# Copyright The PyTorch Lightning team.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.importosfromabcimportABC,abstractmethodfromcontextlibimportcontextmanagerfromfunctoolsimportpartialfrompathlibimportPathfromtypingimportAny,Callable,cast,Dict,Generator,List,Optional,overload,Sequence,Tuple,Unionimporttorchimporttorch.nnasnnfromtorchimportTensorfromtorch.optimimportOptimizerfromtorch.utils.dataimportDataLoader,DistributedSampler,RandomSampler,SequentialSamplerfrompytorch_lightning.accelerators.acceleratorimportAcceleratorfrompytorch_lightning.lite.wrappersimport_LiteDataLoader,_LiteModule,_LiteOptimizerfrompytorch_lightning.pluginsimportPLUGIN_INPUTfrompytorch_lightning.strategiesimportDeepSpeedStrategy,Strategy,TPUSpawnStrategyfrompytorch_lightning.strategies.strategyimportTBroadcastfrompytorch_lightning.trainer.connectors.accelerator_connectorimportAcceleratorConnectorfrompytorch_lightning.utilitiesimport_AcceleratorType,_StrategyType,move_data_to_devicefrompytorch_lightning.utilities.apply_funcimportapply_to_collection,convert_to_tensorsfrompytorch_lightning.utilities.dataimport(_auto_add_worker_init_fn,_replace_dataloader_init_method,_update_dataloader,has_iterable_dataset,)frompytorch_lightning.utilities.exceptionsimportMisconfigurationExceptionfrompytorch_lightning.utilities.seedimportseed_everything
class LightningLite(ABC):
    """Lite accelerates your PyTorch training or inference code with minimal changes required.

    - Automatic placement of models and data onto the device.
    - Automatic support for mixed and double precision (smaller memory footprint).
    - Seamless switching between hardware (CPU, GPU, TPU) and distributed training strategies
      (data-parallel training, sharded training, etc.).
    - Automated spawning of processes, no launch utilities required.
    - Multi-node support.

    Args:
        accelerator: The hardware to run on. Possible choices are: ``"cpu"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
        strategy: Strategy for how to run across multiple devices. Possible choices are:
            ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"ddp_sharded"``, ``"ddp_sharded_spawn"``.
        devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
            The value applies per node.
        num_nodes: Number of GPU nodes for distributed training.
        precision: Double precision (``64``), full precision (``32``), half precision (``16``),
            or bfloat16 precision (``"bf16"``).
        plugins: One or several custom plugins.
        gpus: Provides the same function as the ``devices`` argument but implies ``accelerator="gpu"``.
        tpu_cores: Provides the same function as the ``devices`` argument but implies ``accelerator="tpu"``.
    """

    def __init__(
        self,
        accelerator: Optional[Union[str, Accelerator]] = None,
        strategy: Optional[Union[str, Strategy]] = None,
        devices: Optional[Union[List[int], str, int]] = None,
        num_nodes: int = 1,
        precision: Union[int, str] = 32,
        plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
        gpus: Optional[Union[List[int], str, int]] = None,
        tpu_cores: Optional[Union[List[int], str, int]] = None,
    ) -> None:
        # Reject accelerator/strategy choices Lite does not support before building the connector.
        self._check_accelerator_support(accelerator)
        self._check_strategy_support(strategy)
        self._accelerator_connector = AcceleratorConnector(
            num_processes=None,
            devices=devices,
            tpu_cores=tpu_cores,
            ipus=None,
            accelerator=accelerator,
            strategy=strategy,
            gpus=gpus,
            num_nodes=num_nodes,
            sync_batchnorm=False,  # TODO: add support?
            benchmark=False,
            replace_sampler_ddp=True,
            deterministic=False,
            precision=precision,
            amp_type="native",
            amp_level=None,
            plugins=plugins,
            auto_select_gpus=False,
        )
        self._strategy = self._accelerator_connector.strategy
        self._accelerator = self._strategy.accelerator
        self._precision_plugin = self._strategy.precision_plugin
        # Number of successful `setup()` calls; `backward()` consults it when running with DeepSpeed.
        self._models_setup: int = 0

        # wrap the run method so we can inject setup logic or spawn processes for the user
        setattr(self, "run", partial(self._run_impl, self.run))

    @property
    def device(self) -> torch.device:
        """The current device this process runs on.

        Use this to create tensors directly on the device if needed.
        """
        return self._strategy.root_device

    @property
    def global_rank(self) -> int:
        """The global index of the current process across all devices and nodes."""
        return getattr(self._strategy, "global_rank", 0)

    @property
    def local_rank(self) -> int:
        """The index of the current process among the processes running on the local node."""
        return getattr(self._strategy, "local_rank", 0)

    @property
    def node_rank(self) -> int:
        """The index of the current node."""
        return getattr(self._strategy, "node_rank", 0)

    @property
    def world_size(self) -> int:
        """The total number of processes running across all devices and nodes."""
        return getattr(self._strategy, "world_size", 1)

    @property
    def is_global_zero(self) -> bool:
        """Whether this rank is rank zero."""
        return self._strategy.is_global_zero

    @abstractmethod
    def run(self, *args: Any, **kwargs: Any) -> Any:
        """All the code inside this run method gets accelerated by Lite.

        You can pass arbitrary arguments to this function when overriding it.
        """

    def setup(
        self,
        model: nn.Module,
        *optimizers: Optimizer,
        move_to_device: bool = True,
    ) -> Any:  # no specific return because the way we want our API to look does not play well with mypy
        """Set up a model and its optimizers for accelerated training.

        Args:
            model: A model to set up.
            *optimizers: The optimizer(s) to set up (no optimizers is also possible).
            move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
                and alternatively use :meth:`to_device` manually.

        Returns:
            The wrapped model alone if no optimizers were passed. Otherwise, a list containing the wrapped model
            followed by the wrapped optimizers, in the same order they were passed in.

        Raises:
            MisconfigurationException: If the model or any of the optimizers has already been set up.
        """
        self._validate_setup(model, optimizers)

        if move_to_device:
            model = self._move_model_to_device(model=model, optimizers=list(optimizers))

        # Let accelerator/plugin wrap and connect the models and optimizers
        model, optimizers = self._strategy._setup_model_and_optimizers(model, list(optimizers))
        model = _LiteModule(model, self._precision_plugin)
        optimizers = [_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers]
        self._models_setup += 1
        if optimizers:
            # join both types in a list for API convenience
            return [model] + optimizers  # type: ignore
        return model

    def setup_dataloaders(
        self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
    ) -> Union[DataLoader, List[DataLoader]]:
        """Set up one or multiple dataloaders for accelerated training. If you need different settings for each
        dataloader, call this method individually for each one.

        Args:
            *dataloaders: A single dataloader or a sequence of dataloaders.
            replace_sampler: If set ``True`` (default), automatically wraps or replaces the sampler on the dataloader(s)
                for distributed training. If you have a custom sampler defined, set this argument to ``False``.
            move_to_device: If set ``True`` (default), moves the data returned by the dataloader(s) automatically to
                the correct device. Set this to ``False`` and alternatively use :meth:`to_device` manually on the
                returned data.

        Returns:
            The wrapped dataloader if exactly one was passed, otherwise a list of the wrapped dataloaders in the
            same order they were passed in.

        Raises:
            MisconfigurationException: If a dataloader was already set up or is not a PyTorch ``DataLoader``.
        """
        self._validate_setup_dataloaders(dataloaders)
        dataloaders = [
            self._setup_dataloader(dataloader, replace_sampler=replace_sampler, move_to_device=move_to_device)
            for dataloader in dataloaders
        ]
        dataloaders = dataloaders[0] if len(dataloaders) == 1 else dataloaders
        return dataloaders  # type: ignore[return-value]

    def _setup_dataloader(
        self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
    ) -> DataLoader:
        """Set up a single dataloader for accelerated training.

        Args:
            dataloader: The dataloader to accelerate.
            replace_sampler: If set ``True`` (default), automatically wraps or replaces the sampler on the dataloader
                for distributed training. If you have a custom sampler defined, set this argument to ``False``.
            move_to_device: If set ``True`` (default), moves the data returned by the dataloader automatically to
                the correct device. Set this to ``False`` and alternatively use :meth:`to_device` manually on the
                returned data.

        Returns:
            The wrapped dataloader.

        Raises:
            MisconfigurationException: If a distributed sampler is required but the dataloader uses a sampler other
                than ``SequentialSampler``/``RandomSampler`` while ``replace_sampler`` is ``True``.
        """
        sampler = dataloader.sampler
        if replace_sampler and self._requires_distributed_sampler(dataloader):
            # Only the default samplers may be swapped silently; a custom one would otherwise be discarded.
            if not isinstance(sampler, (SequentialSampler, RandomSampler)):
                raise MisconfigurationException(
                    "You seem to have configured a sampler in your DataLoader. This will be replaced"
                    " by `DistributedSampler` since `replace_sampler=True` and you are using"
                    " distributed training. Either remove the sampler from your DataLoader or set"
                    " `replace_sampler=False` if you want to use your custom sampler."
                )
            sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs)

        # the dataloader needs to be re-instantiated because we want to update the input arguments (e.g., sampler)
        dataloader = _update_dataloader(dataloader, sampler)

        # add worker_init_fn for correct seeding in worker processes
        _auto_add_worker_init_fn(dataloader, self.global_rank)

        dataloader = self._strategy.process_dataloader(dataloader)
        # With TPUSpawnStrategy no target device is given to the wrapper, so data is not moved here.
        device = self.device if move_to_device and not isinstance(self._strategy, TPUSpawnStrategy) else None
        lite_dataloader = _LiteDataLoader(dataloader=dataloader, device=device)
        lite_dataloader = cast(DataLoader, lite_dataloader)
        return lite_dataloader

    def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = None, **kwargs: Any) -> None:
        """Replaces ``loss.backward()`` in your training loop. The call is delegated to the precision plugin, so
        precision-specific handling is applied automatically for you.

        Args:
            tensor: The tensor (loss) to back-propagate gradients from.
            *args: Optional positional arguments passed to the underlying backward function.
            model: Optional model instance for plugins that require the model for backward().
            **kwargs: Optional named keyword arguments passed to the underlying backward function.

        Raises:
            MisconfigurationException: With ``strategy="deepspeed"`` and no ``model`` passed, if no model has been
                set up yet, or if more than one model has been set up.

        Note:
            When using ``strategy="deepspeed"`` and multiple models were setup, it is required to pass in the
            model as argument here.
        """
        # unwrap the `_LiteModule` to get the module the strategy prepared in `setup()`
        module = model.module if model is not None else model
        if isinstance(self._strategy, DeepSpeedStrategy):
            if model is None:
                if self._models_setup == 0:
                    raise MisconfigurationException(
                        "No models were setup for backward. Did you forget to call `self.setup()`?"
                    )
                if self._models_setup > 1:
                    raise MisconfigurationException(
                        "When using multiple models + deepspeed, please provide the model used to perform"
                        " the optimization: `self.backward(loss, model=model)`"
                    )
                module = self._strategy.model
            else:
                # requires to attach the current `DeepSpeedEngine` for the `_LiteOptimizer.step` call.
                self._strategy.model = module

        self._precision_plugin._run_backward(tensor, module, *args, **kwargs)

    @contextmanager
    def autocast(self) -> Generator[None, None, None]:
        """A context manager to automatically convert operations for the chosen precision.

        Use this only if the `forward` method of your model does not cover all operations you wish to run with the
        chosen precision setting.
        """
        with self._precision_plugin.forward_context():
            yield

    def to_device(self, obj: Union[nn.Module, Tensor, Any]) -> Union[nn.Module, Tensor, Any]:
        """Move a :class:`torch.nn.Module` or a collection of tensors to the current device, if it is not already
        on that device.

        Args:
            obj: An object to move to the device. Can be an instance of :class:`torch.nn.Module`, a tensor, or a
                 (nested) collection of tensors (e.g., a dictionary).

        Returns:
            A reference to the object that was moved to the new device.
        """
        if isinstance(obj, nn.Module):
            if self.device.type == "cuda":
                # need to call this manually here again in case we spawned with DDPSpawnStrategy
                # TODO: refactor to let plugin handle this cleanly
                torch.cuda.set_device(self.device)
            return obj.to(self.device)
        return move_data_to_device(obj, device=self.device)

    def print(self, *args: Any, **kwargs: Any) -> None:
        """Print something only on the first process of each node (``local_rank == 0``).

        Arguments passed to this method are forwarded to the Python built-in :func:`print` function.
        """
        if self.local_rank == 0:
            print(*args, **kwargs)

    def barrier(self, name: Optional[str] = None) -> None:
        """Wait for all processes to enter this call. Use this to synchronize all parallel processes, but only if
        necessary, otherwise the overhead of synchronization will cause your program to slow down.

        Args:
            name: Optional name forwarded to the strategy's ``barrier``.

        Example::

            if self.global_rank == 0:
                # let process 0 download the dataset
                dataset.download_files()

            # let all processes wait before reading the dataset
            self.barrier()

            # now all processes can read the files and start training
        """
        self._strategy.barrier(name=name)

    def all_gather(
        self, data: Union[torch.Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False
    ) -> Union[torch.Tensor, Dict, List, Tuple]:
        r"""
        Gather tensors or collections of tensors from multiple processes.

        Args:
            data: int, float, tensor of shape (batch, ...), or a (possibly nested) collection thereof.
            group: the process group to gather results from. Defaults to all processes (world)
            sync_grads: flag that allows users to synchronize gradients for the all_gather operation

        Return:
            A tensor of shape (world_size, batch, ...), or if the input was a collection
            the output will also be a collection with tensors of this shape.
        """
        group = group if group is not None else torch.distributed.group.WORLD
        # plain Python numbers are converted to tensors on this process's device before gathering
        data = convert_to_tensors(data, device=self.device)
        return apply_to_collection(data, torch.Tensor, self._strategy.all_gather, group=group, sync_grads=sync_grads)

    def save(self, content: Dict[str, Any], filepath: Union[str, Path]) -> None:
        """Save checkpoint contents to a file.

        How and which processes save gets determined by the `strategy`. For example, the `ddp` strategy
        saves checkpoints only on process 0.

        Args:
            content: A dictionary with contents, i.e., the state dict of your model
            filepath: A path to where the file should be saved
        """
        self._strategy.save_checkpoint(content, filepath)

    def load(self, filepath: Union[str, Path]) -> Any:
        """Load a checkpoint from a file.

        How and which processes load gets determined by the `strategy`.

        Args:
            filepath: A path to where the file is located

        Returns:
            Whatever the strategy's ``load_checkpoint`` returns for this file.
        """
        return self._strategy.load_checkpoint(filepath)

    @staticmethod
    def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None) -> int:
        """Helper function to seed everything without explicitly importing Lightning.

        Unlike the Lightning default, ``workers`` is treated as ``True`` when not given.
        See :func:`pytorch_lightning.seed_everything` for more details.
        """
        if workers is None:
            # Lightning sets `workers=False` by default to avoid breaking reproducibility, but since this is a new
            # release, we can afford to do it.
            workers = True
        # resolves to the module-level `seed_everything` import, not this staticmethod
        return seed_everything(seed=seed, workers=workers)

    def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
        """Run the user's ``run`` method after strategy setup, via the strategy's launcher when it has one."""
        run_method = partial(self._run_with_strategy_setup, run_method)

        if self._strategy.launcher is not None:
            return self._strategy.launcher.launch(run_method, *args, **kwargs)
        else:
            return run_method(*args, **kwargs)

    def _run_with_strategy_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
        """Set up the strategy environment, then call ``run_method`` inside the strategy/dataloader contexts."""
        self._strategy.setup_environment()
        # apply sharded context to prevent OOM
        # NOTE(review): `_replace_dataloader_init_method` presumably records DataLoader constructor arguments so
        # `_update_dataloader` can re-instantiate them later — confirm in `pytorch_lightning.utilities.data`.
        with self._strategy.model_sharded_context(), _replace_dataloader_init_method():
            return run_method(*args, **kwargs)

    def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module:
        """Move ``model`` to :attr:`device`; on TPU, also re-point the optimizers' params to the moved tensors."""
        if isinstance(self._strategy, TPUSpawnStrategy):
            # When the user creates the optimizer, they reference the parameters on the CPU.
            # However, when running with TPU the parameters get copied and the reference in the optimizer
            # remains invalid. We need to update the references to point to the parameter tensors on the device.
            params_before_move = dict(model.named_parameters())
            model = self.to_device(model)
            # XLA makes a copy on the parameters, so the device is not the same before and after to_device.
            params_on_device = dict(model.named_parameters())

            mapping = {param: params_on_device[name] for name, param in params_before_move.items()}
            for optimizer in optimizers:
                for param_group in optimizer.param_groups:
                    # params not found in the mapping are kept as they are
                    param_group["params"] = [mapping.get(p, p) for p in param_group["params"]]
        else:
            model = self.to_device(model)
        return model

    def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool:
        """Whether ``dataloader`` needs a ``DistributedSampler``: running distributed, not already using one,
        and not backed by an iterable dataset."""
        return (
            self._accelerator_connector.is_distributed
            and not isinstance(dataloader.sampler, DistributedSampler)
            and not has_iterable_dataset(dataloader)
        )

    @staticmethod
    def _get_distributed_sampler(dataloader: DataLoader, **kwargs: Any) -> DistributedSampler:
        """Create a ``DistributedSampler`` over the dataloader's dataset; ``seed`` defaults to the
        ``PL_GLOBAL_SEED`` environment variable (or 0 if unset) unless passed in ``kwargs``."""
        kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0)))
        return DistributedSampler(dataloader.dataset, **kwargs)

    def _check_accelerator_support(self, accelerator: Optional[Union[str, Accelerator]]) -> None:
        """Raise ``MisconfigurationException`` unless ``accelerator`` is ``None``, an ``Accelerator`` instance,
        or one of the supported names."""
        supported = [t.value.lower() for t in self._supported_device_types()] + ["auto"]
        valid = accelerator is None or isinstance(accelerator, Accelerator) or accelerator in supported
        if not valid:
            raise MisconfigurationException(
                f"`accelerator={repr(accelerator)}` is not a valid choice."
                f" Choose one of {supported} or pass in an `Accelerator` instance."
            )

    def _check_strategy_support(self, strategy: Optional[Union[str, Strategy]]) -> None:
        """Raise ``MisconfigurationException`` unless ``strategy`` is ``None``, a ``Strategy`` instance,
        or one of the supported names."""
        supported = [t.lower() for t in self._supported_strategy_types()]
        valid = strategy is None or isinstance(strategy, Strategy) or strategy in supported
        if not valid:
            raise MisconfigurationException(
                f"`strategy={repr(strategy)}` is not a valid choice."
                f" Choose one of {supported} or pass in a `Strategy` instance."
            )

    @staticmethod
    def _supported_device_types() -> Sequence[_AcceleratorType]:
        """Accelerator types Lite accepts by name."""
        return (
            _AcceleratorType.CPU,
            _AcceleratorType.GPU,
            _AcceleratorType.TPU,
        )

    @staticmethod
    def _supported_strategy_types() -> Sequence[_StrategyType]:
        """Strategy types Lite accepts by name."""
        return (
            _StrategyType.DP,
            _StrategyType.DDP,
            _StrategyType.DDP_SPAWN,
            _StrategyType.DEEPSPEED,
            _StrategyType.DDP_SHARDED,
            _StrategyType.DDP_SHARDED_SPAWN,
        )

    @staticmethod
    def _validate_setup(model: nn.Module, optimizers: Sequence[Optimizer]) -> None:
        """Reject a model or optimizer that has already been wrapped by :meth:`setup`."""
        if isinstance(model, _LiteModule):
            raise MisconfigurationException("A model should be passed only once to the `setup` method.")

        if any(isinstance(opt, _LiteOptimizer) for opt in optimizers):
            raise MisconfigurationException("An optimizer should be passed only once to the `setup` method.")

    @staticmethod
    def _validate_setup_dataloaders(dataloaders: Sequence[DataLoader]) -> None:
        """Reject dataloaders that were already wrapped or are not PyTorch ``DataLoader`` instances."""
        if any(isinstance(dl, _LiteDataLoader) for dl in dataloaders):
            raise MisconfigurationException("A dataloader should be passed only once to the `setup_dataloaders` method")

        if any(not isinstance(dl, DataLoader) for dl in dataloaders):
            raise MisconfigurationException("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.")
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Read PyTorch Lightning's Privacy Policy.