FSDPStrategy
- class lightning.pytorch.strategies.FSDPStrategy(accelerator=None, parallel_devices=None, cluster_environment=None, checkpoint_io=None, precision_plugin=None, process_group_backend=None, cpu_offload=None, mixed_precision=None, activation_checkpointing=None, **kwargs)[source]
Bases: lightning.pytorch.strategies.parallel.ParallelStrategy
Strategy for Fully Sharded Data Parallel provided by torch.distributed.
Warning
This is an experimental feature.
Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model size, whilst using efficient communication to reduce overhead. In practice, this means we can remain at parity with PyTorch DDP, whilst scaling our model sizes dramatically. The technique is similar to ZeRO-Stage 3.
For more information, check out this blogpost.
Defaults have been set and options have been exposed, but they may require configuration depending on the memory/speed trade-off you need. We suggest having a look at this tutorial for more information.
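As a minimal, hedged sketch of enabling this strategy through the Trainer (the LightningModule name below is a placeholder, and the string alias depends on your Lightning version):

    import lightning.pytorch as pl
    from lightning.pytorch.strategies import FSDPStrategy

    trainer = pl.Trainer(
        accelerator="gpu",
        devices=4,
        strategy=FSDPStrategy(),  # a string alias such as "fsdp" may also be registered, depending on your version
    )
    # trainer.fit(MyLightningModule())  # MyLightningModule is a placeholder for your own module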
- Parameters
cpu_offload (Union[bool, CPUOffload, None]) – See the cpu_offload parameter in torch.distributed.fsdp.FullyShardedDataParallel.
mixed_precision (Optional[MixedPrecision]) – See the mixed_precision parameter in torch.distributed.fsdp.FullyShardedDataParallel.
activation_checkpointing (Union[Type[Module], List[Type[Module]], None]) – A single layer class or a list of layer classes for which you want to enable activation checkpointing. This is typically your transformer block (including attention + feed-forward). Enabling this can free up a significant amount of memory at the cost of speed since activations in these layers need to be recomputed during backpropagation.
**kwargs – See available parameters in torch.distributed.fsdp.FullyShardedDataParallel (a construction sketch follows below).
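A hedged construction sketch showing the three exposed arguments together; TransformerBlock is a stand-in for your own layer class, and CPUOffload/MixedPrecision come from torch.distributed.fsdp:

    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import CPUOffload, MixedPrecision
    from lightning.pytorch.strategies import FSDPStrategy

    class TransformerBlock(nn.Module):
        """Placeholder for your real attention + feed-forward block."""

    strategy = FSDPStrategy(
        # Keep sharded parameters on CPU when not in use (slower, but saves GPU memory).
        cpu_offload=CPUOffload(offload_params=True),
        # Run parameters, gradient reduction, and buffers in fp16.
        mixed_precision=MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        ),
        # Recompute activations of these layer classes during the backward pass.
        activation_checkpointing=TransformerBlock,  # or a list of layer classes
    )

The resulting strategy object is passed to the Trainer exactly as in the example above.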
- barrier(name=None)[source]
Synchronizes all processes, blocking them until the whole group enters this function.
- broadcast(obj, src=0)[source]
Broadcasts an object to all processes.
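For illustration (not from the source), one way these collectives might be used from a LightningModule hook; load_vocab is a hypothetical rank-0-only helper and self.trainer.strategy is the active strategy instance:

    import lightning.pytorch as pl

    def load_vocab():
        # Hypothetical helper: stands in for any work you only want rank 0 to perform.
        return {"<pad>": 0, "<unk>": 1}

    class MyModule(pl.LightningModule):
        def on_fit_start(self):
            vocab = load_vocab() if self.trainer.is_global_zero else None
            # Hand rank 0's object to every process, then wait for all ranks to catch up.
            self.vocab = self.trainer.strategy.broadcast(vocab, src=0)
            self.trainer.strategy.barrier("vocab-sync")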
- model_sharded_context()[source]
Provides a hook to create modules in a distributed-aware context. This lets the model be sharded instantly on creation, which can save memory and initialization time for extremely large models.
Returns: Model parallel context.
- Return type: Generator
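A sketch of how this context is typically consumed, assuming your Lightning version exposes the configure_sharded_model hook (newer releases use configure_model): the trainer enters model_sharded_context() around that hook, so wrap() from torch.distributed.fsdp.wrap picks up the strategy's settings.

    import lightning.pytorch as pl
    import torch.nn as nn
    from torch.distributed.fsdp.wrap import wrap

    class ShardedModel(pl.LightningModule):
        def configure_sharded_model(self):
            # The trainer enters strategy.model_sharded_context() around this hook,
            # so modules wrapped here are created already sharded across ranks.
            self.block = wrap(nn.Linear(32, 32))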
- predict_step(*args, **kwargs)[source]
The actual predict step.
See predict_step() for more details.
- reduce(tensor, group=None, reduce_op='mean')[source]
Reduces a tensor from several distributed processes to one aggregated tensor.
- Parameters
tensor – the tensor to sync and reduce.
group – the process group to gather results from. Defaults to all processes (world).
reduce_op – the reduction operation. Defaults to 'mean'. Can also be a string 'sum' to calculate the sum during reduction.
- Return type: Tensor
- Returns
reduced value, except when the input was not a tensor the output remains unchanged
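An illustrative, hedged sketch of the reduce semantics from inside a LightningModule (the loss tensor here is a dummy stand-in):

    import torch
    import lightning.pytorch as pl

    class MyModule(pl.LightningModule):
        def validation_step(self, batch, batch_idx):
            loss = torch.as_tensor(0.123, device=self.device)  # stand-in for a real loss
            # Average the per-process value across all ranks; a non-tensor input
            # would be returned unchanged, as noted above.
            avg_loss = self.trainer.strategy.reduce(loss, reduce_op="mean")
            self.log("val_loss", avg_loss)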
- setup(trainer)[source]
Sets up plugins for the trainer fit and creates optimizers.
- setup_environment()[source]
Sets up any processes or distributed connections.
This is called before the LightningModule/DataModule setup hook, which allows the user to access the accelerator environment before setup is complete.
- Return type: None
- setup_optimizers(trainer)[source]
Creates optimizers and schedulers.
- teardown()[source]
This method is called to tear down the training process.
It is the right place to release memory and free other resources.
- Return type: None
- test_step(*args, **kwargs)[source]
The actual test step.
See test_step() for more details.
- training_step(*args, **kwargs)[source]
The actual training step.
See training_step() for more details.
- validation_step(*args, **kwargs)[source]
The actual validation step.
See validation_step() for more details.
- property root_device: torch.device
Return the root device.
- Return type: torch.device