Source code for lightning.fabric.plugins.environments.xla
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import logging
from typing import Any

from typing_extensions import override

from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, _XLA_GREATER_EQUAL_2_1, XLAAccelerator
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment

log = logging.getLogger(__name__)
class XLAEnvironment(ClusterEnvironment):
    """Cluster environment for training on a TPU Pod with the `PyTorch/XLA <https://pytorch.org/xla>`_ library.

    A list of environment variables set by XLA can be found
    `here <https://github.com/pytorch/xla/blob/master/torch_xla/core/xla_env_vars.py>`_.

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        if not _XLA_AVAILABLE:
            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
        super().__init__(*args, **kwargs)

    @property
    @override
    def creates_processes_externally(self) -> bool:
        return False

    @property
    @override
    def main_address(self) -> str:
        # unused by lightning
        raise NotImplementedError

    @property
    @override
    def main_port(self) -> int:
        # unused by lightning
        raise NotImplementedError

    @staticmethod
    @override
    def detect() -> bool:
        # Report whether XLA devices are available on this machine.
        return XLAAccelerator.is_available()
    @override
    @functools.lru_cache(maxsize=1)
    def world_size(self) -> int:
        """The number of processes across all devices and hosts.

        The output is cached for performance.

        """
        import torch_xla.core.xla_model as xm

        return xm.xrt_world_size()
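    # Caching note (a sketch, not part of the upstream source):
    # `functools.lru_cache(maxsize=1)` memoizes the result keyed on `self`, so
    # repeated calls on the same instance skip the torch_xla query, e.g.:
    #
    #   env = XLAEnvironment()
    #   env.world_size()  # queries torch_xla.core.xla_model once
    #   env.world_size()  # served from the cache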
    @override
    def set_world_size(self, size: int) -> None:
        log.debug("XLAEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
    @override
    @functools.lru_cache(maxsize=1)
    def global_rank(self) -> int:
        """The rank (index) of the currently running process across all hosts and devices.

        The output is cached for performance.

        """
        import torch_xla.core.xla_model as xm

        return xm.get_ordinal()
    @override
    def set_global_rank(self, rank: int) -> None:
        log.debug("XLAEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
    @override
    @functools.lru_cache(maxsize=1)
    def local_rank(self) -> int:
        """The rank (index) of the currently running process within the current host.

        The output is cached for performance.

        """
        import torch_xla.core.xla_model as xm

        return xm.get_local_ordinal()
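    # Worked illustration (a sketch, not part of the upstream source): on a
    # pod with two hosts of four devices each, world_size() is 8, and the
    # process on host 1, device 2 reports global_rank() == 6,
    # local_rank() == 2 and node_rank() == 1.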
    @override
    @functools.lru_cache(maxsize=1)
    def node_rank(self) -> int:
        """The rank (index) of the host on which the current process runs.

        The output is cached for performance.

        """
        if _XLA_GREATER_EQUAL_2_1:
            from torch_xla import runtime as xr

            return xr.host_index()
        # Older torch_xla releases expose the host index only through the
        # HOST_ORDINAL environment variable; default to 0 on a single host.
        import torch_xla.core.xla_env_vars as xenv
        from torch_xla.utils.utils import getenv_as

        return getenv_as(xenv.HOST_ORDINAL, int, 0)
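A minimal usage sketch, assuming torch_xla is installed and TPU devices are available (train_fn is a hypothetical worker function, not part of this module): each process spawned by torch_xla can instantiate the environment and read the pod topology.

    import torch_xla.distributed.xla_multiprocessing as xmp

    from lightning.fabric.plugins.environments.xla import XLAEnvironment


    def train_fn(index: int) -> None:
        # Every spawned process sees the same world size but its own ranks.
        env = XLAEnvironment()
        print(env.world_size(), env.global_rank(), env.local_rank(), env.node_rank())


    if __name__ == "__main__":
        xmp.spawn(train_fn)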