
Train at scale and retain control, with Fabric

Fabric is the easiest way to scale your models while maintaining full control over your training loop and inference logic.

pip install lightning

Quick implementation

No need to refactor your code. Just swap a few lines and you can start unlocking scale.
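
A minimal sketch of what that swap can look like. The toy model, optimizer, and data below are placeholders standing in for your existing PyTorch objects; only the numbered lines are Fabric-specific.

import lightning as L
import torch
from torch import nn

# toy stand-ins for your existing model, optimizer, and dataloader
model = nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.randn(64, 32), torch.randn(64, 2)),
    batch_size=8,
)

fabric = L.Fabric()                                   # 1. create a Fabric object
fabric.launch()                                       #    launches processes when running on multiple devices

model, optimizer = fabric.setup(model, optimizer)     # 2. let Fabric set up the model and optimizer
dataloader = fabric.setup_dataloaders(dataloader)     # 3. and the dataloader(s)

for x, y in dataloader:
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    fabric.backward(loss)                             # 4. replace loss.backward() with fabric.backward(loss)
    optimizer.step()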

Control your training loop

Fabric’s features are opt-in, making it easier to develop and debug your PyTorch code as you gradually add only the features you need.
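
For instance, you can begin with the defaults and opt into acceleration, distribution, and precision one constructor argument at a time. The values below are illustrative; pick the ones that match your hardware.

import lightning as L

# defaults: single device, full precision, no distributed strategy
fabric = L.Fabric()

# opt into the features you need, one argument at a time
fabric = L.Fabric(
    accelerator="cuda",        # run on GPU instead of CPU
    devices=4,                 # use 4 devices
    strategy="ddp",            # distributed data parallel across those devices
    precision="bf16-mixed",    # mixed precision
)
fabric.launch()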

Great for complex tasks

Full control over training logic and distributed mechanisms, for tasks like reinforcement learning and training large-scale transformers.
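
As a rough sketch of that control, here is an RL-flavored loop with two models and two optimizers, each updated on its own schedule. The networks, data, and loss terms are placeholders, not a real RL algorithm.

import lightning as L
import torch
from torch import nn

fabric = L.Fabric()
fabric.launch()

# toy actor and critic, each with its own optimizer
actor = nn.Sequential(nn.Linear(8, 16), nn.Tanh(), nn.Linear(16, 2))
critic = nn.Sequential(nn.Linear(8, 16), nn.Tanh(), nn.Linear(16, 1))
actor_opt = torch.optim.Adam(actor.parameters(), lr=3e-4)
critic_opt = torch.optim.Adam(critic.parameters(), lr=1e-3)

actor, actor_opt = fabric.setup(actor, actor_opt)
critic, critic_opt = fabric.setup(critic, critic_opt)

for step in range(100):
    obs = torch.randn(32, 8, device=fabric.device)      # placeholder "observations"
    returns = torch.randn(32, 1, device=fabric.device)  # placeholder "returns"

    # update the critic first ...
    critic_opt.zero_grad()
    value_loss = nn.functional.mse_loss(critic(obs), returns)
    fabric.backward(value_loss)
    critic_opt.step()

    # ... then the actor, with whatever objective and schedule you choose
    actor_opt.zero_grad()
    advantage = (returns - critic(obs)).detach()
    log_probs = torch.log_softmax(actor(obs), dim=-1)
    actor_loss = -(log_probs.mean(dim=-1, keepdim=True) * advantage).mean()  # placeholder objective
    fabric.backward(actor_loss)
    actor_opt.step()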

Scale PyTorch models

Fabric was designed with multi-billion parameter models in mind, allowing you to maximize performance while leveraging tools like callbacks and checkpoints when you need them. Write your own training logic without the boilerplate, right down to the individual optimizer calls.


import lightning as L
from torch import nn, optim, utils
import torchvision


def main():
    # create models
    encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
    decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
    params = list(encoder.parameters()) + list(decoder.parameters())
    optimizer = optim.Adam(params, lr=1e-3)

    # setup data
    dataset = torchvision.datasets.MNIST(".", download=True, transform=torchvision.transforms.ToTensor())
    train_loader = utils.data.DataLoader(dataset, batch_size=64)

    # setup Fabric
    fabric = L.Fabric()
    encoder, optimizer = fabric.setup(encoder, optimizer)
    decoder = fabric.setup(decoder)
    train_loader = fabric.setup_dataloaders(train_loader)

    # train the model
    for epoch in range(2):
        fabric.print("Epoch:", epoch)
        for i, batch in enumerate(train_loader):
            # get the inputs; data is a list of [inputs, labels]
            x, y = batch
            x = x.view(x.size(0), -1)
            optimizer.zero_grad()

            # forward + loss
            z = encoder(x)
            x_hat = decoder(z)
            loss = nn.functional.mse_loss(x_hat, x)

            # backward + optimize
            fabric.backward(loss)
            optimizer.step()

            if i % 100 == 0:
                fabric.print("train_loss", float(loss))
                fabric.log("train_loss", loss)


if __name__ == "__main__":
    main()

[Screenshot: converting PyTorch code to Fabric]

Supercharge PyTorch code

Unlock speed and scale in existing PyTorch projects without large refactors. In just a few lines, add accelerators, strategies, distributed checkpointing, and mixed precision to your PyTorch code.
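
A sketch of those knobs in one place, assuming an FSDP setup on 8 GPUs with bf16 mixed precision. The model, learning rate, and checkpoint path are illustrative stand-ins for your own.

import lightning as L
import torch
from torch import nn

fabric = L.Fabric(accelerator="cuda", devices=8, strategy="fsdp", precision="bf16-mixed")
fabric.launch()

with fabric.init_module():                       # instantiate weights directly on the target device
    model = nn.Transformer(d_model=512, nhead=8)  # stand-in for your own model

# with FSDP, set up the module first, then create and set up the optimizer
model = fabric.setup_module(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
optimizer = fabric.setup_optimizers(optimizer)

# ... your training loop ...

# checkpointing: Fabric gathers or shards the state as the chosen strategy requires
state = {"model": model, "optimizer": optimizer, "step": 0}
fabric.save("checkpoint.ckpt", state)
fabric.load("checkpoint.ckpt", state)            # restore later with the same strategy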

Build your own Trainer

Use LightningModule hooks and Callbacks with Fabric, or build your own specialized Trainer out of the Fabric building blocks.
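
A small sketch of the callback mechanism: any object can act as a callback, and your own loop decides when hooks fire. The hook name and arguments below are made up for illustration.

import lightning as L

class PrintingCallback:
    def on_train_epoch_end(self, epoch, loss):
        print(f"epoch {epoch} finished with loss {loss:.4f}")

fabric = L.Fabric(callbacks=[PrintingCallback()])

# inside your own trainer/loop, trigger the hook on all registered callbacks
fabric.call("on_train_epoch_end", epoch=0, loss=0.1234)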

[Screenshot: callbacks]

Jumpstart with Lightning

Docs

Glossary, tutorials, and code samples to guide you as you build with Lightning.

Discord

Ask questions, give feedback, and help define the direction of Lightning.

AI Education

Learn deep learning with a modern open source stack.

Forums

