Introduction

In this tutorial, we will write a Trace transformation to perform CPU Offloading of intermediate tensors.

CPU Offloading is a technique to decrease peak memory usage during training. This can allow us to train a larger model that would otherwise not fit in device memory. However, we trade off some performance (due to the increased memory transfers) to achieve this.

[1]:
import gc
from typing import Callable, Mapping

import torch
import torch.utils.benchmark

import thunder
from thunder.core.trace import TraceCtx
from thunder.core.transform_common import Transform
from thunder.core.proxies import TensorProxy, variableify, Variable
from thunder.core.pytree import tree_map
from thunder.core.trace import tracectx, from_trace, TraceTag
from thunder.extend import OperatorExecutor
from thunder.core.symbol import BoundSymbol
from thunder.core import prims
from thunder.core.transforms import bsym_list_to_dag, Node, toposort_bsym_dag, TOPOSORT_ORDER
from thunder.core.vjp_utils import get_saved_for_backward_tensors
from thunder.core.module import ThunderModule

Transforms

To understand transforms, we first need to know what a Trace is. In thunder, a Trace is the representation of the jitted program in terms of thunder operations/symbols. A Trace holds a collection of BoundSymbols, i.e. Symbols together with their inputs and outputs. We can also print a trace as a Python program for easier inspection; we will do this later in the notebook. To learn more about these concepts, you can read the helpful zero_to_thunder.ipynb.
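
To make this concrete, here is a small illustrative example (the toy function and names below are ours): we jit a tiny function, run it once, and print its final execution trace as a Python program, which is built from the trace's BoundSymbols.

def toy_fn(x):
    return torch.nn.functional.relu(x) * 2

jtoy_fn = thunder.jit(toy_fn)
jtoy_fn(torch.randn(4))
# `thunder.last_traces` returns the forward traces; the last one is the execution trace.
print(thunder.last_traces(jtoy_fn)[-1])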

thunder allows us to write custom transforms that operate on traces. Such transforms can be used to replace pointwise operations with a fused implementation, compute the gradient of a given computation, etc. Besides this, thunder lets us apply these transforms at different stages of compilation. In this tutorial, we will use the post-optimization stage, the point at which the forward and the backward execution traces are already available. To write our transform, we inherit from the Transform class. This class defines the interface that every transform should implement and, by default, provides no-op transformations. To use our transform, we pass an instance of it to thunder.jit via the transforms argument.
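
As a minimal sketch (the class name is ours, purely illustrative), a custom transform subclasses Transform and overrides the hook it cares about. Here we override transform_trace_post_optimization, the same hook our CPU offloading transform uses below, to report the number of bound symbols and return the trace unchanged.

class CountBoundSymbols(Transform):
    def transform_trace_post_optimization(self, computation_trace, **kwargs):
        # Report the size of the trace and leave it untouched.
        print(f"post-optimization trace has {len(computation_trace.bound_symbols)} bound symbols")
        return computation_trace

# Usage: jmodel = thunder.jit(model, transforms=[CountBoundSymbols()])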

However, before writing our transform, we will create an OperatorExecutor and register two operators/symbols with it: one to offload tensors to the CPU and one to load the offloaded tensors back onto the CUDA device. You can read more about adding custom operators in adding_custom_operator.ipynb and also zero_to_thunder.ipynb.

[8]:
# Create a new executor.
offload_ex = OperatorExecutor("offload_ex")

# NOTE: We create the offloaded CPU tensor in pinned memory and load the tensor back onto GPU with `to(non_blocking=True)`.
#       These allow for better memory transfer speeds.
#       Read the following tutorial for a detailed explanation - https://pytorch.org/tutorials/intermediate/pinmem_nonblock.html

# Offload the GPU tensor to a pinned CPU tensor.
def offload_to_cpu_impl(t):
    # Due to https://github.com/Lightning-AI/lightning-thunder/issues/950
    # it may receive a tensor that is already on the CPU.
    if t.device == torch.device("cpu"):
        return t

    packed = torch.empty(
        t.size(),
        dtype=t.dtype,
        layout=t.layout,
        pin_memory=True,
    )
    packed.copy_(t)
    return packed

offload_to_cpu = offload_ex.register_operator(
    "offload_to_cpu",
    meta=lambda t: TensorProxy("offloaded_" + t.name, like=t, device=thunder.core.devices.Device("cpu")),
    fn=offload_to_cpu_impl,
)


# Load the offloaded tensor back onto the given GPU device.
def load_to_gpu_impl(t, device):
    return t.to(device, non_blocking=True)


load_to_gpu = offload_ex.register_operator(
    "load_to_gpu",
    meta=lambda t, device: TensorProxy(like=t, device=thunder.core.devices.Device(device)),
    fn=load_to_gpu_impl,
)

First, we will write some helper functions to implement our transformation.

[9]:
def get_symbols_to_first_or_last_used_variables(symbols, first_used=False):
    """
    This function processes a sequence of symbols and determines which variables
    are first/last used by each symbol determined based on argument `first_used`.
    It returns a mapping from variables to the symbols where they were first/last used.

    Args:
        symbols (iterable): An iterable of symbols
        first_used (bool): Whether to return the map of first used variable to symbol mapping if True otherwise return the map for last used.
                           Defaults to False.

    Returns:
        variable_to_symbol (dict): A dictionary mapping each variable to the  symbol where it is first/last used based on `first_used` argument.
    """
    variable_to_symbol = {}

    def _mark_first_or_last_use(symbol, variable):
        if variable not in variable_to_symbol:
            variable_to_symbol[variable] = symbol

    iter_symbols = symbols if first_used else reversed(symbols)
    for symbol in iter_symbols:
        # If this function is used in the combined nvfuser+torch executor, there are no symbols but regions.
        # Regions do not have args, kwargs
        if hasattr(symbol, "inputs"):
            variables = tuple(symbol.inputs) + tuple(symbol.outputs)
        else:
            variables = tuple(symbol.flat_variableified_proxy_args) + tuple(symbol.flat_variableified_proxy_outs)
        tree_map(lambda x: _mark_first_or_last_use(symbol, x), variables)

    return variable_to_symbol


def get_symbol_to_idx(symbols):
    '''
    This function returns a map from each symbol to its position in the sequence.
    '''
    return {sym: idx for idx, sym in enumerate(symbols)}


def move_closer_to_consumer(execution_trace: TraceCtx) -> TraceCtx:
    '''
    This function takes a trace and reorders its operations such that an operation producing an input
    is placed closer to the operation that consumes it.

    This is required because, in the backward trace, the first consumer of a saved_for_backward tensor may be
    a reshape or permute op while the actual computation occurs 50-100 (or more) lines later.
    Because of that, we would eagerly load more tensors than required (thus decreasing the memory gains from CPU Offloading).

    Args:
        execution_trace (TraceCtx): Trace to be re-ordered.
    '''
    order_in_trace = {bsym: i for i, bsym in enumerate(execution_trace.bound_symbols)}

    def prefer_ops_closer_to_consumer(eligible_nodes: list[Node]) -> int:
        def key(node: Node) -> int:
            return order_in_trace[node.bsym]

        return min(range(len(eligible_nodes)), key=lambda i: key(eligible_nodes[i]))

    # This moves all `del` / clear-collection symbols to the bottom (as they don't return anything).
    bound_symbols = toposort_bsym_dag(
        bsym_list_to_dag(execution_trace.bound_symbols)[1],
        TOPOSORT_ORDER.BOTTOM_UP,
        selector=prefer_ops_closer_to_consumer,
    )

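    # Find the index of the first `del` symbol. The `del`/clear-collection symbols were moved to the
    # bottom by the toposort above, so we drop them below and let `del_last_used` re-insert the
    # deletions at the right places.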
    for idx, bsym in enumerate(bound_symbols):
        if bsym.sym.id == prims.PrimIDs.DEL:
            break

    new_execution_trace = from_trace(execution_trace)
    new_execution_trace.bound_symbols = bound_symbols[:idx]

    new_execution_trace = thunder.executors.passes.del_last_used(new_execution_trace, clear_mutable_collections=True)
    return new_execution_trace
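
To build some intuition for get_symbols_to_first_or_last_used_variables, here is a toy, purely illustrative example that uses SimpleNamespace objects as stand-ins for real symbols (real traces use BoundSymbols):

from types import SimpleNamespace

s1 = SimpleNamespace(inputs=("a",), outputs=("b",))
s2 = SimpleNamespace(inputs=("b",), outputs=("c",))
s3 = SimpleNamespace(inputs=("a", "c"), outputs=("d",))

# Last use: "a" -> s3, "b" -> s2, "c" -> s3, "d" -> s3
print(get_symbols_to_first_or_last_used_variables([s1, s2, s3], first_used=False))
# First use: "a" -> s1, "b" -> s1, "c" -> s2, "d" -> s3
print(get_symbols_to_first_or_last_used_variables([s1, s2, s3], first_used=True))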

Now to the main topic: writing the transform for CPU Offloading.

The rough implementation of the transform looks like this:

  1. From the forward computation trace, determine which tensors we want to offload to CPU. The return symbol of the forward trace has a sequence of tensors which are saved for the backward trace. We go through this list of tensors and find all the intermediate tensors (i.e. those which are not an input to the trace). Here, we also call a user-provided callback which can further filter this list of tensors to offload (an example policy is sketched after the transform implementation below).

  2. In the forward trace, we then find the last use of each tensor to offload from the above step and insert a call to the offload_to_cpu symbol that we created above. Note that we also save a map of which tensors we offloaded, along with the original device each tensor lived on, so that we can load it back to the correct device.

  3. In the forward trace, we then update the return symbol to return the offloaded tensors (which are saved for the backward pass).

  4. In the backward trace, we read from the map of the tensors which were offloaded and update the unpack symbol of saved tensors to replace the original tensors with our offloaded tensors.

  5. In the backward trace, we then find the first use of each offloaded tensor in a computation and insert a load_to_gpu call before it. Note that here we use the previously stored map from tensor to original device so that each tensor is loaded onto the correct device.

To see these steps in our implementation:

  1. See the method transform_trace_post_optimization, which is invoked by thunder first with the forward trace and then separately with the backward trace.

  2. See the method _offload_tensors_from_forward, which implements Steps 1, 2 and 3 from above.

  3. See the method _load_tensors_for_backward, which implements Steps 4 and 5 from above.

Note that each of the above methods contains more details regarding the implementation.

[10]:
class CPUOffloading(Transform):
    '''
    Transform to implement CPU Offloading.

    Args:
        save_tensor_policy: Users can pass a callback with signature fn(offloaded_tensors, forward_trace) to filter
                            the offloaded_tensors based on their preference, e.g. the biggest 20% of intermediate tensors
                            or the intermediates of certain operations.
    '''
    def __init__(self, save_tensor_policy: Callable[[tuple[TensorProxy, ...], TraceCtx], tuple[TensorProxy, ...]] | None = None):
        self.forward_pass = None
        self.backward_pass = None
        self._offloaded_tensors: Mapping[Variable, TensorProxy] = {}
        self._offloaded_tensors_dev: Mapping[Variable, str] = {}
        self.save_tensor_policy = None
        if save_tensor_policy is not None:
            assert callable(save_tensor_policy)
            self.save_tensor_policy = save_tensor_policy

    def _get_tensors_to_offload(self, forward_trace):
        '''
        Based on the `forward_trace`, find the tensors that we want to offload to CPU.
        This function finds the intermediate tensors that are saved for backward, i.e. ones that are not arguments of this trace.
        '''
        return_bsym = forward_trace.bound_symbols[-1]
        trace_args = return_bsym.args[0]["flat_args"]
        saved_tensors = get_saved_for_backward_tensors(forward_trace)

        tensor_args_name = tuple(arg.name for arg in trace_args if isinstance(arg, TensorProxy))

        def is_in_tensor_args(t):
            return t.name in tensor_args_name

        def is_cuda_tensor(t):
            return t.device.type == "cuda"

        # Tensors which are intermediate and not argument to the computation trace are
        # the ones we are interested in offloading.
        tensors_to_offload = tuple(t for t in saved_tensors if ((not is_in_tensor_args(t)) and is_cuda_tensor(t)))

        # Apply users policy if present.
        if self.save_tensor_policy is not None:
            tensors_to_offload = self.save_tensor_policy(tensors_to_offload, forward_trace)
        self.tensors_to_offload = tensors_to_offload
        return self.tensors_to_offload

    def _replace_saved_tensors(self, forward_trace, new_output_map):
        return_bsym = forward_trace.bound_symbols[-1]
        new_return_bsym = return_bsym.from_bsym_swap_proxies(new_output_map)

        # Replace the old return with our new return.
        forward_trace.bound_symbols.pop(-1)
        forward_trace.bound_symbols.append(new_return_bsym)

    def _offload_tensors_from_forward(self, computation_trace):
        '''
        This function takes the forward computation trace and performs the following steps:
        1. Find the tensors to be offloaded using `_get_tensors_to_offload` (this also calls the user's `save_tensor_policy` if present).
        2. Insert calls to the `offload_to_cpu` symbol with the tensors to offload. These calls are placed after the last computational
           use of the tensors to be offloaded so that we free the memory as soon as possible.
        3. Finally, update the last symbol, i.e. the `return` symbol, to return the offloaded tensors instead of the original tensors.
        '''
        # Step 1
        # Find the tensors to offload.
        # We offload tensors which are saved for backward and are not arguments to the computation trace.
        tensors_to_offload = self._get_tensors_to_offload(computation_trace)

        # Step 2
        # Insert the offloading calls after the last use of the saved tensor (which we want to offload).
        # NOTE - We pass `computation_trace.bound_symbols[:-1]` as we don't want to pass the `return` symbol (which will otherwise be the last use of the saved tensors).
        variable_to_last_symbol = get_symbols_to_first_or_last_used_variables(
            computation_trace.bound_symbols[:-1], first_used=False
        )
        symbol_to_idx = get_symbol_to_idx(computation_trace.bound_symbols)

        # Book keeping for backward pass update.
        new_output_map: Mapping[Variable, TensorProxy] = {}
        new_output_dev_map: Mapping[Variable, str] = {}

        # Since we are inserting into the list, we need to process tensors in increasing order of insertion position,
        # otherwise the insertions will be incorrect.
        sorted_tensors_to_offload = sorted(
            tensors_to_offload, key=lambda t: symbol_to_idx[variable_to_last_symbol[variableify(t)]]
        )
        for idx, t in enumerate(sorted_tensors_to_offload):
            last_used_symbol = variable_to_last_symbol[variableify(t)]
            last_used_symbol_idx = symbol_to_idx[last_used_symbol]
            computation_trace.push_scope([])
            with tracectx(computation_trace):
                o = offload_to_cpu(t)
                prims.python_del(t)
            scoped_comp = computation_trace.pop_scope()
            scoped_comp[0].header = "Created by CPU Offloading Transform"
            offload_to_cpu_symbol = scoped_comp[0]
            del_symbol = scoped_comp[1]

            # This will insert `del` first and then push it down when we insert `offload_to_cpu`.
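            # The `idx * 2` offset below accounts for the (offload_to_cpu, del) pair already inserted for each previously processed tensor.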
            computation_trace.bound_symbols.insert(last_used_symbol_idx + 1 + (idx * 2), del_symbol)
            computation_trace.bound_symbols.insert(last_used_symbol_idx + 1 + (idx * 2), offload_to_cpu_symbol)

            # Update book keeping.
            new_output_map[variableify(t)] = o
            new_output_dev_map[variableify(t)] = t.device.device_str()

        # Step 3
        # Update the return symbol to return our offloaded tensors in saved for backward.
        self._replace_saved_tensors(computation_trace, new_output_map)

        # Book keeping for backward pass update.
        self._offloaded_tensors = new_output_map
        self._offloaded_tensors_dev = new_output_dev_map
        return computation_trace

    def _load_tensors_for_backward(self, computation_trace):
        '''
        This function takes the backward computation trace and performs the following steps:
        1. Find the unpack collection symbol which unpacks the saved tensors passed to the backward trace.
        2. Update the unpack collection to unpack the offloaded tensors instead of the original ones.
        3. Before the first use of each offloaded tensor in the computation, insert a `load_to_gpu` call to load the tensor back onto the GPU.
        '''
        self.backward_pass = computation_trace
        offloaded_tensors = self._offloaded_tensors
        offloaded_tensors_dev_map = self._offloaded_tensors_dev

        compute_producers, compute_consumers = thunder.core.utils.producers_and_consumers(computation_trace)

        # We want to insert `loads` before the first use of offloaded_tensors.
        variable_to_first_symbol = get_symbols_to_first_or_last_used_variables(computation_trace.bound_symbols, first_used=True)

        symbol_to_idx = get_symbol_to_idx(computation_trace.bound_symbols)

        # Step 1 and 2
        # Update unpack collection so that it
        # outputs the offloaded tensor proxies (not the original ones).
        unpack_sym = compute_producers[list(offloaded_tensors.keys())[0].proxy]
        unpack_idx = symbol_to_idx[unpack_sym]
        unpack_sym_out = unpack_sym.output
        new_out = []
        for out in unpack_sym_out:
            if (vout := variableify(out)) in offloaded_tensors:
                new_out.append(offloaded_tensors[vout])
            else:
                new_out.append(out)
        new_unpack_bsym = BoundSymbol.from_bsym(unpack_sym, output=tuple(new_out))
        computation_trace.bound_symbols[unpack_idx] = new_unpack_bsym

        # Now we find the first usages of the offloaded tensors again, this time only after the unpack symbol.
        # This will point us to the first actual consumer of each offloaded tensor.
        offset = unpack_idx + 1
        variable_to_first_symbol = get_symbols_to_first_or_last_used_variables(computation_trace.bound_symbols[offset:], first_used=True)

        # Step 3
        # Load the offloaded tensors to GPU before usage.
        # Should iterate in correct order (else insertion positions will be incorrect).
        for idx, (vt, offloaded_t) in enumerate(
            sorted(offloaded_tensors.items(), key=lambda kv: symbol_to_idx[variable_to_first_symbol[kv[0]]])
        ):
            first_used_symbol = variable_to_first_symbol[vt]
            first_used_symbol_idx = symbol_to_idx[first_used_symbol]
            t = vt.proxy
            device = offloaded_tensors_dev_map[vt]

            with tracectx(computation_trace):
                new_sym = load_to_gpu.bind(offloaded_t, device, output=t)

            new_sym.header = "Created by CPU Offloading Transform"
            computation_trace.bound_symbols.insert(first_used_symbol_idx + idx, new_sym)

        return computation_trace

    def transform_trace_post_optimization(self, computation_trace: thunder.TraceCtx, **kwargs):
        if self.forward_pass is None:
            self.forward_pass = computation_trace
            # Processing for the forward pass (only if we are going to compute backward).
            if TraceTag.AUGMENTED_FORWARD not in computation_trace.tags:
                return computation_trace

            # Create a new copy of the computation trace using `from_trace`.
            new_computation_trace = from_trace(computation_trace)
            # `from_trace` creates a shallow copy where `bound_symbols` and `provenance` are not copied.
            new_computation_trace.bound_symbols = computation_trace.bound_symbols

            new_computation_trace = self._offload_tensors_from_forward(new_computation_trace)
        else:
            # Skip if no tensor was offloaded.
            if len(self._offloaded_tensors) == 0:
                return computation_trace

            # Create a new copy of computation trace using `from_trace`.
            new_computation_trace = from_trace(computation_trace)
            # `from_trace` creates a shallow copy where `bound_symbols` and `provenance` are not copied.
            new_computation_trace.bound_symbols = computation_trace.bound_symbols

            # We need this because, in the unmodified backward trace, the first consumer of a saved_for_backward tensor may be
            # a reshape or permute op while the actual computation occurs 50-100 (or more) lines later.
            # Because of that, we would eagerly load more tensors than required (thus decreasing the memory gains from CPU Offloading).
            # E.g. on line 92:
            #   # Created by CPU Offloading Transform
            #   t1319 = load_to_gpu(offloaded_t1319, 'cuda:0')  # t1319: "cuda:0 f32[8, 1024, 11008]"
            #   t4021 = torch.reshape(t1319, (-1, 11008))  # t4021: "cuda:0 f32[8192, 11008]"
            #     # t4021 = ltorch.reshape(t1319, (-1, 11008))  # t4021: "cuda:0 f32[8192, 11008]"
            #       # t4021 = prims.reshape(t1319, (8192, 11008))  # t4021: "cuda:0 f32[8192, 11008]"
            #   del t1319
            # And its use in the computation is at line 612:
            # t4022 = torch.matmul(t4020, t4021)  # t4022: "cuda:0 f32[4096, 11008]"
            #   t4022 = ltorch.matmul(t4020, t4021)  # t4022: "cuda:0 f32[4096, 11008]"
            #     t4022 = prims.matmul(t4020, t4021)  # t4022: "cuda:0 f32[4096, 11008]"
            new_computation_trace = move_closer_to_consumer(new_computation_trace)

            # Transform the backward trace to load offloaded tensors back to the device.
            new_computation_trace = self._load_tensors_for_backward(new_computation_trace)

        return new_computation_trace
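
As an example of the save_tensor_policy callback described above, here is an illustrative sketch (the function name and the min_numel threshold are ours, and it assumes static integer shapes on the tensor proxies) that only offloads intermediates with at least a given number of elements, leaving small tensors on the GPU.

import math

def offload_only_large_tensors(tensors_to_offload, forward_trace, min_numel=1024):
    # Keep only intermediates with at least `min_numel` elements; smaller tensors stay on the GPU.
    return tuple(t for t in tensors_to_offload if math.prod(t.shape) >= min_numel)

# It could then be used as:
# jmodel = thunder.jit(model, transforms=[CPUOffloading(save_tensor_policy=offload_only_large_tensors)])
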
[11]:
def clear_memory():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_accumulated_memory_stats()
    torch.cuda.reset_peak_memory_stats()
    print(f"Allocated Memory after cleaning {torch.cuda.memory_allocated() / 1e9} GB")
[12]:
def benchmark(jmodel: ThunderModule, model: torch.nn.Module, args, kwargs):
    # NOTE - This function takes care of warm-up
    stmt = """
# Use the optimized model for prediction and backward
o = jmodel(*args, **kwargs)
o.sum().backward()
for param in model.parameters():  # use the original model to clear grads
    param.grad = None
"""
    timer = torch.utils.benchmark.Timer(
        stmt=stmt, globals={"jmodel": jmodel, "model": model, "args": args, "kwargs": kwargs}
    ).timeit(number=10)
    return timer

Testing our Transform on a Simple Model

[13]:
class MySimpleModel(torch.nn.Module):
    def __init__(self, n_layers=10):
        super().__init__()
        self.fcs = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(n_layers)])

    def forward(self, x):
        for fc in self.fcs:
            x = torch.nn.functional.relu(fc(x))

        return x

def get_model_and_args():
    device = 'cuda'
    model = MySimpleModel(n_layers=100).to(device)
    args = (torch.randn(128, 16, device=device),)
    kwargs = {}
    return model, args, kwargs

model, args, kwargs = get_model_and_args()

# Check against the vanilla `thunder.jit` model
expected = thunder.jit(model)(*args, **kwargs)

grad_output = torch.randn_like(expected)
expected_grads = torch.autograd.grad(expected, model.parameters(), grad_output)

print(f"Peak Memory with thunder : {torch.cuda.max_memory_allocated()} bytes")
torch.cuda.reset_peak_memory_stats()

with torch.no_grad():
    expected_cpu = expected.to("cpu")
    expected_grads_cpu = tree_map(lambda t: t.to("cpu"), expected_grads)

jmodel = thunder.jit(model, transforms=[CPUOffloading()])

actual = jmodel(*args, **kwargs)

# Verify that saved tensors are on CPU.
saved_tensor_devices = set()
for t in actual.grad_fn.saved_tensors:
    saved_tensor_devices.add(str(t.device))

assert "cpu" in saved_tensor_devices  # Verify that we actually have saved tensors on CPU
actual_grads = torch.autograd.grad(actual, jmodel.parameters(), grad_output)

print(f"Peak Memory with CPU Offloading : {torch.cuda.max_memory_allocated()} bytes")

with torch.no_grad():
    actual_cpu = actual.to("cpu")
    actual_grads_cpu = tree_map(lambda t: t.to("cpu"), actual_grads)

# Sanity Check that values match
torch.testing.assert_close(actual_cpu, expected_cpu)
torch.testing.assert_close(actual_grads_cpu, expected_grads_cpu)

# Fetch the forward and backward traces for inspection
fw_traces = thunder.last_traces(jmodel)
bw_traces = thunder.last_backward_traces(jmodel)

del jmodel, model, args, kwargs, actual, actual_grads, expected, expected_grads, grad_output  # Free memory.

clear_memory()
Peak Memory with thunder : 19279872 bytes
Peak Memory with CPU Offloading : 18444288 bytes
Allocated Memory after cleaning 0.017047552 GB

Inspecting the forward and the backward traces.

[ ]:
fw_traces[-1]  # Note the calls to `offload_to_cpu` and verify that they are after the last usage of the tensor.

Snippet from the forward trace

t485 = torch.nn.functional.linear(t484, t_fcs_97_weight, t_fcs_97_bias)  # t485: "cuda:0 f32[128, 16]"
  # t485 = ltorch.linear(t484, t_fcs_97_weight, t_fcs_97_bias)  # t485: "cuda:0 f32[128, 16]"
    # t485 = prims.linear(t484, t_fcs_97_weight, t_fcs_97_bias)  # t485: "cuda:0 f32[128, 16]"
# Created by CPU Offloading Transform
offloaded_t484 = offload_to_cpu(t484)  # offloaded_t484: "cpu f32[128, 16]"
del t484
[t487, t489] = nvFusion97(t485)
  # t487 = prims.gt(t485, 0.0)  # t487: "cuda:0 b8[128, 16]"
  # t489 = prims.where(t487, t485, 0.0)  # t489: "cuda:0 f32[128, 16]"
# Created by CPU Offloading Transform
offloaded_t487 = offload_to_cpu(t487)  # offloaded_t487: "cpu b8[128, 16]"
del t487
del t485
t490 = torch.nn.functional.linear(t489, t_fcs_98_weight, t_fcs_98_bias)  # t490: "cuda:0 f32[128, 16]"
  # t490 = ltorch.linear(t489, t_fcs_98_weight, t_fcs_98_bias)  # t490: "cuda:0 f32[128, 16]"
    # t490 = prims.linear(t489, t_fcs_98_weight, t_fcs_98_bias)  # t490: "cuda:0 f32[128, 16]"
# Created by CPU Offloading Transform
offloaded_t489 = offload_to_cpu(t489)  # offloaded_t489: "cpu f32[128, 16]"
del t489
[ ]:
bw_traces[-1]  # Note the calls to `load_to_gpu` and verify that they are before the first usage of the tensor.

Snippet from the backward trace

# Created by CPU Offloading Transform
t489 = load_to_gpu(offloaded_t489, 'cuda:0')  # t489: "cuda:0 f32[128, 16]"
t2015 = torch.reshape(t489, (-1, 16))  # t2015: "cuda:0 f32[128, 16]"
  # t2015 = ltorch.reshape(t489, (-1, 16))  # t2015: "cuda:0 f32[128, 16]"
    # t2015 = prims.reshape(t489, (128, 16))  # t2015: "cuda:0 f32[128, 16]"
del t489
t2016 = torch.matmul(t2014, t2015)  # t2016: "cuda:0 f32[16, 16]"
  # t2016 = ltorch.matmul(t2014, t2015)  # t2016: "cuda:0 f32[16, 16]"
    # t2016 = prims.matmul(t2014, t2015)  # t2016: "cuda:0 f32[16, 16]"
del t2014, t2015
t2005 = torch.permute(t2002, (1, 0))  # t2005: "cuda:0 f32[16, 128]"
  # t2005 = ltorch.permute(t2002, (1, 0))  # t2005: "cuda:0 f32[16, 128]"
    # t2005 = prims.transpose(t2002, (1, 0))  # t2005: "cuda:0 f32[16, 128]"
del t2002
# Created by CPU Offloading Transform
t494 = load_to_gpu(offloaded_t494, 'cuda:0')  # t494: "cuda:0 f32[128, 16]"
t2006 = torch.reshape(t494, (-1, 16))  # t2006: "cuda:0 f32[128, 16]"
  # t2006 = ltorch.reshape(t494, (-1, 16))  # t2006: "cuda:0 f32[128, 16]"
    # t2006 = prims.reshape(t494, (128, 16))  # t2006: "cuda:0 f32[128, 16]"
del t494

Benchmark thunder vs thunder + CPU Offloading on Simple Model

[14]:
model, args, kwargs = get_model_and_args()

measurement_thunder = benchmark(thunder.jit(model), model, args, kwargs)
measurement_thunder_offload = benchmark(thunder.jit(model, transforms=[CPUOffloading()]), model, args, kwargs)

del model, args, kwargs

clear_memory()
Allocated Memory after cleaning 0.017047552 GB
[15]:
measurement_thunder
[15]:
<torch.utils.benchmark.utils.common.Measurement object at 0x74ff48fbb6a0>
stmt:
  # Use the optimized model for prediction and backward
  o = jmodel(*args, **kwargs)
  o.sum().backward()
  for param in model.parameters():  # use the original model to clear grads
      param.grad = None

  8.50 ms
  1 measurement, 10 runs , 1 thread
[16]:
measurement_thunder_offload
[16]:
<torch.utils.benchmark.utils.common.Measurement object at 0x74ff439d2dd0>
stmt:
  # Use the optimized model for prediction and backward
  o = jmodel(*args, **kwargs)
  o.sum().backward()
  for param in model.parameters():  # use the original model to clear grads
      param.grad = None

  12.62 ms
  1 measurement, 10 runs , 1 thread

Let us try it on a real-life model: Llama-3. We will run a smaller configuration of Llama-3. Feel free to update N_LAYER and BLOCK_SIZE based on the available device memory.

NOTE: Running the cell below requires litgpt to be installed. Use pip install litgpt if it is not available.

[22]:
from litgpt import Config, GPT
from functools import partial
from torch.testing import make_tensor

N_LAYER = 9
BLOCK_SIZE = 1024

def get_model_and_args(batchdims=8):
    with torch.device("cuda"):
        cfg: Config = Config.from_name("Llama-3-8B")
        # Smaller configuration
        cfg.n_layer = N_LAYER
        cfg.block_size = BLOCK_SIZE

        model = GPT(cfg)
        make = partial(make_tensor, low=0, high=255, device='cuda', dtype=torch.int64, requires_grad=False)
        shape = (batchdims,) + (cfg.block_size,)

        x = make(shape)
        args, kwargs = (x,), {}

    return model, args, kwargs, cfg

def print_memory_usage_and_benchmark(name):
    print(f"{name} took -")
    model, args, kwargs, cfg = get_model_and_args()

    if name == 'thunder':
        jmodel = thunder.jit(model)
    elif name == 'thunder_offload':
        jmodel = thunder.jit(model, transforms=[CPUOffloading()])
    else:
        raise RuntimeError("Received invalid value for `name` - try `thunder` or `thunder_offload`.")

    memory_after_model_load = torch.cuda.max_memory_allocated() / 1e9
    print(f"Peak memory after loading the model : {memory_after_model_load} GB")

    a = jmodel(*args, **kwargs)

    memory_after_forward = torch.cuda.max_memory_allocated() / 1e9
    print(f"Peak memory after forward the model : {memory_after_forward} GB")

    g = torch.rand_like(a)
    actual_grads = torch.autograd.grad(a, model.parameters(), g)

    memory_after_backward = torch.cuda.max_memory_allocated() / 1e9
    print(f"Peak memory after backward the model : {memory_after_backward} GB")

    del a, g, actual_grads  # Clear data which is not required for benchmark to free some memory.
    gc.collect()
    torch.cuda.empty_cache()

    measurement = benchmark(jmodel, model, args, kwargs)
    print(f"Benchmark Timings - mean : {measurement.mean} - median {measurement.median}")

    del jmodel, model, cfg, args, kwargs

    clear_memory()
[23]:
# Uncomment this to run the benchmarks.
# print_memory_usage_and_benchmark("thunder")
thunder took -
Peak memory after loading the model : 12.073366016 GB
Peak memory after the forward pass : 38.901342208 GB
Peak memory after the backward pass : 46.245552128 GB
Benchmark Timings - mean : 5.008525840996299 - median 5.008525840996299
Allocated Memory after cleaning 0.017047552 GB
[24]:
# Uncomment this to run the benchmarks.
# print_memory_usage_and_benchmark("thunder_offload")
thunder_offload took -
Peak memory after loading the model : 12.073366016 GB
Peak memory after the forward pass : 16.409812992 GB
Peak memory after the backward pass : 35.545775616 GB
Benchmark Timings - mean : 5.91704241540283 - median 5.91704241540283
Allocated Memory after cleaning 0.017047552 GB

Conclusion

In this notebook, we have seen how to write our own Transform in thunder. As an example, we wrote a CPUOffloading transform that implements the CPU offloading technique to decrease peak memory usage during training.