# Copyright The Lightning AI team.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.importjsonimportosfromtypingimportAny,Callable,Dict,Iterable,List,Optional,Tuple,Unionimporttorchfromlightning_utilities.core.apply_funcimportapply_to_collectionfromtorchimportTensorfromtorch.utils.dataimportDataLoader,Samplerimportpytorch_lightningasplfromlightning_fabric.pluginsimportCheckpointIO,ClusterEnvironmentfromlightning_fabric.utilities.cloud_ioimportget_filesystemfrompytorch_lightning.accelerators.ipuimport_IPU_AVAILABLE,_POPTORCH_AVAILABLEfrompytorch_lightning.overrides.baseimport_LightningModuleWrapperBasefrompytorch_lightning.plugins.precisionimportPrecisionPluginfrompytorch_lightning.strategies.parallelimportParallelStrategyfrompytorch_lightning.strategies.strategyimportTBroadcastfrompytorch_lightning.strategies.utilsimport_fp_to_halffrompytorch_lightning.trainer.statesimportRunningStage,TrainerFnfrompytorch_lightning.utilitiesimportrank_zero_warnfrompytorch_lightning.utilities.dataimport_get_dataloader_init_args_and_kwargs,_reinstantiate_wrapped_clsfrompytorch_lightning.utilities.exceptionsimportMisconfigurationExceptionfrompytorch_lightning.utilities.model_helpersimportis_overriddenfrompytorch_lightning.utilities.typesimportSTEP_OUTPUTif_POPTORCH_AVAILABLE:importpoptorchelse:poptorch=None
[docs]classIPUStrategy(ParallelStrategy):"""Plugin for training on IPU devices."""strategy_name="ipu_strategy"def__init__(self,accelerator:Optional["pl.accelerators.Accelerator"]=None,device_iterations:int=1,autoreport:bool=False,autoreport_dir:Optional[str]=None,parallel_devices:Optional[List[torch.device]]=None,cluster_environment:Optional[ClusterEnvironment]=None,checkpoint_io:Optional[CheckpointIO]=None,precision_plugin:Optional[PrecisionPlugin]=None,training_opts:Optional["poptorch.Options"]=None,inference_opts:Optional["poptorch.Options"]=None,)->None:""" Arguments: device_iterations: Number of iterations to run on device at once before returning to host. This can be used as an optimization to speed up training. https://docs.graphcore.ai/projects/poptorch-user-guide/en/latest/batching.html autoreport: Enable auto-reporting for IPUs using PopVision https://docs.graphcore.ai/projects/graphcore-popvision-user-guide/en/latest/graph/graph.html autoreport_dir: Optional directory to store autoReport output. training_opts: Optional ``poptorch.Options`` to override the default created options for training. inference_opts: Optional ``poptorch.Options`` to override the default created options for validation/testing and predicting. """super().__init__(accelerator=accelerator,parallel_devices=parallel_devices,cluster_environment=cluster_environment,checkpoint_io=checkpoint_io,precision_plugin=precision_plugin,)ifnot_IPU_AVAILABLE:raiseMisconfigurationException("The IPU Accelerator requires IPU devices to run. ""Learn more or get started with IPUs at https://www.graphcore.ai/getstarted")self.device_iterations=device_iterationsself.autoreport=autoreportself.autoreport_dir=autoreport_dirself.poptorch_models:Dict[RunningStage,"poptorch.PoplarExecutor"]={}self._training_opts=training_optsself._inference_opts=inference_optsifself.autoreport:options:Dict[str,Any]={"autoReport.all":self.autoreport}ifself.autoreport_dir:self._fs=get_filesystem(str(self.autoreport_dir))self._fs.makedirs(self.autoreport_dir,exist_ok=True)options["autoReport.directory"]=self.autoreport_diros.environ["POPLAR_ENGINE_OPTIONS"]=json.dumps(options)self._update_dataloader_original:Optional[Callable]=Noneself._optimizer_zero_grad_original:Optional[Callable]=None
[docs]defsetup(self,trainer:"pl.Trainer")->None:# set the `accumulate_grad_batches` property as early as possibleself._handle_gradient_accumulation_steps()# patch the dataloader creation function with the custom `poptorch.DataLoader`.# this violates the intended control flow for the plugins, but since this is experimental, we have chosen# to use the simpler solution before adding abstractions to override the `DataLoader` classself._update_dataloader_original=pl.trainer.connectors.data_connector._update_dataloaderpl.trainer.connectors.data_connector._update_dataloader=self._convert_to_poptorch_loadersuper().setup(trainer)assertself.lightning_moduleisnotNone# disable the `optimizer_zero_grad` function by setting it to `None`.# this is because the IPU zeros the gradients internallyself._optimizer_zero_grad_original=self.lightning_module.optimizer_zero_gradself._disable_zero_grad()self.model=_LightningModuleWrapperBase(self.lightning_module)# reset the backupself.poptorch_models={}# Separate models are instantiated for different stages, but they share the same weights on host.# When validation/test models are run, weights are synced first.trainer_fn=self.lightning_module.trainer.state.fniftrainer_fn==TrainerFn.FITTING:# Create model for training and validation which will run on fittraining_opts=self.training_optsinference_opts=self.inference_optsoptimizer=self.lightning_module.trainer.optimizers[0]model=poptorch.trainingModel(model=self.model,options=training_opts,optimizer=optimizer)self.poptorch_models[RunningStage.TRAINING]=modelifself.lightning_module.trainer.enable_validation:model=poptorch.inferenceModel(model=self.model,options=inference_opts)self.poptorch_models[RunningStage.VALIDATING]=modelifself.lightning_module.trainer.num_sanity_val_steps>0:self.poptorch_models[RunningStage.SANITY_CHECKING]=modeleliftrainer_fn==TrainerFn.VALIDATING:model=poptorch.inferenceModel(model=self.model,options=self.inference_opts)self.poptorch_models[RunningStage.VALIDATING]=modeleliftrainer_fn==TrainerFn.TESTING:model=poptorch.inferenceModel(model=self.model,options=self.inference_opts)self.poptorch_models[RunningStage.TESTING]=modeleliftrainer_fn==TrainerFn.PREDICTING:model=poptorch.inferenceModel(model=self.model,options=self.inference_opts)self.poptorch_models[RunningStage.PREDICTING]=model
[docs]defsetup_optimizers(self,trainer:"pl.Trainer")->None:super().setup_optimizers(trainer)iflen(self.optimizers)>1:raiseMisconfigurationException("IPUs currently only support one optimizer.")
@propertydefreplication_factor(self)->int:ifnotself.lightning_moduleornotself.poptorch_models:# The plugin has been passed in by the user and has not been connected to the Trainer.# Check if the user has passed in custom poptorch.Options to infer number of IPUs being used.# In this scenario we prioritize the training options.ifself._training_opts:returnself._training_opts.replication_factorifself._inference_opts:returnself._inference_opts.replication_factorassertself.parallel_devicesreturnlen(self.parallel_devices)stage=self.lightning_module.trainer.state.stageassertstageisnotNonereturnself.poptorch_models[stage]._options.toDict()["replication_factor"]def_create_opts(self,training:bool)->"poptorch.Options":assertself.lightning_moduleisnotNoneopts=poptorch.Options()opts.deviceIterations(self.device_iterations)opts.replicationFactor(self.replication_factor)gradient_accumulation=self.lightning_module.trainer.accumulate_grad_batchesiftrainingelse1opts.Training.gradientAccumulation(gradient_accumulation)ifos.environ.get("PL_GLOBAL_SEED"):opts.randomSeed(int(os.environ["PL_GLOBAL_SEED"]))returnopts@propertydeftraining_opts(self)->"poptorch.Options":ifself._training_optsisNone:self._training_opts=self._create_opts(training=True)returnself._training_opts@propertydefinference_opts(self)->"poptorch.Options":ifself._inference_optsisNone:self._inference_opts=self._create_opts(training=False)returnself._inference_optsdef_convert_to_poptorch_loader(self,dataloader:DataLoader,sampler:Union[Sampler,Iterable],mode:Optional[RunningStage]=None)->"poptorch.DataLoader":ifisinstance(dataloader,poptorch.DataLoader):# the user is returning the `poptorch.DataLoader` directly, don't change anything.returndataloaderdl_args,dl_kwargs=_get_dataloader_init_args_and_kwargs(dataloader,sampler,mode,self.replication_factor>1)opts=self.training_optsifmode==RunningStage.TRAININGelseself.inference_optsdataloader=_reinstantiate_wrapped_cls(dataloader,opts,*dl_args,explicit_cls=poptorch.DataLoader,**dl_kwargs)returndataloaderdef_handle_gradient_accumulation_steps(self)->None:"""Override the trainer.accumulation_scheduler to act as ``accumulate_grad_batches=1`` if gradient accumulation has been set. ``optimizer_step`` will be called on every batch, and the IPU will handle grad accumulation internally. """assertself.lightning_moduleisnotNoneaccumulation_scheduler=self.lightning_module.trainer.accumulation_schedulerifaccumulation_scheduler.epochs!=[0]:raiseMisconfigurationException("IPUs currently does not support different `accumulate_grad_batches` at different epochs.")# TODO(@tchaton): Add support for accumulate_grad_batches being a dictionaryaccumulation_scheduler.scheduling.update({0:1})@propertydef_n_replicate(self)->int:assertself.lightning_moduleisnotNoneopts=self.training_optsifself.lightning_module.trainingelseself.inference_optsaccumulate_grad_batches=opts.Training.gradient_accumulationdevice_iterations=opts.device_iterationsreplication_factor=opts.replication_factorreturnreplication_factor*device_iterations*accumulate_grad_batchesdef_prepare_input(self,args:Any)->Any:defto_tuple(x:Any)->Tuple:returntuple(x)defto_tensor(x:Any)->Tensor:returntorch.tensor(x).unsqueeze(0).repeat(self._n_replicate)args=apply_to_collection(args,dtype=list,function=to_tuple)args=apply_to_collection(args,dtype=(int,float),function=to_tensor)returnargs
[docs]defbatch_to_device(self,batch:Any,device:Optional[torch.device]=None,dataloader_idx:int=0)->Any:# This override is necessary because the cast must occur before the data# is moved to the device to prevent wasteful host->device copies.batch=apply_to_collection(batch,Tensor,function=_fp_to_half,precision=self.precision_plugin.precision)# We don't call `super().batch_to_device` because `data.to(device)` is not# currently necessary for IPUs. The movement of data from host<->IPU is# currently handled by PopTorch.returnbatch
def_disable_zero_grad(self)->None:lightning_module=self.lightning_moduleassertlightning_moduleisnotNoneifis_overridden("optimizer_zero_grad",lightning_module):assertlightning_moduleisnotNone# `is_overridden` returns False otherwiserank_zero_warn("You have overridden the `LightningModule.optimizer_zero_grad` hook but it will be ignored since"" IPUs handle the zeroing of gradients internally.")lightning_module.optimizer_zero_grad=None# type: ignore[assignment]def_step(self,stage:RunningStage,*args:Any,**kwargs:Any)->STEP_OUTPUT:args=self._prepare_input(args)poptorch_model=self.poptorch_models[stage]withpl.core.module._jit_is_scripting():returnpoptorch_model(*args,**kwargs)
def_compiled(self,model:Any)->bool:# Required to ensure we only attach compiled models, as they are compiled lazily.returnmodel._executableisnotNonedef_detach_models(self)->None:"""Detaches all stage specific models from IPU devices."""fork,modelinself.poptorch_models.items():ifself._compiled(model)andmodel.isAttachedToDevice():model.detachFromDevice()def_load_model(self,stage:RunningStage)->None:"""Loads the stage specific accelerator model onto device if compiled and not attached to IPU devices. Args: stage: The stage to load """self._detach_models()model=self.poptorch_models[stage]ifself._compiled(model)andnotmodel.isAttachedToDevice():model.attachToDevice()
[docs]defon_train_batch_start(self,batch:Any,batch_idx:int)->None:# Updates optimizer stats if LR scheduler modified the optimizer stateoptimizer=self.optimizers[0]self.poptorch_models[RunningStage.TRAINING].setOptimizer(optimizer)
@propertydefroot_device(self)->torch.device:# type: ignore[empty-body]# TODO: this should return `self.parallel_devices[self.local_rank]`pass
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Read PyTorch Lightning's Privacy Policy.