Source code for pytorch_lightning.strategies.bagua
import logging
import os
from typing import Any, Dict, List, Optional, Union

import torch
from torch.nn import Module

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import (
    _LightningModuleWrapperBase,
    _LightningPrecisionModuleWrapperBase,
    unwrap_lightning_module,
)
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _BAGUA_AVAILABLE
from pytorch_lightning.utilities.seed import reset_seed

if _BAGUA_AVAILABLE:
    import bagua.torch_api as bagua
    from bagua.torch_api.algorithms import Algorithm
    from bagua.torch_api.algorithms.q_adam import QAdamOptimizer
    from bagua.torch_api.communication import allreduce_inplace, barrier, broadcast_object, is_initialized
    from bagua.torch_api.communication import ReduceOp as BaguaReduceOp
    from bagua.torch_api.data_parallel.distributed import (
        DistributedDataParallel_V1_9_0 as BaguaDistributedDataParallel,
    )
else:
    BaguaReduceOp = None
    BaguaDistributedDataParallel = None

log = logging.getLogger(__name__)


class LightningBaguaModule(_LightningModuleWrapperBase):
    def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
        super().__init__(pl_module)
        # Bagua uses `bagua_module_name` to distinguish different modules
        self._bagua_module_name = f"{pl_module.__class__.__name__}{id(pl_module)}"


if _BAGUA_AVAILABLE:
    # Convert a reduce op to its equivalent `bagua.torch_api.ReduceOp`
    _bagua_reduce_ops = {
        ReduceOp.SUM: BaguaReduceOp.SUM,
        ReduceOp.PRODUCT: BaguaReduceOp.PRODUCT,
        ReduceOp.MIN: BaguaReduceOp.MIN,
        ReduceOp.MAX: BaguaReduceOp.MAX,
        ReduceOp.BAND: BaguaReduceOp.BAND,
        ReduceOp.BOR: BaguaReduceOp.BOR,
        ReduceOp.BXOR: BaguaReduceOp.BXOR,
        "avg": BaguaReduceOp.AVG,
        "mean": BaguaReduceOp.AVG,
        "sum": BaguaReduceOp.SUM,
    }
else:
    _bagua_reduce_ops = {}
class BaguaStrategy(DDPStrategy):
    strategy_name = "bagua"

    def __init__(
        self,
        algorithm: str = "gradient_allreduce",
        flatten: bool = True,
        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        **bagua_kwargs: Union[Any, Dict[str, Any]],
    ):
        """Strategy for training using the `Bagua <https://github.com/BaguaSys/bagua>`_ library, with advanced
        distributed training algorithms and system optimizations.

        This strategy requires the `bagua` package to be installed. See
        `installation guide <https://tutorials.baguasys.com/installation>`_ for more information.

        The :class:`BaguaStrategy` is only supported on GPU and on Linux systems.

        Arguments:
            algorithm: Distributed algorithm used to do the actual communication and update. Built-in algorithms
                include "gradient_allreduce", "bytegrad", "decentralized", "low_precision_decentralized", "qadam"
                and "async".
            flatten: Whether to flatten the Bagua communication buckets. The flatten operation will reset the data
                pointer of the bucket tensors so that they can use faster code paths.
            bagua_kwargs: Additional keyword arguments that will be passed to initialize the Bagua algorithm. More
                details on the keyword arguments accepted by each algorithm can be found in the
                `documentation <https://bagua.readthedocs.io/en/latest/autoapi/bagua/torch_api/algorithms/index.html>`_.
        """
        if not _BAGUA_AVAILABLE:
            raise MisconfigurationException(
                "To use the `BaguaStrategy`, you must have `Bagua` installed. Use `pip install bagua` to install it."
            )

        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )

        self._bagua_algorithm = algorithm
        self._bagua_flatten = flatten
        self._bagua_kwargs = bagua_kwargs

    @property
    def lightning_module(self) -> "pl.LightningModule":
        model = self._model
        if isinstance(model, BaguaDistributedDataParallel):
            model = model.module
        return unwrap_lightning_module(model)  # type: ignore[arg-type]

    def setup_distributed(self) -> None:
        reset_seed()

        # determine which process we are and world size
        self.set_world_ranks()

        self._init_bagua_distributed()

    def _init_bagua_distributed(self) -> None:
        self._set_node_environment_variables()
        log.info(
            "Initializing Bagua Distributed: "
            f"GLOBAL_RANK: {self.global_rank}, "
            f"MEMBER: {self.global_rank + 1}/{self.world_size}"
        )

        # the device needs to be set before initializing the Bagua distributed environment
        # Note: setup_environment calls super().setup_distributed after calling init_distributed()
        torch.cuda.set_device(self.local_rank)

        if not is_initialized():
            bagua.init_process_group()

    def _set_node_environment_variables(self) -> None:
        """Set the environment variables as required by the :func:`bagua.init_process_group` call.

        This enables the use of other cluster environments which don't set these exact variables, e.g., Bagua can be
        launched with ``torch.distributed.run``.
"""os.environ["MASTER_ADDR"]=self.cluster_environment.main_address# type: ignore[union-attr]os.environ["MASTER_PORT"]=str(self.cluster_environment.main_port)# type: ignore[union-attr]os.environ["RANK"]=str(self.global_rank)os.environ["NODE_RANK"]=str(self.node_rank)os.environ["WORLD_SIZE"]=str(self.world_size)os.environ["LOCAL_RANK"]=str(self.local_rank)def_check_qadam_optimizer(self)->None:has_qadam_optimizer=any([isinstance(opt,QAdamOptimizer)foroptinself.optimizers])ifnothas_qadam_optimizerorlen(self.optimizers)>1orlen(self.lr_scheduler_configs)>1:raiseMisconfigurationException("Bagua QAdam can only accept one QAdamOptimizer and one LR Scheduler.")self._bagua_kwargs["q_adam_optimizer"]=self.optimizers[0]defconfigure_ddp(self)->None:model=LightningBaguaModule(self.model)# type: ignore[arg-type]self._model=self._setup_model(model)# start the background communication for async algorithmassertself.lightning_module.trainerisnotNoneifself.lightning_module.trainer.trainingandself._bagua_algorithm=="async":self.model.bagua_algorithm.resume(self.model)# type: ignoredef_setup_model(self,model:Module)->BaguaDistributedDataParallel:"""Wraps the model into a Bagua distributed module."""ifself._bagua_algorithm=="qadam":self._check_qadam_optimizer()algorithm=Algorithm.init(self._bagua_algorithm,**self._bagua_kwargs)returnBaguaDistributedDataParallel(module=model,optimizers=self.optimizers,algorithm=algorithm,gradient_as_bucket_view=self._bagua_flatten,)@classmethoddefregister_strategies(cls,strategy_registry:Dict)->None:strategy_registry.register(cls.strategy_name,cls,description=f"{cls.__class__.__name__}",)
    def teardown(self) -> None:
        # abort the background communication for async algorithm
        assert self.lightning_module.trainer is not None
        if self.lightning_module.trainer.training and self._bagua_algorithm == "async":
            self.model.bagua_algorithm.abort(self.model)  # type: ignore

        if isinstance(self.model, BaguaDistributedDataParallel):
            self.model = self.lightning_module

        if self.root_device.type == "cuda":
            # GPU teardown
            log.detail(f"{self.__class__.__name__}: moving model to CPU")
            self.lightning_module.cpu()
            # clean up memory
            torch.cuda.empty_cache()
    def reduce(
        self, tensor: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
    ) -> torch.Tensor:
        """Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: The tensor to sync and reduce.
            group: The process group to gather results from. Defaults to all processes (world).
            reduce_op: The reduction operation. Can also be a string 'sum' or ReduceOp.

        Return:
            The reduced value, except when the input was not a tensor the output remains unchanged.
        """
        if not isinstance(tensor, torch.Tensor):
            return tensor

        if group is not None:
            raise ValueError("`Bagua` does not support allreduce using a subcommunicator at this time. Unset `group`.")

        if reduce_op is None:
            op = BaguaReduceOp.AVG
        else:
            op = _bagua_reduce_ops.get(reduce_op, None)
            if op is None:
                raise ValueError(f"Unrecognized `reduce_op` for `BaguaStrategy`: {reduce_op}")

        allreduce_inplace(tensor, op=op)
        return tensor
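
# --- Usage sketch (editor's addition, not part of the library source) -----------------
# A minimal, hedged example of how this strategy is typically selected from user code,
# based only on the constructor arguments documented above. It assumes a Linux machine
# with multiple GPUs and the `bagua` package installed; the algorithm name and device
# count are illustrative, not required values.
if __name__ == "__main__":
    from pytorch_lightning import Trainer

    trainer = Trainer(
        accelerator="gpu",
        devices=2,
        # "gradient_allreduce" is the default algorithm; "bytegrad", "decentralized",
        # "low_precision_decentralized", "qadam" and "async" are also built in.
        strategy=BaguaStrategy(algorithm="gradient_allreduce", flatten=True),
    )
    # trainer.fit(...) would then be called with a user-defined LightningModule.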