thunder.torch._register_nvfuser_translator
- thunder.torch._register_nvfuser_translator(symbol, translator_for_nvfuser, checker=None)
Register a translator for the nvfuser executor for symbol.

- Parameters:
  - symbol (Symbol) – The Symbol returned by thunder.torch.custom_op._register_custom_op().
  - translator_for_nvfuser (Callable[[Any], Any]) – A function that takes Proxy objects and Python built-in types as its args and kwargs, plus the keyword arguments fd, an nvfuser.FusionDefinition, and lc_to_nv_map, a dictionary mapping TensorProxy objects to their actual nvfuser values.
  - checker (Optional[Callable[[Any], bool]]) – A function that takes the arguments of symbol and returns True if the nvfuser definition supports those arguments. By default, a function that always returns True is used. See the sketch after the example for one way to write a checker.
- Return type:
  None
Note
Currently, backward is not supported: the backward of the custom op falls back to its original implementation (see the backward trace in the example below).
Example
import torch
import torch.nn as nn

import thunder
from thunder.core.dtypes import to_dtype
from thunder.executors.nvfuserex_impl import getnv
from thunder.executors.nvfuserex_impl import lcdtype_to_nvdtype
from thunder.torch.custom_op import _register_custom_op
from thunder.torch.custom_op import _register_nvfuser_translator


@torch.library.custom_op("my_custom_op::mul", mutates_args=())
def mul(a: torch.Tensor, b: torch.Tensor, c: float | None = None) -> torch.Tensor:
    return a * b


@torch.library.register_fake("my_custom_op::mul")
def _(a: torch.Tensor, b: torch.Tensor, c: float | None = None) -> torch.Tensor:
    return torch.empty_like(a)


def setup_context_for_my_custom_op_mul(ctx, inputs, output) -> None:
    a, b, *_ = inputs
    ctx.save_for_backward(a, b)


def backward_of_my_custom_op_mul(ctx, grad) -> tuple[torch.Tensor, torch.Tensor, None]:
    a, b = ctx.saved_tensors
    return torch.ops.my_custom_op.mul(grad, b), torch.ops.my_custom_op.mul(grad, a), None


torch.library.register_autograd(
    "my_custom_op::mul",
    backward_of_my_custom_op_mul,
    setup_context=setup_context_for_my_custom_op_mul,
)


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2, bias=False)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        out = torch.ops.my_custom_op.mul(self.linear(x), y)
        return torch.relu(out)


# Custom nvfuser definition.
def mul_translator(a, b, c=None, *, fd, lc_to_nv_map):
    nva = getnv(a, fd, lc_to_nv_map)
    nvb = getnv(b, fd, lc_to_nv_map)
    result = fd.ops.mul(nva, nvb)
    out = fd.ops.cast(result, lcdtype_to_nvdtype(to_dtype(a.dtype)))
    return out


DEVICE = torch.device("cuda")
DTYPE = torch.bfloat16
SHAPE = (8, 2)

if __name__ == "__main__":
    # Register the custom_op of `mul` with :func:`thunder.torch._register_custom_op`
    _symbol = _register_custom_op(mul)
    # Register custom nvfuser definition for the already registered custom_op of mul
    _register_nvfuser_translator(_symbol, mul_translator)

    model = MyModule().to(device=DEVICE, dtype=DTYPE)
    with DEVICE:
        x = torch.randn(SHAPE, dtype=DTYPE)
        y = torch.randn(SHAPE, dtype=DTYPE)

    jitted = thunder.jit(model)
    out = jitted(x, y)

    fwd_extrace = thunder.last_traces(jitted)[-1]
    print(fwd_extrace)
    # def computation(x, y, t_linear_weight):
    #   # x: "cuda:0 bf16[8, 2]"
    #   # y: "cuda:0 bf16[8, 2]"
    #   # t_linear_weight: "cuda:0 bf16[2, 2]"
    #
    #   # /path/to/torch/nn/modules/linear.py:134: return F.linear(input, self.weight, self.bias)
    #   t28 = torch.nn.functional.linear(x, t_linear_weight, None)  # t28: "cuda:0 bf16[8, 2]"
    #     # t28 = ltorch.linear(x, t_linear_weight, None)  # t28: "cuda:0 bf16[8, 2]"
    #       # t28 = prims.linear(x, t_linear_weight, None)  # t28: "cuda:0 bf16[8, 2]"
    #   [t21, t22] = nvFusion0(t28, y)
    #     # t17 = ltorch.my_custom_op_mul(t28, y)  # t17: "cuda:0 bf16[8, 2]"
    #     # t21 = prims.gt(t17, 0.0)  # t21: "cuda:0 b8[8, 2]"
    #     # t22 = prims.where(t21, t17, 0.0)  # t22: "cuda:0 bf16[8, 2]"
    #   return {'output': (t22,), 'flat_args': [x, y, t_linear_weight], 'flat_output': (t22,)}, ((t21, t28, y, x), ())

    out.mean().backward()
    # The backward uses the original implementation.
    print(thunder.last_backward_traces(jitted)[-1])
    # def backward_fn(saved_for_backward, cotangents):
    #   # saved_for_backward: "Collection"
    #   # cotangents: "Collection"
    #   C0, C1, = saved_for_backward
    #   # C0: "Collection"
    #   # C1: "Collection"
    #   clear_mutable_collection(saved_for_backward)
    #   clear_mutable_collection(C1)
    #   del C1, saved_for_backward
    #   t23, = cotangents
    #   # t23: "cuda:0 bf16[8, 2]"
    #   clear_mutable_collection(cotangents)
    #   del cotangents
    #   t21, t28, y, x, = C0
    #   # t21: "cuda:0 b8[8, 2]"
    #   # t28: "cuda:0 bf16[8, 2]"
    #   # y: "cuda:0 bf16[8, 2]"
    #   # x: "cuda:0 bf16[8, 2]"
    #   clear_mutable_collection(C0)
    #   del C0
    #   [t24] = nvFusion1(t21, t23)
    #     # t24 = prims.where(t21, t23, 0.0)  # t24: "cuda:0 bf16[8, 2]"
    #   del t21, t23
    #   [t19, t20] = my_custom_op_mul_backward(t28, y, t24)
    #   del t20, t28, y, t24
    #   t30 = torch.reshape(t19, (-1, 2))  # t30: "cuda:0 bf16[8, 2]"
    #     # t30 = ltorch.reshape(t19, (-1, 2))  # t30: "cuda:0 bf16[8, 2]"
    #       # t30 = prims.reshape(t19, (8, 2))  # t30: "cuda:0 bf16[8, 2]"
    #   del t19
    #   t31 = torch.permute(t30, (1, 0))  # t31: "cuda:0 bf16[2, 8]"
    #     # t31 = ltorch.permute(t30, (1, 0))  # t31: "cuda:0 bf16[2, 8]"
    #       # t31 = prims.transpose(t30, (1, 0))  # t31: "cuda:0 bf16[2, 8]"
    #   del t30
    #   t32 = torch.reshape(x, (-1, 2))  # t32: "cuda:0 bf16[8, 2]"
    #     # t32 = ltorch.reshape(x, (-1, 2))  # t32: "cuda:0 bf16[8, 2]"
    #       # t32 = prims.reshape(x, (8, 2))  # t32: "cuda:0 bf16[8, 2]"
    #   del x
    #   t29 = torch.matmul(t31, t32)  # t29: "cuda:0 bf16[2, 2]"
    #     # t29 = ltorch.matmul(t31, t32)  # t29: "cuda:0 bf16[2, 2]"
    #       # t29 = prims.matmul(t31, t32)  # t29: "cuda:0 bf16[2, 2]"
    #   del t31, t32
    #   return (None, None, t29)
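The example above relies on the default checker, which always returns True. A checker, described under Parameters, lets the translator opt out of inputs the nvfuser definition does not handle. The following is a minimal sketch, not part of the original example; the particular predicate (matching shape and dtype) is only an illustrative assumption, and it reuses _symbol and mul_translator from the example above.

# A minimal, illustrative checker: it receives the same arguments as the
# symbol (here proxies for a, b and the optional c) and returns True only
# when this nvfuser translation should be used.
def mul_checker(a, b, c=None) -> bool:
    # Assumption for illustration: only claim support for same-shape,
    # same-dtype inputs; otherwise thunder executes the symbol without
    # this nvfuser translator.
    return a.shape == b.shape and a.dtype == b.dtype

# Pass the checker explicitly instead of relying on the always-True default.
_register_nvfuser_translator(_symbol, mul_translator, checker=mul_checker)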