Source code for pytorch_lightning.strategies.ddp_spawn
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from typing import Any, Dict, List, Optional, Union

import torch
import torch.distributed
from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel

import pytorch_lightning as pl
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.launchers.spawn import _SpawnLauncher
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.distributed import (
    _get_process_group_backend_from_env,
    distributed_available,
    get_default_process_group_backend_for_device,
)
from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.distributed import (
    init_dist_connection,
    ReduceOp,
    register_ddp_comm_hook,
    sync_ddp_if_available,
)
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT

log = logging.getLogger(__name__)
class DDPSpawnStrategy(ParallelStrategy):
    """Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes after training
    finishes."""

    strategy_name = "ddp_spawn"

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        ddp_comm_state: Optional[object] = None,
        ddp_comm_hook: Optional[callable] = None,
        ddp_comm_wrapper: Optional[callable] = None,
        process_group_backend: Optional[str] = None,
        **kwargs: Any,
    ):
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )
        self._num_nodes = 1
        self._ddp_kwargs = kwargs
        self._ddp_comm_state = ddp_comm_state
        self._ddp_comm_hook = ddp_comm_hook
        self._ddp_comm_wrapper = ddp_comm_wrapper
        self._local_rank = 0
        self._process_group_backend: Optional[str] = process_group_backend

    @property
    def num_nodes(self) -> int:
        return self._num_nodes

    @num_nodes.setter
    def num_nodes(self, num_nodes: int) -> None:
        # note that world ranks is related to num_nodes, when resetting it, need to reset world ranks
        self._num_nodes = num_nodes

    @property
    def local_rank(self) -> int:
        return self._local_rank

    @property
    def root_device(self):
        return self.parallel_devices[self.local_rank]

    @property
    def num_processes(self):
        return len(self.parallel_devices) if self.parallel_devices is not None else 0

    @property
    def distributed_sampler_kwargs(self):
        distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
        return distributed_sampler_kwargs

    @property
    def _is_single_process_single_device(self):
        return True

    @property
    def process_group_backend(self) -> Optional[str]:
        return self._process_group_backend

    def _configure_launcher(self):
        self._launcher = _SpawnLauncher(self)
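Note (not part of the module source): a minimal sketch of how this strategy is typically selected, assuming a PyTorch Lightning version where ``DDPSpawnStrategy`` is importable from ``pytorch_lightning.strategies``. Any extra keyword arguments (e.g. ``find_unused_parameters``) are stored in ``self._ddp_kwargs`` above and forwarded to ``DistributedDataParallel`` later in ``_setup_model``.

    import pytorch_lightning as pl
    from pytorch_lightning.strategies import DDPSpawnStrategy

    # select the strategy by its registered name ...
    trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp_spawn")

    # ... or instantiate it directly; kwargs not consumed by __init__ are forwarded to DDP
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=2,
        strategy=DDPSpawnStrategy(process_group_backend="nccl", find_unused_parameters=False),
    )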
    def setup(self, trainer: "pl.Trainer") -> None:
        os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
        super().setup(trainer)

        # move the model to the correct device
        self.model_to_device()

        trainer_fn = self.lightning_module.trainer.state.fn
        if trainer_fn != TrainerFn.FITTING:
            return

        if self._layer_sync:
            self.model = self._layer_sync.apply(self.model)

        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
        self.configure_ddp()
    def _setup_model(self, model: Module) -> DistributedDataParallel:
        """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
        return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)

    def set_world_ranks(self, process_idx: int = 0) -> None:
        self._local_rank = process_idx
        if self.cluster_environment is None:
            return
        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
        rank_zero_only.rank = self.cluster_environment.global_rank()

    def _worker_setup(self, process_idx: int):
        reset_seed()
        self.set_world_ranks(process_idx)
        rank_zero_only.rank = self.global_rank
        self._process_group_backend = self._get_process_group_backend()
        init_dist_connection(
            self.cluster_environment, self._process_group_backend, self.global_rank, self.world_size
        )

    def _get_process_group_backend(self) -> str:
        return (
            self._process_group_backend
            or _get_process_group_backend_from_env()
            or get_default_process_group_backend_for_device(self.root_device)
        )

    def pre_configure_ddp(self):
        # if unset, default `find_unused_parameters` `True`
        # Many models require setting this parameter to True, as there are corner cases
        # when not all parameter backward hooks are fired by the autograd engine even if require_grad is set to True.
        # This flag does come with a performance hit, so it is suggested to disable in cases where it is possible.
        self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True)

    def _register_ddp_hooks(self) -> None:
        # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
        # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
        if self.root_device.type == "cuda" and self._is_single_process_single_device:
            register_ddp_comm_hook(
                model=self.model,
                ddp_comm_state=self._ddp_comm_state,
                ddp_comm_hook=self._ddp_comm_hook,
                ddp_comm_wrapper=self._ddp_comm_wrapper,
            )

    def configure_ddp(self) -> None:
        self.pre_configure_ddp()
        self.model = self._setup_model(LightningDistributedModule(self.model))
        self._register_ddp_hooks()

    def determine_ddp_device_ids(self):
        if self.root_device.type == "cpu":
            return None
        return [self.root_device.index]
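Note (not part of the module source): the ``ddp_comm_hook`` constructor argument is forwarded to ``register_ddp_comm_hook`` in ``_register_ddp_hooks`` above, which only registers it on CUDA devices. As a hedged illustration, assuming a CUDA/NCCL setup, one of PyTorch's built-in gradient compression hooks could be attached like this:

    import pytorch_lightning as pl
    from torch.distributed.algorithms.ddp_comm_hooks import default_hooks
    from pytorch_lightning.strategies import DDPSpawnStrategy

    # compress gradients to fp16 during all-reduce
    strategy = DDPSpawnStrategy(ddp_comm_hook=default_hooks.fp16_compress_hook)
    trainer = pl.Trainer(accelerator="gpu", devices=2, strategy=strategy)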
    def model_to_device(self):
        if self.root_device.type == "cuda":
            # set the device on the spawned subprocesses
            torch.cuda.set_device(self.root_device)
        self.model.to(self.root_device)
    def pre_backward(self, closure_loss: torch.Tensor) -> None:
        """Run before precision plugin executes backward."""
        if not self.lightning_module.automatic_optimization:
            prepare_for_backward(self.model, closure_loss)
    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> torch.Tensor:
        """Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' to calculate the sum during reduction.

        Return:
            reduced value, except when the input was not a tensor the output remains unchanged
        """
        if isinstance(tensor, torch.Tensor):
            tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
        return tensor
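Note (not part of the module source): a hedged sketch of calling ``reduce`` through the active strategy from user code; ``MyModel`` and ``compute_loss`` are hypothetical names used only for illustration.

    import pytorch_lightning as pl

    class MyModel(pl.LightningModule):
        def validation_step(self, batch, batch_idx):
            loss = self.compute_loss(batch)  # hypothetical helper
            # average the per-process loss across all spawned processes ("sum" is also accepted)
            mean_loss = self.trainer.strategy.reduce(loss, reduce_op="mean")
            return mean_loss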
    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
        with self.precision_plugin.val_step_context():
            if isinstance(self.model, DistributedDataParallel):
                # used when calling `trainer.fit`
                return self.model(*args, **kwargs)
            else:
                # used when calling `trainer.validate`
                return self.lightning_module.validation_step(*args, **kwargs)
    def post_training_step(self):
        if not self.lightning_module.automatic_optimization:
            self.model.require_backward_grad_sync = True

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        strategy_registry.register(
            "ddp_spawn_find_unused_parameters_false",
            cls,
            description="DDPSpawn Strategy with `find_unused_parameters` as False",
            find_unused_parameters=False,
        )
        strategy_registry.register(
            cls.strategy_name,
            cls,
            description=f"{cls.__class__.__name__}",
        )
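Note (not part of the module source): both names registered above can be passed directly as the Trainer's ``strategy`` argument; the accelerator/devices values below are illustrative only.

    import pytorch_lightning as pl

    # equivalent to DDPSpawnStrategy() with default arguments
    trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp_spawn")

    # equivalent to DDPSpawnStrategy(find_unused_parameters=False), skipping the
    # unused-parameter search whose cost is noted in pre_configure_ddp above
    trainer = pl.Trainer(
        accelerator="gpu", devices=2, strategy="ddp_spawn_find_unused_parameters_false"
    )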
    def teardown(self) -> None:
        log.detail(f"{self.__class__.__name__}: tearing down strategy")
        super().teardown()

        if isinstance(self.model, DistributedDataParallel):
            if (
                _TORCH_GREATER_EQUAL_1_11
                and not self.model.static_graph
                and self.model._get_ddp_logging_data().get("can_set_static_graph")
            ):
                rank_zero_info(
                    "Your model can run with static graph optimizations. For future training runs, we suggest you"
                    f" pass `Trainer(..., strategy={self.__class__.__name__}(static_graph=True))` to enable them."
                )
            # unwrap model
            self.model = self.lightning_module

        if (
            self.lightning_module.trainer is not None
            and self.lightning_module.trainer.state.fn == TrainerFn.FITTING
            and self._layer_sync
        ):
            # `self.lightning_module.trainer` can be None if teardown gets called on an exception before
            # the trainer gets set on the LightningModule
            self.model = self._layer_sync.revert(self.model)

        if self.root_device.type == "cuda":
            # GPU teardown
            log.detail(f"{self.__class__.__name__}: moving model to CPU")
            self.lightning_module.cpu()
            # clean up memory
            torch.cuda.empty_cache()