Source code for pytorch_lightning.strategies.horovod
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import ExitStack
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.distributed import group as dist_group
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_only
if _HOROVOD_AVAILABLE:
import horovod.torch as hvd
[docs]class HorovodStrategy(ParallelStrategy):
"""Plugin for Horovod distributed training integration."""
strategy_name = "horovod"
def __init__(
self,
accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
parallel_devices: Optional[List[torch.device]] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
):
super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
cluster_environment=None,
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
)
rank_zero_only.rank = self.global_rank
self._exit_stack: Optional[ExitStack] = None
@property
def global_rank(self) -> int:
return hvd.rank()
@property
def local_rank(self) -> int:
return hvd.local_rank()
@property
def world_size(self) -> int:
return hvd.size()
@property
def root_device(self) -> torch.device:
assert isinstance(self.parallel_devices, list)
return self.parallel_devices[self.local_rank]
@property
def distributed_sampler_kwargs(self) -> Dict[str, Any]:
distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
return distributed_sampler_kwargs
@property
def handles_gradient_accumulation(self) -> bool:
"""Whether the plugin handles gradient accumulation internally."""
return True
[docs] def setup(self, trainer: "pl.Trainer") -> None:
self.model_to_device()
super().setup(trainer)
self._exit_stack = ExitStack()
self._exit_stack.__enter__()
if not trainer.training:
# no need to setup optimizers
return
def _unpack_lightning_optimizer(opt: Optimizer) -> Optimizer:
return opt._optimizer if isinstance(opt, LightningOptimizer) else opt
optimizers = self.optimizers
optimizers = [_unpack_lightning_optimizer(opt) for opt in optimizers]
# Horovod: scale the learning rate by the number of workers to account for
# increased total batch size
for optimizer in optimizers:
for param_group in optimizer.param_groups:
param_group["lr"] *= self.world_size
# Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR
lr_scheduler_configs = self.lr_scheduler_configs
for config in lr_scheduler_configs:
scheduler = config.scheduler
if hasattr(scheduler, "base_lrs"):
scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs] # type: ignore[union-attr]
assert self.lightning_module is not None
# Horovod: broadcast parameters & optimizer state to ensure consistent initialization
hvd.broadcast_parameters(self.lightning_module.state_dict(), root_rank=0)
for optimizer in optimizers:
hvd.broadcast_optimizer_state(optimizer, root_rank=0)
accumulation_scheduler = trainer.accumulation_scheduler
if accumulation_scheduler.epochs != [0]:
raise MisconfigurationException(
"Horovod currently does not support different `accumulate_grad_batches` at different epochs."
)
self.optimizers = self._wrap_optimizers(optimizers, trainer.accumulate_grad_batches)
for optimizer in self.optimizers:
# Synchronization will be performed explicitly following backward()
self._exit_stack.enter_context(optimizer.skip_synchronize())
[docs] def barrier(self, *args: Any, **kwargs: Any) -> None:
if distributed_available():
self.join()
[docs] def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
obj = hvd.broadcast_object(obj, src)
return obj
[docs] def model_to_device(self) -> None:
if self.root_device.type == "cuda":
# this can potentially be removed after #8312. Not done due to lack of horovod testing
torch.cuda.set_device(self.root_device)
assert self.model is not None
self.model.to(self.root_device)
def join(self) -> None:
if self.root_device.type == "cuda":
hvd.join(self.local_rank)
else:
hvd.join()
[docs] def reduce(
self,
tensor: Union[Any, Tensor],
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = "mean",
) -> Union[Any, Tensor]:
"""Reduces a tensor from several distributed processes to one aggregated tensor.
Args:
tensor: the tensor to sync and reduce
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
Can also be a string 'sum' to calculate the sum during reduction.
Return:
reduced value, except when the input was not a tensor the output remains is unchanged
"""
if group is not None:
raise ValueError("Horovod does not support allreduce using a subcommunicator at this time. Unset `group`.")
if reduce_op in (None, "avg", "mean"):
reduce_op = hvd.Average
elif reduce_op in ("sum", ReduceOp.SUM):
reduce_op = hvd.Sum
else:
raise ValueError(f"unrecognized `reduce_op`: {reduce_op}")
# sync all processes before reduction
self.join()
return hvd.allreduce(tensor, op=reduce_op)
[docs] def all_gather(self, result: Tensor, group: Optional[Any] = dist_group.WORLD, sync_grads: bool = False) -> Tensor:
if group is not None and group != dist_group.WORLD:
raise ValueError("Horovod does not support allgather using a subcommunicator at this time. Unset `group`.")
if len(result.shape) == 0:
# Convert scalars to single dimension tensors
result = result.reshape(1)
# sync and gather all
self.join()
return hvd.allgather(result)
[docs] def post_backward(self, closure_loss: Tensor) -> None:
# synchronize all horovod optimizers.
for optimizer in self.optimizers:
optimizer.synchronize()
def _wrap_optimizers(
self, optimizers: List[Optimizer], accumulate_grad_batches: int
) -> List["hvd.DistributedOptimizer"]:
"""Wraps optimizers to perform gradient aggregation via allreduce."""
assert self.lightning_module is not None
return [
hvd.DistributedOptimizer(
opt,
backward_passes_per_step=accumulate_grad_batches,
named_parameters=self._filter_named_parameters(self.lightning_module, opt),
)
if "horovod" not in str(opt.__class__)
else opt
for opt in optimizers
]
@staticmethod
def _filter_named_parameters(model: nn.Module, optimizer: Optimizer) -> List[Tuple[str, nn.Parameter]]:
opt_params = {p for group in optimizer.param_groups for p in group.get("params", [])}
return [(name, p) for name, p in model.named_parameters() if p in opt_params]
@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register(
cls.strategy_name,
cls,
description=f"{cls.__class__.__name__}",
)
[docs] def teardown(self) -> None:
# teardown may be called before `_exit_stack` is set
if self._exit_stack:
self._exit_stack.__exit__(None, None, None)
self._exit_stack = None
# Make sure all workers have finished training before returning to the user
self.join()
super().teardown()