Additional executors #################### nvFuser and Pytorch are not the only executors available in Thunder today. Additional executors can be added to Thunder prior to compilation through a registration mechanism, which makes it easy to have specialized executors perform certain operations more efficiently. This section contains a list of all executors supported by PyTorch beyond nvFuser and PyTorch. Triton CrossEntropy Executor ============================ The Triton CrossEntropy executor can execute ``torch.cross_entropy()`` using an optimized kernel written in OpenAI Triton (https://github.com/openai/triton). It can be used like in the following example:: import torch import thunder from thunder.executors.triton_crossentropy import triton_ex as triton_cross_entropy_ex def xentropy(logits, labels, weight, reduction, ignore_index): return thunder.torch.cross_entropy( logits, labels, weight=weight, reduction=reduction, ignore_index=ignore_index ) jitted_xentropy = thunder.jit( xentropy, executors=[triton_cross_entropy_ex,] ) device = 'cuda' dtype = torch.float32 logits = torch.randn([2048, 50257], device=device, dtype=dtype) labels = torch.randint(0, 50257, [2048], device=device) weight = torch.rand(50257, device=device, dtype=dtype, requires_grad=False) reduction = "sum" ignore_index = labels[5].item() jitted_xentropy(logits, labels, weight, reduction, ignore_index) traces = thunder.last_traces(jitted_xentropy) print(traces[-1]) This prints:: # Constructed by Delete Last Used (took 0 milliseconds) import torch from thunder.executors.torchex import no_autocast @torch.no_grad() @no_autocast def computation(logits, labels, weight): # logits: "cuda:0 f32[2048, 50257]" # labels: "cuda:0 i64[2048]" # weight: "cuda:0 f32[50257]" t23 = triton_crossentropy(logits, labels, weight, None, 45279, None, 'sum', 0.0) # t23: "cuda:0 f32[]" del logits, labels, weight return t23 As shown in the above trace, ``triton_crossentropy()`` is the one running the operation. Apex CrossEntropy Executor ========================== The Apex CrossEntropy executor can execute ``torch.cross_entropy()`` through an optimized kernel, like this:: import torch import thunder from thunder.executors.apexex import apex_ex def xentropy(logits, labels): return thunder.torch.cross_entropy( logits, labels, reduction='mean', ignore_index=-1 ) jitted_xentropy = thunder.jit(xentropy, executors=[apex_ex,]) device = 'cuda' dtype = torch.float32 logits = torch.randn([2048, 50257], device=device, dtype=dtype) labels = torch.randint(0, 50257, [2048], device=device) jitted_xentropy(logits, labels) traces = thunder.last_traces(jitted_xentropy) print(traces[-1]) This prints:: # Constructed by Delete Last Used (took 0 milliseconds) import torch from thunder.executors.torchex import no_autocast @torch.no_grad() @no_autocast def computation(logits, labels): # logits: "cuda:0 f32[2048, 50257]" # labels: "cuda:0 i64[2048]" (t18, _) = apex_cross_entropy(logits, labels, 'mean', 0.0) del logits, labels return t18 showing that Apex is running the operation. cuDNN SDPA Executor =================== TODO RC1 TransformerEngine Executor ========================== TODO RC1