Zero to Thunder
Here we take a very short tour of what is possible with Thunder.
To get started we import it (and a bunch of things for this notebook).
[1]:
import sys
sys.path.insert(0, '..')
import torch, thunder
Compiling a first module with Thunder
So let’s get started! As a “Hello World”, let us apply it to a small model, say, the MLP part found in Llama 2. We take it from LitGPT.
[2]:
class LLaMAMLP(torch.nn.Module):
    def __init__(self, n_embd, intermediate_size) -> None:
        super().__init__()
        self.fc_1 = torch.nn.Linear(n_embd, intermediate_size, bias=False)
        self.fc_2 = torch.nn.Linear(n_embd, intermediate_size, bias=False)
        self.proj = torch.nn.Linear(intermediate_size, n_embd, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_fc_1 = self.fc_1(x)
        x_fc_2 = self.fc_2(x)
        x = torch.nn.functional.silu(x_fc_1) * x_fc_2
        return self.proj(x)

with torch.device("cuda"):
    m = LLaMAMLP(4096, 11008)
for p in m.parameters():
    p.requires_grad_(False)
print(m)
LLaMAMLP(
(fc_1): Linear(in_features=4096, out_features=11008, bias=False)
(fc_2): Linear(in_features=4096, out_features=11008, bias=False)
(proj): Linear(in_features=11008, out_features=4096, bias=False)
)
Now we can apply Thunder. This uses the most important function of Thunder, thunder.jit, which can be used to compile a torch.nn.Module or a function. It will wrap our MLP in a ThunderModule.
[3]:
thunder_model = thunder.jit(m)
[4]:
thunder_model
[4]:
ThunderModule(
(_model): LLaMAMLP(
(fc_1): Linear(in_features=4096, out_features=11008, bias=False)
(fc_2): Linear(in_features=4096, out_features=11008, bias=False)
(proj): Linear(in_features=11008, out_features=4096, bias=False)
)
)
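Incidentally, thunder.jit works the same way on a plain function instead of a module. A minimal sketch (the function below is just for illustration, mirroring the activation pattern of the MLP):

def silu_mul(a, b):
    # silu(a) * b, the same pattern as inside the MLP above
    return torch.nn.functional.silu(a) * b

jitted_silu_mul = thunder.jit(silu_mul)
a = torch.randn(16, 16, device="cuda")
b = torch.randn(16, 16, device="cuda")
print('deviation:', (jitted_silu_mul(a, b) - silu_mul(a, b)).abs().max().item())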
Our Thunder module computes (up to numerical accuracy) the same thing as our original model, and for a small model like this it also has approximately the same performance.
[5]:
x = torch.randn(2, 2048, 4096, device="cuda")
print('deviation:', (thunder_model(x) - m(x)).abs().max().item())
%timeit thunder_model(x); torch.cuda.synchronize()
%timeit m(x); torch.cuda.synchronize()
deviation: 1.4901161193847656e-07
61.3 ms ± 106 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
62.1 ms ± 89.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
So what has changed? Quite a bit!
When we call the Thunder module, it does the computation in a single function without control flow. And what’s more, it applies optimizations, such as creating fusions for NVFuser to execute. We can see all this by showing the last computation trace:
[6]:
thunder.last_traces(thunder_model)[-1]
[6]:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
import torch.nn.functional
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight):
# x: "cuda:0 f32[2, 2048, 4096]"
# t_fc_1_weight: "cuda:0 f32[11008, 4096]"
# t_fc_2_weight: "cuda:0 f32[11008, 4096]"
# t_proj_weight: "cuda:0 f32[4096, 11008]"
x_fc_1 = torch.nn.functional.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]"
# x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]"
# x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]"
del t_fc_1_weight
x_fc_2 = torch.nn.functional.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]"
# x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]"
# x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]"
del x, t_fc_2_weight
[result] = nvFusion0(x_fc_1, x_fc_2)
# t9 = prims.neg(x_fc_1) # t9: "cuda:0 f32[2, 2048, 11008]"
# t10 = prims.exp(t9) # t10: "cuda:0 f32[2, 2048, 11008]"
# t11 = prims.add(1.0, t10) # t11: "cuda:0 f32[2, 2048, 11008]"
# t12 = prims.reciprocal(t11) # t12: "cuda:0 f32[2, 2048, 11008]"
# a = prims.mul(x_fc_1, t12) # a: "cuda:0 f32[2, 2048, 11008]"
# result = prims.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]"
del x_fc_1, x_fc_2
t18 = torch.nn.functional.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]"
# t18 = ltorch.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]"
# t18 = prims.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]"
del result, t_proj_weight
return t18
For more detail on what is going on in this trace:
- Thunder has transformed the computation (more precisely, m.__call__) into a single function that takes all the MLP parameters as arguments.
- It has recorded the tensor metadata.
- Operations have been mapped from the PyTorch functions to their thunder.torch (aka ltorch) equivalents and decomposed into primitive operations.
- The multiplication and activation (x = torch.nn.functional.silu(x_fc_1) * x_fc_2) have been put into one NVFuser fusion. (NVFuser is one, particularly important, optimization among many, and we make it easy to add your own.)
- You can see how the parameters are obtained and the metadata is checked in the prologue; get it through thunder.last_prologue_traces(thunder_model)[-1].

You can actually see the whole series of traces: last_traces gives you a list of transformed traces in chronological order. For example, the initial trace thunder.last_traces(thunder_model)[0] does not have the fusion yet; see the short sketch below.
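A minimal sketch of pulling these up interactively (the variable names are just for illustration):

prologue_trace = thunder.last_prologue_traces(thunder_model)[-1]  # input checks and parameter collection
print(prologue_trace)

traces = thunder.last_traces(thunder_model)  # all computation traces, in chronological order
print(traces[0])   # initial trace: no fusions yet
print(traces[-1])  # final trace: what actually gets executed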
Compiling a more complex model
Obviously, we aim for larger models, so we can do the same with the entire Llama 2 (well, we use a smaller model here to be gentle on our CI, but if you have a large GPU, just drop the line that reduces the number of layers):
NOTE: Running the cells below requires litgpt, which can be installed with pip install 'litgpt[all] @ git+https://github.com/Lightning-AI/litgpt'. See the LitGPT repository at https://github.com/Lightning-AI/litgpt to learn more about litgpt.
[7]:
from litgpt import GPT
from thunder.tests.litgpt_model import Config
cfg = Config.from_name('Llama-2-7b-hf')
cfg.n_layer = 16 # fewer layers
torch.set_default_dtype(torch.bfloat16)
with torch.device('cuda'):
    m = GPT(cfg)
m
[7]:
GPT(
(lm_head): Linear(in_features=4096, out_features=32000, bias=False)
(transformer): ModuleDict(
(wte): Embedding(32000, 4096)
(h): ModuleList(
(0-15): 16 x Block(
(norm_1): RMSNorm()
(attn): CausalSelfAttention(
(attn): Linear(in_features=4096, out_features=12288, bias=False)
(proj): Linear(in_features=4096, out_features=4096, bias=False)
)
(norm_2): RMSNorm()
(mlp): LLaMAMLP(
(fc_1): Linear(in_features=4096, out_features=11008, bias=False)
(fc_2): Linear(in_features=4096, out_features=11008, bias=False)
(proj): Linear(in_features=11008, out_features=4096, bias=False)
)
)
)
(ln_f): RMSNorm()
)
)
Again we jit our model and compare the output…
[8]:
thunder_model = thunder.jit(m)
inp = torch.randint(1, m.config.vocab_size, (1, 512), device="cuda")
actual = thunder_model(inp)
expected = m(inp)
print("deviation:", (actual - expected).abs().max().item())
deviation: 0.03125
One thing to keep in mind here is that for bf16, the numerical accuracy impact of rearranging operations can be quite pronounced.
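To put the number above into perspective, one can relate it to the magnitude of the outputs; a quick, illustrative check:

abs_dev = (actual - expected).abs().max()
scale = expected.abs().max()
print('relative deviation:', (abs_dev / scale).item())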
Just like before, we can see the program it ran; it is a lot longer, though.
[9]:
print(actual.grad_fn)
thunder.last_traces(thunder_model)[-1]
<torch.autograd.function.ThunderFunctionBackward object at 0x7f923f792ac0>
[9]:
# Constructed by Delete Last Used (took 10 milliseconds)
import torch
from torch import Tensor
import torch.nn.functional
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def augmented_forward_fn(*args):
# args: "Collection"
t0, \
t1, \
t2, \
t3, \
t4, \
t5, \
t6, \
t7, \
t8, \
t9, \
t10, \
t11, \
t12, \
t13, \
t14, \
t15, \
t16, \
t17, \
t18, \
t19, \
t20, \
t21, \
t22, \
t23, \
t24, \
t25, \
t26, \
t27, \
t28, \
t29, \
t30, \
t31, \
t32, \
t33, \
t34, \
t35, \
t36, \
t37, \
t38, \
t39, \
t40, \
t41, \
t42, \
t43, \
t44, \
t45, \
t46, \
t47, \
t48, \
t49, \
t50, \
t51, \
t52, \
t53, \
t54, \
t55, \
t56, \
t57, \
t58, \
t59, \
t60, \
t61, \
t62, \
t63, \
t64, \
t65, \
t66, \
t67, \
t68, \
t69, \
t70, \
t71, \
t72, \
t73, \
t74, \
t75, \
t76, \
t77, \
t78, \
t79, \
t80, \
t81, \
t82, \
t83, \
t84, \
t85, \
t86, \
t87, \
t88, \
t89, \
t90, \
t91, \
t92, \
t93, \
t94, \
t95, \
t96, \
t97, \
t98, \
t99, \
t100, \
t101, \
t102, \
t103, \
t104, \
t105, \
t106, \
t107, \
t108, \
t109, \
t110, \
t111, \
t112, \
t113, \
t114, \
t115, \
t116, \
t117, \
= args
del args
t122 = torch.nn.functional.embedding(t0, t117, None, None, 2.0, False, False) # t122: "cuda:0 bf16[1, 512, 4096]"
# t122 = ltorch.embedding(t0, t117, None, None, 2.0, False, False) # t122: "cuda:0 bf16[1, 512, 4096]"
# t1867 = ltorch.reshape(t0, [512]) # t1867: "cuda:0 i64[512]"
# t1867 = prims.reshape(t0, (512,)) # t1867: "cuda:0 i64[512]"
# t1868 = prims.take(t117, t1867, 0) # t1868: "cuda:0 bf16[512, 4096]"
# t122 = ltorch.reshape(t1868, [1, 512, 4096]) # t122: "cuda:0 bf16[1, 512, 4096]"
# t122 = prims.reshape(t1868, (1, 512, 4096)) # t122: "cuda:0 bf16[1, 512, 4096]"
t118 = torch_slice_prim_impl(t1, [0, 0], [512, 128], [1, 1]) # t118: "cuda:0 f32[512, 128]"
t119 = torch_slice_prim_impl(t2, [0, 0], [512, 128], [1, 1]) # t119: "cuda:0 f32[512, 128]"
t2015 = torch.unsqueeze(t53, 0) # t2015: "cuda:0 bf16[1, 4096]"
# t2015 = ltorch.unsqueeze(t53, 0) # t2015: "cuda:0 bf16[1, 4096]"
# t2015 = prims.broadcast_in_dim(t53, [1, 4096], [1]) # t2015: "cuda:0 bf16[1, 4096]"
t2016 = torch.unsqueeze(t2015, 1) # t2016: "cuda:0 bf16[1, 1, 4096]"
# t2016 = ltorch.unsqueeze(t2015, 1) # t2016: "cuda:0 bf16[1, 1, 4096]"
# t2016 = prims.broadcast_in_dim(t2015, [1, 1, 4096], [0, 2]) # t2016: "cuda:0 bf16[1, 1, 4096]"
del t2015
t133 = Tensor.expand(t2016, (1, 512, 4096)) # t133: "cuda:0 bf16[1, 512, 4096]"
# t133 = ltorch.expand(t2016, (1, 512, 4096)) # t133: "cuda:0 bf16[1, 512, 4096]"
# t133 = prims.broadcast_in_dim(t2016, (1, 512, 4096), (0, 1, 2)) # t133: "cuda:0 bf16[1, 512, 4096]"
del t2016
t2356 = torch.unsqueeze(t82, 0) # t2356: "cuda:0 bf16[1, 4096]"
# t2356 = ltorch.unsqueeze(t82, 0) # t2356: "cuda:0 bf16[1, 4096]"
# t2356 = prims.broadcast_in_dim(t82, [1, 4096], [1]) # t2356: "cuda:0 bf16[1, 4096]"
t2357 = torch.unsqueeze(t2356, 1) # t2357: "cuda:0 bf16[1, 1, 4096]"
# t2357 = ltorch.unsqueeze(t2356, 1) # t2357: "cuda:0 bf16[1, 1, 4096]"
# t2357 = prims.broadcast_in_dim(t2356, [1, 1, 4096], [0, 2]) # t2357: "cuda:0 bf16[1, 1, 4096]"
del t2356
t1609 = Tensor.expand(t2357, (1, 512, 4096)) # t1609: "cuda:0 bf16[1, 512, 4096]"
# t1609 = ltorch.expand(t2357, (1, 512, 4096)) # t1609: "cuda:0 bf16[1, 512, 4096]"
# t1609 = prims.broadcast_in_dim(t2357, (1, 512, 4096), (0, 1, 2)) # t1609: "cuda:0 bf16[1, 512, 4096]"
del t2357
t2359 = torch.unsqueeze(t58, 0) # t2359: "cuda:0 bf16[1, 4096]"
# t2359 = ltorch.unsqueeze(t58, 0) # t2359: "cuda:0 bf16[1, 4096]"
# t2359 = prims.broadcast_in_dim(t58, [1, 4096], [1]) # t2359: "cuda:0 bf16[1, 4096]"
t2360 = torch.unsqueeze(t2359, 1) # t2360: "cuda:0 bf16[1, 1, 4096]"
# t2360 = ltorch.unsqueeze(t2359, 1) # t2360: "cuda:0 bf16[1, 1, 4096]"
# t2360 = prims.broadcast_in_dim(t2359, [1, 1, 4096], [0, 2]) # t2360: "cuda:0 bf16[1, 1, 4096]"
del t2359
t1645 = Tensor.expand(t2360, (1, 512, 4096)) # t1645: "cuda:0 bf16[1, 512, 4096]"
# t1645 = ltorch.expand(t2360, (1, 512, 4096)) # t1645: "cuda:0 bf16[1, 512, 4096]"
# t1645 = prims.broadcast_in_dim(t2360, (1, 512, 4096), (0, 1, 2)) # t1645: "cuda:0 bf16[1, 512, 4096]"
del t2360
t2044 = torch.unsqueeze(t69, 0) # t2044: "cuda:0 bf16[1, 4096]"
# t2044 = ltorch.unsqueeze(t69, 0) # t2044: "cuda:0 bf16[1, 4096]"
# t2044 = prims.broadcast_in_dim(t69, [1, 4096], [1]) # t2044: "cuda:0 bf16[1, 4096]"
t2045 = torch.unsqueeze(t2044, 1) # t2045: "cuda:0 bf16[1, 1, 4096]"
# t2045 = ltorch.unsqueeze(t2044, 1) # t2045: "cuda:0 bf16[1, 1, 4096]"
# t2045 = prims.broadcast_in_dim(t2044, [1, 1, 4096], [0, 2]) # t2045: "cuda:0 bf16[1, 1, 4096]"
del t2044
t205 = Tensor.expand(t2045, (1, 512, 4096)) # t205: "cuda:0 bf16[1, 512, 4096]"
# t205 = ltorch.expand(t2045, (1, 512, 4096)) # t205: "cuda:0 bf16[1, 512, 4096]"
# t205 = prims.broadcast_in_dim(t2045, (1, 512, 4096), (0, 1, 2)) # t205: "cuda:0 bf16[1, 512, 4096]"
del t2045
t2380 = torch.unsqueeze(t83, 0) # t2380: "cuda:0 bf16[1, 4096]"
# t2380 = ltorch.unsqueeze(t83, 0) # t2380: "cuda:0 bf16[1, 4096]"
# t2380 = prims.broadcast_in_dim(t83, [1, 4096], [1]) # t2380: "cuda:0 bf16[1, 4096]"
t2381 = torch.unsqueeze(t2380, 1) # t2381: "cuda:0 bf16[1, 1, 4096]"
# t2381 = ltorch.unsqueeze(t2380, 1) # t2381: "cuda:0 bf16[1, 1, 4096]"
# t2381 = prims.broadcast_in_dim(t2380, [1, 1, 4096], [0, 2]) # t2381: "cuda:0 bf16[1, 1, 4096]"
del t2380
t1717 = Tensor.expand(t2381, (1, 512, 4096)) # t1717: "cuda:0 bf16[1, 512, 4096]"
# t1717 = ltorch.expand(t2381, (1, 512, 4096)) # t1717: "cuda:0 bf16[1, 512, 4096]"
# t1717 = prims.broadcast_in_dim(t2381, (1, 512, 4096), (0, 1, 2)) # t1717: "cuda:0 bf16[1, 512, 4096]"
del t2381
t2047 = torch.unsqueeze(t60, 0) # t2047: "cuda:0 bf16[1, 4096]"
# t2047 = ltorch.unsqueeze(t60, 0) # t2047: "cuda:0 bf16[1, 4096]"
# t2047 = prims.broadcast_in_dim(t60, [1, 4096], [1]) # t2047: "cuda:0 bf16[1, 4096]"
t2048 = torch.unsqueeze(t2047, 1) # t2048: "cuda:0 bf16[1, 1, 4096]"
# t2048 = ltorch.unsqueeze(t2047, 1) # t2048: "cuda:0 bf16[1, 1, 4096]"
# t2048 = prims.broadcast_in_dim(t2047, [1, 1, 4096], [0, 2]) # t2048: "cuda:0 bf16[1, 1, 4096]"
del t2047
t241 = Tensor.expand(t2048, (1, 512, 4096)) # t241: "cuda:0 bf16[1, 512, 4096]"
# t241 = ltorch.expand(t2048, (1, 512, 4096)) # t241: "cuda:0 bf16[1, 512, 4096]"
# t241 = prims.broadcast_in_dim(t2048, (1, 512, 4096), (0, 1, 2)) # t241: "cuda:0 bf16[1, 512, 4096]"
del t2048
t2383 = torch.unsqueeze(t59, 0) # t2383: "cuda:0 bf16[1, 4096]"
# t2383 = ltorch.unsqueeze(t59, 0) # t2383: "cuda:0 bf16[1, 4096]"
# t2383 = prims.broadcast_in_dim(t59, [1, 4096], [1]) # t2383: "cuda:0 bf16[1, 4096]"
t2384 = torch.unsqueeze(t2383, 1) # t2384: "cuda:0 bf16[1, 1, 4096]"
# t2384 = ltorch.unsqueeze(t2383, 1) # t2384: "cuda:0 bf16[1, 1, 4096]"
# t2384 = prims.broadcast_in_dim(t2383, [1, 1, 4096], [0, 2]) # t2384: "cuda:0 bf16[1, 1, 4096]"
del t2383
t1753 = Tensor.expand(t2384, (1, 512, 4096)) # t1753: "cuda:0 bf16[1, 512, 4096]"
# t1753 = ltorch.expand(t2384, (1, 512, 4096)) # t1753: "cuda:0 bf16[1, 512, 4096]"
# t1753 = prims.broadcast_in_dim(t2384, (1, 512, 4096), (0, 1, 2)) # t1753: "cuda:0 bf16[1, 512, 4096]"
del t2384
t2068 = torch.unsqueeze(t70, 0) # t2068: "cuda:0 bf16[1, 4096]"
# t2068 = ltorch.unsqueeze(t70, 0) # t2068: "cuda:0 bf16[1, 4096]"
# t2068 = prims.broadcast_in_dim(t70, [1, 4096], [1]) # t2068: "cuda:0 bf16[1, 4096]"
t2069 = torch.unsqueeze(t2068, 1) # t2069: "cuda:0 bf16[1, 1, 4096]"
# t2069 = ltorch.unsqueeze(t2068, 1) # t2069: "cuda:0 bf16[1, 1, 4096]"
# t2069 = prims.broadcast_in_dim(t2068, [1, 1, 4096], [0, 2]) # t2069: "cuda:0 bf16[1, 1, 4096]"
del t2068
t313 = Tensor.expand(t2069, (1, 512, 4096)) # t313: "cuda:0 bf16[1, 512, 4096]"
# t313 = ltorch.expand(t2069, (1, 512, 4096)) # t313: "cuda:0 bf16[1, 512, 4096]"
# t313 = prims.broadcast_in_dim(t2069, (1, 512, 4096), (0, 1, 2)) # t313: "cuda:0 bf16[1, 512, 4096]"
del t2069
t2404 = torch.unsqueeze(t84, 0) # t2404: "cuda:0 bf16[1, 4096]"
# t2404 = ltorch.unsqueeze(t84, 0) # t2404: "cuda:0 bf16[1, 4096]"
# t2404 = prims.broadcast_in_dim(t84, [1, 4096], [1]) # t2404: "cuda:0 bf16[1, 4096]"
t2405 = torch.unsqueeze(t2404, 1) # t2405: "cuda:0 bf16[1, 1, 4096]"
# t2405 = ltorch.unsqueeze(t2404, 1) # t2405: "cuda:0 bf16[1, 1, 4096]"
# t2405 = prims.broadcast_in_dim(t2404, [1, 1, 4096], [0, 2]) # t2405: "cuda:0 bf16[1, 1, 4096]"
del t2404
t1825 = Tensor.expand(t2405, (1, 512, 4096)) # t1825: "cuda:0 bf16[1, 512, 4096]"
# t1825 = ltorch.expand(t2405, (1, 512, 4096)) # t1825: "cuda:0 bf16[1, 512, 4096]"
# t1825 = prims.broadcast_in_dim(t2405, (1, 512, 4096), (0, 1, 2)) # t1825: "cuda:0 bf16[1, 512, 4096]"
del t2405
t2071 = torch.unsqueeze(t61, 0) # t2071: "cuda:0 bf16[1, 4096]"
# t2071 = ltorch.unsqueeze(t61, 0) # t2071: "cuda:0 bf16[1, 4096]"
# t2071 = prims.broadcast_in_dim(t61, [1, 4096], [1]) # t2071: "cuda:0 bf16[1, 4096]"
t2072 = torch.unsqueeze(t2071, 1) # t2072: "cuda:0 bf16[1, 1, 4096]"
# t2072 = ltorch.unsqueeze(t2071, 1) # t2072: "cuda:0 bf16[1, 1, 4096]"
# t2072 = prims.broadcast_in_dim(t2071, [1, 1, 4096], [0, 2]) # t2072: "cuda:0 bf16[1, 1, 4096]"
del t2071
t349 = Tensor.expand(t2072, (1, 512, 4096)) # t349: "cuda:0 bf16[1, 512, 4096]"
# t349 = ltorch.expand(t2072, (1, 512, 4096)) # t349: "cuda:0 bf16[1, 512, 4096]"
# t349 = prims.broadcast_in_dim(t2072, (1, 512, 4096), (0, 1, 2)) # t349: "cuda:0 bf16[1, 512, 4096]"
del t2072
t2407 = torch.unsqueeze(t52, 0) # t2407: "cuda:0 bf16[1, 4096]"
# t2407 = ltorch.unsqueeze(t52, 0) # t2407: "cuda:0 bf16[1, 4096]"
# t2407 = prims.broadcast_in_dim(t52, [1, 4096], [1]) # t2407: "cuda:0 bf16[1, 4096]"
t2408 = torch.unsqueeze(t2407, 1) # t2408: "cuda:0 bf16[1, 1, 4096]"
# t2408 = ltorch.unsqueeze(t2407, 1) # t2408: "cuda:0 bf16[1, 1, 4096]"
# t2408 = prims.broadcast_in_dim(t2407, [1, 1, 4096], [0, 2]) # t2408: "cuda:0 bf16[1, 1, 4096]"
del t2407
t1861 = Tensor.expand(t2408, (1, 512, 4096)) # t1861: "cuda:0 bf16[1, 512, 4096]"
# t1861 = ltorch.expand(t2408, (1, 512, 4096)) # t1861: "cuda:0 bf16[1, 512, 4096]"
# t1861 = prims.broadcast_in_dim(t2408, (1, 512, 4096), (0, 1, 2)) # t1861: "cuda:0 bf16[1, 512, 4096]"
del t2408
t2095 = torch.unsqueeze(t62, 0) # t2095: "cuda:0 bf16[1, 4096]"
# t2095 = ltorch.unsqueeze(t62, 0) # t2095: "cuda:0 bf16[1, 4096]"
# t2095 = prims.broadcast_in_dim(t62, [1, 4096], [1]) # t2095: "cuda:0 bf16[1, 4096]"
t2096 = torch.unsqueeze(t2095, 1) # t2096: "cuda:0 bf16[1, 1, 4096]"
# t2096 = ltorch.unsqueeze(t2095, 1) # t2096: "cuda:0 bf16[1, 1, 4096]"
# t2096 = prims.broadcast_in_dim(t2095, [1, 1, 4096], [0, 2]) # t2096: "cuda:0 bf16[1, 1, 4096]"
del t2095
t457 = Tensor.expand(t2096, (1, 512, 4096)) # t457: "cuda:0 bf16[1, 512, 4096]"
# t457 = ltorch.expand(t2096, (1, 512, 4096)) # t457: "cuda:0 bf16[1, 512, 4096]"
# t457 = prims.broadcast_in_dim(t2096, (1, 512, 4096), (0, 1, 2)) # t457: "cuda:0 bf16[1, 512, 4096]"
del t2096
t2092 = torch.unsqueeze(t71, 0) # t2092: "cuda:0 bf16[1, 4096]"
# t2092 = ltorch.unsqueeze(t71, 0) # t2092: "cuda:0 bf16[1, 4096]"
# t2092 = prims.broadcast_in_dim(t71, [1, 4096], [1]) # t2092: "cuda:0 bf16[1, 4096]"
t2093 = torch.unsqueeze(t2092, 1) # t2093: "cuda:0 bf16[1, 1, 4096]"
# t2093 = ltorch.unsqueeze(t2092, 1) # t2093: "cuda:0 bf16[1, 1, 4096]"
# t2093 = prims.broadcast_in_dim(t2092, [1, 1, 4096], [0, 2]) # t2093: "cuda:0 bf16[1, 1, 4096]"
del t2092
t421 = Tensor.expand(t2093, (1, 512, 4096)) # t421: "cuda:0 bf16[1, 512, 4096]"
# t421 = ltorch.expand(t2093, (1, 512, 4096)) # t421: "cuda:0 bf16[1, 512, 4096]"
# t421 = prims.broadcast_in_dim(t2093, (1, 512, 4096), (0, 1, 2)) # t421: "cuda:0 bf16[1, 512, 4096]"
del t2093
t2116 = torch.unsqueeze(t72, 0) # t2116: "cuda:0 bf16[1, 4096]"
# t2116 = ltorch.unsqueeze(t72, 0) # t2116: "cuda:0 bf16[1, 4096]"
# t2116 = prims.broadcast_in_dim(t72, [1, 4096], [1]) # t2116: "cuda:0 bf16[1, 4096]"
t2117 = torch.unsqueeze(t2116, 1) # t2117: "cuda:0 bf16[1, 1, 4096]"
# t2117 = ltorch.unsqueeze(t2116, 1) # t2117: "cuda:0 bf16[1, 1, 4096]"
# t2117 = prims.broadcast_in_dim(t2116, [1, 1, 4096], [0, 2]) # t2117: "cuda:0 bf16[1, 1, 4096]"
del t2116
t529 = Tensor.expand(t2117, (1, 512, 4096)) # t529: "cuda:0 bf16[1, 512, 4096]"
# t529 = ltorch.expand(t2117, (1, 512, 4096)) # t529: "cuda:0 bf16[1, 512, 4096]"
# t529 = prims.broadcast_in_dim(t2117, (1, 512, 4096), (0, 1, 2)) # t529: "cuda:0 bf16[1, 512, 4096]"
del t2117
t2119 = torch.unsqueeze(t63, 0) # t2119: "cuda:0 bf16[1, 4096]"
# t2119 = ltorch.unsqueeze(t63, 0) # t2119: "cuda:0 bf16[1, 4096]"
# t2119 = prims.broadcast_in_dim(t63, [1, 4096], [1]) # t2119: "cuda:0 bf16[1, 4096]"
t2120 = torch.unsqueeze(t2119, 1) # t2120: "cuda:0 bf16[1, 1, 4096]"
# t2120 = ltorch.unsqueeze(t2119, 1) # t2120: "cuda:0 bf16[1, 1, 4096]"
# t2120 = prims.broadcast_in_dim(t2119, [1, 1, 4096], [0, 2]) # t2120: "cuda:0 bf16[1, 1, 4096]"
del t2119
t565 = Tensor.expand(t2120, (1, 512, 4096)) # t565: "cuda:0 bf16[1, 512, 4096]"
# t565 = ltorch.expand(t2120, (1, 512, 4096)) # t565: "cuda:0 bf16[1, 512, 4096]"
# t565 = prims.broadcast_in_dim(t2120, (1, 512, 4096), (0, 1, 2)) # t565: "cuda:0 bf16[1, 512, 4096]"
del t2120
t2140 = torch.unsqueeze(t73, 0) # t2140: "cuda:0 bf16[1, 4096]"
# t2140 = ltorch.unsqueeze(t73, 0) # t2140: "cuda:0 bf16[1, 4096]"
# t2140 = prims.broadcast_in_dim(t73, [1, 4096], [1]) # t2140: "cuda:0 bf16[1, 4096]"
t2141 = torch.unsqueeze(t2140, 1) # t2141: "cuda:0 bf16[1, 1, 4096]"
# t2141 = ltorch.unsqueeze(t2140, 1) # t2141: "cuda:0 bf16[1, 1, 4096]"
# t2141 = prims.broadcast_in_dim(t2140, [1, 1, 4096], [0, 2]) # t2141: "cuda:0 bf16[1, 1, 4096]"
del t2140
t637 = Tensor.expand(t2141, (1, 512, 4096)) # t637: "cuda:0 bf16[1, 512, 4096]"
# t637 = ltorch.expand(t2141, (1, 512, 4096)) # t637: "cuda:0 bf16[1, 512, 4096]"
# t637 = prims.broadcast_in_dim(t2141, (1, 512, 4096), (0, 1, 2)) # t637: "cuda:0 bf16[1, 512, 4096]"
del t2141
t2143 = torch.unsqueeze(t64, 0) # t2143: "cuda:0 bf16[1, 4096]"
# t2143 = ltorch.unsqueeze(t64, 0) # t2143: "cuda:0 bf16[1, 4096]"
# t2143 = prims.broadcast_in_dim(t64, [1, 4096], [1]) # t2143: "cuda:0 bf16[1, 4096]"
t2144 = torch.unsqueeze(t2143, 1) # t2144: "cuda:0 bf16[1, 1, 4096]"
# t2144 = ltorch.unsqueeze(t2143, 1) # t2144: "cuda:0 bf16[1, 1, 4096]"
# t2144 = prims.broadcast_in_dim(t2143, [1, 1, 4096], [0, 2]) # t2144: "cuda:0 bf16[1, 1, 4096]"
del t2143
t673 = Tensor.expand(t2144, (1, 512, 4096)) # t673: "cuda:0 bf16[1, 512, 4096]"
# t673 = ltorch.expand(t2144, (1, 512, 4096)) # t673: "cuda:0 bf16[1, 512, 4096]"
# t673 = prims.broadcast_in_dim(t2144, (1, 512, 4096), (0, 1, 2)) # t673: "cuda:0 bf16[1, 512, 4096]"
del t2144
t2164 = torch.unsqueeze(t74, 0) # t2164: "cuda:0 bf16[1, 4096]"
# t2164 = ltorch.unsqueeze(t74, 0) # t2164: "cuda:0 bf16[1, 4096]"
# t2164 = prims.broadcast_in_dim(t74, [1, 4096], [1]) # t2164: "cuda:0 bf16[1, 4096]"
t2165 = torch.unsqueeze(t2164, 1) # t2165: "cuda:0 bf16[1, 1, 4096]"
# t2165 = ltorch.unsqueeze(t2164, 1) # t2165: "cuda:0 bf16[1, 1, 4096]"
# t2165 = prims.broadcast_in_dim(t2164, [1, 1, 4096], [0, 2]) # t2165: "cuda:0 bf16[1, 1, 4096]"
del t2164
t745 = Tensor.expand(t2165, (1, 512, 4096)) # t745: "cuda:0 bf16[1, 512, 4096]"
# t745 = ltorch.expand(t2165, (1, 512, 4096)) # t745: "cuda:0 bf16[1, 512, 4096]"
# t745 = prims.broadcast_in_dim(t2165, (1, 512, 4096), (0, 1, 2)) # t745: "cuda:0 bf16[1, 512, 4096]"
del t2165
t2167 = torch.unsqueeze(t65, 0) # t2167: "cuda:0 bf16[1, 4096]"
# t2167 = ltorch.unsqueeze(t65, 0) # t2167: "cuda:0 bf16[1, 4096]"
# t2167 = prims.broadcast_in_dim(t65, [1, 4096], [1]) # t2167: "cuda:0 bf16[1, 4096]"
t2168 = torch.unsqueeze(t2167, 1) # t2168: "cuda:0 bf16[1, 1, 4096]"
# t2168 = ltorch.unsqueeze(t2167, 1) # t2168: "cuda:0 bf16[1, 1, 4096]"
# t2168 = prims.broadcast_in_dim(t2167, [1, 1, 4096], [0, 2]) # t2168: "cuda:0 bf16[1, 1, 4096]"
del t2167
t781 = Tensor.expand(t2168, (1, 512, 4096)) # t781: "cuda:0 bf16[1, 512, 4096]"
# t781 = ltorch.expand(t2168, (1, 512, 4096)) # t781: "cuda:0 bf16[1, 512, 4096]"
# t781 = prims.broadcast_in_dim(t2168, (1, 512, 4096), (0, 1, 2)) # t781: "cuda:0 bf16[1, 512, 4096]"
del t2168
t2188 = torch.unsqueeze(t75, 0) # t2188: "cuda:0 bf16[1, 4096]"
# t2188 = ltorch.unsqueeze(t75, 0) # t2188: "cuda:0 bf16[1, 4096]"
# t2188 = prims.broadcast_in_dim(t75, [1, 4096], [1]) # t2188: "cuda:0 bf16[1, 4096]"
t2189 = torch.unsqueeze(t2188, 1) # t2189: "cuda:0 bf16[1, 1, 4096]"
# t2189 = ltorch.unsqueeze(t2188, 1) # t2189: "cuda:0 bf16[1, 1, 4096]"
# t2189 = prims.broadcast_in_dim(t2188, [1, 1, 4096], [0, 2]) # t2189: "cuda:0 bf16[1, 1, 4096]"
del t2188
t853 = Tensor.expand(t2189, (1, 512, 4096)) # t853: "cuda:0 bf16[1, 512, 4096]"
# t853 = ltorch.expand(t2189, (1, 512, 4096)) # t853: "cuda:0 bf16[1, 512, 4096]"
# t853 = prims.broadcast_in_dim(t2189, (1, 512, 4096), (0, 1, 2)) # t853: "cuda:0 bf16[1, 512, 4096]"
del t2189
t2191 = torch.unsqueeze(t66, 0) # t2191: "cuda:0 bf16[1, 4096]"
# t2191 = ltorch.unsqueeze(t66, 0) # t2191: "cuda:0 bf16[1, 4096]"
# t2191 = prims.broadcast_in_dim(t66, [1, 4096], [1]) # t2191: "cuda:0 bf16[1, 4096]"
t2192 = torch.unsqueeze(t2191, 1) # t2192: "cuda:0 bf16[1, 1, 4096]"
# t2192 = ltorch.unsqueeze(t2191, 1) # t2192: "cuda:0 bf16[1, 1, 4096]"
# t2192 = prims.broadcast_in_dim(t2191, [1, 1, 4096], [0, 2]) # t2192: "cuda:0 bf16[1, 1, 4096]"
del t2191
t889 = Tensor.expand(t2192, (1, 512, 4096)) # t889: "cuda:0 bf16[1, 512, 4096]"
# t889 = ltorch.expand(t2192, (1, 512, 4096)) # t889: "cuda:0 bf16[1, 512, 4096]"
# t889 = prims.broadcast_in_dim(t2192, (1, 512, 4096), (0, 1, 2)) # t889: "cuda:0 bf16[1, 512, 4096]"
del t2192
t2212 = torch.unsqueeze(t76, 0) # t2212: "cuda:0 bf16[1, 4096]"
# t2212 = ltorch.unsqueeze(t76, 0) # t2212: "cuda:0 bf16[1, 4096]"
# t2212 = prims.broadcast_in_dim(t76, [1, 4096], [1]) # t2212: "cuda:0 bf16[1, 4096]"
t2213 = torch.unsqueeze(t2212, 1) # t2213: "cuda:0 bf16[1, 1, 4096]"
# t2213 = ltorch.unsqueeze(t2212, 1) # t2213: "cuda:0 bf16[1, 1, 4096]"
# t2213 = prims.broadcast_in_dim(t2212, [1, 1, 4096], [0, 2]) # t2213: "cuda:0 bf16[1, 1, 4096]"
del t2212
t961 = Tensor.expand(t2213, (1, 512, 4096)) # t961: "cuda:0 bf16[1, 512, 4096]"
# t961 = ltorch.expand(t2213, (1, 512, 4096)) # t961: "cuda:0 bf16[1, 512, 4096]"
# t961 = prims.broadcast_in_dim(t2213, (1, 512, 4096), (0, 1, 2)) # t961: "cuda:0 bf16[1, 512, 4096]"
del t2213
t2215 = torch.unsqueeze(t67, 0) # t2215: "cuda:0 bf16[1, 4096]"
# t2215 = ltorch.unsqueeze(t67, 0) # t2215: "cuda:0 bf16[1, 4096]"
# t2215 = prims.broadcast_in_dim(t67, [1, 4096], [1]) # t2215: "cuda:0 bf16[1, 4096]"
t2216 = torch.unsqueeze(t2215, 1) # t2216: "cuda:0 bf16[1, 1, 4096]"
# t2216 = ltorch.unsqueeze(t2215, 1) # t2216: "cuda:0 bf16[1, 1, 4096]"
# t2216 = prims.broadcast_in_dim(t2215, [1, 1, 4096], [0, 2]) # t2216: "cuda:0 bf16[1, 1, 4096]"
del t2215
t997 = Tensor.expand(t2216, (1, 512, 4096)) # t997: "cuda:0 bf16[1, 512, 4096]"
# t997 = ltorch.expand(t2216, (1, 512, 4096)) # t997: "cuda:0 bf16[1, 512, 4096]"
# t997 = prims.broadcast_in_dim(t2216, (1, 512, 4096), (0, 1, 2)) # t997: "cuda:0 bf16[1, 512, 4096]"
del t2216
t2236 = torch.unsqueeze(t77, 0) # t2236: "cuda:0 bf16[1, 4096]"
# t2236 = ltorch.unsqueeze(t77, 0) # t2236: "cuda:0 bf16[1, 4096]"
# t2236 = prims.broadcast_in_dim(t77, [1, 4096], [1]) # t2236: "cuda:0 bf16[1, 4096]"
t2237 = torch.unsqueeze(t2236, 1) # t2237: "cuda:0 bf16[1, 1, 4096]"
# t2237 = ltorch.unsqueeze(t2236, 1) # t2237: "cuda:0 bf16[1, 1, 4096]"
# t2237 = prims.broadcast_in_dim(t2236, [1, 1, 4096], [0, 2]) # t2237: "cuda:0 bf16[1, 1, 4096]"
del t2236
t1069 = Tensor.expand(t2237, (1, 512, 4096)) # t1069: "cuda:0 bf16[1, 512, 4096]"
# t1069 = ltorch.expand(t2237, (1, 512, 4096)) # t1069: "cuda:0 bf16[1, 512, 4096]"
# t1069 = prims.broadcast_in_dim(t2237, (1, 512, 4096), (0, 1, 2)) # t1069: "cuda:0 bf16[1, 512, 4096]"
del t2237
t2239 = torch.unsqueeze(t68, 0) # t2239: "cuda:0 bf16[1, 4096]"
# t2239 = ltorch.unsqueeze(t68, 0) # t2239: "cuda:0 bf16[1, 4096]"
# t2239 = prims.broadcast_in_dim(t68, [1, 4096], [1]) # t2239: "cuda:0 bf16[1, 4096]"
t2240 = torch.unsqueeze(t2239, 1) # t2240: "cuda:0 bf16[1, 1, 4096]"
# t2240 = ltorch.unsqueeze(t2239, 1) # t2240: "cuda:0 bf16[1, 1, 4096]"
# t2240 = prims.broadcast_in_dim(t2239, [1, 1, 4096], [0, 2]) # t2240: "cuda:0 bf16[1, 1, 4096]"
del t2239
t1105 = Tensor.expand(t2240, (1, 512, 4096)) # t1105: "cuda:0 bf16[1, 512, 4096]"
# t1105 = ltorch.expand(t2240, (1, 512, 4096)) # t1105: "cuda:0 bf16[1, 512, 4096]"
# t1105 = prims.broadcast_in_dim(t2240, (1, 512, 4096), (0, 1, 2)) # t1105: "cuda:0 bf16[1, 512, 4096]"
del t2240
t2260 = torch.unsqueeze(t78, 0) # t2260: "cuda:0 bf16[1, 4096]"
# t2260 = ltorch.unsqueeze(t78, 0) # t2260: "cuda:0 bf16[1, 4096]"
# t2260 = prims.broadcast_in_dim(t78, [1, 4096], [1]) # t2260: "cuda:0 bf16[1, 4096]"
t2261 = torch.unsqueeze(t2260, 1) # t2261: "cuda:0 bf16[1, 1, 4096]"
# t2261 = ltorch.unsqueeze(t2260, 1) # t2261: "cuda:0 bf16[1, 1, 4096]"
# t2261 = prims.broadcast_in_dim(t2260, [1, 1, 4096], [0, 2]) # t2261: "cuda:0 bf16[1, 1, 4096]"
del t2260
t1177 = Tensor.expand(t2261, (1, 512, 4096)) # t1177: "cuda:0 bf16[1, 512, 4096]"
# t1177 = ltorch.expand(t2261, (1, 512, 4096)) # t1177: "cuda:0 bf16[1, 512, 4096]"
# t1177 = prims.broadcast_in_dim(t2261, (1, 512, 4096), (0, 1, 2)) # t1177: "cuda:0 bf16[1, 512, 4096]"
del t2261
t2263 = torch.unsqueeze(t54, 0) # t2263: "cuda:0 bf16[1, 4096]"
# t2263 = ltorch.unsqueeze(t54, 0) # t2263: "cuda:0 bf16[1, 4096]"
# t2263 = prims.broadcast_in_dim(t54, [1, 4096], [1]) # t2263: "cuda:0 bf16[1, 4096]"
t2264 = torch.unsqueeze(t2263, 1) # t2264: "cuda:0 bf16[1, 1, 4096]"
# t2264 = ltorch.unsqueeze(t2263, 1) # t2264: "cuda:0 bf16[1, 1, 4096]"
# t2264 = prims.broadcast_in_dim(t2263, [1, 1, 4096], [0, 2]) # t2264: "cuda:0 bf16[1, 1, 4096]"
del t2263
t1213 = Tensor.expand(t2264, (1, 512, 4096)) # t1213: "cuda:0 bf16[1, 512, 4096]"
# t1213 = ltorch.expand(t2264, (1, 512, 4096)) # t1213: "cuda:0 bf16[1, 512, 4096]"
# t1213 = prims.broadcast_in_dim(t2264, (1, 512, 4096), (0, 1, 2)) # t1213: "cuda:0 bf16[1, 512, 4096]"
del t2264
t2284 = torch.unsqueeze(t79, 0) # t2284: "cuda:0 bf16[1, 4096]"
# t2284 = ltorch.unsqueeze(t79, 0) # t2284: "cuda:0 bf16[1, 4096]"
# t2284 = prims.broadcast_in_dim(t79, [1, 4096], [1]) # t2284: "cuda:0 bf16[1, 4096]"
t2285 = torch.unsqueeze(t2284, 1) # t2285: "cuda:0 bf16[1, 1, 4096]"
# t2285 = ltorch.unsqueeze(t2284, 1) # t2285: "cuda:0 bf16[1, 1, 4096]"
# t2285 = prims.broadcast_in_dim(t2284, [1, 1, 4096], [0, 2]) # t2285: "cuda:0 bf16[1, 1, 4096]"
del t2284
t1285 = Tensor.expand(t2285, (1, 512, 4096)) # t1285: "cuda:0 bf16[1, 512, 4096]"
# t1285 = ltorch.expand(t2285, (1, 512, 4096)) # t1285: "cuda:0 bf16[1, 512, 4096]"
# t1285 = prims.broadcast_in_dim(t2285, (1, 512, 4096), (0, 1, 2)) # t1285: "cuda:0 bf16[1, 512, 4096]"
del t2285
t2287 = torch.unsqueeze(t55, 0) # t2287: "cuda:0 bf16[1, 4096]"
# t2287 = ltorch.unsqueeze(t55, 0) # t2287: "cuda:0 bf16[1, 4096]"
# t2287 = prims.broadcast_in_dim(t55, [1, 4096], [1]) # t2287: "cuda:0 bf16[1, 4096]"
t2288 = torch.unsqueeze(t2287, 1) # t2288: "cuda:0 bf16[1, 1, 4096]"
# t2288 = ltorch.unsqueeze(t2287, 1) # t2288: "cuda:0 bf16[1, 1, 4096]"
# t2288 = prims.broadcast_in_dim(t2287, [1, 1, 4096], [0, 2]) # t2288: "cuda:0 bf16[1, 1, 4096]"
del t2287
t1321 = Tensor.expand(t2288, (1, 512, 4096)) # t1321: "cuda:0 bf16[1, 512, 4096]"
# t1321 = ltorch.expand(t2288, (1, 512, 4096)) # t1321: "cuda:0 bf16[1, 512, 4096]"
# t1321 = prims.broadcast_in_dim(t2288, (1, 512, 4096), (0, 1, 2)) # t1321: "cuda:0 bf16[1, 512, 4096]"
del t2288
t2308 = torch.unsqueeze(t80, 0) # t2308: "cuda:0 bf16[1, 4096]"
# t2308 = ltorch.unsqueeze(t80, 0) # t2308: "cuda:0 bf16[1, 4096]"
# t2308 = prims.broadcast_in_dim(t80, [1, 4096], [1]) # t2308: "cuda:0 bf16[1, 4096]"
t2309 = torch.unsqueeze(t2308, 1) # t2309: "cuda:0 bf16[1, 1, 4096]"
# t2309 = ltorch.unsqueeze(t2308, 1) # t2309: "cuda:0 bf16[1, 1, 4096]"
# t2309 = prims.broadcast_in_dim(t2308, [1, 1, 4096], [0, 2]) # t2309: "cuda:0 bf16[1, 1, 4096]"
del t2308
t1393 = Tensor.expand(t2309, (1, 512, 4096)) # t1393: "cuda:0 bf16[1, 512, 4096]"
# t1393 = ltorch.expand(t2309, (1, 512, 4096)) # t1393: "cuda:0 bf16[1, 512, 4096]"
# t1393 = prims.broadcast_in_dim(t2309, (1, 512, 4096), (0, 1, 2)) # t1393: "cuda:0 bf16[1, 512, 4096]"
del t2309
t2311 = torch.unsqueeze(t56, 0) # t2311: "cuda:0 bf16[1, 4096]"
# t2311 = ltorch.unsqueeze(t56, 0) # t2311: "cuda:0 bf16[1, 4096]"
# t2311 = prims.broadcast_in_dim(t56, [1, 4096], [1]) # t2311: "cuda:0 bf16[1, 4096]"
t2312 = torch.unsqueeze(t2311, 1) # t2312: "cuda:0 bf16[1, 1, 4096]"
# t2312 = ltorch.unsqueeze(t2311, 1) # t2312: "cuda:0 bf16[1, 1, 4096]"
# t2312 = prims.broadcast_in_dim(t2311, [1, 1, 4096], [0, 2]) # t2312: "cuda:0 bf16[1, 1, 4096]"
del t2311
t1429 = Tensor.expand(t2312, (1, 512, 4096)) # t1429: "cuda:0 bf16[1, 512, 4096]"
# t1429 = ltorch.expand(t2312, (1, 512, 4096)) # t1429: "cuda:0 bf16[1, 512, 4096]"
# t1429 = prims.broadcast_in_dim(t2312, (1, 512, 4096), (0, 1, 2)) # t1429: "cuda:0 bf16[1, 512, 4096]"
del t2312
t2332 = torch.unsqueeze(t81, 0) # t2332: "cuda:0 bf16[1, 4096]"
# t2332 = ltorch.unsqueeze(t81, 0) # t2332: "cuda:0 bf16[1, 4096]"
# t2332 = prims.broadcast_in_dim(t81, [1, 4096], [1]) # t2332: "cuda:0 bf16[1, 4096]"
t2333 = torch.unsqueeze(t2332, 1) # t2333: "cuda:0 bf16[1, 1, 4096]"
# t2333 = ltorch.unsqueeze(t2332, 1) # t2333: "cuda:0 bf16[1, 1, 4096]"
# t2333 = prims.broadcast_in_dim(t2332, [1, 1, 4096], [0, 2]) # t2333: "cuda:0 bf16[1, 1, 4096]"
del t2332
t1501 = Tensor.expand(t2333, (1, 512, 4096)) # t1501: "cuda:0 bf16[1, 512, 4096]"
# t1501 = ltorch.expand(t2333, (1, 512, 4096)) # t1501: "cuda:0 bf16[1, 512, 4096]"
# t1501 = prims.broadcast_in_dim(t2333, (1, 512, 4096), (0, 1, 2)) # t1501: "cuda:0 bf16[1, 512, 4096]"
del t2333
t2335 = torch.unsqueeze(t57, 0) # t2335: "cuda:0 bf16[1, 4096]"
# t2335 = ltorch.unsqueeze(t57, 0) # t2335: "cuda:0 bf16[1, 4096]"
# t2335 = prims.broadcast_in_dim(t57, [1, 4096], [1]) # t2335: "cuda:0 bf16[1, 4096]"
t2336 = torch.unsqueeze(t2335, 1) # t2336: "cuda:0 bf16[1, 1, 4096]"
# t2336 = ltorch.unsqueeze(t2335, 1) # t2336: "cuda:0 bf16[1, 1, 4096]"
# t2336 = prims.broadcast_in_dim(t2335, [1, 1, 4096], [0, 2]) # t2336: "cuda:0 bf16[1, 1, 4096]"
del t2335
t1537 = Tensor.expand(t2336, (1, 512, 4096)) # t1537: "cuda:0 bf16[1, 512, 4096]"
# t1537 = ltorch.expand(t2336, (1, 512, 4096)) # t1537: "cuda:0 bf16[1, 512, 4096]"
# t1537 = prims.broadcast_in_dim(t2336, (1, 512, 4096), (0, 1, 2)) # t1537: "cuda:0 bf16[1, 512, 4096]"
del t2336
t2036 = torch.unsqueeze(t118, 0) # t2036: "cuda:0 f32[1, 512, 128]"
# t2036 = ltorch.unsqueeze(t118, 0) # t2036: "cuda:0 f32[1, 512, 128]"
# t2036 = prims.broadcast_in_dim(t118, [1, 512, 128], [1, 2]) # t2036: "cuda:0 f32[1, 512, 128]"
del t118
t2037 = torch.unsqueeze(t2036, 1) # t2037: "cuda:0 f32[1, 1, 512, 128]"
# t2037 = ltorch.unsqueeze(t2036, 1) # t2037: "cuda:0 f32[1, 1, 512, 128]"
# t2037 = prims.broadcast_in_dim(t2036, [1, 1, 512, 128], [0, 2, 3]) # t2037: "cuda:0 f32[1, 1, 512, 128]"
del t2036
t154 = Tensor.expand(t2037, (1, 32, 512, 128)) # t154: "cuda:0 f32[1, 32, 512, 128]"
# t154 = ltorch.expand(t2037, (1, 32, 512, 128)) # t154: "cuda:0 f32[1, 32, 512, 128]"
# t154 = prims.broadcast_in_dim(t2037, (1, 32, 512, 128), (0, 1, 2, 3)) # t154: "cuda:0 f32[1, 32, 512, 128]"
del t2037
t2039 = torch.unsqueeze(t119, 0) # t2039: "cuda:0 f32[1, 512, 128]"
# t2039 = ltorch.unsqueeze(t119, 0) # t2039: "cuda:0 f32[1, 512, 128]"
# t2039 = prims.broadcast_in_dim(t119, [1, 512, 128], [1, 2]) # t2039: "cuda:0 f32[1, 512, 128]"
del t119
t2040 = torch.unsqueeze(t2039, 1) # t2040: "cuda:0 f32[1, 1, 512, 128]"
# t2040 = ltorch.unsqueeze(t2039, 1) # t2040: "cuda:0 f32[1, 1, 512, 128]"
# t2040 = prims.broadcast_in_dim(t2039, [1, 1, 512, 128], [0, 2, 3]) # t2040: "cuda:0 f32[1, 1, 512, 128]"
del t2039
t157 = Tensor.expand(t2040, (1, 32, 512, 128)) # t157: "cuda:0 f32[1, 32, 512, 128]"
# t157 = ltorch.expand(t2040, (1, 32, 512, 128)) # t157: "cuda:0 f32[1, 32, 512, 128]"
# t157 = prims.broadcast_in_dim(t2040, (1, 32, 512, 128), (0, 1, 2, 3)) # t157: "cuda:0 f32[1, 32, 512, 128]"
del t2040
[t129, t137] = nvFusion0(t122, t133)
# t123 = prims.convert_element_type(t122, dtypes.float32) # t123: "cuda:0 f32[1, 512, 4096]"
# t124 = prims.mul(t123, t123) # t124: "cuda:0 f32[1, 512, 4096]"
# t125 = prims.sum(t124, (2,)) # t125: "cuda:0 f32[1, 512]"
# t126 = prims.broadcast_in_dim(t125, [1, 512, 1], [0, 1]) # t126: "cuda:0 f32[1, 512, 1]"
# t127 = prims.div(t126, 4096.0) # t127: "cuda:0 f32[1, 512, 1]"
# t128 = prims.add(t127, 1e-05) # t128: "cuda:0 f32[1, 512, 1]"
# t129 = prims.rsqrt(t128) # t129: "cuda:0 f32[1, 512, 1]"
# t130 = prims.broadcast_in_dim(t129, (1, 512, 4096), (0, 1, 2)) # t130: "cuda:0 f32[1, 512, 4096]"
# t131 = prims.mul(t123, t130) # t131: "cuda:0 f32[1, 512, 4096]"
# t135 = prims.convert_element_type(t133, dtypes.float32) # t135: "cuda:0 f32[1, 512, 4096]"
# t136 = prims.mul(t131, t135) # t136: "cuda:0 f32[1, 512, 4096]"
# t137 = prims.convert_element_type(t136, dtypes.bfloat16) # t137: "cuda:0 bf16[1, 512, 4096]"
t138 = torch.nn.functional.linear(t137, t3, None) # t138: "cuda:0 bf16[1, 512, 12288]"
# t138 = ltorch.linear(t137, t3, None) # t138: "cuda:0 bf16[1, 512, 12288]"
# t138 = prims.linear(t137, t3, None) # t138: "cuda:0 bf16[1, 512, 12288]"
t139 = torch.reshape(t138, (1, 512, 32, 3, 128)) # t139: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t139 = ltorch.reshape(t138, (1, 512, 32, 3, 128)) # t139: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t139 = prims.reshape(t138, (1, 512, 32, 3, 128)) # t139: "cuda:0 bf16[1, 512, 32, 3, 128]"
del t138
t140 = torch.permute(t139, (0, 2, 3, 1, 4)) # t140: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t140 = ltorch.permute(t139, (0, 2, 3, 1, 4)) # t140: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t140 = prims.transpose(t139, (0, 2, 3, 1, 4)) # t140: "cuda:0 bf16[1, 32, 3, 512, 128]"
del t139
(t141, t142, t143) = torch.split(t140, (1, 1, 1), 2)
# (t141, t142, t143) = ltorch.split(t140, (1, 1, 1), 2)
# t141 = prims.slice_prim(t140, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t141: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t142 = prims.slice_prim(t140, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t142: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t143 = prims.slice_prim(t140, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t143: "cuda:0 bf16[1, 32, 1, 512, 128]"
del t140
t144 = torch.reshape(t141, (1, 32, 512, 128)) # t144: "cuda:0 bf16[1, 32, 512, 128]"
# t144 = ltorch.reshape(t141, (1, 32, 512, 128)) # t144: "cuda:0 bf16[1, 32, 512, 128]"
# t144 = prims.reshape(t141, (1, 32, 512, 128)) # t144: "cuda:0 bf16[1, 32, 512, 128]"
del t141
t145 = torch.reshape(t142, (1, 32, 512, 128)) # t145: "cuda:0 bf16[1, 32, 512, 128]"
# t145 = ltorch.reshape(t142, (1, 32, 512, 128)) # t145: "cuda:0 bf16[1, 32, 512, 128]"
# t145 = prims.reshape(t142, (1, 32, 512, 128)) # t145: "cuda:0 bf16[1, 32, 512, 128]"
del t142
t146 = torch.reshape(t143, (1, 32, 512, 128)) # t146: "cuda:0 bf16[1, 32, 512, 128]"
# t146 = ltorch.reshape(t143, (1, 32, 512, 128)) # t146: "cuda:0 bf16[1, 32, 512, 128]"
# t146 = prims.reshape(t143, (1, 32, 512, 128)) # t146: "cuda:0 bf16[1, 32, 512, 128]"
del t143
t147 = torch_slice_prim_impl(t144, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t147: "cuda:0 bf16[1, 32, 512, 128]"
t162 = torch_slice_prim_impl(t145, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t162: "cuda:0 bf16[1, 32, 512, 128]"
t177 = torch_slice_prim_impl(t144, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t177: "cuda:0 bf16[1, 32, 512, 0]"
del t144
t179 = torch_slice_prim_impl(t145, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t179: "cuda:0 bf16[1, 32, 512, 0]"
del t145
t149 = torch_slice_prim_impl(t147, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t149: "cuda:0 bf16[1, 32, 512, 64]"
t148 = torch_slice_prim_impl(t147, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t148: "cuda:0 bf16[1, 32, 512, 64]"
t163 = torch_slice_prim_impl(t162, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t163: "cuda:0 bf16[1, 32, 512, 64]"
t164 = torch_slice_prim_impl(t162, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t164: "cuda:0 bf16[1, 32, 512, 64]"
[t152, t167] = nvFusion1(t147, t149, t162, t164)
# t150 = prims.convert_element_type(t149, dtypes.float32) # t150: "cuda:0 f32[1, 32, 512, 64]"
# t151 = prims.neg(t150) # t151: "cuda:0 f32[1, 32, 512, 64]"
# t152 = prims.convert_element_type(t151, dtypes.bfloat16) # t152: "cuda:0 bf16[1, 32, 512, 64]"
# t165 = prims.convert_element_type(t164, dtypes.float32) # t165: "cuda:0 f32[1, 32, 512, 64]"
# t166 = prims.neg(t165) # t166: "cuda:0 f32[1, 32, 512, 64]"
# t167 = prims.convert_element_type(t166, dtypes.bfloat16) # t167: "cuda:0 bf16[1, 32, 512, 64]"
del t149, t164
t168 = torch.cat((t167, t163), -1) # t168: "cuda:0 bf16[1, 32, 512, 128]"
# t168 = ltorch.cat((t167, t163), -1) # t168: "cuda:0 bf16[1, 32, 512, 128]"
# t168 = prims.cat((t167, t163), -1) # t168: "cuda:0 bf16[1, 32, 512, 128]"
del t167, t163
t153 = torch.cat((t152, t148), -1) # t153: "cuda:0 bf16[1, 32, 512, 128]"
# t153 = ltorch.cat((t152, t148), -1) # t153: "cuda:0 bf16[1, 32, 512, 128]"
# t153 = prims.cat((t152, t148), -1) # t153: "cuda:0 bf16[1, 32, 512, 128]"
del t152, t148
[t161, t176] = nvFusion2(t147, t153, t154, t157, t162, t168)
# t155 = prims.convert_element_type(t147, dtypes.float32) # t155: "cuda:0 f32[1, 32, 512, 128]"
# t170 = prims.convert_element_type(t162, dtypes.float32) # t170: "cuda:0 f32[1, 32, 512, 128]"
# t156 = prims.mul(t155, t154) # t156: "cuda:0 f32[1, 32, 512, 128]"
# t158 = prims.convert_element_type(t153, dtypes.float32) # t158: "cuda:0 f32[1, 32, 512, 128]"
# t159 = prims.mul(t158, t157) # t159: "cuda:0 f32[1, 32, 512, 128]"
# t160 = prims.add(t156, t159) # t160: "cuda:0 f32[1, 32, 512, 128]"
# t161 = prims.convert_element_type(t160, dtypes.bfloat16) # t161: "cuda:0 bf16[1, 32, 512, 128]"
# t171 = prims.mul(t170, t154) # t171: "cuda:0 f32[1, 32, 512, 128]"
# t173 = prims.convert_element_type(t168, dtypes.float32) # t173: "cuda:0 f32[1, 32, 512, 128]"
# t174 = prims.mul(t173, t157) # t174: "cuda:0 f32[1, 32, 512, 128]"
# t175 = prims.add(t171, t174) # t175: "cuda:0 f32[1, 32, 512, 128]"
# t176 = prims.convert_element_type(t175, dtypes.bfloat16) # t176: "cuda:0 bf16[1, 32, 512, 128]"
del t147, t153, t162, t168
t178 = torch.cat((t161, t177), -1) # t178: "cuda:0 bf16[1, 32, 512, 128]"
# t178 = ltorch.cat((t161, t177), -1) # t178: "cuda:0 bf16[1, 32, 512, 128]"
# t178 = prims.cat((t161, t177), -1) # t178: "cuda:0 bf16[1, 32, 512, 128]"
del t161, t177
t180 = torch.cat((t176, t179), -1) # t180: "cuda:0 bf16[1, 32, 512, 128]"
# t180 = ltorch.cat((t176, t179), -1) # t180: "cuda:0 bf16[1, 32, 512, 128]"
# t180 = prims.cat((t176, t179), -1) # t180: "cuda:0 bf16[1, 32, 512, 128]"
del t176, t179
(t181, t182, t183, t184, _, _, t185, t186, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t178, t180, t146, 0.0, True, scale=0.08838834764831843)
t188 = torch.permute(t181, (0, 2, 1, 3)) # t188: "cuda:0 bf16[1, 512, 32, 128]"
# t188 = ltorch.permute(t181, (0, 2, 1, 3)) # t188: "cuda:0 bf16[1, 512, 32, 128]"
# t188 = prims.transpose(t181, (0, 2, 1, 3)) # t188: "cuda:0 bf16[1, 512, 32, 128]"
t189 = torch.reshape(t188, (1, 512, 4096)) # t189: "cuda:0 bf16[1, 512, 4096]"
# t189 = ltorch.reshape(t188, (1, 512, 4096)) # t189: "cuda:0 bf16[1, 512, 4096]"
# t189 = prims.reshape(t188, (1, 512, 4096)) # t189: "cuda:0 bf16[1, 512, 4096]"
del t188
t190 = torch.nn.functional.linear(t189, t85, None) # t190: "cuda:0 bf16[1, 512, 4096]"
# t190 = ltorch.linear(t189, t85, None) # t190: "cuda:0 bf16[1, 512, 4096]"
# t190 = prims.linear(t189, t85, None) # t190: "cuda:0 bf16[1, 512, 4096]"
[t194, t201, t209] = nvFusion3(t122, t190, t205)
# t191 = prims.convert_element_type(t190, dtypes.float32) # t191: "cuda:0 f32[1, 512, 4096]"
# t192 = prims.convert_element_type(t122, dtypes.float32) # t192: "cuda:0 f32[1, 512, 4096]"
# t193 = prims.add(t191, t192) # t193: "cuda:0 f32[1, 512, 4096]"
# t194 = prims.convert_element_type(t193, dtypes.bfloat16) # t194: "cuda:0 bf16[1, 512, 4096]"
# t196 = prims.mul(t193, t193) # t196: "cuda:0 f32[1, 512, 4096]"
# t197 = prims.sum(t196, (2,)) # t197: "cuda:0 f32[1, 512]"
# t198 = prims.broadcast_in_dim(t197, [1, 512, 1], [0, 1]) # t198: "cuda:0 f32[1, 512, 1]"
# t199 = prims.div(t198, 4096.0) # t199: "cuda:0 f32[1, 512, 1]"
# t200 = prims.add(t199, 1e-05) # t200: "cuda:0 f32[1, 512, 1]"
# t201 = prims.rsqrt(t200) # t201: "cuda:0 f32[1, 512, 1]"
# t202 = prims.broadcast_in_dim(t201, (1, 512, 4096), (0, 1, 2)) # t202: "cuda:0 f32[1, 512, 4096]"
# t203 = prims.mul(t193, t202) # t203: "cuda:0 f32[1, 512, 4096]"
# t207 = prims.convert_element_type(t205, dtypes.float32) # t207: "cuda:0 f32[1, 512, 4096]"
# t208 = prims.mul(t203, t207) # t208: "cuda:0 f32[1, 512, 4096]"
# t209 = prims.convert_element_type(t208, dtypes.bfloat16) # t209: "cuda:0 bf16[1, 512, 4096]"
t210 = torch.nn.functional.linear(t209, t19, None) # t210: "cuda:0 bf16[1, 512, 11008]"
# t210 = ltorch.linear(t209, t19, None) # t210: "cuda:0 bf16[1, 512, 11008]"
# t210 = prims.linear(t209, t19, None) # t210: "cuda:0 bf16[1, 512, 11008]"
t211 = torch.nn.functional.linear(t209, t35, None) # t211: "cuda:0 bf16[1, 512, 11008]"
# t211 = ltorch.linear(t209, t35, None) # t211: "cuda:0 bf16[1, 512, 11008]"
# t211 = prims.linear(t209, t35, None) # t211: "cuda:0 bf16[1, 512, 11008]"
[t225] = nvFusion4(t210, t211)
# t212 = prims.convert_element_type(t210, dtypes.float32) # t212: "cuda:0 f32[1, 512, 11008]"
# t213 = prims.neg(t212) # t213: "cuda:0 f32[1, 512, 11008]"
# t214 = prims.exp(t213) # t214: "cuda:0 f32[1, 512, 11008]"
# t215 = prims.add(1.0, t214) # t215: "cuda:0 f32[1, 512, 11008]"
# t216 = prims.reciprocal(t215) # t216: "cuda:0 f32[1, 512, 11008]"
# t220 = prims.mul(t212, t216) # t220: "cuda:0 f32[1, 512, 11008]"
# t223 = prims.convert_element_type(t211, dtypes.float32) # t223: "cuda:0 f32[1, 512, 11008]"
# t224 = prims.mul(t220, t223) # t224: "cuda:0 f32[1, 512, 11008]"
# t225 = prims.convert_element_type(t224, dtypes.bfloat16) # t225: "cuda:0 bf16[1, 512, 11008]"
t226 = torch.nn.functional.linear(t225, t86, None) # t226: "cuda:0 bf16[1, 512, 4096]"
# t226 = ltorch.linear(t225, t86, None) # t226: "cuda:0 bf16[1, 512, 4096]"
# t226 = prims.linear(t225, t86, None) # t226: "cuda:0 bf16[1, 512, 4096]"
[t230, t237, t245] = nvFusion5(t194, t226, t241)
# t228 = prims.convert_element_type(t194, dtypes.float32) # t228: "cuda:0 f32[1, 512, 4096]"
# t227 = prims.convert_element_type(t226, dtypes.float32) # t227: "cuda:0 f32[1, 512, 4096]"
# t229 = prims.add(t227, t228) # t229: "cuda:0 f32[1, 512, 4096]"
# t230 = prims.convert_element_type(t229, dtypes.bfloat16) # t230: "cuda:0 bf16[1, 512, 4096]"
# t232 = prims.mul(t229, t229) # t232: "cuda:0 f32[1, 512, 4096]"
# t233 = prims.sum(t232, (2,)) # t233: "cuda:0 f32[1, 512]"
# t234 = prims.broadcast_in_dim(t233, [1, 512, 1], [0, 1]) # t234: "cuda:0 f32[1, 512, 1]"
# t235 = prims.div(t234, 4096.0) # t235: "cuda:0 f32[1, 512, 1]"
# t236 = prims.add(t235, 1e-05) # t236: "cuda:0 f32[1, 512, 1]"
# t237 = prims.rsqrt(t236) # t237: "cuda:0 f32[1, 512, 1]"
# t238 = prims.broadcast_in_dim(t237, (1, 512, 4096), (0, 1, 2)) # t238: "cuda:0 f32[1, 512, 4096]"
# t239 = prims.mul(t229, t238) # t239: "cuda:0 f32[1, 512, 4096]"
# t243 = prims.convert_element_type(t241, dtypes.float32) # t243: "cuda:0 f32[1, 512, 4096]"
# t244 = prims.mul(t239, t243) # t244: "cuda:0 f32[1, 512, 4096]"
# t245 = prims.convert_element_type(t244, dtypes.bfloat16) # t245: "cuda:0 bf16[1, 512, 4096]"
t246 = torch.nn.functional.linear(t245, t4, None) # t246: "cuda:0 bf16[1, 512, 12288]"
# t246 = ltorch.linear(t245, t4, None) # t246: "cuda:0 bf16[1, 512, 12288]"
# t246 = prims.linear(t245, t4, None) # t246: "cuda:0 bf16[1, 512, 12288]"
t247 = torch.reshape(t246, (1, 512, 32, 3, 128)) # t247: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t247 = ltorch.reshape(t246, (1, 512, 32, 3, 128)) # t247: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t247 = prims.reshape(t246, (1, 512, 32, 3, 128)) # t247: "cuda:0 bf16[1, 512, 32, 3, 128]"
del t246
t248 = torch.permute(t247, (0, 2, 3, 1, 4)) # t248: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t248 = ltorch.permute(t247, (0, 2, 3, 1, 4)) # t248: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t248 = prims.transpose(t247, (0, 2, 3, 1, 4)) # t248: "cuda:0 bf16[1, 32, 3, 512, 128]"
del t247
(t249, t250, t251) = torch.split(t248, (1, 1, 1), 2)
# (t249, t250, t251) = ltorch.split(t248, (1, 1, 1), 2)
# t249 = prims.slice_prim(t248, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t249: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t250 = prims.slice_prim(t248, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t250: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t251 = prims.slice_prim(t248, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t251: "cuda:0 bf16[1, 32, 1, 512, 128]"
del t248
t252 = torch.reshape(t249, (1, 32, 512, 128)) # t252: "cuda:0 bf16[1, 32, 512, 128]"
# t252 = ltorch.reshape(t249, (1, 32, 512, 128)) # t252: "cuda:0 bf16[1, 32, 512, 128]"
# t252 = prims.reshape(t249, (1, 32, 512, 128)) # t252: "cuda:0 bf16[1, 32, 512, 128]"
del t249
t253 = torch.reshape(t250, (1, 32, 512, 128)) # t253: "cuda:0 bf16[1, 32, 512, 128]"
# t253 = ltorch.reshape(t250, (1, 32, 512, 128)) # t253: "cuda:0 bf16[1, 32, 512, 128]"
# t253 = prims.reshape(t250, (1, 32, 512, 128)) # t253: "cuda:0 bf16[1, 32, 512, 128]"
del t250
t254 = torch.reshape(t251, (1, 32, 512, 128)) # t254: "cuda:0 bf16[1, 32, 512, 128]"
# t254 = ltorch.reshape(t251, (1, 32, 512, 128)) # t254: "cuda:0 bf16[1, 32, 512, 128]"
# t254 = prims.reshape(t251, (1, 32, 512, 128)) # t254: "cuda:0 bf16[1, 32, 512, 128]"
del t251
t285 = torch_slice_prim_impl(t252, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t285: "cuda:0 bf16[1, 32, 512, 0]"
t287 = torch_slice_prim_impl(t253, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t287: "cuda:0 bf16[1, 32, 512, 0]"
t255 = torch_slice_prim_impl(t252, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t255: "cuda:0 bf16[1, 32, 512, 128]"
del t252
t270 = torch_slice_prim_impl(t253, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t270: "cuda:0 bf16[1, 32, 512, 128]"
del t253
t256 = torch_slice_prim_impl(t255, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t256: "cuda:0 bf16[1, 32, 512, 64]"
t257 = torch_slice_prim_impl(t255, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t257: "cuda:0 bf16[1, 32, 512, 64]"
t272 = torch_slice_prim_impl(t270, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t272: "cuda:0 bf16[1, 32, 512, 64]"
t271 = torch_slice_prim_impl(t270, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t271: "cuda:0 bf16[1, 32, 512, 64]"
[t260, t275] = nvFusion6(t255, t257, t270, t272)
# t258 = prims.convert_element_type(t257, dtypes.float32) # t258: "cuda:0 f32[1, 32, 512, 64]"
# t259 = prims.neg(t258) # t259: "cuda:0 f32[1, 32, 512, 64]"
# t260 = prims.convert_element_type(t259, dtypes.bfloat16) # t260: "cuda:0 bf16[1, 32, 512, 64]"
# t273 = prims.convert_element_type(t272, dtypes.float32) # t273: "cuda:0 f32[1, 32, 512, 64]"
# t274 = prims.neg(t273) # t274: "cuda:0 f32[1, 32, 512, 64]"
# t275 = prims.convert_element_type(t274, dtypes.bfloat16) # t275: "cuda:0 bf16[1, 32, 512, 64]"
del t257, t272
t261 = torch.cat((t260, t256), -1) # t261: "cuda:0 bf16[1, 32, 512, 128]"
# t261 = ltorch.cat((t260, t256), -1) # t261: "cuda:0 bf16[1, 32, 512, 128]"
# t261 = prims.cat((t260, t256), -1) # t261: "cuda:0 bf16[1, 32, 512, 128]"
del t260, t256
t276 = torch.cat((t275, t271), -1) # t276: "cuda:0 bf16[1, 32, 512, 128]"
# t276 = ltorch.cat((t275, t271), -1) # t276: "cuda:0 bf16[1, 32, 512, 128]"
# t276 = prims.cat((t275, t271), -1) # t276: "cuda:0 bf16[1, 32, 512, 128]"
del t275, t271
[t269, t284] = nvFusion7(t154, t157, t255, t261, t270, t276)
# t263 = prims.convert_element_type(t255, dtypes.float32) # t263: "cuda:0 f32[1, 32, 512, 128]"
# t278 = prims.convert_element_type(t270, dtypes.float32) # t278: "cuda:0 f32[1, 32, 512, 128]"
# t264 = prims.mul(t263, t154) # t264: "cuda:0 f32[1, 32, 512, 128]"
# t266 = prims.convert_element_type(t261, dtypes.float32) # t266: "cuda:0 f32[1, 32, 512, 128]"
# t267 = prims.mul(t266, t157) # t267: "cuda:0 f32[1, 32, 512, 128]"
# t268 = prims.add(t264, t267) # t268: "cuda:0 f32[1, 32, 512, 128]"
# t269 = prims.convert_element_type(t268, dtypes.bfloat16) # t269: "cuda:0 bf16[1, 32, 512, 128]"
# t279 = prims.mul(t278, t154) # t279: "cuda:0 f32[1, 32, 512, 128]"
# t281 = prims.convert_element_type(t276, dtypes.float32) # t281: "cuda:0 f32[1, 32, 512, 128]"
# t282 = prims.mul(t281, t157) # t282: "cuda:0 f32[1, 32, 512, 128]"
# t283 = prims.add(t279, t282) # t283: "cuda:0 f32[1, 32, 512, 128]"
# t284 = prims.convert_element_type(t283, dtypes.bfloat16) # t284: "cuda:0 bf16[1, 32, 512, 128]"
del t255, t261, t270, t276
t288 = torch.cat((t284, t287), -1) # t288: "cuda:0 bf16[1, 32, 512, 128]"
# t288 = ltorch.cat((t284, t287), -1) # t288: "cuda:0 bf16[1, 32, 512, 128]"
# t288 = prims.cat((t284, t287), -1) # t288: "cuda:0 bf16[1, 32, 512, 128]"
del t284, t287
t286 = torch.cat((t269, t285), -1) # t286: "cuda:0 bf16[1, 32, 512, 128]"
# t286 = ltorch.cat((t269, t285), -1) # t286: "cuda:0 bf16[1, 32, 512, 128]"
# t286 = prims.cat((t269, t285), -1) # t286: "cuda:0 bf16[1, 32, 512, 128]"
del t269, t285
(t289, t290, t291, t292, _, _, t293, t294, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t286, t288, t254, 0.0, True, scale=0.08838834764831843)
t296 = torch.permute(t289, (0, 2, 1, 3)) # t296: "cuda:0 bf16[1, 512, 32, 128]"
# t296 = ltorch.permute(t289, (0, 2, 1, 3)) # t296: "cuda:0 bf16[1, 512, 32, 128]"
# t296 = prims.transpose(t289, (0, 2, 1, 3)) # t296: "cuda:0 bf16[1, 512, 32, 128]"
t297 = torch.reshape(t296, (1, 512, 4096)) # t297: "cuda:0 bf16[1, 512, 4096]"
# t297 = ltorch.reshape(t296, (1, 512, 4096)) # t297: "cuda:0 bf16[1, 512, 4096]"
# t297 = prims.reshape(t296, (1, 512, 4096)) # t297: "cuda:0 bf16[1, 512, 4096]"
del t296
t298 = torch.nn.functional.linear(t297, t87, None) # t298: "cuda:0 bf16[1, 512, 4096]"
# t298 = ltorch.linear(t297, t87, None) # t298: "cuda:0 bf16[1, 512, 4096]"
# t298 = prims.linear(t297, t87, None) # t298: "cuda:0 bf16[1, 512, 4096]"
[t302, t309, t317] = nvFusion8(t230, t298, t313)
# t300 = prims.convert_element_type(t230, dtypes.float32) # t300: "cuda:0 f32[1, 512, 4096]"
# t299 = prims.convert_element_type(t298, dtypes.float32) # t299: "cuda:0 f32[1, 512, 4096]"
# t301 = prims.add(t299, t300) # t301: "cuda:0 f32[1, 512, 4096]"
# t302 = prims.convert_element_type(t301, dtypes.bfloat16) # t302: "cuda:0 bf16[1, 512, 4096]"
# t304 = prims.mul(t301, t301) # t304: "cuda:0 f32[1, 512, 4096]"
# t305 = prims.sum(t304, (2,)) # t305: "cuda:0 f32[1, 512]"
# t306 = prims.broadcast_in_dim(t305, [1, 512, 1], [0, 1]) # t306: "cuda:0 f32[1, 512, 1]"
# t307 = prims.div(t306, 4096.0) # t307: "cuda:0 f32[1, 512, 1]"
# t308 = prims.add(t307, 1e-05) # t308: "cuda:0 f32[1, 512, 1]"
# t309 = prims.rsqrt(t308) # t309: "cuda:0 f32[1, 512, 1]"
# t310 = prims.broadcast_in_dim(t309, (1, 512, 4096), (0, 1, 2)) # t310: "cuda:0 f32[1, 512, 4096]"
# t311 = prims.mul(t301, t310) # t311: "cuda:0 f32[1, 512, 4096]"
# t315 = prims.convert_element_type(t313, dtypes.float32) # t315: "cuda:0 f32[1, 512, 4096]"
# t316 = prims.mul(t311, t315) # t316: "cuda:0 f32[1, 512, 4096]"
# t317 = prims.convert_element_type(t316, dtypes.bfloat16) # t317: "cuda:0 bf16[1, 512, 4096]"
t318 = torch.nn.functional.linear(t317, t20, None) # t318: "cuda:0 bf16[1, 512, 11008]"
# t318 = ltorch.linear(t317, t20, None) # t318: "cuda:0 bf16[1, 512, 11008]"
# t318 = prims.linear(t317, t20, None) # t318: "cuda:0 bf16[1, 512, 11008]"
t319 = torch.nn.functional.linear(t317, t36, None) # t319: "cuda:0 bf16[1, 512, 11008]"
# t319 = ltorch.linear(t317, t36, None) # t319: "cuda:0 bf16[1, 512, 11008]"
# t319 = prims.linear(t317, t36, None) # t319: "cuda:0 bf16[1, 512, 11008]"
[t333] = nvFusion9(t318, t319)
# t320 = prims.convert_element_type(t318, dtypes.float32) # t320: "cuda:0 f32[1, 512, 11008]"
# t321 = prims.neg(t320) # t321: "cuda:0 f32[1, 512, 11008]"
# t322 = prims.exp(t321) # t322: "cuda:0 f32[1, 512, 11008]"
# t323 = prims.add(1.0, t322) # t323: "cuda:0 f32[1, 512, 11008]"
# t324 = prims.reciprocal(t323) # t324: "cuda:0 f32[1, 512, 11008]"
# t328 = prims.mul(t320, t324) # t328: "cuda:0 f32[1, 512, 11008]"
# t331 = prims.convert_element_type(t319, dtypes.float32) # t331: "cuda:0 f32[1, 512, 11008]"
# t332 = prims.mul(t328, t331) # t332: "cuda:0 f32[1, 512, 11008]"
# t333 = prims.convert_element_type(t332, dtypes.bfloat16) # t333: "cuda:0 bf16[1, 512, 11008]"
t334 = torch.nn.functional.linear(t333, t88, None) # t334: "cuda:0 bf16[1, 512, 4096]"
# t334 = ltorch.linear(t333, t88, None) # t334: "cuda:0 bf16[1, 512, 4096]"
# t334 = prims.linear(t333, t88, None) # t334: "cuda:0 bf16[1, 512, 4096]"
[t338, t345, t353] = nvFusion10(t302, t334, t349)
# t336 = prims.convert_element_type(t302, dtypes.float32) # t336: "cuda:0 f32[1, 512, 4096]"
# t335 = prims.convert_element_type(t334, dtypes.float32) # t335: "cuda:0 f32[1, 512, 4096]"
# t337 = prims.add(t335, t336) # t337: "cuda:0 f32[1, 512, 4096]"
# t338 = prims.convert_element_type(t337, dtypes.bfloat16) # t338: "cuda:0 bf16[1, 512, 4096]"
# t340 = prims.mul(t337, t337) # t340: "cuda:0 f32[1, 512, 4096]"
# t341 = prims.sum(t340, (2,)) # t341: "cuda:0 f32[1, 512]"
# t342 = prims.broadcast_in_dim(t341, [1, 512, 1], [0, 1]) # t342: "cuda:0 f32[1, 512, 1]"
# t343 = prims.div(t342, 4096.0) # t343: "cuda:0 f32[1, 512, 1]"
# t344 = prims.add(t343, 1e-05) # t344: "cuda:0 f32[1, 512, 1]"
# t345 = prims.rsqrt(t344) # t345: "cuda:0 f32[1, 512, 1]"
# t346 = prims.broadcast_in_dim(t345, (1, 512, 4096), (0, 1, 2)) # t346: "cuda:0 f32[1, 512, 4096]"
# t347 = prims.mul(t337, t346) # t347: "cuda:0 f32[1, 512, 4096]"
# t351 = prims.convert_element_type(t349, dtypes.float32) # t351: "cuda:0 f32[1, 512, 4096]"
# t352 = prims.mul(t347, t351) # t352: "cuda:0 f32[1, 512, 4096]"
# t353 = prims.convert_element_type(t352, dtypes.bfloat16) # t353: "cuda:0 bf16[1, 512, 4096]"
t354 = torch.nn.functional.linear(t353, t5, None) # t354: "cuda:0 bf16[1, 512, 12288]"
# t354 = ltorch.linear(t353, t5, None) # t354: "cuda:0 bf16[1, 512, 12288]"
# t354 = prims.linear(t353, t5, None) # t354: "cuda:0 bf16[1, 512, 12288]"
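# t5 appears to be a fused QKV weight: the 12288-wide projection is
# reshaped to (batch, seq, heads=32, qkv=3, head_dim=128) and split into
# query, key and value below.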
t355 = torch.reshape(t354, (1, 512, 32, 3, 128)) # t355: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t355 = ltorch.reshape(t354, (1, 512, 32, 3, 128)) # t355: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t355 = prims.reshape(t354, (1, 512, 32, 3, 128)) # t355: "cuda:0 bf16[1, 512, 32, 3, 128]"
del t354
t356 = torch.permute(t355, (0, 2, 3, 1, 4)) # t356: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t356 = ltorch.permute(t355, (0, 2, 3, 1, 4)) # t356: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t356 = prims.transpose(t355, (0, 2, 3, 1, 4)) # t356: "cuda:0 bf16[1, 32, 3, 512, 128]"
del t355
(t357, t358, t359) = torch.split(t356, (1, 1, 1), 2)
# (t357, t358, t359) = ltorch.split(t356, (1, 1, 1), 2)
# t357 = prims.slice_prim(t356, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t357: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t358 = prims.slice_prim(t356, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t358: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t359 = prims.slice_prim(t356, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t359: "cuda:0 bf16[1, 32, 1, 512, 128]"
del t356
t360 = torch.reshape(t357, (1, 32, 512, 128)) # t360: "cuda:0 bf16[1, 32, 512, 128]"
# t360 = ltorch.reshape(t357, (1, 32, 512, 128)) # t360: "cuda:0 bf16[1, 32, 512, 128]"
# t360 = prims.reshape(t357, (1, 32, 512, 128)) # t360: "cuda:0 bf16[1, 32, 512, 128]"
del t357
t361 = torch.reshape(t358, (1, 32, 512, 128)) # t361: "cuda:0 bf16[1, 32, 512, 128]"
# t361 = ltorch.reshape(t358, (1, 32, 512, 128)) # t361: "cuda:0 bf16[1, 32, 512, 128]"
# t361 = prims.reshape(t358, (1, 32, 512, 128)) # t361: "cuda:0 bf16[1, 32, 512, 128]"
del t358
t362 = torch.reshape(t359, (1, 32, 512, 128)) # t362: "cuda:0 bf16[1, 32, 512, 128]"
# t362 = ltorch.reshape(t359, (1, 32, 512, 128)) # t362: "cuda:0 bf16[1, 32, 512, 128]"
# t362 = prims.reshape(t359, (1, 32, 512, 128)) # t362: "cuda:0 bf16[1, 32, 512, 128]"
del t359
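# torch_slice_prim_impl is presumably the torch executor's implementation of
# prims.slice_prim; the full-width slices below select the part of each head
# that receives rotary embedding (all 128 dims here), the 0-width slices the
# (empty) remainder.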
t363 = torch_slice_prim_impl(t360, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t363: "cuda:0 bf16[1, 32, 512, 128]"
t378 = torch_slice_prim_impl(t361, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t378: "cuda:0 bf16[1, 32, 512, 128]"
t393 = torch_slice_prim_impl(t360, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t393: "cuda:0 bf16[1, 32, 512, 0]"
del t360
t395 = torch_slice_prim_impl(t361, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t395: "cuda:0 bf16[1, 32, 512, 0]"
del t361
t364 = torch_slice_prim_impl(t363, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t364: "cuda:0 bf16[1, 32, 512, 64]"
t365 = torch_slice_prim_impl(t363, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t365: "cuda:0 bf16[1, 32, 512, 64]"
t379 = torch_slice_prim_impl(t378, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t379: "cuda:0 bf16[1, 32, 512, 64]"
t380 = torch_slice_prim_impl(t378, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t380: "cuda:0 bf16[1, 32, 512, 64]"
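# Rotary position embedding for q and k: each 128-dim head is split into two
# 64-dim halves, the second half is negated (nvFusion11) and placed in front
# of the first ("rotate half"), and nvFusion12 forms
# x * t154 + rotate_half(x) * t157, where t154/t157 are presumably the
# precomputed cos/sin tables.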
[t368, t383] = nvFusion11(t363, t365, t378, t380)
# t366 = prims.convert_element_type(t365, dtypes.float32) # t366: "cuda:0 f32[1, 32, 512, 64]"
# t367 = prims.neg(t366) # t367: "cuda:0 f32[1, 32, 512, 64]"
# t368 = prims.convert_element_type(t367, dtypes.bfloat16) # t368: "cuda:0 bf16[1, 32, 512, 64]"
# t381 = prims.convert_element_type(t380, dtypes.float32) # t381: "cuda:0 f32[1, 32, 512, 64]"
# t382 = prims.neg(t381) # t382: "cuda:0 f32[1, 32, 512, 64]"
# t383 = prims.convert_element_type(t382, dtypes.bfloat16) # t383: "cuda:0 bf16[1, 32, 512, 64]"
del t365, t380
t369 = torch.cat((t368, t364), -1) # t369: "cuda:0 bf16[1, 32, 512, 128]"
# t369 = ltorch.cat((t368, t364), -1) # t369: "cuda:0 bf16[1, 32, 512, 128]"
# t369 = prims.cat((t368, t364), -1) # t369: "cuda:0 bf16[1, 32, 512, 128]"
del t368, t364
t384 = torch.cat((t383, t379), -1) # t384: "cuda:0 bf16[1, 32, 512, 128]"
# t384 = ltorch.cat((t383, t379), -1) # t384: "cuda:0 bf16[1, 32, 512, 128]"
# t384 = prims.cat((t383, t379), -1) # t384: "cuda:0 bf16[1, 32, 512, 128]"
del t383, t379
[t377, t392] = nvFusion12(t154, t157, t363, t369, t378, t384)
# t371 = prims.convert_element_type(t363, dtypes.float32) # t371: "cuda:0 f32[1, 32, 512, 128]"
# t386 = prims.convert_element_type(t378, dtypes.float32) # t386: "cuda:0 f32[1, 32, 512, 128]"
# t372 = prims.mul(t371, t154) # t372: "cuda:0 f32[1, 32, 512, 128]"
# t374 = prims.convert_element_type(t369, dtypes.float32) # t374: "cuda:0 f32[1, 32, 512, 128]"
# t375 = prims.mul(t374, t157) # t375: "cuda:0 f32[1, 32, 512, 128]"
# t376 = prims.add(t372, t375) # t376: "cuda:0 f32[1, 32, 512, 128]"
# t377 = prims.convert_element_type(t376, dtypes.bfloat16) # t377: "cuda:0 bf16[1, 32, 512, 128]"
# t387 = prims.mul(t386, t154) # t387: "cuda:0 f32[1, 32, 512, 128]"
# t389 = prims.convert_element_type(t384, dtypes.float32) # t389: "cuda:0 f32[1, 32, 512, 128]"
# t390 = prims.mul(t389, t157) # t390: "cuda:0 f32[1, 32, 512, 128]"
# t391 = prims.add(t387, t390) # t391: "cuda:0 f32[1, 32, 512, 128]"
# t392 = prims.convert_element_type(t391, dtypes.bfloat16) # t392: "cuda:0 bf16[1, 32, 512, 128]"
del t363, t369, t378, t384
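# t393 and t395 have zero width, so these concatenations simply pass the
# rotated q and k through unchanged.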
t394 = torch.cat((t377, t393), -1) # t394: "cuda:0 bf16[1, 32, 512, 128]"
# t394 = ltorch.cat((t377, t393), -1) # t394: "cuda:0 bf16[1, 32, 512, 128]"
# t394 = prims.cat((t377, t393), -1) # t394: "cuda:0 bf16[1, 32, 512, 128]"
del t377, t393
t396 = torch.cat((t392, t395), -1) # t396: "cuda:0 bf16[1, 32, 512, 128]"
# t396 = ltorch.cat((t392, t395), -1) # t396: "cuda:0 bf16[1, 32, 512, 128]"
# t396 = prims.cat((t392, t395), -1) # t396: "cuda:0 bf16[1, 32, 512, 128]"
del t392, t395
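# Scaled dot-product attention with scale = 1/sqrt(128); the additional
# outputs of the grad-forward variant are presumably kept for the backward
# trace.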
(t397, t398, t399, t400, _, _, t401, t402, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t394, t396, t362, 0.0, True, scale=0.08838834764831843)
t404 = torch.permute(t397, (0, 2, 1, 3)) # t404: "cuda:0 bf16[1, 512, 32, 128]"
# t404 = ltorch.permute(t397, (0, 2, 1, 3)) # t404: "cuda:0 bf16[1, 512, 32, 128]"
# t404 = prims.transpose(t397, (0, 2, 1, 3)) # t404: "cuda:0 bf16[1, 512, 32, 128]"
t405 = torch.reshape(t404, (1, 512, 4096)) # t405: "cuda:0 bf16[1, 512, 4096]"
# t405 = ltorch.reshape(t404, (1, 512, 4096)) # t405: "cuda:0 bf16[1, 512, 4096]"
# t405 = prims.reshape(t404, (1, 512, 4096)) # t405: "cuda:0 bf16[1, 512, 4096]"
del t404
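# The merged heads (1, 512, 4096) are projected with t89 (presumably the
# attention output projection); nvFusion13 then adds the residual t338 and
# applies the next RMSNorm, mirroring nvFusion10 above.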
t406 = torch.nn.functional.linear(t405, t89, None) # t406: "cuda:0 bf16[1, 512, 4096]"
# t406 = ltorch.linear(t405, t89, None) # t406: "cuda:0 bf16[1, 512, 4096]"
# t406 = prims.linear(t405, t89, None) # t406: "cuda:0 bf16[1, 512, 4096]"
[t410, t417, t425] = nvFusion13(t338, t406, t421)
# t408 = prims.convert_element_type(t338, dtypes.float32) # t408: "cuda:0 f32[1, 512, 4096]"
# t407 = prims.convert_element_type(t406, dtypes.float32) # t407: "cuda:0 f32[1, 512, 4096]"
# t409 = prims.add(t407, t408) # t409: "cuda:0 f32[1, 512, 4096]"
# t410 = prims.convert_element_type(t409, dtypes.bfloat16) # t410: "cuda:0 bf16[1, 512, 4096]"
# t412 = prims.mul(t409, t409) # t412: "cuda:0 f32[1, 512, 4096]"
# t413 = prims.sum(t412, (2,)) # t413: "cuda:0 f32[1, 512]"
# t414 = prims.broadcast_in_dim(t413, [1, 512, 1], [0, 1]) # t414: "cuda:0 f32[1, 512, 1]"
# t415 = prims.div(t414, 4096.0) # t415: "cuda:0 f32[1, 512, 1]"
# t416 = prims.add(t415, 1e-05) # t416: "cuda:0 f32[1, 512, 1]"
# t417 = prims.rsqrt(t416) # t417: "cuda:0 f32[1, 512, 1]"
# t418 = prims.broadcast_in_dim(t417, (1, 512, 4096), (0, 1, 2)) # t418: "cuda:0 f32[1, 512, 4096]"
# t419 = prims.mul(t409, t418) # t419: "cuda:0 f32[1, 512, 4096]"
# t423 = prims.convert_element_type(t421, dtypes.float32) # t423: "cuda:0 f32[1, 512, 4096]"
# t424 = prims.mul(t419, t423) # t424: "cuda:0 f32[1, 512, 4096]"
# t425 = prims.convert_element_type(t424, dtypes.bfloat16) # t425: "cuda:0 bf16[1, 512, 4096]"
t426 = torch.nn.functional.linear(t425, t21, None) # t426: "cuda:0 bf16[1, 512, 11008]"
# t426 = ltorch.linear(t425, t21, None) # t426: "cuda:0 bf16[1, 512, 11008]"
# t426 = prims.linear(t425, t21, None) # t426: "cuda:0 bf16[1, 512, 11008]"
t427 = torch.nn.functional.linear(t425, t37, None) # t427: "cuda:0 bf16[1, 512, 11008]"
# t427 = ltorch.linear(t425, t37, None) # t427: "cuda:0 bf16[1, 512, 11008]"
# t427 = prims.linear(t425, t37, None) # t427: "cuda:0 bf16[1, 512, 11008]"
[t441] = nvFusion14(t426, t427)
# t428 = prims.convert_element_type(t426, dtypes.float32) # t428: "cuda:0 f32[1, 512, 11008]"
# t429 = prims.neg(t428) # t429: "cuda:0 f32[1, 512, 11008]"
# t430 = prims.exp(t429) # t430: "cuda:0 f32[1, 512, 11008]"
# t431 = prims.add(1.0, t430) # t431: "cuda:0 f32[1, 512, 11008]"
# t432 = prims.reciprocal(t431) # t432: "cuda:0 f32[1, 512, 11008]"
# t436 = prims.mul(t428, t432) # t436: "cuda:0 f32[1, 512, 11008]"
# t439 = prims.convert_element_type(t427, dtypes.float32) # t439: "cuda:0 f32[1, 512, 11008]"
# t440 = prims.mul(t436, t439) # t440: "cuda:0 f32[1, 512, 11008]"
# t441 = prims.convert_element_type(t440, dtypes.bfloat16) # t441: "cuda:0 bf16[1, 512, 11008]"
t442 = torch.nn.functional.linear(t441, t90, None) # t442: "cuda:0 bf16[1, 512, 4096]"
# t442 = ltorch.linear(t441, t90, None) # t442: "cuda:0 bf16[1, 512, 4096]"
# t442 = prims.linear(t441, t90, None) # t442: "cuda:0 bf16[1, 512, 4096]"
[t446, t453, t461] = nvFusion15(t410, t442, t457)
# t444 = prims.convert_element_type(t410, dtypes.float32) # t444: "cuda:0 f32[1, 512, 4096]"
# t443 = prims.convert_element_type(t442, dtypes.float32) # t443: "cuda:0 f32[1, 512, 4096]"
# t445 = prims.add(t443, t444) # t445: "cuda:0 f32[1, 512, 4096]"
# t446 = prims.convert_element_type(t445, dtypes.bfloat16) # t446: "cuda:0 bf16[1, 512, 4096]"
# t448 = prims.mul(t445, t445) # t448: "cuda:0 f32[1, 512, 4096]"
# t449 = prims.sum(t448, (2,)) # t449: "cuda:0 f32[1, 512]"
# t450 = prims.broadcast_in_dim(t449, [1, 512, 1], [0, 1]) # t450: "cuda:0 f32[1, 512, 1]"
# t451 = prims.div(t450, 4096.0) # t451: "cuda:0 f32[1, 512, 1]"
# t452 = prims.add(t451, 1e-05) # t452: "cuda:0 f32[1, 512, 1]"
# t453 = prims.rsqrt(t452) # t453: "cuda:0 f32[1, 512, 1]"
# t454 = prims.broadcast_in_dim(t453, (1, 512, 4096), (0, 1, 2)) # t454: "cuda:0 f32[1, 512, 4096]"
# t455 = prims.mul(t445, t454) # t455: "cuda:0 f32[1, 512, 4096]"
# t459 = prims.convert_element_type(t457, dtypes.float32) # t459: "cuda:0 f32[1, 512, 4096]"
# t460 = prims.mul(t455, t459) # t460: "cuda:0 f32[1, 512, 4096]"
# t461 = prims.convert_element_type(t460, dtypes.bfloat16) # t461: "cuda:0 bf16[1, 512, 4096]"
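# From here the trace repeats the same per-block pattern (fused residual +
# RMSNorm, QKV projection, RoPE, SDPA, output projection, fused SwiGLU MLP)
# for the remaining transformer blocks; only the tensor names, weights and
# nvFusion indices change.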
t462 = torch.nn.functional.linear(t461, t6, None) # t462: "cuda:0 bf16[1, 512, 12288]"
# t462 = ltorch.linear(t461, t6, None) # t462: "cuda:0 bf16[1, 512, 12288]"
# t462 = prims.linear(t461, t6, None) # t462: "cuda:0 bf16[1, 512, 12288]"
t463 = torch.reshape(t462, (1, 512, 32, 3, 128)) # t463: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t463 = ltorch.reshape(t462, (1, 512, 32, 3, 128)) # t463: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t463 = prims.reshape(t462, (1, 512, 32, 3, 128)) # t463: "cuda:0 bf16[1, 512, 32, 3, 128]"
del t462
t464 = torch.permute(t463, (0, 2, 3, 1, 4)) # t464: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t464 = ltorch.permute(t463, (0, 2, 3, 1, 4)) # t464: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t464 = prims.transpose(t463, (0, 2, 3, 1, 4)) # t464: "cuda:0 bf16[1, 32, 3, 512, 128]"
del t463
(t465, t466, t467) = torch.split(t464, (1, 1, 1), 2)
# (t465, t466, t467) = ltorch.split(t464, (1, 1, 1), 2)
# t465 = prims.slice_prim(t464, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t465: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t466 = prims.slice_prim(t464, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t466: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t467 = prims.slice_prim(t464, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t467: "cuda:0 bf16[1, 32, 1, 512, 128]"
del t464
t468 = torch.reshape(t465, (1, 32, 512, 128)) # t468: "cuda:0 bf16[1, 32, 512, 128]"
# t468 = ltorch.reshape(t465, (1, 32, 512, 128)) # t468: "cuda:0 bf16[1, 32, 512, 128]"
# t468 = prims.reshape(t465, (1, 32, 512, 128)) # t468: "cuda:0 bf16[1, 32, 512, 128]"
del t465
t469 = torch.reshape(t466, (1, 32, 512, 128)) # t469: "cuda:0 bf16[1, 32, 512, 128]"
# t469 = ltorch.reshape(t466, (1, 32, 512, 128)) # t469: "cuda:0 bf16[1, 32, 512, 128]"
# t469 = prims.reshape(t466, (1, 32, 512, 128)) # t469: "cuda:0 bf16[1, 32, 512, 128]"
del t466
t470 = torch.reshape(t467, (1, 32, 512, 128)) # t470: "cuda:0 bf16[1, 32, 512, 128]"
# t470 = ltorch.reshape(t467, (1, 32, 512, 128)) # t470: "cuda:0 bf16[1, 32, 512, 128]"
# t470 = prims.reshape(t467, (1, 32, 512, 128)) # t470: "cuda:0 bf16[1, 32, 512, 128]"
del t467
t471 = torch_slice_prim_impl(t468, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t471: "cuda:0 bf16[1, 32, 512, 128]"
t486 = torch_slice_prim_impl(t469, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t486: "cuda:0 bf16[1, 32, 512, 128]"
t501 = torch_slice_prim_impl(t468, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t501: "cuda:0 bf16[1, 32, 512, 0]"
del t468
t503 = torch_slice_prim_impl(t469, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t503: "cuda:0 bf16[1, 32, 512, 0]"
del t469
t472 = torch_slice_prim_impl(t471, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t472: "cuda:0 bf16[1, 32, 512, 64]"
t473 = torch_slice_prim_impl(t471, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t473: "cuda:0 bf16[1, 32, 512, 64]"
t487 = torch_slice_prim_impl(t486, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t487: "cuda:0 bf16[1, 32, 512, 64]"
t488 = torch_slice_prim_impl(t486, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t488: "cuda:0 bf16[1, 32, 512, 64]"
[t476, t491] = nvFusion16(t471, t473, t486, t488)
# t474 = prims.convert_element_type(t473, dtypes.float32) # t474: "cuda:0 f32[1, 32, 512, 64]"
# t475 = prims.neg(t474) # t475: "cuda:0 f32[1, 32, 512, 64]"
# t476 = prims.convert_element_type(t475, dtypes.bfloat16) # t476: "cuda:0 bf16[1, 32, 512, 64]"
# t489 = prims.convert_element_type(t488, dtypes.float32) # t489: "cuda:0 f32[1, 32, 512, 64]"
# t490 = prims.neg(t489) # t490: "cuda:0 f32[1, 32, 512, 64]"
# t491 = prims.convert_element_type(t490, dtypes.bfloat16) # t491: "cuda:0 bf16[1, 32, 512, 64]"
del t473, t488
t477 = torch.cat((t476, t472), -1) # t477: "cuda:0 bf16[1, 32, 512, 128]"
# t477 = ltorch.cat((t476, t472), -1) # t477: "cuda:0 bf16[1, 32, 512, 128]"
# t477 = prims.cat((t476, t472), -1) # t477: "cuda:0 bf16[1, 32, 512, 128]"
del t476, t472
t492 = torch.cat((t491, t487), -1) # t492: "cuda:0 bf16[1, 32, 512, 128]"
# t492 = ltorch.cat((t491, t487), -1) # t492: "cuda:0 bf16[1, 32, 512, 128]"
# t492 = prims.cat((t491, t487), -1) # t492: "cuda:0 bf16[1, 32, 512, 128]"
del t491, t487
[t485, t500] = nvFusion17(t154, t157, t471, t477, t486, t492)
# t479 = prims.convert_element_type(t471, dtypes.float32) # t479: "cuda:0 f32[1, 32, 512, 128]"
# t494 = prims.convert_element_type(t486, dtypes.float32) # t494: "cuda:0 f32[1, 32, 512, 128]"
# t480 = prims.mul(t479, t154) # t480: "cuda:0 f32[1, 32, 512, 128]"
# t482 = prims.convert_element_type(t477, dtypes.float32) # t482: "cuda:0 f32[1, 32, 512, 128]"
# t483 = prims.mul(t482, t157) # t483: "cuda:0 f32[1, 32, 512, 128]"
# t484 = prims.add(t480, t483) # t484: "cuda:0 f32[1, 32, 512, 128]"
# t485 = prims.convert_element_type(t484, dtypes.bfloat16) # t485: "cuda:0 bf16[1, 32, 512, 128]"
# t495 = prims.mul(t494, t154) # t495: "cuda:0 f32[1, 32, 512, 128]"
# t497 = prims.convert_element_type(t492, dtypes.float32) # t497: "cuda:0 f32[1, 32, 512, 128]"
# t498 = prims.mul(t497, t157) # t498: "cuda:0 f32[1, 32, 512, 128]"
# t499 = prims.add(t495, t498) # t499: "cuda:0 f32[1, 32, 512, 128]"
# t500 = prims.convert_element_type(t499, dtypes.bfloat16) # t500: "cuda:0 bf16[1, 32, 512, 128]"
del t471, t477, t486, t492
t502 = torch.cat((t485, t501), -1) # t502: "cuda:0 bf16[1, 32, 512, 128]"
# t502 = ltorch.cat((t485, t501), -1) # t502: "cuda:0 bf16[1, 32, 512, 128]"
# t502 = prims.cat((t485, t501), -1) # t502: "cuda:0 bf16[1, 32, 512, 128]"
del t485, t501
t504 = torch.cat((t500, t503), -1) # t504: "cuda:0 bf16[1, 32, 512, 128]"
# t504 = ltorch.cat((t500, t503), -1) # t504: "cuda:0 bf16[1, 32, 512, 128]"
# t504 = prims.cat((t500, t503), -1) # t504: "cuda:0 bf16[1, 32, 512, 128]"
del t500, t503
(t505, t506, t507, t508, _, _, t509, t510, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t502, t504, t470, 0.0, True, scale=0.08838834764831843)
t512 = torch.permute(t505, (0, 2, 1, 3)) # t512: "cuda:0 bf16[1, 512, 32, 128]"
# t512 = ltorch.permute(t505, (0, 2, 1, 3)) # t512: "cuda:0 bf16[1, 512, 32, 128]"
# t512 = prims.transpose(t505, (0, 2, 1, 3)) # t512: "cuda:0 bf16[1, 512, 32, 128]"
t513 = torch.reshape(t512, (1, 512, 4096)) # t513: "cuda:0 bf16[1, 512, 4096]"
# t513 = ltorch.reshape(t512, (1, 512, 4096)) # t513: "cuda:0 bf16[1, 512, 4096]"
# t513 = prims.reshape(t512, (1, 512, 4096)) # t513: "cuda:0 bf16[1, 512, 4096]"
del t512
t514 = torch.nn.functional.linear(t513, t91, None) # t514: "cuda:0 bf16[1, 512, 4096]"
# t514 = ltorch.linear(t513, t91, None) # t514: "cuda:0 bf16[1, 512, 4096]"
# t514 = prims.linear(t513, t91, None) # t514: "cuda:0 bf16[1, 512, 4096]"
[t518, t525, t533] = nvFusion18(t446, t514, t529)
# t516 = prims.convert_element_type(t446, dtypes.float32) # t516: "cuda:0 f32[1, 512, 4096]"
# t515 = prims.convert_element_type(t514, dtypes.float32) # t515: "cuda:0 f32[1, 512, 4096]"
# t517 = prims.add(t515, t516) # t517: "cuda:0 f32[1, 512, 4096]"
# t518 = prims.convert_element_type(t517, dtypes.bfloat16) # t518: "cuda:0 bf16[1, 512, 4096]"
# t520 = prims.mul(t517, t517) # t520: "cuda:0 f32[1, 512, 4096]"
# t521 = prims.sum(t520, (2,)) # t521: "cuda:0 f32[1, 512]"
# t522 = prims.broadcast_in_dim(t521, [1, 512, 1], [0, 1]) # t522: "cuda:0 f32[1, 512, 1]"
# t523 = prims.div(t522, 4096.0) # t523: "cuda:0 f32[1, 512, 1]"
# t524 = prims.add(t523, 1e-05) # t524: "cuda:0 f32[1, 512, 1]"
# t525 = prims.rsqrt(t524) # t525: "cuda:0 f32[1, 512, 1]"
# t526 = prims.broadcast_in_dim(t525, (1, 512, 4096), (0, 1, 2)) # t526: "cuda:0 f32[1, 512, 4096]"
# t527 = prims.mul(t517, t526) # t527: "cuda:0 f32[1, 512, 4096]"
# t531 = prims.convert_element_type(t529, dtypes.float32) # t531: "cuda:0 f32[1, 512, 4096]"
# t532 = prims.mul(t527, t531) # t532: "cuda:0 f32[1, 512, 4096]"
# t533 = prims.convert_element_type(t532, dtypes.bfloat16) # t533: "cuda:0 bf16[1, 512, 4096]"
t534 = torch.nn.functional.linear(t533, t22, None) # t534: "cuda:0 bf16[1, 512, 11008]"
# t534 = ltorch.linear(t533, t22, None) # t534: "cuda:0 bf16[1, 512, 11008]"
# t534 = prims.linear(t533, t22, None) # t534: "cuda:0 bf16[1, 512, 11008]"
t535 = torch.nn.functional.linear(t533, t38, None) # t535: "cuda:0 bf16[1, 512, 11008]"
# t535 = ltorch.linear(t533, t38, None) # t535: "cuda:0 bf16[1, 512, 11008]"
# t535 = prims.linear(t533, t38, None) # t535: "cuda:0 bf16[1, 512, 11008]"
[t549] = nvFusion19(t534, t535)
# t536 = prims.convert_element_type(t534, dtypes.float32) # t536: "cuda:0 f32[1, 512, 11008]"
# t537 = prims.neg(t536) # t537: "cuda:0 f32[1, 512, 11008]"
# t538 = prims.exp(t537) # t538: "cuda:0 f32[1, 512, 11008]"
# t539 = prims.add(1.0, t538) # t539: "cuda:0 f32[1, 512, 11008]"
# t540 = prims.reciprocal(t539) # t540: "cuda:0 f32[1, 512, 11008]"
# t544 = prims.mul(t536, t540) # t544: "cuda:0 f32[1, 512, 11008]"
# t547 = prims.convert_element_type(t535, dtypes.float32) # t547: "cuda:0 f32[1, 512, 11008]"
# t548 = prims.mul(t544, t547) # t548: "cuda:0 f32[1, 512, 11008]"
# t549 = prims.convert_element_type(t548, dtypes.bfloat16) # t549: "cuda:0 bf16[1, 512, 11008]"
t550 = torch.nn.functional.linear(t549, t92, None) # t550: "cuda:0 bf16[1, 512, 4096]"
# t550 = ltorch.linear(t549, t92, None) # t550: "cuda:0 bf16[1, 512, 4096]"
# t550 = prims.linear(t549, t92, None) # t550: "cuda:0 bf16[1, 512, 4096]"
[t554, t561, t569] = nvFusion20(t518, t550, t565)
# t552 = prims.convert_element_type(t518, dtypes.float32) # t552: "cuda:0 f32[1, 512, 4096]"
# t551 = prims.convert_element_type(t550, dtypes.float32) # t551: "cuda:0 f32[1, 512, 4096]"
# t553 = prims.add(t551, t552) # t553: "cuda:0 f32[1, 512, 4096]"
# t554 = prims.convert_element_type(t553, dtypes.bfloat16) # t554: "cuda:0 bf16[1, 512, 4096]"
# t556 = prims.mul(t553, t553) # t556: "cuda:0 f32[1, 512, 4096]"
# t557 = prims.sum(t556, (2,)) # t557: "cuda:0 f32[1, 512]"
# t558 = prims.broadcast_in_dim(t557, [1, 512, 1], [0, 1]) # t558: "cuda:0 f32[1, 512, 1]"
# t559 = prims.div(t558, 4096.0) # t559: "cuda:0 f32[1, 512, 1]"
# t560 = prims.add(t559, 1e-05) # t560: "cuda:0 f32[1, 512, 1]"
# t561 = prims.rsqrt(t560) # t561: "cuda:0 f32[1, 512, 1]"
# t562 = prims.broadcast_in_dim(t561, (1, 512, 4096), (0, 1, 2)) # t562: "cuda:0 f32[1, 512, 4096]"
# t563 = prims.mul(t553, t562) # t563: "cuda:0 f32[1, 512, 4096]"
# t567 = prims.convert_element_type(t565, dtypes.float32) # t567: "cuda:0 f32[1, 512, 4096]"
# t568 = prims.mul(t563, t567) # t568: "cuda:0 f32[1, 512, 4096]"
# t569 = prims.convert_element_type(t568, dtypes.bfloat16) # t569: "cuda:0 bf16[1, 512, 4096]"
t570 = torch.nn.functional.linear(t569, t7, None) # t570: "cuda:0 bf16[1, 512, 12288]"
# t570 = ltorch.linear(t569, t7, None) # t570: "cuda:0 bf16[1, 512, 12288]"
# t570 = prims.linear(t569, t7, None) # t570: "cuda:0 bf16[1, 512, 12288]"
t571 = torch.reshape(t570, (1, 512, 32, 3, 128)) # t571: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t571 = ltorch.reshape(t570, (1, 512, 32, 3, 128)) # t571: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t571 = prims.reshape(t570, (1, 512, 32, 3, 128)) # t571: "cuda:0 bf16[1, 512, 32, 3, 128]"
del t570
t572 = torch.permute(t571, (0, 2, 3, 1, 4)) # t572: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t572 = ltorch.permute(t571, (0, 2, 3, 1, 4)) # t572: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t572 = prims.transpose(t571, (0, 2, 3, 1, 4)) # t572: "cuda:0 bf16[1, 32, 3, 512, 128]"
del t571
(t573, t574, t575) = torch.split(t572, (1, 1, 1), 2)
# (t573, t574, t575) = ltorch.split(t572, (1, 1, 1), 2)
# t573 = prims.slice_prim(t572, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t573: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t574 = prims.slice_prim(t572, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t574: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t575 = prims.slice_prim(t572, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t575: "cuda:0 bf16[1, 32, 1, 512, 128]"
del t572
t576 = torch.reshape(t573, (1, 32, 512, 128)) # t576: "cuda:0 bf16[1, 32, 512, 128]"
# t576 = ltorch.reshape(t573, (1, 32, 512, 128)) # t576: "cuda:0 bf16[1, 32, 512, 128]"
# t576 = prims.reshape(t573, (1, 32, 512, 128)) # t576: "cuda:0 bf16[1, 32, 512, 128]"
del t573
t577 = torch.reshape(t574, (1, 32, 512, 128)) # t577: "cuda:0 bf16[1, 32, 512, 128]"
# t577 = ltorch.reshape(t574, (1, 32, 512, 128)) # t577: "cuda:0 bf16[1, 32, 512, 128]"
# t577 = prims.reshape(t574, (1, 32, 512, 128)) # t577: "cuda:0 bf16[1, 32, 512, 128]"
del t574
t578 = torch.reshape(t575, (1, 32, 512, 128)) # t578: "cuda:0 bf16[1, 32, 512, 128]"
# t578 = ltorch.reshape(t575, (1, 32, 512, 128)) # t578: "cuda:0 bf16[1, 32, 512, 128]"
# t578 = prims.reshape(t575, (1, 32, 512, 128)) # t578: "cuda:0 bf16[1, 32, 512, 128]"
del t575
t579 = torch_slice_prim_impl(t576, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t579: "cuda:0 bf16[1, 32, 512, 128]"
t594 = torch_slice_prim_impl(t577, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t594: "cuda:0 bf16[1, 32, 512, 128]"
t609 = torch_slice_prim_impl(t576, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t609: "cuda:0 bf16[1, 32, 512, 0]"
del t576
t611 = torch_slice_prim_impl(t577, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t611: "cuda:0 bf16[1, 32, 512, 0]"
del t577
t580 = torch_slice_prim_impl(t579, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t580: "cuda:0 bf16[1, 32, 512, 64]"
t581 = torch_slice_prim_impl(t579, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t581: "cuda:0 bf16[1, 32, 512, 64]"
t595 = torch_slice_prim_impl(t594, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t595: "cuda:0 bf16[1, 32, 512, 64]"
t596 = torch_slice_prim_impl(t594, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t596: "cuda:0 bf16[1, 32, 512, 64]"
[t584, t599] = nvFusion21(t579, t581, t594, t596)
# t582 = prims.convert_element_type(t581, dtypes.float32) # t582: "cuda:0 f32[1, 32, 512, 64]"
# t583 = prims.neg(t582) # t583: "cuda:0 f32[1, 32, 512, 64]"
# t584 = prims.convert_element_type(t583, dtypes.bfloat16) # t584: "cuda:0 bf16[1, 32, 512, 64]"
# t597 = prims.convert_element_type(t596, dtypes.float32) # t597: "cuda:0 f32[1, 32, 512, 64]"
# t598 = prims.neg(t597) # t598: "cuda:0 f32[1, 32, 512, 64]"
# t599 = prims.convert_element_type(t598, dtypes.bfloat16) # t599: "cuda:0 bf16[1, 32, 512, 64]"
del t581, t596
t600 = torch.cat((t599, t595), -1) # t600: "cuda:0 bf16[1, 32, 512, 128]"
# t600 = ltorch.cat((t599, t595), -1) # t600: "cuda:0 bf16[1, 32, 512, 128]"
# t600 = prims.cat((t599, t595), -1) # t600: "cuda:0 bf16[1, 32, 512, 128]"
del t599, t595
t585 = torch.cat((t584, t580), -1) # t585: "cuda:0 bf16[1, 32, 512, 128]"
# t585 = ltorch.cat((t584, t580), -1) # t585: "cuda:0 bf16[1, 32, 512, 128]"
# t585 = prims.cat((t584, t580), -1) # t585: "cuda:0 bf16[1, 32, 512, 128]"
del t584, t580
[t593, t608] = nvFusion22(t154, t157, t579, t585, t594, t600)
# t587 = prims.convert_element_type(t579, dtypes.float32) # t587: "cuda:0 f32[1, 32, 512, 128]"
# t602 = prims.convert_element_type(t594, dtypes.float32) # t602: "cuda:0 f32[1, 32, 512, 128]"
# t603 = prims.mul(t602, t154) # t603: "cuda:0 f32[1, 32, 512, 128]"
# t605 = prims.convert_element_type(t600, dtypes.float32) # t605: "cuda:0 f32[1, 32, 512, 128]"
# t606 = prims.mul(t605, t157) # t606: "cuda:0 f32[1, 32, 512, 128]"
# t607 = prims.add(t603, t606) # t607: "cuda:0 f32[1, 32, 512, 128]"
# t608 = prims.convert_element_type(t607, dtypes.bfloat16) # t608: "cuda:0 bf16[1, 32, 512, 128]"
# t588 = prims.mul(t587, t154) # t588: "cuda:0 f32[1, 32, 512, 128]"
# t590 = prims.convert_element_type(t585, dtypes.float32) # t590: "cuda:0 f32[1, 32, 512, 128]"
# t591 = prims.mul(t590, t157) # t591: "cuda:0 f32[1, 32, 512, 128]"
# t592 = prims.add(t588, t591) # t592: "cuda:0 f32[1, 32, 512, 128]"
# t593 = prims.convert_element_type(t592, dtypes.bfloat16) # t593: "cuda:0 bf16[1, 32, 512, 128]"
del t579, t585, t594, t600
t612 = torch.cat((t608, t611), -1) # t612: "cuda:0 bf16[1, 32, 512, 128]"
# t612 = ltorch.cat((t608, t611), -1) # t612: "cuda:0 bf16[1, 32, 512, 128]"
# t612 = prims.cat((t608, t611), -1) # t612: "cuda:0 bf16[1, 32, 512, 128]"
del t608, t611
t610 = torch.cat((t593, t609), -1) # t610: "cuda:0 bf16[1, 32, 512, 128]"
# t610 = ltorch.cat((t593, t609), -1) # t610: "cuda:0 bf16[1, 32, 512, 128]"
# t610 = prims.cat((t593, t609), -1) # t610: "cuda:0 bf16[1, 32, 512, 128]"
del t593, t609
(t613, t614, t615, t616, _, _, t617, t618, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t610, t612, t578, 0.0, True, scale=0.08838834764831843)
t620 = torch.permute(t613, (0, 2, 1, 3)) # t620: "cuda:0 bf16[1, 512, 32, 128]"
# t620 = ltorch.permute(t613, (0, 2, 1, 3)) # t620: "cuda:0 bf16[1, 512, 32, 128]"
# t620 = prims.transpose(t613, (0, 2, 1, 3)) # t620: "cuda:0 bf16[1, 512, 32, 128]"
t621 = torch.reshape(t620, (1, 512, 4096)) # t621: "cuda:0 bf16[1, 512, 4096]"
# t621 = ltorch.reshape(t620, (1, 512, 4096)) # t621: "cuda:0 bf16[1, 512, 4096]"
# t621 = prims.reshape(t620, (1, 512, 4096)) # t621: "cuda:0 bf16[1, 512, 4096]"
del t620
t622 = torch.nn.functional.linear(t621, t93, None) # t622: "cuda:0 bf16[1, 512, 4096]"
# t622 = ltorch.linear(t621, t93, None) # t622: "cuda:0 bf16[1, 512, 4096]"
# t622 = prims.linear(t621, t93, None) # t622: "cuda:0 bf16[1, 512, 4096]"
[t626, t633, t641] = nvFusion23(t554, t622, t637)
# t624 = prims.convert_element_type(t554, dtypes.float32) # t624: "cuda:0 f32[1, 512, 4096]"
# t623 = prims.convert_element_type(t622, dtypes.float32) # t623: "cuda:0 f32[1, 512, 4096]"
# t625 = prims.add(t623, t624) # t625: "cuda:0 f32[1, 512, 4096]"
# t626 = prims.convert_element_type(t625, dtypes.bfloat16) # t626: "cuda:0 bf16[1, 512, 4096]"
# t628 = prims.mul(t625, t625) # t628: "cuda:0 f32[1, 512, 4096]"
# t629 = prims.sum(t628, (2,)) # t629: "cuda:0 f32[1, 512]"
# t630 = prims.broadcast_in_dim(t629, [1, 512, 1], [0, 1]) # t630: "cuda:0 f32[1, 512, 1]"
# t631 = prims.div(t630, 4096.0) # t631: "cuda:0 f32[1, 512, 1]"
# t632 = prims.add(t631, 1e-05) # t632: "cuda:0 f32[1, 512, 1]"
# t633 = prims.rsqrt(t632) # t633: "cuda:0 f32[1, 512, 1]"
# t634 = prims.broadcast_in_dim(t633, (1, 512, 4096), (0, 1, 2)) # t634: "cuda:0 f32[1, 512, 4096]"
# t635 = prims.mul(t625, t634) # t635: "cuda:0 f32[1, 512, 4096]"
# t639 = prims.convert_element_type(t637, dtypes.float32) # t639: "cuda:0 f32[1, 512, 4096]"
# t640 = prims.mul(t635, t639) # t640: "cuda:0 f32[1, 512, 4096]"
# t641 = prims.convert_element_type(t640, dtypes.bfloat16) # t641: "cuda:0 bf16[1, 512, 4096]"
t643 = torch.nn.functional.linear(t641, t39, None) # t643: "cuda:0 bf16[1, 512, 11008]"
# t643 = ltorch.linear(t641, t39, None) # t643: "cuda:0 bf16[1, 512, 11008]"
# t643 = prims.linear(t641, t39, None) # t643: "cuda:0 bf16[1, 512, 11008]"
t642 = torch.nn.functional.linear(t641, t23, None) # t642: "cuda:0 bf16[1, 512, 11008]"
# t642 = ltorch.linear(t641, t23, None) # t642: "cuda:0 bf16[1, 512, 11008]"
# t642 = prims.linear(t641, t23, None) # t642: "cuda:0 bf16[1, 512, 11008]"
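# Same SwiGLU MLP as in the earlier blocks; the two linears merely appear in
# swapped order here (t643 before t642), the dataflow is unchanged.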
[t657] = nvFusion24(t642, t643)
# t644 = prims.convert_element_type(t642, dtypes.float32) # t644: "cuda:0 f32[1, 512, 11008]"
# t645 = prims.neg(t644) # t645: "cuda:0 f32[1, 512, 11008]"
# t646 = prims.exp(t645) # t646: "cuda:0 f32[1, 512, 11008]"
# t647 = prims.add(1.0, t646) # t647: "cuda:0 f32[1, 512, 11008]"
# t648 = prims.reciprocal(t647) # t648: "cuda:0 f32[1, 512, 11008]"
# t652 = prims.mul(t644, t648) # t652: "cuda:0 f32[1, 512, 11008]"
# t655 = prims.convert_element_type(t643, dtypes.float32) # t655: "cuda:0 f32[1, 512, 11008]"
# t656 = prims.mul(t652, t655) # t656: "cuda:0 f32[1, 512, 11008]"
# t657 = prims.convert_element_type(t656, dtypes.bfloat16) # t657: "cuda:0 bf16[1, 512, 11008]"
t658 = torch.nn.functional.linear(t657, t94, None) # t658: "cuda:0 bf16[1, 512, 4096]"
# t658 = ltorch.linear(t657, t94, None) # t658: "cuda:0 bf16[1, 512, 4096]"
# t658 = prims.linear(t657, t94, None) # t658: "cuda:0 bf16[1, 512, 4096]"
[t662, t669, t677] = nvFusion25(t626, t658, t673)
# t660 = prims.convert_element_type(t626, dtypes.float32) # t660: "cuda:0 f32[1, 512, 4096]"
# t659 = prims.convert_element_type(t658, dtypes.float32) # t659: "cuda:0 f32[1, 512, 4096]"
# t661 = prims.add(t659, t660) # t661: "cuda:0 f32[1, 512, 4096]"
# t662 = prims.convert_element_type(t661, dtypes.bfloat16) # t662: "cuda:0 bf16[1, 512, 4096]"
# t664 = prims.mul(t661, t661) # t664: "cuda:0 f32[1, 512, 4096]"
# t665 = prims.sum(t664, (2,)) # t665: "cuda:0 f32[1, 512]"
# t666 = prims.broadcast_in_dim(t665, [1, 512, 1], [0, 1]) # t666: "cuda:0 f32[1, 512, 1]"
# t667 = prims.div(t666, 4096.0) # t667: "cuda:0 f32[1, 512, 1]"
# t668 = prims.add(t667, 1e-05) # t668: "cuda:0 f32[1, 512, 1]"
# t669 = prims.rsqrt(t668) # t669: "cuda:0 f32[1, 512, 1]"
# t670 = prims.broadcast_in_dim(t669, (1, 512, 4096), (0, 1, 2)) # t670: "cuda:0 f32[1, 512, 4096]"
# t671 = prims.mul(t661, t670) # t671: "cuda:0 f32[1, 512, 4096]"
# t675 = prims.convert_element_type(t673, dtypes.float32) # t675: "cuda:0 f32[1, 512, 4096]"
# t676 = prims.mul(t671, t675) # t676: "cuda:0 f32[1, 512, 4096]"
# t677 = prims.convert_element_type(t676, dtypes.bfloat16) # t677: "cuda:0 bf16[1, 512, 4096]"
t678 = torch.nn.functional.linear(t677, t8, None) # t678: "cuda:0 bf16[1, 512, 12288]"
# t678 = ltorch.linear(t677, t8, None) # t678: "cuda:0 bf16[1, 512, 12288]"
# t678 = prims.linear(t677, t8, None) # t678: "cuda:0 bf16[1, 512, 12288]"
t679 = torch.reshape(t678, (1, 512, 32, 3, 128)) # t679: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t679 = ltorch.reshape(t678, (1, 512, 32, 3, 128)) # t679: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t679 = prims.reshape(t678, (1, 512, 32, 3, 128)) # t679: "cuda:0 bf16[1, 512, 32, 3, 128]"
del t678
t680 = torch.permute(t679, (0, 2, 3, 1, 4)) # t680: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t680 = ltorch.permute(t679, (0, 2, 3, 1, 4)) # t680: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t680 = prims.transpose(t679, (0, 2, 3, 1, 4)) # t680: "cuda:0 bf16[1, 32, 3, 512, 128]"
del t679
(t681, t682, t683) = torch.split(t680, (1, 1, 1), 2)
# (t681, t682, t683) = ltorch.split(t680, (1, 1, 1), 2)
# t681 = prims.slice_prim(t680, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t681: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t682 = prims.slice_prim(t680, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t682: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t683 = prims.slice_prim(t680, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t683: "cuda:0 bf16[1, 32, 1, 512, 128]"
del t680
t684 = torch.reshape(t681, (1, 32, 512, 128)) # t684: "cuda:0 bf16[1, 32, 512, 128]"
# t684 = ltorch.reshape(t681, (1, 32, 512, 128)) # t684: "cuda:0 bf16[1, 32, 512, 128]"
# t684 = prims.reshape(t681, (1, 32, 512, 128)) # t684: "cuda:0 bf16[1, 32, 512, 128]"
del t681
t685 = torch.reshape(t682, (1, 32, 512, 128)) # t685: "cuda:0 bf16[1, 32, 512, 128]"
# t685 = ltorch.reshape(t682, (1, 32, 512, 128)) # t685: "cuda:0 bf16[1, 32, 512, 128]"
# t685 = prims.reshape(t682, (1, 32, 512, 128)) # t685: "cuda:0 bf16[1, 32, 512, 128]"
del t682
t686 = torch.reshape(t683, (1, 32, 512, 128)) # t686: "cuda:0 bf16[1, 32, 512, 128]"
# t686 = ltorch.reshape(t683, (1, 32, 512, 128)) # t686: "cuda:0 bf16[1, 32, 512, 128]"
# t686 = prims.reshape(t683, (1, 32, 512, 128)) # t686: "cuda:0 bf16[1, 32, 512, 128]"
del t683
t687 = torch_slice_prim_impl(t684, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t687: "cuda:0 bf16[1, 32, 512, 128]"
t702 = torch_slice_prim_impl(t685, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t702: "cuda:0 bf16[1, 32, 512, 128]"
t717 = torch_slice_prim_impl(t684, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t717: "cuda:0 bf16[1, 32, 512, 0]"
del t684
t719 = torch_slice_prim_impl(t685, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t719: "cuda:0 bf16[1, 32, 512, 0]"
del t685
t688 = torch_slice_prim_impl(t687, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t688: "cuda:0 bf16[1, 32, 512, 64]"
t689 = torch_slice_prim_impl(t687, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t689: "cuda:0 bf16[1, 32, 512, 64]"
t703 = torch_slice_prim_impl(t702, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t703: "cuda:0 bf16[1, 32, 512, 64]"
t704 = torch_slice_prim_impl(t702, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t704: "cuda:0 bf16[1, 32, 512, 64]"
[t692, t707] = nvFusion26(t687, t689, t702, t704)
# t690 = prims.convert_element_type(t689, dtypes.float32) # t690: "cuda:0 f32[1, 32, 512, 64]"
# t691 = prims.neg(t690) # t691: "cuda:0 f32[1, 32, 512, 64]"
# t692 = prims.convert_element_type(t691, dtypes.bfloat16) # t692: "cuda:0 bf16[1, 32, 512, 64]"
# t705 = prims.convert_element_type(t704, dtypes.float32) # t705: "cuda:0 f32[1, 32, 512, 64]"
# t706 = prims.neg(t705) # t706: "cuda:0 f32[1, 32, 512, 64]"
# t707 = prims.convert_element_type(t706, dtypes.bfloat16) # t707: "cuda:0 bf16[1, 32, 512, 64]"
del t689, t704
t708 = torch.cat((t707, t703), -1) # t708: "cuda:0 bf16[1, 32, 512, 128]"
# t708 = ltorch.cat((t707, t703), -1) # t708: "cuda:0 bf16[1, 32, 512, 128]"
# t708 = prims.cat((t707, t703), -1) # t708: "cuda:0 bf16[1, 32, 512, 128]"
del t707, t703
t693 = torch.cat((t692, t688), -1) # t693: "cuda:0 bf16[1, 32, 512, 128]"
# t693 = ltorch.cat((t692, t688), -1) # t693: "cuda:0 bf16[1, 32, 512, 128]"
# t693 = prims.cat((t692, t688), -1) # t693: "cuda:0 bf16[1, 32, 512, 128]"
del t692, t688
[t701, t716] = nvFusion27(t154, t157, t687, t693, t702, t708)
# t695 = prims.convert_element_type(t687, dtypes.float32) # t695: "cuda:0 f32[1, 32, 512, 128]"
# t710 = prims.convert_element_type(t702, dtypes.float32) # t710: "cuda:0 f32[1, 32, 512, 128]"
# t711 = prims.mul(t710, t154) # t711: "cuda:0 f32[1, 32, 512, 128]"
# t713 = prims.convert_element_type(t708, dtypes.float32) # t713: "cuda:0 f32[1, 32, 512, 128]"
# t714 = prims.mul(t713, t157) # t714: "cuda:0 f32[1, 32, 512, 128]"
# t715 = prims.add(t711, t714) # t715: "cuda:0 f32[1, 32, 512, 128]"
# t716 = prims.convert_element_type(t715, dtypes.bfloat16) # t716: "cuda:0 bf16[1, 32, 512, 128]"
# t696 = prims.mul(t695, t154) # t696: "cuda:0 f32[1, 32, 512, 128]"
# t698 = prims.convert_element_type(t693, dtypes.float32) # t698: "cuda:0 f32[1, 32, 512, 128]"
# t699 = prims.mul(t698, t157) # t699: "cuda:0 f32[1, 32, 512, 128]"
# t700 = prims.add(t696, t699) # t700: "cuda:0 f32[1, 32, 512, 128]"
# t701 = prims.convert_element_type(t700, dtypes.bfloat16) # t701: "cuda:0 bf16[1, 32, 512, 128]"
del t687, t693, t702, t708
t720 = torch.cat((t716, t719), -1) # t720: "cuda:0 bf16[1, 32, 512, 128]"
# t720 = ltorch.cat((t716, t719), -1) # t720: "cuda:0 bf16[1, 32, 512, 128]"
# t720 = prims.cat((t716, t719), -1) # t720: "cuda:0 bf16[1, 32, 512, 128]"
del t716, t719
t718 = torch.cat((t701, t717), -1) # t718: "cuda:0 bf16[1, 32, 512, 128]"
# t718 = ltorch.cat((t701, t717), -1) # t718: "cuda:0 bf16[1, 32, 512, 128]"
# t718 = prims.cat((t701, t717), -1) # t718: "cuda:0 bf16[1, 32, 512, 128]"
del t701, t717
(t721, t722, t723, t724, _, _, t725, t726, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t718, t720, t686, 0.0, True, scale=0.08838834764831843)
t728 = torch.permute(t721, (0, 2, 1, 3)) # t728: "cuda:0 bf16[1, 512, 32, 128]"
# t728 = ltorch.permute(t721, (0, 2, 1, 3)) # t728: "cuda:0 bf16[1, 512, 32, 128]"
# t728 = prims.transpose(t721, (0, 2, 1, 3)) # t728: "cuda:0 bf16[1, 512, 32, 128]"
t729 = torch.reshape(t728, (1, 512, 4096)) # t729: "cuda:0 bf16[1, 512, 4096]"
# t729 = ltorch.reshape(t728, (1, 512, 4096)) # t729: "cuda:0 bf16[1, 512, 4096]"
# t729 = prims.reshape(t728, (1, 512, 4096)) # t729: "cuda:0 bf16[1, 512, 4096]"
del t728
t730 = torch.nn.functional.linear(t729, t95, None) # t730: "cuda:0 bf16[1, 512, 4096]"
# t730 = ltorch.linear(t729, t95, None) # t730: "cuda:0 bf16[1, 512, 4096]"
# t730 = prims.linear(t729, t95, None) # t730: "cuda:0 bf16[1, 512, 4096]"
[t734, t741, t749] = nvFusion28(t662, t730, t745)
# t732 = prims.convert_element_type(t662, dtypes.float32) # t732: "cuda:0 f32[1, 512, 4096]"
# t731 = prims.convert_element_type(t730, dtypes.float32) # t731: "cuda:0 f32[1, 512, 4096]"
# t733 = prims.add(t731, t732) # t733: "cuda:0 f32[1, 512, 4096]"
# t734 = prims.convert_element_type(t733, dtypes.bfloat16) # t734: "cuda:0 bf16[1, 512, 4096]"
# t736 = prims.mul(t733, t733) # t736: "cuda:0 f32[1, 512, 4096]"
# t737 = prims.sum(t736, (2,)) # t737: "cuda:0 f32[1, 512]"
# t738 = prims.broadcast_in_dim(t737, [1, 512, 1], [0, 1]) # t738: "cuda:0 f32[1, 512, 1]"
# t739 = prims.div(t738, 4096.0) # t739: "cuda:0 f32[1, 512, 1]"
# t740 = prims.add(t739, 1e-05) # t740: "cuda:0 f32[1, 512, 1]"
# t741 = prims.rsqrt(t740) # t741: "cuda:0 f32[1, 512, 1]"
# t742 = prims.broadcast_in_dim(t741, (1, 512, 4096), (0, 1, 2)) # t742: "cuda:0 f32[1, 512, 4096]"
# t743 = prims.mul(t733, t742) # t743: "cuda:0 f32[1, 512, 4096]"
# t747 = prims.convert_element_type(t745, dtypes.float32) # t747: "cuda:0 f32[1, 512, 4096]"
# t748 = prims.mul(t743, t747) # t748: "cuda:0 f32[1, 512, 4096]"
# t749 = prims.convert_element_type(t748, dtypes.bfloat16) # t749: "cuda:0 bf16[1, 512, 4096]"
t750 = torch.nn.functional.linear(t749, t24, None) # t750: "cuda:0 bf16[1, 512, 11008]"
# t750 = ltorch.linear(t749, t24, None) # t750: "cuda:0 bf16[1, 512, 11008]"
# t750 = prims.linear(t749, t24, None) # t750: "cuda:0 bf16[1, 512, 11008]"
t751 = torch.nn.functional.linear(t749, t40, None) # t751: "cuda:0 bf16[1, 512, 11008]"
# t751 = ltorch.linear(t749, t40, None) # t751: "cuda:0 bf16[1, 512, 11008]"
# t751 = prims.linear(t749, t40, None) # t751: "cuda:0 bf16[1, 512, 11008]"
[t765] = nvFusion29(t750, t751)
# t752 = prims.convert_element_type(t750, dtypes.float32) # t752: "cuda:0 f32[1, 512, 11008]"
# t753 = prims.neg(t752) # t753: "cuda:0 f32[1, 512, 11008]"
# t754 = prims.exp(t753) # t754: "cuda:0 f32[1, 512, 11008]"
# t755 = prims.add(1.0, t754) # t755: "cuda:0 f32[1, 512, 11008]"
# t756 = prims.reciprocal(t755) # t756: "cuda:0 f32[1, 512, 11008]"
# t760 = prims.mul(t752, t756) # t760: "cuda:0 f32[1, 512, 11008]"
# t763 = prims.convert_element_type(t751, dtypes.float32) # t763: "cuda:0 f32[1, 512, 11008]"
# t764 = prims.mul(t760, t763) # t764: "cuda:0 f32[1, 512, 11008]"
# t765 = prims.convert_element_type(t764, dtypes.bfloat16) # t765: "cuda:0 bf16[1, 512, 11008]"
t766 = torch.nn.functional.linear(t765, t96, None) # t766: "cuda:0 bf16[1, 512, 4096]"
# t766 = ltorch.linear(t765, t96, None) # t766: "cuda:0 bf16[1, 512, 4096]"
# t766 = prims.linear(t765, t96, None) # t766: "cuda:0 bf16[1, 512, 4096]"
[t770, t777, t785] = nvFusion30(t734, t766, t781)
# t768 = prims.convert_element_type(t734, dtypes.float32) # t768: "cuda:0 f32[1, 512, 4096]"
# t767 = prims.convert_element_type(t766, dtypes.float32) # t767: "cuda:0 f32[1, 512, 4096]"
# t769 = prims.add(t767, t768) # t769: "cuda:0 f32[1, 512, 4096]"
# t770 = prims.convert_element_type(t769, dtypes.bfloat16) # t770: "cuda:0 bf16[1, 512, 4096]"
# t772 = prims.mul(t769, t769) # t772: "cuda:0 f32[1, 512, 4096]"
# t773 = prims.sum(t772, (2,)) # t773: "cuda:0 f32[1, 512]"
# t774 = prims.broadcast_in_dim(t773, [1, 512, 1], [0, 1]) # t774: "cuda:0 f32[1, 512, 1]"
# t775 = prims.div(t774, 4096.0) # t775: "cuda:0 f32[1, 512, 1]"
# t776 = prims.add(t775, 1e-05) # t776: "cuda:0 f32[1, 512, 1]"
# t777 = prims.rsqrt(t776) # t777: "cuda:0 f32[1, 512, 1]"
# t778 = prims.broadcast_in_dim(t777, (1, 512, 4096), (0, 1, 2)) # t778: "cuda:0 f32[1, 512, 4096]"
# t779 = prims.mul(t769, t778) # t779: "cuda:0 f32[1, 512, 4096]"
# t783 = prims.convert_element_type(t781, dtypes.float32) # t783: "cuda:0 f32[1, 512, 4096]"
# t784 = prims.mul(t779, t783) # t784: "cuda:0 f32[1, 512, 4096]"
# t785 = prims.convert_element_type(t784, dtypes.bfloat16) # t785: "cuda:0 bf16[1, 512, 4096]"
t786 = torch.nn.functional.linear(t785, t9, None) # t786: "cuda:0 bf16[1, 512, 12288]"
# t786 = ltorch.linear(t785, t9, None) # t786: "cuda:0 bf16[1, 512, 12288]"
# t786 = prims.linear(t785, t9, None) # t786: "cuda:0 bf16[1, 512, 12288]"
t787 = torch.reshape(t786, (1, 512, 32, 3, 128)) # t787: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t787 = ltorch.reshape(t786, (1, 512, 32, 3, 128)) # t787: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t787 = prims.reshape(t786, (1, 512, 32, 3, 128)) # t787: "cuda:0 bf16[1, 512, 32, 3, 128]"
del t786
t788 = torch.permute(t787, (0, 2, 3, 1, 4)) # t788: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t788 = ltorch.permute(t787, (0, 2, 3, 1, 4)) # t788: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t788 = prims.transpose(t787, (0, 2, 3, 1, 4)) # t788: "cuda:0 bf16[1, 32, 3, 512, 128]"
del t787
(t789, t790, t791) = torch.split(t788, (1, 1, 1), 2)
# (t789, t790, t791) = ltorch.split(t788, (1, 1, 1), 2)
# t789 = prims.slice_prim(t788, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t789: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t790 = prims.slice_prim(t788, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t790: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t791 = prims.slice_prim(t788, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t791: "cuda:0 bf16[1, 32, 1, 512, 128]"
del t788
t792 = torch.reshape(t789, (1, 32, 512, 128)) # t792: "cuda:0 bf16[1, 32, 512, 128]"
# t792 = ltorch.reshape(t789, (1, 32, 512, 128)) # t792: "cuda:0 bf16[1, 32, 512, 128]"
# t792 = prims.reshape(t789, (1, 32, 512, 128)) # t792: "cuda:0 bf16[1, 32, 512, 128]"
del t789
t793 = torch.reshape(t790, (1, 32, 512, 128)) # t793: "cuda:0 bf16[1, 32, 512, 128]"
# t793 = ltorch.reshape(t790, (1, 32, 512, 128)) # t793: "cuda:0 bf16[1, 32, 512, 128]"
# t793 = prims.reshape(t790, (1, 32, 512, 128)) # t793: "cuda:0 bf16[1, 32, 512, 128]"
del t790
t794 = torch.reshape(t791, (1, 32, 512, 128)) # t794: "cuda:0 bf16[1, 32, 512, 128]"
# t794 = ltorch.reshape(t791, (1, 32, 512, 128)) # t794: "cuda:0 bf16[1, 32, 512, 128]"
# t794 = prims.reshape(t791, (1, 32, 512, 128)) # t794: "cuda:0 bf16[1, 32, 512, 128]"
del t791
t795 = torch_slice_prim_impl(t792, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t795: "cuda:0 bf16[1, 32, 512, 128]"
t810 = torch_slice_prim_impl(t793, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t810: "cuda:0 bf16[1, 32, 512, 128]"
t825 = torch_slice_prim_impl(t792, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t825: "cuda:0 bf16[1, 32, 512, 0]"
del t792
t827 = torch_slice_prim_impl(t793, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t827: "cuda:0 bf16[1, 32, 512, 0]"
del t793
t796 = torch_slice_prim_impl(t795, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t796: "cuda:0 bf16[1, 32, 512, 64]"
t797 = torch_slice_prim_impl(t795, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t797: "cuda:0 bf16[1, 32, 512, 64]"
t811 = torch_slice_prim_impl(t810, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t811: "cuda:0 bf16[1, 32, 512, 64]"
t812 = torch_slice_prim_impl(t810, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t812: "cuda:0 bf16[1, 32, 512, 64]"
[t800, t815] = nvFusion31(t795, t797, t810, t812)
# t798 = prims.convert_element_type(t797, dtypes.float32) # t798: "cuda:0 f32[1, 32, 512, 64]"
# t799 = prims.neg(t798) # t799: "cuda:0 f32[1, 32, 512, 64]"
# t800 = prims.convert_element_type(t799, dtypes.bfloat16) # t800: "cuda:0 bf16[1, 32, 512, 64]"
# t813 = prims.convert_element_type(t812, dtypes.float32) # t813: "cuda:0 f32[1, 32, 512, 64]"
# t814 = prims.neg(t813) # t814: "cuda:0 f32[1, 32, 512, 64]"
# t815 = prims.convert_element_type(t814, dtypes.bfloat16) # t815: "cuda:0 bf16[1, 32, 512, 64]"
del t797, t812
t816 = torch.cat((t815, t811), -1) # t816: "cuda:0 bf16[1, 32, 512, 128]"
# t816 = ltorch.cat((t815, t811), -1) # t816: "cuda:0 bf16[1, 32, 512, 128]"
# t816 = prims.cat((t815, t811), -1) # t816: "cuda:0 bf16[1, 32, 512, 128]"
del t815, t811
t801 = torch.cat((t800, t796), -1) # t801: "cuda:0 bf16[1, 32, 512, 128]"
# t801 = ltorch.cat((t800, t796), -1) # t801: "cuda:0 bf16[1, 32, 512, 128]"
# t801 = prims.cat((t800, t796), -1) # t801: "cuda:0 bf16[1, 32, 512, 128]"
del t800, t796
[t809, t824] = nvFusion32(t154, t157, t795, t801, t810, t816)
# t803 = prims.convert_element_type(t795, dtypes.float32) # t803: "cuda:0 f32[1, 32, 512, 128]"
# t818 = prims.convert_element_type(t810, dtypes.float32) # t818: "cuda:0 f32[1, 32, 512, 128]"
# t819 = prims.mul(t818, t154) # t819: "cuda:0 f32[1, 32, 512, 128]"
# t821 = prims.convert_element_type(t816, dtypes.float32) # t821: "cuda:0 f32[1, 32, 512, 128]"
# t822 = prims.mul(t821, t157) # t822: "cuda:0 f32[1, 32, 512, 128]"
# t823 = prims.add(t819, t822) # t823: "cuda:0 f32[1, 32, 512, 128]"
# t824 = prims.convert_element_type(t823, dtypes.bfloat16) # t824: "cuda:0 bf16[1, 32, 512, 128]"
# t804 = prims.mul(t803, t154) # t804: "cuda:0 f32[1, 32, 512, 128]"
# t806 = prims.convert_element_type(t801, dtypes.float32) # t806: "cuda:0 f32[1, 32, 512, 128]"
# t807 = prims.mul(t806, t157) # t807: "cuda:0 f32[1, 32, 512, 128]"
# t808 = prims.add(t804, t807) # t808: "cuda:0 f32[1, 32, 512, 128]"
# t809 = prims.convert_element_type(t808, dtypes.bfloat16) # t809: "cuda:0 bf16[1, 32, 512, 128]"
del t795, t801, t810, t816
t828 = torch.cat((t824, t827), -1) # t828: "cuda:0 bf16[1, 32, 512, 128]"
# t828 = ltorch.cat((t824, t827), -1) # t828: "cuda:0 bf16[1, 32, 512, 128]"
# t828 = prims.cat((t824, t827), -1) # t828: "cuda:0 bf16[1, 32, 512, 128]"
del t824, t827
t826 = torch.cat((t809, t825), -1) # t826: "cuda:0 bf16[1, 32, 512, 128]"
# t826 = ltorch.cat((t809, t825), -1) # t826: "cuda:0 bf16[1, 32, 512, 128]"
# t826 = prims.cat((t809, t825), -1) # t826: "cuda:0 bf16[1, 32, 512, 128]"
del t809, t825
(t829, t830, t831, t832, _, _, t833, t834, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t826, t828, t794, 0.0, True, scale=0.08838834764831843)
t836 = torch.permute(t829, (0, 2, 1, 3)) # t836: "cuda:0 bf16[1, 512, 32, 128]"
# t836 = ltorch.permute(t829, (0, 2, 1, 3)) # t836: "cuda:0 bf16[1, 512, 32, 128]"
# t836 = prims.transpose(t829, (0, 2, 1, 3)) # t836: "cuda:0 bf16[1, 512, 32, 128]"
t837 = torch.reshape(t836, (1, 512, 4096)) # t837: "cuda:0 bf16[1, 512, 4096]"
# t837 = ltorch.reshape(t836, (1, 512, 4096)) # t837: "cuda:0 bf16[1, 512, 4096]"
# t837 = prims.reshape(t836, (1, 512, 4096)) # t837: "cuda:0 bf16[1, 512, 4096]"
del t836
t838 = torch.nn.functional.linear(t837, t97, None) # t838: "cuda:0 bf16[1, 512, 4096]"
# t838 = ltorch.linear(t837, t97, None) # t838: "cuda:0 bf16[1, 512, 4096]"
# t838 = prims.linear(t837, t97, None) # t838: "cuda:0 bf16[1, 512, 4096]"
[t842, t849, t857] = nvFusion33(t770, t838, t853)
# t840 = prims.convert_element_type(t770, dtypes.float32) # t840: "cuda:0 f32[1, 512, 4096]"
# t839 = prims.convert_element_type(t838, dtypes.float32) # t839: "cuda:0 f32[1, 512, 4096]"
# t841 = prims.add(t839, t840) # t841: "cuda:0 f32[1, 512, 4096]"
# t842 = prims.convert_element_type(t841, dtypes.bfloat16) # t842: "cuda:0 bf16[1, 512, 4096]"
# t844 = prims.mul(t841, t841) # t844: "cuda:0 f32[1, 512, 4096]"
# t845 = prims.sum(t844, (2,)) # t845: "cuda:0 f32[1, 512]"
# t846 = prims.broadcast_in_dim(t845, [1, 512, 1], [0, 1]) # t846: "cuda:0 f32[1, 512, 1]"
# t847 = prims.div(t846, 4096.0) # t847: "cuda:0 f32[1, 512, 1]"
# t848 = prims.add(t847, 1e-05) # t848: "cuda:0 f32[1, 512, 1]"
# t849 = prims.rsqrt(t848) # t849: "cuda:0 f32[1, 512, 1]"
# t850 = prims.broadcast_in_dim(t849, (1, 512, 4096), (0, 1, 2)) # t850: "cuda:0 f32[1, 512, 4096]"
# t851 = prims.mul(t841, t850) # t851: "cuda:0 f32[1, 512, 4096]"
# t855 = prims.convert_element_type(t853, dtypes.float32) # t855: "cuda:0 f32[1, 512, 4096]"
# t856 = prims.mul(t851, t855) # t856: "cuda:0 f32[1, 512, 4096]"
# t857 = prims.convert_element_type(t856, dtypes.bfloat16) # t857: "cuda:0 bf16[1, 512, 4096]"
t858 = torch.nn.functional.linear(t857, t25, None) # t858: "cuda:0 bf16[1, 512, 11008]"
# t858 = ltorch.linear(t857, t25, None) # t858: "cuda:0 bf16[1, 512, 11008]"
# t858 = prims.linear(t857, t25, None) # t858: "cuda:0 bf16[1, 512, 11008]"
t859 = torch.nn.functional.linear(t857, t41, None) # t859: "cuda:0 bf16[1, 512, 11008]"
# t859 = ltorch.linear(t857, t41, None) # t859: "cuda:0 bf16[1, 512, 11008]"
# t859 = prims.linear(t857, t41, None) # t859: "cuda:0 bf16[1, 512, 11008]"
[t873] = nvFusion34(t858, t859)
# t860 = prims.convert_element_type(t858, dtypes.float32) # t860: "cuda:0 f32[1, 512, 11008]"
# t861 = prims.neg(t860) # t861: "cuda:0 f32[1, 512, 11008]"
# t862 = prims.exp(t861) # t862: "cuda:0 f32[1, 512, 11008]"
# t863 = prims.add(1.0, t862) # t863: "cuda:0 f32[1, 512, 11008]"
# t864 = prims.reciprocal(t863) # t864: "cuda:0 f32[1, 512, 11008]"
# t868 = prims.mul(t860, t864) # t868: "cuda:0 f32[1, 512, 11008]"
# t871 = prims.convert_element_type(t859, dtypes.float32) # t871: "cuda:0 f32[1, 512, 11008]"
# t872 = prims.mul(t868, t871) # t872: "cuda:0 f32[1, 512, 11008]"
# t873 = prims.convert_element_type(t872, dtypes.bfloat16) # t873: "cuda:0 bf16[1, 512, 11008]"
t874 = torch.nn.functional.linear(t873, t98, None) # t874: "cuda:0 bf16[1, 512, 4096]"
# t874 = ltorch.linear(t873, t98, None) # t874: "cuda:0 bf16[1, 512, 4096]"
# t874 = prims.linear(t873, t98, None) # t874: "cuda:0 bf16[1, 512, 4096]"
[t878, t885, t893] = nvFusion35(t842, t874, t889)
# t876 = prims.convert_element_type(t842, dtypes.float32) # t876: "cuda:0 f32[1, 512, 4096]"
# t875 = prims.convert_element_type(t874, dtypes.float32) # t875: "cuda:0 f32[1, 512, 4096]"
# t877 = prims.add(t875, t876) # t877: "cuda:0 f32[1, 512, 4096]"
# t878 = prims.convert_element_type(t877, dtypes.bfloat16) # t878: "cuda:0 bf16[1, 512, 4096]"
# t880 = prims.mul(t877, t877) # t880: "cuda:0 f32[1, 512, 4096]"
# t881 = prims.sum(t880, (2,)) # t881: "cuda:0 f32[1, 512]"
# t882 = prims.broadcast_in_dim(t881, [1, 512, 1], [0, 1]) # t882: "cuda:0 f32[1, 512, 1]"
# t883 = prims.div(t882, 4096.0) # t883: "cuda:0 f32[1, 512, 1]"
# t884 = prims.add(t883, 1e-05) # t884: "cuda:0 f32[1, 512, 1]"
# t885 = prims.rsqrt(t884) # t885: "cuda:0 f32[1, 512, 1]"
# t886 = prims.broadcast_in_dim(t885, (1, 512, 4096), (0, 1, 2)) # t886: "cuda:0 f32[1, 512, 4096]"
# t887 = prims.mul(t877, t886) # t887: "cuda:0 f32[1, 512, 4096]"
# t891 = prims.convert_element_type(t889, dtypes.float32) # t891: "cuda:0 f32[1, 512, 4096]"
# t892 = prims.mul(t887, t891) # t892: "cuda:0 f32[1, 512, 4096]"
# t893 = prims.convert_element_type(t892, dtypes.bfloat16) # t893: "cuda:0 bf16[1, 512, 4096]"
t894 = torch.nn.functional.linear(t893, t10, None) # t894: "cuda:0 bf16[1, 512, 12288]"
# t894 = ltorch.linear(t893, t10, None) # t894: "cuda:0 bf16[1, 512, 12288]"
# t894 = prims.linear(t893, t10, None) # t894: "cuda:0 bf16[1, 512, 12288]"
t895 = torch.reshape(t894, (1, 512, 32, 3, 128)) # t895: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t895 = ltorch.reshape(t894, (1, 512, 32, 3, 128)) # t895: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t895 = prims.reshape(t894, (1, 512, 32, 3, 128)) # t895: "cuda:0 bf16[1, 512, 32, 3, 128]"
del t894
t896 = torch.permute(t895, (0, 2, 3, 1, 4)) # t896: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t896 = ltorch.permute(t895, (0, 2, 3, 1, 4)) # t896: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t896 = prims.transpose(t895, (0, 2, 3, 1, 4)) # t896: "cuda:0 bf16[1, 32, 3, 512, 128]"
del t895
(t897, t898, t899) = torch.split(t896, (1, 1, 1), 2)
# (t897, t898, t899) = ltorch.split(t896, (1, 1, 1), 2)
# t897 = prims.slice_prim(t896, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t897: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t898 = prims.slice_prim(t896, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t898: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t899 = prims.slice_prim(t896, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t899: "cuda:0 bf16[1, 32, 1, 512, 128]"
del t896
t900 = torch.reshape(t897, (1, 32, 512, 128)) # t900: "cuda:0 bf16[1, 32, 512, 128]"
# t900 = ltorch.reshape(t897, (1, 32, 512, 128)) # t900: "cuda:0 bf16[1, 32, 512, 128]"
# t900 = prims.reshape(t897, (1, 32, 512, 128)) # t900: "cuda:0 bf16[1, 32, 512, 128]"
del t897
t901 = torch.reshape(t898, (1, 32, 512, 128)) # t901: "cuda:0 bf16[1, 32, 512, 128]"
# t901 = ltorch.reshape(t898, (1, 32, 512, 128)) # t901: "cuda:0 bf16[1, 32, 512, 128]"
# t901 = prims.reshape(t898, (1, 32, 512, 128)) # t901: "cuda:0 bf16[1, 32, 512, 128]"
del t898
t902 = torch.reshape(t899, (1, 32, 512, 128)) # t902: "cuda:0 bf16[1, 32, 512, 128]"
# t902 = ltorch.reshape(t899, (1, 32, 512, 128)) # t902: "cuda:0 bf16[1, 32, 512, 128]"
# t902 = prims.reshape(t899, (1, 32, 512, 128)) # t902: "cuda:0 bf16[1, 32, 512, 128]"
del t899
t935 = torch_slice_prim_impl(t901, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t935: "cuda:0 bf16[1, 32, 512, 0]"
t903 = torch_slice_prim_impl(t900, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t903: "cuda:0 bf16[1, 32, 512, 128]"
t918 = torch_slice_prim_impl(t901, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t918: "cuda:0 bf16[1, 32, 512, 128]"
del t901
t933 = torch_slice_prim_impl(t900, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t933: "cuda:0 bf16[1, 32, 512, 0]"
del t900
t904 = torch_slice_prim_impl(t903, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t904: "cuda:0 bf16[1, 32, 512, 64]"
t905 = torch_slice_prim_impl(t903, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t905: "cuda:0 bf16[1, 32, 512, 64]"
t919 = torch_slice_prim_impl(t918, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t919: "cuda:0 bf16[1, 32, 512, 64]"
t920 = torch_slice_prim_impl(t918, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t920: "cuda:0 bf16[1, 32, 512, 64]"
[t908, t923] = nvFusion36(t903, t905, t918, t920)
# t906 = prims.convert_element_type(t905, dtypes.float32) # t906: "cuda:0 f32[1, 32, 512, 64]"
# t907 = prims.neg(t906) # t907: "cuda:0 f32[1, 32, 512, 64]"
# t908 = prims.convert_element_type(t907, dtypes.bfloat16) # t908: "cuda:0 bf16[1, 32, 512, 64]"
# t921 = prims.convert_element_type(t920, dtypes.float32) # t921: "cuda:0 f32[1, 32, 512, 64]"
# t922 = prims.neg(t921) # t922: "cuda:0 f32[1, 32, 512, 64]"
# t923 = prims.convert_element_type(t922, dtypes.bfloat16) # t923: "cuda:0 bf16[1, 32, 512, 64]"
del t905, t920
t924 = torch.cat((t923, t919), -1) # t924: "cuda:0 bf16[1, 32, 512, 128]"
# t924 = ltorch.cat((t923, t919), -1) # t924: "cuda:0 bf16[1, 32, 512, 128]"
# t924 = prims.cat((t923, t919), -1) # t924: "cuda:0 bf16[1, 32, 512, 128]"
del t923, t919
t909 = torch.cat((t908, t904), -1) # t909: "cuda:0 bf16[1, 32, 512, 128]"
# t909 = ltorch.cat((t908, t904), -1) # t909: "cuda:0 bf16[1, 32, 512, 128]"
# t909 = prims.cat((t908, t904), -1) # t909: "cuda:0 bf16[1, 32, 512, 128]"
del t908, t904
[t917, t932] = nvFusion37(t154, t157, t903, t909, t918, t924)
# t911 = prims.convert_element_type(t903, dtypes.float32) # t911: "cuda:0 f32[1, 32, 512, 128]"
# t926 = prims.convert_element_type(t918, dtypes.float32) # t926: "cuda:0 f32[1, 32, 512, 128]"
# t927 = prims.mul(t926, t154) # t927: "cuda:0 f32[1, 32, 512, 128]"
# t929 = prims.convert_element_type(t924, dtypes.float32) # t929: "cuda:0 f32[1, 32, 512, 128]"
# t930 = prims.mul(t929, t157) # t930: "cuda:0 f32[1, 32, 512, 128]"
# t931 = prims.add(t927, t930) # t931: "cuda:0 f32[1, 32, 512, 128]"
# t932 = prims.convert_element_type(t931, dtypes.bfloat16) # t932: "cuda:0 bf16[1, 32, 512, 128]"
# t912 = prims.mul(t911, t154) # t912: "cuda:0 f32[1, 32, 512, 128]"
# t914 = prims.convert_element_type(t909, dtypes.float32) # t914: "cuda:0 f32[1, 32, 512, 128]"
# t915 = prims.mul(t914, t157) # t915: "cuda:0 f32[1, 32, 512, 128]"
# t916 = prims.add(t912, t915) # t916: "cuda:0 f32[1, 32, 512, 128]"
# t917 = prims.convert_element_type(t916, dtypes.bfloat16) # t917: "cuda:0 bf16[1, 32, 512, 128]"
del t903, t909, t918, t924
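  # t933 and t935 are zero-width slices (last dimension of size 0), so the concatenations
  # below pass the rotated q/k through unchanged; this appears to mirror applying RoPE to the
  # first rope_n_elem channels and concatenating the (here empty) remainder back on.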
t936 = torch.cat((t932, t935), -1) # t936: "cuda:0 bf16[1, 32, 512, 128]"
# t936 = ltorch.cat((t932, t935), -1) # t936: "cuda:0 bf16[1, 32, 512, 128]"
# t936 = prims.cat((t932, t935), -1) # t936: "cuda:0 bf16[1, 32, 512, 128]"
del t932, t935
t934 = torch.cat((t917, t933), -1) # t934: "cuda:0 bf16[1, 32, 512, 128]"
# t934 = ltorch.cat((t917, t933), -1) # t934: "cuda:0 bf16[1, 32, 512, 128]"
# t934 = prims.cat((t917, t933), -1) # t934: "cuda:0 bf16[1, 32, 512, 128]"
del t917, t933
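  # Scaled dot-product attention on the rotated query (t934), key (t936) and value (t902);
  # the extra outputs are intermediates saved for the backward pass of the memory-efficient
  # attention kernel.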
(t937, t938, t939, t940, _, _, t941, t942, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t934, t936, t902, 0.0, True, scale=0.08838834764831843)
t944 = torch.permute(t937, (0, 2, 1, 3)) # t944: "cuda:0 bf16[1, 512, 32, 128]"
# t944 = ltorch.permute(t937, (0, 2, 1, 3)) # t944: "cuda:0 bf16[1, 512, 32, 128]"
# t944 = prims.transpose(t937, (0, 2, 1, 3)) # t944: "cuda:0 bf16[1, 512, 32, 128]"
t945 = torch.reshape(t944, (1, 512, 4096)) # t945: "cuda:0 bf16[1, 512, 4096]"
# t945 = ltorch.reshape(t944, (1, 512, 4096)) # t945: "cuda:0 bf16[1, 512, 4096]"
# t945 = prims.reshape(t944, (1, 512, 4096)) # t945: "cuda:0 bf16[1, 512, 4096]"
del t944
t946 = torch.nn.functional.linear(t945, t99, None) # t946: "cuda:0 bf16[1, 512, 4096]"
# t946 = ltorch.linear(t945, t99, None) # t946: "cuda:0 bf16[1, 512, 4096]"
# t946 = prims.linear(t945, t99, None) # t946: "cuda:0 bf16[1, 512, 4096]"
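  # t946 is the attention output projection; nvFusion38 then fuses the residual add
  # (t878 + t946) with the RMSNorm that follows it: mean of squares, rsqrt, and scaling by
  # the norm weight (presumably t961), all computed in float32 before casting back to bfloat16.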
[t950, t957, t965] = nvFusion38(t878, t946, t961)
# t948 = prims.convert_element_type(t878, dtypes.float32) # t948: "cuda:0 f32[1, 512, 4096]"
# t947 = prims.convert_element_type(t946, dtypes.float32) # t947: "cuda:0 f32[1, 512, 4096]"
# t949 = prims.add(t947, t948) # t949: "cuda:0 f32[1, 512, 4096]"
# t950 = prims.convert_element_type(t949, dtypes.bfloat16) # t950: "cuda:0 bf16[1, 512, 4096]"
# t952 = prims.mul(t949, t949) # t952: "cuda:0 f32[1, 512, 4096]"
# t953 = prims.sum(t952, (2,)) # t953: "cuda:0 f32[1, 512]"
# t954 = prims.broadcast_in_dim(t953, [1, 512, 1], [0, 1]) # t954: "cuda:0 f32[1, 512, 1]"
# t955 = prims.div(t954, 4096.0) # t955: "cuda:0 f32[1, 512, 1]"
# t956 = prims.add(t955, 1e-05) # t956: "cuda:0 f32[1, 512, 1]"
# t957 = prims.rsqrt(t956) # t957: "cuda:0 f32[1, 512, 1]"
# t958 = prims.broadcast_in_dim(t957, (1, 512, 4096), (0, 1, 2)) # t958: "cuda:0 f32[1, 512, 4096]"
# t959 = prims.mul(t949, t958) # t959: "cuda:0 f32[1, 512, 4096]"
# t963 = prims.convert_element_type(t961, dtypes.float32) # t963: "cuda:0 f32[1, 512, 4096]"
# t964 = prims.mul(t959, t963) # t964: "cuda:0 f32[1, 512, 4096]"
# t965 = prims.convert_element_type(t964, dtypes.bfloat16) # t965: "cuda:0 bf16[1, 512, 4096]"
t967 = torch.nn.functional.linear(t965, t42, None) # t967: "cuda:0 bf16[1, 512, 11008]"
# t967 = ltorch.linear(t965, t42, None) # t967: "cuda:0 bf16[1, 512, 11008]"
# t967 = prims.linear(t965, t42, None) # t967: "cuda:0 bf16[1, 512, 11008]"
t966 = torch.nn.functional.linear(t965, t26, None) # t966: "cuda:0 bf16[1, 512, 11008]"
# t966 = ltorch.linear(t965, t26, None) # t966: "cuda:0 bf16[1, 512, 11008]"
# t966 = prims.linear(t965, t26, None) # t966: "cuda:0 bf16[1, 512, 11008]"
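  # nvFusion39 computes the SwiGLU gating of the MLP in a single kernel: sigmoid(x)*x (silu)
  # is applied to the t966 branch and multiplied by the t967 branch, matching
  # silu(x_fc_1) * x_fc_2 from the MLP definition, evaluated in float32 and cast back to bfloat16.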
[t981] = nvFusion39(t966, t967)
# t968 = prims.convert_element_type(t966, dtypes.float32) # t968: "cuda:0 f32[1, 512, 11008]"
# t969 = prims.neg(t968) # t969: "cuda:0 f32[1, 512, 11008]"
# t970 = prims.exp(t969) # t970: "cuda:0 f32[1, 512, 11008]"
# t971 = prims.add(1.0, t970) # t971: "cuda:0 f32[1, 512, 11008]"
# t972 = prims.reciprocal(t971) # t972: "cuda:0 f32[1, 512, 11008]"
# t976 = prims.mul(t968, t972) # t976: "cuda:0 f32[1, 512, 11008]"
# t979 = prims.convert_element_type(t967, dtypes.float32) # t979: "cuda:0 f32[1, 512, 11008]"
# t980 = prims.mul(t976, t979) # t980: "cuda:0 f32[1, 512, 11008]"
# t981 = prims.convert_element_type(t980, dtypes.bfloat16) # t981: "cuda:0 bf16[1, 512, 11008]"
t982 = torch.nn.functional.linear(t981, t100, None) # t982: "cuda:0 bf16[1, 512, 4096]"
# t982 = ltorch.linear(t981, t100, None) # t982: "cuda:0 bf16[1, 512, 4096]"
# t982 = prims.linear(t981, t100, None) # t982: "cuda:0 bf16[1, 512, 4096]"
[t1001, t986, t993] = nvFusion40(t950, t982, t997)
# t984 = prims.convert_element_type(t950, dtypes.float32) # t984: "cuda:0 f32[1, 512, 4096]"
# t983 = prims.convert_element_type(t982, dtypes.float32) # t983: "cuda:0 f32[1, 512, 4096]"
# t985 = prims.add(t983, t984) # t985: "cuda:0 f32[1, 512, 4096]"
# t986 = prims.convert_element_type(t985, dtypes.bfloat16) # t986: "cuda:0 bf16[1, 512, 4096]"
# t988 = prims.mul(t985, t985) # t988: "cuda:0 f32[1, 512, 4096]"
# t989 = prims.sum(t988, (2,)) # t989: "cuda:0 f32[1, 512]"
# t990 = prims.broadcast_in_dim(t989, [1, 512, 1], [0, 1]) # t990: "cuda:0 f32[1, 512, 1]"
# t991 = prims.div(t990, 4096.0) # t991: "cuda:0 f32[1, 512, 1]"
# t992 = prims.add(t991, 1e-05) # t992: "cuda:0 f32[1, 512, 1]"
# t993 = prims.rsqrt(t992) # t993: "cuda:0 f32[1, 512, 1]"
# t994 = prims.broadcast_in_dim(t993, (1, 512, 4096), (0, 1, 2)) # t994: "cuda:0 f32[1, 512, 4096]"
# t995 = prims.mul(t985, t994) # t995: "cuda:0 f32[1, 512, 4096]"
# t999 = prims.convert_element_type(t997, dtypes.float32) # t999: "cuda:0 f32[1, 512, 4096]"
# t1000 = prims.mul(t995, t999) # t1000: "cuda:0 f32[1, 512, 4096]"
# t1001 = prims.convert_element_type(t1000, dtypes.bfloat16) # t1001: "cuda:0 bf16[1, 512, 4096]"
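  # Start of the next transformer block: t11 is the packed QKV projection weight for this
  # block (output size 12288 = 3 * 4096); the reshape, permute and split below separate it
  # into per-head query, key and value tensors of shape [1, 32, 512, 128].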
t1002 = torch.nn.functional.linear(t1001, t11, None) # t1002: "cuda:0 bf16[1, 512, 12288]"
# t1002 = ltorch.linear(t1001, t11, None) # t1002: "cuda:0 bf16[1, 512, 12288]"
# t1002 = prims.linear(t1001, t11, None) # t1002: "cuda:0 bf16[1, 512, 12288]"
t1003 = torch.reshape(t1002, (1, 512, 32, 3, 128)) # t1003: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t1003 = ltorch.reshape(t1002, (1, 512, 32, 3, 128)) # t1003: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t1003 = prims.reshape(t1002, (1, 512, 32, 3, 128)) # t1003: "cuda:0 bf16[1, 512, 32, 3, 128]"
del t1002
t1004 = torch.permute(t1003, (0, 2, 3, 1, 4)) # t1004: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t1004 = ltorch.permute(t1003, (0, 2, 3, 1, 4)) # t1004: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t1004 = prims.transpose(t1003, (0, 2, 3, 1, 4)) # t1004: "cuda:0 bf16[1, 32, 3, 512, 128]"
del t1003
(t1005, t1006, t1007) = torch.split(t1004, (1, 1, 1), 2)
# (t1005, t1006, t1007) = ltorch.split(t1004, (1, 1, 1), 2)
# t1005 = prims.slice_prim(t1004, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1005: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t1006 = prims.slice_prim(t1004, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1006: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t1007 = prims.slice_prim(t1004, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1007: "cuda:0 bf16[1, 32, 1, 512, 128]"
del t1004
t1008 = torch.reshape(t1005, (1, 32, 512, 128)) # t1008: "cuda:0 bf16[1, 32, 512, 128]"
# t1008 = ltorch.reshape(t1005, (1, 32, 512, 128)) # t1008: "cuda:0 bf16[1, 32, 512, 128]"
# t1008 = prims.reshape(t1005, (1, 32, 512, 128)) # t1008: "cuda:0 bf16[1, 32, 512, 128]"
del t1005
t1009 = torch.reshape(t1006, (1, 32, 512, 128)) # t1009: "cuda:0 bf16[1, 32, 512, 128]"
# t1009 = ltorch.reshape(t1006, (1, 32, 512, 128)) # t1009: "cuda:0 bf16[1, 32, 512, 128]"
# t1009 = prims.reshape(t1006, (1, 32, 512, 128)) # t1009: "cuda:0 bf16[1, 32, 512, 128]"
del t1006
t1010 = torch.reshape(t1007, (1, 32, 512, 128)) # t1010: "cuda:0 bf16[1, 32, 512, 128]"
# t1010 = ltorch.reshape(t1007, (1, 32, 512, 128)) # t1010: "cuda:0 bf16[1, 32, 512, 128]"
# t1010 = prims.reshape(t1007, (1, 32, 512, 128)) # t1010: "cuda:0 bf16[1, 32, 512, 128]"
del t1007
t1026 = torch_slice_prim_impl(t1009, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1026: "cuda:0 bf16[1, 32, 512, 128]"
t1041 = torch_slice_prim_impl(t1008, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1041: "cuda:0 bf16[1, 32, 512, 0]"
t1043 = torch_slice_prim_impl(t1009, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1043: "cuda:0 bf16[1, 32, 512, 0]"
del t1009
t1011 = torch_slice_prim_impl(t1008, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1011: "cuda:0 bf16[1, 32, 512, 128]"
del t1008
t1027 = torch_slice_prim_impl(t1026, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1027: "cuda:0 bf16[1, 32, 512, 64]"
t1028 = torch_slice_prim_impl(t1026, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1028: "cuda:0 bf16[1, 32, 512, 64]"
t1013 = torch_slice_prim_impl(t1011, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1013: "cuda:0 bf16[1, 32, 512, 64]"
t1012 = torch_slice_prim_impl(t1011, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1012: "cuda:0 bf16[1, 32, 512, 64]"
[t1016, t1031] = nvFusion41(t1011, t1013, t1026, t1028)
# t1014 = prims.convert_element_type(t1013, dtypes.float32) # t1014: "cuda:0 f32[1, 32, 512, 64]"
# t1015 = prims.neg(t1014) # t1015: "cuda:0 f32[1, 32, 512, 64]"
# t1016 = prims.convert_element_type(t1015, dtypes.bfloat16) # t1016: "cuda:0 bf16[1, 32, 512, 64]"
# t1029 = prims.convert_element_type(t1028, dtypes.float32) # t1029: "cuda:0 f32[1, 32, 512, 64]"
# t1030 = prims.neg(t1029) # t1030: "cuda:0 f32[1, 32, 512, 64]"
# t1031 = prims.convert_element_type(t1030, dtypes.bfloat16) # t1031: "cuda:0 bf16[1, 32, 512, 64]"
del t1013, t1028
t1032 = torch.cat((t1031, t1027), -1) # t1032: "cuda:0 bf16[1, 32, 512, 128]"
# t1032 = ltorch.cat((t1031, t1027), -1) # t1032: "cuda:0 bf16[1, 32, 512, 128]"
# t1032 = prims.cat((t1031, t1027), -1) # t1032: "cuda:0 bf16[1, 32, 512, 128]"
del t1031, t1027
t1017 = torch.cat((t1016, t1012), -1) # t1017: "cuda:0 bf16[1, 32, 512, 128]"
# t1017 = ltorch.cat((t1016, t1012), -1) # t1017: "cuda:0 bf16[1, 32, 512, 128]"
# t1017 = prims.cat((t1016, t1012), -1) # t1017: "cuda:0 bf16[1, 32, 512, 128]"
del t1016, t1012
[t1025, t1040] = nvFusion42(t1011, t1017, t1026, t1032, t154, t157)
# t1019 = prims.convert_element_type(t1011, dtypes.float32) # t1019: "cuda:0 f32[1, 32, 512, 128]"
# t1034 = prims.convert_element_type(t1026, dtypes.float32) # t1034: "cuda:0 f32[1, 32, 512, 128]"
# t1020 = prims.mul(t1019, t154) # t1020: "cuda:0 f32[1, 32, 512, 128]"
# t1022 = prims.convert_element_type(t1017, dtypes.float32) # t1022: "cuda:0 f32[1, 32, 512, 128]"
# t1023 = prims.mul(t1022, t157) # t1023: "cuda:0 f32[1, 32, 512, 128]"
# t1024 = prims.add(t1020, t1023) # t1024: "cuda:0 f32[1, 32, 512, 128]"
# t1025 = prims.convert_element_type(t1024, dtypes.bfloat16) # t1025: "cuda:0 bf16[1, 32, 512, 128]"
# t1035 = prims.mul(t1034, t154) # t1035: "cuda:0 f32[1, 32, 512, 128]"
# t1037 = prims.convert_element_type(t1032, dtypes.float32) # t1037: "cuda:0 f32[1, 32, 512, 128]"
# t1038 = prims.mul(t1037, t157) # t1038: "cuda:0 f32[1, 32, 512, 128]"
# t1039 = prims.add(t1035, t1038) # t1039: "cuda:0 f32[1, 32, 512, 128]"
# t1040 = prims.convert_element_type(t1039, dtypes.bfloat16) # t1040: "cuda:0 bf16[1, 32, 512, 128]"
del t1011, t1017, t1026, t1032
t1042 = torch.cat((t1025, t1041), -1) # t1042: "cuda:0 bf16[1, 32, 512, 128]"
# t1042 = ltorch.cat((t1025, t1041), -1) # t1042: "cuda:0 bf16[1, 32, 512, 128]"
# t1042 = prims.cat((t1025, t1041), -1) # t1042: "cuda:0 bf16[1, 32, 512, 128]"
del t1025, t1041
t1044 = torch.cat((t1040, t1043), -1) # t1044: "cuda:0 bf16[1, 32, 512, 128]"
# t1044 = ltorch.cat((t1040, t1043), -1) # t1044: "cuda:0 bf16[1, 32, 512, 128]"
# t1044 = prims.cat((t1040, t1043), -1) # t1044: "cuda:0 bf16[1, 32, 512, 128]"
del t1040, t1043
(t1045, t1046, t1047, t1048, _, _, t1049, t1050, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1042, t1044, t1010, 0.0, True, scale=0.08838834764831843)
t1052 = torch.permute(t1045, (0, 2, 1, 3)) # t1052: "cuda:0 bf16[1, 512, 32, 128]"
# t1052 = ltorch.permute(t1045, (0, 2, 1, 3)) # t1052: "cuda:0 bf16[1, 512, 32, 128]"
# t1052 = prims.transpose(t1045, (0, 2, 1, 3)) # t1052: "cuda:0 bf16[1, 512, 32, 128]"
t1053 = torch.reshape(t1052, (1, 512, 4096)) # t1053: "cuda:0 bf16[1, 512, 4096]"
# t1053 = ltorch.reshape(t1052, (1, 512, 4096)) # t1053: "cuda:0 bf16[1, 512, 4096]"
# t1053 = prims.reshape(t1052, (1, 512, 4096)) # t1053: "cuda:0 bf16[1, 512, 4096]"
del t1052
t1054 = torch.nn.functional.linear(t1053, t101, None) # t1054: "cuda:0 bf16[1, 512, 4096]"
# t1054 = ltorch.linear(t1053, t101, None) # t1054: "cuda:0 bf16[1, 512, 4096]"
# t1054 = prims.linear(t1053, t101, None) # t1054: "cuda:0 bf16[1, 512, 4096]"
[t1058, t1065, t1073] = nvFusion43(t1054, t1069, t986)
# t1056 = prims.convert_element_type(t986, dtypes.float32) # t1056: "cuda:0 f32[1, 512, 4096]"
# t1055 = prims.convert_element_type(t1054, dtypes.float32) # t1055: "cuda:0 f32[1, 512, 4096]"
# t1057 = prims.add(t1055, t1056) # t1057: "cuda:0 f32[1, 512, 4096]"
# t1058 = prims.convert_element_type(t1057, dtypes.bfloat16) # t1058: "cuda:0 bf16[1, 512, 4096]"
# t1060 = prims.mul(t1057, t1057) # t1060: "cuda:0 f32[1, 512, 4096]"
# t1061 = prims.sum(t1060, (2,)) # t1061: "cuda:0 f32[1, 512]"
# t1062 = prims.broadcast_in_dim(t1061, [1, 512, 1], [0, 1]) # t1062: "cuda:0 f32[1, 512, 1]"
# t1063 = prims.div(t1062, 4096.0) # t1063: "cuda:0 f32[1, 512, 1]"
# t1064 = prims.add(t1063, 1e-05) # t1064: "cuda:0 f32[1, 512, 1]"
# t1065 = prims.rsqrt(t1064) # t1065: "cuda:0 f32[1, 512, 1]"
# t1066 = prims.broadcast_in_dim(t1065, (1, 512, 4096), (0, 1, 2)) # t1066: "cuda:0 f32[1, 512, 4096]"
# t1067 = prims.mul(t1057, t1066) # t1067: "cuda:0 f32[1, 512, 4096]"
# t1071 = prims.convert_element_type(t1069, dtypes.float32) # t1071: "cuda:0 f32[1, 512, 4096]"
# t1072 = prims.mul(t1067, t1071) # t1072: "cuda:0 f32[1, 512, 4096]"
# t1073 = prims.convert_element_type(t1072, dtypes.bfloat16) # t1073: "cuda:0 bf16[1, 512, 4096]"
t1074 = torch.nn.functional.linear(t1073, t27, None) # t1074: "cuda:0 bf16[1, 512, 11008]"
# t1074 = ltorch.linear(t1073, t27, None) # t1074: "cuda:0 bf16[1, 512, 11008]"
# t1074 = prims.linear(t1073, t27, None) # t1074: "cuda:0 bf16[1, 512, 11008]"
t1075 = torch.nn.functional.linear(t1073, t43, None) # t1075: "cuda:0 bf16[1, 512, 11008]"
# t1075 = ltorch.linear(t1073, t43, None) # t1075: "cuda:0 bf16[1, 512, 11008]"
# t1075 = prims.linear(t1073, t43, None) # t1075: "cuda:0 bf16[1, 512, 11008]"
[t1089] = nvFusion44(t1074, t1075)
# t1076 = prims.convert_element_type(t1074, dtypes.float32) # t1076: "cuda:0 f32[1, 512, 11008]"
# t1077 = prims.neg(t1076) # t1077: "cuda:0 f32[1, 512, 11008]"
# t1078 = prims.exp(t1077) # t1078: "cuda:0 f32[1, 512, 11008]"
# t1079 = prims.add(1.0, t1078) # t1079: "cuda:0 f32[1, 512, 11008]"
# t1080 = prims.reciprocal(t1079) # t1080: "cuda:0 f32[1, 512, 11008]"
# t1084 = prims.mul(t1076, t1080) # t1084: "cuda:0 f32[1, 512, 11008]"
# t1087 = prims.convert_element_type(t1075, dtypes.float32) # t1087: "cuda:0 f32[1, 512, 11008]"
# t1088 = prims.mul(t1084, t1087) # t1088: "cuda:0 f32[1, 512, 11008]"
# t1089 = prims.convert_element_type(t1088, dtypes.bfloat16) # t1089: "cuda:0 bf16[1, 512, 11008]"
t1090 = torch.nn.functional.linear(t1089, t102, None) # t1090: "cuda:0 bf16[1, 512, 4096]"
# t1090 = ltorch.linear(t1089, t102, None) # t1090: "cuda:0 bf16[1, 512, 4096]"
# t1090 = prims.linear(t1089, t102, None) # t1090: "cuda:0 bf16[1, 512, 4096]"
[t1094, t1101, t1109] = nvFusion45(t1058, t1090, t1105)
# t1092 = prims.convert_element_type(t1058, dtypes.float32) # t1092: "cuda:0 f32[1, 512, 4096]"
# t1091 = prims.convert_element_type(t1090, dtypes.float32) # t1091: "cuda:0 f32[1, 512, 4096]"
# t1093 = prims.add(t1091, t1092) # t1093: "cuda:0 f32[1, 512, 4096]"
# t1094 = prims.convert_element_type(t1093, dtypes.bfloat16) # t1094: "cuda:0 bf16[1, 512, 4096]"
# t1096 = prims.mul(t1093, t1093) # t1096: "cuda:0 f32[1, 512, 4096]"
# t1097 = prims.sum(t1096, (2,)) # t1097: "cuda:0 f32[1, 512]"
# t1098 = prims.broadcast_in_dim(t1097, [1, 512, 1], [0, 1]) # t1098: "cuda:0 f32[1, 512, 1]"
# t1099 = prims.div(t1098, 4096.0) # t1099: "cuda:0 f32[1, 512, 1]"
# t1100 = prims.add(t1099, 1e-05) # t1100: "cuda:0 f32[1, 512, 1]"
# t1101 = prims.rsqrt(t1100) # t1101: "cuda:0 f32[1, 512, 1]"
# t1102 = prims.broadcast_in_dim(t1101, (1, 512, 4096), (0, 1, 2)) # t1102: "cuda:0 f32[1, 512, 4096]"
# t1103 = prims.mul(t1093, t1102) # t1103: "cuda:0 f32[1, 512, 4096]"
# t1107 = prims.convert_element_type(t1105, dtypes.float32) # t1107: "cuda:0 f32[1, 512, 4096]"
# t1108 = prims.mul(t1103, t1107) # t1108: "cuda:0 f32[1, 512, 4096]"
# t1109 = prims.convert_element_type(t1108, dtypes.bfloat16) # t1109: "cuda:0 bf16[1, 512, 4096]"
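  # From here on the trace repeats the same pattern for each remaining transformer block:
  # QKV projection and split, RoPE, fused attention, output projection, fused
  # residual-add + RMSNorm, and the SwiGLU MLP, each time with fresh tensor and weight names.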
t1110 = torch.nn.functional.linear(t1109, t12, None) # t1110: "cuda:0 bf16[1, 512, 12288]"
# t1110 = ltorch.linear(t1109, t12, None) # t1110: "cuda:0 bf16[1, 512, 12288]"
# t1110 = prims.linear(t1109, t12, None) # t1110: "cuda:0 bf16[1, 512, 12288]"
t1111 = torch.reshape(t1110, (1, 512, 32, 3, 128)) # t1111: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t1111 = ltorch.reshape(t1110, (1, 512, 32, 3, 128)) # t1111: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t1111 = prims.reshape(t1110, (1, 512, 32, 3, 128)) # t1111: "cuda:0 bf16[1, 512, 32, 3, 128]"
del t1110
t1112 = torch.permute(t1111, (0, 2, 3, 1, 4)) # t1112: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t1112 = ltorch.permute(t1111, (0, 2, 3, 1, 4)) # t1112: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t1112 = prims.transpose(t1111, (0, 2, 3, 1, 4)) # t1112: "cuda:0 bf16[1, 32, 3, 512, 128]"
del t1111
(t1113, t1114, t1115) = torch.split(t1112, (1, 1, 1), 2)
# (t1113, t1114, t1115) = ltorch.split(t1112, (1, 1, 1), 2)
# t1113 = prims.slice_prim(t1112, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1113: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t1114 = prims.slice_prim(t1112, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1114: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t1115 = prims.slice_prim(t1112, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1115: "cuda:0 bf16[1, 32, 1, 512, 128]"
del t1112
t1116 = torch.reshape(t1113, (1, 32, 512, 128)) # t1116: "cuda:0 bf16[1, 32, 512, 128]"
# t1116 = ltorch.reshape(t1113, (1, 32, 512, 128)) # t1116: "cuda:0 bf16[1, 32, 512, 128]"
# t1116 = prims.reshape(t1113, (1, 32, 512, 128)) # t1116: "cuda:0 bf16[1, 32, 512, 128]"
del t1113
t1117 = torch.reshape(t1114, (1, 32, 512, 128)) # t1117: "cuda:0 bf16[1, 32, 512, 128]"
# t1117 = ltorch.reshape(t1114, (1, 32, 512, 128)) # t1117: "cuda:0 bf16[1, 32, 512, 128]"
# t1117 = prims.reshape(t1114, (1, 32, 512, 128)) # t1117: "cuda:0 bf16[1, 32, 512, 128]"
del t1114
t1118 = torch.reshape(t1115, (1, 32, 512, 128)) # t1118: "cuda:0 bf16[1, 32, 512, 128]"
# t1118 = ltorch.reshape(t1115, (1, 32, 512, 128)) # t1118: "cuda:0 bf16[1, 32, 512, 128]"
# t1118 = prims.reshape(t1115, (1, 32, 512, 128)) # t1118: "cuda:0 bf16[1, 32, 512, 128]"
del t1115
t1119 = torch_slice_prim_impl(t1116, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1119: "cuda:0 bf16[1, 32, 512, 128]"
t1134 = torch_slice_prim_impl(t1117, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1134: "cuda:0 bf16[1, 32, 512, 128]"
t1149 = torch_slice_prim_impl(t1116, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1149: "cuda:0 bf16[1, 32, 512, 0]"
del t1116
t1151 = torch_slice_prim_impl(t1117, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1151: "cuda:0 bf16[1, 32, 512, 0]"
del t1117
t1120 = torch_slice_prim_impl(t1119, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1120: "cuda:0 bf16[1, 32, 512, 64]"
t1121 = torch_slice_prim_impl(t1119, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1121: "cuda:0 bf16[1, 32, 512, 64]"
t1136 = torch_slice_prim_impl(t1134, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1136: "cuda:0 bf16[1, 32, 512, 64]"
t1135 = torch_slice_prim_impl(t1134, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1135: "cuda:0 bf16[1, 32, 512, 64]"
[t1124, t1139] = nvFusion46(t1119, t1121, t1134, t1136)
# t1122 = prims.convert_element_type(t1121, dtypes.float32) # t1122: "cuda:0 f32[1, 32, 512, 64]"
# t1123 = prims.neg(t1122) # t1123: "cuda:0 f32[1, 32, 512, 64]"
# t1124 = prims.convert_element_type(t1123, dtypes.bfloat16) # t1124: "cuda:0 bf16[1, 32, 512, 64]"
# t1137 = prims.convert_element_type(t1136, dtypes.float32) # t1137: "cuda:0 f32[1, 32, 512, 64]"
# t1138 = prims.neg(t1137) # t1138: "cuda:0 f32[1, 32, 512, 64]"
# t1139 = prims.convert_element_type(t1138, dtypes.bfloat16) # t1139: "cuda:0 bf16[1, 32, 512, 64]"
del t1121, t1136
t1125 = torch.cat((t1124, t1120), -1) # t1125: "cuda:0 bf16[1, 32, 512, 128]"
# t1125 = ltorch.cat((t1124, t1120), -1) # t1125: "cuda:0 bf16[1, 32, 512, 128]"
# t1125 = prims.cat((t1124, t1120), -1) # t1125: "cuda:0 bf16[1, 32, 512, 128]"
del t1124, t1120
t1140 = torch.cat((t1139, t1135), -1) # t1140: "cuda:0 bf16[1, 32, 512, 128]"
# t1140 = ltorch.cat((t1139, t1135), -1) # t1140: "cuda:0 bf16[1, 32, 512, 128]"
# t1140 = prims.cat((t1139, t1135), -1) # t1140: "cuda:0 bf16[1, 32, 512, 128]"
del t1139, t1135
[t1133, t1148] = nvFusion47(t1119, t1125, t1134, t1140, t154, t157)
# t1127 = prims.convert_element_type(t1119, dtypes.float32) # t1127: "cuda:0 f32[1, 32, 512, 128]"
# t1142 = prims.convert_element_type(t1134, dtypes.float32) # t1142: "cuda:0 f32[1, 32, 512, 128]"
# t1128 = prims.mul(t1127, t154) # t1128: "cuda:0 f32[1, 32, 512, 128]"
# t1130 = prims.convert_element_type(t1125, dtypes.float32) # t1130: "cuda:0 f32[1, 32, 512, 128]"
# t1131 = prims.mul(t1130, t157) # t1131: "cuda:0 f32[1, 32, 512, 128]"
# t1132 = prims.add(t1128, t1131) # t1132: "cuda:0 f32[1, 32, 512, 128]"
# t1133 = prims.convert_element_type(t1132, dtypes.bfloat16) # t1133: "cuda:0 bf16[1, 32, 512, 128]"
# t1143 = prims.mul(t1142, t154) # t1143: "cuda:0 f32[1, 32, 512, 128]"
# t1145 = prims.convert_element_type(t1140, dtypes.float32) # t1145: "cuda:0 f32[1, 32, 512, 128]"
# t1146 = prims.mul(t1145, t157) # t1146: "cuda:0 f32[1, 32, 512, 128]"
# t1147 = prims.add(t1143, t1146) # t1147: "cuda:0 f32[1, 32, 512, 128]"
# t1148 = prims.convert_element_type(t1147, dtypes.bfloat16) # t1148: "cuda:0 bf16[1, 32, 512, 128]"
del t1119, t1125, t1134, t1140
t1152 = torch.cat((t1148, t1151), -1) # t1152: "cuda:0 bf16[1, 32, 512, 128]"
# t1152 = ltorch.cat((t1148, t1151), -1) # t1152: "cuda:0 bf16[1, 32, 512, 128]"
# t1152 = prims.cat((t1148, t1151), -1) # t1152: "cuda:0 bf16[1, 32, 512, 128]"
del t1148, t1151
t1150 = torch.cat((t1133, t1149), -1) # t1150: "cuda:0 bf16[1, 32, 512, 128]"
# t1150 = ltorch.cat((t1133, t1149), -1) # t1150: "cuda:0 bf16[1, 32, 512, 128]"
# t1150 = prims.cat((t1133, t1149), -1) # t1150: "cuda:0 bf16[1, 32, 512, 128]"
del t1133, t1149
(t1153, t1154, t1155, t1156, _, _, t1157, t1158, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1150, t1152, t1118, 0.0, True, scale=0.08838834764831843)
t1160 = torch.permute(t1153, (0, 2, 1, 3)) # t1160: "cuda:0 bf16[1, 512, 32, 128]"
# t1160 = ltorch.permute(t1153, (0, 2, 1, 3)) # t1160: "cuda:0 bf16[1, 512, 32, 128]"
# t1160 = prims.transpose(t1153, (0, 2, 1, 3)) # t1160: "cuda:0 bf16[1, 512, 32, 128]"
t1161 = torch.reshape(t1160, (1, 512, 4096)) # t1161: "cuda:0 bf16[1, 512, 4096]"
# t1161 = ltorch.reshape(t1160, (1, 512, 4096)) # t1161: "cuda:0 bf16[1, 512, 4096]"
# t1161 = prims.reshape(t1160, (1, 512, 4096)) # t1161: "cuda:0 bf16[1, 512, 4096]"
del t1160
t1162 = torch.nn.functional.linear(t1161, t103, None) # t1162: "cuda:0 bf16[1, 512, 4096]"
# t1162 = ltorch.linear(t1161, t103, None) # t1162: "cuda:0 bf16[1, 512, 4096]"
# t1162 = prims.linear(t1161, t103, None) # t1162: "cuda:0 bf16[1, 512, 4096]"
[t1166, t1173, t1181] = nvFusion48(t1094, t1162, t1177)
# t1164 = prims.convert_element_type(t1094, dtypes.float32) # t1164: "cuda:0 f32[1, 512, 4096]"
# t1163 = prims.convert_element_type(t1162, dtypes.float32) # t1163: "cuda:0 f32[1, 512, 4096]"
# t1165 = prims.add(t1163, t1164) # t1165: "cuda:0 f32[1, 512, 4096]"
# t1166 = prims.convert_element_type(t1165, dtypes.bfloat16) # t1166: "cuda:0 bf16[1, 512, 4096]"
# t1168 = prims.mul(t1165, t1165) # t1168: "cuda:0 f32[1, 512, 4096]"
# t1169 = prims.sum(t1168, (2,)) # t1169: "cuda:0 f32[1, 512]"
# t1170 = prims.broadcast_in_dim(t1169, [1, 512, 1], [0, 1]) # t1170: "cuda:0 f32[1, 512, 1]"
# t1171 = prims.div(t1170, 4096.0) # t1171: "cuda:0 f32[1, 512, 1]"
# t1172 = prims.add(t1171, 1e-05) # t1172: "cuda:0 f32[1, 512, 1]"
# t1173 = prims.rsqrt(t1172) # t1173: "cuda:0 f32[1, 512, 1]"
# t1174 = prims.broadcast_in_dim(t1173, (1, 512, 4096), (0, 1, 2)) # t1174: "cuda:0 f32[1, 512, 4096]"
# t1175 = prims.mul(t1165, t1174) # t1175: "cuda:0 f32[1, 512, 4096]"
# t1179 = prims.convert_element_type(t1177, dtypes.float32) # t1179: "cuda:0 f32[1, 512, 4096]"
# t1180 = prims.mul(t1175, t1179) # t1180: "cuda:0 f32[1, 512, 4096]"
# t1181 = prims.convert_element_type(t1180, dtypes.bfloat16) # t1181: "cuda:0 bf16[1, 512, 4096]"
t1182 = torch.nn.functional.linear(t1181, t28, None) # t1182: "cuda:0 bf16[1, 512, 11008]"
# t1182 = ltorch.linear(t1181, t28, None) # t1182: "cuda:0 bf16[1, 512, 11008]"
# t1182 = prims.linear(t1181, t28, None) # t1182: "cuda:0 bf16[1, 512, 11008]"
t1183 = torch.nn.functional.linear(t1181, t44, None) # t1183: "cuda:0 bf16[1, 512, 11008]"
# t1183 = ltorch.linear(t1181, t44, None) # t1183: "cuda:0 bf16[1, 512, 11008]"
# t1183 = prims.linear(t1181, t44, None) # t1183: "cuda:0 bf16[1, 512, 11008]"
[t1197] = nvFusion49(t1182, t1183)
# t1184 = prims.convert_element_type(t1182, dtypes.float32) # t1184: "cuda:0 f32[1, 512, 11008]"
# t1185 = prims.neg(t1184) # t1185: "cuda:0 f32[1, 512, 11008]"
# t1186 = prims.exp(t1185) # t1186: "cuda:0 f32[1, 512, 11008]"
# t1187 = prims.add(1.0, t1186) # t1187: "cuda:0 f32[1, 512, 11008]"
# t1188 = prims.reciprocal(t1187) # t1188: "cuda:0 f32[1, 512, 11008]"
# t1192 = prims.mul(t1184, t1188) # t1192: "cuda:0 f32[1, 512, 11008]"
# t1195 = prims.convert_element_type(t1183, dtypes.float32) # t1195: "cuda:0 f32[1, 512, 11008]"
# t1196 = prims.mul(t1192, t1195) # t1196: "cuda:0 f32[1, 512, 11008]"
# t1197 = prims.convert_element_type(t1196, dtypes.bfloat16) # t1197: "cuda:0 bf16[1, 512, 11008]"
t1198 = torch.nn.functional.linear(t1197, t104, None) # t1198: "cuda:0 bf16[1, 512, 4096]"
# t1198 = ltorch.linear(t1197, t104, None) # t1198: "cuda:0 bf16[1, 512, 4096]"
# t1198 = prims.linear(t1197, t104, None) # t1198: "cuda:0 bf16[1, 512, 4096]"
[t1202, t1209, t1217] = nvFusion50(t1166, t1198, t1213)
# t1200 = prims.convert_element_type(t1166, dtypes.float32) # t1200: "cuda:0 f32[1, 512, 4096]"
# t1199 = prims.convert_element_type(t1198, dtypes.float32) # t1199: "cuda:0 f32[1, 512, 4096]"
# t1201 = prims.add(t1199, t1200) # t1201: "cuda:0 f32[1, 512, 4096]"
# t1202 = prims.convert_element_type(t1201, dtypes.bfloat16) # t1202: "cuda:0 bf16[1, 512, 4096]"
# t1204 = prims.mul(t1201, t1201) # t1204: "cuda:0 f32[1, 512, 4096]"
# t1205 = prims.sum(t1204, (2,)) # t1205: "cuda:0 f32[1, 512]"
# t1206 = prims.broadcast_in_dim(t1205, [1, 512, 1], [0, 1]) # t1206: "cuda:0 f32[1, 512, 1]"
# t1207 = prims.div(t1206, 4096.0) # t1207: "cuda:0 f32[1, 512, 1]"
# t1208 = prims.add(t1207, 1e-05) # t1208: "cuda:0 f32[1, 512, 1]"
# t1209 = prims.rsqrt(t1208) # t1209: "cuda:0 f32[1, 512, 1]"
# t1210 = prims.broadcast_in_dim(t1209, (1, 512, 4096), (0, 1, 2)) # t1210: "cuda:0 f32[1, 512, 4096]"
# t1211 = prims.mul(t1201, t1210) # t1211: "cuda:0 f32[1, 512, 4096]"
# t1215 = prims.convert_element_type(t1213, dtypes.float32) # t1215: "cuda:0 f32[1, 512, 4096]"
# t1216 = prims.mul(t1211, t1215) # t1216: "cuda:0 f32[1, 512, 4096]"
# t1217 = prims.convert_element_type(t1216, dtypes.bfloat16) # t1217: "cuda:0 bf16[1, 512, 4096]"
t1218 = torch.nn.functional.linear(t1217, t13, None) # t1218: "cuda:0 bf16[1, 512, 12288]"
# t1218 = ltorch.linear(t1217, t13, None) # t1218: "cuda:0 bf16[1, 512, 12288]"
# t1218 = prims.linear(t1217, t13, None) # t1218: "cuda:0 bf16[1, 512, 12288]"
t1219 = torch.reshape(t1218, (1, 512, 32, 3, 128)) # t1219: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t1219 = ltorch.reshape(t1218, (1, 512, 32, 3, 128)) # t1219: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t1219 = prims.reshape(t1218, (1, 512, 32, 3, 128)) # t1219: "cuda:0 bf16[1, 512, 32, 3, 128]"
del t1218
t1220 = torch.permute(t1219, (0, 2, 3, 1, 4)) # t1220: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t1220 = ltorch.permute(t1219, (0, 2, 3, 1, 4)) # t1220: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t1220 = prims.transpose(t1219, (0, 2, 3, 1, 4)) # t1220: "cuda:0 bf16[1, 32, 3, 512, 128]"
del t1219
(t1221, t1222, t1223) = torch.split(t1220, (1, 1, 1), 2)
# (t1221, t1222, t1223) = ltorch.split(t1220, (1, 1, 1), 2)
# t1221 = prims.slice_prim(t1220, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1221: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t1222 = prims.slice_prim(t1220, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1222: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t1223 = prims.slice_prim(t1220, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1223: "cuda:0 bf16[1, 32, 1, 512, 128]"
del t1220
t1224 = torch.reshape(t1221, (1, 32, 512, 128)) # t1224: "cuda:0 bf16[1, 32, 512, 128]"
# t1224 = ltorch.reshape(t1221, (1, 32, 512, 128)) # t1224: "cuda:0 bf16[1, 32, 512, 128]"
# t1224 = prims.reshape(t1221, (1, 32, 512, 128)) # t1224: "cuda:0 bf16[1, 32, 512, 128]"
del t1221
t1225 = torch.reshape(t1222, (1, 32, 512, 128)) # t1225: "cuda:0 bf16[1, 32, 512, 128]"
# t1225 = ltorch.reshape(t1222, (1, 32, 512, 128)) # t1225: "cuda:0 bf16[1, 32, 512, 128]"
# t1225 = prims.reshape(t1222, (1, 32, 512, 128)) # t1225: "cuda:0 bf16[1, 32, 512, 128]"
del t1222
t1226 = torch.reshape(t1223, (1, 32, 512, 128)) # t1226: "cuda:0 bf16[1, 32, 512, 128]"
# t1226 = ltorch.reshape(t1223, (1, 32, 512, 128)) # t1226: "cuda:0 bf16[1, 32, 512, 128]"
# t1226 = prims.reshape(t1223, (1, 32, 512, 128)) # t1226: "cuda:0 bf16[1, 32, 512, 128]"
del t1223
t1227 = torch_slice_prim_impl(t1224, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1227: "cuda:0 bf16[1, 32, 512, 128]"
t1242 = torch_slice_prim_impl(t1225, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1242: "cuda:0 bf16[1, 32, 512, 128]"
t1257 = torch_slice_prim_impl(t1224, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1257: "cuda:0 bf16[1, 32, 512, 0]"
del t1224
t1259 = torch_slice_prim_impl(t1225, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1259: "cuda:0 bf16[1, 32, 512, 0]"
del t1225
t1228 = torch_slice_prim_impl(t1227, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1228: "cuda:0 bf16[1, 32, 512, 64]"
t1229 = torch_slice_prim_impl(t1227, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1229: "cuda:0 bf16[1, 32, 512, 64]"
t1243 = torch_slice_prim_impl(t1242, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1243: "cuda:0 bf16[1, 32, 512, 64]"
t1244 = torch_slice_prim_impl(t1242, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1244: "cuda:0 bf16[1, 32, 512, 64]"
[t1232, t1247] = nvFusion51(t1227, t1229, t1242, t1244)
# t1230 = prims.convert_element_type(t1229, dtypes.float32) # t1230: "cuda:0 f32[1, 32, 512, 64]"
# t1231 = prims.neg(t1230) # t1231: "cuda:0 f32[1, 32, 512, 64]"
# t1232 = prims.convert_element_type(t1231, dtypes.bfloat16) # t1232: "cuda:0 bf16[1, 32, 512, 64]"
# t1245 = prims.convert_element_type(t1244, dtypes.float32) # t1245: "cuda:0 f32[1, 32, 512, 64]"
# t1246 = prims.neg(t1245) # t1246: "cuda:0 f32[1, 32, 512, 64]"
# t1247 = prims.convert_element_type(t1246, dtypes.bfloat16) # t1247: "cuda:0 bf16[1, 32, 512, 64]"
del t1229, t1244
t1233 = torch.cat((t1232, t1228), -1) # t1233: "cuda:0 bf16[1, 32, 512, 128]"
# t1233 = ltorch.cat((t1232, t1228), -1) # t1233: "cuda:0 bf16[1, 32, 512, 128]"
# t1233 = prims.cat((t1232, t1228), -1) # t1233: "cuda:0 bf16[1, 32, 512, 128]"
del t1232, t1228
t1248 = torch.cat((t1247, t1243), -1) # t1248: "cuda:0 bf16[1, 32, 512, 128]"
# t1248 = ltorch.cat((t1247, t1243), -1) # t1248: "cuda:0 bf16[1, 32, 512, 128]"
# t1248 = prims.cat((t1247, t1243), -1) # t1248: "cuda:0 bf16[1, 32, 512, 128]"
del t1247, t1243
[t1241, t1256] = nvFusion52(t1227, t1233, t1242, t1248, t154, t157)
# t1235 = prims.convert_element_type(t1227, dtypes.float32) # t1235: "cuda:0 f32[1, 32, 512, 128]"
# t1250 = prims.convert_element_type(t1242, dtypes.float32) # t1250: "cuda:0 f32[1, 32, 512, 128]"
# t1236 = prims.mul(t1235, t154) # t1236: "cuda:0 f32[1, 32, 512, 128]"
# t1238 = prims.convert_element_type(t1233, dtypes.float32) # t1238: "cuda:0 f32[1, 32, 512, 128]"
# t1239 = prims.mul(t1238, t157) # t1239: "cuda:0 f32[1, 32, 512, 128]"
# t1240 = prims.add(t1236, t1239) # t1240: "cuda:0 f32[1, 32, 512, 128]"
# t1241 = prims.convert_element_type(t1240, dtypes.bfloat16) # t1241: "cuda:0 bf16[1, 32, 512, 128]"
# t1251 = prims.mul(t1250, t154) # t1251: "cuda:0 f32[1, 32, 512, 128]"
# t1253 = prims.convert_element_type(t1248, dtypes.float32) # t1253: "cuda:0 f32[1, 32, 512, 128]"
# t1254 = prims.mul(t1253, t157) # t1254: "cuda:0 f32[1, 32, 512, 128]"
# t1255 = prims.add(t1251, t1254) # t1255: "cuda:0 f32[1, 32, 512, 128]"
# t1256 = prims.convert_element_type(t1255, dtypes.bfloat16) # t1256: "cuda:0 bf16[1, 32, 512, 128]"
del t1227, t1233, t1242, t1248
t1258 = torch.cat((t1241, t1257), -1) # t1258: "cuda:0 bf16[1, 32, 512, 128]"
# t1258 = ltorch.cat((t1241, t1257), -1) # t1258: "cuda:0 bf16[1, 32, 512, 128]"
# t1258 = prims.cat((t1241, t1257), -1) # t1258: "cuda:0 bf16[1, 32, 512, 128]"
del t1241, t1257
t1260 = torch.cat((t1256, t1259), -1) # t1260: "cuda:0 bf16[1, 32, 512, 128]"
# t1260 = ltorch.cat((t1256, t1259), -1) # t1260: "cuda:0 bf16[1, 32, 512, 128]"
# t1260 = prims.cat((t1256, t1259), -1) # t1260: "cuda:0 bf16[1, 32, 512, 128]"
del t1256, t1259
(t1261, t1262, t1263, t1264, _, _, t1265, t1266, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1258, t1260, t1226, 0.0, True, scale=0.08838834764831843)
t1268 = torch.permute(t1261, (0, 2, 1, 3)) # t1268: "cuda:0 bf16[1, 512, 32, 128]"
# t1268 = ltorch.permute(t1261, (0, 2, 1, 3)) # t1268: "cuda:0 bf16[1, 512, 32, 128]"
# t1268 = prims.transpose(t1261, (0, 2, 1, 3)) # t1268: "cuda:0 bf16[1, 512, 32, 128]"
t1269 = torch.reshape(t1268, (1, 512, 4096)) # t1269: "cuda:0 bf16[1, 512, 4096]"
# t1269 = ltorch.reshape(t1268, (1, 512, 4096)) # t1269: "cuda:0 bf16[1, 512, 4096]"
# t1269 = prims.reshape(t1268, (1, 512, 4096)) # t1269: "cuda:0 bf16[1, 512, 4096]"
del t1268
t1270 = torch.nn.functional.linear(t1269, t105, None) # t1270: "cuda:0 bf16[1, 512, 4096]"
# t1270 = ltorch.linear(t1269, t105, None) # t1270: "cuda:0 bf16[1, 512, 4096]"
# t1270 = prims.linear(t1269, t105, None) # t1270: "cuda:0 bf16[1, 512, 4096]"
[t1274, t1281, t1289] = nvFusion53(t1202, t1270, t1285)
# t1272 = prims.convert_element_type(t1202, dtypes.float32) # t1272: "cuda:0 f32[1, 512, 4096]"
# t1271 = prims.convert_element_type(t1270, dtypes.float32) # t1271: "cuda:0 f32[1, 512, 4096]"
# t1273 = prims.add(t1271, t1272) # t1273: "cuda:0 f32[1, 512, 4096]"
# t1274 = prims.convert_element_type(t1273, dtypes.bfloat16) # t1274: "cuda:0 bf16[1, 512, 4096]"
# t1276 = prims.mul(t1273, t1273) # t1276: "cuda:0 f32[1, 512, 4096]"
# t1277 = prims.sum(t1276, (2,)) # t1277: "cuda:0 f32[1, 512]"
# t1278 = prims.broadcast_in_dim(t1277, [1, 512, 1], [0, 1]) # t1278: "cuda:0 f32[1, 512, 1]"
# t1279 = prims.div(t1278, 4096.0) # t1279: "cuda:0 f32[1, 512, 1]"
# t1280 = prims.add(t1279, 1e-05) # t1280: "cuda:0 f32[1, 512, 1]"
# t1281 = prims.rsqrt(t1280) # t1281: "cuda:0 f32[1, 512, 1]"
# t1282 = prims.broadcast_in_dim(t1281, (1, 512, 4096), (0, 1, 2)) # t1282: "cuda:0 f32[1, 512, 4096]"
# t1283 = prims.mul(t1273, t1282) # t1283: "cuda:0 f32[1, 512, 4096]"
# t1287 = prims.convert_element_type(t1285, dtypes.float32) # t1287: "cuda:0 f32[1, 512, 4096]"
# t1288 = prims.mul(t1283, t1287) # t1288: "cuda:0 f32[1, 512, 4096]"
# t1289 = prims.convert_element_type(t1288, dtypes.bfloat16) # t1289: "cuda:0 bf16[1, 512, 4096]"
t1290 = torch.nn.functional.linear(t1289, t29, None) # t1290: "cuda:0 bf16[1, 512, 11008]"
# t1290 = ltorch.linear(t1289, t29, None) # t1290: "cuda:0 bf16[1, 512, 11008]"
# t1290 = prims.linear(t1289, t29, None) # t1290: "cuda:0 bf16[1, 512, 11008]"
t1291 = torch.nn.functional.linear(t1289, t45, None) # t1291: "cuda:0 bf16[1, 512, 11008]"
# t1291 = ltorch.linear(t1289, t45, None) # t1291: "cuda:0 bf16[1, 512, 11008]"
# t1291 = prims.linear(t1289, t45, None) # t1291: "cuda:0 bf16[1, 512, 11008]"
[t1305] = nvFusion54(t1290, t1291)
# t1292 = prims.convert_element_type(t1290, dtypes.float32) # t1292: "cuda:0 f32[1, 512, 11008]"
# t1293 = prims.neg(t1292) # t1293: "cuda:0 f32[1, 512, 11008]"
# t1294 = prims.exp(t1293) # t1294: "cuda:0 f32[1, 512, 11008]"
# t1295 = prims.add(1.0, t1294) # t1295: "cuda:0 f32[1, 512, 11008]"
# t1296 = prims.reciprocal(t1295) # t1296: "cuda:0 f32[1, 512, 11008]"
# t1300 = prims.mul(t1292, t1296) # t1300: "cuda:0 f32[1, 512, 11008]"
# t1303 = prims.convert_element_type(t1291, dtypes.float32) # t1303: "cuda:0 f32[1, 512, 11008]"
# t1304 = prims.mul(t1300, t1303) # t1304: "cuda:0 f32[1, 512, 11008]"
# t1305 = prims.convert_element_type(t1304, dtypes.bfloat16) # t1305: "cuda:0 bf16[1, 512, 11008]"
t1306 = torch.nn.functional.linear(t1305, t106, None) # t1306: "cuda:0 bf16[1, 512, 4096]"
# t1306 = ltorch.linear(t1305, t106, None) # t1306: "cuda:0 bf16[1, 512, 4096]"
# t1306 = prims.linear(t1305, t106, None) # t1306: "cuda:0 bf16[1, 512, 4096]"
[t1310, t1317, t1325] = nvFusion55(t1274, t1306, t1321)
# t1308 = prims.convert_element_type(t1274, dtypes.float32) # t1308: "cuda:0 f32[1, 512, 4096]"
# t1307 = prims.convert_element_type(t1306, dtypes.float32) # t1307: "cuda:0 f32[1, 512, 4096]"
# t1309 = prims.add(t1307, t1308) # t1309: "cuda:0 f32[1, 512, 4096]"
# t1310 = prims.convert_element_type(t1309, dtypes.bfloat16) # t1310: "cuda:0 bf16[1, 512, 4096]"
# t1312 = prims.mul(t1309, t1309) # t1312: "cuda:0 f32[1, 512, 4096]"
# t1313 = prims.sum(t1312, (2,)) # t1313: "cuda:0 f32[1, 512]"
# t1314 = prims.broadcast_in_dim(t1313, [1, 512, 1], [0, 1]) # t1314: "cuda:0 f32[1, 512, 1]"
# t1315 = prims.div(t1314, 4096.0) # t1315: "cuda:0 f32[1, 512, 1]"
# t1316 = prims.add(t1315, 1e-05) # t1316: "cuda:0 f32[1, 512, 1]"
# t1317 = prims.rsqrt(t1316) # t1317: "cuda:0 f32[1, 512, 1]"
# t1318 = prims.broadcast_in_dim(t1317, (1, 512, 4096), (0, 1, 2)) # t1318: "cuda:0 f32[1, 512, 4096]"
# t1319 = prims.mul(t1309, t1318) # t1319: "cuda:0 f32[1, 512, 4096]"
# t1323 = prims.convert_element_type(t1321, dtypes.float32) # t1323: "cuda:0 f32[1, 512, 4096]"
# t1324 = prims.mul(t1319, t1323) # t1324: "cuda:0 f32[1, 512, 4096]"
# t1325 = prims.convert_element_type(t1324, dtypes.bfloat16) # t1325: "cuda:0 bf16[1, 512, 4096]"
t1326 = torch.nn.functional.linear(t1325, t14, None) # t1326: "cuda:0 bf16[1, 512, 12288]"
# t1326 = ltorch.linear(t1325, t14, None) # t1326: "cuda:0 bf16[1, 512, 12288]"
# t1326 = prims.linear(t1325, t14, None) # t1326: "cuda:0 bf16[1, 512, 12288]"
t1327 = torch.reshape(t1326, (1, 512, 32, 3, 128)) # t1327: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t1327 = ltorch.reshape(t1326, (1, 512, 32, 3, 128)) # t1327: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t1327 = prims.reshape(t1326, (1, 512, 32, 3, 128)) # t1327: "cuda:0 bf16[1, 512, 32, 3, 128]"
del t1326
t1328 = torch.permute(t1327, (0, 2, 3, 1, 4)) # t1328: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t1328 = ltorch.permute(t1327, (0, 2, 3, 1, 4)) # t1328: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t1328 = prims.transpose(t1327, (0, 2, 3, 1, 4)) # t1328: "cuda:0 bf16[1, 32, 3, 512, 128]"
del t1327
(t1329, t1330, t1331) = torch.split(t1328, (1, 1, 1), 2)
# (t1329, t1330, t1331) = ltorch.split(t1328, (1, 1, 1), 2)
# t1329 = prims.slice_prim(t1328, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1329: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t1330 = prims.slice_prim(t1328, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1330: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t1331 = prims.slice_prim(t1328, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1331: "cuda:0 bf16[1, 32, 1, 512, 128]"
del t1328
t1332 = torch.reshape(t1329, (1, 32, 512, 128)) # t1332: "cuda:0 bf16[1, 32, 512, 128]"
# t1332 = ltorch.reshape(t1329, (1, 32, 512, 128)) # t1332: "cuda:0 bf16[1, 32, 512, 128]"
# t1332 = prims.reshape(t1329, (1, 32, 512, 128)) # t1332: "cuda:0 bf16[1, 32, 512, 128]"
del t1329
t1333 = torch.reshape(t1330, (1, 32, 512, 128)) # t1333: "cuda:0 bf16[1, 32, 512, 128]"
# t1333 = ltorch.reshape(t1330, (1, 32, 512, 128)) # t1333: "cuda:0 bf16[1, 32, 512, 128]"
# t1333 = prims.reshape(t1330, (1, 32, 512, 128)) # t1333: "cuda:0 bf16[1, 32, 512, 128]"
del t1330
t1334 = torch.reshape(t1331, (1, 32, 512, 128)) # t1334: "cuda:0 bf16[1, 32, 512, 128]"
# t1334 = ltorch.reshape(t1331, (1, 32, 512, 128)) # t1334: "cuda:0 bf16[1, 32, 512, 128]"
# t1334 = prims.reshape(t1331, (1, 32, 512, 128)) # t1334: "cuda:0 bf16[1, 32, 512, 128]"
del t1331
t1335 = torch_slice_prim_impl(t1332, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1335: "cuda:0 bf16[1, 32, 512, 128]"
t1350 = torch_slice_prim_impl(t1333, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1350: "cuda:0 bf16[1, 32, 512, 128]"
t1365 = torch_slice_prim_impl(t1332, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1365: "cuda:0 bf16[1, 32, 512, 0]"
del t1332
t1367 = torch_slice_prim_impl(t1333, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1367: "cuda:0 bf16[1, 32, 512, 0]"
del t1333
t1336 = torch_slice_prim_impl(t1335, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1336: "cuda:0 bf16[1, 32, 512, 64]"
t1337 = torch_slice_prim_impl(t1335, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1337: "cuda:0 bf16[1, 32, 512, 64]"
t1351 = torch_slice_prim_impl(t1350, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1351: "cuda:0 bf16[1, 32, 512, 64]"
t1352 = torch_slice_prim_impl(t1350, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1352: "cuda:0 bf16[1, 32, 512, 64]"
[t1340, t1355] = nvFusion56(t1335, t1337, t1350, t1352)
# t1338 = prims.convert_element_type(t1337, dtypes.float32) # t1338: "cuda:0 f32[1, 32, 512, 64]"
# t1339 = prims.neg(t1338) # t1339: "cuda:0 f32[1, 32, 512, 64]"
# t1340 = prims.convert_element_type(t1339, dtypes.bfloat16) # t1340: "cuda:0 bf16[1, 32, 512, 64]"
# t1353 = prims.convert_element_type(t1352, dtypes.float32) # t1353: "cuda:0 f32[1, 32, 512, 64]"
# t1354 = prims.neg(t1353) # t1354: "cuda:0 f32[1, 32, 512, 64]"
# t1355 = prims.convert_element_type(t1354, dtypes.bfloat16) # t1355: "cuda:0 bf16[1, 32, 512, 64]"
del t1337, t1352
t1341 = torch.cat((t1340, t1336), -1) # t1341: "cuda:0 bf16[1, 32, 512, 128]"
# t1341 = ltorch.cat((t1340, t1336), -1) # t1341: "cuda:0 bf16[1, 32, 512, 128]"
# t1341 = prims.cat((t1340, t1336), -1) # t1341: "cuda:0 bf16[1, 32, 512, 128]"
del t1340, t1336
t1356 = torch.cat((t1355, t1351), -1) # t1356: "cuda:0 bf16[1, 32, 512, 128]"
# t1356 = ltorch.cat((t1355, t1351), -1) # t1356: "cuda:0 bf16[1, 32, 512, 128]"
# t1356 = prims.cat((t1355, t1351), -1) # t1356: "cuda:0 bf16[1, 32, 512, 128]"
del t1355, t1351
[t1349, t1364] = nvFusion57(t1335, t1341, t1350, t1356, t154, t157)
# t1343 = prims.convert_element_type(t1335, dtypes.float32) # t1343: "cuda:0 f32[1, 32, 512, 128]"
# t1358 = prims.convert_element_type(t1350, dtypes.float32) # t1358: "cuda:0 f32[1, 32, 512, 128]"
# t1344 = prims.mul(t1343, t154) # t1344: "cuda:0 f32[1, 32, 512, 128]"
# t1346 = prims.convert_element_type(t1341, dtypes.float32) # t1346: "cuda:0 f32[1, 32, 512, 128]"
# t1347 = prims.mul(t1346, t157) # t1347: "cuda:0 f32[1, 32, 512, 128]"
# t1348 = prims.add(t1344, t1347) # t1348: "cuda:0 f32[1, 32, 512, 128]"
# t1349 = prims.convert_element_type(t1348, dtypes.bfloat16) # t1349: "cuda:0 bf16[1, 32, 512, 128]"
# t1359 = prims.mul(t1358, t154) # t1359: "cuda:0 f32[1, 32, 512, 128]"
# t1361 = prims.convert_element_type(t1356, dtypes.float32) # t1361: "cuda:0 f32[1, 32, 512, 128]"
# t1362 = prims.mul(t1361, t157) # t1362: "cuda:0 f32[1, 32, 512, 128]"
# t1363 = prims.add(t1359, t1362) # t1363: "cuda:0 f32[1, 32, 512, 128]"
# t1364 = prims.convert_element_type(t1363, dtypes.bfloat16) # t1364: "cuda:0 bf16[1, 32, 512, 128]"
del t1335, t1341, t1350, t1356
t1366 = torch.cat((t1349, t1365), -1) # t1366: "cuda:0 bf16[1, 32, 512, 128]"
# t1366 = ltorch.cat((t1349, t1365), -1) # t1366: "cuda:0 bf16[1, 32, 512, 128]"
# t1366 = prims.cat((t1349, t1365), -1) # t1366: "cuda:0 bf16[1, 32, 512, 128]"
del t1349, t1365
t1368 = torch.cat((t1364, t1367), -1) # t1368: "cuda:0 bf16[1, 32, 512, 128]"
# t1368 = ltorch.cat((t1364, t1367), -1) # t1368: "cuda:0 bf16[1, 32, 512, 128]"
# t1368 = prims.cat((t1364, t1367), -1) # t1368: "cuda:0 bf16[1, 32, 512, 128]"
del t1364, t1367
(t1369, t1370, t1371, t1372, _, _, t1373, t1374, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1366, t1368, t1334, 0.0, True, scale=0.08838834764831843)
t1376 = torch.permute(t1369, (0, 2, 1, 3)) # t1376: "cuda:0 bf16[1, 512, 32, 128]"
# t1376 = ltorch.permute(t1369, (0, 2, 1, 3)) # t1376: "cuda:0 bf16[1, 512, 32, 128]"
# t1376 = prims.transpose(t1369, (0, 2, 1, 3)) # t1376: "cuda:0 bf16[1, 512, 32, 128]"
t1377 = torch.reshape(t1376, (1, 512, 4096)) # t1377: "cuda:0 bf16[1, 512, 4096]"
# t1377 = ltorch.reshape(t1376, (1, 512, 4096)) # t1377: "cuda:0 bf16[1, 512, 4096]"
# t1377 = prims.reshape(t1376, (1, 512, 4096)) # t1377: "cuda:0 bf16[1, 512, 4096]"
del t1376
t1378 = torch.nn.functional.linear(t1377, t107, None) # t1378: "cuda:0 bf16[1, 512, 4096]"
# t1378 = ltorch.linear(t1377, t107, None) # t1378: "cuda:0 bf16[1, 512, 4096]"
# t1378 = prims.linear(t1377, t107, None) # t1378: "cuda:0 bf16[1, 512, 4096]"
[t1382, t1389, t1397] = nvFusion58(t1310, t1378, t1393)
# t1380 = prims.convert_element_type(t1310, dtypes.float32) # t1380: "cuda:0 f32[1, 512, 4096]"
# t1379 = prims.convert_element_type(t1378, dtypes.float32) # t1379: "cuda:0 f32[1, 512, 4096]"
# t1381 = prims.add(t1379, t1380) # t1381: "cuda:0 f32[1, 512, 4096]"
# t1382 = prims.convert_element_type(t1381, dtypes.bfloat16) # t1382: "cuda:0 bf16[1, 512, 4096]"
# t1384 = prims.mul(t1381, t1381) # t1384: "cuda:0 f32[1, 512, 4096]"
# t1385 = prims.sum(t1384, (2,)) # t1385: "cuda:0 f32[1, 512]"
# t1386 = prims.broadcast_in_dim(t1385, [1, 512, 1], [0, 1]) # t1386: "cuda:0 f32[1, 512, 1]"
# t1387 = prims.div(t1386, 4096.0) # t1387: "cuda:0 f32[1, 512, 1]"
# t1388 = prims.add(t1387, 1e-05) # t1388: "cuda:0 f32[1, 512, 1]"
# t1389 = prims.rsqrt(t1388) # t1389: "cuda:0 f32[1, 512, 1]"
# t1390 = prims.broadcast_in_dim(t1389, (1, 512, 4096), (0, 1, 2)) # t1390: "cuda:0 f32[1, 512, 4096]"
# t1391 = prims.mul(t1381, t1390) # t1391: "cuda:0 f32[1, 512, 4096]"
# t1395 = prims.convert_element_type(t1393, dtypes.float32) # t1395: "cuda:0 f32[1, 512, 4096]"
# t1396 = prims.mul(t1391, t1395) # t1396: "cuda:0 f32[1, 512, 4096]"
# t1397 = prims.convert_element_type(t1396, dtypes.bfloat16) # t1397: "cuda:0 bf16[1, 512, 4096]"
t1398 = torch.nn.functional.linear(t1397, t30, None) # t1398: "cuda:0 bf16[1, 512, 11008]"
# t1398 = ltorch.linear(t1397, t30, None) # t1398: "cuda:0 bf16[1, 512, 11008]"
# t1398 = prims.linear(t1397, t30, None) # t1398: "cuda:0 bf16[1, 512, 11008]"
t1399 = torch.nn.functional.linear(t1397, t46, None) # t1399: "cuda:0 bf16[1, 512, 11008]"
# t1399 = ltorch.linear(t1397, t46, None) # t1399: "cuda:0 bf16[1, 512, 11008]"
# t1399 = prims.linear(t1397, t46, None) # t1399: "cuda:0 bf16[1, 512, 11008]"
[t1413] = nvFusion59(t1398, t1399)
# t1400 = prims.convert_element_type(t1398, dtypes.float32) # t1400: "cuda:0 f32[1, 512, 11008]"
# t1401 = prims.neg(t1400) # t1401: "cuda:0 f32[1, 512, 11008]"
# t1402 = prims.exp(t1401) # t1402: "cuda:0 f32[1, 512, 11008]"
# t1403 = prims.add(1.0, t1402) # t1403: "cuda:0 f32[1, 512, 11008]"
# t1404 = prims.reciprocal(t1403) # t1404: "cuda:0 f32[1, 512, 11008]"
# t1408 = prims.mul(t1400, t1404) # t1408: "cuda:0 f32[1, 512, 11008]"
# t1411 = prims.convert_element_type(t1399, dtypes.float32) # t1411: "cuda:0 f32[1, 512, 11008]"
# t1412 = prims.mul(t1408, t1411) # t1412: "cuda:0 f32[1, 512, 11008]"
# t1413 = prims.convert_element_type(t1412, dtypes.bfloat16) # t1413: "cuda:0 bf16[1, 512, 11008]"
t1414 = torch.nn.functional.linear(t1413, t108, None) # t1414: "cuda:0 bf16[1, 512, 4096]"
# t1414 = ltorch.linear(t1413, t108, None) # t1414: "cuda:0 bf16[1, 512, 4096]"
# t1414 = prims.linear(t1413, t108, None) # t1414: "cuda:0 bf16[1, 512, 4096]"
[t1418, t1425, t1433] = nvFusion60(t1382, t1414, t1429)
# t1416 = prims.convert_element_type(t1382, dtypes.float32) # t1416: "cuda:0 f32[1, 512, 4096]"
# t1415 = prims.convert_element_type(t1414, dtypes.float32) # t1415: "cuda:0 f32[1, 512, 4096]"
# t1417 = prims.add(t1415, t1416) # t1417: "cuda:0 f32[1, 512, 4096]"
# t1418 = prims.convert_element_type(t1417, dtypes.bfloat16) # t1418: "cuda:0 bf16[1, 512, 4096]"
# t1420 = prims.mul(t1417, t1417) # t1420: "cuda:0 f32[1, 512, 4096]"
# t1421 = prims.sum(t1420, (2,)) # t1421: "cuda:0 f32[1, 512]"
# t1422 = prims.broadcast_in_dim(t1421, [1, 512, 1], [0, 1]) # t1422: "cuda:0 f32[1, 512, 1]"
# t1423 = prims.div(t1422, 4096.0) # t1423: "cuda:0 f32[1, 512, 1]"
# t1424 = prims.add(t1423, 1e-05) # t1424: "cuda:0 f32[1, 512, 1]"
# t1425 = prims.rsqrt(t1424) # t1425: "cuda:0 f32[1, 512, 1]"
# t1426 = prims.broadcast_in_dim(t1425, (1, 512, 4096), (0, 1, 2)) # t1426: "cuda:0 f32[1, 512, 4096]"
# t1427 = prims.mul(t1417, t1426) # t1427: "cuda:0 f32[1, 512, 4096]"
# t1431 = prims.convert_element_type(t1429, dtypes.float32) # t1431: "cuda:0 f32[1, 512, 4096]"
# t1432 = prims.mul(t1427, t1431) # t1432: "cuda:0 f32[1, 512, 4096]"
# t1433 = prims.convert_element_type(t1432, dtypes.bfloat16) # t1433: "cuda:0 bf16[1, 512, 4096]"
t1434 = torch.nn.functional.linear(t1433, t15, None) # t1434: "cuda:0 bf16[1, 512, 12288]"
# t1434 = ltorch.linear(t1433, t15, None) # t1434: "cuda:0 bf16[1, 512, 12288]"
# t1434 = prims.linear(t1433, t15, None) # t1434: "cuda:0 bf16[1, 512, 12288]"
t1435 = torch.reshape(t1434, (1, 512, 32, 3, 128)) # t1435: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t1435 = ltorch.reshape(t1434, (1, 512, 32, 3, 128)) # t1435: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t1435 = prims.reshape(t1434, (1, 512, 32, 3, 128)) # t1435: "cuda:0 bf16[1, 512, 32, 3, 128]"
del t1434
t1436 = torch.permute(t1435, (0, 2, 3, 1, 4)) # t1436: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t1436 = ltorch.permute(t1435, (0, 2, 3, 1, 4)) # t1436: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t1436 = prims.transpose(t1435, (0, 2, 3, 1, 4)) # t1436: "cuda:0 bf16[1, 32, 3, 512, 128]"
del t1435
(t1437, t1438, t1439) = torch.split(t1436, (1, 1, 1), 2)
# (t1437, t1438, t1439) = ltorch.split(t1436, (1, 1, 1), 2)
# t1437 = prims.slice_prim(t1436, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1437: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t1438 = prims.slice_prim(t1436, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1438: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t1439 = prims.slice_prim(t1436, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1439: "cuda:0 bf16[1, 32, 1, 512, 128]"
del t1436
t1440 = torch.reshape(t1437, (1, 32, 512, 128)) # t1440: "cuda:0 bf16[1, 32, 512, 128]"
# t1440 = ltorch.reshape(t1437, (1, 32, 512, 128)) # t1440: "cuda:0 bf16[1, 32, 512, 128]"
# t1440 = prims.reshape(t1437, (1, 32, 512, 128)) # t1440: "cuda:0 bf16[1, 32, 512, 128]"
del t1437
t1441 = torch.reshape(t1438, (1, 32, 512, 128)) # t1441: "cuda:0 bf16[1, 32, 512, 128]"
# t1441 = ltorch.reshape(t1438, (1, 32, 512, 128)) # t1441: "cuda:0 bf16[1, 32, 512, 128]"
# t1441 = prims.reshape(t1438, (1, 32, 512, 128)) # t1441: "cuda:0 bf16[1, 32, 512, 128]"
del t1438
t1442 = torch.reshape(t1439, (1, 32, 512, 128)) # t1442: "cuda:0 bf16[1, 32, 512, 128]"
# t1442 = ltorch.reshape(t1439, (1, 32, 512, 128)) # t1442: "cuda:0 bf16[1, 32, 512, 128]"
# t1442 = prims.reshape(t1439, (1, 32, 512, 128)) # t1442: "cuda:0 bf16[1, 32, 512, 128]"
del t1439
t1443 = torch_slice_prim_impl(t1440, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1443: "cuda:0 bf16[1, 32, 512, 128]"
t1458 = torch_slice_prim_impl(t1441, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1458: "cuda:0 bf16[1, 32, 512, 128]"
t1473 = torch_slice_prim_impl(t1440, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1473: "cuda:0 bf16[1, 32, 512, 0]"
del t1440
t1475 = torch_slice_prim_impl(t1441, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1475: "cuda:0 bf16[1, 32, 512, 0]"
del t1441
t1444 = torch_slice_prim_impl(t1443, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1444: "cuda:0 bf16[1, 32, 512, 64]"
t1445 = torch_slice_prim_impl(t1443, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1445: "cuda:0 bf16[1, 32, 512, 64]"
t1459 = torch_slice_prim_impl(t1458, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1459: "cuda:0 bf16[1, 32, 512, 64]"
t1460 = torch_slice_prim_impl(t1458, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1460: "cuda:0 bf16[1, 32, 512, 64]"
[t1448, t1463] = nvFusion61(t1443, t1445, t1458, t1460)
# t1446 = prims.convert_element_type(t1445, dtypes.float32) # t1446: "cuda:0 f32[1, 32, 512, 64]"
# t1447 = prims.neg(t1446) # t1447: "cuda:0 f32[1, 32, 512, 64]"
# t1448 = prims.convert_element_type(t1447, dtypes.bfloat16) # t1448: "cuda:0 bf16[1, 32, 512, 64]"
# t1461 = prims.convert_element_type(t1460, dtypes.float32) # t1461: "cuda:0 f32[1, 32, 512, 64]"
# t1462 = prims.neg(t1461) # t1462: "cuda:0 f32[1, 32, 512, 64]"
# t1463 = prims.convert_element_type(t1462, dtypes.bfloat16) # t1463: "cuda:0 bf16[1, 32, 512, 64]"
del t1445, t1460
t1464 = torch.cat((t1463, t1459), -1) # t1464: "cuda:0 bf16[1, 32, 512, 128]"
# t1464 = ltorch.cat((t1463, t1459), -1) # t1464: "cuda:0 bf16[1, 32, 512, 128]"
# t1464 = prims.cat((t1463, t1459), -1) # t1464: "cuda:0 bf16[1, 32, 512, 128]"
del t1463, t1459
t1449 = torch.cat((t1448, t1444), -1) # t1449: "cuda:0 bf16[1, 32, 512, 128]"
# t1449 = ltorch.cat((t1448, t1444), -1) # t1449: "cuda:0 bf16[1, 32, 512, 128]"
# t1449 = prims.cat((t1448, t1444), -1) # t1449: "cuda:0 bf16[1, 32, 512, 128]"
del t1448, t1444
[t1457, t1472] = nvFusion62(t1443, t1449, t1458, t1464, t154, t157)
# t1451 = prims.convert_element_type(t1443, dtypes.float32) # t1451: "cuda:0 f32[1, 32, 512, 128]"
# t1466 = prims.convert_element_type(t1458, dtypes.float32) # t1466: "cuda:0 f32[1, 32, 512, 128]"
# t1467 = prims.mul(t1466, t154) # t1467: "cuda:0 f32[1, 32, 512, 128]"
# t1469 = prims.convert_element_type(t1464, dtypes.float32) # t1469: "cuda:0 f32[1, 32, 512, 128]"
# t1470 = prims.mul(t1469, t157) # t1470: "cuda:0 f32[1, 32, 512, 128]"
# t1471 = prims.add(t1467, t1470) # t1471: "cuda:0 f32[1, 32, 512, 128]"
# t1472 = prims.convert_element_type(t1471, dtypes.bfloat16) # t1472: "cuda:0 bf16[1, 32, 512, 128]"
# t1452 = prims.mul(t1451, t154) # t1452: "cuda:0 f32[1, 32, 512, 128]"
# t1454 = prims.convert_element_type(t1449, dtypes.float32) # t1454: "cuda:0 f32[1, 32, 512, 128]"
# t1455 = prims.mul(t1454, t157) # t1455: "cuda:0 f32[1, 32, 512, 128]"
# t1456 = prims.add(t1452, t1455) # t1456: "cuda:0 f32[1, 32, 512, 128]"
# t1457 = prims.convert_element_type(t1456, dtypes.bfloat16) # t1457: "cuda:0 bf16[1, 32, 512, 128]"
del t1443, t1449, t1458, t1464
t1476 = torch.cat((t1472, t1475), -1) # t1476: "cuda:0 bf16[1, 32, 512, 128]"
# t1476 = ltorch.cat((t1472, t1475), -1) # t1476: "cuda:0 bf16[1, 32, 512, 128]"
# t1476 = prims.cat((t1472, t1475), -1) # t1476: "cuda:0 bf16[1, 32, 512, 128]"
del t1472, t1475
t1474 = torch.cat((t1457, t1473), -1) # t1474: "cuda:0 bf16[1, 32, 512, 128]"
# t1474 = ltorch.cat((t1457, t1473), -1) # t1474: "cuda:0 bf16[1, 32, 512, 128]"
# t1474 = prims.cat((t1457, t1473), -1) # t1474: "cuda:0 bf16[1, 32, 512, 128]"
del t1457, t1473
(t1477, t1478, t1479, t1480, _, _, t1481, t1482, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1474, t1476, t1442, 0.0, True, scale=0.08838834764831843)
t1484 = torch.permute(t1477, (0, 2, 1, 3)) # t1484: "cuda:0 bf16[1, 512, 32, 128]"
# t1484 = ltorch.permute(t1477, (0, 2, 1, 3)) # t1484: "cuda:0 bf16[1, 512, 32, 128]"
# t1484 = prims.transpose(t1477, (0, 2, 1, 3)) # t1484: "cuda:0 bf16[1, 512, 32, 128]"
t1485 = torch.reshape(t1484, (1, 512, 4096)) # t1485: "cuda:0 bf16[1, 512, 4096]"
# t1485 = ltorch.reshape(t1484, (1, 512, 4096)) # t1485: "cuda:0 bf16[1, 512, 4096]"
# t1485 = prims.reshape(t1484, (1, 512, 4096)) # t1485: "cuda:0 bf16[1, 512, 4096]"
del t1484
t1486 = torch.nn.functional.linear(t1485, t109, None) # t1486: "cuda:0 bf16[1, 512, 4096]"
# t1486 = ltorch.linear(t1485, t109, None) # t1486: "cuda:0 bf16[1, 512, 4096]"
# t1486 = prims.linear(t1485, t109, None) # t1486: "cuda:0 bf16[1, 512, 4096]"
[t1490, t1497, t1505] = nvFusion63(t1418, t1486, t1501)
# t1488 = prims.convert_element_type(t1418, dtypes.float32) # t1488: "cuda:0 f32[1, 512, 4096]"
# t1487 = prims.convert_element_type(t1486, dtypes.float32) # t1487: "cuda:0 f32[1, 512, 4096]"
# t1489 = prims.add(t1487, t1488) # t1489: "cuda:0 f32[1, 512, 4096]"
# t1490 = prims.convert_element_type(t1489, dtypes.bfloat16) # t1490: "cuda:0 bf16[1, 512, 4096]"
# t1492 = prims.mul(t1489, t1489) # t1492: "cuda:0 f32[1, 512, 4096]"
# t1493 = prims.sum(t1492, (2,)) # t1493: "cuda:0 f32[1, 512]"
# t1494 = prims.broadcast_in_dim(t1493, [1, 512, 1], [0, 1]) # t1494: "cuda:0 f32[1, 512, 1]"
# t1495 = prims.div(t1494, 4096.0) # t1495: "cuda:0 f32[1, 512, 1]"
# t1496 = prims.add(t1495, 1e-05) # t1496: "cuda:0 f32[1, 512, 1]"
# t1497 = prims.rsqrt(t1496) # t1497: "cuda:0 f32[1, 512, 1]"
# t1498 = prims.broadcast_in_dim(t1497, (1, 512, 4096), (0, 1, 2)) # t1498: "cuda:0 f32[1, 512, 4096]"
# t1499 = prims.mul(t1489, t1498) # t1499: "cuda:0 f32[1, 512, 4096]"
# t1503 = prims.convert_element_type(t1501, dtypes.float32) # t1503: "cuda:0 f32[1, 512, 4096]"
# t1504 = prims.mul(t1499, t1503) # t1504: "cuda:0 f32[1, 512, 4096]"
# t1505 = prims.convert_element_type(t1504, dtypes.bfloat16) # t1505: "cuda:0 bf16[1, 512, 4096]"
t1506 = torch.nn.functional.linear(t1505, t31, None) # t1506: "cuda:0 bf16[1, 512, 11008]"
# t1506 = ltorch.linear(t1505, t31, None) # t1506: "cuda:0 bf16[1, 512, 11008]"
# t1506 = prims.linear(t1505, t31, None) # t1506: "cuda:0 bf16[1, 512, 11008]"
t1507 = torch.nn.functional.linear(t1505, t47, None) # t1507: "cuda:0 bf16[1, 512, 11008]"
# t1507 = ltorch.linear(t1505, t47, None) # t1507: "cuda:0 bf16[1, 512, 11008]"
# t1507 = prims.linear(t1505, t47, None) # t1507: "cuda:0 bf16[1, 512, 11008]"
[t1521] = nvFusion64(t1506, t1507)
# t1508 = prims.convert_element_type(t1506, dtypes.float32) # t1508: "cuda:0 f32[1, 512, 11008]"
# t1509 = prims.neg(t1508) # t1509: "cuda:0 f32[1, 512, 11008]"
# t1510 = prims.exp(t1509) # t1510: "cuda:0 f32[1, 512, 11008]"
# t1511 = prims.add(1.0, t1510) # t1511: "cuda:0 f32[1, 512, 11008]"
# t1512 = prims.reciprocal(t1511) # t1512: "cuda:0 f32[1, 512, 11008]"
# t1516 = prims.mul(t1508, t1512) # t1516: "cuda:0 f32[1, 512, 11008]"
# t1519 = prims.convert_element_type(t1507, dtypes.float32) # t1519: "cuda:0 f32[1, 512, 11008]"
# t1520 = prims.mul(t1516, t1519) # t1520: "cuda:0 f32[1, 512, 11008]"
# t1521 = prims.convert_element_type(t1520, dtypes.bfloat16) # t1521: "cuda:0 bf16[1, 512, 11008]"
t1522 = torch.nn.functional.linear(t1521, t110, None) # t1522: "cuda:0 bf16[1, 512, 4096]"
# t1522 = ltorch.linear(t1521, t110, None) # t1522: "cuda:0 bf16[1, 512, 4096]"
# t1522 = prims.linear(t1521, t110, None) # t1522: "cuda:0 bf16[1, 512, 4096]"
[t1526, t1533, t1541] = nvFusion65(t1490, t1522, t1537)
# t1524 = prims.convert_element_type(t1490, dtypes.float32) # t1524: "cuda:0 f32[1, 512, 4096]"
# t1523 = prims.convert_element_type(t1522, dtypes.float32) # t1523: "cuda:0 f32[1, 512, 4096]"
# t1525 = prims.add(t1523, t1524) # t1525: "cuda:0 f32[1, 512, 4096]"
# t1526 = prims.convert_element_type(t1525, dtypes.bfloat16) # t1526: "cuda:0 bf16[1, 512, 4096]"
# t1528 = prims.mul(t1525, t1525) # t1528: "cuda:0 f32[1, 512, 4096]"
# t1529 = prims.sum(t1528, (2,)) # t1529: "cuda:0 f32[1, 512]"
# t1530 = prims.broadcast_in_dim(t1529, [1, 512, 1], [0, 1]) # t1530: "cuda:0 f32[1, 512, 1]"
# t1531 = prims.div(t1530, 4096.0) # t1531: "cuda:0 f32[1, 512, 1]"
# t1532 = prims.add(t1531, 1e-05) # t1532: "cuda:0 f32[1, 512, 1]"
# t1533 = prims.rsqrt(t1532) # t1533: "cuda:0 f32[1, 512, 1]"
# t1534 = prims.broadcast_in_dim(t1533, (1, 512, 4096), (0, 1, 2)) # t1534: "cuda:0 f32[1, 512, 4096]"
# t1535 = prims.mul(t1525, t1534) # t1535: "cuda:0 f32[1, 512, 4096]"
# t1539 = prims.convert_element_type(t1537, dtypes.float32) # t1539: "cuda:0 f32[1, 512, 4096]"
# t1540 = prims.mul(t1535, t1539) # t1540: "cuda:0 f32[1, 512, 4096]"
# t1541 = prims.convert_element_type(t1540, dtypes.bfloat16) # t1541: "cuda:0 bf16[1, 512, 4096]"
t1542 = torch.nn.functional.linear(t1541, t16, None) # t1542: "cuda:0 bf16[1, 512, 12288]"
# t1542 = ltorch.linear(t1541, t16, None) # t1542: "cuda:0 bf16[1, 512, 12288]"
# t1542 = prims.linear(t1541, t16, None) # t1542: "cuda:0 bf16[1, 512, 12288]"
t1543 = torch.reshape(t1542, (1, 512, 32, 3, 128)) # t1543: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t1543 = ltorch.reshape(t1542, (1, 512, 32, 3, 128)) # t1543: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t1543 = prims.reshape(t1542, (1, 512, 32, 3, 128)) # t1543: "cuda:0 bf16[1, 512, 32, 3, 128]"
del t1542
t1544 = torch.permute(t1543, (0, 2, 3, 1, 4)) # t1544: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t1544 = ltorch.permute(t1543, (0, 2, 3, 1, 4)) # t1544: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t1544 = prims.transpose(t1543, (0, 2, 3, 1, 4)) # t1544: "cuda:0 bf16[1, 32, 3, 512, 128]"
del t1543
(t1545, t1546, t1547) = torch.split(t1544, (1, 1, 1), 2)
# (t1545, t1546, t1547) = ltorch.split(t1544, (1, 1, 1), 2)
# t1545 = prims.slice_prim(t1544, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1545: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t1546 = prims.slice_prim(t1544, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1546: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t1547 = prims.slice_prim(t1544, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1547: "cuda:0 bf16[1, 32, 1, 512, 128]"
del t1544
t1548 = torch.reshape(t1545, (1, 32, 512, 128)) # t1548: "cuda:0 bf16[1, 32, 512, 128]"
# t1548 = ltorch.reshape(t1545, (1, 32, 512, 128)) # t1548: "cuda:0 bf16[1, 32, 512, 128]"
# t1548 = prims.reshape(t1545, (1, 32, 512, 128)) # t1548: "cuda:0 bf16[1, 32, 512, 128]"
del t1545
t1549 = torch.reshape(t1546, (1, 32, 512, 128)) # t1549: "cuda:0 bf16[1, 32, 512, 128]"
# t1549 = ltorch.reshape(t1546, (1, 32, 512, 128)) # t1549: "cuda:0 bf16[1, 32, 512, 128]"
# t1549 = prims.reshape(t1546, (1, 32, 512, 128)) # t1549: "cuda:0 bf16[1, 32, 512, 128]"
del t1546
t1550 = torch.reshape(t1547, (1, 32, 512, 128)) # t1550: "cuda:0 bf16[1, 32, 512, 128]"
# t1550 = ltorch.reshape(t1547, (1, 32, 512, 128)) # t1550: "cuda:0 bf16[1, 32, 512, 128]"
# t1550 = prims.reshape(t1547, (1, 32, 512, 128)) # t1550: "cuda:0 bf16[1, 32, 512, 128]"
del t1547
t1551 = torch_slice_prim_impl(t1548, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1551: "cuda:0 bf16[1, 32, 512, 128]"
t1566 = torch_slice_prim_impl(t1549, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1566: "cuda:0 bf16[1, 32, 512, 128]"
t1581 = torch_slice_prim_impl(t1548, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1581: "cuda:0 bf16[1, 32, 512, 0]"
del t1548
t1583 = torch_slice_prim_impl(t1549, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1583: "cuda:0 bf16[1, 32, 512, 0]"
del t1549
t1552 = torch_slice_prim_impl(t1551, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1552: "cuda:0 bf16[1, 32, 512, 64]"
t1553 = torch_slice_prim_impl(t1551, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1553: "cuda:0 bf16[1, 32, 512, 64]"
t1567 = torch_slice_prim_impl(t1566, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1567: "cuda:0 bf16[1, 32, 512, 64]"
t1568 = torch_slice_prim_impl(t1566, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1568: "cuda:0 bf16[1, 32, 512, 64]"
[t1556, t1571] = nvFusion66(t1551, t1553, t1566, t1568)
# t1554 = prims.convert_element_type(t1553, dtypes.float32) # t1554: "cuda:0 f32[1, 32, 512, 64]"
# t1555 = prims.neg(t1554) # t1555: "cuda:0 f32[1, 32, 512, 64]"
# t1556 = prims.convert_element_type(t1555, dtypes.bfloat16) # t1556: "cuda:0 bf16[1, 32, 512, 64]"
# t1569 = prims.convert_element_type(t1568, dtypes.float32) # t1569: "cuda:0 f32[1, 32, 512, 64]"
# t1570 = prims.neg(t1569) # t1570: "cuda:0 f32[1, 32, 512, 64]"
# t1571 = prims.convert_element_type(t1570, dtypes.bfloat16) # t1571: "cuda:0 bf16[1, 32, 512, 64]"
del t1553, t1568
t1572 = torch.cat((t1571, t1567), -1) # t1572: "cuda:0 bf16[1, 32, 512, 128]"
# t1572 = ltorch.cat((t1571, t1567), -1) # t1572: "cuda:0 bf16[1, 32, 512, 128]"
# t1572 = prims.cat((t1571, t1567), -1) # t1572: "cuda:0 bf16[1, 32, 512, 128]"
del t1571, t1567
t1557 = torch.cat((t1556, t1552), -1) # t1557: "cuda:0 bf16[1, 32, 512, 128]"
# t1557 = ltorch.cat((t1556, t1552), -1) # t1557: "cuda:0 bf16[1, 32, 512, 128]"
# t1557 = prims.cat((t1556, t1552), -1) # t1557: "cuda:0 bf16[1, 32, 512, 128]"
del t1556, t1552
[t1565, t1580] = nvFusion67(t154, t1551, t1557, t1566, t157, t1572)
# t1559 = prims.convert_element_type(t1551, dtypes.float32) # t1559: "cuda:0 f32[1, 32, 512, 128]"
# t1574 = prims.convert_element_type(t1566, dtypes.float32) # t1574: "cuda:0 f32[1, 32, 512, 128]"
# t1575 = prims.mul(t1574, t154) # t1575: "cuda:0 f32[1, 32, 512, 128]"
# t1577 = prims.convert_element_type(t1572, dtypes.float32) # t1577: "cuda:0 f32[1, 32, 512, 128]"
# t1578 = prims.mul(t1577, t157) # t1578: "cuda:0 f32[1, 32, 512, 128]"
# t1579 = prims.add(t1575, t1578) # t1579: "cuda:0 f32[1, 32, 512, 128]"
# t1580 = prims.convert_element_type(t1579, dtypes.bfloat16) # t1580: "cuda:0 bf16[1, 32, 512, 128]"
# t1560 = prims.mul(t1559, t154) # t1560: "cuda:0 f32[1, 32, 512, 128]"
# t1562 = prims.convert_element_type(t1557, dtypes.float32) # t1562: "cuda:0 f32[1, 32, 512, 128]"
# t1563 = prims.mul(t1562, t157) # t1563: "cuda:0 f32[1, 32, 512, 128]"
# t1564 = prims.add(t1560, t1563) # t1564: "cuda:0 f32[1, 32, 512, 128]"
# t1565 = prims.convert_element_type(t1564, dtypes.bfloat16) # t1565: "cuda:0 bf16[1, 32, 512, 128]"
del t1551, t1557, t1566, t1572
t1584 = torch.cat((t1580, t1583), -1) # t1584: "cuda:0 bf16[1, 32, 512, 128]"
# t1584 = ltorch.cat((t1580, t1583), -1) # t1584: "cuda:0 bf16[1, 32, 512, 128]"
# t1584 = prims.cat((t1580, t1583), -1) # t1584: "cuda:0 bf16[1, 32, 512, 128]"
del t1580, t1583
t1582 = torch.cat((t1565, t1581), -1) # t1582: "cuda:0 bf16[1, 32, 512, 128]"
# t1582 = ltorch.cat((t1565, t1581), -1) # t1582: "cuda:0 bf16[1, 32, 512, 128]"
# t1582 = prims.cat((t1565, t1581), -1) # t1582: "cuda:0 bf16[1, 32, 512, 128]"
del t1565, t1581
(t1585, t1586, t1587, t1588, _, _, t1589, t1590, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1582, t1584, t1550, 0.0, True, scale=0.08838834764831843)
t1592 = torch.permute(t1585, (0, 2, 1, 3)) # t1592: "cuda:0 bf16[1, 512, 32, 128]"
# t1592 = ltorch.permute(t1585, (0, 2, 1, 3)) # t1592: "cuda:0 bf16[1, 512, 32, 128]"
# t1592 = prims.transpose(t1585, (0, 2, 1, 3)) # t1592: "cuda:0 bf16[1, 512, 32, 128]"
t1593 = torch.reshape(t1592, (1, 512, 4096)) # t1593: "cuda:0 bf16[1, 512, 4096]"
# t1593 = ltorch.reshape(t1592, (1, 512, 4096)) # t1593: "cuda:0 bf16[1, 512, 4096]"
# t1593 = prims.reshape(t1592, (1, 512, 4096)) # t1593: "cuda:0 bf16[1, 512, 4096]"
del t1592
t1594 = torch.nn.functional.linear(t1593, t111, None) # t1594: "cuda:0 bf16[1, 512, 4096]"
# t1594 = ltorch.linear(t1593, t111, None) # t1594: "cuda:0 bf16[1, 512, 4096]"
# t1594 = prims.linear(t1593, t111, None) # t1594: "cuda:0 bf16[1, 512, 4096]"
[t1598, t1605, t1613] = nvFusion68(t1526, t1594, t1609)
# t1596 = prims.convert_element_type(t1526, dtypes.float32) # t1596: "cuda:0 f32[1, 512, 4096]"
# t1595 = prims.convert_element_type(t1594, dtypes.float32) # t1595: "cuda:0 f32[1, 512, 4096]"
# t1597 = prims.add(t1595, t1596) # t1597: "cuda:0 f32[1, 512, 4096]"
# t1598 = prims.convert_element_type(t1597, dtypes.bfloat16) # t1598: "cuda:0 bf16[1, 512, 4096]"
# t1600 = prims.mul(t1597, t1597) # t1600: "cuda:0 f32[1, 512, 4096]"
# t1601 = prims.sum(t1600, (2,)) # t1601: "cuda:0 f32[1, 512]"
# t1602 = prims.broadcast_in_dim(t1601, [1, 512, 1], [0, 1]) # t1602: "cuda:0 f32[1, 512, 1]"
# t1603 = prims.div(t1602, 4096.0) # t1603: "cuda:0 f32[1, 512, 1]"
# t1604 = prims.add(t1603, 1e-05) # t1604: "cuda:0 f32[1, 512, 1]"
# t1605 = prims.rsqrt(t1604) # t1605: "cuda:0 f32[1, 512, 1]"
# t1606 = prims.broadcast_in_dim(t1605, (1, 512, 4096), (0, 1, 2)) # t1606: "cuda:0 f32[1, 512, 4096]"
# t1607 = prims.mul(t1597, t1606) # t1607: "cuda:0 f32[1, 512, 4096]"
# t1611 = prims.convert_element_type(t1609, dtypes.float32) # t1611: "cuda:0 f32[1, 512, 4096]"
# t1612 = prims.mul(t1607, t1611) # t1612: "cuda:0 f32[1, 512, 4096]"
# t1613 = prims.convert_element_type(t1612, dtypes.bfloat16) # t1613: "cuda:0 bf16[1, 512, 4096]"
t1614 = torch.nn.functional.linear(t1613, t32, None) # t1614: "cuda:0 bf16[1, 512, 11008]"
# t1614 = ltorch.linear(t1613, t32, None) # t1614: "cuda:0 bf16[1, 512, 11008]"
# t1614 = prims.linear(t1613, t32, None) # t1614: "cuda:0 bf16[1, 512, 11008]"
t1615 = torch.nn.functional.linear(t1613, t48, None) # t1615: "cuda:0 bf16[1, 512, 11008]"
# t1615 = ltorch.linear(t1613, t48, None) # t1615: "cuda:0 bf16[1, 512, 11008]"
# t1615 = prims.linear(t1613, t48, None) # t1615: "cuda:0 bf16[1, 512, 11008]"
[t1629] = nvFusion69(t1614, t1615)
# t1616 = prims.convert_element_type(t1614, dtypes.float32) # t1616: "cuda:0 f32[1, 512, 11008]"
# t1617 = prims.neg(t1616) # t1617: "cuda:0 f32[1, 512, 11008]"
# t1618 = prims.exp(t1617) # t1618: "cuda:0 f32[1, 512, 11008]"
# t1619 = prims.add(1.0, t1618) # t1619: "cuda:0 f32[1, 512, 11008]"
# t1620 = prims.reciprocal(t1619) # t1620: "cuda:0 f32[1, 512, 11008]"
# t1624 = prims.mul(t1616, t1620) # t1624: "cuda:0 f32[1, 512, 11008]"
# t1627 = prims.convert_element_type(t1615, dtypes.float32) # t1627: "cuda:0 f32[1, 512, 11008]"
# t1628 = prims.mul(t1624, t1627) # t1628: "cuda:0 f32[1, 512, 11008]"
# t1629 = prims.convert_element_type(t1628, dtypes.bfloat16) # t1629: "cuda:0 bf16[1, 512, 11008]"
t1630 = torch.nn.functional.linear(t1629, t112, None) # t1630: "cuda:0 bf16[1, 512, 4096]"
# t1630 = ltorch.linear(t1629, t112, None) # t1630: "cuda:0 bf16[1, 512, 4096]"
# t1630 = prims.linear(t1629, t112, None) # t1630: "cuda:0 bf16[1, 512, 4096]"
[t1634, t1641, t1649] = nvFusion70(t1598, t1630, t1645)
# t1632 = prims.convert_element_type(t1598, dtypes.float32) # t1632: "cuda:0 f32[1, 512, 4096]"
# t1631 = prims.convert_element_type(t1630, dtypes.float32) # t1631: "cuda:0 f32[1, 512, 4096]"
# t1633 = prims.add(t1631, t1632) # t1633: "cuda:0 f32[1, 512, 4096]"
# t1634 = prims.convert_element_type(t1633, dtypes.bfloat16) # t1634: "cuda:0 bf16[1, 512, 4096]"
# t1636 = prims.mul(t1633, t1633) # t1636: "cuda:0 f32[1, 512, 4096]"
# t1637 = prims.sum(t1636, (2,)) # t1637: "cuda:0 f32[1, 512]"
# t1638 = prims.broadcast_in_dim(t1637, [1, 512, 1], [0, 1]) # t1638: "cuda:0 f32[1, 512, 1]"
# t1639 = prims.div(t1638, 4096.0) # t1639: "cuda:0 f32[1, 512, 1]"
# t1640 = prims.add(t1639, 1e-05) # t1640: "cuda:0 f32[1, 512, 1]"
# t1641 = prims.rsqrt(t1640) # t1641: "cuda:0 f32[1, 512, 1]"
# t1642 = prims.broadcast_in_dim(t1641, (1, 512, 4096), (0, 1, 2)) # t1642: "cuda:0 f32[1, 512, 4096]"
# t1643 = prims.mul(t1633, t1642) # t1643: "cuda:0 f32[1, 512, 4096]"
# t1647 = prims.convert_element_type(t1645, dtypes.float32) # t1647: "cuda:0 f32[1, 512, 4096]"
# t1648 = prims.mul(t1643, t1647) # t1648: "cuda:0 f32[1, 512, 4096]"
# t1649 = prims.convert_element_type(t1648, dtypes.bfloat16) # t1649: "cuda:0 bf16[1, 512, 4096]"
t1650 = torch.nn.functional.linear(t1649, t17, None) # t1650: "cuda:0 bf16[1, 512, 12288]"
# t1650 = ltorch.linear(t1649, t17, None) # t1650: "cuda:0 bf16[1, 512, 12288]"
# t1650 = prims.linear(t1649, t17, None) # t1650: "cuda:0 bf16[1, 512, 12288]"
t1651 = torch.reshape(t1650, (1, 512, 32, 3, 128)) # t1651: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t1651 = ltorch.reshape(t1650, (1, 512, 32, 3, 128)) # t1651: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t1651 = prims.reshape(t1650, (1, 512, 32, 3, 128)) # t1651: "cuda:0 bf16[1, 512, 32, 3, 128]"
del t1650
t1652 = torch.permute(t1651, (0, 2, 3, 1, 4)) # t1652: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t1652 = ltorch.permute(t1651, (0, 2, 3, 1, 4)) # t1652: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t1652 = prims.transpose(t1651, (0, 2, 3, 1, 4)) # t1652: "cuda:0 bf16[1, 32, 3, 512, 128]"
del t1651
(t1653, t1654, t1655) = torch.split(t1652, (1, 1, 1), 2)
# (t1653, t1654, t1655) = ltorch.split(t1652, (1, 1, 1), 2)
# t1653 = prims.slice_prim(t1652, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1653: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t1654 = prims.slice_prim(t1652, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1654: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t1655 = prims.slice_prim(t1652, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1655: "cuda:0 bf16[1, 32, 1, 512, 128]"
del t1652
t1656 = torch.reshape(t1653, (1, 32, 512, 128)) # t1656: "cuda:0 bf16[1, 32, 512, 128]"
# t1656 = ltorch.reshape(t1653, (1, 32, 512, 128)) # t1656: "cuda:0 bf16[1, 32, 512, 128]"
# t1656 = prims.reshape(t1653, (1, 32, 512, 128)) # t1656: "cuda:0 bf16[1, 32, 512, 128]"
del t1653
t1657 = torch.reshape(t1654, (1, 32, 512, 128)) # t1657: "cuda:0 bf16[1, 32, 512, 128]"
# t1657 = ltorch.reshape(t1654, (1, 32, 512, 128)) # t1657: "cuda:0 bf16[1, 32, 512, 128]"
# t1657 = prims.reshape(t1654, (1, 32, 512, 128)) # t1657: "cuda:0 bf16[1, 32, 512, 128]"
del t1654
t1658 = torch.reshape(t1655, (1, 32, 512, 128)) # t1658: "cuda:0 bf16[1, 32, 512, 128]"
# t1658 = ltorch.reshape(t1655, (1, 32, 512, 128)) # t1658: "cuda:0 bf16[1, 32, 512, 128]"
# t1658 = prims.reshape(t1655, (1, 32, 512, 128)) # t1658: "cuda:0 bf16[1, 32, 512, 128]"
del t1655
t1689 = torch_slice_prim_impl(t1656, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1689: "cuda:0 bf16[1, 32, 512, 0]"
t1691 = torch_slice_prim_impl(t1657, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1691: "cuda:0 bf16[1, 32, 512, 0]"
t1659 = torch_slice_prim_impl(t1656, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1659: "cuda:0 bf16[1, 32, 512, 128]"
del t1656
t1674 = torch_slice_prim_impl(t1657, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1674: "cuda:0 bf16[1, 32, 512, 128]"
del t1657
t1660 = torch_slice_prim_impl(t1659, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1660: "cuda:0 bf16[1, 32, 512, 64]"
t1661 = torch_slice_prim_impl(t1659, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1661: "cuda:0 bf16[1, 32, 512, 64]"
t1675 = torch_slice_prim_impl(t1674, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1675: "cuda:0 bf16[1, 32, 512, 64]"
t1676 = torch_slice_prim_impl(t1674, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1676: "cuda:0 bf16[1, 32, 512, 64]"
[t1664, t1679] = nvFusion71(t1659, t1661, t1674, t1676)
# t1662 = prims.convert_element_type(t1661, dtypes.float32) # t1662: "cuda:0 f32[1, 32, 512, 64]"
# t1663 = prims.neg(t1662) # t1663: "cuda:0 f32[1, 32, 512, 64]"
# t1664 = prims.convert_element_type(t1663, dtypes.bfloat16) # t1664: "cuda:0 bf16[1, 32, 512, 64]"
# t1677 = prims.convert_element_type(t1676, dtypes.float32) # t1677: "cuda:0 f32[1, 32, 512, 64]"
# t1678 = prims.neg(t1677) # t1678: "cuda:0 f32[1, 32, 512, 64]"
# t1679 = prims.convert_element_type(t1678, dtypes.bfloat16) # t1679: "cuda:0 bf16[1, 32, 512, 64]"
del t1661, t1676
t1680 = torch.cat((t1679, t1675), -1) # t1680: "cuda:0 bf16[1, 32, 512, 128]"
# t1680 = ltorch.cat((t1679, t1675), -1) # t1680: "cuda:0 bf16[1, 32, 512, 128]"
# t1680 = prims.cat((t1679, t1675), -1) # t1680: "cuda:0 bf16[1, 32, 512, 128]"
del t1679, t1675
t1665 = torch.cat((t1664, t1660), -1) # t1665: "cuda:0 bf16[1, 32, 512, 128]"
# t1665 = ltorch.cat((t1664, t1660), -1) # t1665: "cuda:0 bf16[1, 32, 512, 128]"
# t1665 = prims.cat((t1664, t1660), -1) # t1665: "cuda:0 bf16[1, 32, 512, 128]"
del t1664, t1660
[t1673, t1688] = nvFusion72(t154, t157, t1659, t1665, t1674, t1680)
# t1667 = prims.convert_element_type(t1659, dtypes.float32) # t1667: "cuda:0 f32[1, 32, 512, 128]"
# t1682 = prims.convert_element_type(t1674, dtypes.float32) # t1682: "cuda:0 f32[1, 32, 512, 128]"
# t1683 = prims.mul(t1682, t154) # t1683: "cuda:0 f32[1, 32, 512, 128]"
# t1685 = prims.convert_element_type(t1680, dtypes.float32) # t1685: "cuda:0 f32[1, 32, 512, 128]"
# t1686 = prims.mul(t1685, t157) # t1686: "cuda:0 f32[1, 32, 512, 128]"
# t1687 = prims.add(t1683, t1686) # t1687: "cuda:0 f32[1, 32, 512, 128]"
# t1688 = prims.convert_element_type(t1687, dtypes.bfloat16) # t1688: "cuda:0 bf16[1, 32, 512, 128]"
# t1668 = prims.mul(t1667, t154) # t1668: "cuda:0 f32[1, 32, 512, 128]"
# t1670 = prims.convert_element_type(t1665, dtypes.float32) # t1670: "cuda:0 f32[1, 32, 512, 128]"
# t1671 = prims.mul(t1670, t157) # t1671: "cuda:0 f32[1, 32, 512, 128]"
# t1672 = prims.add(t1668, t1671) # t1672: "cuda:0 f32[1, 32, 512, 128]"
# t1673 = prims.convert_element_type(t1672, dtypes.bfloat16) # t1673: "cuda:0 bf16[1, 32, 512, 128]"
del t1659, t1665, t1674, t1680
t1692 = torch.cat((t1688, t1691), -1) # t1692: "cuda:0 bf16[1, 32, 512, 128]"
# t1692 = ltorch.cat((t1688, t1691), -1) # t1692: "cuda:0 bf16[1, 32, 512, 128]"
# t1692 = prims.cat((t1688, t1691), -1) # t1692: "cuda:0 bf16[1, 32, 512, 128]"
del t1688, t1691
t1690 = torch.cat((t1673, t1689), -1) # t1690: "cuda:0 bf16[1, 32, 512, 128]"
# t1690 = ltorch.cat((t1673, t1689), -1) # t1690: "cuda:0 bf16[1, 32, 512, 128]"
# t1690 = prims.cat((t1673, t1689), -1) # t1690: "cuda:0 bf16[1, 32, 512, 128]"
del t1673, t1689
(t1693, t1694, t1695, t1696, _, _, t1697, t1698, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1690, t1692, t1658, 0.0, True, scale=0.08838834764831843)
t1700 = torch.permute(t1693, (0, 2, 1, 3)) # t1700: "cuda:0 bf16[1, 512, 32, 128]"
# t1700 = ltorch.permute(t1693, (0, 2, 1, 3)) # t1700: "cuda:0 bf16[1, 512, 32, 128]"
# t1700 = prims.transpose(t1693, (0, 2, 1, 3)) # t1700: "cuda:0 bf16[1, 512, 32, 128]"
t1701 = torch.reshape(t1700, (1, 512, 4096)) # t1701: "cuda:0 bf16[1, 512, 4096]"
# t1701 = ltorch.reshape(t1700, (1, 512, 4096)) # t1701: "cuda:0 bf16[1, 512, 4096]"
# t1701 = prims.reshape(t1700, (1, 512, 4096)) # t1701: "cuda:0 bf16[1, 512, 4096]"
del t1700
t1702 = torch.nn.functional.linear(t1701, t113, None) # t1702: "cuda:0 bf16[1, 512, 4096]"
# t1702 = ltorch.linear(t1701, t113, None) # t1702: "cuda:0 bf16[1, 512, 4096]"
# t1702 = prims.linear(t1701, t113, None) # t1702: "cuda:0 bf16[1, 512, 4096]"
[t1706, t1713, t1721] = nvFusion73(t1634, t1702, t1717)
# t1704 = prims.convert_element_type(t1634, dtypes.float32) # t1704: "cuda:0 f32[1, 512, 4096]"
# t1703 = prims.convert_element_type(t1702, dtypes.float32) # t1703: "cuda:0 f32[1, 512, 4096]"
# t1705 = prims.add(t1703, t1704) # t1705: "cuda:0 f32[1, 512, 4096]"
# t1706 = prims.convert_element_type(t1705, dtypes.bfloat16) # t1706: "cuda:0 bf16[1, 512, 4096]"
# t1708 = prims.mul(t1705, t1705) # t1708: "cuda:0 f32[1, 512, 4096]"
# t1709 = prims.sum(t1708, (2,)) # t1709: "cuda:0 f32[1, 512]"
# t1710 = prims.broadcast_in_dim(t1709, [1, 512, 1], [0, 1]) # t1710: "cuda:0 f32[1, 512, 1]"
# t1711 = prims.div(t1710, 4096.0) # t1711: "cuda:0 f32[1, 512, 1]"
# t1712 = prims.add(t1711, 1e-05) # t1712: "cuda:0 f32[1, 512, 1]"
# t1713 = prims.rsqrt(t1712) # t1713: "cuda:0 f32[1, 512, 1]"
# t1714 = prims.broadcast_in_dim(t1713, (1, 512, 4096), (0, 1, 2)) # t1714: "cuda:0 f32[1, 512, 4096]"
# t1715 = prims.mul(t1705, t1714) # t1715: "cuda:0 f32[1, 512, 4096]"
# t1719 = prims.convert_element_type(t1717, dtypes.float32) # t1719: "cuda:0 f32[1, 512, 4096]"
# t1720 = prims.mul(t1715, t1719) # t1720: "cuda:0 f32[1, 512, 4096]"
# t1721 = prims.convert_element_type(t1720, dtypes.bfloat16) # t1721: "cuda:0 bf16[1, 512, 4096]"
t1722 = torch.nn.functional.linear(t1721, t33, None) # t1722: "cuda:0 bf16[1, 512, 11008]"
# t1722 = ltorch.linear(t1721, t33, None) # t1722: "cuda:0 bf16[1, 512, 11008]"
# t1722 = prims.linear(t1721, t33, None) # t1722: "cuda:0 bf16[1, 512, 11008]"
t1723 = torch.nn.functional.linear(t1721, t49, None) # t1723: "cuda:0 bf16[1, 512, 11008]"
# t1723 = ltorch.linear(t1721, t49, None) # t1723: "cuda:0 bf16[1, 512, 11008]"
# t1723 = prims.linear(t1721, t49, None) # t1723: "cuda:0 bf16[1, 512, 11008]"
[t1737] = nvFusion74(t1722, t1723)
# t1724 = prims.convert_element_type(t1722, dtypes.float32) # t1724: "cuda:0 f32[1, 512, 11008]"
# t1725 = prims.neg(t1724) # t1725: "cuda:0 f32[1, 512, 11008]"
# t1726 = prims.exp(t1725) # t1726: "cuda:0 f32[1, 512, 11008]"
# t1727 = prims.add(1.0, t1726) # t1727: "cuda:0 f32[1, 512, 11008]"
# t1728 = prims.reciprocal(t1727) # t1728: "cuda:0 f32[1, 512, 11008]"
# t1732 = prims.mul(t1724, t1728) # t1732: "cuda:0 f32[1, 512, 11008]"
# t1735 = prims.convert_element_type(t1723, dtypes.float32) # t1735: "cuda:0 f32[1, 512, 11008]"
# t1736 = prims.mul(t1732, t1735) # t1736: "cuda:0 f32[1, 512, 11008]"
# t1737 = prims.convert_element_type(t1736, dtypes.bfloat16) # t1737: "cuda:0 bf16[1, 512, 11008]"
t1738 = torch.nn.functional.linear(t1737, t114, None) # t1738: "cuda:0 bf16[1, 512, 4096]"
# t1738 = ltorch.linear(t1737, t114, None) # t1738: "cuda:0 bf16[1, 512, 4096]"
# t1738 = prims.linear(t1737, t114, None) # t1738: "cuda:0 bf16[1, 512, 4096]"
[t1742, t1749, t1757] = nvFusion75(t1706, t1738, t1753)
# t1740 = prims.convert_element_type(t1706, dtypes.float32) # t1740: "cuda:0 f32[1, 512, 4096]"
# t1739 = prims.convert_element_type(t1738, dtypes.float32) # t1739: "cuda:0 f32[1, 512, 4096]"
# t1741 = prims.add(t1739, t1740) # t1741: "cuda:0 f32[1, 512, 4096]"
# t1742 = prims.convert_element_type(t1741, dtypes.bfloat16) # t1742: "cuda:0 bf16[1, 512, 4096]"
# t1744 = prims.mul(t1741, t1741) # t1744: "cuda:0 f32[1, 512, 4096]"
# t1745 = prims.sum(t1744, (2,)) # t1745: "cuda:0 f32[1, 512]"
# t1746 = prims.broadcast_in_dim(t1745, [1, 512, 1], [0, 1]) # t1746: "cuda:0 f32[1, 512, 1]"
# t1747 = prims.div(t1746, 4096.0) # t1747: "cuda:0 f32[1, 512, 1]"
# t1748 = prims.add(t1747, 1e-05) # t1748: "cuda:0 f32[1, 512, 1]"
# t1749 = prims.rsqrt(t1748) # t1749: "cuda:0 f32[1, 512, 1]"
# t1750 = prims.broadcast_in_dim(t1749, (1, 512, 4096), (0, 1, 2)) # t1750: "cuda:0 f32[1, 512, 4096]"
# t1751 = prims.mul(t1741, t1750) # t1751: "cuda:0 f32[1, 512, 4096]"
# t1755 = prims.convert_element_type(t1753, dtypes.float32) # t1755: "cuda:0 f32[1, 512, 4096]"
# t1756 = prims.mul(t1751, t1755) # t1756: "cuda:0 f32[1, 512, 4096]"
# t1757 = prims.convert_element_type(t1756, dtypes.bfloat16) # t1757: "cuda:0 bf16[1, 512, 4096]"
t1758 = torch.nn.functional.linear(t1757, t18, None) # t1758: "cuda:0 bf16[1, 512, 12288]"
# t1758 = ltorch.linear(t1757, t18, None) # t1758: "cuda:0 bf16[1, 512, 12288]"
# t1758 = prims.linear(t1757, t18, None) # t1758: "cuda:0 bf16[1, 512, 12288]"
t1759 = torch.reshape(t1758, (1, 512, 32, 3, 128)) # t1759: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t1759 = ltorch.reshape(t1758, (1, 512, 32, 3, 128)) # t1759: "cuda:0 bf16[1, 512, 32, 3, 128]"
# t1759 = prims.reshape(t1758, (1, 512, 32, 3, 128)) # t1759: "cuda:0 bf16[1, 512, 32, 3, 128]"
del t1758
t1760 = torch.permute(t1759, (0, 2, 3, 1, 4)) # t1760: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t1760 = ltorch.permute(t1759, (0, 2, 3, 1, 4)) # t1760: "cuda:0 bf16[1, 32, 3, 512, 128]"
# t1760 = prims.transpose(t1759, (0, 2, 3, 1, 4)) # t1760: "cuda:0 bf16[1, 32, 3, 512, 128]"
del t1759
(t1761, t1762, t1763) = torch.split(t1760, (1, 1, 1), 2)
# (t1761, t1762, t1763) = ltorch.split(t1760, (1, 1, 1), 2)
# t1761 = prims.slice_prim(t1760, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1761: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t1762 = prims.slice_prim(t1760, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1762: "cuda:0 bf16[1, 32, 1, 512, 128]"
# t1763 = prims.slice_prim(t1760, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1763: "cuda:0 bf16[1, 32, 1, 512, 128]"
del t1760
t1764 = torch.reshape(t1761, (1, 32, 512, 128)) # t1764: "cuda:0 bf16[1, 32, 512, 128]"
# t1764 = ltorch.reshape(t1761, (1, 32, 512, 128)) # t1764: "cuda:0 bf16[1, 32, 512, 128]"
# t1764 = prims.reshape(t1761, (1, 32, 512, 128)) # t1764: "cuda:0 bf16[1, 32, 512, 128]"
del t1761
t1765 = torch.reshape(t1762, (1, 32, 512, 128)) # t1765: "cuda:0 bf16[1, 32, 512, 128]"
# t1765 = ltorch.reshape(t1762, (1, 32, 512, 128)) # t1765: "cuda:0 bf16[1, 32, 512, 128]"
# t1765 = prims.reshape(t1762, (1, 32, 512, 128)) # t1765: "cuda:0 bf16[1, 32, 512, 128]"
del t1762
t1766 = torch.reshape(t1763, (1, 32, 512, 128)) # t1766: "cuda:0 bf16[1, 32, 512, 128]"
# t1766 = ltorch.reshape(t1763, (1, 32, 512, 128)) # t1766: "cuda:0 bf16[1, 32, 512, 128]"
# t1766 = prims.reshape(t1763, (1, 32, 512, 128)) # t1766: "cuda:0 bf16[1, 32, 512, 128]"
del t1763
t1767 = torch_slice_prim_impl(t1764, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1767: "cuda:0 bf16[1, 32, 512, 128]"
t1782 = torch_slice_prim_impl(t1765, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1782: "cuda:0 bf16[1, 32, 512, 128]"
t1797 = torch_slice_prim_impl(t1764, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1797: "cuda:0 bf16[1, 32, 512, 0]"
del t1764
t1799 = torch_slice_prim_impl(t1765, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1799: "cuda:0 bf16[1, 32, 512, 0]"
del t1765
t1768 = torch_slice_prim_impl(t1767, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1768: "cuda:0 bf16[1, 32, 512, 64]"
t1769 = torch_slice_prim_impl(t1767, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1769: "cuda:0 bf16[1, 32, 512, 64]"
t1783 = torch_slice_prim_impl(t1782, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1783: "cuda:0 bf16[1, 32, 512, 64]"
t1784 = torch_slice_prim_impl(t1782, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1784: "cuda:0 bf16[1, 32, 512, 64]"
[t1772, t1787] = nvFusion76(t1767, t1769, t1782, t1784)
# t1770 = prims.convert_element_type(t1769, dtypes.float32) # t1770: "cuda:0 f32[1, 32, 512, 64]"
# t1771 = prims.neg(t1770) # t1771: "cuda:0 f32[1, 32, 512, 64]"
# t1772 = prims.convert_element_type(t1771, dtypes.bfloat16) # t1772: "cuda:0 bf16[1, 32, 512, 64]"
# t1785 = prims.convert_element_type(t1784, dtypes.float32) # t1785: "cuda:0 f32[1, 32, 512, 64]"
# t1786 = prims.neg(t1785) # t1786: "cuda:0 f32[1, 32, 512, 64]"
# t1787 = prims.convert_element_type(t1786, dtypes.bfloat16) # t1787: "cuda:0 bf16[1, 32, 512, 64]"
del t1769, t1784
t1788 = torch.cat((t1787, t1783), -1) # t1788: "cuda:0 bf16[1, 32, 512, 128]"
# t1788 = ltorch.cat((t1787, t1783), -1) # t1788: "cuda:0 bf16[1, 32, 512, 128]"
# t1788 = prims.cat((t1787, t1783), -1) # t1788: "cuda:0 bf16[1, 32, 512, 128]"
del t1787, t1783
t1773 = torch.cat((t1772, t1768), -1) # t1773: "cuda:0 bf16[1, 32, 512, 128]"
# t1773 = ltorch.cat((t1772, t1768), -1) # t1773: "cuda:0 bf16[1, 32, 512, 128]"
# t1773 = prims.cat((t1772, t1768), -1) # t1773: "cuda:0 bf16[1, 32, 512, 128]"
del t1772, t1768
[t1781, t1796] = nvFusion77(t154, t157, t1767, t1773, t1782, t1788)
# t1775 = prims.convert_element_type(t1767, dtypes.float32) # t1775: "cuda:0 f32[1, 32, 512, 128]"
# t1790 = prims.convert_element_type(t1782, dtypes.float32) # t1790: "cuda:0 f32[1, 32, 512, 128]"
# t1791 = prims.mul(t1790, t154) # t1791: "cuda:0 f32[1, 32, 512, 128]"
# t1793 = prims.convert_element_type(t1788, dtypes.float32) # t1793: "cuda:0 f32[1, 32, 512, 128]"
# t1794 = prims.mul(t1793, t157) # t1794: "cuda:0 f32[1, 32, 512, 128]"
# t1795 = prims.add(t1791, t1794) # t1795: "cuda:0 f32[1, 32, 512, 128]"
# t1796 = prims.convert_element_type(t1795, dtypes.bfloat16) # t1796: "cuda:0 bf16[1, 32, 512, 128]"
# t1776 = prims.mul(t1775, t154) # t1776: "cuda:0 f32[1, 32, 512, 128]"
# t1778 = prims.convert_element_type(t1773, dtypes.float32) # t1778: "cuda:0 f32[1, 32, 512, 128]"
# t1779 = prims.mul(t1778, t157) # t1779: "cuda:0 f32[1, 32, 512, 128]"
# t1780 = prims.add(t1776, t1779) # t1780: "cuda:0 f32[1, 32, 512, 128]"
# t1781 = prims.convert_element_type(t1780, dtypes.bfloat16) # t1781: "cuda:0 bf16[1, 32, 512, 128]"
del t1767, t1773, t1782, t1788
t1800 = torch.cat((t1796, t1799), -1) # t1800: "cuda:0 bf16[1, 32, 512, 128]"
# t1800 = ltorch.cat((t1796, t1799), -1) # t1800: "cuda:0 bf16[1, 32, 512, 128]"
# t1800 = prims.cat((t1796, t1799), -1) # t1800: "cuda:0 bf16[1, 32, 512, 128]"
del t1796, t1799
t1798 = torch.cat((t1781, t1797), -1) # t1798: "cuda:0 bf16[1, 32, 512, 128]"
# t1798 = ltorch.cat((t1781, t1797), -1) # t1798: "cuda:0 bf16[1, 32, 512, 128]"
# t1798 = prims.cat((t1781, t1797), -1) # t1798: "cuda:0 bf16[1, 32, 512, 128]"
del t1781, t1797
(t1801, t1802, t1803, t1804, _, _, t1805, t1806, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1798, t1800, t1766, 0.0, True, scale=0.08838834764831843)
t1808 = torch.permute(t1801, (0, 2, 1, 3)) # t1808: "cuda:0 bf16[1, 512, 32, 128]"
# t1808 = ltorch.permute(t1801, (0, 2, 1, 3)) # t1808: "cuda:0 bf16[1, 512, 32, 128]"
# t1808 = prims.transpose(t1801, (0, 2, 1, 3)) # t1808: "cuda:0 bf16[1, 512, 32, 128]"
t1809 = torch.reshape(t1808, (1, 512, 4096)) # t1809: "cuda:0 bf16[1, 512, 4096]"
# t1809 = ltorch.reshape(t1808, (1, 512, 4096)) # t1809: "cuda:0 bf16[1, 512, 4096]"
# t1809 = prims.reshape(t1808, (1, 512, 4096)) # t1809: "cuda:0 bf16[1, 512, 4096]"
del t1808
t1810 = torch.nn.functional.linear(t1809, t115, None) # t1810: "cuda:0 bf16[1, 512, 4096]"
# t1810 = ltorch.linear(t1809, t115, None) # t1810: "cuda:0 bf16[1, 512, 4096]"
# t1810 = prims.linear(t1809, t115, None) # t1810: "cuda:0 bf16[1, 512, 4096]"
[t1814, t1821, t1829] = nvFusion78(t1742, t1810, t1825)
# t1812 = prims.convert_element_type(t1742, dtypes.float32) # t1812: "cuda:0 f32[1, 512, 4096]"
# t1811 = prims.convert_element_type(t1810, dtypes.float32) # t1811: "cuda:0 f32[1, 512, 4096]"
# t1813 = prims.add(t1811, t1812) # t1813: "cuda:0 f32[1, 512, 4096]"
# t1814 = prims.convert_element_type(t1813, dtypes.bfloat16) # t1814: "cuda:0 bf16[1, 512, 4096]"
# t1816 = prims.mul(t1813, t1813) # t1816: "cuda:0 f32[1, 512, 4096]"
# t1817 = prims.sum(t1816, (2,)) # t1817: "cuda:0 f32[1, 512]"
# t1818 = prims.broadcast_in_dim(t1817, [1, 512, 1], [0, 1]) # t1818: "cuda:0 f32[1, 512, 1]"
# t1819 = prims.div(t1818, 4096.0) # t1819: "cuda:0 f32[1, 512, 1]"
# t1820 = prims.add(t1819, 1e-05) # t1820: "cuda:0 f32[1, 512, 1]"
# t1821 = prims.rsqrt(t1820) # t1821: "cuda:0 f32[1, 512, 1]"
# t1822 = prims.broadcast_in_dim(t1821, (1, 512, 4096), (0, 1, 2)) # t1822: "cuda:0 f32[1, 512, 4096]"
# t1823 = prims.mul(t1813, t1822) # t1823: "cuda:0 f32[1, 512, 4096]"
# t1827 = prims.convert_element_type(t1825, dtypes.float32) # t1827: "cuda:0 f32[1, 512, 4096]"
# t1828 = prims.mul(t1823, t1827) # t1828: "cuda:0 f32[1, 512, 4096]"
# t1829 = prims.convert_element_type(t1828, dtypes.bfloat16) # t1829: "cuda:0 bf16[1, 512, 4096]"
t1831 = torch.nn.functional.linear(t1829, t50, None) # t1831: "cuda:0 bf16[1, 512, 11008]"
# t1831 = ltorch.linear(t1829, t50, None) # t1831: "cuda:0 bf16[1, 512, 11008]"
# t1831 = prims.linear(t1829, t50, None) # t1831: "cuda:0 bf16[1, 512, 11008]"
t1830 = torch.nn.functional.linear(t1829, t34, None) # t1830: "cuda:0 bf16[1, 512, 11008]"
# t1830 = ltorch.linear(t1829, t34, None) # t1830: "cuda:0 bf16[1, 512, 11008]"
# t1830 = prims.linear(t1829, t34, None) # t1830: "cuda:0 bf16[1, 512, 11008]"
[t1845] = nvFusion79(t1830, t1831)
# t1832 = prims.convert_element_type(t1830, dtypes.float32) # t1832: "cuda:0 f32[1, 512, 11008]"
# t1833 = prims.neg(t1832) # t1833: "cuda:0 f32[1, 512, 11008]"
# t1834 = prims.exp(t1833) # t1834: "cuda:0 f32[1, 512, 11008]"
# t1835 = prims.add(1.0, t1834) # t1835: "cuda:0 f32[1, 512, 11008]"
# t1836 = prims.reciprocal(t1835) # t1836: "cuda:0 f32[1, 512, 11008]"
# t1840 = prims.mul(t1832, t1836) # t1840: "cuda:0 f32[1, 512, 11008]"
# t1843 = prims.convert_element_type(t1831, dtypes.float32) # t1843: "cuda:0 f32[1, 512, 11008]"
# t1844 = prims.mul(t1840, t1843) # t1844: "cuda:0 f32[1, 512, 11008]"
# t1845 = prims.convert_element_type(t1844, dtypes.bfloat16) # t1845: "cuda:0 bf16[1, 512, 11008]"
t1846 = torch.nn.functional.linear(t1845, t116, None) # t1846: "cuda:0 bf16[1, 512, 4096]"
# t1846 = ltorch.linear(t1845, t116, None) # t1846: "cuda:0 bf16[1, 512, 4096]"
# t1846 = prims.linear(t1845, t116, None) # t1846: "cuda:0 bf16[1, 512, 4096]"
[t1857, t1865] = nvFusion80(t1814, t1846, t1861)
# t1848 = prims.convert_element_type(t1814, dtypes.float32) # t1848: "cuda:0 f32[1, 512, 4096]"
# t1847 = prims.convert_element_type(t1846, dtypes.float32) # t1847: "cuda:0 f32[1, 512, 4096]"
# t1849 = prims.add(t1847, t1848) # t1849: "cuda:0 f32[1, 512, 4096]"
# t1852 = prims.mul(t1849, t1849) # t1852: "cuda:0 f32[1, 512, 4096]"
# t1853 = prims.sum(t1852, (2,)) # t1853: "cuda:0 f32[1, 512]"
# t1854 = prims.broadcast_in_dim(t1853, [1, 512, 1], [0, 1]) # t1854: "cuda:0 f32[1, 512, 1]"
# t1855 = prims.div(t1854, 4096.0) # t1855: "cuda:0 f32[1, 512, 1]"
# t1856 = prims.add(t1855, 1e-05) # t1856: "cuda:0 f32[1, 512, 1]"
# t1857 = prims.rsqrt(t1856) # t1857: "cuda:0 f32[1, 512, 1]"
# t1858 = prims.broadcast_in_dim(t1857, (1, 512, 4096), (0, 1, 2)) # t1858: "cuda:0 f32[1, 512, 4096]"
# t1859 = prims.mul(t1849, t1858) # t1859: "cuda:0 f32[1, 512, 4096]"
# t1863 = prims.convert_element_type(t1861, dtypes.float32) # t1863: "cuda:0 f32[1, 512, 4096]"
# t1864 = prims.mul(t1859, t1863) # t1864: "cuda:0 f32[1, 512, 4096]"
# t1865 = prims.convert_element_type(t1864, dtypes.bfloat16) # t1865: "cuda:0 bf16[1, 512, 4096]"
t1866 = torch.nn.functional.linear(t1865, t51, None) # t1866: "cuda:0 bf16[1, 512, 32000]"
# t1866 = ltorch.linear(t1865, t51, None) # t1866: "cuda:0 bf16[1, 512, 32000]"
# t1866 = prims.linear(t1865, t51, None) # t1866: "cuda:0 bf16[1, 512, 32000]"
return {'output': t1866, 'flat_args': [t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, t18, t19, t20, t21, t22, t23, t24, t25, t26, t27, t28, t29, t30, t31, t32, t33, t34, t35, t36, t37, t38, t39, t40, t41, t42, t43, t44, t45, t46, t47, t48, t49, t50, t51, t52, t53, t54, t55, t56, t57, t58, t59, t60, t61, t62, t63, t64, t65, t66, t67, t68, t69, t70, t71, t72, t73, t74, t75, t76, t77, t78, t79, t80, t81, t82, t83, t84, t85, t86, t87, t88, t89, t90, t91, t92, t93, t94, t95, t96, t97, t98, t99, t100, t101, t102, t103, t104, t105, t106, t107, t108, t109, t110, t111, t112, t113, t114, t115, t116, t117], 'flat_output': (t1866,)}, ((t0, t10, t100, t1001, t101, t1010, t102, t103, t104, t1042, t1044, t1045, t1046, t1047, t1048, t1049, t105, t1050, t1053, t1054, t1058, t106, t1065, t1069, t107, t1073, t1074, t1075, t108, t1089, t109, t1090, t1094, t11, t110, t1101, t1105, t1109, t111, t1118, t112, t113, t114, t115, t1150, t1152, t1153, t1154, t1155, t1156, t1157, t1158, t116, t1161, t1162, t1166, t1173, t1177, t1181, t1182, t1183, t1197, t1198, t12, t1202, t1209, t1213, t1217, t122, t1226, t1258, t1260, t1261, t1262, t1263, t1264, t1265, t1266, t1269, t1270, t1274, t1281, t1285, t1289, t129, t1290, t1291, t13, t1305, t1306, t1310, t1317, t1321, t1325, t133, t1334, t1366, t1368, t1369, t137, t1370, t1371, t1372, t1373, t1374, t1377, t1378, t1382, t1389, t1393, t1397, t1398, t1399, t14, t1413, t1414, t1418, t1425, t1429, t1433, t1442, t146, t1474, t1476, t1477, t1478, t1479, t1480, t1481, t1482, t1485, t1486, t1490, t1497, t15, t1501, t1505, t1506, t1507, t1521, t1522, t1526, t1533, t1537, t154, t1541, t1550, t157, t1582, t1584, t1585, t1586, t1587, t1588, t1589, t1590, t1593, t1594, t1598, t16, t1605, t1609, t1613, t1614, t1615, t1629, t1630, t1634, t1641, t1645, t1649, t1658, t1690, t1692, t1693, t1694, t1695, t1696, t1697, t1698, t17, t1701, t1702, t1706, t1713, t1717, t1721, t1722, t1723, t1737, t1738, t1742, t1749, t1753, t1757, t1766, t178, t1798, t18, t180, t1800, t1801, t1802, t1803, t1804, t1805, t1806, t1809, t181, t1810, t1814, t182, t1821, t1825, t1829, t183, t1830, t1831, t184, t1845, t1846, t185, t1857, t186, t1861, t1865, t189, t19, t190, t194, t20, t201, t205, t209, t21, t210, t211, t22, t225, t226, t23, t230, t237, t24, t241, t245, t25, t254, t26, t27, t28, t286, t288, t289, t29, t290, t291, t292, t293, t294, t297, t298, t3, t30, t302, t309, t31, t313, t317, t318, t319, t32, t33, t333, t334, t338, t34, t345, t349, t35, t353, t36, t362, t37, t38, t39, t394, t396, t397, t398, t399, t4, t40, t400, t401, t402, t405, t406, t41, t410, t417, t42, t421, t425, t426, t427, t43, t44, t441, t442, t446, t45, t453, t457, t46, t461, t47, t470, t48, t49, t5, t50, t502, t504, t505, t506, t507, t508, t509, t51, t510, t513, t514, t518, t525, t529, t533, t534, t535, t549, t550, t554, t561, t565, t569, t578, t6, t610, t612, t613, t614, t615, t616, t617, t618, t621, t622, t626, t633, t637, t641, t642, t643, t657, t658, t662, t669, t673, t677, t686, t7, t718, t720, t721, t722, t723, t724, t725, t726, t729, t730, t734, t741, t745, t749, t750, t751, t765, t766, t770, t777, t781, t785, t794, t8, t826, t828, t829, t830, t831, t832, t833, t834, t837, t838, t842, t849, t85, t853, t857, t858, t859, t86, t87, t873, t874, t878, t88, t885, t889, t89, t893, t9, t90, t902, t91, t92, t93, t934, t936, t937, t938, t939, t94, t940, t941, t942, t945, t946, t95, t950, t957, t96, t961, t965, t966, t967, t97, t98, t981, t982, t986, t99, t993, t997), (False, True, False, True, True, True, True, 
True, True, True, True, True, True, True, True, True, True, True, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 0.0, 4096.0, 4096.0, 0.08838834764831843, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 32000, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2))
Well, that is quite a bit to look through. But here is the key point: the function now returns a lot more than just the output. This is because Thunder applies the same treatment to the backward pass and, to that end, saves information from the forward. You can see a hint of this in the output, which has a ThunderFunctionBackward as its grad_fn. (You can inspect the backward trace with thunder.last_backward_traces(thunder_model)[-1].)
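For a quick look at that backward, you can print the trace directly; a minimal sketch, assuming the thunder_model from above is still in scope:
bwd_trace = thunder.last_backward_traces(thunder_model)[-1]
print(bwd_trace)  # the optimized backward computation that pairs with the forward trace above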
[10]:
actual
[10]:
tensor([[[ 0.4160, -0.4668, 1.1016, ..., 0.5430, 1.2656, 0.2891],
[ 0.3320, -0.0557, 1.7891, ..., 1.0703, 1.0078, 1.2266],
[ 0.6836, -0.2871, 0.9531, ..., 0.0806, 0.7070, 0.8477],
...,
[ 0.7695, -0.1260, 0.7266, ..., 0.1118, -0.0238, -1.2656],
[-0.7773, -0.5547, -0.3047, ..., -0.1807, 0.1895, 0.6875],
[ 0.8867, 0.4766, 0.3984, ..., 0.0815, -0.0879, 0.3477]]],
device='cuda:0', grad_fn=<ThunderFunctionBackward>)
Let us clean up a bit.
[11]:
del actual, expected
import gc
gc.collect();
But is it faster? Yes!
[12]:
%timeit r = m(inp); torch.autograd.grad(r.sum(), m.parameters()); torch.cuda.synchronize()
%timeit r = thunder_model(inp); torch.autograd.grad(r.sum(), m.parameters()); torch.cuda.synchronize()
240 ms ± 105 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
208 ms ± 147 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
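As a rough way to connect this speedup to the fusions in the trace above, one can count the nvFuser fusion calls with a simple string search; this is just a sketch, and the exact count depends on the model configuration:
fwd_trace_str = str(thunder.last_traces(thunder_model)[-1])
num_fusions = sum('nvFusion' in line and not line.lstrip().startswith('#') for line in fwd_trace_str.split('\n'))
print(num_fusions)  # number of nvFuser fusion regions in the forward trace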
So far, so good! Thunder should work with LitGPT today, and we are busy adding the support required to run other models as well!
[13]:
del m, thunder_model
import gc
gc.collect()
torch.cuda.empty_cache()
Distributed with Thunder
Those Large Language Models are called large for a reason, and the memory of a single GPU is invariably too small. So we need multiple GPUs.
Happily, Thunder sports an FSDP interface for using multiple cards in one box.
You still need to set up the process group, but as far as the model is concerned,
model = thunder.jit(thunder.distributed.fsdp(model))
is all you need. Because it is tricky to run multiprocessing from notebooks, we write a small example into a file and run it through torchrun.
Check out our LitGPT Thunder examples for complete distributed training and finetuning!
[14]:
%%writefile zero_to_thunder_fsdp_simple_example.py
from thunder.tests.litgpt_model import GPT, Config
import os
import torch, torch.distributed
import thunder, thunder.distributed
# Create Model
# NOTE: We create the model on CPU.
device='cpu'
torch.set_default_dtype(torch.bfloat16)
cfg = Config.from_name('Llama-2-7b-hf')
cfg.n_layer = 8 # fewer layers
model = GPT(cfg)
# Setup for distributed
torch.distributed.init_process_group(backend='nccl')
rank = int(os.environ["LOCAL_RANK"])
device = f"cuda:{rank}"
x = torch.randint(1, model.config.vocab_size, (1, 1024), device=device)
# thunder.distributed.fsdp takes care of moving the parameter
# shard to the correct GPU for the current process.
model = thunder.jit(thunder.distributed.fsdp(model)) # <---------------------------------------
print(f"rank {rank} computing")
# Run the forward pass.
for i in range(10):
res = model(x)
res.sum().backward()
Overwriting zero_to_thunder_fsdp_simple_example.py
Now we can launch it. Note that you need two GPUs for this to run correctly.
[15]:
# commented out for CI limitations, see https://github.com/Lightning-AI/lightning-thunder/issues/465
# !torchrun --standalone --nnodes=1 --nproc_per_node=2 zero_to_thunder_fsdp_simple_example.py
W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757]
W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] *****************************************
W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] *****************************************
rank 1 computing
rank 0 computing
So there it is: FSDP, just by wrapping the model in fsdp.
Extending Thunder
But we promised that Thunder is extensible. Let's find out what that looks like.
Specifically, we will incorporate the fast RoPE embedding kernel from the great Unsloth project into our model (note that NVFuser also creates a fused kernel for this).
In Thunder, extensions (as well as most built-in optimizations, which use the exact same mechanism) work through executors that handle operations. Let us define one.
[16]:
my_ex = thunder.extend.OperatorExecutor('my_ex', version='0.0.1')
thunder.extend.register_executor(my_ex)
[16]:
my_ex
For our base implementation, we take the code from LitGPT's implementation. Because we will demonstrate Thunder's ability to divert functions in the model, we make a copy here that will not be diverted; the operator registration itself follows in the next section.
[17]:
import litgpt
def apply_rope_copy(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
head_size = x.size(-1)
x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)
x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2)
rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs)
roped = (x * cos) + (rotated * sin)
return roped.to(dtype=x.dtype)
Registering operators
Say we have a function apply_rope that applies the RoPE transformation in PyTorch.
For each operator, we define a meta function that only computes the metadata (like shapes) of the outputs and an implementation function that does the actual computation. We then register the pair with our executor using the register_operator function and tell it to use the new symbol in place of the original function litgpt.model.apply_rope.
[18]:
import torch, thunder
from thunder.tests.litgpt_model import GPT
from thunder import TensorProxy
def apply_rope_impl(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
return litgpt.model.apply_rope(x, cos, sin)
def apply_rope_meta(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> TensorProxy:
return TensorProxy(like=x)
apply_rope = my_ex.register_operator('apply_rope', like=apply_rope_meta, fn=apply_rope_impl,
replaces=litgpt.model.apply_rope)
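As an aside, the symbol returned by register_operator is itself callable from code that we jit, so we could also call it directly instead of relying on the diversion of litgpt.model.apply_rope. A minimal sketch (not executed here; the names are just for illustration):
def direct_apply_rope(x, cos, sin):
    # calls the symbol we registered above
    return apply_rope(x, cos, sin)

thunder_direct_apply_rope = thunder.jit(direct_apply_rope, executors=(my_ex,) + thunder.get_default_executors())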
Testing our new operator
[19]:
with torch.device('cuda'): m = GPT.from_name('llama2-like'); Q = torch.randn(2, 128, 4096, 16)
def test_apply_rope(x, m):
return litgpt.model.apply_rope(x, m.cos, m.sin)
thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors())
expected = test_apply_rope(Q, m); actual = thunder_apply_rope(Q, m); print("deviation:", (expected - actual).abs().max().item())
thunder.last_traces(thunder_apply_rope)[-1]
deviation: 0.0
[19]:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(x, t_1_cos, t_1_sin):
# x: "cuda:0 bf16[2, 128, 4096, 16]"
# t_1_cos: "cuda:0 f32[4096, 16]"
# t_1_sin: "cuda:0 f32[4096, 16]"
t2 = apply_rope(x, t_1_cos, t_1_sin) # t2: "cuda:0 bf16[2, 128, 4096, 16]"
del x, t_1_cos, t_1_sin
return t2
Optimized kernels
But why did we do this? Well, we can now layer a faster implementation on top. For this we take the Unsloth fast RoPE embedding kernels and move the bits that were in the forward and backward of their autograd.Function into our implementation functions. Note that we include the transpositions in our setup for compatibility with the LitGPT implementation. This change in the memory layout of the operands can have a large effect on the runtime, though, so our timings are likely not representative of the ones the Unsloth project gets in their use of the same Triton kernels.
[20]:
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import triton
import triton.language as tl
import torch
MAX_FUSED_SIZE = 65536
next_power_of_2 = triton.next_power_of_2
def calculate_settings(n):
BLOCK_SIZE = next_power_of_2(n)
if BLOCK_SIZE > MAX_FUSED_SIZE:
raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
num_warps = 4
if BLOCK_SIZE >= 32768: num_warps = 32
elif BLOCK_SIZE >= 8192: num_warps = 16
elif BLOCK_SIZE >= 2048: num_warps = 8
return BLOCK_SIZE, num_warps
@triton.heuristics({"BACKWARD_PASS": lambda args: args["BACKWARD_PASS"],})
@triton.jit
def _rope_embedding(
Q, Q_row_stride,
cos, cos_row_stride,
sin, sin_row_stride,
seqlen, head_dim, group_size, n_heads,
BACKWARD_PASS: tl.constexpr,
BLOCK_SIZE : tl.constexpr,
):
"""
Calculates the RoPE Embedding quickly
RoPE is Q * cos + rotate_half(Q) * sin
See our blog post for more info
"""
row_position = tl.program_id(0)
group_head_position = tl.program_id(1)
col_offsets = tl.arange(0, BLOCK_SIZE)
half_head_dim = head_dim // 2
mask = col_offsets < half_head_dim
sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \
half_head_dim*0 + col_offsets, mask = mask, other = 0)
cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \
half_head_dim*0 + col_offsets, mask = mask, other = 0)
if BACKWARD_PASS:
# See our blog post for more info.
sin1 = -sin1
pass
head_start = group_head_position * group_size
head_end = min((head_start + group_size), n_heads)
for i in range(head_start, head_end):
offs_q1 = row_position * Q_row_stride + i * head_dim + col_offsets
offs_q2 = row_position * Q_row_stride + i * head_dim + col_offsets + half_head_dim
# For Gemma - sometimes RoPE must be done in float32 and not bfloat16
Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)
Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)
tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)
tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)
pass
pass
def fast_rope_embedding_forward(Q, cos, sin):
Q = Q.transpose(1, 2).clone()
cos, sin = cos.squeeze(), sin.squeeze()
batch, seq_len, n_heads, head_dim = Q.shape
Q = Q.reshape(batch*seq_len, n_heads*head_dim)
n_rows, n_cols = Q.shape
assert(seq_len <= cos.shape[0])
# [TODO] Changing blocksize to head_dim//2 seems to have
# some concurrency / un-deterministic issues.
BLOCK_SIZE, num_warps = calculate_settings(head_dim//2) # (head_dim//2)
group_size = 4 # 4 or 8, too large group_size can hurt performance.
n_groups = triton.cdiv(n_heads, group_size)
grid = (n_rows, n_groups, )
_rope_embedding[grid](
Q, Q.stride(0),
cos, cos.stride(0),
sin, sin.stride(0),
seq_len, head_dim, group_size, n_heads,
BACKWARD_PASS = False,
BLOCK_SIZE = BLOCK_SIZE,
num_warps = num_warps,
)
Q = Q.view(batch, seq_len, n_heads, head_dim).transpose(1, 2)
return Q, (BLOCK_SIZE, num_warps)
def fast_rope_embedding_backward(BLOCK_SIZE, num_warps, cos, sin, dY):
dY = dY.transpose(1, 2)
batch, seq_len, n_heads, head_dim = dY.shape
dY = dY.reshape(batch*seq_len, n_heads*head_dim)
# Must be reshape not view
n_rows, n_cols = dY.shape
group_size = 4 # 4 or 8, too large group_size can hurt performance.
n_groups = triton.cdiv(n_heads, group_size)
grid = (n_rows, n_groups, )
_rope_embedding[grid](
dY, dY .stride(0),
cos, cos.stride(0),
sin, sin.stride(0),
seq_len, head_dim, group_size, n_heads,
BACKWARD_PASS = True,
BLOCK_SIZE = BLOCK_SIZE,
num_warps = num_warps,
)
dY = dY.view(batch, seq_len, n_heads, head_dim)
dY = dY.transpose(1, 2)
return dY
We also define the corresponding meta functions.
[21]:
def fast_rope_embedding_forward_meta(Q, cos, sin):
batch, n_heads, seq_len, head_dim = Q.shape
n_rows, n_cols = batch*seq_len, n_heads*head_dim
assert(seq_len <= cos.shape[0])
BLOCK_SIZE, num_warps = calculate_settings(head_dim//2)
return TensorProxy(like=Q), (BLOCK_SIZE, num_warps)
def fast_rope_embedding_backward_meta(BLOCK_SIZE, num_warps, cos, sin, dY):
return TensorProxy(like=dY)
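As a quick sanity check of the contract between meta and implementation, the kernel's output shape should match what the meta function promises (TensorProxy(like=Q), i.e. the same shape as Q). A minimal sketch with small made-up CUDA tensors; the names below are only for illustration:
q_small = torch.randn(2, 16, 128, 16, device='cuda', dtype=torch.bfloat16)
cos_small = torch.randn(128, 16, device='cuda')
sin_small = torch.randn(128, 16, device='cuda')
out_small, (block_size, num_warps) = fast_rope_embedding_forward(q_small, cos_small, sin_small)
print(out_small.shape)  # same shape as q_small, just as fast_rope_embedding_forward_meta promises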
Register optimized operators
Just as with apply_rope before, we can register operators for the optimized forward and backward.
[22]:
unsloth_apply_rope_forward = my_ex.register_operator('unsloth_apply_rope_forward',
meta=fast_rope_embedding_forward_meta, fn=fast_rope_embedding_forward)
unsloth_apply_rope_backward = my_ex.register_operator('unsloth_apply_rope_backward',
meta=fast_rope_embedding_backward_meta, fn=fast_rope_embedding_backward)
Implementations for operators
Do we need to divert apply_rope again? No! We can register the specialized kernel as an implementation of our base apply_rope operator. For this we need an execution transform - which is a fancy word for a function that implements the original operator (apply_rope) in terms of our new operator - so it has the call signature of apply_rope. Because - like many fast implementations - the Unsloth RoPE embedding does not implement the operator in full generality (it mainly expects a 4d tensor input), we also implement a checker function: it takes the arguments of the operator we want to specialize and returns a bool indicating whether our implementation handles the given inputs.
[23]:
def apply_rope_to_unsloth(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> TensorProxy:
assert len(x.shape) == 4
res, *_ = unsloth_apply_rope_forward(x, cos, sin)
return res
def apply_rope_to_unsloth_checker(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> bool:
if len(x.shape) != 4:
return False
return (x.device.devicetype == thunder.devices.DeviceType.CUDA and
cos.device.devicetype == thunder.devices.DeviceType.CUDA and
sin.device.devicetype == thunder.devices.DeviceType.CUDA)
my_ex.register_implementation(apply_rope,
checker=apply_rope_to_unsloth_checker,
execution_transform=apply_rope_to_unsloth)
So let us give it a try! Works great…
[24]:
thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors())
expected = test_apply_rope(Q, m)
actual = thunder_apply_rope(Q, m)
print("deviation:", (expected - actual).abs().max().item())
thunder.last_traces(thunder_apply_rope)[-1]
deviation: 0.015625
[24]:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(x, t_1_cos, t_1_sin):
# x: "cuda:0 bf16[2, 128, 4096, 16]"
# t_1_cos: "cuda:0 f32[4096, 16]"
# t_1_sin: "cuda:0 f32[4096, 16]"
(t2, (_, _)) = unsloth_apply_rope_forward(x, t_1_cos, t_1_sin)
del x, t_1_cos, t_1_sin
return t2
And this is also automatic when we instantiate a larger llama2-like model:
[25]:
torch.set_default_dtype(torch.float32)
with torch.device('cuda'):
m = GPT(Config.from_name('llama2-like'))
for p in m.parameters():
p.requires_grad_(False)
thunder_model = thunder.jit(m, executors=(my_ex,) + thunder.get_default_executors())
inp = torch.randint(1, m.config.vocab_size, (1, 128), device="cuda")
actual = thunder_model(inp)
expected = m(inp)
print("deviation:", (actual - expected).abs().max().item())
deviation: 5.960464477539062e-07
By peeking into the trace, we can see that it actually used the unsloth apply rope:
[26]:
[s for s in str(thunder.last_traces(thunder_model)[-1]).split('\n') if 'apply_rope' in s]
[26]:
[' (q_roped, (_, _)) = unsloth_apply_rope_forward(t55, cos, sin)',
' (k_roped, (_, _)) = unsloth_apply_rope_forward(t57, cos, sin)',
' (t165, (_, _)) = unsloth_apply_rope_forward(t164, cos, sin)',
' (t167, (_, _)) = unsloth_apply_rope_forward(t166, cos, sin)']
But what about the backward?
Well, we have to connect forward and backward with a grad transformation. With our specialized ops, this is very simple: we compute the forward, call get_grad on the output, compute the backward, and attach the gradient to the input with put_grads.
[27]:
from thunder.core.transforms import get_grad, put_grads
def unsloth_apply_rope_grad(x: TensorProxy, cos: TensorProxy, sin: TensorProxy):
    res, (BLOCK_SIZE, num_warps) = unsloth_apply_rope_forward(x, cos, sin)
    grad_res = get_grad(res)
    grad_x = unsloth_apply_rope_backward(BLOCK_SIZE, num_warps, cos, sin, grad_res)
    put_grads((x,), (grad_x,))
    return res
my_ex.register_implementation(apply_rope, checker=apply_rope_to_unsloth_checker,
execution_transform=apply_rope_to_unsloth,
grad_transform=unsloth_apply_rope_grad
)
Note that the forward and backward parts are not actually executed at the same time in the real computation; they only run together like this during tracing.
And let us try our function using the optimized backward:
[28]:
Q.requires_grad_()
thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors())
expected = test_apply_rope(Q, m)
go = torch.ones_like(expected)
gr_expected, = torch.autograd.grad(expected, Q, go)
actual = thunder_apply_rope(Q, m)
gr_actual, = torch.autograd.grad(actual, Q, go)
print("res deviation:", (expected - actual).abs().max().item())
print("grad deviation:", (gr_expected - gr_actual).abs().max().item())
res deviation: 0.015625
grad deviation: 0.0078125
And with last_backward_traces
we can check that our module is using the unsloth backward:
[29]:
thunder.last_backward_traces(thunder_apply_rope)[-1]
[29]:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
# saved_for_backward: "Collection"
# cotangents: "Collection"
C0, \
_, \
= saved_for_backward
clear_collection(saved_for_backward)
del saved_for_backward
t4, \
= cotangents
clear_collection(cotangents)
del cotangents
t1, \
t2, \
= C0
clear_collection(C0)
del C0
t3 = unsloth_apply_rope_backward(8, 4, t1, t2, t4) # t3: "cuda:0 bf16[2, 128, 4096, 16]"
del t1, t2, t4
return (t3, None, None)
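For completeness, the same trace-grepping trick as above can confirm that the forward trace still uses the unsloth forward kernel (a quick sanity check; the exact tensor names in the output will differ from run to run):
[s for s in str(thunder.last_traces(thunder_apply_rope)[-1]).split('\n') if 'unsloth_apply_rope' in s]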
Comparing and exploring optimizations
It is also straightforward to compare potential optimizations.
Note again that our use of the unsloth kernel might not achieve the same performance the unsloth project reports, due to differences in hardware, software environment, or the memory layout of the operands.
[30]:
def test_apply_rope_copy(x, m):
    return apply_rope_copy(x, m.cos, m.sin)
test_apply_rope_myex = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors())
test_apply_rope_nvfuser = thunder.jit(test_apply_rope_copy)
y = test_apply_rope(Q, m); gr, = torch.autograd.grad(y, Q, go)
y = test_apply_rope_myex(Q, m); gr, = torch.autograd.grad(y, Q, go)
y = test_apply_rope_nvfuser(Q, m); gr, = torch.autograd.grad(y, Q, go)
print("eager")
%timeit y = test_apply_rope(Q, m); gr, = torch.autograd.grad(y, Q, go); torch.cuda.synchronize()
print("thunder + unsloth")
%timeit y = test_apply_rope_myex(Q, m); gr, = torch.autograd.grad(y, Q, go); torch.cuda.synchronize()
print("thunder default (nvfuser)")
%timeit y = test_apply_rope_nvfuser(Q, m); gr, = torch.autograd.grad(y, Q, go); torch.cuda.synchronize()
eager
3.84 ms ± 3.46 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
thunder + unsloth
6.69 ms ± 3.45 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
thunder default (nvfuser)
1.4 ms ± 4.98 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
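To see what the default path did instead of our handwritten kernel, we can peek at the nvFuser-backed trace and grep for its fusion regions (a quick look; the fusion names and shapes will vary with your software versions and hardware):
[s for s in str(thunder.last_traces(test_apply_rope_nvfuser)[-1]).split('\n') if 'nvFusion' in s]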
That’s it!
Conclusion
To wrap up, we hope you got a taste of:

- Getting things going with Thunder:
  - applying Thunder through thunder.jit, and
  - using FSDP by just wrapping the model in thunder.distributed.fsdp before compilation.
- Seeing what is going on by inspecting traces:
  - thunder.last_traces for the forward traces and
  - thunder.last_backward_traces for the backward.
- Extending Thunder:
  - registering operators with an OperatorExecutor, and
  - defining implementations with custom forward and backward to include optimized kernels (recapped schematically below).
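In code, the extension workflow of this notebook boils down to a handful of calls. The following schematic recap reuses the executor and helper functions defined above; it is a summary, not a new API:
# schematic recap of the extension workflow shown in this notebook
unsloth_apply_rope_forward = my_ex.register_operator('unsloth_apply_rope_forward',
    meta=fast_rope_embedding_forward_meta, fn=fast_rope_embedding_forward)
my_ex.register_implementation(apply_rope,
    checker=apply_rope_to_unsloth_checker,
    execution_transform=apply_rope_to_unsloth,
    grad_transform=unsloth_apply_rope_grad)
thunder_model = thunder.jit(m, executors=(my_ex,) + thunder.get_default_executors())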
Keep in mind that Thunder is still experimental and only expected to work with the limited set of models we have tested it with. You will find bugs and missing pieces. Naturally, we would love for you to help us fix these! You can find us on the Thunder section of the Lightning forums or in the #thunder
channel on the PyTorch-Lightning slack.
Do check out our LitGPT studios and the other tutorial notebooks.