thunder.transforms.ConstantFolding
- class thunder.transforms.ConstantFolding[source]
Bases: Transform
Apply constant folding to the computation trace.
With this transform applied to a computation trace, subsequent passes (i.e. trace transformations) operate on the simplified trace.
from thunder.transforms import ConstantFolding

model = ...
transforms = [ConstantFolding()]
jitted = thunder.jit(model, transforms=transforms)

# If you prefer `ThunderCompiler`...
from thunder.dynamo import ThunderCompiler
backend = ThunderCompiler(transforms=transforms)
jitted = torch.compile(model, backend=backend)
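Note that ThunderCompiler is a torch.compile backend that compiles the supported portions of the captured FX graph with thunder.jit, so the transforms passed to it are applied to each of those jitted subgraphs.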
To see the effect of this transform, let’s use the following function:
def forward(x):
    scale_t = torch.tensor([2.])
    scale_t = (scale_t * 10) / 5
    return x * scale_t
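One way to reproduce the traces shown below is to jit the function with the transform and inspect the recorded computation traces via thunder.last_traces; the assumption here is that the returned list is ordered from the earliest recorded trace to the final one:

import torch
import thunder
from thunder.transforms import ConstantFolding

jitted = thunder.jit(forward, transforms=[ConstantFolding()])
jitted(torch.randn(3))  # execute once so the traces are recorded

traces = thunder.last_traces(jitted)
print(traces[0])    # earliest recorded computation trace
print(traces[-1])   # final trace, after all transforms have run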
The initial computation trace is as follows:
def computation(x):
  # x: "cpu f32[3]"
  scale_t = ltorch.tensor([2.0], device=None, dtype=None, requires_grad=False, pin_memory=False)  # scale_t: "cpu f32[1]"
    # scale_t = prims.tensor_from_sequence([2.0], dtype=None, device=devices.Device("cpu"))  # scale_t: "cpu f32[1]"
  t1 = ltorch.mul(scale_t, 10)  # t1: "cpu f32[1]"
    # _ = prims.convert_element_type(10, float)
    # t1 = prims.mul(scale_t, 10.0)  # t1: "cpu f32[1]"
  t2 = ltorch.true_divide(t1, 5)  # t2: "cpu f32[1]"
    # _ = prims.convert_element_type(5, float)
    # t2 = prims.div(t1, 5.0)  # t2: "cpu f32[1]"
  t4 = ltorch.mul(x, t2)  # t4: "cpu f32[3]"
    # t3 = prims.broadcast_in_dim(t2, (3,), (0,))  # t3: "cpu f32[3]"
    # t4 = prims.mul(x, t3)  # t4: "cpu f32[3]"
  return t4
This transform evaluates the constant subgraph at trace time ((2.0 * 10) / 5 = 4.0) and simplifies the trace to:
def computation(x):
  # x: "cpu f32[3]"
  t2 = prims.tensor_from_sequence([4.0], dtype=dtypes.float32, device=devices.Device("cpu"))  # t2: "cpu f32[1]"
  t4 = ltorch.mul(x, t2)  # t4: "cpu f32[3]"
    # t3 = prims.broadcast_in_dim(t2, (3,), (0,))  # t3: "cpu f32[3]"
    # t4 = prims.mul(x, t3)  # t4: "cpu f32[3]"
  return {'output': t4, 'flat_args': [x]}
Methods

__init__()

reverse_transform_state_dict_for_submodule(...)

transform_module(model)
    Transforms the ThunderModule.

transform_state_dict_for_submodule(model, ...)
    Implement this to transform the state dict (mostly parameters and buffers) of a module, e.g. …

transform_trace_post_optimization(...)
    Enables transforming the computation trace after the optimization pass.

transform_traces_pre_prologue(prologue_trc, ...)
    Enables transforming the prologue, computation, and epilogue traces.
- transform_traces_pre_prologue(prologue_trc, computation_trc, epilogue_trc, **kwargs)[source]
transform_traces_pre_prologue enables transforming the prologue, computation, and epilogue traces. Note that the computation trace here is taken before the autograd transform, so any update to it will also update the backward trace.
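As an illustrative sketch of how this hook can be overridden, here is a minimal transform that only prints the computation trace and leaves all traces unchanged. The import path for the Transform base class and the convention of returning the three traces are assumptions that may vary across thunder versions:

import torch
import thunder
from thunder.core.transform_common import Transform  # assumed import path

class TraceLogger(Transform):
    # Hypothetical example transform: prints the computation trace and
    # returns all traces unchanged.
    def transform_traces_pre_prologue(self, prologue_trc, computation_trc, epilogue_trc, **kwargs):
        # This hook runs before the autograd transform, so any edit made to
        # computation_trc here would also shape the backward trace.
        print(computation_trc)
        # Assumed convention (matching the base class default): return the
        # possibly-modified traces.
        return prologue_trc, computation_trc, epilogue_trc

jitted = thunder.jit(torch.nn.Linear(4, 4), transforms=[TraceLogger()])
jitted(torch.randn(2, 4))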