Source code for pytorch_lightning.strategies.strategy
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, TypeVar, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers, LightningOptimizer
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins import TorchCheckpointIO
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.launchers.base import _Launcher
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.optimizer import optimizer_to_device, optimizers_to_device
from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, STEP_OUTPUT

TBroadcast = TypeVar("TBroadcast")
class Strategy(ABC):
    """Base class for all strategies that change the behaviour of the training, validation and test-loop."""

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
    ) -> None:
        self.accelerator = accelerator
        self._launcher: Optional[_Launcher] = None
        self._model: Optional[Module] = None
        self.checkpoint_io = checkpoint_io
        self.precision_plugin = precision_plugin
        self._optimizers: List[Optimizer] = []
        self._lightning_optimizers: Dict[int, LightningOptimizer] = {}
        self.lr_scheduler_configs: List[LRSchedulerConfig] = []
        self.optimizer_frequencies: List[int] = []
        if is_overridden("post_dispatch", self, parent=Strategy):
            rank_zero_deprecation(
                f"`{self.__class__.__name__}.post_dispatch()` has been deprecated in v1.6 and will be removed in v1.7."
                f" Move your implementation to `{self.__class__.__name__}.teardown()` instead."
            )

    @property
    def launcher(self) -> Optional[_Launcher]:
        return self._launcher

    @property
    def accelerator(self) -> "pl.accelerators.accelerator.Accelerator":
        return self._accelerator

    @accelerator.setter
    def accelerator(self, accelerator: "pl.accelerators.accelerator.Accelerator") -> None:
        self._accelerator = accelerator

    @property
    def checkpoint_io(self) -> CheckpointIO:
        return self._checkpoint_io if self._checkpoint_io is not None else TorchCheckpointIO()

    @checkpoint_io.setter
    def checkpoint_io(self, io: Optional[CheckpointIO]) -> None:
        self._checkpoint_io = io

    @property
    def precision_plugin(self) -> PrecisionPlugin:
        return self._precision_plugin if self._precision_plugin is not None else PrecisionPlugin()

    @precision_plugin.setter
    def precision_plugin(self, precision_plugin: Optional[PrecisionPlugin]) -> None:
        self._precision_plugin = precision_plugin

    @property
    def optimizers(self) -> List[Optimizer]:
        return self._optimizers

    @optimizers.setter
    def optimizers(self, optimizers: List[Optimizer]) -> None:
        self._optimizers = optimizers
        self._lightning_optimizers = {
            idx: LightningOptimizer._to_lightning_optimizer(opt, self, idx) for idx, opt in enumerate(self.optimizers)
        }
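    # Construction sketch: a concrete strategy is built from an accelerator, a checkpoint IO
    # plugin and a precision plugin; passing ``None`` for the latter two makes the properties
    # above fall back to ``TorchCheckpointIO()`` and ``PrecisionPlugin()``. The class names
    # below are real Lightning components, but the exact constructor arguments are an
    # assumption and may differ between versions:
    #
    #     from pytorch_lightning.accelerators import CPUAccelerator
    #     from pytorch_lightning.strategies import SingleDeviceStrategy
    #
    #     strategy = SingleDeviceStrategy(
    #         device=torch.device("cpu"),
    #         accelerator=CPUAccelerator(),
    #         checkpoint_io=None,       # resolved to TorchCheckpointIO() by the property
    #         precision_plugin=None,    # resolved to PrecisionPlugin() by the property
    #     )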
    def connect(self, model: Module) -> None:
        """Called by the accelerator to connect the accelerator and the model with this plugin."""
        self.model = model

    def _configure_launcher(self):
        """Attach the launcher based on Strategy."""
    def setup_environment(self) -> None:
        """Setup any processes or distributed connections.

        This is called before the LightningModule/DataModule setup hook which allows the user to access the accelerator
        environment before setup is complete.
        """
        self.accelerator.setup_environment(self.root_device)

    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
        """Creates optimizers and schedulers.

        Args:
            trainer: the Trainer, these optimizers should be connected to
        """
        if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
            return
        self.optimizers, self.lr_scheduler_configs, self.optimizer_frequencies = _init_optimizers_and_lr_schedulers(
            self.lightning_module
        )

    def setup(self, trainer: "pl.Trainer") -> None:
        """Setup plugins for the trainer fit and creates optimizers.

        Args:
            trainer: the trainer instance
        """
        self.accelerator.setup(trainer)
        self.setup_optimizers(trainer)
        self.setup_precision_plugin()
        optimizers_to_device(self.optimizers, self.root_device)

    def setup_precision_plugin(self) -> None:
        """Attaches the precision plugin to the accelerator."""
        model, optimizers, lr_scheduler_configs = self.precision_plugin.connect(
            self.model, self.optimizers, self.lr_scheduler_configs
        )
        self.model = model
        self.optimizers = optimizers
        self.lr_scheduler_configs = lr_scheduler_configs

    def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
        """Returns state of an optimizer.

        Allows for syncing/collating optimizer state from processes in custom plugins.
        """
        return optimizer.state_dict()

    def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
        """Forwards backward-calls to the precision plugin.

        Args:
            closure_loss: a tensor holding the loss value to backpropagate
        """
        self.pre_backward(closure_loss)
        closure_loss = self.precision_plugin.pre_backward(self.lightning_module, closure_loss)

        self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)

        closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss)
        self.post_backward(closure_loss)

        return closure_loss

    def optimizer_step(
        self,
        optimizer: Optimizer,
        opt_idx: int,
        closure: Callable[[], Any],
        model: Optional[Union["pl.LightningModule", Module]] = None,
        **kwargs: Any,
    ) -> Any:
        """Performs the actual optimizer step.

        Args:
            optimizer: the optimizer performing the step
            opt_idx: index of the current optimizer
            closure: closure calculating the loss value
            model: reference to the model, optionally defining optimizer step related hooks
            **kwargs: Any extra arguments to ``optimizer.step``
        """
        model = model or self.lightning_module
        return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
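    # Step sequence sketch: the training loop wraps forward/backward in a closure and hands it
    # to the strategy, which defers to the precision plugin so that gradient scaling/clipping
    # run in the right order. The names ``batch`` and ``batch_idx`` are placeholders:
    #
    #     def closure():
    #         loss = strategy.training_step(batch, batch_idx)
    #         strategy.backward(loss)
    #         return loss
    #
    #     strategy.optimizer_step(optimizer, opt_idx=0, closure=closure)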
    def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
        """Setup a model and multiple optimizers together.

        The returned objects are expected to be in the same order they were passed in. The default implementation will
        call :meth:`_setup_model` and :meth:`_setup_optimizer` on the inputs.
        """
        # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324
        model = self._setup_model(model)
        optimizers = [self._setup_optimizer(optimizer) for optimizer in optimizers]
        return model, optimizers

    def _setup_model(self, model: Module) -> Module:
        """Performs setup for the model, e.g., by wrapping it by another class."""
        # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324
        return model

    def _setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
        """Performs setup for the optimizer, e.g., by wrapping it by another class."""
        # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324
        return optimizer
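    # Override sketch: a distributed subclass might wrap the model in `_setup_model` roughly
    # like this (`MyDistributedStrategy` is hypothetical; assumes a `torch.distributed`
    # process group is already initialized):
    #
    #     class MyDistributedStrategy(Strategy):
    #         def _setup_model(self, model: Module) -> Module:
    #             # wrap the module so gradients are synchronized across processes
    #             return torch.nn.parallel.DistributedDataParallel(
    #                 model, device_ids=[self.root_device.index]
    #             )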
    def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
        """Moves the batch to the correct device.

        The returned batch is of the same type as the input batch, just having all tensors on the correct device.

        Args:
            batch: The batch of samples to move to the correct device
            device: The target device
            dataloader_idx: The index of the dataloader to which the batch belongs.
        """
        model = self.lightning_module
        device = device or self.root_device
        if model is not None:
            return model._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx)
        return move_data_to_device(batch, device)
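    # Usage sketch: the loops call this right before each step, e.g.
    #
    #     batch = strategy.batch_to_device(batch, dataloader_idx=0)
    #
    # which routes through the LightningModule transfer hooks (`on_before_batch_transfer`,
    # `transfer_batch_to_device`, `on_after_batch_transfer`) when a module is attached, and
    # falls back to `move_data_to_device` otherwise.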
    @property
    @abstractmethod
    def root_device(self) -> torch.device:
        """Returns the root device."""

    @abstractmethod
    def model_to_device(self) -> None:
        """Moves the model to the correct device."""

    @property
    @abstractmethod
    def is_global_zero(self) -> bool:
        """Whether the current process is the rank zero process not only on the local node, but for all nodes."""
    @abstractmethod
    def reduce(
        self,
        tensor: Union[torch.Tensor, Any],
        group: Optional[Any] = None,
        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
    ) -> Union[torch.Tensor, Any]:
        """Reduces the given tensor (e.g. across GPUs/processes).

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to reduce
            reduce_op: the reduction operation. Defaults to 'mean'. Can also be a string 'sum' or ReduceOp.
        """
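    # Implementation sketch for `reduce` in a DDP-style subclass (assumes the default
    # `torch.distributed` process group is initialized; simplified compared to any real strategy):
    #
    #     def reduce(self, tensor, group=None, reduce_op="mean"):
    #         if isinstance(tensor, torch.Tensor):
    #             torch.distributed.all_reduce(tensor, group=group)      # in-place sum across ranks
    #             if reduce_op in ("mean", "avg"):
    #                 tensor = tensor / torch.distributed.get_world_size(group)
    #         return tensor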
    @abstractmethod
    def barrier(self, name: Optional[str] = None) -> None:
        """Synchronizes all processes which blocks processes until the whole group enters this function.

        Args:
            name: an optional name to pass into barrier.
        """

    @abstractmethod
    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
        """Broadcasts an object to all processes.

        Args:
            obj: the object to broadcast
            src: source rank
        """

    @abstractmethod
    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
        """Perform an all_gather on all processes.

        Args:
            tensor: the tensor to all_gather
            group: the process group to gather results from
            sync_grads: flag that allows users to synchronize gradients for the all_gather op
        """

    def reduce_boolean_decision(self, decision: bool) -> bool:
        """Reduce the early stopping decision across all processes."""
        return decision

    def pre_backward(self, closure_loss: torch.Tensor) -> None:
        """Run before precision plugin executes backward."""

    def post_backward(self, closure_loss: torch.Tensor) -> None:
        """Run after precision plugin executes backward."""
    @property
    def model(self) -> Optional[Module]:
        """Returns the potentially wrapped LightningModule."""
        return self._model

    @model.setter
    def model(self, new_model: Optional[Module]) -> None:
        self._model = new_model

    @property
    def lightning_module(self) -> Optional["pl.LightningModule"]:
        """Returns the pure LightningModule without potential wrappers."""
        return unwrap_lightning_module(self.model) if self.model is not None else None

    def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
        torch.cuda.empty_cache()
        return self.checkpoint_io.load_checkpoint(checkpoint_path)

    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        self.lightning_module.load_state_dict(checkpoint["state_dict"])

    def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        optimizer_states = checkpoint["optimizer_states"]
        for optimizer, opt_state in zip(self.optimizers, optimizer_states):
            optimizer.load_state_dict(opt_state)
            optimizer_to_device(optimizer, self.root_device)
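    # Restore flow sketch: checkpoint loading is split across the three hooks above so that
    # subclasses can customize each stage independently (the path is a placeholder):
    #
    #     checkpoint = strategy.load_checkpoint("some.ckpt")   # raw dict via CheckpointIO
    #     strategy.load_model_state_dict(checkpoint)           # restore LightningModule weights
    #     strategy.load_optimizer_state_dict(checkpoint)       # restore and move optimizer state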
    def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
        """The actual training step.

        See :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` for more details
        """
        with self.precision_plugin.train_step_context():
            return self.model.training_step(*args, **kwargs)

    def post_training_step(self):
        pass

    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
        """The actual validation step.

        See :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` for more details
        """
        with self.precision_plugin.val_step_context():
            return self.model.validation_step(*args, **kwargs)

    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
        """The actual test step.

        See :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` for more details
        """
        with self.precision_plugin.test_step_context():
            return self.model.test_step(*args, **kwargs)

    def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
        """The actual predict step.

        See :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` for more details
        """
        with self.precision_plugin.predict_step_context():
            return self.model.predict_step(*args, **kwargs)

    def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
        """Wraps the dataloader if necessary.

        Args:
            dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader`
        """
        return dataloader

    @property
    def restore_checkpoint_after_setup(self) -> bool:
        """Override to delay restoring from checkpoint till after pre-dispatch.

        This is useful when the plugin requires all the setup hooks to run before loading checkpoint.

        Returns:
            If ``True``, restore checkpoint after pre_dispatch.
        """
        return False

    @property
    def lightning_restore_optimizer(self) -> bool:
        """Override to disable Lightning restoring optimizers/schedulers.

        This is useful for plugins which manage restoring optimizers/schedulers.
        """
        return True

    @property
    def handles_gradient_accumulation(self) -> bool:
        """Whether the plugin handles gradient accumulation internally."""
        return False

    def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
        """Returns model state."""
        model = self.lightning_module
        return model.state_dict()
    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            filepath: write-target file's path
            storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin
        """
        if self.is_global_zero:
            self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
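    # Because only global rank zero writes, a custom `CheckpointIO` (here a hypothetical
    # `MyCloudCheckpointIO` subclass) can be plugged in without any rank handling of its own:
    #
    #     trainer = pl.Trainer(plugins=[MyCloudCheckpointIO()])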
    def remove_checkpoint(self, filepath: _PATH) -> None:
        """Remove checkpoint filepath from the filesystem.

        Args:
            filepath: Path to checkpoint
        """
        if self.is_global_zero:
            self.checkpoint_io.remove_checkpoint(filepath)
    @contextlib.contextmanager
    def model_sharded_context(self) -> Generator:
        """Provide a hook to create modules in a distributed-aware context.

        This is useful when we'd like to shard the model instantly, which can save memory and
        initialization time for extremely large models.

        Returns:
            Model parallel context.
        """
        yield
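    # Usage sketch: a sharding strategy overrides this context manager, and model construction
    # then happens inside it so parameters can be sharded at creation time
    # (`MyLightningModule` is a placeholder):
    #
    #     with strategy.model_sharded_context():
    #         model = MyLightningModule()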
    def teardown(self) -> None:
        """This method is called to teardown the training process.

        It is the right place to release memory and free other resources.
        """
        optimizers_to_device(self.optimizers, torch.device("cpu"))
        self.precision_plugin.teardown()

    def on_train_start(self) -> None:
        """Called when train begins."""
        pass

    def on_validation_start(self) -> None:
        """Called when validation begins."""
        pass

    def on_test_start(self) -> None:
        """Called when test begins."""
        pass

    def on_predict_start(self) -> None:
        """Called when predict begins."""
        pass

    def on_train_end(self) -> None:
        """Called when train ends."""
        pass

    def on_validation_end(self) -> None:
        """Called when validation ends."""
        pass
    def on_test_end(self) -> None:
        """Called when test ends."""
        pass
    def on_predict_end(self):
        """Called when predict ends."""
        pass

    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Called in the training loop before anything happens for that batch."""
        pass

    def dispatch(self, trainer: "pl.Trainer") -> None:
        """Hook to do something before the training/evaluation/prediction starts."""
        self.precision_plugin.dispatch(trainer)

    def __getstate__(self) -> Dict:
        # `LightningOptimizer` overrides `self.__class__` so they cannot be pickled
        state = dict(vars(self))  # copy
        state["_lightning_optimizers"] = {}
        return state

    def __setstate__(self, state: Dict) -> None:
        self.__dict__ = state
        self.optimizers = self.optimizers  # re-create the `_lightning_optimizers`

    def post_dispatch(self, trainer: "pl.Trainer") -> None:
        r"""
        .. deprecated:: v1.6
            This method has been deprecated in v1.6 and will be removed in v1.7. Use :meth:`teardown` instead.

        Hook to do something after the training/evaluation/prediction finishes.
        """