Convert PyTorch code to Fabric
Here are five easy steps to let Fabric scale your PyTorch models.
Step 1: Create the Fabric object at the beginning of your training code.
from lightning.fabric import Fabric
fabric = Fabric()
Step 2: Call launch() if you intend to use multiple devices (e.g., multi-GPU).
fabric.launch()
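For example, to target multiple GPUs you can configure the Fabric object accordingly before launching. A minimal sketch; the accelerator, device count, and strategy shown are example values, pick the ones that match your hardware:
# Example values: two local CUDA GPUs trained with DDP
fabric = Fabric(accelerator="cuda", devices=2, strategy="ddp")
fabric.launch()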
Step 3: Call setup() on each model and optimizer pair and setup_dataloaders() on all your data loaders.
model, optimizer = fabric.setup(model, optimizer)
dataloader = fabric.setup_dataloaders(dataloader)
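setup_dataloaders() also accepts several dataloaders at once and returns them in the same order. A sketch; train_dataloader and val_dataloader are hypothetical names from your own script:
# Pass all dataloaders together; they come back wrapped, in the same order
train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)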
Step 4: Remove all .to and .cuda calls, since Fabric takes care of device placement for you.
- model.to(device)
- batch.to(device)
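If you still need a device explicitly, for example to allocate a new tensor, Fabric exposes the device it selected. A minimal sketch using fabric.device and fabric.to_device(); batch and batch_size are placeholder names:
# Move an arbitrary (possibly nested) collection of tensors to the selected device
batch = fabric.to_device(batch)
# Or create tensors directly on it
mask = torch.zeros(batch_size, device=fabric.device)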
Step 5: Replace loss.backward() with fabric.backward(loss).
- loss.backward()
+ fabric.backward(loss)
These are all the code changes required to prepare your script for Fabric. You can now run it from the terminal as usual:
python path/to/your/script.py
All steps combined, this is how your code will change:
import torch
from lightning.pytorch.demos import WikiText2, Transformer
+ import lightning as L
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ fabric = L.Fabric(accelerator="cuda", devices=8, strategy="ddp")
+ fabric.launch()
dataset = WikiText2()
dataloader = torch.utils.data.DataLoader(dataset)
model = Transformer(vocab_size=dataset.vocab_size)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
- model = model.to(device)
+ model, optimizer = fabric.setup(model, optimizer)
+ dataloader = fabric.setup_dataloaders(dataloader)
model.train()
for epoch in range(20):
    for batch in dataloader:
        input, target = batch
-       input, target = input.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(input, target)
        loss = torch.nn.functional.nll_loss(output, target.view(-1))
-       loss.backward()
+       fabric.backward(loss)
        optimizer.step()
That’s it! You can now train on any device at any scale by switching a flag. Check out our before-and-after example for image classification and many more examples that use Fabric.
Optional changes
Here are a few optional upgrades you can make to your code, if applicable:
- Replace torch.save() and torch.load() with Fabric’s save and load methods.
- Replace collective operations from torch.distributed (barrier, broadcast, etc.) with Fabric’s collective methods.
- Use Fabric’s no_backward_sync() context manager if you implemented gradient accumulation.
- Initialize your model under the init_module() context manager.
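A minimal sketch of what these optional changes can look like; the checkpoint path, state keys, and variables such as epoch, vocab, and is_last_micro_batch are placeholders from a hypothetical script:
# Checkpointing through Fabric instead of torch.save()/torch.load()
state = {"model": model, "optimizer": optimizer, "epoch": epoch}
fabric.save("checkpoint.ckpt", state)
fabric.load("checkpoint.ckpt", state)

# Collectives through Fabric instead of torch.distributed
fabric.barrier()
vocab = fabric.broadcast(vocab, src=0)

# Skip gradient synchronization on accumulation steps
with fabric.no_backward_sync(model, enabled=not is_last_micro_batch):
    fabric.backward(loss)

# Instantiate the model directly on the target device and in the selected precision
with fabric.init_module():
    model = Transformer(vocab_size=dataset.vocab_size)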