Source code for pytorch_lightning.strategies.strategy
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, TypeVar, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers, LightningOptimizer
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins import TorchCheckpointIO
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.launchers.base import _Launcher
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.optimizer import optimizer_to_device, optimizers_to_device
from pytorch_lightning.utilities.types import (
    _PATH,
    LRSchedulerConfig,
    PredictStep,
    STEP_OUTPUT,
    TestStep,
    TrainingStep,
    ValidationStep,
)

TBroadcast = TypeVar("TBroadcast")
TReduce = TypeVar("TReduce")

log = logging.getLogger(__name__)
[docs]
class Strategy(ABC):
    """Base class for all strategies that change the behaviour of the training, validation and test-loop."""

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
    ) -> None:
        self._accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = accelerator
        self._checkpoint_io: Optional[CheckpointIO] = checkpoint_io
        self._precision_plugin: Optional[PrecisionPlugin] = precision_plugin
        self._launcher: Optional[_Launcher] = None
        self._model: Optional[Module] = None
        self._optimizers: List[Optimizer] = []
        self._lightning_optimizers: Dict[int, LightningOptimizer] = {}
        self.lr_scheduler_configs: List[LRSchedulerConfig] = []
        self.optimizer_frequencies: List[int] = []

    @property
    def launcher(self) -> Optional[_Launcher]:
        return self._launcher

    @property
    def accelerator(self) -> Optional["pl.accelerators.accelerator.Accelerator"]:
        return self._accelerator

    @accelerator.setter
    def accelerator(self, accelerator: "pl.accelerators.accelerator.Accelerator") -> None:
        self._accelerator = accelerator

    @property
    def checkpoint_io(self) -> CheckpointIO:
        if self._checkpoint_io is None:
            self._checkpoint_io = TorchCheckpointIO()
        elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
            self._checkpoint_io.checkpoint_io = TorchCheckpointIO()
        return self._checkpoint_io

    @checkpoint_io.setter
    def checkpoint_io(self, io: Optional[CheckpointIO]) -> None:
        self._checkpoint_io = io

    @property
    def precision_plugin(self) -> PrecisionPlugin:
        return self._precision_plugin if self._precision_plugin is not None else PrecisionPlugin()

    @precision_plugin.setter
    def precision_plugin(self, precision_plugin: Optional[PrecisionPlugin]) -> None:
        self._precision_plugin = precision_plugin

    @property
    def optimizers(self) -> List[Optimizer]:
        return self._optimizers

    @optimizers.setter
    def optimizers(self, optimizers: List[Optimizer]) -> None:
        self._optimizers = optimizers
        self._lightning_optimizers = {
            idx: LightningOptimizer._to_lightning_optimizer(opt, self, idx) for idx, opt in enumerate(self.optimizers)
        }
[docs]
    def connect(self, model: Module) -> None:
        """Called by the accelerator to connect the accelerator and the model with this plugin."""
        self.model = model
    def _configure_launcher(self) -> None:
        """Attach the launcher based on Strategy."""
[docs]
    def setup_environment(self) -> None:
        """Setup any processes or distributed connections.

        This is called before the LightningModule/DataModule setup hook which allows the user to access the
        accelerator environment before setup is complete.
        """
        assert self.accelerator is not None
        self.accelerator.setup_environment(self.root_device)
[docs]
    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
        """Creates optimizers and schedulers.

        Args:
            trainer: the Trainer, these optimizers should be connected to
        """
        if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
            return
        assert self.lightning_module is not None
        self.optimizers, self.lr_scheduler_configs, self.optimizer_frequencies = _init_optimizers_and_lr_schedulers(
            self.lightning_module
        )
[docs]
    def setup(self, trainer: "pl.Trainer") -> None:
        """Setup plugins for the trainer fit and creates optimizers.

        Args:
            trainer: the trainer instance
        """
        assert self.accelerator is not None
        self.accelerator.setup(trainer)
        self.setup_optimizers(trainer)
        self.setup_precision_plugin()
        optimizers_to_device(self.optimizers, self.root_device)
[docs]
    def setup_precision_plugin(self) -> None:
        """Attaches the precision plugin to the accelerator."""
        assert self.model is not None
        model, optimizers, lr_scheduler_configs = self.precision_plugin.connect(
            self.model, self.optimizers, self.lr_scheduler_configs
        )
        self.model = model
        self.optimizers = optimizers
        self.lr_scheduler_configs = lr_scheduler_configs
[docs]
    def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
        """Returns state of an optimizer.

        Allows for syncing/collating optimizer state from processes in custom plugins.
        """
        return optimizer.state_dict()
[docs]
    def backward(
        self,
        closure_loss: Tensor,
        optimizer: Optional[Optimizer],
        optimizer_idx: Optional[int],
        *args: Any,
        **kwargs: Any,
    ) -> Tensor:
        """Forwards backward-calls to the precision plugin.

        Args:
            closure_loss: a tensor holding the loss value to backpropagate
        """
        self.pre_backward(closure_loss)
        assert self.lightning_module is not None
        closure_loss = self.precision_plugin.pre_backward(self.lightning_module, closure_loss)

        self.precision_plugin.backward(self.lightning_module, closure_loss, optimizer, optimizer_idx, *args, **kwargs)

        closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss)
        self.post_backward(closure_loss)

        return closure_loss
[docs]
    def optimizer_step(
        self,
        optimizer: Optimizer,
        opt_idx: int,
        closure: Callable[[], Any],
        model: Optional[Union["pl.LightningModule", Module]] = None,
        **kwargs: Any,
    ) -> Any:
        """Performs the actual optimizer step.

        Args:
            optimizer: the optimizer performing the step
            opt_idx: index of the current optimizer
            closure: closure calculating the loss value
            model: reference to the model, optionally defining optimizer step related hooks
            **kwargs: Any extra arguments to ``optimizer.step``
        """
        model = model or self.lightning_module
        return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
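    # Usage sketch (hedged; `strategy`, `batch`, `batch_idx`, `optimizer` and `opt_idx` are
    # illustrative names, and the step output is assumed to be a plain loss tensor). The loop
    # passes a closure so optimizers that re-evaluate the loss (e.g. ``torch.optim.LBFGS``) can
    # call it from within ``optimizer.step``:
    #
    #     def closure() -> Tensor:
    #         loss = strategy.training_step(batch, batch_idx)
    #         strategy.backward(loss, optimizer, opt_idx)
    #         return loss
    #
    #     strategy.optimizer_step(optimizer, opt_idx, closure)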
    def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
        """Setup a model and multiple optimizers together.

        The returned objects are expected to be in the same order they were passed in. The default implementation will
        call :meth:`_setup_model` and :meth:`_setup_optimizer` on the inputs.
        """
        # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324
        model = self._setup_model(model)
        optimizers = [self._setup_optimizer(optimizer) for optimizer in optimizers]
        return model, optimizers

    def _setup_model(self, model: Module) -> Module:
        """Performs setup for the model, e.g., by wrapping it by another class."""
        # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324
        return model

    def _setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
        """Performs setup for the optimizer, e.g., by wrapping it by another class."""
        # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324
        return optimizer
[docs]
    def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
        """Moves the batch to the correct device.

        The returned batch is of the same type as the input batch, just having all tensors on the correct device.

        Args:
            batch: The batch of samples to move to the correct device
            device: The target device
            dataloader_idx: The index of the dataloader to which the batch belongs.
        """
        model = self.lightning_module
        device = device or self.root_device
        if model is not None:
            return model._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx)
        return move_data_to_device(batch, device)
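    # Usage sketch (hedged; `strategy` and `dataloader` are illustrative names): move a batch
    # onto the strategy's device before running one of the ``*_step`` methods further below.
    #
    #     batch = next(iter(dataloader))
    #     batch = strategy.batch_to_device(batch)  # `device` defaults to `strategy.root_device`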
    @property
    @abstractmethod
    def root_device(self) -> torch.device:
        """Returns the root device."""
[docs]
    @abstractmethod
    def model_to_device(self) -> None:
        """Moves the model to the correct device."""
    @property
    @abstractmethod
    def is_global_zero(self) -> bool:
        """Whether the current process is the rank zero process not only on the local node, but for all nodes."""
[docs]
    @abstractmethod
    def reduce(
        self,
        tensor: Union[Tensor, Any],
        group: Optional[Any] = None,
        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
    ) -> Union[Tensor, Any]:
        """Reduces the given tensor (e.g. across GPUs/processes).

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to reduce
            reduce_op: the reduction operation. Defaults to 'mean'. Can also be a string 'sum' or ReduceOp.
        """
[docs]
    @abstractmethod
    def barrier(self, name: Optional[str] = None) -> None:
        """Synchronizes all processes which blocks processes until the whole group enters this function.

        Args:
            name: an optional name to pass into barrier.
        """
[docs]
    @abstractmethod
    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
        """Broadcasts an object to all processes.

        Args:
            obj: the object to broadcast
            src: source rank
        """
[docs]
    @abstractmethod
    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
        """Perform an all_gather on all processes.

        Args:
            tensor: the tensor to all_gather
            group: the process group to gather results from
            sync_grads: flag that allows users to synchronize gradients for the all_gather operation
        """
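    # Note: `root_device`, `model_to_device`, `is_global_zero`, `reduce`, `barrier`, `broadcast`
    # and `all_gather` above form the abstract interface every concrete strategy must implement;
    # a hedged single-device sketch is appended at the end of this listing.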
[docs]
    def reduce_boolean_decision(self, decision: bool) -> bool:
        """Reduce a boolean decision across all processes."""
        return decision
[docs]
    def pre_backward(self, closure_loss: Tensor) -> None:
        """Run before precision plugin executes backward."""
[docs]
    def post_backward(self, closure_loss: Tensor) -> None:
        """Run after precision plugin executes backward."""
    @property
    def model(self) -> Optional[Module]:
        """Returns the potentially wrapped LightningModule."""
        return self._model

    @model.setter
    def model(self, new_model: Optional[Module]) -> None:
        self._model = new_model

    @property
    def lightning_module(self) -> Optional["pl.LightningModule"]:
        """Returns the pure LightningModule without potential wrappers."""
        return unwrap_lightning_module(self.model) if self.model is not None else None

    def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
        torch.cuda.empty_cache()
        return self.checkpoint_io.load_checkpoint(checkpoint_path)

    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        assert self.lightning_module is not None
        self.lightning_module.load_state_dict(checkpoint["state_dict"])

    def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        optimizer_states = checkpoint["optimizer_states"]
        for optimizer, opt_state in zip(self.optimizers, optimizer_states):
            optimizer.load_state_dict(opt_state)
            optimizer_to_device(optimizer, self.root_device)
[docs]
    def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """The actual training step.

        See :meth:`~pytorch_lightning.core.module.LightningModule.training_step` for more details
        """
        with self.precision_plugin.train_step_context():
            assert isinstance(self.model, TrainingStep)
            return self.model.training_step(*args, **kwargs)
    def post_training_step(self) -> None:
        pass
[docs]
    def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        """The actual validation step.

        See :meth:`~pytorch_lightning.core.module.LightningModule.validation_step` for more details
        """
        with self.precision_plugin.val_step_context():
            assert isinstance(self.model, ValidationStep)
            return self.model.validation_step(*args, **kwargs)
[docs]
    def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        """The actual test step.

        See :meth:`~pytorch_lightning.core.module.LightningModule.test_step` for more details
        """
        with self.precision_plugin.test_step_context():
            assert isinstance(self.model, TestStep)
            return self.model.test_step(*args, **kwargs)
[docs]
    def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """The actual predict step.

        See :meth:`~pytorch_lightning.core.module.LightningModule.predict_step` for more details
        """
        with self.precision_plugin.predict_step_context():
            assert isinstance(self.model, PredictStep)
            return self.model.predict_step(*args, **kwargs)
[docs]
    def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
        """Wraps the dataloader if necessary.

        Args:
            dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader`
        """
        return dataloader
    @property
    def restore_checkpoint_after_setup(self) -> bool:
        """Override to delay restoring from checkpoint till after pre-dispatch.

        This is useful when the plugin requires all the setup hooks to run before loading checkpoint.

        Returns:
            If ``True``, restore checkpoint after pre_dispatch.
        """
        return False

    @property
    def lightning_restore_optimizer(self) -> bool:
        """Override to disable Lightning restoring optimizers/schedulers.

        This is useful for plugins which manage restoring optimizers/schedulers.
        """
        return True

    @property
    def handles_gradient_accumulation(self) -> bool:
        """Whether the plugin handles gradient accumulation internally."""
        return False
[docs]
    def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
        """Returns model state."""
        assert self.lightning_module is not None
        return self.lightning_module.state_dict()
[docs]
    def save_checkpoint(
        self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
    ) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            filepath: write-target file's path
            storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin
        """
        if self.is_global_zero:
            self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
[docs]
    def remove_checkpoint(self, filepath: _PATH) -> None:
        """Remove checkpoint filepath from the filesystem.

        Args:
            filepath: Path to checkpoint
        """
        if self.is_global_zero:
            self.checkpoint_io.remove_checkpoint(filepath)
[docs]
    @contextlib.contextmanager
    def model_sharded_context(self) -> Generator:
        """Provide hook to create modules in a distributed aware context.

        This is useful when we'd like to shard the model instantly, which can save memory and initialization time
        for extremely large models.

        Returns:
            Model parallel context.
        """
        yield
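    # Usage sketch (hedged; `strategy` and `MyLightningModule` are illustrative names): sharded
    # strategies override this hook so parameters can be created directly in their sharded form
    # instead of being fully materialized on a single device first.
    #
    #     with strategy.model_sharded_context():
    #         module = MyLightningModule()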
[docs]
    def teardown(self) -> None:
        """This method is called to teardown the training process.

        It is the right place to release memory and free other resources.
        """
        optimizers_to_device(self.optimizers, torch.device("cpu"))

        if self.lightning_module is not None:
            log.detail(f"{self.__class__.__name__}: moving model to CPU")
            self.lightning_module.cpu()
        self.precision_plugin.teardown()
        assert self.accelerator is not None
        self.accelerator.teardown()
        self.checkpoint_io.teardown()
[docs]
    def on_train_start(self) -> None:
        """Called when train begins."""
        pass

[docs]
    def on_validation_start(self) -> None:
        """Called when validation begins."""
        pass

[docs]
    def on_test_start(self) -> None:
        """Called when test begins."""
        pass

[docs]
    def on_predict_start(self) -> None:
        """Called when predict begins."""
        pass

[docs]
    def on_train_end(self) -> None:
        """Called when train ends."""
        pass

[docs]
    def on_validation_end(self) -> None:
        """Called when validation ends."""
        pass

[docs]
    def on_test_end(self) -> None:
        """Called when test ends."""
        pass

[docs]
    def on_predict_end(self) -> None:
        """Called when predict ends."""
        pass

[docs]
    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
        """Called in the training loop before anything happens for that batch."""
        pass
[docs]
    def dispatch(self, trainer: "pl.Trainer") -> None:
        """Hook to do something before the training/evaluation/prediction starts."""
        self.precision_plugin.dispatch(trainer)
    def __getstate__(self) -> Dict:
        # `LightningOptimizer` overrides `self.__class__` so they cannot be pickled
        state = dict(vars(self))  # copy
        state["_lightning_optimizers"] = {}
        return state

    def __setstate__(self, state: Dict) -> None:
        self.__dict__ = state
        self.optimizers = self.optimizers  # re-create the `_lightning_optimizers`
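# --- Hedged example (not part of the pytorch_lightning source above) ---------------------------
# A minimal sketch of a custom single-process strategy, showing which abstract members a
# subclass must provide: `root_device`, `model_to_device`, `is_global_zero`, `reduce`,
# `barrier`, `broadcast` and `all_gather`. The class name `SingleDeviceSketch` and its `device`
# argument are illustrative only (Lightning already ships `SingleDeviceStrategy` for this case),
# and the sketch relies on the imports at the top of this listing.


class SingleDeviceSketch(Strategy):
    def __init__(self, device: torch.device = torch.device("cpu")) -> None:
        super().__init__()
        self._root_device = device

    @property
    def root_device(self) -> torch.device:
        # single process: the only device is the root device
        return self._root_device

    def model_to_device(self) -> None:
        assert self.model is not None
        self.model.to(self.root_device)

    @property
    def is_global_zero(self) -> bool:
        # with a single process, this process is always rank zero
        return True

    def reduce(self, tensor: Union[Tensor, Any], *args: Any, **kwargs: Any) -> Union[Tensor, Any]:
        # nothing to reduce across processes
        return tensor

    def barrier(self, name: Optional[str] = None) -> None:
        # no other processes to synchronize with
        pass

    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
        return obj

    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
        return tensor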