# Source code for lightning.pytorch.callbacks.lr_finder
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
LearningRateFinder
==================
Finds a good initial learning rate by running a learning-rate range test.
"""
from typing import Optional
import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.tuner.lr_finder import _lr_find, _LRFinder
from lightning.pytorch.utilities.exceptions import _TunerExitException
from lightning.pytorch.utilities.seed import isolate_rng
class LearningRateFinder(Callback):
    """The ``LearningRateFinder`` callback enables the user to do a range test of good initial learning rates, to
    reduce the amount of guesswork in picking a good starting learning rate.

    .. warning::  This is an :ref:`experimental <versioning:Experimental API>` feature.

    Args:
        min_lr: Minimum learning rate to investigate
        max_lr: Maximum learning rate to investigate
        num_training_steps: Number of learning rates to test
        mode: Search strategy to update learning rate after each batch:

            - ``'exponential'`` (default): Increases the learning rate exponentially.
            - ``'linear'``: Increases the learning rate linearly.

            The value is lower-cased before validation, so matching is case-insensitive.
        early_stop_threshold: Threshold for stopping the search. If the
            loss at any point is larger than early_stop_threshold*best_loss
            then the search is stopped. To disable, set to None.
        update_attr: Whether to update the learning rate attribute or not.
        attr_name: Name of the attribute which stores the learning rate. The names 'learning_rate' or 'lr' get
            automatically detected. Otherwise, set the name here.

    Attributes:
        optimal_lr: The value returned by the internal ``_lr_find`` call during the most recent
            :meth:`lr_find` run, or ``None`` if no search has run yet.

    Example::

        # Customize LearningRateFinder callback to run at different epochs.
        # This feature is useful while fine-tuning models.
        from lightning.pytorch.callbacks import LearningRateFinder


        class FineTuneLearningRateFinder(LearningRateFinder):
            def __init__(self, milestones, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.milestones = milestones

            def on_fit_start(self, *args, **kwargs):
                return

            def on_train_epoch_start(self, trainer, pl_module):
                if trainer.current_epoch in self.milestones or trainer.current_epoch == 0:
                    self.lr_find(trainer, pl_module)


        trainer = Trainer(callbacks=[FineTuneLearningRateFinder(milestones=(5, 10))])
        trainer.fit(...)

    Raises:
        ValueError:
            If ``mode`` is not one of ``SUPPORTED_MODES`` (compared after lower-casing).
        MisconfigurationException:
            If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden, or if you are using more than
            one optimizer.
    """

    SUPPORTED_MODES = ("linear", "exponential")

    def __init__(
        self,
        min_lr: float = 1e-8,
        max_lr: float = 1,
        num_training_steps: int = 100,
        mode: str = "exponential",
        early_stop_threshold: Optional[float] = 4.0,
        update_attr: bool = True,
        attr_name: str = "",
    ) -> None:
        # Normalise first so e.g. "Linear" and "EXPONENTIAL" are accepted.
        mode = mode.lower()
        if mode not in self.SUPPORTED_MODES:
            raise ValueError(f"`mode` should be either of {self.SUPPORTED_MODES}")

        self._min_lr = min_lr
        self._max_lr = max_lr
        self._num_training_steps = num_training_steps
        self._mode = mode
        self._early_stop_threshold = early_stop_threshold
        self._update_attr = update_attr
        self._attr_name = attr_name

        # Checked at the end of ``lr_find``. Nothing in this class sets it to True; presumably it is
        # toggled externally (e.g. by the tuner) to stop after the search. TODO confirm.
        self._early_exit = False
        # Not assigned anywhere else in this class; the search result is stored in ``optimal_lr``.
        self.lr_finder: Optional[_LRFinder] = None
        # Initialised here so reading it before ``lr_find`` runs yields None instead of AttributeError.
        self.optimal_lr: Optional[_LRFinder] = None

    def lr_find(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Run the learning-rate range test with this callback's configuration.

        The value returned by ``_lr_find`` is stored in ``self.optimal_lr``.

        Args:
            trainer: The trainer driving the search.
            pl_module: The module whose learning rate is being searched.

        Raises:
            _TunerExitException: If ``self._early_exit`` is set when the search finishes.
        """
        # Wrapped in ``isolate_rng`` -- presumably so RNG state consumed by the search does not
        # leak into the subsequent training run. TODO confirm against ``isolate_rng`` docs.
        with isolate_rng():
            self.optimal_lr = _lr_find(
                trainer,
                pl_module,
                min_lr=self._min_lr,
                max_lr=self._max_lr,
                num_training=self._num_training_steps,
                mode=self._mode,
                early_stop_threshold=self._early_stop_threshold,
                update_attr=self._update_attr,
                attr_name=self._attr_name,
            )

        if self._early_exit:
            raise _TunerExitException()

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Run the learning-rate search once, when fitting starts."""
        self.lr_find(trainer, pl_module)