Source code for lightning.pytorch.callbacks.gradient_accumulation_scheduler
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Gradient Accumulator
====================

Change gradient accumulation factor according to scheduling.
Trainer also calls ``optimizer.step()`` for the last indivisible step number.

"""
from typing import Any, Dict

import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _LIGHTNING_COLOSSALAI_AVAILABLE
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
class GradientAccumulationScheduler(Callback):
    r"""Change gradient accumulation factor according to scheduling.

    Args:
        scheduling: scheduling in format {epoch: accumulation_factor}

    Note:
        The argument ``scheduling`` is a dictionary. Each key represents an epoch and
        its value is the associated accumulation factor.

    Warning:
        Epochs are zero-indexed, i.e. if you want to change the accumulation factor after
        4 epochs, set ``Trainer(accumulate_grad_batches={4: factor})``
        or ``GradientAccumulationScheduler(scheduling={4: factor})``.
        For more info check the example below.

    Raises:
        TypeError:
            If ``scheduling`` is an empty ``dict``,
            or not all keys and values of ``scheduling`` are integers.
        IndexError:
            If ``minimal_epoch`` is less than 0.

    Example::

        >>> from lightning.pytorch import Trainer
        >>> from lightning.pytorch.callbacks import GradientAccumulationScheduler

        # from epoch 5, it starts accumulating every 2 batches. Here we have 4 instead of 5
        # because epoch (key) should be zero-indexed.
        >>> accumulator = GradientAccumulationScheduler(scheduling={4: 2})
        >>> trainer = Trainer(callbacks=[accumulator])

    """

    def __init__(self, scheduling: Dict[int, int]):
        super().__init__()

        if not scheduling:  # empty dict error
            raise TypeError("Empty dict cannot be interpreted correct")

        if any(not isinstance(key, int) or key < 0 for key in scheduling):
            raise MisconfigurationException(
                f"Epoch should be an int greater than or equal to 0. Got {list(scheduling.keys())}."
            )

        if any(not isinstance(value, int) or value < 1 for value in scheduling.values()):
            raise MisconfigurationException(
                f"Accumulation factor should be an int greater than 0. Got {list(scheduling.values())}."
            )

        minimal_epoch = min(scheduling.keys())
        if minimal_epoch < 0:
            raise IndexError(f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct")
        if minimal_epoch != 0:  # if user didn't define first epoch accumulation factor
            scheduling.update({0: 1})

        self.scheduling = scheduling
        self.epochs = sorted(scheduling.keys())

    def going_to_accumulate_grad_batches(self) -> bool:
        return any(v > 1 for v in self.scheduling.values())

    def get_accumulate_grad_batches(self, epoch: int) -> int:
        accumulate_grad_batches = 1
        for iter_epoch in reversed(self.epochs):
            if epoch >= iter_epoch:
                accumulate_grad_batches = self.scheduling[iter_epoch]
                break
        return accumulate_grad_batches
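    # NOTE (editorial illustration, not part of the library source): with a hypothetical
    # scheduling={0: 1, 4: 2, 8: 4}, ``self.epochs`` is [0, 4, 8] and
    # ``get_accumulate_grad_batches`` walks it in reverse, returning the factor of the
    # largest scheduled epoch that is <= the current epoch:
    # epochs 0-3 -> 1, epochs 4-7 -> 2, epochs 8 and later -> 4.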
    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Performs a configuration validation before training starts and raises errors for incompatible settings."""
        if not pl_module.automatic_optimization:
            raise RuntimeError(
                """Automatic gradient accumulation and the `GradientAccumulationScheduler` is not supported for
                manual optimization. Please remove the callback or switch to automatic optimization."""
            )

        overridden_optimizer_step = is_overridden("optimizer_step", pl_module)
        overridden_optimizer_zero_grad = is_overridden("optimizer_zero_grad", pl_module)
        going_to_accumulate_grad_batches = self.going_to_accumulate_grad_batches()
        has_overridden_optimization_functions = overridden_optimizer_step or overridden_optimizer_zero_grad
        if has_overridden_optimization_functions and going_to_accumulate_grad_batches:
            rank_zero_warn(
                "When using `Trainer(accumulate_grad_batches != 1)` and overriding"
                " `LightningModule.optimizer_{step,zero_grad}`, the hooks will not be called on every batch"
                " (rather, they are called on every optimization step)."
            )

        # local import to avoid circular import
        from lightning.pytorch.accelerators import IPUAccelerator
        from lightning.pytorch.strategies import DeepSpeedStrategy

        unsupported_accelerators = (IPUAccelerator,)
        unsupported_strategies = [DeepSpeedStrategy]
        if _LIGHTNING_COLOSSALAI_AVAILABLE:
            from lightning_colossalai import ColossalAIStrategy

            unsupported_strategies.append(ColossalAIStrategy)

        if isinstance(trainer.accelerator, unsupported_accelerators):
            raise RuntimeError(
                f"The `{type(trainer.accelerator).__name__}` does not support `accumulate_grad_batches` changing"
                " between epochs."
            )
        if isinstance(trainer.strategy, tuple(unsupported_strategies)):
            raise RuntimeError(
                f"The `{type(trainer.strategy).__name__}` does not support `accumulate_grad_batches` changing"
                " between epochs."
            )
        if trainer.accumulate_grad_batches != 1:
            raise ValueError(
                "You have set `accumulate_grad_batches` and are using the `GradientAccumulationScheduler`"
                " callback. Either remove `accumulate_grad_batches` from the Trainer or remove the callback."
            )
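The sketch below is a minimal usage illustration based on the code above; it is not part of the
library source. It builds a scheduler with a hypothetical schedule, shows how the accumulation
factor resolves per epoch via ``get_accumulate_grad_batches``, and attaches the callback to a
``Trainer``. Note that ``on_train_start`` raises a ``ValueError`` when ``accumulate_grad_batches``
is also set on the ``Trainer``, so the schedule should be supplied through the callback only::

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import GradientAccumulationScheduler

    # accumulate 8 batches for epochs 0-3, 4 batches for epochs 4-7, then stop accumulating
    accumulator = GradientAccumulationScheduler(scheduling={0: 8, 4: 4, 8: 1})

    assert accumulator.get_accumulate_grad_batches(epoch=0) == 8
    assert accumulator.get_accumulate_grad_batches(epoch=5) == 4
    assert accumulator.get_accumulate_grad_batches(epoch=10) == 1

    # do not also pass `accumulate_grad_batches` to the Trainer here
    trainer = Trainer(callbacks=[accumulator])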