# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
from collections.abc import Generator, Mapping
from contextlib import contextmanager, nullcontext
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Optional, Union

import torch
from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
from torch import Tensor
from torch.optim import Optimizer
from typing_extensions import override

import lightning.pytorch as pl
from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
from lightning.fabric.strategies.model_parallel import (
_distributed_checkpoint_save,
_is_sharded_checkpoint,
_load_checkpoint,
_setup_device_mesh,
)
from lightning.fabric.utilities.distributed import (
_distributed_is_initialized,
_get_default_process_group_backend_for_device,
_init_dist_connection,
_sync_ddp_if_available,
)
from lightning.fabric.utilities.distributed import group as _group
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.init import _materialize_distributed_module
from lightning.fabric.utilities.load import _METADATA_FILENAME
from lightning.fabric.utilities.optimizer import _optimizers_to_device
from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import _PATH, ReduceOp
from lightning.pytorch.core.optimizer import LightningOptimizer
from lightning.pytorch.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from lightning.pytorch.strategies.parallel import ParallelStrategy
from lightning.pytorch.strategies.strategy import TBroadcast
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_only

if TYPE_CHECKING:
from torch.distributed.device_mesh import DeviceMesh


class ModelParallelStrategy(ParallelStrategy):
"""Enables user-defined parallelism applied to a model.
.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
Currently supports up to 2D parallelism. Specifically, it supports the combination of
Fully Sharded Data-Parallel 2 (FSDP2) with Tensor Parallelism (DTensor). These PyTorch APIs are currently still
experimental in PyTorch (see https://pytorch.org/docs/stable/distributed.tensor.parallel.html).
Requires PyTorch 2.4 or newer.
Arguments:
data_parallel_size: The number of devices within a data-parallel group. Defaults to ``"auto"``, which
sets this size to the number of nodes in the cluster.
tensor_parallel_size: The number of devices within a tensor-parallel group. Defaults to ``"auto"``, which
sets this size to the number of GPUs in a single node.
save_distributed_checkpoint: If ``True``, each rank saves its shard of weights and optimizer states to a file.
The checkpoint is a folder with as many files as the world size.
If ``False``, the full weights and optimizer states get assembled on rank 0 and saved to a single file.
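
    Example (a minimal sketch; the module, device count, and the exact parallelization applied inside
    ``configure_model()`` are illustrative and depend on your model)::

        import lightning.pytorch as pl
        from lightning.pytorch.strategies import ModelParallelStrategy

        class LitModel(pl.LightningModule):
            def configure_model(self):
                # Shard/parallelize the model here using ``self.device_mesh``, e.g. with
                # ``torch.distributed._composable.fsdp.fully_shard`` (FSDP2) and/or
                # ``torch.distributed.tensor.parallel.parallelize_module`` (TP).
                ...

        trainer = pl.Trainer(
            accelerator="gpu",
            devices=4,
            strategy=ModelParallelStrategy(data_parallel_size=2, tensor_parallel_size=2),
        )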
"""

    def __init__(
self,
data_parallel_size: Union[Literal["auto"], int] = "auto",
tensor_parallel_size: Union[Literal["auto"], int] = "auto",
save_distributed_checkpoint: bool = True,
process_group_backend: Optional[str] = None,
timeout: Optional[timedelta] = default_pg_timeout,
) -> None:
super().__init__()
if not _TORCH_GREATER_EQUAL_2_4:
raise ImportError(f"{type(self).__name__} requires PyTorch 2.4 or higher.")
self._data_parallel_size = data_parallel_size
self._tensor_parallel_size = tensor_parallel_size
self._save_distributed_checkpoint = save_distributed_checkpoint
self._process_group_backend: Optional[str] = process_group_backend
self._timeout: Optional[timedelta] = timeout
self._device_mesh: Optional[DeviceMesh] = None
self.num_nodes = 1

    @property
def device_mesh(self) -> "DeviceMesh":
if self._device_mesh is None:
raise RuntimeError("Accessing the device mesh before processes have initialized is not allowed.")
return self._device_mesh

    @property
@override
def root_device(self) -> torch.device:
assert self.parallel_devices is not None
return self.parallel_devices[self.local_rank]

    @property
def num_processes(self) -> int:
return len(self.parallel_devices) if self.parallel_devices is not None else 0

    @property
@override
def distributed_sampler_kwargs(self) -> dict[str, Any]:
assert self.device_mesh is not None
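        # Shard data only across the data-parallel dimension; ranks within the same
        # tensor-parallel group receive identical batches.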
data_parallel_mesh = self.device_mesh["data_parallel"]
return {"num_replicas": data_parallel_mesh.size(), "rank": data_parallel_mesh.get_local_rank()}

    @property
def process_group_backend(self) -> Optional[str]:
return self._process_group_backend

    @property
@override
def restore_checkpoint_after_setup(self) -> bool:
return True

    @property
@override
def lightning_restore_optimizer(self) -> bool:
return False

    @override
def _configure_launcher(self) -> None:
assert self.cluster_environment is not None
if not self.cluster_environment.creates_processes_externally:
self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)

    @override
def setup_environment(self) -> None:
super().setup_environment()
self._setup_distributed()
if self._data_parallel_size == "auto":
self._data_parallel_size = self.num_nodes
if self._tensor_parallel_size == "auto":
self._tensor_parallel_size = self.num_processes
self._device_mesh = _setup_device_mesh(
self._data_parallel_size, self._tensor_parallel_size, self.world_size, self.root_device
)
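        # The resulting mesh has two named dimensions, "data_parallel" and "tensor_parallel",
        # addressable e.g. as ``device_mesh["data_parallel"]``.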
# Users can access device mesh in `LightningModule.configure_model()`
assert self.lightning_module is not None
self.lightning_module._device_mesh = self._device_mesh

    @override
def setup(self, trainer: "pl.Trainer") -> None:
from torch.distributed.fsdp import FullyShardedDataParallel

        assert self.model is not None
assert self.accelerator is not None
self.accelerator.setup(trainer)
if not is_overridden("configure_model", self.lightning_module):
raise TypeError(
f"When using the {type(self).__name__}, you are required to override the `configure_model()` hook in"
f" the LightningModule and apply parallelization there."
)
if any(isinstance(mod, FullyShardedDataParallel) for mod in self.model.modules()):
raise TypeError(
"Found modules that are wrapped with `torch.distributed.fsdp.FullyShardedDataParallel`."
f" The `{self.__class__.__name__}` only supports the new FSDP2 APIs in PyTorch >= 2.4."
)
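        # Parameters created on the meta device in ``tensor_init_context`` get materialized here,
        # after ``configure_model()`` has applied the parallelization.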
_materialize_distributed_module(self.model, self.root_device)
self.model = self.precision_plugin.convert_module(self.model)
        self.model_to_device()  # move any remaining layers still left on the CPU
self.barrier()
if trainer.state.fn == TrainerFn.FITTING:
self.setup_optimizers(trainer)
self.setup_precision_plugin()
if trainer.state.fn == TrainerFn.FITTING:
_optimizers_to_device(self.optimizers, self.root_device)

    @override
def setup_optimizers(self, trainer: "pl.Trainer") -> None:
# If we're setting up for evaluation after fitting, we need to discard the optimizers
# since we're rewrapping the model, otherwise optimizer param references are no longer valid
# and subsequent checkpoint saving can fail
self._reset_optimizers_and_schedulers()
return super().setup_optimizers(trainer)

    @override
def model_to_device(self) -> None:
assert self.model is not None
self.model.to(self.root_device)

    @contextmanager
@override
def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]:
        # Materialization happens in `setup()`
empty_init_context = torch.device("meta") if empty_init else nullcontext()
with empty_init_context, self.precision_plugin.tensor_init_context():
yield

    @override
def barrier(self, name: Optional[str] = None) -> None:
if not _distributed_is_initialized():
return
if torch.distributed.get_backend() == "nccl":
torch.distributed.barrier(device_ids=self._determine_device_ids())
else:
torch.distributed.barrier()

    @override
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
if not _distributed_is_initialized():
return obj
obj = [obj]
torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
return obj[0]

    @override
def reduce(
self,
tensor: Union[Tensor, Any],
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = "mean",
) -> Tensor:
if isinstance(tensor, Tensor):
return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
return tensor

    def _determine_device_ids(self) -> list[int]:
return [self.root_device.index]

    @override
def teardown(self) -> None:
assert self.cluster_environment is not None
assert self.accelerator is not None
self.cluster_environment.teardown()
self.precision_plugin.teardown()
self.accelerator.teardown()

    @override
def lightning_module_state_dict(self) -> dict[str, Any]:
"""Collects the state dict of the model.
Only returns a non-empty state dict on rank 0 if ``save_distributed_checkpoint=False``.
"""
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict

        state_dict_options = StateDictOptions(full_state_dict=(not self._save_distributed_checkpoint), cpu_offload=True)
assert self.model is not None
return get_model_state_dict(self.model, options=state_dict_options)

    @override
def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None:
# Override to do nothing, the strategy already loaded the states in `load_checkpoint()`
pass

    @override
def optimizer_state(self, optimizer: Optimizer) -> dict[str, Any]:
"""Collects the state of the given optimizer.
Only returns a non-empty state dict on rank 0 if ``save_distributed_checkpoint=False``.
"""
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import OptimStateKeyType

        state_dict_options = StateDictOptions(full_state_dict=(not self._save_distributed_checkpoint), cpu_offload=True)
if isinstance(optimizer, LightningOptimizer):
optimizer = optimizer._optimizer
assert self.model is not None
state_dict = get_optimizer_state_dict(self.model, optimizer, options=state_dict_options)
if not self._save_distributed_checkpoint and self.global_rank == 0:
# Store the optimizer state dict in standard format
state_dict = FSDP.rekey_optim_state_dict(state_dict, OptimStateKeyType.PARAM_ID, self.model)
return state_dict

    @override
def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
# Override to do nothing, the strategy already loaded the states in `load_checkpoint()`
pass

    @override
def save_checkpoint(
self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
) -> None:
if storage_options is not None:
raise TypeError(
f"`{type(self).__name__}.save_checkpoint(..., storage_options=...)` is not supported because"
f" `{type(self).__name__}` does not use the `CheckpointIO`."
)
# broadcast the path from rank 0 to ensure all the checkpoints are saved to a common path
path = Path(self.broadcast(filepath))
if path.is_dir() and not self._save_distributed_checkpoint and not _is_sharded_checkpoint(path):
raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}")
if self._save_distributed_checkpoint:
if path.is_file():
path.unlink()
path.mkdir(parents=True, exist_ok=True)
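            # Split model and optimizer states out of the checkpoint: they are saved as a sharded
            # (distributed) checkpoint, while the remaining metadata is written by rank 0 below.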
converted_state = {"state_dict": checkpoint.pop("state_dict")}
converted_state.update({
f"optimizer_{idx}": optim_state
for idx, optim_state in enumerate(checkpoint.pop("optimizer_states", []))
})
_distributed_checkpoint_save(converted_state, path)
if self.global_rank == 0:
torch.save(checkpoint, path / _METADATA_FILENAME)
else:
if _is_sharded_checkpoint(path):
shutil.rmtree(path)
return super().save_checkpoint(checkpoint=checkpoint, filepath=path)

    @override
def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]:
# broadcast the path from rank 0 to ensure all the states are loaded from a common path
path = Path(self.broadcast(checkpoint_path))
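        # Pass the live module and optimizer objects so the distributed checkpoint loader can
        # restore their states in place.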
state = {
"state_dict": self.model,
**{f"optimizer_{idx}": optimizer for idx, optimizer in enumerate(self.optimizers)},
}
assert self.lightning_module is not None
return _load_checkpoint(
path=path,
state=state,
strict=self.lightning_module.strict_loading,
optimizer_states_from_list=True,
)

    def _setup_distributed(self) -> None:
super().setup_environment()
reset_seed()
self.set_world_ranks()
self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)

    def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

    def set_world_ranks(self) -> None:
if self.cluster_environment is not None:
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
# `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
# additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank