FSDPStrategy
- class lightning.fabric.strategies.FSDPStrategy(accelerator=None, parallel_devices=None, cluster_environment=None, precision=None, process_group_backend=None, timeout=datetime.timedelta(seconds=1800), cpu_offload=None, mixed_precision=None, auto_wrap_policy=None, activation_checkpointing=None, activation_checkpointing_policy=None, sharding_strategy='FULL_SHARD', state_dict_type='sharded', device_mesh=None, **kwargs)
Bases: ParallelStrategy, _Sharded
Strategy for Fully Sharded Data Parallel provided by torch.distributed.
Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model size while using efficient communication to reduce overhead. In practice, this means performance can remain on par with PyTorch DDP while model sizes scale dramatically. The technique is similar to ZeRO Stage 3.
For more information, check out this blog post.
Defaults have been set and all options are exposed, but you may need to configure them depending on the memory/speed trade-off you are targeting. We suggest having a look at this tutorial for more information.
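To make the memory claim concrete, the back-of-the-envelope sketch below estimates per-GPU memory for weights, gradients, and optimizer state under each `sharding_strategy` value from the parameter list below, including the `device_mesh` check for HYBRID_SHARD. It is not Lightning code: the helper name `per_rank_bytes` is hypothetical, it assumes plain fp32 training with Adam (4 + 4 + 8 = 16 bytes per parameter), and it ignores activations, buffers, and the parameters FSDP temporarily gathers during forward and backward.

```python
from typing import Optional, Tuple

# Assumption: plain fp32 training with Adam.
# 4 bytes per weight, 4 per gradient, 8 for Adam's two moment buffers.
PARAM_BYTES, GRAD_BYTES, OPTIM_BYTES = 4, 4, 8


def per_rank_bytes(
    num_params: int,
    world_size: int,
    sharding_strategy: str = "FULL_SHARD",
    device_mesh: Optional[Tuple[int, int]] = None,
) -> float:
    """Rough per-rank bytes for weights, gradients and optimizer state.

    Hypothetical helper; ignores activations, buffers and gathered parameters.
    """
    if device_mesh is not None and sharding_strategy != "HYBRID_SHARD":
        raise ValueError("device_mesh is only valid with HYBRID_SHARD")

    if sharding_strategy == "FULL_SHARD":
        p = g = o = world_size  # everything split across all ranks
    elif sharding_strategy == "SHARD_GRAD_OP":
        p, g, o = 1, world_size, world_size  # weights stay replicated
    elif sharding_strategy == "NO_SHARD":
        p = g = o = 1  # everything replicated, as in DDP
    elif sharding_strategy == "HYBRID_SHARD":
        if device_mesh is None:
            raise ValueError("this sketch needs device_mesh=(replicate, shard)")
        replicate, shard = device_mesh
        if replicate * shard != world_size:
            raise ValueError("replication size * sharding size must equal world size")
        p = g = o = shard  # split within a shard group, replicated across groups
    else:
        raise ValueError(f"unknown sharding strategy: {sharding_strategy}")

    return num_params * (PARAM_BYTES / p + GRAD_BYTES / g + OPTIM_BYTES / o)
```

Under these assumptions, a 7-billion-parameter model on 8 GPUs needs about 112 GB per GPU with NO_SHARD but about 14 GB with FULL_SHARD, which is the sense in which sharding lets model size grow with the number of GPUs.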
- Parameters:
  - cpu_offload (Union[bool, CPUOffload, None]) – See the cpu_offload parameter in torch.distributed.fsdp.FullyShardedDataParallel.
  - mixed_precision (Optional[MixedPrecision]) – See the mixed_precision parameter in torch.distributed.fsdp.FullyShardedDataParallel.
  - auto_wrap_policy (Union[set[type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy, None]) – Same as the auto_wrap_policy parameter in torch.distributed.fsdp.FullyShardedDataParallel. For convenience, this also accepts a set of the layer classes to wrap.
  - activation_checkpointing (Union[type[Module], list[type[Module]], None]) – Deprecated. Use activation_checkpointing_policy instead.
  - activation_checkpointing_policy (Union[set[type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy, None]) – Same as the auto_wrap_policy parameter in torch.distributed.fsdp.FullyShardedDataParallel, but used to select the modules for which activation checkpointing is enabled. Enabling this can free up a significant amount of memory at the cost of speed, since activations in these layers must be recomputed during backpropagation. For convenience, this also accepts a set of the layer classes to checkpoint.
  - sharding_strategy (Union[ShardingStrategy, Literal['FULL_SHARD', 'SHARD_GRAD_OP', 'NO_SHARD', 'HYBRID_SHARD']]) – Select whether to shard model parameters, gradients, optimizer states, or a combination of them. Available values are:
    - "FULL_SHARD": Shards model parameters, gradients, and optimizer states (default).
    - "SHARD_GRAD_OP": Shards gradients and optimizer states only. Model parameters get replicated.
    - "NO_SHARD": No sharding (identical to regular DDP).
    - "HYBRID_SHARD": Shards model parameters, gradients, and optimizer states within a single machine, but replicates across machines. See also the device_mesh parameter below.
    Also accepts a torch.distributed.fsdp.ShardingStrategy enum value.
  - device_mesh (Union[tuple[int], DeviceMesh, None]) – A tuple (replication size, sharding size) that defines over how many devices to shard and replicate the model. The product of the two numbers must equal the world size. Only valid in combination with the HYBRID_SHARD sharding strategy.
  - state_dict_type (Literal['full', 'sharded']) – The format in which the state of the model and optimizers gets saved into the checkpoint.
    - "full": The full weights and optimizer states get assembled on rank 0 and saved to a single file.
    - "sharded": Each rank saves its shard of weights and optimizer states to a file. The checkpoint is a folder with as many files as the world size (default).
  - **kwargs (Any) – See the available parameters in torch.distributed.fsdp.FullyShardedDataParallel.
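The callable form of auto_wrap_policy (and of activation_checkpointing_policy) follows the Callable[[Module, bool, int], bool] signature listed above. Assuming PyTorch's usual traversal, FSDP calls the policy with recurse=True to ask whether to descend into a module's children, and with recurse=False to ask whether to wrap that module, passing the number of parameters not already inside a wrapped child. The minimal size-based sketch below is hypothetical (the factory name `make_size_policy` and any threshold are illustrative); PyTorch ships a comparable `size_based_auto_wrap_policy` in torch.distributed.fsdp.wrap.

```python
from typing import Any, Callable

# In real use the first argument is a torch.nn.Module; Any keeps the sketch torch-free.
WrapPolicy = Callable[[Any, bool, int], bool]


def make_size_policy(min_num_params: int) -> WrapPolicy:
    """Hypothetical factory: wrap any submodule that still holds at least
    min_num_params parameters not already wrapped in a child."""

    def policy(module: Any, recurse: bool, nonwrapped_numel: int) -> bool:
        if recurse:
            # Asked whether to descend into the children: always traverse.
            return True
        # Asked whether to wrap this module itself.
        return nonwrapped_numel >= min_num_params

    return policy
```

Such a callable could be passed as auto_wrap_policy=make_size_policy(100_000_000); passing a set of layer classes, as documented above, is the simpler alternative when the model is built from repeated blocks.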
- all_reduce(tensor, group=None, reduce_op='mean')
Reduces the given tensor across processes (e.g. across GPUs). With the default reduce_op='mean', the result is the element-wise mean over all processes.
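The semantics of the reduction can be illustrated with a pure-Python simulation in which each inner list stands for one rank's flattened tensor. No communication happens here; the function name and the list-of-lists representation are illustrative only, and only 'mean' and 'sum' are modeled.

```python
from typing import List, Sequence


def simulated_all_reduce(
    per_rank: Sequence[Sequence[float]], reduce_op: str = "mean"
) -> List[List[float]]:
    """per_rank[r] is rank r's (flattened) tensor; returns every rank's result."""
    if reduce_op not in ("mean", "sum"):
        raise ValueError(f"unsupported reduce_op in this sketch: {reduce_op}")
    world_size = len(per_rank)
    # Element-wise sum over ranks.
    reduced = [sum(column) for column in zip(*per_rank)]
    if reduce_op == "mean":
        reduced = [x / world_size for x in reduced]
    # After an all-reduce, every rank holds the same reduced tensor.
    return [list(reduced) for _ in range(world_size)]
```

For two ranks holding [1.0, 2.0] and [3.0, 4.0], both ranks end up with [2.0, 3.0] under 'mean'.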
- barrier(*args, **kwargs)
Synchronizes all processes: each process blocks until every process in the group has entered this function.
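The blocking behavior can be mimicked inside a single Python process with the standard library's threading.Barrier, using threads as stand-ins for ranks. This is only an analogy for a distributed barrier, not how Lightning implements it; `run_with_barrier` is a hypothetical name.

```python
import threading
import time
from typing import List, Tuple


def run_with_barrier(world_size: int) -> List[Tuple[str, int]]:
    barrier = threading.Barrier(world_size)
    lock = threading.Lock()
    events: List[Tuple[str, int]] = []

    def worker(rank: int) -> None:
        time.sleep(0.01 * rank)  # ranks reach the barrier at different times
        with lock:
            events.append(("arrive", rank))
        barrier.wait()  # blocks until all world_size workers have arrived
        with lock:
            events.append(("leave", rank))

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return events
```

Regardless of arrival order, every "arrive" event is recorded before any "leave" event.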
- clip_gradients_norm(module, optimizer, max_norm, norm_type=2.0, error_if_nonfinite=True)
Clip gradients by norm.
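With a sharded model no rank holds all gradients, so the total norm has to be assembled from per-shard contributions: each rank sums |g|^p over its own shard, the partial sums are added across ranks, and the p-th root is taken once. The pure-Python sketch below simulates this for a finite norm_type, adding the partial sums locally instead of through an all-reduce. The function name is hypothetical; the 1e-6 epsilon mirrors torch.nn.utils.clip_grad_norm_.

```python
import math
from typing import List


def clip_sharded_grad_norm(
    shards: List[List[float]],
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = True,
) -> float:
    """shards[r] holds rank r's gradient shard; scales the shards in place and
    returns the total norm computed before clipping (finite norm_type only)."""
    # Each rank: sum of |g|^p over its local shard only.
    partial_sums = [sum(abs(g) ** norm_type for g in shard) for shard in shards]
    # Real training combines these with a sum all-reduce; here we just add them.
    total_norm = sum(partial_sums) ** (1.0 / norm_type)
    if error_if_nonfinite and not math.isfinite(total_norm):
        raise RuntimeError(f"total norm of order {norm_type} is non-finite")
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        # Every shard is scaled by the same global coefficient.
        for shard in shards:
            for i, g in enumerate(shard):
                shard[i] = g * clip_coef
    return total_norm
```

Clipping each shard by its own local norm would give a different (and rank-dependent) result, which is why the global sum is needed.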
- load_checkpoint(path, state=None, strict=True)
Load the contents from a checkpoint and restore the state of the given objects.
- module_init_context(empty_init=None)
A context manager wrapping the model instantiation.
Here, the strategy can control how the parameters of the model get created (device, dtype) and/or apply other patches to the model.
- module_sharded_context()
A context manager that wraps the instantiation of a torch.nn.Module and handles sharding of parameters on creation. By sharding layers directly on instantiation, one can reduce peak memory usage and initialization time.
- save_checkpoint(path, state, storage_options=None, filter=None)
Save model, optimizer, and other state to a checkpoint on disk.
If state_dict_type is 'full', the checkpoint will be written to a single file containing the weights, optimizer state, and other metadata. If state_dict_type is 'sharded', the checkpoint gets saved as a directory containing one file per process, with the model and optimizer shards stored per file. Additionally, it creates a metadata file meta.pt with the rest of the user's state (only saved from rank 0).
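Writing a 'full' checkpoint means undoing the sharding on rank 0. The sketch below is a simplified, pure-Python model of FSDP's flat-parameter layout, assuming parameters are flattened into one sequence, padded to a multiple of the world size, and split into equal per-rank chunks. The helper names are hypothetical, and real FSDP operates on tensors and gathers shards with collectives.

```python
from typing import List


def shard_flat(flat: List[float], world_size: int) -> List[List[float]]:
    """Pad to a multiple of world_size and split into equal per-rank chunks."""
    pad = (-len(flat)) % world_size
    padded = flat + [0.0] * pad
    chunk = len(padded) // world_size
    return [padded[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def gather_full(shards: List[List[float]], numel: int) -> List[float]:
    """What a 'full' save needs on rank 0: concatenate shards, drop padding."""
    full = [x for shard in shards for x in shard]
    return full[:numel]
```

With 'sharded', each rank would instead persist only its own entry of shard_flat's output, which is why that checkpoint is a directory with one file per process.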
- setup_environment()
Set up any processes or distributed connections.
This must be called by the framework at the beginning of every process, before any distributed communication takes place.
- setup_module(module)
Wraps the model into a FullyShardedDataParallel module.
- setup_module_and_optimizers(module, optimizers)
Wraps the model into a FullyShardedDataParallel module and sets use_orig_params=True to keep the reference to the original parameters in the optimizer.
- setup_optimizer(optimizer)
Set up an optimizer for a model wrapped with FSDP.
This setup method neither modifies nor wraps the optimizer. Currently, it only verifies that the optimizer was created after the model was wrapped with setup_module(), so that the optimizer holds references to the flattened parameters.