Source code for pytorch_lightning.strategies.hpu_parallel
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from typing import Dict, List, Optional

import torch
import torch.distributed

import pytorch_lightning as pl
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _HPU_AVAILABLE, _TORCH_LESSER_EQUAL_1_10_2

if _HPU_AVAILABLE:
    import habana_frameworks.torch.core.hccl  # noqa: F401
    from habana_frameworks.torch.utils.library_loader import load_habana_module

log = logging.getLogger(__name__)
class HPUParallelStrategy(DDPStrategy):
    """Strategy for distributed training on multiple HPU devices."""

    strategy_name = "hpu_parallel"

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        process_group_backend: Optional[str] = "hccl",
    ) -> None:
        if not _HPU_AVAILABLE:
            raise MisconfigurationException("`HPUParallelStrategy` requires HPU devices to run")

        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            checkpoint_io=checkpoint_io or HPUCheckpointIO(),
            precision_plugin=precision_plugin,
            process_group_backend=process_group_backend,
        )
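A minimal usage sketch of the strategy above, assuming a machine with 8 HPU devices; the device count and the Trainer arguments are illustrative and not part of this module:

    # Hedged usage sketch: the 8-device count and `accelerator="hpu"` are assumptions.
    import torch
    from pytorch_lightning import Trainer
    from pytorch_lightning.strategies import HPUParallelStrategy

    trainer = Trainer(
        accelerator="hpu",
        devices=8,
        strategy=HPUParallelStrategy(parallel_devices=[torch.device("hpu")] * 8),
    )
    # trainer.fit(model)  # `model` would be a LightningModule defined elsewhere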
    def setup_environment(self) -> None:
        # Load the Habana libraries required for PyTorch to register HPU as an available device.
        load_habana_module()

        os.environ["ID"] = str(self.local_rank)
        if self._process_group_backend == "hccl":
            # this env var is checked in the overrides to detect which backend was initiated
            os.environ["HCCL_DISTRIBUTED_BACKEND"] = str(1)
        super().setup_environment()
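As a rough illustration of the flag set above: downstream helpers can branch on the process environment rather than on the strategy object. The check below is a hedged sketch of that pattern, not the exact code in pytorch_lightning.overrides.torch_distributed:

    import os

    def _using_hccl_backend() -> bool:
        # Illustrative helper, not part of the Lightning API: after `setup_environment()`
        # with the "hccl" backend, this flag is present and set to "1".
        return os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1"

    # os.environ["ID"] holds the local rank as a string, e.g. "0" on the first process.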
    def determine_ddp_device_ids(self) -> None:
        return None

    def pre_configure_ddp(self):  # type: ignore
        # If unset, default `find_unused_parameters` to `True`.
        # Many models require this, as there are corner cases where not all parameter backward
        # hooks are fired by the autograd engine even if `requires_grad` is set to `True`.
        # The flag comes with a performance hit, so disable it whenever possible.
        self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True)

        self._static_graph = False
        static_graph = self._ddp_kwargs.get("static_graph")
        if static_graph:
            # When `_set_static_graph()` is called, `find_unused_parameters` has no significance.
            # Reset it to `False`, which is the default value for DDP.
            self._ddp_kwargs["find_unused_parameters"] = False
            self._static_graph = True
        if static_graph is not None:
            # DDP does not accept `static_graph` as a parameter, hence remove it from the kwargs
            del self._ddp_kwargs["static_graph"]

    def configure_ddp(self) -> None:
        # DDP does not accept `static_graph` as a parameter with torch < 1.11
        if _TORCH_LESSER_EQUAL_1_10_2:
            log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel")
            self.pre_configure_ddp()
            self.model = self._setup_model(LightningDistributedModule(self.model))  # type: ignore
            if self.root_device.type == "hpu" and self._static_graph:
                self._model._set_static_graph()  # type: ignore
            self._register_ddp_hooks()
        else:
            super().configure_ddp()
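The interplay between `find_unused_parameters` and `static_graph` in `pre_configure_ddp` can be hard to follow; the standalone sketch below (with a hypothetical user-provided kwargs dict, not strategy code) traces what ends up being passed to `DistributedDataParallel`:

    # Standalone sketch of the kwarg resolution above; `ddp_kwargs` is hypothetical.
    ddp_kwargs = {"static_graph": True}

    ddp_kwargs["find_unused_parameters"] = ddp_kwargs.get("find_unused_parameters", True)
    static_graph = ddp_kwargs.get("static_graph")
    if static_graph:
        # with a static graph, unused-parameter tracking is irrelevant
        ddp_kwargs["find_unused_parameters"] = False
    if static_graph is not None:
        # torch < 1.11 DDP does not accept `static_graph` as a constructor argument
        del ddp_kwargs["static_graph"]

    assert ddp_kwargs == {"find_unused_parameters": False}
    # `_set_static_graph()` is then called on the wrapped model when running on HPU.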
    def teardown(self) -> None:
        log.detail(f"{self.__class__.__name__}: tearing down strategy.")
        super().teardown()

        log.detail(f"{self.__class__.__name__}: moving model to CPU")
        self.lightning_module.cpu()  # type: ignore
        # was set to the local rank in `setup_environment`
        os.environ.pop("ID", None)
        os.environ.pop("HCCL_DISTRIBUTED_BACKEND", None)