Stopping an Epoch Early
You can stop and skip the rest of the current epoch early by overriding the on_train_batch_start() hook to return -1 when some condition is met.
If you do this repeatedly, for every epoch you had originally requested, then this will stop your entire training.
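The behaviour described above can be illustrated with a toy, framework-free loop. This is not Lightning's training loop. `run_training` is hypothetical, and so is its `(epoch, batch_idx)` hook signature. The plain function passed in stands in for overriding on_train_batch_start() on a LightningModule. The sketch only models the documented contract: returning -1 skips the rest of the current epoch.

```python
from typing import Callable, List, Optional


def run_training(
    max_epochs: int,
    batches_per_epoch: int,
    on_train_batch_start: Callable[[int, int], Optional[int]],
) -> List[int]:
    """Toy loop (hypothetical, not Lightning's): returns how many batches
    were actually trained in each epoch."""
    trained_per_epoch = []
    for epoch in range(max_epochs):
        trained = 0
        for batch_idx in range(batches_per_epoch):
            # Mirrors the documented contract: -1 skips the rest of this epoch.
            if on_train_batch_start(epoch, batch_idx) == -1:
                break
            trained += 1  # a real loop would run the training step here
        trained_per_epoch.append(trained)
    return trained_per_epoch


# Skip everything from the 4th batch onward, but only in epoch 1.
print(run_training(3, 5, lambda epoch, i: -1 if epoch == 1 and i >= 3 else None))
# Returning -1 on the first batch of every epoch trains nothing at all.
print(run_training(3, 5, lambda epoch, i: -1))
```

The second call shows why returning -1 in every epoch effectively ends the whole run: no batch is ever trained.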
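Early Stopping based on a metric using the EarlyStopping Callback
The EarlyStopping callback can be used to monitor a metric and stop training when no improvement is observed.
To enable it:
- Import the EarlyStopping callback.
- Log the metric you want to monitor using self.log() in your LightningModule.
- Init the callback, and set monitor to the logged metric of your choice.
- Set mode according to whether the monitored metric should decrease ("min", e.g. a loss) or increase ("max", e.g. an accuracy).
- Pass the EarlyStopping callback to the callbacks argument of the Trainer.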
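    from lightning.pytorch import LightningModule, Trainer
    from lightning.pytorch.callbacks.early_stopping import EarlyStopping


    class LitModel(LightningModule):
        def validation_step(self, batch, batch_idx):
            loss = ...  # compute the validation loss for this batch
            # Log the metric under the name that EarlyStopping will monitor
            self.log("val_loss", loss)


    model = LitModel()
    # Stop training when "val_loss" stops decreasing
    trainer = Trainer(callbacks=[EarlyStopping(monitor="val_loss", mode="min")])
    trainer.fit(model)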
You can customize the callback's behaviour by changing its parameters.
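    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks.early_stopping import EarlyStopping

    # Stop once "val_accuracy" has not increased for 3 consecutive checks
    early_stop_callback = EarlyStopping(
        monitor="val_accuracy", min_delta=0.00, patience=3, verbose=False, mode="max"
    )
    trainer = Trainer(callbacks=[early_stop_callback])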
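To show how monitor values, mode, min_delta and patience interact, here is a rough, framework-free sketch. It assumes the usual rule. A value counts as an improvement only if it beats the best value so far by more than min_delta. Training stops once patience consecutive checks fail to improve. The `PatienceTracker` class is hypothetical and simplified. It is not Lightning's implementation.

```python
import math


class PatienceTracker:
    """Hypothetical, simplified stand-in for the EarlyStopping bookkeeping."""

    def __init__(self, mode: str = "min", min_delta: float = 0.0, patience: int = 3):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.mode = mode
        self.min_delta = abs(min_delta)
        self.patience = patience
        # Start from the worst possible value so the first check always improves.
        self.best = math.inf if mode == "min" else -math.inf
        self.wait = 0  # consecutive checks without improvement

    def _improved(self, value: float) -> bool:
        if self.mode == "min":
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def step(self, value: float) -> bool:
        """Record one check of the monitored metric; return True to stop."""
        if self._improved(value):
            self.best = value
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


# mode="max" as in the val_accuracy example: accuracy plateaus after check 2.
tracker = PatienceTracker(mode="max", min_delta=0.0, patience=3)
for check, acc in enumerate([0.70, 0.80, 0.80, 0.79, 0.80], start=1):
    if tracker.step(acc):
        print(f"stop after check {check}")  # prints "stop after check 5"
        break
```

With min_delta=0.00, merely matching the best accuracy (0.80 again) does not count as an improvement. A larger min_delta makes the tracker stricter about what counts as progress.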
Additional parameters that stop training at extreme points:
stopping_threshold: Stops training immediately once the monitored quantity reaches this threshold. It is useful when we know that going beyond a certain optimal value does not further benefit us.
divergence_threshold: Stops training as soon as the monitored quantity becomes worse than this threshold. A value this bad suggests the model cannot recover, so it is better to stop early and rerun with different initial conditions.
check_finite: When turned on, it stops training if the monitored metric becomes NaN or infinite.
check_on_train_epoch_end: When turned on, the metric is checked at the end of each training epoch instead of after validation. Use this only when the monitored metric is logged at epoch level from training-specific hooks.
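The extra stopping criteria above can be pictured as checks that run before the usual patience logic. The sketch below uses a hypothetical helper, `threshold_stop_reason`. It assumes strict comparisons and the order finite check, then stopping_threshold, then divergence_threshold. Lightning's exact ordering and boundary handling may differ.

```python
import math
from typing import Optional


def threshold_stop_reason(
    value: float,
    mode: str = "min",
    stopping_threshold: Optional[float] = None,
    divergence_threshold: Optional[float] = None,
    check_finite: bool = True,
) -> Optional[str]:
    """Hypothetical helper: return why training should stop, or None to continue."""
    # "better" means lower in min mode and higher in max mode (strict comparison).
    better = (lambda a, b: a < b) if mode == "min" else (lambda a, b: a > b)
    if check_finite and not math.isfinite(value):
        return "non-finite"  # NaN or +/-inf
    if stopping_threshold is not None and better(value, stopping_threshold):
        return "reached stopping_threshold"  # good enough, no need to continue
    if divergence_threshold is not None and better(divergence_threshold, value):
        return "diverged"  # worse than the divergence threshold
    return None


# Monitoring a loss (mode="min"): stop below 0.05, give up above 10.0.
for loss in [2.3, 0.04, 12.0, float("nan")]:
    print(loss, threshold_stop_reason(loss, "min", 0.05, 10.0))
```

In max mode the same helper flips both comparisons. For an accuracy, values above stopping_threshold end training early and values below divergence_threshold count as divergence.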
In case you need early stopping in a different part of training, subclass EarlyStopping and change where it is called:
    from lightning.pytorch.callbacks.early_stopping import EarlyStopping


    class MyEarlyStopping(EarlyStopping):
        def on_validation_end(self, trainer, pl_module):
            # override this to disable early stopping at the end of the validation loop
            pass

        def on_train_end(self, trainer, pl_module):
            # instead, run the early stopping check at the end of the training loop
            self._run_early_stopping_check(trainer)
The EarlyStopping callback runs at the end of every validation epoch by default. However, the frequency of validation can be modified by setting various parameters in the Trainer, for example check_val_every_n_epoch and val_check_interval.
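Note that the patience parameter counts the number of validation checks with no improvement, not the number of training epochs. Therefore, with parameters check_val_every_n_epoch=10 and patience=3, the trainer will perform at least 40 training epochs before being stopped. The first check, at epoch 10, sets the baseline, and three further checks without improvement (at epochs 20, 30 and 40) are needed to exhaust the patience.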
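The epoch arithmetic above can be checked with a small simulation. It makes three assumptions. Validation runs only at the end of every check_val_every_n_epoch-th epoch. The first check always counts as an improvement. The metric never improves afterwards. `first_possible_stop_epoch` is a hypothetical helper, not part of Lightning.

```python
def first_possible_stop_epoch(check_val_every_n_epoch: int, patience: int) -> int:
    """Earliest epoch at which patience-based early stopping can trigger
    (hypothetical helper; assumes no improvement after the first check)."""
    wait = 0
    baseline_set = False
    epoch = 0
    while True:
        epoch += 1
        if epoch % check_val_every_n_epoch != 0:
            continue  # no validation this epoch, so no patience is consumed
        if not baseline_set:
            baseline_set = True  # the first check always sets the best score
            continue
        wait += 1  # a validation check without improvement
        if wait >= patience:
            return epoch


print(first_possible_stop_epoch(check_val_every_n_epoch=10, patience=3))  # 40
print(first_possible_stop_epoch(check_val_every_n_epoch=1, patience=3))   # 4
```

For patience of at least 1 this works out to check_val_every_n_epoch * (patience + 1) epochs. Patience therefore scales with how often validation runs, not with the epoch count.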