Source code for pytorch_lightning.strategies.sharded

# Copyright The PyTorch Lightning team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Dict, Generator, List, Optional, Tuple, Union

from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_only

    from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
    from fairscale.optim import OSS

    from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel, unwrap_lightning_module_sharded
    OSS = ShardedDataParallel = object

class DDPShardedStrategy(DDPStrategy):
    """Optimizer and gradient sharded training provided by FairScale."""

    strategy_name = "ddp_sharded"
    _REDUCE_BUFFER_SIZE_DEFAULT: int = 2**23  # 8M

    def setup(self, trainer: "pl.Trainer") -> None:
        """Prepare the strategy for the given trainer.

        Sets up the accelerator, moves the model to its device, sets up the precision plugin
        and, only when fitting, applies layer sync (if configured) and wraps the model and
        optimizers via :meth:`configure_ddp`.

        Args:
            trainer: the trainer this strategy is attached to; its ``state.fn`` decides
                whether the sharded wrapping is performed.
        """
        # share ddp pids to all processes
        self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
        if self._should_run_deadlock_detection():
            self._share_information_to_prevent_deadlock()

        self.accelerator.setup(trainer)

        # move the model to the correct device
        self.model_to_device()

        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
        trainer_fn = trainer.state.fn
        is_fitting = trainer_fn == TrainerFn.FITTING
        if is_fitting and self._layer_sync:
            self.model = self._layer_sync.apply(self.model)

        self.setup_precision_plugin()

        if is_fitting:
            self.configure_ddp()

    def configure_ddp(self) -> None:
        """Wrap the model and its optimizers with the sharded components.

        Fills in the default DDP kwargs, creates the optimizers, wraps model and optimizers
        and finally moves the optimizer state to the root device.
        """
        self._set_ddp_kwargs()
        self.setup_optimizers(self.model.trainer)
        self.model, self.optimizers = self._setup_model_and_optimizers(
            model=LightningShardedDataParallel(self.model),
            optimizers=self.optimizers,
        )
        optimizers_to_device(self.optimizers, self.root_device)

    def _set_ddp_kwargs(self) -> None:
        """Set ``reduce_buffer_size`` in the DDP kwargs unless it was already provided."""
        if "reduce_buffer_size" not in self._ddp_kwargs:
            # For multi-node training, enabling bucketing will improve performance.
            self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0

    def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
        """Wraps the model and optimizers with fairscale components.

        Args:
            model: the module to wrap.
            optimizers: the optimizers to wrap; they are passed to the wrapped model as its
                ``sharded_optimizer``.

        Return:
            The model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
            and a list of optimizer wrapped in :class:`~fairscale.optim.OSS`.
        """
        optimizers = self._wrap_optimizers(optimizers)
        model = ShardedDataParallel(model, sharded_optimizer=optimizers, **self._ddp_kwargs)
        return model, optimizers

    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
        """Return the optimizers re-created as ``OSS`` optimizers, or unchanged when a model is set
        and the trainer is not fitting."""
        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
            return optimizers

        return self._reinit_optimizers_with_oss(optimizers)

    def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
        """Replace, in place, every optimizer that is not already an ``OSS`` with an ``OSS`` equivalent.

        Args:
            optimizers: list of optimizers; ``LightningOptimizer`` entries are unwrapped first.
                The list is mutated and also returned.

        Return:
            The same list object, with non-``OSS`` entries replaced by ``OSS`` instances.
        """
        for idx, optimizer in enumerate(optimizers):
            if isinstance(optimizer, LightningOptimizer):
                optimizer = optimizer._optimizer
            if not isinstance(optimizer, OSS):
                optim_class = type(optimizer)
                zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
                if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
                    is_fp16 = self.precision_plugin.precision in (PrecisionType.MIXED, PrecisionType.HALF)
                    # For multi-node training, compressing the model shards in fp16 before broadcasting
                    # improves performance. When using PyTorch AMP, it will not degrade
                    # the model performance.
                    zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1
                optimizers[idx] = zero_optimizer
                # drop the local reference to the replaced optimizer
                del optimizer
        return optimizers

    def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
        """Consolidate the optimizer state and return its state dict.

        Args:
            optimizer: the optimizer to read; a ``LightningOptimizer`` is unwrapped first.

        Return:
            The result of :meth:`_optim_state_dict`, which is decorated with ``rank_zero_only``.
        """
        if isinstance(optimizer, LightningOptimizer):
            optimizer = optimizer._optimizer
        optimizer.consolidate_state_dict()
        return self._optim_state_dict(optimizer)

    @rank_zero_only
    def _optim_state_dict(self, optimizer: Optimizer) -> Optional[dict]:
        """
        Retrieves state dict only on rank 0, which contains the entire optimizer state after calling
        :meth:`consolidate_state_dict`.
        """
        return optimizer.state_dict()

    @property
    def lightning_module(self) -> Optional["pl.LightningModule"]:
        """The unwrapped ``LightningModule``, or ``None`` when no model is set.

        Raises:
            MisconfigurationException: if ``fairscale`` is not installed.
        """
        if not _FAIRSCALE_AVAILABLE:  # pragma: no cover
            raise MisconfigurationException(
                "`DDPShardedStrategy` requires `fairscale` to be installed."
                " Install it by running `pip install fairscale`."
            )
        return unwrap_lightning_module_sharded(self.model) if self.model is not None else None

    def pre_backward(self, closure_loss: Tensor) -> None:
        """No-op for this strategy."""
        pass

    @contextmanager
    def block_backward_sync(self) -> Generator:
        """Blocks syncing gradients behaviour on backwards pass.

        This is useful for skipping sync when accumulating gradients, reducing communication overhead.

        Returns:
            context manager with sync behaviour off
        """
        if isinstance(self.model, ShardedDataParallel):
            with self.model.no_sync():
                yield None
        else:
            # model is not wrapped, so there is nothing to block
            yield None

    def post_training_step(self) -> None:
        """No-op for this strategy."""
        pass

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        """Register this strategy under its names in the given registry.

        Args:
            strategy_registry: registry exposing a ``register(name, cls, description=..., **kwargs)`` method.
        """
        strategy_registry.register(
            "ddp_sharded_find_unused_parameters_false",
            cls,
            description="DDP Sharded Strategy with `find_unused_parameters` as False",
            find_unused_parameters=False,
        )
        strategy_registry.register(
            cls.strategy_name,
            cls,
            # ``cls`` is already the class here; ``cls.__class__.__name__`` would yield the
            # metaclass name instead of the strategy class name.
            description=f"{cls.__name__}",
        )

© Copyright (c) 2018-2023, Lightning AI et al.

Built with Sphinx using a theme provided by Read the Docs.