ModelParallelStrategy¶
- class lightning.fabric.strategies.ModelParallelStrategy(parallelize_fn, data_parallel_size='auto', tensor_parallel_size='auto', save_distributed_checkpoint=True, process_group_backend=None, timeout=datetime.timedelta(seconds=1800))[source]¶
Bases:
ParallelStrategy
Enables user-defined parallelism applied to a model.
Warning
This is an experimental feature.
Currently supports up to 2D parallelism, specifically the combination of Fully Sharded Data Parallel 2 (FSDP2) with Tensor Parallelism (DTensor). Both PyTorch APIs are still experimental. Requires PyTorch 2.4 or newer.
- Parameters:
  - parallelize_fn¶ (Callable[[TModel, DeviceMesh], TModel]) – A function that applies parallelisms to a module. The strategy will provide the model and device mesh as input.
  - data_parallel_size¶ (Union[Literal['auto'], int]) – The number of devices within a data-parallel group. Defaults to "auto", which sets this size to the number of nodes in the cluster.
  - tensor_parallel_size¶ (Union[Literal['auto'], int]) – The number of devices within a tensor-parallel group. Defaults to "auto", which sets this size to the number of GPUs in a single node.
  - save_distributed_checkpoint¶ (bool) – If True, each rank saves its shard of weights and optimizer states to a file. The checkpoint is a folder with as many files as the world size. If False, the full weights and optimizer states get assembled on rank 0 and saved to a single file.
- all_reduce(tensor, group=None, reduce_op='mean')[source]¶
Reduces the given tensor (e.g. across GPUs/processes).
- barrier(*args, **kwargs)[source]¶
Synchronizes all processes, blocking each process until the whole group has entered this function.
- load_checkpoint(path, state=None, strict=True)[source]¶
Load the contents from a checkpoint and restore the state of the given objects.
- module_init_context(empty_init=None)[source]¶
A context manager wrapping the model instantiation.
Here, the strategy can control how the parameters of the model get created (device, dtype) and/or apply other patches to the model.
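The effect of an empty-weight initialization can be illustrated directly with PyTorch's meta device, which is the kind of deferred allocation this context manager can apply (in Fabric code, users typically reach it through fabric.init_module(empty_init=True) rather than calling the strategy directly):

```python
import torch
import torch.nn as nn

# Illustration only: parameters created under the meta device have a shape
# and dtype but no real storage, so a very large model can be described
# before any memory is allocated or any weights are materialized.
with torch.device("meta"):
    layer = nn.Linear(4096, 4096)

print(layer.weight.is_meta)  # → True
```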
- save_checkpoint(path, state, storage_options=None, filter=None)[source]¶
Save model, optimizer, and other state to a checkpoint on disk.
If distributed checkpointing is enabled (the default), the checkpoint is saved as a directory containing one file per process, with the model and optimizer shards stored per file. Additionally, a metadata file meta.pt is created with the rest of the user's state (saved from rank 0 only). If distributed checkpointing is disabled (save_distributed_checkpoint=False), the checkpoint is written to a single file containing the weights, optimizer state, and other metadata.
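As a hedged sketch, a save/restore round trip through Fabric (which routes to this strategy's save_checkpoint and load_checkpoint) might look as follows; the fabric, model, and optimizer objects are assumed to come from a setup like the one in the class description, and the path is hypothetical:

```python
def save_and_restore(fabric, model, optimizer, path="checkpoints/latest"):
    """Illustrative only: `fabric`, `model`, and `optimizer` are assumed
    to be already set up; `path` is a hypothetical location."""
    state = {"model": model, "optimizer": optimizer, "step": 0}
    # With save_distributed_checkpoint=True (the default), `path` becomes a
    # directory with one shard file per rank plus a meta.pt holding the
    # remaining user state (saved from rank 0 only).
    fabric.save(path, state)
    # On resume, the objects referenced in `state` are restored in place.
    fabric.load(path, state)
    return state
```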
- setup_environment()[source]¶
Set up any processes or distributed connections.
This must be called by the framework at the beginning of every process, before any distributed communication takes place.
- setup_module(module)[source]¶
Performs setup for the model, e.g., by wrapping it by another class.