FSDPStrategy

class lightning.fabric.strategies.FSDPStrategy(accelerator=None, parallel_devices=None, cluster_environment=None, precision=None, process_group_backend=None, timeout=datetime.timedelta(seconds=1800), cpu_offload=None, mixed_precision=None, auto_wrap_policy=None, activation_checkpointing=None, activation_checkpointing_policy=None, sharding_strategy='FULL_SHARD', state_dict_type='sharded', **kwargs)

Bases: ParallelStrategy, _Sharded

Strategy for Fully Sharded Data Parallel provided by torch.distributed.

Warning

This is an experimental feature.

Fully Sharded Training shards the entire model across all available GPUs, which lets you scale model size while using efficient communication to keep overhead low. In practice, this means you can stay at parity with PyTorch DDP while scaling model sizes dramatically. The technique is similar to ZeRO Stage 3.

For more information, check out this blog post.

Defaults are provided and the options are exposed, but you may need to tune them depending on the memory/speed trade-off you are aiming for. We suggest having a look at this tutorial for more information.
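The core idea of full sharding can be illustrated without any GPUs. The following is a toy, single-process sketch in plain Python (it does not use Lightning or PyTorch): it pads a flat list of parameters so it divides evenly by the world size, gives each rank one contiguous shard, and simulates the all-gather that rebuilds the full parameters when they are needed. The helper names `shard_flat_params` and `all_gather` are hypothetical.

```python
from typing import List


def shard_flat_params(flat: List[float], world_size: int) -> List[List[float]]:
    """Split a flat parameter list into equal, contiguous per-rank shards.

    Pads with zeros so the length divides evenly by world_size, similar in
    spirit to the padding FSDP applies to its flattened parameters.
    """
    if world_size < 1:
        raise ValueError("world_size must be >= 1")
    shard_size = -(-len(flat) // world_size)  # ceiling division
    padded = flat + [0.0] * (shard_size * world_size - len(flat))
    return [padded[r * shard_size:(r + 1) * shard_size] for r in range(world_size)]


def all_gather(shards: List[List[float]], numel: int) -> List[float]:
    """Concatenate every rank's shard and drop the padding to recover the full parameters."""
    full = [x for shard in shards for x in shard]
    return full[:numel]


params = [0.1, 0.2, 0.3, 0.4, 0.5]
shards = shard_flat_params(params, world_size=2)
print(shards)  # each rank would keep only its own entry of this list
print(all_gather(shards, numel=len(params)) == params)
```

With `sharding_strategy='FULL_SHARD'`, FSDP releases the gathered parameters again once a wrapped unit has finished its computation, so only a small part of the model is fully materialized at any time.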

Parameters:
_configure_launcher()

Attach the launcher based on the strategy.

Return type:

None

all_reduce(tensor, group=None, reduce_op='mean')

Reduces the given tensor (e.g. across GPUs/processes).

Parameters:
  • tensor (Tensor) – the tensor to sync and reduce

  • group (Optional[Any]) – the process group to reduce over

  • reduce_op (Union[ReduceOp, str, None]) – the reduction operation. Defaults to ‘mean’. Can also be the string ‘sum’ or a ReduceOp.

Return type:

Tensor
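To make the semantics of the string reduction ops concrete, here is a toy, single-process simulation in plain Python; it does not call `torch.distributed`. Each inner list stands in for one rank's (flattened) tensor, and the function returns what every rank would hold afterwards. The name `simulate_all_reduce` is hypothetical; in Fabric user code the same operation is typically reached through `fabric.all_reduce(...)`.

```python
from typing import List, Sequence


def simulate_all_reduce(per_rank: Sequence[Sequence[float]], reduce_op: str = "mean") -> List[List[float]]:
    """Return what every rank would hold after an element-wise all-reduce.

    Toy sketch: per_rank[r] is rank r's tensor, flattened to a list.
    Only the string ops "sum" and "mean" are modelled here.
    """
    if reduce_op not in ("sum", "mean"):
        raise ValueError(f"unsupported reduce_op: {reduce_op!r}")
    world_size = len(per_rank)
    reduced = [sum(values) for values in zip(*per_rank)]  # element-wise sum across ranks
    if reduce_op == "mean":
        reduced = [v / world_size for v in reduced]
    # Every rank receives an identical copy of the reduced result.
    return [list(reduced) for _ in range(world_size)]


# e.g. a per-rank loss being averaged across 4 processes
losses = [[1.0], [2.0], [3.0], [6.0]]
print(simulate_all_reduce(losses)[0])
print(simulate_all_reduce(losses, reduce_op="sum")[0])
```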

barrier(*args, **kwargs)

Synchronizes all processes, blocking each process until the whole group has entered this function.

Parameters:

name – an optional name to pass into barrier.

Return type:

None
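The blocking behavior of a barrier can be demonstrated with Python's `threading.Barrier`, using threads as stand-ins for processes. This is a toy sketch of the semantics only, not of how the strategy implements it (Fabric users would normally just call `fabric.barrier()`); `run_with_barrier` is a hypothetical helper.

```python
import threading
from typing import List


def run_with_barrier(world_size: int) -> List[str]:
    """Toy sketch of barrier semantics using threads as stand-ins for processes.

    Each "rank" logs a 'before' event, waits at the barrier, then logs an 'after' event.
    """
    barrier = threading.Barrier(world_size)
    log: List[str] = []
    lock = threading.Lock()

    def worker(rank: int) -> None:
        with lock:
            log.append(f"before:{rank}")
        barrier.wait()  # blocks until all world_size threads have called wait()
        with lock:
            log.append(f"after:{rank}")

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return log


events = run_with_barrier(4)
last_before = max(i for i, e in enumerate(events) if e.startswith("before"))
first_after = min(i for i, e in enumerate(events) if e.startswith("after"))
print(last_before < first_after)  # no rank gets past the barrier before all have arrived
```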

broadcast(obj, src=0)

Broadcasts an object to all processes.

Parameters:
  • obj (TypeVar(TBroadcast)) – the object to broadcast

  • src (int) – source rank

Return type:

TypeVar(TBroadcast)
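A toy, single-process simulation of broadcast: every rank ends up with its own copy of the object held by `src`, whatever it held before. The helper `simulate_broadcast` and the checkpoint path in the example are hypothetical; in Fabric user code this is typically done with `fabric.broadcast(obj, src=0)`.

```python
import copy
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def simulate_broadcast(per_rank: Sequence[T], src: int = 0) -> List[T]:
    """Return what each rank holds after broadcasting per_rank[src].

    Toy sketch: every rank receives its own deep copy of the source object,
    mimicking the fact that each process ends up with an independent object.
    """
    if not 0 <= src < len(per_rank):
        raise ValueError(f"src must be in [0, {len(per_rank)}), got {src}")
    return [copy.deepcopy(per_rank[src]) for _ in per_rank]


# e.g. rank 0 decides on a (hypothetical) checkpoint directory and shares it with everyone
received = simulate_broadcast(["ckpt/run-1", None, None], src=0)
print(received)
```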

clip_gradients_norm(module, optimizer, max_norm, norm_type=2.0, error_if_nonfinite=True)

Clip gradients by norm.

Return type:

Tensor
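With FSDP, no single rank holds the whole gradient, so the norm used for clipping has to be the global norm over all shards. The toy sketch below (plain Python, no PyTorch) shows the arithmetic: each rank computes a local partial result, the partials are combined (a stand-in for an all-reduce), and one shared coefficient scales every shard. PyTorch's FSDP provides `FullyShardedDataParallel.clip_grad_norm_` for the real case; `clip_sharded_grad_norm` here is hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def clip_sharded_grad_norm(
    grad_shards: Sequence[Sequence[float]],
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = True,
) -> Tuple[float, List[List[float]]]:
    """Toy sketch: clip gradients that are split across ranks by their global norm.

    grad_shards[r] holds the gradient entries owned by rank r.
    """
    if math.isinf(norm_type):
        # Infinity norm: equivalent to an all-reduce with max over each shard's local max.
        total_norm = max((abs(g) for shard in grad_shards for g in shard), default=0.0)
    else:
        # p-norm: each rank sums |g|^p locally, the partial sums are added
        # (all-reduce with sum), then the p-th root is taken once.
        local = [sum(abs(g) ** norm_type for g in shard) for shard in grad_shards]
        total_norm = sum(local) ** (1.0 / norm_type)
    if error_if_nonfinite and not math.isfinite(total_norm):
        raise RuntimeError(f"total gradient norm is non-finite: {total_norm}")
    # Same coefficient as torch.nn.utils.clip_grad_norm_: max_norm / (norm + 1e-6), capped at 1.0.
    clip_coef = min(max_norm / (total_norm + 1e-6), 1.0)
    clipped = [[g * clip_coef for g in shard] for shard in grad_shards]
    return total_norm, clipped


norm, clipped = clip_sharded_grad_norm([[3.0], [4.0]], max_norm=1.0)
print(norm)  # global L2 norm across both shards
print(clipped)
```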

load_checkpoint(path, state=None, strict=True)

Load the contents from a checkpoint and restore the state of the given objects.

The strategy currently only supports saving and loading sharded checkpoints, which are stored in the form of a directory of multiple files rather than a single file.

Return type:

Dict[str, Any]

module_init_context(empty_init=None)

A context manager wrapping the model instantiation.

Here, the strategy can control how the parameters of the model get created (device, dtype) and/or apply other patches to the model.

Parameters:

empty_init (Optional[bool]) – Whether to initialize the model with empty weights (uninitialized memory). If None, the strategy will decide. Some strategies may not support all options.

Return type:

ContextManager
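The pattern behind such a context manager, temporarily changing how parameters are created inside a `with` block, can be sketched in plain Python. Everything here is hypothetical and for illustration only: `toy_init_context`, `ToyLayer`, and the string dtypes are not part of Lightning, and `None` for `empty_init` simply falls back to `False` in this toy instead of letting a strategy decide. In Fabric, the real entry point is typically `fabric.init_module(empty_init=...)`.

```python
import contextlib
import contextvars
import random
from typing import Iterator, List, Optional

# Hypothetical defaults that toy layers consult when they create their parameters.
_DTYPE = contextvars.ContextVar("dtype", default="float32")
_EMPTY_INIT = contextvars.ContextVar("empty_init", default=False)


@contextlib.contextmanager
def toy_init_context(dtype: str = "float32", empty_init: Optional[bool] = None) -> Iterator[None]:
    """Toy sketch: temporarily change how parameters are created inside the block."""
    dtype_token = _DTYPE.set(dtype)
    empty_token = _EMPTY_INIT.set(bool(empty_init))  # None falls back to False in this toy
    try:
        yield
    finally:
        # Restore the previous defaults even if model construction raises.
        _DTYPE.reset(dtype_token)
        _EMPTY_INIT.reset(empty_token)


class ToyLayer:
    def __init__(self, size: int) -> None:
        self.dtype = _DTYPE.get()
        self.weight: List[Optional[float]]
        if _EMPTY_INIT.get():
            # Skip the (expensive) random init; values are expected to be loaded later.
            self.weight = [None] * size
        else:
            self.weight = [random.gauss(0.0, 0.02) for _ in range(size)]


with toy_init_context(dtype="bfloat16", empty_init=True):
    layer = ToyLayer(4)
print(layer.dtype, layer.weight)
outside = ToyLayer(4)
print(outside.dtype)  # defaults are restored after the block
```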

module_sharded_context()

A context manager that wraps the instantiation of a torch.nn.Module and handles sharding of parameters on creation.

By sharding layers directly on instantiation, one can reduce peak memory usage and initialization time.

Return type:

ContextManager

module_to_device(module)

Moves the model to the correct device.

Return type:

None

save_checkpoint(path, state, storage_options=None, filter=None)

Save model, optimizer, and other state to a checkpoint on disk.

If the state-dict-type is 'full', the checkpoint will be written to a single file containing the weights, optimizer state and other metadata. If the state-dict-type is 'sharded', the checkpoint gets saved as a directory containing one file per process, with the model and optimizer shards stored per file. Additionally, it creates a metadata file meta.pt containing the rest of the user’s state (saved only from rank 0).

Return type:

None
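The 'sharded' directory layout described above can be mimicked with a toy round trip in plain Python. This is only an illustration of the layout: the shard file names are hypothetical, and the toy pickles everything from a single process, whereas in a real run each process writes its own shard. In Fabric user code, saving and loading typically go through `fabric.save(path, state)` and `fabric.load(path, state)`.

```python
import os
import pickle
import tempfile
from typing import Any, Dict, List


def save_sharded(path: str, shards: List[Dict[str, Any]], metadata: Dict[str, Any]) -> None:
    """Toy sketch of the 'sharded' layout: one file per rank plus meta.pt in a directory."""
    os.makedirs(path, exist_ok=True)
    for rank, shard in enumerate(shards):
        # Hypothetical file name; in a real run each process writes only its own shard.
        with open(os.path.join(path, f"shard_rank{rank}.pkl"), "wb") as f:
            pickle.dump(shard, f)
    # The remaining user state is written once (by rank 0 in a real run).
    with open(os.path.join(path, "meta.pt"), "wb") as f:
        pickle.dump(metadata, f)


def load_sharded(path: str, world_size: int) -> Dict[str, Any]:
    """Read every shard and the metadata back from the toy directory layout."""
    if not os.path.isdir(path):
        raise ValueError(f"expected a checkpoint directory, got {path!r}")
    shards = []
    for rank in range(world_size):
        with open(os.path.join(path, f"shard_rank{rank}.pkl"), "rb") as f:
            shards.append(pickle.load(f))
    with open(os.path.join(path, "meta.pt"), "rb") as f:
        metadata = pickle.load(f)
    return {"shards": shards, "meta": metadata}


with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = os.path.join(tmp, "checkpoint")
    save_sharded(ckpt_dir, [{"w": [0.1, 0.2]}, {"w": [0.3, 0.4]}], {"step": 100})
    files = sorted(os.listdir(ckpt_dir))
    restored = load_sharded(ckpt_dir, world_size=2)
print(files)
print(restored["meta"])
```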

setup_environment()

Set up any processes or distributed connections.

This must be called by the framework at the beginning of every process, before any distributed communication takes place.

Return type:

None

setup_module(module)

Wraps the model into a FullyShardedDataParallel module.

Return type:

Module

setup_module_and_optimizers(module, optimizers)

Wraps the model into a FullyShardedDataParallel module and sets use_orig_params=True to keep the reference to the original parameters in the optimizer.

Return type:

Tuple[Module, List[Optimizer]]

setup_optimizer(optimizer)

Set up an optimizer for a model wrapped with FSDP.

This setup method doesn’t modify or wrap the optimizer. The only thing it currently does is verify that the optimizer was created after the model was wrapped with setup_module(), so that it holds references to the flattened parameters.

Return type:

Optimizer
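The verification described above is an identity check: an optimizer created before wrapping still points at the original parameter objects, which wrapping has replaced. The toy sketch below models that with plain Python objects; `ToyParam` and `check_optimizer_params` are hypothetical and are not the strategy's actual implementation or error message.

```python
from typing import Iterable, List, Optional


class ToyParam:
    """Stand-in for a parameter tensor; object identity is what matters here."""

    def __init__(self, name: str) -> None:
        self.name = name


def check_optimizer_params(optimizer_params: Iterable[ToyParam], wrapped_params: Iterable[ToyParam]) -> None:
    """Toy sketch: every optimizer param must be one of the wrapped module's params.

    Identity (not equality) is compared, because a stale optimizer would still
    reference the original, now-replaced parameter objects.
    """
    wrapped_ids = {id(p) for p in wrapped_params}
    stale: List[str] = [p.name for p in optimizer_params if id(p) not in wrapped_ids]
    if stale:
        raise ValueError(
            f"Optimizer references parameters not owned by the wrapped module: {stale}. "
            "Create the optimizer after calling setup_module()."
        )


original = [ToyParam("layer.weight"), ToyParam("layer.bias")]
flattened = [ToyParam("flat_param")]  # toy: wrapping replaces the original parameters

check_optimizer_params(flattened, flattened)  # passes: optimizer created after wrapping
caught: Optional[ValueError] = None
try:
    check_optimizer_params(original, flattened)  # fails: optimizer created before wrapping
except ValueError as err:
    caught = err
    print(err)
```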

property distributed_sampler_kwargs: Dict[str, Any]

Arguments for the DistributedSampler.

If this method is not defined, or it returns None, then the DistributedSampler will not be used.
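To see what these arguments control, the toy sketch below reproduces how `torch.utils.data.DistributedSampler` splits indices when `shuffle=False` and `drop_last=False`: the index list is padded by wrapping around so every replica gets the same count, then each rank takes every `num_replicas`-th index. It assumes, as a hypothetical, that the property returns `num_replicas` and `rank` (both are `DistributedSampler` constructor arguments); `sampler_kwargs` and `sampler_indices` are illustrative helpers, not Lightning APIs.

```python
import math
from typing import Any, Dict, List


def sampler_kwargs(world_size: int, global_rank: int) -> Dict[str, Any]:
    """Hypothetical stand-in for distributed_sampler_kwargs."""
    return {"num_replicas": world_size, "rank": global_rank}


def sampler_indices(dataset_len: int, num_replicas: int, rank: int) -> List[int]:
    """Indices one rank draws under DistributedSampler's unshuffled, drop_last=False rule."""
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    padding = total_size - len(indices)
    # Pad by repeating indices from the start so all ranks get num_samples items.
    if padding <= len(indices):
        indices += indices[:padding]
    else:
        indices += (indices * math.ceil(padding / len(indices)))[:padding]
    return indices[rank:total_size:num_replicas]


kwargs = sampler_kwargs(world_size=4, global_rank=2)
print(sampler_indices(10, **kwargs))
```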

property root_device: device

Returns the root device.