# Copyright The PyTorch Lightning team.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.importastimportcsvimportinspectimportloggingimportosfromargparseimportNamespacefromcopyimportdeepcopyfromenumimportEnumfromtypingimportAny,Callable,Dict,IO,MutableMapping,Optional,Unionfromwarningsimportwarnimporttorchimportyamlfrompytorch_lightning.utilitiesimport_OMEGACONF_AVAILABLE,AttributeDictfrompytorch_lightning.utilities.apply_funcimportapply_to_collectionfrompytorch_lightning.utilities.cloud_ioimportget_filesystemfrompytorch_lightning.utilities.cloud_ioimportloadaspl_loadfrompytorch_lightning.utilities.migrationimportpl_legacy_patchfrompytorch_lightning.utilities.parsingimportparse_class_init_keysfrompytorch_lightning.utilities.rank_zeroimportrank_zero_warnlog=logging.getLogger(__name__)PRIMITIVE_TYPES=(bool,int,float,str)ALLOWED_CONFIG_TYPES=(AttributeDict,MutableMapping,Namespace)if_OMEGACONF_AVAILABLE:fromomegaconfimportOmegaConffromomegaconf.dictconfigimportDictConfigfromomegaconf.errorsimportUnsupportedValueType,ValidationError# the older shall be on the topCHECKPOINT_PAST_HPARAMS_KEYS=("hparams","module_arguments")# used in 0.7.6
[docs]@classmethoddefload_from_checkpoint(cls,checkpoint_path:Union[str,IO],map_location:Optional[Union[Dict[str,str],str,torch.device,int,Callable]]=None,hparams_file:Optional[str]=None,strict:bool=True,**kwargs,):r""" Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments passed to ``__init__`` in the checkpoint under ``"hyper_parameters"``. Any arguments specified through \*\*kwargs will override args stored in ``"hyper_parameters"``. Args: checkpoint_path: Path to checkpoint. This can also be a URL, or file-like object map_location: If your checkpoint saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map to the new setup. The behaviour is the same as in :func:`torch.load`. hparams_file: Optional path to a .yaml file with hierarchical structure as in this example:: drop_prob: 0.2 dataloader: batch_size: 32 You most likely won't need this since Lightning will always save the hyperparameters to the checkpoint. However, if your checkpoint weights don't have the hyperparameters saved, use this method to pass in a .yaml file with the hparams you'd like to use. These will be converted into a :class:`~dict` and passed into your :class:`LightningModule` for use. If your model's ``hparams`` argument is :class:`~argparse.Namespace` and .yaml file has hierarchical structure, you need to refactor your model to treat ``hparams`` as :class:`~dict`. strict: Whether to strictly enforce that the keys in :attr:`checkpoint_path` match the keys returned by this module's state dict. kwargs: Any extra keyword args needed to init the model. Can also be used to override saved hyperparameter values. Return: :class:`LightningModule` instance with loaded weights and hyperparameters (if available). Note: ``load_from_checkpoint`` is a **class** method. You should use your :class:`LightningModule` **class** to call it instead of the :class:`LightningModule` instance. Example:: # load weights without mapping ... model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt') # or load weights mapping all weights from GPU 1 to GPU 0 ... map_location = {'cuda:1':'cuda:0'} model = MyLightningModule.load_from_checkpoint( 'path/to/checkpoint.ckpt', map_location=map_location ) # or load weights and hyperparameters from separate files. model = MyLightningModule.load_from_checkpoint( 'path/to/checkpoint.ckpt', hparams_file='/path/to/hparams_file.yaml' ) # override some of the params with new values model = MyLightningModule.load_from_checkpoint( PATH, num_layers=128, pretrained_ckpt_path=NEW_PATH, ) # predict pretrained_model.eval() pretrained_model.freeze() y_hat = pretrained_model(x) """withpl_legacy_patch():ifmap_locationisnotNone:checkpoint=pl_load(checkpoint_path,map_location=map_location)else:checkpoint=pl_load(checkpoint_path,map_location=lambdastorage,loc:storage)ifhparams_fileisnotNone:extension=hparams_file.split(".")[-1]ifextension.lower()=="csv":hparams=load_hparams_from_tags_csv(hparams_file)elifextension.lower()in("yml","yaml"):hparams=load_hparams_from_yaml(hparams_file)else:raiseValueError(".csv, .yml or .yaml is required for `hparams_file`")hparams["on_gpu"]=False# overwrite hparams by the given filecheckpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]=hparams# for past checkpoint need to add the new keyifcls.CHECKPOINT_HYPER_PARAMS_KEYnotincheckpoint:checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]={}# override the hparams with values that were passed incheckpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)model=cls._load_model_state(checkpoint,strict=strict,**kwargs)returnmodel
@classmethoddef_load_model_state(cls,checkpoint:Dict[str,Any],strict:bool=True,**cls_kwargs_new):cls_spec=inspect.getfullargspec(cls.__init__)cls_init_args_name=inspect.signature(cls.__init__).parameters.keys()self_var,args_var,kwargs_var=parse_class_init_keys(cls)drop_names=[nfornin(self_var,args_var,kwargs_var)ifn]cls_init_args_name=list(filter(lambdan:nnotindrop_names,cls_init_args_name))cls_kwargs_loaded={}# pass in the values we saved automaticallyifcls.CHECKPOINT_HYPER_PARAMS_KEYincheckpoint:# 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keysfor_old_hparam_keyinCHECKPOINT_PAST_HPARAMS_KEYS:cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key,{}))# 2. Try to restore model hparams from checkpoint using the new key_new_hparam_key=cls.CHECKPOINT_HYPER_PARAMS_KEYcls_kwargs_loaded.update(checkpoint.get(_new_hparam_key))# 3. Ensure that `cls_kwargs_old` has the right type, back compatibility between dict and Namespacecls_kwargs_loaded=_convert_loaded_hparams(cls_kwargs_loaded,checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE))# 4. Update cls_kwargs_new with cls_kwargs_old, such that new has higher priorityargs_name=checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)ifargs_nameandargs_nameincls_init_args_name:cls_kwargs_loaded={args_name:cls_kwargs_loaded}_cls_kwargs={}_cls_kwargs.update(cls_kwargs_loaded)_cls_kwargs.update(cls_kwargs_new)ifnotcls_spec.varkw:# filter kwargs according to class init unless it allows any argument via kwargs_cls_kwargs={k:vfork,vin_cls_kwargs.items()ifkincls_init_args_name}model=cls(**_cls_kwargs)# give model a chance to load somethingmodel.on_load_checkpoint(checkpoint)# load the state_dict on the model automaticallykeys=model.load_state_dict(checkpoint["state_dict"],strict=strict)ifnotstrict:ifkeys.missing_keys:rank_zero_warn(f"Found keys that are in the model state dict but not in the checkpoint: {keys.missing_keys}")ifkeys.unexpected_keys:rank_zero_warn(f"Found keys that are not in the model state dict but in the checkpoint: {keys.unexpected_keys}")returnmodel# -------------------------# OPTIONAL HOOKS# -------------------------
[docs]defon_hpc_save(self,checkpoint:Dict[str,Any])->None:"""Hook to do whatever you need right before Slurm manager saves the model. Args: checkpoint: A dictionary in which you can save variables to save in a checkpoint. Contents need to be pickleable. .. deprecated:: v1.6 This method is deprecated in v1.6 and will be removed in v1.8. Please use ``LightningModule.on_save_checkpoint`` instead. """
[docs]defon_hpc_load(self,checkpoint:Dict[str,Any])->None:"""Hook to do whatever you need right before Slurm manager loads the model. Args: checkpoint: A dictionary with variables from the checkpoint. .. deprecated:: v1.6 This method is deprecated in v1.6 and will be removed in v1.8. Please use ``LightningModule.on_load_checkpoint`` instead. """
def_convert_loaded_hparams(model_args:dict,hparams_type:Optional[Union[Callable,str]]=None)->object:"""Convert hparams according given type in callable or string (past) format."""# if not hparams type defineifnothparams_type:returnmodel_args# if past checkpoint loaded, convert str to callableifisinstance(hparams_type,str):hparams_type=AttributeDict# convert hparamsreturnhparams_type(model_args)defupdate_hparams(hparams:dict,updates:dict)->None:"""Overrides hparams with new values. >>> hparams = {'c': 4} >>> update_hparams(hparams, {'a': {'b': 2}, 'c': 1}) >>> hparams['a']['b'], hparams['c'] (2, 1) >>> update_hparams(hparams, {'a': {'b': 4}, 'c': 7}) >>> hparams['a']['b'], hparams['c'] (4, 7) Args: hparams: the original params and also target object updates: new params to be used as update """fork,vinupdates.items():# if missing, add the keyifknotinhparams:hparams[k]=vcontinue# recurse if dictionaryifisinstance(v,dict):update_hparams(hparams[k],updates[k])else:# update the valuehparams.update({k:v})defload_hparams_from_tags_csv(tags_csv:str)->Dict[str,Any]:"""Load hparams from a file. >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here') >>> path_csv = os.path.join('.', 'testing-hparams.csv') >>> save_hparams_to_tags_csv(path_csv, hparams) >>> hparams_new = load_hparams_from_tags_csv(path_csv) >>> vars(hparams) == hparams_new True >>> os.remove(path_csv) """fs=get_filesystem(tags_csv)ifnotfs.exists(tags_csv):rank_zero_warn(f"Missing Tags: {tags_csv}.",category=RuntimeWarning)return{}withfs.open(tags_csv,"r",newline="")asfp:csv_reader=csv.reader(fp,delimiter=",")tags={row[0]:convert(row[1])forrowinlist(csv_reader)[1:]}returntagsdefsave_hparams_to_tags_csv(tags_csv:str,hparams:Union[dict,Namespace])->None:fs=get_filesystem(tags_csv)ifnotfs.isdir(os.path.dirname(tags_csv)):raiseRuntimeError(f"Missing folder: {os.path.dirname(tags_csv)}.")ifisinstance(hparams,Namespace):hparams=vars(hparams)withfs.open(tags_csv,"w",newline="")asfp:fieldnames=["key","value"]writer=csv.DictWriter(fp,fieldnames=fieldnames)writer.writerow({"key":"key","value":"value"})fork,vinhparams.items():writer.writerow({"key":k,"value":v})defload_hparams_from_yaml(config_yaml:str,use_omegaconf:bool=True)->Dict[str,Any]:"""Load hparams from a file. Args: config_yaml: Path to config yaml file use_omegaconf: If omegaconf is available and ``use_omegaconf=True``, the hparams will be converted to ``DictConfig`` if possible. >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here') >>> path_yaml = './testing-hparams.yaml' >>> save_hparams_to_yaml(path_yaml, hparams) >>> hparams_new = load_hparams_from_yaml(path_yaml) >>> vars(hparams) == hparams_new True >>> os.remove(path_yaml) """fs=get_filesystem(config_yaml)ifnotfs.exists(config_yaml):rank_zero_warn(f"Missing Tags: {config_yaml}.",category=RuntimeWarning)return{}withfs.open(config_yaml,"r")asfp:hparams=yaml.full_load(fp)if_OMEGACONF_AVAILABLE:ifuse_omegaconf:try:returnOmegaConf.create(hparams)except(UnsupportedValueType,ValidationError):passreturnhparamsdefsave_hparams_to_yaml(config_yaml,hparams:Union[dict,Namespace],use_omegaconf:bool=True)->None:""" Args: config_yaml: path to new YAML file hparams: parameters to be saved use_omegaconf: If omegaconf is available and ``use_omegaconf=True``, the hparams will be converted to ``DictConfig`` if possible. """fs=get_filesystem(config_yaml)ifnotfs.isdir(os.path.dirname(config_yaml)):raiseRuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.")# convert Namespace or AD to dictifisinstance(hparams,Namespace):hparams=vars(hparams)elifisinstance(hparams,AttributeDict):hparams=dict(hparams)# saving with OmegaConf objectsif_OMEGACONF_AVAILABLEanduse_omegaconf:# deepcopy: hparams from user shouldn't be resolvedhparams=deepcopy(hparams)hparams=apply_to_collection(hparams,DictConfig,OmegaConf.to_container,resolve=True)withfs.open(config_yaml,"w",encoding="utf-8")asfp:try:OmegaConf.save(hparams,fp)returnexcept(UnsupportedValueType,ValidationError):passifnotisinstance(hparams,dict):raiseTypeError("hparams must be dictionary")hparams_allowed={}# drop parameters which contain some strange datatypes as fsspecfork,vinhparams.items():try:v=v.nameifisinstance(v,Enum)elsevyaml.dump(v)exceptTypeError:warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.")hparams[k]=type(v).__name__else:hparams_allowed[k]=v# saving the standard waywithfs.open(config_yaml,"w",newline="")asfp:yaml.dump(hparams_allowed,fp)defconvert(val:str)->Union[int,float,bool,str]:try:returnast.literal_eval(val)except(ValueError,SyntaxError)aserr:log.debug(err)returnval
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Read PyTorch Lightning's Privacy Policy.