thunder

Compiling functions and modules

- Just-in-time compile a callable (function or model).

Querying information on compiled functions and modules

- Options can be dynamically registered; the currently registered ones are listed below.
- Obtains the compilation data from a JITed function.
- Obtains the compilation statistics from a JITed function.
- Obtains the list of computation traces that have been produced for the last run of the function.
- Obtains the list of backward traces that have been produced for the last run of the function and the selected prologue.
- Obtains the list of prologue traces that have been produced for the last run of the function and the selected prologue.
- Returns the cache options set when JITting the function.
- Returns the number of cache hits we found when running the function.
- Returns the number of cache misses we found when running the function.
- Returns the list of (explicit) transforms applied to the JITed function.
- Returns the list of instructions the interpreter encountered while tracing through the user program (on the last cache miss).
- Returns the list of instructions and other information the interpreter encountered while tracing through the user program (on the last cache miss).
- Prints how compiled options were used (or not).
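As a quick orientation, here is a minimal sketch of the typical flow: JIT-compile a callable and then query the traces produced for its last run. It assumes only the thunder.jit and thunder.last_traces entry points from the list above; foo and its inputs are placeholders.

import torch
import thunder

def foo(a, b):
    return a + b

jfoo = thunder.jit(foo)             # just-in-time compile the callable
a = torch.randn(2, 2)
b = torch.randn(2, 2)
out = jfoo(a, b)                    # the first call traces and compiles

traces = thunder.last_traces(jfoo)  # computation traces from the last run
print(traces[-1])                   # the final, executed trace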
JITed Model wrapper
- class thunder.ThunderModule(model, compiled_model_call)[source]
Bases: Module
An nn.Module subclass wrapping a jitted model. This wrapper is returned by thunder.jit; you would typically not instantiate it manually.
- get_buffer(name)[source]
Return the buffer given by target if it exists, otherwise throw an error.
See the docstring for get_submodule for a more detailed explanation of this method’s functionality as well as how to correctly specify target.
- Parameters:
target – The fully-qualified string name of the buffer to look for. (See get_submodule for how to specify a fully-qualified string.)
- Returns:
The buffer referenced by target
- Return type:
torch.Tensor
- Raises:
AttributeError – If the target string references an invalid path or resolves to something that is not a buffer
- get_parameter(name)[source]
Return the parameter given by target if it exists, otherwise throw an error.
See the docstring for get_submodule for a more detailed explanation of this method’s functionality as well as how to correctly specify target.
- Parameters:
target – The fully-qualified string name of the Parameter to look for. (See get_submodule for how to specify a fully-qualified string.)
- Returns:
The Parameter referenced by target
- Return type:
torch.nn.Parameter
- Raises:
AttributeError – If the target string references an invalid path or resolves to something that is not an nn.Parameter
- get_submodule(name)[source]
Return the submodule given by target if it exists, otherwise throw an error.
For example, let’s say you have an nn.Module A that looks like this:

A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
        )
        (linear): Linear(in_features=100, out_features=200, bias=True)
    )
)

(The diagram shows an nn.Module A. A has a nested submodule net_b, which itself has two submodules net_c and linear. net_c then has a submodule conv.)
To check whether or not we have the linear submodule, we would call get_submodule("net_b.linear"). To check whether we have the conv submodule, we would call get_submodule("net_b.net_c.conv").
The runtime of get_submodule is bounded by the degree of module nesting in target. A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used.
- Parameters:
target – The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.)
- Returns:
The submodule referenced by target
- Return type:
torch.nn.Module
- Raises:
AttributeError – If the target string references an invalid path or resolves to something that is not an nn.Module
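For concreteness, here is a small sketch of the nesting described above using a plain nn.Module. The class A and its layers are illustrative (a BatchNorm is added so there is also a buffer to look up), and the fully-qualified names shown are what get_submodule, get_parameter and get_buffer expect.

import torch.nn as nn

# Illustrative module matching the nested structure shown above,
# with a BatchNorm added so there is also a buffer to look up.
class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.net_b = nn.Module()
        self.net_b.net_c = nn.Module()
        self.net_b.net_c.conv = nn.Conv2d(16, 33, kernel_size=3, stride=2)
        self.net_b.net_c.bn = nn.BatchNorm2d(33)
        self.net_b.linear = nn.Linear(100, 200)

a = A()
conv = a.get_submodule("net_b.net_c.conv")                # the Conv2d module
weight = a.get_parameter("net_b.linear.weight")           # fully-qualified parameter name
running_var = a.get_buffer("net_b.net_c.bn.running_var")  # fully-qualified buffer name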
- load_state_dict(state_dict, strict=True, assign=False)[source]
Loads the state dict into a transformed module.
This is similar to, but much simpler than, the original load_state_dict (e.g. regarding hooks, customization etc.).
- named_buffers(prefix='', recurse=True, remove_duplicate=True, *, persistent=None)[source]
Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
- Parameters:
prefix (str) – prefix to prepend to all buffer names.
recurse (bool, optional) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.
remove_duplicate (bool, optional) – whether to remove the duplicated buffers in the result. Defaults to True.
- Yields:
(str, torch.Tensor) – Tuple containing the name and buffer
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
- named_parameters(prefix='', recurse=True, remove_duplicate=True)[source]
Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
- Parameters:
prefix (str) – prefix to prepend to all parameter names.
recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
remove_duplicate (bool, optional) – whether to remove the duplicated parameters in the result. Defaults to True.
- Yields:
(str, Parameter) – Tuple containing the name and parameter
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
- no_sync()[source]
Context manager to disable gradient synchronization in data parallel mode.
This context manager is intended to be used in conjunction with torch.nn.parallel.DistributedDataParallel to disable gradient synchronization in the backward pass. It will not have any effect when used with other modules.

Note
This could lead to different accumulated gradients than with torch.nn.parallel.distributed.DistributedDataParallel.no_sync. PyTorch’s gradient synchronization is implemented by applying all-reduce to gradient buckets of torch.nn.Parameter.grad. Thus PyTorch’s no_sync context leads to AllReduce(g_1 + … + g_n), where n is the number of gradient accumulation steps. In contrast, this synchronizes the accumulated gradients when exiting the context, leading to AllReduce(g_1 + … + g_{n-1}) + AllReduce(g_n).

Warning
You must reuse this context manager in each group of gradient accumulation iterations since gradients will get synchronized on context manager exit.

with model.no_sync():
    for _ in range(len(gradient_accumulation_iters)):
        loss(model(x)).backward()  # uses no-sync-backward trace
loss(model(x)).backward()          # uses the regular backward trace
optimizer.step()
- original_state_dict(*, destination=None, prefix='', keep_vars=False)[source]
Returns the state dict of the transformed ThunderModule with the reverse transform applied.
For example, ThunderModule.state_dict() returns a state dict of sharded tensors if thunder.distributed.fsdp() is applied to a model, while ThunderModule.original_state_dict() returns a state dict of unsharded tensors.
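A hedged sketch of that distinction, assuming a distributed setup (e.g. launched via torchrun) with a process group already initialized; the model itself is a placeholder.

import torch.nn as nn
import thunder
import thunder.distributed

# Placeholder model; fsdp() shards its parameters across the ranks of the
# already-initialized process group, and thunder.jit wraps the sharded module.
model = nn.Sequential(nn.Linear(256, 256), nn.GELU(), nn.Linear(256, 8))
jmodel = thunder.jit(thunder.distributed.fsdp(model))

sharded = jmodel.state_dict()             # state dict of the transformed (sharded) module
unsharded = jmodel.original_state_dict()  # reverse transform applied: unsharded tensors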
- state_dict(*, destination=None, prefix='', keep_vars=False)[source]
Returns the state dict of the (transformed) Thunder module.
Note that this is similar to, but rather more rudimentary than, the original state_dict (e.g. no hook support yet).
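As a minimal illustration of state_dict and load_state_dict together, a hedged save/restore round trip; the model and file name are placeholders.

import torch
import torch.nn as nn
import thunder

model = nn.Linear(4, 4)                                # placeholder model
jmodel = thunder.jit(model)

torch.save(jmodel.state_dict(), "checkpoint.pt")       # state dict of the (transformed) module

restored = thunder.jit(nn.Linear(4, 4))
restored.load_state_dict(torch.load("checkpoint.pt"))  # load it back into a transformed module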