Introduction
In this tutorial, we will write a Trace transformation to perform CPU Offloading of intermediate tensors.
CPU Offloading is a technique to decrease peak memory usage during training. This can allow us to train larger models that would otherwise not fit in memory. However, we have to trade off some performance (increased memory transfers) to achieve this.
[1]:
import gc
from typing import Callable, Mapping
import torch
import torch.utils.benchmark
import thunder
from thunder.core.trace import TraceCtx
from thunder.core.transform_common import Transform
from thunder.core.proxies import TensorProxy, variableify, Variable
from thunder.core.pytree import tree_map
from thunder.core.trace import tracectx, from_trace, TraceTag
from thunder.extend import OperatorExecutor
from thunder.core.symbol import BoundSymbol
from thunder.core import prims
from thunder.core.transforms import bsym_list_to_dag, Node, toposort_bsym_dag, TOPOSORT_ORDER
from thunder.core.vjp_utils import get_saved_for_backward_tensors
from thunder.core.module import ThunderModule
Transforms
To understand transforms, we first need to know what a Trace is. In thunder, a Trace is the representation of the jitted program in terms of thunder operations/symbols. A Trace is a collection of BoundSymbols, i.e. Symbols together with their inputs and outputs. We can also print a trace as a Python program for easier inspection; we will do this later in the notebook. To understand these concepts in more depth, the zero_to_thunder.ipynb notebook is a helpful reference.
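As a quick illustration, here is a minimal sketch that jits a tiny function and prints its final execution trace (the function and its inputs below are made up for this example):

import torch
import thunder

def fn(x, w):
    return torch.nn.functional.relu(x @ w)

jfn = thunder.jit(fn)
out = jfn(torch.randn(4, 8), torch.randn(8, 8))

# Each line of the printed trace is a BoundSymbol; the sub-comments show the
# decomposition into ltorch/prims symbols.
print(thunder.last_traces(jfn)[-1])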
thunder allows us to write custom transforms that transform trace(s). These transforms can be used to replace pointwise operations with a fused implementation, to compute the gradient of a given computation, and so on. Besides this, thunder enables us to apply these transforms at different stages during compilation. In this tutorial, we will use the post-optimization stage, the point at which the forward and backward execution traces are already available. To write our transform, we have to inherit from the Transform class. This class implements the interface that each transform should have; by default, it provides no-op transformations. To use our transform, we pass an instance of our transform object to thunder.jit via the transforms argument.
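To make the interface concrete, here is a minimal, do-nothing sketch of such a transform (the class name and the print statement are ours, purely for illustration):

from thunder.core.transform_common import Transform

class InspectTraceTransform(Transform):
    # A no-op transform: it only reports which trace it received and returns it unchanged.
    def transform_trace_post_optimization(self, computation_trace, **kwargs):
        print(f"Got a trace with {len(computation_trace.bound_symbols)} bound symbols")
        return computation_trace

# Usage: jmodel = thunder.jit(model, transforms=[InspectTraceTransform()])

Our CPUOffloading transform below follows the same pattern, but rewrites the traces instead of returning them unchanged.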
However, before writing our transform, we will create an OperatorExecutor with which we will register two operators/symbols: one to offload tensors to the CPU and one to load the offloaded tensors back onto the CUDA device. You can read more about adding custom operators in adding_custom_operator.ipynb and zero_to_thunder.ipynb.
[8]:
# Create a new executor.
offload_ex = OperatorExecutor("offload_ex")


# NOTE: We create the offloaded CPU tensor in pinned memory and load the tensor back onto the GPU with `to(non_blocking=True)`.
# These allow for better memory transfer speeds.
# Read the following tutorial for a detailed explanation - https://pytorch.org/tutorials/intermediate/pinmem_nonblock.html


# Offload the GPU tensor to a pinned CPU tensor.
def offload_to_cpu_impl(t):
    # Due to https://github.com/Lightning-AI/lightning-thunder/issues/950
    # it may receive a tensor that is already on the CPU.
    if t.device == torch.device("cpu"):
        return t

    packed = torch.empty(
        t.size(),
        dtype=t.dtype,
        layout=t.layout,
        pin_memory=True,
    )
    packed.copy_(t)
    return packed


offload_to_cpu = offload_ex.register_operator(
    "offload_to_cpu",
    meta=lambda t: TensorProxy("offloaded_" + t.name, like=t, device=thunder.core.devices.Device("cpu")),
    fn=offload_to_cpu_impl,
)


# Load the tensor onto the given GPU.
def load_to_gpu_impl(t, device):
    return t.to(device, non_blocking=True)


load_to_gpu = offload_ex.register_operator(
    "load_to_gpu",
    meta=lambda t, device: TensorProxy(like=t, device=thunder.core.devices.Device(device)),
    fn=load_to_gpu_impl,
)
First, we define some helper functions used to implement our transformation.
[9]:
def get_symbols_to_first_or_last_used_variables(symbols, first_used=False):
    """
    This function processes a sequence of symbols and determines which variables
    are first/last used by each symbol, based on the `first_used` argument.
    It returns a mapping from variables to the symbols where they were first/last used.

    Args:
        symbols (iterable): An iterable of symbols.
        first_used (bool): If True, return the map from each variable to the symbol where it is first used;
            otherwise return the map for the last use. Defaults to False.

    Returns:
        variable_to_symbol (dict): A dictionary mapping each variable to the symbol where it is first/last used based on the `first_used` argument.
    """
    variable_to_symbol = {}

    def _mark_first_or_last_use(symbol, variable):
        if variable not in variable_to_symbol:
            variable_to_symbol[variable] = symbol

    iter_symbols = symbols if first_used else reversed(symbols)
    for symbol in iter_symbols:
        # If this function is used in the combined nvfuser+torch executor, there are no symbols but regions.
        # Regions do not have args, kwargs.
        if hasattr(symbol, "inputs"):
            variables = tuple(symbol.inputs) + tuple(symbol.outputs)
        else:
            variables = tuple(symbol.flat_variableified_proxy_args) + tuple(symbol.flat_variableified_proxy_outs)
        tree_map(lambda x: _mark_first_or_last_use(symbol, x), variables)
    return variable_to_symbol


def get_symbol_to_idx(symbols):
    '''
    This function returns a map from a symbol to its position in the sequence.
    '''
    return {sym: idx for idx, sym in enumerate(symbols)}
def move_closer_to_consumer(execution_trace: TraceCtx) -> TraceCtx:
    '''
    This function takes the trace and reorders the operations such that operations producing an input
    for the next operation are closer together.

    This is required because, in the backward trace, the first consumer of a saved_for_backward tensor may be
    a reshape or permute op while the actual computation occurs 50-100 (or more) lines later.
    Because of this, we would eagerly load more tensors than required (thus decreasing the memory gains from CPU Offloading).

    Args:
        execution_trace (TraceCtx): Trace to be re-ordered.
    '''
    order_in_trace = {bsym: i for i, bsym in enumerate(execution_trace.bound_symbols)}

    def prefer_ops_closer_to_consumer(eligible_nodes: list[Node]) -> int:
        def key(node: Node) -> int:
            return order_in_trace[node.bsym]

        return min(range(len(eligible_nodes)), key=lambda i: key(eligible_nodes[i]))

    # This moves all del or clear collection calls to the bottom (as they don't return anything).
    bound_symbols = toposort_bsym_dag(
        bsym_list_to_dag(execution_trace.bound_symbols)[1],
        TOPOSORT_ORDER.BOTTOM_UP,
        selector=prefer_ops_closer_to_consumer,
    )

    # Drop the trailing `del` symbols; `del_last_used` below re-inserts them at the correct positions.
    for idx, bsym in enumerate(bound_symbols):
        if bsym.sym.id == prims.PrimIDs.DEL:
            break

    new_execution_trace = from_trace(execution_trace)
    new_execution_trace.bound_symbols = bound_symbols[:idx]
    new_execution_trace = thunder.executors.passes.del_last_used(new_execution_trace, clear_mutable_collections=True)
    return new_execution_trace
Now to the main topic: writing the transform for CPU Offloading.
The rough implementation of the transform looks like this:
1. From the forward computation trace, determine which tensors we want to offload to the CPU. The return symbol of the forward trace has a sequence of tensors which are saved for the backward trace. We go through this list of tensors and find all the intermediate tensors (i.e. those which are not an input to the trace). Here, we also call a user-provided callback which can further filter this list of tensors to offload.
2. In the forward trace, we then find the last use of the tensors to offload from the above step and insert a call to the offload_to_cpu symbol that we created above. Note that we also save a map of which tensors we offloaded, and we note the original device where each tensor lived so that we can load it back to the correct device.
3. In the forward trace, we then update the return symbol to return the offloaded tensors (which are saved for the backward pass).
4. In the backward trace, we read from the map of the tensors which were offloaded and update the unpack symbol of the saved tensors to replace the original tensors with our offloaded tensors.
5. In the backward trace, we then find the first use of each offloaded tensor in a computation and insert a load_to_gpu call before it. Note that here we use the previously stored map of tensor to original device so that we load it onto the correct device.
To see these steps in our implementation:

- See the method transform_trace_post_optimization, which is invoked by thunder first with the forward trace and then separately with the backward trace.
- See the method _offload_tensors_from_forward, which implements Steps 1, 2 and 3 from above.
- See the method _load_tensors_for_backward, which implements Steps 4 and 5 from above.

Note that each of the above methods contains more details regarding the implementation.
[10]:
class CPUOffloading(Transform):
    '''
    Transform to implement CPU Offloading.

    Args:
        save_tensor_policy: Users can pass a callback with signature fn(offloaded_tensors, forward_trace) to filter
            the offloaded_tensors based on their preference, e.g. the biggest 20% of intermediate tensors or
            the intermediates of certain operations.
    '''

    def __init__(self, save_tensor_policy: Callable[[tuple[TensorProxy, ...], TraceCtx], tuple[TensorProxy, ...]] | None = None):
        self.forward_pass = None
        self.backward_pass = None
        self._offloaded_tensors: Mapping[Variable, TensorProxy] = {}
        self._offloaded_tensors_dev: Mapping[Variable, str] = {}
        self.save_tensor_policy = None
        if save_tensor_policy is not None:
            assert callable(save_tensor_policy)
            self.save_tensor_policy = save_tensor_policy

    def _get_tensors_to_offload(self, forward_trace):
        '''
        Based on the `forward_trace`, we find the tensors that we want to offload to CPU.

        This function finds the intermediate tensors that are saved for backward, i.e. the ones that are not an input or output of this trace.
        '''
        return_bsym = forward_trace.bound_symbols[-1]
        trace_args = return_bsym.args[0]["flat_args"]
        saved_tensors = get_saved_for_backward_tensors(forward_trace)
        tensor_args_name = tuple(arg.name for arg in trace_args if isinstance(arg, TensorProxy))

        def is_in_tensor_args(t):
            return t.name in tensor_args_name

        def is_cuda_tensor(t):
            return t.device.type == "cuda"

        # Tensors which are intermediate and not an argument to the computation trace are
        # the ones we are interested in offloading.
        tensors_to_offload = tuple(t for t in saved_tensors if ((not is_in_tensor_args(t)) and is_cuda_tensor(t)))

        # Apply the user's policy if present.
        if self.save_tensor_policy is not None:
            tensors_to_offload = self.save_tensor_policy(tensors_to_offload, forward_trace)

        self.tensors_to_offload = tensors_to_offload
        return self.tensors_to_offload

    def _replace_saved_tensors(self, forward_trace, new_output_map):
        return_bsym = forward_trace.bound_symbols[-1]
        new_return_bsym = return_bsym.from_bsym_swap_proxies(new_output_map)
        # Replace the old return with our new return.
        forward_trace.bound_symbols.pop(-1)
        forward_trace.bound_symbols.append(new_return_bsym)

    def _offload_tensors_from_forward(self, computation_trace):
        '''
        This function takes the forward computation trace and performs the following steps:
        1. Find the tensors to be offloaded using `_get_tensors_to_offload` (this also calls the user's `save_tensor_policy` if present).
        2. Insert calls to the `offload_to_cpu` symbol with the tensors to offload. These calls are placed after the last computational
           use of the tensors to be offloaded so that we free the memory as soon as possible.
        3. Finally, update the last symbol, i.e. the `return` symbol, to return the offloaded tensors instead of the original tensors.
        '''
        # Step 1
        # Find the tensors to offload.
        # We offload saved tensors which are not arguments to the computation trace and are saved for backwards.
        tensors_to_offload = self._get_tensors_to_offload(computation_trace)

        # Step 2
        # Insert the offloading calls after the last use of the saved tensor (which we want to offload).
        # NOTE - We pass `computation_trace.bound_symbols[:-1]` as we don't want to pass the `return` symbol (which would otherwise be the last use of the saved tensors).
        variable_to_last_symbol = get_symbols_to_first_or_last_used_variables(
            computation_trace.bound_symbols[:-1], first_used=False
        )
        symbol_to_idx = get_symbol_to_idx(computation_trace.bound_symbols)

        # Bookkeeping for the backward pass update.
        new_output_map: Mapping[Variable, TensorProxy] = {}
        new_output_dev_map: Mapping[Variable, str] = {}

        # Since we are inserting into the list, we need to process the tensors in increasing order of insertion position - else the insertions will be incorrect.
        sorted_tensors_to_offload = sorted(
            tensors_to_offload, key=lambda t: symbol_to_idx[variable_to_last_symbol[variableify(t)]]
        )
        for idx, t in enumerate(sorted_tensors_to_offload):
            last_used_symbol = variable_to_last_symbol[variableify(t)]
            last_used_symbol_idx = symbol_to_idx[last_used_symbol]

            computation_trace.push_scope([])
            with tracectx(computation_trace):
                o = offload_to_cpu(t)
                prims.python_del(t)

            scoped_comp = computation_trace.pop_scope()
            scoped_comp[0].header = "Created by CPU Offloading Transform"
            offload_to_cpu_symbol = scoped_comp[0]
            del_symbol = scoped_comp[1]

            # This will insert `del` first and then push it down when we insert `offload_to_cpu`.
            computation_trace.bound_symbols.insert(last_used_symbol_idx + 1 + (idx * 2), del_symbol)
            computation_trace.bound_symbols.insert(last_used_symbol_idx + 1 + (idx * 2), offload_to_cpu_symbol)

            # Update bookkeeping.
            new_output_map[variableify(t)] = o
            new_output_dev_map[variableify(t)] = t.device.device_str()

        # Step 3
        # Update the return symbol to return our offloaded tensors in saved-for-backward.
        self._replace_saved_tensors(computation_trace, new_output_map)

        # Bookkeeping for the backward pass update.
        self._offloaded_tensors = new_output_map
        self._offloaded_tensors_dev = new_output_dev_map
        return computation_trace

    def _load_tensors_for_backward(self, computation_trace):
        '''
        This function takes the backward computation trace and performs the following steps:
        1. Find the unpack collection symbol which unpacks the saved tensors passed to the backward trace.
        2. Update the unpack collection to unpack the offloaded tensors instead of the original ones.
        3. Before the first use of an offloaded tensor in a computation, insert a `load_to_gpu` call to load the tensor back onto the GPU.
        '''
        self.backward_pass = computation_trace
        offloaded_tensors = self._offloaded_tensors
        offloaded_tensors_dev_map = self._offloaded_tensors_dev
        compute_producers, compute_consumers = thunder.core.utils.producers_and_consumers(computation_trace)

        # We want to insert `loads` before the first use of offloaded_tensors.
        variable_to_first_symbol = get_symbols_to_first_or_last_used_variables(computation_trace.bound_symbols, first_used=True)
        symbol_to_idx = get_symbol_to_idx(computation_trace.bound_symbols)

        # Step 1 and 2
        # Update the unpack collection so that it
        # outputs the offloaded tensor proxies (not the original ones).
        unpack_sym = compute_producers[list(offloaded_tensors.keys())[0].proxy]
        unpack_idx = symbol_to_idx[unpack_sym]
        unpack_sym_out = unpack_sym.output
        new_out = []
        for out in unpack_sym_out:
            if (vout := variableify(out)) in offloaded_tensors:
                new_out.append(offloaded_tensors[vout])
            else:
                new_out.append(out)
        new_unpack_bsym = BoundSymbol.from_bsym(unpack_sym, output=tuple(new_out))
        computation_trace.bound_symbols[unpack_idx] = new_unpack_bsym

        # Now we again find the first usages of the offloaded tensors.
        # This will actually point us to the first consumer of the offloaded tensor.
        offset = unpack_idx + 1
        variable_to_first_symbol = get_symbols_to_first_or_last_used_variables(computation_trace.bound_symbols[offset:], first_used=True)

        # Step 3
        # Load the offloaded tensors to GPU before usage.
        # We should iterate in the correct order (else the insertion positions will be incorrect).
        for idx, (vt, offloaded_t) in enumerate(
            sorted(offloaded_tensors.items(), key=lambda kv: symbol_to_idx[variable_to_first_symbol[kv[0]]])
        ):
            first_used_symbol = variable_to_first_symbol[vt]
            first_used_symbol_idx = symbol_to_idx[first_used_symbol]
            t = vt.proxy
            device = offloaded_tensors_dev_map[vt]
            with tracectx(computation_trace):
                new_sym = load_to_gpu.bind(offloaded_t, device, output=t)
                new_sym.header = "Created by CPU Offloading Transform"
            computation_trace.bound_symbols.insert(first_used_symbol_idx + idx, new_sym)

        return computation_trace

    def transform_trace_post_optimization(self, computation_trace: thunder.TraceCtx, **kwargs):
        if self.forward_pass is None:
            self.forward_pass = computation_trace

        # Processing for the forward pass (only if we are going to compute backward).
        if TraceTag.AUGMENTED_FORWARD in computation_trace.tags:
            # Create a new copy of the computation trace using `from_trace`.
            new_computation_trace = from_trace(computation_trace)
            # `from_trace` creates a shallow copy where `bound_symbols` and `provenance` are not copied.
            new_computation_trace.bound_symbols = computation_trace.bound_symbols
            new_computation_trace = self._offload_tensors_from_forward(new_computation_trace)
        else:
            # Skip if no tensor was offloaded.
            if len(self._offloaded_tensors) == 0:
                return computation_trace

            # Create a new copy of the computation trace using `from_trace`.
            new_computation_trace = from_trace(computation_trace)
            # `from_trace` creates a shallow copy where `bound_symbols` and `provenance` are not copied.
            new_computation_trace.bound_symbols = computation_trace.bound_symbols

            # We need this because, in the unmodified backward trace, the first consumer of a saved_for_backward tensor may be
            # a reshape or permute op while the actual computation occurs 50-100 (or more) lines later.
            # Because of this we would eagerly load more tensors than required (thus decreasing the memory gains from CPU Offloading).
            # E.g. on line 92 we might have
            #   # Created by CPU Offloading Transform
            #   t1319 = load_to_gpu(offloaded_t1319, 'cuda:0')  # t1319: "cuda:0 f32[8, 1024, 11008]"
            #   t4021 = torch.reshape(t1319, (-1, 11008))  # t4021: "cuda:0 f32[8192, 11008]"
            #     # t4021 = ltorch.reshape(t1319, (-1, 11008))  # t4021: "cuda:0 f32[8192, 11008]"
            #     # t4021 = prims.reshape(t1319, (8192, 11008))  # t4021: "cuda:0 f32[8192, 11008]"
            #   del t1319
            # while its usage in the computation is at line 612
            #   t4022 = torch.matmul(t4020, t4021)  # t4022: "cuda:0 f32[4096, 11008]"
            #     # t4022 = ltorch.matmul(t4020, t4021)  # t4022: "cuda:0 f32[4096, 11008]"
            #     # t4022 = prims.matmul(t4020, t4021)  # t4022: "cuda:0 f32[4096, 11008]"
            new_computation_trace = move_closer_to_consumer(new_computation_trace)

            # Transform the backward trace to load offloaded tensors back onto the device.
            new_computation_trace = self._load_tensors_for_backward(new_computation_trace)

        return new_computation_trace
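As an example of the `save_tensor_policy` hook, here is a hedged sketch that keeps only large intermediates on the offload list; the helper name, the threshold, and the use of `math.prod(t.shape)` to compute the element count are illustrative choices of ours, not part of the transform:

import math

def offload_only_large_tensors(offloaded_tensors, forward_trace, min_numel=16 * 1024):
    # Keep only intermediates with at least `min_numel` elements; smaller tensors
    # stay on the GPU since offloading them saves little memory.
    return tuple(t for t in offloaded_tensors if math.prod(t.shape) >= min_numel)

# Usage: thunder.jit(model, transforms=[CPUOffloading(save_tensor_policy=offload_only_large_tensors)])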
[11]:
def clear_memory():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_accumulated_memory_stats()
    torch.cuda.reset_peak_memory_stats()
    print(f"Allocated Memory after cleaning {torch.cuda.memory_allocated() / 1e9} GB")
[12]:
def benchmark(jmodel: ThunderModule, model: torch.nn.Module, args, kwargs):
    # NOTE - This function takes care of warm-up.
    stmt = """
# Use the optimized model for prediction and backward
o = jmodel(*args, **kwargs)
o.sum().backward()
for param in model.parameters(): # use original model for clear grads
    param.grad = None
"""
    timer = torch.utils.benchmark.Timer(
        stmt=stmt, globals={"jmodel": jmodel, "model": model, "args": args, "kwargs": kwargs}
    ).timeit(number=10)
    return timer
Testing our Transform on a Simple Model
[13]:
class MySimpleModel(torch.nn.Module):
    def __init__(self, n_layers=10):
        super().__init__()
        self.fcs = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(n_layers)])

    def forward(self, x):
        for fc in self.fcs:
            x = torch.nn.functional.relu(fc(x))
        return x


def get_model_and_args():
    device = 'cuda'
    model = MySimpleModel(n_layers=100).to(device)
    args = (torch.randn(128, 16, device=device),)
    kwargs = {}
    return model, args, kwargs


model, args, kwargs = get_model_and_args()

# Check against the vanilla `thunder.jit` model.
expected = thunder.jit(model)(*args, **kwargs)
grad_output = torch.randn_like(expected)
expected_grads = torch.autograd.grad(expected, model.parameters(), grad_output)

print(f"Peak Memory with thunder : {torch.cuda.max_memory_allocated()} bytes")
torch.cuda.reset_peak_memory_stats()

with torch.no_grad():
    expected_cpu = expected.to("cpu")
    expected_grads_cpu = tree_map(lambda t: t.to("cpu"), expected_grads)

jmodel = thunder.jit(model, transforms=[CPUOffloading()])
actual = jmodel(*args, **kwargs)

# Verify that saved tensors are on CPU.
saved_tensor_devices = set()
for t in actual.grad_fn.saved_tensors:
    saved_tensor_devices.add(str(t.device))
assert "cpu" in saved_tensor_devices  # Verify that we actually have saved tensors on CPU.

actual_grads = torch.autograd.grad(actual, jmodel.parameters(), grad_output)
print(f"Peak Memory with CPU Offloading : {torch.cuda.max_memory_allocated()} bytes")

with torch.no_grad():
    actual_cpu = actual.to("cpu")
    actual_grads_cpu = tree_map(lambda t: t.to("cpu"), actual_grads)

# Sanity check that values match.
torch.testing.assert_close(actual_cpu, expected_cpu)
torch.testing.assert_close(actual_grads_cpu, expected_grads_cpu)

# Fetch the forward and backward traces for inspection.
fw_traces = thunder.last_traces(jmodel)
bw_traces = thunder.last_backward_traces(jmodel)

del jmodel, model, args, kwargs, actual, actual_grads, expected, expected_grads, grad_output  # Free memory.
clear_memory()
Peak Memory with thunder : 19279872 bytes
Peak Memory with CPU Offloading : 18444288 bytes
Allocated Memory after cleaning 0.017047552 GB
Inspecting the forward and the backward traces.
[ ]:
fw_traces[-1] # Note the calls to `offload_to_cpu` and verify that they are after the last usage of the tensor.
Snippet from the forward trace
t485 = torch.nn.functional.linear(t484, t_fcs_97_weight, t_fcs_97_bias) # t485: "cuda:0 f32[128, 16]"
# t485 = ltorch.linear(t484, t_fcs_97_weight, t_fcs_97_bias) # t485: "cuda:0 f32[128, 16]"
# t485 = prims.linear(t484, t_fcs_97_weight, t_fcs_97_bias) # t485: "cuda:0 f32[128, 16]"
# Created by CPU Offloading Transform
offloaded_t484 = offload_to_cpu(t484) # offloaded_t484: "cpu f32[128, 16]"
del t484
[t487, t489] = nvFusion97(t485)
# t487 = prims.gt(t485, 0.0) # t487: "cuda:0 b8[128, 16]"
# t489 = prims.where(t487, t485, 0.0) # t489: "cuda:0 f32[128, 16]"
# Created by CPU Offloading Transform
offloaded_t487 = offload_to_cpu(t487) # offloaded_t487: "cpu b8[128, 16]"
del t487
del t485
t490 = torch.nn.functional.linear(t489, t_fcs_98_weight, t_fcs_98_bias) # t490: "cuda:0 f32[128, 16]"
# t490 = ltorch.linear(t489, t_fcs_98_weight, t_fcs_98_bias) # t490: "cuda:0 f32[128, 16]"
# t490 = prims.linear(t489, t_fcs_98_weight, t_fcs_98_bias) # t490: "cuda:0 f32[128, 16]"
# Created by CPU Offloading Transform
offloaded_t489 = offload_to_cpu(t489) # offloaded_t489: "cpu f32[128, 16]"
del t489
[ ]:
bw_traces[-1] # Note the calls to `load_to_gpu` and verify that they are before the first usage of the tensor.
Snippet from the backward trace
# Created by CPU Offloading Transform
t489 = load_to_gpu(offloaded_t489, 'cuda:0') # t489: "cuda:0 f32[128, 16]"
t2015 = torch.reshape(t489, (-1, 16)) # t2015: "cuda:0 f32[128, 16]"
# t2015 = ltorch.reshape(t489, (-1, 16)) # t2015: "cuda:0 f32[128, 16]"
# t2015 = prims.reshape(t489, (128, 16)) # t2015: "cuda:0 f32[128, 16]"
del t489
t2016 = torch.matmul(t2014, t2015) # t2016: "cuda:0 f32[16, 16]"
# t2016 = ltorch.matmul(t2014, t2015) # t2016: "cuda:0 f32[16, 16]"
# t2016 = prims.matmul(t2014, t2015) # t2016: "cuda:0 f32[16, 16]"
del t2014, t2015
t2005 = torch.permute(t2002, (1, 0)) # t2005: "cuda:0 f32[16, 128]"
# t2005 = ltorch.permute(t2002, (1, 0)) # t2005: "cuda:0 f32[16, 128]"
# t2005 = prims.transpose(t2002, (1, 0)) # t2005: "cuda:0 f32[16, 128]"
del t2002
# Created by CPU Offloading Transform
t494 = load_to_gpu(offloaded_t494, 'cuda:0') # t494: "cuda:0 f32[128, 16]"
t2006 = torch.reshape(t494, (-1, 16)) # t2006: "cuda:0 f32[128, 16]"
# t2006 = ltorch.reshape(t494, (-1, 16)) # t2006: "cuda:0 f32[128, 16]"
# t2006 = prims.reshape(t494, (128, 16)) # t2006: "cuda:0 f32[128, 16]"
del t494
Benchmark thunder vs thunder + CPU Offloading on the Simple Model
[14]:
model, args, kwargs = get_model_and_args()
measurement_thunder = benchmark(thunder.jit(model), model, args, kwargs)
measurement_thunder_offload = benchmark(thunder.jit(model, transforms=[CPUOffloading()]), model, args, kwargs)
del model, args, kwargs
clear_memory()
Allocated Memory after cleaning 0.017047552 GB
[15]:
measurement_thunder
[15]:
<torch.utils.benchmark.utils.common.Measurement object at 0x74ff48fbb6a0>
stmt:
# Use the optimized model for prediction and backward
o = jmodel(*args, **kwargs)
o.sum().backward()
for param in model.parameters(): # use original model for clear grads
    param.grad = None
8.50 ms
1 measurement, 10 runs , 1 thread
[16]:
measurement_thunder_offload
[16]:
<torch.utils.benchmark.utils.common.Measurement object at 0x74ff439d2dd0>
stmt:
# Use the optimized model for prediction and backward
o = jmodel(*args, **kwargs)
o.sum().backward()
for param in model.parameters(): # use original model for clear grads
    param.grad = None
12.62 ms
1 measurement, 10 runs , 1 thread
Let us try it on a real-life model, Llama-3. We will run a smaller configuration of Llama-3; feel free to update N_LAYER and BLOCK_SIZE based on the available device memory.
NOTE: Running the cell below requires litgpt to be installed. Use pip install litgpt if it is not available.
[22]:
from litgpt import Config, GPT
from functools import partial
from torch.testing import make_tensor

N_LAYER = 9
BLOCK_SIZE = 1024


def get_model_and_args(batchdims=8):
    with torch.device("cuda"):
        cfg: Config = Config.from_name("Llama-3-8B")
        # Smaller configuration
        cfg.n_layer = N_LAYER
        cfg.block_size = BLOCK_SIZE
        model = GPT(cfg)

    make = partial(make_tensor, low=0, high=255, device='cuda', dtype=torch.int64, requires_grad=False)
    shape = (batchdims,) + (cfg.block_size,)
    x = make(shape)
    args, kwargs = (x,), {}
    return model, args, kwargs, cfg


def print_memory_usage_and_benchmark(name):
    print(f"{name} took -")
    model, args, kwargs, cfg = get_model_and_args()

    if name == 'thunder':
        jmodel = thunder.jit(model)
    elif name == 'thunder_offload':
        jmodel = thunder.jit(model, transforms=[CPUOffloading()])
    else:
        raise RuntimeError("Received invalid value for `name` - try `thunder` or `thunder_offload`.")

    memory_after_model_load = torch.cuda.max_memory_allocated() / 1e9
    print(f"Peak memory after loading the model : {memory_after_model_load} GB")

    a = jmodel(*args, **kwargs)
    memory_after_forward = torch.cuda.max_memory_allocated() / 1e9
    print(f"Peak memory after forward the model : {memory_after_forward} GB")

    g = torch.rand_like(a)
    actual_grads = torch.autograd.grad(a, model.parameters(), g)
    memory_after_backward = torch.cuda.max_memory_allocated() / 1e9
    print(f"Peak memory after backward the model : {memory_after_backward} GB")

    del a, g, actual_grads  # Clear data which is not required for benchmark to free some memory.
    gc.collect()
    torch.cuda.empty_cache()

    measurement = benchmark(jmodel, model, args, kwargs)
    print(f"Benchmark Timings - mean : {measurement.mean} - median {measurement.median}")

    del jmodel, model, cfg, args, kwargs
    clear_memory()
[23]:
# Uncomment this to run the benchmarks.
# print_memory_usage_and_benchmark("thunder")
thunder took -
Peak memory after loading the model : 12.073366016 GB
Peak memory after forward the model : 38.901342208 GB
Peak memory after backward the model : 46.245552128 GB
Benchmark Timings - mean : 5.008525840996299 - median 5.008525840996299
Allocated Memory after cleaning 0.017047552 GB
[24]:
# Uncomment this to run the benchmarks.
# print_memory_usage_and_benchmark("thunder_offload")
thunder_offload took -
Peak memory after loading the model : 12.073366016 GB
Peak memory after forward the model : 16.409812992 GB
Peak memory after backward the model : 35.545775616 GB
Benchmark Timings - mean : 5.91704241540283 - median 5.91704241540283
Allocated Memory after cleaning 0.017047552 GB
Conclusion
In this notebook, we have seen how to write our own Transform in thunder. As an example, we wrote a CPUOffloading transform that implements the CPU offloading technique to decrease peak memory usage during training.