
Scaling Large (Language) Models with PyTorch Lightning

In this blog, you will learn techniques for training large models like Llama (or any LLM) and Stable Diffusion using the distributed training strategy FSDP with PyTorch Lightning.

In recent times, there has been a notable shift in the scale of models, particularly in the realm of language models such as GPT-4, Llama, and Falcon. Various research teams have embarked on training extremely large language models (LLMs) with billions of parameters on massive datasets. These models serve a multitude of purposes, including classification, generation, retrieval, and chat-based applications in various downstream tasks. The open-source community is actively fine-tuning these large models to adapt them to diverse scenarios and to optimize them for faster inference.

These large models are not confined to language modeling. They are also used for image generation and even in genomic research, where they play a crucial role. An example of this is HyenaDNA, a genomic foundation model capable of handling context lengths of up to 1 million tokens at single-nucleotide resolution, implemented using PyTorch Lightning.

Training Billion Parameter Large Models

Training a large model is often a complex endeavor, mainly because of various challenges such as the notorious CUDA Out of Memory (OOM) issue, the need for distributed training on a large scale, and cost management. PyTorch Lightning not only eliminates the need for writing repetitive code for distributed training, handling GPU/TPU device management, and configuring precision settings but also offers scalability for training trillion-parameter models. It has been extensively tested in the training of foundational models like Stable Diffusion, Nvidia’s NeMo, and Stanford’s HyenaDNA, and is trusted by hundreds, if not thousands, of companies worldwide.

There is often a perception that when using a high-level library such as PyTorch Lightning, developers forfeit control over the program. However, PyTorch Lightning not only grants you control over most, if not all, aspects of the training process through its comprehensive hooks and callbacks at each stage of training, testing, and validation, but it also allows you to override and customize the training behavior to your liking. With its battle-tested Trainer and LightningModule, PyTorch Lightning assists in preventing human errors by optimizing and managing the engineering of the model training process.

💡 If you need even more flexibility in managing your training process, have a look at Lightning Fabric, which, simply put, exposes the internals of PyTorch Lightning directly to you for managing distributed training and precision settings without the need to use LightningModule and Trainer.
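For a taste of that lower-level API, here is a minimal Fabric sketch; the toy model, data, and hyperparameters are placeholders for illustration only, not part of the original training scripts in this post.

import torch
import torch.nn as nn
from lightning.fabric import Fabric

# Fabric handles process launching, device placement, sharding, and precision;
# you keep writing the raw PyTorch training loop yourself.
fabric = Fabric(accelerator="cuda", devices=4, strategy="fsdp", precision="bf16-mixed")
fabric.launch()

# Toy model and optimizer purely for illustration; swap in your own.
model = fabric.setup_module(nn.Linear(32, 2))
optimizer = fabric.setup_optimizers(torch.optim.SGD(model.parameters(), lr=0.01))

for _ in range(10):
    batch = torch.randn(8, 32, device=fabric.device)
    optimizer.zero_grad()
    loss = model(batch).sum()
    fabric.backward(loss)  # replaces loss.backward()
    optimizer.step()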

 

CUDA Out of Memory with Large (Language) Models

Every machine learning engineer has encountered this error at least once, and even more often when training a large model. Let’s understand when and why we face this issue with an example.

Suppose we aim to train a Llama 7B variant. This model itself comprises 7 billion parameters. When we opt to train it using float16 or bfloat16 precision, often referred to as half precision, the mere act of loading the model into GPU memory necessitates approximately 14 gigabytes (GB). This calculation, however, does not account for the additional memory needed for the optimizer and gradients.

The optimizer maintains its own state for every model parameter (Adam, for example, keeps two extra tensors per parameter), which consumes additional GPU memory. Furthermore, the model retains gradients for the purpose of backpropagation, further contributing to memory usage. When we factor in the model parameters, optimizer state, and gradients, the total memory requirement surges to over 40 GB, which already matches the typical capacity of a GPU like the A100.
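As a rough back-of-the-envelope check (exact numbers depend on the optimizer and implementation details), the arithmetic looks like this:

# Ballpark memory estimate for a 7B-parameter model trained with Adam in half precision.
# Activations, buffers, and memory fragmentation are not counted here.
params = 7e9
bytes_per_value = 2                                   # fp16/bf16

weights_gb = params * bytes_per_value / 1e9           # ~14 GB just to load the model
grads_gb = params * bytes_per_value / 1e9             # ~14 GB for gradients
adam_states_gb = 2 * params * bytes_per_value / 1e9   # Adam keeps two extra tensors per parameter

print(f"~{weights_gb + grads_gb + adam_states_gb:.0f} GB")  # well over 40 GB before activations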

Using PyTorch Lightning, we can reduce the memory requirement by sharding the model across multiple devices and offloading parameters. The PyTorch Lightning Trainer supports the FSDP (Fully Sharded Data Parallel) strategy, offering a straightforward API to address this memory constraint.

import pytorch_lightning as pl
trainer = pl.Trainer(accelerator="gpu", devices=4, strategy="fsdp")

Optimizing Large Model Training

FSDP is a data-parallel training technique that shards the model’s parameters, gradients, and optimizer states across data-parallel workers, with the option to offload the sharded model parameters to the CPU. This saves GPU memory and enables training large models at massive scale.

We’ll explore how to perform optimized training for the Llama 7B model. But first, we’ll delve into the FSDP (Fully Sharded Data Parallel) strategy in PyTorch Lightning using a 2B transformer model. Afterward, we’ll apply our newfound knowledge to train the Llama model.

Training on a single GPU

We set up a Transformer training script on the WikiText2 dataset. We create a LightningModule named LanguageModel that defines the training step and configures the optimizer. We then create a dataloader and configure the Trainer; all device and training-strategy-related arguments are provided to the Trainer.


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.demos import Transformer, WikiText2
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.callbacks import DeviceStatsMonitor


# initialize the model and optimizer and define the training step
class LanguageModel(pl.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.model = Transformer(
            vocab_size=vocab_size,
            nlayers=64,
            nhid=4096,
            ninp=1024,
            nhead=64,
        )

    def training_step(self, batch):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.1)


pl.seed_everything(42)

# Data
dataset = WikiText2()
train_dataloader = DataLoader(dataset, batch_size=64)

# Model
model = LanguageModel(vocab_size=dataset.vocab_size)

# Trainer
trainer = pl.Trainer(accelerator="cuda", devices=1)
trainer.fit(model, train_dataloader)
trainer.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

After running the training script, we observe that it consumed 39.02 GB of GPU memory for a 2B-parameter model. If you are running this on an A10G 24 GB GPU, you are already out of luck. In the following section, we will see how to overcome this limitation.
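As an aside, the script above imports RichProgressBar and DeviceStatsMonitor but never uses them; if you want to watch device memory and utilization while the run progresses, you could attach them to the Trainer as callbacks, for example:

# Optional: richer progress bar plus per-device stats logged during training.
trainer = pl.Trainer(
    accelerator="cuda",
    devices=1,
    callbacks=[RichProgressBar(), DeviceStatsMonitor()],
)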

Sharding the training on multiple GPUs

We’ll distribute the model across multiple devices using the FSDP strategy. To employ 4 GPUs with FSDP, all that’s required is to update our Trainer arguments by setting devices=4 and strategy="fsdp". This uses the default FSDP settings in PyTorch Lightning and sets up our model and optimizer for distributed training.


trainer = pl.Trainer(accelerator="cuda", devices=4, strategy="fsdp")

After the execution of the revised script, it utilizes approximately 29.30 GB of memory per GPU. We’ve managed to decrease memory usage by 24.8%, but it remains insufficient for training on an A10G with 24GB of memory.

FSDP allows wrapping layers in a nested way so that only the layers in a single FSDP instance need to gather their full parameters onto a device during forward or backward computation. To optimize our memory usage further, we will import the FSDPStrategy class and configure auto_wrap_policy.


from pytorch_lightning.strategies import FSDPStrategy
import pytorch_lightning as pl

strategy = FSDPStrategy(
    auto_wrap_policy={
        nn.TransformerEncoderLayer,
        nn.TransformerDecoderLayer,
    },
    activation_checkpointing_policy=None,
)
trainer = pl.Trainer(accelerator="cuda", devices=4, strategy=strategy)
trainer.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

This reduces the memory usage to ~13 GB per GPU, so now we can train on A10G 24 GB GPUs.

Activation Checkpointing

We can save more memory using activation checkpointing, where we trade memory for compute. Normally, the model stores the intermediate activations of the entire computation graph for backpropagation. With checkpointing, we don’t save these intermediate activations; instead, we recompute them during the backward pass. With PyTorch Lightning, we can enable activation checkpointing similarly to how we set the auto_wrap_policy.


layers = {
    nn.TransformerEncoderLayer,
    nn.TransformerDecoderLayer,
}
strategy = FSDPStrategy(
    auto_wrap_policy=layers,
    activation_checkpointing_policy=layers,  # enables activation checkpointing for the given layers
)
trainer = pl.Trainer(accelerator="cuda", devices=4, strategy=strategy)

After enabling activation checkpointing, we were able to reduce the GPU memory usage to 12.5 GB. Note that this comes with an increase in training time.
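For intuition about the trade-off, plain PyTorch exposes the same mechanism through torch.utils.checkpoint: activations inside the wrapped block are discarded after the forward pass and recomputed during the backward pass. A small standalone sketch, separate from the Lightning code above:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
x = torch.randn(8, 1024, requires_grad=True)

# The activations inside `block` are not stored; they are recomputed when backward() runs,
# saving memory at the cost of an extra forward pass through the block.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()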

You can find the full training script here.

CPU Offloading

You can reduce the GPU memory requirement drastically by using CPU offloading, where the parameters are offloaded to the CPU. This method slows down training and should only be used if you have enough CPU memory and you are unable to scale the model after applying all the other strategies, including checkpointing, parallelization, and mixed precision. To enable CPU offloading, pass cpu_offload=True to the FSDPStrategy class.


strategy = FSDPStrategy(
    auto_wrap_policy=layers,
    activation_checkpointing_policy=layers,
    cpu_offload=True,
)
trainer = pl.Trainer(accelerator="cuda", devices=4, strategy=strategy)

Training a Llama 7B

In the section above, we saw that we can reduce the GPU memory from 39 GB to 12.5 GB by using a distributed training strategy and activation checkpointing with FSDP. We will apply all the techniques we used for training our 2B-parameter Transformer model to train the Llama 7B model. In addition, we will use half precision, which reduces memory by storing and computing in a lower-precision format. We won’t go deep into precision settings, but you can read more about them in the Accelerating Large Language Models with Mixed-Precision Techniques blog.
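In the Trainer, enabling half precision is a single argument; for example (an illustrative line, matching the bf16 setting used in the Llama script below):

# "bf16-true" keeps weights and computation in bfloat16,
# while "bf16-mixed" keeps float32 weights and autocasts the computation.
trainer = pl.Trainer(accelerator="cuda", devices=4, strategy="fsdp", precision="bf16-true")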

We will use Lit-GPT to initialize our Llama 7B model. Similar to the LightningModule for the Transformer, we implement configure_model and configure_optimizers to initialize the model and optimizer, and we implement the training and validation step functions.


import torch
import lightning as L
from typing import Any, Optional

from lit_gpt.model import GPT
# Config, the hyperparameters (learning_rate, weight_decay, beta1, beta2, decay_lr),
# and the helpers get_lr and chunked_cross_entropy come from the full training script.


class LightningGPTModule(L.LightningModule):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.config = config
        self.module: Optional[torch.nn.Module] = None

    def configure_model(self) -> None:
        self.module = GPT(self.config)
        self.module.apply(self.module._init_weights)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.AdamW(
            self.module.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(beta1, beta2),
            foreach=False,
        )

    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
        if not decay_lr:
            return
        # determine and set the learning rate for this iteration
        lr = get_lr(self.trainer.fit_loop.total_batch_idx)
        for optimizer in self.trainer.strategy.optimizers:
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

    def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
        input_ids, targets = batch
        logits = self.module(input_ids)
        loss = chunked_cross_entropy(logits, targets, chunk_size=0)
        self.log("train_loss", loss, on_step=True, on_epoch=False, prog_bar=True)
        return loss

    def validation_step(self, batch: Any, batch_idx: int) -> None:
        input_ids, targets = batch
        logits = self.module(input_ids)
        loss = chunked_cross_entropy(logits, targets, chunk_size=0)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)

We create a Trainer with all the configurations and finally run the training using trainer.fit(model, train_dataloader, val_dataloader).


from lit_gpt.model import Block

strategy = FSDPStrategy(
    auto_wrap_policy={Block},
    activation_checkpointing_policy={Block},
    # the argument is not available in the Trainer strategy, but it's the default anyways
    # state_dict_type="full",
    limit_all_gathers=True,
    cpu_offload=False,
)

trainer = L.Trainer(
    devices=devices,
    strategy=strategy,
    precision="bf16-true",
    max_steps=max_iters,
    max_epochs=1,
    limit_val_batches=eval_iters,
    accumulate_grad_batches=gradient_accumulation_steps,
    log_every_n_steps=log_interval,
    val_check_interval=eval_interval,
)

trainer.fit(model, train_dataloader, val_dataloader, ckpt_path="last")

You can find the full training script here. Feel free to download the script and give it a spin. You can experiment with different FSDP configurations, such as cpu_offload, and with precision settings to scale on your hardware.
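As a starting point for such experiments, you could combine the knobs from this post into a single configuration and adjust from there; the values below are illustrative, not a recommendation:

import lightning as L
from lightning.pytorch.strategies import FSDPStrategy
from lit_gpt.model import Block

strategy = FSDPStrategy(
    auto_wrap_policy={Block},
    activation_checkpointing_policy={Block},
    cpu_offload=True,             # slower, but frees GPU memory if you still hit OOM
    limit_all_gathers=True,
)
trainer = L.Trainer(
    devices=4,
    strategy=strategy,
    precision="bf16-mixed",       # or "bf16-true" / "16-mixed", depending on your hardware
    accumulate_grad_batches=8,    # smaller per-device batches, same effective batch size
)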

Conclusion

We learned how to train large models with billions of parameters, scaling the training across multiple GPUs and sharding with FSDP. We also explored Llama 7B training with bfloat16 precision using PyTorch Lightning. You can learn more about LLMs and model training in our blogs here.

Resources

Join our Discord community to chat and ask your questions!