RWKV model overview
RWKV (pronounced Rwa-Kuv) is the next breakthrough in large language models (LLMs). It was first proposed in this repository. It performs incredibly well in conversational tasks, uses resources efficiently in both training and inference, and learns quickly. And of course, it's free and open source.
Most LLMs today, such as GPT-4, are based on the Transformer architecture, but RWKV is a recurrent neural network (RNN). Because it is an RNN, it has relatively low resource requirements, fast inference, and can handle long input sequences. Training and fine-tuning can be done on a GPU with less than 8 GB of VRAM.
It even has its own ChatGPT-like chatbot hosted on HuggingFace. Check out ChatRWKV to see for yourself just how powerful it is.
Training RWKV with HuggingFace
If you are interested in training this new state-of-the-art LLM, you may have been deterred by the lack of documentation. Moreover, the code in the official repository is hard to follow (a common reality in data science).
Fortunately, HuggingFace provides a high-level API for causal language modeling with RWKV. It is called RwkvForCausalLM and is documented here. The code presented here uses RwkvForCausalLM as a drop-in replacement for the AutoModelForCausalLM used in the HuggingFace causal LM tutorial. I highly recommend working through that tutorial.
This article presents a clear-cut way to train an RWKV model for causal language modeling. We will be using the Transformers library.
As always, the code can be found on our GitHub. Without further ado, let's get into it.
Installing dependencies
Install the necessary libraries if you haven't already.
pip install transformers torch accelerate datasets numpy
The code for training RWKV
This code was developed based on HuggingFace's guide for causal language modeling. I refactored their methods a bit to produce a more modular script that is easier to work with. If you need to change anything, just edit the constants under the if __name__ == "__main__" block.
Start by importing the necessary libraries.
from transformers import (
RwkvForCausalLM,
RwkvConfig,
AutoTokenizer,
DataCollatorForLanguageModeling,
TrainingArguments,
Trainer
)
from datasets import load_dataset
import torch
import numpy as np
import collections
from typing import Any, Dict
import math
Now, define some preprocessing functions that do the following:
- Remove square brackets and (_URL_X_) tags from the data
- Tokenize the data
- Concatenate then chunk the data
- Make a copy of our data to use as our labels
def remove_url_from_text(text: str):
    """Remove square brackets around linked text and (_URL_0_) after"""
    return re.sub(r"\[|\]|\(_URL_\d+_\)", "", text)


def tokenize_function(examples: Dict[str, Any]) -> Dict[str, Any]:
    """Concatenate and tokenize the answers in flattened ELI5 data"""
    concatenated = [remove_url_from_text(" ".join(x)) for x in examples["answers.text"]]
    return tokenizer(concatenated)


def chunk(examples: Dict[str, Any], chunk_size: int = 256) -> Dict[str, Any]:
    """Concatenate and chunk batches of data"""
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    total_length = (total_length // chunk_size) * chunk_size
    return {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated.items()
    }


def set_labels(examples: Dict[str, Any]) -> Dict[str, Any]:
    """Add a labels column to the dataset which is a copy of input_ids"""
    examples["labels"] = examples["input_ids"].copy()
    return examples
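To get a feel for what these helpers do, here is a quick sketch on made-up inputs (the string and token IDs below are hypothetical, not taken from ELI5):
# Hypothetical inputs, purely to illustrate the preprocessing helpers
print(remove_url_from_text("Read [this explanation](_URL_0_) for more detail"))
# -> "Read this explanation for more detail"

toy_batch = {"input_ids": [[1, 2, 3], [4, 5, 6, 7]]}
print(chunk(toy_batch, chunk_size=2))
# -> {"input_ids": [[1, 2], [3, 4], [5, 6]]} (the leftover token 7 is dropped)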
Now that we have our preprocessing functions set up, we can set up the model for training. We will define our constants up top, instantiate the model and tokenizer, and load our dataset. Finally, we will create a DataCollatorForLanguageModeling with mlm=False. Rather than masking tokens, this collator dynamically pads each batch and prepares the labels for causal language modeling (padding positions are ignored by the loss), so the model learns to predict each next token in the sequence.
if __name__ == "__main__":
    MODEL_NAME = "sgugger/rwkv-430M-pile"
    DATASET = "eli5"
    CHUNK_SIZE = 128
    TEST_SPLIT_SIZE = 0.2
    BATCH_SIZE = 32
    DATASET_SPLIT = "train_asks[:500]"

    model = RwkvForCausalLM.from_pretrained(MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token

    dataset = load_dataset(DATASET, split=DATASET_SPLIT)
    dataset = dataset.train_test_split(test_size=TEST_SPLIT_SIZE)
    dataset = dataset.flatten()

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
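If you are curious about what the collator actually produces, you can add a quick sanity check right after creating it. The two sentences below are made up purely for illustration:
    # Optional sanity check: the collator pads the batch and copies input_ids
    # into labels, setting padded positions to -100 so the loss ignores them
    example_batch = data_collator(
        [tokenizer("Hello world"), tokenizer("A slightly longer example sentence")]
    )
    print(example_batch["input_ids"].shape)  # torch.Size([2, longest_sequence_in_batch])
    print(example_batch["labels"][0])        # trailing -100s where padding was added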
Now that our model is set up, let's map our preprocessing functions over our dataset.
    # Encode
    encoded_dataset = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=4,
        remove_columns=dataset["train"].column_names,
    )

    # Chunk
    chunked_dataset = encoded_dataset.map(
        chunk,
        fn_kwargs={"chunk_size": CHUNK_SIZE},
        batched=True,
    )

    # Label
    lm_dataset = chunked_dataset.map(
        set_labels,
        batched=True,
    )
With the dataset preprocessed, we can set up our trainer and start training. We will use perplexity to assess our model's performance. Perplexity is simply the exponential of the cross-entropy loss, and it tells us how "surprised" the model is by the text it sees. Lower values indicate a better model.
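For intuition, here is a toy calculation (the loss value is made up, not a result from this model):
import math
# A loss of 3.0 corresponds to a perplexity of exp(3.0) ≈ 20.09, i.e. the model
# is roughly as uncertain as if it were choosing uniformly among ~20 tokens
print(math.exp(3.0))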
    training_args = TrainingArguments(
        output_dir=MODEL_NAME + "-" + DATASET,
        overwrite_output_dir=True,
        evaluation_strategy="epoch",
        learning_rate=2e-5,
        weight_decay=0.01,
        per_device_train_batch_size=BATCH_SIZE,
        push_to_hub=False,
        logging_steps=len(lm_dataset["train"]) // BATCH_SIZE,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=lm_dataset["train"],
        eval_dataset=lm_dataset["test"],
        data_collator=data_collator,
    )

    # Evaluate before training
    eval_0 = trainer.evaluate()
    perplexity_0 = math.exp(eval_0["eval_loss"])

    # Train
    trainer.train()

    # Evaluate after training
    eval_f = trainer.evaluate()
    perplexity_f = math.exp(eval_f["eval_loss"])

    print(f"Perplexity before training: {perplexity_0:.2f}")
    print(f"Perplexity after training: {perplexity_f:.2f}")
You can run this script as-is or make adjustments as necessary.
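If you want a quick qualitative check once training finishes, you can save the model and generate a short completion at the end of the script. The prompt below is just an example; adjust the generation settings as you see fit:
    # Save the fine-tuned model and try a short generation (example prompt)
    trainer.save_model(training_args.output_dir)
    prompt = "Why is the sky blue?"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(inputs["input_ids"], max_new_tokens=64)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))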
Dealing with CUDA out of memory errors
If you are running into "CUDA out of memory" errors, there are a few things you can do to reduce the training memory footprint: reduce your batch size, use gradient accumulation, and enable mixed precision training. To learn more, check out this wonderful article from HuggingFace on how to efficiently train on a single GPU.
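As a rough sketch, those three remedies map onto TrainingArguments like this (the values are illustrative only; tune them for your hardware):
    # Illustrative memory-saving settings; the numbers are examples, not recommendations
    training_args = TrainingArguments(
        output_dir=MODEL_NAME + "-" + DATASET,
        overwrite_output_dir=True,
        evaluation_strategy="epoch",
        learning_rate=2e-5,
        weight_decay=0.01,
        per_device_train_batch_size=4,   # smaller batches use less VRAM
        gradient_accumulation_steps=8,   # keeps the effective batch size at 32
        fp16=True,                       # mixed precision (requires a CUDA GPU)
        push_to_hub=False,
    )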