ModelParallelStrategy

class lightning.fabric.strategies.ModelParallelStrategy(parallelize_fn, data_parallel_size='auto', tensor_parallel_size='auto', save_distributed_checkpoint=True, process_group_backend=None, timeout=datetime.timedelta(seconds=1800))[source]

Bases: ParallelStrategy

Enables user-defined parallelism applied to a model.

Warning

This is an experimental feature.

Currently supports up to 2D parallelism, specifically the combination of Fully Sharded Data-Parallel 2 (FSDP2) with Tensor Parallelism (DTensor). These PyTorch APIs are themselves still experimental. Requires PyTorch 2.4 or newer. A minimal usage sketch follows the parameter list below.

Parameters:
  • parallelize_fn (Callable[[TypeVar(TModel, bound=Module), DeviceMesh], TypeVar(TModel, bound=Module)]) – A function that applies parallelisms to a module. The strategy will provide the model and device mesh as input.

  • data_parallel_size (Union[Literal['auto'], int]) – The number of devices within a data-parallel group. Defaults to "auto", which sets this size to the number of nodes in the cluster.

  • tensor_parallel_size (Union[Literal['auto'], int]) – The number of devices within a tensor-parallel group. Defaults to "auto", which sets this size to the number of GPUs in a single node.

  • save_distributed_checkpoint (bool) – If True, each rank saves its shard of weights and optimizer states to a file. The checkpoint is a folder with as many files as the world size. If False, the full weights and optimizer states get assembled on rank 0 and saved to a single file.
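
A minimal usage sketch is shown below. The FeedForward module, its attribute names w1/w2, and the mesh dimension name "tensor_parallel" are illustrative assumptions (the strategy builds a 2D device mesh whose dimensions are conventionally named "data_parallel" and "tensor_parallel"; verify the names against your installed version). Only ModelParallelStrategy, Fabric, and the PyTorch tensor-parallel APIs are taken from their respective libraries.

    import torch
    import torch.nn as nn
    from lightning.fabric import Fabric
    from lightning.fabric.strategies import ModelParallelStrategy
    from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

    class FeedForward(nn.Module):
        def __init__(self):
            super().__init__()
            self.w1 = nn.Linear(512, 2048)
            self.w2 = nn.Linear(2048, 512)

        def forward(self, x):
            return self.w2(torch.relu(self.w1(x)))

    def parallelize(model, device_mesh):
        # Shard w1 column-wise and w2 row-wise over the tensor-parallel mesh dimension.
        tp_mesh = device_mesh["tensor_parallel"]
        parallelize_module(model, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
        return model

    strategy = ModelParallelStrategy(parallelize_fn=parallelize)
    fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy)
    fabric.launch()

    with fabric.init_module():
        model = FeedForward()
    model = fabric.setup(model)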

_configure_launcher()[source]

Attach the launcher used by this strategy.

Return type:

None

all_reduce(tensor, group=None, reduce_op='mean')[source]

Reduces the given tensor (e.g. across GPUs/processes).

Parameters:
  • tensor (Tensor) – the tensor to sync and reduce

  • group (Optional[Any]) – the process group to reduce

  • reduce_op (Union[ReduceOp, str, None]) – the reduction operation. Defaults to 'mean'. Can also be a string ('sum') or a ReduceOp instance.

Return type:

Tensor
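
A short sketch, continuing the launched Fabric run from the example above (Fabric.all_reduce delegates to the strategy's implementation):

    import torch

    # Average a scalar metric across all processes; reduce_op defaults to "mean".
    local_loss = torch.tensor(0.25, device=fabric.device)
    mean_loss = fabric.all_reduce(local_loss, reduce_op="mean")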

barrier(*args, **kwargs)[source]

Synchronizes all processes, blocking each process until the whole group enters this function.

Parameters:

name – an optional name to pass into barrier.

Return type:

None
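
For example, continuing the sketch above (download_dataset is a hypothetical helper):

    # Let rank 0 prepare data while all other ranks wait at the barrier.
    if fabric.global_rank == 0:
        download_dataset()
    fabric.barrier()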

broadcast(obj, src=0)[source]

Broadcasts an object to all processes.

Parameters:
  • obj (TypeVar(TBroadcast)) – the object to broadcast

  • src (int) – source rank

Return type:

TypeVar(TBroadcast)
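
For example, continuing the sketch above:

    # Decide a value on rank 0 and share it with every process.
    seed = 1234 if fabric.global_rank == 0 else None
    seed = fabric.broadcast(seed, src=0)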

load_checkpoint(path, state=None, strict=True)[source]

Load the contents from a checkpoint and restore the state of the given objects.

Return type:

dict[str, Any]
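
A sketch, continuing the example above (model and optimizer are assumed to have been set up with Fabric already; the checkpoint path is hypothetical). Fabric.load delegates to this strategy method:

    # Restore model and optimizer in place; entries not consumed are returned.
    state = {"model": model, "optimizer": optimizer, "step": 0}
    remainder = fabric.load("checkpoints/step_1000", state)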

module_init_context(empty_init=None)[source]

A context manager wrapping the model instantiation.

Here, the strategy can control how the parameters of the model get created (device, dtype) and/or apply other patches to the model.

Parameters:

empty_init (Optional[bool]) – Whether to initialize the model with empty weights (uninitialized memory). If None, the strategy will decide. Some strategies may not support all options.

Return type:

AbstractContextManager
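
In a Fabric run this context is typically entered through fabric.init_module(), for example (FeedForward is the hypothetical module from the sketch above):

    # Create the model with uninitialized (empty) weights; parameters are
    # materialized and sharded later when the strategy sets up the module.
    with fabric.init_module(empty_init=True):
        model = FeedForward()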

module_to_device(module)[source]

Moves the model to the correct device.

Return type:

None

save_checkpoint(path, state, storage_options=None, filter=None)[source]

Save model, optimizer, and other state to a checkpoint on disk.

If distributed checkpointing is enabled (the default), the checkpoint gets saved as a directory containing one file per process, with the model and optimizer shards stored across those files. Additionally, a metadata file meta.pt is created with the rest of the user’s state (saved from rank 0 only). If distributed checkpointing is disabled (save_distributed_checkpoint=False), the checkpoint is written to a single file containing the weights, optimizer state, and other metadata.

Return type:

None
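
A sketch, continuing the example above (model and optimizer are assumed to have been set up with Fabric; the path is hypothetical). Fabric.save delegates to this strategy method:

    # With save_distributed_checkpoint=True (the default), this writes a checkpoint
    # directory with one shard file per rank plus a meta.pt for the remaining state.
    state = {"model": model, "optimizer": optimizer, "step": 100}
    fabric.save("checkpoints/step_1000", state)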

setup_environment()[source]

Set up any processes or distributed connections.

This must be called by the framework at the beginning of every process, before any distributed communication takes place.

Return type:

None

setup_module(module)[source]

Performs setup for the model, e.g., by wrapping it in another class.

Return type:

Module

property distributed_sampler_kwargs: dict[str, Any]

Arguments for the DistributedSampler.

If this property is not defined, or it returns None, then the DistributedSampler will not be used.

property root_device: device

Returns the root device.