Problem
I’m running into an issue where the model trains fine, and the saved checkpoint does indeed contain the hparams used during training. When loading the model with MyModel.load_from_checkpoint(), however, these hparams are not restored.
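As a minimal illustration of what I expect (assuming the full script below, test_save_load.py, has already been run once and produced example.ckpt):

from test_save_load import MyModel  # the script shown further down

# The checkpoint's hyper_parameters contain sample_rate=8000, so I would expect
# the restored model to have it in self.hparams again.
model = MyModel.load_from_checkpoint("example.ckpt", pretrained_hparams=False)
print(model.hparams["sample_rate"])  # expected: 8000, actual: KeyError inside __init__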
Code breakdown
Apologies for the length: the following code is the smallest working version I could make that can be executed and replicates my issue. I’m trying to support various non-Lightning pre-trained PyTorch weights and models (which I hope to make open source).
Simply said, what each class does (a condensed skeleton follows this list):
- ModelBase: loads the model’s pre-trained weights based on the given sample_rate value. Also sets the other hparams to how the model was trained.
- Linear3: one of the many models that can be selected. Inherits from ModelBase.
- ModelTrainer: does the transfer learning. Can be swapped out so the user can choose whether to train for a multi-class or a multi-label task.
- MyModel: inherits from both ModelTrainer and Linear3 to construct a “normal” PyTorch Lightning LightningModule.
- SimpleDataset & SimpleDatamodule: only there to be able to use the Trainer and its save_checkpoint function.
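For orientation, here is a condensed, illustrative skeleton of that hierarchy (bodies elided; taken from the script below). Every __init__ forwards **kwargs up the MRO, which is MyModel → ModelTrainer → Linear3 → ModelBase → pl.LightningModule:

import pytorch_lightning as pl

class ModelBase(pl.LightningModule):
    def __init__(self, pretrained_hparams: bool, **kwargs): ...  # may call self.save_hyperparameters()

class Linear3(ModelBase):
    def __init__(self, sample_rate, **kwargs): ...               # builds its layers from self.hparams

class ModelTrainer(pl.LightningModule):
    def __init__(self, learning_rate=1e-3, **kwargs): ...        # training/validation/test steps

class MyModel(ModelTrainer, Linear3):
    def __init__(self, unfreeze_epoch=1, **kwargs): ...          # the module that is actually trained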
The following code can also be downloaded from GitHub: “Demonstrates issue of self.hparams not being restored when loading from checkpoint. Details can be found here: https://lightning.ai/forums/t/hparams-not-restored-when-using-load-from-checkpoint-default-argument-values-are-the-problem/237”
Script
from abc import abstractmethod

import torch
from torch import nn as nn
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy, to_categorical
from pytorch_lightning import Trainer
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F


class ModelBase(pl.LightningModule):
    def __init__(self, pretrained_hparams: bool, **kwargs):  # kwargs contains e.g. sample_rate
        print(f"Init ModelBase, hparams:\n{self.hparams}\n")
        super().__init__()
        print(f"Init ModelBase after, hparams:\n{self.hparams}\n")

        # use PANNs.load_from_checkpoint when loading weights after transfer learning
        if pretrained_hparams:
            # save all arguments in self.hparams
            self.save_hyperparameters()
            print("Argument hparams: ", self.hparams)
            # needed hparams for non-lightning pre-trained weights
            self.set_pretrained_hparams()
            # print("All hparams: ", self.hparams)

    @abstractmethod
    def forward(self, x):
        pass

    def set_pretrained_hparams(self):
        if self.hparams["sample_rate"] == 8000:
            self.hparams["hlayer1"] = 400
        elif self.hparams["sample_rate"] == 16000:
            self.hparams["hlayer1"] = 800
        self.hparams["classes_num"] = 3

    def load_non_lightning_weights(self, weights_path):
        # checkpoint = torch.load(weights_path)
        # self.load_state_dict(checkpoint['model'])
        pass


# 1 variant
class Linear3(ModelBase):
    def __init__(self, sample_rate, **kwargs):
        print(f"Init Linear3, hparams:\n{self.hparams}\n")
        super().__init__(sample_rate=sample_rate, **kwargs)
        print(f"Init Linear3 after, hparams:\n{self.hparams}\n")

        # 1 sec of audio
        self.input_layer = nn.Linear(self.hparams["sample_rate"], self.hparams["hlayer1"], bias=True)
        self.hidden_layer = nn.Linear(self.hparams["hlayer1"], 128, bias=True)
        self.output_layer = nn.Linear(128, self.hparams["classes_num"], bias=True)

    def forward(self, input):
        x = F.relu_(self.input_layer(input))
        x = F.relu_(self.hidden_layer(x))
        output = self.output_layer(x)  # torch.sigmoid()
        return output


class ModelTrainer(pl.LightningModule):
    # arguments should NOT be positional due to inheritance; always have a default value
    def __init__(self, learning_rate=1e-3, **kwargs):
        print(f"Init ModelTrainer, hparams:\n{self.hparams}\n")
        # everything included in the init call will be included in self.hparams
        # (here only kwargs is included), meaning only those will be saved in a .ckpt file
        super().__init__(learning_rate=learning_rate, **kwargs)
        print(f"Init ModelTrainer after, hparams:\n{self.hparams}\n")

        self.criterion = nn.CrossEntropyLoss()

    def calculate_loss(self, prediction, target):
        """Cross-entropy loss"""
        # loss = F.binary_cross_entropy_with_logits(prediction, target)
        loss = self.criterion(prediction, target)
        return loss

    def training_step(self, batch, batch_idx):
        input, target = batch
        prediction = self(input)
        loss = self.calculate_loss(prediction, target)
        result = pl.TrainResult(minimize=loss)
        result.log('train_loss', loss)
        return result

    def validation_step(self, batch, batch_idx):
        input, target = batch
        prediction = self(input)
        loss = self.calculate_loss(prediction, target)
        result = pl.EvalResult(checkpoint_on=loss)
        result.log('val_loss', loss)
        result.log('val_acc', accuracy(prediction, target))
        return result

    def test_step(self, batch, batch_idx):
        input, target = batch
        prediction = self(input)
        loss = self.calculate_loss(prediction, target)
        result = pl.EvalResult()  # checkpoint_on=loss
        result.log('test_loss', loss)
        result.log('test_acc', accuracy(prediction, target))  # to_categorical()
        return result

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams["learning_rate"])


class MyModel(ModelTrainer, Linear3):
    def __init__(self, unfreeze_epoch=1, **kwargs):
        # arguments passed here are stored in self.hparams
        print(f"Init MyModel, hparams:\n{self.hparams}\n")
        super().__init__(unfreeze_epoch=unfreeze_epoch, **kwargs)
        print(f"Init MyModel after, hparams:\n{self.hparams}\n")
        # print("hparams after init: ", self.hparams)
        # self.unfreeze_epoch = unfreeze_epoch
        # self.freeze()

    def forward(self, input, mixup_lambda=None):
        # unfreeze deep layers after unfreeze_epoch epochs
        # if self.current_epoch == self.unfreeze_epoch:
        #     self.unfreeze()
        x = F.relu_(self.input_layer(input))
        x = F.relu_(self.hidden_layer(x))
        output = self.output_layer(x)  # torch.sigmoid()
        return output


# DATA
class SimpleDataset(Dataset):
    def __init__(self, sample_rate=8000):
        self.sample_rate = sample_rate

    def __len__(self):
        return 16

    def __getitem__(self, idx):
        # target is 0, 1 or 2
        target = torch.randint(0, 3, size=(1,)).squeeze()
        # input is a tensor of size 8000/16000 filled with 0.0, 0.5, or 1.0
        input = torch.full((self.sample_rate,), (target.float() / 2).item())
        # torch.empty(self.sample_rate,).fill_(target.float()/2)
        return input, target


class SimpleDatamodule(pl.LightningDataModule):
    def setup(self, stage: str = None):
        pass

    def train_dataloader(self):
        return DataLoader(SimpleDataset(), batch_size=4)

    def val_dataloader(self):
        return DataLoader(SimpleDataset(), batch_size=4)
        # dataset = self._set_dataset_split("val")
        # return DataLoader(dataset, batch_size=self.hparams["batch_size"],
        #                   sampler=SubsetRandomSampler(dataset.indices), num_workers=4)

    def test_dataloader(self):
        return DataLoader(SimpleDataset(), batch_size=4)


if __name__ == '__main__':
    sr = 8000
    checkpoint_location = "example.ckpt"

    # network
    model = MyModel(sample_rate=sr, pretrained_hparams=True)
    print(f"After all init, hparams:\n{model.hparams}\n")

    # data
    dm = SimpleDatamodule()

    # train
    trainer = Trainer(max_epochs=4, deterministic=True)  # gpus=1,
    trainer.fit(model, dm)

    # save
    trainer.save_checkpoint(checkpoint_location)

    # check model contents
    print(f"\n\nModel save completed. Checking contents saved model...")
    checkpoint = torch.load(checkpoint_location)
    print(f"Checkpoint hyper parameters: {checkpoint['hyper_parameters']}")  # .keys() # ['state_dict']

    # ERROR: load weights into new model
    print("\nContents check completed. Trying to restore model with checkpoint...")
    model2 = MyModel.load_from_checkpoint(checkpoint_location, pretrained_hparams=False)
    # KeyError: 'sample_rate'
Script output
$ python test_save_load.py
Init MyModel, hparams:
Init ModelTrainer, hparams:
Init Linear3, hparams:
Init ModelBase, hparams:
Init ModelBase after, hparams:
Argument hparams: "learning_rate": 0.001
"pretrained_hparams": True
"sample_rate": 8000
"unfreeze_epoch": 1
Init Linear3 after, hparams:
"classes_num": 3
"hlayer1": 400
"learning_rate": 0.001
"pretrained_hparams": True
"sample_rate": 8000
"unfreeze_epoch": 1
Init ModelTrainer after, hparams:
"classes_num": 3
"hlayer1": 400
"learning_rate": 0.001
"pretrained_hparams": True
"sample_rate": 8000
"unfreeze_epoch": 1
Init MyModel after, hparams:
"classes_num": 3
"hlayer1": 400
"learning_rate": 0.001
"pretrained_hparams": True
"sample_rate": 8000
"unfreeze_epoch": 1
After all init, hparams:
{self.hparams}
GPU available: True, used: False
TPU available: False, using: 0 TPU cores
/home/stefempath/anaconda3/envs/pytorchlit/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:37: UserWarning: GPU available but not used. Set the --gpus flag when calling the script.
warnings.warn(*args, **kwargs)
/home/stefempath/anaconda3/envs/pytorchlit/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:37: UserWarning: Could not log computational graph since the `model.example_input_array` attribute is not set or `input_array` was not given
warnings.warn(*args, **kwargs)
  | Name         | Type             | Params
--------------------------------------------------
0 | input_layer  | Linear           | 3 M
1 | hidden_layer | Linear           | 51 K
2 | output_layer | Linear           | 387
3 | criterion    | CrossEntropyLoss | 0
/home/stefempath/anaconda3/envs/pytorchlit/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:37: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 16 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
warnings.warn(*args, **kwargs)
/home/stefempath/anaconda3/envs/pytorchlit/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:37: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 16 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
warnings.warn(*args, **kwargs)
Epoch 3: 100%|████████████████Saving latest checkpoint..███████████████████████████| 8/8 [00:00<00:00, 103.70it/s, loss=1.651, v_num=28]
Epoch 3: 100%|█████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 102.18it/s, loss=1.651, v_num=28]
Model save completed. Checking contents saved model...
Checkpoint hyper parameters: "classes_num": 3
"hlayer1": 400
"learning_rate": 0.001
"pretrained_hparams": True
"sample_rate": 8000
"unfreeze_epoch": 1
Contents check completed. Trying to restore model with checkpoint...
Init MyModel, hparams:
Init ModelTrainer, hparams:
Init Linear3, hparams:
Init ModelBase, hparams:
Init ModelBase after, hparams:
Init Linear3 after, hparams:
Traceback (most recent call last):
  File "test_save_load.py", line 183, in <module>
    model2 = MyModel.load_from_checkpoint(checkpoint_location, pretrained_hparams=False)
  File "/home/stefempath/anaconda3/envs/pytorchlit/lib/python3.7/site-packages/pytorch_lightning/core/saving.py", line 153, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, *args, strict=strict, **kwargs)
  File "/home/stefempath/anaconda3/envs/pytorchlit/lib/python3.7/site-packages/pytorch_lightning/core/saving.py", line 190, in _load_model_state
    model = cls(*cls_args, **cls_kwargs)
  File "test_save_load.py", line 111, in __init__
    super().__init__(unfreeze_epoch=unfreeze_epoch, **kwargs)
  File "test_save_load.py", line 67, in __init__
    super().__init__(learning_rate=learning_rate, **kwargs)
  File "test_save_load.py", line 50, in __init__
    self.input_layer = nn.Linear(self.hparams["sample_rate"], self.hparams["hlayer1"], bias=True)
KeyError: 'sample_rate'
Closer look at the issue
From the following code and output
checkpoint = torch.load(checkpoint_location)
print(f"Checkpoint hyper parameters: {checkpoint['hyper_parameters']}")
# Checkpoint hyper parameters:
# "classes_num": 3
# "hlayer1": 400
# "learning_rate": 0.001
# "pretrained_hparams": True
# "sample_rate": 8000
# "unfreeze_epoch": 1
we can see that nothing went wrong with training and storing the model, as sample_rate is indeed present in the checkpoint.
However, the following code
model2 = MyModel.load_from_checkpoint(checkpoint_location, pretrained_hparams=False)
# self.input_layer = nn.Linear(self.hparams["sample_rate"], self.hparams["hlayer1"], bias=True)
# KeyError: 'sample_rate'
fails with a missing key. From the full script output we can also see that, both before and after the call to super().__init__(), there are NO values stored in self.hparams.
From the documentation, you would call self.save_hyperparameters() if you DON’T want to use the values stored in the checkpoint. The loading call passes pretrained_hparams=False, which means that this block is NOT executed:

# not executed during loading, because pretrained_hparams is False
if pretrained_hparams:
    # save all arguments in self.hparams
    self.save_hyperparameters()
    print("Argument hparams: ", self.hparams)
    # needed hparams for non-lightning pre-trained weights
    self.set_pretrained_hparams()

Therefore it shouldn’t be an issue of me overriding self.hparams.
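To make the failure path concrete, this is how I read the traceback above: a rough, hypothetical simplification of the load path (the helper name and its arguments are made up for illustration, this is not the actual pytorch_lightning code):

import torch

def sketch_load_from_checkpoint(cls, ckpt_path, **extra_kwargs):
    # Hypothetical simplification, not Lightning's implementation.
    checkpoint = torch.load(ckpt_path)
    init_kwargs = dict(checkpoint["hyper_parameters"])  # hparams stored at save time
    init_kwargs.update(extra_kwargs)                    # e.g. pretrained_hparams=False
    # The class is simply re-instantiated; as the script output shows, self.hparams is
    # still empty while __init__ runs, so self.hparams["sample_rate"] raises KeyError.
    model = cls(**init_kwargs)
    model.load_state_dict(checkpoint["state_dict"])
    return model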
Question
Why is self.hparams not properly restored when I use .load_from_checkpoint(), even though I avoid calling self.save_hyperparameters() during loading (so nothing should be overriding the restored values)?
System
- Ubuntu 18.04
- PyTorch: 1.6.0
- PyTorch Lightning: 0.9.0 & 0.10.0 (tested both)
Reproduce conda environment
conda create --name pytorchlit10
conda activate pytorchlit10
conda install pytorch torchvision cudatoolkit=10.2 -c pytorch
pip install pytorch-lightning # 0.10.0