# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, Optional, Type, TypeVar, Union

from deprecate import void
from torchmetrics import Metric

import pytorch_lightning as pl
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.progress import BaseProgress
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training

T = TypeVar("T")  # the output type of `run`


class Loop(ABC, Generic[T]):
    """Basic Loops interface.

    All classes derived from this must implement the following properties and methods:

        * :attr:`done` (property): Condition to break the loop
        * :attr:`reset` (method): Resets the internal state between multiple calls of :attr:`run`
        * :attr:`advance` (method): Implements one step of the loop

    This class implements the following loop structure:

    .. code-block:: python

        on_run_start()

        while not done:
            on_advance_start()
            advance()
            on_advance_end()

        on_run_end()
    """

    def __init__(self) -> None:
        # set to ``True`` when this loop's state was reloaded from a checkpoint
        self._restarting = False
        # back-reference to the owning Trainer; attached externally via the ``trainer`` setter
        self._trainer: Optional["pl.Trainer"] = None

    @property
    def trainer(self) -> "pl.Trainer":
        """The ``Trainer`` this loop is attached to.

        Raises:
            RuntimeError: If no trainer has been attached yet.
        """
        if self._trainer is None:
            raise RuntimeError("The loop is not attached to a Trainer.")
        return self._trainer

    @trainer.setter
    def trainer(self, trainer: "pl.Trainer") -> None:
        """Connects this loop's trainer and its children."""
        self._trainer = trainer
        # propagate the reference to every sub-loop stored as an attribute
        for v in self.__dict__.values():
            if isinstance(v, Loop):
                v.trainer = trainer

    @property
    def restarting(self) -> bool:
        """Whether the state of this loop was reloaded and it needs to restart."""
        return self._restarting

    @restarting.setter
    def restarting(self, restarting: bool) -> None:
        """Connects this loop's restarting value and its children."""
        self._restarting = restarting
        # propagate the flag to every sub-loop stored as an attribute
        for loop in vars(self).values():
            if isinstance(loop, Loop):
                loop.restarting = restarting

    @property
    @abstractmethod
    def done(self) -> bool:
        """Property indicating when the loop is finished.

        Example::

            @property
            def done(self):
                return self.trainer.global_step >= self.trainer.max_steps
        """

    @property
    def skip(self) -> bool:
        """Determine whether to return immediately from the call to :meth:`run`.

        Example::

            @property
            def skip(self):
                return len(self.trainer.train_dataloader) == 0
        """
        return False

    def connect(self, **kwargs: "Loop") -> None:
        """Optionally connect one or multiple loops to this one.

        Linked loops should form a tree.
        """

    def replace(self, **loops: Union["Loop", Type["Loop"]]) -> None:
        """Optionally replace one or multiple of this loop's sub-loops.

        This method takes care of instantiating the class (if necessary) with all existing arguments, connecting all
        sub-loops of the old loop to the new instance, setting the ``Trainer`` reference, and connecting the new loop
        to the parent.

        Args:
            **loops: ``Loop`` subclasses or instances. The name used should match the loop attribute name you want to
                replace.

        Raises:
            MisconfigurationException: When passing a ``Loop`` class, if the ``__init__`` arguments do not match those
                of the Loop class it replaces.
        """
        new_loops = {}

        for name, type_or_object in loops.items():
            old_loop = getattr(self, name)

            if isinstance(type_or_object, type):
                # compare the signatures
                old_parameters = inspect.signature(old_loop.__class__.__init__).parameters
                current_parameters = inspect.signature(type_or_object.__init__).parameters
                if old_parameters != current_parameters:
                    raise MisconfigurationException(
                        f"`{self.__class__.__name__}.replace({type_or_object.__name__})` can only be used if the"
                        f" `__init__` signatures match but `{old_loop.__class__.__name__}` does not."
                    )
                # instantiate the loop, carrying over the old loop's constructor arguments
                kwargs = {p: getattr(old_loop, p) for p in old_parameters if p != "self"}
                loop = type_or_object(**kwargs)
            else:
                loop = type_or_object

            # connect sub-loops of the old loop to the new instance
            kwargs = {n: l for n, l in old_loop.__dict__.items() if isinstance(l, Loop)}
            loop.connect(**kwargs)
            # set the trainer reference
            loop.trainer = self.trainer

            new_loops[name] = loop
        # connect to self
        self.connect(**new_loops)

    def on_skip(self) -> T:
        """The function to run when :meth:`run` should be skipped, determined by the condition in :attr:`skip`.

        Returns:
            the default output value of :meth:`on_run_end`
        """
[docs]defrun(self,*args:Any,**kwargs:Any)->T:"""The main entry point to the loop. Will frequently check the :attr:`done` condition and calls :attr:`advance` until :attr:`done` evaluates to ``True``. Override this if you wish to change the default behavior. The default implementation is: Example:: def run(self, *args, **kwargs): if self.skip: return self.on_skip() self.reset() self.on_run_start(*args, **kwargs) while not self.done: self.advance(*args, **kwargs) output = self.on_run_end() return output Returns: The output of :attr:`on_run_end` (often outputs collected from each step of the loop) """ifself.skip:returnself.on_skip()self.reset()self.on_run_start(*args,**kwargs)whilenotself.done:try:self.on_advance_start(*args,**kwargs)self.advance(*args,**kwargs)self.on_advance_end()self._restarting=FalseexceptStopIteration:breakself._restarting=Falseoutput=self.on_run_end()returnoutput
[docs]@abstractmethoddefreset(self)->None:"""Resets the internal state of the loop at the beginning of each call to :attr:`run`. Example:: def reset(self): # reset your internal state or add custom logic # if you expect run() to be called multiple times self.current_iteration = 0 self.outputs = [] """
defon_run_start(self,*args:Any,**kwargs:Any)->None:"""Hook to be called as the first thing after entering :attr:`run` (except the state reset). Accepts all arguments passed to :attr:`run`. """void(*args,**kwargs)defon_advance_start(self,*args:Any,**kwargs:Any)->None:"""Hook to be called each time before :attr:`advance` is called. Accepts all arguments passed to :attr`run`. """void(*args,**kwargs)
[docs]@abstractmethoddefadvance(self,*args:Any,**kwargs:Any)->None:"""Performs a single step. Accepts all arguments passed to :attr:`run`. Example:: def advance(self, iterator): batch = next(iterator) loss = self.trainer.lightning_module.training_step(batch, batch_idx) ... """
defon_advance_end(self)->None:"""Hook to be called each time after :attr:`advance` is called."""defon_run_end(self)->T:"""Hook to be called at the end of the run. Its return argument is returned from :attr:`run`. """defteardown(self)->None:"""Use to release memory etc."""defon_save_checkpoint(self)->Dict:"""Called when saving a model checkpoint, use to persist loop state. Returns: The current loop state. """return{}defon_load_checkpoint(self,state_dict:Dict)->None:"""Called when loading a model checkpoint, use to reload loop state."""defstate_dict(self,destination:Optional[Dict]=None,prefix:str="")->Dict:"""The state dict is determined by the state and progress of this loop and all its children. Args: destination: An existing dictionary to update with this loop's state. By default a new dictionary is returned. prefix: A prefix for each key in the state dictionary """ifdestinationisNone:destination={}destination[prefix+"state_dict"]=self.on_save_checkpoint()# do not get the mode from `self.trainer` because it might not have been attached yetft_enabled=_fault_tolerant_training()fork,vinself.__dict__.items():key=prefix+kifisinstance(v,BaseProgress):destination[key]=v.state_dict()elifisinstance(v,Loop):v.state_dict(destination,key+".")elifft_enabledandisinstance(v,_ResultCollection):# sync / unsync metricsv.sync()destination[key]=v.state_dict()v.unsync()returndestinationdefload_state_dict(self,state_dict:Dict,prefix:str="",metrics:Optional[Dict[str,Metric]]=None,)->None:"""Loads the state of this loop and all its children."""self._load_from_state_dict(state_dict.copy(),prefix,metrics)fork,vinself.__dict__.items():ifisinstance(v,Loop):v.load_state_dict(state_dict.copy(),prefix+k+".")self.restarting=Truedef_load_from_state_dict(self,state_dict:Dict,prefix:str,metrics:Optional[Dict[str,Metric]]=None)->None:trainer=self._trainerfork,vinself.__dict__.items():key=prefix+kifkeynotinstate_dict:# compatibility with old 
checkpointscontinueifisinstance(v,BaseProgress):v.load_state_dict(state_dict[key])elifisinstance(v,_ResultCollection)andtrainerisnotNoneandtrainer.lightning_moduleisnotNone:metric_attributes={name:moduleforname,moduleinself.trainer.lightning_module.named_modules()ifisinstance(module,Metric)}ifmetrics:metric_attributes.update(metrics)# The `_ResultCollection` objects have 2 types of metrics: `Tensor` and `torchmetrics.Metric`.# When creating a checkpoint, the `Metric`s are dropped from the loop `state_dict` to serialize only# Python primitives. However, their states are saved with the model's `state_dict`.# On reload, we need to re-attach the `Metric`s back to the `_ResultCollection`.# The references are provided through the `metric_attributes` dictionary.v.load_state_dict(state_dict[key],metrics=metric_attributes,sync_fn=self.trainer.strategy.reduce)ifnotself.trainer.is_global_zero:v.reset(metrics=False)ifprefix+"state_dict"instate_dict:# compatibility with old checkpointsself.on_load_checkpoint(state_dict[prefix+"state_dict"])
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Read PyTorch Lightning's Privacy Policy.