Source code for pytorch_lightning.loops.batch.training_batch_loop
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Tuple, Union

from deprecate import void
from torch import Tensor

from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.optimization.manual_loop import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE
from pytorch_lightning.loops.optimization.manual_loop import ManualOptimization
from pytorch_lightning.loops.optimization.optimizer_loop import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE
from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop
from pytorch_lightning.loops.utilities import _get_active_optimizers
from pytorch_lightning.trainer.supporters import TensorRunningAccum

_OUTPUTS_TYPE = List[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]]
class TrainingBatchLoop(Loop[_OUTPUTS_TYPE]):
    """Runs over a single batch of data."""

    def __init__(self) -> None:
        super().__init__()
        self.accumulated_loss = TensorRunningAccum(window_length=20)
        self.running_loss = TensorRunningAccum(window_length=20)
        # the current split index when the batch gets split into chunks in truncated backprop through time
        self.split_idx: int = 0
        self.optimizer_loop = OptimizerLoop()
        self.manual_loop = ManualOptimization()

        self._outputs: _OUTPUTS_TYPE = []
        self._remaining_splits: List[Tuple[int, Any]] = []

    @property
    def done(self) -> bool:
        """Returns if all batch splits have been processed already."""
        return len(self._remaining_splits) == 0
    def reset(self) -> None:
        """Resets the loop state."""
        self._outputs = []
    def on_run_start(self, batch: Any, batch_idx: int) -> None:  # type: ignore[override]
        """Splits the data into tbptt splits.

        Args:
            batch: the current batch to run the trainstep on
            batch_idx: the index of the current batch
        """
        void(batch_idx)
        self._remaining_splits = list(enumerate(self._tbptt_split_batch(batch)))
    def advance(self, batch: Any, batch_idx: int) -> None:  # type: ignore[override]
        """Runs the train step together with optimization (if necessary) on the current batch split.

        Args:
            batch: the current batch to run the training on (this is not the split!)
            batch_idx: the index of the current batch
        """
        void(batch)
        self.split_idx, split_batch = self._remaining_splits.pop(0)
        self.trainer._logger_connector.on_train_split_start(self.split_idx)

        outputs: Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]] = None  # for mypy
        # choose which loop will run the optimization
        if self.trainer.lightning_module.automatic_optimization:
            optimizers = _get_active_optimizers(
                self.trainer.optimizers, self.trainer.optimizer_frequencies, batch_idx
            )
            outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
        else:
            outputs = self.manual_loop.run(split_batch, batch_idx)
        if outputs:
            # automatic: can be empty if all optimizers skip their batches
            # manual: #9052 added support for raising `StopIteration` in the `training_step`. If that happens,
            # then `advance` doesn't finish and an empty dict is returned
            self._outputs.append(outputs)
    def on_run_end(self) -> _OUTPUTS_TYPE:
        self.optimizer_loop._hiddens = None
        # this is not necessary as the manual loop runs for only 1 iteration, but just in case
        self.manual_loop._hiddens = None
        output, self._outputs = self._outputs, []  # free memory
        self._remaining_splits = []
        return output
    def _tbptt_split_batch(self, batch: Any) -> List[Any]:
        """Splits a single batch into a list of sequence steps for tbptt.

        Args:
            batch: the current batch to split
        """
        tbptt_steps = self.trainer.lightning_module.truncated_bptt_steps
        if tbptt_steps == 0:
            return [batch]

        splits = self.trainer._call_lightning_module_hook("tbptt_split_batch", batch, tbptt_steps)
        return splits

    def _update_running_loss(self, current_loss: Tensor) -> None:
        """Updates the running loss value with the current value."""
        if self.trainer.lightning_module.automatic_optimization:
            # track total loss for logging (avoid mem leaks)
            self.accumulated_loss.append(current_loss)

        accumulated_loss = self.accumulated_loss.mean()
        if accumulated_loss is not None:
            # calculate running loss for display
            self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches)

        # reset for next set of accumulated grads
        self.accumulated_loss.reset()
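
For context, the splitting done in ``on_run_start`` via ``_tbptt_split_batch`` is driven by the ``truncated_bptt_steps`` attribute of the LightningModule. A minimal sketch of a module that opts into truncated backpropagation through time (the module name, layer sizes, and tensor shapes are illustrative assumptions, not part of this file):

import torch
import pytorch_lightning as pl


class SequenceModel(pl.LightningModule):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.rnn = torch.nn.LSTM(input_size=32, hidden_size=64, batch_first=True)
        self.head = torch.nn.Linear(64, 1)
        # asks the batch loop to split each batch into chunks of 10 time steps
        self.truncated_bptt_steps = 10

    def training_step(self, batch, batch_idx, hiddens):
        # `batch` here is one tbptt split produced by `_tbptt_split_batch`
        x, y = batch
        out, hiddens = self.rnn(x, hiddens)
        loss = torch.nn.functional.mse_loss(self.head(out), y)
        # returning `hiddens` lets the loop carry the recurrent state to the next split
        return {"loss": loss, "hiddens": hiddens}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)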
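
``advance`` dispatches to the ``OptimizerLoop`` when the module uses automatic optimization and to ``ManualOptimization`` otherwise. A minimal sketch of a module that takes the manual path (the model internals are assumptions for illustration):

import torch
import pytorch_lightning as pl


class ManualModel(pl.LightningModule):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)
        # routes this module through `ManualOptimization` instead of `OptimizerLoop`
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)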
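
Finally, ``_update_running_loss`` scales the mean of the losses gathered in ``accumulated_loss`` by ``accumulate_grad_batches``, so the running loss tracked for display reflects the effective loss of one optimizer step rather than of one micro-batch, smoothed over the 20-value window of ``TensorRunningAccum``. A tiny sketch of enabling gradient accumulation (the value 4 is an arbitrary example):

from pytorch_lightning import Trainer

# accumulate gradients over 4 batches before each optimizer step;
# the running loss appended above is then (mean of the 4 losses) * 4
trainer = Trainer(accumulate_grad_batches=4)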