Hey, so basically you have two options: instantiate the submodel inside the LightningModule, or instantiate it outside and pass it in.
Option A
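```python
import lightning as L

class OptionA(L.LightningModule):
    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters()  # kwargs end up in self.hparams
        self.model_a = MyModelA(**kwargs)  # submodel built from those kwargs

    ...

OptionA.load_from_checkpoint('path/to/checkpoint')
```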
In this case, the submodel (`model_a`) is always part of the LightningModule's checkpoint and is always instantiated automatically when a checkpoint is loaded, because the kwargs needed to build it are saved as hyperparameters.
Option B
```python
import lightning as L
import torch

class OptionB(L.LightningModule):
    def __init__(self, model_b: torch.nn.Module):
        super().__init__()
        self.save_hyperparameters(ignore=["model_b"])  # keep the nn.Module out of hparams
        self.model_b = model_b

    ...

OptionB.load_from_checkpoint('path/to/checkpoint', model_b=MyModelB())
```
In this case, `model_b` is still part of the checkpoint (its weights are stored in the `state_dict`), but it is not saved in the hyperparameters, since an `nn.Module` instance is not easily serializable. So `OptionB.load_from_checkpoint` requires the keyword argument `model_b=MyModelB()` in order to instantiate the class. The weights are then loaded as usual, since the checkpoint still contains `model_b`.
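A simplified plain-PyTorch sketch can show why Option B needs the submodel passed in while still restoring its weights. It is not Lightning's implementation; it assumes loading roughly means calling `__init__` with the saved hyperparameters plus any keyword arguments you supply, then loading the `state_dict`. `OptionBLike`, `save_checkpoint`, `load_checkpoint` and this toy `MyModelB` are hypothetical.

```python
import io

import torch
from torch import nn


class MyModelB(nn.Module):
    """Hypothetical submodel that is built outside and passed in."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(2, 3)


class OptionBLike(nn.Module):
    """Stand-in for OptionB: receives an already-built submodel."""

    def __init__(self, model_b: nn.Module):
        super().__init__()
        # Mirrors save_hyperparameters(ignore=["model_b"]): the module itself
        # is not recorded, so there are no hyperparameters left to store.
        self.hparams = {}
        self.model_b = model_b


def save_checkpoint(module: nn.Module) -> bytes:
    # Hypothetical helper: store the hyperparameters next to the weights.
    buffer = io.BytesIO()
    torch.save({"hyper_parameters": module.hparams, "state_dict": module.state_dict()}, buffer)
    return buffer.getvalue()


def load_checkpoint(cls, data: bytes, **kwargs):
    # Hypothetical helper: rebuild from saved hparams plus the kwargs passed in,
    # then restore the weights.
    checkpoint = torch.load(io.BytesIO(data), weights_only=True)
    module = cls(**{**checkpoint["hyper_parameters"], **kwargs})
    module.load_state_dict(checkpoint["state_dict"])
    return module


original = OptionBLike(MyModelB())
data = save_checkpoint(original)

failed_without_model_b = False
try:
    load_checkpoint(OptionBLike, data)  # __init__ is called without model_b
except TypeError:
    failed_without_model_b = True
print(failed_without_model_b)  # True

# A fresh, randomly initialized MyModelB is passed in; its weights are then
# overwritten with the ones stored in the checkpoint.
restored = load_checkpoint(OptionBLike, data, model_b=MyModelB())
print(torch.equal(restored.model_b.layer.weight, original.model_b.layer.weight))  # True
```

The freshly constructed `MyModelB()` only provides the architecture; the trained values come from the checkpoint.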