ModelIO
- class pytorch_lightning.core.saving.ModelIO
Bases: object
- classmethod load_from_checkpoint(checkpoint_path, map_location=None, hparams_file=None, strict=True, **kwargs)
Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments passed to __init__ in the checkpoint under "hyper_parameters". Any arguments specified through **kwargs will override args stored in "hyper_parameters".
- Parameters
  - checkpoint_path (Union[str, IO]) – Path to checkpoint. This can also be a URL, or a file-like object.
  - map_location (Union[Dict[str, str], str, device, int, Callable, None]) – If your checkpoint saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map to the new setup. The behaviour is the same as in torch.load().
  - hparams_file (Optional[str]) – Optional path to a .yaml file with hierarchical structure as in this example:

    drop_prob: 0.2
    dataloader:
        batch_size: 32

    You most likely won't need this since Lightning will always save the hyperparameters to the checkpoint. However, if your checkpoint weights don't have the hyperparameters saved, use this method to pass in a .yaml file with the hparams you'd like to use. These will be converted into a dict and passed into your LightningModule for use.

    If your model's hparams argument is Namespace and the .yaml file has hierarchical structure, you need to refactor your model to treat hparams as dict (see the sketch after this list).
  - strict (bool) – Whether to strictly enforce that the keys in checkpoint_path match the keys returned by this module's state dict.
  - kwargs – Any extra keyword args needed to init the model. Can also be used to override saved hyperparameter values.
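A hedged sketch of that Namespace-to-dict refactor (the module name is invented for illustration; the hyperparameter keys mirror the .yaml example above). A model that indexes self.hparams like a dict can consume hierarchical hyperparameters directly:

import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):  # hypothetical module for illustration
    def __init__(self, **hparams):
        super().__init__()
        # save_hyperparameters() stores the passed args under
        # "hyper_parameters" in every checkpoint and exposes them
        # as the dict-like self.hparams
        self.save_hyperparameters()
        # dict-style access, so the hierarchical .yaml above works
        self.dropout = torch.nn.Dropout(self.hparams["drop_prob"])
        # nested keys from the .yaml arrive as plain dicts
        self.batch_size = self.hparams["dataloader"]["batch_size"]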
- Returns
  LightningModule instance with loaded weights and hyperparameters (if available).
Note
load_from_checkpoint is a class method. You should use your LightningModule class to call it instead of the LightningModule instance.
Example:
# load weights without mapping ...
model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

# or load weights mapping all weights from GPU 1 to GPU 0 ...
map_location = {'cuda:1': 'cuda:0'}
model = MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    map_location=map_location
)

# or load weights and hyperparameters from separate files.
model = MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    hparams_file='/path/to/hparams_file.yaml'
)

# override some of the params with new values
model = MyLightningModule.load_from_checkpoint(
    PATH,
    num_layers=128,
    pretrained_ckpt_path=NEW_PATH,
)

# predict
pretrained_model.eval()
pretrained_model.freeze()
y_hat = pretrained_model(x)
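A short, hedged sketch of the strict flag (the checkpoint path and module name are placeholders): with strict=False, loading proceeds even when the checkpoint's keys and the module's state dict keys differ, e.g. after adding a new layer:

# assumption: MyLightningModule gained a layer since the checkpoint
# was written, so its state dict has keys the checkpoint lacks
model = MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    strict=False,  # don't raise on missing/unexpected keys
)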
- on_hpc_load(checkpoint)
Hook to do whatever you need right before the Slurm manager loads the model.
Deprecated since version v1.6: This method is deprecated in v1.6 and will be removed in v1.8. Please use LightningModule.on_load_checkpoint instead.
- Return type
  None
- on_hpc_save(checkpoint)
Hook to do whatever you need right before the Slurm manager saves the model.
- Parameters
  - checkpoint (Dict[str, Any]) – A dictionary in which you can store variables to be saved with the checkpoint. Contents need to be pickleable.
Deprecated since version v1.6: This method is deprecated in v1.6 and will be removed in v1.8. Please use LightningModule.on_save_checkpoint instead.
- Return type
  None
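Correspondingly, a hedged sketch of migrating to LightningModule.on_save_checkpoint (the module name and key are again illustrative):

import pytorch_lightning as pl


class LitModel(pl.LightningModule):  # placeholder module
    # replaces on_hpc_save; called whenever Lightning writes a checkpoint
    def on_save_checkpoint(self, checkpoint):
        # contents must be pickleable, as noted above
        checkpoint["some_state"] = self.some_state  # illustrative key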