Hey there,
PL Bolts has a nice callback called SSLOnlineEvaluator that evaluates your model by stacking an MLP on top of the learned features and training it, to assess whether the features are meaningful, as done in SimCLR (lines 266-270).
My question: is it possible to train the MLP for more than one epoch?
An unclean solution could be something like:
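```python
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
    # run the extra MLP training only once per epoch, on the first batch
    if batch_idx == 0:
        train_loader = trainer.train_dataloader
        epochs = 5
        for _ in range(epochs):
            # use a different name so the hook's own `batch` argument is not shadowed
            for mlp_batch in train_loader:
                ...  # forward + backward of the MLP on mlp_batch
```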
But I don't think this is how the callback is meant to be used.
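To make the goal concrete, here is a framework-free sketch of the behaviour I'm after: fit the probe for several epochs over feature vectors that have already been extracted (and detached) from the encoder, instead of re-running the whole dataloader inside a batch hook. It is only an illustration under assumptions: `LinearProbe` and `train_probe_multi_epoch` are hypothetical names, and a logistic-regression head stands in for the MLP so the example has no dependencies.

```python
import math
import random
from typing import List, Sequence, Tuple

Batch = List[Tuple[List[float], int]]  # (frozen feature vector, binary label) pairs


class LinearProbe:
    """Hypothetical stand-in for the MLP: logistic regression trained with plain SGD."""

    def __init__(self, dim: int, lr: float = 0.1) -> None:
        self.w = [0.0] * dim
        self.b = 0.0
        self.lr = lr

    def predict_proba(self, x: Sequence[float]) -> float:
        z = sum(wi * xi for wi, xi in zip(self.w, x)) + self.b
        return 1.0 / (1.0 + math.exp(-z))

    def step(self, batch: Batch) -> float:
        """One SGD step on a mini-batch; returns the mean log-loss before the update."""
        grad_w = [0.0] * len(self.w)
        grad_b = 0.0
        loss = 0.0
        for x, y in batch:
            p = self.predict_proba(x)
            loss -= y * math.log(p + 1e-12) + (1 - y) * math.log(1 - p + 1e-12)
            err = p - y  # derivative of the log-loss w.r.t. the logit
            for i, xi in enumerate(x):
                grad_w[i] += err * xi
            grad_b += err
        n = len(batch)
        self.w = [wi - self.lr * gi / n for wi, gi in zip(self.w, grad_w)]
        self.b -= self.lr * grad_b / n
        return loss / n


def train_probe_multi_epoch(probe: LinearProbe, cached_batches: List[Batch], epochs: int) -> List[float]:
    """Train the probe for `epochs` passes over cached (frozen) features."""
    history = []
    for _ in range(epochs):
        order = list(cached_batches)  # shuffle a copy, leave the cache untouched
        random.shuffle(order)
        epoch_loss = sum(probe.step(b) for b in order) / len(order)
        history.append(epoch_loss)
    return history


# Toy usage: 1-D "features" whose sign determines the label.
random.seed(0)
data: Batch = []
for _ in range(80):
    x = random.uniform(0.1, 1.0) * random.choice([-1, 1])
    data.append(([x], int(x > 0)))
batches = [data[i:i + 8] for i in range(0, len(data), 8)]

probe = LinearProbe(dim=1, lr=0.5)
history = train_probe_multi_epoch(probe, batches, epochs=30)
accuracy = sum(int(probe.predict_proba(x) > 0.5) == y for x, y in data) / len(data)
```

In a real callback the features would presumably come from the encoder with gradients disabled, so the probe's extra epochs never touch the backbone.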
Best,
Temi