Source code for pytorch_lightning.accelerators.gpu
# Copyright The PyTorch Lightning team.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.importloggingimportosimportshutilimportsubprocessfromtypingimportAny,Dict,List,Optional,Unionimporttorchimportpytorch_lightningasplfrompytorch_lightning.accelerators.acceleratorimportAcceleratorfrompytorch_lightning.utilitiesimportdevice_parserfrompytorch_lightning.utilities.exceptionsimportMisconfigurationExceptionfrompytorch_lightning.utilities.typesimport_DEVICE_log=logging.getLogger(__name__)
class GPUAccelerator(Accelerator):
    """Accelerator for GPU devices."""

    def setup_environment(self, root_device: torch.device) -> None:
        """Bind the process to the given CUDA device.

        Raises:
            MisconfigurationException: If the selected device is not GPU.
        """
        super().setup_environment(root_device)
        if root_device.type != "cuda":
            raise MisconfigurationException(f"Device should be GPU, got {root_device} instead")
        torch.cuda.set_device(root_device)

    def setup(self, trainer: "pl.Trainer") -> None:
        """Log device visibility and free cached CUDA memory before training."""
        # TODO refactor input from trainer to local_rank @four4fish
        self.set_nvidia_flags(trainer.local_rank)
        # clear cache before training
        torch.cuda.empty_cache()

    @staticmethod
    def set_nvidia_flags(local_rank: int) -> None:
        """Pin CUDA device enumeration to PCI bus order and log the visible devices."""
        # set the correct cuda visible devices (using pci order)
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        all_gpu_ids = ",".join(str(x) for x in range(torch.cuda.device_count()))
        devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
        _log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]")

    def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
        """Gets stats for the given GPU device.

        Args:
            device: GPU device for which to get stats

        Returns:
            A dictionary mapping the metrics to their values.
        """
        # NOTE(review): the previous docstring claimed this raises FileNotFoundError
        # when nvidia-smi is missing; that clause was copied from the module-level
        # ``get_nvidia_gpu_stats`` helper and is false here — this method only
        # queries ``torch.cuda.memory_stats`` and never invokes nvidia-smi.
        return torch.cuda.memory_stats(device)

    @staticmethod
    def get_parallel_devices(devices: List[int]) -> List[torch.device]:
        """Gets parallel devices for the Accelerator."""
        return [torch.device("cuda", i) for i in devices]

    @staticmethod
    def auto_device_count() -> int:
        """Get the devices when set to auto."""
        return torch.cuda.device_count()
defget_nvidia_gpu_stats(device:_DEVICE)->Dict[str,float]:# pragma: no-cover"""Get GPU stats including memory, fan speed, and temperature from nvidia-smi. Args: device: GPU device for which to get stats Returns: A dictionary mapping the metrics to their values. Raises: FileNotFoundError: If nvidia-smi installation not found """nvidia_smi_path=shutil.which("nvidia-smi")ifnvidia_smi_pathisNone:raiseFileNotFoundError("nvidia-smi: command not found")gpu_stat_metrics=[("utilization.gpu","%"),("memory.used","MB"),("memory.free","MB"),("utilization.memory","%"),("fan.speed","%"),("temperature.gpu","°C"),("temperature.memory","°C"),]gpu_stat_keys=[kfork,_ingpu_stat_metrics]gpu_query=",".join(gpu_stat_keys)index=torch._utils._get_device_index(device)gpu_id=_get_gpu_id(index)result=subprocess.run([nvidia_smi_path,f"--query-gpu={gpu_query}","--format=csv,nounits,noheader",f"--id={gpu_id}"],encoding="utf-8",capture_output=True,check=True,)def_to_float(x:str)->float:try:returnfloat(x)exceptValueError:return0.0s=result.stdout.strip()stats=[_to_float(x)forxins.split(", ")]gpu_stats={f"{x} ({unit})":statfor(x,unit),statinzip(gpu_stat_metrics,stats)}returngpu_statsdef_get_gpu_id(device_id:int)->str:"""Get the unmasked real GPU IDs."""# All devices if `CUDA_VISIBLE_DEVICES` unsetdefault=",".join(str(i)foriinrange(torch.cuda.device_count()))cuda_visible_devices=os.getenv("CUDA_VISIBLE_DEVICES",default=default).split(",")returncuda_visible_devices[device_id].strip()
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Read PyTorch Lightning's Privacy Policy.