
LearningRateFinder

class pytorch_lightning.callbacks.LearningRateFinder(min_lr=1e-08, max_lr=1, num_training_steps=100, mode='exponential', early_stop_threshold=4.0, update_attr=False)

Bases: pytorch_lightning.callbacks.callback.Callback

The LearningRateFinder callback runs a learning rate range test: it trains for a small number of steps while increasing the learning rate from min_lr to max_lr and records the loss at each step. This reduces the guesswork in picking a good initial learning rate.

Parameters:
  • min_lr (float) – Minimum learning rate to investigate.

  • max_lr (float) – Maximum learning rate to investigate.

  • num_training_steps (int) – Number of training steps to run; one learning rate is tested per step.

  • mode (str) –

    Search strategy for updating the learning rate after each batch:

    • 'exponential' (default): Increases the learning rate exponentially.

    • 'linear': Increases the learning rate linearly.

  • early_stop_threshold (Optional[float]) – Threshold for stopping the search. If the loss at any point is larger than early_stop_threshold*best_loss then the search is stopped. To disable, set to None.

  • update_attr (bool) – Whether to set the model's learning rate attribute to the suggested learning rate once the search finishes.
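To make the parameters concrete, here is a plain-Python sketch of the search they describe. It does not call Lightning: lr_schedule and range_test are hypothetical names, the interpolation formulas for the two modes are an assumption (Lightning's internal schedulers may index steps slightly differently), and the caller-supplied loss_at function stands in for one training step. Lightning also smooths the recorded loss before comparing it with the best loss; this sketch uses the raw value.

```python
import math
from typing import Callable, List, Optional, Tuple


def lr_schedule(min_lr: float, max_lr: float, num_training_steps: int,
                mode: str = "exponential") -> List[float]:
    """Hypothetical helper: learning rates tried at each step of the range test."""
    denom = max(num_training_steps - 1, 1)  # avoid division by zero for a single step
    lrs = []
    for step in range(num_training_steps):
        r = step / denom  # progress through the search, 0.0 to 1.0 inclusive
        if mode == "exponential":
            # Geometric interpolation (requires min_lr > 0): constant ratio between steps.
            lrs.append(min_lr * (max_lr / min_lr) ** r)
        elif mode == "linear":
            # Arithmetic interpolation: constant difference between steps.
            lrs.append(min_lr + r * (max_lr - min_lr))
        else:
            raise ValueError(f"unknown mode: {mode!r}")
    return lrs


def range_test(loss_at: Callable[[float], float], min_lr: float = 1e-8,
               max_lr: float = 1.0, num_training_steps: int = 100,
               mode: str = "exponential",
               early_stop_threshold: Optional[float] = 4.0) -> Tuple[List[float], List[float]]:
    """Hypothetical helper: record (lr, loss) pairs until the early-stop rule fires."""
    lrs, losses = [], []
    best_loss = math.inf
    for lr in lr_schedule(min_lr, max_lr, num_training_steps, mode):
        loss = loss_at(lr)  # stands in for one training step at this learning rate
        lrs.append(lr)
        losses.append(loss)
        # Stop once the loss exceeds early_stop_threshold * best_loss; None disables this.
        if early_stop_threshold is not None and loss > early_stop_threshold * best_loss:
            break
        best_loss = min(best_loss, loss)
    return lrs, losses


if __name__ == "__main__":
    def toy_loss(lr: float) -> float:
        # Illustrative loss curve with its minimum at lr = 1e-3.
        return abs(math.log10(lr) + 3) + 0.1

    lrs, losses = range_test(toy_loss)
    best = lrs[losses.index(min(losses))]
    print(f"tested {len(lrs)} learning rates; lowest loss at lr={best:.2e}")
```

With the defaults, the exponential mode spends as many steps between 1e-8 and 1e-7 as between 0.1 and 1, which is why it is usually preferred when min_lr and max_lr span several orders of magnitude.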

Example:

# Customize the LearningRateFinder callback to run at different epochs.
# This is useful when fine-tuning models.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateFinder


class FineTuneLearningRateFinder(LearningRateFinder):
    def __init__(self, milestones, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Epochs (in addition to epoch 0) at which to rerun the search.
        self.milestones = milestones

    def on_fit_start(self, *args, **kwargs):
        # Skip the default search at the start of fit; it runs per epoch instead.
        return

    def on_train_epoch_start(self, trainer, pl_module):
        if trainer.current_epoch in self.milestones or trainer.current_epoch == 0:
            self.lr_find(trainer, pl_module)


trainer = Trainer(callbacks=[FineTuneLearningRateFinder(milestones=(5, 10))])
trainer.fit(...)

Raises:

MisconfigurationException – If the model or model.hparams does not define a learning rate attribute (learning_rate or lr) that can be overridden when auto_lr_find=True, or if more than one optimizer is used.
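The two failure conditions above can be illustrated with a plain-Python sketch. find_lr_attr and check_single_optimizer are hypothetical helpers, not Lightning APIs; they raise ValueError in place of MisconfigurationException, and the candidate names learning_rate and lr follow the wording of the exception description. The lookup order is an assumption.

```python
from typing import Any, Sequence, Tuple


def find_lr_attr(model: Any, candidates: Tuple[str, ...] = ("learning_rate", "lr")) -> str:
    """Hypothetical helper: name of the learning rate attribute the finder could override."""
    hparams = getattr(model, "hparams", None)
    for name in candidates:
        # Attribute set directly on the model, e.g. self.lr = 1e-3
        if hasattr(model, name):
            return name
        # Entry in hparams, which may be a dict-like mapping or a plain object
        if isinstance(hparams, dict) and name in hparams:
            return name
        if hparams is not None and hasattr(hparams, name):
            return name
    raise ValueError(
        f"model or model.hparams must define one of {candidates} "
        "so the learning rate finder can override it"
    )


def check_single_optimizer(optimizers: Sequence[Any]) -> None:
    """Hypothetical helper: reject configurations with more than one optimizer."""
    if len(optimizers) > 1:
        raise ValueError(f"expected a single optimizer, got {len(optimizers)}")
```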

on_fit_start(trainer, pl_module)

Called when fit begins. By default, LearningRateFinder runs the learning rate search in this hook.

Return type:

None