
Introduction to Lightning Fabric

Lightning Fabric is a new, open-source library that allows you to quickly and easily scale models while maintaining full control over your training loop.

In the past, PyTorch Lightning was our answer for getting PyTorch code to run efficiently on GPUs and scaling it up to many machines and large datasets. Over time, however, we became aware of the need for a scaling option that sits between a raw deep learning framework like PyTorch on the one hand and a high-level, feature-rich framework like PyTorch Lightning on the other. Lightning Fabric is just that.

While PyTorch Lightning provides many features that save time and improve readability and collaboration, some complex use cases need full control over the training loop. That’s why we built Fabric.

 

What’s inside?

Fabric is part of the Lightning 2.0 package. You can install or upgrade with:



pip install -U lightning

The newest addition is the lightning.fabric module. Here’s what’s inside:

  • Accelerators (CPU, GPU, TPU, …)
  • Distributed Training (DDP, DeepSpeed, FSDP, …)
  • Mixed Precision Training (FP32, FP16, Bfloat16, …)
  • Loggers (TensorBoard, CSV, …)
  • Callback system
  • Checkpointing primitives (supports distributed checkpoints)
  • Distributed Collectives
  • Gradient Accumulation
  • Lots more!

All of these features are already available in PyTorch Lightning. The key difference with Fabric is how they’re applied to your code, which is best demonstrated with a short example.

 

Accelerate your code without the boilerplate

Let’s start with a simple PyTorch training script:


import torch
import torch.nn.functional as F
from lightning.pytorch.demos import Transformer, WikiText2
   
dataset = WikiText2()
dataloader = torch.utils.data.DataLoader(dataset)
model = Transformer(vocab_size=dataset.vocab_size)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
   
model.train()
for epoch in range(20):
    for batch in dataloader:
        input, target = batch
        optimizer.zero_grad()
        output = model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        loss.backward()
        optimizer.step()

Unfortunately, this code only runs on the CPU. That’s not ideal.

To run this on a GPU, a couple of .to("cuda") calls on the model and data are enough. But with larger models, bigger data samples (e.g., high-resolution images), or bigger batch sizes, we’re limited by the available GPU memory. To get around this, we could implement mixed precision training and distributed training across multiple GPUs, but that means changing our code once again. On top of that, doing this correctly and efficiently is difficult and time-consuming, and it produces a lot of boilerplate code.
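For comparison, here is a minimal sketch of the manual approach in plain PyTorch. It assumes a tiny stand-in model (`torch.nn.Linear`) and a random batch instead of the Transformer and WikiText2 above, and it falls back to the CPU when no CUDA GPU is present. Note that every tensor needs its own `.to(device)` call, and this is before any mixed precision or multi-GPU logic is added.

```python
import torch
import torch.nn.functional as F

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tiny stand-in model; the Transformer above would be moved the same way
model = torch.nn.Linear(8, 2).to(device)
# PyTorch recommends moving the model before constructing its optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Random stand-in batch: 4 samples with 8 features and a class label each
inputs = torch.randn(4, 8)
targets = torch.randint(0, 2, (4,))

model.train()
# Every batch must be moved to the same device as the model
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
loss = F.cross_entropy(model(inputs), targets)
loss.backward()
optimizer.step()
```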

With Lightning Fabric, you can add a few calls to your code, once, and then you have the flexibility to run it anywhere, like so:


import torch
import torch.nn.functional as F
import lightning as L
from lightning.pytorch.demos import Transformer, WikiText2

# Create the Fabric object; accelerator, devices, precision, etc. come from the CLI
fabric = L.Fabric()

dataset = WikiText2()
dataloader = torch.utils.data.DataLoader(dataset)
model = Transformer(vocab_size=dataset.vocab_size)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Let Fabric move and wrap the model, optimizer, and dataloader
model, optimizer = fabric.setup(model, optimizer)
dataloader = fabric.setup_dataloaders(dataloader)

model.train()
for epoch in range(20):
    for batch in dataloader:
        input, target = batch
        optimizer.zero_grad()
        output = model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        # Replaces loss.backward() so precision and strategy are handled for you
        fabric.backward(loss)
        optimizer.step()

Now, unleash the full power of Fabric on your Python script.

Run on the M1/M2 GPU of your MacBook:



lightning run model train.py --accelerator=mps

Or on a beefy GPU server with eight GPUs in bfloat16 precision:



lightning run model train.py \
--accelerator=cuda \
--devices=8 \
--precision=bf16

Or across multiple machines in your cluster:



lightning run model train.py \
--accelerator=cuda \
--devices=8 \
--num-nodes=4 \
--main-address=10.10.10.24 \
--node-rank=1

You can find plenty of other configurations in the official Fabric documentation.

Most importantly: None of these configurations require any code changes on your side.

 

More use cases for Fabric

The examples above demonstrate only a few potential use cases for Fabric. Our official documentation has plenty of more complete examples, including image classification, reinforcement learning, large language model pre-training with nanoGPT, and much more.

 

Make your own trainer

Accelerating your training code is just the tip of the iceberg. Fabric brings a toolset that helps you build a fully-fledged trainer from the ground up.

Here’s a selection of key building blocks you can add to your own custom Trainer (all optional, of course!):

 

Loggers/Experiment Trackers

Fabric doesn’t track metrics by default. It’s all up to the user to decide what and when they want to log a metric value. To add logging capabilities to your Trainer, you can choose one or several loggers from the lightning.fabric.loggers module:



import lightning as L
from lightning.fabric.loggers import TensorBoardLogger

# Pick a logger and add it to Fabric
logger = TensorBoardLogger(root_dir="logs")
fabric = L.Fabric(loggers=logger)

# Log a Python scalar or a scalar tensor (`value` comes from your own code)
fabric.log("some_value", value)

# Works with TorchMetrics too (`metric` is a TorchMetrics metric object)
fabric.log("my_metric", metric.compute())

You can control many more things with logging, such as the logging frequency, log media (images, audio, etc.), multiple loggers at once, and more.

Read about logging in our docs.

 

Checkpoints

Saving and resuming the training state is important for development and for long-running experiments. Fabric supports this through the fabric.save and fabric.load primitives:



# Define the state of your program/loop
state = {"model1": model1, "model2": model2, "optimizer": optimizer, "iteration": iteration, "hparams": ...}

# Save it to a file
fabric.save("path/to/checkpoint.ckpt", state)

# Load the state (in-place)
fabric.load("path/to/checkpoint.ckpt", state)

As you can see, the syntax is almost identical to torch.save and torch.load, which makes it easy to convert existing PyTorch code. Unlike the plain PyTorch functions, however, Fabric’s saving and loading methods also handle sharded models correctly in distributed settings. Without Fabric, this would require a lot of boilerplate code.

Read about checkpoints in our docs.

 

Callbacks

When you build a Trainer/framework for your team, a community, or even just for yourself, it can be useful to have a callback system to hook into the machinery and extend its functionality without needing to change the actual source code. Fabric brings the building blocks so you don’t have to reinvent the wheel:


import torch
import torch.nn.functional as F
import lightning as L
from lightning.pytorch.demos import Transformer, WikiText2


# The code of a callback can live anywhere, away from your Trainer
class MyCallback:
    def on_train_batch_end(self, loss, output):
        # Do something with the loss and output
        print("current loss:", loss)


# Add one or several callbacks:
fabric = L.Fabric(callbacks=[MyCallback()])

dataset = WikiText2()
dataloader = torch.utils.data.DataLoader(dataset)
model = Transformer(vocab_size=dataset.vocab_size)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

model, optimizer = fabric.setup(model, optimizer)
dataloader = fabric.setup_dataloaders(dataloader)

# Anywhere in your Trainer, call the appropriate callback methods
model.train()
for epoch in range(20):
    for batch in dataloader:
        input, target = batch
        optimizer.zero_grad()
        output = model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        fabric.backward(loss)
        optimizer.step()

        # Let a callback add some arbitrary processing at the appropriate place
        # Give the callback access to some variables
        fabric.call("on_train_batch_end", loss=loss, output=output)

Read about callbacks in our docs.

 

LightningModule

As mentioned before, Fabric can wrap around your PyTorch code no matter how it is organized, giving you maximum flexibility! However, maybe you prefer to standardize and separate the research code (model, loss, optimization, etc.) from the “trainer” code (training loop, checkpointing, logging, etc.). This is exactly what the LightningModule was made for!


import torch
import torch.nn.functional as F
import lightning as L
from lightning.pytorch.demos import Transformer, WikiText2


# Organize code in LightningModule hooks
class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.dataset = dataset = WikiText2()
        self.model = Transformer(vocab_size=dataset.vocab_size)

    def training_step(self, batch):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        return loss

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.dataset)

    def configure_optimizers(self):
        # Optimize this module's own parameters
        return torch.optim.SGD(self.parameters(), lr=0.1)


# Instantiate the LightningModule
model = LitModel()

Now, you can make your Trainer accept an instance of LightningModule. The Trainer needs to call the LightningModule hooks, e.g., training_step, at the right time:


fabric = L.Fabric()

# Call `configure_optimizers` and dataloader hooks
model, optimizer = fabric.setup(model, model.configure_optimizers())
dataloader = fabric.setup_dataloaders(model.train_dataloader())

# Call the hooks at the right time
model.on_train_start()
model.train()
for epoch in range(20):
    for batch in dataloader:
        optimizer.zero_grad()
        # Call training_step at the right place
        loss = model.training_step(batch)
        fabric.backward(loss)
        optimizer.step()

Learn more about how the LightningModule works with Fabric in our docs.

 

These are the essentials, but there is more cool stuff that didn’t fit into this blog post: You can read up about more advanced topics like gradient accumulation, distributed communication, multiple models and optimizers, and more in our docs.

An important takeaway here is that all the tools we just saw are opt-in: You pick what you find useful, and leave what you don’t!

 

The future of Fabric and PyTorch Lightning

Fabric is an important step towards modularizing and “unbundling” our beloved Lightning Trainer as we know it. In the near future, we will refactor the Trainer’s internals to rely heavily on Fabric, which will become the core framework for building *any* Trainer. This means that from now on, Lightning will offer two distinct experiences that target different user groups:

 

PyTorch Lightning: This is our general-purpose, battle-tested, fast, and reliable trainer with the most popular features built in. It is the best solution for ML researchers who want to get started quickly and iterate fast.

Lightning Fabric: For ML researchers and framework builders who want to hack together their own trainers. It is the most flexible way to work with PyTorch while still getting the most valuable benefits of Lightning.

Starting with 2.0, Lightning becomes more modular, easier to customize and hack, easier to understand in terms of code readability, and overall a more lightweight package with fewer dependencies.

 

Wrap-up!

With Lightning 2.0 and the introduction of Fabric, we are resolving several existing challenges and addressing feedback from the PyTorch community. Users who faced difficulties building domain-specific frameworks on top of PyTorch Lightning can now use Fabric, a toolset that is much better suited to these tasks. At the same time, users who love the pre-built, feature-rich Trainer but struggled to understand its internals now get a cleaner, more readable, and more debuggable experience. This also significantly lowers the bar for contributing to the Lightning source code and joining discussions around it. Moreover, and we believe this is the best part, Fabric will reach new users who were previously hesitant to adopt Lightning because of the abstractions in the Trainer. Fabric doesn’t force abstractions, follows an opt-in philosophy, and can immediately add value to any PyTorch code base.

We hope you give Fabric a try. Please reach out to us for feedback on Discord, ask questions on the Forums, and report bugs and feature requests in our GitHub!