Source code for pytorch_lightning.strategies.sharded_spawn
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Dict, Generator, List, Tuple

from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

import pytorch_lightning as pl
from lightning_fabric.utilities.optimizer import _optimizers_to_device
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE, _reinit_optimizers_with_oss
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation

if _FAIRSCALE_AVAILABLE:
    from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
    from fairscale.optim import OSS
else:
    OSS = ShardedDataParallel = object
class DDPSpawnShardedStrategy(DDPSpawnStrategy):
    """Optimizer sharded training provided by FairScale."""

    strategy_name = "ddp_sharded_spawn"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        rank_zero_deprecation(
            "PyTorch Lightning's sharded implementation using FairScale has been deprecated in v1.9.0 and will be"
            " removed in v2.0.0. You can try using the `Trainer(strategy='fsdp_native')` instead."
            " The difference is that native FSDP uses PyTorch's implementation and the current strategy uses"
            " FairScale's implementation (which was upstreamed to PyTorch). After removal, `strategy='fsdp'` will use"
            " the native version by default."
        )
        super().__init__(*args, **kwargs)
    def connect(self, model: "pl.LightningModule") -> None:
        if not _FAIRSCALE_AVAILABLE:  # pragma: no cover
            raise MisconfigurationException(
                "`DDPSpawnShardedStrategy` requires `fairscale` to be installed."
                " Install it by running `pip install fairscale`."
            )
        return super().connect(model)
    def configure_ddp(self) -> None:
        # set up optimizers after the wrapped module has been moved to the device
        assert self.lightning_module is not None
        self.setup_optimizers(self.lightning_module.trainer)
        assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
        self.model, self.optimizers = self._setup_model_and_optimizers(
            model=_LightningModuleWrapperBase(self.model), optimizers=self.optimizers
        )
        _optimizers_to_device(self.optimizers, self.root_device)

    def _setup_model_and_optimizers(
        self, model: Module, optimizers: List[Optimizer]
    ) -> Tuple[Module, List[Optimizer]]:
        """Wraps the model and optimizers with fairscale components.

        Return:
            The model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
            and a list of optimizers wrapped in :class:`~fairscale.optim.OSS`.
        """
        optimizers = self._wrap_optimizers(optimizers)
        model = ShardedDataParallel(model, sharded_optimizer=optimizers, **self._ddp_kwargs)
        return model, optimizers

    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
        assert self.lightning_module
        if self.model is not None and self.lightning_module.trainer.state.fn != TrainerFn.FITTING:
            return optimizers
        optimizers = [o._optimizer if isinstance(o, LightningOptimizer) else o for o in optimizers]
        return _reinit_optimizers_with_oss(optimizers, self.precision_plugin, self.num_nodes)
    @contextmanager
    def block_backward_sync(self) -> Generator:
        """Blocks gradient synchronization during the backward pass.

        This is useful for skipping the sync when accumulating gradients, reducing communication overhead.

        Returns: context manager with sync behaviour off
        """
        if isinstance(self.model, ShardedDataParallel):
            with self.model.no_sync():
                yield None
        else:
            yield None
    def post_training_step(self) -> None:
        pass

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        strategy_registry.register(
            "ddp_sharded_spawn_find_unused_parameters_false",
            cls,
            description="DDP Spawn Sharded Strategy with `find_unused_parameters` as False",
            find_unused_parameters=False,
        )
        strategy_registry.register(
            cls.strategy_name,
            cls,
            description=f"{cls.__class__.__name__}",
        )
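
For context, the strategy above is normally selected by its registered name rather than instantiated directly. The following is a minimal usage sketch, not part of the module above: `MyLightningModule` is a hypothetical placeholder for any LightningModule, and the device counts are illustrative. Note that the strategy is deprecated in v1.9.0 and scheduled for removal in v2.0.0, as the deprecation warning in `__init__` states.

    import pytorch_lightning as pl

    # Hypothetical placeholder; substitute your own LightningModule.
    model = MyLightningModule()

    # Select the FairScale sharded spawn strategy by its registered name
    # ("ddp_sharded_spawn", see `strategy_name` above). A variant with
    # `find_unused_parameters=False` is registered as
    # "ddp_sharded_spawn_find_unused_parameters_false".
    trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp_sharded_spawn")
    trainer.fit(model)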