thunder.transforms.MaterializationTransform

class thunder.transforms.MaterializationTransform(device, *, init)

Bases: Transform

Materialize a model that fits on a device only after transforms have been applied.
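To make the motivation concrete, the arithmetic below compares raw parameter storage before and after a weight-shrinking transform. The parameter count, the 16 GiB device, and the 4-bit format are hypothetical, and the estimate ignores buffers, activations, and quantization metadata; it is a back-of-the-envelope sketch, not a Thunder API.

```python
GIB = 1024**3

# Hypothetical figures: a 7-billion-parameter model and a 16 GiB device.
NUM_PARAMS = 7_000_000_000
DEVICE_BYTES = 16 * GIB


def footprint_bytes(num_params: int, bits_per_param: int) -> int:
    """Raw parameter storage only (no buffers, activations, or metadata)."""
    return num_params * bits_per_param // 8


original_bytes = footprint_bytes(NUM_PARAMS, 32)    # float32 weights as first constructed
transformed_bytes = footprint_bytes(NUM_PARAMS, 4)  # after a hypothetical 4-bit transform

fits_before = original_bytes <= DEVICE_BYTES
fits_after = transformed_bytes <= DEVICE_BYTES

print(f"original:    {original_bytes / GIB:.1f} GiB, fits: {fits_before}")
print(f"transformed: {transformed_bytes / GIB:.1f} GiB, fits: {fits_after}")
```

With these numbers the float32 weights (28 GB) exceed the device while the 4-bit weights (3.5 GB) fit, which is why materialization is deferred until the transforms have run.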

Parameters:

device – The device on which the ThunderModule is materialized.

Keyword Arguments:

init

Post-processing callable applied to the ThunderModule after materialization. Possible values are obtained from:

  • MaterializationTransform.init_from_original_state_dict(state_dict): populates weights from a state_dict of the untransformed (original) module.

  • MaterializationTransform.init_from_transformed_state_dict(state_dict): populates weights from a state_dict of the transformed module.

  • MaterializationTransform.init_from_original_module_init(): initializes weights using the original module's own weight initialization (reset_parameters).
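The three options differ only in which weights they start from. The plain-Python toy below, which does not use Thunder, models each option as a factory returning a callable that fills in an unmaterialized module. ToyModule, the toy_* factories, the halving "transform", and the one-argument callable signature are all hypothetical illustrations, not Thunder's implementation or the real init signature.

```python
from typing import Callable, Dict, List, Optional

Weights = Dict[str, Optional[float]]


class ToyModule:
    """Hypothetical module; a value of None marks a weight that is not materialized yet."""

    def __init__(self, names: List[str]) -> None:
        self.weights: Weights = {name: None for name in names}

    def reset_parameters(self) -> None:
        # Toy analogue of a module's own weight initialization.
        for name in self.weights:
            self.weights[name] = 1.0


# A toy init callable post-processes the module after materialization.
InitFn = Callable[[ToyModule], None]


def toy_transform(state_dict: Weights) -> Weights:
    """Hypothetical transform that stores every weight halved (a stand-in for re-encoding)."""
    return {name: value / 2 for name, value in state_dict.items()}


def toy_init_from_original_state_dict(state_dict: Weights) -> InitFn:
    def init(module: ToyModule) -> None:
        # Untransformed weights must first be converted to the transformed layout.
        module.weights.update(toy_transform(state_dict))
    return init


def toy_init_from_transformed_state_dict(state_dict: Weights) -> InitFn:
    def init(module: ToyModule) -> None:
        # Weights are already in the transformed layout, so they are copied as-is.
        module.weights.update(state_dict)
    return init


def toy_init_from_original_module_init() -> InitFn:
    def init(module: ToyModule) -> None:
        # Uses the module's own initializer instead of loading any weights.
        module.reset_parameters()
    return init


original_sd = {"weight": 4.0, "bias": 2.0}

m1 = ToyModule(["weight", "bias"])
toy_init_from_original_state_dict(original_sd)(m1)

m2 = ToyModule(["weight", "bias"])
toy_init_from_transformed_state_dict(toy_transform(original_sd))(m2)

m3 = ToyModule(["weight", "bias"])
toy_init_from_original_module_init()(m3)

print(m1.weights, m2.weights, m3.weights)
```

Loading the original state dict and loading its already-transformed counterpart end in the same weights; the only difference is whether the conversion still has to happen during initialization.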

__init__(device, *, init)

Parameters: see the class description above (device and the keyword-only init).
Return type:

None

Methods

__init__(device, *, init)

init_from_original_module_init()

init_from_original_state_dict(state_dict)

init_from_transformed_state_dict(state_dict)

reverse_transform_state_dict_for_submodule(...)

Return type:

dict[str, Any]

transform_module(model)

Transforms the ThunderModule.

transform_state_dict_for_submodule(model, ...)

Implement this to transform the state dict (mostly parameters and buffers) of a module.
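A transform that changes how weights are stored maps state dicts from the original format to its own, and reverse_transform_state_dict_for_submodule, going by its name, maps them back. The toy below pairs the two directions on plain dicts using a hypothetical fixed-point encoding. ToyFixedPointTransform and SCALE are invented for illustration, and the methods take only a state dict, whereas the real hooks take additional arguments (elided as ... above).

```python
from typing import Any, Dict

SCALE = 256  # hypothetical fixed-point scale


class ToyFixedPointTransform:
    """Hypothetical transform that stores float weights as fixed-point integers.

    Method names echo the hooks listed above, but the signatures are simplified.
    """

    def transform_state_dict_for_submodule(self, state_dict: Dict[str, float]) -> Dict[str, Any]:
        # Original (untransformed) format -> transformed format.
        return {name: round(value * SCALE) for name, value in state_dict.items()}

    def reverse_transform_state_dict_for_submodule(self, state_dict: Dict[str, Any]) -> Dict[str, Any]:
        # Transformed format -> original format.
        return {name: value / SCALE for name, value in state_dict.items()}


t = ToyFixedPointTransform()
original = {"weight": 0.5, "bias": 1.25}
encoded = t.transform_state_dict_for_submodule(original)
decoded = t.reverse_transform_state_dict_for_submodule(encoded)
print(encoded)
print(decoded)
```

With SCALE = 256 the values 0.5 and 1.25 round-trip exactly; a value such as 0.1 would come back as 0.1015625, so a reverse mapping is not necessarily lossless.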

transform_trace_post_optimization(...)

Enables transforming the computation trace after the optimization pass.

transform_traces_pre_prologue(...)

Enables transforming the prologue, computation, and epilogue traces.

transform_module(model)

Transforms the ThunderModule. This is executed once, when the transform is applied.
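The timing matters: module-level work in this hook happens once, not on every call. The toy below, which does not use Thunder, counts how often a module hook runs when a wrapper applies its transforms at construction time. ToyJitted, CountingTransform, and the materialized flag are hypothetical; in Thunder itself, transforms are supplied to thunder.jit through its transforms argument.

```python
from typing import Callable, List


class CountingTransform:
    """Hypothetical transform that records each invocation of its module hook."""

    def __init__(self) -> None:
        self.calls = 0

    def transform_module(self, model: "ToyJitted") -> None:
        self.calls += 1
        model.materialized = True  # toy stand-in for one-off module-level work


class ToyJitted:
    """Hypothetical wrapper that applies module hooks once, when it is created."""

    def __init__(self, fn: Callable[[int], int], transforms: List[CountingTransform]) -> None:
        self.fn = fn
        self.materialized = False
        for transform in transforms:
            transform.transform_module(self)

    def __call__(self, x: int) -> int:
        # Calls go straight to the wrapped function; no module hook runs here.
        return self.fn(x)


counter = CountingTransform()
jitted = ToyJitted(lambda x: x * 2, transforms=[counter])
results = [jitted(i) for i in range(3)]
print(results, counter.calls, jitted.materialized)
```

Three calls through the wrapper still leave the hook count at one, matching the "executed once" behavior described above.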

Parameters:

model (ThunderModule) – The ThunderModule to transform.