# Source code for pytorch_lightning.utilities.memory

# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities related to memory."""

import gc
from io import BytesIO
from typing import Any

import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch.nn import Module

def recursive_detach(in_dict: Any, to_cpu: bool = False) -> Any:
    """Detach every tensor found in ``in_dict``.

    Traversal is delegated to ``apply_to_collection``, so tensors held in
    nested dictionaries are detached as well. Values that are not instances
    of ``Tensor`` are not affected by this utility function.

    Args:
        in_dict: Dictionary with tensors to detach.
        to_cpu: Whether to additionally move each detached tensor to cpu.

    Return:
        out_dict: Dictionary with detached tensors.

    """

    def detach_and_move(t: Tensor, to_cpu: bool) -> Tensor:
        # Detach first; the optional CPU transfer is applied to the detached tensor.
        detached = t.detach()
        return detached.cpu() if to_cpu else detached

    return apply_to_collection(in_dict, Tensor, detach_and_move, to_cpu=to_cpu)
# NOTE(review): the original "based on ..." / "For/because of ..." attribution
# comments for the helpers below lost their URLs in this copy of the file;
# restore them from upstream if provenance matters.


def _runtime_error_mentions(exception: BaseException, *fragments: str) -> bool:
    """Return True if ``exception`` is a single-message ``RuntimeError`` containing all ``fragments``.

    The check requires that ``exception`` is a ``RuntimeError``, that it carries
    exactly one argument, and that this argument is a ``str``. A non-string
    argument (e.g. ``RuntimeError(5)``) yields ``False`` rather than raising
    ``TypeError`` from the ``in`` operator.
    """
    if not isinstance(exception, RuntimeError) or len(exception.args) != 1:
        return False
    message = exception.args[0]
    return isinstance(message, str) and all(fragment in message for fragment in fragments)


def is_oom_error(exception: BaseException) -> bool:
    """Return True if ``exception`` matches any of the out-of-memory patterns checked below.

    Args:
        exception: The exception to classify.

    Returns:
        ``True`` when the exception is recognised by :func:`is_cuda_out_of_memory`,
        :func:`is_cudnn_snafu` or :func:`is_out_of_cpu_memory`; ``False`` otherwise.

    """
    return is_cuda_out_of_memory(exception) or is_cudnn_snafu(exception) or is_out_of_cpu_memory(exception)


def is_cuda_out_of_memory(exception: BaseException) -> bool:
    """Return True for a ``RuntimeError`` whose sole message mentions both "CUDA" and "out of memory"."""
    return _runtime_error_mentions(exception, "CUDA", "out of memory")


def is_cudnn_snafu(exception: BaseException) -> bool:
    """Return True for a ``RuntimeError`` whose sole message reports ``CUDNN_STATUS_NOT_SUPPORTED``."""
    return _runtime_error_mentions(exception, "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED.")


def is_out_of_cpu_memory(exception: BaseException) -> bool:
    """Return True for a ``RuntimeError`` whose sole message reports a ``DefaultCPUAllocator`` failure."""
    return _runtime_error_mentions(exception, "DefaultCPUAllocator: can't allocate memory")
def garbage_collection_cuda() -> None:
    """Run Python garbage collection, then release Torch's cached CUDA memory.

    A ``RuntimeError`` raised while emptying the CUDA cache is suppressed only
    when ``is_oom_error`` recognises it as an out-of-memory error; any other
    ``RuntimeError`` is re-raised unchanged.
    """
    gc.collect()
    try:
        # This is the last thing that should cause an OOM error, but seemingly it can.
        torch.cuda.empty_cache()
    except RuntimeError as exception:
        if is_oom_error(exception):
            # Only OOM errors are handled (by ignoring them).
            return
        raise
def get_model_size_mb(model: Module) -> float:
    """Calculates the size of a Module in megabytes.

    The module's :meth:`~torch.nn.Module.state_dict` (by default the parameters
    and buffers) is serialized with :func:`torch.save` into an in-memory buffer,
    and the size of that buffer is reported.

    Args:
        model: The module whose serialized state dict is measured.

    Returns:
        Size of the serialized state dict in megabytes, where one megabyte is
        ``1e6`` bytes.

    """
    model_size = BytesIO()
    # Serialize into memory so that nothing is written to disk.
    torch.save(model.state_dict(), model_size)
    size_mb = model_size.getbuffer().nbytes / 1e6
    return size_mb

# © Copyright (c) 2018-2023, Lightning AI et al.
#
# Built with Sphinx using a theme provided by Read the Docs.