{ "cells": [ { "cell_type": "markdown", "id": "1638964c", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Zero to Thunder\n", "\n", "Here we take a very short tour of what is possible with Thunder.\n", "\n", "To get started we import it (and a bunch of things for this notebook)." ] }, { "cell_type": "code", "execution_count": 1, "id": "28b99b58", "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.insert(0, '..')\n", "\n", "import torch, thunder" ] }, { "cell_type": "markdown", "id": "54f87aba", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Compiling a first module with Thunder\n", "\n", "So let's get started! As a \"Hello World\", let us apply it to it to a small model, say, the MLP part found in Llama 2. We take it from LitGPT." ] }, { "cell_type": "code", "execution_count": 2, "id": "892be718", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LLaMAMLP(\n", " (fc_1): Linear(in_features=4096, out_features=11008, bias=False)\n", " (fc_2): Linear(in_features=4096, out_features=11008, bias=False)\n", " (proj): Linear(in_features=11008, out_features=4096, bias=False)\n", ")\n" ] } ], "source": [ "class LLaMAMLP(torch.nn.Module):\n", " def __init__(self, n_embd, intermediate_size) -> None:\n", " super().__init__()\n", " self.fc_1 = torch.nn.Linear(n_embd, intermediate_size, bias=False)\n", " self.fc_2 = torch.nn.Linear(n_embd, intermediate_size, bias=False)\n", " self.proj = torch.nn.Linear(intermediate_size, n_embd, bias=False)\n", " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", " x_fc_1 = self.fc_1(x)\n", " x_fc_2 = self.fc_2(x)\n", " x = torch.nn.functional.silu(x_fc_1) * x_fc_2\n", " return self.proj(x)\n", "with torch.device(\"cuda\"):\n", " m = LLaMAMLP(4096, 11008)\n", "for p in m.parameters():\n", " p.requires_grad_(False)\n", "print(m)\n" ] }, { "cell_type": "markdown", "id": "702ea054", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Now we can apply Thunder. This uses the most important function of Thunder, `thunder.jit`, which can be used to compile a `torch.nn.Module` or a function. It will wrap our MLP in a `ThunderModule`" ] }, { "cell_type": "code", "execution_count": 3, "id": "67ca2d37", "metadata": {}, "outputs": [], "source": [ "thunder_model = thunder.jit(m)" ] }, { "cell_type": "code", "execution_count": 4, "id": "964e2689", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "ThunderModule(\n", " (_model): LLaMAMLP(\n", " (fc_1): Linear(in_features=4096, out_features=11008, bias=False)\n", " (fc_2): Linear(in_features=4096, out_features=11008, bias=False)\n", " (proj): Linear(in_features=11008, out_features=4096, bias=False)\n", " )\n", ")" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "thunder_model" ] }, { "cell_type": "markdown", "id": "47d24f2d-0e89-4fe8-8154-9b50f2633e1b", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Our Thunder module computes (up to numerical accuracy) the same thing as our original model and for a small model like this, it also has approximately the same performance." ] }, { "cell_type": "code", "execution_count": 5, "id": "7f4de1b3", "metadata": { "slideshow": { "slide_type": "-" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "deviation: 1.4901161193847656e-07\n", "61.3 ms ± 106 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n", "62.1 ms ± 89.8 µs per loop (mean ± std. dev. 
of 7 runs, 10 loops each)\n" ] } ], "source": [ "x = torch.randn(2, 2048, 4096, device=\"cuda\")\n", "print('deviation:', (thunder_model(x) - m(x)).abs().max().item())\n", "\n", "%timeit thunder_model(x); torch.cuda.synchronize()\n", "%timeit m(x); torch.cuda.synchronize()" ] }, { "cell_type": "markdown", "id": "7996acc7-de20-4aa5-80f0-1ab6042e2650", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "So what has changed? Quite a bit!\n", "\n", "When we call the Thunder module, it does the computation in a single function without control flow. And what's more, it applies optimizations, such as creating fusions for NVFuser to execute. We can see all this by showing the last computation trace:" ] }, { "cell_type": "code", "execution_count": 6, "id": "a6f4b77c", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "# Constructed by Delete Last Used (took 0 milliseconds)\n", "import torch\n", "import torch.nn.functional\n", "from thunder.executors.torchex import no_autocast\n", "\n", "@torch.no_grad()\n", "@no_autocast\n", "def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight):\n", " # x: \"cuda:0 f32[2, 2048, 4096]\" \n", " # t_fc_1_weight: \"cuda:0 f32[11008, 4096]\" \n", " # t_fc_2_weight: \"cuda:0 f32[11008, 4096]\" \n", " # t_proj_weight: \"cuda:0 f32[4096, 11008]\" \n", " x_fc_1 = torch.nn.functional.linear(x, t_fc_1_weight, None) # x_fc_1: \"cuda:0 f32[2, 2048, 11008]\"\n", " # x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: \"cuda:0 f32[2, 2048, 11008]\"\n", " # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: \"cuda:0 f32[2, 2048, 11008]\"\n", " del t_fc_1_weight\n", " x_fc_2 = torch.nn.functional.linear(x, t_fc_2_weight, None) # x_fc_2: \"cuda:0 f32[2, 2048, 11008]\"\n", " # x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: \"cuda:0 f32[2, 2048, 11008]\"\n", " # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: \"cuda:0 f32[2, 2048, 11008]\"\n", " del x, t_fc_2_weight\n", " [result] = nvFusion0(x_fc_1, x_fc_2)\n", " # t9 = prims.neg(x_fc_1) # t9: \"cuda:0 f32[2, 2048, 11008]\"\n", " # t10 = prims.exp(t9) # t10: \"cuda:0 f32[2, 2048, 11008]\"\n", " # t11 = prims.add(1.0, t10) # t11: \"cuda:0 f32[2, 2048, 11008]\"\n", " # t12 = prims.reciprocal(t11) # t12: \"cuda:0 f32[2, 2048, 11008]\"\n", " # a = prims.mul(x_fc_1, t12) # a: \"cuda:0 f32[2, 2048, 11008]\"\n", " # result = prims.mul(a, x_fc_2) # result: \"cuda:0 f32[2, 2048, 11008]\"\n", " del x_fc_1, x_fc_2\n", " t18 = torch.nn.functional.linear(result, t_proj_weight, None) # t18: \"cuda:0 f32[2, 2048, 4096]\"\n", " # t18 = ltorch.linear(result, t_proj_weight, None) # t18: \"cuda:0 f32[2, 2048, 4096]\"\n", " # t18 = prims.linear(result, t_proj_weight, None) # t18: \"cuda:0 f32[2, 2048, 4096]\"\n", " del result, t_proj_weight\n", " return t18" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "thunder.last_traces(thunder_model)[-1]" ] }, { "cell_type": "markdown", "id": "2ef89186-70cd-4737-9695-ed282da2a56c", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "For more detail on what is going on in this trace:\n", "- Thunder has transformed the computation (more precisely, `m.__call__`) into a single function which has all the MLP parameters as arguments.\n", "- It has recorded the tensor metadata.\n", "- Operations have been mapped from the PyTorch functions to `thunder.torch` (aka `ltorch`) equivalents and decomposed into _primitive operations_.\n", "- The multiplication and activation (`x = 
torch.nn.functional.silu(x_fc_1) * x_fc_2`) have been put into one NVFuser fusion. (NVFuser is a particularly important one of many optimizations, and we make it easy to add your own.)\n", "- You can see how the parameters are obtained and the metadata is checked in the prologue - get it through `thunder.last_prologue_traces(thunder_model)[-1]`.\n", "\n", "You can actually see the series of traces: `last_traces` gives you a list of transformed traces in chronological order - for example, the initial trace `thunder.last_traces(thunder_model)[0]` does not have the fusion yet.\n" ] }, { "cell_type": "markdown", "id": "7749aed1", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Compiling a more complex model\n", "\n", "Obviously, we aim for larger models, so we can do the same with the entire Llama 2 (well, we use a smaller model here to go easy on our CI, but if you have a large GPU, just skip reducing the number of layers):\n", "\n", "**NOTE**: For running the cells below, we require `litgpt`, which can be installed with `pip install 'litgpt[all] @ git+https://github.com/Lightning-AI/litgpt'`. See [here](https://github.com/Lightning-AI/litgpt) to learn more about litgpt." ] }, { "cell_type": "code", "execution_count": 7, "id": "d53e0c43", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "GPT(\n", " (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n", " (transformer): ModuleDict(\n", " (wte): Embedding(32000, 4096)\n", " (h): ModuleList(\n", " (0-15): 16 x Block(\n", " (norm_1): RMSNorm()\n", " (attn): CausalSelfAttention(\n", " (attn): Linear(in_features=4096, out_features=12288, bias=False)\n", " (proj): Linear(in_features=4096, out_features=4096, bias=False)\n", " )\n", " (norm_2): RMSNorm()\n", " (mlp): LLaMAMLP(\n", " (fc_1): Linear(in_features=4096, out_features=11008, bias=False)\n", " (fc_2): Linear(in_features=4096, out_features=11008, bias=False)\n", " (proj): Linear(in_features=11008, out_features=4096, bias=False)\n", " )\n", " )\n", " )\n", " (ln_f): RMSNorm()\n", " )\n", ")" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from litgpt import GPT\n", "from thunder.tests.litgpt_model import Config\n", "cfg = Config.from_name('Llama-2-7b-hf')\n", "cfg.n_layer = 16 # fewer layers\n", "torch.set_default_dtype(torch.bfloat16)\n", "with torch.device('cuda'):\n", "    m = GPT(cfg)\n", "m\n" ] }, { "cell_type": "markdown", "id": "e536a4aa", "metadata": {}, "source": [ "Again, we jit our model and compare the output..." ] }, { "cell_type": "code", "execution_count": 8, "id": "36a7be96", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "deviation: 0.03125\n" ] } ], "source": [ "thunder_model = thunder.jit(m)\n", "\n", "inp = torch.randint(1, m.config.vocab_size, (1, 512), device=\"cuda\")\n", "\n", "actual = thunder_model(inp)\n", "expected = m(inp)\n", "\n", "print(\"deviation:\", (actual - expected).abs().max().item())\n" ] }, { "cell_type": "markdown", "id": "9947e8df-cd2d-447d-90b9-ee08bb5a9fb2", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "One thing to keep in mind here is that for bf16, the numerical accuracy impact of rearranging operations can be quite pronounced.\n", "\n", "Just like before, we can see the program it ran. It is a lot longer, though.\n"
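, "\n", "To put the size of this deviation into perspective, here is a small, purely illustrative aside (added here; nothing below depends on it) showing bf16's precision:\n", "\n", "```python\n", "import torch\n", "\n", "# bf16 stores a 7-bit mantissa, so representable values just above 1.0 are\n", "# spaced 2**-7 = 0.0078125 apart; after many reordered and fused bf16\n", "# operations, a deviation of a few times this spacing is unsurprising.\n", "print(torch.finfo(torch.bfloat16).eps)  # 0.0078125\n", "```"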
] }, { "cell_type": "code", "execution_count": 9, "id": "ac7e8bc9", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "text/plain": [ "# Constructed by Delete Last Used (took 10 milliseconds)\n", "import torch\n", "from torch import Tensor\n", "import torch.nn.functional\n", "from thunder.executors.torchex import no_autocast\n", "\n", "@torch.no_grad()\n", "@no_autocast\n", "def augmented_forward_fn(*args):\n", " # args: \"Collection\" \n", " t0, \\\n", " t1, \\\n", " t2, \\\n", " t3, \\\n", " t4, \\\n", " t5, \\\n", " t6, \\\n", " t7, \\\n", " t8, \\\n", " t9, \\\n", " t10, \\\n", " t11, \\\n", " t12, \\\n", " t13, \\\n", " t14, \\\n", " t15, \\\n", " t16, \\\n", " t17, \\\n", " t18, \\\n", " t19, \\\n", " t20, \\\n", " t21, \\\n", " t22, \\\n", " t23, \\\n", " t24, \\\n", " t25, \\\n", " t26, \\\n", " t27, \\\n", " t28, \\\n", " t29, \\\n", " t30, \\\n", " t31, \\\n", " t32, \\\n", " t33, \\\n", " t34, \\\n", " t35, \\\n", " t36, \\\n", " t37, \\\n", " t38, \\\n", " t39, \\\n", " t40, \\\n", " t41, \\\n", " t42, \\\n", " t43, \\\n", " t44, \\\n", " t45, \\\n", " t46, \\\n", " t47, \\\n", " t48, \\\n", " t49, \\\n", " t50, \\\n", " t51, \\\n", " t52, \\\n", " t53, \\\n", " t54, \\\n", " t55, \\\n", " t56, \\\n", " t57, \\\n", " t58, \\\n", " t59, \\\n", " t60, \\\n", " t61, \\\n", " t62, \\\n", " t63, \\\n", " t64, \\\n", " t65, \\\n", " t66, \\\n", " t67, \\\n", " t68, \\\n", " t69, \\\n", " t70, \\\n", " t71, \\\n", " t72, \\\n", " t73, \\\n", " t74, \\\n", " t75, \\\n", " t76, \\\n", " t77, \\\n", " t78, \\\n", " t79, \\\n", " t80, \\\n", " t81, \\\n", " t82, \\\n", " t83, \\\n", " t84, \\\n", " t85, \\\n", " t86, \\\n", " t87, \\\n", " t88, \\\n", " t89, \\\n", " t90, \\\n", " t91, \\\n", " t92, \\\n", " t93, \\\n", " t94, \\\n", " t95, \\\n", " t96, \\\n", " t97, \\\n", " t98, \\\n", " t99, \\\n", " t100, \\\n", " t101, \\\n", " t102, \\\n", " t103, \\\n", " t104, \\\n", " t105, \\\n", " t106, \\\n", " t107, \\\n", " t108, \\\n", " t109, \\\n", " t110, \\\n", " t111, \\\n", " t112, \\\n", " t113, \\\n", " t114, \\\n", " t115, \\\n", " t116, \\\n", " t117, \\\n", " = args\n", " del args\n", " t122 = torch.nn.functional.embedding(t0, t117, None, None, 2.0, False, False) # t122: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t122 = ltorch.embedding(t0, t117, None, None, 2.0, False, False) # t122: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1867 = ltorch.reshape(t0, [512]) # t1867: \"cuda:0 i64[512]\"\n", " # t1867 = prims.reshape(t0, (512,)) # t1867: \"cuda:0 i64[512]\"\n", " # t1868 = prims.take(t117, t1867, 0) # t1868: \"cuda:0 bf16[512, 4096]\"\n", " # t122 = ltorch.reshape(t1868, [1, 512, 4096]) # t122: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t122 = prims.reshape(t1868, (1, 512, 4096)) # t122: \"cuda:0 bf16[1, 512, 4096]\"\n", " t118 = torch_slice_prim_impl(t1, [0, 0], [512, 128], [1, 1]) # t118: \"cuda:0 f32[512, 128]\"\n", " t119 = torch_slice_prim_impl(t2, [0, 0], [512, 128], [1, 1]) # t119: \"cuda:0 f32[512, 128]\"\n", " t2015 = torch.unsqueeze(t53, 0) # t2015: \"cuda:0 bf16[1, 4096]\"\n", " # t2015 = ltorch.unsqueeze(t53, 0) # t2015: \"cuda:0 bf16[1, 4096]\"\n", " # t2015 = prims.broadcast_in_dim(t53, [1, 4096], [1]) # t2015: \"cuda:0 bf16[1, 4096]\"\n", " t2016 = torch.unsqueeze(t2015, 1) # t2016: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2016 = ltorch.unsqueeze(t2015, 1) # t2016: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2016 = prims.broadcast_in_dim(t2015, [1, 1, 4096], [0, 2]) # t2016: \"cuda:0 bf16[1, 1, 4096]\"\n", " del 
t2015\n", " t133 = Tensor.expand(t2016, (1, 512, 4096)) # t133: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t133 = ltorch.expand(t2016, (1, 512, 4096)) # t133: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t133 = prims.broadcast_in_dim(t2016, (1, 512, 4096), (0, 1, 2)) # t133: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2016\n", " t2356 = torch.unsqueeze(t82, 0) # t2356: \"cuda:0 bf16[1, 4096]\"\n", " # t2356 = ltorch.unsqueeze(t82, 0) # t2356: \"cuda:0 bf16[1, 4096]\"\n", " # t2356 = prims.broadcast_in_dim(t82, [1, 4096], [1]) # t2356: \"cuda:0 bf16[1, 4096]\"\n", " t2357 = torch.unsqueeze(t2356, 1) # t2357: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2357 = ltorch.unsqueeze(t2356, 1) # t2357: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2357 = prims.broadcast_in_dim(t2356, [1, 1, 4096], [0, 2]) # t2357: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2356\n", " t1609 = Tensor.expand(t2357, (1, 512, 4096)) # t1609: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1609 = ltorch.expand(t2357, (1, 512, 4096)) # t1609: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1609 = prims.broadcast_in_dim(t2357, (1, 512, 4096), (0, 1, 2)) # t1609: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2357\n", " t2359 = torch.unsqueeze(t58, 0) # t2359: \"cuda:0 bf16[1, 4096]\"\n", " # t2359 = ltorch.unsqueeze(t58, 0) # t2359: \"cuda:0 bf16[1, 4096]\"\n", " # t2359 = prims.broadcast_in_dim(t58, [1, 4096], [1]) # t2359: \"cuda:0 bf16[1, 4096]\"\n", " t2360 = torch.unsqueeze(t2359, 1) # t2360: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2360 = ltorch.unsqueeze(t2359, 1) # t2360: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2360 = prims.broadcast_in_dim(t2359, [1, 1, 4096], [0, 2]) # t2360: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2359\n", " t1645 = Tensor.expand(t2360, (1, 512, 4096)) # t1645: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1645 = ltorch.expand(t2360, (1, 512, 4096)) # t1645: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1645 = prims.broadcast_in_dim(t2360, (1, 512, 4096), (0, 1, 2)) # t1645: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2360\n", " t2044 = torch.unsqueeze(t69, 0) # t2044: \"cuda:0 bf16[1, 4096]\"\n", " # t2044 = ltorch.unsqueeze(t69, 0) # t2044: \"cuda:0 bf16[1, 4096]\"\n", " # t2044 = prims.broadcast_in_dim(t69, [1, 4096], [1]) # t2044: \"cuda:0 bf16[1, 4096]\"\n", " t2045 = torch.unsqueeze(t2044, 1) # t2045: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2045 = ltorch.unsqueeze(t2044, 1) # t2045: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2045 = prims.broadcast_in_dim(t2044, [1, 1, 4096], [0, 2]) # t2045: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2044\n", " t205 = Tensor.expand(t2045, (1, 512, 4096)) # t205: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t205 = ltorch.expand(t2045, (1, 512, 4096)) # t205: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t205 = prims.broadcast_in_dim(t2045, (1, 512, 4096), (0, 1, 2)) # t205: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2045\n", " t2380 = torch.unsqueeze(t83, 0) # t2380: \"cuda:0 bf16[1, 4096]\"\n", " # t2380 = ltorch.unsqueeze(t83, 0) # t2380: \"cuda:0 bf16[1, 4096]\"\n", " # t2380 = prims.broadcast_in_dim(t83, [1, 4096], [1]) # t2380: \"cuda:0 bf16[1, 4096]\"\n", " t2381 = torch.unsqueeze(t2380, 1) # t2381: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2381 = ltorch.unsqueeze(t2380, 1) # t2381: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2381 = prims.broadcast_in_dim(t2380, [1, 1, 4096], [0, 2]) # t2381: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2380\n", " t1717 = Tensor.expand(t2381, (1, 512, 4096)) # t1717: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1717 = ltorch.expand(t2381, (1, 512, 4096)) # t1717: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1717 = prims.broadcast_in_dim(t2381, (1, 512, 4096), 
(0, 1, 2)) # t1717: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2381\n", " t2047 = torch.unsqueeze(t60, 0) # t2047: \"cuda:0 bf16[1, 4096]\"\n", " # t2047 = ltorch.unsqueeze(t60, 0) # t2047: \"cuda:0 bf16[1, 4096]\"\n", " # t2047 = prims.broadcast_in_dim(t60, [1, 4096], [1]) # t2047: \"cuda:0 bf16[1, 4096]\"\n", " t2048 = torch.unsqueeze(t2047, 1) # t2048: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2048 = ltorch.unsqueeze(t2047, 1) # t2048: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2048 = prims.broadcast_in_dim(t2047, [1, 1, 4096], [0, 2]) # t2048: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2047\n", " t241 = Tensor.expand(t2048, (1, 512, 4096)) # t241: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t241 = ltorch.expand(t2048, (1, 512, 4096)) # t241: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t241 = prims.broadcast_in_dim(t2048, (1, 512, 4096), (0, 1, 2)) # t241: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2048\n", " t2383 = torch.unsqueeze(t59, 0) # t2383: \"cuda:0 bf16[1, 4096]\"\n", " # t2383 = ltorch.unsqueeze(t59, 0) # t2383: \"cuda:0 bf16[1, 4096]\"\n", " # t2383 = prims.broadcast_in_dim(t59, [1, 4096], [1]) # t2383: \"cuda:0 bf16[1, 4096]\"\n", " t2384 = torch.unsqueeze(t2383, 1) # t2384: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2384 = ltorch.unsqueeze(t2383, 1) # t2384: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2384 = prims.broadcast_in_dim(t2383, [1, 1, 4096], [0, 2]) # t2384: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2383\n", " t1753 = Tensor.expand(t2384, (1, 512, 4096)) # t1753: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1753 = ltorch.expand(t2384, (1, 512, 4096)) # t1753: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1753 = prims.broadcast_in_dim(t2384, (1, 512, 4096), (0, 1, 2)) # t1753: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2384\n", " t2068 = torch.unsqueeze(t70, 0) # t2068: \"cuda:0 bf16[1, 4096]\"\n", " # t2068 = ltorch.unsqueeze(t70, 0) # t2068: \"cuda:0 bf16[1, 4096]\"\n", " # t2068 = prims.broadcast_in_dim(t70, [1, 4096], [1]) # t2068: \"cuda:0 bf16[1, 4096]\"\n", " t2069 = torch.unsqueeze(t2068, 1) # t2069: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2069 = ltorch.unsqueeze(t2068, 1) # t2069: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2069 = prims.broadcast_in_dim(t2068, [1, 1, 4096], [0, 2]) # t2069: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2068\n", " t313 = Tensor.expand(t2069, (1, 512, 4096)) # t313: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t313 = ltorch.expand(t2069, (1, 512, 4096)) # t313: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t313 = prims.broadcast_in_dim(t2069, (1, 512, 4096), (0, 1, 2)) # t313: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2069\n", " t2404 = torch.unsqueeze(t84, 0) # t2404: \"cuda:0 bf16[1, 4096]\"\n", " # t2404 = ltorch.unsqueeze(t84, 0) # t2404: \"cuda:0 bf16[1, 4096]\"\n", " # t2404 = prims.broadcast_in_dim(t84, [1, 4096], [1]) # t2404: \"cuda:0 bf16[1, 4096]\"\n", " t2405 = torch.unsqueeze(t2404, 1) # t2405: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2405 = ltorch.unsqueeze(t2404, 1) # t2405: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2405 = prims.broadcast_in_dim(t2404, [1, 1, 4096], [0, 2]) # t2405: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2404\n", " t1825 = Tensor.expand(t2405, (1, 512, 4096)) # t1825: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1825 = ltorch.expand(t2405, (1, 512, 4096)) # t1825: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1825 = prims.broadcast_in_dim(t2405, (1, 512, 4096), (0, 1, 2)) # t1825: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2405\n", " t2071 = torch.unsqueeze(t61, 0) # t2071: \"cuda:0 bf16[1, 4096]\"\n", " # t2071 = ltorch.unsqueeze(t61, 0) # t2071: \"cuda:0 bf16[1, 4096]\"\n", " # t2071 = 
prims.broadcast_in_dim(t61, [1, 4096], [1]) # t2071: \"cuda:0 bf16[1, 4096]\"\n", " t2072 = torch.unsqueeze(t2071, 1) # t2072: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2072 = ltorch.unsqueeze(t2071, 1) # t2072: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2072 = prims.broadcast_in_dim(t2071, [1, 1, 4096], [0, 2]) # t2072: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2071\n", " t349 = Tensor.expand(t2072, (1, 512, 4096)) # t349: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t349 = ltorch.expand(t2072, (1, 512, 4096)) # t349: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t349 = prims.broadcast_in_dim(t2072, (1, 512, 4096), (0, 1, 2)) # t349: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2072\n", " t2407 = torch.unsqueeze(t52, 0) # t2407: \"cuda:0 bf16[1, 4096]\"\n", " # t2407 = ltorch.unsqueeze(t52, 0) # t2407: \"cuda:0 bf16[1, 4096]\"\n", " # t2407 = prims.broadcast_in_dim(t52, [1, 4096], [1]) # t2407: \"cuda:0 bf16[1, 4096]\"\n", " t2408 = torch.unsqueeze(t2407, 1) # t2408: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2408 = ltorch.unsqueeze(t2407, 1) # t2408: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2408 = prims.broadcast_in_dim(t2407, [1, 1, 4096], [0, 2]) # t2408: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2407\n", " t1861 = Tensor.expand(t2408, (1, 512, 4096)) # t1861: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1861 = ltorch.expand(t2408, (1, 512, 4096)) # t1861: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1861 = prims.broadcast_in_dim(t2408, (1, 512, 4096), (0, 1, 2)) # t1861: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2408\n", " t2095 = torch.unsqueeze(t62, 0) # t2095: \"cuda:0 bf16[1, 4096]\"\n", " # t2095 = ltorch.unsqueeze(t62, 0) # t2095: \"cuda:0 bf16[1, 4096]\"\n", " # t2095 = prims.broadcast_in_dim(t62, [1, 4096], [1]) # t2095: \"cuda:0 bf16[1, 4096]\"\n", " t2096 = torch.unsqueeze(t2095, 1) # t2096: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2096 = ltorch.unsqueeze(t2095, 1) # t2096: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2096 = prims.broadcast_in_dim(t2095, [1, 1, 4096], [0, 2]) # t2096: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2095\n", " t457 = Tensor.expand(t2096, (1, 512, 4096)) # t457: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t457 = ltorch.expand(t2096, (1, 512, 4096)) # t457: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t457 = prims.broadcast_in_dim(t2096, (1, 512, 4096), (0, 1, 2)) # t457: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2096\n", " t2092 = torch.unsqueeze(t71, 0) # t2092: \"cuda:0 bf16[1, 4096]\"\n", " # t2092 = ltorch.unsqueeze(t71, 0) # t2092: \"cuda:0 bf16[1, 4096]\"\n", " # t2092 = prims.broadcast_in_dim(t71, [1, 4096], [1]) # t2092: \"cuda:0 bf16[1, 4096]\"\n", " t2093 = torch.unsqueeze(t2092, 1) # t2093: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2093 = ltorch.unsqueeze(t2092, 1) # t2093: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2093 = prims.broadcast_in_dim(t2092, [1, 1, 4096], [0, 2]) # t2093: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2092\n", " t421 = Tensor.expand(t2093, (1, 512, 4096)) # t421: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t421 = ltorch.expand(t2093, (1, 512, 4096)) # t421: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t421 = prims.broadcast_in_dim(t2093, (1, 512, 4096), (0, 1, 2)) # t421: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2093\n", " t2116 = torch.unsqueeze(t72, 0) # t2116: \"cuda:0 bf16[1, 4096]\"\n", " # t2116 = ltorch.unsqueeze(t72, 0) # t2116: \"cuda:0 bf16[1, 4096]\"\n", " # t2116 = prims.broadcast_in_dim(t72, [1, 4096], [1]) # t2116: \"cuda:0 bf16[1, 4096]\"\n", " t2117 = torch.unsqueeze(t2116, 1) # t2117: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2117 = ltorch.unsqueeze(t2116, 1) # t2117: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2117 = 
prims.broadcast_in_dim(t2116, [1, 1, 4096], [0, 2]) # t2117: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2116\n", " t529 = Tensor.expand(t2117, (1, 512, 4096)) # t529: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t529 = ltorch.expand(t2117, (1, 512, 4096)) # t529: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t529 = prims.broadcast_in_dim(t2117, (1, 512, 4096), (0, 1, 2)) # t529: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2117\n", " t2119 = torch.unsqueeze(t63, 0) # t2119: \"cuda:0 bf16[1, 4096]\"\n", " # t2119 = ltorch.unsqueeze(t63, 0) # t2119: \"cuda:0 bf16[1, 4096]\"\n", " # t2119 = prims.broadcast_in_dim(t63, [1, 4096], [1]) # t2119: \"cuda:0 bf16[1, 4096]\"\n", " t2120 = torch.unsqueeze(t2119, 1) # t2120: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2120 = ltorch.unsqueeze(t2119, 1) # t2120: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2120 = prims.broadcast_in_dim(t2119, [1, 1, 4096], [0, 2]) # t2120: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2119\n", " t565 = Tensor.expand(t2120, (1, 512, 4096)) # t565: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t565 = ltorch.expand(t2120, (1, 512, 4096)) # t565: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t565 = prims.broadcast_in_dim(t2120, (1, 512, 4096), (0, 1, 2)) # t565: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2120\n", " t2140 = torch.unsqueeze(t73, 0) # t2140: \"cuda:0 bf16[1, 4096]\"\n", " # t2140 = ltorch.unsqueeze(t73, 0) # t2140: \"cuda:0 bf16[1, 4096]\"\n", " # t2140 = prims.broadcast_in_dim(t73, [1, 4096], [1]) # t2140: \"cuda:0 bf16[1, 4096]\"\n", " t2141 = torch.unsqueeze(t2140, 1) # t2141: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2141 = ltorch.unsqueeze(t2140, 1) # t2141: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2141 = prims.broadcast_in_dim(t2140, [1, 1, 4096], [0, 2]) # t2141: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2140\n", " t637 = Tensor.expand(t2141, (1, 512, 4096)) # t637: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t637 = ltorch.expand(t2141, (1, 512, 4096)) # t637: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t637 = prims.broadcast_in_dim(t2141, (1, 512, 4096), (0, 1, 2)) # t637: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2141\n", " t2143 = torch.unsqueeze(t64, 0) # t2143: \"cuda:0 bf16[1, 4096]\"\n", " # t2143 = ltorch.unsqueeze(t64, 0) # t2143: \"cuda:0 bf16[1, 4096]\"\n", " # t2143 = prims.broadcast_in_dim(t64, [1, 4096], [1]) # t2143: \"cuda:0 bf16[1, 4096]\"\n", " t2144 = torch.unsqueeze(t2143, 1) # t2144: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2144 = ltorch.unsqueeze(t2143, 1) # t2144: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2144 = prims.broadcast_in_dim(t2143, [1, 1, 4096], [0, 2]) # t2144: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2143\n", " t673 = Tensor.expand(t2144, (1, 512, 4096)) # t673: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t673 = ltorch.expand(t2144, (1, 512, 4096)) # t673: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t673 = prims.broadcast_in_dim(t2144, (1, 512, 4096), (0, 1, 2)) # t673: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2144\n", " t2164 = torch.unsqueeze(t74, 0) # t2164: \"cuda:0 bf16[1, 4096]\"\n", " # t2164 = ltorch.unsqueeze(t74, 0) # t2164: \"cuda:0 bf16[1, 4096]\"\n", " # t2164 = prims.broadcast_in_dim(t74, [1, 4096], [1]) # t2164: \"cuda:0 bf16[1, 4096]\"\n", " t2165 = torch.unsqueeze(t2164, 1) # t2165: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2165 = ltorch.unsqueeze(t2164, 1) # t2165: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2165 = prims.broadcast_in_dim(t2164, [1, 1, 4096], [0, 2]) # t2165: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2164\n", " t745 = Tensor.expand(t2165, (1, 512, 4096)) # t745: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t745 = ltorch.expand(t2165, (1, 512, 4096)) # t745: \"cuda:0 
bf16[1, 512, 4096]\"\n", " # t745 = prims.broadcast_in_dim(t2165, (1, 512, 4096), (0, 1, 2)) # t745: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2165\n", " t2167 = torch.unsqueeze(t65, 0) # t2167: \"cuda:0 bf16[1, 4096]\"\n", " # t2167 = ltorch.unsqueeze(t65, 0) # t2167: \"cuda:0 bf16[1, 4096]\"\n", " # t2167 = prims.broadcast_in_dim(t65, [1, 4096], [1]) # t2167: \"cuda:0 bf16[1, 4096]\"\n", " t2168 = torch.unsqueeze(t2167, 1) # t2168: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2168 = ltorch.unsqueeze(t2167, 1) # t2168: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2168 = prims.broadcast_in_dim(t2167, [1, 1, 4096], [0, 2]) # t2168: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2167\n", " t781 = Tensor.expand(t2168, (1, 512, 4096)) # t781: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t781 = ltorch.expand(t2168, (1, 512, 4096)) # t781: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t781 = prims.broadcast_in_dim(t2168, (1, 512, 4096), (0, 1, 2)) # t781: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2168\n", " t2188 = torch.unsqueeze(t75, 0) # t2188: \"cuda:0 bf16[1, 4096]\"\n", " # t2188 = ltorch.unsqueeze(t75, 0) # t2188: \"cuda:0 bf16[1, 4096]\"\n", " # t2188 = prims.broadcast_in_dim(t75, [1, 4096], [1]) # t2188: \"cuda:0 bf16[1, 4096]\"\n", " t2189 = torch.unsqueeze(t2188, 1) # t2189: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2189 = ltorch.unsqueeze(t2188, 1) # t2189: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2189 = prims.broadcast_in_dim(t2188, [1, 1, 4096], [0, 2]) # t2189: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2188\n", " t853 = Tensor.expand(t2189, (1, 512, 4096)) # t853: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t853 = ltorch.expand(t2189, (1, 512, 4096)) # t853: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t853 = prims.broadcast_in_dim(t2189, (1, 512, 4096), (0, 1, 2)) # t853: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2189\n", " t2191 = torch.unsqueeze(t66, 0) # t2191: \"cuda:0 bf16[1, 4096]\"\n", " # t2191 = ltorch.unsqueeze(t66, 0) # t2191: \"cuda:0 bf16[1, 4096]\"\n", " # t2191 = prims.broadcast_in_dim(t66, [1, 4096], [1]) # t2191: \"cuda:0 bf16[1, 4096]\"\n", " t2192 = torch.unsqueeze(t2191, 1) # t2192: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2192 = ltorch.unsqueeze(t2191, 1) # t2192: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2192 = prims.broadcast_in_dim(t2191, [1, 1, 4096], [0, 2]) # t2192: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2191\n", " t889 = Tensor.expand(t2192, (1, 512, 4096)) # t889: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t889 = ltorch.expand(t2192, (1, 512, 4096)) # t889: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t889 = prims.broadcast_in_dim(t2192, (1, 512, 4096), (0, 1, 2)) # t889: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2192\n", " t2212 = torch.unsqueeze(t76, 0) # t2212: \"cuda:0 bf16[1, 4096]\"\n", " # t2212 = ltorch.unsqueeze(t76, 0) # t2212: \"cuda:0 bf16[1, 4096]\"\n", " # t2212 = prims.broadcast_in_dim(t76, [1, 4096], [1]) # t2212: \"cuda:0 bf16[1, 4096]\"\n", " t2213 = torch.unsqueeze(t2212, 1) # t2213: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2213 = ltorch.unsqueeze(t2212, 1) # t2213: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2213 = prims.broadcast_in_dim(t2212, [1, 1, 4096], [0, 2]) # t2213: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2212\n", " t961 = Tensor.expand(t2213, (1, 512, 4096)) # t961: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t961 = ltorch.expand(t2213, (1, 512, 4096)) # t961: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t961 = prims.broadcast_in_dim(t2213, (1, 512, 4096), (0, 1, 2)) # t961: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2213\n", " t2215 = torch.unsqueeze(t67, 0) # t2215: \"cuda:0 bf16[1, 4096]\"\n", " # t2215 = ltorch.unsqueeze(t67, 0) # 
t2215: \"cuda:0 bf16[1, 4096]\"\n", " # t2215 = prims.broadcast_in_dim(t67, [1, 4096], [1]) # t2215: \"cuda:0 bf16[1, 4096]\"\n", " t2216 = torch.unsqueeze(t2215, 1) # t2216: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2216 = ltorch.unsqueeze(t2215, 1) # t2216: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2216 = prims.broadcast_in_dim(t2215, [1, 1, 4096], [0, 2]) # t2216: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2215\n", " t997 = Tensor.expand(t2216, (1, 512, 4096)) # t997: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t997 = ltorch.expand(t2216, (1, 512, 4096)) # t997: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t997 = prims.broadcast_in_dim(t2216, (1, 512, 4096), (0, 1, 2)) # t997: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2216\n", " t2236 = torch.unsqueeze(t77, 0) # t2236: \"cuda:0 bf16[1, 4096]\"\n", " # t2236 = ltorch.unsqueeze(t77, 0) # t2236: \"cuda:0 bf16[1, 4096]\"\n", " # t2236 = prims.broadcast_in_dim(t77, [1, 4096], [1]) # t2236: \"cuda:0 bf16[1, 4096]\"\n", " t2237 = torch.unsqueeze(t2236, 1) # t2237: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2237 = ltorch.unsqueeze(t2236, 1) # t2237: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2237 = prims.broadcast_in_dim(t2236, [1, 1, 4096], [0, 2]) # t2237: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2236\n", " t1069 = Tensor.expand(t2237, (1, 512, 4096)) # t1069: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1069 = ltorch.expand(t2237, (1, 512, 4096)) # t1069: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1069 = prims.broadcast_in_dim(t2237, (1, 512, 4096), (0, 1, 2)) # t1069: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2237\n", " t2239 = torch.unsqueeze(t68, 0) # t2239: \"cuda:0 bf16[1, 4096]\"\n", " # t2239 = ltorch.unsqueeze(t68, 0) # t2239: \"cuda:0 bf16[1, 4096]\"\n", " # t2239 = prims.broadcast_in_dim(t68, [1, 4096], [1]) # t2239: \"cuda:0 bf16[1, 4096]\"\n", " t2240 = torch.unsqueeze(t2239, 1) # t2240: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2240 = ltorch.unsqueeze(t2239, 1) # t2240: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2240 = prims.broadcast_in_dim(t2239, [1, 1, 4096], [0, 2]) # t2240: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2239\n", " t1105 = Tensor.expand(t2240, (1, 512, 4096)) # t1105: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1105 = ltorch.expand(t2240, (1, 512, 4096)) # t1105: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1105 = prims.broadcast_in_dim(t2240, (1, 512, 4096), (0, 1, 2)) # t1105: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2240\n", " t2260 = torch.unsqueeze(t78, 0) # t2260: \"cuda:0 bf16[1, 4096]\"\n", " # t2260 = ltorch.unsqueeze(t78, 0) # t2260: \"cuda:0 bf16[1, 4096]\"\n", " # t2260 = prims.broadcast_in_dim(t78, [1, 4096], [1]) # t2260: \"cuda:0 bf16[1, 4096]\"\n", " t2261 = torch.unsqueeze(t2260, 1) # t2261: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2261 = ltorch.unsqueeze(t2260, 1) # t2261: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2261 = prims.broadcast_in_dim(t2260, [1, 1, 4096], [0, 2]) # t2261: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2260\n", " t1177 = Tensor.expand(t2261, (1, 512, 4096)) # t1177: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1177 = ltorch.expand(t2261, (1, 512, 4096)) # t1177: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1177 = prims.broadcast_in_dim(t2261, (1, 512, 4096), (0, 1, 2)) # t1177: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2261\n", " t2263 = torch.unsqueeze(t54, 0) # t2263: \"cuda:0 bf16[1, 4096]\"\n", " # t2263 = ltorch.unsqueeze(t54, 0) # t2263: \"cuda:0 bf16[1, 4096]\"\n", " # t2263 = prims.broadcast_in_dim(t54, [1, 4096], [1]) # t2263: \"cuda:0 bf16[1, 4096]\"\n", " t2264 = torch.unsqueeze(t2263, 1) # t2264: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2264 = ltorch.unsqueeze(t2263, 
1) # t2264: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2264 = prims.broadcast_in_dim(t2263, [1, 1, 4096], [0, 2]) # t2264: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2263\n", " t1213 = Tensor.expand(t2264, (1, 512, 4096)) # t1213: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1213 = ltorch.expand(t2264, (1, 512, 4096)) # t1213: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1213 = prims.broadcast_in_dim(t2264, (1, 512, 4096), (0, 1, 2)) # t1213: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2264\n", " t2284 = torch.unsqueeze(t79, 0) # t2284: \"cuda:0 bf16[1, 4096]\"\n", " # t2284 = ltorch.unsqueeze(t79, 0) # t2284: \"cuda:0 bf16[1, 4096]\"\n", " # t2284 = prims.broadcast_in_dim(t79, [1, 4096], [1]) # t2284: \"cuda:0 bf16[1, 4096]\"\n", " t2285 = torch.unsqueeze(t2284, 1) # t2285: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2285 = ltorch.unsqueeze(t2284, 1) # t2285: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2285 = prims.broadcast_in_dim(t2284, [1, 1, 4096], [0, 2]) # t2285: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2284\n", " t1285 = Tensor.expand(t2285, (1, 512, 4096)) # t1285: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1285 = ltorch.expand(t2285, (1, 512, 4096)) # t1285: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1285 = prims.broadcast_in_dim(t2285, (1, 512, 4096), (0, 1, 2)) # t1285: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2285\n", " t2287 = torch.unsqueeze(t55, 0) # t2287: \"cuda:0 bf16[1, 4096]\"\n", " # t2287 = ltorch.unsqueeze(t55, 0) # t2287: \"cuda:0 bf16[1, 4096]\"\n", " # t2287 = prims.broadcast_in_dim(t55, [1, 4096], [1]) # t2287: \"cuda:0 bf16[1, 4096]\"\n", " t2288 = torch.unsqueeze(t2287, 1) # t2288: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2288 = ltorch.unsqueeze(t2287, 1) # t2288: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2288 = prims.broadcast_in_dim(t2287, [1, 1, 4096], [0, 2]) # t2288: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2287\n", " t1321 = Tensor.expand(t2288, (1, 512, 4096)) # t1321: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1321 = ltorch.expand(t2288, (1, 512, 4096)) # t1321: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1321 = prims.broadcast_in_dim(t2288, (1, 512, 4096), (0, 1, 2)) # t1321: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2288\n", " t2308 = torch.unsqueeze(t80, 0) # t2308: \"cuda:0 bf16[1, 4096]\"\n", " # t2308 = ltorch.unsqueeze(t80, 0) # t2308: \"cuda:0 bf16[1, 4096]\"\n", " # t2308 = prims.broadcast_in_dim(t80, [1, 4096], [1]) # t2308: \"cuda:0 bf16[1, 4096]\"\n", " t2309 = torch.unsqueeze(t2308, 1) # t2309: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2309 = ltorch.unsqueeze(t2308, 1) # t2309: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2309 = prims.broadcast_in_dim(t2308, [1, 1, 4096], [0, 2]) # t2309: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2308\n", " t1393 = Tensor.expand(t2309, (1, 512, 4096)) # t1393: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1393 = ltorch.expand(t2309, (1, 512, 4096)) # t1393: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1393 = prims.broadcast_in_dim(t2309, (1, 512, 4096), (0, 1, 2)) # t1393: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2309\n", " t2311 = torch.unsqueeze(t56, 0) # t2311: \"cuda:0 bf16[1, 4096]\"\n", " # t2311 = ltorch.unsqueeze(t56, 0) # t2311: \"cuda:0 bf16[1, 4096]\"\n", " # t2311 = prims.broadcast_in_dim(t56, [1, 4096], [1]) # t2311: \"cuda:0 bf16[1, 4096]\"\n", " t2312 = torch.unsqueeze(t2311, 1) # t2312: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2312 = ltorch.unsqueeze(t2311, 1) # t2312: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2312 = prims.broadcast_in_dim(t2311, [1, 1, 4096], [0, 2]) # t2312: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2311\n", " t1429 = Tensor.expand(t2312, (1, 512, 4096)) # t1429: \"cuda:0 bf16[1, 
512, 4096]\"\n", " # t1429 = ltorch.expand(t2312, (1, 512, 4096)) # t1429: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1429 = prims.broadcast_in_dim(t2312, (1, 512, 4096), (0, 1, 2)) # t1429: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2312\n", " t2332 = torch.unsqueeze(t81, 0) # t2332: \"cuda:0 bf16[1, 4096]\"\n", " # t2332 = ltorch.unsqueeze(t81, 0) # t2332: \"cuda:0 bf16[1, 4096]\"\n", " # t2332 = prims.broadcast_in_dim(t81, [1, 4096], [1]) # t2332: \"cuda:0 bf16[1, 4096]\"\n", " t2333 = torch.unsqueeze(t2332, 1) # t2333: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2333 = ltorch.unsqueeze(t2332, 1) # t2333: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2333 = prims.broadcast_in_dim(t2332, [1, 1, 4096], [0, 2]) # t2333: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2332\n", " t1501 = Tensor.expand(t2333, (1, 512, 4096)) # t1501: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1501 = ltorch.expand(t2333, (1, 512, 4096)) # t1501: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1501 = prims.broadcast_in_dim(t2333, (1, 512, 4096), (0, 1, 2)) # t1501: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2333\n", " t2335 = torch.unsqueeze(t57, 0) # t2335: \"cuda:0 bf16[1, 4096]\"\n", " # t2335 = ltorch.unsqueeze(t57, 0) # t2335: \"cuda:0 bf16[1, 4096]\"\n", " # t2335 = prims.broadcast_in_dim(t57, [1, 4096], [1]) # t2335: \"cuda:0 bf16[1, 4096]\"\n", " t2336 = torch.unsqueeze(t2335, 1) # t2336: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2336 = ltorch.unsqueeze(t2335, 1) # t2336: \"cuda:0 bf16[1, 1, 4096]\"\n", " # t2336 = prims.broadcast_in_dim(t2335, [1, 1, 4096], [0, 2]) # t2336: \"cuda:0 bf16[1, 1, 4096]\"\n", " del t2335\n", " t1537 = Tensor.expand(t2336, (1, 512, 4096)) # t1537: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1537 = ltorch.expand(t2336, (1, 512, 4096)) # t1537: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1537 = prims.broadcast_in_dim(t2336, (1, 512, 4096), (0, 1, 2)) # t1537: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t2336\n", " t2036 = torch.unsqueeze(t118, 0) # t2036: \"cuda:0 f32[1, 512, 128]\"\n", " # t2036 = ltorch.unsqueeze(t118, 0) # t2036: \"cuda:0 f32[1, 512, 128]\"\n", " # t2036 = prims.broadcast_in_dim(t118, [1, 512, 128], [1, 2]) # t2036: \"cuda:0 f32[1, 512, 128]\"\n", " del t118\n", " t2037 = torch.unsqueeze(t2036, 1) # t2037: \"cuda:0 f32[1, 1, 512, 128]\"\n", " # t2037 = ltorch.unsqueeze(t2036, 1) # t2037: \"cuda:0 f32[1, 1, 512, 128]\"\n", " # t2037 = prims.broadcast_in_dim(t2036, [1, 1, 512, 128], [0, 2, 3]) # t2037: \"cuda:0 f32[1, 1, 512, 128]\"\n", " del t2036\n", " t154 = Tensor.expand(t2037, (1, 32, 512, 128)) # t154: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t154 = ltorch.expand(t2037, (1, 32, 512, 128)) # t154: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t154 = prims.broadcast_in_dim(t2037, (1, 32, 512, 128), (0, 1, 2, 3)) # t154: \"cuda:0 f32[1, 32, 512, 128]\"\n", " del t2037\n", " t2039 = torch.unsqueeze(t119, 0) # t2039: \"cuda:0 f32[1, 512, 128]\"\n", " # t2039 = ltorch.unsqueeze(t119, 0) # t2039: \"cuda:0 f32[1, 512, 128]\"\n", " # t2039 = prims.broadcast_in_dim(t119, [1, 512, 128], [1, 2]) # t2039: \"cuda:0 f32[1, 512, 128]\"\n", " del t119\n", " t2040 = torch.unsqueeze(t2039, 1) # t2040: \"cuda:0 f32[1, 1, 512, 128]\"\n", " # t2040 = ltorch.unsqueeze(t2039, 1) # t2040: \"cuda:0 f32[1, 1, 512, 128]\"\n", " # t2040 = prims.broadcast_in_dim(t2039, [1, 1, 512, 128], [0, 2, 3]) # t2040: \"cuda:0 f32[1, 1, 512, 128]\"\n", " del t2039\n", " t157 = Tensor.expand(t2040, (1, 32, 512, 128)) # t157: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t157 = ltorch.expand(t2040, (1, 32, 512, 128)) # t157: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t157 = 
prims.broadcast_in_dim(t2040, (1, 32, 512, 128), (0, 1, 2, 3)) # t157: \"cuda:0 f32[1, 32, 512, 128]\"\n", " del t2040\n", " [t129, t137] = nvFusion0(t122, t133)\n", " # t123 = prims.convert_element_type(t122, dtypes.float32) # t123: \"cuda:0 f32[1, 512, 4096]\"\n", " # t124 = prims.mul(t123, t123) # t124: \"cuda:0 f32[1, 512, 4096]\"\n", " # t125 = prims.sum(t124, (2,)) # t125: \"cuda:0 f32[1, 512]\"\n", " # t126 = prims.broadcast_in_dim(t125, [1, 512, 1], [0, 1]) # t126: \"cuda:0 f32[1, 512, 1]\"\n", " # t127 = prims.div(t126, 4096.0) # t127: \"cuda:0 f32[1, 512, 1]\"\n", " # t128 = prims.add(t127, 1e-05) # t128: \"cuda:0 f32[1, 512, 1]\"\n", " # t129 = prims.rsqrt(t128) # t129: \"cuda:0 f32[1, 512, 1]\"\n", " # t130 = prims.broadcast_in_dim(t129, (1, 512, 4096), (0, 1, 2)) # t130: \"cuda:0 f32[1, 512, 4096]\"\n", " # t131 = prims.mul(t123, t130) # t131: \"cuda:0 f32[1, 512, 4096]\"\n", " # t135 = prims.convert_element_type(t133, dtypes.float32) # t135: \"cuda:0 f32[1, 512, 4096]\"\n", " # t136 = prims.mul(t131, t135) # t136: \"cuda:0 f32[1, 512, 4096]\"\n", " # t137 = prims.convert_element_type(t136, dtypes.bfloat16) # t137: \"cuda:0 bf16[1, 512, 4096]\"\n", " t138 = torch.nn.functional.linear(t137, t3, None) # t138: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t138 = ltorch.linear(t137, t3, None) # t138: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t138 = prims.linear(t137, t3, None) # t138: \"cuda:0 bf16[1, 512, 12288]\"\n", " t139 = torch.reshape(t138, (1, 512, 32, 3, 128)) # t139: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t139 = ltorch.reshape(t138, (1, 512, 32, 3, 128)) # t139: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t139 = prims.reshape(t138, (1, 512, 32, 3, 128)) # t139: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " del t138\n", " t140 = torch.permute(t139, (0, 2, 3, 1, 4)) # t140: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t140 = ltorch.permute(t139, (0, 2, 3, 1, 4)) # t140: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t140 = prims.transpose(t139, (0, 2, 3, 1, 4)) # t140: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " del t139\n", " (t141, t142, t143) = torch.split(t140, (1, 1, 1), 2)\n", " # (t141, t142, t143) = ltorch.split(t140, (1, 1, 1), 2)\n", " # t141 = prims.slice_prim(t140, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t141: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t142 = prims.slice_prim(t140, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t142: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t143 = prims.slice_prim(t140, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t143: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " del t140\n", " t144 = torch.reshape(t141, (1, 32, 512, 128)) # t144: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t144 = ltorch.reshape(t141, (1, 32, 512, 128)) # t144: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t144 = prims.reshape(t141, (1, 32, 512, 128)) # t144: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t141\n", " t145 = torch.reshape(t142, (1, 32, 512, 128)) # t145: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t145 = ltorch.reshape(t142, (1, 32, 512, 128)) # t145: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t145 = prims.reshape(t142, (1, 32, 512, 128)) # t145: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t142\n", " t146 = torch.reshape(t143, (1, 32, 512, 128)) # t146: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t146 = ltorch.reshape(t143, (1, 32, 512, 128)) # t146: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t146 = prims.reshape(t143, (1, 32, 512, 128)) # t146: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t143\n", " t147 = torch_slice_prim_impl(t144, [0, 0, 
0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t147: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " t162 = torch_slice_prim_impl(t145, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t162: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " t177 = torch_slice_prim_impl(t144, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t177: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " del t144\n", " t179 = torch_slice_prim_impl(t145, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t179: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " del t145\n", " t149 = torch_slice_prim_impl(t147, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t149: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t148 = torch_slice_prim_impl(t147, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t148: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t163 = torch_slice_prim_impl(t162, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t163: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t164 = torch_slice_prim_impl(t162, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t164: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " [t152, t167] = nvFusion1(t147, t149, t162, t164)\n", " # t150 = prims.convert_element_type(t149, dtypes.float32) # t150: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t151 = prims.neg(t150) # t151: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t152 = prims.convert_element_type(t151, dtypes.bfloat16) # t152: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " # t165 = prims.convert_element_type(t164, dtypes.float32) # t165: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t166 = prims.neg(t165) # t166: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t167 = prims.convert_element_type(t166, dtypes.bfloat16) # t167: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " del t149, t164\n", " t168 = torch.cat((t167, t163), -1) # t168: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t168 = ltorch.cat((t167, t163), -1) # t168: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t168 = prims.cat((t167, t163), -1) # t168: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t167, t163\n", " t153 = torch.cat((t152, t148), -1) # t153: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t153 = ltorch.cat((t152, t148), -1) # t153: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t153 = prims.cat((t152, t148), -1) # t153: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t152, t148\n", " [t161, t176] = nvFusion2(t147, t153, t154, t157, t162, t168)\n", " # t155 = prims.convert_element_type(t147, dtypes.float32) # t155: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t170 = prims.convert_element_type(t162, dtypes.float32) # t170: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t156 = prims.mul(t155, t154) # t156: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t158 = prims.convert_element_type(t153, dtypes.float32) # t158: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t159 = prims.mul(t158, t157) # t159: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t160 = prims.add(t156, t159) # t160: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t161 = prims.convert_element_type(t160, dtypes.bfloat16) # t161: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t171 = prims.mul(t170, t154) # t171: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t173 = prims.convert_element_type(t168, dtypes.float32) # t173: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t174 = prims.mul(t173, t157) # t174: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t175 = prims.add(t171, t174) # t175: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t176 = prims.convert_element_type(t175, dtypes.bfloat16) # t176: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t147, t153, t162, t168\n", " t178 = torch.cat((t161, t177), -1) # t178: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t178 = ltorch.cat((t161, t177), -1) # t178: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t178 = 
prims.cat((t161, t177), -1) # t178: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t161, t177\n", " t180 = torch.cat((t176, t179), -1) # t180: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t180 = ltorch.cat((t176, t179), -1) # t180: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t180 = prims.cat((t176, t179), -1) # t180: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t176, t179\n", " (t181, t182, t183, t184, _, _, t185, t186, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t178, t180, t146, 0.0, True, scale=0.08838834764831843)\n", " t188 = torch.permute(t181, (0, 2, 1, 3)) # t188: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t188 = ltorch.permute(t181, (0, 2, 1, 3)) # t188: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t188 = prims.transpose(t181, (0, 2, 1, 3)) # t188: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " t189 = torch.reshape(t188, (1, 512, 4096)) # t189: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t189 = ltorch.reshape(t188, (1, 512, 4096)) # t189: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t189 = prims.reshape(t188, (1, 512, 4096)) # t189: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t188\n", " t190 = torch.nn.functional.linear(t189, t85, None) # t190: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t190 = ltorch.linear(t189, t85, None) # t190: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t190 = prims.linear(t189, t85, None) # t190: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t194, t201, t209] = nvFusion3(t122, t190, t205)\n", " # t191 = prims.convert_element_type(t190, dtypes.float32) # t191: \"cuda:0 f32[1, 512, 4096]\"\n", " # t192 = prims.convert_element_type(t122, dtypes.float32) # t192: \"cuda:0 f32[1, 512, 4096]\"\n", " # t193 = prims.add(t191, t192) # t193: \"cuda:0 f32[1, 512, 4096]\"\n", " # t194 = prims.convert_element_type(t193, dtypes.bfloat16) # t194: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t196 = prims.mul(t193, t193) # t196: \"cuda:0 f32[1, 512, 4096]\"\n", " # t197 = prims.sum(t196, (2,)) # t197: \"cuda:0 f32[1, 512]\"\n", " # t198 = prims.broadcast_in_dim(t197, [1, 512, 1], [0, 1]) # t198: \"cuda:0 f32[1, 512, 1]\"\n", " # t199 = prims.div(t198, 4096.0) # t199: \"cuda:0 f32[1, 512, 1]\"\n", " # t200 = prims.add(t199, 1e-05) # t200: \"cuda:0 f32[1, 512, 1]\"\n", " # t201 = prims.rsqrt(t200) # t201: \"cuda:0 f32[1, 512, 1]\"\n", " # t202 = prims.broadcast_in_dim(t201, (1, 512, 4096), (0, 1, 2)) # t202: \"cuda:0 f32[1, 512, 4096]\"\n", " # t203 = prims.mul(t193, t202) # t203: \"cuda:0 f32[1, 512, 4096]\"\n", " # t207 = prims.convert_element_type(t205, dtypes.float32) # t207: \"cuda:0 f32[1, 512, 4096]\"\n", " # t208 = prims.mul(t203, t207) # t208: \"cuda:0 f32[1, 512, 4096]\"\n", " # t209 = prims.convert_element_type(t208, dtypes.bfloat16) # t209: \"cuda:0 bf16[1, 512, 4096]\"\n", " t210 = torch.nn.functional.linear(t209, t19, None) # t210: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t210 = ltorch.linear(t209, t19, None) # t210: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t210 = prims.linear(t209, t19, None) # t210: \"cuda:0 bf16[1, 512, 11008]\"\n", " t211 = torch.nn.functional.linear(t209, t35, None) # t211: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t211 = ltorch.linear(t209, t35, None) # t211: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t211 = prims.linear(t209, t35, None) # t211: \"cuda:0 bf16[1, 512, 11008]\"\n", " [t225] = nvFusion4(t210, t211)\n", " # t212 = prims.convert_element_type(t210, dtypes.float32) # t212: \"cuda:0 f32[1, 512, 11008]\"\n", " # t213 = prims.neg(t212) # t213: \"cuda:0 f32[1, 512, 11008]\"\n", " # t214 = prims.exp(t213) # t214: \"cuda:0 f32[1, 512, 11008]\"\n", " # t215 = prims.add(1.0, t214) # t215: \"cuda:0 
f32[1, 512, 11008]\"\n", " # t216 = prims.reciprocal(t215) # t216: \"cuda:0 f32[1, 512, 11008]\"\n", " # t220 = prims.mul(t212, t216) # t220: \"cuda:0 f32[1, 512, 11008]\"\n", " # t223 = prims.convert_element_type(t211, dtypes.float32) # t223: \"cuda:0 f32[1, 512, 11008]\"\n", " # t224 = prims.mul(t220, t223) # t224: \"cuda:0 f32[1, 512, 11008]\"\n", " # t225 = prims.convert_element_type(t224, dtypes.bfloat16) # t225: \"cuda:0 bf16[1, 512, 11008]\"\n", " t226 = torch.nn.functional.linear(t225, t86, None) # t226: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t226 = ltorch.linear(t225, t86, None) # t226: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t226 = prims.linear(t225, t86, None) # t226: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t230, t237, t245] = nvFusion5(t194, t226, t241)\n", " # t228 = prims.convert_element_type(t194, dtypes.float32) # t228: \"cuda:0 f32[1, 512, 4096]\"\n", " # t227 = prims.convert_element_type(t226, dtypes.float32) # t227: \"cuda:0 f32[1, 512, 4096]\"\n", " # t229 = prims.add(t227, t228) # t229: \"cuda:0 f32[1, 512, 4096]\"\n", " # t230 = prims.convert_element_type(t229, dtypes.bfloat16) # t230: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t232 = prims.mul(t229, t229) # t232: \"cuda:0 f32[1, 512, 4096]\"\n", " # t233 = prims.sum(t232, (2,)) # t233: \"cuda:0 f32[1, 512]\"\n", " # t234 = prims.broadcast_in_dim(t233, [1, 512, 1], [0, 1]) # t234: \"cuda:0 f32[1, 512, 1]\"\n", " # t235 = prims.div(t234, 4096.0) # t235: \"cuda:0 f32[1, 512, 1]\"\n", " # t236 = prims.add(t235, 1e-05) # t236: \"cuda:0 f32[1, 512, 1]\"\n", " # t237 = prims.rsqrt(t236) # t237: \"cuda:0 f32[1, 512, 1]\"\n", " # t238 = prims.broadcast_in_dim(t237, (1, 512, 4096), (0, 1, 2)) # t238: \"cuda:0 f32[1, 512, 4096]\"\n", " # t239 = prims.mul(t229, t238) # t239: \"cuda:0 f32[1, 512, 4096]\"\n", " # t243 = prims.convert_element_type(t241, dtypes.float32) # t243: \"cuda:0 f32[1, 512, 4096]\"\n", " # t244 = prims.mul(t239, t243) # t244: \"cuda:0 f32[1, 512, 4096]\"\n", " # t245 = prims.convert_element_type(t244, dtypes.bfloat16) # t245: \"cuda:0 bf16[1, 512, 4096]\"\n", " t246 = torch.nn.functional.linear(t245, t4, None) # t246: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t246 = ltorch.linear(t245, t4, None) # t246: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t246 = prims.linear(t245, t4, None) # t246: \"cuda:0 bf16[1, 512, 12288]\"\n", " t247 = torch.reshape(t246, (1, 512, 32, 3, 128)) # t247: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t247 = ltorch.reshape(t246, (1, 512, 32, 3, 128)) # t247: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t247 = prims.reshape(t246, (1, 512, 32, 3, 128)) # t247: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " del t246\n", " t248 = torch.permute(t247, (0, 2, 3, 1, 4)) # t248: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t248 = ltorch.permute(t247, (0, 2, 3, 1, 4)) # t248: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t248 = prims.transpose(t247, (0, 2, 3, 1, 4)) # t248: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " del t247\n", " (t249, t250, t251) = torch.split(t248, (1, 1, 1), 2)\n", " # (t249, t250, t251) = ltorch.split(t248, (1, 1, 1), 2)\n", " # t249 = prims.slice_prim(t248, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t249: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t250 = prims.slice_prim(t248, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t250: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t251 = prims.slice_prim(t248, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t251: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " del t248\n", " t252 = torch.reshape(t249, (1, 32, 512, 128)) # t252: 
\"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t252 = ltorch.reshape(t249, (1, 32, 512, 128)) # t252: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t252 = prims.reshape(t249, (1, 32, 512, 128)) # t252: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t249\n", " t253 = torch.reshape(t250, (1, 32, 512, 128)) # t253: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t253 = ltorch.reshape(t250, (1, 32, 512, 128)) # t253: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t253 = prims.reshape(t250, (1, 32, 512, 128)) # t253: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t250\n", " t254 = torch.reshape(t251, (1, 32, 512, 128)) # t254: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t254 = ltorch.reshape(t251, (1, 32, 512, 128)) # t254: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t254 = prims.reshape(t251, (1, 32, 512, 128)) # t254: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t251\n", " t285 = torch_slice_prim_impl(t252, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t285: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " t287 = torch_slice_prim_impl(t253, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t287: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " t255 = torch_slice_prim_impl(t252, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t255: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t252\n", " t270 = torch_slice_prim_impl(t253, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t270: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t253\n", " t256 = torch_slice_prim_impl(t255, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t256: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t257 = torch_slice_prim_impl(t255, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t257: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t272 = torch_slice_prim_impl(t270, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t272: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t271 = torch_slice_prim_impl(t270, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t271: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " [t260, t275] = nvFusion6(t255, t257, t270, t272)\n", " # t258 = prims.convert_element_type(t257, dtypes.float32) # t258: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t259 = prims.neg(t258) # t259: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t260 = prims.convert_element_type(t259, dtypes.bfloat16) # t260: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " # t273 = prims.convert_element_type(t272, dtypes.float32) # t273: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t274 = prims.neg(t273) # t274: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t275 = prims.convert_element_type(t274, dtypes.bfloat16) # t275: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " del t257, t272\n", " t261 = torch.cat((t260, t256), -1) # t261: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t261 = ltorch.cat((t260, t256), -1) # t261: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t261 = prims.cat((t260, t256), -1) # t261: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t260, t256\n", " t276 = torch.cat((t275, t271), -1) # t276: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t276 = ltorch.cat((t275, t271), -1) # t276: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t276 = prims.cat((t275, t271), -1) # t276: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t275, t271\n", " [t269, t284] = nvFusion7(t154, t157, t255, t261, t270, t276)\n", " # t263 = prims.convert_element_type(t255, dtypes.float32) # t263: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t278 = prims.convert_element_type(t270, dtypes.float32) # t278: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t264 = prims.mul(t263, t154) # t264: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t266 = prims.convert_element_type(t261, dtypes.float32) # t266: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t267 = prims.mul(t266, 
t157) # t267: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t268 = prims.add(t264, t267) # t268: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t269 = prims.convert_element_type(t268, dtypes.bfloat16) # t269: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t279 = prims.mul(t278, t154) # t279: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t281 = prims.convert_element_type(t276, dtypes.float32) # t281: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t282 = prims.mul(t281, t157) # t282: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t283 = prims.add(t279, t282) # t283: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t284 = prims.convert_element_type(t283, dtypes.bfloat16) # t284: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t255, t261, t270, t276\n", " t288 = torch.cat((t284, t287), -1) # t288: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t288 = ltorch.cat((t284, t287), -1) # t288: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t288 = prims.cat((t284, t287), -1) # t288: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t284, t287\n", " t286 = torch.cat((t269, t285), -1) # t286: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t286 = ltorch.cat((t269, t285), -1) # t286: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t286 = prims.cat((t269, t285), -1) # t286: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t269, t285\n", " (t289, t290, t291, t292, _, _, t293, t294, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t286, t288, t254, 0.0, True, scale=0.08838834764831843)\n", " t296 = torch.permute(t289, (0, 2, 1, 3)) # t296: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t296 = ltorch.permute(t289, (0, 2, 1, 3)) # t296: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t296 = prims.transpose(t289, (0, 2, 1, 3)) # t296: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " t297 = torch.reshape(t296, (1, 512, 4096)) # t297: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t297 = ltorch.reshape(t296, (1, 512, 4096)) # t297: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t297 = prims.reshape(t296, (1, 512, 4096)) # t297: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t296\n", " t298 = torch.nn.functional.linear(t297, t87, None) # t298: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t298 = ltorch.linear(t297, t87, None) # t298: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t298 = prims.linear(t297, t87, None) # t298: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t302, t309, t317] = nvFusion8(t230, t298, t313)\n", " # t300 = prims.convert_element_type(t230, dtypes.float32) # t300: \"cuda:0 f32[1, 512, 4096]\"\n", " # t299 = prims.convert_element_type(t298, dtypes.float32) # t299: \"cuda:0 f32[1, 512, 4096]\"\n", " # t301 = prims.add(t299, t300) # t301: \"cuda:0 f32[1, 512, 4096]\"\n", " # t302 = prims.convert_element_type(t301, dtypes.bfloat16) # t302: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t304 = prims.mul(t301, t301) # t304: \"cuda:0 f32[1, 512, 4096]\"\n", " # t305 = prims.sum(t304, (2,)) # t305: \"cuda:0 f32[1, 512]\"\n", " # t306 = prims.broadcast_in_dim(t305, [1, 512, 1], [0, 1]) # t306: \"cuda:0 f32[1, 512, 1]\"\n", " # t307 = prims.div(t306, 4096.0) # t307: \"cuda:0 f32[1, 512, 1]\"\n", " # t308 = prims.add(t307, 1e-05) # t308: \"cuda:0 f32[1, 512, 1]\"\n", " # t309 = prims.rsqrt(t308) # t309: \"cuda:0 f32[1, 512, 1]\"\n", " # t310 = prims.broadcast_in_dim(t309, (1, 512, 4096), (0, 1, 2)) # t310: \"cuda:0 f32[1, 512, 4096]\"\n", " # t311 = prims.mul(t301, t310) # t311: \"cuda:0 f32[1, 512, 4096]\"\n", " # t315 = prims.convert_element_type(t313, dtypes.float32) # t315: \"cuda:0 f32[1, 512, 4096]\"\n", " # t316 = prims.mul(t311, t315) # t316: \"cuda:0 f32[1, 512, 4096]\"\n", " # t317 = prims.convert_element_type(t316, dtypes.bfloat16) # t317: 
\"cuda:0 bf16[1, 512, 4096]\"\n", " t318 = torch.nn.functional.linear(t317, t20, None) # t318: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t318 = ltorch.linear(t317, t20, None) # t318: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t318 = prims.linear(t317, t20, None) # t318: \"cuda:0 bf16[1, 512, 11008]\"\n", " t319 = torch.nn.functional.linear(t317, t36, None) # t319: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t319 = ltorch.linear(t317, t36, None) # t319: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t319 = prims.linear(t317, t36, None) # t319: \"cuda:0 bf16[1, 512, 11008]\"\n", " [t333] = nvFusion9(t318, t319)\n", " # t320 = prims.convert_element_type(t318, dtypes.float32) # t320: \"cuda:0 f32[1, 512, 11008]\"\n", " # t321 = prims.neg(t320) # t321: \"cuda:0 f32[1, 512, 11008]\"\n", " # t322 = prims.exp(t321) # t322: \"cuda:0 f32[1, 512, 11008]\"\n", " # t323 = prims.add(1.0, t322) # t323: \"cuda:0 f32[1, 512, 11008]\"\n", " # t324 = prims.reciprocal(t323) # t324: \"cuda:0 f32[1, 512, 11008]\"\n", " # t328 = prims.mul(t320, t324) # t328: \"cuda:0 f32[1, 512, 11008]\"\n", " # t331 = prims.convert_element_type(t319, dtypes.float32) # t331: \"cuda:0 f32[1, 512, 11008]\"\n", " # t332 = prims.mul(t328, t331) # t332: \"cuda:0 f32[1, 512, 11008]\"\n", " # t333 = prims.convert_element_type(t332, dtypes.bfloat16) # t333: \"cuda:0 bf16[1, 512, 11008]\"\n", " t334 = torch.nn.functional.linear(t333, t88, None) # t334: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t334 = ltorch.linear(t333, t88, None) # t334: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t334 = prims.linear(t333, t88, None) # t334: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t338, t345, t353] = nvFusion10(t302, t334, t349)\n", " # t336 = prims.convert_element_type(t302, dtypes.float32) # t336: \"cuda:0 f32[1, 512, 4096]\"\n", " # t335 = prims.convert_element_type(t334, dtypes.float32) # t335: \"cuda:0 f32[1, 512, 4096]\"\n", " # t337 = prims.add(t335, t336) # t337: \"cuda:0 f32[1, 512, 4096]\"\n", " # t338 = prims.convert_element_type(t337, dtypes.bfloat16) # t338: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t340 = prims.mul(t337, t337) # t340: \"cuda:0 f32[1, 512, 4096]\"\n", " # t341 = prims.sum(t340, (2,)) # t341: \"cuda:0 f32[1, 512]\"\n", " # t342 = prims.broadcast_in_dim(t341, [1, 512, 1], [0, 1]) # t342: \"cuda:0 f32[1, 512, 1]\"\n", " # t343 = prims.div(t342, 4096.0) # t343: \"cuda:0 f32[1, 512, 1]\"\n", " # t344 = prims.add(t343, 1e-05) # t344: \"cuda:0 f32[1, 512, 1]\"\n", " # t345 = prims.rsqrt(t344) # t345: \"cuda:0 f32[1, 512, 1]\"\n", " # t346 = prims.broadcast_in_dim(t345, (1, 512, 4096), (0, 1, 2)) # t346: \"cuda:0 f32[1, 512, 4096]\"\n", " # t347 = prims.mul(t337, t346) # t347: \"cuda:0 f32[1, 512, 4096]\"\n", " # t351 = prims.convert_element_type(t349, dtypes.float32) # t351: \"cuda:0 f32[1, 512, 4096]\"\n", " # t352 = prims.mul(t347, t351) # t352: \"cuda:0 f32[1, 512, 4096]\"\n", " # t353 = prims.convert_element_type(t352, dtypes.bfloat16) # t353: \"cuda:0 bf16[1, 512, 4096]\"\n", " t354 = torch.nn.functional.linear(t353, t5, None) # t354: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t354 = ltorch.linear(t353, t5, None) # t354: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t354 = prims.linear(t353, t5, None) # t354: \"cuda:0 bf16[1, 512, 12288]\"\n", " t355 = torch.reshape(t354, (1, 512, 32, 3, 128)) # t355: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t355 = ltorch.reshape(t354, (1, 512, 32, 3, 128)) # t355: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t355 = prims.reshape(t354, (1, 512, 32, 3, 128)) # t355: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " del t354\n", " t356 = 
torch.permute(t355, (0, 2, 3, 1, 4)) # t356: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t356 = ltorch.permute(t355, (0, 2, 3, 1, 4)) # t356: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t356 = prims.transpose(t355, (0, 2, 3, 1, 4)) # t356: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " del t355\n", " (t357, t358, t359) = torch.split(t356, (1, 1, 1), 2)\n", " # (t357, t358, t359) = ltorch.split(t356, (1, 1, 1), 2)\n", " # t357 = prims.slice_prim(t356, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t357: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t358 = prims.slice_prim(t356, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t358: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t359 = prims.slice_prim(t356, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t359: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " del t356\n", " t360 = torch.reshape(t357, (1, 32, 512, 128)) # t360: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t360 = ltorch.reshape(t357, (1, 32, 512, 128)) # t360: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t360 = prims.reshape(t357, (1, 32, 512, 128)) # t360: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t357\n", " t361 = torch.reshape(t358, (1, 32, 512, 128)) # t361: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t361 = ltorch.reshape(t358, (1, 32, 512, 128)) # t361: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t361 = prims.reshape(t358, (1, 32, 512, 128)) # t361: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t358\n", " t362 = torch.reshape(t359, (1, 32, 512, 128)) # t362: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t362 = ltorch.reshape(t359, (1, 32, 512, 128)) # t362: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t362 = prims.reshape(t359, (1, 32, 512, 128)) # t362: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t359\n", " t363 = torch_slice_prim_impl(t360, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t363: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " t378 = torch_slice_prim_impl(t361, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t378: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " t393 = torch_slice_prim_impl(t360, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t393: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " del t360\n", " t395 = torch_slice_prim_impl(t361, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t395: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " del t361\n", " t364 = torch_slice_prim_impl(t363, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t364: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t365 = torch_slice_prim_impl(t363, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t365: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t379 = torch_slice_prim_impl(t378, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t379: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t380 = torch_slice_prim_impl(t378, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t380: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " [t368, t383] = nvFusion11(t363, t365, t378, t380)\n", " # t366 = prims.convert_element_type(t365, dtypes.float32) # t366: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t367 = prims.neg(t366) # t367: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t368 = prims.convert_element_type(t367, dtypes.bfloat16) # t368: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " # t381 = prims.convert_element_type(t380, dtypes.float32) # t381: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t382 = prims.neg(t381) # t382: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t383 = prims.convert_element_type(t382, dtypes.bfloat16) # t383: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " del t365, t380\n", " t369 = torch.cat((t368, t364), -1) # t369: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t369 = ltorch.cat((t368, t364), -1) # t369: 
\"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t369 = prims.cat((t368, t364), -1) # t369: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t368, t364\n", " t384 = torch.cat((t383, t379), -1) # t384: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t384 = ltorch.cat((t383, t379), -1) # t384: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t384 = prims.cat((t383, t379), -1) # t384: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t383, t379\n", " [t377, t392] = nvFusion12(t154, t157, t363, t369, t378, t384)\n", " # t371 = prims.convert_element_type(t363, dtypes.float32) # t371: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t386 = prims.convert_element_type(t378, dtypes.float32) # t386: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t372 = prims.mul(t371, t154) # t372: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t374 = prims.convert_element_type(t369, dtypes.float32) # t374: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t375 = prims.mul(t374, t157) # t375: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t376 = prims.add(t372, t375) # t376: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t377 = prims.convert_element_type(t376, dtypes.bfloat16) # t377: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t387 = prims.mul(t386, t154) # t387: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t389 = prims.convert_element_type(t384, dtypes.float32) # t389: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t390 = prims.mul(t389, t157) # t390: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t391 = prims.add(t387, t390) # t391: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t392 = prims.convert_element_type(t391, dtypes.bfloat16) # t392: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t363, t369, t378, t384\n", " t394 = torch.cat((t377, t393), -1) # t394: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t394 = ltorch.cat((t377, t393), -1) # t394: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t394 = prims.cat((t377, t393), -1) # t394: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t377, t393\n", " t396 = torch.cat((t392, t395), -1) # t396: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t396 = ltorch.cat((t392, t395), -1) # t396: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t396 = prims.cat((t392, t395), -1) # t396: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t392, t395\n", " (t397, t398, t399, t400, _, _, t401, t402, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t394, t396, t362, 0.0, True, scale=0.08838834764831843)\n", " t404 = torch.permute(t397, (0, 2, 1, 3)) # t404: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t404 = ltorch.permute(t397, (0, 2, 1, 3)) # t404: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t404 = prims.transpose(t397, (0, 2, 1, 3)) # t404: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " t405 = torch.reshape(t404, (1, 512, 4096)) # t405: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t405 = ltorch.reshape(t404, (1, 512, 4096)) # t405: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t405 = prims.reshape(t404, (1, 512, 4096)) # t405: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t404\n", " t406 = torch.nn.functional.linear(t405, t89, None) # t406: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t406 = ltorch.linear(t405, t89, None) # t406: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t406 = prims.linear(t405, t89, None) # t406: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t410, t417, t425] = nvFusion13(t338, t406, t421)\n", " # t408 = prims.convert_element_type(t338, dtypes.float32) # t408: \"cuda:0 f32[1, 512, 4096]\"\n", " # t407 = prims.convert_element_type(t406, dtypes.float32) # t407: \"cuda:0 f32[1, 512, 4096]\"\n", " # t409 = prims.add(t407, t408) # t409: \"cuda:0 f32[1, 512, 4096]\"\n", " # t410 = prims.convert_element_type(t409, dtypes.bfloat16) # t410: \"cuda:0 bf16[1, 512, 
4096]\"\n", " # t412 = prims.mul(t409, t409) # t412: \"cuda:0 f32[1, 512, 4096]\"\n", " # t413 = prims.sum(t412, (2,)) # t413: \"cuda:0 f32[1, 512]\"\n", " # t414 = prims.broadcast_in_dim(t413, [1, 512, 1], [0, 1]) # t414: \"cuda:0 f32[1, 512, 1]\"\n", " # t415 = prims.div(t414, 4096.0) # t415: \"cuda:0 f32[1, 512, 1]\"\n", " # t416 = prims.add(t415, 1e-05) # t416: \"cuda:0 f32[1, 512, 1]\"\n", " # t417 = prims.rsqrt(t416) # t417: \"cuda:0 f32[1, 512, 1]\"\n", " # t418 = prims.broadcast_in_dim(t417, (1, 512, 4096), (0, 1, 2)) # t418: \"cuda:0 f32[1, 512, 4096]\"\n", " # t419 = prims.mul(t409, t418) # t419: \"cuda:0 f32[1, 512, 4096]\"\n", " # t423 = prims.convert_element_type(t421, dtypes.float32) # t423: \"cuda:0 f32[1, 512, 4096]\"\n", " # t424 = prims.mul(t419, t423) # t424: \"cuda:0 f32[1, 512, 4096]\"\n", " # t425 = prims.convert_element_type(t424, dtypes.bfloat16) # t425: \"cuda:0 bf16[1, 512, 4096]\"\n", " t426 = torch.nn.functional.linear(t425, t21, None) # t426: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t426 = ltorch.linear(t425, t21, None) # t426: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t426 = prims.linear(t425, t21, None) # t426: \"cuda:0 bf16[1, 512, 11008]\"\n", " t427 = torch.nn.functional.linear(t425, t37, None) # t427: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t427 = ltorch.linear(t425, t37, None) # t427: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t427 = prims.linear(t425, t37, None) # t427: \"cuda:0 bf16[1, 512, 11008]\"\n", " [t441] = nvFusion14(t426, t427)\n", " # t428 = prims.convert_element_type(t426, dtypes.float32) # t428: \"cuda:0 f32[1, 512, 11008]\"\n", " # t429 = prims.neg(t428) # t429: \"cuda:0 f32[1, 512, 11008]\"\n", " # t430 = prims.exp(t429) # t430: \"cuda:0 f32[1, 512, 11008]\"\n", " # t431 = prims.add(1.0, t430) # t431: \"cuda:0 f32[1, 512, 11008]\"\n", " # t432 = prims.reciprocal(t431) # t432: \"cuda:0 f32[1, 512, 11008]\"\n", " # t436 = prims.mul(t428, t432) # t436: \"cuda:0 f32[1, 512, 11008]\"\n", " # t439 = prims.convert_element_type(t427, dtypes.float32) # t439: \"cuda:0 f32[1, 512, 11008]\"\n", " # t440 = prims.mul(t436, t439) # t440: \"cuda:0 f32[1, 512, 11008]\"\n", " # t441 = prims.convert_element_type(t440, dtypes.bfloat16) # t441: \"cuda:0 bf16[1, 512, 11008]\"\n", " t442 = torch.nn.functional.linear(t441, t90, None) # t442: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t442 = ltorch.linear(t441, t90, None) # t442: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t442 = prims.linear(t441, t90, None) # t442: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t446, t453, t461] = nvFusion15(t410, t442, t457)\n", " # t444 = prims.convert_element_type(t410, dtypes.float32) # t444: \"cuda:0 f32[1, 512, 4096]\"\n", " # t443 = prims.convert_element_type(t442, dtypes.float32) # t443: \"cuda:0 f32[1, 512, 4096]\"\n", " # t445 = prims.add(t443, t444) # t445: \"cuda:0 f32[1, 512, 4096]\"\n", " # t446 = prims.convert_element_type(t445, dtypes.bfloat16) # t446: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t448 = prims.mul(t445, t445) # t448: \"cuda:0 f32[1, 512, 4096]\"\n", " # t449 = prims.sum(t448, (2,)) # t449: \"cuda:0 f32[1, 512]\"\n", " # t450 = prims.broadcast_in_dim(t449, [1, 512, 1], [0, 1]) # t450: \"cuda:0 f32[1, 512, 1]\"\n", " # t451 = prims.div(t450, 4096.0) # t451: \"cuda:0 f32[1, 512, 1]\"\n", " # t452 = prims.add(t451, 1e-05) # t452: \"cuda:0 f32[1, 512, 1]\"\n", " # t453 = prims.rsqrt(t452) # t453: \"cuda:0 f32[1, 512, 1]\"\n", " # t454 = prims.broadcast_in_dim(t453, (1, 512, 4096), (0, 1, 2)) # t454: \"cuda:0 f32[1, 512, 4096]\"\n", " # t455 = prims.mul(t445, t454) # t455: \"cuda:0 f32[1, 
512, 4096]\"\n", " # t459 = prims.convert_element_type(t457, dtypes.float32) # t459: \"cuda:0 f32[1, 512, 4096]\"\n", " # t460 = prims.mul(t455, t459) # t460: \"cuda:0 f32[1, 512, 4096]\"\n", " # t461 = prims.convert_element_type(t460, dtypes.bfloat16) # t461: \"cuda:0 bf16[1, 512, 4096]\"\n", " t462 = torch.nn.functional.linear(t461, t6, None) # t462: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t462 = ltorch.linear(t461, t6, None) # t462: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t462 = prims.linear(t461, t6, None) # t462: \"cuda:0 bf16[1, 512, 12288]\"\n", " t463 = torch.reshape(t462, (1, 512, 32, 3, 128)) # t463: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t463 = ltorch.reshape(t462, (1, 512, 32, 3, 128)) # t463: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t463 = prims.reshape(t462, (1, 512, 32, 3, 128)) # t463: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " del t462\n", " t464 = torch.permute(t463, (0, 2, 3, 1, 4)) # t464: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t464 = ltorch.permute(t463, (0, 2, 3, 1, 4)) # t464: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t464 = prims.transpose(t463, (0, 2, 3, 1, 4)) # t464: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " del t463\n", " (t465, t466, t467) = torch.split(t464, (1, 1, 1), 2)\n", " # (t465, t466, t467) = ltorch.split(t464, (1, 1, 1), 2)\n", " # t465 = prims.slice_prim(t464, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t465: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t466 = prims.slice_prim(t464, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t466: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t467 = prims.slice_prim(t464, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t467: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " del t464\n", " t468 = torch.reshape(t465, (1, 32, 512, 128)) # t468: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t468 = ltorch.reshape(t465, (1, 32, 512, 128)) # t468: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t468 = prims.reshape(t465, (1, 32, 512, 128)) # t468: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t465\n", " t469 = torch.reshape(t466, (1, 32, 512, 128)) # t469: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t469 = ltorch.reshape(t466, (1, 32, 512, 128)) # t469: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t469 = prims.reshape(t466, (1, 32, 512, 128)) # t469: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t466\n", " t470 = torch.reshape(t467, (1, 32, 512, 128)) # t470: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t470 = ltorch.reshape(t467, (1, 32, 512, 128)) # t470: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t470 = prims.reshape(t467, (1, 32, 512, 128)) # t470: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t467\n", " t471 = torch_slice_prim_impl(t468, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t471: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " t486 = torch_slice_prim_impl(t469, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t486: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " t501 = torch_slice_prim_impl(t468, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t501: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " del t468\n", " t503 = torch_slice_prim_impl(t469, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t503: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " del t469\n", " t472 = torch_slice_prim_impl(t471, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t472: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t473 = torch_slice_prim_impl(t471, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t473: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t487 = torch_slice_prim_impl(t486, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t487: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t488 = 
torch_slice_prim_impl(t486, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t488: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " [t476, t491] = nvFusion16(t471, t473, t486, t488)\n", " # t474 = prims.convert_element_type(t473, dtypes.float32) # t474: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t475 = prims.neg(t474) # t475: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t476 = prims.convert_element_type(t475, dtypes.bfloat16) # t476: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " # t489 = prims.convert_element_type(t488, dtypes.float32) # t489: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t490 = prims.neg(t489) # t490: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t491 = prims.convert_element_type(t490, dtypes.bfloat16) # t491: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " del t473, t488\n", " t477 = torch.cat((t476, t472), -1) # t477: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t477 = ltorch.cat((t476, t472), -1) # t477: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t477 = prims.cat((t476, t472), -1) # t477: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t476, t472\n", " t492 = torch.cat((t491, t487), -1) # t492: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t492 = ltorch.cat((t491, t487), -1) # t492: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t492 = prims.cat((t491, t487), -1) # t492: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t491, t487\n", " [t485, t500] = nvFusion17(t154, t157, t471, t477, t486, t492)\n", " # t479 = prims.convert_element_type(t471, dtypes.float32) # t479: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t494 = prims.convert_element_type(t486, dtypes.float32) # t494: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t480 = prims.mul(t479, t154) # t480: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t482 = prims.convert_element_type(t477, dtypes.float32) # t482: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t483 = prims.mul(t482, t157) # t483: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t484 = prims.add(t480, t483) # t484: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t485 = prims.convert_element_type(t484, dtypes.bfloat16) # t485: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t495 = prims.mul(t494, t154) # t495: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t497 = prims.convert_element_type(t492, dtypes.float32) # t497: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t498 = prims.mul(t497, t157) # t498: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t499 = prims.add(t495, t498) # t499: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t500 = prims.convert_element_type(t499, dtypes.bfloat16) # t500: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t471, t477, t486, t492\n", " t502 = torch.cat((t485, t501), -1) # t502: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t502 = ltorch.cat((t485, t501), -1) # t502: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t502 = prims.cat((t485, t501), -1) # t502: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t485, t501\n", " t504 = torch.cat((t500, t503), -1) # t504: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t504 = ltorch.cat((t500, t503), -1) # t504: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t504 = prims.cat((t500, t503), -1) # t504: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t500, t503\n", " (t505, t506, t507, t508, _, _, t509, t510, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t502, t504, t470, 0.0, True, scale=0.08838834764831843)\n", " t512 = torch.permute(t505, (0, 2, 1, 3)) # t512: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t512 = ltorch.permute(t505, (0, 2, 1, 3)) # t512: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t512 = prims.transpose(t505, (0, 2, 1, 3)) # t512: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " t513 = torch.reshape(t512, (1, 512, 4096)) # t513: \"cuda:0 bf16[1, 512, 
4096]\"\n", " # t513 = ltorch.reshape(t512, (1, 512, 4096)) # t513: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t513 = prims.reshape(t512, (1, 512, 4096)) # t513: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t512\n", " t514 = torch.nn.functional.linear(t513, t91, None) # t514: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t514 = ltorch.linear(t513, t91, None) # t514: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t514 = prims.linear(t513, t91, None) # t514: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t518, t525, t533] = nvFusion18(t446, t514, t529)\n", " # t516 = prims.convert_element_type(t446, dtypes.float32) # t516: \"cuda:0 f32[1, 512, 4096]\"\n", " # t515 = prims.convert_element_type(t514, dtypes.float32) # t515: \"cuda:0 f32[1, 512, 4096]\"\n", " # t517 = prims.add(t515, t516) # t517: \"cuda:0 f32[1, 512, 4096]\"\n", " # t518 = prims.convert_element_type(t517, dtypes.bfloat16) # t518: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t520 = prims.mul(t517, t517) # t520: \"cuda:0 f32[1, 512, 4096]\"\n", " # t521 = prims.sum(t520, (2,)) # t521: \"cuda:0 f32[1, 512]\"\n", " # t522 = prims.broadcast_in_dim(t521, [1, 512, 1], [0, 1]) # t522: \"cuda:0 f32[1, 512, 1]\"\n", " # t523 = prims.div(t522, 4096.0) # t523: \"cuda:0 f32[1, 512, 1]\"\n", " # t524 = prims.add(t523, 1e-05) # t524: \"cuda:0 f32[1, 512, 1]\"\n", " # t525 = prims.rsqrt(t524) # t525: \"cuda:0 f32[1, 512, 1]\"\n", " # t526 = prims.broadcast_in_dim(t525, (1, 512, 4096), (0, 1, 2)) # t526: \"cuda:0 f32[1, 512, 4096]\"\n", " # t527 = prims.mul(t517, t526) # t527: \"cuda:0 f32[1, 512, 4096]\"\n", " # t531 = prims.convert_element_type(t529, dtypes.float32) # t531: \"cuda:0 f32[1, 512, 4096]\"\n", " # t532 = prims.mul(t527, t531) # t532: \"cuda:0 f32[1, 512, 4096]\"\n", " # t533 = prims.convert_element_type(t532, dtypes.bfloat16) # t533: \"cuda:0 bf16[1, 512, 4096]\"\n", " t534 = torch.nn.functional.linear(t533, t22, None) # t534: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t534 = ltorch.linear(t533, t22, None) # t534: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t534 = prims.linear(t533, t22, None) # t534: \"cuda:0 bf16[1, 512, 11008]\"\n", " t535 = torch.nn.functional.linear(t533, t38, None) # t535: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t535 = ltorch.linear(t533, t38, None) # t535: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t535 = prims.linear(t533, t38, None) # t535: \"cuda:0 bf16[1, 512, 11008]\"\n", " [t549] = nvFusion19(t534, t535)\n", " # t536 = prims.convert_element_type(t534, dtypes.float32) # t536: \"cuda:0 f32[1, 512, 11008]\"\n", " # t537 = prims.neg(t536) # t537: \"cuda:0 f32[1, 512, 11008]\"\n", " # t538 = prims.exp(t537) # t538: \"cuda:0 f32[1, 512, 11008]\"\n", " # t539 = prims.add(1.0, t538) # t539: \"cuda:0 f32[1, 512, 11008]\"\n", " # t540 = prims.reciprocal(t539) # t540: \"cuda:0 f32[1, 512, 11008]\"\n", " # t544 = prims.mul(t536, t540) # t544: \"cuda:0 f32[1, 512, 11008]\"\n", " # t547 = prims.convert_element_type(t535, dtypes.float32) # t547: \"cuda:0 f32[1, 512, 11008]\"\n", " # t548 = prims.mul(t544, t547) # t548: \"cuda:0 f32[1, 512, 11008]\"\n", " # t549 = prims.convert_element_type(t548, dtypes.bfloat16) # t549: \"cuda:0 bf16[1, 512, 11008]\"\n", " t550 = torch.nn.functional.linear(t549, t92, None) # t550: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t550 = ltorch.linear(t549, t92, None) # t550: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t550 = prims.linear(t549, t92, None) # t550: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t554, t561, t569] = nvFusion20(t518, t550, t565)\n", " # t552 = prims.convert_element_type(t518, dtypes.float32) # t552: \"cuda:0 f32[1, 512, 4096]\"\n", " # 
t551 = prims.convert_element_type(t550, dtypes.float32) # t551: \"cuda:0 f32[1, 512, 4096]\"\n", " # t553 = prims.add(t551, t552) # t553: \"cuda:0 f32[1, 512, 4096]\"\n", " # t554 = prims.convert_element_type(t553, dtypes.bfloat16) # t554: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t556 = prims.mul(t553, t553) # t556: \"cuda:0 f32[1, 512, 4096]\"\n", " # t557 = prims.sum(t556, (2,)) # t557: \"cuda:0 f32[1, 512]\"\n", " # t558 = prims.broadcast_in_dim(t557, [1, 512, 1], [0, 1]) # t558: \"cuda:0 f32[1, 512, 1]\"\n", " # t559 = prims.div(t558, 4096.0) # t559: \"cuda:0 f32[1, 512, 1]\"\n", " # t560 = prims.add(t559, 1e-05) # t560: \"cuda:0 f32[1, 512, 1]\"\n", " # t561 = prims.rsqrt(t560) # t561: \"cuda:0 f32[1, 512, 1]\"\n", " # t562 = prims.broadcast_in_dim(t561, (1, 512, 4096), (0, 1, 2)) # t562: \"cuda:0 f32[1, 512, 4096]\"\n", " # t563 = prims.mul(t553, t562) # t563: \"cuda:0 f32[1, 512, 4096]\"\n", " # t567 = prims.convert_element_type(t565, dtypes.float32) # t567: \"cuda:0 f32[1, 512, 4096]\"\n", " # t568 = prims.mul(t563, t567) # t568: \"cuda:0 f32[1, 512, 4096]\"\n", " # t569 = prims.convert_element_type(t568, dtypes.bfloat16) # t569: \"cuda:0 bf16[1, 512, 4096]\"\n", " t570 = torch.nn.functional.linear(t569, t7, None) # t570: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t570 = ltorch.linear(t569, t7, None) # t570: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t570 = prims.linear(t569, t7, None) # t570: \"cuda:0 bf16[1, 512, 12288]\"\n", " t571 = torch.reshape(t570, (1, 512, 32, 3, 128)) # t571: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t571 = ltorch.reshape(t570, (1, 512, 32, 3, 128)) # t571: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t571 = prims.reshape(t570, (1, 512, 32, 3, 128)) # t571: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " del t570\n", " t572 = torch.permute(t571, (0, 2, 3, 1, 4)) # t572: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t572 = ltorch.permute(t571, (0, 2, 3, 1, 4)) # t572: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t572 = prims.transpose(t571, (0, 2, 3, 1, 4)) # t572: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " del t571\n", " (t573, t574, t575) = torch.split(t572, (1, 1, 1), 2)\n", " # (t573, t574, t575) = ltorch.split(t572, (1, 1, 1), 2)\n", " # t573 = prims.slice_prim(t572, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t573: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t574 = prims.slice_prim(t572, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t574: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t575 = prims.slice_prim(t572, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t575: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " del t572\n", " t576 = torch.reshape(t573, (1, 32, 512, 128)) # t576: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t576 = ltorch.reshape(t573, (1, 32, 512, 128)) # t576: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t576 = prims.reshape(t573, (1, 32, 512, 128)) # t576: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t573\n", " t577 = torch.reshape(t574, (1, 32, 512, 128)) # t577: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t577 = ltorch.reshape(t574, (1, 32, 512, 128)) # t577: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t577 = prims.reshape(t574, (1, 32, 512, 128)) # t577: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t574\n", " t578 = torch.reshape(t575, (1, 32, 512, 128)) # t578: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t578 = ltorch.reshape(t575, (1, 32, 512, 128)) # t578: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t578 = prims.reshape(t575, (1, 32, 512, 128)) # t578: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t575\n", " t579 = torch_slice_prim_impl(t576, 
[0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t579: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " t594 = torch_slice_prim_impl(t577, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t594: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " t609 = torch_slice_prim_impl(t576, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t609: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " del t576\n", " t611 = torch_slice_prim_impl(t577, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t611: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " del t577\n", " t580 = torch_slice_prim_impl(t579, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t580: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t581 = torch_slice_prim_impl(t579, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t581: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t595 = torch_slice_prim_impl(t594, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t595: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t596 = torch_slice_prim_impl(t594, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t596: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " [t584, t599] = nvFusion21(t579, t581, t594, t596)\n", " # t582 = prims.convert_element_type(t581, dtypes.float32) # t582: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t583 = prims.neg(t582) # t583: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t584 = prims.convert_element_type(t583, dtypes.bfloat16) # t584: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " # t597 = prims.convert_element_type(t596, dtypes.float32) # t597: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t598 = prims.neg(t597) # t598: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t599 = prims.convert_element_type(t598, dtypes.bfloat16) # t599: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " del t581, t596\n", " t600 = torch.cat((t599, t595), -1) # t600: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t600 = ltorch.cat((t599, t595), -1) # t600: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t600 = prims.cat((t599, t595), -1) # t600: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t599, t595\n", " t585 = torch.cat((t584, t580), -1) # t585: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t585 = ltorch.cat((t584, t580), -1) # t585: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t585 = prims.cat((t584, t580), -1) # t585: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t584, t580\n", " [t593, t608] = nvFusion22(t154, t157, t579, t585, t594, t600)\n", " # t587 = prims.convert_element_type(t579, dtypes.float32) # t587: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t602 = prims.convert_element_type(t594, dtypes.float32) # t602: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t603 = prims.mul(t602, t154) # t603: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t605 = prims.convert_element_type(t600, dtypes.float32) # t605: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t606 = prims.mul(t605, t157) # t606: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t607 = prims.add(t603, t606) # t607: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t608 = prims.convert_element_type(t607, dtypes.bfloat16) # t608: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t588 = prims.mul(t587, t154) # t588: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t590 = prims.convert_element_type(t585, dtypes.float32) # t590: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t591 = prims.mul(t590, t157) # t591: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t592 = prims.add(t588, t591) # t592: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t593 = prims.convert_element_type(t592, dtypes.bfloat16) # t593: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t579, t585, t594, t600\n", " t612 = torch.cat((t608, t611), -1) # t612: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t612 = ltorch.cat((t608, t611), -1) # t612: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # 
t612 = prims.cat((t608, t611), -1) # t612: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t608, t611\n", " t610 = torch.cat((t593, t609), -1) # t610: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t610 = ltorch.cat((t593, t609), -1) # t610: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t610 = prims.cat((t593, t609), -1) # t610: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t593, t609\n", " (t613, t614, t615, t616, _, _, t617, t618, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t610, t612, t578, 0.0, True, scale=0.08838834764831843)\n", " t620 = torch.permute(t613, (0, 2, 1, 3)) # t620: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t620 = ltorch.permute(t613, (0, 2, 1, 3)) # t620: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t620 = prims.transpose(t613, (0, 2, 1, 3)) # t620: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " t621 = torch.reshape(t620, (1, 512, 4096)) # t621: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t621 = ltorch.reshape(t620, (1, 512, 4096)) # t621: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t621 = prims.reshape(t620, (1, 512, 4096)) # t621: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t620\n", " t622 = torch.nn.functional.linear(t621, t93, None) # t622: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t622 = ltorch.linear(t621, t93, None) # t622: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t622 = prims.linear(t621, t93, None) # t622: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t626, t633, t641] = nvFusion23(t554, t622, t637)\n", " # t624 = prims.convert_element_type(t554, dtypes.float32) # t624: \"cuda:0 f32[1, 512, 4096]\"\n", " # t623 = prims.convert_element_type(t622, dtypes.float32) # t623: \"cuda:0 f32[1, 512, 4096]\"\n", " # t625 = prims.add(t623, t624) # t625: \"cuda:0 f32[1, 512, 4096]\"\n", " # t626 = prims.convert_element_type(t625, dtypes.bfloat16) # t626: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t628 = prims.mul(t625, t625) # t628: \"cuda:0 f32[1, 512, 4096]\"\n", " # t629 = prims.sum(t628, (2,)) # t629: \"cuda:0 f32[1, 512]\"\n", " # t630 = prims.broadcast_in_dim(t629, [1, 512, 1], [0, 1]) # t630: \"cuda:0 f32[1, 512, 1]\"\n", " # t631 = prims.div(t630, 4096.0) # t631: \"cuda:0 f32[1, 512, 1]\"\n", " # t632 = prims.add(t631, 1e-05) # t632: \"cuda:0 f32[1, 512, 1]\"\n", " # t633 = prims.rsqrt(t632) # t633: \"cuda:0 f32[1, 512, 1]\"\n", " # t634 = prims.broadcast_in_dim(t633, (1, 512, 4096), (0, 1, 2)) # t634: \"cuda:0 f32[1, 512, 4096]\"\n", " # t635 = prims.mul(t625, t634) # t635: \"cuda:0 f32[1, 512, 4096]\"\n", " # t639 = prims.convert_element_type(t637, dtypes.float32) # t639: \"cuda:0 f32[1, 512, 4096]\"\n", " # t640 = prims.mul(t635, t639) # t640: \"cuda:0 f32[1, 512, 4096]\"\n", " # t641 = prims.convert_element_type(t640, dtypes.bfloat16) # t641: \"cuda:0 bf16[1, 512, 4096]\"\n", " t643 = torch.nn.functional.linear(t641, t39, None) # t643: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t643 = ltorch.linear(t641, t39, None) # t643: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t643 = prims.linear(t641, t39, None) # t643: \"cuda:0 bf16[1, 512, 11008]\"\n", " t642 = torch.nn.functional.linear(t641, t23, None) # t642: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t642 = ltorch.linear(t641, t23, None) # t642: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t642 = prims.linear(t641, t23, None) # t642: \"cuda:0 bf16[1, 512, 11008]\"\n", " [t657] = nvFusion24(t642, t643)\n", " # t644 = prims.convert_element_type(t642, dtypes.float32) # t644: \"cuda:0 f32[1, 512, 11008]\"\n", " # t645 = prims.neg(t644) # t645: \"cuda:0 f32[1, 512, 11008]\"\n", " # t646 = prims.exp(t645) # t646: \"cuda:0 f32[1, 512, 11008]\"\n", " # t647 = prims.add(1.0, t646) # t647: 
\"cuda:0 f32[1, 512, 11008]\"\n", " # t648 = prims.reciprocal(t647) # t648: \"cuda:0 f32[1, 512, 11008]\"\n", " # t652 = prims.mul(t644, t648) # t652: \"cuda:0 f32[1, 512, 11008]\"\n", " # t655 = prims.convert_element_type(t643, dtypes.float32) # t655: \"cuda:0 f32[1, 512, 11008]\"\n", " # t656 = prims.mul(t652, t655) # t656: \"cuda:0 f32[1, 512, 11008]\"\n", " # t657 = prims.convert_element_type(t656, dtypes.bfloat16) # t657: \"cuda:0 bf16[1, 512, 11008]\"\n", " t658 = torch.nn.functional.linear(t657, t94, None) # t658: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t658 = ltorch.linear(t657, t94, None) # t658: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t658 = prims.linear(t657, t94, None) # t658: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t662, t669, t677] = nvFusion25(t626, t658, t673)\n", " # t660 = prims.convert_element_type(t626, dtypes.float32) # t660: \"cuda:0 f32[1, 512, 4096]\"\n", " # t659 = prims.convert_element_type(t658, dtypes.float32) # t659: \"cuda:0 f32[1, 512, 4096]\"\n", " # t661 = prims.add(t659, t660) # t661: \"cuda:0 f32[1, 512, 4096]\"\n", " # t662 = prims.convert_element_type(t661, dtypes.bfloat16) # t662: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t664 = prims.mul(t661, t661) # t664: \"cuda:0 f32[1, 512, 4096]\"\n", " # t665 = prims.sum(t664, (2,)) # t665: \"cuda:0 f32[1, 512]\"\n", " # t666 = prims.broadcast_in_dim(t665, [1, 512, 1], [0, 1]) # t666: \"cuda:0 f32[1, 512, 1]\"\n", " # t667 = prims.div(t666, 4096.0) # t667: \"cuda:0 f32[1, 512, 1]\"\n", " # t668 = prims.add(t667, 1e-05) # t668: \"cuda:0 f32[1, 512, 1]\"\n", " # t669 = prims.rsqrt(t668) # t669: \"cuda:0 f32[1, 512, 1]\"\n", " # t670 = prims.broadcast_in_dim(t669, (1, 512, 4096), (0, 1, 2)) # t670: \"cuda:0 f32[1, 512, 4096]\"\n", " # t671 = prims.mul(t661, t670) # t671: \"cuda:0 f32[1, 512, 4096]\"\n", " # t675 = prims.convert_element_type(t673, dtypes.float32) # t675: \"cuda:0 f32[1, 512, 4096]\"\n", " # t676 = prims.mul(t671, t675) # t676: \"cuda:0 f32[1, 512, 4096]\"\n", " # t677 = prims.convert_element_type(t676, dtypes.bfloat16) # t677: \"cuda:0 bf16[1, 512, 4096]\"\n", " t678 = torch.nn.functional.linear(t677, t8, None) # t678: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t678 = ltorch.linear(t677, t8, None) # t678: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t678 = prims.linear(t677, t8, None) # t678: \"cuda:0 bf16[1, 512, 12288]\"\n", " t679 = torch.reshape(t678, (1, 512, 32, 3, 128)) # t679: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t679 = ltorch.reshape(t678, (1, 512, 32, 3, 128)) # t679: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t679 = prims.reshape(t678, (1, 512, 32, 3, 128)) # t679: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " del t678\n", " t680 = torch.permute(t679, (0, 2, 3, 1, 4)) # t680: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t680 = ltorch.permute(t679, (0, 2, 3, 1, 4)) # t680: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t680 = prims.transpose(t679, (0, 2, 3, 1, 4)) # t680: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " del t679\n", " (t681, t682, t683) = torch.split(t680, (1, 1, 1), 2)\n", " # (t681, t682, t683) = ltorch.split(t680, (1, 1, 1), 2)\n", " # t681 = prims.slice_prim(t680, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t681: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t682 = prims.slice_prim(t680, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t682: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t683 = prims.slice_prim(t680, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t683: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " del t680\n", " t684 = torch.reshape(t681, (1, 32, 512, 
128)) # t684: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t684 = ltorch.reshape(t681, (1, 32, 512, 128)) # t684: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t684 = prims.reshape(t681, (1, 32, 512, 128)) # t684: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t681\n", " t685 = torch.reshape(t682, (1, 32, 512, 128)) # t685: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t685 = ltorch.reshape(t682, (1, 32, 512, 128)) # t685: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t685 = prims.reshape(t682, (1, 32, 512, 128)) # t685: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t682\n", " t686 = torch.reshape(t683, (1, 32, 512, 128)) # t686: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t686 = ltorch.reshape(t683, (1, 32, 512, 128)) # t686: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t686 = prims.reshape(t683, (1, 32, 512, 128)) # t686: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t683\n", " t687 = torch_slice_prim_impl(t684, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t687: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " t702 = torch_slice_prim_impl(t685, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t702: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " t717 = torch_slice_prim_impl(t684, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t717: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " del t684\n", " t719 = torch_slice_prim_impl(t685, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t719: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " del t685\n", " t688 = torch_slice_prim_impl(t687, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t688: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t689 = torch_slice_prim_impl(t687, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t689: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t703 = torch_slice_prim_impl(t702, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t703: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t704 = torch_slice_prim_impl(t702, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t704: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " [t692, t707] = nvFusion26(t687, t689, t702, t704)\n", " # t690 = prims.convert_element_type(t689, dtypes.float32) # t690: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t691 = prims.neg(t690) # t691: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t692 = prims.convert_element_type(t691, dtypes.bfloat16) # t692: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " # t705 = prims.convert_element_type(t704, dtypes.float32) # t705: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t706 = prims.neg(t705) # t706: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t707 = prims.convert_element_type(t706, dtypes.bfloat16) # t707: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " del t689, t704\n", " t708 = torch.cat((t707, t703), -1) # t708: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t708 = ltorch.cat((t707, t703), -1) # t708: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t708 = prims.cat((t707, t703), -1) # t708: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t707, t703\n", " t693 = torch.cat((t692, t688), -1) # t693: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t693 = ltorch.cat((t692, t688), -1) # t693: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t693 = prims.cat((t692, t688), -1) # t693: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t692, t688\n", " [t701, t716] = nvFusion27(t154, t157, t687, t693, t702, t708)\n", " # t695 = prims.convert_element_type(t687, dtypes.float32) # t695: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t710 = prims.convert_element_type(t702, dtypes.float32) # t710: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t711 = prims.mul(t710, t154) # t711: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t713 = prims.convert_element_type(t708, dtypes.float32) # t713: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t714 = 
prims.mul(t713, t157) # t714: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t715 = prims.add(t711, t714) # t715: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t716 = prims.convert_element_type(t715, dtypes.bfloat16) # t716: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t696 = prims.mul(t695, t154) # t696: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t698 = prims.convert_element_type(t693, dtypes.float32) # t698: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t699 = prims.mul(t698, t157) # t699: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t700 = prims.add(t696, t699) # t700: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t701 = prims.convert_element_type(t700, dtypes.bfloat16) # t701: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t687, t693, t702, t708\n", " t720 = torch.cat((t716, t719), -1) # t720: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t720 = ltorch.cat((t716, t719), -1) # t720: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t720 = prims.cat((t716, t719), -1) # t720: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t716, t719\n", " t718 = torch.cat((t701, t717), -1) # t718: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t718 = ltorch.cat((t701, t717), -1) # t718: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t718 = prims.cat((t701, t717), -1) # t718: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t701, t717\n", " (t721, t722, t723, t724, _, _, t725, t726, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t718, t720, t686, 0.0, True, scale=0.08838834764831843)\n", " t728 = torch.permute(t721, (0, 2, 1, 3)) # t728: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t728 = ltorch.permute(t721, (0, 2, 1, 3)) # t728: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t728 = prims.transpose(t721, (0, 2, 1, 3)) # t728: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " t729 = torch.reshape(t728, (1, 512, 4096)) # t729: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t729 = ltorch.reshape(t728, (1, 512, 4096)) # t729: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t729 = prims.reshape(t728, (1, 512, 4096)) # t729: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t728\n", " t730 = torch.nn.functional.linear(t729, t95, None) # t730: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t730 = ltorch.linear(t729, t95, None) # t730: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t730 = prims.linear(t729, t95, None) # t730: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t734, t741, t749] = nvFusion28(t662, t730, t745)\n", " # t732 = prims.convert_element_type(t662, dtypes.float32) # t732: \"cuda:0 f32[1, 512, 4096]\"\n", " # t731 = prims.convert_element_type(t730, dtypes.float32) # t731: \"cuda:0 f32[1, 512, 4096]\"\n", " # t733 = prims.add(t731, t732) # t733: \"cuda:0 f32[1, 512, 4096]\"\n", " # t734 = prims.convert_element_type(t733, dtypes.bfloat16) # t734: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t736 = prims.mul(t733, t733) # t736: \"cuda:0 f32[1, 512, 4096]\"\n", " # t737 = prims.sum(t736, (2,)) # t737: \"cuda:0 f32[1, 512]\"\n", " # t738 = prims.broadcast_in_dim(t737, [1, 512, 1], [0, 1]) # t738: \"cuda:0 f32[1, 512, 1]\"\n", " # t739 = prims.div(t738, 4096.0) # t739: \"cuda:0 f32[1, 512, 1]\"\n", " # t740 = prims.add(t739, 1e-05) # t740: \"cuda:0 f32[1, 512, 1]\"\n", " # t741 = prims.rsqrt(t740) # t741: \"cuda:0 f32[1, 512, 1]\"\n", " # t742 = prims.broadcast_in_dim(t741, (1, 512, 4096), (0, 1, 2)) # t742: \"cuda:0 f32[1, 512, 4096]\"\n", " # t743 = prims.mul(t733, t742) # t743: \"cuda:0 f32[1, 512, 4096]\"\n", " # t747 = prims.convert_element_type(t745, dtypes.float32) # t747: \"cuda:0 f32[1, 512, 4096]\"\n", " # t748 = prims.mul(t743, t747) # t748: \"cuda:0 f32[1, 512, 4096]\"\n", " # t749 = prims.convert_element_type(t748, 
dtypes.bfloat16) # t749: \"cuda:0 bf16[1, 512, 4096]\"\n", " t750 = torch.nn.functional.linear(t749, t24, None) # t750: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t750 = ltorch.linear(t749, t24, None) # t750: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t750 = prims.linear(t749, t24, None) # t750: \"cuda:0 bf16[1, 512, 11008]\"\n", " t751 = torch.nn.functional.linear(t749, t40, None) # t751: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t751 = ltorch.linear(t749, t40, None) # t751: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t751 = prims.linear(t749, t40, None) # t751: \"cuda:0 bf16[1, 512, 11008]\"\n", " [t765] = nvFusion29(t750, t751)\n", " # t752 = prims.convert_element_type(t750, dtypes.float32) # t752: \"cuda:0 f32[1, 512, 11008]\"\n", " # t753 = prims.neg(t752) # t753: \"cuda:0 f32[1, 512, 11008]\"\n", " # t754 = prims.exp(t753) # t754: \"cuda:0 f32[1, 512, 11008]\"\n", " # t755 = prims.add(1.0, t754) # t755: \"cuda:0 f32[1, 512, 11008]\"\n", " # t756 = prims.reciprocal(t755) # t756: \"cuda:0 f32[1, 512, 11008]\"\n", " # t760 = prims.mul(t752, t756) # t760: \"cuda:0 f32[1, 512, 11008]\"\n", " # t763 = prims.convert_element_type(t751, dtypes.float32) # t763: \"cuda:0 f32[1, 512, 11008]\"\n", " # t764 = prims.mul(t760, t763) # t764: \"cuda:0 f32[1, 512, 11008]\"\n", " # t765 = prims.convert_element_type(t764, dtypes.bfloat16) # t765: \"cuda:0 bf16[1, 512, 11008]\"\n", " t766 = torch.nn.functional.linear(t765, t96, None) # t766: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t766 = ltorch.linear(t765, t96, None) # t766: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t766 = prims.linear(t765, t96, None) # t766: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t770, t777, t785] = nvFusion30(t734, t766, t781)\n", " # t768 = prims.convert_element_type(t734, dtypes.float32) # t768: \"cuda:0 f32[1, 512, 4096]\"\n", " # t767 = prims.convert_element_type(t766, dtypes.float32) # t767: \"cuda:0 f32[1, 512, 4096]\"\n", " # t769 = prims.add(t767, t768) # t769: \"cuda:0 f32[1, 512, 4096]\"\n", " # t770 = prims.convert_element_type(t769, dtypes.bfloat16) # t770: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t772 = prims.mul(t769, t769) # t772: \"cuda:0 f32[1, 512, 4096]\"\n", " # t773 = prims.sum(t772, (2,)) # t773: \"cuda:0 f32[1, 512]\"\n", " # t774 = prims.broadcast_in_dim(t773, [1, 512, 1], [0, 1]) # t774: \"cuda:0 f32[1, 512, 1]\"\n", " # t775 = prims.div(t774, 4096.0) # t775: \"cuda:0 f32[1, 512, 1]\"\n", " # t776 = prims.add(t775, 1e-05) # t776: \"cuda:0 f32[1, 512, 1]\"\n", " # t777 = prims.rsqrt(t776) # t777: \"cuda:0 f32[1, 512, 1]\"\n", " # t778 = prims.broadcast_in_dim(t777, (1, 512, 4096), (0, 1, 2)) # t778: \"cuda:0 f32[1, 512, 4096]\"\n", " # t779 = prims.mul(t769, t778) # t779: \"cuda:0 f32[1, 512, 4096]\"\n", " # t783 = prims.convert_element_type(t781, dtypes.float32) # t783: \"cuda:0 f32[1, 512, 4096]\"\n", " # t784 = prims.mul(t779, t783) # t784: \"cuda:0 f32[1, 512, 4096]\"\n", " # t785 = prims.convert_element_type(t784, dtypes.bfloat16) # t785: \"cuda:0 bf16[1, 512, 4096]\"\n", " t786 = torch.nn.functional.linear(t785, t9, None) # t786: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t786 = ltorch.linear(t785, t9, None) # t786: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t786 = prims.linear(t785, t9, None) # t786: \"cuda:0 bf16[1, 512, 12288]\"\n", " t787 = torch.reshape(t786, (1, 512, 32, 3, 128)) # t787: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t787 = ltorch.reshape(t786, (1, 512, 32, 3, 128)) # t787: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t787 = prims.reshape(t786, (1, 512, 32, 3, 128)) # t787: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " del 
t786\n", " t788 = torch.permute(t787, (0, 2, 3, 1, 4)) # t788: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t788 = ltorch.permute(t787, (0, 2, 3, 1, 4)) # t788: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t788 = prims.transpose(t787, (0, 2, 3, 1, 4)) # t788: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " del t787\n", " (t789, t790, t791) = torch.split(t788, (1, 1, 1), 2)\n", " # (t789, t790, t791) = ltorch.split(t788, (1, 1, 1), 2)\n", " # t789 = prims.slice_prim(t788, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t789: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t790 = prims.slice_prim(t788, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t790: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t791 = prims.slice_prim(t788, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t791: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " del t788\n", " t792 = torch.reshape(t789, (1, 32, 512, 128)) # t792: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t792 = ltorch.reshape(t789, (1, 32, 512, 128)) # t792: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t792 = prims.reshape(t789, (1, 32, 512, 128)) # t792: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t789\n", " t793 = torch.reshape(t790, (1, 32, 512, 128)) # t793: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t793 = ltorch.reshape(t790, (1, 32, 512, 128)) # t793: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t793 = prims.reshape(t790, (1, 32, 512, 128)) # t793: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t790\n", " t794 = torch.reshape(t791, (1, 32, 512, 128)) # t794: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t794 = ltorch.reshape(t791, (1, 32, 512, 128)) # t794: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t794 = prims.reshape(t791, (1, 32, 512, 128)) # t794: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t791\n", " t795 = torch_slice_prim_impl(t792, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t795: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " t810 = torch_slice_prim_impl(t793, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t810: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " t825 = torch_slice_prim_impl(t792, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t825: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " del t792\n", " t827 = torch_slice_prim_impl(t793, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t827: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " del t793\n", " t796 = torch_slice_prim_impl(t795, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t796: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t797 = torch_slice_prim_impl(t795, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t797: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t811 = torch_slice_prim_impl(t810, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t811: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t812 = torch_slice_prim_impl(t810, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t812: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " [t800, t815] = nvFusion31(t795, t797, t810, t812)\n", " # t798 = prims.convert_element_type(t797, dtypes.float32) # t798: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t799 = prims.neg(t798) # t799: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t800 = prims.convert_element_type(t799, dtypes.bfloat16) # t800: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " # t813 = prims.convert_element_type(t812, dtypes.float32) # t813: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t814 = prims.neg(t813) # t814: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t815 = prims.convert_element_type(t814, dtypes.bfloat16) # t815: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " del t797, t812\n", " t816 = torch.cat((t815, t811), -1) # t816: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t816 = ltorch.cat((t815, 
t811), -1) # t816: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t816 = prims.cat((t815, t811), -1) # t816: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t815, t811\n", " t801 = torch.cat((t800, t796), -1) # t801: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t801 = ltorch.cat((t800, t796), -1) # t801: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t801 = prims.cat((t800, t796), -1) # t801: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t800, t796\n", " [t809, t824] = nvFusion32(t154, t157, t795, t801, t810, t816)\n", " # t803 = prims.convert_element_type(t795, dtypes.float32) # t803: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t818 = prims.convert_element_type(t810, dtypes.float32) # t818: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t819 = prims.mul(t818, t154) # t819: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t821 = prims.convert_element_type(t816, dtypes.float32) # t821: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t822 = prims.mul(t821, t157) # t822: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t823 = prims.add(t819, t822) # t823: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t824 = prims.convert_element_type(t823, dtypes.bfloat16) # t824: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t804 = prims.mul(t803, t154) # t804: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t806 = prims.convert_element_type(t801, dtypes.float32) # t806: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t807 = prims.mul(t806, t157) # t807: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t808 = prims.add(t804, t807) # t808: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t809 = prims.convert_element_type(t808, dtypes.bfloat16) # t809: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t795, t801, t810, t816\n", " t828 = torch.cat((t824, t827), -1) # t828: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t828 = ltorch.cat((t824, t827), -1) # t828: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t828 = prims.cat((t824, t827), -1) # t828: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t824, t827\n", " t826 = torch.cat((t809, t825), -1) # t826: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t826 = ltorch.cat((t809, t825), -1) # t826: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t826 = prims.cat((t809, t825), -1) # t826: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t809, t825\n", " (t829, t830, t831, t832, _, _, t833, t834, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t826, t828, t794, 0.0, True, scale=0.08838834764831843)\n", " t836 = torch.permute(t829, (0, 2, 1, 3)) # t836: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t836 = ltorch.permute(t829, (0, 2, 1, 3)) # t836: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t836 = prims.transpose(t829, (0, 2, 1, 3)) # t836: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " t837 = torch.reshape(t836, (1, 512, 4096)) # t837: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t837 = ltorch.reshape(t836, (1, 512, 4096)) # t837: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t837 = prims.reshape(t836, (1, 512, 4096)) # t837: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t836\n", " t838 = torch.nn.functional.linear(t837, t97, None) # t838: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t838 = ltorch.linear(t837, t97, None) # t838: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t838 = prims.linear(t837, t97, None) # t838: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t842, t849, t857] = nvFusion33(t770, t838, t853)\n", " # t840 = prims.convert_element_type(t770, dtypes.float32) # t840: \"cuda:0 f32[1, 512, 4096]\"\n", " # t839 = prims.convert_element_type(t838, dtypes.float32) # t839: \"cuda:0 f32[1, 512, 4096]\"\n", " # t841 = prims.add(t839, t840) # t841: \"cuda:0 f32[1, 512, 4096]\"\n", " # t842 = prims.convert_element_type(t841, dtypes.bfloat16) # t842: 
\"cuda:0 bf16[1, 512, 4096]\"\n", " # t844 = prims.mul(t841, t841) # t844: \"cuda:0 f32[1, 512, 4096]\"\n", " # t845 = prims.sum(t844, (2,)) # t845: \"cuda:0 f32[1, 512]\"\n", " # t846 = prims.broadcast_in_dim(t845, [1, 512, 1], [0, 1]) # t846: \"cuda:0 f32[1, 512, 1]\"\n", " # t847 = prims.div(t846, 4096.0) # t847: \"cuda:0 f32[1, 512, 1]\"\n", " # t848 = prims.add(t847, 1e-05) # t848: \"cuda:0 f32[1, 512, 1]\"\n", " # t849 = prims.rsqrt(t848) # t849: \"cuda:0 f32[1, 512, 1]\"\n", " # t850 = prims.broadcast_in_dim(t849, (1, 512, 4096), (0, 1, 2)) # t850: \"cuda:0 f32[1, 512, 4096]\"\n", " # t851 = prims.mul(t841, t850) # t851: \"cuda:0 f32[1, 512, 4096]\"\n", " # t855 = prims.convert_element_type(t853, dtypes.float32) # t855: \"cuda:0 f32[1, 512, 4096]\"\n", " # t856 = prims.mul(t851, t855) # t856: \"cuda:0 f32[1, 512, 4096]\"\n", " # t857 = prims.convert_element_type(t856, dtypes.bfloat16) # t857: \"cuda:0 bf16[1, 512, 4096]\"\n", " t858 = torch.nn.functional.linear(t857, t25, None) # t858: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t858 = ltorch.linear(t857, t25, None) # t858: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t858 = prims.linear(t857, t25, None) # t858: \"cuda:0 bf16[1, 512, 11008]\"\n", " t859 = torch.nn.functional.linear(t857, t41, None) # t859: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t859 = ltorch.linear(t857, t41, None) # t859: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t859 = prims.linear(t857, t41, None) # t859: \"cuda:0 bf16[1, 512, 11008]\"\n", " [t873] = nvFusion34(t858, t859)\n", " # t860 = prims.convert_element_type(t858, dtypes.float32) # t860: \"cuda:0 f32[1, 512, 11008]\"\n", " # t861 = prims.neg(t860) # t861: \"cuda:0 f32[1, 512, 11008]\"\n", " # t862 = prims.exp(t861) # t862: \"cuda:0 f32[1, 512, 11008]\"\n", " # t863 = prims.add(1.0, t862) # t863: \"cuda:0 f32[1, 512, 11008]\"\n", " # t864 = prims.reciprocal(t863) # t864: \"cuda:0 f32[1, 512, 11008]\"\n", " # t868 = prims.mul(t860, t864) # t868: \"cuda:0 f32[1, 512, 11008]\"\n", " # t871 = prims.convert_element_type(t859, dtypes.float32) # t871: \"cuda:0 f32[1, 512, 11008]\"\n", " # t872 = prims.mul(t868, t871) # t872: \"cuda:0 f32[1, 512, 11008]\"\n", " # t873 = prims.convert_element_type(t872, dtypes.bfloat16) # t873: \"cuda:0 bf16[1, 512, 11008]\"\n", " t874 = torch.nn.functional.linear(t873, t98, None) # t874: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t874 = ltorch.linear(t873, t98, None) # t874: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t874 = prims.linear(t873, t98, None) # t874: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t878, t885, t893] = nvFusion35(t842, t874, t889)\n", " # t876 = prims.convert_element_type(t842, dtypes.float32) # t876: \"cuda:0 f32[1, 512, 4096]\"\n", " # t875 = prims.convert_element_type(t874, dtypes.float32) # t875: \"cuda:0 f32[1, 512, 4096]\"\n", " # t877 = prims.add(t875, t876) # t877: \"cuda:0 f32[1, 512, 4096]\"\n", " # t878 = prims.convert_element_type(t877, dtypes.bfloat16) # t878: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t880 = prims.mul(t877, t877) # t880: \"cuda:0 f32[1, 512, 4096]\"\n", " # t881 = prims.sum(t880, (2,)) # t881: \"cuda:0 f32[1, 512]\"\n", " # t882 = prims.broadcast_in_dim(t881, [1, 512, 1], [0, 1]) # t882: \"cuda:0 f32[1, 512, 1]\"\n", " # t883 = prims.div(t882, 4096.0) # t883: \"cuda:0 f32[1, 512, 1]\"\n", " # t884 = prims.add(t883, 1e-05) # t884: \"cuda:0 f32[1, 512, 1]\"\n", " # t885 = prims.rsqrt(t884) # t885: \"cuda:0 f32[1, 512, 1]\"\n", " # t886 = prims.broadcast_in_dim(t885, (1, 512, 4096), (0, 1, 2)) # t886: \"cuda:0 f32[1, 512, 4096]\"\n", " # t887 = prims.mul(t877, t886) # 
t887: \"cuda:0 f32[1, 512, 4096]\"\n", " # t891 = prims.convert_element_type(t889, dtypes.float32) # t891: \"cuda:0 f32[1, 512, 4096]\"\n", " # t892 = prims.mul(t887, t891) # t892: \"cuda:0 f32[1, 512, 4096]\"\n", " # t893 = prims.convert_element_type(t892, dtypes.bfloat16) # t893: \"cuda:0 bf16[1, 512, 4096]\"\n", " t894 = torch.nn.functional.linear(t893, t10, None) # t894: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t894 = ltorch.linear(t893, t10, None) # t894: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t894 = prims.linear(t893, t10, None) # t894: \"cuda:0 bf16[1, 512, 12288]\"\n", " t895 = torch.reshape(t894, (1, 512, 32, 3, 128)) # t895: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t895 = ltorch.reshape(t894, (1, 512, 32, 3, 128)) # t895: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t895 = prims.reshape(t894, (1, 512, 32, 3, 128)) # t895: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " del t894\n", " t896 = torch.permute(t895, (0, 2, 3, 1, 4)) # t896: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t896 = ltorch.permute(t895, (0, 2, 3, 1, 4)) # t896: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t896 = prims.transpose(t895, (0, 2, 3, 1, 4)) # t896: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " del t895\n", " (t897, t898, t899) = torch.split(t896, (1, 1, 1), 2)\n", " # (t897, t898, t899) = ltorch.split(t896, (1, 1, 1), 2)\n", " # t897 = prims.slice_prim(t896, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t897: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t898 = prims.slice_prim(t896, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t898: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t899 = prims.slice_prim(t896, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t899: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " del t896\n", " t900 = torch.reshape(t897, (1, 32, 512, 128)) # t900: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t900 = ltorch.reshape(t897, (1, 32, 512, 128)) # t900: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t900 = prims.reshape(t897, (1, 32, 512, 128)) # t900: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t897\n", " t901 = torch.reshape(t898, (1, 32, 512, 128)) # t901: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t901 = ltorch.reshape(t898, (1, 32, 512, 128)) # t901: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t901 = prims.reshape(t898, (1, 32, 512, 128)) # t901: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t898\n", " t902 = torch.reshape(t899, (1, 32, 512, 128)) # t902: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t902 = ltorch.reshape(t899, (1, 32, 512, 128)) # t902: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t902 = prims.reshape(t899, (1, 32, 512, 128)) # t902: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t899\n", " t935 = torch_slice_prim_impl(t901, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t935: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " t903 = torch_slice_prim_impl(t900, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t903: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " t918 = torch_slice_prim_impl(t901, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t918: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t901\n", " t933 = torch_slice_prim_impl(t900, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t933: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " del t900\n", " t904 = torch_slice_prim_impl(t903, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t904: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t905 = torch_slice_prim_impl(t903, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t905: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t919 = torch_slice_prim_impl(t918, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t919: \"cuda:0 bf16[1, 
32, 512, 64]\"\n", " t920 = torch_slice_prim_impl(t918, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t920: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " [t908, t923] = nvFusion36(t903, t905, t918, t920)\n", " # t906 = prims.convert_element_type(t905, dtypes.float32) # t906: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t907 = prims.neg(t906) # t907: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t908 = prims.convert_element_type(t907, dtypes.bfloat16) # t908: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " # t921 = prims.convert_element_type(t920, dtypes.float32) # t921: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t922 = prims.neg(t921) # t922: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t923 = prims.convert_element_type(t922, dtypes.bfloat16) # t923: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " del t905, t920\n", " t924 = torch.cat((t923, t919), -1) # t924: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t924 = ltorch.cat((t923, t919), -1) # t924: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t924 = prims.cat((t923, t919), -1) # t924: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t923, t919\n", " t909 = torch.cat((t908, t904), -1) # t909: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t909 = ltorch.cat((t908, t904), -1) # t909: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t909 = prims.cat((t908, t904), -1) # t909: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t908, t904\n", " [t917, t932] = nvFusion37(t154, t157, t903, t909, t918, t924)\n", " # t911 = prims.convert_element_type(t903, dtypes.float32) # t911: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t926 = prims.convert_element_type(t918, dtypes.float32) # t926: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t927 = prims.mul(t926, t154) # t927: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t929 = prims.convert_element_type(t924, dtypes.float32) # t929: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t930 = prims.mul(t929, t157) # t930: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t931 = prims.add(t927, t930) # t931: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t932 = prims.convert_element_type(t931, dtypes.bfloat16) # t932: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t912 = prims.mul(t911, t154) # t912: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t914 = prims.convert_element_type(t909, dtypes.float32) # t914: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t915 = prims.mul(t914, t157) # t915: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t916 = prims.add(t912, t915) # t916: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t917 = prims.convert_element_type(t916, dtypes.bfloat16) # t917: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t903, t909, t918, t924\n", " t936 = torch.cat((t932, t935), -1) # t936: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t936 = ltorch.cat((t932, t935), -1) # t936: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t936 = prims.cat((t932, t935), -1) # t936: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t932, t935\n", " t934 = torch.cat((t917, t933), -1) # t934: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t934 = ltorch.cat((t917, t933), -1) # t934: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t934 = prims.cat((t917, t933), -1) # t934: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t917, t933\n", " (t937, t938, t939, t940, _, _, t941, t942, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t934, t936, t902, 0.0, True, scale=0.08838834764831843)\n", " t944 = torch.permute(t937, (0, 2, 1, 3)) # t944: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t944 = ltorch.permute(t937, (0, 2, 1, 3)) # t944: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t944 = prims.transpose(t937, (0, 2, 1, 3)) # t944: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " t945 = torch.reshape(t944, (1, 512, 4096)) # t945: 
\"cuda:0 bf16[1, 512, 4096]\"\n", " # t945 = ltorch.reshape(t944, (1, 512, 4096)) # t945: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t945 = prims.reshape(t944, (1, 512, 4096)) # t945: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t944\n", " t946 = torch.nn.functional.linear(t945, t99, None) # t946: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t946 = ltorch.linear(t945, t99, None) # t946: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t946 = prims.linear(t945, t99, None) # t946: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t950, t957, t965] = nvFusion38(t878, t946, t961)\n", " # t948 = prims.convert_element_type(t878, dtypes.float32) # t948: \"cuda:0 f32[1, 512, 4096]\"\n", " # t947 = prims.convert_element_type(t946, dtypes.float32) # t947: \"cuda:0 f32[1, 512, 4096]\"\n", " # t949 = prims.add(t947, t948) # t949: \"cuda:0 f32[1, 512, 4096]\"\n", " # t950 = prims.convert_element_type(t949, dtypes.bfloat16) # t950: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t952 = prims.mul(t949, t949) # t952: \"cuda:0 f32[1, 512, 4096]\"\n", " # t953 = prims.sum(t952, (2,)) # t953: \"cuda:0 f32[1, 512]\"\n", " # t954 = prims.broadcast_in_dim(t953, [1, 512, 1], [0, 1]) # t954: \"cuda:0 f32[1, 512, 1]\"\n", " # t955 = prims.div(t954, 4096.0) # t955: \"cuda:0 f32[1, 512, 1]\"\n", " # t956 = prims.add(t955, 1e-05) # t956: \"cuda:0 f32[1, 512, 1]\"\n", " # t957 = prims.rsqrt(t956) # t957: \"cuda:0 f32[1, 512, 1]\"\n", " # t958 = prims.broadcast_in_dim(t957, (1, 512, 4096), (0, 1, 2)) # t958: \"cuda:0 f32[1, 512, 4096]\"\n", " # t959 = prims.mul(t949, t958) # t959: \"cuda:0 f32[1, 512, 4096]\"\n", " # t963 = prims.convert_element_type(t961, dtypes.float32) # t963: \"cuda:0 f32[1, 512, 4096]\"\n", " # t964 = prims.mul(t959, t963) # t964: \"cuda:0 f32[1, 512, 4096]\"\n", " # t965 = prims.convert_element_type(t964, dtypes.bfloat16) # t965: \"cuda:0 bf16[1, 512, 4096]\"\n", " t967 = torch.nn.functional.linear(t965, t42, None) # t967: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t967 = ltorch.linear(t965, t42, None) # t967: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t967 = prims.linear(t965, t42, None) # t967: \"cuda:0 bf16[1, 512, 11008]\"\n", " t966 = torch.nn.functional.linear(t965, t26, None) # t966: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t966 = ltorch.linear(t965, t26, None) # t966: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t966 = prims.linear(t965, t26, None) # t966: \"cuda:0 bf16[1, 512, 11008]\"\n", " [t981] = nvFusion39(t966, t967)\n", " # t968 = prims.convert_element_type(t966, dtypes.float32) # t968: \"cuda:0 f32[1, 512, 11008]\"\n", " # t969 = prims.neg(t968) # t969: \"cuda:0 f32[1, 512, 11008]\"\n", " # t970 = prims.exp(t969) # t970: \"cuda:0 f32[1, 512, 11008]\"\n", " # t971 = prims.add(1.0, t970) # t971: \"cuda:0 f32[1, 512, 11008]\"\n", " # t972 = prims.reciprocal(t971) # t972: \"cuda:0 f32[1, 512, 11008]\"\n", " # t976 = prims.mul(t968, t972) # t976: \"cuda:0 f32[1, 512, 11008]\"\n", " # t979 = prims.convert_element_type(t967, dtypes.float32) # t979: \"cuda:0 f32[1, 512, 11008]\"\n", " # t980 = prims.mul(t976, t979) # t980: \"cuda:0 f32[1, 512, 11008]\"\n", " # t981 = prims.convert_element_type(t980, dtypes.bfloat16) # t981: \"cuda:0 bf16[1, 512, 11008]\"\n", " t982 = torch.nn.functional.linear(t981, t100, None) # t982: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t982 = ltorch.linear(t981, t100, None) # t982: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t982 = prims.linear(t981, t100, None) # t982: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t1001, t986, t993] = nvFusion40(t950, t982, t997)\n", " # t984 = prims.convert_element_type(t950, dtypes.float32) # t984: \"cuda:0 
f32[1, 512, 4096]\"\n", " # t983 = prims.convert_element_type(t982, dtypes.float32) # t983: \"cuda:0 f32[1, 512, 4096]\"\n", " # t985 = prims.add(t983, t984) # t985: \"cuda:0 f32[1, 512, 4096]\"\n", " # t986 = prims.convert_element_type(t985, dtypes.bfloat16) # t986: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t988 = prims.mul(t985, t985) # t988: \"cuda:0 f32[1, 512, 4096]\"\n", " # t989 = prims.sum(t988, (2,)) # t989: \"cuda:0 f32[1, 512]\"\n", " # t990 = prims.broadcast_in_dim(t989, [1, 512, 1], [0, 1]) # t990: \"cuda:0 f32[1, 512, 1]\"\n", " # t991 = prims.div(t990, 4096.0) # t991: \"cuda:0 f32[1, 512, 1]\"\n", " # t992 = prims.add(t991, 1e-05) # t992: \"cuda:0 f32[1, 512, 1]\"\n", " # t993 = prims.rsqrt(t992) # t993: \"cuda:0 f32[1, 512, 1]\"\n", " # t994 = prims.broadcast_in_dim(t993, (1, 512, 4096), (0, 1, 2)) # t994: \"cuda:0 f32[1, 512, 4096]\"\n", " # t995 = prims.mul(t985, t994) # t995: \"cuda:0 f32[1, 512, 4096]\"\n", " # t999 = prims.convert_element_type(t997, dtypes.float32) # t999: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1000 = prims.mul(t995, t999) # t1000: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1001 = prims.convert_element_type(t1000, dtypes.bfloat16) # t1001: \"cuda:0 bf16[1, 512, 4096]\"\n", " t1002 = torch.nn.functional.linear(t1001, t11, None) # t1002: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t1002 = ltorch.linear(t1001, t11, None) # t1002: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t1002 = prims.linear(t1001, t11, None) # t1002: \"cuda:0 bf16[1, 512, 12288]\"\n", " t1003 = torch.reshape(t1002, (1, 512, 32, 3, 128)) # t1003: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t1003 = ltorch.reshape(t1002, (1, 512, 32, 3, 128)) # t1003: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t1003 = prims.reshape(t1002, (1, 512, 32, 3, 128)) # t1003: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " del t1002\n", " t1004 = torch.permute(t1003, (0, 2, 3, 1, 4)) # t1004: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t1004 = ltorch.permute(t1003, (0, 2, 3, 1, 4)) # t1004: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t1004 = prims.transpose(t1003, (0, 2, 3, 1, 4)) # t1004: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " del t1003\n", " (t1005, t1006, t1007) = torch.split(t1004, (1, 1, 1), 2)\n", " # (t1005, t1006, t1007) = ltorch.split(t1004, (1, 1, 1), 2)\n", " # t1005 = prims.slice_prim(t1004, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1005: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t1006 = prims.slice_prim(t1004, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1006: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t1007 = prims.slice_prim(t1004, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1007: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " del t1004\n", " t1008 = torch.reshape(t1005, (1, 32, 512, 128)) # t1008: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1008 = ltorch.reshape(t1005, (1, 32, 512, 128)) # t1008: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1008 = prims.reshape(t1005, (1, 32, 512, 128)) # t1008: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1005\n", " t1009 = torch.reshape(t1006, (1, 32, 512, 128)) # t1009: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1009 = ltorch.reshape(t1006, (1, 32, 512, 128)) # t1009: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1009 = prims.reshape(t1006, (1, 32, 512, 128)) # t1009: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1006\n", " t1010 = torch.reshape(t1007, (1, 32, 512, 128)) # t1010: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1010 = ltorch.reshape(t1007, (1, 32, 512, 128)) # t1010: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1010 = prims.reshape(t1007, (1, 
32, 512, 128)) # t1010: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1007\n", " t1026 = torch_slice_prim_impl(t1009, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1026: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " t1041 = torch_slice_prim_impl(t1008, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1041: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " t1043 = torch_slice_prim_impl(t1009, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1043: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " del t1009\n", " t1011 = torch_slice_prim_impl(t1008, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1011: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1008\n", " t1027 = torch_slice_prim_impl(t1026, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1027: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t1028 = torch_slice_prim_impl(t1026, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1028: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t1013 = torch_slice_prim_impl(t1011, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1013: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t1012 = torch_slice_prim_impl(t1011, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1012: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " [t1016, t1031] = nvFusion41(t1011, t1013, t1026, t1028)\n", " # t1014 = prims.convert_element_type(t1013, dtypes.float32) # t1014: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1015 = prims.neg(t1014) # t1015: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1016 = prims.convert_element_type(t1015, dtypes.bfloat16) # t1016: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " # t1029 = prims.convert_element_type(t1028, dtypes.float32) # t1029: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1030 = prims.neg(t1029) # t1030: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1031 = prims.convert_element_type(t1030, dtypes.bfloat16) # t1031: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " del t1013, t1028\n", " t1032 = torch.cat((t1031, t1027), -1) # t1032: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1032 = ltorch.cat((t1031, t1027), -1) # t1032: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1032 = prims.cat((t1031, t1027), -1) # t1032: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1031, t1027\n", " t1017 = torch.cat((t1016, t1012), -1) # t1017: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1017 = ltorch.cat((t1016, t1012), -1) # t1017: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1017 = prims.cat((t1016, t1012), -1) # t1017: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1016, t1012\n", " [t1025, t1040] = nvFusion42(t1011, t1017, t1026, t1032, t154, t157)\n", " # t1019 = prims.convert_element_type(t1011, dtypes.float32) # t1019: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1034 = prims.convert_element_type(t1026, dtypes.float32) # t1034: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1020 = prims.mul(t1019, t154) # t1020: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1022 = prims.convert_element_type(t1017, dtypes.float32) # t1022: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1023 = prims.mul(t1022, t157) # t1023: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1024 = prims.add(t1020, t1023) # t1024: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1025 = prims.convert_element_type(t1024, dtypes.bfloat16) # t1025: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1035 = prims.mul(t1034, t154) # t1035: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1037 = prims.convert_element_type(t1032, dtypes.float32) # t1037: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1038 = prims.mul(t1037, t157) # t1038: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1039 = prims.add(t1035, t1038) # t1039: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1040 = prims.convert_element_type(t1039, dtypes.bfloat16) # t1040: 
\"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1011, t1017, t1026, t1032\n", " t1042 = torch.cat((t1025, t1041), -1) # t1042: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1042 = ltorch.cat((t1025, t1041), -1) # t1042: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1042 = prims.cat((t1025, t1041), -1) # t1042: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1025, t1041\n", " t1044 = torch.cat((t1040, t1043), -1) # t1044: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1044 = ltorch.cat((t1040, t1043), -1) # t1044: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1044 = prims.cat((t1040, t1043), -1) # t1044: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1040, t1043\n", " (t1045, t1046, t1047, t1048, _, _, t1049, t1050, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1042, t1044, t1010, 0.0, True, scale=0.08838834764831843)\n", " t1052 = torch.permute(t1045, (0, 2, 1, 3)) # t1052: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t1052 = ltorch.permute(t1045, (0, 2, 1, 3)) # t1052: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t1052 = prims.transpose(t1045, (0, 2, 1, 3)) # t1052: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " t1053 = torch.reshape(t1052, (1, 512, 4096)) # t1053: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1053 = ltorch.reshape(t1052, (1, 512, 4096)) # t1053: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1053 = prims.reshape(t1052, (1, 512, 4096)) # t1053: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t1052\n", " t1054 = torch.nn.functional.linear(t1053, t101, None) # t1054: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1054 = ltorch.linear(t1053, t101, None) # t1054: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1054 = prims.linear(t1053, t101, None) # t1054: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t1058, t1065, t1073] = nvFusion43(t1054, t1069, t986)\n", " # t1056 = prims.convert_element_type(t986, dtypes.float32) # t1056: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1055 = prims.convert_element_type(t1054, dtypes.float32) # t1055: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1057 = prims.add(t1055, t1056) # t1057: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1058 = prims.convert_element_type(t1057, dtypes.bfloat16) # t1058: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1060 = prims.mul(t1057, t1057) # t1060: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1061 = prims.sum(t1060, (2,)) # t1061: \"cuda:0 f32[1, 512]\"\n", " # t1062 = prims.broadcast_in_dim(t1061, [1, 512, 1], [0, 1]) # t1062: \"cuda:0 f32[1, 512, 1]\"\n", " # t1063 = prims.div(t1062, 4096.0) # t1063: \"cuda:0 f32[1, 512, 1]\"\n", " # t1064 = prims.add(t1063, 1e-05) # t1064: \"cuda:0 f32[1, 512, 1]\"\n", " # t1065 = prims.rsqrt(t1064) # t1065: \"cuda:0 f32[1, 512, 1]\"\n", " # t1066 = prims.broadcast_in_dim(t1065, (1, 512, 4096), (0, 1, 2)) # t1066: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1067 = prims.mul(t1057, t1066) # t1067: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1071 = prims.convert_element_type(t1069, dtypes.float32) # t1071: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1072 = prims.mul(t1067, t1071) # t1072: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1073 = prims.convert_element_type(t1072, dtypes.bfloat16) # t1073: \"cuda:0 bf16[1, 512, 4096]\"\n", " t1074 = torch.nn.functional.linear(t1073, t27, None) # t1074: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1074 = ltorch.linear(t1073, t27, None) # t1074: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1074 = prims.linear(t1073, t27, None) # t1074: \"cuda:0 bf16[1, 512, 11008]\"\n", " t1075 = torch.nn.functional.linear(t1073, t43, None) # t1075: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1075 = ltorch.linear(t1073, t43, None) # t1075: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1075 = 
prims.linear(t1073, t43, None) # t1075: \"cuda:0 bf16[1, 512, 11008]\"\n", " [t1089] = nvFusion44(t1074, t1075)\n", " # t1076 = prims.convert_element_type(t1074, dtypes.float32) # t1076: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1077 = prims.neg(t1076) # t1077: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1078 = prims.exp(t1077) # t1078: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1079 = prims.add(1.0, t1078) # t1079: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1080 = prims.reciprocal(t1079) # t1080: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1084 = prims.mul(t1076, t1080) # t1084: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1087 = prims.convert_element_type(t1075, dtypes.float32) # t1087: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1088 = prims.mul(t1084, t1087) # t1088: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1089 = prims.convert_element_type(t1088, dtypes.bfloat16) # t1089: \"cuda:0 bf16[1, 512, 11008]\"\n", " t1090 = torch.nn.functional.linear(t1089, t102, None) # t1090: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1090 = ltorch.linear(t1089, t102, None) # t1090: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1090 = prims.linear(t1089, t102, None) # t1090: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t1094, t1101, t1109] = nvFusion45(t1058, t1090, t1105)\n", " # t1092 = prims.convert_element_type(t1058, dtypes.float32) # t1092: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1091 = prims.convert_element_type(t1090, dtypes.float32) # t1091: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1093 = prims.add(t1091, t1092) # t1093: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1094 = prims.convert_element_type(t1093, dtypes.bfloat16) # t1094: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1096 = prims.mul(t1093, t1093) # t1096: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1097 = prims.sum(t1096, (2,)) # t1097: \"cuda:0 f32[1, 512]\"\n", " # t1098 = prims.broadcast_in_dim(t1097, [1, 512, 1], [0, 1]) # t1098: \"cuda:0 f32[1, 512, 1]\"\n", " # t1099 = prims.div(t1098, 4096.0) # t1099: \"cuda:0 f32[1, 512, 1]\"\n", " # t1100 = prims.add(t1099, 1e-05) # t1100: \"cuda:0 f32[1, 512, 1]\"\n", " # t1101 = prims.rsqrt(t1100) # t1101: \"cuda:0 f32[1, 512, 1]\"\n", " # t1102 = prims.broadcast_in_dim(t1101, (1, 512, 4096), (0, 1, 2)) # t1102: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1103 = prims.mul(t1093, t1102) # t1103: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1107 = prims.convert_element_type(t1105, dtypes.float32) # t1107: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1108 = prims.mul(t1103, t1107) # t1108: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1109 = prims.convert_element_type(t1108, dtypes.bfloat16) # t1109: \"cuda:0 bf16[1, 512, 4096]\"\n", " t1110 = torch.nn.functional.linear(t1109, t12, None) # t1110: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t1110 = ltorch.linear(t1109, t12, None) # t1110: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t1110 = prims.linear(t1109, t12, None) # t1110: \"cuda:0 bf16[1, 512, 12288]\"\n", " t1111 = torch.reshape(t1110, (1, 512, 32, 3, 128)) # t1111: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t1111 = ltorch.reshape(t1110, (1, 512, 32, 3, 128)) # t1111: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t1111 = prims.reshape(t1110, (1, 512, 32, 3, 128)) # t1111: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " del t1110\n", " t1112 = torch.permute(t1111, (0, 2, 3, 1, 4)) # t1112: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t1112 = ltorch.permute(t1111, (0, 2, 3, 1, 4)) # t1112: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t1112 = prims.transpose(t1111, (0, 2, 3, 1, 4)) # t1112: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " del t1111\n", " (t1113, t1114, t1115) = torch.split(t1112, (1, 1, 1), 2)\n", " # 
(t1113, t1114, t1115) = ltorch.split(t1112, (1, 1, 1), 2)\n", " # t1113 = prims.slice_prim(t1112, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1113: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t1114 = prims.slice_prim(t1112, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1114: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t1115 = prims.slice_prim(t1112, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1115: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " del t1112\n", " t1116 = torch.reshape(t1113, (1, 32, 512, 128)) # t1116: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1116 = ltorch.reshape(t1113, (1, 32, 512, 128)) # t1116: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1116 = prims.reshape(t1113, (1, 32, 512, 128)) # t1116: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1113\n", " t1117 = torch.reshape(t1114, (1, 32, 512, 128)) # t1117: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1117 = ltorch.reshape(t1114, (1, 32, 512, 128)) # t1117: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1117 = prims.reshape(t1114, (1, 32, 512, 128)) # t1117: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1114\n", " t1118 = torch.reshape(t1115, (1, 32, 512, 128)) # t1118: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1118 = ltorch.reshape(t1115, (1, 32, 512, 128)) # t1118: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1118 = prims.reshape(t1115, (1, 32, 512, 128)) # t1118: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1115\n", " t1119 = torch_slice_prim_impl(t1116, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1119: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " t1134 = torch_slice_prim_impl(t1117, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1134: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " t1149 = torch_slice_prim_impl(t1116, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1149: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " del t1116\n", " t1151 = torch_slice_prim_impl(t1117, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1151: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " del t1117\n", " t1120 = torch_slice_prim_impl(t1119, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1120: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t1121 = torch_slice_prim_impl(t1119, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1121: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t1136 = torch_slice_prim_impl(t1134, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1136: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t1135 = torch_slice_prim_impl(t1134, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1135: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " [t1124, t1139] = nvFusion46(t1119, t1121, t1134, t1136)\n", " # t1122 = prims.convert_element_type(t1121, dtypes.float32) # t1122: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1123 = prims.neg(t1122) # t1123: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1124 = prims.convert_element_type(t1123, dtypes.bfloat16) # t1124: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " # t1137 = prims.convert_element_type(t1136, dtypes.float32) # t1137: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1138 = prims.neg(t1137) # t1138: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1139 = prims.convert_element_type(t1138, dtypes.bfloat16) # t1139: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " del t1121, t1136\n", " t1125 = torch.cat((t1124, t1120), -1) # t1125: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1125 = ltorch.cat((t1124, t1120), -1) # t1125: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1125 = prims.cat((t1124, t1120), -1) # t1125: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1124, t1120\n", " t1140 = torch.cat((t1139, t1135), -1) # t1140: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1140 = 
ltorch.cat((t1139, t1135), -1) # t1140: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1140 = prims.cat((t1139, t1135), -1) # t1140: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1139, t1135\n", " [t1133, t1148] = nvFusion47(t1119, t1125, t1134, t1140, t154, t157)\n", " # t1127 = prims.convert_element_type(t1119, dtypes.float32) # t1127: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1142 = prims.convert_element_type(t1134, dtypes.float32) # t1142: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1128 = prims.mul(t1127, t154) # t1128: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1130 = prims.convert_element_type(t1125, dtypes.float32) # t1130: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1131 = prims.mul(t1130, t157) # t1131: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1132 = prims.add(t1128, t1131) # t1132: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1133 = prims.convert_element_type(t1132, dtypes.bfloat16) # t1133: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1143 = prims.mul(t1142, t154) # t1143: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1145 = prims.convert_element_type(t1140, dtypes.float32) # t1145: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1146 = prims.mul(t1145, t157) # t1146: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1147 = prims.add(t1143, t1146) # t1147: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1148 = prims.convert_element_type(t1147, dtypes.bfloat16) # t1148: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1119, t1125, t1134, t1140\n", " t1152 = torch.cat((t1148, t1151), -1) # t1152: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1152 = ltorch.cat((t1148, t1151), -1) # t1152: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1152 = prims.cat((t1148, t1151), -1) # t1152: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1148, t1151\n", " t1150 = torch.cat((t1133, t1149), -1) # t1150: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1150 = ltorch.cat((t1133, t1149), -1) # t1150: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1150 = prims.cat((t1133, t1149), -1) # t1150: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1133, t1149\n", " (t1153, t1154, t1155, t1156, _, _, t1157, t1158, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1150, t1152, t1118, 0.0, True, scale=0.08838834764831843)\n", " t1160 = torch.permute(t1153, (0, 2, 1, 3)) # t1160: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t1160 = ltorch.permute(t1153, (0, 2, 1, 3)) # t1160: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t1160 = prims.transpose(t1153, (0, 2, 1, 3)) # t1160: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " t1161 = torch.reshape(t1160, (1, 512, 4096)) # t1161: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1161 = ltorch.reshape(t1160, (1, 512, 4096)) # t1161: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1161 = prims.reshape(t1160, (1, 512, 4096)) # t1161: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t1160\n", " t1162 = torch.nn.functional.linear(t1161, t103, None) # t1162: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1162 = ltorch.linear(t1161, t103, None) # t1162: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1162 = prims.linear(t1161, t103, None) # t1162: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t1166, t1173, t1181] = nvFusion48(t1094, t1162, t1177)\n", " # t1164 = prims.convert_element_type(t1094, dtypes.float32) # t1164: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1163 = prims.convert_element_type(t1162, dtypes.float32) # t1163: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1165 = prims.add(t1163, t1164) # t1165: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1166 = prims.convert_element_type(t1165, dtypes.bfloat16) # t1166: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1168 = prims.mul(t1165, t1165) # t1168: \"cuda:0 f32[1, 512, 4096]\"\n", 
" # t1169 = prims.sum(t1168, (2,)) # t1169: \"cuda:0 f32[1, 512]\"\n", " # t1170 = prims.broadcast_in_dim(t1169, [1, 512, 1], [0, 1]) # t1170: \"cuda:0 f32[1, 512, 1]\"\n", " # t1171 = prims.div(t1170, 4096.0) # t1171: \"cuda:0 f32[1, 512, 1]\"\n", " # t1172 = prims.add(t1171, 1e-05) # t1172: \"cuda:0 f32[1, 512, 1]\"\n", " # t1173 = prims.rsqrt(t1172) # t1173: \"cuda:0 f32[1, 512, 1]\"\n", " # t1174 = prims.broadcast_in_dim(t1173, (1, 512, 4096), (0, 1, 2)) # t1174: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1175 = prims.mul(t1165, t1174) # t1175: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1179 = prims.convert_element_type(t1177, dtypes.float32) # t1179: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1180 = prims.mul(t1175, t1179) # t1180: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1181 = prims.convert_element_type(t1180, dtypes.bfloat16) # t1181: \"cuda:0 bf16[1, 512, 4096]\"\n", " t1182 = torch.nn.functional.linear(t1181, t28, None) # t1182: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1182 = ltorch.linear(t1181, t28, None) # t1182: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1182 = prims.linear(t1181, t28, None) # t1182: \"cuda:0 bf16[1, 512, 11008]\"\n", " t1183 = torch.nn.functional.linear(t1181, t44, None) # t1183: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1183 = ltorch.linear(t1181, t44, None) # t1183: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1183 = prims.linear(t1181, t44, None) # t1183: \"cuda:0 bf16[1, 512, 11008]\"\n", " [t1197] = nvFusion49(t1182, t1183)\n", " # t1184 = prims.convert_element_type(t1182, dtypes.float32) # t1184: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1185 = prims.neg(t1184) # t1185: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1186 = prims.exp(t1185) # t1186: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1187 = prims.add(1.0, t1186) # t1187: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1188 = prims.reciprocal(t1187) # t1188: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1192 = prims.mul(t1184, t1188) # t1192: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1195 = prims.convert_element_type(t1183, dtypes.float32) # t1195: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1196 = prims.mul(t1192, t1195) # t1196: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1197 = prims.convert_element_type(t1196, dtypes.bfloat16) # t1197: \"cuda:0 bf16[1, 512, 11008]\"\n", " t1198 = torch.nn.functional.linear(t1197, t104, None) # t1198: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1198 = ltorch.linear(t1197, t104, None) # t1198: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1198 = prims.linear(t1197, t104, None) # t1198: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t1202, t1209, t1217] = nvFusion50(t1166, t1198, t1213)\n", " # t1200 = prims.convert_element_type(t1166, dtypes.float32) # t1200: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1199 = prims.convert_element_type(t1198, dtypes.float32) # t1199: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1201 = prims.add(t1199, t1200) # t1201: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1202 = prims.convert_element_type(t1201, dtypes.bfloat16) # t1202: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1204 = prims.mul(t1201, t1201) # t1204: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1205 = prims.sum(t1204, (2,)) # t1205: \"cuda:0 f32[1, 512]\"\n", " # t1206 = prims.broadcast_in_dim(t1205, [1, 512, 1], [0, 1]) # t1206: \"cuda:0 f32[1, 512, 1]\"\n", " # t1207 = prims.div(t1206, 4096.0) # t1207: \"cuda:0 f32[1, 512, 1]\"\n", " # t1208 = prims.add(t1207, 1e-05) # t1208: \"cuda:0 f32[1, 512, 1]\"\n", " # t1209 = prims.rsqrt(t1208) # t1209: \"cuda:0 f32[1, 512, 1]\"\n", " # t1210 = prims.broadcast_in_dim(t1209, (1, 512, 4096), (0, 1, 2)) # t1210: \"cuda:0 f32[1, 512, 4096]\"\n", " # 
t1211 = prims.mul(t1201, t1210) # t1211: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1215 = prims.convert_element_type(t1213, dtypes.float32) # t1215: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1216 = prims.mul(t1211, t1215) # t1216: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1217 = prims.convert_element_type(t1216, dtypes.bfloat16) # t1217: \"cuda:0 bf16[1, 512, 4096]\"\n", " t1218 = torch.nn.functional.linear(t1217, t13, None) # t1218: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t1218 = ltorch.linear(t1217, t13, None) # t1218: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t1218 = prims.linear(t1217, t13, None) # t1218: \"cuda:0 bf16[1, 512, 12288]\"\n", " t1219 = torch.reshape(t1218, (1, 512, 32, 3, 128)) # t1219: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t1219 = ltorch.reshape(t1218, (1, 512, 32, 3, 128)) # t1219: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t1219 = prims.reshape(t1218, (1, 512, 32, 3, 128)) # t1219: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " del t1218\n", " t1220 = torch.permute(t1219, (0, 2, 3, 1, 4)) # t1220: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t1220 = ltorch.permute(t1219, (0, 2, 3, 1, 4)) # t1220: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t1220 = prims.transpose(t1219, (0, 2, 3, 1, 4)) # t1220: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " del t1219\n", " (t1221, t1222, t1223) = torch.split(t1220, (1, 1, 1), 2)\n", " # (t1221, t1222, t1223) = ltorch.split(t1220, (1, 1, 1), 2)\n", " # t1221 = prims.slice_prim(t1220, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1221: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t1222 = prims.slice_prim(t1220, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1222: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t1223 = prims.slice_prim(t1220, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1223: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " del t1220\n", " t1224 = torch.reshape(t1221, (1, 32, 512, 128)) # t1224: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1224 = ltorch.reshape(t1221, (1, 32, 512, 128)) # t1224: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1224 = prims.reshape(t1221, (1, 32, 512, 128)) # t1224: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1221\n", " t1225 = torch.reshape(t1222, (1, 32, 512, 128)) # t1225: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1225 = ltorch.reshape(t1222, (1, 32, 512, 128)) # t1225: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1225 = prims.reshape(t1222, (1, 32, 512, 128)) # t1225: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1222\n", " t1226 = torch.reshape(t1223, (1, 32, 512, 128)) # t1226: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1226 = ltorch.reshape(t1223, (1, 32, 512, 128)) # t1226: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1226 = prims.reshape(t1223, (1, 32, 512, 128)) # t1226: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1223\n", " t1227 = torch_slice_prim_impl(t1224, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1227: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " t1242 = torch_slice_prim_impl(t1225, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1242: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " t1257 = torch_slice_prim_impl(t1224, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1257: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " del t1224\n", " t1259 = torch_slice_prim_impl(t1225, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1259: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " del t1225\n", " t1228 = torch_slice_prim_impl(t1227, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1228: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t1229 = torch_slice_prim_impl(t1227, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1229: 
\"cuda:0 bf16[1, 32, 512, 64]\"\n", " t1243 = torch_slice_prim_impl(t1242, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1243: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t1244 = torch_slice_prim_impl(t1242, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1244: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " [t1232, t1247] = nvFusion51(t1227, t1229, t1242, t1244)\n", " # t1230 = prims.convert_element_type(t1229, dtypes.float32) # t1230: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1231 = prims.neg(t1230) # t1231: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1232 = prims.convert_element_type(t1231, dtypes.bfloat16) # t1232: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " # t1245 = prims.convert_element_type(t1244, dtypes.float32) # t1245: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1246 = prims.neg(t1245) # t1246: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1247 = prims.convert_element_type(t1246, dtypes.bfloat16) # t1247: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " del t1229, t1244\n", " t1233 = torch.cat((t1232, t1228), -1) # t1233: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1233 = ltorch.cat((t1232, t1228), -1) # t1233: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1233 = prims.cat((t1232, t1228), -1) # t1233: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1232, t1228\n", " t1248 = torch.cat((t1247, t1243), -1) # t1248: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1248 = ltorch.cat((t1247, t1243), -1) # t1248: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1248 = prims.cat((t1247, t1243), -1) # t1248: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1247, t1243\n", " [t1241, t1256] = nvFusion52(t1227, t1233, t1242, t1248, t154, t157)\n", " # t1235 = prims.convert_element_type(t1227, dtypes.float32) # t1235: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1250 = prims.convert_element_type(t1242, dtypes.float32) # t1250: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1236 = prims.mul(t1235, t154) # t1236: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1238 = prims.convert_element_type(t1233, dtypes.float32) # t1238: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1239 = prims.mul(t1238, t157) # t1239: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1240 = prims.add(t1236, t1239) # t1240: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1241 = prims.convert_element_type(t1240, dtypes.bfloat16) # t1241: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1251 = prims.mul(t1250, t154) # t1251: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1253 = prims.convert_element_type(t1248, dtypes.float32) # t1253: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1254 = prims.mul(t1253, t157) # t1254: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1255 = prims.add(t1251, t1254) # t1255: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1256 = prims.convert_element_type(t1255, dtypes.bfloat16) # t1256: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1227, t1233, t1242, t1248\n", " t1258 = torch.cat((t1241, t1257), -1) # t1258: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1258 = ltorch.cat((t1241, t1257), -1) # t1258: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1258 = prims.cat((t1241, t1257), -1) # t1258: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1241, t1257\n", " t1260 = torch.cat((t1256, t1259), -1) # t1260: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1260 = ltorch.cat((t1256, t1259), -1) # t1260: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1260 = prims.cat((t1256, t1259), -1) # t1260: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1256, t1259\n", " (t1261, t1262, t1263, t1264, _, _, t1265, t1266, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1258, t1260, t1226, 0.0, True, scale=0.08838834764831843)\n", " t1268 = torch.permute(t1261, (0, 
2, 1, 3)) # t1268: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t1268 = ltorch.permute(t1261, (0, 2, 1, 3)) # t1268: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t1268 = prims.transpose(t1261, (0, 2, 1, 3)) # t1268: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " t1269 = torch.reshape(t1268, (1, 512, 4096)) # t1269: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1269 = ltorch.reshape(t1268, (1, 512, 4096)) # t1269: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1269 = prims.reshape(t1268, (1, 512, 4096)) # t1269: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t1268\n", " t1270 = torch.nn.functional.linear(t1269, t105, None) # t1270: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1270 = ltorch.linear(t1269, t105, None) # t1270: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1270 = prims.linear(t1269, t105, None) # t1270: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t1274, t1281, t1289] = nvFusion53(t1202, t1270, t1285)\n", " # t1272 = prims.convert_element_type(t1202, dtypes.float32) # t1272: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1271 = prims.convert_element_type(t1270, dtypes.float32) # t1271: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1273 = prims.add(t1271, t1272) # t1273: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1274 = prims.convert_element_type(t1273, dtypes.bfloat16) # t1274: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1276 = prims.mul(t1273, t1273) # t1276: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1277 = prims.sum(t1276, (2,)) # t1277: \"cuda:0 f32[1, 512]\"\n", " # t1278 = prims.broadcast_in_dim(t1277, [1, 512, 1], [0, 1]) # t1278: \"cuda:0 f32[1, 512, 1]\"\n", " # t1279 = prims.div(t1278, 4096.0) # t1279: \"cuda:0 f32[1, 512, 1]\"\n", " # t1280 = prims.add(t1279, 1e-05) # t1280: \"cuda:0 f32[1, 512, 1]\"\n", " # t1281 = prims.rsqrt(t1280) # t1281: \"cuda:0 f32[1, 512, 1]\"\n", " # t1282 = prims.broadcast_in_dim(t1281, (1, 512, 4096), (0, 1, 2)) # t1282: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1283 = prims.mul(t1273, t1282) # t1283: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1287 = prims.convert_element_type(t1285, dtypes.float32) # t1287: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1288 = prims.mul(t1283, t1287) # t1288: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1289 = prims.convert_element_type(t1288, dtypes.bfloat16) # t1289: \"cuda:0 bf16[1, 512, 4096]\"\n", " t1290 = torch.nn.functional.linear(t1289, t29, None) # t1290: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1290 = ltorch.linear(t1289, t29, None) # t1290: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1290 = prims.linear(t1289, t29, None) # t1290: \"cuda:0 bf16[1, 512, 11008]\"\n", " t1291 = torch.nn.functional.linear(t1289, t45, None) # t1291: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1291 = ltorch.linear(t1289, t45, None) # t1291: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1291 = prims.linear(t1289, t45, None) # t1291: \"cuda:0 bf16[1, 512, 11008]\"\n", " [t1305] = nvFusion54(t1290, t1291)\n", " # t1292 = prims.convert_element_type(t1290, dtypes.float32) # t1292: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1293 = prims.neg(t1292) # t1293: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1294 = prims.exp(t1293) # t1294: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1295 = prims.add(1.0, t1294) # t1295: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1296 = prims.reciprocal(t1295) # t1296: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1300 = prims.mul(t1292, t1296) # t1300: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1303 = prims.convert_element_type(t1291, dtypes.float32) # t1303: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1304 = prims.mul(t1300, t1303) # t1304: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1305 = prims.convert_element_type(t1304, dtypes.bfloat16) # t1305: \"cuda:0 
bf16[1, 512, 11008]\"\n", " t1306 = torch.nn.functional.linear(t1305, t106, None) # t1306: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1306 = ltorch.linear(t1305, t106, None) # t1306: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1306 = prims.linear(t1305, t106, None) # t1306: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t1310, t1317, t1325] = nvFusion55(t1274, t1306, t1321)\n", " # t1308 = prims.convert_element_type(t1274, dtypes.float32) # t1308: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1307 = prims.convert_element_type(t1306, dtypes.float32) # t1307: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1309 = prims.add(t1307, t1308) # t1309: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1310 = prims.convert_element_type(t1309, dtypes.bfloat16) # t1310: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1312 = prims.mul(t1309, t1309) # t1312: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1313 = prims.sum(t1312, (2,)) # t1313: \"cuda:0 f32[1, 512]\"\n", " # t1314 = prims.broadcast_in_dim(t1313, [1, 512, 1], [0, 1]) # t1314: \"cuda:0 f32[1, 512, 1]\"\n", " # t1315 = prims.div(t1314, 4096.0) # t1315: \"cuda:0 f32[1, 512, 1]\"\n", " # t1316 = prims.add(t1315, 1e-05) # t1316: \"cuda:0 f32[1, 512, 1]\"\n", " # t1317 = prims.rsqrt(t1316) # t1317: \"cuda:0 f32[1, 512, 1]\"\n", " # t1318 = prims.broadcast_in_dim(t1317, (1, 512, 4096), (0, 1, 2)) # t1318: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1319 = prims.mul(t1309, t1318) # t1319: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1323 = prims.convert_element_type(t1321, dtypes.float32) # t1323: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1324 = prims.mul(t1319, t1323) # t1324: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1325 = prims.convert_element_type(t1324, dtypes.bfloat16) # t1325: \"cuda:0 bf16[1, 512, 4096]\"\n", " t1326 = torch.nn.functional.linear(t1325, t14, None) # t1326: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t1326 = ltorch.linear(t1325, t14, None) # t1326: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t1326 = prims.linear(t1325, t14, None) # t1326: \"cuda:0 bf16[1, 512, 12288]\"\n", " t1327 = torch.reshape(t1326, (1, 512, 32, 3, 128)) # t1327: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t1327 = ltorch.reshape(t1326, (1, 512, 32, 3, 128)) # t1327: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t1327 = prims.reshape(t1326, (1, 512, 32, 3, 128)) # t1327: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " del t1326\n", " t1328 = torch.permute(t1327, (0, 2, 3, 1, 4)) # t1328: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t1328 = ltorch.permute(t1327, (0, 2, 3, 1, 4)) # t1328: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t1328 = prims.transpose(t1327, (0, 2, 3, 1, 4)) # t1328: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " del t1327\n", " (t1329, t1330, t1331) = torch.split(t1328, (1, 1, 1), 2)\n", " # (t1329, t1330, t1331) = ltorch.split(t1328, (1, 1, 1), 2)\n", " # t1329 = prims.slice_prim(t1328, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1329: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t1330 = prims.slice_prim(t1328, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1330: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t1331 = prims.slice_prim(t1328, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1331: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " del t1328\n", " t1332 = torch.reshape(t1329, (1, 32, 512, 128)) # t1332: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1332 = ltorch.reshape(t1329, (1, 32, 512, 128)) # t1332: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1332 = prims.reshape(t1329, (1, 32, 512, 128)) # t1332: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1329\n", " t1333 = torch.reshape(t1330, (1, 32, 512, 128)) # t1333: 
\"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1333 = ltorch.reshape(t1330, (1, 32, 512, 128)) # t1333: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1333 = prims.reshape(t1330, (1, 32, 512, 128)) # t1333: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1330\n", " t1334 = torch.reshape(t1331, (1, 32, 512, 128)) # t1334: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1334 = ltorch.reshape(t1331, (1, 32, 512, 128)) # t1334: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1334 = prims.reshape(t1331, (1, 32, 512, 128)) # t1334: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1331\n", " t1335 = torch_slice_prim_impl(t1332, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1335: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " t1350 = torch_slice_prim_impl(t1333, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1350: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " t1365 = torch_slice_prim_impl(t1332, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1365: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " del t1332\n", " t1367 = torch_slice_prim_impl(t1333, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1367: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " del t1333\n", " t1336 = torch_slice_prim_impl(t1335, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1336: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t1337 = torch_slice_prim_impl(t1335, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1337: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t1351 = torch_slice_prim_impl(t1350, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1351: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t1352 = torch_slice_prim_impl(t1350, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1352: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " [t1340, t1355] = nvFusion56(t1335, t1337, t1350, t1352)\n", " # t1338 = prims.convert_element_type(t1337, dtypes.float32) # t1338: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1339 = prims.neg(t1338) # t1339: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1340 = prims.convert_element_type(t1339, dtypes.bfloat16) # t1340: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " # t1353 = prims.convert_element_type(t1352, dtypes.float32) # t1353: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1354 = prims.neg(t1353) # t1354: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1355 = prims.convert_element_type(t1354, dtypes.bfloat16) # t1355: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " del t1337, t1352\n", " t1341 = torch.cat((t1340, t1336), -1) # t1341: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1341 = ltorch.cat((t1340, t1336), -1) # t1341: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1341 = prims.cat((t1340, t1336), -1) # t1341: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1340, t1336\n", " t1356 = torch.cat((t1355, t1351), -1) # t1356: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1356 = ltorch.cat((t1355, t1351), -1) # t1356: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1356 = prims.cat((t1355, t1351), -1) # t1356: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1355, t1351\n", " [t1349, t1364] = nvFusion57(t1335, t1341, t1350, t1356, t154, t157)\n", " # t1343 = prims.convert_element_type(t1335, dtypes.float32) # t1343: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1358 = prims.convert_element_type(t1350, dtypes.float32) # t1358: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1344 = prims.mul(t1343, t154) # t1344: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1346 = prims.convert_element_type(t1341, dtypes.float32) # t1346: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1347 = prims.mul(t1346, t157) # t1347: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1348 = prims.add(t1344, t1347) # t1348: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1349 = prims.convert_element_type(t1348, 
dtypes.bfloat16) # t1349: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1359 = prims.mul(t1358, t154) # t1359: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1361 = prims.convert_element_type(t1356, dtypes.float32) # t1361: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1362 = prims.mul(t1361, t157) # t1362: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1363 = prims.add(t1359, t1362) # t1363: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1364 = prims.convert_element_type(t1363, dtypes.bfloat16) # t1364: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1335, t1341, t1350, t1356\n", " t1366 = torch.cat((t1349, t1365), -1) # t1366: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1366 = ltorch.cat((t1349, t1365), -1) # t1366: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1366 = prims.cat((t1349, t1365), -1) # t1366: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1349, t1365\n", " t1368 = torch.cat((t1364, t1367), -1) # t1368: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1368 = ltorch.cat((t1364, t1367), -1) # t1368: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1368 = prims.cat((t1364, t1367), -1) # t1368: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1364, t1367\n", " (t1369, t1370, t1371, t1372, _, _, t1373, t1374, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1366, t1368, t1334, 0.0, True, scale=0.08838834764831843)\n", " t1376 = torch.permute(t1369, (0, 2, 1, 3)) # t1376: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t1376 = ltorch.permute(t1369, (0, 2, 1, 3)) # t1376: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t1376 = prims.transpose(t1369, (0, 2, 1, 3)) # t1376: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " t1377 = torch.reshape(t1376, (1, 512, 4096)) # t1377: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1377 = ltorch.reshape(t1376, (1, 512, 4096)) # t1377: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1377 = prims.reshape(t1376, (1, 512, 4096)) # t1377: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t1376\n", " t1378 = torch.nn.functional.linear(t1377, t107, None) # t1378: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1378 = ltorch.linear(t1377, t107, None) # t1378: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1378 = prims.linear(t1377, t107, None) # t1378: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t1382, t1389, t1397] = nvFusion58(t1310, t1378, t1393)\n", " # t1380 = prims.convert_element_type(t1310, dtypes.float32) # t1380: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1379 = prims.convert_element_type(t1378, dtypes.float32) # t1379: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1381 = prims.add(t1379, t1380) # t1381: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1382 = prims.convert_element_type(t1381, dtypes.bfloat16) # t1382: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1384 = prims.mul(t1381, t1381) # t1384: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1385 = prims.sum(t1384, (2,)) # t1385: \"cuda:0 f32[1, 512]\"\n", " # t1386 = prims.broadcast_in_dim(t1385, [1, 512, 1], [0, 1]) # t1386: \"cuda:0 f32[1, 512, 1]\"\n", " # t1387 = prims.div(t1386, 4096.0) # t1387: \"cuda:0 f32[1, 512, 1]\"\n", " # t1388 = prims.add(t1387, 1e-05) # t1388: \"cuda:0 f32[1, 512, 1]\"\n", " # t1389 = prims.rsqrt(t1388) # t1389: \"cuda:0 f32[1, 512, 1]\"\n", " # t1390 = prims.broadcast_in_dim(t1389, (1, 512, 4096), (0, 1, 2)) # t1390: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1391 = prims.mul(t1381, t1390) # t1391: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1395 = prims.convert_element_type(t1393, dtypes.float32) # t1395: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1396 = prims.mul(t1391, t1395) # t1396: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1397 = prims.convert_element_type(t1396, dtypes.bfloat16) # t1397: \"cuda:0 bf16[1, 512, 4096]\"\n", 
" t1398 = torch.nn.functional.linear(t1397, t30, None) # t1398: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1398 = ltorch.linear(t1397, t30, None) # t1398: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1398 = prims.linear(t1397, t30, None) # t1398: \"cuda:0 bf16[1, 512, 11008]\"\n", " t1399 = torch.nn.functional.linear(t1397, t46, None) # t1399: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1399 = ltorch.linear(t1397, t46, None) # t1399: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1399 = prims.linear(t1397, t46, None) # t1399: \"cuda:0 bf16[1, 512, 11008]\"\n", " [t1413] = nvFusion59(t1398, t1399)\n", " # t1400 = prims.convert_element_type(t1398, dtypes.float32) # t1400: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1401 = prims.neg(t1400) # t1401: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1402 = prims.exp(t1401) # t1402: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1403 = prims.add(1.0, t1402) # t1403: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1404 = prims.reciprocal(t1403) # t1404: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1408 = prims.mul(t1400, t1404) # t1408: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1411 = prims.convert_element_type(t1399, dtypes.float32) # t1411: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1412 = prims.mul(t1408, t1411) # t1412: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1413 = prims.convert_element_type(t1412, dtypes.bfloat16) # t1413: \"cuda:0 bf16[1, 512, 11008]\"\n", " t1414 = torch.nn.functional.linear(t1413, t108, None) # t1414: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1414 = ltorch.linear(t1413, t108, None) # t1414: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1414 = prims.linear(t1413, t108, None) # t1414: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t1418, t1425, t1433] = nvFusion60(t1382, t1414, t1429)\n", " # t1416 = prims.convert_element_type(t1382, dtypes.float32) # t1416: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1415 = prims.convert_element_type(t1414, dtypes.float32) # t1415: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1417 = prims.add(t1415, t1416) # t1417: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1418 = prims.convert_element_type(t1417, dtypes.bfloat16) # t1418: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1420 = prims.mul(t1417, t1417) # t1420: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1421 = prims.sum(t1420, (2,)) # t1421: \"cuda:0 f32[1, 512]\"\n", " # t1422 = prims.broadcast_in_dim(t1421, [1, 512, 1], [0, 1]) # t1422: \"cuda:0 f32[1, 512, 1]\"\n", " # t1423 = prims.div(t1422, 4096.0) # t1423: \"cuda:0 f32[1, 512, 1]\"\n", " # t1424 = prims.add(t1423, 1e-05) # t1424: \"cuda:0 f32[1, 512, 1]\"\n", " # t1425 = prims.rsqrt(t1424) # t1425: \"cuda:0 f32[1, 512, 1]\"\n", " # t1426 = prims.broadcast_in_dim(t1425, (1, 512, 4096), (0, 1, 2)) # t1426: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1427 = prims.mul(t1417, t1426) # t1427: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1431 = prims.convert_element_type(t1429, dtypes.float32) # t1431: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1432 = prims.mul(t1427, t1431) # t1432: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1433 = prims.convert_element_type(t1432, dtypes.bfloat16) # t1433: \"cuda:0 bf16[1, 512, 4096]\"\n", " t1434 = torch.nn.functional.linear(t1433, t15, None) # t1434: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t1434 = ltorch.linear(t1433, t15, None) # t1434: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t1434 = prims.linear(t1433, t15, None) # t1434: \"cuda:0 bf16[1, 512, 12288]\"\n", " t1435 = torch.reshape(t1434, (1, 512, 32, 3, 128)) # t1435: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t1435 = ltorch.reshape(t1434, (1, 512, 32, 3, 128)) # t1435: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t1435 = prims.reshape(t1434, 
(1, 512, 32, 3, 128)) # t1435: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " del t1434\n", " t1436 = torch.permute(t1435, (0, 2, 3, 1, 4)) # t1436: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t1436 = ltorch.permute(t1435, (0, 2, 3, 1, 4)) # t1436: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t1436 = prims.transpose(t1435, (0, 2, 3, 1, 4)) # t1436: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " del t1435\n", " (t1437, t1438, t1439) = torch.split(t1436, (1, 1, 1), 2)\n", " # (t1437, t1438, t1439) = ltorch.split(t1436, (1, 1, 1), 2)\n", " # t1437 = prims.slice_prim(t1436, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1437: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t1438 = prims.slice_prim(t1436, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1438: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t1439 = prims.slice_prim(t1436, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1439: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " del t1436\n", " t1440 = torch.reshape(t1437, (1, 32, 512, 128)) # t1440: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1440 = ltorch.reshape(t1437, (1, 32, 512, 128)) # t1440: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1440 = prims.reshape(t1437, (1, 32, 512, 128)) # t1440: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1437\n", " t1441 = torch.reshape(t1438, (1, 32, 512, 128)) # t1441: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1441 = ltorch.reshape(t1438, (1, 32, 512, 128)) # t1441: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1441 = prims.reshape(t1438, (1, 32, 512, 128)) # t1441: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1438\n", " t1442 = torch.reshape(t1439, (1, 32, 512, 128)) # t1442: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1442 = ltorch.reshape(t1439, (1, 32, 512, 128)) # t1442: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1442 = prims.reshape(t1439, (1, 32, 512, 128)) # t1442: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1439\n", " t1443 = torch_slice_prim_impl(t1440, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1443: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " t1458 = torch_slice_prim_impl(t1441, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1458: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " t1473 = torch_slice_prim_impl(t1440, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1473: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " del t1440\n", " t1475 = torch_slice_prim_impl(t1441, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1475: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " del t1441\n", " t1444 = torch_slice_prim_impl(t1443, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1444: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t1445 = torch_slice_prim_impl(t1443, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1445: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t1459 = torch_slice_prim_impl(t1458, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1459: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t1460 = torch_slice_prim_impl(t1458, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1460: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " [t1448, t1463] = nvFusion61(t1443, t1445, t1458, t1460)\n", " # t1446 = prims.convert_element_type(t1445, dtypes.float32) # t1446: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1447 = prims.neg(t1446) # t1447: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1448 = prims.convert_element_type(t1447, dtypes.bfloat16) # t1448: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " # t1461 = prims.convert_element_type(t1460, dtypes.float32) # t1461: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1462 = prims.neg(t1461) # t1462: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1463 = prims.convert_element_type(t1462, 
dtypes.bfloat16) # t1463: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " del t1445, t1460\n", " t1464 = torch.cat((t1463, t1459), -1) # t1464: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1464 = ltorch.cat((t1463, t1459), -1) # t1464: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1464 = prims.cat((t1463, t1459), -1) # t1464: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1463, t1459\n", " t1449 = torch.cat((t1448, t1444), -1) # t1449: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1449 = ltorch.cat((t1448, t1444), -1) # t1449: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1449 = prims.cat((t1448, t1444), -1) # t1449: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1448, t1444\n", " [t1457, t1472] = nvFusion62(t1443, t1449, t1458, t1464, t154, t157)\n", " # t1451 = prims.convert_element_type(t1443, dtypes.float32) # t1451: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1466 = prims.convert_element_type(t1458, dtypes.float32) # t1466: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1467 = prims.mul(t1466, t154) # t1467: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1469 = prims.convert_element_type(t1464, dtypes.float32) # t1469: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1470 = prims.mul(t1469, t157) # t1470: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1471 = prims.add(t1467, t1470) # t1471: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1472 = prims.convert_element_type(t1471, dtypes.bfloat16) # t1472: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1452 = prims.mul(t1451, t154) # t1452: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1454 = prims.convert_element_type(t1449, dtypes.float32) # t1454: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1455 = prims.mul(t1454, t157) # t1455: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1456 = prims.add(t1452, t1455) # t1456: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1457 = prims.convert_element_type(t1456, dtypes.bfloat16) # t1457: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1443, t1449, t1458, t1464\n", " t1476 = torch.cat((t1472, t1475), -1) # t1476: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1476 = ltorch.cat((t1472, t1475), -1) # t1476: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1476 = prims.cat((t1472, t1475), -1) # t1476: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1472, t1475\n", " t1474 = torch.cat((t1457, t1473), -1) # t1474: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1474 = ltorch.cat((t1457, t1473), -1) # t1474: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1474 = prims.cat((t1457, t1473), -1) # t1474: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1457, t1473\n", " (t1477, t1478, t1479, t1480, _, _, t1481, t1482, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1474, t1476, t1442, 0.0, True, scale=0.08838834764831843)\n", " t1484 = torch.permute(t1477, (0, 2, 1, 3)) # t1484: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t1484 = ltorch.permute(t1477, (0, 2, 1, 3)) # t1484: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t1484 = prims.transpose(t1477, (0, 2, 1, 3)) # t1484: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " t1485 = torch.reshape(t1484, (1, 512, 4096)) # t1485: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1485 = ltorch.reshape(t1484, (1, 512, 4096)) # t1485: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1485 = prims.reshape(t1484, (1, 512, 4096)) # t1485: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t1484\n", " t1486 = torch.nn.functional.linear(t1485, t109, None) # t1486: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1486 = ltorch.linear(t1485, t109, None) # t1486: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1486 = prims.linear(t1485, t109, None) # t1486: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t1490, t1497, t1505] = nvFusion63(t1418, t1486, t1501)\n", " 
# t1488 = prims.convert_element_type(t1418, dtypes.float32) # t1488: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1487 = prims.convert_element_type(t1486, dtypes.float32) # t1487: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1489 = prims.add(t1487, t1488) # t1489: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1490 = prims.convert_element_type(t1489, dtypes.bfloat16) # t1490: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1492 = prims.mul(t1489, t1489) # t1492: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1493 = prims.sum(t1492, (2,)) # t1493: \"cuda:0 f32[1, 512]\"\n", " # t1494 = prims.broadcast_in_dim(t1493, [1, 512, 1], [0, 1]) # t1494: \"cuda:0 f32[1, 512, 1]\"\n", " # t1495 = prims.div(t1494, 4096.0) # t1495: \"cuda:0 f32[1, 512, 1]\"\n", " # t1496 = prims.add(t1495, 1e-05) # t1496: \"cuda:0 f32[1, 512, 1]\"\n", " # t1497 = prims.rsqrt(t1496) # t1497: \"cuda:0 f32[1, 512, 1]\"\n", " # t1498 = prims.broadcast_in_dim(t1497, (1, 512, 4096), (0, 1, 2)) # t1498: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1499 = prims.mul(t1489, t1498) # t1499: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1503 = prims.convert_element_type(t1501, dtypes.float32) # t1503: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1504 = prims.mul(t1499, t1503) # t1504: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1505 = prims.convert_element_type(t1504, dtypes.bfloat16) # t1505: \"cuda:0 bf16[1, 512, 4096]\"\n", " t1506 = torch.nn.functional.linear(t1505, t31, None) # t1506: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1506 = ltorch.linear(t1505, t31, None) # t1506: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1506 = prims.linear(t1505, t31, None) # t1506: \"cuda:0 bf16[1, 512, 11008]\"\n", " t1507 = torch.nn.functional.linear(t1505, t47, None) # t1507: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1507 = ltorch.linear(t1505, t47, None) # t1507: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1507 = prims.linear(t1505, t47, None) # t1507: \"cuda:0 bf16[1, 512, 11008]\"\n", " [t1521] = nvFusion64(t1506, t1507)\n", " # t1508 = prims.convert_element_type(t1506, dtypes.float32) # t1508: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1509 = prims.neg(t1508) # t1509: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1510 = prims.exp(t1509) # t1510: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1511 = prims.add(1.0, t1510) # t1511: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1512 = prims.reciprocal(t1511) # t1512: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1516 = prims.mul(t1508, t1512) # t1516: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1519 = prims.convert_element_type(t1507, dtypes.float32) # t1519: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1520 = prims.mul(t1516, t1519) # t1520: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1521 = prims.convert_element_type(t1520, dtypes.bfloat16) # t1521: \"cuda:0 bf16[1, 512, 11008]\"\n", " t1522 = torch.nn.functional.linear(t1521, t110, None) # t1522: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1522 = ltorch.linear(t1521, t110, None) # t1522: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1522 = prims.linear(t1521, t110, None) # t1522: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t1526, t1533, t1541] = nvFusion65(t1490, t1522, t1537)\n", " # t1524 = prims.convert_element_type(t1490, dtypes.float32) # t1524: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1523 = prims.convert_element_type(t1522, dtypes.float32) # t1523: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1525 = prims.add(t1523, t1524) # t1525: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1526 = prims.convert_element_type(t1525, dtypes.bfloat16) # t1526: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1528 = prims.mul(t1525, t1525) # t1528: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1529 = prims.sum(t1528, (2,)) # t1529: 
\"cuda:0 f32[1, 512]\"\n", " # t1530 = prims.broadcast_in_dim(t1529, [1, 512, 1], [0, 1]) # t1530: \"cuda:0 f32[1, 512, 1]\"\n", " # t1531 = prims.div(t1530, 4096.0) # t1531: \"cuda:0 f32[1, 512, 1]\"\n", " # t1532 = prims.add(t1531, 1e-05) # t1532: \"cuda:0 f32[1, 512, 1]\"\n", " # t1533 = prims.rsqrt(t1532) # t1533: \"cuda:0 f32[1, 512, 1]\"\n", " # t1534 = prims.broadcast_in_dim(t1533, (1, 512, 4096), (0, 1, 2)) # t1534: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1535 = prims.mul(t1525, t1534) # t1535: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1539 = prims.convert_element_type(t1537, dtypes.float32) # t1539: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1540 = prims.mul(t1535, t1539) # t1540: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1541 = prims.convert_element_type(t1540, dtypes.bfloat16) # t1541: \"cuda:0 bf16[1, 512, 4096]\"\n", " t1542 = torch.nn.functional.linear(t1541, t16, None) # t1542: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t1542 = ltorch.linear(t1541, t16, None) # t1542: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t1542 = prims.linear(t1541, t16, None) # t1542: \"cuda:0 bf16[1, 512, 12288]\"\n", " t1543 = torch.reshape(t1542, (1, 512, 32, 3, 128)) # t1543: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t1543 = ltorch.reshape(t1542, (1, 512, 32, 3, 128)) # t1543: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t1543 = prims.reshape(t1542, (1, 512, 32, 3, 128)) # t1543: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " del t1542\n", " t1544 = torch.permute(t1543, (0, 2, 3, 1, 4)) # t1544: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t1544 = ltorch.permute(t1543, (0, 2, 3, 1, 4)) # t1544: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t1544 = prims.transpose(t1543, (0, 2, 3, 1, 4)) # t1544: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " del t1543\n", " (t1545, t1546, t1547) = torch.split(t1544, (1, 1, 1), 2)\n", " # (t1545, t1546, t1547) = ltorch.split(t1544, (1, 1, 1), 2)\n", " # t1545 = prims.slice_prim(t1544, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1545: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t1546 = prims.slice_prim(t1544, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1546: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t1547 = prims.slice_prim(t1544, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1547: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " del t1544\n", " t1548 = torch.reshape(t1545, (1, 32, 512, 128)) # t1548: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1548 = ltorch.reshape(t1545, (1, 32, 512, 128)) # t1548: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1548 = prims.reshape(t1545, (1, 32, 512, 128)) # t1548: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1545\n", " t1549 = torch.reshape(t1546, (1, 32, 512, 128)) # t1549: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1549 = ltorch.reshape(t1546, (1, 32, 512, 128)) # t1549: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1549 = prims.reshape(t1546, (1, 32, 512, 128)) # t1549: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1546\n", " t1550 = torch.reshape(t1547, (1, 32, 512, 128)) # t1550: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1550 = ltorch.reshape(t1547, (1, 32, 512, 128)) # t1550: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1550 = prims.reshape(t1547, (1, 32, 512, 128)) # t1550: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1547\n", " t1551 = torch_slice_prim_impl(t1548, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1551: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " t1566 = torch_slice_prim_impl(t1549, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1566: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " t1581 = torch_slice_prim_impl(t1548, [0, 0, 0, 0], 
[1, 32, 512, 0], [1, 1, 1, 1]) # t1581: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " del t1548\n", " t1583 = torch_slice_prim_impl(t1549, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1583: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " del t1549\n", " t1552 = torch_slice_prim_impl(t1551, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1552: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t1553 = torch_slice_prim_impl(t1551, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1553: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t1567 = torch_slice_prim_impl(t1566, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1567: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t1568 = torch_slice_prim_impl(t1566, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1568: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " [t1556, t1571] = nvFusion66(t1551, t1553, t1566, t1568)\n", " # t1554 = prims.convert_element_type(t1553, dtypes.float32) # t1554: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1555 = prims.neg(t1554) # t1555: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1556 = prims.convert_element_type(t1555, dtypes.bfloat16) # t1556: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " # t1569 = prims.convert_element_type(t1568, dtypes.float32) # t1569: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1570 = prims.neg(t1569) # t1570: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1571 = prims.convert_element_type(t1570, dtypes.bfloat16) # t1571: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " del t1553, t1568\n", " t1572 = torch.cat((t1571, t1567), -1) # t1572: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1572 = ltorch.cat((t1571, t1567), -1) # t1572: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1572 = prims.cat((t1571, t1567), -1) # t1572: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1571, t1567\n", " t1557 = torch.cat((t1556, t1552), -1) # t1557: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1557 = ltorch.cat((t1556, t1552), -1) # t1557: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1557 = prims.cat((t1556, t1552), -1) # t1557: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1556, t1552\n", " [t1565, t1580] = nvFusion67(t154, t1551, t1557, t1566, t157, t1572)\n", " # t1559 = prims.convert_element_type(t1551, dtypes.float32) # t1559: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1574 = prims.convert_element_type(t1566, dtypes.float32) # t1574: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1575 = prims.mul(t1574, t154) # t1575: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1577 = prims.convert_element_type(t1572, dtypes.float32) # t1577: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1578 = prims.mul(t1577, t157) # t1578: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1579 = prims.add(t1575, t1578) # t1579: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1580 = prims.convert_element_type(t1579, dtypes.bfloat16) # t1580: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1560 = prims.mul(t1559, t154) # t1560: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1562 = prims.convert_element_type(t1557, dtypes.float32) # t1562: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1563 = prims.mul(t1562, t157) # t1563: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1564 = prims.add(t1560, t1563) # t1564: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1565 = prims.convert_element_type(t1564, dtypes.bfloat16) # t1565: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1551, t1557, t1566, t1572\n", " t1584 = torch.cat((t1580, t1583), -1) # t1584: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1584 = ltorch.cat((t1580, t1583), -1) # t1584: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1584 = prims.cat((t1580, t1583), -1) # t1584: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1580, t1583\n", " t1582 = torch.cat((t1565, t1581), -1) 
# t1582: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1582 = ltorch.cat((t1565, t1581), -1) # t1582: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1582 = prims.cat((t1565, t1581), -1) # t1582: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1565, t1581\n", " (t1585, t1586, t1587, t1588, _, _, t1589, t1590, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1582, t1584, t1550, 0.0, True, scale=0.08838834764831843)\n", " t1592 = torch.permute(t1585, (0, 2, 1, 3)) # t1592: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t1592 = ltorch.permute(t1585, (0, 2, 1, 3)) # t1592: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t1592 = prims.transpose(t1585, (0, 2, 1, 3)) # t1592: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " t1593 = torch.reshape(t1592, (1, 512, 4096)) # t1593: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1593 = ltorch.reshape(t1592, (1, 512, 4096)) # t1593: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1593 = prims.reshape(t1592, (1, 512, 4096)) # t1593: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t1592\n", " t1594 = torch.nn.functional.linear(t1593, t111, None) # t1594: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1594 = ltorch.linear(t1593, t111, None) # t1594: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1594 = prims.linear(t1593, t111, None) # t1594: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t1598, t1605, t1613] = nvFusion68(t1526, t1594, t1609)\n", " # t1596 = prims.convert_element_type(t1526, dtypes.float32) # t1596: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1595 = prims.convert_element_type(t1594, dtypes.float32) # t1595: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1597 = prims.add(t1595, t1596) # t1597: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1598 = prims.convert_element_type(t1597, dtypes.bfloat16) # t1598: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1600 = prims.mul(t1597, t1597) # t1600: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1601 = prims.sum(t1600, (2,)) # t1601: \"cuda:0 f32[1, 512]\"\n", " # t1602 = prims.broadcast_in_dim(t1601, [1, 512, 1], [0, 1]) # t1602: \"cuda:0 f32[1, 512, 1]\"\n", " # t1603 = prims.div(t1602, 4096.0) # t1603: \"cuda:0 f32[1, 512, 1]\"\n", " # t1604 = prims.add(t1603, 1e-05) # t1604: \"cuda:0 f32[1, 512, 1]\"\n", " # t1605 = prims.rsqrt(t1604) # t1605: \"cuda:0 f32[1, 512, 1]\"\n", " # t1606 = prims.broadcast_in_dim(t1605, (1, 512, 4096), (0, 1, 2)) # t1606: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1607 = prims.mul(t1597, t1606) # t1607: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1611 = prims.convert_element_type(t1609, dtypes.float32) # t1611: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1612 = prims.mul(t1607, t1611) # t1612: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1613 = prims.convert_element_type(t1612, dtypes.bfloat16) # t1613: \"cuda:0 bf16[1, 512, 4096]\"\n", " t1614 = torch.nn.functional.linear(t1613, t32, None) # t1614: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1614 = ltorch.linear(t1613, t32, None) # t1614: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1614 = prims.linear(t1613, t32, None) # t1614: \"cuda:0 bf16[1, 512, 11008]\"\n", " t1615 = torch.nn.functional.linear(t1613, t48, None) # t1615: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1615 = ltorch.linear(t1613, t48, None) # t1615: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1615 = prims.linear(t1613, t48, None) # t1615: \"cuda:0 bf16[1, 512, 11008]\"\n", " [t1629] = nvFusion69(t1614, t1615)\n", " # t1616 = prims.convert_element_type(t1614, dtypes.float32) # t1616: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1617 = prims.neg(t1616) # t1617: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1618 = prims.exp(t1617) # t1618: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1619 = prims.add(1.0, t1618) # t1619: 
\"cuda:0 f32[1, 512, 11008]\"\n", " # t1620 = prims.reciprocal(t1619) # t1620: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1624 = prims.mul(t1616, t1620) # t1624: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1627 = prims.convert_element_type(t1615, dtypes.float32) # t1627: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1628 = prims.mul(t1624, t1627) # t1628: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1629 = prims.convert_element_type(t1628, dtypes.bfloat16) # t1629: \"cuda:0 bf16[1, 512, 11008]\"\n", " t1630 = torch.nn.functional.linear(t1629, t112, None) # t1630: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1630 = ltorch.linear(t1629, t112, None) # t1630: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1630 = prims.linear(t1629, t112, None) # t1630: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t1634, t1641, t1649] = nvFusion70(t1598, t1630, t1645)\n", " # t1632 = prims.convert_element_type(t1598, dtypes.float32) # t1632: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1631 = prims.convert_element_type(t1630, dtypes.float32) # t1631: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1633 = prims.add(t1631, t1632) # t1633: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1634 = prims.convert_element_type(t1633, dtypes.bfloat16) # t1634: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1636 = prims.mul(t1633, t1633) # t1636: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1637 = prims.sum(t1636, (2,)) # t1637: \"cuda:0 f32[1, 512]\"\n", " # t1638 = prims.broadcast_in_dim(t1637, [1, 512, 1], [0, 1]) # t1638: \"cuda:0 f32[1, 512, 1]\"\n", " # t1639 = prims.div(t1638, 4096.0) # t1639: \"cuda:0 f32[1, 512, 1]\"\n", " # t1640 = prims.add(t1639, 1e-05) # t1640: \"cuda:0 f32[1, 512, 1]\"\n", " # t1641 = prims.rsqrt(t1640) # t1641: \"cuda:0 f32[1, 512, 1]\"\n", " # t1642 = prims.broadcast_in_dim(t1641, (1, 512, 4096), (0, 1, 2)) # t1642: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1643 = prims.mul(t1633, t1642) # t1643: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1647 = prims.convert_element_type(t1645, dtypes.float32) # t1647: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1648 = prims.mul(t1643, t1647) # t1648: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1649 = prims.convert_element_type(t1648, dtypes.bfloat16) # t1649: \"cuda:0 bf16[1, 512, 4096]\"\n", " t1650 = torch.nn.functional.linear(t1649, t17, None) # t1650: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t1650 = ltorch.linear(t1649, t17, None) # t1650: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t1650 = prims.linear(t1649, t17, None) # t1650: \"cuda:0 bf16[1, 512, 12288]\"\n", " t1651 = torch.reshape(t1650, (1, 512, 32, 3, 128)) # t1651: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t1651 = ltorch.reshape(t1650, (1, 512, 32, 3, 128)) # t1651: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t1651 = prims.reshape(t1650, (1, 512, 32, 3, 128)) # t1651: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " del t1650\n", " t1652 = torch.permute(t1651, (0, 2, 3, 1, 4)) # t1652: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t1652 = ltorch.permute(t1651, (0, 2, 3, 1, 4)) # t1652: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t1652 = prims.transpose(t1651, (0, 2, 3, 1, 4)) # t1652: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " del t1651\n", " (t1653, t1654, t1655) = torch.split(t1652, (1, 1, 1), 2)\n", " # (t1653, t1654, t1655) = ltorch.split(t1652, (1, 1, 1), 2)\n", " # t1653 = prims.slice_prim(t1652, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1653: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t1654 = prims.slice_prim(t1652, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1654: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t1655 = prims.slice_prim(t1652, [0, 0, 2, 0, 0], [1, 32, 3, 512, 
128], [1, 1, 1, 1, 1]) # t1655: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " del t1652\n", " t1656 = torch.reshape(t1653, (1, 32, 512, 128)) # t1656: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1656 = ltorch.reshape(t1653, (1, 32, 512, 128)) # t1656: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1656 = prims.reshape(t1653, (1, 32, 512, 128)) # t1656: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1653\n", " t1657 = torch.reshape(t1654, (1, 32, 512, 128)) # t1657: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1657 = ltorch.reshape(t1654, (1, 32, 512, 128)) # t1657: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1657 = prims.reshape(t1654, (1, 32, 512, 128)) # t1657: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1654\n", " t1658 = torch.reshape(t1655, (1, 32, 512, 128)) # t1658: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1658 = ltorch.reshape(t1655, (1, 32, 512, 128)) # t1658: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1658 = prims.reshape(t1655, (1, 32, 512, 128)) # t1658: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1655\n", " t1689 = torch_slice_prim_impl(t1656, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1689: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " t1691 = torch_slice_prim_impl(t1657, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1691: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " t1659 = torch_slice_prim_impl(t1656, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1659: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1656\n", " t1674 = torch_slice_prim_impl(t1657, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1674: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1657\n", " t1660 = torch_slice_prim_impl(t1659, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1660: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t1661 = torch_slice_prim_impl(t1659, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1661: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t1675 = torch_slice_prim_impl(t1674, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1675: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t1676 = torch_slice_prim_impl(t1674, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1676: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " [t1664, t1679] = nvFusion71(t1659, t1661, t1674, t1676)\n", " # t1662 = prims.convert_element_type(t1661, dtypes.float32) # t1662: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1663 = prims.neg(t1662) # t1663: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1664 = prims.convert_element_type(t1663, dtypes.bfloat16) # t1664: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " # t1677 = prims.convert_element_type(t1676, dtypes.float32) # t1677: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1678 = prims.neg(t1677) # t1678: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1679 = prims.convert_element_type(t1678, dtypes.bfloat16) # t1679: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " del t1661, t1676\n", " t1680 = torch.cat((t1679, t1675), -1) # t1680: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1680 = ltorch.cat((t1679, t1675), -1) # t1680: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1680 = prims.cat((t1679, t1675), -1) # t1680: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1679, t1675\n", " t1665 = torch.cat((t1664, t1660), -1) # t1665: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1665 = ltorch.cat((t1664, t1660), -1) # t1665: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1665 = prims.cat((t1664, t1660), -1) # t1665: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1664, t1660\n", " [t1673, t1688] = nvFusion72(t154, t157, t1659, t1665, t1674, t1680)\n", " # t1667 = prims.convert_element_type(t1659, dtypes.float32) # t1667: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1682 = prims.convert_element_type(t1674, 
dtypes.float32) # t1682: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1683 = prims.mul(t1682, t154) # t1683: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1685 = prims.convert_element_type(t1680, dtypes.float32) # t1685: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1686 = prims.mul(t1685, t157) # t1686: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1687 = prims.add(t1683, t1686) # t1687: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1688 = prims.convert_element_type(t1687, dtypes.bfloat16) # t1688: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1668 = prims.mul(t1667, t154) # t1668: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1670 = prims.convert_element_type(t1665, dtypes.float32) # t1670: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1671 = prims.mul(t1670, t157) # t1671: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1672 = prims.add(t1668, t1671) # t1672: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1673 = prims.convert_element_type(t1672, dtypes.bfloat16) # t1673: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1659, t1665, t1674, t1680\n", " t1692 = torch.cat((t1688, t1691), -1) # t1692: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1692 = ltorch.cat((t1688, t1691), -1) # t1692: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1692 = prims.cat((t1688, t1691), -1) # t1692: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1688, t1691\n", " t1690 = torch.cat((t1673, t1689), -1) # t1690: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1690 = ltorch.cat((t1673, t1689), -1) # t1690: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1690 = prims.cat((t1673, t1689), -1) # t1690: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1673, t1689\n", " (t1693, t1694, t1695, t1696, _, _, t1697, t1698, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1690, t1692, t1658, 0.0, True, scale=0.08838834764831843)\n", " t1700 = torch.permute(t1693, (0, 2, 1, 3)) # t1700: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t1700 = ltorch.permute(t1693, (0, 2, 1, 3)) # t1700: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t1700 = prims.transpose(t1693, (0, 2, 1, 3)) # t1700: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " t1701 = torch.reshape(t1700, (1, 512, 4096)) # t1701: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1701 = ltorch.reshape(t1700, (1, 512, 4096)) # t1701: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1701 = prims.reshape(t1700, (1, 512, 4096)) # t1701: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t1700\n", " t1702 = torch.nn.functional.linear(t1701, t113, None) # t1702: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1702 = ltorch.linear(t1701, t113, None) # t1702: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1702 = prims.linear(t1701, t113, None) # t1702: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t1706, t1713, t1721] = nvFusion73(t1634, t1702, t1717)\n", " # t1704 = prims.convert_element_type(t1634, dtypes.float32) # t1704: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1703 = prims.convert_element_type(t1702, dtypes.float32) # t1703: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1705 = prims.add(t1703, t1704) # t1705: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1706 = prims.convert_element_type(t1705, dtypes.bfloat16) # t1706: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1708 = prims.mul(t1705, t1705) # t1708: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1709 = prims.sum(t1708, (2,)) # t1709: \"cuda:0 f32[1, 512]\"\n", " # t1710 = prims.broadcast_in_dim(t1709, [1, 512, 1], [0, 1]) # t1710: \"cuda:0 f32[1, 512, 1]\"\n", " # t1711 = prims.div(t1710, 4096.0) # t1711: \"cuda:0 f32[1, 512, 1]\"\n", " # t1712 = prims.add(t1711, 1e-05) # t1712: \"cuda:0 f32[1, 512, 1]\"\n", " # t1713 = prims.rsqrt(t1712) # t1713: \"cuda:0 f32[1, 512, 1]\"\n", " # t1714 = 
prims.broadcast_in_dim(t1713, (1, 512, 4096), (0, 1, 2)) # t1714: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1715 = prims.mul(t1705, t1714) # t1715: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1719 = prims.convert_element_type(t1717, dtypes.float32) # t1719: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1720 = prims.mul(t1715, t1719) # t1720: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1721 = prims.convert_element_type(t1720, dtypes.bfloat16) # t1721: \"cuda:0 bf16[1, 512, 4096]\"\n", " t1722 = torch.nn.functional.linear(t1721, t33, None) # t1722: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1722 = ltorch.linear(t1721, t33, None) # t1722: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1722 = prims.linear(t1721, t33, None) # t1722: \"cuda:0 bf16[1, 512, 11008]\"\n", " t1723 = torch.nn.functional.linear(t1721, t49, None) # t1723: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1723 = ltorch.linear(t1721, t49, None) # t1723: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1723 = prims.linear(t1721, t49, None) # t1723: \"cuda:0 bf16[1, 512, 11008]\"\n", " [t1737] = nvFusion74(t1722, t1723)\n", " # t1724 = prims.convert_element_type(t1722, dtypes.float32) # t1724: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1725 = prims.neg(t1724) # t1725: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1726 = prims.exp(t1725) # t1726: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1727 = prims.add(1.0, t1726) # t1727: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1728 = prims.reciprocal(t1727) # t1728: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1732 = prims.mul(t1724, t1728) # t1732: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1735 = prims.convert_element_type(t1723, dtypes.float32) # t1735: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1736 = prims.mul(t1732, t1735) # t1736: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1737 = prims.convert_element_type(t1736, dtypes.bfloat16) # t1737: \"cuda:0 bf16[1, 512, 11008]\"\n", " t1738 = torch.nn.functional.linear(t1737, t114, None) # t1738: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1738 = ltorch.linear(t1737, t114, None) # t1738: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1738 = prims.linear(t1737, t114, None) # t1738: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t1742, t1749, t1757] = nvFusion75(t1706, t1738, t1753)\n", " # t1740 = prims.convert_element_type(t1706, dtypes.float32) # t1740: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1739 = prims.convert_element_type(t1738, dtypes.float32) # t1739: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1741 = prims.add(t1739, t1740) # t1741: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1742 = prims.convert_element_type(t1741, dtypes.bfloat16) # t1742: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1744 = prims.mul(t1741, t1741) # t1744: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1745 = prims.sum(t1744, (2,)) # t1745: \"cuda:0 f32[1, 512]\"\n", " # t1746 = prims.broadcast_in_dim(t1745, [1, 512, 1], [0, 1]) # t1746: \"cuda:0 f32[1, 512, 1]\"\n", " # t1747 = prims.div(t1746, 4096.0) # t1747: \"cuda:0 f32[1, 512, 1]\"\n", " # t1748 = prims.add(t1747, 1e-05) # t1748: \"cuda:0 f32[1, 512, 1]\"\n", " # t1749 = prims.rsqrt(t1748) # t1749: \"cuda:0 f32[1, 512, 1]\"\n", " # t1750 = prims.broadcast_in_dim(t1749, (1, 512, 4096), (0, 1, 2)) # t1750: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1751 = prims.mul(t1741, t1750) # t1751: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1755 = prims.convert_element_type(t1753, dtypes.float32) # t1755: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1756 = prims.mul(t1751, t1755) # t1756: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1757 = prims.convert_element_type(t1756, dtypes.bfloat16) # t1757: \"cuda:0 bf16[1, 512, 4096]\"\n", " t1758 = torch.nn.functional.linear(t1757, t18, 
None) # t1758: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t1758 = ltorch.linear(t1757, t18, None) # t1758: \"cuda:0 bf16[1, 512, 12288]\"\n", " # t1758 = prims.linear(t1757, t18, None) # t1758: \"cuda:0 bf16[1, 512, 12288]\"\n", " t1759 = torch.reshape(t1758, (1, 512, 32, 3, 128)) # t1759: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t1759 = ltorch.reshape(t1758, (1, 512, 32, 3, 128)) # t1759: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " # t1759 = prims.reshape(t1758, (1, 512, 32, 3, 128)) # t1759: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", " del t1758\n", " t1760 = torch.permute(t1759, (0, 2, 3, 1, 4)) # t1760: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t1760 = ltorch.permute(t1759, (0, 2, 3, 1, 4)) # t1760: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " # t1760 = prims.transpose(t1759, (0, 2, 3, 1, 4)) # t1760: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", " del t1759\n", " (t1761, t1762, t1763) = torch.split(t1760, (1, 1, 1), 2)\n", " # (t1761, t1762, t1763) = ltorch.split(t1760, (1, 1, 1), 2)\n", " # t1761 = prims.slice_prim(t1760, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1761: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t1762 = prims.slice_prim(t1760, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1762: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " # t1763 = prims.slice_prim(t1760, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1763: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", " del t1760\n", " t1764 = torch.reshape(t1761, (1, 32, 512, 128)) # t1764: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1764 = ltorch.reshape(t1761, (1, 32, 512, 128)) # t1764: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1764 = prims.reshape(t1761, (1, 32, 512, 128)) # t1764: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1761\n", " t1765 = torch.reshape(t1762, (1, 32, 512, 128)) # t1765: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1765 = ltorch.reshape(t1762, (1, 32, 512, 128)) # t1765: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1765 = prims.reshape(t1762, (1, 32, 512, 128)) # t1765: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1762\n", " t1766 = torch.reshape(t1763, (1, 32, 512, 128)) # t1766: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1766 = ltorch.reshape(t1763, (1, 32, 512, 128)) # t1766: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1766 = prims.reshape(t1763, (1, 32, 512, 128)) # t1766: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1763\n", " t1767 = torch_slice_prim_impl(t1764, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1767: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " t1782 = torch_slice_prim_impl(t1765, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1782: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " t1797 = torch_slice_prim_impl(t1764, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1797: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " del t1764\n", " t1799 = torch_slice_prim_impl(t1765, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1799: \"cuda:0 bf16[1, 32, 512, 0]\"\n", " del t1765\n", " t1768 = torch_slice_prim_impl(t1767, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1768: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t1769 = torch_slice_prim_impl(t1767, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1769: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t1783 = torch_slice_prim_impl(t1782, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1783: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " t1784 = torch_slice_prim_impl(t1782, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1784: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " [t1772, t1787] = nvFusion76(t1767, t1769, t1782, t1784)\n", " # t1770 = prims.convert_element_type(t1769, 
dtypes.float32) # t1770: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1771 = prims.neg(t1770) # t1771: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1772 = prims.convert_element_type(t1771, dtypes.bfloat16) # t1772: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " # t1785 = prims.convert_element_type(t1784, dtypes.float32) # t1785: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1786 = prims.neg(t1785) # t1786: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t1787 = prims.convert_element_type(t1786, dtypes.bfloat16) # t1787: \"cuda:0 bf16[1, 32, 512, 64]\"\n", " del t1769, t1784\n", " t1788 = torch.cat((t1787, t1783), -1) # t1788: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1788 = ltorch.cat((t1787, t1783), -1) # t1788: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1788 = prims.cat((t1787, t1783), -1) # t1788: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1787, t1783\n", " t1773 = torch.cat((t1772, t1768), -1) # t1773: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1773 = ltorch.cat((t1772, t1768), -1) # t1773: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1773 = prims.cat((t1772, t1768), -1) # t1773: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1772, t1768\n", " [t1781, t1796] = nvFusion77(t154, t157, t1767, t1773, t1782, t1788)\n", " # t1775 = prims.convert_element_type(t1767, dtypes.float32) # t1775: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1790 = prims.convert_element_type(t1782, dtypes.float32) # t1790: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1791 = prims.mul(t1790, t154) # t1791: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1793 = prims.convert_element_type(t1788, dtypes.float32) # t1793: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1794 = prims.mul(t1793, t157) # t1794: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1795 = prims.add(t1791, t1794) # t1795: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1796 = prims.convert_element_type(t1795, dtypes.bfloat16) # t1796: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1776 = prims.mul(t1775, t154) # t1776: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1778 = prims.convert_element_type(t1773, dtypes.float32) # t1778: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1779 = prims.mul(t1778, t157) # t1779: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1780 = prims.add(t1776, t1779) # t1780: \"cuda:0 f32[1, 32, 512, 128]\"\n", " # t1781 = prims.convert_element_type(t1780, dtypes.bfloat16) # t1781: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1767, t1773, t1782, t1788\n", " t1800 = torch.cat((t1796, t1799), -1) # t1800: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1800 = ltorch.cat((t1796, t1799), -1) # t1800: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1800 = prims.cat((t1796, t1799), -1) # t1800: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1796, t1799\n", " t1798 = torch.cat((t1781, t1797), -1) # t1798: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1798 = ltorch.cat((t1781, t1797), -1) # t1798: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " # t1798 = prims.cat((t1781, t1797), -1) # t1798: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t1781, t1797\n", " (t1801, t1802, t1803, t1804, _, _, t1805, t1806, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1798, t1800, t1766, 0.0, True, scale=0.08838834764831843)\n", " t1808 = torch.permute(t1801, (0, 2, 1, 3)) # t1808: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t1808 = ltorch.permute(t1801, (0, 2, 1, 3)) # t1808: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " # t1808 = prims.transpose(t1801, (0, 2, 1, 3)) # t1808: \"cuda:0 bf16[1, 512, 32, 128]\"\n", " t1809 = torch.reshape(t1808, (1, 512, 4096)) # t1809: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1809 = ltorch.reshape(t1808, (1, 512, 4096)) # t1809: \"cuda:0 
bf16[1, 512, 4096]\"\n", " # t1809 = prims.reshape(t1808, (1, 512, 4096)) # t1809: \"cuda:0 bf16[1, 512, 4096]\"\n", " del t1808\n", " t1810 = torch.nn.functional.linear(t1809, t115, None) # t1810: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1810 = ltorch.linear(t1809, t115, None) # t1810: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1810 = prims.linear(t1809, t115, None) # t1810: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t1814, t1821, t1829] = nvFusion78(t1742, t1810, t1825)\n", " # t1812 = prims.convert_element_type(t1742, dtypes.float32) # t1812: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1811 = prims.convert_element_type(t1810, dtypes.float32) # t1811: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1813 = prims.add(t1811, t1812) # t1813: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1814 = prims.convert_element_type(t1813, dtypes.bfloat16) # t1814: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1816 = prims.mul(t1813, t1813) # t1816: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1817 = prims.sum(t1816, (2,)) # t1817: \"cuda:0 f32[1, 512]\"\n", " # t1818 = prims.broadcast_in_dim(t1817, [1, 512, 1], [0, 1]) # t1818: \"cuda:0 f32[1, 512, 1]\"\n", " # t1819 = prims.div(t1818, 4096.0) # t1819: \"cuda:0 f32[1, 512, 1]\"\n", " # t1820 = prims.add(t1819, 1e-05) # t1820: \"cuda:0 f32[1, 512, 1]\"\n", " # t1821 = prims.rsqrt(t1820) # t1821: \"cuda:0 f32[1, 512, 1]\"\n", " # t1822 = prims.broadcast_in_dim(t1821, (1, 512, 4096), (0, 1, 2)) # t1822: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1823 = prims.mul(t1813, t1822) # t1823: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1827 = prims.convert_element_type(t1825, dtypes.float32) # t1827: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1828 = prims.mul(t1823, t1827) # t1828: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1829 = prims.convert_element_type(t1828, dtypes.bfloat16) # t1829: \"cuda:0 bf16[1, 512, 4096]\"\n", " t1831 = torch.nn.functional.linear(t1829, t50, None) # t1831: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1831 = ltorch.linear(t1829, t50, None) # t1831: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1831 = prims.linear(t1829, t50, None) # t1831: \"cuda:0 bf16[1, 512, 11008]\"\n", " t1830 = torch.nn.functional.linear(t1829, t34, None) # t1830: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1830 = ltorch.linear(t1829, t34, None) # t1830: \"cuda:0 bf16[1, 512, 11008]\"\n", " # t1830 = prims.linear(t1829, t34, None) # t1830: \"cuda:0 bf16[1, 512, 11008]\"\n", " [t1845] = nvFusion79(t1830, t1831)\n", " # t1832 = prims.convert_element_type(t1830, dtypes.float32) # t1832: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1833 = prims.neg(t1832) # t1833: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1834 = prims.exp(t1833) # t1834: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1835 = prims.add(1.0, t1834) # t1835: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1836 = prims.reciprocal(t1835) # t1836: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1840 = prims.mul(t1832, t1836) # t1840: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1843 = prims.convert_element_type(t1831, dtypes.float32) # t1843: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1844 = prims.mul(t1840, t1843) # t1844: \"cuda:0 f32[1, 512, 11008]\"\n", " # t1845 = prims.convert_element_type(t1844, dtypes.bfloat16) # t1845: \"cuda:0 bf16[1, 512, 11008]\"\n", " t1846 = torch.nn.functional.linear(t1845, t116, None) # t1846: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1846 = ltorch.linear(t1845, t116, None) # t1846: \"cuda:0 bf16[1, 512, 4096]\"\n", " # t1846 = prims.linear(t1845, t116, None) # t1846: \"cuda:0 bf16[1, 512, 4096]\"\n", " [t1857, t1865] = nvFusion80(t1814, t1846, t1861)\n", " # t1848 = prims.convert_element_type(t1814, 
dtypes.float32) # t1848: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1847 = prims.convert_element_type(t1846, dtypes.float32) # t1847: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1849 = prims.add(t1847, t1848) # t1849: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1852 = prims.mul(t1849, t1849) # t1852: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1853 = prims.sum(t1852, (2,)) # t1853: \"cuda:0 f32[1, 512]\"\n", " # t1854 = prims.broadcast_in_dim(t1853, [1, 512, 1], [0, 1]) # t1854: \"cuda:0 f32[1, 512, 1]\"\n", " # t1855 = prims.div(t1854, 4096.0) # t1855: \"cuda:0 f32[1, 512, 1]\"\n", " # t1856 = prims.add(t1855, 1e-05) # t1856: \"cuda:0 f32[1, 512, 1]\"\n", " # t1857 = prims.rsqrt(t1856) # t1857: \"cuda:0 f32[1, 512, 1]\"\n", " # t1858 = prims.broadcast_in_dim(t1857, (1, 512, 4096), (0, 1, 2)) # t1858: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1859 = prims.mul(t1849, t1858) # t1859: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1863 = prims.convert_element_type(t1861, dtypes.float32) # t1863: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1864 = prims.mul(t1859, t1863) # t1864: \"cuda:0 f32[1, 512, 4096]\"\n", " # t1865 = prims.convert_element_type(t1864, dtypes.bfloat16) # t1865: \"cuda:0 bf16[1, 512, 4096]\"\n", " t1866 = torch.nn.functional.linear(t1865, t51, None) # t1866: \"cuda:0 bf16[1, 512, 32000]\"\n", " # t1866 = ltorch.linear(t1865, t51, None) # t1866: \"cuda:0 bf16[1, 512, 32000]\"\n", " # t1866 = prims.linear(t1865, t51, None) # t1866: \"cuda:0 bf16[1, 512, 32000]\"\n", " return {'output': t1866, 'flat_args': [t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, t18, t19, t20, t21, t22, t23, t24, t25, t26, t27, t28, t29, t30, t31, t32, t33, t34, t35, t36, t37, t38, t39, t40, t41, t42, t43, t44, t45, t46, t47, t48, t49, t50, t51, t52, t53, t54, t55, t56, t57, t58, t59, t60, t61, t62, t63, t64, t65, t66, t67, t68, t69, t70, t71, t72, t73, t74, t75, t76, t77, t78, t79, t80, t81, t82, t83, t84, t85, t86, t87, t88, t89, t90, t91, t92, t93, t94, t95, t96, t97, t98, t99, t100, t101, t102, t103, t104, t105, t106, t107, t108, t109, t110, t111, t112, t113, t114, t115, t116, t117], 'flat_output': (t1866,)}, ((t0, t10, t100, t1001, t101, t1010, t102, t103, t104, t1042, t1044, t1045, t1046, t1047, t1048, t1049, t105, t1050, t1053, t1054, t1058, t106, t1065, t1069, t107, t1073, t1074, t1075, t108, t1089, t109, t1090, t1094, t11, t110, t1101, t1105, t1109, t111, t1118, t112, t113, t114, t115, t1150, t1152, t1153, t1154, t1155, t1156, t1157, t1158, t116, t1161, t1162, t1166, t1173, t1177, t1181, t1182, t1183, t1197, t1198, t12, t1202, t1209, t1213, t1217, t122, t1226, t1258, t1260, t1261, t1262, t1263, t1264, t1265, t1266, t1269, t1270, t1274, t1281, t1285, t1289, t129, t1290, t1291, t13, t1305, t1306, t1310, t1317, t1321, t1325, t133, t1334, t1366, t1368, t1369, t137, t1370, t1371, t1372, t1373, t1374, t1377, t1378, t1382, t1389, t1393, t1397, t1398, t1399, t14, t1413, t1414, t1418, t1425, t1429, t1433, t1442, t146, t1474, t1476, t1477, t1478, t1479, t1480, t1481, t1482, t1485, t1486, t1490, t1497, t15, t1501, t1505, t1506, t1507, t1521, t1522, t1526, t1533, t1537, t154, t1541, t1550, t157, t1582, t1584, t1585, t1586, t1587, t1588, t1589, t1590, t1593, t1594, t1598, t16, t1605, t1609, t1613, t1614, t1615, t1629, t1630, t1634, t1641, t1645, t1649, t1658, t1690, t1692, t1693, t1694, t1695, t1696, t1697, t1698, t17, t1701, t1702, t1706, t1713, t1717, t1721, t1722, t1723, t1737, t1738, t1742, t1749, t1753, t1757, t1766, t178, t1798, t18, t180, t1800, t1801, t1802, t1803, t1804, t1805, t1806, t1809, t181, t1810, 
t1814, t182, t1821, t1825, t1829, t183, t1830, t1831, t184, t1845, t1846, t185, t1857, t186, t1861, t1865, t189, t19, t190, t194, t20, t201, t205, t209, t21, t210, t211, t22, t225, t226, t23, t230, t237, t24, t241, t245, t25, t254, t26, t27, t28, t286, t288, t289, t29, t290, t291, t292, t293, t294, t297, t298, t3, t30, t302, t309, t31, t313, t317, t318, t319, t32, t33, t333, t334, t338, t34, t345, t349, t35, t353, t36, t362, t37, t38, t39, t394, t396, t397, t398, t399, t4, t40, t400, t401, t402, t405, t406, t41, t410, t417, t42, t421, t425, t426, t427, t43, t44, t441, t442, t446, t45, t453, t457, t46, t461, t47, t470, t48, t49, t5, t50, t502, t504, t505, t506, t507, t508, t509, t51, t510, t513, t514, t518, t525, t529, t533, t534, t535, t549, t550, t554, t561, t565, t569, t578, t6, t610, t612, t613, t614, t615, t616, t617, t618, t621, t622, t626, t633, t637, t641, t642, t643, t657, t658, t662, t669, t673, t677, t686, t7, t718, t720, t721, t722, t723, t724, t725, t726, t729, t730, t734, t741, t745, t749, t750, t751, t765, t766, t770, t777, t781, t785, t794, t8, t826, t828, t829, t830, t831, t832, t833, t834, t837, t838, t842, t849, t85, t853, t857, t858, t859, t86, t87, t873, t874, t878, t88, t885, t889, t89, t893, t9, t90, t902, t91, t92, t93, t934, t936, t937, t938, t939, t94, t940, t941, t942, t945, t946, t95, t950, t957, t96, t961, t965, t966, t967, t97, t98, t981, t982, t986, t99, t993, t997), (False, True, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 0.0, 4096.0, 4096.0, 0.08838834764831843, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 32000, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2))" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(actual.grad_fn)\n", "thunder.last_traces(thunder_model)[-1]" ] }, { "cell_type": "markdown", "id": "558f2553-37f7-4b58-b7cd-a744155613a8", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "Well, that is quite a bit to look through.\n", "But here is a key thing: The function now returns a bunch of things. This is because Thunder applies the same treatment to the backward and to this end saves information from the forward. You can see a hint of this because the output has a `ThunderFunctionBackward` on as its `grad_fn`. (You can see the backward trace with \n", "`thunder.last_backward_traces(thunder_model)[-1]`)." 
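, "\n", "\n", "For instance, both can be inspected directly (a small sketch, reusing `actual` and `thunder_model` from the cells above):\n", "\n", "```python\n", "# the forward output carries Thunder's autograd node ...\n", "print(actual.grad_fn)\n", "# ... and the matching backward computation trace can be printed just like the forward one\n", "print(thunder.last_backward_traces(thunder_model)[-1])\n", "```"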
] }, { "cell_type": "code", "execution_count": 10, "id": "59643398-d6e2-4c32-81bd-145a1198b1f3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[[ 0.4160, -0.4668, 1.1016, ..., 0.5430, 1.2656, 0.2891],\n", " [ 0.3320, -0.0557, 1.7891, ..., 1.0703, 1.0078, 1.2266],\n", " [ 0.6836, -0.2871, 0.9531, ..., 0.0806, 0.7070, 0.8477],\n", " ...,\n", " [ 0.7695, -0.1260, 0.7266, ..., 0.1118, -0.0238, -1.2656],\n", " [-0.7773, -0.5547, -0.3047, ..., -0.1807, 0.1895, 0.6875],\n", " [ 0.8867, 0.4766, 0.3984, ..., 0.0815, -0.0879, 0.3477]]],\n", " device='cuda:0', grad_fn=)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "actual" ] }, { "cell_type": "markdown", "id": "17341d86-d4c9-46bd-ac5e-3a05da1ff72c", "metadata": {}, "source": [ "Let us clean up a bit." ] }, { "cell_type": "code", "execution_count": 11, "id": "6ba7f715", "metadata": {}, "outputs": [], "source": [ "del actual, expected\n", "import gc\n", "gc.collect();" ] }, { "cell_type": "markdown", "id": "0261eb11", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "But is it faster? Yes!" ] }, { "cell_type": "code", "execution_count": 12, "id": "bccec79b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "240 ms ± 105 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", "208 ms ± 147 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], "source": [ "%timeit r = m(inp); torch.autograd.grad(r.sum(), m.parameters()); torch.cuda.synchronize()\n", "%timeit r = thunder_model(inp); torch.autograd.grad(r.sum(), m.parameters()); torch.cuda.synchronize()" ] }, { "cell_type": "markdown", "id": "1d31e7f8", "metadata": {}, "source": [ "So far, so good! Thunder should work with LitGPT today and we busy are adding the support required to run other models as well!\n" ] }, { "cell_type": "code", "execution_count": 13, "id": "ecad9125-bbf2-42c8-b11c-23eed4a6cd8f", "metadata": {}, "outputs": [], "source": [ "del m, thunder_model\n", "import gc\n", "gc.collect()\n", "torch.cuda.empty_cache()\n" ] }, { "cell_type": "markdown", "id": "49e3273c-99be-4370-9e59-121c00481b4e", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Distributed with Thunder\n", "\n", "Those Large Language Models are called Large for a reason, and memory in a single GPU is invariably small. So we need multiple.\n", "\n", "Happily Thunder sports an FSDP interface to use multiple cards in our box.\n", "\n", "You still need to setup the process group, but as far as the model is concerned,\n", "\n", "```python\n", "model = thunder.jit(thunder.distributed.fsdp(model))\n", "```\n", "\n", "is all you need. Because it is tricky to run multiprocessing from Notebooks, we write a small example into a file and run it though `torch-run`.\n", "\n", "Check out our LitGPT Thunder examples for complete distributed training and finetuning!" 
] }, { "cell_type": "code", "execution_count": 14, "id": "18dd3379", "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting zero_to_thunder_fsdp_simple_example.py\n" ] } ], "source": [ "%%writefile zero_to_thunder_fsdp_simple_example.py\n", "from thunder.tests.litgpt_model import GPT, Config\n", "import os\n", "import torch, torch.distributed\n", "import thunder, thunder.distributed\n", "\n", "# Create Model\n", "# NOTE: We create the model on CPU.\n", "device='cpu'\n", "torch.set_default_dtype(torch.bfloat16)\n", "cfg = Config.from_name('Llama-2-7b-hf')\n", "cfg.n_layer = 8 # fewer layers\n", "model = GPT(cfg)\n", "\n", "# Setup for distributed\n", "torch.distributed.init_process_group(backend='nccl')\n", "rank = int(os.environ[\"LOCAL_RANK\"])\n", "\n", "device = f\"cuda:{rank}\"\n", "x = torch.randint(1, model.config.vocab_size, (1, 1024), device=device)\n", "\n", "# thunder.distributed.fsdp takes care of moving the parameter\n", "# shard to the correct GPU for the current process.\n", "model = thunder.jit(thunder.distributed.fsdp(model)) # <---------------------------------------\n", "print(f\"rank {rank} computing\")\n", "# Run the forward pass.\n", "for i in range(10):\n", " res = model(x)\n", " res.sum().backward()\n" ] }, { "cell_type": "markdown", "id": "97e8edbf-424d-49a7-8ed6-12cb5e5d65fc", "metadata": {}, "source": [ "Now we can launch it. Note that you need two GPUs for this to run correctly." ] }, { "cell_type": "code", "execution_count": 15, "id": "2bad9b64", "metadata": { "scrolled": true, "slideshow": { "slide_type": "skip" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] \n", "W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] *****************************************\n", "W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. \n", "W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] *****************************************\n", "rank 1 computing\n", "rank 0 computing\n" ] } ], "source": [ "# commented out for CI limitations, see https://github.com/Lightning-AI/lightning-thunder/issues/465\n", "# !torchrun --standalone --nnodes=1 --nproc_per_node=2 zero_to_thunder_fsdp_simple_example.py" ] }, { "cell_type": "markdown", "id": "9c65e75d", "metadata": { "slideshow": { "slide_type": "skip" } }, "source": [ "So there. FSDP with just wrapping the model in `fsdp`.\n" ] }, { "cell_type": "markdown", "id": "4a6d7a20", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Extending Thunder\n", "\n", "But we promised that thunder is extensible. Let's find out what's up with that.\n", "\n", "Specifically, we will incorporate the fast rope embedding kernel from the great [Unsloth project](https://github.com/unslothai/unsloth/) into our model (note that NVFuser also creates a fused kernel for this).\n", "\n", "In Thunder, extensions (as well as most builtin optimizations which use the exact same mechanism) work with _executors_ handling operations. Let us define one." 
] }, { "cell_type": "code", "execution_count": 16, "id": "f7639065", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "my_ex" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "my_ex = thunder.extend.OperatorExecutor('my_ex', version='0.0.1')\n", "thunder.extend.register_executor(my_ex)" ] }, { "cell_type": "markdown", "id": "2fe3b40b-c6e9-417c-ab7a-32606cee871a", "metadata": { "slideshow": { "slide_type": "skip" } }, "source": [ "For our base implementation, we take the code from [LitGPT's implementation](https://github.com/Lightning-AI/litgpt/blob/be6139e1fd4b240d253efd58124457496d23d173/litgpt/model.py#L355-L361)\n", "\n", "In thunder, we define a *meta* function that only defines the metadata (like shapes) of outputs and the actual implementation for each operator and then register the pair with our executor using the `register_operator` function.\n", "Because we will demonstrate Thunder's ability to divert functions in the model, we make a version here that will not be diverted." ] }, { "cell_type": "code", "execution_count": 17, "id": "3e74436b-d8eb-472b-9d6d-b6412378fde7", "metadata": { "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ "import litgpt\n", "def apply_rope_copy(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:\n", " head_size = x.size(-1)\n", " x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)\n", " x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2)\n", " rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs)\n", " roped = (x * cos) + (rotated * sin)\n", " return roped.to(dtype=x.dtype)" ] }, { "cell_type": "markdown", "id": "a63595ab", "metadata": { "slideshow": { "slide_type": "skip" } }, "source": [ "### Registering operators\n", "\n", "Say we have a function `apply_rope` applying the RoPE transformation in PyTorch.\n", "\n", "In thunder, we define a *meta* function that only defines the metadata (like shapes) of outputs and the actual implementation for each operator and then register the pair with our executor using the `register_operator` function and tell it to use the new symbol instead of the original function `litgpt.model.apply_rope`.\n" ] }, { "cell_type": "code", "execution_count": 18, "id": "247074b3", "metadata": {}, "outputs": [], "source": [ "import torch, thunder\n", "from thunder.tests.litgpt_model import GPT\n", "from thunder import TensorProxy\n", "\n", "def apply_rope_impl(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:\n", " return litgpt.model.apply_rope(x, cos, sin)\n", "\n", "def apply_rope_meta(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> TensorProxy:\n", " return TensorProxy(like=x)\n", "\n", "apply_rope = my_ex.register_operator('apply_rope', like=apply_rope_meta, fn=apply_rope_impl,\n", " replaces=litgpt.model.apply_rope)" ] }, { "cell_type": "markdown", "id": "d6b7d056", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Testing our new operator " ] }, { "cell_type": "code", "execution_count": 19, "id": "0ebd5dd1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "deviation: 0.0\n" ] }, { "data": { "text/plain": [ "# Constructed by Delete Last Used (took 0 milliseconds)\n", "import torch\n", "from thunder.executors.torchex import no_autocast\n", "\n", "@torch.no_grad()\n", "@no_autocast\n", "def computation(x, t_1_cos, t_1_sin):\n", " # x: \"cuda:0 bf16[2, 128, 4096, 16]\" \n", " # t_1_cos: \"cuda:0 f32[4096, 16]\" \n", " # t_1_sin: \"cuda:0 f32[4096, 16]\" \n", 
" t2 = apply_rope(x, t_1_cos, t_1_sin) # t2: \"cuda:0 bf16[2, 128, 4096, 16]\"\n", " del x, t_1_cos, t_1_sin\n", " return t2" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "with torch.device('cuda'): m = GPT.from_name('llama2-like'); Q = torch.randn(2, 128, 4096, 16)\n", "\n", "def test_apply_rope(x, m):\n", " return litgpt.model.apply_rope(x, m.cos, m.sin)\n", "\n", "thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors()) \n", "\n", "expected = test_apply_rope(Q, m); actual = thunder_apply_rope(Q, m); print(\"deviation:\", (expected - actual).abs().max().item())\n", "\n", "thunder.last_traces(thunder_apply_rope)[-1]" ] }, { "cell_type": "markdown", "id": "8c620a38", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Optimized kernels\n", "\n", "But why did we do this? Well, we can now layer a faster implementation on top.\n", "For this we take the [unsloth fast rope embedding](https://github.com/unslothai/unsloth/blob/42076f6580e71522ed1c122043edfba595be64e4/unsloth/kernels/rope_embedding.py) kernels. We take the bits that were in the forward and backward of the `autograd.Function` into our implementation functions. Note that we include the transpositions in our setup in order to have compatibility to the LitGPT implementation. This change in memory layout of the operands can have a large effect on the runtime though, so our timings are likely not representative of the ones the Unsloth project gets in their use of the same triton kernels." ] }, { "cell_type": "code", "execution_count": 20, "id": "6e6d0b1e-ba14-43e5-b0d9-27c0e3b46879", "metadata": {}, "outputs": [], "source": [ "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. 
All rights reserved.\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# http://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n", "\n", "import triton\n", "import triton.language as tl\n", "import torch\n", "\n", "MAX_FUSED_SIZE = 65536\n", "next_power_of_2 = triton.next_power_of_2\n", "\n", "def calculate_settings(n):\n", " BLOCK_SIZE = next_power_of_2(n)\n", " if BLOCK_SIZE > MAX_FUSED_SIZE:\n", " raise RuntimeError(f\"Cannot launch Triton kernel since n = {n} exceeds \"\\\n", " f\"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.\")\n", " num_warps = 4\n", " if BLOCK_SIZE >= 32768: num_warps = 32\n", " elif BLOCK_SIZE >= 8192: num_warps = 16\n", " elif BLOCK_SIZE >= 2048: num_warps = 8\n", " return BLOCK_SIZE, num_warps\n", "\n", "@triton.heuristics({\"BACKWARD_PASS\": lambda args: args[\"BACKWARD_PASS\"],})\n", "@triton.jit\n", "def _rope_embedding(\n", " Q, Q_row_stride,\n", " cos, cos_row_stride,\n", " sin, sin_row_stride,\n", " seqlen, head_dim, group_size, n_heads,\n", " BACKWARD_PASS: tl.constexpr,\n", " BLOCK_SIZE : tl.constexpr,\n", "):\n", " \"\"\"\n", " Calculates the RoPE Embedding quickly\n", " RoPE is Q * cos + rotate_half(Q) * sin\n", " See our blog post for more info\n", " \"\"\"\n", " row_position = tl.program_id(0)\n", " group_head_position = tl.program_id(1)\n", " col_offsets = tl.arange(0, BLOCK_SIZE)\n", " half_head_dim = head_dim // 2\n", " mask = col_offsets < half_head_dim\n", "\n", " sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \\\n", " half_head_dim*0 + col_offsets, mask = mask, other = 0)\n", " cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \\\n", " half_head_dim*0 + col_offsets, mask = mask, other = 0)\n", "\n", " if BACKWARD_PASS:\n", " # See our blog post for more info.\n", " sin1 = -sin1\n", " pass\n", "\n", " head_start = group_head_position * group_size\n", " head_end = min((head_start + group_size), n_heads)\n", "\n", " for i in range(head_start, head_end):\n", " offs_q1 = row_position * Q_row_stride + i * head_dim + col_offsets\n", " offs_q2 = row_position * Q_row_stride + i * head_dim + col_offsets + half_head_dim\n", "\n", " # For Gemma - sometimes RoPE must be done in float32 and not bfloat16\n", " Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)\n", " Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)\n", "\n", " tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)\n", " tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)\n", " pass\n", "pass\n", "\n", "\n", "def fast_rope_embedding_forward(Q, cos, sin):\n", " Q = Q.transpose(1, 2).clone()\n", " cos, sin = cos.squeeze(), sin.squeeze()\n", " batch, seq_len, n_heads, head_dim = Q.shape\n", " Q = Q.reshape(batch*seq_len, n_heads*head_dim)\n", " n_rows, n_cols = Q.shape\n", " assert(seq_len <= cos.shape[0])\n", "\n", " # [TODO] Changing blocksize to head_dim//2 seems to have\n", " # some concurrency / un-deterministic issues.\n", " BLOCK_SIZE, num_warps = calculate_settings(head_dim//2) # (head_dim//2)\n", " group_size = 4 # 4 or 8, too large 
group_size can hurt performance.\n", " n_groups = triton.cdiv(n_heads, group_size)\n", "\n", " grid = (n_rows, n_groups, )\n", " _rope_embedding[grid](\n", " Q, Q.stride(0),\n", " cos, cos.stride(0),\n", " sin, sin.stride(0),\n", " seqlen, head_dim, group_size, n_heads,\n", " BACKWARD_PASS = False,\n", " BLOCK_SIZE = BLOCK_SIZE,\n", " num_warps = num_warps,\n", " )\n", " Q = Q.view(batch, seq_len, n_heads, head_dim).transpose(1, 2)\n", " return Q, (BLOCK_SIZE, num_warps) \n", "\n", "def fast_rope_embedding_backward(BLOCK_SIZE, num_warps, cos, sin, dY):\n", " dY = dY.transpose(1, 2)\n", " batch, seq_len, n_heads, head_dim = dY.shape\n", " dY = dY.reshape(batch*seq_len, n_heads*head_dim)\n", " # Must be reshape not view\n", " n_rows, n_cols = dY.shape\n", "\n", " group_size = 4 # 4 or 8, too large group_size can hurt performance.\n", " n_groups = triton.cdiv(n_heads, group_size)\n", "\n", " grid = (n_rows, n_groups, )\n", " _rope_embedding[grid](\n", " dY, dY.stride(0),\n", " cos, cos.stride(0),\n", " sin, sin.stride(0),\n", " seq_len, head_dim, group_size, n_heads,\n", " BACKWARD_PASS = True,\n", " BLOCK_SIZE = BLOCK_SIZE,\n", " num_warps = num_warps,\n", " )\n", " dY = dY.view(batch, seq_len, n_heads, head_dim)\n", " dY = dY.transpose(1, 2) \n", " return dY\n" ] }, { "cell_type": "markdown", "id": "ed1e9be3-d1c9-4c4b-bf14-a025a03687ac", "metadata": {}, "source": [ "We also define the corresponding meta functions." ] }, { "cell_type": "code", "execution_count": 21, "id": "d7e6612d-f1fc-497c-9d64-15ef99824086", "metadata": {}, "outputs": [], "source": [ "def fast_rope_embedding_forward_meta(Q, cos, sin):\n", " batch, n_heads, seq_len, head_dim = Q.shape\n", " n_rows, n_cols = batch*seq_len, n_heads*head_dim \n", " assert(seq_len <= cos.shape[0])\n", "\n", " BLOCK_SIZE, num_warps = calculate_settings(head_dim//2)\n", " return TensorProxy(like=Q), (BLOCK_SIZE, num_warps) \n", "\n", "def fast_rope_embedding_backward_meta(BLOCK_SIZE, num_warps, cos, sin, dY):\n", " return TensorProxy(like=dY)" ] }, { "cell_type": "markdown", "id": "b70eba5f", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Register optimized operators\n", "\n", "Just as with `apply_rope` before, we can register operators for the optimized forward and backward." ] }, { "cell_type": "code", "execution_count": 22, "id": "f8f1e77e", "metadata": {}, "outputs": [], "source": [ "unsloth_apply_rope_forward = my_ex.register_operator('unsloth_apply_rope_forward', \n", " meta=fast_rope_embedding_forward_meta, fn=fast_rope_embedding_forward)\n", "unsloth_apply_rope_backward = my_ex.register_operator('unsloth_apply_rope_backward', \n", " meta=fast_rope_embedding_backward_meta, fn=fast_rope_embedding_backward)" ] }, { "cell_type": "markdown", "id": "2426263d", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Implementations for operators\n", "\n", "Do we need to divert `apply_rope` again? No!\n", "We can register the specialized kernel as an _implementation_ of our base `apply_rope` operator. For this we need an _execution transform_ - which is a fancy word for a function that implements the original operator (`apply_rope`) in terms of our new operator - so it has the call signature of `apply_rope`. 
Because - like many fast implementations - the unsloth rope embedding does not implement the operator in full generality (well, actually they mainly want a 4d tensor input), we implement a checker function, too: it takes the arguments of the operator we want to specialize and returns a bool indicating whether our implementation handles the given inputs." ] }, { "cell_type": "code", "execution_count": 23, "id": "6b5c8320", "metadata": {}, "outputs": [], "source": [ "def apply_rope_to_unsloth(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> TensorProxy:\n", " assert len(x.shape) == 4\n", " res, *_ = unsloth_apply_rope_forward(x, cos, sin)\n", " return res\n", "\n", "def apply_rope_to_unsloth_checker(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> bool:\n", " if len(x.shape) != 4:\n", " return False\n", " return (x.device.devicetype == thunder.devices.DeviceType.CUDA and\n", " cos.device.devicetype == thunder.devices.DeviceType.CUDA and\n", " sin.device.devicetype == thunder.devices.DeviceType.CUDA)\n", "\n", "my_ex.register_implementation(apply_rope,\n", " checker=apply_rope_to_unsloth_checker,\n", " execution_transform=apply_rope_to_unsloth)\n" ] }, { "cell_type": "markdown", "id": "eec7c95a", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "So let us give it a try! Works great..." ] }, { "cell_type": "code", "execution_count": 24, "id": "965ba1d7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "deviation: 0.015625\n" ] }, { "data": { "text/plain": [ "# Constructed by Delete Last Used (took 0 milliseconds)\n", "import torch\n", "from thunder.executors.torchex import no_autocast\n", "\n", "@torch.no_grad()\n", "@no_autocast\n", "def computation(x, t_1_cos, t_1_sin):\n", " # x: \"cuda:0 bf16[2, 128, 4096, 16]\" \n", " # t_1_cos: \"cuda:0 f32[4096, 16]\" \n", " # t_1_sin: \"cuda:0 f32[4096, 16]\" \n", " (t2, (_, _)) = unsloth_apply_rope_forward(x, t_1_cos, t_1_sin)\n", " del x, t_1_cos, t_1_sin\n", " return t2" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors()) \n", "\n", "expected = test_apply_rope(Q, m)\n", "actual = thunder_apply_rope(Q, m)\n", "print(\"deviation:\", (expected - actual).abs().max().item())\n", "\n", "thunder.last_traces(thunder_apply_rope)[-1]" ] }, { "cell_type": "markdown", "id": "69a93d3d-3a88-4297-b330-23a7fff2c4b4", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "And this is also automatic when we instantiate a larger llama2-like model:" ] }, { "cell_type": "code", "execution_count": 25, "id": "7fff2522", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "deviation: 5.960464477539062e-07\n" ] } ], "source": [ "torch.set_default_dtype(torch.float32)\n", "with torch.device('cuda'):\n", " m = GPT(Config.from_name('llama2-like'))\n", "\n", "for p in m.parameters():\n", " p.requires_grad_(False)\n", "\n", "thunder_model = thunder.jit(m, executors=(my_ex,) + thunder.get_default_executors())\n", "\n", "inp = torch.randint(1, m.config.vocab_size, (1, 128), device=\"cuda\")\n", "actual = thunder_model(inp)\n", "expected = m(inp)\n", "\n", "print(\"deviation:\", (actual - expected).abs().max().item())" ] }, { "cell_type": "markdown", "id": "b538cb40", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "By peeking into the trace, we can see that it actually used the unsloth apply rope:" ] }, { "cell_type": 
"code", "execution_count": 26, "id": "c260cb25", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[' (q_roped, (_, _)) = unsloth_apply_rope_forward(t55, cos, sin)',\n", " ' (k_roped, (_, _)) = unsloth_apply_rope_forward(t57, cos, sin)',\n", " ' (t165, (_, _)) = unsloth_apply_rope_forward(t164, cos, sin)',\n", " ' (t167, (_, _)) = unsloth_apply_rope_forward(t166, cos, sin)']" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[s for s in str(thunder.last_traces(thunder_model)[-1]).split('\\n') if 'apply_rope' in s]" ] }, { "cell_type": "markdown", "id": "0f6c0780", "metadata": {}, "source": [ "### But what about the backward?\n", "\n", "Well, we have to connect forward and backward with a grad transformation. With our specialized ops, this is very simple, we compute the forward, call `get_grad` for the output, compute the backward, and put it on the input with `put_grads`. \n" ] }, { "cell_type": "code", "execution_count": 27, "id": "7670a872", "metadata": {}, "outputs": [], "source": [ "from thunder.core.transforms import get_grad, put_grads\n", "\n", "def unsloth_apply_rope_grad(x: TensorProxy, cos: TensorProxy, sin: TensorProxy):\n", " res, (BLOCK_SIZE, num_warps) = unsloth_apply_rope_forward(x, cos, sin)\n", " grad_res = get_grad(res)\n", " grad_x = unsloth_apply_rope_backward(BLOCK_SIZE, num_warps, cos, sin, grad_res)\n", " put_grads((x,), (grad_x,))\n", " return res\n", "\n", "my_ex.register_implementation(apply_rope, checker=apply_rope_to_unsloth_checker,\n", " execution_transform=apply_rope_to_unsloth,\n", " grad_transform=unsloth_apply_rope_grad \n", " )\n", "\n" ] }, { "cell_type": "markdown", "id": "219dfaa4-cdef-47de-b60c-7c7c1642cb84", "metadata": {}, "source": [ "Note that the parts are not actually executed at the same time in the actual computation, but just during tracing.\n" ] }, { "cell_type": "markdown", "id": "68226a4a-6ad8-43fb-b92f-c1e8eec6f13e", "metadata": {}, "source": [ "And let us try our function using the optimized backward" ] }, { "cell_type": "code", "execution_count": 28, "id": "ccc3ed63-ddc2-4b0e-bcd0-f77d66fefe9f", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "res deviation: 0.015625\n", "grad deviation: 0.0078125\n" ] } ], "source": [ "Q.requires_grad_()\n", "\n", "thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors())\n", "\n", "expected = test_apply_rope(Q, m)\n", "go = torch.ones_like(expected)\n", "gr_expected, = torch.autograd.grad(expected, Q, go)\n", "actual = thunder_apply_rope(Q, m)\n", "gr_actual, = torch.autograd.grad(actual, Q, go)\n", "\n", "print(\"res deviation:\", (expected - actual).abs().max().item())\n", "print(\"grad deviation:\", (gr_expected - gr_actual).abs().max().item())" ] }, { "cell_type": "markdown", "id": "63cb61ee-c791-49d1-ba5c-3fe4b5b9a9d5", "metadata": {}, "source": [ "And with `last_backward_traces` we can check that our module is using the unsloth backward:" ] }, { "cell_type": "code", "execution_count": 29, "id": "cd12ca02-6f06-4d88-b5b7-25c4c27dbc9a", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "# Constructed by Delete Last Used (took 0 milliseconds)\n", "import torch\n", "from thunder.executors.torchex import no_autocast\n", "\n", "@torch.no_grad()\n", "@no_autocast\n", "def backward_fn(saved_for_backward, cotangents):\n", " # saved_for_backward: \"Collection\" \n", " # cotangents: \"Collection\" \n", " C0, \\\n", " _, \\\n", " = 
saved_for_backward\n", " clear_collection(saved_for_backward)\n", " del saved_for_backward\n", " t4, \\\n", " = cotangents\n", " clear_collection(cotangents)\n", " del cotangents\n", " t1, \\\n", " t2, \\\n", " = C0\n", " clear_collection(C0)\n", " del C0\n", " t3 = unsloth_apply_rope_backward(8, 4, t1, t2, t4) # t3: \"cuda:0 bf16[2, 128, 4096, 16]\"\n", " del t1, t2, t4\n", " return (t3, None, None)" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "thunder.last_backward_traces(thunder_apply_rope)[-1]" ] }, { "cell_type": "markdown", "id": "2776d183-0232-495e-aa75-3b90e799c841", "metadata": {}, "source": [ "### Comparing and exploring optimizations\n", "\n", "It is also straightforward to compare potential optimizations.\n", "\n", "Note again, that our use of the unsloth kernel might not result in the same performance as the unsloth project sees due to differences in the hardware used, software environment, or memory layout of the operands." ] }, { "cell_type": "code", "execution_count": 30, "id": "a5e0ce05", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "eager\n", "3.84 ms ± 3.46 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", "thunder + unsloth\n", "6.69 ms ± 3.45 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", "thunder default (nvfuser)\n", "1.4 ms ± 4.98 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" ] } ], "source": [ "def test_apply_rope_copy(x, m):\n", " return apply_rope_copy(x, m.cos, m.sin)\n", "\n", "test_apply_rope_myex = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors()) \n", "test_apply_rope_nvfuser = thunder.jit(test_apply_rope_copy)\n", "y = test_apply_rope(Q, m); gr, = torch.autograd.grad(y, Q, go)\n", "y = test_apply_rope_myex(Q, m); gr, = torch.autograd.grad(y, Q, go)\n", "y = test_apply_rope_nvfuser(Q, m); gr, = torch.autograd.grad(y, Q, go)\n", "\n", "print(\"eager\")\n", "%timeit y = test_apply_rope(Q, m); gr, = torch.autograd.grad(y, Q, go); torch.cuda.synchronize()\n", "print(\"thunder + unsloth\")\n", "%timeit y = test_apply_rope_myex(Q, m); gr, = torch.autograd.grad(y, Q, go); torch.cuda.synchronize()\n", "print(\"thunder default (nvfuser)\")\n", "%timeit y = test_apply_rope_nvfuser(Q, m); gr, = torch.autograd.grad(y, Q, go); torch.cuda.synchronize()\n" ] }, { "cell_type": "markdown", "id": "08b8454f-c725-470c-92a5-56b2206af0e8", "metadata": {}, "source": [ "That's it!\n", "\n", "## Conclusion\n", "\n", "To wrap up, we hope you got a taste of\n", "\n", "- Getting things going with Thunder:\n", "\n", " - Applying Thunder through `thunder.jit` and\n", " - using FSDP by just wrapping the model in `thunder.distributed.fsdp` before compilation.\n", "\n", "- See what's going on inspecting traces:\n", "\n", " - `thunder.last_traces` for the forward traces,\n", " - `thunder.last_backward_traces` for the backward,\n", " \n", "- Extending Thunder:\n", "\n", " - registering operators with the `OperatorExecutor`,\n", " - defining implementations with custom forward and backward to include optimized kernels.\n", "\n", "Keep in mind that Thunder is still experimental and only expected to work with the limited set of models we have tested it with. You will find bugs and missing pieces. Naturally, we would love for you to help us fix these! 
You can find us on the [Thunder section of the Lightning forums](https://lightning.ai/forums/c/thunder) or in the `#thunder` channel on the [PyTorch-Lightning slack](https://pytorch-lightning.slack.com/). \n", "\n", "Do check out our LitGPT studios and the other tutorial notebooks.\n" ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.10" } }, "nbformat": 4, "nbformat_minor": 5 }