# Source code for pytorch_lightning.utilities.distributed
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities that can be used with distributed training."""
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from torch.nn.parallel.distributed import DistributedDataParallel
import pytorch_lightning as pl
from pytorch_lightning.utilities.imports import _HPU_AVAILABLE, _TORCH_GREATER_EQUAL_1_9, _TPU_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_debug as new_rank_zero_debug
from pytorch_lightning.utilities.rank_zero import rank_zero_only # noqa: F401
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_info as new_rank_zero_info
from pytorch_lightning.utilities.rank_zero import rank_zero_warn as new_rank_zero_warn
if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm
if not torch.distributed.is_available():
    # torch was built without distributed support: define minimal stand-ins so
    # that module-level references such as ``ReduceOp.SUM`` and ``group.WORLD``
    # (used in annotations and default arguments below) still resolve.

    class ReduceOp:  # type: ignore # (see https://github.com/python/mypy/issues/1153)
        SUM = None

    class group:  # type: ignore
        WORLD = None

else:
    from torch.distributed import group, ReduceOp
log = logging.getLogger(__name__)
def gather_all_tensors(result: torch.Tensor, group: Optional[Any] = None) -> List[torch.Tensor]:
    """Collect a copy of ``result`` from every process of ``group`` on every process.

    Args:
        result: the local tensor to share with the other processes
        group: the process group to gather from. ``None`` selects ``torch.distributed.group.WORLD``

    Return:
        A list with one entry per process of the group, where entry ``i`` is the tensor
        contributed by process ``i``.
    """
    process_group = torch.distributed.group.WORLD if group is None else group

    # the receive buffers are allocated from the contiguous version of the local tensor
    local = result.contiguous()
    num_processes = torch.distributed.get_world_size(process_group)
    buffers = []
    for _ in range(num_processes):
        buffers.append(torch.zeros_like(local))

    # wait for every process of the group, then exchange the tensors
    torch.distributed.barrier(group=process_group)
    torch.distributed.all_gather(buffers, local, process_group)
    return buffers
def distributed_available() -> bool:
    """Report whether a distributed context is active.

    This is the case when ``torch.distributed`` is both available and initialized; otherwise the
    answer is whatever :func:`tpu_distributed` reports.
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return True
    return tpu_distributed()
def sync_ddp_if_available(
    result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
) -> torch.Tensor:
    """Reduce a tensor across worker processes when a distributed context is active.

    Args:
        result: the value to sync and reduce (typically tensor or number)
        group: the process group to reduce over. Defaults to all processes (world)
        reduce_op: the reduction operation, forwarded unchanged to :func:`sync_ddp`.
            Can also be a string of 'avg', 'mean' to calculate the mean during reduction.

    Return:
        The reduced value, or ``result`` itself when :func:`distributed_available` is false.
    """
    if not distributed_available():
        # nothing to synchronise with: hand the input back untouched
        return result
    return sync_ddp(result, group=group, reduce_op=reduce_op)
def sync_ddp(
    result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
) -> torch.Tensor:
    """Function to reduce the tensors from several ddp processes to one main process.

    Args:
        result: the value to sync and reduce (typically tensor or number)
        group: the process group to gather results from. Defaults to all processes (world)
        reduce_op: the reduction operation. Defaults to sum (``None`` is treated as ``ReduceOp.SUM``).
            Can also be a string of 'avg', 'mean' to calculate the mean during reduction, or the
            name of any other ``ReduceOp`` member (case-insensitive).

    Return:
        reduced value

    Raises:
        AttributeError:
            If ``reduce_op`` is a string that is neither 'avg'/'mean' nor the name of a ``ReduceOp`` member
    """
    if group is None:
        group = torch.distributed.group.WORLD

    divide_by_world_size = False
    if reduce_op is None:
        # Honour the documented default instead of forwarding ``op=None`` to ``all_reduce``.
        op = ReduceOp.SUM
    elif isinstance(reduce_op, str):
        if reduce_op.lower() in ("avg", "mean"):
            # the mean is computed as a sum followed by a division by the group size
            op = ReduceOp.SUM
            divide_by_world_size = True
        else:
            op = getattr(ReduceOp, reduce_op.upper())
    else:
        op = reduce_op

    # WA for HPU. HPU doesn't support Long types, forcefully set it to float
    if _HPU_AVAILABLE:
        is_hpu_backend = os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1"
        if is_hpu_backend and result.type() in ("torch.LongTensor", "torch.hpu.LongTensor"):
            new_rank_zero_info("Long tensor unsupported on HPU, casting to float")
            result = result.float()

    # sync all processes before reduction
    torch.distributed.barrier(group=group)
    # ``all_reduce`` writes the reduced values into ``result`` itself
    torch.distributed.all_reduce(result, op=op, group=group, async_op=False)

    if divide_by_world_size:
        result = result / torch.distributed.get_world_size(group)

    return result
class AllGatherGrad(torch.autograd.Function):
    """All-gather with a backward pass, so gradients can flow through the gathered tensor."""

    @staticmethod
    def forward(
        ctx: Any,
        tensor: torch.Tensor,
        group: Optional["torch.distributed.ProcessGroup"] = group.WORLD,
    ) -> torch.Tensor:
        """Gather ``tensor`` from every process of ``group`` and stack the results on a new leading dim.

        Args:
            ctx: autograd context; the group is stored on it for the backward pass
            tensor: the local tensor to gather
            group: the process group to gather from

        Return:
            A tensor whose first dimension has one entry per process of ``group``.
        """
        ctx.group = group
        # One receive buffer per member of the group used for the collective. Sizing by the
        # global world size would be wrong whenever ``group`` is a sub-group.
        group_size = torch.distributed.get_world_size(group=group)
        # same normalisation ``gather_all_tensors`` applies before calling ``all_gather``
        tensor = tensor.contiguous()
        gathered_tensor = [torch.zeros_like(tensor) for _ in range(group_size)]
        torch.distributed.all_gather(gathered_tensor, tensor, group=group)
        return torch.stack(gathered_tensor, dim=0)

    @staticmethod
    def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Sum the incoming gradients over the group and return the slice belonging to this process.

        Return:
            The gradient for ``tensor`` and ``None`` for the non-differentiable ``group`` argument.
        """
        grad_output = torch.cat(grad_output)
        torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)
        # The leading dim is indexed by position inside ``ctx.group``, so use the rank within that
        # group. For the world group this is identical to the global rank.
        return grad_output[torch.distributed.get_rank(group=ctx.group)], None
def all_gather_ddp_if_available(
    tensor: torch.Tensor, group: Optional["torch.distributed.ProcessGroup"] = None, sync_grads: bool = False
) -> torch.Tensor:
    """Function to gather a tensor from several distributed processes.

    Args:
        tensor: tensor of shape (batch, ...)
        group: the process group to gather results from. Defaults to all processes (world)
        sync_grads: flag that allows users to synchronize gradients for all_gather op

    Return:
        A tensor of shape (world_size, batch, ...) when a distributed context is active,
        otherwise ``tensor`` itself, unchanged.
    """
    # Check first: ``torch.distributed.group`` is only exported when torch is built with
    # distributed support, so resolving the default group must not happen on the no-op path.
    if not distributed_available():
        return tensor

    group = group if group is not None else torch.distributed.group.WORLD
    if sync_grads:
        return AllGatherGrad.apply(tensor, group)
    # gradients are not needed: run the gather without recording it for autograd
    with torch.no_grad():
        return AllGatherGrad.apply(tensor, group)
def register_ddp_comm_hook(
    model: DistributedDataParallel,
    ddp_comm_state: Optional[object] = None,
    ddp_comm_hook: Optional[Callable] = None,
    ddp_comm_wrapper: Optional[Callable] = None,
) -> None:
    """Function to register communication hook for DDP model https://pytorch.org/docs/master/ddp_comm_hooks.html.

    Nothing is registered when ``ddp_comm_hook`` is ``None``.

    Args:
        model:
            DDP model
        ddp_comm_state:
            state is passed to the hook and can be used to maintain
            and update any state information that users would like to
            maintain as part of the training process. Examples: error
            feedback in gradient compression, peers to communicate with
            next in GossipGrad etc.
        ddp_comm_hook:
            hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future

            This callable function is called once the bucket is ready. The
            hook can perform whatever processing is needed and return
            a Future indicating completion of any async work (ex: allreduce).
            If the hook doesn't perform any communication, it can also
            just return a completed Future. The Future should hold the
            new value of grad bucket's tensors. Once a bucket is ready,
            c10d reducer would call this hook and use the tensors returned
            by the Future and copy grads to individual parameters.
        ddp_comm_wrapper:
            communication hook wrapper to support a communication hook such
            as FP16 compression as wrapper, which could be combined with
            ddp_comm_hook

    .. warning ::
        DDP communication hook needs pytorch version at least 1.8.0

    .. warning ::
        DDP communication wrapper needs pytorch version at least 1.9.0
        Post-localSGD hook needs pytorch version at least 1.9.0

    Examples:

        >>> from torch.distributed.algorithms.ddp_comm_hooks import ( # doctest: +SKIP
        ...     default_hooks as default,
        ...     powerSGD_hook as powerSGD,
        ...     post_localSGD_hook as post_localSGD,
        ... )
        >>>
        >>> # fp16_compress_hook for compress gradients
        >>> ddp_model = ...
        >>> register_ddp_comm_hook( # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_hook=default.fp16_compress_hook,
        ... )
        >>>
        >>> # powerSGD_hook
        >>> ddp_model = ...
        >>> register_ddp_comm_hook( # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_state=powerSGD.PowerSGDState(
        ...         process_group=None,
        ...         matrix_approximation_rank=1,
        ...         start_powerSGD_iter=5000,
        ...     ),
        ...     ddp_comm_hook=powerSGD.powerSGD_hook,
        ... )
        >>>
        >>> # post_localSGD_hook
        >>> subgroup, _ = torch.distributed.new_subgroups() # doctest: +SKIP
        >>> ddp_model = ...
        >>> register_ddp_comm_hook( # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_state=post_localSGD.PostLocalSGDState(
        ...         process_group=None,
        ...         subgroup=subgroup,
        ...         start_localSGD_iter=1_000,
        ...     ),
        ...     ddp_comm_hook=post_localSGD.post_localSGD_hook,
        ... )
        >>>
        >>> # fp16_compress_wrapper combined with other communication hook
        >>> ddp_model = ...
        >>> register_ddp_comm_hook( # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_state=powerSGD.PowerSGDState(
        ...         process_group=None,
        ...         matrix_approximation_rank=1,
        ...         start_powerSGD_iter=5000,
        ...     ),
        ...     ddp_comm_hook=powerSGD.powerSGD_hook,
        ...     ddp_comm_wrapper=default.fp16_compress_wrapper,
        ... )
    """
    if ddp_comm_hook is None:
        return

    # a separately named, non-optional local keeps the type checker happy
    hook: Callable = ddp_comm_hook

    if ddp_comm_wrapper is not None:
        if _TORCH_GREATER_EQUAL_1_9:
            new_rank_zero_info(
                f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({hook.__qualname__})."
            )
            hook = ddp_comm_wrapper(hook)
        else:
            # the wrapper is skipped, but the bare hook is still registered below
            new_rank_zero_warn(
                "Not applying DDP comm wrapper. To use communication wrapper, please use pytorch>=1.9.0."
            )

    new_rank_zero_debug(f"Registering DDP comm hook: {hook.__qualname__}.")
    model.register_comm_hook(state=ddp_comm_state, hook=hook)
def tpu_distributed() -> bool:
    """Report whether TPU support is available and the XLA runtime reports more than one process."""
    if not _TPU_AVAILABLE:
        # ``xm`` is only imported when TPU support is available
        return False
    return xm.xrt_world_size() > 1
def get_default_process_group_backend_for_device(device: torch.device) -> str:
    """Return the default process group backend name for ``device``.

    Args:
        device: the device the processes run on

    Return:
        ``"nccl"`` for devices of type ``"cuda"``, ``"gloo"`` for everything else.
    """
    if device.type == "cuda":
        return "nccl"
    return "gloo"
def _get_process_group_backend_from_env() -> Optional[str]:
    """Read the process group backend from the deprecated ``PL_TORCH_DISTRIBUTED_BACKEND`` variable.

    A deprecation message is emitted whenever the variable is set.

    Return:
        The value of the environment variable, or ``None`` when it is not set.
    """
    backend = os.environ.get("PL_TORCH_DISTRIBUTED_BACKEND")
    if backend is None:
        return None

    rank_zero_deprecation(
        "Environment variable `PL_TORCH_DISTRIBUTED_BACKEND`"
        " was deprecated in v1.6 and will be removed in v1.8."
        " Specify `process_group_backend` directly on the strategy constructor."
    )
    return backend
def init_dist_connection(
    cluster_environment: "pl.plugins.environments.ClusterEnvironment",
    torch_distributed_backend: str,
    global_rank: Optional[int] = None,
    world_size: Optional[int] = None,
    **kwargs: Any,
) -> None:
    """Set up the distributed connection: export the rendezvous env variables and create the process group.

    Does nothing (apart from a debug log) when ``torch.distributed`` is already initialized.

    Args:
        cluster_environment: ``ClusterEnvironment`` instance providing the main address/port and,
            when not given explicitly, the global rank and world size
        torch_distributed_backend: backend to use (includes `nccl` and `gloo`)
        global_rank: rank of the current process
        world_size: number of processes in the group
        kwargs: kwargs for ``init_process_group``

    Raises:
        RuntimeError:
            If ``torch.distributed`` is not available
    """
    if not torch.distributed.is_available():
        raise RuntimeError("torch.distributed is not available. Cannot initialize distributed process group")

    if torch.distributed.is_initialized():
        log.debug("torch.distributed is already initialized. Exiting early")
        return

    # explicit arguments win; otherwise ask the cluster environment
    if global_rank is None:
        global_rank = cluster_environment.global_rank()
    if world_size is None:
        world_size = cluster_environment.world_size()

    # rendezvous information for ``init_process_group``
    os.environ["MASTER_ADDR"] = cluster_environment.main_address
    os.environ["MASTER_PORT"] = str(cluster_environment.main_port)

    log.info(f"Initializing distributed: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
    torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs)

    # on rank=0 let everyone know training is starting
    separator = "-" * 100
    new_rank_zero_info(
        f"{separator}\n"
        f"distributed_backend={torch_distributed_backend}\n"
        f"All distributed processes registered. Starting with {world_size} processes\n"
        f"{separator}\n"
    )
def _broadcast_object_list(obj: Any, rank: int) -> Any:
    """Broadcast the object held by process ``rank`` and return it on the calling process.

    Args:
        obj: the local object; only the one held by the source process is sent
        rank: rank of the process whose object is broadcast

    Return:
        The first (and only) entry of the list after ``broadcast_object_list`` has run.
    """
    is_source = torch.distributed.get_rank() == rank
    # single-slot container: filled locally on the source, left as a placeholder elsewhere
    payload = [obj] if is_source else [None]
    torch.distributed.broadcast_object_list(payload, src=rank)
    return payload[0]
# TODO: Refactor with the Strategy Collectives once finalized.
def _collect_states_on_rank_zero(state: Dict[str, Any]) -> Dict[int, Any]:
    """Collect the dictionary state of every process into one mapping keyed by rank.

    Args:
        state: Dictionary containing the state of the current process

    Returns:
        states: A dictionary whose keys are process ranks and whose values are the state
            broadcast from that rank. When no distributed context is active, the mapping only
            holds the local state under key ``0``.
    """
    if distributed_available():
        world_size = torch.distributed.get_world_size()
        # each rank takes a turn as the broadcast source
        return {rank: _broadcast_object_list(state, rank) for rank in range(world_size)}
    return {0: state}
def rank_zero_info(*args: Any, **kwargs: Any) -> Any:
    """Deprecated alias: emit a deprecation message, then forward to the ``rank_zero`` module's function.

    All positional and keyword arguments are passed through unchanged and the forwarded
    call's result is returned.
    """
    deprecation_message = (
        "pytorch_lightning.utilities.distributed.rank_zero_info has been deprecated in v1.6"
        " and will be removed in v1.8."
        " Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead."
    )
    rank_zero_deprecation(deprecation_message)
    return new_rank_zero_info(*args, **kwargs)
def rank_zero_debug(*args: Any, **kwargs: Any) -> Any:
    """Deprecated alias: emit a deprecation message, then forward to the ``rank_zero`` module's function.

    All positional and keyword arguments are passed through unchanged and the forwarded
    call's result is returned.
    """
    deprecation_message = (
        "pytorch_lightning.utilities.distributed.rank_zero_debug has been deprecated in v1.6"
        " and will be removed in v1.8."
        " Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead."
    )
    rank_zero_deprecation(deprecation_message)
    return new_rank_zero_debug(*args, **kwargs)