Test set

Lightning forces the user to run the test set separately to make sure it isn’t evaluated by mistake. Testing is performed using the trainer object’s .test() method.

Trainer.test(model=None, dataloaders=None, ckpt_path=None, verbose=True, datamodule=None, test_dataloaders=None)

Perform one evaluation epoch over the test set. It’s separated from fit to make sure you never run on your test set until you want to.

Return type

List[Dict[str, float]]

Returns

List of dictionaries with metrics logged during the test phase, e.g., in model or callback hooks such as test_step() and test_epoch_end(). The list contains one dictionary per test dataloader.
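The return shape can be made concrete with a pure-Python sketch of one test epoch. This is not Lightning's implementation: run_test_epoch, the (inputs, targets) batch format and the test_acc metric name are hypothetical. The sketch makes a single pass over each dataloader without updating anything and returns one dictionary per dataloader, mirroring List[Dict[str, float]].

```python
from typing import Callable, Dict, Iterable, List, Sequence, Tuple

# Hypothetical batch format: (inputs, targets)
Batch = Tuple[List[float], List[int]]


def run_test_epoch(
    predict: Callable[[List[float]], List[int]],
    dataloaders: Sequence[Iterable[Batch]],
) -> List[Dict[str, float]]:
    """Toy stand-in for one test epoch: one pass per dataloader, no weight updates."""
    results: List[Dict[str, float]] = []
    for loader in dataloaders:
        correct = 0
        total = 0
        for inputs, targets in loader:
            preds = predict(inputs)
            correct += sum(int(p == t) for p, t in zip(preds, targets))
            total += len(targets)
        # One metrics dict per dataloader, in dataloader order
        results.append({"test_acc": correct / total if total else 0.0})
    return results


# Usage: two toy "dataloaders" (lists of batches) and a threshold classifier
predict = lambda xs: [int(x > 0.5) for x in xs]
loaders = [[([0.2, 0.9], [0, 1])], [([0.7], [0]), ([0.1], [0])]]
print(run_test_epoch(predict, loaders))  # [{'test_acc': 1.0}, {'test_acc': 0.5}]
```

Because the result list is indexed by dataloader, results[i] always refers to the i-th test dataloader.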

Test after fit

To run the test set after training completes, use this method.

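# run full training
trainer.fit(model)

# (1) load the best checkpoint automatically (lightning tracks this for you)
trainer.test(ckpt_path="best")

# (2) test using a specific checkpoint
trainer.test(ckpt_path="/path/to/my_checkpoint.ckpt")

# (3) test with an explicit model (will use this model and not load a checkpoint)
trainer.test(model)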
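The three variants differ only in which weights get evaluated. The hypothetical helper below, resolve_checkpoint, is a simplified sketch of that decision, not Lightning's own code; in particular, treating a missing ckpt_path with no model passed as "best" is an assumption made here.

```python
from typing import Optional


def resolve_checkpoint(
    ckpt_path: Optional[str],
    model_passed: bool,
    best_model_path: Optional[str],
) -> Optional[str]:
    """Hypothetical helper: return the checkpoint file to load, or None to keep in-memory weights."""
    if ckpt_path is None:
        if model_passed:
            return None  # case (3): test the model object as it is
        ckpt_path = "best"  # assumption: fall back to the best checkpoint
    if ckpt_path == "best":
        if not best_model_path:
            raise ValueError("no best checkpoint recorded; run fit() with checkpointing first")
        return best_model_path  # case (1): path tracked during fit()
    return ckpt_path  # case (2): an explicit checkpoint file


print(resolve_checkpoint("best", False, "epoch=3.ckpt"))  # epoch=3.ckpt
print(resolve_checkpoint(None, True, "epoch=3.ckpt"))  # None
```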

Test multiple models

You can run the test set on multiple models using the same trainer instance.

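from pytorch_lightning import Trainer

# LitModel and GANModel are your own LightningModule subclasses
model1 = LitModel()
model2 = GANModel()

trainer = Trainer()
trainer.test(model1)
trainer.test(model2)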

Test pre-trained model

To run the test set on a pre-trained model, use this method.

# load the model weights from a saved checkpoint
model = MyLightningModule.load_from_checkpoint(...)

# init trainer with whatever options
trainer = Trainer(...)

# test (pass in the model)
trainer.test(model)

In this case, the options you pass to the Trainer (e.g., 16-bit precision, DP, DDP) are also used when running the test set.

Test with additional data loaders

You can still run inference on a test set even if the test_dataloader method hasn't been defined in your LightningModule. This is the case, for example, when your test data is not yet available at the time the model is defined.

from torch.utils.data import DataLoader

# setup your data loader
test_dataloader = DataLoader(...)

# test (pass in the loader)
trainer.test(dataloaders=test_dataloader)

You can pass in either a single dataloader or a list of them. This optional named parameter can be combined with any of the use cases above. You can also pass in a datamodule that overrides the test_dataloader method.

import pytorch_lightning as pl
from torch.utils.data import DataLoader


class MyDataModule(pl.LightningDataModule):

    def test_dataloader(self):
        return DataLoader(...)

# setup your datamodule
dm = MyDataModule(...)

# test (pass in datamodule)
trainer.test(datamodule=dm)