Hi, I’m currently training a network with PyTorch Lightning, and I need to run validation multiple times after each training epoch. My goal is to log the mean and standard deviation of the validation accuracy across these runs (because of the randomness in the data loader, each run produces slightly different results).
To achieve this, I have implemented a custom callback by subclassing PyTorch Lightning’s pl.Callback, as follows:
    import copy

    import numpy as np
    import torch
    import pytorch_lightning as pl


    class MultiValCallback(pl.Callback):
        def __init__(self, datamodule, multi_val_test_num):
            super().__init__()
            self.dm = datamodule
            self.multi_val_test_num = multi_val_test_num

        def on_validation_epoch_end(self, trainer, pl_module):
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            # Deep-copy the model so the extra passes don't disturb training state
            model = copy.deepcopy(pl_module.model).to(device)
            model.eval()
            multi_val_acc = []
            for _ in range(self.multi_val_test_num):
                correct = 0
                total = 0
                with torch.no_grad():
                    for data in self.dm.val_dataloader():
                        data = data.to(device)
                        out = model(data)
                        pred = torch.argmax(out, dim=1)
                        correct += (pred == data.y).sum().item()
                        total += data.y.shape[0]
                # Accuracy for this validation pass
                multi_val_acc.append(correct / total)
            pl_module.log('val/multi_acc/mean', np.mean(multi_val_acc), sync_dist=True)
            pl_module.log('val/multi_acc/std', np.std(multi_val_acc), sync_dist=True)
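For reference, the aggregation step at the end can be checked in isolation. A minimal sketch with hypothetical per-run accuracies (note that np.std defaults to the population standard deviation, ddof=0):

```python
import numpy as np

# Hypothetical accuracies from 5 validation passes
multi_val_acc = [0.81, 0.79, 0.83, 0.80, 0.82]

mean_acc = float(np.mean(multi_val_acc))
std_acc = float(np.std(multi_val_acc))  # population std (ddof=0)

print(mean_acc)  # 0.81
```

Pass ddof=1 to np.std if you want the sample standard deviation instead.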
However, I’ve noticed that this approach is significantly slower than PyTorch Lightning’s native validation loop, and the code doesn’t seem to align with Lightning’s intended workflow.
I’m seeking advice on best practices for running multiple validation passes. Is there a more efficient way to do this?
Any suggestions or alternative approaches would be greatly appreciated.