# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DDP (DistributedDataParallel) training strategy for the Lightning Trainer.
import logging
from contextlib import nullcontext
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Union

import torch
import torch.distributed
from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
from torch import Tensor
from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim.optimizer import Optimizer
from typing_extensions import override

import lightning.pytorch as pl
from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
from lightning.fabric.strategies import _StrategyRegistry
# NOTE(review): `_distributed_is_initialized`, `_group` and `TBroadcast` are not used by the code visible
# in this view; presumably they are used by parts of the module not shown here — verify before removing.
from lightning.fabric.utilities.distributed import (
    _distributed_is_initialized,
    _get_default_process_group_backend_for_device,
    _init_dist_connection,
    _sync_ddp_if_available,
)
from lightning.fabric.utilities.distributed import group as _group
from lightning.fabric.utilities.imports import _IS_WINDOWS
from lightning.fabric.utilities.optimizer import _optimizers_to_device
from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import ReduceOp
from lightning.pytorch.core.optimizer import LightningOptimizer
from lightning.pytorch.overrides.distributed import _register_ddp_comm_hook, _sync_module_states, prepare_for_backward
from lightning.pytorch.plugins.precision import Precision
from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher, _SubprocessScriptLauncher
from lightning.pytorch.strategies.parallel import ParallelStrategy
from lightning.pytorch.strategies.strategy import TBroadcast, _ForwardRedirection
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.exceptions import _augment_message
from lightning.pytorch.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_only

if TYPE_CHECKING:
    # Only needed for the `Optional[ModelAverager]` annotation inside `DDPStrategy.__init__`,
    # so it is not imported at runtime.
    from torch.distributed.algorithms.model_averaging.averagers import ModelAverager

# Module-level logger used for the strategy's debug messages.
log = logging.getLogger(__name__)

# Every strategy name that `DDPStrategy.register_strategies` registers with `start_method="fork"`
# (the `ddp_fork` / `ddp_notebook` aliases and their `find_unused_parameters` variants).
# Not referenced in this view; presumably consumed elsewhere to special-case fork-based launching — verify.
_DDP_FORK_ALIASES = (
    "ddp_fork",
    "ddp_fork_find_unused_parameters_false",
    "ddp_fork_find_unused_parameters_true",
    "ddp_notebook",
    "ddp_notebook_find_unused_parameters_false",
    "ddp_notebook_find_unused_parameters_true",
)
class DDPStrategy(ParallelStrategy):
    """Multi-process strategy with one device per process, usable on one or several nodes.

    Each process picks its device from ``parallel_devices`` using its local rank. The model is wrapped in
    :class:`~torch.nn.parallel.distributed.DistributedDataParallel` only while fitting, and any extra keyword
    arguments given to the constructor are forwarded verbatim to that wrapper.

    """

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.Accelerator"] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None,
        ddp_comm_state: Optional[object] = None,
        ddp_comm_hook: Optional[Callable] = None,
        ddp_comm_wrapper: Optional[Callable] = None,
        model_averaging_period: Optional[int] = None,
        process_group_backend: Optional[str] = None,
        timeout: Optional[timedelta] = default_pg_timeout,
        start_method: Literal["popen", "spawn", "fork", "forkserver"] = "popen",
        **kwargs: Any,
    ) -> None:
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )
        log.debug(f"{self.__class__.__name__}: initializing DDP strategy")

        # Launching and process-group configuration.
        self._num_nodes = 1
        self._start_method = start_method
        self._process_group_backend: Optional[str] = process_group_backend
        self._timeout: Optional[timedelta] = timeout

        # DDP wrapper kwargs and communication-hook configuration.
        self._ddp_kwargs = kwargs
        self._ddp_comm_state = ddp_comm_state
        self._ddp_comm_hook = ddp_comm_hook
        self._ddp_comm_wrapper = ddp_comm_wrapper

        # Post-local-SGD model averaging; the averager is only built in `_enable_model_averaging`.
        self._model_averaging_period = model_averaging_period
        self._model_averager: Optional[ModelAverager] = None

        # Toggles DDP's backward gradient sync around forward calls under manual optimization.
        self._forward_redirection = _DDPForwardRedirection()

    @property
    def is_distributed(self) -> bool:  # pragma: no-cover
        """Deprecated and always ``True``; retained only so that older callers keep working."""
        # NOTE: `stacklevel` counts frames from here, so this call must stay directly in the property body.
        rank_zero_deprecation(
            f"`{type(self).__name__}.is_distributed` is deprecated. Use is discouraged.",
            stacklevel=6,
        )
        return True

    @property
    @override
    def root_device(self) -> torch.device:
        """The device owned by this process: the ``parallel_devices`` entry at ``local_rank``."""
        devices = self.parallel_devices
        assert devices is not None
        return devices[self.local_rank]

    @property
    def num_nodes(self) -> int:
        """Number of nodes taking part in training."""
        return self._num_nodes

    @num_nodes.setter
    def num_nodes(self, num_nodes: int) -> None:
        # world ranks depend on num_nodes; whoever changes it is responsible for resetting the world ranks
        self._num_nodes = num_nodes

    @property
    def num_processes(self) -> int:
        """Number of parallel devices assigned to this strategy, or ``0`` when none are set."""
        devices = self.parallel_devices
        return 0 if devices is None else len(devices)

    @property
    @override
    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
        """Sampler settings: one replica per process across all nodes, indexed by this process's global rank."""
        world_size = self.num_nodes * self.num_processes
        return dict(num_replicas=world_size, rank=self.global_rank)

    @property
    def process_group_backend(self) -> Optional[str]:
        """The process-group backend: the user-supplied value until ``setup_distributed`` resolves it."""
        return self._process_group_backend

    @override
    def _configure_launcher(self) -> None:
        env = self.cluster_environment
        assert env is not None
        if self._start_method != "popen":
            # "spawn" / "fork" / "forkserver" go through the multiprocessing launcher
            self._launcher = _MultiProcessingLauncher(self, start_method=self._start_method)
            return
        self._launcher = _SubprocessScriptLauncher(env, self.num_processes, self.num_nodes)

    @override
    def setup(self, trainer: "pl.Trainer") -> None:
        """Prepare accelerator, model, precision plugin and (when fitting) optimizers for ``trainer``."""
        accelerator = self.accelerator
        assert accelerator is not None
        accelerator.setup(trainer)

        is_fitting = trainer.state.fn == TrainerFn.FITTING
        assert self.model is not None
        if is_fitting and self._layer_sync:
            self.model = self._layer_sync.apply(self.model)

        self.precision_plugin.convert_module(self.model)
        self.model_to_device()

        if not is_fitting:
            # no DDP wrapper outside of fitting (no gradients to reduce), so sync module states by hand
            _sync_module_states(self.model)
            self.setup_precision_plugin()
            return

        self.configure_ddp()
        # optimizers are created only after the wrapped module has been moved to its device
        self.setup_optimizers(trainer)
        self.setup_precision_plugin()
        _optimizers_to_device(self.optimizers, self.root_device)

        import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD

        if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
            self._enable_model_averaging()

    @override
    def _setup_model(self, model: Module) -> DistributedDataParallel:
        """Wrap ``model`` in :class:`~torch.nn.parallel.distributed.DistributedDataParallel`.

        When device ids are available, the wrapper is built on a fresh CUDA stream
        (see https://pytorch.org/docs/stable/notes/cuda.html#id5).

        """
        device_ids = self.determine_ddp_device_ids()
        log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
        if device_ids is None:
            stream_ctx = nullcontext()
        else:
            stream_ctx = torch.cuda.stream(torch.cuda.Stream())
        with stream_ctx:
            return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)

    def setup_distributed(self) -> None:
        """Reset the seed, assign world ranks, resolve the backend and initialize the process group."""
        log.debug(f"{self.__class__.__name__}: setting up distributed...")
        reset_seed()
        self.set_world_ranks()
        backend = self._get_process_group_backend()
        self._process_group_backend = backend
        env = self.cluster_environment
        assert env is not None
        _init_dist_connection(env, backend, timeout=self._timeout)

    def _get_process_group_backend(self) -> str:
        explicit = self._process_group_backend
        if explicit:
            return explicit
        # nothing (or an empty string) was configured: fall back to the default for the root device
        return _get_default_process_group_backend_for_device(self.root_device)

    def set_world_ranks(self) -> None:
        """Push this process's global rank and the world size to the cluster environment, then sync rank-zero state."""
        env = self.cluster_environment
        if env is not None:
            env.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
            env.set_world_size(self.num_nodes * self.num_processes)
        # `LightningEnvironment.set_global_rank` also does this, but that is an implementation detail, and for some
        # environments the setter is a no-op, so read the rank back through the getter instead
        global_rank = self.global_rank
        rank_zero_only.rank = global_rank
        utils_rank_zero_only.rank = global_rank

    def _register_ddp_hooks(self) -> None:
        log.debug(f"{self.__class__.__name__}: registering ddp hooks")
        # DDP communication hooks currently only work with the NCCL backend in SPSD (single process single device)
        # mode: https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
        if self.root_device.type != "cuda":
            return
        assert isinstance(self.model, DistributedDataParallel)
        _register_ddp_comm_hook(
            model=self.model,
            ddp_comm_state=self._ddp_comm_state,
            ddp_comm_hook=self._ddp_comm_hook,
            ddp_comm_wrapper=self._ddp_comm_wrapper,
        )

    def _enable_model_averaging(self) -> None:
        log.debug(f"{self.__class__.__name__}: reinitializing optimizers with post localSGD")
        period = self._model_averaging_period
        if period is None:
            raise ValueError(
                "Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy."
            )
        from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer

        # optimizer types that cannot be combined with model averaging; DistributedOptimizer is skipped on Windows
        incompatible = (ZeroRedundancyOptimizer, PostLocalSGDOptimizer)
        if not _IS_WINDOWS:
            incompatible += (DistributedOptimizer,)

        for opt in self.optimizers:
            inner = opt._optimizer if isinstance(opt, LightningOptimizer) else opt
            if isinstance(inner, incompatible):
                raise ValueError(
                    f"Currently model averaging cannot work with a distributed optimizer of type "
                    f"{inner.__class__.__name__}."
                )

        comm_state = self._ddp_comm_state
        assert comm_state is not None
        self._model_averager = torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager(
            period=period, warmup_steps=comm_state.start_localSGD_iter
        )

    @override
    def optimizer_step(
        self,
        optimizer: Optimizer,
        closure: Callable[[], Any],
        model: Optional[Union["pl.LightningModule", Module]] = None,
        **kwargs: Any,
    ) -> Any:
        """Run the optimizer step, then average parameters if post-local-SGD averaging is enabled.

        Args:
            optimizer: the optimizer performing the step
            closure: closure calculating the loss value
            model: reference to the model, optionally defining optimizer step related hooks
            **kwargs: Any extra arguments to ``optimizer.step``

        Returns:
            Whatever the parent strategy's ``optimizer_step`` returned.

        """
        output = super().optimizer_step(optimizer, closure, model, **kwargs)

        averager = self._model_averager
        if averager is not None:
            # only parameters that actually received a gradient take part in averaging
            with_grads = [p for g in optimizer.param_groups for p in g["params"] if p.grad is not None]
            averager.average_parameters(iter(with_grads))

        return output

    @override
    def pre_backward(self, closure_loss: Tensor) -> None:
        """Hook called before the precision plugin runs backward.

        With a DDP-wrapped model under manual optimization, ``prepare_for_backward`` is invoked on the loss.

        """
        model = self.model
        if isinstance(model, DistributedDataParallel):
            module = self.lightning_module
            assert module is not None
            if not module.automatic_optimization:
                prepare_for_backward(model, closure_loss)

    @override
    def model_to_device(self) -> None:
        """Move the current model onto this process's root device."""
        device = self.root_device
        log.debug(f"{self.__class__.__name__}: moving model to device [{device}]...")
        model = self.model
        assert model is not None
        model.to(device)

    @override
    def reduce(
        self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
    ) -> Tensor:
        """Reduce a tensor across distributed processes into one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' to calculate the sum during reduction.

        Return:
            The reduced value. Inputs that are not tensors are returned unchanged.

        """
        if not isinstance(tensor, Tensor):
            return tensor
        return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)

    @classmethod
    @override
    def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
        # Base alias -> start method. Every alias is registered once by itself and then once per
        # `find_unused_parameters` value (False first, then True).
        launch_aliases = (
            ("ddp", "popen"),
            ("ddp_spawn", "spawn"),
            ("ddp_fork", "fork"),
            ("ddp_notebook", "fork"),
        )
        for alias, method in launch_aliases:
            strategy_registry.register(
                alias,
                cls,
                description=f"DDP strategy with `start_method` '{method}'",
                start_method=method,
            )
        for alias, method in launch_aliases:
            for find_unused in (False, True):
                suffix = "true" if find_unused else "false"
                strategy_registry.register(
                    f"{alias}_find_unused_parameters_{suffix}",
                    cls,
                    description=(
                        f"DDP strategy with `find_unused_parameters` as {find_unused} and `start_method` '{method}'"
                    ),
                    find_unused_parameters=find_unused,
                    start_method=method,
                )

    @override
    def on_exception(self, exception: BaseException) -> None:
        # Attach a hint about unused parameters to DDP's "reduction not finished" error; the matching and
        # rewriting is delegated to `_augment_message`.
        hint = (
            "It looks like your LightningModule has parameters that were not used in producing the loss returned"
            " by training_step. If this is intentional, you must enable the detection of unused parameters in DDP,"
            " either by setting the string value `strategy='ddp_find_unused_parameters_true'`"
            " or by setting the flag in the strategy with `strategy=DDPStrategy(find_unused_parameters=True)`."
        )
        _augment_message(
            exception,
            pattern=".*Expected to have finished reduction in the prior iteration.*",
            new_message=hint,
        )

    @override
    def teardown(self) -> None:
        """Unwrap the DDP model, revert layer sync applied for fitting, then defer to the parent teardown."""
        log.debug(f"{self.__class__.__name__}: tearing down strategy")

        pl_module = self.lightning_module
        wrapped = self.model
        if isinstance(wrapped, DistributedDataParallel):
            if not wrapped.static_graph and wrapped._get_ddp_logging_data().get("can_set_static_graph"):
                rank_zero_info(
                    "Your model can run with static graph optimizations. For future training runs, we suggest you"
                    f" pass `Trainer(..., strategy={self.__class__.__name__}(static_graph=True))` to enable them."
                )
            # drop the DDP wrapper and go back to the plain LightningModule
            self.model = pl_module

        # `_trainer` may still be None if teardown runs because of an exception raised before the trainer
        # was attached to the LightningModule
        trainer = None if pl_module is None else pl_module._trainer
        was_fitting = trainer is not None and trainer.state.fn == TrainerFn.FITTING
        if was_fitting and self._layer_sync:
            assert self.model is not None
            self.model = self._layer_sync.revert(self.model)

        super().teardown()
class _DDPForwardRedirection(_ForwardRedirection):
    """Forward redirection that toggles DDP's ``require_backward_grad_sync`` under manual optimization.

    The flag is switched off after the inner forward and back on after the outer forward. This only happens
    when the wrapper is a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` and the LightningModule
    uses manual optimization.

    """

    @staticmethod
    def _set_grad_sync(wrapper_module: Module, original_module: "pl.LightningModule", enabled: bool) -> None:
        # Leave anything that is not DDP, or that uses automatic optimization, untouched.
        if not isinstance(wrapper_module, DistributedDataParallel):
            return
        if original_module.automatic_optimization:
            return
        wrapper_module.require_backward_grad_sync = enabled

    @override
    def on_after_inner_forward(self, wrapper_module: Module, original_module: "pl.LightningModule") -> None:
        # With manual optimization the DDP reducer must be held back here, because
        # synchronization is done manually in `LightningModule.manual_backward`.
        self._set_grad_sync(wrapper_module, original_module, False)

    @override
    def on_after_outer_forward(self, wrapper_module: Module, original_module: "pl.LightningModule") -> None:
        # Turn gradient sync back on once the outer forward has finished.
        self._set_grad_sync(wrapper_module, original_module, True)
To analyze traffic and optimize your experience, we serve cookies on this
site. By clicking or navigating, you agree to allow our usage of cookies.
Read PyTorch Lightning's
Privacy Policy.
You are viewing an outdated version of PyTorch Lightning Docs