thunder.transforms.ConstantFolding

class thunder.transforms.ConstantFolding

Bases: Transform

Applies constant folding to the computation trace.

Operations whose inputs are all known at trace time are evaluated ahead of time and replaced by their results. Passes that run after this one (i.e., other trace transformations) therefore operate on the simplified computation.

import thunder
import torch
from thunder.transforms import ConstantFolding

model = ...  # placeholder: the torch.nn.Module or function to compile
transforms = [ConstantFolding()]
jitted = thunder.jit(model, transforms=transforms)

# If you prefer `ThunderCompiler`...
from thunder.dynamo import ThunderCompiler
backend = ThunderCompiler(transforms=transforms)
jitted = torch.compile(model, backend=backend)

To see the effect of this transform, let’s use the following function:

import torch

def forward(x):
    scale_t = torch.tensor([2.])
    scale_t = (scale_t * 10) / 5
    return x * scale_t
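Because `scale_t` is built only from literals, its value can be worked out before `x` is known. This is the computation that constant folding performs ahead of time:

```latex
\mathrm{scale\_t} = \frac{2 \cdot 10}{5} = \frac{20}{5} = 4,
\qquad
\texttt{forward}(x) = 4\,x .
```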

The initial computation trace is as follows:

def computation(x):
  # x: "cpu f32[3]"

  scale_t = ltorch.tensor([2.0], device=None, dtype=None, requires_grad=False, pin_memory=False)  # scale_t: "cpu f32[1]"
    # scale_t = prims.tensor_from_sequence([2.0], dtype=None, device=devices.Device("cpu"))  # scale_t: "cpu f32[1]"

  t1 = ltorch.mul(scale_t, 10)  # t1: "cpu f32[1]"
    # _ = prims.convert_element_type(10, float)
    # t1 = prims.mul(scale_t, 10.0)  # t1: "cpu f32[1]"
  t2 = ltorch.true_divide(t1, 5)  # t2: "cpu f32[1]"
    # _ = prims.convert_element_type(5, float)
    # t2 = prims.div(t1, 5.0)  # t2: "cpu f32[1]"

  t4 = ltorch.mul(x, t2)  # t4: "cpu f32[3]"
    # t3 = prims.broadcast_in_dim(t2, (3,), (0,))  # t3: "cpu f32[3]"
    # t4 = prims.mul(x, t3)  # t4: "cpu f32[3]"
  return t4

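This transform evaluates the three constant-only operations (`tensor`, `mul`, `true_divide`) and replaces them with a single `prims.tensor_from_sequence` holding the result, 4.0. The simplified trace is: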

def computation(x):
  # x: "cpu f32[3]"
  t2 = prims.tensor_from_sequence([4.0], dtype=dtypes.float32, device=devices.Device("cpu"))  # t2: "cpu f32[1]"

  t4 = ltorch.mul(x, t2)  # t4: "cpu f32[3]"
    # t3 = prims.broadcast_in_dim(t2, (3,), (0,))  # t3: "cpu f32[3]"
    # t4 = prims.mul(x, t3)  # t4: "cpu f32[3]"
  return {'output': t4, 'flat_args': [x]}
__init__()

Methods

__init__()

reverse_transform_state_dict_for_submodule(...)

rtype: dict[str, Any]

transform_module(model)

Transforms the ThunderModule.

transform_state_dict_for_submodule(model, ...)

Implement this to transform the state dict (mostly parameters and buffers) of a module, e.g.

transform_trace_post_optimization(...)

transform_trace_post_optimization enables transforming the computation trace after the optimization pass.

transform_traces_pre_prologue(prologue_trc, ...)

transform_traces_pre_prologue enables transforming the prologue, computation, and epilogue traces.

transform_traces_pre_prologue(prologue_trc, computation_trc, epilogue_trc, **kwargs)

transform_traces_pre_prologue enables transforming the prologue, computation, and epilogue traces. Note that the computation trace passed here precedes the autograd transform, so any update to the computation trace is also reflected in the backward trace.