Source code for pytorch_lightning.utilities.distributed
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities that can be used with distributed training."""

import logging
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch.nn.parallel.distributed import DistributedDataParallel

import pytorch_lightning as pl
from pytorch_lightning.utilities.imports import _HPU_AVAILABLE, _TORCH_GREATER_EQUAL_1_9, _TPU_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_debug as new_rank_zero_debug
from pytorch_lightning.utilities.rank_zero import rank_zero_only  # noqa: F401
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_info as new_rank_zero_info
from pytorch_lightning.utilities.rank_zero import rank_zero_warn as new_rank_zero_warn

if _TPU_AVAILABLE:
    import torch_xla.core.xla_model as xm

if torch.distributed.is_available():
    from torch.distributed import group, ReduceOp
else:

    class ReduceOp:  # type: ignore # (see https://github.com/python/mypy/issues/1153)
        SUM = None

    class group:  # type: ignore
        WORLD = None


log = logging.getLogger(__name__)
def gather_all_tensors(result: torch.Tensor, group: Optional[Any] = None) -> List[torch.Tensor]:
    """Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes.

    Args:
        result: the value to sync
        group: the process group to gather results from. Defaults to all processes (world)

    Return:
        gathered_result: list with size equal to the process group where
            gathered_result[i] corresponds to result tensor from process i
    """
    if group is None:
        group = torch.distributed.group.WORLD

    # convert tensors to contiguous format
    result = result.contiguous()

    world_size = torch.distributed.get_world_size(group)

    gathered_result = [torch.zeros_like(result) for _ in range(world_size)]

    # sync and broadcast all
    torch.distributed.barrier(group=group)
    torch.distributed.all_gather(gathered_result, result, group)

    return gathered_result
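# As a hedged illustration (not part of the module), ``gather_all_tensors`` can be used to
# collect a per-process value once the process group is initialized; the tensor below is
# hypothetical:
#
#     # Illustrative sketch only: assumes torch.distributed has been initialized
#     # (e.g. via init_dist_connection further down) and that `local_metric` lives
#     # on the device required by the backend (CUDA for NCCL).
#     local_metric = torch.tensor([0.5])            # hypothetical per-process value
#     gathered = gather_all_tensors(local_metric)   # list with one tensor per process
#     stacked = torch.stack(gathered)               # shape: (world_size, 1)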
def sync_ddp_if_available(
    result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
) -> torch.Tensor:
    """Function to reduce a tensor across worker processes during distributed training.

    Args:
        result: the value to sync and reduce (typically tensor or number)
        group: the process group to gather results from. Defaults to all processes (world)
        reduce_op: the reduction operation. Defaults to sum.
            Can also be a string of 'avg', 'mean' to calculate the mean during reduction.

    Return:
        reduced value
    """
    if distributed_available():
        return sync_ddp(result, group=group, reduce_op=reduce_op)
    return result
def sync_ddp(
    result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
) -> torch.Tensor:
    """Function to reduce the tensors from several ddp processes to one main process.

    Args:
        result: the value to sync and reduce (typically tensor or number)
        group: the process group to gather results from. Defaults to all processes (world)
        reduce_op: the reduction operation. Defaults to sum.
            Can also be a string of 'avg', 'mean' to calculate the mean during reduction.

    Return:
        reduced value
    """
    divide_by_world_size = False

    if group is None:
        group = torch.distributed.group.WORLD

    if isinstance(reduce_op, str):
        if reduce_op.lower() in ("avg", "mean"):
            op = ReduceOp.SUM
            divide_by_world_size = True
        else:
            op = getattr(ReduceOp, reduce_op.upper())
    else:
        op = reduce_op

    # WA for HPU. HPU doesn't support Long types, forcefully set it to float
    if _HPU_AVAILABLE:
        is_hpu_backend = os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1"
        if is_hpu_backend:
            if (result.type() == "torch.LongTensor") or (result.type() == "torch.hpu.LongTensor"):
                new_rank_zero_info("Long tensor unsupported on HPU, casting to float")
                result = result.float()

    # sync all processes before reduction
    torch.distributed.barrier(group=group)
    torch.distributed.all_reduce(result, op=op, group=group, async_op=False)

    if divide_by_world_size:
        result = result / torch.distributed.get_world_size(group)

    return result
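# As a hedged example of the ``reduce_op`` handling above: a string of ``"avg"`` or
# ``"mean"`` triggers a SUM all-reduce followed by division by the world size, so two
# ranks holding 1.0 and 3.0 both end up with 2.0:
#
#     # Illustrative sketch only (hypothetical values); requires an initialized process
#     # group and a tensor on the device expected by the backend (CUDA for NCCL).
#     loss = torch.tensor(1.0)                                   # e.g. 1.0 on rank 0, 3.0 on rank 1
#     mean_loss = sync_ddp_if_available(loss, reduce_op="mean")  # -> 2.0 on both ranks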
def all_gather_ddp_if_available(
    tensor: torch.Tensor, group: Optional["torch.distributed.ProcessGroup"] = None, sync_grads: bool = False
) -> torch.Tensor:
    """Function to gather a tensor from several distributed processes.

    Args:
        tensor: tensor of shape (batch, ...)
        group: the process group to gather results from. Defaults to all processes (world)
        sync_grads: flag that allows users to synchronize gradients for all_gather op

    Return:
        A tensor of shape (world_size, batch, ...)
    """
    group = group if group is not None else torch.distributed.group.WORLD
    if distributed_available():
        if sync_grads:
            return AllGatherGrad.apply(tensor, group)
        with torch.no_grad():
            return AllGatherGrad.apply(tensor, group)
    return tensor
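# A brief, hedged sketch of what ``sync_grads`` changes: with ``sync_grads=True`` the gather
# stays on the autograd graph, so gradients flow back to the local tensor on each rank; the
# tensors below are hypothetical:
#
#     # Illustrative sketch only; requires an initialized process group for the gather to happen.
#     features = torch.randn(8, 16, requires_grad=True)                      # (batch, feature)
#     all_features = all_gather_ddp_if_available(features, sync_grads=True)  # (world_size, 8, 16)
#     all_features.sum().backward()  # gradients propagate back to `features` on every rank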
def register_ddp_comm_hook(
    model: DistributedDataParallel,
    ddp_comm_state: Optional[object] = None,
    ddp_comm_hook: Optional[Callable] = None,
    ddp_comm_wrapper: Optional[Callable] = None,
) -> None:
    """Function to register communication hook for DDP model https://pytorch.org/docs/master/ddp_comm_hooks.html.

    Args:
        model:
            DDP model
        ddp_comm_state:
            state is passed to the hook and can be used to maintain and update any state information
            that users would like to maintain as part of the training process. Examples: error feedback
            in gradient compression, peers to communicate with next in GossipGrad etc.
        ddp_comm_hook:
            hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future

            This callable function is called once the bucket is ready. The hook can perform whatever
            processing is needed and return a Future indicating completion of any async work (ex: allreduce).
            If the hook doesn't perform any communication, it can also just return a completed Future.
            The Future should hold the new value of grad bucket's tensors. Once a bucket is ready,
            c10d reducer would call this hook and use the tensors returned by the Future and copy
            grads to individual parameters.
        ddp_comm_wrapper:
            communication hook wrapper to support a communication hook such as FP16 compression
            as wrapper, which could be combined with ddp_comm_hook

    .. warning ::
        DDP communication hook needs pytorch version at least 1.8.0

    .. warning ::
        DDP communication wrapper needs pytorch version at least 1.9.0
        Post-localSGD hook needs pytorch version at least 1.9.0

    Examples:

        >>> from torch.distributed.algorithms.ddp_comm_hooks import ( # doctest: +SKIP
        ...     default_hooks as default,
        ...     powerSGD_hook as powerSGD,
        ...     post_localSGD_hook as post_localSGD,
        ... )
        >>>
        >>> # fp16_compress_hook for compress gradients
        >>> ddp_model = ...
        >>> register_ddp_comm_hook( # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_hook=default.fp16_compress_hook,
        ... )
        >>>
        >>> # powerSGD_hook
        >>> ddp_model = ...
        >>> register_ddp_comm_hook( # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_state=powerSGD.PowerSGDState(
        ...         process_group=None,
        ...         matrix_approximation_rank=1,
        ...         start_powerSGD_iter=5000,
        ...     ),
        ...     ddp_comm_hook=powerSGD.powerSGD_hook,
        ... )
        >>>
        >>> # post_localSGD_hook
        >>> subgroup, _ = torch.distributed.new_subgroups() # doctest: +SKIP
        >>> ddp_model = ...
        >>> register_ddp_comm_hook( # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_state=post_localSGD.PostLocalSGDState(
        ...         process_group=None,
        ...         subgroup=subgroup,
        ...         start_localSGD_iter=1_000,
        ...     ),
        ...     ddp_comm_hook=post_localSGD.post_localSGD_hook,
        ... )
        >>>
        >>> # fp16_compress_wrapper combined with other communication hook
        >>> ddp_model = ...
        >>> register_ddp_comm_hook( # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_state=powerSGD.PowerSGDState(
        ...         process_group=None,
        ...         matrix_approximation_rank=1,
        ...         start_powerSGD_iter=5000,
        ...     ),
        ...     ddp_comm_hook=powerSGD.powerSGD_hook,
        ...     ddp_comm_wrapper=default.fp16_compress_wrapper,
        ... )
    """
    if ddp_comm_hook is None:
        return
    # inform mypy that ddp_comm_hook is callable
    ddp_comm_hook: Callable = ddp_comm_hook

    if ddp_comm_wrapper is not None:
        if not _TORCH_GREATER_EQUAL_1_9:
            new_rank_zero_warn(
                "Not applying DDP comm wrapper. To use communication wrapper, please use pytorch>=1.9.0."
            )
        else:
            new_rank_zero_info(
                f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."
            )
            ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)

    new_rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
    model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook)
def tpu_distributed() -> bool:
    return _TPU_AVAILABLE and xm.xrt_world_size() > 1


def get_default_process_group_backend_for_device(device: torch.device) -> str:
    return "nccl" if device.type == "cuda" else "gloo"


def _get_process_group_backend_from_env() -> Optional[str]:
    torch_backend = os.getenv("PL_TORCH_DISTRIBUTED_BACKEND")
    if torch_backend is not None:
        rank_zero_deprecation(
            "Environment variable `PL_TORCH_DISTRIBUTED_BACKEND`"
            " was deprecated in v1.6 and will be removed in v1.8."
            " Specify `process_group_backend` directly on the strategy constructor."
        )
    return torch_backend
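# For illustration (a sketch, not part of the module), the backend helper above simply maps
# CUDA devices to ``nccl`` and everything else to ``gloo``; constructing the device objects
# does not require a GPU to be present:
#
#     assert get_default_process_group_backend_for_device(torch.device("cuda", 0)) == "nccl"
#     assert get_default_process_group_backend_for_device(torch.device("cpu")) == "gloo"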
def init_dist_connection(
    cluster_environment: "pl.plugins.environments.ClusterEnvironment",
    torch_distributed_backend: str,
    global_rank: Optional[int] = None,
    world_size: Optional[int] = None,
    **kwargs: Any,
) -> None:
    """Utility function to initialize distributed connection by setting env variables and initializing the
    distributed process group.

    Args:
        cluster_environment: ``ClusterEnvironment`` instance
        torch_distributed_backend: backend to use (includes `nccl` and `gloo`)
        global_rank: rank of the current process
        world_size: number of processes in the group
        kwargs: kwargs for ``init_process_group``

    Raises:
        RuntimeError:
            If ``torch.distributed`` is not available
    """
    if not torch.distributed.is_available():
        raise RuntimeError("torch.distributed is not available. Cannot initialize distributed process group")
    if torch.distributed.is_initialized():
        log.debug("torch.distributed is already initialized. Exiting early")
        return
    global_rank = global_rank if global_rank is not None else cluster_environment.global_rank()
    world_size = world_size if world_size is not None else cluster_environment.world_size()
    os.environ["MASTER_ADDR"] = cluster_environment.main_address
    os.environ["MASTER_PORT"] = str(cluster_environment.main_port)
    log.info(f"Initializing distributed: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
    torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs)

    # on rank=0 let everyone know training is starting
    new_rank_zero_info(
        f"{'-' * 100}\n"
        f"distributed_backend={torch_distributed_backend}\n"
        f"All distributed processes registered. Starting with {world_size} processes\n"
        f"{'-' * 100}\n"
    )
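# A hedged usage sketch for ``init_dist_connection``: Lightning's strategies normally call it
# internally, but in a standalone script one might combine it with a cluster environment plugin
# (``LightningEnvironment`` is assumed here) and the backend helper above:
#
#     # Illustrative sketch only; assumes a single-node, single-process run where
#     # LightningEnvironment supplies the address/port, and gloo as the CPU backend.
#     from pytorch_lightning.plugins.environments import LightningEnvironment
#
#     env = LightningEnvironment()
#     backend = get_default_process_group_backend_for_device(torch.device("cpu"))  # "gloo"
#     init_dist_connection(env, torch_distributed_backend=backend, global_rank=0, world_size=1)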
def _broadcast_object_list(obj: Any, rank: int) -> Any:
    objects = [obj if torch.distributed.get_rank() == rank else None]
    torch.distributed.broadcast_object_list(objects, src=rank)
    return objects[0]


# TODO: Refactor with the Strategy Collectives once finalized.
def _collect_states_on_rank_zero(state: Dict[str, Any]) -> Dict[int, Any]:
    """This distributed utility collects dictionary state across all processes.

    Args:
        state: Dictionary containing the state of the current process

    Returns:
        states: On global rank 0, a dictionary where the primary keys are the process rank and the values their
            associated states. Otherwise, returns None.
    """
    if not distributed_available():
        return {0: state}
    return {rank: _broadcast_object_list(state, rank) for rank in range(torch.distributed.get_world_size())}


def rank_zero_info(*args: Any, **kwargs: Any) -> Any:
    rank_zero_deprecation(
        "pytorch_lightning.utilities.distributed.rank_zero_info has been deprecated in v1.6"
        " and will be removed in v1.8."
        " Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead."
    )
    return new_rank_zero_info(*args, **kwargs)


def rank_zero_debug(*args: Any, **kwargs: Any) -> Any:
    rank_zero_deprecation(
        "pytorch_lightning.utilities.distributed.rank_zero_debug has been deprecated in v1.6"
        " and will be removed in v1.8."
        " Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead."
    )
    return new_rank_zero_debug(*args, **kwargs)
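# For completeness, a hedged sketch of how ``_collect_states_on_rank_zero`` above might be
# called (the state dictionary is hypothetical); every rank participates, and rank 0 ends up
# with a mapping from each rank to that rank's state:
#
#     # Illustrative sketch only; outside a distributed run this simply returns {0: local_state}.
#     local_state = {"best_score": 0.9}                      # hypothetical per-process state
#     all_states = _collect_states_on_rank_zero(local_state)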