Extending Thunder
This notebook shows how to use thunder’s extend submodule to add new operations and custom grad and execution transforms.
[1]:
import sys
sys.path.insert(0, '..')
from numbers import Number
import thunder
import thunder.torch as ltorch
from thunder.core.devices import DeviceType
from thunder.core.proxies import TensorProxy
from thunder.core.transforms import grad, put_grads, get_grad
import torch
import numpy as np
torch.manual_seed(42);
[2]:
from thunder.extend import OperatorExecutor, register_executor
[3]:
# Creates a new operator executor and registers it with thunder
myex = OperatorExecutor("myex", version="0.1")
register_executor(myex)
[3]:
myex
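# Once registered, myex appears in thunder's global registry of executors. As a quick sanity check
# we could list them (a sketch; it assumes thunder.extend.get_all_executors is available in your
# thunder version):
from thunder.extend import get_all_executors
print([ex.name for ex in get_all_executors()])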
[4]:
# Our operator executor will use the "multimul" function as a new example operator.
# This function uses NumPy to perform two multiplications on four inputs.
# The function is contrived, but it's useful for illustrating the extend submodule's capabilities.
def multimul_impl(
    a: Number | torch.Tensor,
    b: Number | torch.Tensor,
    c: Number | torch.Tensor,
    d: Number | torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    return np.multiply(a, b), np.multiply(c, d)
[5]:
# We can verify that multimul is a valid Python function that operates on PyTorch tensors -- at least PyTorch tensors on the CPU.
a = torch.randn((2, 2))
b = torch.randn((2, 2))
multimul_impl(a, b, a, b)
[5]:
(tensor([[-0.3781, -0.0240],
[ 0.5177, -0.1470]]),
tensor([[-0.3781, -0.0240],
[ 0.5177, -0.1470]]))
[6]:
# To let thunder use multimul we need to define how it propagates metadata. This can be done by directly defining a "meta function",
# or by defining a traceable "like" function that describes what multimul does in terms of existing thunder operations.
# The "like" function can be used for metadata propagation AND for transforming the new operator, as we'll see below.
# In this case, the "like" function just describes the two multiplications that multimul performs.
def multimul_like(
    a: Number | TensorProxy,
    b: Number | TensorProxy,
    c: Number | TensorProxy,
    d: Number | TensorProxy,
):
    return a * b, c * d
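# As an aside, the explicit "meta function" alternative would only propagate metadata and could not
# be used to transform the operator. A rough sketch (assuming output proxies can be constructed with
# TensorProxy(like=...) and that tensor inputs are passed for a and c):
def multimul_meta(
    a: Number | TensorProxy,
    b: Number | TensorProxy,
    c: Number | TensorProxy,
    d: Number | TensorProxy,
) -> tuple[TensorProxy, TensorProxy]:
    # Each output has the same shape, dtype, and device as the corresponding tensor input
    return TensorProxy(like=a), TensorProxy(like=c)
# Such a meta function would be registered via register_operator's meta parameter rather than like
# (assuming your thunder version exposes it).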
[7]:
# The "register_operator" method of operator executor's returns a "Symbol" object for multimul that can be called directly
# from compiled thunder code.
multimul = myex.register_operator('multimul', like=multimul_like, fn=multimul_impl)
[8]:
# Example of calling the new multimul symbol
def foo(a, b, c, d):
    return multimul(a, b, c, d)
cfoo = thunder.jit(foo, executors=[myex])
cfoo(a, b, a, b)
[8]:
(tensor([[-0.3781, -0.0240],
[ 0.5177, -0.1470]]),
tensor([[-0.3781, -0.0240],
[ 0.5177, -0.1470]]))
[9]:
# The symbol is recorded, like other operations, into thunder's trace
thunder.last_traces(cfoo)[-1]
[9]:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(t_0, t_1, t_2, t_3):
  # t_0: "cpu f32[2, 2]"
  # t_1: "cpu f32[2, 2]"
  # t_2: "cpu f32[2, 2]"
  # t_3: "cpu f32[2, 2]"
  (t0, t1) = multimul(t_0, t_1, t_2, t_3)
    # t0 = ltorch.mul(t_0, t_1) # t0: "cpu f32[2, 2]"
      # t0 = prims.mul(t_0, t_1) # t0: "cpu f32[2, 2]"
    # t1 = ltorch.mul(t_2, t_3) # t1: "cpu f32[2, 2]"
      # t1 = prims.mul(t_2, t_3) # t1: "cpu f32[2, 2]"
  del t_0, t_1, t_2, t_3
  return (t0, t1)
[10]:
# multimul is even differentiable because its "like" function is differentiable
a.requires_grad_(True)
b.requires_grad_(True)
cfoo_grad = grad(cfoo)
cfoo_grad(a, b, a, b)
print(thunder.last_traces(cfoo_grad)[-1])
a.requires_grad_(False)
b.requires_grad_(False)
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(t_0, t_1, t_2, t_3):
  # t_1: "cpu f32[2, 2]"
  t8 = torch.full((2, 2), 1.0, device=torch.device("cpu"), dtype=torch.float32) # t8: "cpu f32[2, 2]"
    # t8 = ltorch.full((2, 2), 1.0, device=torch.device("cpu"), dtype=torch.float32) # t8: "cpu f32[2, 2]"
      # t8 = prims.full((2, 2), 1.0, device=devices.Device("cpu"), dtype=dtypes.float32) # t8: "cpu f32[2, 2]"
  t2 = torch.mul(t_1, t8) # t2: "cpu f32[2, 2]"
    # t2 = ltorch.mul(t_1, t8) # t2: "cpu f32[2, 2]"
      # t2 = prims.mul(t_1, t8) # t2: "cpu f32[2, 2]"
  del t_1
  # t_0: "cpu f32[2, 2]"
  t3 = torch.mul(t_0, t8) # t3: "cpu f32[2, 2]"
    # t3 = ltorch.mul(t_0, t8) # t3: "cpu f32[2, 2]"
      # t3 = prims.mul(t_0, t8) # t3: "cpu f32[2, 2]"
  del t_0, t8
  # t_3: "cpu f32[2, 2]"
  t9 = torch.full((2, 2), 1.0, device=torch.device("cpu"), dtype=torch.float32) # t9: "cpu f32[2, 2]"
    # t9 = ltorch.full((2, 2), 1.0, device=torch.device("cpu"), dtype=torch.float32) # t9: "cpu f32[2, 2]"
      # t9 = prims.full((2, 2), 1.0, device=devices.Device("cpu"), dtype=dtypes.float32) # t9: "cpu f32[2, 2]"
  t6 = torch.mul(t_3, t9) # t6: "cpu f32[2, 2]"
    # t6 = ltorch.mul(t_3, t9) # t6: "cpu f32[2, 2]"
      # t6 = prims.mul(t_3, t9) # t6: "cpu f32[2, 2]"
  del t_3
  # t_2: "cpu f32[2, 2]"
  t7 = torch.mul(t_2, t9) # t7: "cpu f32[2, 2]"
    # t7 = ltorch.mul(t_2, t9) # t7: "cpu f32[2, 2]"
      # t7 = prims.mul(t_2, t9) # t7: "cpu f32[2, 2]"
  del t_2, t9
  return [t2, t3, t6, t7]
[10]:
tensor([[-1.1229, -0.1863],
[ 2.2082, -0.6380]])
[11]:
# We can tell thunder to execute existing operations using multimul by defining a transform
# from them to multimul, and a "checker" function that returns True when the
# transform is valid and False otherwise.
# We can translate mul to multimul by passing dummy arguments for the second multiplication and ignoring its result
def mul_to_multimul(a: Number | TensorProxy, b: Number | TensorProxy) -> TensorProxy:
    result, _ = multimul(a, b, 0, 0)
    return result
# The "checker" function verifies that all inputs are CPU tensors or numbers, because NumPy
# can't handle other inputs
def mul_to_multimul_checker(a: Number | TensorProxy, b: Number | TensorProxy) -> bool:
def is_cpu(x: Number | TensorProxy) -> bool:
if isinstance(a, TensorProxy):
return a.device.devicetype == DeviceType.CPU
return True
return all(is_cpu(x) for x in (a, b))
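# Checkers can be arbitrarily strict. For example, a variant that additionally requires float32
# tensor inputs might look roughly like this (a sketch, not registered below; it assumes
# thunder.core.dtypes.float32 and the TensorProxy.dtype attribute):
from thunder.core.dtypes import float32
def strict_mul_to_multimul_checker(a: Number | TensorProxy, b: Number | TensorProxy) -> bool:
    if not mul_to_multimul_checker(a, b):
        return False
    # Only accept float32 tensors; plain numbers are always fine
    return all(not isinstance(x, TensorProxy) or x.dtype == float32 for x in (a, b))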
[12]:
# The "register_implementation" method describes how to translate mul to multimul
myex.register_implementation(ltorch.mul, checker=mul_to_multimul_checker, execution_transform=mul_to_multimul)
[13]:
# Verifies the implementation of mul using multimul, and shows the execution transform
def bar(a, b):
    return a * b
cbar = thunder.jit(bar, executors=[myex])
cbar(a, b)
thunder.last_traces(cbar)[-1]
[13]:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(t_0, t_1):
  # t_0: "cpu f32[2, 2]"
  # t_1: "cpu f32[2, 2]"
  (t0, _) = multimul(t_0, t_1, 0, 0)
    # t0 = ltorch.mul(t_0, t_1) # t0: "cpu f32[2, 2]"
      # t0 = prims.mul(t_0, t_1) # t0: "cpu f32[2, 2]"
  del t_0, t_1
  return t0
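# As an extra check, the multimul-backed multiplication should agree numerically with eager PyTorch
# (a small sketch on the tensors defined above):
import torch.testing
torch.testing.assert_close(cbar(a, b), a * b)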
[14]:
# Execution transforms happen AFTER semantic transforms like grad, so even when computing the grad
# of mul (which involves two multiplications to compute the grad) we still see multimul in the
# execution trace
a.requires_grad_(True)
b.requires_grad_(True)
cbar_grad = grad(cbar)
cbar_grad(a, b)
thunder.last_traces(cbar_grad)[-1]
[14]:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(t_0, t_1):
  # t_1: "cpu f32[2, 2]"
  t4 = torch.full((2, 2), 1.0, device=torch.device("cpu"), dtype=torch.float32) # t4: "cpu f32[2, 2]"
    # t4 = ltorch.full((2, 2), 1.0, device=torch.device("cpu"), dtype=torch.float32) # t4: "cpu f32[2, 2]"
      # t4 = prims.full((2, 2), 1.0, device=devices.Device("cpu"), dtype=dtypes.float32) # t4: "cpu f32[2, 2]"
  (t2, _) = multimul(t_1, t4, 0, 0)
    # t2 = ltorch.mul(t_1, t4) # t2: "cpu f32[2, 2]"
      # t2 = prims.mul(t_1, t4) # t2: "cpu f32[2, 2]"
  del t_1
  # t_0: "cpu f32[2, 2]"
  (t3, _) = multimul(t_0, t4, 0, 0)
    # t3 = ltorch.mul(t_0, t4) # t3: "cpu f32[2, 2]"
      # t3 = prims.mul(t_0, t4) # t3: "cpu f32[2, 2]"
  del t_0, t4
  return [t2, t3]
[15]:
# In the above grad trace there are two multimuls, and both ignore one of their multiplications.
# It would be more efficient to perform just one multimul, and we can make this happen
# by defining a new grad transform for mul that calls multimul once.
# thunder's grad transforms are defined in a novel way that's not the focus of this notebook,
# but below we define the grad transform to use multimul.
def mymul_grad(a: TensorProxy, b: TensorProxy) -> TensorProxy:
    fwd = a * b
    g = get_grad(fwd)
    a_grad, b_grad = multimul(b, g, a, g)
    put_grads((a, b), (a_grad, b_grad))
    return fwd
# Re-registers the implementation, including the execution transform and now a grad transform
myex.register_implementation(ltorch.mul, checker=mul_to_multimul_checker, execution_transform=mul_to_multimul, grad_transform=mymul_grad)
[16]:
# Verifies our new grad transform is used and that a single multimul call is made
cbar_grad = grad(cbar)
cbar_grad(a, b)
thunder.last_traces(cbar_grad)[-1]
[16]:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(t_0, t_1):
  # t_0: "cpu f32[2, 2]"
  t4 = torch.full((2, 2), 1.0, device=torch.device("cpu"), dtype=torch.float32) # t4: "cpu f32[2, 2]"
    # t4 = ltorch.full((2, 2), 1.0, device=torch.device("cpu"), dtype=torch.float32) # t4: "cpu f32[2, 2]"
      # t4 = prims.full((2, 2), 1.0, device=devices.Device("cpu"), dtype=dtypes.float32) # t4: "cpu f32[2, 2]"
  # t_1: "cpu f32[2, 2]"
  (t2, t3) = multimul(t_1, t4, t_0, t4)
    # t2 = ltorch.mul(t_1, t4) # t2: "cpu f32[2, 2]"
      # t2 = prims.mul(t_1, t4) # t2: "cpu f32[2, 2]"
    # t3 = ltorch.mul(t_0, t4) # t3: "cpu f32[2, 2]"
      # t3 = prims.mul(t_0, t4) # t3: "cpu f32[2, 2]"
  del t_1, t4, t_0
  return [t2, t3]
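# We can also spot-check the values: with thunder's all-ones initial cotangent (the torch.full
# call in the trace above), the grads of a * b are simply b and a
# (a sketch comparing against cbar_grad's output):
grads = cbar_grad(a, b)
torch.testing.assert_close(grads[0], b)
torch.testing.assert_close(grads[1], a)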
[17]:
# Some operations may require that their inputs have particular properties (like being contiguous), or a transform may wish
# to interleave torch operations with new operations. The transform function supports this. Here
# we see an example where the inputs to multimul are made contiguous before it's called
def mul_to_contiguous_multimul(a: Number | TensorProxy, b: Number | TensorProxy) -> TensorProxy:
    a = a.contiguous()
    b = b.contiguous()
    result, _ = multimul(a, b, 0, 0)
    return result
myex.register_implementation(ltorch.mul, checker=mul_to_multimul_checker, execution_transform=mul_to_contiguous_multimul)
[18]:
# Verifies that the new "prologue" for multimul works as expected. Note that the contiguous operations are
# executed by PyTorch and don't have to be handled by our executor
a.requires_grad_(False)
b.requires_grad_(False)
def caz(a, b):
    return a * b
ccaz = thunder.jit(caz, executors=[myex])
ccaz(a, b)
thunder.last_traces(ccaz)[-1]
[18]:
# Constructed by Delete Last Used (took 0 milliseconds)
from torch import Tensor
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(t_0, t_1):
  # t_0: "cpu f32[2, 2]"
  # t_1: "cpu f32[2, 2]"
  t1 = Tensor.contiguous(t_0, memory_format=_torch_memory_format_0) # t1: "cpu f32[2, 2]"
    # t1 = ltorch.contiguous(t_0, memory_format=_torch_memory_format_0) # t1: "cpu f32[2, 2]"
      # t1 = prims.stride_order(t_0, (1, 0)) # t1: "cpu f32[2, 2]"
  del t_0
  t2 = Tensor.contiguous(t_1, memory_format=_torch_memory_format_0) # t2: "cpu f32[2, 2]"
    # t2 = ltorch.contiguous(t_1, memory_format=_torch_memory_format_0) # t2: "cpu f32[2, 2]"
      # t2 = prims.stride_order(t_1, (1, 0)) # t2: "cpu f32[2, 2]"
  del t_1
  (t0, _) = multimul(t1, t2, 0, 0)
    # t0 = ltorch.mul(t1, t2) # t0: "cpu f32[2, 2]"
      # t0 = prims.mul(t1, t2) # t0: "cpu f32[2, 2]"
  del t1, t2
  return t0
[19]:
# NVIDIA's APEX cross-entropy executor is a good example of a real-world operator executor. It defines
# fast forward and backward functions for torch.nn.functional.cross_entropy. We can see its custom
# forward and backward operations below.
# NOTE: This cell and the following cells require the APEX executor to be installed to run properly
dtype = torch.float32
device = 'cuda'
logits = torch.randn([2048, 50257], device=device, dtype=ltorch.to_torch_dtype(dtype), requires_grad=False)
labels = torch.randint(0, 50257, [2048], device=device)
from thunder.executors.apexex import apex_ex
def foo(logits, labels):
    return torch.nn.functional.cross_entropy(logits, labels, reduction="mean", ignore_index=-1)
cfoo = thunder.jit(foo, executors=[apex_ex])
[20]:
# Shows the forward operation
cfoo(logits, labels)
thunder.last_traces(cfoo)[-1]
[20]:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(t_0, t_1):
  # t_0: "cuda:0 f32[2048, 50257]"
  # t_1: "cuda:0 i64[2048]"
  (t18, _) = apex_cross_entropy(t_0, t_1, 'mean', 0.0)
  del t_0, t_1
  return t18
[21]:
# Shows APEX's custom forward and backward operations, plus additional PyTorch operations between the two
logits.requires_grad_(True)
cfoo_grad = grad(cfoo)
cfoo_grad(logits, labels)
thunder.last_traces(cfoo_grad)[-1]
[21]:
# Constructed by Delete Last Used (took 0 milliseconds)
from torch import Tensor
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(t_0, t_1):
  # t_0: "cuda:0 f32[2048, 50257]"
  # t_1: "cuda:0 i64[2048]"
  (_, t1) = apex_cross_entropy(t_0, t_1, 'mean', 0.0)
  t6 = Tensor.contiguous(t_0, memory_format=_torch_memory_format_0) # t6: "cuda:0 f32[2048, 50257]"
    # t6 = ltorch.contiguous(t_0, memory_format=_torch_memory_format_0) # t6: "cuda:0 f32[2048, 50257]"
      # t6 = prims.stride_order(t_0, (1, 0)) # t6: "cuda:0 f32[2048, 50257]"
  del t_0
  t8 = torch.full((), 1.0, device=torch.device("cuda:0"), dtype=torch.float32) # t8: "cuda:0 f32[]"
    # t8 = ltorch.full((), 1.0, device=torch.device("cuda:0"), dtype=torch.float32) # t8: "cuda:0 f32[]"
      # t8 = prims.full((), 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32) # t8: "cuda:0 f32[]"
  t12 = torch.unsqueeze(t8, 0) # t12: "cuda:0 f32[1]"
    # t12 = ltorch.unsqueeze(t8, 0) # t12: "cuda:0 f32[1]"
      # t12 = prims.broadcast_in_dim(t8, [1], []) # t12: "cuda:0 f32[1]"
  del t8
  t3 = Tensor.expand(t12, [1]) # t3: "cuda:0 f32[1]"
    # t3 = ltorch.expand(t12, [1]) # t3: "cuda:0 f32[1]"
      # t3 = prims.broadcast_in_dim(t12, (1,), (0,)) # t3: "cuda:0 f32[1]"
  del t12
  t4 = Tensor.expand(t3, (2048,)) # t4: "cuda:0 f32[2048]"
    # t4 = ltorch.expand(t3, (2048,)) # t4: "cuda:0 f32[2048]"
      # t4 = prims.broadcast_in_dim(t3, (2048,), (0,)) # t4: "cuda:0 f32[2048]"
  del t3
  t5 = torch.mul(t4, 0.00048828125) # t5: "cuda:0 f32[2048]"
    # t5 = ltorch.mul(t4, 0.00048828125) # t5: "cuda:0 f32[2048]"
      # t5 = prims.mul(t4, 0.00048828125) # t5: "cuda:0 f32[2048]"
  del t4
  t7 = apex_cross_entropy_backward(t5, t6, target=t_1, max_log_sum_exp=t1, label_smoothing=0.0) # t7: "cuda:0 f32[2048, 50257]"
  del t5, t6, t1, t_1
  return [t7]
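# The APEX executor follows essentially the same recipe we used for multimul, just with separate
# forward and backward symbols wired together by a grad transform. Roughly (all names below are
# illustrative placeholders, not real thunder or APEX APIs):
#
#   fused_fwd = myex.register_operator("fused_fwd", like=fused_fwd_like, fn=fused_fwd_impl)
#   fused_bwd = myex.register_operator("fused_bwd", like=fused_bwd_like, fn=fused_bwd_impl)
#
#   def fused_grad_transform(x):
#       result, saved = fused_fwd(x)
#       g = get_grad(result)
#       put_grads((x,), (fused_bwd(g, saved),))
#       return result
#
#   myex.register_implementation(
#       some_thunder_op,
#       checker=fused_checker,
#       execution_transform=lambda x: fused_fwd(x)[0],
#       grad_transform=fused_grad_transform,
#   )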