# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import queue as q
import traceback
from multiprocessing import Process, Queue
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from lightning_utilities.core.imports import ModuleAvailableCache

from lightning_fabric.accelerators.accelerator import Accelerator
from lightning_fabric.utilities.device_parser import _check_data_type


class TPUAccelerator(Accelerator):
    """Accelerator for TPU devices."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        if not _XLA_AVAILABLE:
            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
        super().__init__(*args, **kwargs)

    @staticmethod
    def get_parallel_devices(devices: Union[int, List[int]]) -> List[int]:
        """Gets parallel devices for the Accelerator."""
        if isinstance(devices, int):
            return list(range(devices))
        return devices

    @staticmethod
    def auto_device_count() -> int:
        """Get the devices when set to auto."""
        return 8

    @staticmethod
    @functools.lru_cache(maxsize=1)
    def is_available() -> bool:
        # check `_XLA_AVAILABLE` again to avoid launching processes
        return bool(_XLA_AVAILABLE) and _is_device_tpu()

# define TPU availability timeout in seconds
TPU_CHECK_TIMEOUT = 60


def _inner_f(queue: Queue, func: Callable, *args: Any, **kwargs: Any) -> None:  # pragma: no cover
    try:
        queue.put(func(*args, **kwargs))
    except Exception:
        traceback.print_exc()
        queue.put(None)


def _multi_process(func: Callable) -> Callable:
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Union[bool, Any]:
        queue: Queue = Queue()
        proc = Process(target=_inner_f, args=(queue, func, *args), kwargs=kwargs)
        proc.start()
        proc.join(TPU_CHECK_TIMEOUT)
        try:
            return queue.get_nowait()
        except q.Empty:
            traceback.print_exc()
            return False

    return wrapper


@_multi_process
def _is_device_tpu() -> bool:
    """Check if TPU devices are available. Runs the XLA device check within a separate process.

    Return:
        A boolean value indicating if TPU devices are available
    """
    if not _XLA_AVAILABLE:
        return False
    import torch_xla.core.xla_model as xm

    # For the TPU Pod training process, for example, if we have
    # TPU v3-32 with 4 VMs, the world size would be 4 and as
    # we would have to use `torch_xla.distributed.xla_dist` for
    # multiple VMs and TPU_CONFIG won't be available, running
    # `xm.get_xla_supported_devices("TPU")` won't be possible.
    return (xm.xrt_world_size() > 1) or bool(xm.get_xla_supported_devices("TPU"))


_XLA_AVAILABLE = ModuleAvailableCache("torch_xla")


def _tpu_distributed() -> bool:
    if not TPUAccelerator.is_available():
        return False
    import torch_xla.core.xla_model as xm

    return xm.xrt_world_size() > 1


def _parse_tpu_devices(devices: Optional[Union[int, str, List[int]]]) -> Optional[Union[int, List[int]]]:
    """Parses the TPU devices given in the format accepted by the
    :class:`~pytorch_lightning.trainer.Trainer` and :class:`~lightning_fabric.Fabric`.

    Args:
        devices: An int of 1 or string '1' indicates that 1 core with multi-processing should be used.
            An int of 8 or string '8' indicates that all 8 cores with multi-processing should be used.
            A list of ints or a string containing a comma-separated list of integers
            indicates the specific TPU core to use.

    Returns:
        A list of tpu_cores to be used or ``None`` if no TPU cores were requested

    Raises:
        TypeError:
            If TPU devices aren't 1, 8 or [<1-8>]
    """
    _check_data_type(devices)

    if isinstance(devices, str):
        devices = _parse_tpu_devices_str(devices.strip())

    if not _tpu_devices_valid(devices):
        raise TypeError("`devices` can only be 1, 8 or [<1-8>] for TPUs.")

    return devices


def _tpu_devices_valid(devices: Any) -> bool:
    # allow 1 or 8 cores
    if devices in (1, 8, None):
        return True

    # allow picking 1 of 8 indexes
    if isinstance(devices, (list, tuple, set)):
        has_1_tpu_idx = len(devices) == 1
        is_valid_tpu_idx = 1 <= list(devices)[0] <= 8

        is_valid_tpu_core_choice = has_1_tpu_idx and is_valid_tpu_idx
        return is_valid_tpu_core_choice

    return False


def _parse_tpu_devices_str(devices: str) -> Union[int, List[int]]:
    if devices in ("1", "8"):
        return int(devices)
    return [int(x.strip()) for x in devices.split(",") if len(x) > 0]
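

# ------------------------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hypothetical example of how the
# device-parsing helpers above behave. It only exercises `_parse_tpu_devices` and
# `TPUAccelerator.get_parallel_devices`, which do not need a TPU or `torch_xla` at runtime,
# so the assertions below follow directly from the code in this file.
# ------------------------------------------------------------------------------------------
if __name__ == "__main__":
    # The strings "1" and "8" parse to the integers 1 and 8 (single-core / all-core multiprocessing).
    assert _parse_tpu_devices("8") == 8
    # A comma-separated string selects one specific core, e.g. core 5.
    assert _parse_tpu_devices("5,") == [5]
    # A one-element list with an index between 1 and 8 is also accepted.
    assert _parse_tpu_devices([3]) == [3]
    # An integer device count expands to a list of device indices.
    assert TPUAccelerator.get_parallel_devices(8) == list(range(8))
    # Any other request (e.g. 4 cores) is rejected.
    try:
        _parse_tpu_devices(4)
    except TypeError as err:
        print(f"rejected as expected: {err}")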