Train at scale and retain control with Fabric
Fabric is the easiest way to scale your models while maintaining full control over your training loop and inference logic.

pip install lightning
Quick implementation
No need to refactor your code. Just swap a few lines and you can start unlocking scale.
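For a sense of what that swap looks like, here is a minimal sketch with a toy model and a dummy batch (assumes lightning and torch are installed; not taken from the full example below):

import lightning as L
import torch

fabric = L.Fabric()                                    # takes over device handling
fabric.launch()

model = torch.nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)      # instead of model.to(device)

batch = fabric.to_device(torch.randn(8, 32))           # instead of batch.to(device)
loss = model(batch).sum()
fabric.backward(loss)                                  # instead of loss.backward()
optimizer.step()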
Control your training loop
Fabric’s features are opt-in, making it easier to develop and debug your PyTorch code as you gradually add only the features you need.
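As a rough illustration of that opt-in flow, each constructor call below enables one more feature (the flag values are illustrative, not a recommended configuration):

import lightning as L

fabric = L.Fabric()                                                        # plain single-device PyTorch
fabric = L.Fabric(accelerator="cuda", devices=2)                           # opt into multi-GPU
fabric = L.Fabric(accelerator="cuda", devices=2, precision="bf16-mixed")   # add mixed precision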
Great for complex tasks
Full control over training logic and distributed mechanisms, for tasks like reinforcement learning and training large-scale transformers.
Scale PyTorch models
Fabric was designed with multi-billion parameter models in mind, allowing you to maximize performance while leveraging tools like callbacks and checkpoints when you need them. Write your own training logic without the boilerplate, right down to the individual optimizer calls.
import lightning as L
from torch import nn, optim, utils
import torchvision


def main():
    # create models
    encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
    decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
    params = list(encoder.parameters()) + list(decoder.parameters())
    optimizer = optim.Adam(params, lr=1e-3)

    # setup data
    dataset = torchvision.datasets.MNIST(".", download=True, transform=torchvision.transforms.ToTensor())
    train_loader = utils.data.DataLoader(dataset, batch_size=64)

    # setup Fabric
    fabric = L.Fabric()
    encoder, optimizer = fabric.setup(encoder, optimizer)
    decoder = fabric.setup(decoder)
    train_loader = fabric.setup_dataloaders(train_loader)

    # train the model
    for epoch in range(2):
        fabric.print("Epoch:", epoch)
        for i, batch in enumerate(train_loader):
            # get the inputs; data is a list of [inputs, labels]
            x, y = batch
            x = x.view(x.size(0), -1)
            optimizer.zero_grad()

            # forward + loss
            z = encoder(x)
            x_hat = decoder(z)
            loss = nn.functional.mse_loss(x_hat, x)

            # backward + optimize
            fabric.backward(loss)
            optimizer.step()

            if i % 100 == 0:
                fabric.print("train_loss", float(loss))
                fabric.log("train_loss", loss)


if __name__ == "__main__":
    main()
Supercharge PyTorch code
Unlock speed and scale in existing PyTorch projects without large refactors. In just a few lines, add accelerators, strategies, distributed checkpointing, and mixed precision to your PyTorch code.
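A sketch of what those few lines can look like (the device counts, strategy, and checkpoint file name are illustrative assumptions, not a prescribed setup):

import lightning as L
import torch

fabric = L.Fabric(
    accelerator="cuda",
    devices=8,
    num_nodes=2,
    strategy="fsdp",              # or "ddp", "deepspeed", ...
    precision="bf16-mixed",
)
fabric.launch()

model = torch.nn.Linear(128, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
model, optimizer = fabric.setup(model, optimizer)

# ... your training loop ...

# distributed checkpointing: Fabric saves and restores the (possibly sharded) state
state = {"model": model, "optimizer": optimizer, "step": 0}
fabric.save("checkpoint.ckpt", state)
fabric.load("checkpoint.ckpt", state)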
Build your own Trainer
Use LightningModule hooks and Callbacks with Fabric, or build your own specialized Trainer out of the Fabric building blocks.
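For example, a minimal sketch of driving callbacks from a hand-rolled loop (the callback class and hook name are made up for illustration):

import lightning as L

class PrintEpochEnd:
    def on_train_epoch_end(self, epoch):
        print(f"finished epoch {epoch}")

fabric = L.Fabric(callbacks=[PrintEpochEnd()])
fabric.launch()

for epoch in range(2):
    # ... your training steps ...
    fabric.call("on_train_epoch_end", epoch=epoch)   # runs the hook on every registered callback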