Source code for lightning_fabric.accelerators.cuda
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from contextlib import contextmanager
from functools import lru_cache
from typing import cast, Dict, Generator, List, Optional, Union

import torch
from lightning_utilities.core.rank_zero import rank_zero_info

from lightning_fabric.accelerators.accelerator import Accelerator
from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_0
class CUDAAccelerator(Accelerator):
    """Accelerator for NVIDIA CUDA devices."""
    def setup_device(self, device: torch.device) -> None:
        """
        Raises:
            ValueError:
                If the selected device is not of type CUDA.
        """
        if device.type != "cuda":
            raise ValueError(f"Device should be CUDA, got {device} instead.")
        _check_cuda_matmul_precision(device)
        torch.cuda.set_device(device)
    @staticmethod
    def get_parallel_devices(devices: List[int]) -> List[torch.device]:
        """Gets parallel devices for the Accelerator."""
        return [torch.device("cuda", i) for i in devices]
    @staticmethod
    def auto_device_count() -> int:
        """Get the number of devices when set to auto."""
        return num_cuda_devices()
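# --- Illustrative example (ours, not part of the original module) -------------
# A minimal sketch of how the accelerator turns raw indices into `torch.device`
# objects and initializes one of them. Wrapped in a function so that this
# listing stays side-effect free on import; the device indices are hypothetical.
def _example_cuda_accelerator() -> None:
    accelerator = CUDAAccelerator()
    devices = accelerator.get_parallel_devices([0, 1])  # [cuda:0, cuda:1] device objects
    if is_cuda_available():
        accelerator.setup_device(devices[0])  # raises ValueError for non-CUDA devices
    print(f"auto device count: {CUDAAccelerator.auto_device_count()}")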
def find_usable_cuda_devices(num_devices: int = -1) -> List[int]:
    """Returns a list of all available and usable CUDA GPU devices.

    A GPU is considered usable if we can successfully move a tensor to the device, and this is what this function
    tests for each GPU on the system until the target number of usable devices is found. A subset of GPUs on the
    system might be used by other processes, and if the GPU is configured to operate in 'exclusive' mode
    (configurable by the admin), then only one process is allowed to occupy it.

    Args:
        num_devices: The number of devices you want to request. By default, this function will return as many as
            there are usable CUDA GPU devices available.

    Warning:
        If multiple processes call this function at the same time, there can be race conditions in the case where
        both processes determine that the device is unoccupied, leading to one of them crashing later on.
    """
    visible_devices = _get_all_visible_cuda_devices()
    if not visible_devices:
        raise ValueError(
            f"You requested to find {num_devices} devices but there are no visible CUDA devices on this machine."
        )
    if num_devices > len(visible_devices):
        raise ValueError(
            f"You requested to find {num_devices} devices but this machine only has {len(visible_devices)} GPUs."
        )

    available_devices = []
    unavailable_devices = []

    for gpu_idx in visible_devices:
        try:
            torch.tensor(0, device=torch.device("cuda", gpu_idx))
        except RuntimeError:
            unavailable_devices.append(gpu_idx)
            continue

        available_devices.append(gpu_idx)
        if len(available_devices) == num_devices:
            # exit early if we found the right number of GPUs
            break

    if num_devices != -1 and len(available_devices) != num_devices:
        raise RuntimeError(
            f"You requested to find {num_devices} devices but only {len(available_devices)} are currently available."
            f" The devices {unavailable_devices} are occupied by other processes and can't be used at the moment."
        )
    return available_devices


def _get_all_visible_cuda_devices() -> List[int]:
    """Returns a list of all visible CUDA GPU devices.

    Devices masked by the environment variable ``CUDA_VISIBLE_DEVICES`` won't be returned here. For example, assume
    you have 8 physical GPUs. If ``CUDA_VISIBLE_DEVICES="1,3,6"``, then this function will return the list
    ``[0, 1, 2]`` because these are the three visible GPUs after applying the mask ``CUDA_VISIBLE_DEVICES``.
    """
    return list(range(num_cuda_devices()))


# TODO: Remove once minimum supported PyTorch version is 2.0
@contextmanager
def _patch_cuda_is_available() -> Generator:
    """Context manager that safely patches :func:`torch.cuda.is_available` with its NVML-based version if
    possible."""
    if hasattr(torch._C, "_cuda_getDeviceCount") and _device_count_nvml() >= 0 and not _TORCH_GREATER_EQUAL_2_0:
        # we can safely patch is_available if both torch has CUDA compiled and the NVML count is succeeding
        # otherwise, patching is_available could lead to attribute errors or infinite recursion
        orig_check = torch.cuda.is_available
        torch.cuda.is_available = is_cuda_available
        try:
            yield
        finally:
            torch.cuda.is_available = orig_check
    else:
        yield
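# --- Illustrative example (ours, not part of the original module) -------------
# A minimal sketch of `find_usable_cuda_devices` in practice: request exactly
# two free GPUs, or omit the argument to take every GPU that accepts a tensor
# allocation. The index values in the comments are hypothetical.
def _example_find_usable_devices() -> None:
    exactly_two = find_usable_cuda_devices(2)  # e.g. [0, 3] on a busy 4-GPU box
    all_usable = find_usable_cuda_devices()  # e.g. [0, 2, 3]
    print(exactly_two, all_usable)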
"""if_TORCH_GREATER_EQUAL_2_0:returntorch.cuda.device_count()# Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879# TODO: Remove once minimum supported PyTorch version is 2.0nvml_count=_device_count_nvml()returntorch.cuda.device_count()ifnvml_count<0elsenvml_countdefis_cuda_available()->bool:"""Returns a bool indicating if CUDA is currently available. Unlike :func:`torch.cuda.is_available`, this function does its best not to create a CUDA context for fork support, if the platform allows it. """# We set `PYTORCH_NVML_BASED_CUDA_CHECK=1` in lightning_fabric.__init__.pyreturntorch.cuda.is_available()if_TORCH_GREATER_EQUAL_2_0elsenum_cuda_devices()>0# TODO: Remove once minimum supported PyTorch version is 2.0def_parse_visible_devices()->Union[List[int],List[str]]:"""Parse CUDA_VISIBLE_DEVICES environment variable."""var=os.getenv("CUDA_VISIBLE_DEVICES")ifvarisNone:returnlist(range(64))def_strtoul(s:str)->int:"""Return -1 or positive integer sequence string starts with,"""ifnots:return-1foridx,cinenumerate(s):ifnot(c.isdigit()or(idx==0andcin"+-")):breakifidx+1==len(s):idx+=1returnint(s[:idx])ifidx>0else-1defparse_list_with_prefix(lst:str,prefix:str)->List[str]:rcs:List[str]=[]foreleminlst.split(","):# Repeated id results in empty setifeleminrcs:returncast(List[str],[])# Anything other but prefix is ignoredifnotelem.startswith(prefix):breakrcs.append(elem)returnrcsifvar.startswith("GPU-"):returnparse_list_with_prefix(var,"GPU-")ifvar.startswith("MIG-"):returnparse_list_with_prefix(var,"MIG-")# CUDA_VISIBLE_DEVICES uses something like strtoul# which makes `1gpu2,2ampere` is equivalent to `1,2`rc:List[int]=[]foreleminvar.split(","):x=_strtoul(elem.strip())# Repeated ordinal results in empty setifxinrc:returncast(List[int],[])# Negative value aborts the sequenceifx<0:breakrc.append(x)returnrc# TODO: Remove once minimum supported PyTorch version is 2.0def_raw_device_count_nvml()->int:"""Return number of devices as reported by NVML or negative value if NVML discovery/initialization failed."""fromctypesimportbyref,c_int,CDLLnvml_h=CDLL("libnvidia-ml.so.1")rc=nvml_h.nvmlInit()ifrc!=0:warnings.warn("Can't initialize NVML")return-1dev_count=c_int(-1)rc=nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))ifrc!=0:warnings.warn("Can't get nvml device count")return-1delnvml_hreturndev_count.value# TODO: Remove once minimum supported PyTorch version is 2.0def_raw_device_uuid_nvml()->Optional[List[str]]:"""Return list of device UUID as reported by NVML or None if NVM discovery/initialization failed."""fromctypesimportbyref,c_int,c_void_p,CDLL,create_string_buffernvml_h=CDLL("libnvidia-ml.so.1")rc=nvml_h.nvmlInit()ifrc!=0:warnings.warn("Can't initialize NVML")returnNonedev_count=c_int(-1)rc=nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))ifrc!=0:warnings.warn("Can't get nvml device count")returnNoneuuids:List[str]=[]foridxinrange(dev_count.value):dev_id=c_void_p()rc=nvml_h.nvmlDeviceGetHandleByIndex_v2(idx,byref(dev_id))ifrc!=0:warnings.warn("Can't get device handle")returnNonebuf_len=96buf=create_string_buffer(buf_len)rc=nvml_h.nvmlDeviceGetUUID(dev_id,buf,buf_len)ifrc!=0:warnings.warn("Can't get device UUID")returnNoneuuids.append(buf.raw.decode("ascii").strip("\0"))delnvml_hreturnuuids# TODO: Remove once minimum supported PyTorch version is 2.0def_transform_uuid_to_ordinals(candidates:List[str],uuids:List[str])->List[int]:"""Given the set of partial uuids and list of known uuids builds a set of ordinals excluding ambiguous partials 
IDs."""defuuid_to_orinal(candidate:str,uuids:List[str])->int:best_match=-1foridx,uuidinenumerate(uuids):ifnotuuid.startswith(candidate):continue# Ambigous candidateifbest_match!=-1:return-1best_match=idxreturnbest_matchrc:List[int]=[]forcandidateincandidates:idx=uuid_to_orinal(candidate,uuids)# First invalid ordinal stops parsingifidx<0:break# Duplicates result in empty setifidxinrc:returncast(List[int],[])rc.append(idx)returnrc# TODO: Remove once minimum supported PyTorch version is 2.0def_device_count_nvml()->int:"""Return number of devices as reported by NVML taking CUDA_VISIBLE_DEVICES into account. Negative value is returned if NVML discovery or initialization has failed. """visible_devices=_parse_visible_devices()ifnotvisible_devices:return0try:iftype(visible_devices[0])isstr:# Skip MIG parsingifvisible_devices[0].startswith("MIG-"):return-1uuids=_raw_device_uuid_nvml()ifuuidsisNone:return-1visible_devices=_transform_uuid_to_ordinals(cast(List[str],visible_devices),uuids)else:raw_cnt=_raw_device_count_nvml()ifraw_cnt<=0:returnraw_cnt# Trim the list up to a maximum available deviceforidx,valinenumerate(visible_devices):ifcast(int,val)>=raw_cnt:returnidxexceptOSError:return-1exceptAttributeError:return-1returnlen(visible_devices)def_check_cuda_matmul_precision(device:torch.device)->None:ifnot_TORCH_GREATER_EQUAL_1_12:# before 1.12, tf32 was used by defaultreturnmajor,_=torch.cuda.get_device_capability(device)ampere_or_later=major>=8# Ampere and later leverage tensor cores, where this setting becomes usefulifnotampere_or_later:return# check that the user hasn't changed the precision already, this works for both `allow_tf32 = True` and# `set_float32_matmul_precision`iftorch.get_float32_matmul_precision()=="highest":# defaultrank_zero_info(f"You are using a CUDA device ({torch.cuda.get_device_name(device)!r}) that has Tensor Cores. To properly"" utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off"" precision for performance. For more details, read https://pytorch.org/docs/stable/generated/""torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision")# note: no need change `torch.backends.cudnn.allow_tf32` as it's enabled by default:# https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devicesdef_clear_cuda_memory()->None:if_TORCH_GREATER_EQUAL_2_0:# https://github.com/pytorch/pytorch/issues/95668torch._C._cuda_clearCublasWorkspaces()torch.cuda.empty_cache()