Extending Thunder
#################

.. TODO RC1: update using the extend API

This section describes how to add an executor to Thunder for a PyTorch operation.

First, define a Python function with the same signature as the targeted operation, and have it call your implementation. For example, the Apex executor for ``torch.nn.functional.cross_entropy`` might define its implementation like::

  import torch
  import xentropy_cuda

  def apex_xentropy(
      a: torch.Tensor,  # a is an actual PyTorch tensor
      target,
      weight=None,
      size_average=None,
      ignore_index=-100,
      reduce=None,
      reduction="mean",
      label_smoothing=0.0,
  ):
      # Have the kernel return float32 losses for half-precision inputs
      half_to_float = a.dtype in (torch.float16, torch.bfloat16)
      losses, max_log_sum_exp = xentropy_cuda.forward(a, target, label_smoothing, half_to_float)

      # Ignored targets contribute no loss, as in PyTorch
      ignored = target == ignore_index
      losses = losses.masked_fill(ignored, 0.0)
      if reduction == "sum":
          return losses.sum()
      if reduction == "mean":
          # PyTorch averages over the non-ignored targets only
          return losses.sum() / (~ignored).sum()
      return losses  # reduction == "none"

When this implementation is used, it is called with actual PyTorch tensors, not with proxies.

Next, define a "checker" function with the same signature as the targeted operation. It returns ``True`` if your implementation can execute the targeted operation for the given inputs, and ``False`` otherwise. Unlike implementations, checkers are called with proxies rather than actual PyTorch tensors, because they run at optimization time, when only metadata such as shape, dtype, and device is available. The purpose of a checker function is to let an executor target only specific inputs to an operation and defer to another executor on all other inputs.
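The fallback behaviour described above can be sketched in plain Python. This is a hypothetical model, not Thunder's internal dispatch code: ``TensorMeta``, ``Executor``, ``pick_implementation``, and the placeholder implementations are names invented for illustration. Each executor maps an operation name to a ``(name, checker, implementation)`` tuple, the checker sees only metadata (as a proxy would), and executors are tried in priority order until one checker accepts the call.

```python
from dataclasses import dataclass, field

OP = "torch.nn.functional.cross_entropy"


@dataclass(frozen=True)
class TensorMeta:
    """Hypothetical stand-in for a proxy: it carries metadata but no data."""
    shape: tuple
    dtype: str
    device: str


@dataclass
class Executor:
    name: str
    # Maps an operation name to (implementation name, checker, implementation)
    ops: dict = field(default_factory=dict)


def pick_implementation(executors, op_name, *args, **kwargs):
    """Return (executor name, implementation name, implementation) for the first
    executor, in priority order, whose checker accepts these arguments."""
    for ex in executors:
        entry = ex.ops.get(op_name)
        if entry is None:
            continue  # this executor does not handle the operation at all
        impl_name, checker, impl = entry
        if checker(*args, **kwargs):
            return ex.name, impl_name, impl
    raise RuntimeError(f"no executor can run {op_name}")


def xent_checker(a, target, weight=None, reduction="mean", **kwargs):
    # Uses metadata only: claim CUDA inputs without weights and with a supported reduction
    return a.device == "cuda" and weight is None and reduction in ("sum", "mean", "none")


def fast_xent(*args, **kwargs):
    return "fast kernel result"  # placeholder for a call into a custom kernel


def reference_xent(*args, **kwargs):
    return "reference result"  # placeholder for the default implementation


fast = Executor("apex", {OP: ("apex_xentropy", xent_checker, fast_xent)})
fallback = Executor("torch", {OP: ("cross_entropy", lambda *args, **kwargs: True, reference_xent)})
executors = [fast, fallback]

logits = TensorMeta((8, 10), "float32", "cuda")
labels = TensorMeta((8,), "int64", "cuda")
print(pick_implementation(executors, OP, logits, labels)[:2])  # ('apex', 'apex_xentropy')
cpu_logits = TensorMeta((8, 10), "float32", "cpu")
print(pick_implementation(executors, OP, cpu_logits, labels)[:2])  # ('torch', 'cross_entropy')
```

The second call shows the deferral: the specialized checker rejects CPU inputs, so the call falls through to the next executor instead of failing.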
A checker function for the Apex executor might look like::

  from thunder.core.proxies import TensorProxy

  def apex_xentropy_checker(
      a: TensorProxy,  # a is a proxy
      target,
      weight=None,
      size_average=None,
      ignore_index=-100,
      reduce=None,
      reduction="mean",
      label_smoothing=0.0,
  ):
      # The implementation does not handle class weights or the legacy
      # size_average/reduce arguments, so defer such calls to another executor
      if weight is not None or size_average is not None or reduce is not None:
          return False
      # Apex's xentropy only supports "sum", "mean" or "none" reductions
      if reduction not in ["sum", "mean", "none"]:
          return False
      return True

Create a mapping from the name of the PyTorch operation to a tuple of your implementation's name, its checker, and the implementation itself::

  _op_to_xentropy = {
      "torch.nn.functional.cross_entropy": ("apex_xentropy", apex_xentropy_checker, apex_xentropy),
  }

Then define a registration function that practitioners can call to enable your executor::

  def register_apex_xentropyex(*, add_to_default_executors: bool = True) -> None:
      from thunder.executors import add_operator_executor

      add_operator_executor("apex_xentropy", _op_to_xentropy, add_to_default_executors=add_to_default_executors)

You can test your executor by registering it, compiling a function that calls the targeted operation, and verifying both that your implementation is called (by inspecting the execution trace) and that it produces the correct output. The tests for the Apex executor are a good example.
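The testing strategy described above can be sketched without Thunder, Apex, or a GPU. The example below is a hypothetical, pure-Python model: the ``traced`` decorator and the ``TRACE`` list stand in for inspecting an execution trace, ``my_xentropy`` stands in for the implementation under test, and ``reference_cross_entropy`` re-implements PyTorch's cross-entropy semantics for class-index targets (no weights, no label smoothing) to compare against. In a real test you would inspect Thunder's execution trace and compare against ``torch.nn.functional.cross_entropy`` instead.

```python
import math

TRACE = []  # names of implementations that ran; a stand-in for an execution trace


def traced(name):
    """Hypothetical decorator: record `name` in TRACE whenever the wrapped function runs."""
    def wrap(fn):
        def inner(*args, **kwargs):
            TRACE.append(name)
            return fn(*args, **kwargs)
        return inner
    return wrap


def reference_cross_entropy(logits, target, ignore_index=-100, reduction="mean"):
    """Plain-Python cross entropy for class-index targets (no weights, no label smoothing)."""
    losses = []
    for row, t in zip(logits, target):
        if t == ignore_index:
            losses.append(0.0)  # ignored targets contribute zero loss
            continue
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        losses.append(m + math.log(sum(math.exp(x - m) for x in row)) - row[t])
    if reduction == "none":
        return losses
    if reduction == "sum":
        return sum(losses)
    return sum(losses) / sum(t != ignore_index for t in target)  # mean over kept targets


@traced("my_xentropy")
def my_xentropy(logits, target, ignore_index=-100, reduction="mean"):
    """Hypothetical implementation under test, written differently: -log(softmax(row)[t])."""
    losses = [
        0.0 if t == ignore_index else -math.log(math.exp(row[t]) / sum(math.exp(x) for x in row))
        for row, t in zip(logits, target)
    ]
    if reduction == "none":
        return losses
    total = sum(losses)
    return total if reduction == "sum" else total / sum(t != ignore_index for t in target)


def close(a, b, tol=1e-9):
    """Compare scalars or lists of scalars within a tolerance."""
    if isinstance(a, list):
        return len(a) == len(b) and all(close(x, y, tol) for x, y in zip(a, b))
    return math.isclose(a, b, rel_tol=tol, abs_tol=tol)


def test_my_xentropy():
    logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 0.3], [1.0, 1.0, 1.0]]
    target = [0, 2, -100]  # the last row is ignored
    for reduction in ("none", "sum", "mean"):
        TRACE.clear()
        got = my_xentropy(logits, target, reduction=reduction)
        assert "my_xentropy" in TRACE  # the implementation under test actually ran
        assert close(got, reference_cross_entropy(logits, target, reduction=reduction))


test_my_xentropy()
print("ok")
```

Checking the trace and the numerical result separately matters: a test that only compares outputs would also pass if the checker silently rejected every input and the default implementation ran instead.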