# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import shutil
from contextlib import contextmanager, nullcontext
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Literal, Mapping, Optional, Set, Type, Union

import torch
from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

import lightning.pytorch as pl
from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
from lightning.fabric.strategies import _StrategyRegistry
from lightning.fabric.strategies.fsdp import (
    _METADATA_FILENAME,
    _activation_checkpointing_kwargs,
    _auto_wrap_policy_kwargs,
    _get_full_state_dict_context,
    _get_sharded_state_dict_context,
    _has_meta_device_parameters,
    _init_cpu_offload,
    _init_sharding_strategy,
    _is_full_checkpoint,
    _is_sharded_checkpoint,
    _load_raw_module_state,
    _move_torchmetrics_to_device,
    _optimizer_has_flat_params,
    _setup_activation_checkpointing,
)
from lightning.fabric.utilities.distributed import (
    _distributed_is_initialized,
    _get_default_process_group_backend_for_device,
    _init_dist_connection,
    _sync_ddp_if_available,
)
from lightning.fabric.utilities.distributed import group as _group
from lightning.fabric.utilities.imports import (
    _TORCH_GREATER_EQUAL_1_13,
    _TORCH_GREATER_EQUAL_2_0,
    _TORCH_GREATER_EQUAL_2_1,
)
from lightning.fabric.utilities.init import _EmptyInit
from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors
from lightning.fabric.utilities.optimizer import _optimizers_to_device
from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import _PATH, ReduceOp
from lightning.pytorch.core.optimizer import LightningOptimizer
from lightning.pytorch.plugins.precision import Precision
from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision
from lightning.pytorch.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from lightning.pytorch.strategies.parallel import ParallelStrategy
from lightning.pytorch.strategies.strategy import TBroadcast
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn

# The aliases below are only evaluated by static type checkers (guarded by `TYPE_CHECKING`), so the FSDP
# classes they reference are never imported at runtime by this block.
if TYPE_CHECKING:
    from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision, ShardingStrategy

    # Accepted forms of a wrapping policy: a set of layer classes, a policy callable, or (torch >= 2.0)
    # a `ModuleWrapPolicy` instance.
    if _TORCH_GREATER_EQUAL_2_0:
        from torch.distributed.fsdp.wrap import ModuleWrapPolicy

        _POLICY = Union[Set[Type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy]
    else:
        _POLICY = Union[Set[Type[Module]], Callable[[Module, bool, int], bool]]  # type: ignore[misc]

    # A sharding strategy given either as the torch enum or as the name of one of its members.
    _SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]]

# Module-level logger for debug messages emitted by this strategy.
log = logging.getLogger(__name__)
class FSDPStrategy(ParallelStrategy):
    r"""Strategy for Fully Sharded Data Parallel provided by torch.distributed.

    .. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.

    Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model
    size, whilst using efficient communication to reduce overhead. In practice, this means we can remain
    at parity with PyTorch DDP, whilst scaling our model sizes dramatically. The technique is similar
    to ZeRO-Stage 3.

    For more information check out
    `this blogpost <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api>`__.

    Defaults have been set and options have been exposed, but may require configuration
    based on your level of memory/speed efficiency. We suggest having a look at
    `this tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__ for more information.

    Arguments:
        cpu_offload: See ``cpu_offload`` parameter in :class:`torch.distributed.fsdp.FullyShardedDataParallel`.
        mixed_precision: See ``mixed_precision`` parameter in :class:`torch.distributed.fsdp.FullyShardedDataParallel`.
        auto_wrap_policy: Same as ``auto_wrap_policy`` parameter in
            :class:`torch.distributed.fsdp.FullyShardedDataParallel`. For convenience, this also accepts a set of the
            layer classes to wrap.
        activation_checkpointing: Deprecated. Use ``activation_checkpointing_policy``.
        activation_checkpointing_policy: Same as ``auto_wrap_policy`` parameter in
            :class:`torch.distributed.fsdp.FullyShardedDataParallel` but used when selecting the modules for which you
            want to enable activation checkpointing. Enabling this can free up a significant amount of memory at the
            cost of speed since activations in these layers need to be recomputed during backpropagation. For
            convenience, this also accepts a set of the layer classes to wrap.
        sharding_strategy: Select whether to shard model parameters, gradients, optimizer states, or a combination of
            them. Available values are:

            - ``"FULL_SHARD"``: Shards model parameters, gradients, and optimizer states (default).
            - ``"SHARD_GRAD_OP"``: Shards gradients and optimizer states only. Model parameters get replicated.
            - ``"NO_SHARD"``: No sharding (identical to regular DDP).
            - ``"HYBRID_SHARD"``: Shards model parameters, gradients, and optimizer states within a single machine, but
              replicates across machines.

            Also accepts a :class:`torch.distributed.fsdp.ShardingStrategy` enum value.

        state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint.

            - ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file.
            - ``"sharded"``: Each rank saves its shard of weights and optimizer states to a file. The checkpoint is
              a folder with as many files as the world size.

        \**kwargs: See available parameters in :class:`torch.distributed.fsdp.FullyShardedDataParallel`.

    """

    strategy_name = "fsdp"
    _registered_strategies: List[str] = []

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.Accelerator"] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None,
        process_group_backend: Optional[str] = None,
        timeout: Optional[timedelta] = default_pg_timeout,
        cpu_offload: Union[bool, "CPUOffload", None] = None,
        mixed_precision: Optional["MixedPrecision"] = None,
        auto_wrap_policy: Optional["_POLICY"] = None,
        activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None,
        activation_checkpointing_policy: Optional["_POLICY"] = None,
        sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD",
        state_dict_type: Literal["full", "sharded"] = "full",
        **kwargs: Any,
    ) -> None:
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )
        self.num_nodes = 1
        self._process_group_backend = process_group_backend
        self._timeout: Optional[timedelta] = timeout
        self.cpu_offload = _init_cpu_offload(cpu_offload)
        self.mixed_precision = mixed_precision
        # Remaining keyword arguments are forwarded to `FullyShardedDataParallel` in `_setup_model`.
        self.kwargs = _auto_wrap_policy_kwargs(auto_wrap_policy, kwargs)
        self.sharding_strategy = _init_sharding_strategy(sharding_strategy, self.kwargs)

        if _TORCH_GREATER_EQUAL_2_0:
            # Avoids the need for user to reference params in `configure_optimizers` via
            # `self.trainer.model.parameters()` and enables support for multiple parameter groups.
            self.kwargs.setdefault("use_orig_params", True)

        self._activation_checkpointing_kwargs = _activation_checkpointing_kwargs(
            activation_checkpointing, activation_checkpointing_policy
        )

        if state_dict_type == "sharded" and not _TORCH_GREATER_EQUAL_2_0:
            raise NotImplementedError(
                "Saving checkpoints with `FSDPStrategy(state_dict_type='sharded')` is not supported in PyTorch < 2.0."
                " Please upgrade `torch`."
            )
        self._state_dict_type = state_dict_type

    @property
    def root_device(self) -> torch.device:
        """The device of this process: the entry of ``parallel_devices`` at ``local_rank``."""
        assert self.parallel_devices is not None
        return self.parallel_devices[self.local_rank]

    @property
    def num_processes(self) -> int:
        """Number of entries in ``parallel_devices``, or 0 when it has not been set."""
        return len(self.parallel_devices) if self.parallel_devices is not None else 0

    @property
    def process_group_backend(self) -> Optional[str]:
        """The process group backend, as given by the user or resolved in :meth:`setup_environment`."""
        return self._process_group_backend

    @property
    def mixed_precision_config(self) -> Optional["MixedPrecision"]:
        """The explicit ``mixed_precision`` argument if set, otherwise the precision plugin's configuration."""
        if self.mixed_precision:
            return self.mixed_precision
        plugin = self.precision_plugin
        if isinstance(plugin, FSDPPrecision):
            return plugin.mixed_precision_config
        return None

    @property  # type: ignore[override]
    def precision_plugin(self) -> FSDPPrecision:
        """The configured :class:`FSDPPrecision` plugin, or a new ``FSDPPrecision("32-true")`` if none is set."""
        plugin = self._precision_plugin
        if plugin is not None:
            assert isinstance(plugin, FSDPPrecision)
            return plugin
        return FSDPPrecision("32-true")

    @precision_plugin.setter
    def precision_plugin(self, precision_plugin: Optional[FSDPPrecision]) -> None:
        if precision_plugin is not None and not isinstance(precision_plugin, FSDPPrecision):
            raise TypeError(
                f"The FSDP strategy can only work with the `FSDPPrecision` plugin, found {precision_plugin}"
            )
        self._precision_plugin = precision_plugin

    @property
    def distributed_sampler_kwargs(self) -> Dict:
        """Replica count (``num_nodes * num_processes``) and this process' global rank for a distributed sampler."""
        return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank}

    @property
    def restore_checkpoint_after_setup(self) -> bool:
        """Restore checkpoints after :meth:`setup`: :meth:`load_checkpoint` needs the wrapped model and optimizers."""
        return True

    @property
    def lightning_restore_optimizer(self) -> bool:
        """Optimizer states are restored inside :meth:`load_checkpoint`, not by Lightning's generic logic."""
        return False

    def setup_environment(self) -> None:
        """Reset the seed, assign world ranks and initialize the distributed process group."""
        log.debug(f"{self.__class__.__name__}: setting up distributed...")
        reset_seed()

        # determine which process we are and world size
        self.set_world_ranks()

        self._process_group_backend = self._get_process_group_backend()
        assert self.cluster_environment is not None
        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)

        super().setup_environment()

    def _get_process_group_backend(self) -> str:
        # An explicitly configured backend wins; otherwise pick the default for the root device type.
        return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

    def set_world_ranks(self) -> None:
        """Propagate global rank and world size to the cluster environment and the rank-zero utilities."""
        if self.cluster_environment is not None:
            self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
        # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
        # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
        rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank

    def _configure_launcher(self) -> None:
        # Only spawn processes ourselves when the cluster environment does not create them externally.
        assert self.cluster_environment is not None
        if not self.cluster_environment.creates_processes_externally:
            self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)

    def _setup_model(self, model: Module) -> Module:
        """Wraps the model into a :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel`
        module."""
        from torch.distributed.fsdp import FullyShardedDataParallel

        if any(isinstance(mod, FullyShardedDataParallel) for mod in model.modules()):
            if _has_meta_device_parameters(model):
                rank_zero_warn(
                    "The model is already wrapped in `FSDP` but there are still parameters on the meta device."
                )
            if "auto_wrap_policy" in self.kwargs:
                # The user has wrapped their submodules manually, don't apply the auto wrap policy.
                rank_zero_warn(
                    "A FSDP `auto_wrap_policy` is set, but the model is already wrapped. The policy will be ignored."
                )
                del self.kwargs["auto_wrap_policy"]
        else:
            log.debug(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}")
            model = FullyShardedDataParallel(
                module=model,
                cpu_offload=self.cpu_offload,
                mixed_precision=self.mixed_precision_config,
                sharding_strategy=self.sharding_strategy,
                device_id=self.root_device.index,
                **self.kwargs,
            )

        _move_torchmetrics_to_device(model, self.root_device)

        # activation checkpointing needs to be set up after wrapping the model
        if _TORCH_GREATER_EQUAL_1_13:
            _setup_activation_checkpointing(model, self._activation_checkpointing_kwargs)

        return model

    def setup(self, trainer: "pl.Trainer") -> None:
        """Set up the accelerator, wrap the model with FSDP, then create the optimizers and precision plugin."""
        assert self.accelerator is not None
        assert self.model is not None
        self.accelerator.setup(trainer)

        if trainer.state.fn == TrainerFn.FITTING and self._layer_sync:
            self.model = self._layer_sync.apply(self.model)

        # we set the device so that optimizers can be created with distributed comms.
        assert self.lightning_module is not None
        self.lightning_module._device = self.root_device

        if is_overridden("configure_sharded_model", self.lightning_module):
            # legacy: we don't skip setup with the `configure_model` alternative
            rank_zero_info(
                "You have overridden `LightningModule.configure_sharded_model` hook. It will assume that all the layers"
                " are already wrapped for sharding and won't wrap the entire model using `FullyShardedDataParallel`."
            )
        else:
            self.model = self._setup_model(self.model)
        self.barrier()

        self.setup_optimizers(trainer)
        _optimizers_to_device(self.optimizers, self.root_device)

        self.setup_precision_plugin()

    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
        """Create the optimizers for the (already wrapped) model.

        Raises:
            ValueError: If ``use_orig_params`` is not enabled and an optimizer does not reference the flattened
                FSDP parameters (or was created with an empty parameter list).
        """
        # If we're setting up for evaluation after fitting, we need to discard the optimizers
        # since we're rewrapping the model, otherwise optimizer param references are no longer valid
        # and subsequent checkpoint saving can fail
        self._reset_optimizers_and_schedulers()

        if self.kwargs.get("use_orig_params"):
            return super().setup_optimizers(trainer)

        invalid_params_error = False
        try:
            # In PyTorch < 2.0, or if `use_orig_params=False` the user needs to access
            # `self.trainer.model.parameters()` in configure_optimizers()
            super().setup_optimizers(trainer)
        except ValueError as ex:
            if "optimizer got an empty parameter list" not in str(ex):
                raise
            invalid_params_error = True

        if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
            # We avoid this limitation in PyTorch >= 2.0 by setting `use_orig_params=True`
            raise ValueError(
                "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the"
                " optimizer after setting up the model by referencing `self.trainer.model.parameters()` in the"
                " `configure_optimizers()` hook."
            )
        return None

    @contextmanager
    def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]:
        """Context for instantiating the model's tensors.

        With PyTorch >= 2.1 and ``empty_init`` truthy, tensors are created on the meta device. With PyTorch >= 1.13
        an ``_EmptyInit`` context is used (enabled only when ``empty_init`` is truthy). Older versions use a no-op
        context. The precision plugin's own ``tensor_init_context`` is always entered as well.
        """
        empty_init_context: Union[torch.device, _EmptyInit, nullcontext]
        if _TORCH_GREATER_EQUAL_2_1 and empty_init:
            # Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is:
            # 1) materialize module 2) call `reset_parameters()` 3) shard the module.
            # These operations are applied to each submodule 'bottom up' in the module hierarchy.
            empty_init_context = torch.device("meta")
        elif _TORCH_GREATER_EQUAL_1_13:
            empty_init_context = _EmptyInit(enabled=bool(empty_init))
        else:
            empty_init_context = nullcontext()
        with empty_init_context, self.precision_plugin.tensor_init_context():
            yield

    def reduce(
        self,
        tensor: Union[Tensor, Any],
        group: Optional[Any] = None,
        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
    ) -> Tensor:
        """Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' to calculate the sum during reduction.

        Return:
            reduced value, except when the input was not a tensor the output remains unchanged

        """
        if isinstance(tensor, Tensor):
            return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
        return tensor

    def teardown(self) -> None:
        """Revert layer syncing applied during fitting and tear down the environment, precision plugin and
        accelerator."""
        log.debug(f"{self.__class__.__name__}: tearing down strategy...")

        pl_module = self.lightning_module
        if (
            pl_module is not None
            # `self.lightning_module._trainer` can be None if teardown gets called on an exception before
            # the trainer gets set on the LightningModule
            and pl_module._trainer is not None
            and pl_module._trainer.state.fn == TrainerFn.FITTING
            and self._layer_sync
        ):
            assert self.model is not None
            self.model = self._layer_sync.revert(self.model)

        assert self.cluster_environment is not None
        assert self.accelerator is not None
        self.cluster_environment.teardown()
        self.precision_plugin.teardown()
        self.accelerator.teardown()

    @classmethod
    def get_registered_strategies(cls) -> List[str]:
        """Names under which this class has been registered by :meth:`register_strategies`."""
        return cls._registered_strategies

    @classmethod
    def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
        """Register the ``"fsdp"`` and ``"fsdp_cpu_offload"`` aliases (only if ``torch.distributed`` is available)."""
        if not torch.distributed.is_available():
            return
        strategy_registry.register(
            "fsdp",
            cls,
            description="Fully Sharded Data Parallel (FSDP) training",
        )
        cls._registered_strategies.append("fsdp")

        strategy_registry.register(
            "fsdp_cpu_offload",
            cls,
            description="Fully Sharded Data Parallel (FSDP) training with Full Sharding and CPU Offloading",
            cpu_offload=True,
        )
        cls._registered_strategies.append("fsdp_cpu_offload")

    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        """No-op: the model weights are already loaded in :meth:`load_checkpoint`."""
        # Override to do nothing, FSDP already loaded the states in `load_checkpoint()`
        pass

    def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
        """Return the optimizer state in the format selected by ``state_dict_type``.

        For ``"full"``, rank 0 re-keys the consolidated state by parameter id (the standard PyTorch format).
        With PyTorch < 2.0 a warning is emitted and an empty dict is returned.

        Raises:
            ValueError: If ``state_dict_type`` is neither ``"full"`` nor ``"sharded"``.
        """
        if not _TORCH_GREATER_EQUAL_2_0:
            rank_zero_warn("FSDP in Lightning with PyTorch < 2.0 does not support saving the optimizer state.")
            return {}

        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.fsdp import OptimStateKeyType

        if isinstance(optimizer, LightningOptimizer):
            optimizer = optimizer._optimizer

        assert self.model is not None
        if self._state_dict_type == "sharded":
            with _get_sharded_state_dict_context(self.model):
                return FSDP.optim_state_dict(self.model, optimizer)

        elif self._state_dict_type == "full":
            with _get_full_state_dict_context(self.model, world_size=self.world_size):
                state_dict = FSDP.optim_state_dict(self.model, optimizer)
                if self.global_rank == 0:
                    # Store the optimizer state dict in standard format
                    state_dict = FSDP.rekey_optim_state_dict(state_dict, OptimStateKeyType.PARAM_ID, self.model)
                return state_dict

        raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}")

    def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        """No-op: the optimizer states are already loaded in :meth:`load_checkpoint`."""
        # Override to do nothing, the FSDP already loaded the states in `load_checkpoint()`
        pass

    def save_checkpoint(
        self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
    ) -> None:
        """Save ``checkpoint`` to ``filepath`` as a sharded folder or a single full file.

        Raises:
            TypeError: If ``storage_options`` is passed (the ``CheckpointIO`` is not used).
            IsADirectoryError: For ``"full"`` when ``filepath`` is an existing directory that is not a sharded
                checkpoint.
            ValueError: If ``state_dict_type`` is neither ``"full"`` nor ``"sharded"``.
        """
        if storage_options is not None:
            raise TypeError(
                "`FSDPStrategy.save_checkpoint(..., storage_options=...)` is not supported because"
                " `FSDPStrategy` does not use the `CheckpointIO`."
            )

        # broadcast the path from rank 0 so that all ranks write to a common location
        path = Path(self.broadcast(filepath))
        if path.is_dir() and self._state_dict_type == "full" and not _is_sharded_checkpoint(path):
            raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}")

        if self._state_dict_type == "sharded":
            from torch.distributed.checkpoint import FileSystemWriter, save_state_dict

            if path.is_file():
                path.unlink()
            path.mkdir(parents=True, exist_ok=True)

            converted_state = {"model": checkpoint.pop("state_dict")}
            # Checkpoints without optimizer states (e.g. weights-only saves) have no "optimizer_states" key.
            converted_state.update(
                {
                    f"optimizer_{idx}": optim_state
                    for idx, optim_state in enumerate(checkpoint.pop("optimizer_states", []))
                }
            )

            # FSDP's FileSystemWriter streams the tensors to disk to minimize memory peaks
            writer = FileSystemWriter(path=path, single_file_per_rank=True)
            save_state_dict(converted_state, writer)

            if self.global_rank == 0:
                # Everything that is not model/optimizer state is stored once, next to the shards
                torch.save(checkpoint, path / _METADATA_FILENAME)
        elif self._state_dict_type == "full":
            if _is_sharded_checkpoint(path):
                shutil.rmtree(path)
            return super().save_checkpoint(checkpoint=checkpoint, filepath=path)
        else:
            raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}")

    def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
        """Load model (and, while fitting, optimizer) states from a sharded folder or a full checkpoint file.

        Returns:
            The remaining checkpoint contents: the metadata file for sharded checkpoints, or the loaded
            checkpoint dict (without ``"state_dict"``) for full checkpoints.

        Raises:
            RuntimeError: If a full checkpoint holds a different number of optimizer states than configured.
            ValueError: If ``checkpoint_path`` is neither a sharded nor a full checkpoint.
        """
        # broadcast the path from rank 0 to ensure all the states are loaded from a common path
        path = Path(self.broadcast(checkpoint_path))

        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

        assert self.model is not None
        assert self.lightning_module is not None

        if _is_sharded_checkpoint(path):
            from torch.distributed.checkpoint import FileSystemReader, load_state_dict
            from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict

            state_dict_ctx = _get_sharded_state_dict_context(self.model)
            reader = FileSystemReader(path=path)

            with state_dict_ctx:
                module_state = {"model": self.model.state_dict()}
                load_state_dict(module_state, reader)
                self.model.load_state_dict(module_state["model"])

                if self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
                    # the optimizer states must be loaded separately
                    for idx, optim in enumerate(self.optimizers):
                        optim_key = f"optimizer_{idx}"
                        optim_state = load_sharded_optimizer_state_dict(
                            model_state_dict=module_state["model"],
                            optimizer_key=optim_key,
                            storage_reader=reader,
                        )
                        flattened_osd = FSDP.optim_state_dict_to_load(
                            optim_state_dict=optim_state[optim_key],
                            model=self.model,
                            optim=optim,
                        )
                        optim.load_state_dict(flattened_osd)

            # Load metadata (anything not a module or optimizer)
            metadata = torch.load(path / _METADATA_FILENAME)
            return metadata

        if _is_full_checkpoint(path):
            checkpoint = _lazy_load(path) if _TORCH_GREATER_EQUAL_2_0 else torch.load(path, map_location="cpu")
            _load_raw_module_state(checkpoint.pop("state_dict"), module=self.model, world_size=self.world_size)

            if _TORCH_GREATER_EQUAL_2_0:
                # Materialize lazy tensors if there are any left in the checkpoint
                # The `torch.Optimizer.load_state_dict` method can't load lazy tensors because of deepcopy pickle issues
                checkpoint = _materialize_tensors(checkpoint)

            from torch.distributed.fsdp import OptimStateKeyType

            optimizer_states = checkpoint.get("optimizer_states")
            if optimizer_states is None or self.lightning_module.trainer.state.fn != TrainerFn.FITTING:
                # If the optimizer states are not present, we don't need to do anything (backward compatibility)
                return checkpoint
            if not _TORCH_GREATER_EQUAL_2_0:
                rank_zero_warn("FSDP in Lightning with PyTorch < 2.0 does not support loading the optimizer state.")
                return checkpoint
            if len(self.optimizers) != len(optimizer_states):
                raise RuntimeError(
                    f"You have configured {len(self.optimizers)} optimizers but the checkpoint contains"
                    f" {len(optimizer_states)} optimizers to load. Please resume training with the same number"
                    " of optimizers or edit the checkpoint manually to remove states."
                )

            # rank0_only should be false because we need to load the optimizer state on all ranks
            with _get_full_state_dict_context(self.model, world_size=self.world_size, rank0_only=False):
                for optimizer, opt_state in zip(self.optimizers, optimizer_states):
                    # The per-parameter "state" is empty if the checkpoint was saved before the first optimizer
                    # step; indexing its first key would raise IndexError, so peek with a default instead.
                    first_state_key = next(iter(opt_state["state"]), None)
                    if isinstance(first_state_key, int):
                        # Handling the case where the optimizer state is saved from a normal optimizer
                        opt_state = FSDP.rekey_optim_state_dict(opt_state, OptimStateKeyType.PARAM_NAME, self.model)
                    opt_state = FSDP.optim_state_dict_to_load(
                        optim_state_dict=opt_state,
                        model=self.model,
                        optim=optimizer,
                    )
                    optimizer.load_state_dict(opt_state)

            return checkpoint

        raise ValueError(
            f"The path {str(path)!r} does not point to a valid checkpoint. Make sure the path points to either a"
            " directory with FSDP checkpoint shards, or a single file with a full checkpoint."
        )
To analyze traffic and optimize your experience, we serve cookies on this
site. By clicking or navigating, you agree to allow our usage of cookies.
Read PyTorch Lightning's
Privacy Policy.
You are viewing an outdated version of PyTorch Lightning Docs