Source code for pytorch_lightning.strategies.bagua
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from typing import Any, Dict, List, Optional, Union

import torch
from lightning_utilities.core.imports import module_available
from torch import Tensor
from torch.nn import Module

import pytorch_lightning as pl
from lightning_fabric.plugins import CheckpointIO, ClusterEnvironment
from lightning_fabric.utilities.optimizer import _optimizers_to_device
from lightning_fabric.utilities.seed import reset_seed
from lightning_fabric.utilities.types import ReduceOp
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

_BAGUA_AVAILABLE = module_available("bagua.torch_api")

if _BAGUA_AVAILABLE:
    import bagua.torch_api as bagua
    from bagua.torch_api.algorithms import Algorithm
    from bagua.torch_api.algorithms.q_adam import QAdamOptimizer
    from bagua.torch_api.communication import allreduce_inplace, barrier, broadcast_object, is_initialized
    from bagua.torch_api.communication import ReduceOp as BaguaReduceOp
    from bagua.torch_api.data_parallel.distributed import (
        DistributedDataParallel_V1_9_0 as BaguaDistributedDataParallel,
    )

    # Convert a reduce op to its equivalent `bagua.torch_api.ReduceOp`
    _bagua_reduce_ops = {
        ReduceOp.SUM: BaguaReduceOp.SUM,
        ReduceOp.PRODUCT: BaguaReduceOp.PRODUCT,
        ReduceOp.MIN: BaguaReduceOp.MIN,
        ReduceOp.MAX: BaguaReduceOp.MAX,
        ReduceOp.BAND: BaguaReduceOp.BAND,
        ReduceOp.BOR: BaguaReduceOp.BOR,
        ReduceOp.BXOR: BaguaReduceOp.BXOR,
        "avg": BaguaReduceOp.AVG,
        "mean": BaguaReduceOp.AVG,
        "sum": BaguaReduceOp.SUM,
    }
else:
    _bagua_reduce_ops = {}

log = logging.getLogger(__name__)


class LightningBaguaModule(_LightningModuleWrapperBase):
    def __init__(
        self,
        forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
        pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
    ) -> None:
        self._validate_init_arguments(pl_module, forward_module)
        forward_module = pl_module or forward_module
        super().__init__(forward_module=forward_module)
        # Bagua uses `bagua_module_name` to distinguish different modules
        self._bagua_module_name = f"{forward_module.__class__.__name__}{id(forward_module)}"

    def forward(self, *inputs: Any, **kwargs: Any) -> Any:
        pl_module = self.lightning_module
        trainer = pl_module._trainer

        if trainer is not None:
            if trainer.training:
                output = self._forward_module.training_step(*inputs, **kwargs)
                # In manual_optimization, we need to prevent DDP reducer as
                # it is done manually in `LightningModule.manual_backward`.
                # `require_backward_grad_sync` will be reset in the
                # ddp_strategy `post_training_step` hook
                if not pl_module.automatic_optimization:
                    # Using bagua strategy, the model is redefined in model.inner
                    # and cannot be accessed directly. We need this to make manual
                    # backward work.
                    trainer.model.inner.require_backward_grad_sync = False  # type: ignore[union-attr]
                return output
            else:
                return super().forward(*inputs, **kwargs)
        return self._forward_module(*inputs, **kwargs)
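
# For illustration only (a sketch, meaningful only when Bagua is installed): the
# `_bagua_reduce_ops` mapping above resolves both the string aliases and the
# framework-level `ReduceOp` members to Bagua's own reduce ops, and unrecognized
# values fall through to `None` so that `BaguaStrategy.reduce` below can raise a
# clear error:
#
#     _bagua_reduce_ops.get("mean", None)        # -> BaguaReduceOp.AVG
#     _bagua_reduce_ops.get(ReduceOp.SUM, None)  # -> BaguaReduceOp.SUM
#     _bagua_reduce_ops.get("median", None)      # -> None (rejected by `reduce`)
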
class BaguaStrategy(DDPStrategy):
    strategy_name = "bagua"

    def __init__(
        self,
        algorithm: str = "gradient_allreduce",
        flatten: bool = True,
        accelerator: Optional["pl.accelerators.Accelerator"] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        **bagua_kwargs: Union[Any, Dict[str, Any]],
    ):
        """Strategy for training using the `Bagua <https://github.com/BaguaSys/bagua>`_ library, with advanced
        distributed training algorithms and system optimizations.

        This strategy requires the `bagua` package to be installed. See the
        `installation guide <https://tutorials.baguasys.com/installation>`_ for more information.

        The :class:`BaguaStrategy` is only supported on GPU and on Linux systems.

        Arguments:
            algorithm: Distributed algorithm used to do the actual communication and update. Built-in algorithms
                include "gradient_allreduce", "bytegrad", "decentralized", "low_precision_decentralized", "qadam" and
                "async".
            flatten: Whether to flatten the Bagua communication buckets. The flatten operation resets the data
                pointers of the bucket tensors so that they can use faster code paths.
            bagua_kwargs: Additional keyword arguments that will be passed to initialize the Bagua algorithm. More
                details on the keyword arguments accepted by each algorithm can be found in the
                `documentation <https://bagua.readthedocs.io/en/latest/autoapi/bagua/torch_api/algorithms/index.html>`_.
        """
        if not _BAGUA_AVAILABLE:
            raise MisconfigurationException(
                "To use the `BaguaStrategy`, you must have `Bagua` installed. Use `pip install bagua` to install it."
            )

        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )

        self._bagua_algorithm = algorithm
        self._bagua_flatten = flatten
        self._bagua_kwargs = bagua_kwargs

    def setup_distributed(self) -> None:
        reset_seed()

        # determine which process we are and the world size
        self.set_world_ranks()

        self._init_bagua_distributed()

    def _init_bagua_distributed(self) -> None:
        self._set_node_environment_variables()
        log.info(
            "Initializing Bagua Distributed: "
            f"GLOBAL_RANK: {self.global_rank}, "
            f"MEMBER: {self.global_rank + 1}/{self.world_size}"
        )

        # the device needs to be set before initializing the Bagua distributed environment
        # Note: setup_environment calls super().setup_distributed after calling init_distributed()
        torch.cuda.set_device(self.local_rank)

        if not is_initialized():
            bagua.init_process_group()

    def _set_node_environment_variables(self) -> None:
        """Set the environment variables as required by the :func:`bagua.init_process_group` call.

        This enables the use of other cluster environments which don't set these exact variables, e.g., Bagua can be
        launched with ``torch.distributed.run``.
        """
        os.environ["MASTER_ADDR"] = self.cluster_environment.main_address  # type: ignore[union-attr]
        os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)  # type: ignore[union-attr]
        os.environ["RANK"] = str(self.global_rank)
        os.environ["NODE_RANK"] = str(self.node_rank)
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["LOCAL_RANK"] = str(self.local_rank)
    def setup(self, trainer: "pl.Trainer") -> None:
        self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
        if self._should_run_deadlock_detection():
            self._share_information_to_prevent_deadlock()

        assert self.accelerator is not None
        self.accelerator.setup(trainer)

        # move the model to the correct device
        self.model_to_device()

        trainer_fn = trainer.state.fn

        if trainer_fn == TrainerFn.FITTING:
            if self._layer_sync and self.model:
                self.model = self._layer_sync.apply(self.model)

        self.setup_precision_plugin()

        if trainer_fn == TrainerFn.FITTING:
            # set up optimizers after the module has been moved to the device
            # but before the module has been wrapped
            self.setup_optimizers(trainer)
            _optimizers_to_device(self.optimizers, self.root_device)

            # skip wrapping the model if we are not fitting as no gradients need to be exchanged
            self._configure_bagua_model(trainer)
    def _check_qadam_optimizer(self) -> None:
        has_qadam_optimizer = any([isinstance(opt, QAdamOptimizer) for opt in self.optimizers])

        if not has_qadam_optimizer or len(self.optimizers) > 1 or len(self.lr_scheduler_configs) > 1:
            raise MisconfigurationException("Bagua QAdam can only accept one QAdamOptimizer and one LR Scheduler.")

        self._bagua_kwargs["q_adam_optimizer"] = self.optimizers[0]

    def _configure_bagua_model(self, trainer: "pl.Trainer") -> None:
        model = LightningBaguaModule(self.model)  # type: ignore[arg-type]
        self.model = self._setup_model(model)

        # start the background communication for the async algorithm
        if trainer.training and self._bagua_algorithm == "async":
            self.model.bagua_algorithm.resume(self.model)  # type: ignore

    def _setup_model(self, model: Module) -> "BaguaDistributedDataParallel":
        """Wraps the model into a Bagua distributed module."""
        if self._bagua_algorithm == "qadam":
            self._check_qadam_optimizer()

        algorithm = Algorithm.init(self._bagua_algorithm, **self._bagua_kwargs)
        return BaguaDistributedDataParallel(
            module=model,
            optimizers=self.optimizers,
            algorithm=algorithm,
            gradient_as_bucket_view=self._bagua_flatten,
        )

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        strategy_registry.register(
            cls.strategy_name,
            cls,
            description=f"{cls.__name__}",
        )
    def teardown(self) -> None:
        # abort the background communication for the async algorithm
        assert self.lightning_module is not None
        if self.lightning_module.trainer.training and self._bagua_algorithm == "async":
            self.model.bagua_algorithm.abort(self.model)  # type: ignore

        if isinstance(self.model, BaguaDistributedDataParallel):
            self.model = self.lightning_module

        super().teardown()

    def barrier(self, *args: Any, **kwargs: Any) -> None:
        if is_initialized():
            barrier()

    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
        return broadcast_object(obj, src)
    def post_training_step(self) -> None:
        assert self.lightning_module is not None
        # Using bagua strategy, the model is redefined in model.inner
        # and cannot be accessed directly. We need to redefine the
        # post_training_step function to make manual backward work.
        if not self.lightning_module.automatic_optimization:
            self.model.inner.require_backward_grad_sync = True  # type: ignore[union-attr]
    def reduce(
        self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
    ) -> Tensor:
        """Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: The tensor to sync and reduce.
            group: The process group to gather results from. Defaults to all processes (world).
            reduce_op: The reduction operation. Can also be a string 'sum' or ReduceOp.

        Return:
            The reduced value, except when the input was not a tensor, in which case the output is unchanged.
        """
        if not isinstance(tensor, Tensor):
            return tensor

        if group is not None:
            raise ValueError("`Bagua` does not support allreduce using a subcommunicator at this time. Unset `group`.")

        if reduce_op is None:
            op = BaguaReduceOp.AVG
        else:
            op = _bagua_reduce_ops.get(reduce_op, None)
            if op is None:
                raise ValueError(f"Unrecognized `reduce_op` for `BaguaStrategy`: {reduce_op}")

        allreduce_inplace(tensor, op=op)
        return tensor
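
A minimal usage sketch, not part of the module above. It assumes Bagua is installed, a Linux host with (here) four GPUs, and a user-defined ``LightningModule`` instance named ``model``; the algorithm name and device count are illustrative, and any extra keyword arguments passed to ``BaguaStrategy`` are forwarded to the chosen Bagua algorithm via ``bagua_kwargs`` (see ``_setup_model``):

    import pytorch_lightning as pl
    from pytorch_lightning.strategies import BaguaStrategy

    # select the Bagua strategy when building the Trainer
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=4,
        strategy=BaguaStrategy(algorithm="gradient_allreduce"),
    )
    trainer.fit(model)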