Source code for pytorch_lightning.strategies.horovod
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import ExitStack
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.distributed import group as dist_group
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_only

if _HOROVOD_AVAILABLE:
    import horovod.torch as hvd

class HorovodStrategy(ParallelStrategy):
    """Plugin for Horovod distributed training integration."""

    strategy_name = "horovod"

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
    ):
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=None,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )
        rank_zero_only.rank = self.global_rank
        self._exit_stack: Optional[ExitStack] = None

    @property
    def global_rank(self) -> int:
        return hvd.rank()

    @property
    def local_rank(self) -> int:
        return hvd.local_rank()

    @property
    def world_size(self) -> int:
        return hvd.size()

    @property
    def root_device(self):
        return self.parallel_devices[self.local_rank]

    @property
    def distributed_sampler_kwargs(self):
        distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
        return distributed_sampler_kwargs

    @property
    def handles_gradient_accumulation(self) -> bool:
        """Whether the plugin handles gradient accumulation internally."""
        return True

    def setup(self, trainer: "pl.Trainer") -> None:
        self.model_to_device()

        super().setup(trainer)

        self._exit_stack = ExitStack()
        self._exit_stack.__enter__()

        if not trainer.training:
            # no need to setup optimizers
            return

        def _unpack_lightning_optimizer(opt):
            return opt._optimizer if isinstance(opt, LightningOptimizer) else opt

        optimizers = self.optimizers
        optimizers = [_unpack_lightning_optimizer(opt) for opt in optimizers]

        # Horovod: scale the learning rate by the number of workers to account for
        # increased total batch size
        for optimizer in optimizers:
            for param_group in optimizer.param_groups:
                param_group["lr"] *= self.world_size

        # Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR
        lr_scheduler_configs = self.lr_scheduler_configs
        for config in lr_scheduler_configs:
            scheduler = config.scheduler
            scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs]

        # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
        hvd.broadcast_parameters(self.lightning_module.state_dict(), root_rank=0)
        for optimizer in optimizers:
            hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        accumulation_scheduler = trainer.accumulation_scheduler
        if accumulation_scheduler.epochs != [0]:
            raise MisconfigurationException(
                "Horovod currently does not support different `accumulate_grad_batches` at different epochs."
            )

        self.optimizers = self._wrap_optimizers(optimizers, trainer.accumulate_grad_batches)
        for optimizer in self.optimizers:
            # Synchronization will be performed explicitly following backward()
            self._exit_stack.enter_context(optimizer.skip_synchronize())

    def model_to_device(self):
        if self.root_device.type == "cuda":
            # this can potentially be removed after #8312. Not done due to lack of horovod testing
            torch.cuda.set_device(self.root_device)
        self.model.to(self.root_device)

    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
        """Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' to calculate the sum during reduction.

        Return:
            reduced value, except when the input was not a tensor the output remains unchanged
        """
        if group is not None:
            raise ValueError("Horovod does not support allreduce using a subcommunicator at this time. Unset `group`.")

        if reduce_op in (None, "avg", "mean"):
            reduce_op = hvd.Average
        elif reduce_op in ("sum", ReduceOp.SUM):
            reduce_op = hvd.Sum
        else:
            raise ValueError(f"unrecognized `reduce_op`: {reduce_op}")

        # sync all processes before reduction
        self.join()
        return hvd.allreduce(tensor, op=reduce_op)

    def all_gather(
        self, result: torch.Tensor, group: Optional[Any] = dist_group.WORLD, sync_grads: bool = False
    ) -> torch.Tensor:
        if group is not None and group != dist_group.WORLD:
            raise ValueError("Horovod does not support allgather using a subcommunicator at this time. Unset `group`.")

        if len(result.shape) == 0:
            # Convert scalars to single dimension tensors
            result = result.reshape(1)

        # sync and gather all
        self.join()
        return hvd.allgather(result)

    def post_backward(self, closure_loss: torch.Tensor) -> None:
        # synchronize all horovod optimizers.
        for optimizer in self.optimizers:
            optimizer.synchronize()

    def _wrap_optimizers(
        self, optimizers: List[Optimizer], accumulate_grad_batches: int
    ) -> List["hvd.DistributedOptimizer"]:
        """Wraps optimizers to perform gradient aggregation via allreduce."""
        return [
            hvd.DistributedOptimizer(
                opt,
                backward_passes_per_step=accumulate_grad_batches,
                named_parameters=self._filter_named_parameters(self.lightning_module, opt),
            )
            if "horovod" not in str(opt.__class__)
            else opt
            for opt in optimizers
        ]

    @staticmethod
    def _filter_named_parameters(model: nn.Module, optimizer: Optimizer) -> List[Tuple[str, nn.Parameter]]:
        opt_params = {p for group in optimizer.param_groups for p in group.get("params", [])}
        return [(name, p) for name, p in model.named_parameters() if p in opt_params]

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        strategy_registry.register(
            cls.strategy_name,
            cls,
            description=f"{cls.__class__.__name__}",
        )

    def teardown(self) -> None:
        super().teardown()
        # teardown may be called before `_exit_stack` is set
        if self._exit_stack:
            self._exit_stack.__exit__(None, None, None)
            self._exit_stack = None
        # Make sure all workers have finished training before returning to the user
        self.join()
        if self.root_device.type == "cuda":
            # GPU teardown
            self.lightning_module.cpu()
            # clean up memory
            torch.cuda.empty_cache()
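
A minimal usage sketch (not part of the module source above): the strategy is typically selected by name on the Trainer and the script is launched once per worker with horovodrun. The model and datamodule names below are placeholders for your own LightningModule and LightningDataModule.

# Hypothetical training script; launch with e.g. `horovodrun -np 4 python train.py`
import pytorch_lightning as pl

from my_project import MyDataModule, MyLightningModule  # placeholder imports

model = MyLightningModule()
datamodule = MyDataModule()

# "horovod" resolves to HorovodStrategy through the strategy registry;
# each Horovod process drives a single device, so devices=1 per process.
trainer = pl.Trainer(strategy="horovod", accelerator="gpu", devices=1)
trainer.fit(model, datamodule=datamodule)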