XLAFSDPStrategy

class lightning.fabric.strategies.XLAFSDPStrategy(accelerator=None, parallel_devices=None, checkpoint_io=None, precision=None, auto_wrap_policy=None, activation_checkpointing_policy=None, state_dict_type='sharded', sequential_save=False, **kwargs)[source]

Bases: ParallelStrategy, _Sharded

Strategy for training on multiple XLA devices in parallel using the torch_xla.distributed.fsdp.XlaFullyShardedDataParallel class.

Warning

This is an experimental feature.

For more information, check out https://github.com/pytorch/xla/blob/v2.5.0/docs/fsdp.md

Parameters:
  • auto_wrap_policy (Union[set[type[Module]], Callable[[Module, bool, int], bool], None]) – Same as auto_wrap_policy parameter in torch_xla.distributed.fsdp.XlaFullyShardedDataParallel. For convenience, this also accepts a set of the layer classes to wrap.

  • activation_checkpointing_policy (Optional[set[type[Module]]]) – Used when selecting the modules for which you want to enable activation checkpointing. Enabling this can free up a significant amount of memory at the cost of speed since activations in these layers need to be recomputed during backpropagation. This accepts a set of the layer classes to wrap.

  • state_dict_type (Literal['full', 'sharded']) –

    The format in which the state of the model and optimizers gets saved into the checkpoint.

    • "full": The full weights and optimizer states get assembled on rank 0 and saved to a single file.

    • "sharded": Each rank saves its shard of weights and optimizer states to a file. The checkpoint is a folder with files for each shard in the host. Note that TPU VM multihost does not have a shared filesystem.

  • sequential_save (bool) – With this enabled, individual ranks save their state-dictionary shards one after the other, reducing peak system RAM usage at the cost of a longer saving process.

  • **kwargs (Any) – See available parameters in torch_xla.distributed.fsdp.XlaFullyShardedDataParallel.
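
The strategy is typically passed to a Fabric instance rather than used directly. A minimal sketch, assuming a hypothetical Block submodule class to drive the wrap policies and an 8-device TPU host:

    import torch.nn as nn
    from lightning.fabric import Fabric
    from lightning.fabric.strategies import XLAFSDPStrategy

    class Block(nn.Module):  # placeholder for e.g. a transformer block
        ...

    strategy = XLAFSDPStrategy(
        auto_wrap_policy={Block},                 # shard every Block as its own FSDP unit
        activation_checkpointing_policy={Block},  # recompute Block activations during backward
        state_dict_type="sharded",                # each rank saves its own shard
    )
    fabric = Fabric(accelerator="tpu", devices=8, strategy=strategy)
    fabric.launch()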

_configure_launcher()[source]

Attaches the launcher used by this strategy.

Return type:

None

all_gather(tensor, group=None, sync_grads=False)[source]

Gathers a tensor from several distributed processes.

Parameters:
  • tensor (Tensor) – tensor to all-gather.

  • group (Optional[Any]) – unused.

  • sync_grads (bool) – flag that allows users to synchronize gradients for the all-gather operation.

Return type:

Tensor

Returns:

A tensor of shape (world_size, …)
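
In user code this is usually reached through fabric.all_gather(). A minimal sketch, assuming a launched Fabric instance named fabric:

    import torch

    # every process contributes a one-element tensor; the result stacks them along a new leading dimension
    local = torch.tensor([float(fabric.global_rank)], device=fabric.device)
    gathered = fabric.all_gather(local)  # shape: (world_size, 1)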

all_reduce(output, group=None, reduce_op=None)[source]

Reduces the given tensor (e.g. across GPUs/processes).

Parameters:
  • output (Tensor) – the tensor to sync and reduce

  • group (Optional[Any]) – the process group to reduce

  • reduce_op (Union[ReduceOp, str, None]) – the reduction operation. Defaults to ‘mean’. Can also be a string ‘sum’ or ReduceOp.

Return type:

Tensor
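
A minimal sketch of averaging a scalar metric across processes via fabric.all_reduce(), assuming a launched Fabric instance and an illustrative value:

    import torch

    loss = torch.tensor(0.5, device=fabric.device)         # illustrative value
    avg_loss = fabric.all_reduce(loss, reduce_op="mean")   # "mean" is also the default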

barrier(name=None, *args, **kwargs)[source]

Synchronizes all processes, blocking each process until the whole group has entered this function.

Parameters:

name (Optional[str]) – an optional name to pass into barrier.

Return type:

None

broadcast(obj, src=0)[source]

Broadcasts an object to all processes.

Parameters:
  • obj (TypeVar(TBroadcast)) – the object to broadcast

  • src (int) – source rank

Return type:

TypeVar(TBroadcast)
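
A minimal sketch of distributing a value decided on rank 0 to all other ranks, assuming a launched Fabric instance and an illustrative run name:

    # rank 0 picks the value; broadcast makes every rank see the same one
    run_name = "exp-001" if fabric.global_rank == 0 else None
    run_name = fabric.broadcast(run_name, src=0)
    fabric.barrier()  # optional: wait until every process has reached this point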

clip_gradients_norm(module, optimizer, max_norm, norm_type=2.0, error_if_nonfinite=True)[source]

Clip gradients by norm.

Return type:

Tensor
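
Gradient clipping is usually requested through fabric.clip_gradients(), which dispatches to the strategy. A minimal sketch, assuming model, optimizer, and loss have already been set up:

    fabric.backward(loss)
    fabric.clip_gradients(model, optimizer, max_norm=1.0, norm_type=2.0)
    optimizer.step()
    optimizer.zero_grad()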

clip_gradients_value(module, optimizer, clip_val)[source]

Clip gradients by value.

Return type:

None

load_checkpoint(path, state=None, strict=True, weights_only=None)[source]

Given a folder, load the contents from a checkpoint and restore the state of the given objects.

The strategy currently only supports saving and loading sharded checkpoints which are stored in form of a directory of multiple files rather than a single file.

Return type:

dict[str, Any]
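
A minimal sketch of restoring from a sharded checkpoint directory via fabric.load(), with an illustrative path and assuming model and optimizer were set up with this strategy:

    state = {"model": model, "optimizer": optimizer, "step": 0}
    remainder = fabric.load("checkpoints/step-1000", state)  # checkpoint entries not listed in state are returned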

module_init_context(empty_init=None)[source]

A context manager wrapping the model instantiation.

Here, the strategy can control how the parameters of the model get created (device, dtype) and/or apply other patches to the model.

Parameters:

empty_init (Optional[bool]) – Whether to initialize the model with empty weights (uninitialized memory). If None, the strategy will decide. Some strategies may not support all options.

Return type:

AbstractContextManager
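
With Fabric, this context is entered through fabric.init_module(). A minimal sketch, with MyModel as a placeholder module class:

    with fabric.init_module(empty_init=True):
        model = MyModel()  # parameters are created without a redundant full initialization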

module_sharded_context()[source]

A context manager that wraps the instantiation of a torch.nn.Module and handles sharding of parameters on creation.

By sharding layers directly on instantiation, one can reduce peak memory usage and initialization time.

Return type:

AbstractContextManager

module_to_device(module)[source]

Moves the model to the correct device.

Return type:

None

optimizer_step(optimizer, **kwargs)[source]

Performs the actual optimizer step.

Parameters:
  • optimizer (Optimizable) – the optimizer performing the step

  • **kwargs (Any) – Any extra arguments to optimizer.step

Return type:

Any

process_dataloader(dataloader)[source]

Wraps the dataloader if necessary.

Parameters:

dataloader (DataLoader) – the dataloader to wrap, ideally of type torch.utils.data.DataLoader

Return type:

MpDeviceLoader
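
In practice the wrapping happens when a dataloader is passed through fabric.setup_dataloaders(). A minimal sketch, assuming an existing dataset object:

    from torch.utils.data import DataLoader

    dataloader = DataLoader(dataset, batch_size=32)
    dataloader = fabric.setup_dataloaders(dataloader)  # returns the XLA MpDeviceLoader wrapper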

save_checkpoint(path, state, storage_options=None, filter=None)[source]

Save model, optimizer, and other state in the provided checkpoint directory.

If the user specifies sharded checkpointing, the directory will contain one file per process, with model and optimizer shards stored per file. If the user specifies full checkpointing, the directory will contain a consolidated checkpoint combining all of the sharded checkpoints.

Return type:

None
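
A minimal sketch of saving via fabric.save(), with an illustrative directory path:

    state = {"model": model, "optimizer": optimizer, "step": step}
    fabric.save("checkpoints/step-1000", state)  # one shard file per rank with state_dict_type="sharded"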

setup_environment()[source]

Set up any processes or distributed connections.

This must be called by the framework at the beginning of every process, before any distributed communication takes place.

Return type:

None

setup_module(module)[source]

Performs setup for the model, e.g., by wrapping it in another class.

Return type:

Module

setup_module_and_optimizers(module, optimizers, scheduler=None)[source]

Set up a model and multiple optimizers together.

The returned objects are expected to be in the same order they were passed in. The default implementation will call setup_module() and setup_optimizer() on the inputs.

Return type:

tuple[Module, list[Optimizer], Optional[_LRScheduler]]

setup_optimizer(optimizer)[source]

Performs setup for the optimizer, e.g., by wrapping it in another class.

Return type:

Optimizer
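
With FSDP-style sharding it is common to set up the module first and create the optimizer on the wrapped parameters afterwards. A minimal sketch, with MyModel as a placeholder module class:

    import torch

    model = fabric.setup_module(MyModel())            # wrap and shard the model
    optimizer = torch.optim.Adam(model.parameters())  # optimizer sees the sharded parameters
    optimizer = fabric.setup_optimizers(optimizer)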

property root_device: device

Returns the root device.