Source code for lightning.pytorch.profilers.pytorch
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Profiler to check if there are any bottlenecks in your code."""

import inspect
import logging
import os
from contextlib import AbstractContextManager
from functools import lru_cache, partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import torch
from torch import Tensor, nn
from torch.autograd.profiler import EventList, record_function
from torch.profiler import ProfilerAction, ProfilerActivity, tensorboard_trace_handler
from torch.utils.hooks import RemovableHandle
from typing_extensions import override

from lightning.fabric.accelerators.cuda import is_cuda_available
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.pytorch.profilers.profiler import Profiler
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn

if TYPE_CHECKING:
    from lightning.pytorch.core.module import LightningModule

log = logging.getLogger(__name__)
warning_cache = WarningCache()

_PROFILER = Union[torch.profiler.profile, torch.autograd.profiler.profile, torch.autograd.profiler.emit_nvtx]
_KINETO_AVAILABLE = torch.profiler.kineto_available()


class RegisterRecordFunction:
    """While profiling autograd operations, this class adds labels with the module names around the forward
    function.

    The Lightning PyTorch Profiler will activate this feature automatically. It can be deactivated as follows:

    Example::

        from lightning.pytorch.profilers import PyTorchProfiler

        profiler = PyTorchProfiler(record_module_names=False)
        Trainer(profiler=profiler)

    It can be used outside of Lightning as follows:

    Example::

        from lightning.pytorch.profilers.pytorch import RegisterRecordFunction

        with RegisterRecordFunction(model):
            out = model(batch)

    """

    def __init__(self, model: nn.Module) -> None:
        self._model = model
        self._records: dict[str, record_function] = {}
        self._handles: dict[str, list[RemovableHandle]] = {}

    def _start_recording_forward(self, _: nn.Module, input: Tensor, record_name: str) -> Tensor:
        # Add [pl][module] in name for pytorch profiler to recognize
        record = record_function("[pl][module]" + record_name)
        record.__enter__()
        self._records[record_name] = record
        return input

    def _stop_recording_forward(self, _: nn.Module, __: Tensor, output: Tensor, record_name: str) -> Tensor:
        self._records[record_name].__exit__(None, None, None)
        return output

    def __enter__(self) -> None:
        for module_name, module in self._model.named_modules():
            if module_name:
                full_name = f"{type(module).__module__}.{type(module).__name__}"
                record_name = f"{full_name}: {module_name}"
                pre_forward_handle = module.register_forward_pre_hook(
                    partial(self._start_recording_forward, record_name=record_name)
                )
                post_forward_handle = module.register_forward_hook(
                    partial(self._stop_recording_forward, record_name=record_name)
                )
                self._handles[module_name] = [pre_forward_handle, post_forward_handle]

    def __exit__(self, type: Any, value: Any, traceback: Any) -> None:
        for handles in self._handles.values():
            for h in handles:
                h.remove()
        self._handles = {}


class ScheduleWrapper:
    """This class is used to override the schedule logic from the profiler and perform recording for both
    `training_step` and `validation_step`."""

    def __init__(self, schedule: Callable) -> None:
        if not _KINETO_AVAILABLE:
            raise ModuleNotFoundError("You are trying to use `ScheduleWrapper`, which requires kineto to be installed.")
        self._schedule = schedule
        self.reset()

    def reset(self) -> None:
        # handle `fast_dev_run` properly. PyTorch Profiler will fail otherwise.
        self._num_training_step = 0
        self._num_validation_step = 0
        self._num_test_step = 0
        self._num_predict_step = 0
        self._training_step_reached_end = False
        self._validation_step_reached_end = False
        self._test_step_reached_end = False
        self._predict_step_reached_end = False
        # used to stop profiler when `ProfilerAction.RECORD_AND_SAVE` is reached.
        self._current_action: Optional[str] = None
        self._prev_schedule_action: Optional[ProfilerAction] = None
        self._start_action_name: Optional[str] = None

    def setup(self, start_action_name: str) -> None:
        self._start_action_name = start_action_name

    def pre_step(self, current_action: str) -> None:
        self._current_action = current_action

    @property
    def is_training(self) -> bool:
        assert self._current_action is not None
        return self._current_action.endswith("training_step")

    @property
    def is_validating(self) -> bool:
        assert self._current_action is not None
        return self._current_action.endswith("validation_step")

    @property
    def is_testing(self) -> bool:
        assert self._current_action is not None
        return self._current_action.endswith("test_step")

    @property
    def is_predicting(self) -> bool:
        assert self._current_action is not None
        return self._current_action.endswith("predict_step")

    @property
    def num_step(self) -> int:
        if self.is_training:
            return self._num_training_step
        if self.is_validating:
            return self._num_validation_step
        if self.is_testing:
            return self._num_test_step
        if self.is_predicting:
            return self._num_predict_step
        return 0

    def _step(self) -> None:
        if self.is_training:
            self._num_training_step += 1
        elif self.is_validating:
            assert self._start_action_name is not None
            if self._start_action_name.endswith("on_fit_start"):
                if self._num_training_step > 0:
                    self._num_validation_step += 1
            else:
                self._num_validation_step += 1
        elif self.is_testing:
            self._num_test_step += 1
        elif self.is_predicting:
            self._num_predict_step += 1

    @property
    def has_finished(self) -> bool:
        if self.is_training:
            return self._training_step_reached_end
        if self.is_validating:
            return self._validation_step_reached_end
        if self.is_testing:
            return self._test_step_reached_end
        if self.is_predicting:
            return self._predict_step_reached_end
        return False

    def __call__(self, num_step: int) -> "ProfilerAction":
        # ignore the provided input. Keep internal state instead.
        if self._current_action is None or self.has_finished:
            return ProfilerAction.NONE

        self._step()
        action = self._schedule(max(self.num_step, 0))
        if self._prev_schedule_action == ProfilerAction.RECORD and action == ProfilerAction.WARMUP:
            # Work around the corner case when validation starts before training.
            # In this case, the action is RECORD in the validation loop, then execution moves into the
            # training loop where the action is still WARMUP, which PyTorch would flag as an error.
            action = ProfilerAction.RECORD
        if action == ProfilerAction.RECORD_AND_SAVE:
            if self.is_training:
                self._training_step_reached_end = True
            elif self.is_validating:
                self._validation_step_reached_end = True
            elif self.is_testing:
                self._test_step_reached_end = True
            elif self.is_predicting:
                self._predict_step_reached_end = True
        self._prev_schedule_action = action
        return action
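# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): how a ``torch.profiler``
# schedule maps consecutive step counts to ``ProfilerAction`` values.
# ``ScheduleWrapper`` above feeds its own per-stage step counter into such a
# callable instead of the profiler's global counter. ``default_schedule`` and
# ``actions`` below are hypothetical names used only for this sketch.
#
#     from torch.profiler import schedule
#
#     default_schedule = schedule(wait=1, warmup=1, active=3)
#     actions = [default_schedule(step) for step in range(5)]
#     # -> [NONE, WARMUP, RECORD, RECORD, RECORD_AND_SAVE]
# ---------------------------------------------------------------------------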
class PyTorchProfiler(Profiler):
    STEP_FUNCTIONS = {"training_step", "validation_step", "test_step", "predict_step"}
    AVAILABLE_SORT_KEYS = {
        "cpu_time",
        "cuda_time",
        "cpu_time_total",
        "cuda_time_total",
        "cpu_memory_usage",
        "cuda_memory_usage",
        "self_cpu_memory_usage",
        "self_cuda_memory_usage",
        "count",
    }

    def __init__(
        self,
        dirpath: Optional[Union[str, Path]] = None,
        filename: Optional[str] = None,
        group_by_input_shapes: bool = False,
        emit_nvtx: bool = False,
        export_to_chrome: bool = True,
        row_limit: int = 20,
        sort_by_key: Optional[str] = None,
        record_module_names: bool = True,
        table_kwargs: Optional[dict[str, Any]] = None,
        **profiler_kwargs: Any,
    ) -> None:
        r"""This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of
        different operators inside your model - both on the CPU and GPU.

        Args:
            dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present,
                the ``trainer.log_dir`` (from :class:`~lightning.pytorch.loggers.tensorboard.TensorBoardLogger`)
                will be used.

            filename: If present, filename where the profiler results will be saved instead of printing to stdout.
                The ``.txt`` extension will be used automatically.

            group_by_input_shapes: Include operator input shapes and group calls by shape.

            emit_nvtx: Context manager that makes every autograd operation emit an NVTX range. Run::

                    nvprof --profile-from-start off -o trace_name.prof -- <regular command here>

                To visualize, you can either use::

                    nvvp trace_name.prof
                    torch.autograd.profiler.load_nvprof(path)

            export_to_chrome: Whether to export the sequence of profiled operators for Chrome.
                It will generate a ``.json`` file which can be read by Chrome.

            row_limit: Limit the number of rows in a table, ``-1`` is a special value that
                removes the limit completely.

            sort_by_key: Attribute used to sort entries. By default
                they are printed in the same order as they were registered.
                Valid keys include: ``cpu_time``, ``cuda_time``, ``cpu_time_total``,
                ``cuda_time_total``, ``cpu_memory_usage``, ``cuda_memory_usage``,
                ``self_cpu_memory_usage``, ``self_cuda_memory_usage``, ``count``.

            record_module_names: Whether to add module names while recording autograd operations.

            table_kwargs: Dictionary with keyword arguments for the summary table.

            \**profiler_kwargs: Keyword arguments for the PyTorch profiler. This depends on your PyTorch version.

        Raises:
            MisconfigurationException:
                If arg ``sort_by_key`` is not present in ``AVAILABLE_SORT_KEYS``.
                If arg ``schedule`` is not a ``Callable``.
                If arg ``schedule`` does not return a ``torch.profiler.ProfilerAction``.

        """
        super().__init__(dirpath=dirpath, filename=filename)

        self._group_by_input_shapes = group_by_input_shapes and profiler_kwargs.get("record_shapes", False)
        self._emit_nvtx = emit_nvtx
        self._export_to_chrome = export_to_chrome
        self._row_limit = row_limit
        self._sort_by_key = sort_by_key or _default_sort_by_key(profiler_kwargs)
        self._record_module_names = record_module_names
        self._profiler_kwargs = profiler_kwargs
        self._table_kwargs = table_kwargs if table_kwargs is not None else {}

        self.profiler: Optional[_PROFILER] = None
        self.function_events: Optional[EventList] = None
        self._lightning_module: Optional["LightningModule"] = None  # set by ProfilerConnector
        self._register: Optional[RegisterRecordFunction] = None
        self._parent_profiler: Optional[AbstractContextManager] = None
        self._recording_map: dict[str, record_function] = {}
        self._start_action_name: Optional[str] = None
        self._schedule: Optional[ScheduleWrapper] = None

        if _KINETO_AVAILABLE:
            self._init_kineto(profiler_kwargs)

        if self._sort_by_key not in self.AVAILABLE_SORT_KEYS:
            raise MisconfigurationException(
                f"Found sort_by_key: {self._sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}."
            )

        for key in self._table_kwargs:
            if key in {"sort_by", "row_limit"}:
                raise KeyError(
                    f"Found invalid table_kwargs key: {key}. This is already a positional argument of the Profiler."
                )
            valid_table_keys = set(inspect.signature(EventList.table).parameters.keys()) - {
                "self",
                "sort_by",
                "row_limit",
            }
            if key not in valid_table_keys:
                raise KeyError(f"Found invalid table_kwargs key: {key}. Should be within {valid_table_keys}.")

    def _init_kineto(self, profiler_kwargs: Any) -> None:
        has_schedule = "schedule" in profiler_kwargs
        self._has_on_trace_ready = "on_trace_ready" in profiler_kwargs

        schedule = profiler_kwargs.get("schedule", None)
        if schedule is not None:
            if not callable(schedule):
                raise MisconfigurationException(f"Schedule should be a callable. Found: {schedule}")
            action = schedule(0)
            if not isinstance(action, ProfilerAction):
                raise MisconfigurationException(
                    f"Schedule should return a `torch.profiler.ProfilerAction`. Found: {action}"
                )
        self._default_schedule()
        schedule = schedule if has_schedule else self._default_schedule()
        self._schedule = ScheduleWrapper(schedule) if schedule is not None else schedule
        self._profiler_kwargs["schedule"] = self._schedule

        activities = profiler_kwargs.get("activities", None)
        self._profiler_kwargs["activities"] = activities or self._default_activities()
        self._export_to_flame_graph = profiler_kwargs.get("export_to_flame_graph", False)
        self._metric = profiler_kwargs.get("metric", "self_cpu_time_total")
        with_stack = profiler_kwargs.get("with_stack", False) or self._export_to_flame_graph
        self._profiler_kwargs["with_stack"] = with_stack

    @property
    def _total_steps(self) -> Union[int, float]:
        assert self._schedule is not None
        assert self._lightning_module is not None
        trainer = self._lightning_module.trainer
        if self._schedule.is_training:
            return trainer.num_training_batches
        if self._schedule.is_validating:
            num_val_batches = (
                sum(trainer.num_val_batches)
                if isinstance(trainer.num_val_batches, list)
                else trainer.num_val_batches
            )
            num_sanity_val_batches = (
                sum(trainer.num_sanity_val_batches)
                if isinstance(trainer.num_sanity_val_batches, list)
                else trainer.num_sanity_val_batches
            )
            return num_val_batches + num_sanity_val_batches
        if self._schedule.is_testing:
            num_test_batches = (
                sum(trainer.num_test_batches)
                if isinstance(trainer.num_test_batches, list)
                else trainer.num_test_batches
            )
            return num_test_batches
        if self._schedule.is_predicting:
            return sum(trainer.num_predict_batches)
        raise NotImplementedError("Unsupported schedule")

    def _should_override_schedule(self) -> bool:
        return (
            self._lightning_module is not None
            and self._schedule is not None
            and self._total_steps < 5
            and self._schedule._schedule == self._default_schedule()
        )

    @staticmethod
    @lru_cache(1)
    def _default_schedule() -> Optional[Callable]:
        if _KINETO_AVAILABLE:
            # Those schedule defaults allow the profiling overhead to be negligible over training time.
            return torch.profiler.schedule(wait=1, warmup=1, active=3)
        return None

    def _default_activities(self) -> list["ProfilerActivity"]:
        activities: list[ProfilerActivity] = []
        if not _KINETO_AVAILABLE:
            return activities
        if _TORCH_GREATER_EQUAL_2_4:
            activities.append(ProfilerActivity.CPU)
            if is_cuda_available():
                activities.append(ProfilerActivity.CUDA)
        else:
            # `use_cpu` and `use_cuda` are deprecated in PyTorch >= 2.4
            if self._profiler_kwargs.get("use_cpu", True):
                activities.append(ProfilerActivity.CPU)
            if self._profiler_kwargs.get("use_cuda", is_cuda_available()):
                activities.append(ProfilerActivity.CUDA)
        return activities

    @override
    def start(self, action_name: str) -> None:
        if self.profiler is None:
            # close profiler if it is already opened. might happen if 2 profilers
            # are created and the first one did not call `describe`
            if torch.autograd._profiler_enabled():
                torch.autograd._disable_profiler()

            if self._schedule is not None:
                self._schedule.setup(action_name)

            self._create_profilers()

            profiler = self.profiler.__enter__()
            if profiler is not None:
                self.profiler = profiler

            if self._parent_profiler is not None:
                self._parent_profiler.__enter__()

            if self._lightning_module is not None and self._register is None and self._record_module_names:
                self._register = RegisterRecordFunction(self._lightning_module)
                self._register.__enter__()

        if self.profiler is not None and action_name not in self._recording_map:
            # Add [pl][profile] in name for pytorch profiler to recognize
            recording = record_function("[pl][profile]" + action_name)
            recording.__enter__()
            self._recording_map[action_name] = recording

    @override
    def stop(self, action_name: str) -> None:
        if action_name in self._recording_map:
            self._recording_map[action_name].__exit__(None, None, None)
            del self._recording_map[action_name]

        if not _KINETO_AVAILABLE or self._emit_nvtx:
            return

        if self.profiler is not None and any(action_name.endswith(func) for func in self.STEP_FUNCTIONS):
            assert isinstance(self.profiler, torch.profiler.profile)
            if self._schedule is not None:
                self._schedule.pre_step(action_name)

            # the default schedule requires a minimum of 5 steps to properly work: `wait=1, warmup=1, active=3`.
            # otherwise, this will raise a `segmentation fault`.
            if self._should_override_schedule():
                warning_cache.warn(
                    "The PyTorch Profiler default schedule will be overridden as there are not enough "
                    "steps to properly record traces."
                )
                self._schedule = None
                self.profiler.schedule = torch.profiler.profiler._default_schedule_fn

            def on_trace_ready(profiler: _PROFILER) -> None:
                if self.dirpath is not None:
                    if self._export_to_chrome:
                        handler = tensorboard_trace_handler(
                            str(self.dirpath), self._prepare_filename(action_name=action_name, extension="")
                        )
                        handler(profiler)

                    if self._export_to_flame_graph:
                        path = os.path.join(
                            self.dirpath, self._prepare_filename(action_name=action_name, extension=".stack")
                        )
                        assert isinstance(profiler, torch.autograd.profiler.profile)
                        profiler.export_stacks(path, metric=self._metric)
                else:
                    rank_zero_warn("The PyTorchProfiler failed to export trace as `dirpath` is None")

            if not self._has_on_trace_ready:
                self.profiler.on_trace_ready = on_trace_ready

            if self._schedule is not None:
                self.profiler.step_num = self._schedule.num_step
            self.profiler.step()
            self.profiler.add_metadata("Framework", "pytorch-lightning")


def _default_sort_by_key(profiler_kwargs: dict) -> str:
    activities = profiler_kwargs.get("activities", [])
    is_cuda = (
        profiler_kwargs.get("use_cuda", False)  # `use_cuda` is deprecated in PyTorch >= 2.4
        or (activities and ProfilerActivity.CUDA in activities)
        or (not activities and is_cuda_available())
    )
    return f"{'cuda' if is_cuda else 'cpu'}_time_total"
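A minimal usage sketch of the profiler shown above, following the ``Example::`` convention of the docstrings; it assumes a standard ``Trainer`` setup, and ``MyModel`` and ``train_loader`` are hypothetical placeholders::

    from lightning.pytorch import Trainer
    from lightning.pytorch.profilers import PyTorchProfiler

    # Write Chrome traces and the summary table under ./profiler_logs;
    # the summary is saved to a .txt file instead of being printed to stdout.
    profiler = PyTorchProfiler(dirpath="profiler_logs", filename="perf", export_to_chrome=True)
    trainer = Trainer(profiler=profiler, max_epochs=1)
    trainer.fit(MyModel(), train_loader)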