Lightning 2.1: Train Bigger, Better, Faster
Lightning AI is excited to announce the release of Lightning 2.1 ⚡. The theme this time around is “bigger, better, faster”.
Bigger because training large multi-billion parameter models has gotten even more efficient thanks to FSDP, efficient initialization and sharded checkpointing improvements.
Better because it’s easier than ever to scale models without making substantial code changes or installing third-party packages.
Faster because it leverages the latest hardware features to speed up training in low-bit precision thanks to new precision plugins like bitsandbytes and transformer engine.
All of these goodies are available both in PyTorch Lightning and Lightning Fabric. Don’t know what Fabric is? It’s the latest addition to Lightning’s family of tools – a fast and lightweight way to scale PyTorch models without boilerplate. You can convert PyTorch code to Fabric in just 5 lines and get access to SOTA distributed training features (DDP, FSDP, DeepSpeed, mixed precision and more) while maintaining full control over your training loop.
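To give a quick taste of what that conversion looks like, here is a minimal sketch of a Fabric training loop; the accelerator, device count, strategy, and the model/optimizer/dataloader placeholders are illustrative, not prescribed by the release:

import lightning as L

fabric = L.Fabric(accelerator="cuda", devices=4, strategy="ddp")
fabric.launch()

model, optimizer, dataloader = ...  # your usual PyTorch objects

# Fabric moves the model to the right device(s) and wraps it for distributed training
model, optimizer = fabric.setup(model, optimizer)
dataloader = fabric.setup_dataloaders(dataloader)

for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch)
    fabric.backward(loss)  # replaces loss.backward()
    optimizer.step()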
Upgrade to 2.1
Here is how you upgrade:
pip install -U lightning
Or if you’re using the older pytorch-lightning package:
pip install -U pytorch-lightning
Upgrading from 2.0 to 2.1 won’t require any code changes. If you’re upgrading from a version prior to 2.0, follow our migration guide.
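If you want to double-check which version you ended up with after upgrading, a quick sanity check is:

import lightning  # or `import pytorch_lightning` if you use the older package
print(lightning.__version__)  # should report 2.1.x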
Here are the big highlights we want you to try out in the new release.
Improvements To Large-Scale Training With FSDP
With FSDP, you can train large billion-parameter models that don't fit on a single GPU or even a single machine. In this release, the FSDP strategy gets substantial improvements and new features: it is now more user-friendly to configure, has memory-management and speed improvements, and we have a brand-new end-to-end user guide with best practices (Trainer, Fabric).
Efficient Saving and Loading of Large Checkpoints
When training large billion-parameter models with FSDP, saving and resuming training, or even just loading model parameters for finetuning, can be challenging, as users are often plagued by out-of-memory errors and speed bottlenecks. Starting with saving checkpoints, we added support for distributed/sharded checkpoints, enabled through the state_dict_type setting in the strategy.
Trainer:
import lightning as L
from lightning.pytorch.strategies import FSDPStrategy

# Default used by the strategy
strategy = FSDPStrategy(state_dict_type="full")

# Enable saving distributed checkpoints
strategy = FSDPStrategy(state_dict_type="sharded")

trainer = L.Trainer(strategy=strategy, ...)
Fabric:
import lightning as L
from lightning.fabric.strategies import FSDPStrategy

# Saving distributed checkpoints is the default
strategy = FSDPStrategy(state_dict_type="sharded")

# Save consolidated (single file) checkpoints
strategy = FSDPStrategy(state_dict_type="full")

fabric = L.Fabric(strategy=strategy, ...)
Distributed checkpoints are the fastest and most memory-efficient way to save the state of very large models. The distributed checkpoint format also makes it efficient to load these checkpoints back for resuming training in parallel, and it significantly reduces CPU memory usage. Furthermore, we've introduced lazy loading for non-distributed checkpoints, which greatly reduces CPU memory usage when loading a consolidated (single-file) checkpoint (e.g. for finetuning). Learn more about these features in our FSDP guides (Trainer, Fabric).
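To make this concrete, here is a minimal sketch of saving and resuming with the sharded format in Fabric; the checkpoint path, state keys, and iteration value are illustrative placeholders, not part of the release notes:

import lightning as L
from lightning.fabric.strategies import FSDPStrategy

fabric = L.Fabric(devices=8, strategy=FSDPStrategy(state_dict_type="sharded"))
fabric.launch()

model, optimizer = ...  # set up via fabric.setup() / fabric.setup_optimizers()

# With the sharded format, the checkpoint is written as a directory of
# per-rank shard files rather than one consolidated file
state = {"model": model, "optimizer": optimizer, "iteration": 1000}
fabric.save("path/to/checkpoint", state)

# When resuming, each rank loads only its own shard, keeping CPU memory low
fabric.load("path/to/checkpoint", state)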
Fast and Memory-Optimized Initialization
A major challenge that users face when working with large models such as LLMs is dealing with the extreme memory requirements. Even something as simple as instantiating a model becomes non-trivial if the model is so large it won't fit on a single GPU or even a single machine. In Lightning 2.1, we are introducing empty-weights initialization through the Fabric.init_module() and Trainer.init_module()/LightningModule.configure_model() methods.
Trainer:
import lightning as L


class MyModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        # Delay initialization of model to `configure_model()`

    def configure_model(self):
        # Model initialized in correct precision and weights on meta-device
        self.model = ...

    ...


model = MyModel()
trainer = L.Trainer(strategy="fsdp", ...)
trainer.fit(model)
Fabric:
import torch
import lightning as L

fabric = L.Fabric(strategy="fsdp", ...)

# Model initialized in correct precision and weights on meta-device
with fabric.init_module(empty_init=True):
    model = ...

# You can also initialize buffers and tensors directly on device and dtype
with fabric.init_tensor():
    model.mask.create()
    model.kv_cache.create()
    x = torch.randn(4, 128)

# Materialization and sharding of model happens inside here
model = fabric.setup(model)
Read more about this new feature and its other benefits in our docs (Trainer, Fabric).
User-Friendly Configuration
We made it super easy to configure the sharding and activation-checkpointing policies when you want to auto-wrap particular layers of your model for advanced control.
Before:
import lightning as L
from lightning.pytorch.strategies import FSDPStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

strategy = FSDPStrategy(auto_wrap_policy=ModuleWrapPolicy({MyTransformerBlock}))
trainer = L.Trainer(strategy=strategy, ...)
After:
import lightning as L
from lightning.pytorch.strategies import FSDPStrategy

strategy = FSDPStrategy(auto_wrap_policy={MyTransformerBlock})
trainer = L.Trainer(strategy=strategy, ...)
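The same simplified syntax also covers the activation-checkpointing policy mentioned above. As a short sketch, reusing the MyTransformerBlock placeholder from the example:

import lightning as L
from lightning.pytorch.strategies import FSDPStrategy

# Shard and activation-checkpoint the same block type
strategy = FSDPStrategy(
    auto_wrap_policy={MyTransformerBlock},
    activation_checkpointing_policy={MyTransformerBlock},
)
trainer = L.Trainer(strategy=strategy, ...)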
True Half-Precision
Lightning now supports true half-precision for training and inference with all built-in strategies. With this setting, the memory required to store the model weights is only half of what is normally needed when running with float32. In addition, you get the same speed benefits as mixed precision training (precision="16-mixed"):
import lightning as L
# default
trainer = L.Trainer(precision="32-true")
# train with model weights in `torch.float16`
trainer = L.Trainer(precision="16-true")
# train with model weights in `torch.bfloat16`
# (if hardware supports it)
trainer = L.Trainer(precision="bf16-true")
The same settings are also available in Fabric! We recommend trying bfloat16 training (precision="bf16-true") as it is often more numerically stable than regular 16-bit precision (precision="16-true").
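For completeness, the equivalent in Fabric is just as short; a minimal sketch:

import lightning as L

# Model weights and computation in `torch.bfloat16`
fabric = L.Fabric(precision="bf16-true")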
Bitsandbytes Quantization
With the new Bitsandbytes precision plugin, you can now quantize your model for significant memory savings during training, finetuning, or inference with a selection of several state-of-the-art quantization algorithms (int8, fp4, nf4 and more). For the first time, Trainer and Fabric make bitsandbytes easy to use for general models.
Trainer:
import lightning as L
from lightning.pytorch.plugins import BitsandbytesPrecisionPlugin

# this will pick out the compute dtype automatically, by default `bfloat16`
precision = BitsandbytesPrecisionPlugin("nf4-dq")
trainer = L.Trainer(plugins=precision)
Fabric:
import lightning as L
from lightning.fabric.plugins import BitsandbytesPrecision

# this will pick out the compute dtype automatically, by default `bfloat16`
precision = BitsandbytesPrecision("nf4-dq")
fabric = L.Fabric(plugins=precision)
Transformer Engine
The Transformer Engine by NVIDIA is a library for accelerating transformer layers on the new Hopper (H100) generation of GPUs. With the integration in Lightning Trainer and Fabric, you have easy access to the 8-bit mixed precision for significant speed ups:
Trainer:
import lightning as L

# Select 8-bit mixed precision via TransformerEngine, with model weights in float16
trainer = L.Trainer(precision="transformer-engine-float16")
Fabric:
import lightning as L

# Select 8-bit mixed precision via TransformerEngine, with model weights in float16
fabric = L.Fabric(precision="transformer-engine-float16")
More configuration options are available through the respective plugins in Trainer and Fabric.
Lightning on TPU Goes Brrr
Lightning 2.1 runs on the latest generation of TPU hardware on Google Cloud! TPU-v4 and TPU-v5 are now fully supported in both Fabric and the Trainer, and they run using the new PjRT runtime by default. PjRT is the runtime used by JAX and has shown an average improvement of 35% on benchmarks.
Trainer:
import lightning as L

trainer = L.Trainer(accelerator="tpu", devices=8)
model = MyModel()
trainer.fit(model)  # uses PjRT if available
Fabric:
import lightning as L

def train(fabric):
    ...

fabric = L.Fabric(accelerator="tpu")
fabric.launch(train)  # uses PjRT if available
And what’s even more exciting, you can now scale massive multi-billion parameter models on TPUs using FSDP.
import lightning as L
from lightning.fabric.strategies import XLAFSDPStrategy

strategy = XLAFSDPStrategy(
    # Most arguments from the PyTorch native FSDP strategy are also available here!
    auto_wrap_policy={Block},
    activation_checkpointing_policy={Block},
    state_dict_type="full",
    sequential_save=True,
)
fabric = L.Fabric(devices=8, strategy=strategy)
fabric.launch(finetune)
You can find a full end-to-end finetuning example script in our Lit-GPT repository. The new XLA-FSDP strategy is experimental and currently only available in Fabric. Support in the Trainer will follow in the future.
Granular Control Over Checkpoints in Fabric
Several improvements for checkpoint saving and loading have landed in Fabric, enabling more fine-grained control over what is saved/loaded while reducing boilerplate code:
There is a new Fabric.load_raw() method with which you can load model or optimizer state dicts saved by a non-Fabric application (e.g., raw PyTorch):
import lightning as L

fabric = L.Fabric()
model = MyModel()

# A model weights file saved by your friend who doesn't use Fabric
fabric.load_raw("path/to/model.pt", model)

# Equivalent to this:
# model.load_state_dict(torch.load("path/to/model.pt"))
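load_raw() works the same way for optimizer state. A small sketch, where the optimizer and the file path are illustrative:

import torch
import lightning as L

fabric = L.Fabric()
model = MyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# An optimizer state-dict file saved with `torch.save(optimizer.state_dict(), ...)`
fabric.load_raw("path/to/optimizer.pt", optimizer)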
Then there is a new strict parameter in Fabric.load(..., strict=True|False) that lets you disable strict loading:
import lightning as L

fabric = L.Fabric()
model = MyModel()
state = {"model": model}

# strict loading is the default
fabric.load("path/to/checkpoint.ckpt", state, strict=True)

# disable strict loading
fabric.load("path/to/checkpoint.ckpt", state, strict=False)
Finally, a new filter parameter in Fabric.save(..., filter=...) enables you to exclude certain parameters of your model without writing boilerplate code for it:
import lightning as L

fabric = L.Fabric()
model, optimizer = ...
state = {"model": model, "optimizer": optimizer, "foo": 123}

# save only the weights that match a pattern
filter = {"model": lambda k, v: "weight" in k}
fabric.save("path/to/checkpoint.ckpt", state, filter=filter)
You can read more about the new options in our checkpoint guide.
Conclusion
Our vision here at Lightning is to make deep learning accessible to everyone, enabling both beginners and experts to contribute to the advancement of the state of the art in AI research, or to just build cool stuff. With 2.1, we’re putting new tools for making models large and efficient to train in the hands of our users, so that they can invest the time they used to spend in debugging boilerplate code elsewhere.
Lightning is built by the community, for the community. We want to thank the 75+ developers who have contributed code to 2.1, and the hundreds of users who gave us feedback.
Make sure to join our Discord if you have any questions or want to chat about Lightning.