“Hello, World!” ThunderFX
In this tutorial, we’ll explore how to use ThunderFX to accelerate a PyTorch program.
We’ll cover the basics of ThunderFX, demonstrate how to apply it to PyTorch functions and models, and evaluate its performance in both inference (forward-only) and training (forward and backward).
Getting Started with ThunderFX
Let’s see an example of using ThunderFX on a PyTorch function. ThunderFX optimizes the given callable and returns a compiled version of the function. You can then use the compiled function just like you would the original one.
[1]:
import torch
from thunder.dynamo import thunderfx
def foo(x, y):
    return torch.sin(x) + torch.cos(y)
# Compiles foo with ThunderFX
compiled_foo = thunderfx(foo)
# Creates inputs
inputs = [torch.randn(4, 4), torch.randn(4, 4)]
# Runs the original function eagerly
eager_results = foo(*inputs)
# Runs the compiled function
thunderfx_results = compiled_foo(*inputs)
# Checks that the compiled results match the eager results
torch.testing.assert_close(eager_results, thunderfx_results)
ThunderFX supports both CPU and CUDA tensors. However, its primary focus is optimizing CUDA calculations. The following example demonstrates ThunderFX with CUDA tensors:
[2]:
import sys
# Checks if CUDA is available
if not torch.cuda.is_available():
    print("No suitable GPU detected. Unable to proceed with the tutorial. Cell execution has been stopped.")
    sys.exit()
# Creates inputs
inputs = [torch.randn(4, 4, device="cuda"), torch.randn(4, 4, device="cuda")]
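# Runs the eager function and the same compiled function on the CUDA inputs
eager_result = foo(*inputs)
thunderfx_result = compiled_foo(*inputs)
# Checks that both results match
torch.testing.assert_close(eager_result, thunderfx_result)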
Performance Optimization with ThunderFX
Next, let’s evaluate how ThunderFX improves performance on a real-world model. We’ll use the Llama3 model as an example and compare execution times for inference and for training (forward and backward).
We begin by loading the Llama-3-8B configuration from LitGPT and reducing it to a smaller, two-layer model:
[3]:
from litgpt import Config, GPT
from functools import partial
from torch.testing import make_tensor
from thunder.dynamo import thunderfx
cfg = Config.from_name("Llama-3-8B")
# Uses a reduced configuration for this tutorial
cfg.n_layer = 2
cfg.block_size = 1024
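batch_dim = 4
# Creates all floating-point tensors (including model weights) in bfloat16
torch.set_default_dtype(torch.bfloat16)
# Helper that draws random token IDs in [0, 255) on the GPU
make = partial(make_tensor, low=0, high=255, device='cuda', dtype=torch.int64)
# Instantiates the model directly on the GPU
with torch.device('cuda'):
    model = GPT(cfg)
# A batch of random token sequences, each block_size tokens long
shape = (batch_dim, cfg.block_size)
x = make(shape)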
model
[3]:
GPT(
  (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
  (transformer): ModuleDict(
    (wte): Embedding(128256, 4096)
    (h): ModuleList(
      (0-1): 2 x Block(
        (norm_1): RMSNorm()
        (attn): CausalSelfAttention(
          (attn): Linear(in_features=4096, out_features=6144, bias=False)
          (proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (post_attention_norm): Identity()
        (norm_2): RMSNorm()
        (mlp): LLaMAMLP(
          (fc_1): Linear(in_features=4096, out_features=14336, bias=False)
          (fc_2): Linear(in_features=4096, out_features=14336, bias=False)
          (proj): Linear(in_features=14336, out_features=4096, bias=False)
        )
        (post_mlp_norm): Identity()
      )
    )
    (ln_f): RMSNorm()
  )
)
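The printed layer shapes are enough to estimate the size of this reduced model. The arithmetic below is a sketch that assumes, beyond what the printout shows, that each RMSNorm holds a single weight vector of size 4096 and that the model has no other trainable parameters. The head layout (32 query heads, 8 key/value groups, head size 128) is likewise an assumption taken from Llama-3-8B, chosen because it reproduces the printed `out_features=6144` of the fused `attn` projection.

```python
# Rough parameter count for the 2-layer model printed above.
# Assumptions: each RMSNorm has one weight vector of size n_embd, and there are
# no trainable parameters other than the layers shown (all Linear layers use bias=False).

n_embd = 4096          # hidden size (Embedding/Linear widths above)
vocab = 128256         # rows of transformer.wte and lm_head
intermediate = 14336   # LLaMAMLP hidden size
n_layer = 2            # cfg.n_layer in this tutorial

# Fused q/k/v projection with grouped-query attention (assumed Llama-3-8B layout):
# 32 query heads plus 8 key groups and 8 value groups, each of size 128.
head_size, n_head, n_query_groups = 128, 32, 8
qkv_width = (n_head + 2 * n_query_groups) * head_size

per_block = (
    n_embd * qkv_width             # attn.attn
    + n_embd * n_embd              # attn.proj
    + 2 * n_embd * intermediate    # mlp.fc_1 and mlp.fc_2
    + intermediate * n_embd        # mlp.proj
    + 2 * n_embd                   # norm_1 and norm_2 weights
)
total = (
    vocab * n_embd                 # transformer.wte
    + vocab * n_embd               # lm_head (not tied to wte)
    + n_layer * per_block
    + n_embd                       # ln_f weight
)
bf16_bytes = 2 * total             # bfloat16 uses 2 bytes per element

print(f"qkv width: {qkv_width}")
print(f"parameters: {total:,}")
print(f"weights in bfloat16: {bf16_bytes / 2**30:.2f} GiB")
```

Under these assumptions the model has about 1.49 billion parameters, roughly 2.8 GiB of bfloat16 weights, most of them in the 128256-row embedding and `lm_head`. Running `sum(p.numel() for p in model.parameters())` on the actual model would confirm or correct the estimate.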
As before, we first compile the model and compare its output with that of eager PyTorch. Thunder’s optimized kernels may produce slightly different results than PyTorch’s kernels, particularly in a low-precision format such as bfloat16, but the differences shouldn’t be significant in practice.
[ ]:
compiled_model = thunderfx(model)
thunderfx_result = compiled_model(x)
eager_result = model(x)
print("deviation:", (thunderfx_result - eager_result).abs().max().item())
deviation: 0.015625
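To put the printed deviation in context: 0.015625 is exactly 2^-6, which is the gap between neighbouring bfloat16 values for magnitudes between 2 and 4. The helper below (a hypothetical name, standard library only) computes that gap for any normal, non-zero magnitude. A maximum deviation of about one such step at the magnitude of the affected logits is consistent with ordinary rounding differences between kernels, although this check alone cannot rule out other causes.

```python
import math


def bf16_spacing(x: float) -> float:
    """Gap between adjacent bfloat16 values at magnitude |x| (normal numbers only).

    bfloat16 stores 7 explicit mantissa bits (8 bits of precision), so each
    binade [2**e, 2**(e + 1)) is split into 2**7 equal steps of size 2**(e - 7).
    """
    if x == 0 or not math.isfinite(x):
        raise ValueError("x must be a finite, non-zero number")
    _, exp = math.frexp(abs(x))  # abs(x) == m * 2**exp with 0.5 <= m < 1
    e = exp - 1                  # so 2**e <= abs(x) < 2**(e + 1)
    return 2.0 ** (e - 7)


for magnitude in (0.5, 1.0, 2.0, 3.9, 8.0):
    print(magnitude, bf16_spacing(magnitude))
```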
Note: ThunderFX compiles the model into optimized kernels as it executes. Compiling these kernels can take seconds or even minutes for larger models, but this cost is paid once: subsequent calls reuse the compiled kernels. Because the compiled model has already run once above, the timings below do not include compilation.
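To see the one-off compilation cost directly, the first call can be timed separately from the calls that follow. The function below is a hypothetical helper written for illustration, not part of Thunder; it uses only the standard library and accepts an optional `sync` callback, which in this tutorial would be `torch.cuda.synchronize`, because CUDA kernels are launched asynchronously.

```python
import time
from typing import Callable, Optional, Tuple


def first_call_vs_steady(
    fn: Callable[..., object],
    *args: object,
    repeats: int = 5,
    sync: Optional[Callable[[], None]] = None,
) -> Tuple[float, float]:
    """Return (seconds for the first call, mean seconds of the next `repeats` calls).

    Pass something like torch.cuda.synchronize as `sync`: CUDA kernels launch
    asynchronously, so without it the clock would stop before the GPU finishes.
    """
    if repeats < 1:
        raise ValueError("repeats must be at least 1")

    def timed_call() -> float:
        if sync is not None:
            sync()  # start from an idle device
        start = time.perf_counter()
        fn(*args)
        if sync is not None:
            sync()  # include all queued work in the measurement
        return time.perf_counter() - start

    first = timed_call()  # includes one-off costs such as compilation
    steady = [timed_call() for _ in range(repeats)]
    return first, sum(steady) / len(steady)
```

For the first timing to include compilation, the helper has to be used before the compiled callable runs for the first time, for example `first_call_vs_steady(compiled_model, x, sync=torch.cuda.synchronize)` in place of the first `compiled_model(x)` call above.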
To evaluate ThunderFX’s inference performance, we compare the execution time of the compiled model versus the standard PyTorch model:
[5]:
# Clears data to free some memory.
del thunderfx_result, eager_result
import gc
gc.collect()
torch.cuda.empty_cache()
# Measures inference time
print("ThunderFX Inference Time:")
%timeit r = compiled_model(x); torch.cuda.synchronize()
print("Torch Eager Inference Time:")
%timeit r = model(x); torch.cuda.synchronize()
ThunderFX Inference Time:
66.7 ms ± 289 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Torch Eager Inference Time:
72.2 ms ± 287 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
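`%timeit` is an IPython magic, so it is not available in a plain Python script. The sketch below is a hypothetical stand-in built on the standard library: it runs some warm-up calls, then reports the mean and (population) standard deviation of the per-loop time across several runs, loosely following the "mean ± std. dev. of 7 runs, 10 loops each" layout above. It is not IPython's implementation; in particular, `%timeit` chooses the loop count automatically, while here it is a parameter.

```python
import statistics
import time
from typing import Callable, Optional, Tuple


def benchmark(
    fn: Callable[[], object],
    *,
    runs: int = 7,
    loops: int = 10,
    warmup: int = 1,
    sync: Optional[Callable[[], None]] = None,
) -> Tuple[float, float]:
    """Return (mean, std. dev.) of the per-loop time in milliseconds across `runs`.

    `sync` is called after every call, like torch.cuda.synchronize() in the cells above.
    """
    if runs < 1 or loops < 1:
        raise ValueError("runs and loops must be at least 1")

    def call() -> None:
        fn()
        if sync is not None:
            sync()  # wait for asynchronous (e.g. CUDA) work to finish

    for _ in range(warmup):  # keeps one-off costs such as compilation out of the stats
        call()

    per_loop_ms = []
    for _ in range(runs):
        start = time.perf_counter()
        for _ in range(loops):
            call()
        per_loop_ms.append((time.perf_counter() - start) * 1000 / loops)
    return statistics.mean(per_loop_ms), statistics.pstdev(per_loop_ms)
```

In a script, the inference measurement above would become something like `benchmark(lambda: compiled_model(x), sync=torch.cuda.synchronize)`.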
Similarly, let’s measure training performance by timing a forward pass followed by a backward pass:
[6]:
print("ThunderFX Training Time:")
%timeit r = compiled_model(x); r.sum().backward(); torch.cuda.synchronize()
print("Torch Eager Training Time:")
%timeit r = model(x); r.sum().backward(); torch.cuda.synchronize()
ThunderFX Training Time:
197 ms ± 5.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Torch Eager Training Time:
213 ms ± 2.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
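The usual single-number summary is the speedup, the ratio of eager time to ThunderFX time. The arithmetic below uses only the mean values printed above; both phases come out at roughly 1.08x, or about 7.5% less time per loop. These figures describe this particular run on this GPU, and the training measurement's standard deviation of 5.3 ms is a noticeable fraction of the 16 ms difference, so repeated runs would be needed before drawing finer conclusions.

```python
# Speedups implied by the mean times reported above (milliseconds per loop).
timings_ms = {
    "inference": {"eager": 72.2, "thunderfx": 66.7},
    "training": {"eager": 213.0, "thunderfx": 197.0},
}

speedups = {}
for phase, t in timings_ms.items():
    speedups[phase] = t["eager"] / t["thunderfx"]  # > 1 means ThunderFX is faster
    saved_pct = 100 * (t["eager"] - t["thunderfx"]) / t["eager"]  # share of eager time saved
    print(f"{phase}: {speedups[phase]:.2f}x speedup, {saved_pct:.1f}% less time per loop")
```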
Conclusion
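ThunderFX can accelerate PyTorch programs, particularly CUDA programs, by compiling optimized kernels specific to the program you’re running. It can accelerate both inference (forward-only) and training (forward and backward) computations: on the reduced Llama3 model in this tutorial, it lowered the measured inference time from 72.2 ms to 66.7 ms and the training time from 213 ms to 197 ms.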
For more information about Thunder, and ThunderFX in particular, see https://github.com/Lightning-AI/lightning-thunder/tree/main/notebooks.