{ "cells": [ { "cell_type": "markdown", "id": "b6f1f42d-f146-4c9c-8ed8-74f2bcf153f0", "metadata": {}, "source": [ "# Defining new Thunder operators\n", "\n", "We are going to add a new operator to Thunder together with a corresponding executor. The operator will be called `sincos` and will compute the sine and cosine of a given input.\n", "\n", "Thunder has three sets of core operators: `thunder.torch`, `thunder.clang`, and `thunder.prims`. `thunder.prims` is a set of operators that are implemented in Python and are used to build the other two sets of operators. A primitive is an operator that is not implemented in terms of other operators." ] }, { "cell_type": "code", "execution_count": 1, "id": "576d267d-9cef-4414-a722-b2cef0665cce", "metadata": {}, "outputs": [], "source": [ "import thunder\n", "import torch\n", "\n", "from thunder.core.proxies import TensorProxy\n", "from enum import Enum" ] }, { "cell_type": "markdown", "id": "a1b6863a", "metadata": {}, "source": [ "Let us define some helper functions (execute the cell below) for printing what's going on." 
] }, { "cell_type": "code", "execution_count": 2, "id": "02e16bf5", "metadata": {}, "outputs": [], "source": [ "import functools\n", "\n", "_indentation = 0\n", "def _log(msg=None):\n", "    \"\"\"Print a message at current indentation.\"\"\"\n", "    if msg is not None:\n", "        print(\" \" * _indentation + msg)\n", "\n", "def _log_indent(msg=None):\n", "    \"\"\"Print a message and then indent the rest.\"\"\"\n", "    global _indentation\n", "    _log(msg)\n", "    _indentation = 2 + _indentation\n", "\n", "def _log_unindent(msg=None):\n", "    \"\"\"Unindent then print a message.\"\"\"\n", "    global _indentation\n", "    _indentation = _indentation - 2\n", "    _log(msg)\n", "\n", "def log(func):\n", "    \"\"\"A decorator for functions to log arguments and results.\"\"\"\n", "    name = func.__name__\n", "    def pp(v):\n", "        \"\"\"Print certain values more succinctly\"\"\"\n", "        vtype = str(type(v))\n", "        if isinstance(v, tuple):\n", "            return \"({})\".format(pp_values(v))\n", "        elif isinstance(v, thunder.core.proxies.TensorProxy):\n", "            return f\"TensorProxy(name={v.name}, shape={v.shape}, dtype={v.dtype}, device={v.device})\"\n", "        elif isinstance(v, torch.Tensor):\n", "            return f\"Tensor(shape={v.shape}, stride={v.stride()}, dtype={v.dtype}, device={v.device}) with values {v}\"\n", "        else:\n", "            return str(v)\n", "    def pp_values(args):\n", "        return \", \".join([pp(arg) for arg in args])\n", "\n", "    @functools.wraps(func)\n", "    def func_wrapper(*args):\n", "        _log_indent(\"call {}({})\".format(name, pp_values(args)))\n", "        res = func(*args)\n", "        _log_unindent(\"|<- {} = {}\\n\".format(name, pp(res)))\n", "        return res\n", "\n", "    return func_wrapper" ] }, { "cell_type": "markdown", "id": "c8e1626f", "metadata": {}, "source": [ "Our new operator has the following signature `sincos(x: Tensor) -> Tuple[Tensor, Tensor]`. It takes a tensor as input and returns a tuple of two tensors. 
The first tensor is the sine of the input and the second tensor is the cosine of the input.\n", "\n", "We call all callables that should be recorded in the trace *Symbols*. Symbols are the building blocks of the trace. Symbols are either primitives or composite operators. Composite operators are implemented in terms of other operators and primitives. Primitives are operators that are not implemented in terms of other operators or primitives.\n", "\n", "The easiest way to register a new operator is to define a meta function, which describes what the metadata of the output looks like given the metadata of the inputs, and an implementation, which deals with concrete objects like Python `Number`s and PyTorch `Tensor`s, and then to register both through an executor. This automatically creates a symbol for us.\n", "\n", "So we create an executor:" ] }, { "cell_type": "code", "execution_count": 3, "id": "f680ae37", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[sincos_executor, sdpa]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sincos_executor = thunder.extend.OperatorExecutor(\"sincos_executor\", version='0.1')\n", "thunder.add_default_executor(sincos_executor)" ] }, { "cell_type": "markdown", "id": "4f147274", "metadata": {}, "source": [ "We define the meta function and the implementation:" ] }, { "cell_type": "code", "execution_count": 4, "id": "d5a72aff", "metadata": {}, "outputs": [], "source": [ "@log\n", "def sincos_meta(inp):\n", "    return (TensorProxy(like=inp), TensorProxy(like=inp))\n", "\n", "@log\n", "def sincos_impl(inp):\n", "    return torch.sin(inp), torch.cos(inp)" ] }, { "cell_type": "markdown", "id": "a06c6260", "metadata": {}, "source": [ "And register them as `sincos`:\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "03516b03", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Symbol name=sincos]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sincos = 
sincos_executor.register_operator('sincos', meta=sincos_meta, fn=sincos_impl)\n", "sincos" ] }, { "cell_type": "markdown", "id": "e234a47b", "metadata": {}, "source": [ "That's it! We have implemented our new primitive. Let's test it." ] }, { "cell_type": "code", "execution_count": 6, "id": "8c5da6f2", "metadata": {}, "outputs": [], "source": [ "def fun(a, b):\n", "    sin, cos = sincos(a)\n", "    return sin + cos + b" ] }, { "cell_type": "code", "execution_count": 7, "id": "aef98360", "metadata": {}, "outputs": [], "source": [ "a = torch.randn(1)\n", "b = torch.randn(1)" ] }, { "cell_type": "markdown", "id": "6266dea4", "metadata": {}, "source": [ "`fun` is now a Thunder function, meaning it can only accept Thunder's `TensorProxy` objects as inputs. Calling it directly with `torch.Tensor` inputs fails:" ] }, { "cell_type": "code", "execution_count": 8, "id": "87f9f6e7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Attempting to execute outside of a tracing context, which is not supported\n" ] } ], "source": [ "try:\n", "    fun(a, b)\n", "except Exception as e:\n", "    print(e)" ] }, { "cell_type": "markdown", "id": "8dde0c4c", "metadata": {}, "source": [ "In the future we will add support for `torch.Tensor` and `numpy.ndarray` inputs for eager mode of Thunder functions. For now, this function works only in tracing mode." 
] }, { "cell_type": "code", "execution_count": 9, "id": "f938dff7-bac6-4807-b79d-a16cb5c6d90c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "call sincos_meta(TensorProxy(name=a, shape=(1,), dtype=float32, device=cpu))\n", "|<- sincos_meta = (TensorProxy(name=t0, shape=(1,), dtype=float32, device=cpu), TensorProxy(name=t1, shape=(1,), dtype=float32, device=cpu))\n", "\n", "# Constructed by Dead Code Elimination (took 0 milliseconds)\n", "import thunder\n", "import thunder.torch as ltorch\n", "import torch\n", "from thunder.executors.torchex import no_autocast\n", "\n", "@torch.no_grad()\n", "@no_autocast\n", "def fun(a, b):\n", " # a: \"cpu f32[1]\" \n", " # b: \"cpu f32[1]\" \n", " (t0, t1) = sincos(a)\n", " t2 = ltorch.add(t0, t1, alpha=None) # t2: \"cpu f32[1]\"\n", " # t2 = prims.add(t0, t1) # t2: \"cpu f32[1]\"\n", " t3 = ltorch.add(t2, b, alpha=None) # t3: \"cpu f32[1]\"\n", " # t3 = prims.add(t2, b) # t3: \"cpu f32[1]\"\n", " return t3\n" ] } ], "source": [ "# Let's see first how this function is represented as a trace\n", "trace = thunder.trace()(fun, a, b)\n", "print(trace)" ] }, { "cell_type": "code", "execution_count": 10, "id": "6eb4818b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Bound symbol with id=PrimIDs.UNPACK_TRIVIAL is represented in the trace as |# a: \"cpu f32[1]\" |\n", "Bound symbol with id=PrimIDs.UNPACK_TRIVIAL is represented in the trace as |# b: \"cpu f32[1]\" |\n", "Bound symbol with id=sincos is represented in the trace as |(t0, t1) = sincos(a)|\n", "Bound symbol with id=torch.add is represented in the trace as |t2 = ltorch.add(t0, t1, alpha=None) # t2: \"cpu f32[1]\"\n", " # t2 = prims.add(t0, t1) # t2: \"cpu f32[1]\"|\n", " It has the following subsymbols:\n", " id=PrimIDs.ADD |t2 = prims.add(t0, t1) # t2: \"cpu f32[1]\"|\n", "Bound symbol with id=torch.add is represented in the trace as |t3 = ltorch.add(t2, b, alpha=None) # t3: \"cpu f32[1]\"\n", " # 
t3 = prims.add(t2, b) # t3: \"cpu f32[1]\"|\n", " It has the following subsymbols:\n", " id=PrimIDs.ADD |t3 = prims.add(t2, b) # t3: \"cpu f32[1]\"|\n", "Bound symbol with id=PrimIDs.RETURN is represented in the trace as |return t3|\n" ] } ], "source": [ "# We can loop over the recorded operations that we call BoundSymbols\n", "for bound_symbol in trace.bound_symbols:\n", "    print(f\"Bound symbol with id={bound_symbol.sym.id} is represented in the trace as |{bound_symbol}|\")\n", "    if bound_symbol.subsymbols:\n", "        print(\" It has the following subsymbols:\")\n", "        for subsymbol in bound_symbol.subsymbols:\n", "            print(f\" id={subsymbol.sym.id} |{subsymbol}|\")" ] }, { "cell_type": "markdown", "id": "2b2507d2", "metadata": {}, "source": [ "Let's see what happens if we try to compile a function that uses our new primitive and run it." ] }, { "cell_type": "code", "execution_count": 11, "id": "41566de2-a60f-4c87-a3d6-58e6a89dc38b", "metadata": {}, "outputs": [], "source": [ "cfun = thunder.jit(fun)" ] }, { "cell_type": "code", "execution_count": 12, "id": "24af4b99", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "call sincos_meta(TensorProxy(name=t_0, shape=(1,), dtype=float32, device=cpu))\n", "|<- sincos_meta = (TensorProxy(name=t0, shape=(1,), dtype=float32, device=cpu), TensorProxy(name=t1, shape=(1,), dtype=float32, device=cpu))\n", "\n", "call sincos_impl(Tensor(shape=torch.Size([1]), stride=(1,), dtype=torch.float32, device=cpu) with values tensor([0.1413]))\n", "|<- sincos_impl = (Tensor(shape=torch.Size([1]), stride=(1,), dtype=torch.float32, device=cpu) with values tensor([0.1408]), Tensor(shape=torch.Size([1]), stride=(1,), dtype=torch.float32, device=cpu) with values tensor([0.9900]))\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/tv/firma/grid/thunder/lightning-thunder/thunder/core/jit_ext.py:478: UserWarning: We are using a (non-const) value of unknown type NoneType, which may or may not be 
safe. This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!\n", " warnings.warn(s)\n", "/home/tv/firma/grid/thunder/lightning-thunder/thunder/core/jit_ext.py:478: UserWarning: We are using a (non-const) value of type bool, which is not identified as an input. This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!\n", " warnings.warn(s)\n", "/home/tv/firma/grid/thunder/lightning-thunder/thunder/core/jit_ext.py:478: UserWarning: We are using a (non-const) value of unknown type SequenceIter, which may or may not be safe. This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!\n", " warnings.warn(s)\n", "/home/tv/firma/grid/thunder/lightning-thunder/thunder/core/jit_ext.py:478: UserWarning: We are using a (non-const) value of type int, which is not identified as an input. This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!\n", " warnings.warn(s)\n", "/home/tv/firma/grid/thunder/lightning-thunder/thunder/core/jit_ext.py:478: UserWarning: We are using a (non-const) value of unknown type NotImplementedType, which may or may not be safe. This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!\n", " warnings.warn(s)\n", "/home/tv/firma/grid/thunder/lightning-thunder/thunder/core/jit_ext.py:478: UserWarning: We are using a (non-const) value of unknown type StopIteration, which may or may not be safe. 
This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!\n", " warnings.warn(s)\n" ] }, { "data": { "text/plain": [ "tensor([0.7666])" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cfun(a, b)" ] }, { "cell_type": "markdown", "id": "d7cec09d", "metadata": {}, "source": [ "Let's check how our function is represented in the execution trace now (change to `thunder.last_traces(cfun)[0]` to see the trace before transformations)" ] }, { "cell_type": "code", "execution_count": 13, "id": "a7ff30ef", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "# Constructed by Delete Last Used (took 0 milliseconds)\n", "import torch\n", "from thunder.executors.torchex import no_autocast\n", "\n", "@torch.no_grad()\n", "@no_autocast\n", "def computation(a, b):\n", " # a: \"cpu f32[1]\" \n", " # b: \"cpu f32[1]\" \n", " (res, cos) = sincos(a)\n", " del a\n", " result = torch.add(res, cos) # result: \"cpu f32[1]\"\n", " # result = ltorch.add(res, cos, alpha=None) # result: \"cpu f32[1]\"\n", " # result = prims.add(res, cos) # result: \"cpu f32[1]\"\n", " del res, cos\n", " t3 = torch.add(result, b) # t3: \"cpu f32[1]\"\n", " # t3 = ltorch.add(result, b, alpha=None) # t3: \"cpu f32[1]\"\n", " # t3 = prims.add(result, b) # t3: \"cpu f32[1]\"\n", " del result, b\n", " return t3" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "thunder.last_traces(cfun)[-1]" ] }, { "cell_type": "markdown", "id": "35b71375", "metadata": {}, "source": [ "For a peek under the hood, we can also first create a new symbol (without reference to an executor) and then register an executor for that.\n" ] }, { "cell_type": "code", "execution_count": 14, "id": "f28094bb", "metadata": {}, "outputs": [], "source": [ "from thunder.core.symbol import Symbol\n", "@log\n", "def sincos_meta(input):\n", " 
return (TensorProxy(like=input), TensorProxy(like=input))\n", "\n", "# this gives a nice, unique, printable id\n", "class CustomOps(Enum):\n", "    sincos2 = 0\n", "\n", "sincos2 = Symbol(\n", "    id=CustomOps.sincos2,\n", "    name=\"sincos2\",\n", "    meta=sincos_meta,\n", "    is_prim=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 15, "id": "7fbab758", "metadata": {}, "outputs": [], "source": [ "def fun2(a, b):\n", "    sin, cos = sincos2(a)\n", "    return sin + cos + b\n", "\n", "cfun2 = thunder.jit(fun2)" ] }, { "cell_type": "code", "execution_count": 16, "id": "950d74ad", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "call sincos_meta(TensorProxy(name=t_0, shape=(1,), dtype=float32, device=cpu))\n", "|<- sincos_meta = (TensorProxy(name=t0, shape=(1,), dtype=float32, device=cpu), TensorProxy(name=t1, shape=(1,), dtype=float32, device=cpu))\n", "\n", "Failed to find an executor for bound symbol bsym=(res, cos) = __main__.sincos2(a)\n" ] } ], "source": [ "try:\n", "    cfun2(a, b)\n", "except RuntimeError as e:\n", "    print(e)" ] }, { "cell_type": "markdown", "id": "aadcf2a9", "metadata": {}, "source": [ "There's no registered executor for `sincos2`, so we need to register an executor for our new primitive. Let's do that." ] }, { "cell_type": "markdown", "id": "995febba", "metadata": {}, "source": [ "Check out the \"adding-operator-executor.ipynb\" notebook to see how to implement an executor for a Symbol." 
] }, { "cell_type": "code", "execution_count": 17, "id": "956a4a6b", "metadata": {}, "outputs": [], "source": [ "@log\n", "def checker_sincos2(a):\n", "    # We allow the sincos function to be called with any tensor\n", "    return True\n", "\n", "@log\n", "def executor_sincos2(a):\n", "    # We need something here that works with TensorProxies during the transformations,\n", "    # so we need to use functions from thunder.torch, thunder.clang, or other Symbols\n", "    return thunder.torch.sin(a), thunder.torch.cos(a)\n", "\n", "sincos_executor.register_implementation(sincos2, checker=checker_sincos2, execution_transform=executor_sincos2)\n" ] }, { "cell_type": "code", "execution_count": 18, "id": "1c77c508", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "call sincos_meta(TensorProxy(name=t_0, shape=(1,), dtype=float32, device=cpu))\n", "|<- sincos_meta = (TensorProxy(name=t0, shape=(1,), dtype=float32, device=cpu), TensorProxy(name=t1, shape=(1,), dtype=float32, device=cpu))\n", "\n", "call checker_sincos2(TensorProxy(name=a, shape=(1,), dtype=float32, device=cpu))\n", "|<- checker_sincos2 = True\n", "\n", "call executor_sincos2(TensorProxy(name=a, shape=(1,), dtype=float32, device=cpu))\n", "|<- executor_sincos2 = (TensorProxy(name=t4, shape=(1,), dtype=float32, device=cpu), TensorProxy(name=t5, shape=(1,), dtype=float32, device=cpu))\n", "\n" ] }, { "data": { "text/plain": [ "tensor([0.7666])" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Let's try again\n", "cfun2 = thunder.jit(fun2)\n", "cfun2(a, b)" ] }, { "cell_type": "code", "execution_count": 19, "id": "f9797cf2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "# Constructed by Delete Last Used (took 0 milliseconds)\n", "import torch\n", "from thunder.executors.torchex import no_autocast\n", "\n", "@torch.no_grad()\n", "@no_autocast\n", "def computation(a, b):\n", " # a: \"cpu f32[1]\" \n", " # b: \"cpu f32[1]\" \n", " res 
= torch.sin(a) # res: \"cpu f32[1]\"\n", " # res = ltorch.sin(a) # res: \"cpu f32[1]\"\n", " # res = prims.sin(a) # res: \"cpu f32[1]\"\n", " cos = torch.cos(a) # cos: \"cpu f32[1]\"\n", " # cos = ltorch.cos(a) # cos: \"cpu f32[1]\"\n", " # cos = prims.cos(a) # cos: \"cpu f32[1]\"\n", " del a\n", " result = torch.add(res, cos) # result: \"cpu f32[1]\"\n", " # result = ltorch.add(res, cos, alpha=None) # result: \"cpu f32[1]\"\n", " # result = prims.add(res, cos) # result: \"cpu f32[1]\"\n", " del res, cos\n", " t3 = torch.add(result, b) # t3: \"cpu f32[1]\"\n", " # t3 = ltorch.add(result, b, alpha=None) # t3: \"cpu f32[1]\"\n", " # t3 = prims.add(result, b) # t3: \"cpu f32[1]\"\n", " del result, b\n", " return t3" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Let's check how our function is represented in the execution trace now\n", "thunder.last_traces(cfun2)[-1]" ] }, { "cell_type": "markdown", "id": "122ead11", "metadata": {}, "source": [ "That's it! We've created our custom operator and registered an executor for it. To recap, we've done the following:\n", "* Created a new Symbol called `sincos` that represents the sine and cosine\n", " computation (but not the actual computation itself). All we know about it is\n", " that it takes a tensor as input and returns a tuple of two tensors. We gave this Symbol `name` and `id` attributes to identify it in the trace and when processing the trace.\n", "* Implemented the actual computation by calling PyTorch's `sin` and `cos` functions." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.8" } }, "nbformat": 4, "nbformat_minor": 5 }