
How to train RWKV on a new dataset using HuggingFace Transformers

Greg Bizup
Jul 17, 2023


RWKV model overview

RWKV (pronounced Rwa-Kuv) is a promising new approach to large language models (LLMs). It was first proposed in this repository. It performs well on conversational tasks, uses resources efficiently during training and inference, and learns quickly. And of course, it's free and open source.

Most LLMs today, such as GPT-4, are based on the Transformer architecture, but RWKV is a recurrent neural network (RNN). Because it processes tokens one at a time while carrying a fixed-size state, its memory use during inference does not grow with the length of the context. This gives it relatively low resource requirements, fast inference, and the ability to handle long input sequences. Training and fine-tuning smaller RWKV models can be done on a GPU with less than 8GB of VRAM.
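To see why the recurrent formulation keeps memory flat, here is a minimal single-channel sketch of the "WKV" weighted average described in the RWKV paper, computed two ways: as a direct sum over every earlier token, and as a recurrence that carries only two running numbers forward. This is a simplification under stated assumptions: one channel, a scalar decay `w > 0` and bonus `u`, and none of the receptance gating, token shifting, or numerical-stability tricks the real model uses. The function names are hypothetical.

```python
import math
from typing import List


def wkv_direct(w: float, u: float, k: List[float], v: List[float]) -> List[float]:
    """Compute each output by summing over every earlier token (work grows with t)."""
    out = []
    for t in range(len(k)):
        # The current token gets the bonus u; earlier tokens decay by w per step of distance.
        num = math.exp(u + k[t]) * v[t]
        den = math.exp(u + k[t])
        for i in range(t):
            weight = math.exp(-(t - 1 - i) * w + k[i])
            num += weight * v[i]
            den += weight
        out.append(num / den)
    return out


def wkv_recurrent(w: float, u: float, k: List[float], v: List[float]) -> List[float]:
    """Compute the same outputs while carrying only a fixed-size state (a, b)."""
    a, b = 0.0, 0.0  # decayed running sums of exp(k_i) * v_i and exp(k_i)
    out = []
    for t in range(len(k)):
        bonus = math.exp(u + k[t])
        out.append((a + bonus * v[t]) / (b + bonus))
        # Fold the current token into the state, decaying all older contributions once.
        a = math.exp(-w) * a + math.exp(k[t]) * v[t]
        b = math.exp(-w) * b + math.exp(k[t])
    return out


if __name__ == "__main__":
    k = [0.1, -0.3, 0.5, 0.2]
    v = [1.0, 2.0, -1.0, 0.5]
    print(wkv_direct(0.5, 0.3, k, v))
    print(wkv_recurrent(0.5, 0.3, k, v))  # same values, up to floating-point rounding
```

The recurrent version does a constant amount of work and stores a constant amount of state per token, however long the sequence gets, which is the property behind RWKV's cheap inference on long inputs.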

It even has its own ChatGPT-like chatbot hosted on HuggingFace. Check out ChatRWKV to see for yourself how capable it is.

Training RWKV with HuggingFace

If you are interested in training this new state-of-the-art LLM, you may have been deterred by the lack of documentation. Moreover, the code in the official repository is hard to follow (a common reality in data science).

Fortunately, HuggingFace Transformers provides a high-level API for causal language modeling with RWKV. It is called RwkvForCausalLM and is documented in the Transformers documentation. The code presented here uses RwkvForCausalLM as a drop-in replacement for the AutoModelForCausalLM used in the HuggingFace causal language modeling tutorial. I highly recommend working through that tutorial.

This article presents a clear-cut way to train a RWKV model for causal language modeling using the Transformers library.

As always, the code can be found on our GitHub. Without further ado, let's get into it.

Installing dependencies

Install the necessary libraries if you haven't already.

pip install transformers torch accelerate datasets numpy 

The code for training RWKV

This code is based on HuggingFace's guide to causal language modeling. I refactored its functions a bit to make the code more modular and easier to work with. If you need to change anything, just edit the constants under the if __name__ == "__main__" block.

Start by importing the necessary libraries.

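from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    RwkvForCausalLM,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset
import torch
import numpy as np

import collections
from typing import Any, Dict
import math
import re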

Now, define some preprocessing functions that do the following:

  1. Remove square brackets and (_URL_X_) tags from the data
  2. Tokenize the data
  3. Concatenate then chunk the data
  4. Make a copy of our data to use as our labels

def remove_url_from_text(text: str):
    """Remove square brackets around linked text and (_URL_0_) after"""
    return re.sub(r"\[|\]|\(_URL_\d+_\)", "", text)

def tokenize_function(examples: Dict[str, Any]) -> Dict[str, Any]:
    """Concatenate and tokenize the answers in flattened ELI5 data"""
    # Uses the module-level `tokenizer` created in the __main__ block
    concatenated = [remove_url_from_text(" ".join(x)) for x in examples["answers.text"]]
    return tokenizer(concatenated)

def chunk(examples: Dict[str, Any], chunk_size: int = 256) -> Dict[str, Any]:
    """Concatenate and chunk batches of data"""
    # Flatten each column (input_ids, attention_mask) into one long list
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    # Drop the remainder so every chunk is exactly chunk_size tokens long
    total_length = (total_length // chunk_size) * chunk_size
    return {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated.items()
    }
def set_labels(examples: Dict[str, Any]) -> Dict[str, Any]:
    """Add a labels column to the dataset which is a copy of input_ids"""
    examples["labels"] = examples["input_ids"].copy()
    return examples
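Why copy input_ids rather than shifting them by one position? Causal LM heads in Transformers, including RwkvForCausalLM, shift the labels internally when computing the loss, so the prediction made at position t (which has seen the whole prefix up to t) is scored against the token at position t + 1. The pure-Python sketch below uses a hypothetical helper, not library code, to list the (input token, target) pairs that shift produces, skipping any label set to -100, the ignore index used by PyTorch's cross-entropy loss by default.

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # label value skipped by the loss


def next_token_pairs(input_ids: List[int], labels: List[int]) -> List[Tuple[int, int]]:
    """Pair each position's input token with the label one position later."""
    pairs = []
    # Drop the last input (nothing follows it) and the first label (nothing predicts it).
    for token_at_t, target_at_t in zip(input_ids[:-1], labels[1:]):
        if target_at_t != IGNORE_INDEX:
            pairs.append((token_at_t, target_at_t))
    return pairs


chunk_ids = [464, 3797, 3332, 319, 262]  # made-up token IDs
labels = chunk_ids.copy()  # what set_labels stores
print(next_token_pairs(chunk_ids, labels))
# [(464, 3797), (3797, 3332), (3332, 319), (319, 262)]
```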

Now that our preprocessing functions are set up, we can prepare the model for training. We define our constants up top, instantiate the model and tokenizer, and load our dataset. Finally, we create a DataCollatorForLanguageModeling. With mlm=False it does not mask any tokens: it pads each batch and builds the labels from input_ids, marking padded positions with -100 so the loss ignores them. The model then learns to predict each next token.

if __name__ == "__main__":
    MODEL_NAME = "sgugger/rwkv-430M-pile"
    DATASET = "eli5"
    CHUNK_SIZE = 128
    BATCH_SIZE = 32
    DATASET_SPLIT = "train_asks[:500]"
    TEST_SPLIT_SIZE = 0.2  # fraction held out for evaluation; 0.2 is an assumed value

    model = RwkvForCausalLM.from_pretrained(MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token

    dataset = load_dataset(DATASET, split=DATASET_SPLIT)
    dataset = dataset.train_test_split(test_size=TEST_SPLIT_SIZE)
    dataset = dataset.flatten()
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
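To make the collator's role concrete, here is a simplified plain-Python sketch of what DataCollatorForLanguageModeling does with mlm=False: pad the batch to its longest example, copy input_ids into labels, and replace pad-token positions in labels with -100. The function name and pad ID are hypothetical, right padding is assumed, and the real collator works through the tokenizer and returns PyTorch tensors rather than lists.

```python
from typing import Dict, List

IGNORE_INDEX = -100  # label value the loss skips


def collate_for_causal_lm(batch: List[List[int]], pad_id: int) -> Dict[str, List[List[int]]]:
    """Right-pad a batch and build labels, mimicking mlm=False collation (simplified)."""
    max_len = max(len(ids) for ids in batch)
    input_ids, attention_mask, labels = [], [], []
    for ids in batch:
        n_pad = max_len - len(ids)
        padded = ids + [pad_id] * n_pad
        input_ids.append(padded)
        attention_mask.append([1] * len(ids) + [0] * n_pad)
        # Labels start as a copy of the inputs; every position holding pad_id is ignored.
        labels.append([IGNORE_INDEX if tok == pad_id else tok for tok in padded])
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


batch = collate_for_causal_lm([[11, 12, 13], [21, 22]], pad_id=0)
print(batch["labels"])  # [[11, 12, 13], [21, 22, -100]]
```

Because the chunks produced by chunk() all have the same length, padding rarely matters in this script. Note that since the script sets tokenizer.pad_token = tokenizer.eos_token, any genuine end-of-sequence token in the data would also be ignored in the loss.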

Now that our model is set up, let's map our preprocessing functions over our dataset.

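    # Encode: join and clean each example's answers, then tokenize
    encoded_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset["train"].column_names,
    )

    # Chunk: concatenate token lists and split them into CHUNK_SIZE blocks
    chunked_dataset = encoded_dataset.map(
        chunk,
        batched=True,
        fn_kwargs={"chunk_size": CHUNK_SIZE},
    )

    # Label: copy input_ids into a labels column
    lm_dataset = chunked_dataset.map(set_labels, batched=True)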

With the dataset preprocessed, we can set up our trainer and start training. We will use perplexity to assess the model's performance. Perplexity is the exponential of the mean cross-entropy loss, and it measures how "surprised" the model is by the evaluation text. Lower values indicate a better model.

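    training_args = TrainingArguments(
        output_dir=MODEL_NAME + "-" + DATASET,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        logging_steps=max(1, len(lm_dataset["train"]) // BATCH_SIZE),
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=lm_dataset["train"],
        eval_dataset=lm_dataset["test"],
        data_collator=data_collator,
    )

    # Evaluate before training to get a baseline
    eval_0 = trainer.evaluate()
    perplexity_0 = math.exp(eval_0["eval_loss"])
    print(f"Perplexity before training: {perplexity_0:.2f}")

    # Train
    trainer.train()

    # Evaluate after training
    eval_f = trainer.evaluate()
    perplexity_f = math.exp(eval_f["eval_loss"])
    print(f"Perplexity after training: {perplexity_f:.2f}")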

You can run this script as-is or make adjustments as necessary.
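To build intuition for the perplexity numbers the script prints, the sketch below computes perplexity by hand from the probabilities a model assigned to each correct next token. The mean negative log-probability is the cross-entropy loss, and its exponential is the perplexity. The probabilities are made up for illustration.

```python
import math
from typing import List


def perplexity(target_probs: List[float]) -> float:
    """exp of the mean negative log-probability assigned to each correct next token."""
    mean_nll = -sum(math.log(p) for p in target_probs) / len(target_probs)  # cross-entropy loss
    return math.exp(mean_nll)


# A model that is uniformly unsure between 4 options has perplexity of about 4 ...
print(perplexity([0.25, 0.25, 0.25]))
# ... while a more confident model scores lower (about 1.41 here).
print(perplexity([0.9, 0.5, 0.8]))
```

The math.exp(eval_loss) calls in the script apply the same idea, with the Trainer's eval_loss playing the role of mean_nll.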

Dealing with CUDA out of memory errors

If you run into "CUDA out of memory" errors, there are a few ways to reduce the model's memory footprint: reduce the batch size, use gradient accumulation, and use mixed precision training. With the Trainer, these correspond to the per_device_train_batch_size, gradient_accumulation_steps, and fp16 (or bf16) arguments of TrainingArguments. HuggingFace's guide on efficient training on a single GPU covers these techniques and more.
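Gradient accumulation trades time for memory: instead of one large batch, training runs several smaller micro-batches, combines their gradients, and then takes a single optimizer step. The toy sketch below uses a one-parameter linear model with a mean-squared-error loss (an assumption chosen for readability, nothing RWKV-specific) to show that averaging the gradients of equally sized micro-batches reproduces the full-batch gradient.

```python
from typing import List


def mse_grad(w: float, xs: List[float], ys: List[float]) -> float:
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w: float, xs: List[float], ys: List[float], micro_batch: int) -> float:
    """Average gradients over equally sized micro-batches, as gradient accumulation does."""
    steps = len(xs) // micro_batch  # assumes len(xs) is divisible by micro_batch
    total = 0.0
    for s in range(steps):
        start, end = s * micro_batch, (s + 1) * micro_batch
        total += mse_grad(w, xs[start:end], ys[start:end])  # one small forward/backward pass
    return total / steps


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 7.8, 10.1, 12.2, 13.8, 16.1]
print(mse_grad(0.5, xs, ys))              # one full batch of 8
print(accumulated_grad(0.5, xs, ys, 2))   # 4 micro-batches of 2, same value up to rounding
```

In the script, this corresponds to lowering per_device_train_batch_size and raising gradient_accumulation_steps so that their product stays at BATCH_SIZE, for example 8 x 4 = 32.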