FSDPStrategy¶
- class lightning.fabric.strategies.FSDPStrategy(accelerator=None, parallel_devices=None, cluster_environment=None, checkpoint_io=None, precision=None, process_group_backend=None, timeout=datetime.timedelta(seconds=1800), cpu_offload=None, mixed_precision=None, activation_checkpointing=None, **kwargs)[source]¶
Bases: ParallelStrategy, _Sharded
Strategy for Fully Sharded Data Parallel provided by torch.distributed.
Warning
This is an experimental feature.
Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model size while using efficient communication to reduce overhead. In practice, this means we can remain at parity with PyTorch DDP while scaling our model sizes dramatically. The technique is similar to ZeRO-Stage 3.
For more information check out this blogpost.
Defaults have been set and options have been exposed, but they may require configuration depending on the memory/speed trade-off you need. We suggest having a look at this tutorial for more information.
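As a rough usage sketch (the accelerator and device count below are placeholders), the strategy is typically passed to Fabric, which then drives the setup methods documented below:

```python
from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy

# Shard the model across 4 GPUs; each process holds only a fraction of the parameters.
fabric = Fabric(accelerator="cuda", devices=4, strategy=FSDPStrategy())
fabric.launch()
```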
- Parameters:
  - cpu_offload¶ (Union[bool, CPUOffload, None]) – See the cpu_offload parameter in torch.distributed.fsdp.FullyShardedDataParallel.
  - mixed_precision¶ (Optional[MixedPrecision]) – See the mixed_precision parameter in torch.distributed.fsdp.FullyShardedDataParallel.
  - activation_checkpointing¶ (Union[Type[Module], List[Type[Module]], None]) – A single layer or a list of layer classes for which you want to enable activation checkpointing. This is typically your transformer block (including attention + feed-forward). Enabling this can free up a significant amount of memory at the cost of speed since activations in these layers need to be recomputed during backpropagation. A configuration sketch follows this list.
  - **kwargs¶ (Any) – See the available parameters in torch.distributed.fsdp.FullyShardedDataParallel.
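A hedged configuration sketch: TransformerBlock stands for a hypothetical user-defined module class, and the CPUOffload dataclass comes from torch.distributed.fsdp.

```python
from torch.distributed.fsdp import CPUOffload
from lightning.fabric.strategies import FSDPStrategy

from my_model import TransformerBlock  # hypothetical transformer block class

strategy = FSDPStrategy(
    # Offload sharded parameters to CPU when they are not in use (trades speed for memory).
    cpu_offload=CPUOffload(offload_params=True),
    # Recompute activations of these layers during backward instead of storing them.
    activation_checkpointing=TransformerBlock,
)
```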
- all_reduce(tensor, group=None, reduce_op='mean')[source]¶
Reduces the given tensor (e.g. across GPUs/processes).
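A hedged sketch, where strategy stands for the FSDPStrategy instance attached to the running processes (in user code the equivalent Fabric method is usually preferred):

```python
import torch

# Average a per-rank scalar so that every process ends up with the same value.
local_loss = torch.tensor(0.25, device="cuda")
avg_loss = strategy.all_reduce(local_loss, reduce_op="mean")
```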
- barrier(*args, **kwargs)[source]¶
Synchronizes all processes, blocking them until the whole group has entered this function.
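A hedged sketch, using a hypothetical helper that should run on one rank only:

```python
# Rank 0 prepares a shared resource; all other ranks wait at the barrier until it is done.
if strategy.global_rank == 0:
    prepare_dataset_cache()  # hypothetical one-time setup
strategy.barrier()
```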
- clip_gradients_norm(module, optimizer, max_norm, norm_type=2.0, error_if_nonfinite=True)[source]¶
Clip gradients by norm.
- Return type: Tensor
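A hedged sketch of a training step with norm clipping; wrapped_module and optimizer are assumed to come from setup_module() and setup_optimizer() below:

```python
# Clip the total gradient norm of the sharded parameters to 1.0 before stepping.
total_norm = strategy.clip_gradients_norm(wrapped_module, optimizer, max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
```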
- module_sharded_context()[source]¶
A context manager that wraps the instantiation of a torch.nn.Module and shards its parameters as they are created.
By sharding layers directly on instantiation, one can reduce peak memory usage and initialization time.
- Return type: Generator
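A hedged sketch, assuming a hypothetical MyLargeModel class:

```python
# Layers are sharded as they are created, so no rank ever materializes the full model.
with strategy.module_sharded_context():
    model = MyLargeModel()
```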
- setup_environment()[source]¶
Set up any processes or distributed connections.
This must be called by the framework at the beginning of every process, before any distributed communication takes place.
- Return type: None
- setup_module(module)[source]¶
Wraps the model into a FullyShardedDataParallel module.
- Return type: FullyShardedDataParallel
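A hedged sketch (when going through Fabric, this call normally happens inside Fabric's own setup methods rather than being invoked directly):

```python
# Wrap the freshly created model; its parameters are flattened and sharded afterwards.
wrapped_module = strategy.setup_module(model)
```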
- setup_module_and_optimizers(module, optimizers)[source]¶
Set up a model and multiple optimizers together.
The returned objects are expected to be in the same order they were passed in. The default implementation will call setup_module() and setup_optimizer() on the inputs.
- setup_optimizer(optimizer)[source]¶
Set up an optimizer for a model wrapped with FSDP.
This setup method doesn’t modify or wrap the optimizer. The only thing it currently does is verify that the optimizer was created after the model was wrapped with setup_module(), i.e. that it holds a reference to the flattened parameters.
- Return type: Optimizer
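A hedged sketch of the ordering this implies: wrap the model first, then build the optimizer from the wrapped module's (flattened) parameters:

```python
import torch

wrapped_module = strategy.setup_module(model)
# The optimizer must reference the flattened FSDP parameters, so create it after wrapping.
optimizer = torch.optim.AdamW(wrapped_module.parameters(), lr=1e-3)
optimizer = strategy.setup_optimizer(optimizer)
```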