I have a nested model:
import torch
import torch.nn as nn
import pytorch_lightning as pl

class MovieScoreTask(pl.LightningModule):
    def __init__(self, base_model: nn.Module, learning_rate: float):
        super().__init__()
        self.save_hyperparameters()
        # self.example_input_array = torch.randint(0, 100, (10, 2))
        self.base_model = base_model
        self.lr = learning_rate
When I run it, Lightning shows a warning like this:
Attribute 'base_model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['base_model'])`.
If I follow that advice, I run into a problem when I try to restore the base model from the checkpoint file:
ckpt = torch.load(ckpt_path)
base = BaseModel(**config)
base.load_state_dict(ckpt['state_dict'])  # Fails: the keys are now prefixed, e.g. "base_model.layer.weight" instead of "layer.weight", so they cannot be loaded
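One possible workaround, sketched under the assumption that every key belonging to the nested model starts with the attribute name plus a dot (here "base_model."): keep only those entries and strip the prefix before calling load_state_dict. The helper name strip_prefix is hypothetical; it is not part of Lightning or PyTorch.

```python
from collections import OrderedDict
from typing import Any, Mapping


def strip_prefix(state_dict: Mapping[str, Any], prefix: str) -> "OrderedDict[str, Any]":
    """Hypothetical helper: keep entries under `prefix` and drop the prefix.

    e.g. "base_model.layer.weight" -> "layer.weight".
    Keys that do not start with the prefix (other submodules) are discarded.
    """
    if not prefix.endswith("."):
        prefix += "."
    return OrderedDict(
        (key[len(prefix):], value)
        for key, value in state_dict.items()
        if key.startswith(prefix)
    )


# Usage with a real checkpoint (BaseModel, config and ckpt_path are from the question):
# ckpt = torch.load(ckpt_path)
# base = BaseModel(**config)
# base.load_state_dict(strip_prefix(ckpt["state_dict"], "base_model"))

if __name__ == "__main__":
    # Plain ints stand in for tensors so this demo runs without torch.
    fake = {"base_model.layer.weight": 1, "base_model.layer.bias": 2, "other.weight": 3}
    print(dict(strip_prefix(fake, "base_model")))  # {'layer.weight': 1, 'layer.bias': 2}
```

Another route, if I read the Lightning docs correctly: with ignore=['base_model'], extra keyword arguments to load_from_checkpoint are forwarded to __init__, so MovieScoreTask.load_from_checkpoint(ckpt_path, base_model=BaseModel(**config)).base_model should come back with the trained weights.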
So is the best practice still to call self.save_hyperparameters() without setting the ignore parameter?
If so, the tutorial here seems to be wrong.
It should be:
checkpoint = torch.load(CKPT_PATH)
encoder_weights = checkpoint["hyper_parameters"]["encoder"]
decoder_weights = checkpoint["hyper_parameters"]["decoder"]