Source code for pytorch_lightning.utilities.distributed
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities that can be used with distributed training."""

import logging
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.parallel.distributed import DistributedDataParallel

import pytorch_lightning as pl
from pytorch_lightning.utilities.imports import _HPU_AVAILABLE, _TPU_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_debug as new_rank_zero_debug
from pytorch_lightning.utilities.rank_zero import rank_zero_only  # noqa: F401
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_info as new_rank_zero_info

if _TPU_AVAILABLE:
    import torch_xla.core.xla_model as xm

if torch.distributed.is_available():
    from torch.distributed import group, ReduceOp
else:

    class ReduceOp:  # type: ignore # (see https://github.com/python/mypy/issues/1153)
        SUM = None

    class group:  # type: ignore
        WORLD = None


log = logging.getLogger(__name__)
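Note: two helpers referenced by the functions below, ``distributed_available`` and ``_simple_gather_all_tensors``, belong to this module but are not reproduced in this listing. The following are minimal sketches consistent with how they are called below, not the authoritative definitions.

def distributed_available() -> bool:
    # Sketch: the reduce/gather utilities below become no-ops unless a torch.distributed
    # process group is initialized (or an XLA world with more than one device exists).
    return torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed()


def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
    # Sketch: plain all_gather for tensors that already have the same shape on every rank.
    gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
    torch.distributed.all_gather(gathered_result, result, group)
    return gathered_result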
def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tensor]:
    """Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes.

    Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case
    tensors are padded, gathered and then trimmed to secure equal workload for all processes.

    Args:
        result: the value to sync
        group: the process group to gather results from. Defaults to all processes (world)

    Return:
        gathered_result: list with size equal to the process group where
            gathered_result[i] corresponds to result tensor from process i
    """
    if group is None:
        group = torch.distributed.group.WORLD

    # convert tensors to contiguous format
    result = result.contiguous()

    world_size = torch.distributed.get_world_size(group)
    torch.distributed.barrier(group=group)

    # if the tensor is scalar, things are easy
    if result.ndim == 0:
        return _simple_gather_all_tensors(result, group, world_size)

    # 1. Gather sizes of all tensors
    local_size = torch.tensor(result.shape, device=result.device)
    local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
    torch.distributed.all_gather(local_sizes, local_size, group=group)
    max_size = torch.stack(local_sizes).max(dim=0).values
    all_sizes_equal = all(all(ls == max_size) for ls in local_sizes)

    # 2. If shapes are all the same, then do a simple gather:
    if all_sizes_equal:
        return _simple_gather_all_tensors(result, group, world_size)

    # 3. If not, we need to pad each local tensor to maximum size, gather and then truncate
    pad_dims = []
    pad_by = (max_size - local_size).detach().cpu()
    for val in reversed(pad_by):
        pad_dims.append(0)
        pad_dims.append(val.item())
    result_padded = F.pad(result, pad_dims)
    gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
    torch.distributed.all_gather(gathered_result, result_padded, group)
    for idx, item_size in enumerate(local_sizes):
        slice_param = [slice(dim_size) for dim_size in item_size]
        gathered_result[idx] = gathered_result[idx][slice_param]
    return gathered_result
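An illustrative sketch of the padding path above (not part of the module): each rank may contribute a differently sized tensor and still gets every rank's original shape back, because padding is stripped again after the all_gather.

# Illustrative only: assumes torch.distributed is already initialized and this runs on every rank.
rank = torch.distributed.get_rank()
local = torch.arange(rank + 1, dtype=torch.float32)  # rank 0 -> shape (1,), rank 1 -> shape (2,), ...
gathered = gather_all_tensors(local)
# gathered[i] has shape (i + 1,) on every rank; the padding used for the collective is trimmed away.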
def sync_ddp_if_available(
    result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
) -> Tensor:
    """Function to reduce a tensor across worker processes during distributed training.

    Args:
        result: the value to sync and reduce (typically tensor or number)
        group: the process group to gather results from. Defaults to all processes (world)
        reduce_op: the reduction operation. Defaults to sum.
            Can also be a string of 'avg', 'mean' to calculate the mean during reduction.

    Return:
        reduced value
    """
    if distributed_available():
        return sync_ddp(result, group=group, reduce_op=reduce_op)
    return result
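A hedged usage sketch: with reduce_op="mean" every rank ends up with the average of the per-rank values, and outside a distributed context the tensor is simply returned unchanged.

# Illustrative only: run on every rank; assumes a CPU/gloo process group (use device tensors for nccl).
loss = torch.tensor(float(torch.distributed.get_rank()))
mean_loss = sync_ddp_if_available(loss, reduce_op="mean")  # identical averaged value on all ranks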
def sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> Tensor:
    """Function to reduce the tensors from several ddp processes to one main process.

    Args:
        result: the value to sync and reduce (typically tensor or number)
        group: the process group to gather results from. Defaults to all processes (world)
        reduce_op: the reduction operation. Defaults to sum.
            Can also be a string of 'avg', 'mean' to calculate the mean during reduction.

    Return:
        reduced value
    """
    divide_by_world_size = False

    if group is None:
        group = torch.distributed.group.WORLD

    op: Optional[ReduceOp]
    if isinstance(reduce_op, str):
        if reduce_op.lower() in ("avg", "mean"):
            op = ReduceOp.SUM
            divide_by_world_size = True
        else:
            op = getattr(ReduceOp, reduce_op.upper())
    else:
        op = reduce_op

    # WA for HPU. HPU doesn't support Long types, forcefully set it to float
    if _HPU_AVAILABLE:
        is_hpu_backend = os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1"
        if is_hpu_backend:
            if (result.type() == "torch.LongTensor") or (result.type() == "torch.hpu.LongTensor"):
                new_rank_zero_info("Long tensor unsupported on HPU, casting to float")
                result = result.float()

    # sync all processes before reduction
    torch.distributed.barrier(group=group)
    torch.distributed.all_reduce(result, op=op, group=group, async_op=False)

    if divide_by_world_size:
        result = result / torch.distributed.get_world_size(group)

    return result
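The function ``all_gather_ddp_if_available`` below relies on an ``AllGatherGrad`` autograd function that is not reproduced in this listing. A minimal sketch of such a function, consistent with how it is used (forward all-gathers and stacks; backward all-reduces the incoming gradients and returns this rank's slice), is:

class AllGatherGrad(torch.autograd.Function):
    # Sketch of the autograd function used below; not the authoritative definition.

    @staticmethod
    def forward(ctx: Any, tensor: Tensor, group: Optional[Any] = None) -> Tensor:
        ctx.group = group
        gathered = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(gathered, tensor, group=group)
        return torch.stack(gathered, dim=0)  # shape: (world_size, *tensor.shape)

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor, None]:
        grad_output = torch.cat(grad_output)
        # Sum gradient contributions from all ranks, then keep only the slice belonging to this rank.
        torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)
        return grad_output[torch.distributed.get_rank()], None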
def all_gather_ddp_if_available(
    tensor: Tensor, group: Optional["torch.distributed.ProcessGroup"] = None, sync_grads: bool = False
) -> Tensor:
    """Function to gather a tensor from several distributed processes.

    Args:
        tensor: tensor of shape (batch, ...)
        group: the process group to gather results from. Defaults to all processes (world)
        sync_grads: flag that allows users to synchronize gradients for all_gather op

    Return:
        A tensor of shape (world_size, batch, ...)
    """
    group = group if group is not None else torch.distributed.group.WORLD
    if distributed_available():
        if sync_grads:
            return AllGatherGrad.apply(tensor, group)
        with torch.no_grad():
            return AllGatherGrad.apply(tensor, group)
    return tensor
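A hedged example of the gradient-synchronizing path: with sync_grads=True the gathered tensor stays connected to the autograd graph, so a loss computed over all ranks' embeddings still backpropagates into the local tensor.

# Illustrative only: assumes an initialized process group and a (batch, dim) embedding per rank.
embeddings = torch.randn(4, 8, requires_grad=True)
all_embeddings = all_gather_ddp_if_available(embeddings, sync_grads=True)  # shape (world_size, 4, 8)
loss = all_embeddings.pow(2).sum()
loss.backward()  # embeddings.grad is populated because sync_grads=True kept the autograd graph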
def register_ddp_comm_hook(
    model: DistributedDataParallel,
    ddp_comm_state: Optional[object] = None,
    ddp_comm_hook: Optional[Callable] = None,
    ddp_comm_wrapper: Optional[Callable] = None,
) -> None:
    """Function to register communication hook for DDP model https://pytorch.org/docs/master/ddp_comm_hooks.html.

    Args:
        model:
            DDP model
        ddp_comm_state:
            state is passed to the hook and can be used to maintain
            and update any state information that users would like to
            maintain as part of the training process. Examples: error
            feedback in gradient compression, peers to communicate with
            next in GossipGrad etc.
        ddp_comm_hook:
            hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future

            This callable function is called once the bucket is ready. The
            hook can perform whatever processing is needed and return
            a Future indicating completion of any async work (ex: allreduce).
            If the hook doesn't perform any communication, it can also
            just return a completed Future. The Future should hold the
            new value of grad bucket's tensors. Once a bucket is ready,
            c10d reducer would call this hook and use the tensors returned
            by the Future and copy grads to individual parameters.
        ddp_comm_wrapper:
            communication hook wrapper to support a communication hook such
            as FP16 compression as wrapper, which could be combined with
            ddp_comm_hook

    Examples:

        >>> from torch.distributed.algorithms.ddp_comm_hooks import ( # doctest: +SKIP
        ...     default_hooks as default,
        ...     powerSGD_hook as powerSGD,
        ...     post_localSGD_hook as post_localSGD,
        ... )
        >>>
        >>> # fp16_compress_hook for compress gradients
        >>> ddp_model = ...
        >>> register_ddp_comm_hook( # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_hook=default.fp16_compress_hook,
        ... )
        >>>
        >>> # powerSGD_hook
        >>> ddp_model = ...
        >>> register_ddp_comm_hook( # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_state=powerSGD.PowerSGDState(
        ...         process_group=None,
        ...         matrix_approximation_rank=1,
        ...         start_powerSGD_iter=5000,
        ...     ),
        ...     ddp_comm_hook=powerSGD.powerSGD_hook,
        ... )
        >>>
        >>> # post_localSGD_hook
        >>> subgroup, _ = torch.distributed.new_subgroups() # doctest: +SKIP
        >>> ddp_model = ...
        >>> register_ddp_comm_hook( # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_state=post_localSGD.PostLocalSGDState(
        ...         process_group=None,
        ...         subgroup=subgroup,
        ...         start_localSGD_iter=1_000,
        ...     ),
        ...     ddp_comm_hook=post_localSGD.post_localSGD_hook,
        ... )
        >>>
        >>> # fp16_compress_wrapper combined with other communication hook
        >>> ddp_model = ...
        >>> register_ddp_comm_hook( # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_state=powerSGD.PowerSGDState(
        ...         process_group=None,
        ...         matrix_approximation_rank=1,
        ...         start_powerSGD_iter=5000,
        ...     ),
        ...     ddp_comm_hook=powerSGD.powerSGD_hook,
        ...     ddp_comm_wrapper=default.fp16_compress_wrapper,
        ... )
    """
    if ddp_comm_hook is None:
        return
    # inform mypy that ddp_comm_hook is callable
    ddp_comm_hook: Callable = ddp_comm_hook

    if ddp_comm_wrapper is not None:
        new_rank_zero_info(
            f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."
        )
        ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)

    new_rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
    model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook)  # type: ignore[operator]
def tpu_distributed() -> bool:
    return _TPU_AVAILABLE and xm.xrt_world_size() > 1


def get_default_process_group_backend_for_device(device: torch.device) -> str:
    return "nccl" if device.type == "cuda" else "gloo"


def _get_process_group_backend_from_env() -> Optional[str]:
    torch_backend = os.getenv("PL_TORCH_DISTRIBUTED_BACKEND")
    if torch_backend is not None:
        rank_zero_deprecation(
            "Environment variable `PL_TORCH_DISTRIBUTED_BACKEND`"
            " was deprecated in v1.6 and will be removed in v1.8."
            " Specify `process_group_backend` directly on the strategy constructor."
        )
    return torch_backend
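For reference, a small illustrative check (not part of the module): the default-backend helper maps the device type to a torch.distributed backend string.

# Illustrative only.
assert get_default_process_group_backend_for_device(torch.device("cuda", 0)) == "nccl"
assert get_default_process_group_backend_for_device(torch.device("cpu")) == "gloo"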
def init_dist_connection(
    cluster_environment: "pl.plugins.environments.ClusterEnvironment",
    torch_distributed_backend: str,
    global_rank: Optional[int] = None,
    world_size: Optional[int] = None,
    **kwargs: Any,
) -> None:
    """Utility function to initialize distributed connection by setting env variables and initializing the
    distributed process group.

    Args:
        cluster_environment: ``ClusterEnvironment`` instance
        torch_distributed_backend: backend to use (includes `nccl` and `gloo`)
        global_rank: rank of the current process
        world_size: number of processes in the group
        kwargs: kwargs for ``init_process_group``

    Raises:
        RuntimeError:
            If ``torch.distributed`` is not available
    """
    if not torch.distributed.is_available():
        raise RuntimeError("torch.distributed is not available. Cannot initialize distributed process group")
    if torch.distributed.is_initialized():
        log.debug("torch.distributed is already initialized. Exiting early")
        return
    global_rank = global_rank if global_rank is not None else cluster_environment.global_rank()
    world_size = world_size if world_size is not None else cluster_environment.world_size()
    os.environ["MASTER_ADDR"] = cluster_environment.main_address
    os.environ["MASTER_PORT"] = str(cluster_environment.main_port)
    log.info(f"Initializing distributed: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
    torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs)

    # on rank=0 let everyone know training is starting
    new_rank_zero_info(
        f"{'-' * 100}\n"
        f"distributed_backend={torch_distributed_backend}\n"
        f"All distributed processes registered. Starting with {world_size} processes\n"
        f"{'-' * 100}\n"
    )
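A hedged sketch of how a strategy might call this utility. ``LightningEnvironment`` is used here purely as an example ``ClusterEnvironment``; the single-process setup and the CPU/gloo backend are assumptions for illustration.

# Illustrative only: typically invoked by a Strategy during setup, once per process.
from pytorch_lightning.plugins.environments import LightningEnvironment

cluster_env = LightningEnvironment()
backend = get_default_process_group_backend_for_device(torch.device("cpu"))  # -> "gloo"
init_dist_connection(cluster_env, backend, global_rank=0, world_size=1)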
def _broadcast_object_list(obj: Any, rank: int) -> Any:
    objects = [obj if torch.distributed.get_rank() == rank else None]
    torch.distributed.broadcast_object_list(objects, src=rank)
    return objects[0]


# TODO: Refactor with the Strategy Collectives once finalized.
def _collect_states_on_rank_zero(state: Dict[str, Any]) -> Dict[int, Any]:
    """This distributed utility collects dictionary state across all processes.

    Args:
        state: Dictionary containing the state of the current process

    Returns:
        states: On global rank 0, a dictionary where the primary keys are the process rank and the values their
            associated states. Otherwise, returns None.
    """
    if not distributed_available():
        return {0: state}
    return {rank: _broadcast_object_list(state, rank) for rank in range(torch.distributed.get_world_size())}


def rank_zero_info(*args: Any, **kwargs: Any) -> Any:
    rank_zero_deprecation(
        "pytorch_lightning.utilities.distributed.rank_zero_info has been deprecated in v1.6"
        " and will be removed in v1.8."
        " Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead."
    )
    return new_rank_zero_info(*args, **kwargs)


def rank_zero_debug(*args: Any, **kwargs: Any) -> Any:
    rank_zero_deprecation(
        "pytorch_lightning.utilities.distributed.rank_zero_debug has been deprecated in v1.6"
        " and will be removed in v1.8."
        " Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead."
    )
    return new_rank_zero_debug(*args, **kwargs)
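A hedged sketch of the state-collection helper ``_collect_states_on_rank_zero`` defined above: every rank passes in its own dictionary, and on global rank 0 the result maps each process rank to that rank's state.

# Illustrative only: run on every rank of an initialized process group.
local_state = {"rank": torch.distributed.get_rank(), "best_score": 0.5}
all_states = _collect_states_on_rank_zero(local_state)
# On global rank 0: all_states == {0: <rank 0 state>, 1: <rank 1 state>, ...}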