Source code for pytorch_lightning.strategies.tpu_spawn
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins.environments import XLAEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
from pytorch_lightning.strategies.launchers.xla import _XLALauncher
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from pytorch_lightning.utilities.types import _PATH, EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS

if _TPU_AVAILABLE:
    import torch_xla.core.xla_env_vars as xenv
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
    from torch_xla.core.xla_model import rendezvous
    from torch_xla.distributed.parallel_loader import MpDeviceLoader
else:
    xm, xmp, MpDeviceLoader, rendezvous = [None] * 4
class TPUSpawnStrategy(DDPSpawnStrategy):
    """Strategy for training multiple TPU devices using the
    :func:`torch_xla.distributed.xla_multiprocessing.spawn` method."""

    strategy_name = "tpu_spawn"

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        debug: bool = False,
        **_: Any,
    ) -> None:
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=XLAEnvironment(),
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
            start_method="fork",
        )
        self._checkpoint_io: Optional[CheckpointIO]
        self.debug = debug
        self._launched = False

    @property
    def checkpoint_io(self) -> CheckpointIO:
        if self._checkpoint_io is None:
            self._checkpoint_io = XLACheckpointIO()
        elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
            self._checkpoint_io.checkpoint_io = XLACheckpointIO()
        return self._checkpoint_io

    @checkpoint_io.setter
    def checkpoint_io(self, io: Optional[CheckpointIO]) -> None:
        self._checkpoint_io = io

    @property
    def root_device(self) -> torch.device:
        if not self._launched:
            raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.")
        return xm.xla_device()

    @staticmethod
    def _validate_dataloader(dataloaders: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]) -> None:
        def check_has_len(dataloader: DataLoader) -> None:
            if not has_len(dataloader):
                raise MisconfigurationException(
                    "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`."
                    " HINT: You can mock the length on your dataset to bypass this MisconfigurationException."
                )

        apply_to_collection(dataloaders, dtype=object, wrong_dtype=(Sequence, Mapping), function=check_has_len)

    @staticmethod
    def _validate_patched_dataloaders(model: "pl.LightningModule") -> None:
        """Validate and fail fast if the dataloaders were passed directly to fit."""
        connector: DataConnector = model.trainer._data_connector
        sources = (
            connector._train_dataloader_source,
            connector._val_dataloader_source,
            connector._test_dataloader_source,
            connector._predict_dataloader_source,
        )
        for source in sources:
            if not source.is_module():
                assert source.instance is not None
                assert not isinstance(source.instance, (pl.LightningModule, pl.LightningDataModule))
                TPUSpawnStrategy._validate_dataloader(source.instance)
    def _setup_model(self, model: Module) -> Module:  # type: ignore
        return model

    @property
    def distributed_sampler_kwargs(self) -> Dict[str, int]:
        return dict(num_replicas=self.world_size, rank=self.global_rank)

    @property
    def is_distributed(self) -> bool:
        # HOST_WORLD_SIZE is not set outside the xmp.spawn process
        return (xenv.HOST_WORLD_SIZE in os.environ) and self.world_size != 1
    def process_dataloader(self, dataloader: DataLoader) -> MpDeviceLoader:
        TPUSpawnStrategy._validate_dataloader(dataloader)
        dataloader = MpDeviceLoader(dataloader, self.root_device)
        # Mimic interface to torch.utils.data.DataLoader
        dataloader.dataset = dataloader._loader.dataset
        return dataloader
    def reduce(
        self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
    ) -> Tensor:
        if not isinstance(output, Tensor):
            output = torch.tensor(output, device=self.root_device)

        invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
        invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
        if invalid_reduce_op or invalid_reduce_op_str:
            raise ValueError(
                "Currently, the TPUSpawnStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:"
                f" {reduce_op}"
            )

        output = xm.mesh_reduce("reduce", output, sum)

        if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
            output = output / self.world_size

        return output
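    # Illustrative note, not part of the original module: under this strategy a call such as
    # `self.trainer.strategy.reduce(loss, reduce_op="mean")` mesh-reduces the value across all
    # TPU processes with `xm.mesh_reduce` and then divides by `world_size`; any `ReduceOp`
    # other than SUM (or a string other than "sum"/"mean"/"avg") raises the `ValueError` above.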
    def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
        self._pod_progress_bar_force_stdout()
        return output

    def validation_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
        self._pod_progress_bar_force_stdout()
        return output

    def test_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
        self._pod_progress_bar_force_stdout()
        return output

    def _pod_progress_bar_force_stdout(self) -> None:
        # Why is it required? The way `pytorch_xla.distributed` streams logs
        # from different VMs to the main worker doesn't work well with tqdm
        # Ref: https://github.com/pytorch/xla/blob/master/torch_xla/distributed/xla_dist.py#L140
        # The print statement seems to force tqdm to flush stdout.
        if self.global_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
            print()
    def save_checkpoint(
        self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
    ) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            filepath: write-target file's path
            storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin
        """
        # `xla_model.save` needs to be called on all ranks. It internally checks if the local rank is 0
        self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
    def remove_checkpoint(self, filepath: _PATH) -> None:
        """Remove checkpoint filepath from the filesystem.

        Args:
            filepath: Path to checkpoint
        """
        if self.local_rank == 0:
            self.checkpoint_io.remove_checkpoint(filepath)
    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
        """Function to gather a tensor from several distributed processes.

        Args:
            tensor: tensor of shape (batch, ...)
            group: not available with TPUs
            sync_grads: not available with TPUs

        Return:
            A tensor of shape (world_size, batch, ...)
        """
        if isinstance(tensor, Tensor) and tensor.dim() == 0:
            tensor = tensor.unsqueeze(0)
        return xm.all_gather(tensor)
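    # Illustrative note, not part of the original module: a zero-dimensional tensor (e.g. a scalar
    # loss) is first unsqueezed to shape (1,), so with 8 TPU cores the gathered result via
    # `xm.all_gather` has shape (8,).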
    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        strategy_registry.register(
            "tpu_spawn_debug", cls, description="TPUSpawn Strategy with `debug` as True", debug=True
        )

        strategy_registry.register(
            cls.strategy_name,
            cls,
            description=f"{cls.__name__}",
        )
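A minimal usage sketch, not part of the module above: the strategy is normally selected through the Trainer rather than instantiated by hand. The `devices=8` value and the commented `fit` call with its model and dataloader names are illustrative placeholders.

import pytorch_lightning as pl
from pytorch_lightning.strategies import TPUSpawnStrategy

# Select the strategy by its registered name ...
trainer = pl.Trainer(accelerator="tpu", devices=8, strategy="tpu_spawn")

# ... or pass an instance explicitly, e.g. with `debug=True`
# (the flag the "tpu_spawn_debug" registration above sets).
trainer = pl.Trainer(accelerator="tpu", devices=8, strategy=TPUSpawnStrategy(debug=True))

# trainer.fit(model, train_dataloader)  # placeholder LightningModule and DataLoader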