Source code for pytorch_lightning.strategies.ddp_spawn
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from datetime import timedelta
from typing import Any, Callable, Dict, List, Optional, Union

import torch
import torch.distributed
from torch import Tensor
from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel
from typing_extensions import Literal

import pytorch_lightning as pl
from lightning_fabric.plugins import CheckpointIO, ClusterEnvironment
from lightning_fabric.plugins.collectives.torch_collective import default_pg_timeout
from lightning_fabric.utilities.distributed import (
    _distributed_available,
    _get_default_process_group_backend_for_device,
    _init_dist_connection,
    _sync_ddp_if_available,
)
from lightning_fabric.utilities.distributed import group as _group
from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_11
from lightning_fabric.utilities.optimizer import _optimizers_to_device
from lightning_fabric.utilities.types import ReduceOp
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep

log = logging.getLogger(__name__)

_DDP_FORK_ALIASES = (
    "ddp_fork",
    "ddp_fork_find_unused_parameters_false",
    "ddp_notebook",
    "ddp_notebook_find_unused_parameters_false",
)
class DDPSpawnStrategy(ParallelStrategy):
    """Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes after training
    finishes."""

    strategy_name = "ddp_spawn"

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.Accelerator"] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        ddp_comm_state: Optional[object] = None,
        ddp_comm_hook: Optional[Callable] = None,
        ddp_comm_wrapper: Optional[Callable] = None,
        process_group_backend: Optional[str] = None,
        timeout: Optional[timedelta] = default_pg_timeout,
        start_method: Literal["spawn", "fork", "forkserver"] = "spawn",
        **kwargs: Any,
    ):
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )
        self._num_nodes = 1
        self._ddp_kwargs = kwargs
        self._ddp_comm_state = ddp_comm_state
        self._ddp_comm_hook = ddp_comm_hook
        self._ddp_comm_wrapper = ddp_comm_wrapper
        self._local_rank = 0
        self._process_group_backend: Optional[str] = process_group_backend
        self._timeout: Optional[timedelta] = timeout
        self._start_method = start_method

    @property
    def num_nodes(self) -> int:
        return self._num_nodes

    @num_nodes.setter
    def num_nodes(self, num_nodes: int) -> None:
        # note that the world ranks depend on num_nodes; when resetting it, the world ranks need to be reset as well
        self._num_nodes = num_nodes

    @property
    def local_rank(self) -> int:
        return self._local_rank

    @property
    def root_device(self) -> torch.device:
        assert self.parallel_devices is not None
        return self.parallel_devices[self.local_rank]

    @property
    def num_processes(self) -> int:
        return len(self.parallel_devices) if self.parallel_devices is not None else 0

    @property
    def distributed_sampler_kwargs(self) -> Dict[str, int]:
        distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
        return distributed_sampler_kwargs

    @property
    def _is_single_process_single_device(self) -> bool:
        return True

    @property
    def process_group_backend(self) -> Optional[str]:
        return self._process_group_backend

    def _configure_launcher(self) -> None:
        self._launcher = _MultiProcessingLauncher(self, start_method=self._start_method)
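For context (not part of the module source), a minimal usage sketch: the strategy is usually selected through the ``Trainer`` by its registered name. The accelerator and device count below are assumptions for a single-node, two-GPU run.

from pytorch_lightning import Trainer

# spawns one process per device and joins them after `trainer.fit(...)` returns
trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp_spawn")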
    def setup(self, trainer: "pl.Trainer") -> None:
        assert self.cluster_environment is not None
        os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)

        assert self.accelerator is not None
        self.accelerator.setup(trainer)

        # move the model to the correct device
        self.model_to_device()

        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
        trainer_fn = trainer.state.fn

        if trainer_fn == TrainerFn.FITTING:
            if self._layer_sync:
                assert self.model is not None
                self.model = self._layer_sync.apply(self.model)

        self.setup_precision_plugin()

        if trainer_fn == TrainerFn.FITTING:
            self.configure_ddp()
    def _setup_model(self, model: Module) -> DistributedDataParallel:
        """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
        return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)

    def setup_distributed(self) -> None:
        log.detail(f"{self.__class__.__name__}: setting up distributed...")
        self.set_world_ranks()
        rank_zero_only.rank = self.global_rank
        self._process_group_backend = self._get_process_group_backend()
        assert self.cluster_environment is not None
        _init_dist_connection(
            self.cluster_environment,
            self._process_group_backend,
            self.global_rank,
            self.world_size,
            timeout=self._timeout,
        )

    def set_world_ranks(self) -> None:
        if self.cluster_environment is None:
            return
        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
        rank_zero_only.rank = self.cluster_environment.global_rank()

    def _get_process_group_backend(self) -> str:
        return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

    def pre_configure_ddp(self) -> None:
        # if unset, default `find_unused_parameters` to `True`
        # Many models require setting this parameter to True, as there are corner cases
        # where not all parameter backward hooks are fired by the autograd engine even if `requires_grad` is True.
        # This flag comes with a performance hit, so it is suggested to disable it where possible.
        self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True)

    def _register_ddp_hooks(self) -> None:
        # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
        # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
        if self.root_device.type == "cuda" and self._is_single_process_single_device:
            assert isinstance(self.model, DistributedDataParallel)
            register_ddp_comm_hook(
                model=self.model,
                ddp_comm_state=self._ddp_comm_state,
                ddp_comm_hook=self._ddp_comm_hook,
                ddp_comm_wrapper=self._ddp_comm_wrapper,
            )

    def configure_ddp(self) -> None:
        self.pre_configure_ddp()
        assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
        self.model = self._setup_model(LightningDistributedModule(self.model))
        self._register_ddp_hooks()

        # set up optimizers after the wrapped module has been moved to the device
        assert self.lightning_module is not None
        self.setup_optimizers(self.lightning_module.trainer)
        _optimizers_to_device(self.optimizers, self.root_device)

    def determine_ddp_device_ids(self) -> Optional[List[int]]:
        if self.root_device.type == "cpu":
            return None
        return [self.root_device.index]
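Because the extra keyword arguments accepted by ``__init__`` are forwarded to ``DistributedDataParallel`` via ``self._ddp_kwargs``, the ``find_unused_parameters`` default set in ``pre_configure_ddp`` can be overridden when the strategy is constructed. A minimal sketch, assuming a model in which every parameter receives gradients:

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPSpawnStrategy

# disabling unused-parameter detection avoids its per-iteration overhead
strategy = DDPSpawnStrategy(find_unused_parameters=False)
trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy)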
    def model_to_device(self) -> None:
        if self.root_device.type == "cuda":
            # set the device on the spawned subprocesses
            torch.cuda.set_device(self.root_device)
        assert self.model is not None
        self.model.to(self.root_device)
    def pre_backward(self, closure_loss: Tensor) -> None:
        """Run before precision plugin executes backward."""
        if not isinstance(self.model, DistributedDataParallel):
            return
        assert self.lightning_module is not None
        if not self.lightning_module.automatic_optimization:
            prepare_for_backward(self.model, closure_loss)
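``prepare_for_backward`` is only needed under manual optimization, where the user code drives the backward pass itself. A brief sketch of such a LightningModule (the layer sizes and loss are assumptions, not taken from this module):

import torch
import pytorch_lightning as pl


class ManualOptimModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # enables the manual-optimization branch in `pre_backward`
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.layer(batch).sum()
        self.manual_backward(loss)  # the strategy's `pre_backward` hook runs before the precision plugin's backward
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)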
    def reduce(
        self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
    ) -> Tensor:
        """Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' to calculate the sum during reduction.

        Return:
            the reduced value, except when the input was not a tensor, in which case it is returned unchanged
        """
        if isinstance(tensor, Tensor):
            tensor = _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
        return tensor
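As an illustrative sketch (the hook, metric name, and value are assumptions), ``reduce`` can be reached from user code through ``trainer.strategy`` to aggregate a per-process metric across all spawned workers:

import torch
import pytorch_lightning as pl


class MyModule(pl.LightningModule):
    def on_validation_epoch_end(self) -> None:
        # hypothetical per-process metric, averaged over all processes in the world group
        local_accuracy = torch.tensor(0.93, device=self.device)
        mean_accuracy = self.trainer.strategy.reduce(local_accuracy, reduce_op="mean")
        self.log("avg_val_acc", mean_accuracy)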
    def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        with self.precision_plugin.val_step_context():
            assert self.lightning_module is not None
            assert self.model is not None
            if self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
                # used when calling `trainer.fit`
                return self.model(*args, **kwargs)
            else:
                # used when calling `trainer.validate`
                assert isinstance(self.model, ValidationStep)
                return self.model.validation_step(*args, **kwargs)
    def post_training_step(self) -> None:
        assert self.lightning_module is not None
        if not self.lightning_module.automatic_optimization:
            assert self.model is not None
            self.model.require_backward_grad_sync = True  # type: ignore[assignment]

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        entries = (
            ("ddp_spawn", "spawn"),
            ("ddp_fork", "fork"),
            ("ddp_notebook", "fork"),
        )
        for name, start_method in entries:
            strategy_registry.register(
                name,
                cls,
                description=f"DDP strategy with `start_method` '{start_method}'",
                start_method=start_method,
            )

        entries = (
            ("ddp_spawn_find_unused_parameters_false", "spawn"),
            ("ddp_fork_find_unused_parameters_false", "fork"),
            ("ddp_notebook_find_unused_parameters_false", "fork"),
        )
        for name, start_method in entries:
            strategy_registry.register(
                name,
                cls,
                description=f"DDP strategy with `find_unused_parameters` as False and `start_method` '{start_method}'",
                find_unused_parameters=False,
                start_method=start_method,
            )
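The registered names above can be passed directly as the ``strategy`` string. A short sketch, assuming an interactive (e.g. Jupyter) session with two GPUs:

from pytorch_lightning import Trainer

# fork-based variant registered for notebook environments, where the "spawn" start method is typically unavailable
trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp_notebook")

# same, but with unused-parameter detection disabled
trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp_notebook_find_unused_parameters_false")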
    def teardown(self) -> None:
        log.detail(f"{self.__class__.__name__}: tearing down strategy")

        pl_module = self.lightning_module
        if isinstance(self.model, DistributedDataParallel):
            if (
                _TORCH_GREATER_EQUAL_1_11
                and not self.model.static_graph
                and self.model._get_ddp_logging_data().get("can_set_static_graph")
            ):
                rank_zero_info(
                    "Your model can run with static graph optimizations. For future training runs, we suggest you"
                    f" pass `Trainer(..., strategy={self.__class__.__name__}(static_graph=True))` to enable them."
                )
            # unwrap model
            self.model = pl_module

        if (
            pl_module is not None
            # `self.lightning_module._trainer` can be None if teardown gets called on an exception before
            # the trainer gets set on the LightningModule
            and pl_module._trainer is not None
            and pl_module._trainer.state.fn == TrainerFn.FITTING
            and self._layer_sync
        ):
            assert self.model is not None
            self.model = self._layer_sync.revert(self.model)

        super().teardown()
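The static-graph hint emitted in ``teardown`` can be acted on by constructing the strategy with the corresponding DDP keyword argument, which is forwarded through ``**kwargs`` / ``self._ddp_kwargs``; a brief sketch (device count is an assumption):

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPSpawnStrategy

# enable DDP's static-graph optimizations on the next run, as suggested by the teardown message
trainer = Trainer(accelerator="gpu", devices=2, strategy=DDPSpawnStrategy(static_graph=True))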