thunder.torch._register_nvfuser_translator

thunder.torch._register_nvfuser_translator(symbol, translator_for_nvfuser, checker=None)[source]

Register a translator with the nvFuser executor for symbol.

Parameters:
  • symbol (Symbol) – The Symbol returned by thunder.torch.custom_op._register_custom_op().

  • translator_for_nvfuser (Callable[[Any], Any]) – A function that takes Proxy objects and Python built-in types as its positional and keyword arguments, along with the keyword-only arguments fd, an nvfuser.FusionDefinition, and lc_to_nv_map, a dictionary mapping TensorProxy objects to their nvFuser values.

  • checker (Optional[Callable[[Any], bool]]) – A function that takes the arguments of symbol and returns True if the nvFuser definition supports those arguments. By default, a function that always returns True is used.


Return type:

None

Note

Currently, backward is not supported; the backward pass falls back to the custom op's original implementation (see the example below).
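
As a rough sketch, a translator and a checker have signatures along the following lines; the names below are illustrative only, not part of the API (a complete, runnable example follows):

def my_translator(*args, fd, lc_to_nv_map, **kwargs):
    # Receives the symbol's args/kwargs as Proxy objects and Python built-ins,
    # plus the nvfuser.FusionDefinition `fd` and the proxy-to-nvFuser-value map
    # `lc_to_nv_map`; builds and returns the nvFuser output(s).
    ...

def my_checker(*args, **kwargs):
    # Takes the symbol's arguments and returns True only if the nvFuser
    # definition above supports them.
    return True

# my_symbol comes from thunder.torch.custom_op._register_custom_op()
# _register_nvfuser_translator(my_symbol, my_translator, checker=my_checker)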

Example

import torch
import torch.nn as nn

import thunder
from thunder.core.dtypes import to_dtype
from thunder.executors.nvfuserex_impl import getnv
from thunder.executors.nvfuserex_impl import lcdtype_to_nvdtype
from thunder.torch.custom_op import _register_custom_op
from thunder.torch.custom_op import _register_nvfuser_translator

@torch.library.custom_op("my_custom_op::mul", mutates_args=())
def mul(a: torch.Tensor, b: torch.Tensor, c: float | None = None) -> torch.Tensor:
    return a * b

@torch.library.register_fake("my_custom_op::mul")
def _(a: torch.Tensor, b: torch.Tensor, c: float | None = None) -> torch.Tensor:
    return torch.empty_like(a)

def setup_context_for_my_custom_op_mul(ctx, inputs, output) -> None:
    a, b, *_ = inputs
    ctx.save_for_backward(a, b)

def backward_of_my_custom_op_mul(ctx, grad) -> tuple[torch.Tensor, torch.Tensor, None]:
    a, b = ctx.saved_tensors
    return torch.ops.my_custom_op.mul(grad, b), torch.ops.my_custom_op.mul(grad, a), None

torch.library.register_autograd(
    "my_custom_op::mul",
    backward_of_my_custom_op_mul,
    setup_context=setup_context_for_my_custom_op_mul,
)

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2, bias=False)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        out = torch.ops.my_custom_op.mul(self.linear(x), y)
        return torch.relu(out)

# Custom nvfuser definition.
def mul_translator(a, b, c=None, *, fd, lc_to_nv_map):
    # Look up the nvFuser values corresponding to the TensorProxy inputs.
    nva = getnv(a, fd, lc_to_nv_map)
    nvb = getnv(b, fd, lc_to_nv_map)
    result = fd.ops.mul(nva, nvb)
    # Cast the result back to the dtype of the input `a`.
    out = fd.ops.cast(result, lcdtype_to_nvdtype(to_dtype(a.dtype)))
    return out

DEVICE = torch.device("cuda")
DTYPE = torch.bfloat16
SHAPE = (8, 2)

if __name__ == "__main__":
    # Register the custom op `mul` with :func:`thunder.torch.custom_op._register_custom_op`
    _symbol = _register_custom_op(mul)
    # Register the custom nvFuser definition for the already registered custom op `mul`
    _register_nvfuser_translator(_symbol, mul_translator)
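    # Optionally, a checker could also be passed (illustrative only), e.g.:
    # _register_nvfuser_translator(_symbol, mul_translator, checker=lambda a, b, c=None: c is None)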

    model = MyModule().to(device=DEVICE, dtype=DTYPE)
    with DEVICE:
        x = torch.randn(SHAPE, dtype=DTYPE)
        y = torch.randn(SHAPE, dtype=DTYPE)

    jitted = thunder.jit(model)
    out = jitted(x, y)
    fwd_extrace = thunder.last_traces(jitted)[-1]
    print(fwd_extrace)
    # def computation(x, y, t_linear_weight):
    #   # x: "cuda:0 bf16[8, 2]"
    #   # y: "cuda:0 bf16[8, 2]"
    #   # t_linear_weight: "cuda:0 bf16[2, 2]"
    #
    #   # /path/to/torch/nn/modules/linear.py:134:                return F.linear(input, self.weight, self.bias)
    #   t28 = torch.nn.functional.linear(x, t_linear_weight, None)  # t28: "cuda:0 bf16[8, 2]"
    #     # t28 = ltorch.linear(x, t_linear_weight, None)  # t28: "cuda:0 bf16[8, 2]"
    #       # t28 = prims.linear(x, t_linear_weight, None)  # t28: "cuda:0 bf16[8, 2]"
    #   [t21, t22] = nvFusion0(t28, y)
    #     # t17 = ltorch.my_custom_op_mul(t28, y)  # t17: "cuda:0 bf16[8, 2]"
    #     # t21 = prims.gt(t17, 0.0)  # t21: "cuda:0 b8[8, 2]"
    #     # t22 = prims.where(t21, t17, 0.0)  # t22: "cuda:0 bf16[8, 2]"
    #   return {'output': (t22,), 'flat_args': [x, y, t_linear_weight], 'flat_output': (t22,)}, ((t21, t28, y, x), ())

    out.mean().backward()
    # The backward uses the original implementation.
    print(thunder.last_backward_traces(jitted)[-1])
    # def backward_fn(saved_for_backward, cotangents):
    #   # saved_for_backward: "Collection"
    #   # cotangents: "Collection"
    #   C0, C1, = saved_for_backward
    #   # C0: "Collection"
    #   # C1: "Collection"
    #   clear_mutable_collection(saved_for_backward)
    #   clear_mutable_collection(C1)
    #   del C1, saved_for_backward
    #   t23, = cotangents
    #   # t23: "cuda:0 bf16[8, 2]"
    #   clear_mutable_collection(cotangents)
    #   del cotangents
    #   t21, t28, y, x, = C0
    #   # t21: "cuda:0 b8[8, 2]"
    #   # t28: "cuda:0 bf16[8, 2]"
    #   # y: "cuda:0 bf16[8, 2]"
    #   # x: "cuda:0 bf16[8, 2]"
    #   clear_mutable_collection(C0)
    #   del C0
    #   [t24] = nvFusion1(t21, t23)
    #     # t24 = prims.where(t21, t23, 0.0)  # t24: "cuda:0 bf16[8, 2]"
    #   del t21, t23
    #   [t19, t20] = my_custom_op_mul_backward(t28, y, t24)
    #   del t20, t28, y, t24
    #   t30 = torch.reshape(t19, (-1, 2))  # t30: "cuda:0 bf16[8, 2]"
    #     # t30 = ltorch.reshape(t19, (-1, 2))  # t30: "cuda:0 bf16[8, 2]"
    #       # t30 = prims.reshape(t19, (8, 2))  # t30: "cuda:0 bf16[8, 2]"
    #   del t19
    #   t31 = torch.permute(t30, (1, 0))  # t31: "cuda:0 bf16[2, 8]"
    #     # t31 = ltorch.permute(t30, (1, 0))  # t31: "cuda:0 bf16[2, 8]"
    #       # t31 = prims.transpose(t30, (1, 0))  # t31: "cuda:0 bf16[2, 8]"
    #   del t30
    #   t32 = torch.reshape(x, (-1, 2))  # t32: "cuda:0 bf16[8, 2]"
    #     # t32 = ltorch.reshape(x, (-1, 2))  # t32: "cuda:0 bf16[8, 2]"
    #       # t32 = prims.reshape(x, (8, 2))  # t32: "cuda:0 bf16[8, 2]"
    #   del x
    #   t29 = torch.matmul(t31, t32)  # t29: "cuda:0 bf16[2, 2]"
    #     # t29 = ltorch.matmul(t31, t32)  # t29: "cuda:0 bf16[2, 2]"
    #       # t29 = prims.matmul(t31, t32)  # t29: "cuda:0 bf16[2, 2]"
    #   del t31, t32
    #   return (None, None, t29)