XLAFSDPStrategy¶
- class lightning.fabric.strategies.XLAFSDPStrategy(accelerator=None, parallel_devices=None, checkpoint_io=None, precision=None, auto_wrap_policy=None, activation_checkpointing_policy=None, state_dict_type='sharded', sequential_save=False, **kwargs)[source]¶
Bases: ParallelStrategy, _Sharded

Strategy for training on multiple XLA devices using the torch_xla.distributed.xla_fully_sharded_data_parallel.XlaFullyShardedDataParallel() method.

Warning

This is an experimental feature.

For more information, check out https://github.com/pytorch/xla/blob/v2.5.0/docs/fsdp.md
- Parameters:
  - auto_wrap_policy¶ (Union[set[type[Module]], Callable[[Module, bool, int], bool], None]) – Same as the auto_wrap_policy parameter in torch_xla.distributed.fsdp.XlaFullyShardedDataParallel. For convenience, this also accepts a set of the layer classes to wrap.
  - activation_checkpointing_policy¶ (Optional[set[type[Module]]]) – Used when selecting the modules for which you want to enable activation checkpointing. Enabling this can free up a significant amount of memory at the cost of speed, since activations in these layers need to be recomputed during backpropagation. This accepts a set of the layer classes to wrap.
  - state_dict_type¶ (Literal['full', 'sharded']) – The format in which the state of the model and optimizers gets saved into the checkpoint.
    - "full": The full weights and optimizer states get assembled on rank 0 and saved to a single file.
    - "sharded": Each rank saves its shard of weights and optimizer states to a file. The checkpoint is a folder with one file per shard on the host. Note that TPU VM multihost does not have a shared filesystem.
  - sequential_save¶ (bool) – With this enabled, individual ranks save their state-dictionary shards one after another, reducing peak system RAM usage at the cost of a longer saving process.
  - **kwargs¶ (Any) – See the available parameters in torch_xla.distributed.fsdp.XlaFullyShardedDataParallel. A construction example using these options is shown after this list.
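To show how these options fit together, here is a minimal construction sketch (not taken from the official docs). It assumes a TPU VM with torch_xla installed and uses nn.TransformerEncoderLayer purely as a stand-in for your own block class:

```python
import torch.nn as nn
from lightning.fabric import Fabric
from lightning.fabric.strategies import XLAFSDPStrategy

# nn.TransformerEncoderLayer is only an illustrative wrap target; substitute your model's block class.
strategy = XLAFSDPStrategy(
    auto_wrap_policy={nn.TransformerEncoderLayer},                 # wrap every submodule of this class
    activation_checkpointing_policy={nn.TransformerEncoderLayer},  # recompute these layers' activations in backward
    state_dict_type="sharded",                                     # each rank writes its own checkpoint shard
    sequential_save=True,                                          # ranks save one after another to limit host RAM
)

fabric = Fabric(accelerator="tpu", devices=8, strategy=strategy)
fabric.launch()
```

The method sketches below continue from this fabric object.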
- all_gather(tensor, group=None, sync_grads=False)[source]¶
Function to gather a tensor from several distributed processes.
- all_reduce(output, group=None, reduce_op=None)[source]¶
Reduces the given tensor (e.g. across GPUs/processes).
- barrier(name=None, *args, **kwargs)[source]¶
Synchronizes all processes, blocking them until the whole group enters this function.
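These collectives are usually reached through the Fabric object rather than through the strategy directly. A short sketch, continuing from the fabric object above:

```python
import torch

# Block until every process reaches this point.
fabric.barrier()

# Average a scalar metric across all processes.
loss = torch.tensor(0.5, device=fabric.device)
avg_loss = fabric.all_reduce(loss, reduce_op="mean")

# Gather the tensor from every process; the result gains a leading world-size dimension.
gathered = fabric.all_gather(loss)
```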
- clip_gradients_norm(module, optimizer, max_norm, norm_type=2.0, error_if_nonfinite=True)[source]¶
Clip gradients by norm.
- Return type:
  Tensor
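Through Fabric, norm-based clipping is exposed as fabric.clip_gradients. A hedged sketch of one training step, assuming model, optimizer, and loss come from a training loop already set up with Fabric:

```python
fabric.backward(loss)
# Clip the gradient norm to 1.0 before stepping; this delegates to the XLA FSDP implementation.
fabric.clip_gradients(model, optimizer, max_norm=1.0, norm_type=2.0)
optimizer.step()
optimizer.zero_grad()
```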
- load_checkpoint(path, state=None, strict=True, weights_only=None)[source]¶
Given a folder, load the contents from a checkpoint and restore the state of the given objects.
The strategy currently only supports saving and loading sharded checkpoints, which are stored as a directory of multiple files rather than a single file.
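At the Fabric level, loading goes through fabric.load. A minimal sketch, assuming model and optimizer have already been set up with Fabric (the directory name is made up for illustration):

```python
state = {"model": model, "optimizer": optimizer, "step": 0}
# Restores the passed objects in place and returns any checkpoint entries that were not requested.
remainder = fabric.load("checkpoints/step_1000", state)
```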
- module_init_context(empty_init=None)[source]¶
A context manager wrapping the model instantiation.
Here, the strategy can control how the parameters of the model get created (device, dtype) and optionally apply other patches to the model.
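Users normally enter this context through fabric.init_module. A sketch, continuing from the fabric object above:

```python
import torch.nn as nn

# Instantiate the model under the strategy's init context so parameters are created
# with the appropriate device/dtype handling (empty_init controls whether weights
# get materialized right away, where the strategy supports it).
with fabric.init_module(empty_init=False):
    model = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=512, nhead=8), num_layers=6)
```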
- module_sharded_context()[source]¶
A context manager that goes over the instantiation of a torch.nn.Module and handles sharding of parameters on creation.
By sharding layers directly on instantiation, one can reduce peak memory usage and initialization time.
- Return type:
- process_dataloader(dataloader)[source]¶
Wraps the dataloader if necessary.
- Parameters:
  dataloader¶ (DataLoader) – iterable. Ideally of type: torch.utils.data.DataLoader
- Return type:
  MpDeviceLoader
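In user code this wrapping happens automatically when the dataloader is set up through Fabric. A brief sketch, assuming a dataset object already exists:

```python
from torch.utils.data import DataLoader

train_loader = DataLoader(dataset, batch_size=8)
# Fabric calls the strategy's process_dataloader under the hood; batches are moved to the XLA device.
train_loader = fabric.setup_dataloaders(train_loader)
```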
- save_checkpoint(path, state, storage_options=None, filter=None)[source]¶
Save model, optimizer, and other state in the provided checkpoint directory.
If the user specifies sharded checkpointing, the directory will contain one file per process, each holding that process's model and optimizer shards. If the user specifies full checkpointing, the directory will contain a consolidated checkpoint combining all of the sharded checkpoints.
- Return type:
  None
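At the Fabric level, saving goes through fabric.save. A minimal sketch, assuming model and optimizer set up with Fabric (the directory name is made up for illustration):

```python
state = {"model": model, "optimizer": optimizer, "step": 1000}
# With state_dict_type="sharded", every process writes its own shard into this directory.
fabric.save("checkpoints/step_1000", state)
```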
- setup_environment()[source]¶
Set up any processes or distributed connections.
This must be called by the framework at the beginning of every process, before any distributed communication takes place.
- Return type:
  None
- setup_module(module)[source]¶
Performs setup for the model, e.g., by wrapping it with another class.
- Return type:
- setup_module_and_optimizers(module, optimizers, scheduler=None)[source]¶
Set up a model and multiple optimizers together.
The returned objects are expected to be in the same order they were passed in. The default implementation will call setup_module() and setup_optimizer() on the inputs.
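A hedged sketch of the full setup order, continuing from the fabric object and the model instantiated under fabric.init_module above. With a sharded strategy it is generally safest to wrap the module first and only then create the optimizer from the wrapped module's parameters, which fabric.setup_module and fabric.setup_optimizers make explicit:

```python
import torch

# Wrap the model with XLA FSDP first ...
model = fabric.setup_module(model)

# ... then build the optimizer on the wrapped module's parameters and register it with Fabric.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
optimizer = fabric.setup_optimizers(optimizer)
```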