Hi, I have the following dummy LightningModule:
class MyLightningModule(LightningModule):
    def __init__(
        self,
        param_1: torch.nn.Module = torch.nn.Conv2d(1, 1, 1),
        param_2: torch.nn.Module = MyCustomModule(...),
    ):
        super().__init__()
        # Store all __init__ arguments in self.hparams
        self.save_hyperparameters()
        print(self.hparams.param_1)  # prints out correctly
        print(self.hparams.param_2)  # prints out correctly
When I tried to load a checkpoint via MyLightningModule.load_from_checkpoint(ckpt_path), I noticed that checkpoint["hyper_parameters"] does NOT contain a key for param_2, while it DOES contain a key for param_1. I DO see hparams.param_2 printed correctly in my logger, which is really weird.
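For reference, the check I am describing can be sketched roughly as follows. The helper name missing_hparams and the stand-in dictionary are made up for illustration; with a real checkpoint file the dictionary would come from torch.load, as noted in the comment.

```python
def missing_hparams(checkpoint: dict, expected: list) -> list:
    """Return the expected hyperparameter names that are absent from the checkpoint."""
    saved = checkpoint.get("hyper_parameters", {})
    return [name for name in expected if name not in saved]


# Hypothetical stand-in for a loaded checkpoint. With a real file this would be
# something like: checkpoint = torch.load(ckpt_path, map_location="cpu")
checkpoint = {"hyper_parameters": {"param_1": "Conv2d(1, 1, kernel_size=(1, 1))"}}

print(missing_hparams(checkpoint, ["param_1", "param_2"]))  # -> ['param_2']
```

This only reports which keys are missing from the saved dictionary; it does not explain why they were dropped.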
For param_2 I used a network from the escnn library, which is derived from torch.nn.Module. I traced the problem back to using any layer from that library, for example:
import escnn
import escnn.nn as enn

param_1 = enn.R2Conv(
    enn.FieldType(escnn.gspaces.rot2dOnR2(8), [escnn.gspaces.rot2dOnR2(8).regular_repr]),
    enn.FieldType(escnn.gspaces.rot2dOnR2(8), [escnn.gspaces.rot2dOnR2(8).regular_repr]),
    7,
)
What could be the reason that the custom module is not part of the saved hyperparameters in the checkpoint? Thanks in advance!