# Zero to Thunder

Here we take a very short tour of what is possible with Thunder.

To get started we import it (and a bunch of things for this notebook).

In [1]:
import sys
sys.path.insert(0, '..')

import torch, thunder

## Compiling a first module with Thunder

So let's get started! As a "Hello World", let us apply it to a small model, say, the MLP part found in Llama 2. We take it from LitGPT.

In [2]:
class LLaMAMLP(torch.nn.Module):
    def __init__(self, n_embd, intermediate_size) -> None:
        super().__init__()
        self.fc_1 = torch.nn.Linear(n_embd, intermediate_size, bias=False)
        self.fc_2 = torch.nn.Linear(n_embd, intermediate_size, bias=False)
        self.proj = torch.nn.Linear(intermediate_size, n_embd, bias=False)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_fc_1 = self.fc_1(x)
        x_fc_2 = self.fc_2(x)
        x = torch.nn.functional.silu(x_fc_1) * x_fc_2
        return self.proj(x)
with torch.device("cuda"):
    m = LLaMAMLP(4096, 11008)
for p in m.parameters():
    p.requires_grad_(False)
print(m)


LLaMAMLP(
  (fc_1): Linear(in_features=4096, out_features=11008, bias=False)
  (fc_2): Linear(in_features=4096, out_features=11008, bias=False)
  (proj): Linear(in_features=11008, out_features=4096, bias=False)
)


Now we can apply Thunder. This uses the most important function of Thunder, `thunder.jit`, which can be used to compile a `torch.nn.Module` or a function. It will wrap our MLP in a `ThunderModule`.

In [3]:
thunder_model = thunder.jit(m)

In [4]:
thunder_model

ThunderModule(
  (_model): LLaMAMLP(
    (fc_1): Linear(in_features=4096, out_features=11008, bias=False)
    (fc_2): Linear(in_features=4096, out_features=11008, bias=False)
    (proj): Linear(in_features=11008, out_features=4096, bias=False)
  )
)

Our Thunder module computes (up to numerical accuracy) the same thing as our original model, and for a small model like this it also has approximately the same performance.

In [5]:
x = torch.randn(2, 2048, 4096, device="cuda")
print('deviation:', (thunder_model(x) - m(x)).abs().max().item())

%timeit thunder_model(x); torch.cuda.synchronize()
%timeit m(x); torch.cuda.synchronize()

deviation: 1.4901161193847656e-07
61.3 ms ± 106 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
62.1 ms ± 89.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


So what has changed? Quite a bit!

When we call the Thunder module, it does the computation in a single function without control flow. And what's more, it applies optimizations, such as creating fusions for NVFuser to execute. We can see all this by showing the last computation trace:

In [6]:
thunder.last_traces(thunder_model)[-1]

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
import torch.nn.functional
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight):
  # x: "cuda:0 f32[2, 2048, 4096]" 
  # t_fc_1_weight: "cuda:0 f32[11008, 4096]" 
  # t_fc_2_weight: "cuda:0 f32[11008, 4096]" 
  # t_proj_weight: "cuda:0 f32[4096, 11008]" 
  x_fc_1 = torch.nn.functional.linear(x, t_fc_1_weight, None)  # x_fc_1: "cuda:0 f32[2, 2048, 11008]"
    # x_fc_1 = ltorch.linear(x, t_fc_1_weight, None)  # x_fc_1: "cuda:0 f32[2, 2048, 11008]"
      # x_fc_1 = prims.linear(x, t_fc_1_weight, None)  # x_fc_1: "cuda:0 f32[2, 2048, 11008]"
  del t_fc_1_weight
  x_fc_2 = torch.nn.functional.linear(x, t_fc_2_weight, None)  # x_fc_2: "cuda:0 f32[2, 2048, 11008]"
    # x_fc_2 = ltorch.linear(x, t_fc_2_weight, None)  # x_fc_2: "cuda:0 f32[2, 2048, 11008]"
      # x_fc_2 = prims.linear(x, t_fc_2_weight, None)  # x_fc_2: "cuda:0 f32[2, 

For more detail of what is going on in this trace:
- Thunder has transformed the computation (more precisely, `m.__call__`) into a single function which has all the MLP parameters as arguments.
- It has recorded the tensor metadata.
- Operations have been mapped from the PyTorch functions to `thunder.torch` (aka `ltorch`) equivalents and decomposed into _primitive operations_.
- The multiplication and activation (`x = torch.nn.functional.silu(x_fc_1) * x_fc_2`) have been put into one NVFuser fusion. (NVFuser is one of many optimizations, albeit a particularly important one, and we make it easy to add your own.)
- You can see how the parameters are obtained and the metadata is checked in the prologue - get it through `thunder.last_prologue_traces(thunder_model)[-1]`.

You can actually see the whole series of traces: `last_traces` gives you a list of transformed traces in chronological order. For example, the initial trace `thunder.last_traces(thunder_model)[0]` does not have the fusion yet.


## Compiling a more complex model

Obviously, we aim for larger models, so we can do the same with the entire Llama 2 (well, we use a smaller model here to be gentle on our CI, but if you have a large GPU, just drop the line that reduces the number of layers):

**NOTE**: For running the cells below, we require `litgpt` which can be installed with `pip install 'litgpt[all] @ git+https://github.com/Lightning-AI/litgpt'`. See [here](https://github.com/Lightning-AI/litgpt) to learn more about litgpt.

In [7]:
from litgpt import GPT
from thunder.tests.litgpt_model import Config
cfg = Config.from_name('Llama-2-7b-hf')
cfg.n_layer = 16 # fewer layers
torch.set_default_dtype(torch.bfloat16)
with torch.device('cuda'):
    m = GPT(cfg)
m


GPT(
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
  (transformer): ModuleDict(
    (wte): Embedding(32000, 4096)
    (h): ModuleList(
      (0-15): 16 x Block(
        (norm_1): RMSNorm()
        (attn): CausalSelfAttention(
          (attn): Linear(in_features=4096, out_features=12288, bias=False)
          (proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (norm_2): RMSNorm()
        (mlp): LLaMAMLP(
          (fc_1): Linear(in_features=4096, out_features=11008, bias=False)
          (fc_2): Linear(in_features=4096, out_features=11008, bias=False)
          (proj): Linear(in_features=11008, out_features=4096, bias=False)
        )
      )
    )
    (ln_f): RMSNorm()
  )
)

Again we jit our model and compare the output...

In [8]:
thunder_model = thunder.jit(m)

inp = torch.randint(1, m.config.vocab_size, (1, 512), device="cuda")

actual = thunder_model(inp)
expected = m(inp)

print("deviation:", (actual - expected).abs().max().item())


deviation: 0.03125


One thing to keep in mind here is that for bf16, the numerical accuracy impact of rearranging operations can be quite pronounced.
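To put a deviation like `0.03125` into perspective: bfloat16 keeps 8 significant bits (7 stored mantissa bits), so neighbouring representable values around a number `x` are `2**(floor(log2(|x|)) - 7)` apart. The following sketch is plain Python and independent of Thunder; the helper name `bf16_spacing` is ours, and it only covers the normal (non-subnormal) range.

```python
import math

def bf16_spacing(x: float) -> float:
    """Gap between adjacent bfloat16 values around x (normal range only)."""
    if x == 0.0:
        raise ValueError("spacing around zero is not covered by this sketch")
    # math.frexp returns (m, e) with |x| = |m| * 2**e and 0.5 <= |m| < 1,
    # so floor(log2(|x|)) == e - 1.
    _, e = math.frexp(abs(x))
    # bfloat16 stores 7 mantissa bits after the implicit leading one.
    return 2.0 ** (e - 1 - 7)

print(bf16_spacing(5.0))  # spacing for values with magnitude in [4, 8)
print(bf16_spacing(3.0))  # spacing for values with magnitude in [2, 4)
```

Since `bf16_spacing(5.0)` is `0.03125`, a maximum deviation of that size would amount to a single rounding step if the affected outputs have a magnitude between 4 and 8.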

Just like before, we can see the program it ran. It is a lot longer, though.

In [9]:
print(actual.grad_fn)
thunder.last_traces(thunder_model)[-1]

<torch.autograd.function.ThunderFunctionBackward object at 0x7f923f792ac0>


# Constructed by Delete Last Used (took 10 milliseconds)
import torch
from torch import Tensor
import torch.nn.functional
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def augmented_forward_fn(*args):
  # args: "Collection" 
  t0, \
  t1, \
  t2, \
  t3, \
  t4, \
  t5, \
  t6, \
  t7, \
  t8, \
  t9, \
  t10, \
  t11, \
  t12, \
  t13, \
  t14, \
  t15, \
  t16, \
  t17, \
  t18, \
  t19, \
  t20, \
  t21, \
  t22, \
  t23, \
  t24, \
  t25, \
  t26, \
  t27, \
  t28, \
  t29, \
  t30, \
  t31, \
  t32, \
  t33, \
  t34, \
  t35, \
  t36, \
  t37, \
  t38, \
  t39, \
  t40, \
  t41, \
  t42, \
  t43, \
  t44, \
  t45, \
  t46, \
  t47, \
  t48, \
  t49, \
  t50, \
  t51, \
  t52, \
  t53, \
  t54, \
  t55, \
  t56, \
  t57, \
  t58, \
  t59, \
  t60, \
  t61, \
  t62, \
  t63, \
  t64, \
  t65, \
  t66, \
  t67, \
  t68, \
  t69, \
  t70, \
  t71, \
  t72, \
  t73, \
  t74, \
  t75, \
  t76, \
  t77, \
  t78, \
  t79, \
  t80, \
  t81, \
  t82, \
  t

Well, that is quite a bit to look through.
But here is a key thing: the function now returns a bunch of things. This is because Thunder applies the same treatment to the backward and, to this end, saves information from the forward. You can see a hint of this because the output has a `ThunderFunctionBackward` as its `grad_fn`. (You can see the backward trace with `thunder.last_backward_traces(thunder_model)[-1]`.)

In [10]:
actual

tensor([[[ 0.4160, -0.4668,  1.1016,  ...,  0.5430,  1.2656,  0.2891],
         [ 0.3320, -0.0557,  1.7891,  ...,  1.0703,  1.0078,  1.2266],
         [ 0.6836, -0.2871,  0.9531,  ...,  0.0806,  0.7070,  0.8477],
         ...,
         [ 0.7695, -0.1260,  0.7266,  ...,  0.1118, -0.0238, -1.2656],
         [-0.7773, -0.5547, -0.3047,  ..., -0.1807,  0.1895,  0.6875],
         [ 0.8867,  0.4766,  0.3984,  ...,  0.0815, -0.0879,  0.3477]]],
       device='cuda:0', grad_fn=<ThunderFunctionBackward>)

Let us clean up a bit.

In [11]:
del actual, expected
import gc
gc.collect();

But is it faster? Yes!

In [12]:
%timeit r = m(inp); torch.autograd.grad(r.sum(), m.parameters()); torch.cuda.synchronize()
%timeit r = thunder_model(inp); torch.autograd.grad(r.sum(), m.parameters()); torch.cuda.synchronize()

240 ms ± 105 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
208 ms ± 147 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


So far, so good! Thunder should work with LitGPT today, and we are busy adding the support required to run other models as well!


In [13]:
del m, thunder_model
import gc
gc.collect()
torch.cuda.empty_cache()


## Distributed with Thunder

Those Large Language Models are called Large for a reason, and memory in a single GPU is invariably small. So we need multiple.

Happily Thunder sports an FSDP interface to use multiple cards in our box.

You still need to set up the process group, but as far as the model is concerned,

```python
model = thunder.jit(thunder.distributed.fsdp(model))
```

is all you need. Because it is tricky to run multiprocessing from notebooks, we write a small example into a file and run it through `torchrun`.

Check out our LitGPT Thunder examples for complete distributed training and finetuning!
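For orientation: `torchrun` tells each worker its position in the job through environment variables such as `LOCAL_RANK`, `RANK`, and `WORLD_SIZE`, and the script below reads `LOCAL_RANK` to pick its GPU. Here is a minimal sketch of parsing them, with single-process fallbacks so that it also runs outside `torchrun`. The helper `read_launch_env`, the `LaunchEnv` tuple, and the fallback values are our own illustration and not part of Thunder or PyTorch.

```python
import os
from typing import Mapping, NamedTuple

class LaunchEnv(NamedTuple):
    local_rank: int   # index of this process on the current node
    rank: int         # index of this process in the whole job
    world_size: int   # total number of processes in the job

def read_launch_env(env: Mapping[str, str] = os.environ) -> LaunchEnv:
    """Parse the variables torchrun exports; fall back to a single process."""
    return LaunchEnv(
        local_rank=int(env.get("LOCAL_RANK", "0")),
        rank=int(env.get("RANK", "0")),
        world_size=int(env.get("WORLD_SIZE", "1")),
    )

# Simulate what the second of two workers on one node would see.
info = read_launch_env({"LOCAL_RANK": "1", "RANK": "1", "WORLD_SIZE": "2"})
print(info, f"cuda:{info.local_rank}")
```

Passing the environment as an argument keeps the helper easy to exercise in a notebook, where no launcher has set these variables.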

In [14]:
%%writefile zero_to_thunder_fsdp_simple_example.py
from thunder.tests.litgpt_model import GPT, Config
import os
import torch, torch.distributed
import thunder, thunder.distributed

# Create Model
# NOTE: We create the model on CPU.
device='cpu'
torch.set_default_dtype(torch.bfloat16)
cfg = Config.from_name('Llama-2-7b-hf')
cfg.n_layer = 8 # fewer layers
model = GPT(cfg)

# Setup for distributed
torch.distributed.init_process_group(backend='nccl')
rank = int(os.environ["LOCAL_RANK"])

device = f"cuda:{rank}"
x = torch.randint(1, model.config.vocab_size, (1, 1024), device=device)

# thunder.distributed.fsdp takes care of moving the parameter
# shard to the correct GPU for the current process.
model = thunder.jit(thunder.distributed.fsdp(model)) #  <---------------------------------------
print(f"rank {rank} computing")
# Run the forward and backward passes.
for i in range(10):
    res = model(x)
    res.sum().backward()


Overwriting zero_to_thunder_fsdp_simple_example.py


Now we can launch it. Note that you need two GPUs for this to run correctly.

In [15]:
# commented out for CI limitations, see https://github.com/Lightning-AI/lightning-thunder/issues/465
# !torchrun --standalone --nnodes=1 --nproc_per_node=2 zero_to_thunder_fsdp_simple_example.py

W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] 
W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] *****************************************
W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] *****************************************
rank 1 computing
rank 0 computing


So there. FSDP with just wrapping the model in `fsdp`.


## Extending Thunder

But we promised that Thunder is extensible. Let's find out what's up with that.

Specifically, we will incorporate the fast rope embedding kernel from the great [Unsloth project](https://github.com/unslothai/unsloth/) into our model (note that NVFuser also creates a fused kernel for this).

In Thunder, extensions (as well as most builtin optimizations which use the exact same mechanism) work with _executors_ handling operations. Let us define one.

In [16]:
my_ex = thunder.extend.OperatorExecutor('my_ex', version='0.0.1')
thunder.extend.register_executor(my_ex)

my_ex

For our base implementation, we take the code from [LitGPT's implementation](https://github.com/Lightning-AI/litgpt/blob/be6139e1fd4b240d253efd58124457496d23d173/litgpt/model.py#L355-L361).

Because we will demonstrate Thunder's ability to divert functions in the model, we make a copy here that will not be diverted.

In [17]:
import litgpt
def apply_rope_copy(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    head_size = x.size(-1)
    x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
    x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
    rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
    roped = (x * cos) + (rotated * sin)
    return roped.to(dtype=x.dtype)

### Registering operators

Say we have a function `apply_rope` applying the RoPE transformation in PyTorch.
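As a reminder of what that transformation computes: with the last dimension split into halves `x1` and `x2`, the expression `x * cos + rotate_half(x) * sin` rotates each pair `(x1[i], x2[i])` by an angle, provided that `cos` and `sin` repeat the same angles in both halves. The sketch below checks this on a single vector in plain Python; the function `rope_1d` and the example angles are ours and only serve as an illustration.

```python
import math

def rope_1d(x, cos, sin):
    """x * cos + rotate_half(x) * sin for one vector of even length."""
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    rotated = [-v for v in x2] + list(x1)   # same as torch.cat((-x2, x1))
    return [xi * c + ri * s for xi, ri, c, s in zip(x, rotated, cos, sin)]

angles = [0.3, 1.1]                          # one angle per (x1[i], x2[i]) pair
cos = [math.cos(a) for a in angles] * 2      # repeated for both halves
sin = [math.sin(a) for a in angles] * 2

x = [1.0, 2.0, 3.0, 4.0]
y = rope_1d(x, cos, sin)

# Each pair is rotated in its own plane, so its length is unchanged.
for i in range(2):
    before = math.hypot(x[i], x[i + 2])
    after = math.hypot(y[i], y[i + 2])
    print(f"pair {i}: |before|={before:.6f} |after|={after:.6f}")
```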

In Thunder, we define a *meta* function that only defines the metadata (like shapes) of the outputs, together with the actual implementation of each operator. We then register the pair with our executor using the `register_operator` function and tell it to use the new symbol in place of the original function `litgpt.model.apply_rope`.
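The idea behind a meta function can be illustrated without Thunder: it maps input metadata to output metadata and never touches any data. In the toy sketch below, `Meta` is a hypothetical stand-in for `TensorProxy`, not Thunder's class, and both `*_toy` functions are ours.

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class Meta:
    """Hypothetical stand-in for a tensor proxy: metadata only, no data."""
    shape: Tuple[int, ...]
    dtype: str
    device: str

def apply_rope_meta_toy(x: Meta, cos: Meta, sin: Meta) -> Meta:
    # RoPE keeps the shape, dtype and device of x,
    # the analogue of returning TensorProxy(like=x).
    return Meta(x.shape, x.dtype, x.device)

def linear_meta_toy(x: Meta, weight: Meta) -> Meta:
    # A linear layer replaces the last dimension by the weight's out_features.
    out_features, in_features = weight.shape
    if x.shape[-1] != in_features:
        raise ValueError("last dimension of x must match in_features")
    return Meta(x.shape[:-1] + (out_features,), x.dtype, x.device)

x = Meta((2, 2048, 4096), "float32", "cuda:0")
w = Meta((11008, 4096), "float32", "cuda:0")
print(linear_meta_toy(x, w).shape)
```

The shapes here are the ones from the MLP example at the beginning, so the result can be compared with the `x_fc_1` annotation in the first trace.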


In [18]:
import torch, thunder
from thunder.tests.litgpt_model import GPT
from thunder import TensorProxy

def apply_rope_impl(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    return litgpt.model.apply_rope(x, cos, sin)

def apply_rope_meta(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> TensorProxy:
    return TensorProxy(like=x)

apply_rope = my_ex.register_operator('apply_rope', like=apply_rope_meta, fn=apply_rope_impl,
                                    replaces=litgpt.model.apply_rope)

### Testing our new operator 

In [19]:
with torch.device('cuda'):
    m = GPT.from_name('llama2-like')
    Q = torch.randn(2, 128, 4096, 16)

def test_apply_rope(x, m):
    return litgpt.model.apply_rope(x, m.cos, m.sin)

thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors())

expected = test_apply_rope(Q, m)
actual = thunder_apply_rope(Q, m)
print("deviation:", (expected - actual).abs().max().item())

thunder.last_traces(thunder_apply_rope)[-1]

deviation: 0.0


# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(x, t_1_cos, t_1_sin):
  # x: "cuda:0 bf16[2, 128, 4096, 16]" 
  # t_1_cos: "cuda:0 f32[4096, 16]" 
  # t_1_sin: "cuda:0 f32[4096, 16]" 
  t2 = apply_rope(x, t_1_cos, t_1_sin)  # t2: "cuda:0 bf16[2, 128, 4096, 16]"
  del x, t_1_cos, t_1_sin
  return t2

### Optimized kernels

But why did we do this? Well, we can now layer a faster implementation on top.
For this we take the [unsloth fast rope embedding](https://github.com/unslothai/unsloth/blob/42076f6580e71522ed1c122043edfba595be64e4/unsloth/kernels/rope_embedding.py) kernels. We move the bits that were in the forward and backward of the `autograd.Function` into our implementation functions. Note that we include the transpositions in our setup in order to be compatible with the LitGPT implementation. This change in the memory layout of the operands can have a large effect on the runtime, though, so our timings are likely not representative of the ones the Unsloth project gets in its use of the same Triton kernels.

In [20]:
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import triton
import triton.language as tl
import torch

MAX_FUSED_SIZE = 65536
next_power_of_2 = triton.next_power_of_2

def calculate_settings(n):
    BLOCK_SIZE = next_power_of_2(n)
    if BLOCK_SIZE > MAX_FUSED_SIZE:
        raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
                           f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
    num_warps = 4
    if   BLOCK_SIZE >= 32768: num_warps = 32
    elif BLOCK_SIZE >=  8192: num_warps = 16
    elif BLOCK_SIZE >=  2048: num_warps = 8
    return BLOCK_SIZE, num_warps

@triton.heuristics({"BACKWARD_PASS": lambda args: args["BACKWARD_PASS"],})
@triton.jit
def _rope_embedding(
    Q,     Q_row_stride,
    cos, cos_row_stride,
    sin, sin_row_stride,
    seqlen, head_dim, group_size, n_heads,
    BACKWARD_PASS: tl.constexpr,
    BLOCK_SIZE : tl.constexpr,
):
    """
        Calculates the RoPE Embedding quickly
        RoPE is Q * cos + rotate_half(Q) * sin
        See our blog post for more info
    """
    row_position  = tl.program_id(0)
    group_head_position = tl.program_id(1)
    col_offsets  = tl.arange(0, BLOCK_SIZE)
    half_head_dim = head_dim // 2
    mask = col_offsets < half_head_dim

    sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \
                   half_head_dim*0 + col_offsets, mask = mask, other = 0)
    cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \
                   half_head_dim*0 + col_offsets, mask = mask, other = 0)

    if BACKWARD_PASS:
        # See our blog post for more info.
        sin1 = -sin1
    pass

    head_start = group_head_position * group_size
    head_end = min((head_start + group_size), n_heads)

    for i in range(head_start, head_end):
        offs_q1 = row_position * Q_row_stride + i * head_dim + col_offsets
        offs_q2 = row_position * Q_row_stride + i * head_dim + col_offsets + half_head_dim

        # For Gemma - sometimes RoPE must be done in float32 and not bfloat16
        Q1   = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)
        Q2   = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)

        tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)
        tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)
    pass
pass


def fast_rope_embedding_forward(Q, cos, sin):
    Q = Q.transpose(1, 2).clone()
    cos, sin = cos.squeeze(), sin.squeeze()
    batch, seq_len, n_heads, head_dim = Q.shape
    Q = Q.reshape(batch*seq_len, n_heads*head_dim)
    n_rows, n_cols = Q.shape
    assert(seq_len <= cos.shape[0])

    # [TODO] Changing blocksize to head_dim//2 seems to have
    # some concurrency / un-deterministic issues.
    BLOCK_SIZE, num_warps = calculate_settings(head_dim//2) # (head_dim//2)
    group_size = 4 # 4 or 8, too large group_size can hurt performance.
    n_groups = triton.cdiv(n_heads, group_size)

    grid = (n_rows, n_groups, )
    _rope_embedding[grid](
          Q,   Q.stride(0),
        cos, cos.stride(0),
        sin, sin.stride(0),
        seq_len, head_dim, group_size, n_heads,
        BACKWARD_PASS = False,
        BLOCK_SIZE = BLOCK_SIZE,
        num_warps  = num_warps,
    )
    Q = Q.view(batch, seq_len, n_heads, head_dim).transpose(1, 2)
    return Q, (BLOCK_SIZE, num_warps) 

def fast_rope_embedding_backward(BLOCK_SIZE, num_warps, cos, sin, dY):
    dY = dY.transpose(1, 2)
    batch, seq_len, n_heads, head_dim = dY.shape
    dY = dY.reshape(batch*seq_len, n_heads*head_dim)
    # Must be reshape not view
    n_rows, n_cols = dY.shape

    group_size = 4 # 4 or 8, too large group_size can hurt performance.
    n_groups = triton.cdiv(n_heads, group_size)

    grid = (n_rows, n_groups, )
    _rope_embedding[grid](
        dY,  dY .stride(0),
        cos, cos.stride(0),
        sin, sin.stride(0),
        seq_len, head_dim, group_size, n_heads,
        BACKWARD_PASS = True,
        BLOCK_SIZE = BLOCK_SIZE,
        num_warps  = num_warps,
    )
    dY = dY.view(batch, seq_len, n_heads, head_dim)
    dY = dY.transpose(1, 2)    
    return dY


We also define the corresponding meta functions.

In [21]:
def fast_rope_embedding_forward_meta(Q, cos, sin):
    batch, n_heads, seq_len, head_dim = Q.shape
    n_rows, n_cols = batch*seq_len, n_heads*head_dim 
    assert(seq_len <= cos.shape[0])

    BLOCK_SIZE, num_warps = calculate_settings(head_dim//2)
    return TensorProxy(like=Q), (BLOCK_SIZE, num_warps) 

def fast_rope_embedding_backward_meta(BLOCK_SIZE, num_warps, cos, sin, dY):
    return TensorProxy(like=dY)

### Register optimized operators

Just like the `apply_rope` before, we can register operators for the optimized forward and backward.

In [22]:
unsloth_apply_rope_forward = my_ex.register_operator('unsloth_apply_rope_forward', 
    meta=fast_rope_embedding_forward_meta, fn=fast_rope_embedding_forward)
unsloth_apply_rope_backward = my_ex.register_operator('unsloth_apply_rope_backward', 
    meta=fast_rope_embedding_backward_meta, fn=fast_rope_embedding_backward)

### Implementations for operators

Do we need to divert `apply_rope` again? No!
We can register the specialized kernel as an _implementation_ of our base `apply_rope` operator. For this we need an _execution transform_, which is a fancy word for a function that implements the original operator (`apply_rope`) in terms of our new operator, so it has the call signature of `apply_rope`. Because, like many fast implementations, the unsloth rope embedding does not implement the operator in full generality (well, actually they mainly want a 4d tensor input), we implement a checker function, too: it takes the arguments of the operator we want to specialize and returns a bool indicating whether our implementation handles the given inputs.
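The interplay of checker and execution transform can be pictured as a guarded dispatch. The following toy model is plain Python and makes no claim about how Thunder implements this internally; `ToyExecutor` and every other name in it are hypothetical.

```python
from typing import Callable, List, Tuple

class ToyExecutor:
    """Toy model: specialized implementations guarded by checker functions."""

    def __init__(self, base: Callable):
        self.base = base
        self.impls: List[Tuple[Callable[..., bool], Callable]] = []

    def register_implementation(self, checker: Callable[..., bool], transform: Callable) -> None:
        self.impls.append((checker, transform))

    def __call__(self, *args):
        # Use the first implementation whose checker accepts the inputs,
        # otherwise fall back to the general base operator.
        for checker, transform in self.impls:
            if checker(*args):
                return transform(*args)
        return self.base(*args)

def generic_scale(values, factor):
    return ("generic", [v * factor for v in values])

def fast_scale_len4(values, factor):
    return ("fast", [v * factor for v in values])

def fast_scale_checker(values, factor) -> bool:
    # Like the 4d restriction above: the fast path handles one layout only.
    return len(values) == 4

scale = ToyExecutor(generic_scale)
scale.register_implementation(fast_scale_checker, fast_scale_len4)

print(scale([1, 2, 3, 4], 2))   # accepted by the checker, takes the fast path
print(scale([1, 2, 3], 2))      # rejected by the checker, uses the base operator
```

The point of the checker is that a rejected input is not an error: the call simply falls through to an implementation that handles the general case.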

In [23]:
def apply_rope_to_unsloth(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> TensorProxy:
    assert len(x.shape) == 4
    res, *_ = unsloth_apply_rope_forward(x, cos, sin)
    return res

def apply_rope_to_unsloth_checker(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> bool:
    if len(x.shape) != 4:
        return False
    return (x.device.devicetype == thunder.devices.DeviceType.CUDA and
            cos.device.devicetype == thunder.devices.DeviceType.CUDA and
            sin.device.devicetype == thunder.devices.DeviceType.CUDA)

my_ex.register_implementation(apply_rope,
                              checker=apply_rope_to_unsloth_checker,
                              execution_transform=apply_rope_to_unsloth)


So let us give it a try! Works great...

In [24]:
thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors())    

expected = test_apply_rope(Q, m)
actual = thunder_apply_rope(Q, m)
print("deviation:", (expected - actual).abs().max().item())

thunder.last_traces(thunder_apply_rope)[-1]

deviation: 0.015625


# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(x, t_1_cos, t_1_sin):
  # x: "cuda:0 bf16[2, 128, 4096, 16]" 
  # t_1_cos: "cuda:0 f32[4096, 16]" 
  # t_1_sin: "cuda:0 f32[4096, 16]" 
  (t2, (_, _)) = unsloth_apply_rope_forward(x, t_1_cos, t_1_sin)
  del x, t_1_cos, t_1_sin
  return t2

And this is also automatic when we instantiate a larger llama2-like model:

In [25]:
torch.set_default_dtype(torch.float32)
with torch.device('cuda'):
    m = GPT(Config.from_name('llama2-like'))

for p in m.parameters():
    p.requires_grad_(False)

thunder_model = thunder.jit(m, executors=(my_ex,) + thunder.get_default_executors())

inp = torch.randint(1, m.config.vocab_size, (1, 128), device="cuda")
actual = thunder_model(inp)
expected = m(inp)

print("deviation:", (actual - expected).abs().max().item())

deviation: 5.960464477539062e-07


By peeking into the trace, we can see that it actually used the unsloth apply rope:

In [26]:
[s for s in str(thunder.last_traces(thunder_model)[-1]).split('\n') if 'apply_rope' in s]

['  (q_roped, (_, _)) = unsloth_apply_rope_forward(t55, cos, sin)',
 '  (k_roped, (_, _)) = unsloth_apply_rope_forward(t57, cos, sin)',
 '  (t165, (_, _)) = unsloth_apply_rope_forward(t164, cos, sin)',
 '  (t167, (_, _)) = unsloth_apply_rope_forward(t166, cos, sin)']

### But what about the backward?

Well, we have to connect forward and backward with a grad transformation. With our specialized ops, this is very simple: we compute the forward, call `get_grad` for the output, compute the backward, and put it on the input with `put_grads`.


In [27]:
from thunder.core.transforms import get_grad, put_grads

def unsloth_apply_rope_grad(x: TensorProxy, cos: TensorProxy, sin: TensorProxy):
    res, (BLOCK_SIZE, num_warps) = unsloth_apply_rope_forward(x, cos, sin)
    grad_res = get_grad(res)
    grad_x = unsloth_apply_rope_backward(BLOCK_SIZE, num_warps, cos, sin, grad_res)
    put_grads((x,), (grad_x,))
    return res

my_ex.register_implementation(apply_rope, checker=apply_rope_to_unsloth_checker,
                              execution_transform=apply_rope_to_unsloth,
                              grad_transform=unsloth_apply_rope_grad 
                              )



Note that the forward and backward parts only appear together like this during tracing; in the actual computation they are not executed at the same time.
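Why can the Unsloth kernel serve as its own backward with only `sin` negated? RoPE rotates each pair of components, which is a linear map with an orthogonal matrix, so the gradient with respect to the input is the rotation by the opposite angle applied to the incoming gradient. Here is a plain-Python check of this on one pair against a finite-difference gradient; all names and values are ours.

```python
import math

def rotate(q1, q2, cos, sin):
    """The per-pair update from the kernel: (q1*cos - q2*sin, q2*cos + q1*sin)."""
    return q1 * cos - q2 * sin, q2 * cos + q1 * sin

theta = 0.7
c, s = math.cos(theta), math.sin(theta)
q = (1.5, -0.5)
g = (0.25, 2.0)                       # incoming gradient dL/dy

# Backward as in the kernel: the same formula with sin negated.
grad_q = rotate(g[0], g[1], c, -s)

# Reference: central finite differences of L(q) = <g, rotate(q)>.
def loss(q1, q2):
    y1, y2 = rotate(q1, q2, c, s)
    return g[0] * y1 + g[1] * y2

eps = 1e-6
fd = (
    (loss(q[0] + eps, q[1]) - loss(q[0] - eps, q[1])) / (2 * eps),
    (loss(q[0], q[1] + eps) - loss(q[0], q[1] - eps)) / (2 * eps),
)
print(grad_q, fd)
```

The two tuples agree up to floating-point rounding, which is why one kernel with a `BACKWARD_PASS` switch is enough.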


And let us try our function using the optimized backward:

In [28]:
Q.requires_grad_()

thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors())

expected = test_apply_rope(Q, m)
go = torch.ones_like(expected)
gr_expected, = torch.autograd.grad(expected, Q, go)
actual = thunder_apply_rope(Q, m)
gr_actual, = torch.autograd.grad(actual, Q, go)

print("res deviation:", (expected - actual).abs().max().item())
print("grad deviation:", (gr_expected - gr_actual).abs().max().item())

res deviation: 0.015625
grad deviation: 0.0078125


And with `last_backward_traces` we can check that our module is using the unsloth backward:

In [29]:
thunder.last_backward_traces(thunder_apply_rope)[-1]

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection" 
  # cotangents: "Collection" 
  C0, \
  _, \
  = saved_for_backward
  clear_collection(saved_for_backward)
  del saved_for_backward
  t4, \
  = cotangents
  clear_collection(cotangents)
  del cotangents
  t1, \
  t2, \
  = C0
  clear_collection(C0)
  del C0
  t3 = unsloth_apply_rope_backward(8, 4, t1, t2, t4)  # t3: "cuda:0 bf16[2, 128, 4096, 16]"
  del t1, t2, t4
  return (t3, None, None)
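
The first two arguments of `unsloth_apply_rope_backward` in this trace are the `BLOCK_SIZE` and `num_warps` that `calculate_settings(head_dim // 2)` returned in the forward. For our `Q` with `head_dim = 16` they can be recomputed in plain Python; `next_power_of_2` below is our stand-in for `triton.next_power_of_2`, so that the sketch runs without Triton.

```python
MAX_FUSED_SIZE = 65536

def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n for n >= 1 (stand-in for triton.next_power_of_2)."""
    return 1 << (n - 1).bit_length()

def calculate_settings(n: int):
    block_size = next_power_of_2(n)
    if block_size > MAX_FUSED_SIZE:
        raise RuntimeError(f"n = {n} exceeds the maximum block size {MAX_FUSED_SIZE}")
    # Larger blocks get more warps, as in the cell that defines the kernel.
    num_warps = 4
    if block_size >= 32768:
        num_warps = 32
    elif block_size >= 8192:
        num_warps = 16
    elif block_size >= 2048:
        num_warps = 8
    return block_size, num_warps

head_dim = 16
print(calculate_settings(head_dim // 2))
```

This gives `(8, 4)`, matching the constants `8, 4` that appear in the call above.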

### Comparing and exploring optimizations

It is also straightforward to compare potential optimizations.

Note again that our use of the unsloth kernel might not result in the same performance as the unsloth project sees, due to differences in the hardware used, the software environment, or the memory layout of the operands.

In [30]:
def test_apply_rope_copy(x, m):
    return apply_rope_copy(x, m.cos, m.sin)

test_apply_rope_myex = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors())    
test_apply_rope_nvfuser = thunder.jit(test_apply_rope_copy)
y = test_apply_rope(Q, m); gr, = torch.autograd.grad(y, Q, go)
y = test_apply_rope_myex(Q, m); gr, = torch.autograd.grad(y, Q, go)
y = test_apply_rope_nvfuser(Q, m); gr, = torch.autograd.grad(y, Q, go)

print("eager")
%timeit y = test_apply_rope(Q, m); gr, = torch.autograd.grad(y, Q, go); torch.cuda.synchronize()
print("thunder + unsloth")
%timeit y = test_apply_rope_myex(Q, m); gr, = torch.autograd.grad(y, Q, go); torch.cuda.synchronize()
print("thunder default (nvfuser)")
%timeit y = test_apply_rope_nvfuser(Q, m); gr, = torch.autograd.grad(y, Q, go); torch.cuda.synchronize()


eager
3.84 ms ± 3.46 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
thunder + unsloth
6.69 ms ± 3.45 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
thunder default (nvfuser)
1.4 ms ± 4.98 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


That's it!

## Conclusion

To wrap up, we hope you got a taste of

- Getting things going with Thunder:

  - Applying Thunder through `thunder.jit` and
  - using FSDP by just wrapping the model in `thunder.distributed.fsdp` before compilation.

- Seeing what's going on by inspecting traces:

  - `thunder.last_traces` for the forward traces,
  - `thunder.last_backward_traces` for the backward,
 
- Extending Thunder:

  - registering operators with the `OperatorExecutor`,
  - defining implementations with custom forward and backward to include optimized kernels.

Keep in mind that Thunder is still experimental and only expected to work with the limited set of models we have tested it with. You will find bugs and missing pieces. Naturally, we would love for you to help us fix these! You can find us on the [Thunder section of the Lightning forums](https://lightning.ai/forums/c/thunder) or in the `#thunder` channel on the [PyTorch-Lightning slack](https://pytorch-lightning.slack.com/). 

Do check out our LitGPT studios and the other tutorial notebooks.
