XLAFSDPStrategy¶
- class lightning.fabric.strategies.XLAFSDPStrategy(accelerator=None, parallel_devices=None, checkpoint_io=None, precision=None, auto_wrap_policy=None, activation_checkpointing_policy=None, state_dict_type='sharded', sequential_save=False, **kwargs)[source]¶
Bases: ParallelStrategy, _Sharded

Strategy for training on multiple XLA devices using the torch_xla.distributed.xla_fully_sharded_data_parallel.XlaFullyShardedDataParallel() method.

Warning

This is an experimental feature.

For more information, check out https://github.com/pytorch/xla/blob/v2.5.0/docs/fsdp.md
- Parameters:
  - auto_wrap_policy¶ (Union[Set[Type[Module]], Callable[[Module, bool, int], bool], None]) – Same as the auto_wrap_policy parameter in torch_xla.distributed.fsdp.XlaFullyShardedDataParallel. For convenience, this also accepts a set of the layer classes to wrap.
  - activation_checkpointing_policy¶ (Optional[Set[Type[Module]]]) – Used when selecting the modules for which you want to enable activation checkpointing. Enabling this can free up a significant amount of memory at the cost of speed since activations in these layers need to be recomputed during backpropagation. This accepts a set of the layer classes to wrap.
  - state_dict_type¶ (Literal['full', 'sharded']) – The format in which the state of the model and optimizers gets saved into the checkpoint.
    - "full": The full weights and optimizer states get assembled on rank 0 and saved to a single file.
    - "sharded": Each rank saves its shard of weights and optimizer states to a file. The checkpoint is a folder with files for each shard in the host. Note that TPU VM multihost does not have a shared filesystem.
  - sequential_save¶ (bool) – With this enabled, individual ranks consecutively save their state dictionary shards, reducing peak system RAM usage, although it elongates the saving process.
  - **kwargs¶ (Any) – See available parameters in torch_xla.distributed.fsdp.XlaFullyShardedDataParallel. A configuration sketch follows this parameter list.
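A minimal configuration sketch, not taken from the official examples: Block is a hypothetical layer class standing in for your own submodules, and the accelerator/device settings are illustrative.

```python
import torch.nn as nn
from lightning.fabric import Fabric
from lightning.fabric.strategies import XLAFSDPStrategy


class Block(nn.Module):  # hypothetical layer type used to illustrate the wrapping policies
    def __init__(self) -> None:
        super().__init__()
        self.ff = nn.Linear(32, 32)

    def forward(self, x):
        return self.ff(x)


strategy = XLAFSDPStrategy(
    auto_wrap_policy={Block},                 # wrap each Block in its own FSDP unit
    activation_checkpointing_policy={Block},  # recompute Block activations during backprop
    state_dict_type="sharded",                # each rank writes its own shard; the checkpoint is a folder
    sequential_save=True,                     # ranks save one after another to limit peak host RAM
)
fabric = Fabric(accelerator="tpu", devices=8, strategy=strategy)
fabric.launch()
```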
- all_gather(tensor, group=None, sync_grads=False)[source]¶
Function to gather a tensor from several distributed processes.
- all_reduce(output, group=None, reduce_op=None)[source]¶
Reduces the given tensor (e.g. across GPUs/processes).
- barrier(name=None, *args, **kwargs)[source]¶
Synchronizes all processes, blocking them until the whole group enters this function.
- clip_gradients_norm(module, optimizer, max_norm, norm_type=2.0, error_if_nonfinite=True)[source]¶
Clip gradients by norm.
- Return type:
Tensor
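In a Fabric training loop, clipping is typically invoked through Fabric.clip_gradients(), which dispatches to this method for XLA FSDP. A hedged sketch, assuming fabric, model, optimizer, and loss already exist:

```python
fabric.backward(loss)
fabric.clip_gradients(model, optimizer, max_norm=1.0, norm_type=2.0)  # clip by total gradient norm
optimizer.step()
optimizer.zero_grad()
```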
- load_checkpoint(path, state=None, strict=True)[source]¶
Given a folder, load the contents from a checkpoint and restore the state of the given objects.
The strategy currently only supports saving and loading sharded checkpoints, which are stored in the form of a directory of multiple files rather than a single file.
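A sketch of restoring state through Fabric.load(), which forwards to this method; the path and state keys are illustrative, and fabric, model, and optimizer are assumed to exist:

```python
state = {"model": model, "optimizer": optimizer, "step": 0}
# ``path`` points to the checkpoint directory written by ``save_checkpoint``
remainder = fabric.load("checkpoints/latest", state)  # objects in ``state`` are restored in place
```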
- module_init_context(empty_init=None)[source]¶
A context manager wrapping the model instantiation.
Here, the strategy can control how the parameters of the model get created (device, dtype) and/or apply other patches to the model.
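With Fabric, this context is entered through Fabric.init_module(). A short sketch, assuming fabric was created with XLAFSDPStrategy as in the constructor sketch above:

```python
import torch.nn as nn

with fabric.init_module(empty_init=True):  # optionally skip materializing weights until they are set up
    model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))
```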
- module_sharded_context()[source]¶
A context manager that goes over the instantiation of a torch.nn.Module and handles sharding of parameters on creation.
By sharding layers directly on instantiation, one can reduce peak memory usage and initialization time.
- Return type:
- optimizer_step(optimizer, **kwargs)[source]¶
Overrides the default TPU optimizer_step, since FSDP should not call torch_xla.core.xla_model.optimizer_step(). Performs the actual optimizer step.
- process_dataloader(dataloader)[source]¶
Wraps the dataloader if necessary.
- Parameters:
  dataloader¶ (DataLoader) – iterable. Ideally of type: torch.utils.data.DataLoader
- Return type:
MpDeviceLoader
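Fabric calls this for you inside Fabric.setup_dataloaders(). A small sketch with an illustrative in-memory dataset, assuming fabric exists:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(64, 32))
dataloader = DataLoader(dataset, batch_size=8)
dataloader = fabric.setup_dataloaders(dataloader)  # the strategy wraps the loader so batches land on the XLA device
```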
- save_checkpoint(path, state, storage_options=None, filter=None)[source]¶
Save model, optimizer, and other state in the provided checkpoint directory.
If the user specifies sharded checkpointing, the directory will contain one file per process, with the model and optimizer shards stored per file. If the user specifies full checkpointing, the directory will contain a consolidated checkpoint combining all of the sharded checkpoints.
- Return type:
None
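A sketch of saving through Fabric.save(), which forwards to this method; the path and state keys are illustrative, and fabric, model, optimizer, and step are assumed to exist:

```python
state = {"model": model, "optimizer": optimizer, "step": step}
# With state_dict_type="sharded", the path becomes a directory with one file per process;
# with "full", it contains a single consolidated checkpoint.
fabric.save("checkpoints/step-01000", state)
```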
- setup_environment()[source]¶
Set up any processes or distributed connections.
This must be called by the framework at the beginning of every process, before any distributed communication takes place.
- Return type:
None
- setup_module(module)[source]¶
Performs setup for the model, e.g., by wrapping it with another class.
- Return type:
- setup_module_and_optimizers(module, optimizers)[source]¶
Raises NotImplementedError since, for XLAFSDP, optimizer setup must happen after module setup.
- setup_optimizer(optimizer)[source]¶
Set up an optimizer for a model wrapped with XLAFSDP.
This setup method doesn't modify or wrap the optimizer. It only verifies that the optimizer was created after the model was wrapped with
setup_module()
and that it references the flattened parameters (see the sketch below).
- Return type:
Optimizer
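A sketch of the required order with Fabric: wrap the module first, then build the optimizer from the wrapped module's parameters, then set it up. The joint fabric.setup(model, optimizer) path is not available for this strategy (see setup_module_and_optimizers() above). The optimizer choice and learning rate are illustrative, and fabric and model are assumed to exist.

```python
import torch

model = fabric.setup_module(model)                            # shard/wrap the module first
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)    # parameters are now the flattened shards
optimizer = fabric.setup_optimizers(optimizer)                # verified against the wrapped module
```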