How to Organize PyTorch Into Lightning
To make your code work with Lightning, reorganize your PyTorch code into the following structure.
1. Move your Computational Code
Move the model architecture and forward pass to your LightningModule.
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        return x
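Since a LightningModule is still a regular torch.nn.Module, you can instantiate it and call it like any other module; a quick sanity check of the forward pass above:

model = LitModel()
logits = model(torch.randn(4, 1, 28, 28))  # forward flattens the input to (4, 784)
print(logits.shape)  # torch.Size([4, 10])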
2. Move the Optimizer(s) and LR Scheduler(s)
Move your optimizers and learning-rate schedulers to the configure_optimizers() hook.
class LitModel(pl.LightningModule):
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]
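If you need finer control over when the scheduler steps, configure_optimizers() can also return a dictionary; a minimal sketch, assuming a "val_loss" metric is logged during validation for ReduceLROnPlateau to monitor:

class LitModel(pl.LightningModule):
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",  # step the scheduler once per epoch
                "monitor": "val_loss",  # assumes val_loss is logged in validation_step
            },
        }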
3. Configure the Training Logic
Lightning automates the training loop for you and manages all of the associated components, such as epoch and batch tracking, optimizers and schedulers, and metric reduction. As a user, you only need to define how your model behaves with a batch of training data by overriding the training_step() method, which takes the current batch and the batch_idx as arguments. Optionally, it can also take optimizer_idx if your LightningModule defines multiple optimizers in its configure_optimizers() hook, as shown in the sketch after the example below.
class LitModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss
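With multiple optimizers, Lightning calls training_step() once per optimizer and passes the optimizer index; a minimal sketch, where _generator_loss and _discriminator_loss are hypothetical helpers you would define yourself:

class LitModel(pl.LightningModule):
    def training_step(self, batch, batch_idx, optimizer_idx):
        if optimizer_idx == 0:
            # this call updates the first optimizer's parameters
            return self._generator_loss(batch)
        if optimizer_idx == 1:
            # this call updates the second optimizer's parameters
            return self._discriminator_loss(batch)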
4. Configure the Validation Logic
Lightning also automates the (optional) validation loop for you and manages all of the associated components, such as epoch and batch tracking and metric reduction. As a user, you only need to define how your model behaves with a batch of validation data by overriding the validation_step() method, which takes the current batch and the batch_idx as arguments. Optionally, it can also take dataloader_idx if you configure multiple validation dataloaders (see the sketch after the example below).
class LitModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        val_loss = F.cross_entropy(y_hat, y)
        self.log("val_loss", val_loss)
Additionally, you can run only the validation loop using the validate() method.
model = LitModel()
trainer = pl.Trainer()
trainer.validate(model)
Note
model.eval() and torch.no_grad() are called automatically for validation.
Tip
trainer.validate() loads the best checkpoint automatically by default if checkpointing was enabled during fitting.
5. Configure Testing Logic
Lightning automates the testing loop for you and manages all of the associated components, such as epoch and batch tracking and metric reduction. As a user, you only need to define how your model behaves with a batch of test data by overriding the test_step() method, which takes the current batch and the batch_idx as arguments. Optionally, it can also take dataloader_idx if you configure multiple test dataloaders.
class LitModel(pl.LightningModule):
    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        test_loss = F.cross_entropy(y_hat, y)
        self.log("test_loss", test_loss)
The test loop isn’t run within fit(), so you need to call test() explicitly.
model = LitModel()
trainer = pl.Trainer()
trainer.test(model)
Note
model.eval() and torch.no_grad() are called automatically for testing.
Tip
trainer.test() loads the best checkpoint automatically by default if checkpointing is enabled.
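For example, after fitting with a ModelCheckpoint callback you can point the test run at the best checkpoint explicitly; a sketch, assuming a recent Lightning version where test() accepts a ckpt_path argument:

trainer.fit(model)
trainer.test(ckpt_path="best")  # "best" resolves via the checkpoint callback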
6. Configure Prediction Logic
Lightning automates the prediction loop for you and manages all of the associated components, such as epoch and batch tracking. As a user, you only need to define how your model behaves with a batch of data by overriding the predict_step() method, which takes the current batch and the batch_idx as arguments. Optionally, it can also take dataloader_idx if you configure multiple dataloaders. If you don't override the predict_step() hook, it calls forward() on the batch by default.
class LitModel(pl.LightningModule):
    def predict_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        return pred
The prediction loop will not run until you call predict().
model = LitModel()
trainer = pl.Trainer()
trainer.predict(model)
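predict() returns the predict_step() outputs collected in a list, one entry per batch; a minimal sketch of post-processing, assuming each entry is a tensor of logits as in the example above:

preds = trainer.predict(model)
all_logits = torch.cat(preds)  # concatenate the per-batch outputs
classes = all_logits.argmax(dim=1)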
Note
model.eval() and torch.no_grad() are called automatically for prediction.
Tip
trainer.predict() loads the best checkpoint automatically by default if checkpointing is enabled.
7. Remove any .cuda() or .to(device) Calls
Your LightningModule can automatically run on any hardware!
If you have any explicit calls to .cuda() or .to(device), you can remove them, since Lightning makes sure that the data coming from DataLoader and all the Module instances initialized inside LightningModule.__init__ are moved to the respective devices automatically. If you still need to access the current device, you can use self.device anywhere in your LightningModule except in the __init__ and setup methods.
class LitModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        z = torch.randn(4, 5, device=self.device)
        ...
Hint: If you are initializing a Tensor within the LightningModule.__init__ method and want it to be moved to the device automatically, you should call register_buffer() to register it as a buffer.
class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # num_features: size of the tracked statistic (define it for your model)
        self.register_buffer("running_mean", torch.zeros(num_features))
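During fit(), Lightning performs the device move for you; with plain PyTorch semantics the effect of register_buffer() looks like this (a sketch, assuming a CUDA device is available):

num_features = 10  # example size, just for this sketch
model = LitModel().to("cuda")  # Lightning does this move for you during fit()
print(model.running_mean.device)  # device(type='cuda', index=0)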
8. Use your own data
To use your DataLoaders, you can override the respective dataloader hooks in the LightningModule:
from torch.utils.data import DataLoader


class LitModel(pl.LightningModule):
    def train_dataloader(self):
        return DataLoader(...)

    def val_dataloader(self):
        return DataLoader(...)

    def test_dataloader(self):
        return DataLoader(...)

    def predict_dataloader(self):
        return DataLoader(...)
Alternatively, you can pass your dataloaders in one of the following ways:
- Pass the dataloaders explicitly into the trainer.fit/.validate/.test/.predict calls, as shown in the sketch below.
- Use a LightningDataModule.
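A minimal sketch of the first option (train_loader, val_loader, and test_loader are placeholders for your own DataLoaders; the exact keyword names may vary across Lightning versions):

model = LitModel()
trainer = pl.Trainer()
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.test(model, dataloaders=test_loader)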
Check out the Managing Data doc to understand data management within Lightning.