Source code for pytorch_lightning.callbacks.gradient_accumulation_scheduler
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Gradient Accumulator
====================

Change gradient accumulation factor according to scheduling.
Trainer also calls ``optimizer.step()`` for the last indivisible step number.

"""

from typing import Any, Dict

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
class GradientAccumulationScheduler(Callback):
    r"""
    Change gradient accumulation factor according to scheduling.

    Args:
        scheduling: scheduling in format {epoch: accumulation_factor}

    Note:
        The argument ``scheduling`` is a dictionary. Each key represents an epoch and
        its associated accumulation factor value.

    Warning:
        Epochs are zero-indexed, i.e. if you want to change the accumulation
        factor after 4 epochs, set ``Trainer(accumulate_grad_batches={4: factor})``
        or ``GradientAccumulationScheduler(scheduling={4: factor})``.
        For more info check the example below.

    Raises:
        TypeError:
            If ``scheduling`` is an empty ``dict``,
            or not all keys and values of ``scheduling`` are integers.
        IndexError:
            If ``minimal_epoch`` is less than 0.

    Example::

        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import GradientAccumulationScheduler

        # from epoch 5, it starts accumulating every 2 batches. Here we have 4 instead of 5
        # because epoch (key) should be zero-indexed.
        >>> accumulator = GradientAccumulationScheduler(scheduling={4: 2})
        >>> trainer = Trainer(callbacks=[accumulator])

        # alternatively, pass the scheduling dict directly to the Trainer
        >>> trainer = Trainer(accumulate_grad_batches={4: 2})
    """

    def __init__(self, scheduling: Dict[int, int]):
        super().__init__()

        if not scheduling:  # empty dict error
            raise TypeError("Empty dict cannot be interpreted correctly.")

        if any(not isinstance(key, int) or key < 0 for key in scheduling):
            raise MisconfigurationException(
                f"Epoch should be an int greater than or equal to 0. Got {list(scheduling.keys())}."
            )

        if any(not isinstance(value, int) or value < 1 for value in scheduling.values()):
            raise MisconfigurationException(
                f"Accumulation factor should be an int greater than 0. Got {list(scheduling.values())}."
            )

        minimal_epoch = min(scheduling.keys())
        if minimal_epoch < 0:
            raise IndexError(f"Epochs are zero-indexed, epoch {minimal_epoch} cannot be interpreted correctly.")

        if minimal_epoch != 0:  # if user didn't define first epoch accumulation factor
            scheduling.update({0: 1})

        self.scheduling = scheduling
        self.epochs = sorted(scheduling.keys())

    def going_to_accumulate_grad_batches(self) -> bool:
        return any(v > 1 for v in self.scheduling.values())

    def get_accumulate_grad_batches(self, epoch: int) -> int:
        accumulate_grad_batches = 1
        for iter_epoch in reversed(self.epochs):
            if epoch >= iter_epoch:
                accumulate_grad_batches = self.scheduling[iter_epoch]
                break
        return accumulate_grad_batches
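For illustration only (not part of the module source): a minimal sketch showing how a schedule resolves to a per-epoch accumulation factor. It assumes the ``pytorch_lightning.callbacks`` import path used in the class docstring example; the printed values follow from the ``get_accumulate_grad_batches`` logic above.

from pytorch_lightning.callbacks import GradientAccumulationScheduler

# accumulate 1 batch per step until epoch 4, 2 batches from epoch 4,
# and 8 batches from epoch 10 onwards (epochs are zero-indexed)
accumulator = GradientAccumulationScheduler(scheduling={4: 2, 10: 8})

for epoch in (0, 3, 4, 9, 10, 25):
    print(epoch, accumulator.get_accumulate_grad_batches(epoch))
# 0 1    (a factor of 1 is filled in automatically for epoch 0)
# 3 1
# 4 2
# 9 2
# 10 8
# 25 8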