class TorchCollective(Collective):
    manages_default_group = False

    def __init__(self) -> None:
        if not dist.is_available():
            raise RuntimeError("Torch distributed is not available.")
        super().__init__()

    @property
    def group(self) -> CollectibleGroup:
        if self._group is None:
            self._group = dist.GroupMember.WORLD
        return super().group

    @property
    def rank(self) -> int:
        # local rank
        return dist.get_rank(self.group)  # type: ignore[arg-type]

    @property
    def world_size(self) -> int:
        return dist.get_world_size(self.group)  # type: ignore[arg-type]

    def broadcast(self, tensor: Tensor, src: int) -> Tensor:
        dist.broadcast(tensor, src, group=self.group)
        return tensor

    def all_reduce(self, tensor: Tensor, op: Union[str, ReduceOp, RedOpType] = "sum") -> Tensor:
        op = self._convert_to_native_op(op)
        dist.all_reduce(tensor, op=op, group=self.group)
        return tensor

    def reduce(self, tensor: Tensor, dst: int, op: Union[str, ReduceOp, RedOpType] = "sum") -> Tensor:
        op = self._convert_to_native_op(op)
        dist.reduce(tensor, dst, op=op, group=self.group)
        return tensor

    def all_gather(self, tensor_list: List[Tensor], tensor: Tensor) -> List[Tensor]:
        dist.all_gather(tensor_list, tensor, group=self.group)
        return tensor_list

    def gather(self, tensor: Tensor, gather_list: List[Tensor], dst: int = 0) -> List[Tensor]:
        dist.gather(tensor, gather_list, dst, group=self.group)
        return gather_list

    def scatter(self, tensor: Tensor, scatter_list: List[Tensor], src: int = 0) -> Tensor:
        dist.scatter(tensor, scatter_list, src, group=self.group)
        return tensor

    def reduce_scatter(
        self, output: Tensor, input_list: List[Tensor], op: Union[str, ReduceOp, RedOpType] = "sum"
    ) -> Tensor:
        op = self._convert_to_native_op(op)
        dist.reduce_scatter(output, input_list, op=op, group=self.group)
        return output

    def all_to_all(self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor]) -> List[Tensor]:
        dist.all_to_all(output_tensor_list, input_tensor_list, group=self.group)
        return output_tensor_list

    def send(self, tensor: Tensor, dst: int, tag: int = 0) -> None:
        dist.send(tensor, dst, tag=tag, group=self.group)  # type: ignore[arg-type]

    def recv(self, tensor: Tensor, src: Optional[int] = None, tag: int = 0) -> Tensor:
        dist.recv(tensor, src, tag=tag, group=self.group)  # type: ignore[arg-type]
        return tensor

    def all_gather_object(self, object_list: List[Any], obj: Any) -> List[Any]:
        dist.all_gather_object(object_list, obj, group=self.group)
        return object_list

    def broadcast_object_list(
        self, object_list: List[Any], src: int, device: Optional[torch.device] = None
    ) -> List[Any]:
        dist.broadcast_object_list(object_list, src, group=self.group, device=device)
        return object_list

    def gather_object(self, obj: Any, object_gather_list: List[Any], dst: int = 0) -> List[Any]:
        dist.gather_object(obj, object_gather_list, dst, group=self.group)
        return object_gather_list

    def scatter_object_list(
        self, scatter_object_output_list: List[Any], scatter_object_input_list: List[Any], src: int = 0
    ) -> List[Any]:
        dist.scatter_object_list(scatter_object_output_list, scatter_object_input_list, src, group=self.group)
        return scatter_object_output_list

    def barrier(self, device_ids: Optional[List[int]] = None) -> None:
        if self.group == dist.GroupMember.NON_GROUP_MEMBER:
            return
        dist.barrier(group=self.group, device_ids=device_ids)

    def monitored_barrier(self, timeout: Optional[datetime.timedelta] = None, wait_all_ranks: bool = False) -> None:
        dist.monitored_barrier(group=self.group, timeout=timeout, wait_all_ranks=wait_all_ranks)

    def setup(self, main_address: Optional[str] = None, main_port: Optional[str] = None, **kwargs: Any) -> Self:
        if self.is_initialized():
            return self
        # maybe set addr
        set_addr = False
        addr_key = "MASTER_ADDR"
        if main_address is not None and addr_key not in os.environ:
            os.environ[addr_key] = main_address
            set_addr = True
        # maybe set port
        set_port = False
        port_key = "MASTER_PORT"
        if main_port is not None and port_key not in os.environ:
            os.environ[port_key] = str(main_port)
            set_port = True
        # this will `init_group`
        super().setup(**kwargs)
        # set as a class attribute so any instance can know whether we initialized the default process group
        TorchCollective.manages_default_group = True
        # cleanup
        if set_addr:
            os.environ.pop("MASTER_ADDR", None)
        if set_port:
            os.environ.pop("MASTER_PORT", None)
        return self

    def teardown(self) -> Self:
        group_member = self.group != dist.GroupMember.NON_GROUP_MEMBER
        super().teardown()  # will destroy its own group
        # try to destroy the default group. this should only be done by a group member to avoid race conditions,
        # and only if the class is managing it
        if (
            group_member
            and TorchCollective.manages_default_group
            and dist.GroupMember.WORLD is not None  # not destroyed already
            and len(dist.distributed_c10d._pg_map) == 1  # only the default group is left
        ):
            self.destroy_group(dist.GroupMember.WORLD)
            TorchCollective.manages_default_group = False
        elif TorchCollective.manages_default_group and dist.GroupMember.WORLD is None:
            TorchCollective.manages_default_group = False
        return self

    @classmethod
    def is_available(cls) -> bool:
        return dist.is_available()

    @classmethod
    def is_initialized(cls) -> bool:
        return dist.is_initialized()

    @classmethod
    def init_group(cls, **kwargs: Any) -> None:
        dist.init_process_group(**kwargs)

    @classmethod
    def new_group(cls, **kwargs: Any) -> CollectibleGroup:
        return dist.new_group(**kwargs)

    @classmethod
    def destroy_group(cls, group: CollectibleGroup) -> None:
        # can be called by all processes in the default group, group will be `object()` if they are not part of the
        # current group
        if group in dist.distributed_c10d._pg_map:
            dist.destroy_process_group(group)  # type: ignore[arg-type]

    @classmethod
    def _convert_to_native_op(cls, op: Union[str, ReduceOp, RedOpType]) -> Union[ReduceOp, RedOpType]:
        # in 1.13, `ReduceOp` has become an empty shell for `RedOpType`, the latter being the actually returned class.
        # for example, `ReduceOp.SUM` returns a `RedOpType.SUM`. the only exception is `RedOpType.PREMUL_SUM` where
        # `ReduceOp` is still the desired class, but it's created via a special `_make_nccl_premul_sum` function
        if isinstance(op, ReduceOp) or _TORCH_GREATER_EQUAL_1_13 and isinstance(op, RedOpType):
            return op
        if not isinstance(op, str):
            raise ValueError(f"Unsupported op {op!r} of type {type(op).__name__}")
        op = op.upper()
        # `ReduceOp` should contain `RedOpType`'s members
        value = getattr(ReduceOp, op, None)
        if value is None:
            raise ValueError(f"op {op!r} is not a member of `ReduceOp`")
        return value
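# --------------------------------------------------------------------------------------
# Usage sketch (not part of the class above): a minimal single-process run that exercises
# `setup`, one collective call, and `teardown`. The import path
# `lightning.fabric.plugins.collectives.TorchCollective`, the address/port values, and the
# choice of the "gloo" backend are illustrative assumptions, not prescribed by this module.
# --------------------------------------------------------------------------------------
import torch

from lightning.fabric.plugins.collectives import TorchCollective

collective = TorchCollective()

# `main_address`/`main_port` are exported as MASTER_ADDR/MASTER_PORT only for the duration
# of `setup`; the remaining keyword arguments are forwarded to `init_process_group`.
collective.setup(main_address="127.0.0.1", main_port="29500", backend="gloo", rank=0, world_size=1)

tensor = torch.ones(2)
collective.all_reduce(tensor, op="sum")  # the string op is mapped to `ReduceOp.SUM` internally
print(collective.rank, collective.world_size, tensor)

# Tears down this instance's group and, because this instance initialized the default
# process group, destroys that group as well.
collective.teardown()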