FSDPStrategy

class lightning.pytorch.strategies.FSDPStrategy(accelerator=None, parallel_devices=None, cluster_environment=None, checkpoint_io=None, precision_plugin=None, process_group_backend=None, timeout=datetime.timedelta(seconds=1800), cpu_offload=None, mixed_precision=None, auto_wrap_policy=None, activation_checkpointing=None, activation_checkpointing_policy=None, sharding_strategy='FULL_SHARD', state_dict_type='full', device_mesh=None, **kwargs)

Bases: ParallelStrategy

Strategy for Fully Sharded Data Parallel provided by torch.distributed.

Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model size whilst using efficient communication to reduce overhead. In practice, this means training speed can remain on par with PyTorch DDP while model sizes scale dramatically. The technique is similar to ZeRO Stage 3.
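As a rough illustration of why sharding lets model size scale, the sketch below estimates per-GPU memory for model states only (parameters, gradients, optimizer states). It assumes fp32 parameters and gradients plus Adam's two fp32 moment buffers, i.e. 16 bytes per parameter, and ignores activations, temporarily gathered parameters, and framework overhead. The helper `model_state_bytes_per_gpu` and the 7B-parameter model are hypothetical.

```python
GIB = 1024**3
# Assumption: fp32 parameters (4 B) + fp32 gradients (4 B) + Adam's two fp32
# moment buffers (8 B) = 16 bytes of model state per parameter.
BYTES_PER_PARAM = 4 + 4 + 8


def model_state_bytes_per_gpu(num_params: int, world_size: int, fully_sharded: bool) -> int:
    """Rough per-GPU bytes for parameters, gradients and optimizer states only."""
    total = num_params * BYTES_PER_PARAM
    if fully_sharded:
        # ZeRO-3 style: each rank stores about 1/world_size of every state
        # (integer ceiling division, since shards may be padded to equal size).
        return -(-total // world_size)
    # Replicated, DDP style: every rank stores a full copy.
    return total


num_params = 7_000_000_000  # a hypothetical 7B-parameter model
for sharded in (False, True):
    gib = model_state_bytes_per_gpu(num_params, world_size=8, fully_sharded=sharded) / GIB
    print(f"fully_sharded={sharded}: ~{gib:.1f} GiB per GPU")
```

On eight GPUs this prints about 104.3 GiB per GPU when replicated versus about 13.0 GiB when fully sharded; real usage is higher once activations and communication buffers are added.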

For more information, see this blog post.

Sensible defaults are set and the main options are exposed, but you may need to configure them to reach the balance of memory and speed efficiency you need. We suggest having a look at this tutorial for more information.
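A minimal configuration sketch, assuming a machine with four CUDA GPUs and a model built from `torch.nn.TransformerEncoderLayer` blocks (substitute your own block class). The values are illustrative rather than tuned recommendations, and `model` stands for your own LightningModule.

```python
import torch.nn as nn
from lightning.pytorch import Trainer
from lightning.pytorch.strategies import FSDPStrategy

# Wrap each transformer block in its own FSDP unit and checkpoint its activations.
# nn.TransformerEncoderLayer is a placeholder for your model's block class.
block_classes = {nn.TransformerEncoderLayer}

strategy = FSDPStrategy(
    auto_wrap_policy=block_classes,
    activation_checkpointing_policy=block_classes,  # trade compute for memory
    sharding_strategy="FULL_SHARD",                 # the default
    state_dict_type="sharded",                      # checkpoint is a folder of per-rank files
)

trainer = Trainer(accelerator="cuda", devices=4, strategy=strategy)
# trainer.fit(model)  # `model` is your LightningModule
```

Passing a set of layer classes is the convenience form described under auto_wrap_policy and activation_checkpointing_policy; a callable policy works as well.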

Parameters:
  • cpu_offload (Union[bool, CPUOffload, None]) – See cpu_offload parameter in torch.distributed.fsdp.FullyShardedDataParallel.

  • mixed_precision (Optional[MixedPrecision]) – See mixed_precision parameter in torch.distributed.fsdp.FullyShardedDataParallel.

  • auto_wrap_policy (Union[set[type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy, None]) – Same as auto_wrap_policy parameter in torch.distributed.fsdp.FullyShardedDataParallel. For convenience, this also accepts a set of the layer classes to wrap.

  • activation_checkpointing (Union[type[Module], list[type[Module]], None]) – Deprecated. Use activation_checkpointing_policy.

  • activation_checkpointing_policy (Union[set[type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy, None]) – Same as auto_wrap_policy parameter in torch.distributed.fsdp.FullyShardedDataParallel but used when selecting the modules for which you want to enable activation checkpointing. Enabling this can free up a significant amount of memory at the cost of speed since activations in these layers need to be recomputed during backpropagation. For convenience, this also accepts a set of the layer classes to wrap.

  • sharding_strategy (Union[ShardingStrategy, Literal['FULL_SHARD', 'SHARD_GRAD_OP', 'NO_SHARD', 'HYBRID_SHARD']]) –

    Select whether to shard model parameters, gradients, optimizer states, or a combination of them. Available values are:

    • "FULL_SHARD": Shards model parameters, gradients, and optimizer states (default).

    • "SHARD_GRAD_OP": Shards gradients and optimizer states only. Model parameters get replicated.

    • "NO_SHARD": No sharding (identical to regular DDP).

    • "HYBRID_SHARD": Shards model parameters, gradients, and optimizer states within a single machine, but replicates across machines. See also the device_mesh parameter below.

    Also accepts a torch.distributed.fsdp.ShardingStrategy enum value.

  • device_mesh (Union[tuple[int], DeviceMesh, None]) – A tuple (replication size, sharding size) that defines over how many devices to shard and replicate the model. The product of the two numbers must equal the world size. Only valid in combination with the HYBRID_SHARD sharding strategy.

  • state_dict_type (Literal['full', 'sharded']) –

    The format in which the state of the model and optimizers gets saved into the checkpoint.

    • "full": The full weights and optimizer states get assembled on rank 0 and saved to a single file.

    • "sharded": Each rank saves its shard of weights and optimizer states to a file. The checkpoint is a folder with as many files as the world size.

  • **kwargs (Any) – See available parameters in torch.distributed.fsdp.FullyShardedDataParallel.
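To make the (replication size, sharding size) tuple of device_mesh concrete, the pure-Python sketch below lays global ranks out row-major in a 2-D grid (assumed to match the layout of a 2-D `DeviceMesh` created with `torch.distributed.device_mesh.init_device_mesh`) and lists which ranks shard together and which replicate together under HYBRID_SHARD. The helper `hybrid_shard_groups` is hypothetical; it only mirrors the rule that the product of the two sizes must equal the world size.

```python
def hybrid_shard_groups(world_size: int, mesh: tuple[int, int]):
    """Return (shard_groups, replicate_groups) for a (replication, sharding) mesh.

    Ranks are laid out row-major: row r holds ranks r*shard .. r*shard + shard - 1.
    """
    replicate, shard = mesh
    if replicate * shard != world_size:
        raise ValueError(
            f"mesh {mesh} covers {replicate * shard} devices, but world size is {world_size}"
        )
    grid = [list(range(r * shard, (r + 1) * shard)) for r in range(replicate)]
    # Each row shards the model among its ranks (e.g. the GPUs of one machine).
    shard_groups = grid
    # Each column holds the same shard on every row, so those ranks replicate it.
    replicate_groups = [[row[c] for row in grid] for c in range(shard)]
    return shard_groups, replicate_groups


# Two machines with four GPUs each: shard within a machine, replicate across machines.
shards, replicas = hybrid_shard_groups(8, (2, 4))
print(shards)    # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(replicas)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```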

barrier(name=None)

Synchronizes all processes, blocking each process until the whole group has entered this function.

Parameters:

name (Optional[str]) – an optional name to pass into barrier.

Return type:

None
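The blocking behaviour can be illustrated without a cluster. The sketch below is a single-process analogy that uses Python's `threading.Barrier`, with threads standing in for ranks; it is not how this method is implemented (the real barrier goes through `torch.distributed`), but the guarantee is the same: no participant continues past the barrier until all of them have reached it.

```python
import threading

WORLD_SIZE = 4
barrier = threading.Barrier(WORLD_SIZE)
lock = threading.Lock()
events = []  # shared log of (phase, rank) entries


def worker(rank: int) -> None:
    with lock:
        events.append(("before", rank))
    barrier.wait()  # blocks until all WORLD_SIZE threads have called wait()
    with lock:
        events.append(("after", rank))


threads = [threading.Thread(target=worker, args=(r,)) for r in range(WORLD_SIZE)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Every "before" entry is logged before any "after" entry.
phases = [phase for phase, _ in events]
print(phases)
```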

broadcast(obj, src=0)

Broadcasts an object to all processes.

Parameters:
  • obj (TypeVar(TBroadcast)) – the object to broadcast

  • src (int) – source rank

Return type:

TypeVar(TBroadcast)
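Because each rank usually computes its own value, broadcasting is how ranks agree on a single one (for example, a run name chosen on rank 0). The single-process sketch below imitates those semantics with threads as stand-in ranks. `toy_broadcast` is a hypothetical helper; unlike the real method, which serializes the object and sends it between processes via `torch.distributed`, the threads here simply share one Python object.

```python
import threading

WORLD_SIZE = 4
_barrier = threading.Barrier(WORLD_SIZE)
_slot = {}  # shared "channel" that the source rank writes into
results = [None] * WORLD_SIZE


def toy_broadcast(obj, rank: int, src: int = 0):
    """Every caller returns the object passed in by rank `src`.

    Single-use toy: calling it twice in a row could let `src` overwrite the slot
    before slower threads have read the previous value.
    """
    if rank == src:
        _slot["obj"] = obj
    _barrier.wait()  # nobody reads until the source has published its object
    return _slot["obj"]


def worker(rank: int) -> None:
    # Each rank starts with a different value, e.g. a locally generated run name.
    local_value = f"run-from-rank-{rank}"
    results[rank] = toy_broadcast(local_value, rank, src=0)


threads = [threading.Thread(target=worker, args=(r,)) for r in range(WORLD_SIZE)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(results)  # every entry is 'run-from-rank-0'
```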

lightning_module_state_dict()

Returns model state.

Return type:

dict[str, Any]

model_sharded_context()

Provides a hook to create modules in a distributed-aware context. This is useful when you want to shard the model immediately as it is created, which can save memory and initialization time for extremely large models.

Returns: Model parallel context.

Return type:

Generator[None, None, None]

model_to_device()

Moves the model to the correct device.

Return type:

None

optimizer_state(optimizer)

Returns state of an optimizer.

Allows for syncing/collating optimizer state from processes in custom strategies.

Return type:

dict[str, Tensor]

reduce(tensor, group=None, reduce_op='mean')

Reduces a tensor from several distributed processes to one aggregated tensor.

Parameters:
  • tensor (Union[Tensor, Any]) – the tensor to sync and reduce

  • group (Optional[Any]) – the process group to gather results from. Defaults to all processes (world)

  • reduce_op (Union[ReduceOp, str, None]) – the reduction operation. Defaults to ‘mean’/‘avg’. Can also be the string ‘sum’ to calculate the sum during reduction.

Return type:

Tensor

Returns:

reduced value; if the input was not a tensor, it is returned unchanged
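The documented semantics ("mean"/"avg" by default, "sum" on request, non-tensors passed through) can be sketched in a single process. In the toy below, each list entry plays the role of one rank's scalar tensor and the return value is what every rank would receive. `simulate_reduce` is a hypothetical helper; the real method performs the reduction with a `torch.distributed` collective over `torch.Tensor` objects.

```python
from numbers import Number


def simulate_reduce(per_rank_values, reduce_op="mean"):
    """Toy stand-in for reduce(): combine one scalar "tensor" from each rank."""
    if not all(isinstance(v, Number) for v in per_rank_values):
        # Non-tensor inputs are not reduced; each rank keeps its own value.
        return list(per_rank_values)
    total = sum(per_rank_values)
    op = reduce_op.lower()
    if op in ("mean", "avg"):
        return total / len(per_rank_values)
    if op == "sum":
        return total
    raise ValueError(f"unsupported reduce_op: {reduce_op!r}")


# e.g. each of 4 ranks computed a validation loss on its own data shard
losses = [0.25, 0.5, 0.75, 0.5]
print(simulate_reduce(losses))         # 0.5
print(simulate_reduce(losses, "sum"))  # 2.0
```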

save_checkpoint(checkpoint, filepath, storage_options=None)

Save model/training states as a checkpoint file through state-dump and file-write.

Parameters:
  • checkpoint (dict[str, Any]) – dict containing model and trainer state

  • filepath (Union[str, Path]) – write-target file’s path

  • storage_options (Optional[Any]) – parameter for how to save to storage, passed to CheckpointIO plugin

Return type:

None

setup(trainer)

Sets up the accelerator and plugins, and initializes the optimizers (if needed).

Parameters:

trainer (Trainer) – the trainer instance

Return type:

None

setup_environment()

Sets up any processes or distributed connections.

This is called before the LightningModule/DataModule setup hook which allows the user to access the accelerator environment before setup is complete.

Return type:

None

setup_optimizers(trainer)

Creates optimizers and schedulers.

Parameters:

trainer (Trainer) – the Trainer these optimizers should be connected to

Return type:

None

teardown()

This method is called to tear down the training process.

It is the right place to release memory and free other resources.

Return type:

None

tensor_init_context(empty_init=None)

Controls how tensors get created (device, dtype).

Parameters:

empty_init (Optional[bool]) – Whether to initialize the model with empty weights (uninitialized memory). If None, the strategy will decide. Some strategies may not support all options.

Return type:

Generator[None, None, None]

property lightning_restore_optimizer: bool

Override to disable Lightning restoring optimizers/schedulers.

This is useful for strategies which manage restoring optimizers/schedulers.

property restore_checkpoint_after_setup: bool

Override to delay restoring from the checkpoint until after the setup phase has completed. This is useful when the strategy requires all the setup hooks to run before loading the checkpoint.

Returns:

If True, restore checkpoint after strategy setup.

property root_device: device

Return the root device.