# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import shutil
import signal
import tempfile
import time
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union

import torch
import torch.distributed
from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim.optimizer import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.distributed import (
    _get_process_group_backend_from_env,
    distributed_available,
    get_default_process_group_backend_for_device,
)
from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.distributed import (
    init_dist_connection,
    ReduceOp,
    register_ddp_comm_hook,
    sync_ddp_if_available,
)
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException
from pytorch_lightning.utilities.imports import (
    _FAIRSCALE_AVAILABLE,
    _IS_WINDOWS,
    _TORCH_GREATER_EQUAL_1_9,
    _TORCH_GREATER_EQUAL_1_10,
    _TORCH_GREATER_EQUAL_1_11,
)
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT

if _FAIRSCALE_AVAILABLE:
    from fairscale.optim import OSS
if _TORCH_GREATER_EQUAL_1_10 and torch.distributed.is_available():
    from torch.distributed.algorithms.model_averaging.averagers import ModelAverager

log = logging.getLogger(__name__)

class DDPStrategy(ParallelStrategy):
    """Strategy for multi-process single-device training on one or multiple nodes."""

    strategy_name = "ddp"

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        ddp_comm_state: Optional[object] = None,
        ddp_comm_hook: Optional[callable] = None,
        ddp_comm_wrapper: Optional[callable] = None,
        model_averaging_period: Optional[int] = None,
        process_group_backend: Optional[str] = None,
        **kwargs: Union[Any, Dict[str, Any]],
    ) -> None:
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )
        log.detail(f"{self.__class__.__name__}: initializing DDP plugin")
        self._num_nodes = 1
        self._ddp_kwargs = kwargs
        self._ddp_comm_state = ddp_comm_state
        self._ddp_comm_hook = ddp_comm_hook
        self._ddp_comm_wrapper = ddp_comm_wrapper
        self._model_averaging_period = model_averaging_period
        self._model_averager: Optional[ModelAverager] = None
        self._pids: Optional[List[int]] = None
        self._sync_dir: Optional[str] = None
        self._rank_0_will_call_children_scripts: bool = False
        self._process_group_backend: Optional[str] = process_group_backend

    @property
    def is_distributed(self) -> bool:
        return True

    @property
    def root_device(self) -> torch.device:
        return self.parallel_devices[self.local_rank]

    @property
    def num_nodes(self) -> int:
        return self._num_nodes

    @num_nodes.setter
    def num_nodes(self, num_nodes: int) -> None:
        # note that the world ranks depend on num_nodes; when resetting it, the world ranks need to be reset too
        self._num_nodes = num_nodes

    @property
    def num_processes(self):
        return len(self.parallel_devices) if self.parallel_devices is not None else 0

    @property
    def distributed_sampler_kwargs(self):
        distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
        return distributed_sampler_kwargs

    @property
    def _is_single_process_single_device(self) -> bool:
        return True

    @property
    def process_group_backend(self) -> Optional[str]:
        return self._process_group_backend

    def _configure_launcher(self) -> None:
        if not self.cluster_environment.creates_processes_externally:
            self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
            self._rank_0_will_call_children_scripts = True

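    # A minimal usage sketch, not part of the strategy itself: the strategy is usually passed to the
    # Trainer rather than called directly. The 2-GPU setup and the `MyLightningModule` class below are
    # hypothetical; extra keyword arguments (``**kwargs``) are forwarded to ``DistributedDataParallel``
    # in ``_setup_model``.
    #
    #     import pytorch_lightning as pl
    #     from pytorch_lightning.strategies import DDPStrategy
    #
    #     trainer = pl.Trainer(
    #         accelerator="gpu", devices=2, strategy=DDPStrategy(find_unused_parameters=False)
    #     )
    #     trainer.fit(MyLightningModule())
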
    def setup(self, trainer: "pl.Trainer") -> None:
        super().setup(trainer)

        # share ddp pids to all processes
        self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
        if self._should_run_deadlock_detection():
            self._share_information_to_prevent_deadlock()

        # move the model to the correct device
        self.model_to_device()

        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
        trainer_fn = trainer.state.fn
        if trainer_fn != TrainerFn.FITTING:
            return

        if self._layer_sync:
            self.model = self._layer_sync.apply(self.model)

        self.configure_ddp()

    def _setup_model(self, model: Module) -> DistributedDataParallel:
        """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
        device_ids = self.determine_ddp_device_ids()
        log.detail(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
        return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)

    def setup_distributed(self):
        log.detail(f"{self.__class__.__name__}: setting up distributed...")
        reset_seed()

        # determine which process we are and world size
        self.set_world_ranks()

        # set warning rank
        rank_zero_only.rank = self.global_rank

        self._process_group_backend = self._get_process_group_backend()
        init_dist_connection(self.cluster_environment, self._process_group_backend)

    def _get_process_group_backend(self) -> str:
        return (
            self._process_group_backend
            or _get_process_group_backend_from_env()
            or get_default_process_group_backend_for_device(self.root_device)
        )

    def set_world_ranks(self) -> None:
        if self.cluster_environment is None:
            return
        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
        rank_zero_only.rank = self.cluster_environment.global_rank()

    def pre_configure_ddp(self):
        # if unset, default `find_unused_parameters` to `True`
        # Many models require setting this parameter to True, as there are corner cases
        # when not all parameter backward hooks are fired by the autograd engine even if require_grad is set to True.
        # This flag does come with a performance hit, so it is suggested to disable it in cases where it is possible.
        self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True)

    def _register_ddp_hooks(self) -> None:
        log.detail(f"{self.__class__.__name__}: registering ddp hooks")
        # In 1.8, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
        # Since 1.9, DDP communication hooks can work on all backends.
        if _TORCH_GREATER_EQUAL_1_9 or (self.root_device.type == "cuda" and self._is_single_process_single_device):
            register_ddp_comm_hook(
                model=self.model,
                ddp_comm_state=self._ddp_comm_state,
                ddp_comm_hook=self._ddp_comm_hook,
                ddp_comm_wrapper=self._ddp_comm_wrapper,
            )

            if _TORCH_GREATER_EQUAL_1_10 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
                import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD

                if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
                    self._enable_model_averaging()

    def _enable_model_averaging(self) -> None:
        # Only called when PyTorch version >= 1.10
        log.detail(f"{self.__class__.__name__}: reinitializing optimizers with post localSGD")
        if self._model_averaging_period is None:
            raise ValueError(
                "Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy."
            )
        from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer

        for optimizer in self.optimizers:
            if isinstance(optimizer, LightningOptimizer):
                optimizer = optimizer._optimizer

            is_distributed_optimizer = isinstance(optimizer, DistributedOptimizer) if not _IS_WINDOWS else False
            if (
                is_distributed_optimizer
                or isinstance(optimizer, ZeroRedundancyOptimizer)
                or (_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS))
                or isinstance(optimizer, PostLocalSGDOptimizer)
            ):
                raise ValueError(
                    f"Currently model averaging cannot work with a distributed optimizer of type "
                    f"{optimizer.__class__.__name__}."
                )

        self._model_averager = torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager(
            period=self._model_averaging_period, warmup_steps=self._ddp_comm_state.start_localSGD_iter
        )

    def optimizer_step(
        self,
        optimizer: Optimizer,
        opt_idx: int,
        closure: Callable[[], Any],
        model: Optional[Union["pl.LightningModule", Module]] = None,
        **kwargs: Any,
    ) -> Any:
        """Performs the actual optimizer step.

        Args:
            optimizer: the optimizer performing the step
            opt_idx: index of the current optimizer
            closure: closure calculating the loss value
            model: reference to the model, optionally defining optimizer step related hooks
            **kwargs: Any extra arguments to ``optimizer.step``
        """
        optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs)

        if not _TORCH_GREATER_EQUAL_1_10 or self._model_averager is None:
            return optimizer_output

        params = [param for group in optimizer.param_groups for param in group["params"] if param.grad is not None]
        self._model_averager.average_parameters(iter(params))

        return optimizer_output

    def pre_backward(self, closure_loss: torch.Tensor) -> None:
        """Run before precision plugin executes backward."""
        if not self.lightning_module.automatic_optimization:
            prepare_for_backward(self.model, closure_loss)

    def model_to_device(self):
        log.detail(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
        self.model.to(self.root_device)

    def reduce(
        self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean"
    ) -> torch.Tensor:
        """Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' to calculate the sum during reduction.

        Return:
            reduced value, except when the input was not a tensor the output remains unchanged
        """
        if isinstance(tensor, torch.Tensor):
            tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
        return tensor

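    # A small worked example of ``reduce`` semantics (values are illustrative): with two processes
    # holding ``torch.tensor(1.0)`` and ``torch.tensor(3.0)``, ``reduce_op="mean"`` returns 2.0 on
    # every rank and ``reduce_op="sum"`` returns 4.0, while a non-tensor input is returned unchanged.
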
    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
        with self.precision_plugin.val_step_context():
            if isinstance(self.model, DistributedDataParallel):
                # used when calling `trainer.fit`
                return self.model(*args, **kwargs)
            else:
                # used when calling `trainer.validate`
                return self.lightning_module.validation_step(*args, **kwargs)

    def post_training_step(self):
        if not self.lightning_module.automatic_optimization:
            self.model.require_backward_grad_sync = True

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        strategy_registry.register(
            "ddp_find_unused_parameters_false",
            cls,
            description="DDP Strategy with `find_unused_parameters` as False",
            find_unused_parameters=False,
        )
        strategy_registry.register(
            cls.strategy_name,
            cls,
            description=f"{cls.__class__.__name__}",
        )

    def _should_run_deadlock_detection(self) -> bool:
        """Determines whether the plugin will perform process reconciliation in case of errors.

        If the environment variable `PL_RECONCILE_PROCESS` is set, run detection regardless of the cluster environment.
        By default this is disabled. Otherwise, if the cluster environment creates the processes, allow the scheduler /
        parent process to perform the process termination, external to Lightning.
        """
        return os.getenv("PL_RECONCILE_PROCESS", "0") == "1" or self._rank_0_will_call_children_scripts

    def _share_information_to_prevent_deadlock(self) -> None:
        self._share_pids()

        # there should be a unique sync_dir per node.
        if self.local_rank == 0:
            # create a temporary directory used to synchronize processes on deadlock.
            self._sync_dir = tempfile.mkdtemp()

        sync_dirs = []
        global_node_rank_zero = 0
        for _ in range(self.num_nodes):
            sync_dirs.append(self.broadcast(self._sync_dir, global_node_rank_zero))
            global_node_rank_zero += self.world_size // self.num_nodes

        self._sync_dir = sync_dirs[self.node_rank]

    def _share_pids(self) -> None:
        """Make all DDP processes aware of all other processes' pids."""
        self.barrier()
        pids = self.all_gather(torch.tensor(os.getpid(), device=self.root_device))
        pids = pids.cpu().numpy().tolist()
        self._pids = pids if isinstance(pids, list) else [pids]

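    # ``register_strategies`` above makes the class selectable by registry name; a hedged usage
    # sketch (the 2-GPU setup is hypothetical):
    #
    #     trainer = pl.Trainer(strategy="ddp_find_unused_parameters_false", accelerator="gpu", devices=2)
    #
    # which resolves to ``DDPStrategy(find_unused_parameters=False)`` via the strategy registry.
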
    def reconciliate_processes(self, trace: str) -> None:
        if self.world_size < 2:
            return

        if not self._should_run_deadlock_detection():
            return

        sync_dir = self._sync_dir

        if not sync_dir:
            rank_zero_warn("Error handling mechanism for deadlock detection is uninitialized. Skipping check.")
            return

        # The cluster may be configured to periodically purge the `/tmp`
        # directory, in which case `sync_dir` may not exist anymore at this
        # point. Idempotently create it to ensure its existence.
        Path(sync_dir).mkdir(parents=True, exist_ok=True)

        # save a file locally.
        torch.save(True, os.path.join(sync_dir, f"{self.global_rank}.pl"))

        # sleep for a short time
        time.sleep(3)

        # return if all processes wrote a file in the `sync_dir`.
        # todo (tchaton) Add support for non-shared file-system which will fail.
        if len(os.listdir(sync_dir)) == (self.world_size // self.num_nodes):
            return

        for pid in self._pids:
            if pid != os.getpid():
                os.kill(pid, signal.SIGKILL)
        shutil.rmtree(sync_dir)
        raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank}\n{trace}")

    def teardown(self) -> None:
        log.detail(f"{self.__class__.__name__}: tearing down strategy")
        super().teardown()

        if isinstance(self.model, DistributedDataParallel):
            if (
                _TORCH_GREATER_EQUAL_1_11
                and not self.model.static_graph
                and self.model._get_ddp_logging_data().get("can_set_static_graph")
            ):
                rank_zero_info(
                    "Your model can run with static graph optimizations. For future training runs, we suggest you"
                    f" pass `Trainer(..., strategy={self.__class__.__name__}(static_graph=True))` to enable them."
                )
            # unwrap model
            self.model = self.lightning_module

        if (
            self.lightning_module.trainer is not None
            and self.lightning_module.trainer.state.fn == TrainerFn.FITTING
            and self._layer_sync
        ):
            # `self.lightning_module.trainer` can be None if teardown gets called on an exception before
            # the trainer gets set on the LightningModule
            self.model = self._layer_sync.revert(self.model)

        if self.root_device.type == "cuda":
            # GPU teardown
            log.detail(f"{self.__class__.__name__}: moving model to CPU")
            self.lightning_module.cpu()
            # clean up memory
            torch.cuda.empty_cache()

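# The hint emitted in ``teardown`` above refers to DDP's static-graph mode (PyTorch >= 1.11);
# a hedged sketch of acting on it, relying on extra kwargs being forwarded to
# ``DistributedDataParallel`` (the 2-GPU setup is hypothetical):
#
#     trainer = pl.Trainer(accelerator="gpu", devices=2, strategy=DDPStrategy(static_graph=True))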