Source code for pytorch_lightning.strategies.single_device
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import Any

import torch

import pytorch_lightning as pl
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.strategy import Strategy
from pytorch_lightning.utilities.types import _DEVICE
class SingleDeviceStrategy(Strategy):
    """Strategy that handles communication on a single device."""

    strategy_name = "single_device"

    def __init__(
        self,
        device: _DEVICE = "cpu",
        accelerator: pl.accelerators.accelerator.Accelerator | None = None,
        checkpoint_io: CheckpointIO | None = None,
        precision_plugin: PrecisionPlugin | None = None,
    ):
        super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin)
        self._root_device = torch.device(device)
        self.global_rank = 0
        self.local_rank = 0
        self.world_size = 1
    def reduce(self, tensor: Any | torch.Tensor, *args: Any, **kwargs: Any) -> Any | torch.Tensor:
        """Reduces a tensor from several distributed processes to one aggregated tensor. As this plugin only
        operates with a single device, the reduction is simply the identity.

        Args:
            tensor: the tensor to sync and reduce
            *args: ignored
            **kwargs: ignored

        Return:
            the unmodified input as reduction is not needed for single process operation
        """
        return tensor
    def all_gather(self, tensor: torch.Tensor, group: Any | None = None, sync_grads: bool = False) -> torch.Tensor:
        """Perform an all_gather on all processes."""
        return tensor
    def teardown(self) -> None:
        super().teardown()
        if self.root_device.type == "cuda":
            # GPU teardown: move the model back to CPU
            self.lightning_module.cpu()
            # clean up memory
            torch.cuda.empty_cache()
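A minimal sketch of how this strategy behaves, assuming the import path shown in this file; it is not part of the module itself. Because there is only one process, ``reduce`` and ``all_gather`` simply return their input unchanged:

# Illustrative usage sketch (assumption: not part of this module).
import torch

from pytorch_lightning.strategies.single_device import SingleDeviceStrategy

strategy = SingleDeviceStrategy(device="cpu")

t = torch.tensor([1.0, 2.0, 3.0])
assert strategy.reduce(t) is t       # reduction is the identity on one device
assert strategy.all_gather(t) is t   # gathering from a single process is the identity
assert strategy.world_size == 1      # exactly one process

In practice the Trainer selects this strategy automatically when a single device is requested, though an instance can also be passed explicitly via ``Trainer(strategy=...)``.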