Fine-tuning a language model

This Tango example shows how you can train or fine-tune a causal language model like GPT-2 or GPT-J from transformers on WikiText-2 or a similar dataset. It's best to run this experiment on a machine with a GPU and PyTorch properly installed; otherwise Tango will fall back to CPU and training will be extremely slow.

This example also depends on FairScale, which lets you leverage FullyShardedDataParallel (FSDP) and activation checkpointing to fine-tune GPT-J 6B or a similarly sized model. Just set the constants fsdp and activation_checkpointing in the config to true. Without CPU offloading, you'll need at least 4 x 40GiB A100 GPUs or a different configuration with a comparable amount of total GPU memory.
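
If you're not sure whether your machine has enough GPU memory, a quick PyTorch check like the one below (purely illustrative, not part of the example) will tell you how many CUDA devices are visible and how much total memory they provide:

# Illustrative snippet: report visible CUDA devices and their total memory.
# The exact amount of memory you need depends on the model size, the FSDP
# settings, and whether CPU offloading is enabled.
import torch

if not torch.cuda.is_available():
    print("No CUDA devices found; Tango will fall back to CPU and be very slow.")
else:
    total_gib = 0.0
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        gib = props.total_memory / 1024**3
        total_gib += gib
        print(f"GPU {i}: {props.name}, {gib:.1f} GiB")
    print(f"Total GPU memory: {total_gib:.1f} GiB")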

Tip

You can find the full code for this example on GitHub.

Components

We’ll need to write a step for tokenizing the data and preparing it for language model training. All of the other steps we need are provided by Tango integrations.

So, create a file called tokenize_step.py with the following contents:

import datasets

from tango import Step
from tango.integrations.datasets import DatasetsFormat
from tango.integrations.transformers import Tokenizer


# We need a step to tokenize the raw data. The result of this step will be passed
# directly into the "torch::train" step.
@Step.register("tokenize_data")
class TokenizeData(Step):
    DETERMINISTIC = True
    CACHEABLE = True
    FORMAT = DatasetsFormat()

    def run(  # type: ignore[override]
        self,
        dataset: datasets.DatasetDict,
        tokenizer: Tokenizer,
        block_size: int = 1024,
        num_workers: int = 1,
        field_to_tokenize: str = "text",
    ) -> datasets.DatasetDict:
        def tokenize_function(example):
            return tokenizer(example[field_to_tokenize])

        dataset = dataset.map(
            tokenize_function,
            batched=True,
            num_proc=num_workers,
            remove_columns=list(dataset.column_names.values())[0],  # remove all old columns
            desc="Tokenizing dataset",
        )

        def group_texts(examples):
            # Concatenate all texts.
            concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}  # type: ignore
            total_length = len(concatenated_examples[list(examples.keys())[0]])
            # We drop the small remainder; we could pad instead if the model supported it.
            # You can customize this part to your needs.
            if total_length >= block_size:
                total_length = (total_length // block_size) * block_size
            # Split by chunks of max_len.
            result = {
                k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
                for k, t in concatenated_examples.items()
            }
            result["labels"] = result["input_ids"].copy()
            return result

        dataset = dataset.map(
            group_texts,
            batched=True,
            num_proc=num_workers,
            desc=f"Grouping texts into chunks of {block_size}",
        )

        return dataset
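
To get a feel for what this step produces, here's a small standalone sketch (not part of the experiment) that mirrors its logic on a toy in-memory dataset. Every resulting example is a block of block_size token IDs with labels identical to input_ids, which is what causal language model training needs:

# Standalone sketch that mirrors the "tokenize_data" step on a toy dataset.
import datasets
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
block_size = 16  # tiny block size, just for illustration

raw = datasets.Dataset.from_dict({"text": ["Tango is a Python library. " * 10] * 4})
tokenized = raw.map(lambda ex: tokenizer(ex["text"]), batched=True, remove_columns=["text"])

def group_texts(examples):
    # Concatenate everything, then split into `block_size` chunks, dropping the remainder.
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = (len(concatenated["input_ids"]) // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

chunked = tokenized.map(group_texts, batched=True)
example = chunked[0]
print(len(example["input_ids"]))                  # 16
print(example["labels"] == example["input_ids"])  # True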

Configuration file

Next you’ll need to create a configuration file that defines the experiment. Just copy over these contents into a file called config.jsonnet:

##################
# Model settings #
##################

local pretrained_model = "gpt2";
# With 'fsdp' and 'activation_checkpointing' (see constants below), you should be able to train
# a 6B model on 4x ~40GB GPUs:
# local pretrained_model = "EleutherAI/gpt-j-6B";

# This doesn't seem to work with gpt2, but works fine with gpt-j.
local load_with_low_cpu_mem_usage = std.startsWith(pretrained_model, "EleutherAI/gpt-j");

####################
# Trainer settings #
####################

# Trainer settings, adjust to your use-case.
local training_steps = 200;  # total number of optimization steps to train for
local validate_every = 20;  # how often to validate and save checkpoints

local devices = 1;  # number of devices to train on (will use GPUs if enough are available, otherwise CPU)
local grad_accum = 1;  # number of gradient accumulation steps (changes the effective batch size)
# This is the batch size per GPU, ignoring gradient accumulation:
local batch_size = 8;
# So the effective batch size is `batch_size * grad_accum * devices`
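# For example, with batch_size = 8, grad_accum = 4, and devices = 2, each optimizer
# step would see 8 * 4 * 2 = 64 examples.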

local activation_checkpointing = false;  # use activation/gradient checkpointing (probably needed for GPT-J 6B, but not gpt2)
local amp = false;  # use PyTorch's native automatic mixed precision
local fsdp = false;  # use FairScale's FullyShardedDataParallel (probably needed for GPT-J 6B, but not gpt2)
local cpu_offloading = false;  # Can only be used with 'fsdp' - saves a lot of GPU memory by offloading params+gradients to CPU, but is very slow.

######################
# Optimizer settings #
######################

local warmup_steps = 20;
local learning_rate = 0.00005;  # you can probably use a higher LR for a small model like "gpt2"


# <----- you probably don't need to edit below this line ----> #


assert fsdp == true || cpu_offloading == false : "cpu_offloading only available with fsdp";

# FullyShardedDataParallel config:
local fsdp_config = if fsdp then {
    reshard_after_forward: true,
    move_params_to_cpu: cpu_offloading,
    move_grads_to_cpu: cpu_offloading,
    mixed_precision: amp,
} else null;

local training_engine = {
    type: if fsdp then "fairscale" else "torch",
    optimizer: {
        type: "torch::AdamW",
        lr: learning_rate,
        betas: [0.9, 0.95],
        eps: 1e-6,
    },
    lr_scheduler: {
        type: "transformers::linear",
        num_warmup_steps: warmup_steps,
        num_training_steps: training_steps,
    },
    amp: amp,
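    # In Jsonnet a computed field name of `null` means the field is omitted entirely,
    # so the training engine only gets an `fsdp_config` when `fsdp` is enabled: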
    [if fsdp then "fsdp_config" else null]: fsdp_config,
};

local distributed_dataloader = {
    batch_size: batch_size,
    collate_fn: { type: "transformers::DefaultDataCollator" },
    sampler: {
        type: "torch::DistributedSampler",
        shuffle: true,
        drop_last: true,
    },
};

local single_device_dataloader = {
    shuffle: true,
    batch_size: batch_size,
    collate_fn: { type: "transformers::DefaultDataCollator" },
};

local dataloader = if devices > 1 then distributed_dataloader else single_device_dataloader;

{
    steps: {
        raw_data: {
            type: "datasets::load",
            path: "wikitext",
            name: "wikitext-2-raw-v1",
        },
        tokenized_data: {
            type: "tokenize_data",
            dataset: { type: "ref", ref: "raw_data" },
            tokenizer: { pretrained_model_name_or_path: pretrained_model }
        },
        trained_model: {
            type: "torch::train",
            model: {
                type: "fairscale::with_wrapped_modules",
                model: {
                    type: "transformers::AutoModelForCausalLM::from_pretrained",
                    pretrained_model_name_or_path: pretrained_model,
                    low_cpu_mem_usage: load_with_low_cpu_mem_usage,
                },
                modules_to_wrap: ["transformer\\.h\\.[0-9]+"],  # tell FairScale to wrap the transformer's blocks individually
                fsdp_config: fsdp_config,
                activation_checkpointing: activation_checkpointing,
            },
            dataset_dict: { type: "ref", ref: "tokenized_data" },
            train_dataloader: dataloader,
            validation_split: "validation",
            grad_accum: grad_accum,
            train_steps: training_steps,
            validate_every: validate_every,
            checkpoint_every: validate_every,
            log_every: 1,
            device_count: devices,
            training_engine: training_engine,
        },
        final_metrics: {
            type: "torch::eval",
            model: { type: "ref", ref: "trained_model" },
            dataset_dict: { type: "ref", ref: "tokenized_data" },
            dataloader: single_device_dataloader,
            test_split: "test",
        },
    }
}
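
Note how the output of one step is fed into another with { type: "ref", ref: "..." } entries: raw_data flows into tokenized_data, whose output is consumed by both the torch::train and torch::eval steps. Tango uses these references to build the dependency graph between steps and to determine what needs to run.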

Run it

Now we can run the experiment with:

tango run config.jsonnet -i tokenize_step.py -d /tmp/results
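
The -i flag tells Tango to import tokenize_step.py so that our custom "tokenize_data" step gets registered, and -d points Tango at the directory to use for its workspace, where the result of each step is cached. If you run the same configuration again, steps whose inputs haven't changed will be loaded from the cache instead of being re-computed.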