Source code for pytorch_lightning.strategies.colossalai
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Callable, Dict, List, Mapping, Optional, TYPE_CHECKING, Union

import torch
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
from typing_extensions import OrderedDict

import pytorch_lightning as pl
from lightning_fabric.accelerators.cuda import _patch_cuda_is_available
from lightning_fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning_fabric.utilities.distributed import ReduceOp
from pytorch_lightning.accelerators.cuda import CUDAAccelerator
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import ColossalAIPrecisionPlugin
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import STEP_OUTPUT

_COLOSSALAI_AVAILABLE = RequirementCache("colossalai")
if TYPE_CHECKING and _COLOSSALAI_AVAILABLE:
    with _patch_cuda_is_available():
        from colossalai.utils.model.colo_init_context import ColoInitContext
else:
    ColoInitContext = Any
class ColossalAIStrategy(DDPStrategy):
    """ColossalAI strategy. It only supports a single optimizer, which must be
    :class:`colossalai.nn.optimizer.CPUAdam` or :class:`colossalai.nn.optimizer.HybridAdam` now. Your model must
    be created in the function ``LightningModule.configure_sharded_model()``. Thus, you should override this
    function. More details can be found in the example below.

    It requires the CUDA accelerator and 16-bit precision: configure the ``Trainer`` with ``accelerator="gpu"``
    and ``precision=16`` as shown in the example below, and make sure CUDA is available.

    Example::

        class GLUETransformer(LightningModule):
            ...
            def configure_sharded_model(self) -> None:
                self.model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

        trainer = Trainer(..., accelerator="gpu", precision=16, strategy="colossalai")

    Args:
        use_chunk: Whether to use chunk-based memory management.
            It can speed up training, but slightly more memory will be used.
        chunk_size: The size of a chunk.
            It will be ignored when ``use_chunk=False``.
            If it's None, the best chunk size will be searched out based on ``chunk_search_range``,
            ``chunk_search_n_grids`` and ``min_chunk_size``.
        enable_distributed_storage: Whether to store the model in a distributed manner.
            It reduces memory usage from 1 to 1/N, but it may slow down training.
        placement_policy: It can be "cpu", "cuda" or "auto".

            * If it's "cpu", parameters, gradients and optimizer states will be offloaded to CPU,
              which means min CUDA memory will be used.
            * If it's "cuda", they won't be offloaded, which means max CUDA memory will be used. It's the fastest.
            * If it's "auto", they are moved dynamically based on CPU and CUDA memory usage.
              It will utilize heterogeneous memory space evenly.
              Note that the "auto" policy can only work well when no other processes use CUDA during your training.

        force_outputs_fp32: Whether to cast outputs to fp32.
        gpu_margin_mem_ratio: The ratio of GPU remaining memory (after the first forward-backward)
            which will be used by the optimizer.
            This argument will be ignored when ``placement_policy`` is not "auto".
        chunk_search_range: The range of chunk size to search.
            The actual search range will be from
            ``max(min_chunk_size, max_param_size)`` to ``max(min_chunk_size, max_param_size) + chunk_search_range``.
        chunk_search_n_grids: The number of intervals in the search range.
        min_chunk_size: The minimum size for a chunk in bytes.
        initial_scale: The initial dynamic loss scale value.
        min_scale: The minimum dynamic loss scaling value.
        growth_factor: The multiplication factor for increasing loss scale.
        backoff_factor: The multiplication factor for decreasing loss scale.
        growth_interval: The number of steps to increase loss scale when no overflow occurs.
        hysteresis: The number of overflows before decreasing loss scale.
        max_scale: The maximum dynamic loss scaling value.

    .. _colossalai.nn.optimizer.CPUAdam:
        https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.nn.optimizer.cpu_adam.html
    .. _colossalai.nn.optimizer.HybridAdam:
        https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.nn.optimizer.hybrid_adam.html
    """

    strategy_name = "colossalai"

    def __init__(
        self,
        use_chunk: bool = True,
        chunk_size: Optional[int] = None,
        enable_distributed_storage: bool = True,
        placement_policy: str = "auto",
        force_outputs_fp32: bool = False,
        gpu_margin_mem_ratio: float = 0.0,
        chunk_search_range: int = 64 * 1024**2,
        chunk_search_n_grids: int = 4096,
        min_chunk_size: int = 32 * 1024**2,
        initial_scale: float = 2**16,
        min_scale: float = 1,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        hysteresis: int = 2,
        max_scale: float = 2**32,
        accelerator: Optional["pl.accelerators.Accelerator"] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[ColossalAIPrecisionPlugin] = None,
    ) -> None:
        if not _COLOSSALAI_AVAILABLE:
            raise ModuleNotFoundError(
                "To use the `ColossalAIStrategy`, please install `colossalai` first. "
                "Download `colossalai` by consulting `https://colossalai.org/download`."
            )
        with _patch_cuda_is_available():
            from colossalai.logging import get_dist_logger

        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )
        self.use_chunk = use_chunk
        self.chunk_size = chunk_size
        self.enable_distributed_storage = enable_distributed_storage
        self.placement_policy = placement_policy
        self.force_outputs_fp32 = force_outputs_fp32
        self.gpu_margin_mem_ratio = gpu_margin_mem_ratio
        self.chunk_size_search_kwargs = {
            "search_range": chunk_search_range,
            "n_grids": chunk_search_n_grids,
            "min_chunk_size": min_chunk_size,
        }
        self.amp_kwargs = {
            "initial_scale": initial_scale,
            "min_scale": min_scale,
            "growth_factor": growth_factor,
            "backoff_factor": backoff_factor,
            "growth_interval": growth_interval,
            "hysteresis": hysteresis,
            "max_scale": max_scale,
        }
        self._num_nodes = 1
        self._logger = get_dist_logger()

    @property
    def root_device(self) -> torch.device:
        with _patch_cuda_is_available():
            from colossalai.utils import get_current_device

        if self.parallel_devices is not None:
            return self.parallel_devices[self.local_rank]
        return get_current_device()

    @property
    def handles_gradient_accumulation(self) -> bool:
        """Whether the plugin handles gradient accumulation internally."""
        return True

    @property
    def restore_checkpoint_after_setup(self) -> bool:
        """Override to delay restoring from checkpoint till after pre-dispatch."""
        return True

    def setup_distributed(self) -> None:
        with _patch_cuda_is_available():
            from colossalai.context import ParallelMode
            from colossalai.core import global_context as gpc
            from colossalai.logging import disable_existing_loggers

        assert self.cluster_environment is not None
        self.set_world_ranks()
        if not gpc.is_initialized(ParallelMode.GLOBAL):
            disable_existing_loggers()
            gpc.init_global_dist(
                rank=self.global_rank,
                world_size=self.world_size,
                backend="nccl",
                host=self.cluster_environment.main_address,
                port=self.cluster_environment.main_port,
            )
            gpc.set_device(self.local_rank)
    def model_sharded_context(self) -> "ColoInitContext":
        """Provide a hook to create modules in a distributed-aware context. This is useful when we'd like to shard
        the model instantly, which can save memory and initialization time for extremely large models.

        Returns:
            Model parallel context.
        """
        with _patch_cuda_is_available():
            from colossalai.utils.model.colo_init_context import ColoInitContext

        class ModelShardedContext(ColoInitContext):
            def _post_init_method(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> None:
                if getattr(module, "_colossalai_module", False) is True:
                    return
                super()._post_init_method(module, *args, **kwargs)
                for sub_module in module.modules():
                    sub_module._colossalai_module = True  # type: ignore[assignment]

        return ModelShardedContext()
    def setup_precision_plugin(self) -> None:
        with _patch_cuda_is_available():
            from colossalai.nn.optimizer import CPUAdam, HybridAdam
            from colossalai.zero import ZeroOptimizer

        super().setup_precision_plugin()
        assert self.lightning_module is not None
        is_training = self.lightning_module.trainer and self.lightning_module.trainer.training

        if is_training:
            if len(self.optimizers) > 1:
                raise ValueError("`ColossalAIStrategy` only supports single Optimizer now.")

            optimizer = self.optimizers[0]
            if not isinstance(optimizer, (CPUAdam, HybridAdam)):
                raise ValueError(
                    "`ColossalAIStrategy` only supports `colossalai.nn.optimizer.CPUAdam` "
                    "and `colossalai.nn.optimizer.HybridAdam` as its optimizer."
                )

        assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
        pl_module = self.model

        if not hasattr(pl_module, "_colossalai_zero"):
            with _patch_cuda_is_available():
                from colossalai.nn.parallel import GeminiDDP
                from colossalai.utils import get_current_device
            if not self.use_chunk:
                raise ValueError("`ColossalAIStrategy` must use chunk in versions higher than 0.1.10")
            chunk_search_range: int = self.chunk_size_search_kwargs.get("search_range", 32 * 1024**2)
            search_range_mb: float = chunk_search_range / 1024**2
            search_n_grids: int = self.chunk_size_search_kwargs.get("n_grids", 4096)
            search_interval: int = math.ceil(chunk_search_range / search_n_grids)
            min_chunk_size_mb = int(self.chunk_size_search_kwargs["min_chunk_size"] // (1024**2))

            model = _LightningModuleWrapperBase(self.model)
            self.model = GeminiDDP(
                module=model,
                device=get_current_device(),
                placement_policy=self.placement_policy,
                pin_memory=True,
                force_outputs_fp32=self.force_outputs_fp32,
                search_range_mb=search_range_mb,
                hidden_dim=search_interval,
                min_chunk_size_mb=min_chunk_size_mb,
            )

            assert self.model is not None
            pl_module._colossalai_zero = [self.model]  # type: ignore[assignment]
        else:
            self.model = pl_module._colossalai_zero[0]  # type: ignore[index, assignment]

        if is_training:
            self.optimizers = [
                ZeroOptimizer(optimizer, self.model, gpu_margin_mem_ratio=self.gpu_margin_mem_ratio, **self.amp_kwargs)
            ]
    def setup(self, trainer: "pl.Trainer") -> None:
        precision = self.precision_plugin.precision
        if precision != "16":
            raise ValueError(
                f"`Trainer(strategy='colossalai', precision={precision!r})` is not supported."
                " Consider setting `precision=16`."
            )

        if not isinstance(self.accelerator, CUDAAccelerator):
            raise ValueError(
                "`ColossalAIStrategy` is only supported on `CUDAAccelerator`, "
                f"but `{self.accelerator.__class__.__name__}` is used."
            )

        if trainer.state.fn == TrainerFn.FITTING:
            if is_overridden("backward", trainer.lightning_module):
                rank_zero_warn(
                    "You have overridden the `LightningModule.backward` hook"
                    " but it will be ignored since ColossalAI handles"
                    " the backward logic internally."
                )

            if trainer.accumulate_grad_batches > 1:
                raise ValueError(
                    "ColossalAI does not support gradient accumulation now. Please set `accumulate_grad_batches` to 1."
                )

            accumulation_scheduler = trainer.accumulation_scheduler
            if accumulation_scheduler.epochs != [0]:
                raise ValueError(
                    "ColossalAI currently does not support different `accumulate_grad_batches` at different epochs."
                )

        if not isinstance(self.precision_plugin, ColossalAIPrecisionPlugin):
            raise ValueError("`ColossalAIStrategy` is only compatible with `ColossalAIPrecisionPlugin`.")

        self.accelerator.setup(trainer)
        assert self.lightning_module is not None
        self.lightning_module._device = self.root_device
        self.ignore_no_grad_parameters(self.root_device)
        self.setup_optimizers(trainer)
        self.setup_precision_plugin()
        self.model_to_device()
    def ignore_no_grad_parameters(self, running_device: torch.device) -> None:
        # for those parameters with no gradients,
        # we should ignore them on DDP and move them to CUDA
        assert self.model is not None
        for param in self.model.parameters():
            if not param.requires_grad:
                setattr(param, "_ddp_to_ignore", True)
                param.data = param.data.to(running_device)
    def optimizer_step(
        self,
        optimizer: Optimizer,
        opt_idx: int,
        closure: Callable[[], Any],
        model: Optional[Union["pl.LightningModule", Module]] = None,
        **kwargs: Any,
    ) -> Any:
        model = model or self.lightning_module
        # TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed
        assert isinstance(model, pl.LightningModule)
        return self.precision_plugin.optimizer_step(
            optimizer, model=model, optimizer_idx=opt_idx, closure=closure, **kwargs
        )
    def lightning_module_state_dict(self, rank_zero_only: bool = False) -> Dict[str, Any]:
        """Returns a dictionary containing the whole state of the module, with all tensors detached from their
        parameters and located in CPU memory.

        Args:
            rank_zero_only: If True, only process rank 0 gets the correct dictionary.
                Otherwise, all processes get the same dictionary.
        """
        with _patch_cuda_is_available():
            from colossalai.nn.parallel import ZeroDDP

        assert isinstance(self.model, ZeroDDP)
        org_dict = self.model.state_dict(only_rank_0=rank_zero_only)

        children = list(self.model.named_children())
        assert len(children) == 1
        prefix, child = children[0]
        prefix += "."
        assert child is self.lightning_module

        mapping_dict = dict()
        for key in org_dict.keys():
            mapping_dict[key] = key.replace(prefix, "")  # remove "_forward_module." from the key

        return {mapping_dict[key]: value for key, value in org_dict.items()}
    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        orig_dict = checkpoint["state_dict"]

        assert self.model is not None
        children = list(self.model.named_children())
        assert len(children) == 1
        prefix, child = children[0]
        prefix += "."
        assert child is self.lightning_module

        mapping_dict = dict()
        for key in orig_dict.keys():
            mapping_dict[key] = prefix + key  # add "_forward_module." to the key

        load_dict = OrderedDict({mapping_dict[key]: value for key, value in orig_dict.items()})
        self.model.load_state_dict(load_dict)
    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
        """Broadcasts an object to all processes.

        Args:
            obj: the object to broadcast
            src: source rank
        """
        with _patch_cuda_is_available():
            from colossalai.communication.collective import broadcast
            from colossalai.context import ParallelMode
            from colossalai.core import global_context as gpc

        if isinstance(obj, Tensor):
            return broadcast(obj, src=src, parallel_mode=ParallelMode.GLOBAL)
        else:
            obj_list = [obj]
            torch.distributed.broadcast_object_list(obj_list, src, group=gpc.get_group(ParallelMode.GLOBAL))
            return obj_list[0]
    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
        """Perform an all_gather on all processes."""
        with _patch_cuda_is_available():
            from colossalai.communication.collective import all_gather
            from colossalai.context import ParallelMode

        assert sync_grads is False
        return all_gather(tensor, dim=0, parallel_mode=ParallelMode.GLOBAL)
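A minimal usage sketch based on the class docstring above, not part of this module: the `GLUETransformer` module, the `devices=2` setting, and the hyperparameter values are illustrative assumptions. It shows the constraints the strategy enforces: the model is created in ``configure_sharded_model()``, the single optimizer is ``HybridAdam`` (or ``CPUAdam``), and the ``Trainer`` uses ``accelerator="gpu"`` with ``precision=16``.

import pytorch_lightning as pl
from colossalai.nn.optimizer import HybridAdam
from pytorch_lightning.strategies import ColossalAIStrategy
from transformers import BertForSequenceClassification


class GLUETransformer(pl.LightningModule):
    def configure_sharded_model(self) -> None:
        # Create the model here so it is instantiated inside the strategy's
        # model_sharded_context() instead of in __init__.
        self.model = BertForSequenceClassification.from_pretrained("bert-base-uncased")

    def configure_optimizers(self):
        # The strategy accepts exactly one optimizer, which must be CPUAdam or HybridAdam.
        return HybridAdam(self.model.parameters(), lr=2e-5)

    # training_step, forward, dataloaders, etc. omitted for brevity


# Select the strategy by name with its defaults ...
trainer = pl.Trainer(accelerator="gpu", devices=2, precision=16, strategy="colossalai")

# ... or construct it explicitly to tune placement and chunk search (values are illustrative):
strategy = ColossalAIStrategy(
    placement_policy="auto",          # "cpu", "cuda" or "auto"
    gpu_margin_mem_ratio=0.5,         # only used when placement_policy is "auto"
    chunk_search_range=64 * 1024**2,  # bytes
    min_chunk_size=32 * 1024**2,      # bytes
    initial_scale=2**16,              # initial dynamic loss scale
)
trainer = pl.Trainer(accelerator="gpu", devices=2, precision=16, strategy=strategy)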