Validate and test a model (basic)
Audience: Users who want to add a validation loop to avoid overfitting
Add a test loop
To make sure a model can generalize to unseen data (e.g., before publishing a paper or deploying to production), a dataset is normally split into two parts: the train split and the test split.
The test set is NOT used during training. It is used ONLY once the model has been trained, to estimate how the model will perform in the real world.
Find the train and test splits
Many datasets come with predefined train and test splits. Refer to the dataset documentation to find them.
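import torch.utils.data as data
from torchvision import datasets, transforms

# Load the predefined train and test splits, converting each PIL image to a tensor
# so that the default DataLoader collation can batch the samples
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)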
Define the test loop
To add a test loop, implement the test_step method of the LightningModule:
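import pytorch_lightning as pl  # or: import lightning.pytorch as pl
import torch.nn.functional as F


class LitAutoEncoder(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        ...

    def test_step(self, batch, batch_idx):
        # this is the test loop
        x, y = batch  # labels are unused by the autoencoder
        x = x.view(x.size(0), -1)  # flatten each image into a vector
        z = self.encoder(x)
        x_hat = self.decoder(z)
        test_loss = F.mse_loss(x_hat, x)  # reconstruction error
        self.log("test_loss", test_loss)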
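To make the tensor shapes in test_step concrete, here is a minimal plain-PyTorch sketch of the same computation on a random stand-in batch. The encoder and decoder layer sizes are assumptions chosen for illustration; the page does not specify the architecture.

```python
import torch
import torch.nn.functional as F
from torch import nn

# Hypothetical layer sizes: the page does not define the encoder/decoder.
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

# Random stand-in for one MNIST batch: 4 grayscale 28x28 images plus labels.
x = torch.rand(4, 1, 28, 28)
y = torch.randint(0, 10, (4,))

flat = x.view(x.size(0), -1)         # (4, 1, 28, 28) -> (4, 784)
z = encoder(flat)                    # latent code, shape (4, 3)
x_hat = decoder(z)                   # reconstruction, shape (4, 784)
test_loss = F.mse_loss(x_hat, flat)  # scalar mean squared reconstruction error

print(flat.shape, z.shape, x_hat.shape, test_loss.dim())
```

The labels y are unpacked but never used, which matches the test_step shown on this page: an autoencoder is scored on how well it reconstructs its own input.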
Train with the test loop
Once the model has finished training, call .test:
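from pytorch_lightning import Trainer  # or: from lightning.pytorch import Trainer
from torch.utils.data import DataLoader

# initialize the Trainer
trainer = Trainer()

# test the model (assumes `model` is a trained LitAutoEncoder instance)
trainer.test(model, dataloaders=DataLoader(test_set))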
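For intuition, the following sketch shows roughly the kind of loop that a call to .test spares you from writing by hand: switch to evaluation mode, disable gradients, and average the loss over every batch. It is an illustration in plain PyTorch, not Lightning's actual implementation. The evaluate helper, the tiny stand-in model, and the random data are all hypothetical.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, loader):
    """Return the mean squared reconstruction error over all batches (hypothetical helper)."""
    model.eval()  # disable dropout / batch-norm updates
    total, count = 0.0, 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for x, _ in loader:
            x = x.view(x.size(0), -1)
            x_hat = model(x)
            # sum the squared errors so the final average is exact across uneven batches
            total += F.mse_loss(x_hat, x, reduction="sum").item()
            count += x.numel()
    return total / count


# Random stand-in data and a tiny stand-in autoencoder (assumptions for illustration).
torch.manual_seed(0)
dataset = TensorDataset(torch.rand(10, 1, 28, 28), torch.zeros(10, dtype=torch.long))
model = nn.Sequential(nn.Linear(784, 3), nn.Linear(3, 784))

loss = evaluate(model, DataLoader(dataset, batch_size=4))
print(f"test_loss: {loss:.4f}")
```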
Add a validation loop
During training, it’s common practice to use a small portion of the train split to determine when the model has finished training.
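One common way to turn validation results into a stopping decision is patience-based early stopping: stop once the validation loss has failed to improve for a fixed number of consecutive epochs. The sketch below illustrates that idea in plain Python. The stopping_epoch function, the patience value, and the loss values are hypothetical and are not part of Lightning's API.

```python
def stopping_epoch(val_losses, patience=2):
    """Return the index of the epoch at which training would stop.

    Training stops once the validation loss has not improved on its best
    value for `patience` consecutive epochs. If that never happens, the
    index of the last epoch is returned.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return len(val_losses) - 1


# Made-up validation losses: improvement stalls after epoch 2.
losses = [0.9, 0.7, 0.6, 0.65, 0.66, 0.5]
print(stopping_epoch(losses, patience=2))
```

With these made-up numbers the loss improves through epoch 2, then fails to improve at epochs 3 and 4, so training stops at epoch 4 and never sees the later value of 0.5. A larger patience trades extra training time for robustness against such temporary plateaus.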
Split the training data
As a rule of thumb, we use 20% of the training set as the validation set. This number varies from dataset to dataset.
import torch
import torch.utils.data as data

# use 20% of training data for validation
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size

# split the train set into two, seeding the generator so the split is reproducible
seed = torch.Generator().manual_seed(42)
train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)
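The size arithmetic above can be checked in isolation. The sketch below wraps it in a hypothetical split_sizes helper (not a PyTorch or Lightning function) and applies it to 60,000 samples, the size of the MNIST training split.

```python
def split_sizes(n_samples, train_fraction=0.8):
    """Return (train_size, valid_size) for a dataset of n_samples items."""
    train_size = int(n_samples * train_fraction)  # truncates toward zero
    valid_size = n_samples - train_size           # remainder goes to validation
    return train_size, valid_size


print(split_sizes(60_000))  # MNIST training split
print(split_sizes(7))       # a size that does not divide evenly
```

Computing the validation size as the remainder, rather than as int(n_samples * 0.2), guarantees that the two sizes always add up to the dataset length, which random_split requires when it is given integer lengths.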
Define the validation loop
To add a validation loop, implement the validation_step method of the LightningModule:
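import pytorch_lightning as pl  # or: import lightning.pytorch as pl
import torch.nn.functional as F


class LitAutoEncoder(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        ...

    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        x, y = batch  # labels are unused by the autoencoder
        x = x.view(x.size(0), -1)  # flatten each image into a vector
        z = self.encoder(x)
        x_hat = self.decoder(z)
        val_loss = F.mse_loss(x_hat, x)  # reconstruction error
        self.log("val_loss", val_loss)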
Train with the validation loop
To run the validation loop, pass the validation set to .fit:
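from pytorch_lightning import Trainer  # or: from lightning.pytorch import Trainer
from torch.utils.data import DataLoader

# wrap each split in a DataLoader
train_loader = DataLoader(train_set)
valid_loader = DataLoader(valid_set)  # `valid_set` comes from the random_split above

# train with both splits
trainer = Trainer()
trainer.fit(model, train_loader, valid_loader)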
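DataLoader defaults to batch_size=1 and no shuffling, so in practice the two loaders are usually configured explicitly. The sketch below shows one common setup on a random stand-in dataset: shuffle only the training split and batch both. The batch size of 32 and the 100-sample dataset are assumptions chosen for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Random stand-in for an image dataset: 100 samples shaped like MNIST digits.
full = TensorDataset(torch.rand(100, 1, 28, 28), torch.randint(0, 10, (100,)))

# Reproducible 80/20 split, mirroring the split shown earlier on this page.
generator = torch.Generator().manual_seed(42)
train_part, valid_part = random_split(full, [80, 20], generator=generator)

# Shuffle the training data each epoch; keep validation order fixed.
train_loader = DataLoader(train_part, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid_part, batch_size=32)

print(len(train_loader), len(valid_loader))  # number of batches per epoch
```

The last training batch holds only 16 samples because 80 is not a multiple of 32. Both loaders can then be passed to .fit in place of the default ones.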