{ "cells": [ { "attachments": { "image.png": { "image/png": "" } }, "cell_type": "markdown", "id": "e5aab976", "metadata": { "hideCode": false, "hideOutput": false, "hidePrompt": false }, "source": [ "# Extend Thunder with CUDA-Python\n", "\n", "In this demo, we implement a (naive, unoptimized) version of the flash attention 2 algorithm from\n", "[Tri Dao: FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://tridao.me/publications/flash2/flash2.pdf), following the pseudocode below, which we took from the paper.\n", "\n", "Our implementation won't be quite as fast as flash attention, but we will learn how to use a CUDA kernel from Python/PyTorch/Thunder by extending Thunder with NVIDIA's [CUDA-Python](https://github.com/NVIDIA/cuda-python) low-level bindings, so that you can then do the same for your own, even more awesome kernels.\n", "\n", "There is not much special about the Thunder part of this, so if you have looked at the section on extending Thunder in the [Zero to Thunder tutorial](./zero_to_thunder.ipynb), things should look very familiar. But as CUDA-Python is relatively new, we thought it might be neat to have a spelled-out example here.\n", "\n", "![image.png](attachment:image.png)\n", "\n", "OK, now we know what to do. Let's import some modules and create a few sample inputs."
] }, { "cell_type": "code", "execution_count": 1, "id": "718c1f3e", "metadata": { "hideCode": false, "hideOutput": false, "hidePrompt": false }, "outputs": [], "source": [ "import torch, math, itertools, numpy\n", "\n", "N_inp = 512\n", "N_out = 512\n", "d = 128\n", "\n", "with torch.device(\"cuda\"):\n", "    Q = torch.randn(96, N_out, d)\n", "    K = torch.randn(96, N_inp, d)\n", "    V = torch.randn(96, N_inp, d)" ] }, { "cell_type": "markdown", "id": "86d4e193", "metadata": { "hideCode": false, "hideOutput": false, "hidePrompt": false }, "source": [ "The first thing we do is implement a quite literal translation of the pseudocode into a Python function, using tensors for the tiles.\n", "We are not terribly ambitious here and assume that `B_c` and `B_r` divide `N_inp` and `N_out`, that `N_inp` and `N_out` are actually the same, etc.\n", "\n", "You might improve the generality (and we welcome your PR)." ] }, { "cell_type": "code", "execution_count": 2, "id": "158809ac", "metadata": { "hideCode": false, "hideOutput": false, "hidePrompt": false }, "outputs": [], "source": [ "def flash_attention_reference(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, is_causal: bool = False, scale: float | None = None):\n", "    # N.B.: This uses the PyTorch SDPA tensor shape of batch, head_no, seq_len, head_dim\n", "\n", "    *batch, N_inp, d = K.shape\n", "    *_, N_out, _ = Q.shape\n", "\n", "    # assert shape compat\n", "\n", "\n", "    O = V.new_zeros(*batch, N_out, d)\n", "    L = V.new_zeros(*batch, N_out, 1)\n", "\n", "    dtype = O.dtype\n", "    device = O.device\n", "\n", "    neginf = torch.tensor(-math.inf, dtype=Q.dtype, device=Q.device)\n", "\n", "    B_c = 16  # this is NOT what the original impl uses\n", "    B_r = 16\n", "    T_c = (N_inp + B_c - 1) // B_c\n", "    T_r = (N_out + B_r - 1) // B_r\n", "\n", "    if scale is None:\n", "        scale = 1 / math.sqrt(d)\n", "\n", "    for block in itertools.product(*(range(s) for s in batch)):\n", "        # Q, O and L are split into T_r blocks; K and V into T_c blocks\n", "        for i in range(T_r):\n", "            Q_i
= Q[block][i * B_r: (i+1) * B_r]\n", "            O_i = torch.zeros(B_r, d, device=device, dtype=dtype)\n", "            l_i = torch.zeros(B_r, 1, device=device, dtype=dtype)\n", "            m_i = torch.full((B_r, 1), -math.inf, device=device, dtype=dtype)\n", "            last_m_i = m_i\n", "            for j in range(T_c):\n", "                if is_causal and j * B_c > (i+1) * B_r - 1:\n", "                    break\n", "                # in Python 3.11+ we could write K[*block, j * B_c: (j + 1) * B_c] instead...\n", "                K_j = K[block][j * B_c: (j + 1) * B_c]\n", "                V_j = V[block][j * B_c: (j + 1) * B_c]\n", "                S_i = scale * (Q_i @ K_j.T)\n", "                if is_causal and i * B_r < (j + 1) * B_c - 1:\n", "                    mask = torch.arange(i*B_r, (i+1)*B_r, device=device, dtype=dtype)[:, None] >= torch.arange(j*B_c, (j+1) * B_c, device=device, dtype=dtype)[None, :]\n", "                    S_i = torch.where(mask, S_i, neginf)\n", "\n", "                m_i = torch.maximum(m_i, S_i.max(dim=-1, keepdim=True).values)\n", "                P_i = torch.exp(S_i - m_i)\n", "                l_i = torch.exp(last_m_i - m_i) * l_i + P_i.sum(dim=-1, keepdim=True)\n", "                O_i = torch.exp(last_m_i - m_i) * O_i + P_i @ V_j\n", "                last_m_i = m_i\n", "            O_i = (1.0 / l_i) * O_i\n", "            L_i = m_i + torch.log(l_i)\n", "            O[block][i * B_r: (i + 1) * B_r] = O_i\n", "            L[block][i * B_r: (i + 1) * B_r] = L_i\n", "    return O, L\n" ] }, { "cell_type": "markdown", "id": "59ae13ec", "metadata": {}, "source": [ "Let's see if our function computes the same thing as PyTorch's `scaled_dot_product_attention`."
] }, { "cell_type": "code", "execution_count": 3, "id": "ab9ee99e", "metadata": { "hideCode": false, "hideOutput": false, "hidePrompt": false }, "outputs": [], "source": [ "actual, _ = flash_attention_reference(Q, K, V)" ] }, { "cell_type": "code", "execution_count": 4, "id": "b4648ba0", "metadata": { "hideCode": false, "hideOutput": false, "hidePrompt": false }, "outputs": [], "source": [ "expected = torch.nn.functional.scaled_dot_product_attention(Q, K, V)" ] }, { "cell_type": "code", "execution_count": 5, "id": "f0f8ea7f", "metadata": { "hideCode": false, "hideOutput": false, "hidePrompt": false }, "outputs": [ { "data": { "text/plain": [ "tensor(1.1921e-06, device='cuda:0')" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(actual - expected).abs().max()" ] }, { "cell_type": "markdown", "id": "94d3b173", "metadata": { "hideCode": false, "hideOutput": false, "hidePrompt": false }, "source": [ "That is neat! But we wanted to write our own CUDA kernel, so let's get out CUDA-Python." ] }, { "cell_type": "markdown", "id": "31cc1ead-9039-42e2-945b-4986161d42fd", "metadata": { "hideCode": false, "hideOutput": false, "hidePrompt": false }, "source": [ "## Using CUDA-Python to compile CUDA kernels for PyTorch\n", "\n", "[CUDA-Python](https://github.com/NVIDIA/cuda-python) provides low-level bindings to the CUDA driver and runtime APIs and to NVRTC (NVIDIA's runtime compilation library for CUDA C++). You can install it with `pip install cuda-python`.\n", "\n", "As those functions are very (I mean extremely) low level, we provide a couple of helper functions here:\n", "\n", "- The function `compile_program_and_get_kernel` takes source code and produces a CUDA kernel from it.\n", "- The function `launch_kernel` runs one of our kernels with the specified arguments.\n", "\n", "We first import the `driver` and `nvrtc` modules from `cuda.bindings`."
] }, { "cell_type": "code", "execution_count": null, "id": "471a1133", "metadata": { "hideCode": false, "hideOutput": false, "hidePrompt": false }, "outputs": [], "source": [ "from cuda.bindings import driver, nvrtc" ] }, { "cell_type": "markdown", "id": "36a37fe6", "metadata": { "hideCode": false, "hideOutput": false, "hidePrompt": false }, "source": [ "The function `compile_program_and_get_kernel` compiles the source code to PTX with NVRTC. The PTX\n", "is then loaded (and compiled to SASS for the current GPU) by `cuModuleLoadData`.\n", "Quite likely, one would want to let users access some of the other knobs (e.g. compile flags), but we have to stop somewhere...\n" ] }, { "cell_type": "code", "execution_count": null, "id": "017cd520", "metadata": { "hideCode": false, "hideOutput": false, "hidePrompt": false }, "outputs": [], "source": [ "def compile_program_and_get_kernel(cuda_src, function_name):\n", "    \"\"\"\n", "    Compiles a kernel from the CUDA source code provided in the string `cuda_src` and gets the kernel with the name `function_name` (which needs to be\n", "    defined as extern \"C\" in the CUDA source code).\n", "\n", "    The kernel can then be launched with `launch_kernel`.\n", "    \"\"\"\n", "\n", "    def check_error(results):\n", "        # CUDA-Python calls return a tuple whose first element is the error code\n", "        err, *results = results\n", "        if isinstance(err, driver.CUresult):\n", "            if err != driver.CUresult.CUDA_SUCCESS:\n", "                _, err_str = driver.cuGetErrorString(err)\n", "                raise RuntimeError(f\"CUDA error: {err_str.decode()}\")\n", "        elif isinstance(err, nvrtc.nvrtcResult):\n", "            if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:\n", "                # print the compilation log to help with debugging\n", "                logSize = check_error(nvrtc.nvrtcGetProgramLogSize(prog))\n", "                log = b\" \" * logSize\n", "                check_error(nvrtc.nvrtcGetProgramLog(prog, log))\n", "                print(log.decode())\n", "                _, err_str = nvrtc.nvrtcGetErrorString(err)\n", "                raise RuntimeError(f\"NVRTC error: {err_str.decode()}\")\n", "        else:\n", "            raise TypeError(f\"Unknown error type: {err}\")\n", "        if len(results) == 0:\n", "            return\n", "        if len(results) == 1:\n", "
            return results[0]\n", "        return results\n", "\n", "    torch.cuda.current_stream()  # this initializes the device context for us; we don't need the stream itself.\n", "\n", "    # Create program\n", "    prog = check_error(nvrtc.nvrtcCreateProgram(str.encode(cuda_src), (function_name + '.cu').encode(), 0, [], []))\n", "\n", "    # Compile program for the compute capability of the current device\n", "    major, minor = torch.cuda.get_device_capability()\n", "    opts = [f\"--gpu-architecture=compute_{major}{minor}\".encode()]  #, b\"--expt-relaxed-constexpr\"]\n", "    check_error(nvrtc.nvrtcCompileProgram(prog, len(opts), opts))\n", "\n", "    # Get PTX from compilation\n", "    ptxSize = check_error(nvrtc.nvrtcGetPTXSize(prog))\n", "    ptx = b\" \" * ptxSize\n", "    check_error(nvrtc.nvrtcGetPTX(prog, ptx))\n", "\n", "    # Load PTX as module data and retrieve function\n", "    module = check_error(driver.cuModuleLoadData(ptx))\n", "    kernel = check_error(driver.cuModuleGetFunction(module, function_name.encode()))\n", "    return kernel" ] }, { "cell_type": "markdown", "id": "eaa797c6", "metadata": { "hideCode": false, "hideOutput": false, "hidePrompt": false }, "source": [ "Phew. If you're into CUDA programming (I guess you are, but if you want to learn more, check out the [CUDA-MODE](https://github.com/cuda-mode/resource-stream) series; in fact, this notebook started as a demo for a [lecture there](https://www.youtube.com/watch?v=zEuwuCTEf_0)), you know that to launch a kernel you need to specify the block layout (the number of threads per block, as a 3d \"array\"), the grid layout (the number of blocks, again as a 3d \"array\"), and the amount of dynamic shared memory.\n", "\n", "Another important detail is how we pass kernel arguments to `cuLaunchKernel`: we need to set up an array of pointers, each pointing to a (CPU) memory location that holds one parameter value (which, in the case of tensors, is itself a device pointer). To keep the type information (and thus the size) of scalar arguments, we pass them as numpy scalars (e.g. `numpy.float32(0.5)`).
Note that we don't check whether the kernel actually takes the parameters you give it.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "4628c9c2", "metadata": { "hideCode": false, "hideOutput": false, "hidePrompt": false }, "outputs": [], "source": [ "def launch_kernel(kernel, grid, block, /, *args, shmem=0):\n", "    \"\"\"Utility function to launch kernels.\n", "    Args can be tensors (corresponding to float* etc. kernel params) or numpy scalars (which carry precision info).\n", "    \"\"\"\n", "\n", "    # collect values (data_ptr as uint64 array for tensors, the values as an array for values)\n", "    wrapped_args = []\n", "    for a in args:\n", "        if isinstance(a, torch.Tensor):\n", "            # for tensors, pass in the data_ptr\n", "            wrapped_args.append(numpy.array(a.data_ptr(), dtype=numpy.uint64))\n", "        elif isinstance(a, numpy.number):\n", "            wrapped_args.append(numpy.array([a]))\n", "        else:\n", "            raise TypeError(\"please only pass tensors and numpy numbers to launch_kernel\")\n", "\n", "    # assemble an array of pointers to the args\n", "    args = numpy.array([a.ctypes.data for a in wrapped_args], dtype=numpy.uint64)\n", "\n", "    # set up grid / block layout to be 3d\n", "    grid = tuple(grid)\n", "    block = tuple(block)\n", "    assert 1 <= len(block) <= 3 and 1 <= len(grid) <= 3\n", "    grid = grid + (3 - len(grid)) * (1,)\n", "    block = block + (3 - len(block)) * (1,)\n", "\n", "    # Launch!\n", "    err, = driver.cuLaunchKernel(\n", "        kernel,\n", "        *grid, *block,  # xyz each\n", "        shmem,  # dynamic shared memory\n", "        torch.cuda.current_stream().cuda_stream,  # stream (the raw CUstream handle)\n", "        args.ctypes.data,  # kernel arguments\n", "        0,  # extra (ignore)\n", "    )\n", "    if err != driver.CUresult.CUDA_SUCCESS:\n", "        raise RuntimeError(f\"CUDA error: {err}\")" ] }, { "cell_type": "markdown", "id": "3aace14d", "metadata": { "hideCode": false, "hideOutput": false, "hidePrompt": false }, "source": [ "## A naive flash attention kernel\n", "\n", "With these two helpers in place, we can implement our flash
attention kernel. As with matmul, tiling is important. But given that the head size can be large (128 for Llama-2 7B), using tiles that are also large in the other dimensions puts quite a strain on our on-chip memory resources (we really should move to 16-bit floats; will you send a PR?).\n", "We keep tiles of `Q`, `K`, `V` and `S` (a tile of the matrix of scaled attention scores, i.e. the intermediate result before the softmax) in shared memory, and keep `l` (the running softmax denominator), `m` (the running row maximum used for numerical stabilization) and the `O` tiles in registers. The rest is more or less a spelled-out version of the Python reference, with a bit more hassle because array operations have to be spread across threads and/or loops.\n", "\n", "Again, there are many gaps (e.g. shapes that are not divisible by the tile sizes) and limitations (no causal masking yet), but maybe you find the general idea helpful.\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "73570568", "metadata": { "hideCode": false, "hideOutput": false, "hidePrompt": false }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cuda_d = 128\n", "cuda_B_r = 32\n", "cuda_B_c = 16\n", "\n", "\n", "cuda_src = (\n", "f\"\"\"\n", "constexpr int B_r = {cuda_B_r};\n", "constexpr int B_c = {cuda_B_c};\n", "constexpr int d = {cuda_d};\n", "constexpr int o_per_thread_x = 1;\n", "constexpr int o_per_thread_y = d / 32; // assumes blockDim.y == 32\n", "\"\"\"\n", " + r\"\"\"\\\n", "#define NEG_INFINITY __int_as_float(0xff800000)\n", "\n", "extern \"C\" __global__\n", "void silly_attn(float *out, float* out_l, float *K, float *Q, float* V, float scaling, int batch_stride, int T_r, int T_c)\n", "{\n", " int tid_x = threadIdx.x;\n", " int tid_y = threadIdx.y;\n", " int batch_offset = batch_stride * blockIdx.x;\n", "\n", " __shared__ float Q_i[B_r][d];\n", " __shared__ float K_j[B_c][d];\n", " __shared__ float V_j[B_c][d];\n", " \n", " __shared__ float
S_i[B_r][B_c];\n", "\n", " float l_i[o_per_thread_x];\n", " float m_i[o_per_thread_x];\n", " float O_i[o_per_thread_x][o_per_thread_y];\n", "\n", " for (int i = 0; i < T_r; i++) {\n", " for (int ii = 0; ii < o_per_thread_x; ii++) {\n", " for (int dd = 0; dd < o_per_thread_y; dd++) {\n", " O_i[ii][dd] = 0;\n", " }\n", " l_i[ii] = 0.f;\n", " m_i[ii] = NEG_INFINITY;\n", " }\n", " for (int ii = tid_y; ii < B_r; ii += blockDim.y) {\n", " for (int dd = tid_x; dd < d; dd += blockDim.x) {\n", " Q_i[ii][dd] = Q[batch_offset + (ii + i * B_r) * d + dd];\n", " }\n", " }\n", " for (int j=0; j < T_c; j++) {\n", " __syncthreads();\n", " for (int jj=tid_y; jj < B_c; jj+= blockDim.y) {\n", " for (int dd=tid_x; dd < d; dd += blockDim.x) {\n", " K_j[jj][dd] = K[batch_offset + (jj + j * B_c) * d + dd];\n", " V_j[jj][dd] = V[batch_offset + (jj + j * B_c) * d + dd];\n", " }\n", " }\n", " __syncthreads();\n", " // S_i = scale * (Q_i @ K_j.T)\n", " for (int ii = tid_x; ii < B_r; ii += blockDim.x) {\n", " for (int jj = tid_y; jj < B_c; jj += blockDim.y) {\n", " float S_ij = 0.f;\n", " for (int dd = 0; dd < d; dd++) {\n", " S_ij += Q_i[ii][dd] * K_j[jj][dd];\n", " }\n", " S_ij = scaling * S_ij;\n", " S_i[ii][jj] = S_ij;\n", " }\n", " }\n", " __syncthreads();\n", " for (int ii = 0; ii < o_per_thread_x; ii++) {\n", " float m = m_i[ii];\n", " float last_m = m;\n", " for (int jj = 0; jj < B_c; jj += 1) {\n", " if (m < S_i[ii * blockDim.x + tid_x][jj]) {\n", " m = S_i[ii * blockDim.x + tid_x][jj];\n", " }\n", " }\n", " m_i[ii] = m;\n", " float l = exp(last_m - m) * l_i[ii];\n", " for (int dd = 0; dd < o_per_thread_y; dd++) {\n", " O_i[ii][dd] *= exp(last_m - m);\n", " }\n", " \n", " for (int jj = 0; jj < B_c; jj ++) {\n", " float S_ij = exp(S_i[ii * blockDim.x + tid_x][jj] - m);\n", " l += S_ij;\n", " for (int dd = 0; dd < o_per_thread_y; dd++) {\n", " O_i[ii][dd] += S_ij * V_j[jj][dd * blockDim.y + tid_y];\n", " }\n", " }\n", " l_i[ii] = l;\n", "\n", " }\n", " }\n", " for (int ii = 0; ii < 
o_per_thread_x; ii++) {\n", " for (int dd = 0; dd < o_per_thread_y; dd++) {\n", " out[batch_offset + (ii * blockDim.x + tid_x + i * B_r) * d + dd * blockDim.y + tid_y] = O_i[ii][dd] / l_i[ii];\n", " out_l[batch_offset / d + ii * blockDim.x + tid_x + i * B_r] = l_i[ii];\n", " }\n", " }\n", " }\n", "}\n", "\"\"\")\n", "\n", "\n", "cuda_flash_attention_kernel = compile_program_and_get_kernel(cuda_src, \"silly_attn\")\n", "\n", "cuda_flash_attention_kernel" ] }, { "cell_type": "markdown", "id": "baf44f9b", "metadata": { "hideCode": false, "hideOutput": false, "hidePrompt": false }, "source": [ "Now we need a wrapper that calls our kernel. It prepares the inputs, allocates the outputs, defines the block and grid layout, and calls `launch_kernel`." ] }, { "cell_type": "code", "execution_count": 10, "id": "8ef456ab", "metadata": { "hideCode": false, "hideOutput": false, "hidePrompt": false }, "outputs": [ { "data": { "text/plain": [ "tensor(1.4305e-06, device='cuda:0')" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def cuda_python_flash_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, is_causal: bool = False, scale: float | None = None):\n", "    assert Q.device.type == 'cuda'\n", "\n", "    if is_causal:\n", "        raise NotImplementedError(\"cuda_python_flash_attention is_causal=True is not implemented\")\n", "\n", "    *batch, N_inp, d = K.shape\n", "    *_, N_out, _ = Q.shape\n", "\n", "    assert d == cuda_d\n", "    assert Q.dtype == K.dtype == V.dtype == torch.float32  # the kernel operates on float*\n", "\n", "    Q_3d = Q.reshape(-1, N_out, d)\n", "    K_3d = K.reshape(-1, N_inp, d)\n", "    V_3d = V.reshape(-1, N_inp, d)\n", "\n", "    blocks = Q_3d.shape[0]\n", "\n", "    # assert shape compat\n", "\n", "    O = V.new_zeros(*batch, N_out, d)\n", "    L = V.new_zeros(*batch, N_out, 1)\n", "\n", "    O_3d = O.view(-1, N_out, d)\n", "    L_3d = L.view(-1, N_out, 1)\n", "\n", "    T_c = (N_inp + cuda_B_c - 1) // cuda_B_c\n", "    T_r = (N_out + cuda_B_r - 1) // cuda_B_r\n", "\n", "    if scale is None:\n", "        scale = 1 / math.sqrt(d)\n",
"\n", "    assert N_inp % cuda_B_r == 0  # TODO\n", "    assert N_inp % cuda_B_c == 0  # TODO\n", "    assert N_out == N_inp  # TODO\n", "\n", "    GRID = (blocks,)  # one thread block per entry of the flattened batch dimensions\n", "    BLOCK = (32, 32)  # the kernel assumes blockDim.x == B_r and blockDim.y == 32\n", "\n", "    launch_kernel(cuda_flash_attention_kernel, GRID, BLOCK, O_3d, L_3d, K_3d, Q_3d, V_3d, numpy.float32(scale),\n", "                  numpy.int32(N_inp * d), numpy.int32(T_r), numpy.int32(T_c))\n", "    return O, L\n", "\n", "\n", "Qc = Q.cuda()\n", "Kc = K.cuda()\n", "Vc = V.cuda()\n", "\n", "actual, _ = cuda_python_flash_attention(Qc, Kc, Vc)\n", "expected = torch.nn.functional.scaled_dot_product_attention(Qc, Kc, Vc)\n", "\n", "(actual - expected).abs().max()\n" ] }, { "cell_type": "markdown", "id": "8cce98d7", "metadata": {}, "source": [ "So how fast is it? Quite a lot slower than PyTorch's SDPA: in the run below, about 14.9 ms vs. 1.1 ms per call, i.e. more than an order of magnitude. The exact gap depends on the input sizes.\n", "But hey, that's an optimization opportunity." ] }, { "cell_type": "code", "execution_count": 11, "id": "be789298", "metadata": { "hideCode": false, "hideOutput": false, "hidePrompt": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.1 ms ± 911 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n", "14.9 ms ± 1.63 µs per loop (mean ± std. dev.
of 7 runs, 100 loops each)\n" ] } ], "source": [ "%timeit torch.nn.functional.scaled_dot_product_attention(Qc, Kc, Vc); torch.cuda.synchronize()\n", "%timeit cuda_python_flash_attention(Qc, Kc, Vc); torch.cuda.synchronize()" ] }, { "cell_type": "markdown", "id": "b2995e26-6aea-4c3d-b830-e0ac44ffe224", "metadata": { "hideCode": false, "hidePrompt": false }, "source": [ "# Running our kernel in Thunder\n", "\n", "We want our kernel to handle calls to `scaled_dot_product_attention` where it applies.\n", "Fortunately, this is much easier than writing the kernel itself.\n", "\n", "We start by creating our own executor.\n", "\n", "## Create a Thunder executor\n", "\n", "We create an `OperatorExecutor` and register it as a default executor.\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "9d85f00b", "metadata": { "hideCode": false, "hidePrompt": false }, "outputs": [ { "data": { "text/plain": [ "[attn_ex, sdpa, nvfuser]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import thunder\n", "\n", "attn_ex = thunder.extend.OperatorExecutor('attn_ex', version=0.01)\n", "thunder.add_default_executor(attn_ex)" ] }, { "cell_type": "markdown", "id": "8f00672e-f88e-43a8-a9b6-dbda1f0f8e80", "metadata": { "hideCode": false, "hidePrompt": false }, "source": [ "## Register our implementation as an operator\n", "\n", "The next thing we do is register our implementation as an operator of our executor. We use our implementation above as the execution function (the `fn` parameter) and provide a short meta function that computes the result metadata (shapes, dtypes, devices) from the input metadata."
] }, { "cell_type": "code", "execution_count": 13, "id": "17b1d735-1b03-4ae5-aaca-1cb151a7bd48", "metadata": { "hideCode": false, "hidePrompt": false }, "outputs": [], "source": [ "def my_attn_meta(query, key, value, is_causal, scale):\n", "    # outputs: O with the metadata of query, and L with shape (*query.shape[:-1], 1)\n", "    return thunder.TensorProxy(like=query), thunder.TensorProxy(like=query, shape=(*query.shape[:-1], 1))\n", "\n", "my_attn = attn_ex.register_operator('my_attn', meta=my_attn_meta, fn=cuda_python_flash_attention)" ] }, { "cell_type": "markdown", "id": "ae94b369-5bde-4784-8de2-13c3543e65cf", "metadata": { "hideCode": false, "hidePrompt": false }, "source": [ "## Register our attention as an implementation of torch sdpa\n", "\n", "But to have Thunder automatically use our implementation, we need to tell it that it implements sdpa.\n", "Our checker function takes the same arguments as PyTorch sdpa and returns `True` or `False` depending on whether our implementation applies. It makes sure we do not take variants that our implementation does not support (non-CUDA inputs, causal attention, attention masks, or dropout).\n", "\n", "The execution transform is also just a function that takes the same inputs as PyTorch sdpa and returns the same outputs.
The function itself is very basic, just calling the symbol we registered with Thunder.\n", "\n", "If we had a backward, we could also register a grad transform with `register_implementation` (see the [zero to thunder tutorial](./zero_to_thunder.ipynb) or the [extending thunder tutorial](./dev_tutorials/extend.ipynb) for an example with a grad transform).\n" ] }, { "cell_type": "code", "execution_count": 14, "id": "77d1e584-af9c-4d18-9136-d40d645435cb", "metadata": { "hideCode": false, "hidePrompt": false }, "outputs": [], "source": [ "def my_attn_checker(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):\n", "    if attn_mask is not None or dropout_p != 0.0 or is_causal:\n", "        return False\n", "    # the kernel only supports head_dim == cuda_d and equal sequence lengths divisible by the tile sizes\n", "    if query.shape[-1] != cuda_d or query.shape[-2] != key.shape[-2]:\n", "        return False\n", "    if key.shape[-2] % cuda_B_r != 0 or key.shape[-2] % cuda_B_c != 0:\n", "        return False\n", "    return (query.device.type == 'cuda' and\n", "            key.device == query.device and\n", "            value.device == query.device)\n", "\n", "def my_attn_transform(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):\n", "    if scale is None:\n", "        d = query.shape[-1]\n", "        scale = d**(-0.5)\n", "    out = my_attn(query, key, value, is_causal, scale)\n", "    return out[0]\n", "\n", "\n", "attn_ex.register_implementation(thunder.torch.scaled_dot_product_attention, checker=my_attn_checker,\n", "                                execution_transform=my_attn_transform)" ] }, { "cell_type": "markdown", "id": "23bf790e-afdc-4dcb-9a67-c5ea3ca54f8a", "metadata": { "hideCode": false, "hidePrompt": false }, "source": [ "## Run...\n", "\n", "Now we are ready to run models with our implementation. To keep things simple, we just use a function calling the PyTorch attention function, but you could also use your favourite LLM from [LitGPT](https://github.com/Lightning-AI/litgpt) here."
] }, { "cell_type": "code", "execution_count": 15, "id": "0ddab1c9", "metadata": { "hideCode": false, "hidePrompt": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(1.4305e-06, device='cuda:0')\n" ] } ], "source": [ "def test_fn(query, key, value):\n", "    return torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=False)\n", "\n", "jfn = thunder.jit(test_fn)\n", "\n", "print((jfn(Qc, Kc, Vc) - test_fn(Qc, Kc, Vc)).abs().max())" ] }, { "cell_type": "markdown", "id": "95ee118e", "metadata": { "hideCode": false, "hidePrompt": false }, "source": [ "# Inspect\n", "\n", "Using `thunder.last_traces`, we can look at what happened. The last trace in the list returned by this function is the fully transformed program, and it calls our new `my_attn` operator." ] }, { "cell_type": "code", "execution_count": 16, "id": "efa29992", "metadata": { "hideCode": false, "hidePrompt": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "# Constructed by Delete Last Used (took 0 milliseconds)\n", "import torch\n", "from thunder.executors.torchex import no_autocast\n", "\n", "@torch.no_grad()\n", "@no_autocast\n", "def computation(query, key, value):\n", " # query: \"cuda:0 f32[96, 512, 128]\"\n", " # key: \"cuda:0 f32[96, 512, 128]\"\n", " # value: \"cuda:0 f32[96, 512, 128]\"\n", " (t13, _) = my_attn(query, key, value, False, 0.08838834764831845)\n", " del query, key, value\n", " return t13\n" ] } ], "source": [ "print(thunder.last_traces(jfn)[-1])" ] }, { "cell_type": "markdown", "id": "062ca49f", "metadata": {}, "source": [ "The first trace contains the program as captured by Thunder, so it still contains the call to PyTorch SDPA, which is later translated into our operator.
" ] }, { "cell_type": "code", "execution_count": 17, "id": "44378e68", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "import thunder\n", "import thunder.torch as ltorch\n", "import torch\n", "from thunder.executors.torchex import no_autocast\n", "\n", "@torch.no_grad()\n", "@no_autocast\n", "def computation(query, key, value):\n", " # query: \"cuda:0 f32[96, 512, 128]\"\n", " # key: \"cuda:0 f32[96, 512, 128]\"\n", " # value: \"cuda:0 f32[96, 512, 128]\"\n", " t13 = ltorch.scaled_dot_product_attention(query, key, value, None, 0.0, False, scale=None) # t13: \"cuda:0 f32[96, 512, 128]\"\n", " # t0 = ltorch.mul(query, 0.29730177875068026) # t0: \"cuda:0 f32[96, 512, 128]\"\n", " # t0 = prims.mul(query, 0.29730177875068026) # t0: \"cuda:0 f32[96, 512, 128]\"\n", " # t1 = ltorch.transpose(key, -2, -1) # t1: \"cuda:0 f32[96, 128, 512]\"\n", " # t1 = prims.transpose(key, (0, 2, 1)) # t1: \"cuda:0 f32[96, 128, 512]\"\n", " # t2 = ltorch.mul(t1, 0.29730177875068026) # t2: \"cuda:0 f32[96, 128, 512]\"\n", " # t2 = prims.mul(t1, 0.29730177875068026) # t2: \"cuda:0 f32[96, 128, 512]\"\n", " # t3 = ltorch.matmul(t0, t2) # t3: \"cuda:0 f32[96, 512, 512]\"\n", " # t3 = prims.matmul(t0, t2) # t3: \"cuda:0 f32[96, 512, 512]\"\n", " # t12 = ltorch.softmax(t3, -1, dtype=None) # t12: \"cuda:0 f32[96, 512, 512]\"\n", " # t5 = ltorch.amax(t3, -1, True) # t5: \"cuda:0 f32[96, 512, 1]\"\n", " # t4 = prims.amax(t3, (2,)) # t4: \"cuda:0 f32[96, 512]\"\n", " # t5 = prims.broadcast_in_dim(t4, [96, 512, 1], [0, 1]) # t5: \"cuda:0 f32[96, 512, 1]\"\n", " # t7 = ltorch.sub(t3, t5, alpha=None) # t7: \"cuda:0 f32[96, 512, 512]\"\n", " # t6 = prims.broadcast_in_dim(t5, (96, 512, 512), (0, 1, 2)) # t6: \"cuda:0 f32[96, 512, 512]\"\n", " # t7 = prims.sub(t3, t6) # t7: \"cuda:0 f32[96, 512, 512]\"\n", " # t8 = ltorch.exp(t7) # t8: \"cuda:0 f32[96, 512, 512]\"\n", " # t8 = prims.exp(t7) # t8: \"cuda:0 f32[96, 512, 512]\"\n", " # t10 = ltorch.sum(t8, -1, True, 
dtype=None) # t10: \"cuda:0 f32[96, 512, 1]\"\n", " # t9 = prims.sum(t8, (2,)) # t9: \"cuda:0 f32[96, 512]\"\n", " # t10 = prims.broadcast_in_dim(t9, [96, 512, 1], [0, 1]) # t10: \"cuda:0 f32[96, 512, 1]\"\n", " # t12 = ltorch.true_divide(t8, t10) # t12: \"cuda:0 f32[96, 512, 512]\"\n", " # t11 = prims.broadcast_in_dim(t10, (96, 512, 512), (0, 1, 2)) # t11: \"cuda:0 f32[96, 512, 512]\"\n", " # t12 = prims.div(t8, t11) # t12: \"cuda:0 f32[96, 512, 512]\"\n", " # t13 = ltorch.matmul(t12, value) # t13: \"cuda:0 f32[96, 512, 128]\"\n", " # t13 = prims.matmul(t12, value) # t13: \"cuda:0 f32[96, 512, 128]\"\n", " return t13\n" ] } ], "source": [ "print(thunder.last_traces(jfn)[0])" ] }, { "cell_type": "markdown", "id": "5af18944", "metadata": {}, "source": [ "# Comparing implementations\n", "\n", "If we want to compare implementations, we can also compile the function without our executor to get the \"default\" implementation." ] }, { "cell_type": "code", "execution_count": 18, "id": "edf2d53b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(1.4305e-06, device='cuda:0')\n" ] } ], "source": [ "jfn_without_attn_ex = thunder.jit(test_fn, executors=[thunder.sdpa_executor, thunder.nvfuser_executor])\n", "\n", "print((jfn(Qc, Kc, Vc) - jfn_without_attn_ex(Qc, Kc, Vc)).abs().max())" ] }, { "cell_type": "code", "execution_count": 19, "id": "edda6d3a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "14.9 ms ± 37.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", "1.12 ms ± 1.2 µs per loop (mean ± std. dev. 
of 7 runs, 1,000 loops each)\n" ] } ], "source": [ "%timeit jfn(Qc, Kc, Vc) ; torch.cuda.synchronize()\n", "%timeit jfn_without_attn_ex(Qc, Kc, Vc) ; torch.cuda.synchronize()\n" ] }, { "cell_type": "markdown", "id": "70f1b20d", "metadata": {}, "source": [ "And, of course, we can also see this by inspecting the traces:" ] }, { "cell_type": "code", "execution_count": 20, "id": "bf869e3a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "# Constructed by Delete Last Used (took 0 milliseconds)\n", "import torch\n", "import torch.nn.functional\n", "from thunder.executors.torchex import no_autocast\n", "\n", "@torch.no_grad()\n", "@no_autocast\n", "def computation(query, key, value):\n", " # query: \"cuda:0 f32[96, 512, 128]\"\n", " # key: \"cuda:0 f32[96, 512, 128]\"\n", " # value: \"cuda:0 f32[96, 512, 128]\"\n", " t13 = torch.nn.functional.scaled_dot_product_attention(query, key, value, None, 0.0, False, scale=None) # t13: \"cuda:0 f32[96, 512, 128]\"\n", " # t13 = ltorch.scaled_dot_product_attention(query, key, value, None, 0.0, False, scale=None) # t13: \"cuda:0 f32[96, 512, 128]\"\n", " # t14 = ltorch.mul(query, 0.29730177875068026) # t14: \"cuda:0 f32[96, 512, 128]\"\n", " # t14 = prims.mul(query, 0.29730177875068026) # t14: \"cuda:0 f32[96, 512, 128]\"\n", " # t15 = ltorch.transpose(key, -2, -1) # t15: \"cuda:0 f32[96, 128, 512]\"\n", " # t15 = prims.transpose(key, (0, 2, 1)) # t15: \"cuda:0 f32[96, 128, 512]\"\n", " # t16 = ltorch.mul(t15, 0.29730177875068026) # t16: \"cuda:0 f32[96, 128, 512]\"\n", " # t16 = prims.mul(t15, 0.29730177875068026) # t16: \"cuda:0 f32[96, 128, 512]\"\n", " # t17 = ltorch.matmul(t14, t16) # t17: \"cuda:0 f32[96, 512, 512]\"\n", " # t17 = prims.matmul(t14, t16) # t17: \"cuda:0 f32[96, 512, 512]\"\n", " # t26 = ltorch.softmax(t17, -1, dtype=None) # t26: \"cuda:0 f32[96, 512, 512]\"\n", " # t19 = ltorch.amax(t17, -1, True) # t19: \"cuda:0 f32[96, 512, 1]\"\n", " # t18 = prims.amax(t17, (2,)) # t18: 
\"cuda:0 f32[96, 512]\"\n", " # t19 = prims.broadcast_in_dim(t18, [96, 512, 1], [0, 1]) # t19: \"cuda:0 f32[96, 512, 1]\"\n", " # t21 = ltorch.sub(t17, t19, alpha=None) # t21: \"cuda:0 f32[96, 512, 512]\"\n", " # t20 = prims.broadcast_in_dim(t19, (96, 512, 512), (0, 1, 2)) # t20: \"cuda:0 f32[96, 512, 512]\"\n", " # t21 = prims.sub(t17, t20) # t21: \"cuda:0 f32[96, 512, 512]\"\n", " # t22 = ltorch.exp(t21) # t22: \"cuda:0 f32[96, 512, 512]\"\n", " # t22 = prims.exp(t21) # t22: \"cuda:0 f32[96, 512, 512]\"\n", " # t24 = ltorch.sum(t22, -1, True, dtype=None) # t24: \"cuda:0 f32[96, 512, 1]\"\n", " # t23 = prims.sum(t22, (2,)) # t23: \"cuda:0 f32[96, 512]\"\n", " # t24 = prims.broadcast_in_dim(t23, [96, 512, 1], [0, 1]) # t24: \"cuda:0 f32[96, 512, 1]\"\n", " # t26 = ltorch.true_divide(t22, t24) # t26: \"cuda:0 f32[96, 512, 512]\"\n", " # t25 = prims.broadcast_in_dim(t24, (96, 512, 512), (0, 1, 2)) # t25: \"cuda:0 f32[96, 512, 512]\"\n", " # t26 = prims.div(t22, t25) # t26: \"cuda:0 f32[96, 512, 512]\"\n", " # t13 = ltorch.matmul(t26, value) # t13: \"cuda:0 f32[96, 512, 128]\"\n", " # t13 = prims.matmul(t26, value) # t13: \"cuda:0 f32[96, 512, 128]\"\n", " del query, key, value\n", " return t13\n" ] } ], "source": [ "print(thunder.last_traces(jfn_without_attn_ex)[-1])" ] }, { "cell_type": "markdown", "id": "c63ff2eb", "metadata": {}, "source": [ "# Summary\n", "\n", "So that is it.\n", "\n", "What did we achieve?\n", "\n", "- We implemented a kernel in CUDA following the flash attention 2 pseudocode (without all of its optimizations).\n", "- Then we looked at how to compile and run it using NVIDIA's CUDA-Python bindings.\n", "- Finally, we saw how Thunder executors make it easy to run PyTorch programs with targeted optimizations.\n", "\n", "We hope that you do great things, and may your kernels always turn out to be faster than the baseline!"
] }, { "cell_type": "code", "execution_count": null, "id": "7a41ebe4", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "hide_code_all_hidden": false, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.7" } }, "nbformat": 4, "nbformat_minor": 5 }