Source code for pytorch_lightning.strategies.tpu_spawn
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
from typing import Any, Dict, List, Mapping, Optional, Sequence, TYPE_CHECKING, Union

import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from lightning_fabric.accelerators.tpu import _XLA_AVAILABLE
from lightning_fabric.plugins import CheckpointIO, XLACheckpointIO
from lightning_fabric.plugins.environments import XLAEnvironment
from lightning_fabric.utilities.data import has_len
from lightning_fabric.utilities.optimizer import _optimizers_to_device
from lightning_fabric.utilities.types import _PATH, ReduceOp
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
from pytorch_lightning.strategies.launchers.xla import _XLALauncher
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import find_shared_parameters, set_shared_parameters
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS

if TYPE_CHECKING and _XLA_AVAILABLE:
    from torch_xla.distributed.parallel_loader import MpDeviceLoader
else:
    MpDeviceLoader = None
class TPUSpawnStrategy(DDPSpawnStrategy):
    """Strategy for training multiple TPU devices using the
    :func:`torch_xla.distributed.xla_multiprocessing.spawn` method."""

    strategy_name = "tpu_spawn"

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.Accelerator"] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        debug: bool = False,
        **_: Any,
    ) -> None:
        if not _XLA_AVAILABLE:
            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=XLAEnvironment(),
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
            start_method="fork",
        )
        self._checkpoint_io: Optional[CheckpointIO]
        self.debug = debug
        self._launched = False

    @property
    def checkpoint_io(self) -> CheckpointIO:
        if self._checkpoint_io is None:
            self._checkpoint_io = XLACheckpointIO()
        elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
            self._checkpoint_io.checkpoint_io = XLACheckpointIO()
        return self._checkpoint_io

    @checkpoint_io.setter
    def checkpoint_io(self, io: Optional[CheckpointIO]) -> None:
        self._checkpoint_io = io

    @property
    def root_device(self) -> torch.device:
        if not self._launched:
            raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.")
        import torch_xla.core.xla_model as xm

        return xm.xla_device()

    @property
    def local_rank(self) -> int:
        return self.cluster_environment.local_rank() if self.cluster_environment is not None else 0

    @staticmethod
    def _validate_dataloader(dataloaders: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]) -> None:
        def check_has_len(dataloader: DataLoader) -> None:
            if not has_len(dataloader):
                raise MisconfigurationException(
                    "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`."
                    " HINT: You can mock the length on your dataset to bypass this MisconfigurationException."
                )

        apply_to_collection(dataloaders, dtype=object, wrong_dtype=(Sequence, Mapping), function=check_has_len)

    @staticmethod
    def _validate_patched_dataloaders(model: "pl.LightningModule") -> None:
        """Validate and fail fast if the dataloaders were passed directly to fit."""
        connector: DataConnector = model.trainer._data_connector
        sources = (
            connector._train_dataloader_source,
            connector._val_dataloader_source,
            connector._test_dataloader_source,
            connector._predict_dataloader_source,
        )
        for source in sources:
            if not source.is_module():
                assert source.instance is not None
                assert not isinstance(source.instance, (pl.LightningModule, pl.LightningDataModule))
                TPUSpawnStrategy._validate_dataloader(source.instance)
    def _setup_model(self, model: Module) -> Module:  # type: ignore
        return model

    @property
    def distributed_sampler_kwargs(self) -> Dict[str, int]:
        return dict(num_replicas=self.world_size, rank=self.global_rank)

    @property
    def is_distributed(self) -> bool:
        # HOST_WORLD_SIZE is not set outside the xmp.spawn process
        import torch_xla.core.xla_env_vars as xenv

        return (xenv.HOST_WORLD_SIZE in os.environ) and self.world_size != 1
    def process_dataloader(self, dataloader: DataLoader) -> "MpDeviceLoader":
        TPUSpawnStrategy._validate_dataloader(dataloader)
        from torch_xla.distributed.parallel_loader import MpDeviceLoader

        if isinstance(dataloader, MpDeviceLoader):
            # dataloader is already wrapped by MpDeviceLoader
            return dataloader

        dataloader = MpDeviceLoader(dataloader, self.root_device)
        # Mimic interface to torch.utils.data.DataLoader
        dataloader.dataset = dataloader._loader.dataset
        dataloader.batch_sampler = getattr(dataloader._loader, "batch_sampler", None)
        return dataloader
    def reduce(
        self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
    ) -> Tensor:
        if not isinstance(output, Tensor):
            output = torch.tensor(output, device=self.root_device)

        invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
        invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
        if invalid_reduce_op or invalid_reduce_op_str:
            raise ValueError(
                "Currently, the TPUSpawnStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:"
                f" {reduce_op}"
            )

        import torch_xla.core.xla_model as xm

        output = xm.mesh_reduce("reduce", output, sum)

        if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
            output = output / self.world_size

        return output
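    # Illustrative note (not part of the original module): metrics logged from a
    # LightningModule with `self.log(..., sync_dist=True)` under this strategy are
    # synchronized through `reduce` above. For example, a per-process value of 2.0
    # on 8 TPU cores reduces to 16.0 with "sum" and back to 2.0 with "mean"/"avg".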
    def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
        self._pod_progress_bar_force_stdout()
        return output

    def validation_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
        self._pod_progress_bar_force_stdout()
        return output

    def test_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
        self._pod_progress_bar_force_stdout()
        return output

    def _pod_progress_bar_force_stdout(self) -> None:
        # Why is it required? The way `pytorch_xla.distributed` streams logs
        # from different vms to the main worker doesn't work well with tqdm
        # Ref: https://github.com/pytorch/xla/blob/master/torch_xla/distributed/xla_dist.py#L140
        # The print statement seems to force tqdm to flush stdout.
        import torch_xla.core.xla_env_vars as xenv

        if self.global_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
            print()
    def save_checkpoint(
        self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
    ) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            filepath: write-target file's path
            storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin
        """
        # `xla_model.save` needs to be called on all ranks. It internally checks if the local rank is 0
        self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
    def remove_checkpoint(self, filepath: _PATH) -> None:
        """Remove checkpoint filepath from the filesystem.

        Args:
            filepath: Path to checkpoint
        """
        if self.local_rank == 0:
            self.checkpoint_io.remove_checkpoint(filepath)
    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
        """Function to gather a tensor from several distributed processes.

        Args:
            tensor: tensor of shape (batch, ...)
            group: not available with TPUs
            sync_grads: flag that allows users to synchronize gradients for the all_gather operation

        Return:
            A tensor of shape (world_size, batch, ...)
        """
        if isinstance(tensor, Tensor) and tensor.dim() == 0:
            tensor = tensor.unsqueeze(0)

        import torch_xla.core.functions as xf
        import torch_xla.core.xla_model as xm

        return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
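    # Illustrative note (not part of the original module): `LightningModule.all_gather`
    # forwards to this method. With `sync_grads=True` the differentiable
    # `torch_xla.core.functions.all_gather` is used so gradients flow back through the
    # gathered tensor; otherwise the plain `xm.all_gather` collective is used.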
    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        strategy_registry.register(
            "tpu_spawn_debug", cls, description="TPUSpawn Strategy with `debug` as True", debug=True
        )
        strategy_registry.register(
            cls.strategy_name,
            cls,
            description=f"{cls.__class__.__name__}",
        )
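A minimal usage sketch (illustrative, not part of the module above), assuming a TPU-enabled environment with `torch_xla` installed; the strategy is normally selected through the Trainer rather than instantiated by hand:

import pytorch_lightning as pl
from pytorch_lightning.strategies import TPUSpawnStrategy

# Passing `accelerator="tpu", devices=8` selects this strategy automatically;
# it can also be requested explicitly, e.g. via the registered "tpu_spawn_debug" variant.
trainer = pl.Trainer(accelerator="tpu", devices=8, strategy=TPUSpawnStrategy(debug=False))
# trainer.fit(model, datamodule=datamodule)  # `model`/`datamodule` are placeholders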