# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from contextlib import contextmanager
from datetime import timedelta
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TYPE_CHECKING, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

from lightning_fabric.accelerators import Accelerator
from lightning_fabric.plugins import CheckpointIO, ClusterEnvironment, Precision
from lightning_fabric.plugins.collectives.torch_collective import default_pg_timeout
from lightning_fabric.plugins.precision.fsdp import FSDPPrecision
from lightning_fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from lightning_fabric.strategies.parallel import ParallelStrategy
from lightning_fabric.strategies.strategy import _BackwardSyncControl, _Sharded, TBroadcast
from lightning_fabric.utilities.distributed import (
    _distributed_available,
    _get_default_process_group_backend_for_device,
    _init_dist_connection,
    _sync_ddp_if_available,
)
from lightning_fabric.utilities.distributed import group as _group
from lightning_fabric.utilities.distributed import ReduceOp
from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_1_13
from lightning_fabric.utilities.rank_zero import rank_zero_only
from lightning_fabric.utilities.seed import reset_seed

if TYPE_CHECKING:
    from torch.distributed.fsdp.fully_sharded_data_parallel import (
        BackwardPrefetch,
        CPUOffload,
        FullyShardedDataParallel,
        MixedPrecision,
    )

_FSDP_ALIASES = ("fsdp", "fsdp_full_shard_offload")

class FSDPStrategy(ParallelStrategy, _Sharded):
    r"""Strategy for Fully Sharded Data Parallel provided by torch.distributed.

    .. warning:: ``FSDPStrategy`` is in BETA and subject to change. The interface can bring breaking changes and new
        features with the next release of PyTorch.

    Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model size
    whilst using efficient communication to reduce overhead. In practice, this means we can remain at parity with
    PyTorch DDP whilst scaling our model sizes dramatically. The technique is similar to ZeRO-Stage 3. For more
    information, check out
    `this blog post <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api>`__.

    Defaults have been set and options have been exposed, but they may require configuration based on your level of
    memory/speed efficiency. We suggest having a look at
    `this tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__ for more information.

    Arguments:
        cpu_offload: Enable offloading parameters and gradients to CPU to save GPU memory at the cost of speed.
            You can also pass a config: ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently
            implicitly enables gradient offloading to CPU in order for parameters and gradients to be on the same
            device to work with the optimizer. This API is subject to change. Default: no offloading
        backward_prefetch: This is an experimental feature that is subject to change in the near future. It allows
            users to enable two different backward prefetching algorithms to help backward communication and
            computation overlapping. The pros and cons of each algorithm are explained in the class
            ``BackwardPrefetch``.
        mixed_precision: Mixed-precision config. By default, Lightning will enable FP16 if ``precision=16`` or BF16
            if ``precision=bf16`` unless a config is passed in. This is only available in PyTorch 1.12 and later.
        activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation
            checkpointing. This is typically your transformer block (including attention + feed-forward). Enabling
            this can free up a significant amount of memory at the cost of speed since activations in these layers
            need to be recomputed during backpropagation.
        \**kwargs: Optional keyword arguments passed to the FSDP context manager which will configure the FSDP class
            when wrapping modules.
    """
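
    # A minimal construction sketch (assumes torch >= 1.12 with FSDP available): the strategy is usually
    # selected by name through the registry (``strategy="fsdp"``), or instantiated directly when custom
    # FSDP options are needed. ``MyTransformerBlock`` is a hypothetical user-defined layer class shown
    # only to illustrate the ``activation_checkpointing`` argument.
    #
    #     from torch.distributed.fsdp import CPUOffload
    #
    #     strategy = FSDPStrategy(
    #         cpu_offload=CPUOffload(offload_params=True),
    #         activation_checkpointing=MyTransformerBlock,
    #     )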
"""def__init__(self,accelerator:Optional[Accelerator]=None,parallel_devices:Optional[List[torch.device]]=None,cluster_environment:Optional[ClusterEnvironment]=None,checkpoint_io:Optional[CheckpointIO]=None,precision:Optional[Precision]=None,process_group_backend:Optional[str]=None,timeout:Optional[timedelta]=default_pg_timeout,cpu_offload:Union[bool,"CPUOffload",None]=None,backward_prefetch:Optional["BackwardPrefetch"]=None,mixed_precision:Optional["MixedPrecision"]=None,activation_checkpointing:Optional[Union[Type[Module],List[Type[Module]]]]=None,**kwargs:Any,)->None:ifnot_TORCH_GREATER_EQUAL_1_12:raiseNotImplementedError("`FSDPStrategy` is supported from PyTorch v1.12.0 onwards.")super().__init__(accelerator=accelerator,parallel_devices=parallel_devices,cluster_environment=cluster_environment,checkpoint_io=checkpoint_io,precision=precision,)self._num_nodes=1self._process_group_backend:Optional[str]=process_group_backendself._timeout:Optional[timedelta]=timeoutself._backward_sync_control=_FSDPBackwardSyncControl()self._ddp_kwargs=kwargsifactivation_checkpointingandnot_TORCH_GREATER_EQUAL_1_13:raiseValueError("Activation checkpointing requires torch >= 1.13.0. HINT: `pip install -U torch`")activation_checkpointing=activation_checkpointingor[]self._activation_checkpointing=([activation_checkpointing]ifnotisinstance(activation_checkpointing,list)elseactivation_checkpointing)self.cpu_offload=_init_cpu_offload(cpu_offload)self.backward_prefetch=backward_prefetchself.mixed_precision=mixed_precision@propertydefroot_device(self)->torch.device:assertself.parallel_devicesisnotNonereturnself.parallel_devices[self.local_rank]@propertydefnum_nodes(self)->int:returnself._num_nodes@num_nodes.setterdefnum_nodes(self,num_nodes:int)->None:self._num_nodes=num_nodes@propertydefnum_processes(self)->int:returnlen(self.parallel_devices)ifself.parallel_devicesisnotNoneelse0@propertydefdistributed_sampler_kwargs(self)->Dict[str,Any]:returndict(num_replicas=(self.num_nodes*self.num_processes),rank=self.global_rank)@propertydefprocess_group_backend(self)->Optional[str]:returnself._process_group_backend@propertydefmixed_precision_config(self)->Optional["MixedPrecision"]:ifself.mixed_precision:returnself.mixed_precisionifisinstance(self.precision,FSDPPrecision):returnself.precision.mixed_precision_configdef_configure_launcher(self)->None:assertself.cluster_environmentisnotNoneifnotself.cluster_environment.creates_processes_externally:self._launcher=_SubprocessScriptLauncher(self.cluster_environment,self.num_processes,self.num_nodes)

    def setup_module_and_optimizers(
        self, module: Module, optimizers: List[Optimizer]
    ) -> Tuple[Module, List[Optimizer]]:
        raise NotImplementedError(
            f"The `{type(self).__name__}` does not support the joint setup of module and optimizer(s)."
            " Please do it in this order: Create the model, call `setup_module`, create the optimizer,"
            " call `setup_optimizer`."
        )

    def setup_module(self, module: Module) -> "FullyShardedDataParallel":
        """Wraps the model into a
        :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
        from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel

        if "auto_wrap_policy" in self._ddp_kwargs and any(
            isinstance(mod, FullyShardedDataParallel) for mod in module.modules()
        ):
            # If the model is already wrapped, we need to avoid sending the `auto_wrap_policy`
            del self._ddp_kwargs["auto_wrap_policy"]
        wrapped_module = FullyShardedDataParallel(
            module=module,
            cpu_offload=self.cpu_offload,
            backward_prefetch=self.backward_prefetch,
            mixed_precision=self.mixed_precision_config,
            device_id=self.root_device.index,
            **self._ddp_kwargs,
        )

        # activation checkpointing needs to be set up after wrapping the model
        if _TORCH_GREATER_EQUAL_1_13 and self._activation_checkpointing:
            _setup_activation_checkpointing(module=wrapped_module, layers=self._activation_checkpointing)

        return wrapped_module
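
    # Sketch of passing FSDP wrapping options through ``**kwargs`` (assuming
    # ``size_based_auto_wrap_policy`` from ``torch.distributed.fsdp.wrap`` in the matching torch
    # release); any such keyword is forwarded verbatim to ``FullyShardedDataParallel`` above:
    #
    #     from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
    #
    #     policy = functools.partial(size_based_auto_wrap_policy, min_num_params=int(1e8))
    #     strategy = FSDPStrategy(auto_wrap_policy=policy)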

    def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
        """Set up an optimizer for a model wrapped with FSDP.

        This setup method does not modify or wrap the optimizer. It only verifies that the optimizer was created
        after the model was wrapped with :meth:`setup_module`, i.e. that it references the flattened parameters.
        """
        from torch.distributed.fsdp import FlatParameter

        num_groups = len(optimizer.param_groups)
        if num_groups > 1:
            raise ValueError(
                "An optimizer used with an FSDP model does not support multiple param groups."
                f" Found {num_groups} parameter groups."
            )
        if any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"]):
            return optimizer

        raise ValueError(
            "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the optimizer"
            " after setting up the model."
        )
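
    # Expected call order for the two setup methods above (a sketch; ``model`` is a placeholder module
    # and the optimizer choice is arbitrary): wrap the module first so the optimizer sees FSDP's
    # flattened parameters.
    #
    #     model = strategy.setup_module(model)                # parameters become FlatParameters
    #     optimizer = torch.optim.AdamW(model.parameters())   # create the optimizer afterwards
    #     optimizer = strategy.setup_optimizer(optimizer)     # passes the FlatParameter check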

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        if not _TORCH_GREATER_EQUAL_1_12 or not torch.distributed.is_available():
            return

        strategy_registry.register(
            "fsdp",
            cls,
            description="Fully Sharded Data Parallel",
        )
        strategy_registry.register(
            "fsdp_full_shard_offload",
            cls,
            description="Fully Sharded Data Parallel and CPU Offloading",
            cpu_offload=True,
        )

    def _setup_distributed(self) -> None:
        reset_seed()
        self._set_world_ranks()
        rank_zero_only.rank = self.global_rank
        self._process_group_backend = self._get_process_group_backend()
        assert self.cluster_environment is not None
        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)

    def _get_process_group_backend(self) -> str:
        return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

    def _set_world_ranks(self) -> None:
        if self.cluster_environment is None:
            return
        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
        rank_zero_only.rank = self.cluster_environment.global_rank()
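
# Note on the registry entries in ``register_strategies`` above: selecting the strategy by the name
# ``"fsdp_full_shard_offload"`` is equivalent to constructing ``FSDPStrategy(cpu_offload=True)``;
# both names are also listed in ``_FSDP_ALIASES``.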

def _setup_activation_checkpointing(module: "FullyShardedDataParallel", layers: List[Type[Module]]) -> None:
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        apply_activation_checkpointing,
        checkpoint_wrapper,
        CheckpointImpl,
    )

    check_fn = lambda submodule: isinstance(submodule, tuple(layers))
    wrapper = functools.partial(
        checkpoint_wrapper,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )
    apply_activation_checkpointing(module, checkpoint_wrapper_fn=wrapper, check_fn=check_fn)


class _FSDPBackwardSyncControl(_BackwardSyncControl):
    @contextmanager
    def no_backward_sync(self, module: Module) -> Generator:
        """Blocks gradient synchronization inside the
        :class:`~torch.distributed.fsdp.FullyShardedDataParallel` wrapper."""
        from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel

        if not isinstance(module, FullyShardedDataParallel):
            raise TypeError(
                "Blocking backward sync is only possible if the module passed to"
                f" `{self.__class__.__name__}.no_backward_sync` is wrapped in `FullyShardedDataParallel`."
                f" Got: {module.__class__.__name__}."
            )
        with module.no_sync():
            yield


def _init_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffload"]]) -> "CPUOffload":
    from torch.distributed.fsdp import CPUOffload

    return cpu_offload if isinstance(cpu_offload, CPUOffload) else CPUOffload(offload_params=bool(cpu_offload))


def _optimizer_has_flat_params(optimizer: Optimizer) -> bool:
    from torch.distributed.fsdp import FlatParameter

    return any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"])
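

# A minimal end-to-end sketch (assumptions: a single node with at least two CUDA devices and the
# ``Fabric`` front-end shipped alongside this module; ``MyModel`` is a hypothetical ``nn.Module``,
# and the ``Fabric`` method names referenced here may differ between releases):
#
#     from lightning_fabric import Fabric
#
#     fabric = Fabric(accelerator="cuda", devices=2, strategy=FSDPStrategy(cpu_offload=True))
#     fabric.launch()
#     model = fabric.setup_module(MyModel())            # wrapped via FSDPStrategy.setup_module
#     optimizer = fabric.setup_optimizers(torch.optim.AdamW(model.parameters()))

if __name__ == "__main__":
    # Tiny smoke-check sketch that needs no process group (assumes a torch build with
    # torch.distributed and FSDP available): constructing the strategy only records configuration,
    # and ``_init_cpu_offload`` normalizes bools and ``None`` into a ``CPUOffload`` config.
    _strategy = FSDPStrategy(cpu_offload=True)
    print(_strategy.cpu_offload)    # CPUOffload(offload_params=True)
    print(_init_cpu_offload(None))  # CPUOffload(offload_params=False)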