Audience: Users looking to train models in interactive notebooks (Jupyter, Colab, Kaggle, etc.).

Lightning in notebooks

You can use the Lightning Trainer in interactive notebooks just like in a regular Python script, including multi-GPU training!

import lightning as L

# Works in Jupyter, Colab and Kaggle!
trainer = L.Trainer(accelerator="auto", devices="auto")

You can find many notebook examples on our tutorials page too!

Full example

Paste the following code block into a notebook cell:

import lightning as L
from torch import nn, optim, utils
import torchvision

encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()  # required so LightningModule can register submodules
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-3)

    def prepare_data(self):
        torchvision.datasets.MNIST(".", download=True)

    def train_dataloader(self):
        dataset = torchvision.datasets.MNIST(".", transform=torchvision.transforms.ToTensor())
        return utils.data.DataLoader(dataset, batch_size=64)

autoencoder = LitAutoEncoder(encoder, decoder)
trainer = L.Trainer(max_epochs=2, devices="auto")
trainer.fit(autoencoder)

Multi-GPU Limitations

Multi-GPU training in Jupyter is enabled by launching worker processes with the ‘fork’ start method. This is the only multiprocessing start method supported in notebooks, but it also comes with some limitations you should be aware of.
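The sketch below uses only Python's standard multiprocessing module, not Lightning's own launcher, to illustrate what the ‘fork’ start method means: each child process starts as a copy of the parent, so a function and a variable defined interactively are already available in the children. The ‘spawn’ start method instead starts a fresh interpreter that must import the target function by name, which fails for functions defined only in notebook cells. The names CONFIG, worker and run_with_fork are hypothetical, and ‘fork’ is only available on POSIX systems (Linux, macOS).

```python
import multiprocessing as mp

# State defined in the parent process, like a variable set in an earlier notebook cell.
CONFIG = {"lr": 1e-3}


def worker(rank, queue):
    # With "fork", the child is a copy of the parent's memory, so it can
    # read CONFIG and run this function without re-importing anything.
    queue.put((rank, CONFIG["lr"]))


def run_with_fork(num_procs=2):
    ctx = mp.get_context("fork")  # raises ValueError on Windows, where fork is unavailable
    queue = ctx.Queue()
    procs = [ctx.Process(target=worker, args=(rank, queue)) for rank in range(num_procs)]
    for p in procs:
        p.start()
    # Drain the queue before joining so children never block on a full pipe.
    results = sorted(queue.get() for _ in procs)
    for p in procs:
        p.join()
    return results


print(run_with_fork(2))
```

Because every child inherits the parent's memory at the moment of the fork, anything already set up in the parent (including CUDA state or loaded datasets) is inherited as well, which is the root of the limitations described next.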

Avoid initializing CUDA before .fit()

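Don’t run any torch CUDA functions in your notebook cells before calling trainer.fit(); otherwise your code may hang or crash. A process created with ‘fork’ cannot re-initialize CUDA once it has already been initialized in the parent process.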

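import torch
import lightning as L

# BAD: Don't run CUDA-related code before `.fit()`
x = torch.tensor(1).cuda()  # initializes CUDA in the notebook process

trainer = L.Trainer(accelerator="cuda", devices=2)
# trainer.fit(...) may now hang or crash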
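If you are unsure whether an earlier cell touched CUDA, one option is to check right before calling trainer.fit(). The helper below is a hypothetical sketch, not part of Lightning; it relies on torch.cuda.is_initialized(), which only reports the current state and does not initialize CUDA itself.

```python
import torch


def check_cuda_not_initialized() -> None:
    """Hypothetical helper: fail early if CUDA is already initialized in this process."""
    if torch.cuda.is_initialized():
        raise RuntimeError(
            "CUDA is already initialized in the notebook process, so forked "
            "worker processes cannot use it. Restart the kernel and avoid "
            "CUDA calls before trainer.fit()."
        )


# Run this in the cell right before trainer.fit(...)
check_cuda_not_initialized()
```

Once CUDA has been initialized in the notebook process there is no way to undo it, so restarting the kernel is usually the only fix.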

Move data loading code inside the hooks

If you define or load your data in the main process before calling trainer.fit(), you may see slowdowns or crashes (segmentation fault, SIGSEGV, etc.).

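import torch
import lightning as L

# BAD: Don't load data in the main process
# (`MyDataset` and `model` are placeholders for your own dataset and LightningModule)
dataset = MyDataset("data/")
train_dataloader = torch.utils.data.DataLoader(dataset)

trainer = L.Trainer(accelerator="cuda", devices=2)
trainer.fit(model, train_dataloader)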

The best practice is to move your data loading code inside the *_dataloader() hooks (for example, train_dataloader()) of your LightningModule or LightningDataModule, as shown in the full example above.