{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Extending Thunder\n", "\n", "This notebook shows how to use thunder's extend submodule to add new operations and custom grad and execution transforms." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.insert(0, '..')\n", "from numbers import Number\n", "\n", "import thunder\n", "import thunder.torch as ltorch\n", "from thunder.core.devices import DeviceType\n", "from thunder.core.proxies import TensorProxy\n", "from thunder.core.transforms import grad, put_grads, get_grad\n", "\n", "import torch\n", "import numpy as np\n", "\n", "torch.manual_seed(42);" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from thunder.extend import OperatorExecutor, register_executor" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "myex" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Registers a new operator executor\n", "myex = OperatorExecutor(\"myex\", version=\"0.1\")\n", "register_executor(myex)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Our operator executor will use the \"multimul\" function as a new example operator.\n", "# This function uses NumPy to perform two multiplications of four inputs.\n", "# This function's contrived, but will be useful to illustrate the extend submodule's capabilities.\n", "def multimul_impl(\n", " a: Number | torch.Tensor, \n", " b: Number | torch.Tensor,\n", " c: Number | torch.Tensor,\n", " d: Number | torch.Tensor,) -> tuple[torch.Tensor, torch.Tensor]:\n", " return np.multiply(a, b), np.multiply(c, d)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[-0.3781, -0.0240],\n", " [ 0.5177, -0.1470]]),\n", " tensor([[-0.3781, -0.0240],\n", " 
[ 0.5177, -0.1470]]))" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# We can verify that multimul is a valid Python function that operates on PyTorch tensors -- at least PyTorch tensors on the CPU.\n", "a = torch.randn((2, 2))\n", "b = torch.randn((2, 2))\n", "multimul_impl(a, b, a, b)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# To let thunder use multimul we need to define how it propagates metadata. This can be done by directly defining a \"meta function\", \n", "# or by defining a traceable \"like\" function that describes what multimul does in terms of existing thunder operations. \n", "# The \"like\" function can be used for metadata propagation AND transforming the new operator, as we'll see below.\n", "# In this case, the \"like\" function just describes the two multiplications that multimul performs.\n", "def multimul_like(\n", " a: Number | TensorProxy, \n", " b: Number | TensorProxy,\n", " c: Number | TensorProxy,\n", " d: Number | TensorProxy,\n", "):\n", " return a * b, c * d" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# The \"register_operator\" method of operator executors returns a \"Symbol\" object for multimul that can be called directly\n", "# from compiled thunder code.\n", "multimul = myex.register_operator('multimul', like=multimul_like, fn=multimul_impl)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[-0.3781, -0.0240],\n", " [ 0.5177, -0.1470]]),\n", " tensor([[-0.3781, -0.0240],\n", " [ 0.5177, -0.1470]]))" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Example of calling the new multimul symbol\n", "def foo(a, b, c, d):\n", " return multimul(a, b, c, d)\n", "\n", "cfoo = thunder.jit(foo, executors=[myex])\n", "cfoo(a, b, a, b)" ] }, { "cell_type": "code", "execution_count": 
9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "# Constructed by Delete Last Used (took 0 milliseconds)\n", "import torch\n", "from thunder.executors.torchex import no_autocast\n", "\n", "@torch.no_grad()\n", "@no_autocast\n", "def computation(t_0, t_1, t_2, t_3):\n", " # t_0: \"cpu f32[2, 2]\" \n", " # t_1: \"cpu f32[2, 2]\" \n", " # t_2: \"cpu f32[2, 2]\" \n", " # t_3: \"cpu f32[2, 2]\" \n", " (t0, t1) = multimul(t_0, t_1, t_2, t_3)\n", " # t0 = ltorch.mul(t_0, t_1) # t0: \"cpu f32[2, 2]\"\n", " # t0 = prims.mul(t_0, t_1) # t0: \"cpu f32[2, 2]\"\n", " # t1 = ltorch.mul(t_2, t_3) # t1: \"cpu f32[2, 2]\"\n", " # t1 = prims.mul(t_2, t_3) # t1: \"cpu f32[2, 2]\"\n", " del t_0, t_1, t_2, t_3\n", " return (t0, t1)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The symbol is recorded, like other operations, into thunder's trace\n", "thunder.last_traces(cfoo)[-1]" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "# Constructed by Delete Last Used (took 0 milliseconds)\n", "import torch\n", "from thunder.executors.torchex import no_autocast\n", "\n", "@torch.no_grad()\n", "@no_autocast\n", "def computation(t_0, t_1, t_2, t_3):\n", " # t_1: \"cpu f32[2, 2]\" \n", " t8 = torch.full((2, 2), 1.0, device=torch.device(\"cpu\"), dtype=torch.float32) # t8: \"cpu f32[2, 2]\"\n", " # t8 = ltorch.full((2, 2), 1.0, device=torch.device(\"cpu\"), dtype=torch.float32) # t8: \"cpu f32[2, 2]\"\n", " # t8 = prims.full((2, 2), 1.0, device=devices.Device(\"cpu\"), dtype=dtypes.float32) # t8: \"cpu f32[2, 2]\"\n", " t2 = torch.mul(t_1, t8) # t2: \"cpu f32[2, 2]\"\n", " # t2 = ltorch.mul(t_1, t8) # t2: \"cpu f32[2, 2]\"\n", " # t2 = prims.mul(t_1, t8) # t2: \"cpu f32[2, 2]\"\n", " del t_1\n", " # t_0: \"cpu f32[2, 2]\" \n", " t3 = torch.mul(t_0, t8) # t3: \"cpu f32[2, 2]\"\n", " # t3 = ltorch.mul(t_0, t8) # t3: \"cpu f32[2, 2]\"\n", " # t3 = 
prims.mul(t_0, t8) # t3: \"cpu f32[2, 2]\"\n", " del t_0, t8\n", " # t_3: \"cpu f32[2, 2]\" \n", " t9 = torch.full((2, 2), 1.0, device=torch.device(\"cpu\"), dtype=torch.float32) # t9: \"cpu f32[2, 2]\"\n", " # t9 = ltorch.full((2, 2), 1.0, device=torch.device(\"cpu\"), dtype=torch.float32) # t9: \"cpu f32[2, 2]\"\n", " # t9 = prims.full((2, 2), 1.0, device=devices.Device(\"cpu\"), dtype=dtypes.float32) # t9: \"cpu f32[2, 2]\"\n", " t6 = torch.mul(t_3, t9) # t6: \"cpu f32[2, 2]\"\n", " # t6 = ltorch.mul(t_3, t9) # t6: \"cpu f32[2, 2]\"\n", " # t6 = prims.mul(t_3, t9) # t6: \"cpu f32[2, 2]\"\n", " del t_3\n", " # t_2: \"cpu f32[2, 2]\" \n", " t7 = torch.mul(t_2, t9) # t7: \"cpu f32[2, 2]\"\n", " # t7 = ltorch.mul(t_2, t9) # t7: \"cpu f32[2, 2]\"\n", " # t7 = prims.mul(t_2, t9) # t7: \"cpu f32[2, 2]\"\n", " del t_2, t9\n", " return [t2, t3, t6, t7]\n" ] }, { "data": { "text/plain": [ "tensor([[-1.1229, -0.1863],\n", " [ 2.2082, -0.6380]])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# multimul is even differentiable because its \"like\" function is differentiable\n", "a.requires_grad_(True)\n", "b.requires_grad_(True)\n", "\n", "cfoo_grad = grad(cfoo)\n", "cfoo_grad(a, b, a, b)\n", "print(thunder.last_traces(cfoo_grad)[-1])\n", "\n", "a.requires_grad_(False)\n", "b.requires_grad_(False)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# We can tell thunder to execute existing operations using multimul by defining a transform\n", "# from them to multimul, and a \"checker\" function that returns True when the \n", "# transform is valid and False otherwise.\n", "\n", "# We can translate mul to multimul by ignoring the second multiplication\n", "def mul_to_multimul(a: Number | TensorProxy, b: Number | TensorProxy) -> TensorProxy:\n", " result, _ = multimul(a, b, 0, 0)\n", " return result\n", "\n", "# The \"checker\" function verifies that all inputs are CPU tensors or 
numbers, because NumPy\n", "# can't handle other inputs\n", "def mul_to_multimul_checker(a: Number | TensorProxy, b: Number | TensorProxy) -> bool:\n", "    # Returns True only when every operand is a CPU tensor proxy or a plain number\n", "    def is_cpu(x: Number | TensorProxy) -> bool:\n", "        # Inspect the argument x itself (not the enclosing `a`) so that `b` is validated too\n", "        if isinstance(x, TensorProxy):\n", "            return x.device.devicetype == DeviceType.CPU\n", "        # Non-tensor inputs (numbers) have no device to check\n", "        return True\n", "\n", "    return all(is_cpu(x) for x in (a, b))" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# The \"register_implementation\" method describes how to translate mul to multimul\n", "myex.register_implementation(ltorch.mul, checker=mul_to_multimul_checker, execution_transform=mul_to_multimul)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "# Constructed by Delete Last Used (took 0 milliseconds)\n", "import torch\n", "from thunder.executors.torchex import no_autocast\n", "\n", "@torch.no_grad()\n", "@no_autocast\n", "def computation(t_0, t_1):\n", " # t_0: \"cpu f32[2, 2]\" \n", " # t_1: \"cpu f32[2, 2]\" \n", " (t0, _) = multimul(t_0, t_1, 0, 0)\n", " # t0 = ltorch.mul(t_0, t_1) # t0: \"cpu f32[2, 2]\"\n", " # t0 = prims.mul(t_0, t_1) # t0: \"cpu f32[2, 2]\"\n", " del t_0, t_1\n", " return t0" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Verifies the implementation of mul using multimul, and shows the execution transform\n", "def bar(a, b):\n", " return a * b\n", "cbar = thunder.jit(bar, executors=[myex])\n", "cbar(a, b)\n", "thunder.last_traces(cbar)[-1]" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "# Constructed by Delete Last Used (took 0 milliseconds)\n", "import torch\n", "from thunder.executors.torchex import no_autocast\n", "\n", "@torch.no_grad()\n", "@no_autocast\n", "def computation(t_0, t_1):\n", " # t_1: \"cpu f32[2, 2]\" \n", " t4 = torch.full((2, 2), 1.0, device=torch.device(\"cpu\"), dtype=torch.float32) # t4: \"cpu f32[2, 
2]\"\n", " # t4 = ltorch.full((2, 2), 1.0, device=torch.device(\"cpu\"), dtype=torch.float32) # t4: \"cpu f32[2, 2]\"\n", " # t4 = prims.full((2, 2), 1.0, device=devices.Device(\"cpu\"), dtype=dtypes.float32) # t4: \"cpu f32[2, 2]\"\n", " (t2, _) = multimul(t_1, t4, 0, 0)\n", " # t2 = ltorch.mul(t_1, t4) # t2: \"cpu f32[2, 2]\"\n", " # t2 = prims.mul(t_1, t4) # t2: \"cpu f32[2, 2]\"\n", " del t_1\n", " # t_0: \"cpu f32[2, 2]\" \n", " (t3, _) = multimul(t_0, t4, 0, 0)\n", " # t3 = ltorch.mul(t_0, t4) # t3: \"cpu f32[2, 2]\"\n", " # t3 = prims.mul(t_0, t4) # t3: \"cpu f32[2, 2]\"\n", " del t_0, t4\n", " return [t2, t3]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Execution transforms happen AFTER semantic transforms like grad, so even when computing the grad\n", "# of mul (which involves two multiplications to compute the grad) we still see multimul in the\n", "# execution trace\n", "a.requires_grad_(True)\n", "b.requires_grad_(True)\n", "\n", "cbar_grad = grad(cbar)\n", "cbar_grad(a, b)\n", "thunder.last_traces(cbar_grad)[-1]" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# In the above grad trace there are two multimuls, and both ignore one of their multiplications.\n", "# It would be more efficient to perform just one multimul, and we can make this happen\n", "# by defining a new grad transform for mul that calls multimul once.\n", "# thunder's grad transforms are defined in a novel way that's not the focus of this notebook,\n", "# but below we define the grad transform to use multimul.\n", "def mymul_grad(a: TensorProxy, b: TensorProxy) -> TensorProxy:\n", " fwd = a * b\n", "\n", " g = get_grad(fwd)\n", " a_grad, b_grad = multimul(b, g, a, g)\n", " put_grads((a, b), (a_grad, b_grad))\n", "\n", " return fwd\n", "\n", "# Re-registers the implementation, including the execution transform and now a grad transform\n", "myex.register_implementation(ltorch.mul, 
checker=mul_to_multimul_checker, execution_transform=mul_to_multimul, grad_transform=mymul_grad)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "# Constructed by Delete Last Used (took 0 milliseconds)\n", "import torch\n", "from thunder.executors.torchex import no_autocast\n", "\n", "@torch.no_grad()\n", "@no_autocast\n", "def computation(t_0, t_1):\n", " # t_0: \"cpu f32[2, 2]\" \n", " t4 = torch.full((2, 2), 1.0, device=torch.device(\"cpu\"), dtype=torch.float32) # t4: \"cpu f32[2, 2]\"\n", " # t4 = ltorch.full((2, 2), 1.0, device=torch.device(\"cpu\"), dtype=torch.float32) # t4: \"cpu f32[2, 2]\"\n", " # t4 = prims.full((2, 2), 1.0, device=devices.Device(\"cpu\"), dtype=dtypes.float32) # t4: \"cpu f32[2, 2]\"\n", " # t_1: \"cpu f32[2, 2]\" \n", " (t2, t3) = multimul(t_1, t4, t_0, t4)\n", " # t2 = ltorch.mul(t_1, t4) # t2: \"cpu f32[2, 2]\"\n", " # t2 = prims.mul(t_1, t4) # t2: \"cpu f32[2, 2]\"\n", " # t3 = ltorch.mul(t_0, t4) # t3: \"cpu f32[2, 2]\"\n", " # t3 = prims.mul(t_0, t4) # t3: \"cpu f32[2, 2]\"\n", " del t_1, t4, t_0\n", " return [t2, t3]" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Verifies our new grad transform is used and that a single multimul call is made\n", "cbar_grad = grad(cbar)\n", "cbar_grad(a, b)\n", "thunder.last_traces(cbar_grad)[-1]" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# Some operations may require inputs have particular properties (like be contiguous), or a transform may wish\n", "# to interleave torch operations with new operations. The transform function supports this. 
Here\n", "# we can see an example where the inputs to multimul are made contiguous before it's called\n", "def mul_to_contiguous_multimul(a: Number | TensorProxy, b: Number | TensorProxy) -> TensorProxy:\n", "    # Only tensor proxies are made contiguous; Python numbers (which the checker also accepts) pass through unchanged\n", "    if isinstance(a, TensorProxy):\n", "        a = a.contiguous()\n", "    if isinstance(b, TensorProxy):\n", "        b = b.contiguous()\n", "    # The second multiplication's result is discarded, so it is given placeholder zeros\n", "    result, _ = multimul(a, b, 0, 0)\n", "    return result\n", "\n", "myex.register_implementation(ltorch.mul, checker=mul_to_multimul_checker, execution_transform=mul_to_contiguous_multimul)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "# Constructed by Delete Last Used (took 0 milliseconds)\n", "from torch import Tensor\n", "import torch\n", "from thunder.executors.torchex import no_autocast\n", "\n", "@torch.no_grad()\n", "@no_autocast\n", "def computation(t_0, t_1):\n", " # t_0: \"cpu f32[2, 2]\" \n", " # t_1: \"cpu f32[2, 2]\" \n", " t1 = Tensor.contiguous(t_0, memory_format=_torch_memory_format_0) # t1: \"cpu f32[2, 2]\"\n", " # t1 = ltorch.contiguous(t_0, memory_format=_torch_memory_format_0) # t1: \"cpu f32[2, 2]\"\n", " # t1 = prims.stride_order(t_0, (1, 0)) # t1: \"cpu f32[2, 2]\"\n", " del t_0\n", " t2 = Tensor.contiguous(t_1, memory_format=_torch_memory_format_0) # t2: \"cpu f32[2, 2]\"\n", " # t2 = ltorch.contiguous(t_1, memory_format=_torch_memory_format_0) # t2: \"cpu f32[2, 2]\"\n", " # t2 = prims.stride_order(t_1, (1, 0)) # t2: \"cpu f32[2, 2]\"\n", " del t_1\n", " (t0, _) = multimul(t1, t2, 0, 0)\n", " # t0 = ltorch.mul(t1, t2) # t0: \"cpu f32[2, 2]\"\n", " # t0 = prims.mul(t1, t2) # t0: \"cpu f32[2, 2]\"\n", " del t1, t2\n", " return t0" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Verifies the new \"prologue\" for multimul works as expected. 
Note that the contiguous operations are \n", "# executed by PyTorch, and don't have to be executed by your executor\n", "a.requires_grad_(False)\n", "b.requires_grad_(False)\n", "\n", "def caz(a, b):\n", " return a * b\n", "ccaz = thunder.jit(caz, executors=[myex])\n", "ccaz(a, b)\n", "thunder.last_traces(ccaz)[-1]" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "# NVIDIA's APEX cross-entropy executor is a good example of a real-world operator executor. It defines\n", "# fast forward and backward functions for torch.nn.functional.cross_entropy. We can see its custom\n", "# fwd and bwd operations below\n", "# NOTE This cell and the following cells require the apex executor be installed to run properly\n", "dtype = torch.float32\n", "device = 'cuda'\n", "logits = torch.randn([2048, 50257], device=device, dtype=ltorch.to_torch_dtype(dtype), requires_grad=False)\n", "labels = torch.randint(0, 50257, [2048], device=device)\n", "\n", "from thunder.executors.apexex import apex_ex\n", "\n", "def foo(logits, labels):\n", " return torch.nn.functional.cross_entropy(logits, labels, reduction=\"mean\", ignore_index=-1)\n", "cfoo = thunder.jit(foo, executors=[apex_ex])" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "# Constructed by Delete Last Used (took 0 milliseconds)\n", "import torch\n", "from thunder.executors.torchex import no_autocast\n", "\n", "@torch.no_grad()\n", "@no_autocast\n", "def computation(t_0, t_1):\n", " # t_0: \"cuda:0 f32[2048, 50257]\" \n", " # t_1: \"cuda:0 i64[2048]\" \n", " (t18, _) = apex_cross_entropy(t_0, t_1, 'mean', 0.0)\n", " del t_0, t_1\n", " return t18" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Shows the forward operation\n", "cfoo(logits, labels)\n", "thunder.last_traces(cfoo)[-1]" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { 
"data": { "text/plain": [ "# Constructed by Delete Last Used (took 0 milliseconds)\n", "from torch import Tensor\n", "import torch\n", "from thunder.executors.torchex import no_autocast\n", "\n", "@torch.no_grad()\n", "@no_autocast\n", "def computation(t_0, t_1):\n", " # t_0: \"cuda:0 f32[2048, 50257]\" \n", " # t_1: \"cuda:0 i64[2048]\" \n", " (_, t1) = apex_cross_entropy(t_0, t_1, 'mean', 0.0)\n", " t6 = Tensor.contiguous(t_0, memory_format=_torch_memory_format_0) # t6: \"cuda:0 f32[2048, 50257]\"\n", " # t6 = ltorch.contiguous(t_0, memory_format=_torch_memory_format_0) # t6: \"cuda:0 f32[2048, 50257]\"\n", " # t6 = prims.stride_order(t_0, (1, 0)) # t6: \"cuda:0 f32[2048, 50257]\"\n", " del t_0\n", " t8 = torch.full((), 1.0, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t8: \"cuda:0 f32[]\"\n", " # t8 = ltorch.full((), 1.0, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t8: \"cuda:0 f32[]\"\n", " # t8 = prims.full((), 1.0, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t8: \"cuda:0 f32[]\"\n", " t12 = torch.unsqueeze(t8, 0) # t12: \"cuda:0 f32[1]\"\n", " # t12 = ltorch.unsqueeze(t8, 0) # t12: \"cuda:0 f32[1]\"\n", " # t12 = prims.broadcast_in_dim(t8, [1], []) # t12: \"cuda:0 f32[1]\"\n", " del t8\n", " t3 = Tensor.expand(t12, [1]) # t3: \"cuda:0 f32[1]\"\n", " # t3 = ltorch.expand(t12, [1]) # t3: \"cuda:0 f32[1]\"\n", " # t3 = prims.broadcast_in_dim(t12, (1,), (0,)) # t3: \"cuda:0 f32[1]\"\n", " del t12\n", " t4 = Tensor.expand(t3, (2048,)) # t4: \"cuda:0 f32[2048]\"\n", " # t4 = ltorch.expand(t3, (2048,)) # t4: \"cuda:0 f32[2048]\"\n", " # t4 = prims.broadcast_in_dim(t3, (2048,), (0,)) # t4: \"cuda:0 f32[2048]\"\n", " del t3\n", " t5 = torch.mul(t4, 0.00048828125) # t5: \"cuda:0 f32[2048]\"\n", " # t5 = ltorch.mul(t4, 0.00048828125) # t5: \"cuda:0 f32[2048]\"\n", " # t5 = prims.mul(t4, 0.00048828125) # t5: \"cuda:0 f32[2048]\"\n", " del t4\n", " t7 = apex_cross_entropy_backward(t5, t6, target=t_1, max_log_sum_exp=t1, 
label_smoothing=0.0) # t7: \"cuda:0 f32[2048, 50257]\"\n", " del t5, t6, t1, t_1\n", " return [t7]" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Shows APEX's custom forward and backward operations, plus additional PyTorch operations between the two\n", "logits.requires_grad_(True)\n", "\n", "cfoo_grad = grad(cfoo)\n", "\n", "cfoo_grad(logits, labels)\n", "thunder.last_traces(cfoo_grad)[-1]" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 2 }