Source code for pytorch_lightning.strategies.fully_sharded_native
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import logging
from typing import Any, Dict, Generator, List, Optional, Type, Union

import torch
from torch import Tensor
from torch.nn import Module

import pytorch_lightning as pl
from lightning_fabric.plugins import CheckpointIO, ClusterEnvironment
from lightning_fabric.strategies.fsdp import (
    _init_cpu_offload,
    _optimizer_has_flat_params,
    _setup_activation_checkpointing,
)
from lightning_fabric.utilities.distributed import (
    _get_default_process_group_backend_for_device,
    _init_dist_connection,
    _sync_ddp_if_available,
)
from lightning_fabric.utilities.distributed import group as _group
from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from lightning_fabric.utilities.optimizer import _optimizers_to_device
from lightning_fabric.utilities.seed import reset_seed
from lightning_fabric.utilities.types import ProcessGroup, ReduceOp
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShardedNativeNativeMixedPrecisionPlugin
from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_13
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.types import STEP_OUTPUT

# True when this torch build ships the torch.distributed package.
_distributed_available = torch.distributed.is_available()
# Native FSDP needs both torch >= 1.12 and a distributed-enabled build.
_fsdp_available = _TORCH_GREATER_EQUAL_1_12 and _distributed_available
if _fsdp_available:
    from torch.distributed.fsdp.fully_sharded_data_parallel import (
        BackwardPrefetch,
        CPUOffload,
        FullyShardedDataParallel,
        MixedPrecision,
    )
    from torch.distributed.fsdp.wrap import enable_wrap
else:
    # Placeholders so the class below can still be defined: its signatures reference these names
    # (e.g. ``Optional[BackwardPrefetch]``, ``-> FullyShardedDataParallel``), and those annotations
    # are evaluated when the methods are defined.
    FullyShardedDataParallel = None  # type: ignore[misc,assignment]
    MixedPrecision = None  # type: ignore[misc,assignment]
    BackwardPrefetch = None  # type: ignore[misc,assignment]
    CPUOffload = None  # type: ignore[misc,assignment]

if _distributed_available:
    # Used by ``DDPFullyShardedNativeStrategy.process_group`` as a lazy fallback.
    from torch.distributed.distributed_c10d import _get_default_group

log = logging.getLogger(__name__)
class DDPFullyShardedNativeStrategy(ParallelStrategy):
    r"""Strategy for Fully Sharded Data Parallel provided by torch.distributed.

    .. warning:: ``DDPFullyShardedNativeStrategy`` is in BETA and subject to change. The interface can
        bring breaking changes and new features with the next release of PyTorch.

    Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model
    size, whilst using efficient communication to reduce overhead. In practice, this means we can remain
    at parity with PyTorch DDP, whilst scaling our model sizes dramatically. The technique is similar
    to ZeRO-Stage 3.

    For more information `check out <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api>`__.

    Defaults have been set and options have been exposed, but may require configuration
    based on your level of memory/speed efficiency. We suggest having a look at
    `this tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__ for more information.

    Arguments:
        process_group_backend: Name of the ``torch.distributed`` backend to use. If ``None``, a default is
            chosen for the root device when the distributed environment is set up.
        cpu_offload: Enable offloading parameters and gradients to CPU to save GPU memory at the cost of speed.
            You can also pass a config: ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently
            implicitly enables gradient offloading to CPU in order for parameters and gradients to be on same device
            to work with the optimizer. This API is subject to change. Default: no offloading
        backward_prefetch:
            This is an experimental feature that is subject to change in the near future.
            It allows users to enable two different backward_prefetch algorithms to help
            backward communication and computation overlapping.
            The pros and cons of each algorithm are explained in the class ``BackwardPrefetch``.
        mixed_precision:
            Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16``
            or BF16 if ``precision=bf16`` unless a config is passed in.
            This is only available in PyTorch 1.12 and later.
        activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation
            checkpointing. This is typically your transformer block (including attention + feed-forward).
            Enabling this can free up a significant amount of memory at the cost of speed since activations in
            these layers need to be recomputed during backpropagation. Requires PyTorch 1.13 or later.
        \**kwargs: Passed to the ``FullyShardedDataParallel`` constructor when the model is wrapped. An
            ``auto_wrap_policy`` entry is dropped if the LightningModule already contains FSDP-wrapped submodules.

    """

    strategy_name = "fsdp_native"
    # Names added by ``register_strategies``; stored on the class, so shared by all instances.
    _registered_strategies: List[str] = []

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.Accelerator"] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        process_group_backend: Optional[str] = None,
        cpu_offload: Union[bool, "CPUOffload", None] = None,
        backward_prefetch: Optional[BackwardPrefetch] = None,
        mixed_precision: Optional[MixedPrecision] = None,
        activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None,
        **kwargs: Any,
    ) -> None:
        if not _TORCH_GREATER_EQUAL_1_12:
            raise MisconfigurationException("`DDPFullyShardedNativeStrategy` is supported from PyTorch v1.12.0 onwards.")

        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )
        # Resolved lazily by the ``process_group`` property.
        self._process_group = None
        self.num_nodes = 1
        self._process_group_backend = process_group_backend
        self.cpu_offload = _init_cpu_offload(cpu_offload)
        self.backward_prefetch = backward_prefetch
        self.mixed_precision = mixed_precision
        self._rank_0_will_call_children_scripts: bool = False

        if activation_checkpointing and not _TORCH_GREATER_EQUAL_1_13:
            raise ValueError("Activation checkpointing requires torch >= 1.13.0. HINT: `pip install -U torch`")
        activation_checkpointing = activation_checkpointing or []
        # Always store a list so ``_setup_model`` can treat single classes and lists uniformly.
        self._activation_checkpointing = (
            [activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing
        )
        self.kwargs = kwargs

    @property
    def root_device(self) -> torch.device:
        """The device assigned to this process: ``parallel_devices[local_rank]``."""
        assert self.parallel_devices is not None
        return self.parallel_devices[self.local_rank]

    @property
    def num_processes(self) -> int:
        """Number of devices on this node, or 0 when no devices have been assigned."""
        return len(self.parallel_devices) if self.parallel_devices is not None else 0

    @property
    def process_group(self) -> Optional[ProcessGroup]:
        """Process group FSDP shards over; falls back to the default group on first access."""
        if self._process_group is None:
            # The strategy should have already initialized the process group in setup_environment()
            self._process_group = _get_default_group()
        return self._process_group

    @property
    def process_group_backend(self) -> Optional[str]:
        """Backend given by the user, or the one resolved in ``setup_environment``."""
        return self._process_group_backend

    @property
    def mixed_precision_config(self) -> Optional[MixedPrecision]:
        """FSDP mixed-precision config to use when wrapping the model.

        An explicit ``mixed_precision`` argument wins. Otherwise the config is taken from the precision
        plugin if it is a ``FullyShardedNativeNativeMixedPrecisionPlugin``. Returns ``None`` in all other cases.
        """
        if self.mixed_precision:
            return self.mixed_precision
        plugin = self.precision_plugin
        if isinstance(plugin, FullyShardedNativeNativeMixedPrecisionPlugin):
            return plugin.mixed_precision_config
        return None

    @property
    def distributed_sampler_kwargs(self) -> Dict:
        """Sampler arguments: one replica per process across all nodes, indexed by global rank."""
        return dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)

    def setup_environment(self) -> None:
        """Reset the seed, assign world ranks and initialize the distributed connection."""
        log.detail(f"{self.__class__.__name__}: setting up distributed...")
        reset_seed()

        # determine which process we are and world size
        self.set_world_ranks()

        # set warning rank
        rank_zero_only.rank = self.global_rank

        self._process_group_backend = self._get_process_group_backend()
        assert self.cluster_environment is not None
        _init_dist_connection(self.cluster_environment, self._process_group_backend)
        super().setup_environment()

    def _get_process_group_backend(self) -> str:
        """Return the user-given backend, or the default backend for the root device."""
        return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

    def set_world_ranks(self) -> None:
        """Push this process's global rank and the world size into the cluster environment."""
        if self.cluster_environment is None:
            return
        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
        rank_zero_only.rank = self.cluster_environment.global_rank()

    def _configure_launcher(self) -> None:
        """Use a subprocess-script launcher unless the cluster environment creates processes itself."""
        assert self.cluster_environment is not None
        if not self.cluster_environment.creates_processes_externally:
            self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
            self._rank_0_will_call_children_scripts = True

    def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
        """Wraps the model into a
        :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
        # If model is already wrapped, we need to avoid sending the `auto_wrap_policy`
        assert self.lightning_module is not None
        if "auto_wrap_policy" in self.kwargs and any(
            isinstance(mod, FullyShardedDataParallel) for mod in self.lightning_module.modules()
        ):
            # NOTE: this permanently removes the key from ``self.kwargs``.
            del self.kwargs["auto_wrap_policy"]

        log.detail(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}")

        wrapped_module = FullyShardedDataParallel(
            module=model,
            process_group=self.process_group,
            cpu_offload=self.cpu_offload,
            backward_prefetch=self.backward_prefetch,
            mixed_precision=self.mixed_precision_config,
            device_id=self.root_device.index,
            **self.kwargs,
        )

        # activation checkpointing needs to be set up after wrapping the model
        if _TORCH_GREATER_EQUAL_1_13 and self._activation_checkpointing:
            _setup_activation_checkpointing(module=wrapped_module, layers=self._activation_checkpointing)

        return wrapped_module

    def setup(self, trainer: "pl.Trainer") -> None:
        """Prepare the accelerator, wrap the model with FSDP (unless the user shards it) and set up optimizers."""
        assert self.accelerator is not None
        self.accelerator.setup(trainer)
        # broadcast rank 0's "launches child scripts" flag to all processes
        self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)

        if trainer.state.fn == TrainerFn.FITTING and self._layer_sync:
            assert self.model is not None
            self.model = self._layer_sync.apply(self.model)

        # we set the device so that optimizers can be created with distributed comms.
        assert self.lightning_module is not None
        self.lightning_module._device = self.root_device

        assert isinstance(self.model, pl.LightningModule)
        self.model = _LightningModuleWrapperBase(self.model)
        if is_overridden("configure_sharded_model", self.lightning_module):
            rank_zero_info(
                "You have overridden `LightningModule.configure_sharded_model` hook. It will assume that all the layers"
                " are already wrapped for sharding and won't wrap the entire model using `FullyShardedDataParallel`."
            )
        else:
            self.model = self._setup_model(self.model)
        self.barrier()

        self.setup_optimizers(trainer)
        _optimizers_to_device(self.optimizers, self.root_device)

        self.setup_precision_plugin()

    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
        """Create optimizers and check that they reference the FSDP-flattened parameters.

        Raises:
            ValueError: If optimizer creation fails with an empty parameter list, or if any optimizer does
                not hold FSDP flat parameters. In the first case the original error is attached as the cause.
        """
        invalid_params_message = (
            "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the"
            " optimizer after setting up the model by referencing `self.trainer.model.parameters()` in the"
            " `configure_optimizers()` hook."
        )
        try:
            super().setup_optimizers(trainer)
        except ValueError as err:
            if "optimizer got an empty parameter list" not in str(err):
                raise
            # Chain the original error so its traceback is not lost.
            raise ValueError(invalid_params_message) from err

        if any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
            raise ValueError(invalid_params_message)

    def reduce(
        self,
        tensor: Union[Tensor, Any],
        group: Optional[Any] = None,
        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
    ) -> Tensor:
        """Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' to calculate the sum during reduction.

        Return:
            reduced value, except when the input was not a tensor, in which case the output remains unchanged
        """
        if isinstance(tensor, Tensor):
            tensor = _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
        return tensor

    def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """Forward the training step through the (FSDP-wrapped) model."""
        # we don't need precision context since casting is done by FSDP
        # read `mixed_precision` docstring here: https://pytorch.org/docs/stable/fsdp.html
        assert self.model is not None
        return self.model(*args, **kwargs)

    def teardown(self) -> None:
        """Revert layer sync (if applied while fitting) and tear down environment, precision and accelerator."""
        rank_zero_info(f"{self.__class__.__name__}: tearing down strategy...")

        pl_module = self.lightning_module
        if (
            pl_module is not None
            # `self.lightning_module._trainer` can be None if teardown gets called on an exception before
            # the trainer gets set on the LightningModule
            and pl_module._trainer is not None
            and pl_module._trainer.state.fn == TrainerFn.FITTING
            and self._layer_sync
        ):
            assert self.model is not None
            self.model = self._layer_sync.revert(self.model)

        assert self.cluster_environment is not None
        assert self.accelerator is not None
        self.cluster_environment.teardown()
        self.precision_plugin.teardown()
        self.accelerator.teardown()

    @classmethod
    def get_registered_strategies(cls) -> List[str]:
        """Names this class has registered in the strategy registry."""
        return cls._registered_strategies

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        """Register the ``fsdp_native`` variants, but only when native FSDP is available."""
        if _fsdp_available:
            strategy_registry.register(
                "fsdp_native",
                cls,
                description="Fully Sharded Data Parallel training from torch.distributed.",
            )
            cls._registered_strategies.append("fsdp_native")

            strategy_registry.register(
                "fsdp_native_full_shard_offload",
                cls,
                description="Native FSDP with Full Sharding and CPU Offloading",
                cpu_offload=True,
            )
            cls._registered_strategies.append("fsdp_native_full_shard_offload")
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Read PyTorch Lightning's Privacy Policy.