Convert PyTorch code to Fabric

Here are five easy steps to let Fabric scale your PyTorch models.

Step 1: Create the Fabric object at the beginning of your training code.

from lightning.fabric import Fabric

fabric = Fabric()

Step 2: Call setup() on each model and optimizer pair and setup_dataloaders() on all your data loaders.

model, optimizer = fabric.setup(model, optimizer)
dataloader = fabric.setup_dataloaders(dataloader)

Step 3: Remove all .to() and .cuda() calls, since Fabric takes care of device placement.

- model.to(device)
- batch.to(device)

Step 4: Replace loss.backward() with fabric.backward(loss).

- loss.backward()
+ fabric.backward(loss)

Step 5: Run the script from the terminal with

lightning run model path/to/train.py

or use the launch() method in a notebook. Learn more about launching distributed training.


All steps combined, this is how your code will change:

  import torch
  import torch.nn as nn
  from torch.utils.data import DataLoader, Dataset

+ from lightning.fabric import Fabric

  class PyTorchModel(nn.Module):
      ...

  class PyTorchDataset(Dataset):
      ...

+ fabric = Fabric(accelerator="cuda", devices=8, strategy="ddp")
+ fabric.launch()

- device = "cuda" if torch.cuda.is_available() else "cpu"
  model = PyTorchModel(...)
  optimizer = torch.optim.SGD(model.parameters())
+ model, optimizer = fabric.setup(model, optimizer)
  dataloader = DataLoader(PyTorchDataset(...), ...)
+ dataloader = fabric.setup_dataloaders(dataloader)
  model.train()

  for epoch in range(num_epochs):
      for batch in dataloader:
          input, target = batch
-         input, target = input.to(device), target.to(device)
          optimizer.zero_grad()
          output = model(input)
          loss = loss_fn(output, target)
-         loss.backward()
+         fabric.backward(loss)
          optimizer.step()
          lr_scheduler.step()

That’s it! You can now train on any device at any scale by changing a flag. Check out our before-and-after example for image classification and many more examples that use Fabric.


© Copyright Copyright (c) 2018-2023, Lightning AI et al...
