Source code for lightning_fabric.strategies.strategy
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, TypeVar, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from lightning_fabric.accelerators import Accelerator
from lightning_fabric.plugins.io.checkpoint_io import CheckpointIO
from lightning_fabric.plugins.io.torch_io import TorchCheckpointIO
from lightning_fabric.plugins.precision import Precision
from lightning_fabric.strategies.launchers.base import _Launcher
from lightning_fabric.utilities.apply_func import move_data_to_device
from lightning_fabric.utilities.optimizer import _optimizer_to_device
from lightning_fabric.utilities.types import _PATH, Optimizable, ReduceOp

TBroadcast = TypeVar("TBroadcast")
TReduce = TypeVar("TReduce")

log = logging.getLogger(__name__)


class Strategy(ABC):
    """Base class for all strategies that change the behaviour of the training, validation and test-loop."""

    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision: Optional[Precision] = None,
    ) -> None:
        self._accelerator: Optional[Accelerator] = accelerator
        self._checkpoint_io: Optional[CheckpointIO] = checkpoint_io
        self._precision: Optional[Precision] = precision
        self._launcher: Optional[_Launcher] = None
        self._backward_sync_control: Optional[_BackwardSyncControl] = None

    @property
    @abstractmethod
    def root_device(self) -> torch.device:
        """Returns the root device."""

    @property
    @abstractmethod
    def is_global_zero(self) -> bool:
        """Whether the current process is the rank zero process, not only on the local node but across all nodes."""

    @property
    def launcher(self) -> Optional[_Launcher]:
        return self._launcher

    @property
    def accelerator(self) -> Optional[Accelerator]:
        return self._accelerator

    @accelerator.setter
    def accelerator(self, accelerator: Accelerator) -> None:
        self._accelerator = accelerator

    @property
    def checkpoint_io(self) -> CheckpointIO:
        if self._checkpoint_io is None:
            self._checkpoint_io = TorchCheckpointIO()
        return self._checkpoint_io

    @checkpoint_io.setter
    def checkpoint_io(self, io: Optional[CheckpointIO]) -> None:
        self._checkpoint_io = io

    @property
    def precision(self) -> Precision:
        return self._precision if self._precision is not None else Precision()

    @precision.setter
    def precision(self, precision: Optional[Precision]) -> None:
        self._precision = precision

    def _configure_launcher(self) -> None:
        """Attach the launcher based on the strategy."""

    def setup_environment(self) -> None:
        """Set up any processes or distributed connections.

        This must be called by the framework at the beginning of every process, before any distributed communication
        takes place.
        """
        assert self.accelerator is not None
        self.accelerator.setup_device(self.root_device)

    def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
        """Wraps the dataloader if necessary.

        Args:
            dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader`
        """
        return dataloader

    def setup_module_and_optimizers(
        self, module: Module, optimizers: List[Optimizer]
    ) -> Tuple[Module, List[Optimizer]]:
        """Set up a model and multiple optimizers together.

        The returned objects are expected to be in the same order they were passed in. The default implementation will
        call :meth:`setup_module` and :meth:`setup_optimizer` on the inputs.
        """
        module = self.setup_module(module)
        optimizers = [self.setup_optimizer(optimizer) for optimizer in optimizers]
        return module, optimizers

    def setup_module(self, module: Module) -> Module:
        """Performs setup for the model, e.g., by wrapping it in another class."""
        return module

    def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
        """Performs setup for the optimizer, e.g., by wrapping it in another class."""
        return optimizer

    @abstractmethod
    def module_to_device(self, module: Module) -> None:
        """Moves the model to the correct device."""

    def batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
        """Moves the batch to the correct device.

        The returned batch is of the same type as the input batch, just having all tensors on the correct device.

        Args:
            batch: The batch of samples to move to the correct device
            device: The target device
        """
        device = device or self.root_device
        return move_data_to_device(batch, device)

    def backward(self, tensor: Tensor, module: Optional[Module], *args: Any, **kwargs: Any) -> None:
        r"""Forwards backward-calls to the precision plugin."""
        self.precision.pre_backward(tensor, module)
        self.precision.backward(tensor, module, *args, **kwargs)
        self.precision.post_backward(tensor, module)

    def optimizer_step(
        self,
        optimizer: Optimizable,
        **kwargs: Any,
    ) -> Any:
        """Performs the actual optimizer step.

        Args:
            optimizer: the optimizer performing the step
            **kwargs: Any extra arguments to ``optimizer.step``
        """
        return self.precision.optimizer_step(optimizer, **kwargs)

    @abstractmethod
    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
        """Perform an all_gather on all processes.

        Args:
            tensor: the tensor to all_gather
            group: the process group to gather results from
            sync_grads: flag that allows users to synchronize gradients for the all_gather operation
        """

    @abstractmethod
    def all_reduce(
        self,
        tensor: Union[Tensor, Any],
        group: Optional[Any] = None,
        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
    ) -> Union[Tensor, Any]:
        """Reduces the given tensor (e.g. across GPUs/processes).

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to reduce
            reduce_op: the reduction operation. Defaults to 'mean'. Can also be a string 'sum' or a ``ReduceOp``.
        """

    @abstractmethod
    def barrier(self, name: Optional[str] = None) -> None:
        """Synchronizes all processes by blocking them until the whole group has entered this function.

        Args:
            name: an optional name to pass into barrier.
        """

    @abstractmethod
    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
        """Broadcasts an object to all processes.

        Args:
            obj: the object to broadcast
            src: source rank
        """

    def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
        """Reduce a boolean decision across all processes."""
        return decision

    def save_checkpoint(
        self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
    ) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            filepath: write-target file's path
            storage_options: parameter for how to save to storage, passed to the ``CheckpointIO`` plugin
        """
        if self.is_global_zero:
            self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)

    def get_module_state_dict(self, module: Module) -> Dict[str, Union[Any, Tensor]]:
        """Returns model state."""
        # TODO(fabric): Integrate this into Lightning Fabric
        return module.state_dict()

    def get_optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
        """Returns state of an optimizer.

        Allows for syncing/collating optimizer state from processes in custom plugins.
        """
        if hasattr(optimizer, "consolidate_state_dict"):
            # there are optimizers like PyTorch's ZeroRedundancyOptimizer that shard their
            # states, and to avoid OOM we consolidate the full state on rank 0 only
            optimizer.consolidate_state_dict()
            return optimizer.state_dict() if self.is_global_zero else {}

        # for optimizers that are not sharded, we return the state dict on all ranks
        return optimizer.state_dict()

    def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
        torch.cuda.empty_cache()
        return self.checkpoint_io.load_checkpoint(checkpoint_path)

    def load_module_state_dict(self, module: Module, checkpoint: Mapping[str, Any]) -> None:
        # TODO(fabric): Integrate this into Lightning Fabric
        module.load_state_dict(checkpoint["state_dict"])

    def load_optimizer_state_dict(
        self, optimizers: Union[Optimizer, Iterable[Optimizer]], checkpoint: Mapping[str, Any]
    ) -> None:
        if not isinstance(optimizers, Iterable):
            optimizers = [optimizers]
        optimizer_states = checkpoint["optimizer_states"]
        for optimizer, opt_state in zip(optimizers, optimizer_states):
            optimizer.load_state_dict(opt_state)
            _optimizer_to_device(optimizer, self.root_device)

    def remove_checkpoint(self, filepath: _PATH) -> None:
        """Remove checkpoint filepath from the filesystem.

        Args:
            filepath: Path to checkpoint
        """
        if self.is_global_zero:
            self.checkpoint_io.remove_checkpoint(filepath)

    def teardown(self) -> None:
        """This method is called to tear down the training process.

        It is the right place to release memory and free other resources.
        """
        self.precision.teardown()
        assert self.accelerator is not None
        self.accelerator.teardown()
        self.checkpoint_io.teardown()

    @classmethod
    def register_strategies(cls, strategy_registry: Dict[str, Any]) -> None:
        pass

    def _err_msg_joint_setup_required(self) -> str:
        return (
            f"The `{type(self).__name__}` does not support setting up the module and optimizer(s) independently."
            " Please call `setup_module_and_optimizers(model, [optimizer, ...])` to jointly set them up."
        )
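

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this module): a hypothetical
# `SingleDeviceDemoStrategy` showing how the abstract methods of `Strategy`
# might be implemented for a single, non-distributed device, where all
# collectives degenerate to no-ops. The class name and its behaviour are
# assumptions for illustration only; see
# `lightning_fabric.strategies.single_device` for the real implementation.
class SingleDeviceDemoStrategy(Strategy):
    def __init__(self, device: Union[str, torch.device] = "cpu", **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._root_device = torch.device(device)

    @property
    def root_device(self) -> torch.device:
        return self._root_device

    @property
    def is_global_zero(self) -> bool:
        # With a single process, this process is always rank zero.
        return True

    def module_to_device(self, module: Module) -> None:
        module.to(self.root_device)

    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
        # Nothing to gather from other processes.
        return tensor

    def all_reduce(
        self,
        tensor: Union[Tensor, Any],
        group: Optional[Any] = None,
        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
    ) -> Union[Tensor, Any]:
        # Nothing to reduce across processes.
        return tensor

    def barrier(self, name: Optional[str] = None) -> None:
        # No other processes to wait for.
        pass

    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
        # The single process already holds the object.
        return obj
# ---------------------------------------------------------------------------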


class _BackwardSyncControl(ABC):
    """Interface for any :class:`Strategy` that wants to offer the ability to enable or disable gradient
    synchronization during/after back-propagation.

    The most common use-case is gradient accumulation. If a :class:`Strategy` implements this interface, the user can
    implement their gradient accumulation loop very efficiently by disabling redundant gradient synchronization.
    """

    @contextmanager
    @abstractmethod
    def no_backward_sync(self, module: Module) -> Generator:
        """Blocks the synchronization of gradients during the backward pass.

        This is a context manager. It is only effective if it wraps a call to `.backward()`.
        """


class _Sharded(ABC):
    """Mixin-interface for any :class:`Strategy` that wants to expose functionality for sharding model parameters."""

    @abstractmethod
    @contextmanager
    def module_sharded_context(self) -> Generator:
        """A context manager that wraps the instantiation of a :class:`torch.nn.Module` and handles the sharding of
        parameters on creation.

        By sharding layers directly on instantiation, one can reduce peak memory usage and initialization time.
        """
        yield
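

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this module): a hypothetical
# `_DDPDemoBackwardSyncControl` showing how a `_BackwardSyncControl` can be
# implemented by delegating to `DistributedDataParallel.no_sync()`, which
# skips the gradient all-reduce for backward passes executed inside the
# context. Lightning Fabric's DDP strategy follows the same idea; the class
# and the usage pattern below are assumptions for illustration only.
class _DDPDemoBackwardSyncControl(_BackwardSyncControl):
    @contextmanager
    def no_backward_sync(self, module: Module) -> Generator:
        from torch.nn.parallel import DistributedDataParallel

        if not isinstance(module, DistributedDataParallel):
            raise TypeError(
                "Skipping the backward synchronization requires the module to be wrapped in"
                " `DistributedDataParallel`."
            )
        with module.no_sync():
            yield


# Typical gradient-accumulation pattern built on top of it: synchronize only
# on the last micro-batch before the optimizer step. `model`, `optimizer`,
# `compute_loss`, and `batches` are placeholders.
#
#     sync_control = _DDPDemoBackwardSyncControl()
#     for i, batch in enumerate(batches):
#         if i < len(batches) - 1:
#             with sync_control.no_backward_sync(model):
#                 compute_loss(model, batch).backward()
#         else:
#             compute_loss(model, batch).backward()
#             optimizer.step()
#             optimizer.zero_grad()
# ---------------------------------------------------------------------------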