Source code for pytorch_lightning.utilities.cloud_io
# Copyright The PyTorch Lightning team.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License."""Utilities related to data saving/loading."""importiofrompathlibimportPathfromtypingimportAny,Callable,Dict,IO,Optional,Unionimportfsspecimporttorchfromfsspec.coreimporturl_to_fsfromfsspec.implementations.localimportAbstractFileSystemfrompytorch_lightning.utilities.typesimport_PATH
def load(
    path_or_url: Union[IO, _PATH],
    map_location: Optional[
        Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]]
    ] = None,
) -> Any:
    """Loads a checkpoint.

    Dispatch rules:

    * anything that is not a ``str`` or :class:`~pathlib.Path` (e.g. an ``io.BytesIO``
      or an already-open binary file object) is handed straight to :func:`torch.load`;
    * a string/path starting with ``http://`` or ``https://`` is downloaded with
      :func:`torch.hub.load_state_dict_from_url`;
    * every other path is opened in binary mode through the filesystem returned by
      ``get_filesystem`` and then passed to :func:`torch.load`.

    Args:
        path_or_url: Path or URL of the checkpoint, or a readable binary file-like object.
        map_location: a function, ``torch.device``, string or a dict specifying how to remap
            storage locations. Forwarded unchanged to the underlying loader.

    Returns:
        The object deserialized from the checkpoint.
    """
    if not isinstance(path_or_url, (str, Path)):
        # any sort of BytesIO or similar
        return torch.load(path_or_url, map_location=map_location)

    path_str = str(path_or_url)
    # Match the full scheme rather than the bare "http" prefix: otherwise local paths such as
    # "http_checkpoints/last.ckpt" would be wrongly routed to the URL downloader.
    if path_str.startswith(("http://", "https://")):
        return torch.hub.load_state_dict_from_url(path_str, map_location=map_location)

    fs = get_filesystem(path_or_url)
    with fs.open(path_or_url, "rb") as f:
        return torch.load(f, map_location=map_location)
def atomic_save(checkpoint: Dict[str, Any], filepath: Union[str, Path]) -> None:
    """Saves a checkpoint, serializing it completely before touching the destination.

    The object is first written with :func:`torch.save` into an in-memory buffer. Only once
    that succeeds is ``filepath`` opened (via ``fsspec.open`` in ``"wb"`` mode) and the
    serialized bytes written in a single ``write`` call. A serialization error therefore never
    leaves an empty or truncated file behind. Note that this is not a rename-based atomic
    replace: a failure during the final write itself is not guarded against here.

    Args:
        checkpoint: The object to save. Built to be used with the ``dump_checkpoint`` method,
            but can deal with anything which ``torch.save`` accepts.
        filepath: The path to which the checkpoint will be saved. This points to the file that
            the checkpoint will be stored in.
    """
    # Serialize fully in memory first so that a failing torch.save never opens the target.
    staging = io.BytesIO()
    torch.save(checkpoint, staging)
    payload = staging.getvalue()

    with fsspec.open(filepath, "wb") as destination:
        destination.write(payload)
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Read PyTorch Lightning's Privacy Policy.