# Source code for lightning.pytorch.callbacks.lr_finder

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
LearningRateFinder
==================

Callback that runs a learning-rate range test to help choose a good initial learning rate.
"""

from typing import Optional

from typing_extensions import override

import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.tuner.lr_finder import _lr_find, _LRFinder
from lightning.pytorch.utilities.exceptions import _TunerExitException
from lightning.pytorch.utilities.seed import isolate_rng


class LearningRateFinder(Callback):
    """The ``LearningRateFinder`` callback enables the user to do a range test of good initial learning rates, to
    reduce the amount of guesswork in picking a good starting learning rate.

    .. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.

    Args:
        min_lr: Minimum learning rate to investigate
        max_lr: Maximum learning rate to investigate
        num_training_steps: Number of learning rates to test
        mode: Search strategy to update learning rate after each batch (matched case-insensitively):

            - ``'exponential'`` (default): Increases the learning rate exponentially.
            - ``'linear'``: Increases the learning rate linearly.

        early_stop_threshold: Threshold for stopping the search. If the
            loss at any point is larger than early_stop_threshold*best_loss
            then the search is stopped. To disable, set to None.
        update_attr: Whether to update the learning rate attribute or not.
        attr_name: Name of the attribute which stores the learning rate. The names 'learning_rate' or 'lr' get
            automatically detected. Otherwise, set the name here.

    Attributes:
        lr_finder: The result of the most recent range test, or ``None`` if :meth:`lr_find` has not run yet.
        optimal_lr: The same object as :attr:`lr_finder`, kept under its historical name for existing callers.

    Example::

        # Customize LearningRateFinder callback to run at different epochs.
        # This feature is useful while fine-tuning models.
        from lightning.pytorch.callbacks import LearningRateFinder


        class FineTuneLearningRateFinder(LearningRateFinder):
            def __init__(self, milestones, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.milestones = milestones

            def on_fit_start(self, *args, **kwargs):
                return

            def on_train_epoch_start(self, trainer, pl_module):
                if trainer.current_epoch in self.milestones or trainer.current_epoch == 0:
                    self.lr_find(trainer, pl_module)


        trainer = Trainer(callbacks=[FineTuneLearningRateFinder(milestones=(5, 10))])
        trainer.fit(...)

    Raises:
        ValueError:
            If ``mode`` is not one of ``SUPPORTED_MODES``.
        MisconfigurationException:
            If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden,
            or if you are using more than one optimizer.

    """

    SUPPORTED_MODES = ("linear", "exponential")

    def __init__(
        self,
        min_lr: float = 1e-8,
        max_lr: float = 1,
        num_training_steps: int = 100,
        mode: str = "exponential",
        early_stop_threshold: Optional[float] = 4.0,
        update_attr: bool = True,
        attr_name: str = "",
    ) -> None:
        # Normalise so that e.g. "Linear" and "linear" are both accepted.
        mode = mode.lower()
        if mode not in self.SUPPORTED_MODES:
            raise ValueError(f"`mode` should be either of {self.SUPPORTED_MODES}")

        self._min_lr = min_lr
        self._max_lr = max_lr
        self._num_training_steps = num_training_steps
        self._mode = mode
        self._early_stop_threshold = early_stop_threshold
        self._update_attr = update_attr
        self._attr_name = attr_name

        # When True, ``lr_find`` raises ``_TunerExitException`` right after the search.
        # Nothing in this class sets it; an external caller may flip it.
        self._early_exit = False
        # Result of the most recent search. Both stay ``None`` until ``lr_find`` has run, so
        # reading ``optimal_lr`` early no longer raises ``AttributeError``.
        self.lr_finder: Optional[_LRFinder] = None
        self.optimal_lr: Optional[_LRFinder] = None

    def lr_find(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Run the learning-rate range test using the settings given at construction.

        The value returned by ``_lr_find`` is stored on both :attr:`lr_finder` and :attr:`optimal_lr`.

        Args:
            trainer: The trainer driving the search.
            pl_module: The module whose learning rate is searched.

        Raises:
            _TunerExitException: After the search completes, if ``_early_exit`` has been set.

        """
        # NOTE(review): ``isolate_rng`` presumably restores the global RNG state on exit, so the search
        # does not change the randomness seen by the training that follows. Confirm against its docs.
        with isolate_rng():
            # Populate the declared ``lr_finder`` slot as well as the ``optimal_lr`` name existing callers read.
            self.lr_finder = self.optimal_lr = _lr_find(
                trainer,
                pl_module,
                min_lr=self._min_lr,
                max_lr=self._max_lr,
                num_training=self._num_training_steps,
                mode=self._mode,
                early_stop_threshold=self._early_stop_threshold,
                update_attr=self._update_attr,
                attr_name=self._attr_name,
            )

        if self._early_exit:
            raise _TunerExitException()

    @override
    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Run the learning-rate search once, when fitting starts."""
        self.lr_find(trainer, pl_module)