Source code for pytorch_lightning.strategies.sharded
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Dict, Generator, List, Optional, Tuple, Union

import torch
from torch.nn import Module
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_only

if _FAIRSCALE_AVAILABLE:
    from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
    from fairscale.optim import OSS

    from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel, unwrap_lightning_module_sharded
class DDPShardedStrategy(DDPStrategy):
    """Optimizer and gradient sharded training provided by FairScale."""

    strategy_name = "ddp_sharded"
    _REDUCE_BUFFER_SIZE_DEFAULT: int = 2**23  # 8M

    def configure_ddp(self) -> None:
        trainer = self.lightning_module.trainer
        if "reduce_buffer_size" not in self._ddp_kwargs:
            # For multi-node training, enabling bucketing will improve performance.
            self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0

        self.model, self.optimizers = self._setup_model_and_optimizers(
            model=LightningShardedDataParallel(self.model),
            optimizers=trainer.optimizers,
        )

    def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
        """Wraps the model and optimizers with fairscale components.

        Return:
            The model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
            and a list of optimizers wrapped in :class:`~fairscale.optim.OSS`.
        """
        optimizers = self._wrap_optimizers(optimizers)
        model = ShardedDataParallel(model, sharded_optimizer=optimizers, **self._ddp_kwargs)
        return model, optimizers

    def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
        for x, optimizer in enumerate(optimizers):
            if isinstance(optimizer, LightningOptimizer):
                optimizer = optimizer._optimizer
            if not isinstance(optimizer, OSS):
                optim_class = type(optimizer)
                zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
                if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
                    is_fp16 = self.precision_plugin.precision in (PrecisionType.MIXED, PrecisionType.HALF)
                    # For multi-node training, compressing the model shards in fp16 before broadcasting
                    # improves performance. When using PyTorch AMP, it will not degrade
                    # the model performance.
                    zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1
                optimizers[x] = zero_optimizer
                del optimizer
        return optimizers

    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
            return optimizers
        return self._reinit_optimizers_with_oss(optimizers)
    @rank_zero_only
    def _optim_state_dict(self, optimizer):
        """Retrieves the state dict only on rank 0, which contains the entire optimizer state after calling
        :meth:`consolidate_state_dict`."""
        return optimizer.state_dict()

    @property
    def lightning_module(self) -> Optional["pl.LightningModule"]:
        if not _FAIRSCALE_AVAILABLE:  # pragma: no cover
            raise MisconfigurationException(
                "`DDPShardedStrategy` requires `fairscale` to be installed."
                " Install it by running `pip install fairscale`."
            )
        return unwrap_lightning_module_sharded(self.model) if self.model is not None else None
    @contextmanager
    def block_backward_sync(self) -> Generator:
        """Blocks syncing gradients behaviour on backwards pass.

        This is useful for skipping sync when accumulating gradients, reducing communication overhead.

        Returns:
            context manager with sync behaviour off
        """
        if isinstance(self.model, ShardedDataParallel):
            with self.model.no_sync():
                yield None
        else:
            yield None
    def post_training_step(self):
        pass

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        strategy_registry.register(
            "ddp_sharded_find_unused_parameters_false",
            cls,
            description="DDP Sharded Strategy with `find_unused_parameters` as False",
            find_unused_parameters=False,
        )
        strategy_registry.register(
            cls.strategy_name,
            cls,
            description=cls.__name__,
        )
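A minimal usage sketch of selecting this strategy from the Trainer, assuming FairScale is installed and the Trainer API of the 1.6 series (`strategy`, `accelerator`, `devices`); `MyModel` is a hypothetical LightningModule defined elsewhere:

    import pytorch_lightning as pl

    # "ddp_sharded" resolves to DDPShardedStrategy via the strategy registry above;
    # "ddp_sharded_find_unused_parameters_false" is the registered variant with
    # find_unused_parameters=False.
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=2,
        strategy="ddp_sharded",
        precision=16,  # with multi-node runs, this also enables the fp16 broadcast path in OSS
    )
    trainer.fit(MyModel())  # MyModel: hypothetical LightningModule

During `fit`, `configure_ddp` wraps the module in FairScale's `ShardedDataParallel` and re-instantiates the optimizers as `OSS` (ZeRO-style sharded) optimizers; outside of fitting, the optimizers are left untouched.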