Source code for pytorch_lightning.plugins.precision.double
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, cast, Generator, List, Tuple

import torch
import torch.nn as nn
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities.apply_func import apply_to_collection


class LightningDoublePrecisionModule(_LightningPrecisionModuleWrapperBase):
    """LightningModule wrapper which converts incoming floating point data in ``*_step`` and
    ``forward`` to double (``torch.float64``) precision.

    Args:
        pl_module: the model to wrap
    """

    @staticmethod
    def _to_double_precision(data: torch.Tensor) -> torch.Tensor:
        """Cast a single tensor to ``torch.float64``.

        Only floating-point tensors are converted; integer/bool tensors
        (e.g. labels, masks) pass through unchanged.
        """
        if data.is_floating_point():
            return data.double()
        return data

    @staticmethod
    def _move_float_tensors_to_double(collection: Any) -> Any:
        """Recursively cast every floating-point tensor in ``collection`` to double."""
        return apply_to_collection(
            collection, torch.Tensor, LightningDoublePrecisionModule._to_double_precision
        )

    def _forward_with_double_inputs(self, fn: Any, args: Any, kwargs: Any) -> Any:
        # Shared driver for all ``*_step``/``forward`` wrappers: convert the
        # positional and keyword arguments, then dispatch to the wrapped hook.
        convert = LightningDoublePrecisionModule._move_float_tensors_to_double
        return fn(*convert(args), **convert(kwargs))

    def training_step(self, *args: Any, **kwargs: Any) -> Any:
        """Run ``training_step`` of the wrapped module with double-precision inputs."""
        return self._forward_with_double_inputs(self.module.training_step, args, kwargs)

    def validation_step(self, *args: Any, **kwargs: Any) -> Any:
        """Run ``validation_step`` of the wrapped module with double-precision inputs."""
        return self._forward_with_double_inputs(self.module.validation_step, args, kwargs)

    def test_step(self, *args: Any, **kwargs: Any) -> Any:
        """Run ``test_step`` of the wrapped module with double-precision inputs."""
        return self._forward_with_double_inputs(self.module.test_step, args, kwargs)

    def predict_step(self, *args: Any, **kwargs: Any) -> Any:
        """Run ``predict_step`` of the wrapped module with double-precision inputs."""
        return self._forward_with_double_inputs(self.module.predict_step, args, kwargs)

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Call the wrapped module directly with double-precision inputs."""
        return self._forward_with_double_inputs(self.module, args, kwargs)
class DoublePrecisionPlugin(PrecisionPlugin):
    """Plugin for training with double (``torch.float64``) precision."""

    # Reported precision level; 64 distinguishes this plugin from the
    # default 32-bit and the 16-bit mixed-precision plugins.
    precision: int = 64

    def connect(
        self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
    ) -> Tuple[nn.Module, List["Optimizer"], List[Any]]:
        """Converts the model to double precision and wraps it in a
        ``LightningDoublePrecisionModule`` to convert incoming floating point data to double
        (``torch.float64``) precision.

        Does not alter `optimizers` or `lr_schedulers`.
        """
        model = cast(pl.LightningModule, model.double())
        model = LightningDoublePrecisionModule(model)
        return super().connect(model, optimizers, lr_schedulers)

    @contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        """A context manager to change the default tensor type.

        See: :meth:`torch.set_default_tensor_type`
        """
        torch.set_default_tensor_type(torch.DoubleTensor)
        try:
            yield
        finally:
            # Restore the float32 default even if the forward pass raises;
            # otherwise the process-wide default would be silently left at
            # double precision for all subsequent tensor creation.
            torch.set_default_tensor_type(torch.FloatTensor)
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Read PyTorch Lightning's Privacy Policy.