Running multiple validation steps after each training epoch

Hi, I’m currently working on training a network with PyTorch Lightning, where I need to run validation multiple times after each training epoch. My goal is to log the mean and standard deviation of the validation accuracy across these runs (because of the random sampling in the data loader, each run produces different results).

To achieve this, I have implemented a custom callback using PyTorch Lightning’s pl.Callback as follows:

import copy

import numpy as np
import torch
import pytorch_lightning as pl

class MultiValCallback(pl.Callback):
    def __init__(self, datamodule, multi_val_test_num):
        super().__init__()
        self.dm = datamodule
        self.multi_val_test_num = multi_val_test_num

    def on_validation_epoch_end(self, trainer, pl_module):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Evaluate a copy of the model so the training state is untouched.
        model = copy.deepcopy(pl_module.model).to(device)
        model.eval()
        multi_val_acc = []
        for _ in range(self.multi_val_test_num):
            correct = 0
            total = 0
            with torch.no_grad():
                for data in self.dm.val_dataloader():
                    data = data.to(device)
                    out = model(data)
                    pred = torch.argmax(out, dim=1)
                    correct += (pred == data.y).sum().item()
                    total += data.y.shape[0]
            multi_val_acc.append(correct / total)
        pl_module.log('val/multi_acc/mean', np.mean(multi_val_acc), sync_dist=True)
        pl_module.log('val/multi_acc/std', np.std(multi_val_acc), sync_dist=True)

However, I’ve noticed that this approach is significantly slower than PyTorch Lightning’s native validation loop, and it doesn’t seem to align with the framework’s intended workflow.

I’m looking for advice on best practices for running multiple validation passes. Is there a more efficient way to do this?

Any suggestions or alternative approaches would be greatly appreciated.

@hesamaraghi You can directly return a list of validation dataloaders [1] from the val_dataloader() hook, and the validation loop will iterate over them sequentially. All you have to do is implement validation_step and on_validation_epoch_end to compute and log your metrics; see the sketch below.
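Here is a minimal sketch of that setup. The MultiValDataModule and LitModel names, the num_val_runs parameter, and the make_val_loader factory are all hypothetical; the sketch assumes a classification model whose batches carry labels in batch.y, as in your callback:

import numpy as np
import torch
import pytorch_lightning as pl

class MultiValDataModule(pl.LightningDataModule):
    def __init__(self, make_val_loader, num_val_runs=5):
        super().__init__()
        # make_val_loader is a hypothetical factory that builds one
        # randomly sampled validation DataLoader, like your
        # dm.val_dataloader() currently does.
        self.make_val_loader = make_val_loader
        self.num_val_runs = num_val_runs

    def val_dataloader(self):
        # Returning a list makes Lightning run the validation loop over
        # each dataloader sequentially and pass dataloader_idx to
        # validation_step.
        return [self.make_val_loader() for _ in range(self.num_val_runs)]

class LitModel(pl.LightningModule):
    def __init__(self, model, num_val_runs=5):
        super().__init__()
        self.model = model
        # One (correct, total) accumulator per validation run.
        self.val_correct = [0] * num_val_runs
        self.val_total = [0] * num_val_runs

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        out = self.model(batch)
        pred = torch.argmax(out, dim=1)
        self.val_correct[dataloader_idx] += (pred == batch.y).sum().item()
        self.val_total[dataloader_idx] += batch.y.shape[0]

    def on_validation_epoch_end(self):
        accs = [c / t for c, t in zip(self.val_correct, self.val_total)]
        self.log('val/multi_acc/mean', float(np.mean(accs)), sync_dist=True)
        self.log('val/multi_acc/std', float(np.std(accs)), sync_dist=True)
        # Reset accumulators for the next epoch.
        self.val_correct = [0] * len(self.val_correct)
        self.val_total = [0] * len(self.val_total)

    # training_step and configure_optimizers omitted for brevity.

This way Lightning handles device placement and the eval/no_grad context for you, and there is no deepcopy of the model, which is where most of the overhead in your callback comes from.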

[1] Arbitrary iterable support — PyTorch Lightning 2.1.2 documentation
