The RWKV Language Model

Overview

RWKV-LM

We propose the RWKV language model, with alternating time-mix and channel-mix layers:

\begin{align*}
\text{Time-mix :} && \text{TM}_{t,c} &&=&&\text{sigmoid}(\text{R}_{t,c}) &&\cdot&& &&\textstyle\sum_{u} &&\textbf{W}_{t,u,c} &&\cdot&& \text{softmax}_t(\text{K}_{u,c}) &&\cdot&& \text{V}_{u,c}\\
\text{Channel-mix :} && \text{CM}_{t,c} &&=&&\text{sigmoid}(\text{R}_{t,c}) &&\cdot&& &&\textstyle\sum_d &&\textbf{W}_{c,d} &&\cdot&& \text{gelu}(\text{K}_{t,d}) &&\cdot&& \text{V}_{t,d}
\end{align*}
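As a rough illustration of the Channel-mix formula, here is a minimal NumPy sketch. It assumes R, K, and V come from bias-free linear maps of the input. The names `channel_mix`, `Wr`, `Wk`, `Wv` and the tanh approximation of GELU are choices made for this example, not taken from the repository.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gelu(x):
    # tanh approximation of GELU (an assumption; the exact variant is not stated)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def channel_mix(x, Wr, Wk, Wv, W):
    """x: (T, C), Wr: (C, C), Wk and Wv: (C, D), W: (C, D). Returns (T, C)."""
    R = x @ Wr                 # R[t, c]
    K = x @ Wk                 # K[t, d]
    V = x @ Wv                 # V[t, d]
    hidden = gelu(K) * V       # gelu(K[t, d]) * V[t, d]
    # sum_d W[c, d] * hidden[t, d], gated by sigmoid(R[t, c])
    return sigmoid(R) * (hidden @ W.T)
```

Each position t is processed independently here, so this layer mixes information across channels only.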

  • R, K, and V are generated by linear transforms of the input, and W is a parameter. The idea of RWKV is to decompose attention into R(target) * W(src, target) * K(src). We therefore call R the "receptance"; the sigmoid keeps it in the range 0 to 1.

  • The Time-mix is similar to AFT (https://arxiv.org/abs/2105.14103). It differs in the following ways.

(1) We changed the normalization (denominator). For masked language models, we define:

\text{softmax}_t(\text{K}_{u,c}) = \frac{\exp(\text{K}_{u,c})}{\sum_{v \leq t}\exp(\text{K}_{v,c})}

(2) We decompose W_{t,u,c} and introduce multi-head W (here h is the corresponding head of c):

W_{t,u,c}=f_h(t-u)\cdot \alpha_h(u) \cdot \beta_h(t)

(3) You don't need LayerNorm for Time-mix. In fact, the model converges faster when LayerNorm is removed.

Moreover, we multiply the final output of the Time-mix layer by γ(t). The α, β, and γ factors are there because the context size is smaller when t is small, and they can compensate for this.
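The Time-mix pieces described above (the causal normalization, the decomposed W, and the γ(t) output factor) can be put together in a short NumPy sketch. This is an unoptimized illustration under stated assumptions, not the repository's implementation: the sum over u is restricted to u ≤ t, channels are assigned to heads in contiguous blocks, and f, α, β, γ are passed in as plain arrays. The function and argument names are hypothetical.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def time_mix(R, K, V, f, alpha, beta, gamma):
    """R, K, V: (T, C). f, alpha, beta: (H, T). gamma: (T,).

    f[h, d] holds f_h(d) for the distance d = t - u >= 0.
    """
    T, C = K.shape
    H = f.shape[0]
    # Assumption: channels are split into H contiguous heads.
    head = np.arange(C) * H // C                      # head index of each channel
    f_c, alpha_c, beta_c = f[head], alpha[head], beta[head]   # each (C, T)

    # Causal normalization: denom[t, c] = sum_{v <= t} exp(K[v, c]).
    # Subtracting a per-channel constant cancels in the ratio and avoids overflow.
    expK = np.exp(K - K.max(axis=0, keepdims=True))
    denom = np.cumsum(expK, axis=0)

    out = np.zeros((T, C))
    for t in range(T):
        u = np.arange(t + 1)                          # assumption: only u <= t contributes
        # W[u, c] = f_h(t - u) * alpha_h(u) * beta_h(t)
        W = (f_c[:, t - u] * alpha_c[:, u]).T * beta_c[:, t]  # (t + 1, C)
        soft = expK[: t + 1] / denom[t]               # softmax_t(K[u, c])
        out[t] = gamma[t] * sigmoid(R[t]) * np.sum(W * soft * V[: t + 1], axis=0)
    return out
```

With f, α, β, and γ all set to one and V constant, the weighted sum collapses to that constant, so the output reduces to sigmoid(R) times it. That makes a convenient sanity check.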


We also propose a new sampling method (as in src/utils.py):

(1) Find the max probability p_max after softmax.

(2) Remove all entries whose probability is lower than 0.02 * pow(p_max, 2).

(3) Feel free to tune the factor 0.02 and the exponent 2.
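The steps above can be sketched in a few lines of NumPy. This is a minimal illustration, not the code in src/utils.py: the function names are hypothetical, and renormalizing the surviving probabilities before drawing a token is an assumption, since the steps do not say what happens after the removal.

```python
import numpy as np

def filter_probs(logits, factor=0.02, power=2.0):
    """Softmax the logits, then zero out entries below factor * p_max ** power."""
    logits = np.asarray(logits, dtype=np.float64)
    probs = np.exp(logits - logits.max())      # numerically stable softmax
    probs /= probs.sum()
    p_max = probs.max()                        # step (1)
    threshold = factor * p_max ** power        # step (2)
    probs = np.where(probs < threshold, 0.0, probs)
    return probs / probs.sum()                 # assumption: renormalize the survivors

def sample_token(logits, factor=0.02, power=2.0, rng=None):
    """Draw one token index from the filtered distribution."""
    rng = np.random.default_rng() if rng is None else rng
    probs = filter_probs(logits, factor, power)
    return int(rng.choice(len(probs), p=probs))
```

Because the threshold scales with the square of p_max, a confident prediction (p_max near 1) prunes more aggressively than a flat distribution, where the threshold becomes very small and almost every entry survives.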


Training loss, RWKV vs MHA+Rotary+GeGLU:

[Figure: RWKV-vs-MHA]

(This is character-level loss on the simplebooks-92 dataset: https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip)

Comments
  • Sequence to Sequence?

    Hey @BlinkDL! Awesome project!

    I was wondering if you have performed any Seq-2-Seq experiments with it? Any reason for going with a GPT model in the first place as opposed to something like T5 (standard Transformer)? Any direction on what changes would be required to build a standard encoder-decoder architecture with RWKV?

    Also, is there any report on the in-context learning / few-shot learning (FSL) capability of the latest trained model?

    opened by SushantDaga 2
  • v4 model.py vs model_run.py

    Hi, thanks for this awesome repo! I'm trying to understand the code and found that the v4 folder has both model.py and model_run.py, which contain GPT and RWKV_GPT respectively, and these use different initialization methods. Could you elaborate on when each one should be used? Thanks in advance!

    opened by jingweiz 3
  • RWKV-4 169m/430m in browser with ORT Web / TF.js / tfjs-tflite?

    Hi, really exciting project! I'm wondering if you've published the model conversion script that you used to create the js_models files from the .pth model file? It would be awesome to see how the larger and newer models like RWKV-4 169m/430m perform in the browser! I think the inference speed of RWKV opens up many new possibilities for language models on the web.

    opened by josephrocca 32
  • CUDA compilation error with Ctx Length>2000

    Hello, I am trying out RWKV with audio modality and when I set T_MAX>>1000, it throws this error:

    Emitting ninja build file /root/.cache/torch_extensions/py39_cu116/timex/build.ninja...
    Building extension module timex...
    Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
    [1/2] /usr/local/cuda/bin/nvcc  -DTORCH_EXTENSION_NAME=timex -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1013\" -isystem /root/anaconda3/envs/surya-env/lib/python3.9/site-packages/torch/include -isystem /root/anaconda3/envs/surya-env/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -isystem /root/anaconda3/envs/surya-env/lib/python3.9/site-packages/torch/include/TH -isystem /root/anaconda3/envs/surya-env/lib/python3.9/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /root/anaconda3/envs/surya-env/include/python3.9 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 --compiler-options '-fPIC' --use_fast_math --extra-device-vectorization -DTmax=10000 -DBF=8 -DBB=2 -std=c++14 -c cuda/timex_cuda.cu -o timex_cuda.cuda.o 
    FAILED: timex_cuda.cuda.o 
    /usr/local/cuda/bin/nvcc  -DTORCH_EXTENSION_NAME=timex -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1013\" -isystem /root/anaconda3/envs/surya-env/lib/python3.9/site-packages/torch/include -isystem /root/anaconda3/envs/surya-env/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -isystem /root/anaconda3/envs/surya-env/lib/python3.9/site-packages/torch/include/TH -isystem /root/anaconda3/envs/surya-env/lib/python3.9/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /root/anaconda3/envs/surya-env/include/python3.9 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 --compiler-options '-fPIC' --use_fast_math --extra-device-vectorization -DTmax=10000 -DBF=8 -DBB=2 -std=c++14 -c cuda/timex_cuda.cu -o timex_cuda.cuda.o 
    ptxas error   : Entry function '_Z15kernel_backwardIfEvPKT_S2_S2_PS0_S3_iii' uses too much shared data (0x30d40 bytes, 0xc000 max)
    ptxas error   : Entry function '_Z14kernel_forwardIfEvPKT_S2_PS0_S0_iii' uses too much shared data (0x57e40 bytes, 0xc000 max)
    ninja: build stopped: subcommand failed.
    

    GPU: A100, VRAM: 42GB, CUDA 11.6

    I am okay if the training takes a bit long. But I need this to work. Don't know any CUDA. Can you suggest some workarounds?

    Thanks for the incredible work btw!

    opened by ojus1 8
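  • About using the model for classification tasks

    Hello, author! I'm very interested in this work, because I'm currently using transformer-based models for classification tasks. For classification, a transformer or RNN usually takes the last element of each channel of the final block as the output and maps it to the classes through a fully connected layer. Do you think RWKV works on a similar principle? Is it still safe to take the last element as the output? I hope you can offer some suggestions; I would be very grateful!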

    opened by louisinhit 2
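For the CUDA compilation error reported in the comments above, it can help to decode the hexadecimal sizes in the ptxas messages. The short script below does only that arithmetic, using the values from the log (`-DTmax=10000`, `0xc000 max`). The estimate of the largest Tmax that would fit is a guess that assumes shared-memory use grows in proportion to Tmax with the other build flags unchanged. The log itself does not confirm this.

```python
# Values copied from the ptxas errors and the nvcc command line in the log.
LIMIT = 0xC000        # "0xc000 max" shared-memory bytes per entry function
T_MAX = 10000         # -DTmax=10000
REQUESTED = {"kernel_backward": 0x30D40, "kernel_forward": 0x57E40}

def largest_fitting_t_max(requested_bytes, t_max=T_MAX, limit=LIMIT):
    # Assumption: shared-memory use is proportional to Tmax.
    return limit * t_max // requested_bytes

for name, nbytes in REQUESTED.items():
    print(
        f"{name}: {nbytes} bytes requested (limit {LIMIT}), "
        f"{nbytes // T_MAX} bytes per unit of Tmax, "
        f"estimated largest Tmax: {largest_fitting_t_max(nbytes)}"
    )
```

In decimal, the kernels request 200000 and 360000 bytes against a limit of 49152 bytes. Under the proportionality assumption, the forward kernel would be the binding constraint.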
Releases(4.00)
Owner
PENG Bo
http://zhihu.com/people/bopengbopeng