thunder.transforms.MaterializationTransform
- class thunder.transforms.MaterializationTransform(device, *, init)
Bases: Transform
Materializes a model whose parameters live on the meta device and that fits on the target device only after the other transforms have been applied.
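To illustrate what "materializing" a meta-device model involves, here is a plain-PyTorch sketch. It does not use thunder and is not thunder's implementation; it only shows the underlying idea. A module built under the meta device has parameter shapes and dtypes but no storage, so it has to be given real memory on a device and then initialized. `nn.Module.to_empty` and `reset_parameters` are standard PyTorch APIs; the helper names `build_on_meta` and `materialize` are hypothetical.

```python
import torch
import torch.nn as nn


def build_on_meta() -> nn.Module:
    # Under the meta device, parameters carry shape and dtype but no storage.
    with torch.device("meta"):
        return nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))


def materialize(model: nn.Module, device: str) -> nn.Module:
    # Hypothetical helper: allocate real but uninitialized storage on `device`...
    model = model.to_empty(device=device)
    # ...then rerun each submodule's own initializer, if it defines one
    # (nn.Linear does; nn.ReLU and nn.Sequential do not).
    for module in model.modules():
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()
    return model


model = build_on_meta()
print(all(p.is_meta for p in model.parameters()))  # True: nothing allocated yet
model = materialize(model, "cpu")
print(model(torch.randn(2, 16)).shape)  # torch.Size([2, 4])
```

As described above, MaterializationTransform performs this step after the other transforms have run, so only the transformed model needs to fit on the device.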
- Parameters:
device (str | torch.device) – Device to host the ThunderModule after materialization. The transform marks any unannotated parameters on the meta device to be initialized on this device.
init (Callable[[MaterializationTransform, ThunderModule], None]) – Keyword-only; see Keyword Arguments.
- Keyword Arguments:
init – Post-processing callable applied to the ThunderModule after materialization. Possible values are obtained from:
  - MaterializationTransform.init_from_original_state_dict(state_dict): populate weights from a state_dict of the untransformed module;
  - MaterializationTransform.init_from_transformed_state_dict(state_dict): populate weights from a state_dict of the transformed module;
  - MaterializationTransform.init_from_original_module_init(): initialize using the weight initialization of the original module (reset_parameters).
- __init__(device, *, init)
- Parameters:
device (str | torch.device) –
init (Callable[[MaterializationTransform, ThunderModule], None]) –
- Return type:
None
Methods
__init__(device, *, init)
init_from_original_module_init()
init_from_original_state_dict(state_dict)
init_from_transformed_state_dict(state_dict)
reverse_transform_state_dict_for_submodule(...)
transform_module(model) – Transforms the ThunderModule.
transform_state_dict_for_submodule(model, ...) – Implement this to transform the state dict (mostly parameters and buffers) of a module, e.g. …
transform_trace_post_optimization(...) – Enables transforming the computation trace after the optimization pass.
transform_traces_pre_prologue(...) – Enables transforming the prologue, computation, and epilogue traces.
- transform_module(model)
Transforms the ThunderModule. This is executed once, when the transform is applied.
- Parameters:
model (ThunderModule) –