{ "cells": [ { "cell_type": "markdown", "id": "b6f1f42d-f146-4c9c-8ed8-74f2bcf153f0", "metadata": {}, "source": [ "# Defining custom forward and backward for existing operators\n", "\n", "We are going to add custom executor for forward and backward of `torch.nn.functional.cross_entropy` operator." ] }, { "cell_type": "markdown", "id": "d57fee1e", "metadata": {}, "source": [ "Here's `SoftmaxCrossEntropyLoss` definition from https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py:\n", "\n", "```py\n", "import torch\n", "\n", "import xentropy_cuda\n", "\n", "\n", "class SoftmaxCrossEntropyLoss(torch.autograd.Function):\n", " @staticmethod\n", " def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, half_to_float=False):\n", " losses, max_log_sum_exp = xentropy_cuda.forward(\n", " logits, labels, smoothing, half_to_float)\n", " losses.masked_fill_(labels==padding_idx, 0)\n", "\n", " ctx.save_for_backward(logits, max_log_sum_exp, labels,\n", " torch.FloatTensor([smoothing]),\n", " torch.LongTensor([padding_idx]))\n", "\n", " return losses\n", "\n", " @staticmethod\n", " def backward(ctx, grad_loss):\n", " logits, max_log_sum_exp, labels, smoothing, padding_idx = ctx.saved_tensors\n", "\n", " if not grad_loss.is_contiguous():\n", " grad_loss = grad_loss.contiguous()\n", " grad_loss.masked_fill_(labels==padding_idx.item(), 0)\n", " grad_logits = xentropy_cuda.backward(\n", " grad_loss.contiguous(), logits, max_log_sum_exp,\n", " labels, smoothing.item())\n", "\n", " return grad_logits, None, None, None, None\n", "```" ] }, { "cell_type": "code", "execution_count": 1, "id": "b85398d8", "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.insert(0, '..')\n", "import thunder\n", "import torch\n", "torch.manual_seed(42)\n", "\n", "from thunder.core.proxies import TensorProxy\n" ] }, { "cell_type": "markdown", "id": "981ab590", "metadata": {}, "source": [ "In Thunder, we define _Executors_ to run given ops. Our executor will handle specific ops (rather than fusion regions),\n", "so our first thing is to create our own `OperatorExecutor`and register it with Thunder" ] }, { "cell_type": "code", "execution_count": 2, "id": "576d267d-9cef-4414-a722-b2cef0665cce", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "apex_xentropy_ex" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from thunder.extend import OperatorExecutor, register_executor\n", "apex_xentropy_ex = OperatorExecutor(\"apex_xentropy_ex\", version=\"0.1\")\n", "register_executor(apex_xentropy_ex)" ] }, { "cell_type": "markdown", "id": "ffbbf3f5", "metadata": {}, "source": [ "To get a feel of what's going on, let's have a wrapper that prints function calls and their arguments." 
] }, { "cell_type": "code", "execution_count": 3, "id": "02e16bf5", "metadata": {}, "outputs": [], "source": [ "import functools\n", "\n", "_indentation = 0\n", "def _log(msg=None):\n", " \"\"\"Print a message at current indentation.\"\"\"\n", " if msg is not None:\n", " print(\" \" * _indentation + msg)\n", "\n", "def _log_indent(msg=None):\n", " \"\"\"Print a message and then indent the rest.\"\"\"\n", " global _indentation\n", " _log(msg)\n", " _indentation = 2 + _indentation\n", "\n", "def _log_unindent(msg=None):\n", " \"\"\"Unindent then print a message.\"\"\"\n", " global _indentation\n", " _indentation = _indentation - 2\n", " _log(msg)\n", "\n", "def log(func):\n", " \"\"\"A decorator for functions to log arguments and results.\"\"\"\n", " name = func.__name__\n", " def pp(v):\n", " \"\"\"Print certain values more succinctly\"\"\"\n", " vtype = str(type(v))\n", " if isinstance(v, tuple):\n", " return \"({})\".format(pp_values(v))\n", " elif isinstance(v, thunder.core.proxies.TensorProxy):\n", " return f\"TensorProxy(name={v.name}, shape={v.shape}, dtype={v.dtype}, device={v.device})\"\n", " elif isinstance(v, torch.Tensor):\n", " return f\"Tensor(shape={v.shape}, stride={v.stride()}, dtype={v.dtype}, device={v.device}) with values {v}\"\n", " else:\n", " return str(v)\n", " def pp_values(args):\n", " return \", \".join([pp(arg) for arg in args])\n", "\n", " @functools.wraps(func)\n", " def func_wrapper(*args):\n", " _log_indent(\"call {}({})\".format(name, pp_values(args)))\n", " res = func(*args)\n", " _log_unindent(\"|<- {} = {}\\n\".format(name, pp(res)))\n", " return res\n", "\n", " return func_wrapper" ] }, { "cell_type": "markdown", "id": "a06c6260", "metadata": {}, "source": [ "We want to define operators `apex_xentropy_forward` and `apex_xentropy_backward`.\n", "In thunder, we define a *meta* function that only defines the metadata (like shapes) of outputs and the actual implementation for each operator and then register the pair with our executor.\n", "So we do this for the forward..." ] }, { "cell_type": "code", "execution_count": 4, "id": "ba10b306", "metadata": {}, "outputs": [], "source": [ "@log\n", "def apex_xentropy_forward_meta(\n", " a,\n", " target,\n", " weight=None,\n", " size_average=None,\n", " ignore_index=-100,\n", " reduce=None,\n", " reduction=\"mean\",\n", " label_smoothing=0.0,\n", "):\n", " max_log_sum_exp = TensorProxy(like=target)\n", " if reduction == \"none\":\n", " return TensorProxy(shape=(a.shape[0],), dtype=a.dtype, device=a.device,\n", " requires_grad=a.requires_grad), max_log_sum_exp\n", " else:\n", " raise ValueError(f\"Invalid reduction: {reduction}\")\n", "\n", "import xentropy_cuda\n", "\n", "@log\n", "def apex_xentropy_forward_impl(\n", " a,\n", " target,\n", " weight=None,\n", " size_average=None,\n", " ignore_index=-100,\n", " reduce=None,\n", " reduction=\"mean\",\n", " label_smoothing=0.0,\n", "):\n", " losses, max_log_sum_exp = xentropy_cuda.forward(a, target, label_smoothing, False)\n", "\n", " if reduction == \"none\":\n", " losses = losses.to(a.dtype)\n", " else:\n", " raise ValueError(f\"Invalid reduction: {reduction}\")\n", "\n", " return losses, max_log_sum_exp\n", "\n", "\n", "apex_xentropy_forward = apex_xentropy_ex.register_operator(\n", " \"apex_xentropy_forward\", meta=apex_xentropy_forward_meta, fn=apex_xentropy_forward_impl\n", ")\n", "\n" ] }, { "cell_type": "markdown", "id": "c4bd3c85", "metadata": {}, "source": [ "...and the backward..." 
] }, { "cell_type": "code", "execution_count": 5, "id": "8e1fc927", "metadata": {}, "outputs": [], "source": [ "@log\n", "def apex_xentropy_backward_meta(\n", " grad,\n", " logits,\n", " labels,\n", " max_log_sum_exp,\n", " smoothing,\n", "):\n", " return TensorProxy(like=logits)\n", "\n", "\n", "@log\n", "def apex_xentropy_backward_impl(\n", " grad,\n", " logits,\n", " labels,\n", " max_log_sum_exp,\n", " smoothing,\n", "):\n", " return xentropy_cuda.backward(grad.contiguous(), logits, max_log_sum_exp, labels, smoothing)\n", " \n", "apex_xentropy_backward = apex_xentropy_ex.register_operator(\n", " \"apex_xentropy_backward\", meta=apex_xentropy_backward_meta, fn=apex_xentropy_backward_impl\n", ")" ] }, { "cell_type": "markdown", "id": "afe952ba", "metadata": {}, "source": [ "Because Thunder currently does not allow keyword arguments passed to the operators, we define a convenience wrapper:" ] }, { "cell_type": "code", "execution_count": 6, "id": "53e85b49", "metadata": {}, "outputs": [], "source": [ "def apex_xentropy(\n", " a,\n", " target,\n", " weight=None,\n", " size_average=None,\n", " ignore_index=-100,\n", " reduce=None,\n", " reduction=\"mean\",\n", " label_smoothing=0.0,\n", "):\n", " res, _ = apex_xentropy_forward(a, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)\n", " return res\n" ] }, { "cell_type": "markdown", "id": "1e5255f7", "metadata": {}, "source": [ "We can now `thunder.jit` functions using our operator:" ] }, { "cell_type": "code", "execution_count": 7, "id": "d8b4b898", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "call apex_xentropy_forward_meta(TensorProxy(name=t_0, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=t_1, shape=(2048,), dtype=int64, device=cuda:0), None, None, -100, None, none, 0.0)\n", "|<- apex_xentropy_forward_meta = (TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0))\n", "\n", "call apex_xentropy_forward_impl(Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[ 0.1940, 2.1614, -0.1721, ..., -0.4797, 1.4608, -0.5221],\n", " [ 1.8288, 0.2116, 0.1760, ..., -0.1599, 0.1195, 0.0073],\n", " [-2.1704, 1.0396, 2.2924, ..., 0.6021, 0.6498, -0.6316],\n", " ...,\n", " [ 0.4908, -0.3445, 2.6618, ..., -2.0946, -0.2890, 0.1500],\n", " [-1.0561, -1.3547, -1.0354, ..., 0.4304, -0.7882, -0.5496],\n", " [-0.6883, -1.3283, 0.3513, ..., -0.6951, 0.2013, -1.0238]],\n", " device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.int64, device=cuda:0) with values tensor([ 9132, 12067, 5347, ..., 9268, 12534, 33582], device='cuda:0'), None, None, -100, None, none, 0.0)\n", "|<- apex_xentropy_forward_impl = (Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([10.7236, 11.9374, 11.0063, ..., 11.7434, 9.5018, 10.8008],\n", " device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.3132, 11.3291, 11.3287, ..., 11.3279, 11.3251, 11.3301],\n", " device='cuda:0'))\n", "\n", "deviation from pytorch implementation: 9.5367431640625e-07\n" ] } ], "source": [ "def loss_fn(logits, labels):\n", " return apex_xentropy(logits, labels, reduction=\"none\")\n", "\n", "jfn = thunder.jit(loss_fn)\n", "\n", "logits = torch.randn([2048, 50257], device=\"cuda\")\n", "labels = torch.randint(0, 50257, [2048], device=\"cuda\")\n", 
"\n", "actual_result = jfn(logits, labels)\n", "expected_result = torch.nn.functional.cross_entropy(logits, labels, reduction=\"none\")\n", "\n", "print(\"deviation from pytorch implementation:\", (actual_result - expected_result).abs().max().item())" ] }, { "cell_type": "markdown", "id": "2e374912", "metadata": {}, "source": [ "We can also inspect what program thunder recorded to admire the beauty of our operator being called:" ] }, { "cell_type": "code", "execution_count": 8, "id": "40e50b13", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "# Constructed by Delete Last Used (took 0 milliseconds)\n", "import torch\n", "from thunder.executors.torchex import no_autocast\n", "\n", "@torch.no_grad()\n", "@no_autocast\n", "def computation(logits, labels):\n", " # logits: \"cuda:0 f32[2048, 50257]\" \n", " # labels: \"cuda:0 i64[2048]\" \n", " (res, _) = apex_xentropy_forward(logits, labels, None, None, -100, None, 'none', 0.0)\n", " del logits, labels\n", " return res" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "thunder.last_traces(jfn)[-1]" ] }, { "cell_type": "markdown", "id": "2ec1a9e2", "metadata": {}, "source": [ "But it might be more awesome to have Thunder automatically use our new operators if applicable.\n", "We can define a transformation to do this for us. This consists of two parts:\n", "- a `checker`function that takes the arguments of the function we want to replace (but with `Tensor` arguments replaced by `TensorProxy` ones) and outputs `True` if we handle this case and `False` if not.\n", "- an `execution_transform` that is just a function with the same parameters and same return value as the function we want to replace and does the compute (as you would expect by calling our operator).\n", "\n", "Note that we attach this implementation to the `thunder.torch.cross_entropy` *Symbol* (an operator as appearing in Thunder traces, just like our `apex_xentropy_forward` is a Symbol).\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "7b6144c8", "metadata": {}, "outputs": [], "source": [ "def apex_xentropy_checker(\n", " a: TensorProxy,\n", " /,\n", " target: TensorProxy,\n", " weight: None | TensorProxy = None,\n", " size_average = None,\n", " ignore_index: int = -100,\n", " reduce = None,\n", " reduction: str = \"mean\",\n", " label_smoothing: float = 0.0,\n", ") -> bool:\n", " DeviceType = thunder.devices.DeviceType\n", " if a.device.devicetype != DeviceType.CUDA or target.device.devicetype != DeviceType.CUDA:\n", " return False\n", "\n", " probability_target: bool = thunder.core.utils.same_shape(a.shape, target.shape)\n", " if probability_target or label_smoothing > 0.0:\n", " return False\n", "\n", " torch_dtype: torch.dtype = thunder.torch.to_torch_dtype(a.dtype)\n", " if torch_dtype not in (torch.float16, torch.bfloat16, torch.float32):\n", " return False\n", "\n", " if ignore_index >= 0:\n", " return False\n", "\n", " if weight is not None:\n", " return False\n", "\n", " # NOTE These parameters are deprecated and not supported\n", " if size_average is not None or reduce is not None:\n", " return False\n", "\n", " if reduction not in [\"sum\", \"mean\", \"none\"]:\n", " return False\n", "\n", " # Checks from\n", " # https://github.com/NVIDIA/apex/blob/7b2e71b0d4013f8e2f9f1c8dd21980ff1d76f1b6/apex/contrib/csrc/xentropy/xentropy_kernel.cu#L587-L590\n", " if a.ndim != 2:\n", " return False\n", "\n", " if target.ndim != 1:\n", " return False\n", "\n", " if a.shape[0] != target.shape[0]:\n", " return False\n", "\n", 
" if a.numel == 0:\n", " return False\n", "\n", " # Xentropy kernel produces incorrect results if a.shape[1] is less\n", " # than 30 and not a multiple of 4\n", " if a.shape[1] < 30 and a.shape[1] % 4 != 0:\n", " return False\n", "\n", " return True\n", "\n", "from thunder.core.transforms import get_grad, put_grads\n", "\n", "\n", "def cross_entropy_to_apex(\n", " a,\n", " target,\n", " weight=None,\n", " size_average=None,\n", " ignore_index=-100,\n", " reduce=None,\n", " reduction=\"mean\",\n", " label_smoothing=0.0,\n", "):\n", " loss, max_log_sum_exp = apex_xentropy_forward(\n", " a,\n", " target,\n", " weight,\n", " size_average,\n", " ignore_index,\n", " reduce,\n", " reduction,\n", " label_smoothing,\n", " )\n", " return loss\n", "\n", "apex_xentropy_ex.register_implementation(thunder.torch.cross_entropy, checker=apex_xentropy_checker, \n", " execution_transform=cross_entropy_to_apex)" ] }, { "cell_type": "markdown", "id": "e2ad6d09", "metadata": {}, "source": [ "We now can run the \"unmodified\" PyTorch function with `F.cross_entroy` and still get our implementation (but don't forget the executor in the call to the jit!):" ] }, { "cell_type": "code", "execution_count": 10, "id": "00c6f1ab", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "call apex_xentropy_forward_meta(TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), None, None, -100, None, none, 0.0)\n", "|<- apex_xentropy_forward_meta = (TensorProxy(name=t19, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=t18, shape=(2048,), dtype=int64, device=cuda:0))\n", "\n", "call apex_xentropy_forward_impl(Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[ 1.2891, -0.2912, 0.6866, ..., -1.5067, 1.3132, -0.7352],\n", " [-1.9077, -0.8366, -0.0747, ..., 1.6109, -0.7460, 0.7346],\n", " [-1.0830, -0.2586, 0.0402, ..., -0.2030, -1.0907, -1.7308],\n", " ...,\n", " [ 0.5805, -0.0830, -0.4658, ..., -0.1023, -1.3720, 0.1850],\n", " [-0.8181, 1.3273, 0.8034, ..., 1.2658, -1.4824, 0.0482],\n", " [ 0.9964, -1.8733, 0.3547, ..., 0.0190, -0.3228, 0.4827]],\n", " device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.int64, device=cuda:0) with values tensor([ 8137, 23633, 42622, ..., 39128, 39817, 18664], device='cuda:0'), None, None, -100, None, none, 0.0)\n", "|<- apex_xentropy_forward_impl = (Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([12.9479, 11.7810, 9.1981, ..., 10.1080, 10.4095, 10.5884],\n", " device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.3268, 11.3203, 11.3294, ..., 11.3251, 11.3333, 11.3225],\n", " device='cuda:0'))\n", "\n", "deviation from pytorch implementation: 9.5367431640625e-07\n", "# Constructed by Delete Last Used (took 0 milliseconds)\n", "import torch\n", "from thunder.executors.torchex import no_autocast\n", "\n", "@torch.no_grad()\n", "@no_autocast\n", "def computation(logits, labels):\n", " # logits: \"cuda:0 f32[2048, 50257]\" \n", " # labels: \"cuda:0 i64[2048]\" \n", " (t17, _) = apex_xentropy_forward(logits, labels, None, None, -100, None, 'none', 0.0)\n", " del logits, labels\n", " return t17\n" ] } ], "source": [ "def loss_fn(logits, labels):\n", " return torch.nn.functional.cross_entropy(logits, labels, reduction=\"none\")\n", "\n", "jfn = thunder.jit(loss_fn, 
executors=[apex_xentropy_ex])\n", "\n", "logits = torch.randn([2048, 50257], device=\"cuda\")\n", "labels = torch.randint(0, 50257, [2048], device=\"cuda\")\n", "\n", "actual_result = jfn(logits, labels)\n", "expected_result = torch.nn.functional.cross_entropy(logits, labels, reduction=\"none\")\n", "\n", "print(\"deviation from pytorch implementation:\", (actual_result - expected_result).abs().max().item())\n", "\n", "print(thunder.last_traces(jfn)[-1])" ] }, { "cell_type": "markdown", "id": "bbfe283e", "metadata": {}, "source": [ "## So what about the backward?\n", "\n", "Well, we can define a gradient function and register it along with our implementation.\n", "\n", "We thought a lot about what our extension point for gradients should look like - PyTorch's `autograd.Function` is probably the most well-known approach - and we felt that it would be nice to make the connection between tensors in the computation and their gradients explicit.\n", "\n", "So the grad transform we implement below is a function that does the following:\n", "\n", "- it takes the same arguments as the forward,\n", "- it computes the forward from its arguments,\n", "- it then uses `get_grad` to obtain the required gradients for the forward outputs,\n", "- it computes the gradients for the inputs (this is the backward),\n", "- it finally attaches the computed gradients to the respective tensors with `put_grads`.\n", "\n", "We supply the grad function as an additional argument to `register_implementation`." ] }, { "cell_type": "code", "execution_count": 11, "id": "755f53d2", "metadata": {}, "outputs": [], "source": [ "@log\n", "def apex_cross_entropy_grad(\n", " a,\n", " target,\n", " weight=None,\n", " size_average=None,\n", " ignore_index=-100,\n", " reduce=None,\n", " reduction=\"mean\",\n", " label_smoothing=0.0,\n", "):\n", " loss, max_log_sum_exp = apex_xentropy_forward(\n", " a,\n", " target,\n", " weight,\n", " size_average,\n", " ignore_index,\n", " reduce,\n", " reduction,\n", " label_smoothing,\n", " )\n", " grad = get_grad(loss)\n", " grad_logits = apex_xentropy_backward(\n", " grad,\n", " a,\n", " target,\n", " max_log_sum_exp,\n", " label_smoothing,\n", " )\n", " put_grads((a,), (grad_logits,))\n", " return loss\n", "\n", "apex_xentropy_ex.register_implementation(thunder.torch.cross_entropy, checker=apex_xentropy_checker, \n", " execution_transform=cross_entropy_to_apex, grad_transform=apex_cross_entropy_grad)" ] }, { "cell_type": "markdown", "id": "b4ec7c57", "metadata": {}, "source": [ "With these registrations, we can compile a function and it will automatically be transformed into forward and backward parts and wrapped in a PyTorch `autograd.Function` that calls the backward trace computed by Thunder.\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "8c5da6f2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "call apex_cross_entropy_grad(TensorProxy(name=a, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=target, shape=(2048,), dtype=int64, device=cuda:0), None, None, [IntegerProxy name=ignore_index, value=-1], None, none, [FloatProxy name=label_smoothing, value=0.0])\n", " call apex_xentropy_forward_meta(TensorProxy(name=a, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=target, shape=(2048,), dtype=int64, device=cuda:0), None, None, [IntegerProxy name=ignore_index, value=-1], None, none, [FloatProxy name=label_smoothing, value=0.0])\n", " |<- apex_xentropy_forward_meta = (TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0), 
TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0))\n", "\n", " call apex_xentropy_backward_meta(TensorProxy(name=t2, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=a, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=target, shape=(2048,), dtype=int64, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0), [FloatProxy name=label_smoothing, value=0.0])\n", " |<- apex_xentropy_backward_meta = TensorProxy(name=t3, shape=(2048, 50257), dtype=float32, device=cuda:0)\n", "\n", "|<- apex_cross_entropy_grad = TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0)\n", "\n", "call apex_xentropy_forward_meta(TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), None, None, -1, None, none, 0.0)\n", "|<- apex_xentropy_forward_meta = (TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0))\n", "\n", "call apex_xentropy_backward_meta(TensorProxy(name=t2, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0), [FloatProxy name=f0, value=0.0])\n", "|<- apex_xentropy_backward_meta = TensorProxy(name=t3, shape=(2048, 50257), dtype=float32, device=cuda:0)\n", "\n", "call apex_xentropy_forward_impl(Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-9.2466e-01, -4.2534e-01, -2.6438e+00, ..., 4.5115e-01,\n", " 2.4087e-01, 1.9543e+00],\n", " [ 7.5610e-03, -4.9079e-01, 3.6572e-01, ..., 2.5072e+00,\n", " 9.0470e-01, -1.4305e+00],\n", " [-4.4104e-01, -7.6137e-01, -1.1172e+00, ..., 5.9006e-02,\n", " -1.0212e+00, 3.0210e-02],\n", " ...,\n", " [-4.2869e+00, 1.4900e+00, -9.1910e-01, ..., 3.6535e-03,\n", " -6.8372e-01, 7.1824e-01],\n", " [-4.2704e-02, 1.3505e+00, 2.1361e+00, ..., -1.1139e+00,\n", " 6.1626e-01, 4.8158e-01],\n", " [-7.3334e-01, 2.0820e+00, 3.7722e-02, ..., -7.2141e-01,\n", " 4.6871e-01, 7.0758e-01]], device='cuda:0', requires_grad=True), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.int64, device=cuda:0) with values tensor([ 3957, 45831, 13902, ..., 45225, 32145, 12167], device='cuda:0'), None, None, -1, None, none, 0.0)\n", "|<- apex_xentropy_forward_impl = (Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([12.4000, 10.9672, 12.6648, ..., 11.7144, 11.8293, 10.9396],\n", " device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.3186, 11.3176, 11.3300, ..., 11.3257, 11.3189, 11.3202],\n", " device='cuda:0'))\n", "\n", "call apex_xentropy_backward_impl(Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([ 0.8882, -0.0650, -1.2035, ..., -0.4344, -0.0588, -2.5740],\n", " device='cuda:0'), Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-9.2466e-01, -4.2534e-01, -2.6438e+00, ..., 4.5115e-01,\n", " 2.4087e-01, 1.9543e+00],\n", " [ 7.5610e-03, -4.9079e-01, 3.6572e-01, ..., 2.5072e+00,\n", " 9.0470e-01, -1.4305e+00],\n", " [-4.4104e-01, -7.6137e-01, -1.1172e+00, ..., 5.9006e-02,\n", " -1.0212e+00, 3.0210e-02],\n", " ...,\n", " [-4.2869e+00, 1.4900e+00, 
-9.1910e-01, ..., 3.6535e-03,\n", " -6.8372e-01, 7.1824e-01],\n", " [-4.2704e-02, 1.3505e+00, 2.1361e+00, ..., -1.1139e+00,\n", " 6.1626e-01, 4.8158e-01],\n", " [-7.3334e-01, 2.0820e+00, 3.7722e-02, ..., -7.2141e-01,\n", " 4.6871e-01, 7.0758e-01]], device='cuda:0', requires_grad=True), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.int64, device=cuda:0) with values tensor([ 3957, 45831, 13902, ..., 45225, 32145, 12167], device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.3186, 11.3176, 11.3300, ..., 11.3257, 11.3189, 11.3202],\n", " device='cuda:0'), 0.0)\n", "|<- apex_xentropy_backward_impl = Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[ 4.2787e-06, 7.0495e-06, 7.6679e-07, ..., 1.6936e-05,\n", " 1.3724e-05, 7.6143e-05],\n", " [-7.9652e-07, -4.8391e-07, -1.1396e-06, ..., -9.7005e-06,\n", " -1.9535e-06, -1.8908e-07],\n", " [-9.2972e-06, -6.7489e-06, -4.7280e-06, ..., -1.5329e-05,\n", " -5.2049e-06, -1.4894e-05],\n", " ...,\n", " [-7.2011e-08, -2.3243e-05, -2.0894e-06, ..., -5.2573e-06,\n", " -2.6439e-06, -1.0743e-05],\n", " [-6.8437e-07, -2.7565e-06, -6.0470e-06, ..., -2.3447e-07,\n", " -1.3227e-06, -1.1561e-06],\n", " [-1.4990e-05, -2.5033e-04, -3.2410e-05, ..., -1.5170e-05,\n", " -4.9872e-05, -6.3328e-05]], device='cuda:0')\n", "\n", "Max error in loss: 9.5367431640625e-07\n", "Max error in logits grad: 2.384185791015625e-07\n" ] }, { "data": { "text/plain": [ "[# Constructed by Backward pass\n", " import torch\n", " from thunder.executors.torchex import no_autocast\n", " \n", " @torch.no_grad()\n", " @no_autocast\n", " def backward_fn(saved_for_backward, cotangents):\n", " # saved_for_backward: \"Collection\" \n", " # cotangents: \"Collection\" \n", " C0, \\\n", " C1, \\\n", " = saved_for_backward\n", " t2, \\\n", " = cotangents\n", " logits, \\\n", " labels, \\\n", " t0, \\\n", " = C0\n", " f0, \\\n", " = C1\n", " t3 = apex_xentropy_backward(t2, logits, labels, t0, f0) # t3: \"cuda:0 f32[2048, 50257]\"\n", " return (t3, None),\n", " # Constructed by Transform for execution (took 0 milliseconds)\n", " import torch\n", " from thunder.executors.torchex import no_autocast\n", " \n", " @torch.no_grad()\n", " @no_autocast\n", " def backward_fn(saved_for_backward, cotangents):\n", " # saved_for_backward: \"Collection\" \n", " # cotangents: \"Collection\" \n", " C0, \\\n", " C1, \\\n", " = saved_for_backward\n", " t2, \\\n", " = cotangents\n", " logits, \\\n", " labels, \\\n", " t0, \\\n", " = C0\n", " f0, \\\n", " = C1\n", " t3 = apex_xentropy_backward(t2, logits, labels, t0, f0) # t3: \"cuda:0 f32[2048, 50257]\"\n", " return (t3, None),\n", " # Constructed by Update Call Context (took 0 milliseconds)\n", " import torch\n", " from thunder.executors.torchex import no_autocast\n", " \n", " @torch.no_grad()\n", " @no_autocast\n", " def backward_fn(saved_for_backward, cotangents):\n", " # saved_for_backward: \"Collection\" \n", " # cotangents: \"Collection\" \n", " C0, \\\n", " C1, \\\n", " = saved_for_backward\n", " t2, \\\n", " = cotangents\n", " labels, \\\n", " logits, \\\n", " t0, \\\n", " = C0\n", " f0, \\\n", " = C1\n", " t3 = apex_xentropy_backward(t2, logits, labels, t0, f0) # t3: \"cuda:0 f32[2048, 50257]\"\n", " return (t3, None),\n", " # Constructed by Delete Last Used (took 0 milliseconds)\n", " import torch\n", " from thunder.executors.torchex import no_autocast\n", " \n", " @torch.no_grad()\n", " @no_autocast\n", " def 
backward_fn(saved_for_backward, cotangents):\n", " # saved_for_backward: \"Collection\" \n", " # cotangents: \"Collection\" \n", " C0, \\\n", " C1, \\\n", " = saved_for_backward\n", " clear_collection(saved_for_backward)\n", " del saved_for_backward\n", " t2, \\\n", " = cotangents\n", " clear_collection(cotangents)\n", " del cotangents\n", " labels, \\\n", " logits, \\\n", " t0, \\\n", " = C0\n", " clear_collection(C0)\n", " del C0\n", " f0, \\\n", " = C1\n", " clear_collection(C1)\n", " del C1\n", " t3 = apex_xentropy_backward(t2, logits, labels, t0, f0) # t3: \"cuda:0 f32[2048, 50257]\"\n", " del t2, logits, labels, t0, f0\n", " return (t3, None)]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from thunder import torch as ltorch\n", "\n", "torch.manual_seed(0)\n", "\n", "logits = torch.randn([2048, 50257], device=\"cuda\", requires_grad=True)\n", "labels = torch.randint(0, 50257, [2048], device=\"cuda\")\n", "\n", "def loss_fn(logits, labels):\n", " return torch.nn.functional.cross_entropy(logits, labels, reduction=\"none\", ignore_index=-1)\n", "\n", "cfn = thunder.jit(loss_fn, executors=[apex_xentropy_ex])\n", "\n", "actual_loss = cfn(logits, labels)\n", "go = torch.randn_like(actual_loss)\n", "\n", "actual_grad, = torch.autograd.grad(actual_loss, logits, go)\n", "\n", "expected_loss = loss_fn(logits, labels)\n", "expected_grad, = torch.autograd.grad(expected_loss, logits, go)\n", "\n", "print(\"Max error in loss:\", (actual_loss - expected_loss).abs().max().item())\n", "print(\"Max error in logits grad:\", (actual_grad - expected_grad).abs().max().item())\n", "\n", "thunder.last_traces(cfn)[-1]" ] }, { "cell_type": "markdown", "id": "54d6a5ea", "metadata": {}, "source": [ "Alternatively, we can also use the `grad` transform to get the gradient:" ] }, { "cell_type": "code", "execution_count": 13, "id": "c88118eb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "call apex_cross_entropy_grad(TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), None, None, -100, None, none, 0.0)\n", " call apex_xentropy_forward_meta(TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), None, None, -100, None, none, 0.0)\n", " |<- apex_xentropy_forward_meta = (TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0))\n", "\n", " call apex_xentropy_backward_meta(TensorProxy(name=t2, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0), 0.0)\n", " |<- apex_xentropy_backward_meta = TensorProxy(name=t3, shape=(2048, 50257), dtype=float32, device=cuda:0)\n", "\n", "|<- apex_cross_entropy_grad = TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0)\n", "\n", "call apex_xentropy_forward_impl(Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[ 0.5390, 0.1760, -1.0790, ..., 0.1695, -0.8082, -0.6984],\n", " [ 2.1555, 1.3938, 0.3928, ..., 0.8937, -0.4949, 1.1610],\n", " [ 0.6784, 1.1188, 0.7508, ..., -0.0941, 0.8380, 0.1878],\n", " ...,\n", " [-1.5834, -0.1573, -1.3511, ..., 0.6167, -0.1083, 0.4116],\n", 
" [-0.5476, 0.5831, 0.0791, ..., -0.4986, -0.5270, 0.0954],\n", " [ 0.2825, -1.0378, -0.5506, ..., 0.0149, 1.3521, -1.0823]],\n", " device='cuda:0', requires_grad=True), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.int64, device=cuda:0) with values tensor([44917, 35770, 41569, ..., 9798, 33992, 36123], device='cuda:0'), None, None, -100, None, none, 0.0)\n", "|<- apex_xentropy_forward_impl = (Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([10.0233, 11.9095, 11.2898, ..., 10.9289, 10.7487, 10.7455],\n", " device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.3241, 11.3207, 11.3283, ..., 11.3224, 11.3186, 11.3205],\n", " device='cuda:0'))\n", "\n", "call apex_xentropy_backward_impl(Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([1., 1., 1., ..., 1., 1., 1.], device='cuda:0'), Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[ 0.5390, 0.1760, -1.0790, ..., 0.1695, -0.8082, -0.6984],\n", " [ 2.1555, 1.3938, 0.3928, ..., 0.8937, -0.4949, 1.1610],\n", " [ 0.6784, 1.1188, 0.7508, ..., -0.0941, 0.8380, 0.1878],\n", " ...,\n", " [-1.5834, -0.1573, -1.3511, ..., 0.6167, -0.1083, 0.4116],\n", " [-0.5476, 0.5831, 0.0791, ..., -0.4986, -0.5270, 0.0954],\n", " [ 0.2825, -1.0378, -0.5506, ..., 0.0149, 1.3521, -1.0823]],\n", " device='cuda:0', requires_grad=True), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.int64, device=cuda:0) with values tensor([44917, 35770, 41569, ..., 9798, 33992, 36123], device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.3241, 11.3207, 11.3283, ..., 11.3224, 11.3186, 11.3205],\n", " device='cuda:0'), 0.0)\n", "|<- apex_xentropy_backward_impl = Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[2.0706e-05, 1.4403e-05, 4.1058e-06, ..., 1.4309e-05, 5.3827e-06,\n", " 6.0079e-06],\n", " [1.0461e-04, 4.8840e-05, 1.7949e-05, ..., 2.9621e-05, 7.3879e-06,\n", " 3.8697e-05],\n", " [2.3705e-05, 3.6822e-05, 2.5485e-05, ..., 1.0948e-05, 2.7806e-05,\n", " 1.4513e-05],\n", " ...,\n", " [2.4836e-06, 1.0338e-05, 3.1331e-06, ..., 2.2417e-05, 1.0857e-05,\n", " 1.8259e-05],\n", " [7.0235e-06, 2.1758e-05, 1.3145e-05, ..., 7.3762e-06, 7.1699e-06,\n", " 1.3360e-05],\n", " [1.6078e-05, 4.2941e-06, 6.9897e-06, ..., 1.2304e-05, 4.6857e-05,\n", " 4.1070e-06]], device='cuda:0')\n", "\n", "Difference: 1.3969838619232178e-09\n", "# Constructed by Delete Last Used (took 0 milliseconds)\n", "import torch\n", "from thunder.executors.torchex import no_autocast\n", "\n", "@torch.no_grad()\n", "@no_autocast\n", "def computation(logits, labels):\n", " # logits: \"cuda:0 f32[2048, 50257]\" \n", " # labels: \"cuda:0 i64[2048]\" \n", " (_, t0) = apex_xentropy_forward(logits, labels, None, None, -100, None, 'none', 0.0)\n", " t4 = torch.full((2048,), 1.0, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t4: \"cuda:0 f32[2048]\"\n", " # t4 = ltorch.full((2048,), 1.0, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t4: \"cuda:0 f32[2048]\"\n", " # t4 = prims.full((2048,), 1.0, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t4: \"cuda:0 f32[2048]\"\n", " t3 = apex_xentropy_backward(t4, logits, labels, t0, 0.0) # t3: \"cuda:0 f32[2048, 50257]\"\n", " del t4, logits, labels, t0\n", " return [t3]\n" ] 
} ], "source": [ "logits = torch.randn([2048, 50257], device=\"cuda\", requires_grad=True)\n", "labels = torch.randint(0, 50257, [2048], device=\"cuda\")\n", "\n", "grad_jfn = thunder.core.transforms.grad(jfn)\n", "actual_grad, = grad_jfn(logits, labels)\n", "\n", "expected_grad, = torch.autograd.grad(loss_fn(logits, labels).sum(), logits)\n", "\n", "\n", "print(\"Difference:\", (actual_grad - expected_grad).abs().max().item())\n", "print(thunder.last_traces(grad_jfn)[-1])\n" ] }, { "cell_type": "markdown", "id": "e234a47b", "metadata": {}, "source": [ "So let's wrap up what we did here:\n", "\n", "- We defined a custom executor with custom operations (Symbols in Thunder language), each with a *Meta-* (data propagation) *function* and an implementation.\n", "- We defined and registered rules to map existing operations to our new operations. This allows us to use optimizations on our model without changing the model's code! \n", "- We defined a gradient rule and saw how our automatic PyTorch Autograd integration or the explicit `grad` transform uses it.\n", "\n", "Now go and implement your favourite optimized operators. We would love to hear about your use-cases!\n" ] }, { "cell_type": "code", "execution_count": null, "id": "4e90796d", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.7" } }, "nbformat": 4, "nbformat_minor": 5 }