Custom model definition is not included in checkpoint hyper_parameters

Hi, I have the following dummy LightningModule:

class MyLightningModule(LightningModule):
    def __init__(
        self,
        param_1: torch.nn.Module = torch.nn.Conv2d(1, 1, 1),
        param_2: torch.nn.Module = MyCustomModule(...),
    ):
        super().__init__()
        self.save_hyperparameters()
        print(self.hparams.param_1)  # prints out correctly
        print(self.hparams.param_2)  # prints out correctly

When I tried to load a checkpoint via MyLightningModule.load_from_checkpoint(ckpt_path), I noticed that checkpoint["hyper_parameters"] does NOT contain a key for param_2, while it DOES contain a key for param_1. I DO see hparams.param_2 correctly printed in my logger, which is really weird.

For param_2 I used a network from the escnn library, which is derived from torch.nn.Module. I traced the problem back to using any layer from that library:

import escnn
import escnn.nn as enn

gspace = escnn.gspaces.rot2dOnR2(8)
field_type = enn.FieldType(gspace, [gspace.regular_repr])
param_1 = enn.R2Conv(field_type, field_type, 7)

What could be the reason that the custom model definition is not part of the checkpoint? Thanks in advance!


Here is a demonstration of what you want based on your code example:

import torch
from torch.utils.data import DataLoader, Dataset

from lightning.pytorch import LightningModule, Trainer

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

MyCustomModule = torch.nn.Linear

class MyLightningModule(LightningModule):
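    def __init__(
        self,
        param_1=torch.nn.Conv2d(1, 1, 1),
        param_2=MyCustomModule(2, 2),
    ):
        super().__init__()
        # stores param_1 and param_2 in self.hparams and in the checkpoint
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(32, 2)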

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    model = MyLightningModule()
    trainer = Trainer(max_steps=1)
    trainer.fit(model, train_dataloaders=train_data)

    path = trainer.checkpoint_callback.best_model_path
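    # weights_only=False lets newer PyTorch versions unpickle the stored module objects
    checkpoint = torch.load(path, weights_only=False)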
    print("hyper parameters:", list(checkpoint["hyper_parameters"].keys()))
    # hyper parameters: ['param_1', 'param_2']

    new_model = MyLightningModule.load_from_checkpoint(path)
    # works!

if __name__ == "__main__":
    run()

This works (I used Lightning 2.0); feel free to run it and check the outputs.
While this can be done, I highly recommend not pickling entire module objects as "hyperparameters" into your checkpoint.

Best practice is to exclude modules passed to __init__, like so:

self.save_hyperparameters(ignore=["param_1", "param_2"])

When loading the checkpoint, you can pass param_1 and param_2 as additional keyword arguments:

model1 = ...
model2 = ...
new_model = MyLightningModule.load_from_checkpoint(path, param_1=model1, param_2=model2)

Helpful docs

Hope this helps.

Hi @awaelchli,

Thanks for your reply and your example. It runs smoothly. Debugging from there, I found out that my issue was that my submodule is not picklable and is therefore removed from hparams.
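A quick way to confirm this for your own hyperparameters is to try pickling each value. As far as I know, Lightning runs a similar check in `save_hyperparameters()` and drops entries that fail it, with a warning that the attribute cannot be pickled. The sketch below is a standalone approximation, not Lightning's code: `unpicklable_keys` is a hypothetical helper, and the dict holding a lambda only stands in for an unpicklable submodule:

```python
import pickle
from typing import Any, Dict, List


def unpicklable_keys(hparams: Dict[str, Any]) -> List[str]:
    """Return the names of hyperparameters whose values cannot be pickled."""
    bad = []
    for name, value in hparams.items():
        try:
            pickle.dumps(value)
        except Exception:  # e.g. PicklingError, AttributeError, TypeError
            bad.append(name)
    return bad


# Stand-ins: in practice the values would be your nn.Module instances.
hparams = {
    "param_1": {"kernel_size": 7},               # plain data pickles fine
    "param_2": {"activation": lambda x: 2 * x},  # lambdas cannot be pickled
    "lr": 0.1,
}
print(unpicklable_keys(hparams))  # ['param_2']
```

Calling it on a dict of your real modules should point at whichever one contains the unpicklable part.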

Regarding your suggestion: I have a LightningModule which defines a training procedure. I want to use different networks, each with its own parameters, within this procedure. Most importantly, I want to recover everything solely from a checkpoint. Can I do something like this without saving the entire submodule to hparams? I believe that in

model1 = ...

I would need to know beforehand which submodules were used with the checkpoint, right? Thanks in advance!