Validate and test a model (basic)

Audience: Users who want to add a validation loop to avoid overfitting


Add a test loop

To make sure a model can generalize to unseen data (for example, before publishing a paper or deploying to production), a dataset is normally split into two parts: the train split and the test split.

The test set is NOT used during training. It is used ONLY once the model has been trained, to estimate how the model will perform in the real world.


Find the train and test splits

Many datasets come with predefined train and test splits. Refer to the dataset documentation to find them.

import torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transforms

# Load data sets
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)

Define the test loop

To add a test loop, implement the test_step method of the LightningModule.

import lightning as L
import torch.nn.functional as F


class LitAutoEncoder(L.LightningModule):
    def training_step(self, batch, batch_idx):
        ...

    def test_step(self, batch, batch_idx):
        # this is the test loop
        x, y = batch
        # flatten each image into a vector
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        # reconstruction error on the test batch
        test_loss = F.mse_loss(x_hat, x)
        self.log("test_loss", test_loss)
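
To make the role of test_step more concrete, the following is a rough plain-PyTorch sketch of what an evaluation pass over a test set involves: the model is put in eval mode, gradients are disabled, and the loss is averaged over all batches. It is not Lightning's actual implementation. The run_test_loop helper, the small encoder and decoder, and the random stand-in data are all assumptions made for illustration.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def run_test_loop(encoder, decoder, loader):
    """Average reconstruction loss over a loader, without updating weights."""
    encoder.eval()
    decoder.eval()
    total_loss, n_samples = 0.0, 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for x, _ in loader:
            x = x.view(x.size(0), -1)  # flatten each image into a vector
            x_hat = decoder(encoder(x))
            # weight each batch by its size so the result is a per-sample mean
            total_loss += F.mse_loss(x_hat, x).item() * x.size(0)
            n_samples += x.size(0)
    return total_loss / n_samples


# Hypothetical stand-in data: 32 random 28x28 "images" with dummy labels
fake_test_set = TensorDataset(
    torch.rand(32, 1, 28, 28), torch.zeros(32, dtype=torch.long)
)
# Hypothetical encoder/decoder; the layer sizes are arbitrary
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

test_loss = run_test_loop(encoder, decoder, DataLoader(fake_test_set, batch_size=8))
print(f"test_loss: {test_loss:.4f}")
```

With a LightningModule, this bookkeeping is handled by the Trainer, so test_step only needs to describe what happens for a single batch.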

Train with the test loop

Once the model has finished training, call trainer.test with a DataLoader over the test set.

import lightning as L
from torch.utils.data import DataLoader

# initialize the Trainer
trainer = L.Trainer()

# test the model on the held-out test set
trainer.test(model, dataloaders=DataLoader(test_set))

Add a validation loop

During training, it’s common practice to use a small portion of the train split to determine when the model has finished training.
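
One common way to use validation loss for deciding when training is finished is early stopping: stop once the loss has not improved for a few epochs. Lightning provides an EarlyStopping callback for this. The plain-Python function below is only a hypothetical sketch of the underlying idea, and the should_stop name, the patience and min_delta parameters, and the loss values are invented for illustration.

```python
def should_stop(val_losses, patience=3, min_delta=0.0):
    """Return True if validation loss has not improved in the last `patience` epochs."""
    if len(val_losses) <= patience:
        return False  # not enough history to judge yet
    best_before = min(val_losses[:-patience])  # best loss before the recent window
    recent_best = min(val_losses[-patience:])  # best loss inside the recent window
    # stop when the recent window did not beat the earlier best by more than min_delta
    return recent_best > best_before - min_delta


# Made-up validation losses, one per epoch
history = [0.90, 0.70, 0.60, 0.58, 0.59, 0.61, 0.60, 0.62]

stop_epoch = None
for epoch in range(1, len(history) + 1):
    if should_stop(history[:epoch]):
        stop_epoch = epoch  # 1-indexed epoch at which training would stop
        break

print(f"stop after epoch: {stop_epoch}")
```

Here the loss bottoms out at 0.58 in epoch 4, and training would stop once three further epochs fail to improve on it.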


Split the training data

As a rule of thumb, we use 20% of the training set as the validation set. This number varies from dataset to dataset.

import torch
import torch.utils.data as data

# use 20% of training data for validation
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size

# split the train set into two; a seeded generator makes the split reproducible
generator = torch.Generator().manual_seed(42)
train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=generator)
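
The same split can be sanity-checked on a small synthetic dataset. The sketch below wraps the split in a hypothetical split_train_valid helper and uses a 100-sample TensorDataset as a stand-in for the real training set (both are assumptions for illustration). It checks that the two subsets do not overlap and that the same seed reproduces the same split.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical stand-in for the real train split: 100 samples
full_set = TensorDataset(torch.arange(100).float().unsqueeze(1))


def split_train_valid(dataset, train_fraction=0.8, seed=42):
    """Split a dataset into train/validation subsets using a seeded generator."""
    train_size = int(len(dataset) * train_fraction)
    valid_size = len(dataset) - train_size
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [train_size, valid_size], generator=generator)


train_a, valid_a = split_train_valid(full_set)
train_b, valid_b = split_train_valid(full_set)

# The two subsets share no sample indices
no_overlap = set(train_a.indices).isdisjoint(valid_a.indices)
# The same seed yields the same split on every call
reproducible = list(valid_a.indices) == list(valid_b.indices)
print(len(train_a), len(valid_a), no_overlap, reproducible)
```

Fixing the seed matters when comparing runs: if the validation set changes between experiments, differences in validation loss may reflect the split rather than the model.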

Define the validation loop

To add a validation loop, implement the validation_step method of the LightningModule.

import lightning as L
import torch.nn.functional as F


class LitAutoEncoder(L.LightningModule):
    def training_step(self, batch, batch_idx):
        ...

    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        x, y = batch
        # flatten each image into a vector
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        # reconstruction error on the validation batch
        val_loss = F.mse_loss(x_hat, x)
        self.log("val_loss", val_loss)

Train with the validation loop

To run the validation loop, pass a DataLoader for the validation set to trainer.fit.

import lightning as L
from torch.utils.data import DataLoader

train_loader = DataLoader(train_set)
valid_loader = DataLoader(valid_set)
model = LitAutoEncoder(...)

# train with both splits
trainer = L.Trainer()
trainer.fit(model, train_loader, valid_loader)
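
The loaders above rely on DataLoader defaults, which means a batch size of 1 and no shuffling. In practice the loaders are usually configured explicitly. The sketch below shows one possible setup on hypothetical stand-in data; the batch size of 16 and the choice to shuffle only the training loader are illustrative assumptions, not recommendations from this guide.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Hypothetical stand-in data: 100 flattened "images" with dummy labels
full_set = TensorDataset(torch.rand(100, 28 * 28), torch.zeros(100, dtype=torch.long))
train_set, valid_set = random_split(
    full_set, [80, 20], generator=torch.Generator().manual_seed(42)
)

# Shuffle only the training data; a fixed validation order keeps runs comparable
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=16, shuffle=False)

n_train_batches = len(train_loader)  # ceil(80 / 16)
n_valid_batches = len(valid_loader)  # ceil(20 / 16); the last batch is smaller
first_x, first_y = next(iter(valid_loader))
print(n_train_batches, n_valid_batches, tuple(first_x.shape))
```

Loaders built this way can be passed to trainer.fit in the same positions as train_loader and valid_loader above.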