Lightning 2.1: Train Bigger, Better, Faster

Lightning AI is excited to announce the release of Lightning 2.1 ⚡. The theme this time around is “bigger, better, faster”.

Bigger because training large, multi-billion-parameter models has become even more efficient, thanks to improvements in FSDP, efficient initialization, and sharded checkpointing.

Better because it’s easier than ever to scale models without making substantial code changes or installing third-party packages.

Faster because it leverages the latest hardware features to speed up training in low-bit precision, thanks to new precision plugins such as bitsandbytes and Transformer Engine.

All of these goodies are available both in PyTorch Lightning and Lightning Fabric. Don’t know what Fabric is? It’s the latest addition to Lightning’s family of tools – a fast and lightweight way to scale PyTorch models without boilerplate. You can convert PyTorch code to Fabric in just 5 lines and get access to SOTA distributed training features (DDP, FSDP, DeepSpeed, mixed precision and more) while maintaining full control over your training loop.

Upgrade to 2.1

Here is how you upgrade:

pip install -U lightning

Or if you’re using the older pytorch-lightning package:

pip install -U pytorch-lightning

Upgrading from 2.0 to 2.1 won’t require any code changes. If you’re upgrading from a version prior to 2.0, follow our migration guide.

Here are the big highlights we want you to try out in the new release.

Improvements To Large-Scale Training With FSDP

With FSDP, you can train large billion-parameter models that don’t fit on a single GPU or even a single machine. In this release, the FSDP strategy gets substantial improvements and new features: it is now easier to configure, it manages memory better and runs faster, and it comes with a brand-new end-to-end user guide with best practices (Trainer, Fabric).

Efficient Saving and Loading of Large Checkpoints

When training large billion-parameter models with FSDP, saving and resuming training, or even just loading model parameters for finetuning, can be challenging: users are often plagued by out-of-memory errors and speed bottlenecks. Starting with saving checkpoints, we added support for distributed/sharded checkpoints, enabled through the state_dict_type setting in the strategy.

Trainer:

import lightning as L
from lightning.pytorch.strategies import FSDPStrategy

# Default used by the strategy
strategy = FSDPStrategy(state_dict_type="full")

# Enable saving distributed checkpoints
strategy = FSDPStrategy(state_dict_type="sharded")

trainer = L.Trainer(strategy=strategy, ...)

Fabric:

import lightning as L
from lightning.fabric.strategies import FSDPStrategy

# Saving distributed checkpoints is the default
strategy = FSDPStrategy(state_dict_type="sharded")

# Save consolidated (single file) checkpoints
strategy = FSDPStrategy(state_dict_type="full")

fabric = L.Fabric(strategy=strategy, ...)

Distributed checkpoints are the fastest and most memory-efficient way to save the state of very large models. The distributed checkpoint format also makes it efficient to load these checkpoints back in parallel when resuming training, and it significantly reduces CPU memory usage. Furthermore, we’ve introduced lazy loading for non-distributed checkpoints, which greatly reduces CPU memory usage when loading a consolidated (single-file) checkpoint (e.g., for finetuning). Learn more about these features in our FSDP guides (Trainer, Fabric).
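To make the difference concrete, here is a toy, pure-Python sketch (not Lightning’s implementation) in which a "state dict" is just a mapping of names to lists of floats. A sharded checkpoint stores one piece of every tensor per rank, while a consolidated checkpoint first gathers all pieces into a single full copy. The helper names shard_state_dict, consolidate, and num_elements are hypothetical.

```python
# Toy illustration of sharded vs. consolidated checkpoints (hypothetical helpers,
# not Lightning's implementation). Tensors are modeled as plain lists of floats.


def shard_state_dict(state_dict, world_size):
    """Split every tensor into contiguous chunks, one chunk per rank."""
    shards = [{} for _ in range(world_size)]
    for name, values in state_dict.items():
        chunk = -(-len(values) // world_size)  # ceiling division
        for rank in range(world_size):
            shards[rank][name] = values[rank * chunk:(rank + 1) * chunk]
    return shards


def consolidate(shards):
    """Gather all shards back into one full state dict (what a single-file checkpoint holds)."""
    return {name: [v for shard in shards for v in shard[name]] for name in shards[0]}


def num_elements(state_dict):
    """Count the values a rank would have to hold in memory and write out."""
    return sum(len(values) for values in state_dict.values())


state = {"layer.weight": [float(i) for i in range(10)], "layer.bias": [0.5, 1.5]}
shards = shard_state_dict(state, world_size=4)

# With sharding, each rank only handles its own piece of the model...
per_rank = [num_elements(shard) for shard in shards]  # [4, 4, 3, 1]
# ...while consolidation requires materializing everything in one place.
assert consolidate(shards) == state
```

In this toy, no rank ever builds the full dictionary in the sharded case, so the per-rank footprint stays roughly 1/world_size of the model.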

Fast and Memory-Optimized Initialization

A major challenge that users face when working with large models such as LLMs is dealing with the extreme memory requirements. Even something as simple as instantiating a model becomes non-trivial if the model is so large it won’t fit in a single GPU or even a single machine. In Lightning 2.1, we are introducing empty-weights initialization through the Fabric.init_module() and Trainer.init_module()/LightningModule.configure_model() methods.

Trainer:



import lightning as L

class MyModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        # Delay initialization of model to `configure_model()`

    def configure_model(self):
        # Model initialized in correct precision and weights on meta-device
        self.model = ...

    ...

model = MyModel()
trainer = L.Trainer(strategy="fsdp", ...)
trainer.fit(model)

Fabric:


import torch
import lightning as L

fabric = L.Fabric(strategy="fsdp", ...)

# Model initialized in correct precision and weights on meta-device
with fabric.init_module(empty_init=True):
    model = ...

# You can also initialize buffers and tensors directly on device and dtype
with fabric.init_tensor():
    model.mask.create()
    model.kv_cache.create()
    x = torch.randn(4, 128)

# Materialization and sharding of model happens inside here
model = fabric.setup(model)

Read more about this new feature and its other benefits in our docs (Trainer, Fabric).
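A rough, weights-only estimate shows why instantiating a large model can already be a problem. The sketch below assumes a hypothetical 7-billion-parameter model and counts only the weights (no gradients, optimizer state, or activations); weight_memory_gb is a hypothetical helper, and even sharding is an idealization of what FSDP achieves.

```python
# Weights-only memory estimate (illustrative; ignores gradients, optimizer state,
# activations, and framework overhead). 1 GB is taken as 1e9 bytes.

BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "bfloat16": 2}


def weight_memory_gb(num_params, dtype="float32", num_shards=1):
    """Memory per device needed to hold the weights, optionally sharded evenly."""
    return num_params * BYTES_PER_ELEMENT[dtype] / num_shards / 1e9


# Instantiating a hypothetical 7B-parameter model in float32 on a single device:
print(weight_memory_gb(7e9))  # 28.0
# The same model in bfloat16, sharded evenly across 8 GPUs:
print(weight_memory_gb(7e9, "bfloat16", num_shards=8))  # 1.75
```

If the full float32 model were created on one device before sharding, that device would need the entire 28 GB up front, which is exactly what deferred, empty-weights initialization avoids.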

User-Friendly Configuration

We made it super easy to configure the sharding and activation-checkpointing policies when you want to auto-wrap particular layers of your model for finer control.

Before:

import lightning as L
from lightning.pytorch.strategies import FSDPStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

strategy = FSDPStrategy(auto_wrap_policy=ModuleWrapPolicy({MyTransformerBlock}))
trainer = L.Trainer(strategy=strategy, ...)

After:

import lightning as L
from lightning.pytorch.strategies import FSDPStrategy

strategy = FSDPStrategy(auto_wrap_policy={MyTransformerBlock})
trainer = L.Trainer(strategy=strategy, ...)
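Conceptually, a set such as {MyTransformerBlock} expresses roughly the rule that PyTorch’s ModuleWrapPolicy applies: every submodule that is an instance of one of the listed classes becomes its own wrapped unit. The pure-Python sketch below mimics only that selection step, using hypothetical stand-in classes and a hypothetical should_wrap helper; it does not involve torch or Lightning.

```python
# Pure-Python stand-ins for illustration only (no torch or Lightning involved).
# A set of layer classes defines which submodules become their own wrapped unit.


class Embedding:
    pass


class MyTransformerBlock:
    pass


class MyModel:
    def __init__(self, num_blocks):
        self.embedding = Embedding()
        self.blocks = [MyTransformerBlock() for _ in range(num_blocks)]

    def submodules(self):
        """Yield every child layer (a simplified stand-in for iterating nn.Module children)."""
        yield self.embedding
        yield from self.blocks


def should_wrap(module, layer_classes):
    """Wrap a module if it is an instance of any class in the policy set."""
    return isinstance(module, tuple(layer_classes))


policy = {MyTransformerBlock}
model = MyModel(num_blocks=3)
wrapped = [m for m in model.submodules() if should_wrap(m, policy)]
print(len(wrapped))  # 3: each transformer block is selected, the embedding is not
```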

True Half-Precision

Lightning now supports true half-precision for training and inference with all built-in strategies. With this setting, storing the model weights takes only half the memory needed in float32. In addition, you get the same speed benefits as with mixed precision training (precision="16-mixed"):



import lightning as L

# default
trainer = L.Trainer(precision="32-true")

# train with model weights in `torch.float16`
trainer = L.Trainer(precision="16-true")

# train with model weights in `torch.bfloat16`
# (if hardware supports it)
trainer = L.Trainer(precision="bf16-true")

The same settings are also available in Fabric! We recommend trying bfloat16 training (precision="bf16-true"), as it is often more numerically stable than regular 16-bit precision (precision="16-true").
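A large part of bfloat16’s robustness comes from its range: it keeps float32’s 8 exponent bits and gives up mantissa precision instead, whereas float16 overflows above 65504 and rounds very small values to zero. The stdlib sketch below illustrates this with Python’s struct module, whose "e" format is IEEE half precision. bfloat16 is approximated by truncating a float32 to its upper 16 bits, a simplification of the round-to-nearest conversion hardware typically uses; to_float16 and to_bfloat16 are hypothetical helpers.

```python
import struct


def to_float16(x):
    """Round x to IEEE float16 and back; raises OverflowError if x is out of range."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bfloat16(x):
    """Approximate bfloat16 by keeping only the upper 16 bits of the float32 encoding."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


# Large values: float16 cannot represent 100000, bfloat16 can (with reduced precision).
try:
    to_float16(1e5)
except OverflowError:
    print("float16 overflow")
print(to_bfloat16(1e5))  # 99840.0

# Tiny values: float16 rounds 1e-8 to zero, bfloat16 keeps it nonzero.
print(to_float16(1e-8))  # 0.0
print(to_bfloat16(1e-8) > 0)  # True
```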

Bitsandbytes Quantization

With the new bitsandbytes precision plugin, you can now quantize your model for significant memory savings during training, finetuning, or inference, choosing from several state-of-the-art quantization algorithms (int8, fp4, nf4 and more). For the first time, Trainer and Fabric make bitsandbytes easy to use for general models.

Trainer:

import lightning as L
from lightning.pytorch.plugins import BitsandbytesPrecisionPlugin

# this will pick out the compute dtype automatically, by default `bfloat16`
precision = BitsandbytesPrecisionPlugin("nf4-dq")
trainer = L.Trainer(plugins=precision)

Fabric:

import lightning as L
from lightning.fabric.plugins import BitsandbytesPrecision

# this will pick out the compute dtype automatically, by default `bfloat16`
precision = BitsandbytesPrecision("nf4-dq")
fabric = L.Fabric(plugins=precision)
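For intuition about where the memory savings come from, here is a toy absmax quantizer in pure Python: values are scaled so the largest magnitude maps to 127, rounded to integers that fit in one byte each, and rescaled when needed. It illustrates only the basic scale, round, rescale idea behind 8-bit quantization. The algorithms bitsandbytes actually provides (block-wise scaling, 4-bit data types such as nf4 and fp4, double quantization in "nf4-dq") are more sophisticated, and quantize_absmax and dequantize are hypothetical names.

```python
# Toy absmax int8 quantization (illustrative; not how bitsandbytes implements it).


def quantize_absmax(values):
    """Map floats to integers in [-127, 127] using a single scale for the whole list."""
    absmax = max(abs(v) for v in values)
    scale = 127 / absmax if absmax > 0 else 1.0
    return [round(v * scale) for v in values], scale


def dequantize(qvalues, scale):
    """Recover approximate floats from the stored integers and the scale."""
    return [q / scale for q in qvalues]


weights = [0.1, -0.5, 0.25, 1.0]
qweights, scale = quantize_absmax(weights)
restored = dequantize(qweights, scale)

# Each quantized value fits in 1 byte instead of 4 bytes for float32; the rounding
# error per value is at most half a quantization step (0.5 / scale).
max_error = max(abs(w - r) for w, r in zip(weights, restored))
assert max_error <= 0.5 / scale + 1e-12
```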

Transformer Engine

The Transformer Engine by NVIDIA is a library for accelerating transformer layers on the new Hopper (H100) generation of GPUs. With the integration in Lightning Trainer and Fabric, you have easy access to 8-bit mixed precision for significant speedups:

Trainer:

import lightning as L

# Select 8-bit mixed precision via TransformerEngine, with model weights in float16
trainer = L.Trainer(precision="transformer-engine-float16")

Fabric:

import lightning as L

# Select 8-bit mixed precision via TransformerEngine, with model weights in float16
fabric = L.Fabric(precision="transformer-engine-float16")

More configuration options are available through the respective plugins in Trainer and Fabric.
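FP8 formats have a very narrow range (the E4M3 variant tops out at 448), so FP8 training keeps a scaling factor per tensor and rescales values before casting. Transformer Engine’s delayed-scaling recipe, for example, derives that factor from a history of recently observed maximum absolute values ("amax"). The stdlib sketch below shows only the scaling arithmetic; it does not emulate FP8 rounding, and compute_scale and scale_and_clip are hypothetical names rather than Transformer Engine APIs.

```python
# Toy per-tensor scaling for FP8 E4M3 (illustrative; no actual FP8 cast is performed).

E4M3_MAX = 448.0  # largest finite value in the FP8 E4M3 format


def compute_scale(amax_history):
    """Choose a scale so the largest recently seen |value| maps to E4M3_MAX."""
    amax = max(amax_history)
    return E4M3_MAX / amax if amax > 0 else 1.0


def scale_and_clip(values, scale):
    """Apply the scale and clip into the representable range (the step before a cast)."""
    return [max(-E4M3_MAX, min(E4M3_MAX, v * scale)) for v in values]


scale = compute_scale([2.0, 4.0, 3.5])  # amax = 4.0, so scale = 112.0
scaled = scale_and_clip([4.0, -1.0, 8.0], scale)  # [448.0, -112.0, 448.0]
```

Using a history of amax values rather than only the current tensor lets the scale be chosen ahead of time, at the cost of clipping values that suddenly grow, as the 8.0 input does here.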

Lightning on TPU Goes Brrr

Lightning 2.1 runs on the latest generation of TPU hardware on Google Cloud! TPU-v4 and TPU-v5 are now fully supported in both Fabric and Trainer and use the new PjRT runtime by default. PjRT is the runtime used by JAX and has shown an average performance improvement of 35% on benchmarks.

Trainer:

import lightning as L

trainer = L.Trainer(accelerator="tpu", devices=8)
model = MyModel()
trainer.fit(model)  # uses PjRT if available

Fabric:

import lightning as L

def train(fabric):
    ...

fabric = L.Fabric(accelerator="tpu")
fabric.launch(train)  # uses PjRT if available

And what’s even more exciting, you can now scale massive multi-billion parameter models on TPUs using FSDP.



import lightning as L
from lightning.fabric.strategies import XLAFSDPStrategy

strategy = XLAFSDPStrategy(
    # Most arguments from the PyTorch native FSDP strategy are also available here!
    auto_wrap_policy={Block},
    activation_checkpointing_policy={Block},
    state_dict_type="full",
    sequential_save=True,
)

fabric = L.Fabric(devices=8, strategy=strategy)
fabric.launch(finetune)

You can find a full end-to-end finetuning example script in our Lit-GPT repository. The new XLA-FSDP strategy is experimental and currently only available in Fabric. Support in the Trainer will follow in the future.

Granular Control Over Checkpoints in Fabric

Several improvements for checkpoint saving and loading have landed in Fabric, enabling more fine-grained control over what is saved/loaded while reducing boilerplate code:

There is a new Fabric.load_raw() method with which you can load model- or optimizer state-dicts saved externally by a non-Fabric application (e.g., raw PyTorch):

import lightning as L

fabric = L.Fabric()
model = MyModel()

# A model weights file saved by your friend who doesn't use Fabric
fabric.load_raw("path/to/model.pt", model)

# Equivalent to this:
# model.load_state_dict(torch.load("path/to/model.pt"))

Then there is a new parameter, Fabric.load(..., strict=True|False), to disable strict loading:

import lightning as L

fabric = L.Fabric()
model = MyModel()
state = {"model": model}

# strict loading is the default
fabric.load("path/to/checkpoint.ckpt", state, strict=True)

# disable strict loading
fabric.load("path/to/checkpoint.ckpt", state, strict=False)
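Strict loading requires the checkpoint’s keys to match the model’s exactly, while non-strict loading copies whatever matches and reports the rest, similar in spirit to how PyTorch’s load_state_dict(strict=False) returns missing and unexpected keys. The sketch below models this with plain dicts and a hypothetical load_state helper; it is not Fabric’s code.

```python
# Toy model of strict vs. non-strict loading (illustrative; not Fabric's code).
# State dicts are plain dicts mapping parameter names to values.


def load_state(model_state, checkpoint, strict=True):
    """Copy matching entries from `checkpoint` into `model_state`.

    Returns (missing_keys, unexpected_keys). With strict=True, any mismatch raises
    before anything is copied.
    """
    missing = sorted(model_state.keys() - checkpoint.keys())
    unexpected = sorted(checkpoint.keys() - model_state.keys())
    if strict and (missing or unexpected):
        raise RuntimeError(f"missing keys: {missing}, unexpected keys: {unexpected}")
    for key in model_state.keys() & checkpoint.keys():
        model_state[key] = checkpoint[key]
    return missing, unexpected


model_state = {"layer.weight": 0.0, "layer.bias": 0.0}
checkpoint = {"layer.weight": 1.0, "head.weight": 2.0}

# Non-strict loading copies what matches and reports the rest.
print(load_state(model_state, checkpoint, strict=False))
# (['layer.bias'], ['head.weight'])
```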

Finally, there is a new parameter, Fabric.save(..., filter=...), that lets you exclude certain parameters of your model from the checkpoint without writing boilerplate code for it:

import lightning as L

fabric = L.Fabric()
model, optimizer = ...

state = {"model": model, "optimizer": optimizer, "foo": 123}

# save only the weights that match a pattern
filter = {"model": lambda k, v: "weight" in k}
fabric.save("path/to/checkpoint.ckpt", state, filter=filter)
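Conceptually, each filter is a predicate over the (key, value) pairs of that entry’s state dict. The pure-Python sketch below mimics that selection step with a hypothetical apply_filter helper and a plain dict standing in for the model’s state dict; it makes no claim about how Fabric implements filtering internally.

```python
# Sketch of per-entry filtering before saving (hypothetical helper, not Fabric's internals).
# The "model" entry is represented by a plain dict standing in for its state dict.


def apply_filter(state, filters):
    """Keep only the (key, value) pairs accepted by each entry's filter callable."""
    filtered = {}
    for name, obj in state.items():
        keep = filters.get(name)
        if keep is None:
            filtered[name] = obj  # entries without a filter are saved unchanged
        else:
            filtered[name] = {k: v for k, v in obj.items() if keep(k, v)}
    return filtered


state = {"model": {"layer.weight": [1.0, 2.0], "layer.bias": [0.0]}, "foo": 123}
to_save = apply_filter(state, {"model": lambda k, v: "weight" in k})
print(to_save)  # {'model': {'layer.weight': [1.0, 2.0]}, 'foo': 123}
```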

You can read more about the new options in our checkpoint guide.

Conclusion

Our vision here at Lightning is to make deep learning accessible to everyone, enabling both beginners and experts to contribute to the advancement of the state of the art in AI research, or to just build cool stuff. With 2.1, we’re putting new tools for making models large and efficient to train in the hands of our users, so that they can spend the time they used to lose debugging boilerplate code on something else.

Lightning is built by the community, for the community. We want to thank the 75+ developers who have contributed code to 2.1, and hundreds of users who gave us feedback.

Make sure to join our Discord if you have any questions or want to chat about Lightning.