How to Organize PyTorch Into Lightning
To make your PyTorch code work with Lightning, reorganize it using the following steps.
1. Keep Your Computational Code
Keep your regular nn.Module architecture.
import lightning.pytorch as pl
import torch
import torch.nn as nn
import torch.nn.functional as F


class LitModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 10)

    def forward(self, x):
        # Flatten each sample to a vector before the linear layers.
        x = x.view(x.size(0), -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        return x
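As a quick sanity check, a module like this can be exercised with a random batch before any Lightning code is involved. The sketch below assumes MNIST-shaped inputs (1 x 28 x 28), which the 28 * 28 input size suggests but the text does not state; the class is renamed Net here purely for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Same two-layer architecture, under a hypothetical name."""

    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten (N, 1, 28, 28) -> (N, 784)
        x = F.relu(self.layer_1(x))
        return self.layer_2(x)


net = Net()
fake_images = torch.randn(4, 1, 28, 28)  # assumed input shape
logits = net(fake_images)
print(logits.shape)  # torch.Size([4, 10])
```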
2. Configure Training Logic
In the training_step of the LightningModule, configure how your training routine behaves with a batch of training data:
class LitModel(pl.LightningModule):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.encoder(x)
        loss = F.cross_entropy(y_hat, y)
        return loss
Note
If you need to fully own the training loop for complicated legacy projects, check out Own your loop.
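For orientation, training_step holds what would otherwise be the body of a hand-written training loop. The plain PyTorch sketch below shows roughly which surrounding steps the Trainer takes over. It is a simplified illustration with a stand-in model and random data, not Lightning's actual implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in model and optimizer (hypothetical, for illustration only).
encoder = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

# Random stand-in batches of (images, labels); real code would use a DataLoader.
batches = [(torch.randn(8, 1, 28, 28), torch.randint(0, 10, (8,))) for _ in range(3)]

losses = []
for batch_idx, batch in enumerate(batches):
    # This part corresponds to the body of training_step.
    x, y = batch
    y_hat = encoder(x)
    loss = F.cross_entropy(y_hat, y)

    # The remaining steps are the ones a hand-written loop has to spell out.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(len(losses))  # 3
```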
3. Move Optimizer(s) and LR Scheduler(s)
Move your optimizers and learning rate schedulers to the configure_optimizers() hook.
class LitModel(pl.LightningModule):
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.encoder.parameters(), lr=1e-3)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        # Return a list of optimizers and a list of schedulers.
        return [optimizer], [lr_scheduler]
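To see what the scheduler in this example does: StepLR multiplies the learning rate by gamma (0.1 by default) every step_size scheduler steps. The plain PyTorch sketch below steps it by hand on a stand-in layer. Inside Lightning the Trainer calls the scheduler for you, so this loop is only meant to illustrate the decay.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)  # stand-in for the encoder
optimizer = torch.optim.Adam(layer.parameters(), lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

lrs = []
for _ in range(3):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()     # no gradients were computed, so nothing is updated
    lr_scheduler.step()  # multiply the learning rate by gamma

print(lrs)  # approximately 1e-3, 1e-4, 1e-5
```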
4. Organize Validation Logic (optional)
If you need a validation loop, configure how your validation routine behaves with a batch of validation data:
class LitModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.encoder(x)
        val_loss = F.cross_entropy(y_hat, y)
        self.log("val_loss", val_loss)
Tip
trainer.validate() loads the best checkpoint automatically by default if checkpointing was enabled during fitting.
5. Organize Testing Logic (optional)
If you need a test loop, configure how your testing routine behaves with a batch of test data:
class LitModel(pl.LightningModule):
    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.encoder(x)
        test_loss = F.cross_entropy(y_hat, y)
        self.log("test_loss", test_loss)
6. Configure Prediction Logic (optional)
If you need a prediction loop, configure how your prediction routine behaves with a batch of data:
class LitModel(pl.LightningModule):
    def predict_step(self, batch, batch_idx):
        x, y = batch
        pred = self.encoder(x)
        return pred
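The value returned from predict_step in this example is the raw model output. If the encoder produces class logits, as the cross-entropy loss in the earlier steps suggests, a common follow-up is to convert them to probabilities or class indices. A minimal sketch with made-up logits:

```python
import torch
import torch.nn.functional as F

# Made-up logits for a batch of 2 samples and 3 classes.
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])

probs = F.softmax(logits, dim=1)     # each row sums to 1
pred_classes = logits.argmax(dim=1)  # index of the largest logit per row
print(pred_classes)  # tensor([0, 2])
```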
7. Remove any .cuda() or .to(device) Calls
Your LightningModule can automatically run on any hardware!
If you have any explicit calls to .cuda() or .to(device), you can remove them, since Lightning makes sure that the data coming from DataLoader and all the Module instances initialized inside LightningModule.__init__ are moved to the respective devices automatically. If you still need to access the current device, you can use self.device anywhere in your LightningModule except in the __init__ and setup methods.
class LitModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        # Create new tensors directly on the module's current device.
        z = torch.randn(4, 5, device=self.device)
        ...
Hint: If you are initializing a Tensor within the LightningModule.__init__ method and want it to be moved to the device automatically, you should call register_buffer() to register it as a buffer.
class LitModel(pl.LightningModule):
    def __init__(self, num_features):
        super().__init__()
        # Buffers move with the module but are not trainable parameters.
        self.register_buffer("running_mean", torch.zeros(num_features))
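A registered buffer is part of the module's state but is not trained. The plain nn.Module sketch below (the class name and the num_features value are illustrative) shows that a buffer appears in state_dict() and follows the module through .to() calls, while a tensor stored as an ordinary attribute does not. A dtype change stands in for a device move so the example runs on a CPU-only machine.

```python
import torch
import torch.nn as nn


class WithBuffer(nn.Module):
    def __init__(self, num_features):
        super().__init__()
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.plain_tensor = torch.zeros(num_features)  # not registered


m = WithBuffer(num_features=3)
print(list(m.state_dict().keys()))  # ['running_mean']
print(len(list(m.parameters())))    # 0

m = m.to(torch.float64)       # .to() is applied to buffers and parameters
print(m.running_mean.dtype)   # torch.float64
print(m.plain_tensor.dtype)   # torch.float32
```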
8. Use Your Own Data
Regular PyTorch DataLoaders work with Lightning. For more modular and scalable datasets, check out LightningDataModule.
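As a minimal sketch of the first point, the following builds an ordinary DataLoader from random tensors standing in for a real dataset. The shapes, dataset size, and batch size are assumptions chosen for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Random stand-in data: 64 single-channel 28x28 images with labels 0..9.
images = torch.randn(64, 1, 28, 28)
labels = torch.randint(0, 10, (64,))
train_loader = DataLoader(TensorDataset(images, labels), batch_size=16, shuffle=True)

x, y = next(iter(train_loader))
print(x.shape, y.shape)  # torch.Size([16, 1, 28, 28]) torch.Size([16])
```

A loader like this can be handed to the Trainer unchanged, for example with trainer.fit(model, train_dataloaders=train_loader).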
Good to know
Additionally, you can run only the validation loop using the validate() method.
model = LitModel()
trainer.validate(model)
Note
model.eval() and torch.no_grad() are called automatically for validation.
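For reference, this is roughly what those two calls do in plain PyTorch: eval() switches layers such as dropout to inference behavior, and no_grad() stops autograd from recording operations. The model below is a hypothetical stand-in.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.randn(2, 4)

model.eval()           # dropout becomes a no-op
with torch.no_grad():  # no computation graph is built
    out_a = model(x)
    out_b = model(x)

print(torch.equal(out_a, out_b))  # True: no random dropout in eval mode
print(out_a.requires_grad)        # False
model.train()          # restore training mode afterwards
```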
The test loop isn’t used within fit(), so you need to call test() explicitly.
model = LitModel()
trainer.test(model)
Note
model.eval() and torch.no_grad() are called automatically for testing.
Tip
trainer.test() loads the best checkpoint automatically by default if checkpointing is enabled.
The predict loop will not be used until you call predict().
model = LitModel()
trainer.predict(model)
Note
model.eval() and torch.no_grad() are called automatically for prediction.
Tip
trainer.predict() loads the best checkpoint automatically by default if checkpointing is enabled.