Auto-saving model weights

The way to go is not to change the monitor argument in your callback but, as @ydcjeff suggested, to use checkpoint_on in your validation_step/validation_epoch_end. Your trainer config would then look like this:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# DEFAULTS used by the Trainer
checkpoint_callback = ModelCheckpoint(
    save_top_k=1,
    verbose=True,
    mode='max',
)

trainer = Trainer(checkpoint_callback=checkpoint_callback)

Your validation phase could then look like this:

def validation_step(self, batch, batch_idx):
    # assumes `import pytorch_lightning as pl` at the top of the module
    acc = self.calculate_acc(batch)
    # for early stopping you could also pass early_stop_on here
    result = pl.EvalResult(checkpoint_on=acc)
    result.log('val_acc', acc)
    return result

This version needs no validation_epoch_end: by default the result now averages your values for checkpointing (see here for details). If you really want to sum them instead, you could do it like this:

def validation_step(self, batch, batch_idx):
    acc = self.calculate_acc(batch)
    result = pl.EvalResult()
    result.log('val_acc', acc)
    return result

def validation_epoch_end(self, val_step_output):
    # sum the logged val_acc values over the epoch (needs `import torch`)
    aggregated_val_acc = torch.sum(val_step_output.val_acc)
    end_result = pl.EvalResult(checkpoint_on=aggregated_val_acc)
    return end_result
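
If you are wondering whether summing instead of averaging changes which checkpoint gets kept: as long as every validation epoch has the same number of batches, the sum is just the mean times a constant, so the ranking of epochs stays the same. Here is a rough sketch in plain Python to illustrate that. It does not use any Lightning API; the accuracy numbers are made up, and `best_epoch` is a hypothetical helper that only mimics what save_top_k=1 with mode='max' is meant to do.

```python
from statistics import mean

# Hypothetical per-batch validation accuracies for three epochs (made-up numbers).
per_batch_acc = {
    0: [0.50, 0.60, 0.55],
    1: [0.70, 0.80, 0.75],
    2: [0.65, 0.85, 0.60],
}


def best_epoch(scores, mode="max"):
    """Return the epoch with the best monitored value (mimics save_top_k=1)."""
    pick = max if mode == "max" else min
    return pick(scores, key=scores.get)


# Aggregate each epoch's per-batch values in the two ways discussed above.
mean_scores = {epoch: mean(accs) for epoch, accs in per_batch_acc.items()}
sum_scores = {epoch: sum(accs) for epoch, accs in per_batch_acc.items()}

best_by_mean = best_epoch(mean_scores)
best_by_sum = best_epoch(sum_scores)
print(best_by_mean, best_by_sum)
```

Both aggregations pick the same epoch here. If the number of validation batches can differ between epochs, the two can disagree, so the mean is usually the safer value to checkpoint on.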

Hope that helps :slight_smile:
