PyTorch Lightning ".validate()" returns empty list - [SOLVED]

I am new to Lightning and was experimenting with a ResNet-18 CNN trained on CIFAR-10 to get a basic working template:

# Mini-batch size shared by both dataloaders.
batch_size = 512


# Define CIFAR-10 transformations for training and test sets-
# Training pipeline: random crop + horizontal flip for augmentation,
# then tensor conversion and per-channel normalization.
transform_train = transforms.Compose(
    [
      transforms.RandomCrop(32, padding = 4),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      # NOTE(review): CIFAR-10 per-channel std is commonly ~(0.247, 0.243, 0.262);
      # these much smaller values look like they were computed differently
      # (e.g. std of per-image means) — confirm they are intended.
      transforms.Normalize(
        mean = (0.4914, 0.4822, 0.4465),
        std = (0.0305, 0.0296, 0.0342)),
     ]
)

# Test pipeline: no augmentation, only tensor conversion + normalization
# (statistics computed on the test split).
transform_test = transforms.Compose(
    [
      transforms.ToTensor(),
      transforms.Normalize(
        mean = (0.4942, 0.4846, 0.4498),
        std = (0.0304, 0.0295, 0.0342)),
     ]
)

# Download (if needed) and wrap the CIFAR-10 train/test splits.
train_dataset = torchvision.datasets.CIFAR10(
    root = '/home/majumdar/Downloads/.data', train = True,
    download = True, transform = transform_train
)

test_dataset = torchvision.datasets.CIFAR10(
    root = '/home/majumdar/Downloads/.data', train = False,
    download = True, transform = transform_test
)

# Shuffle only the training data; pin_memory speeds up host-to-GPU copies.
train_loader = torch.utils.data.DataLoader(
    dataset = train_dataset, batch_size = batch_size,
    shuffle = True, num_workers = 4,
    pin_memory = True
    )

test_loader = torch.utils.data.DataLoader(
    dataset = test_dataset, batch_size = batch_size,
    shuffle = False, num_workers = 4,
    pin_memory = True
    )




# Define LightningModule-
class ResNet18_CIFAR10(pl.LightningModule):
    """LightningModule wrapping a ResNet-18 classifier for CIFAR-10.

    Logs 'train_loss' during training and 'test_loss' during validation
    (note: the validation metric is deliberately keyed 'test_loss'), and
    mirrors the most recent step losses into ``self.train_hist``.
    """

    def __init__(self):
        super().__init__()

        # Backbone network; ResNet18 is defined elsewhere in the project.
        self.model = ResNet18()
        # Most recent per-step loss values, overwritten every step.
        self.train_hist = dict()


    def validation_step(self, batch, batch_idx):
        """Run one validation step and log the loss under 'test_loss'."""
        images, labels = batch
        logits = self.model(images)
        val_loss = F.cross_entropy(logits, labels)

        self.train_hist['test_loss'] = val_loss.item()
        self.log('test_loss', val_loss, sync_dist = True)
        return val_loss


    def training_step(self, batch, batch_idx):
        """Run one optimization step: forward pass + cross-entropy loss.

        training_step() defines the training loop and is independent of
        forward().
        """
        images, labels = batch
        logits = self.model(images)
        train_loss = F.cross_entropy(logits, labels)

        self.train_hist['train_loss'] = train_loss.item()
        # Logged to TensorBoard (if installed) by default.
        self.log('train_loss', train_loss, sync_dist = True)
        return train_loss


    def configure_optimizers(self):
        """Adam over every parameter with a fixed learning rate of 1e-3."""
        return optim.Adam(params = self.parameters(), lr = 1e-3)


'''
If you have tensorboard installed, you can use it for visualizing experiments.
Run this on your commandline and open your browser to http://localhost:6006/
tensorboard --logdir .
'''

# Instantiate the LightningModule (a ResNet-18 CIFAR-10 classifier)-
model = ResNet18_CIFAR10()


# Train the model
# The Lightning Trainer “mixes” any LightningModule with any dataset and
# abstracts away all the engineering complexity needed for scale.
trainer = pl.Trainer(
    # DDP across three GPUs; 'ddp_notebook' is the spawn-based strategy
    # intended for interactive/notebook environments.
    accelerator = 'gpu', devices = [0, 1, 2], strategy = 'ddp_notebook',
    # 1.0 means "use 100% of the train/val batches" each epoch.
    limit_train_batches = 1.0, limit_val_batches = 1.0,
    max_epochs = 20
)

# Fit on the train loader, validating against the test loader each epoch.
trainer.fit(
    model = model, train_dataloaders = train_loader,
    val_dataloaders = test_loader
)

After training finishes, I log out of the compute platform and log back in, so the previous “trainer” and “model” objects are gone; only the saved checkpoint remains.

To test the trained model, I do:

trainer = pl.Trainer(
    accelerator = 'gpu', devices = 1,
    # BUG: limit_val_batches = 0 tells the Trainer to run ZERO validation
    # batches, so trainer.validate() below produces no metrics and returns
    # an empty list. It should be left at its default (1.0) — see the
    # working solution further down.
    limit_train_batches = 0, limit_val_batches = 0,
)

model_trained = ResNet18_CIFAR10()

saved_checkpt = 'lightning_logs/version_0/checkpoints/epoch=19-step=660.ckpt'

# ckpt_path restores the trained weights into model_trained before validating.
trainer.validate(
    model = model_trained, dataloaders = test_loader,
    ckpt_path = saved_checkpt
    )

But this returns an empty list — because `limit_val_batches = 0` tells the Trainer to run zero validation batches, so no metrics are ever produced.

Solution:

# Perform validation using trained checkpoints-
# NOTE(review): this example comes from a LeNet-5/MNIST run rather than the
# ResNet-18/CIFAR-10 setup above, but the pattern is identical.
# Re-create the LightningModule shell; the weights are restored from the
# checkpoint by trainer.validate() below.
model_trained = LeNet5_MNIST()

# Fresh single-process Trainer on CPU. Crucially, limit_val_batches stays at
# 1.0 so the full validation set is actually run.
trainer = pl.Trainer(
    accelerator = 'cpu',
    limit_train_batches = 1.0, limit_val_batches = 1.0
)

path_to_ckpt = "lightning_logs/version_0/checkpoints/epoch=19-step=2360.ckpt"

# Restores the checkpoint weights, runs validation, and returns a list with
# one metrics dict per validation dataloader.
validate = trainer.validate(
    model = model_trained,
    ckpt_path = path_to_ckpt,
    dataloaders = test_loader,
)

'''
type(validate), len(validate)
# (list, 1)

validate[0]
# {'test_loss': 0.03873412311077118}

validate[0].keys()
# dict_keys(['test_loss'])
'''

# The metrics dict is keyed 'test_loss' because that is the key logged in
# validation_step().
print(f"validation loss = {validate[0]['test_loss']:.3f}")
# validation loss = 0.039