Train an MLP on MNIST
#####################

Here's a complete program that trains a torchvision MLP on MNIST. First, install torchvision::

  pip install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cu121

Here's the code::

  import os
  import torch
  import torchvision
  import torchvision.transforms as transforms

  import thunder

  # Creates train and test datasets
  device = 'cuda'
  device_transform = transforms.Lambda(lambda t: t.to(device))
  flatten_transform = transforms.Lambda(lambda t: t.flatten())
  my_transform = transforms.Compose((transforms.ToTensor(), device_transform, flatten_transform))

  train_dataset = torchvision.datasets.MNIST("/tmp/mnist/train", train=True, download=True, transform=my_transform)
  test_dataset = torchvision.datasets.MNIST("/tmp/mnist/test", train=False, download=True, transform=my_transform)

  # Creates samplers
  train_sampler = torch.utils.data.RandomSampler(train_dataset)
  test_sampler = torch.utils.data.RandomSampler(test_dataset)

  # Creates dataloaders
  batch_size = 8
  train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
  test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, sampler=test_sampler)

  # Evaluates the model
  def eval_model(model, test_loader):
      num_correct = 0
      total_guesses = 0

      for data, targets in iter(test_loader):
          targets = targets.cuda()

          # Acquires the model's best guesses at each class
          results = model(data)
          best_guesses = torch.argmax(results, 1)

          # Updates the number of correct and total guesses
          num_correct += torch.eq(targets, best_guesses).sum().item()
          total_guesses += batch_size

      # Prints output
      print("Correctly guessed ", (num_correct / total_guesses) * 100, "% of the dataset")

  # Trains the model
  def train_model(model, train_loader, *, num_epochs: int = 1):
      loss_fn = torch.nn.CrossEntropyLoss().to(device)
      optimizer = torch.optim.Adam(model.parameters())

      for epoch in range(num_epochs):
          for data, targets in iter(train_loader):
              targets = targets.cuda()

              # Acquires the model's best guesses at each class
              results = model(data)

              # Computes loss
              loss = loss_fn(results, targets)

              # Updates the model
              optimizer.zero_grad()
              loss.backward()
              optimizer.step()

  # Constructs the model
  model = torchvision.ops.MLP(in_channels=784, hidden_channels=[784, 784, 784, 28, 10], bias=True, dropout=.1).to(device)

  # Performs an initial evaluation
  model.eval().requires_grad_(False)
  jitted_eval_model = thunder.jit(model)
  eval_model(jitted_eval_model, test_loader)

  # Trains the model
  model.train().requires_grad_(True)
  jitted_train_model = thunder.jit(model)
  train_model(jitted_train_model, train_loader)
  model.eval().requires_grad_(False)

  # Performs a final evaluation
  eval_model(jitted_eval_model, test_loader)

  # Evaluates the original, unjitted model.
  # The unjitted and jitted models share parameters, so the
  # original model was also updated by training
  eval_model(model, test_loader)

  # Acquires and prints thunder's "traces", which show what thunder executed.
  # The training module has both "forward" and "backward" traces, corresponding
  # to its forward and backward computations.
  # The evaluation module has only one set of traces.
  fwd_traces = thunder.last_traces(jitted_train_model)
  bwd_traces = thunder.last_backward_traces(jitted_train_model)
  eval_traces = thunder.last_traces(jitted_eval_model)
  print("This is the trace that thunder executed for training's forward computation:")
  print(fwd_traces[-1])
  print("This is the trace that thunder executed for training's backward computation:")
  print(bwd_traces[-1])
  print("This is the trace that thunder executed for eval's computation:")
  print(eval_traces[-1])

Let's look at a few parts of this program more closely. First, up until the call to ``thunder.jit()`` the program is just Python, PyTorch and torchvision. ``thunder.jit()`` accepts a PyTorch module (or function) and returns a Thunder-optimized module that has the same signature, parameters and buffers. After compilation the program is, again, just Python and PyTorch, until the very end.
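As a quick sanity check on the model's size, here is a plain-Python sketch (no torch required) that counts the parameters implied by the ``torchvision.ops.MLP`` configuration above — ``torchvision.ops.MLP`` creates one linear layer per entry in ``hidden_channels``, each with a bias since ``bias=True``::

```python
# Back-of-the-envelope parameter count for
# MLP(in_channels=784, hidden_channels=[784, 784, 784, 28, 10], bias=True)
in_channels = 784
hidden_channels = [784, 784, 784, 28, 10]

total = 0
prev = in_channels
for width in hidden_channels:
    # A linear layer with bias has a (prev x width) weight matrix
    # plus `width` bias terms
    total += prev * width + width
    prev = width

print(f"linear layers: {len(hidden_channels)}, parameters: {total:,}")
```

Note that the eleven tensor arguments to the forward trace shown later (one input plus a weight and a bias for each of the five linear layers) line up with this layer count.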
Behind the scenes, when a Thunder module is called it produces a “trace” representing the sequence of tensor operations to perform. This trace is then transformed and optimized, and the sequence of these traces for the last inputs can be acquired by calling ``thunder.last_traces()`` on the module (the traced program changes when different input data types, devices, or other properties are used). When the module is used for training, ``thunder.last_traces()`` will return both the sequence of “forward” traces and the sequence of “backward” traces, and when it's just used for evaluation it will return just one sequence of traces. In this case we're printing the last trace in each sequence. These traces print as Python programs, and these Python programs are what Thunder executes.

Let's take a look at the execution trace for the training module's forward::

  @torch.no_grad()
  @no_autocast
  def augmented_forward_fn(t0, t4, t5, t21, t22, t38, t39, t55, t56, t72, t73):
    # t0
    # t4
    # t5
    # t21
    # t22
    # t38
    # t39
    # t55
    # t56
    # t72
    # t73
    t1 = torch.nn.functional.linear(t0, t4, t5)  # t1
      # t1 = ltorch.linear(t0, t4, t5)  # t1
        # t1 = prims.linear(t0, t4, t5)  # t1
    [t10, t2, t7] = nvFusion0(t1)
      # t2 = prims.gt(t1, 0.0)  # t2
      # t3 = prims.where(t2, t1, 0.0)  # t3
      # t6 = prims.uniform((8, 784), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t6
      # t7 = prims.lt(t6, 0.9)  # t7
      # t8 = prims.convert_element_type(t7, dtypes.float32)  # t8
      # t9 = prims.mul(t3, t8)  # t9
      # t10 = prims.mul(t9, 1.1111111111111112)  # t10
    del t1
    t11 = torch.nn.functional.linear(t10, t21, t22)  # t11
      # t11 = ltorch.linear(t10, t21, t22)  # t11
        # t11 = prims.linear(t10, t21, t22)  # t11
    [t12, t15, t18] = nvFusion1(t11)
      # t12 = prims.gt(t11, 0.0)  # t12
      # t13 = prims.where(t12, t11, 0.0)  # t13
      # t14 = prims.uniform((8, 784), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t14
      # t15 = prims.lt(t14, 0.9)  # t15
      # t16 = prims.convert_element_type(t15, dtypes.float32)  # t16
      # t17 = prims.mul(t13, t16)  # t17
      # t18 = prims.mul(t17, 1.1111111111111112)  # t18
    del t11
    t19 = torch.nn.functional.linear(t18, t38, t39)  # t19
      # t19 = ltorch.linear(t18, t38, t39)  # t19
        # t19 = prims.linear(t18, t38, t39)  # t19
    [t20, t25, t28] = nvFusion2(t19)
      # t20 = prims.gt(t19, 0.0)  # t20
      # t23 = prims.where(t20, t19, 0.0)  # t23
      # t24 = prims.uniform((8, 784), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t24
      # t25 = prims.lt(t24, 0.9)  # t25
      # t26 = prims.convert_element_type(t25, dtypes.float32)  # t26
      # t27 = prims.mul(t23, t26)  # t27
      # t28 = prims.mul(t27, 1.1111111111111112)  # t28
    del t19
    t29 = torch.nn.functional.linear(t28, t55, t56)  # t29
      # t29 = ltorch.linear(t28, t55, t56)  # t29
        # t29 = prims.linear(t28, t55, t56)  # t29
    [t30, t33, t36] = nvFusion3(t29)
      # t30 = prims.gt(t29, 0.0)  # t30
      # t31 = prims.where(t30, t29, 0.0)  # t31
      # t32 = prims.uniform((8, 28), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t32
      # t33 = prims.lt(t32, 0.9)  # t33
      # t34 = prims.convert_element_type(t33, dtypes.float32)  # t34
      # t35 = prims.mul(t31, t34)  # t35
      # t36 = prims.mul(t35, 1.1111111111111112)  # t36
    del t29
    t37 = torch.nn.functional.linear(t36, t72, t73)  # t37
      # t37 = ltorch.linear(t36, t72, t73)  # t37
        # t37 = prims.linear(t36, t72, t73)  # t37
    [t41, t44] = nvFusion4(t37)
      # t40 = prims.uniform((8, 10), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t40
      # t41 = prims.lt(t40, 0.9)  # t41
      # t42 = prims.convert_element_type(t41, dtypes.float32)  # t42
      # t43 = prims.mul(t37, t42)  # t43
      # t44 = prims.mul(t43, 1.1111111111111112)  # t44
    del t37
    return {'output': (t44, ()), 'flat_args': [t0, t4, t5, t21, t22, t38, t39, t55, t56, t72, t73], 'flat_output': (t44,)}, ((t0, t10, t12, t15, t18, t2, t20, t21, t25, t28, t30, t33, t36, t38, t41, t55, t7, t72), (1.1111111111111112, 1.1111111111111112, 1.1111111111111112, 1.1111111111111112, 1.1111111111111112))

There's a lot going on here, and if you'd like to get into the details then keep reading!
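The primitives inside ``nvFusion0`` are simply a ReLU (the ``gt`` and ``where``) followed by inverted dropout (the ``uniform``, ``lt``, and two ``mul`` operations) — the constant ``1.1111111111111112`` is ``1 / (1 - p)`` for the ``dropout=.1`` passed to the MLP. Here's a scalar, plain-Python re-enactment of that primitive sequence (illustration only; ``relu_dropout`` is a hypothetical helper, and ``random.uniform`` stands in for ``prims.uniform``)::

```python
import random

# Scalar sketch of the primitives inside nvFusion0:
# ReLU (gt + where) followed by inverted dropout (uniform + lt + mul + mul)
p = 0.1              # the dropout=.1 passed to torchvision.ops.MLP
scale = 1 / (1 - p)  # matches the 1.1111111111111112 constant in the trace

def relu_dropout(t1: float, rng: random.Random) -> float:
    t2 = t1 > 0.0                 # t2 = prims.gt(t1, 0.0)
    t3 = t1 if t2 else 0.0        # t3 = prims.where(t2, t1, 0.0)
    t6 = rng.uniform(0.0, 1.0)    # t6 = prims.uniform(...)
    t7 = t6 < (1 - p)             # t7 = prims.lt(t6, 0.9): keep with prob. 0.9
    t8 = 1.0 if t7 else 0.0       # t8 = prims.convert_element_type(t7, float32)
    t9 = t3 * t8                  # t9 = prims.mul(t3, t8)
    return t9 * scale             # prims.mul(t9, 1.1111111111111112)

rng = random.Random(0)
print(scale)                      # 1.1111111111111112
print(relu_dropout(-2.0, rng))    # negative inputs are zeroed by the ReLU
```

Rescaling kept values by ``1 / (1 - p)`` at training time is what keeps the expected activation magnitude unchanged, so no rescaling is needed at evaluation time.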
But we can already see that the trace is a functional Python program, and that Thunder has grouped several sequences of primitive operations to send to nvFuser. Instead of executing those primitives one at a time, nvFuser has compiled each group into an optimized kernel (a “fusion”) and inserted it into the program (``nvFusion0``, ``nvFusion1``, ...). The commented lines under each fusion are the primitive operations that describe precisely what the group computes, although how each fusion is actually executed is entirely up to nvFuser.
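To build some intuition for where those fusion regions come from, here is a toy partitioner — an illustration only, not Thunder's actual partitioning logic — that walks a list of primitive names and groups maximal runs of fusible, elementwise operations between the ``linear`` calls, mirroring the ``nvFusion0`` through ``nvFusion4`` regions in the trace above::

```python
# Toy partitioner: group maximal runs of elementwise primitives
# (the kind a fusing executor can combine into one kernel),
# treating `linear` as a boundary handled outside the fusions.
FUSIBLE = {"gt", "where", "uniform", "lt", "convert_element_type", "mul"}

def group_fusions(ops):
    fusions, current = [], []
    for op in ops:
        if op in FUSIBLE:
            current.append(op)
        elif current:
            fusions.append(current)   # a non-fusible op closes the region
            current = []
    if current:
        fusions.append(current)
    return fusions

# Primitive names of one linear + ReLU + dropout block of the MLP
block = ["linear", "gt", "where", "uniform", "lt",
         "convert_element_type", "mul", "mul"]
# Four full blocks, then the last linear followed by dropout only
trace_ops = block * 4 + ["linear", "uniform", "lt",
                         "convert_element_type", "mul", "mul"]

fusions = group_fusions(trace_ops)
print(len(fusions))  # 5 fusion regions, like nvFusion0..nvFusion4
```

The fifth region is shorter than the others because the final layer has no ReLU, just dropout — exactly the shape of ``nvFusion4`` in the trace.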