Hello!
I’m working on a LightningModule (A) that contains other LightningModules (A.1, A.2, A.3, etc.). When I initialize module A, some of the contained LightningModules are created from scratch, but others are already pretrained, so they are loaded from their own ckpts.
My question is: what is the best way to save module A so it can be used later? Right now, to load the model I also have to put the pretrained submodules’ ckpts in the same place they were when the model was trained, since __init__() will look for them. I feel this could be organized better so that only a single ckpt is needed to use module A, but I’m not sure about the best way to do it.
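For context, roughly what the __init__() looks like right now (the submodule names and ckpt paths below are just placeholders):

import lightning as L

class ModuleA(L.LightningModule):
    def __init__(self):
        super().__init__()
        # A.1 is created from scratch
        self.sub_a1 = SubModuleA1()
        # A.2 is pretrained, so __init__ loads it from a fixed ckpt path,
        # which therefore has to exist wherever module A is loaded later
        self.sub_a2 = SubModuleA2.load_from_checkpoint('checkpoints/a2.ckpt')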
Thanks!
Hey, so basically you have two options: instantiating the submodel either inside or outside the LightningModule.
Option A
import lightning as L

class OptionA(L.LightningModule):
    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.model_a = MyModelA(**kwargs)  # submodel built inside the LightningModule
    ...

OptionA.load_from_checkpoint('path/to/checkpoint')
In this case, the submodel (model_a) will always be part of the LightningModule’s checkpoint and will always be instantiated (automatically) when a checkpoint is loaded!
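A quick way to convince yourself of that is to inspect the saved checkpoint: the submodule’s weights are stored under the model_a. prefix of the state_dict, so only that single ckpt file is needed at load time (the paths below are placeholders):

import torch

# Lightning checkpoints are plain torch-serialized dicts; weights_only=False
# is needed on newer PyTorch because the ckpt also stores hyperparameters.
ckpt = torch.load('path/to/checkpoint', map_location='cpu', weights_only=False)
print([k for k in ckpt['state_dict'] if k.startswith('model_a.')][:5])

# Loading rebuilds model_a from the saved hyperparameters and fills in its
# weights from the state_dict, so no separate submodule ckpt is required.
model = OptionA.load_from_checkpoint('path/to/checkpoint')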
Option B
import torch
import lightning as L

class OptionB(L.LightningModule):
    def __init__(self, model_b: torch.nn.Module):
        super().__init__()
        self.save_hyperparameters(ignore=['model_b'])  # the nn.Module itself is not a serializable hyperparameter
        self.model_b = model_b  # submodel passed in from outside
    ...

OptionB.load_from_checkpoint('path/to/checkpoint', model_b=MyModelB())
In this case, model_b will still be part of the checkpoint, but it is not saved in the hyperparameters because it is not easily serializable. So OptionB.load_from_checkpoint requires the keyword argument model_b=MyModelB() in order to instantiate the class. The checkpoint will then be loaded as usual, since it still contains the weights of model_b.
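Applied to the original question, that means at load time you only need to re-create the pretrained submodule’s architecture; its original ckpt is no longer needed, because module A’s checkpoint already contains those weights. A rough sketch (MyModelB, pretrained_model_b, and the path are placeholders):

# Training time: pass the already-pretrained submodule in from outside,
# e.g. after loading it from its own ckpt as before.
module = OptionB(model_b=pretrained_model_b)  # pretrained_model_b: a placeholder nn.Module

# Load time: MyModelB() only rebuilds the architecture; its randomly
# initialized weights are overwritten by those stored in the checkpoint
# under the "model_b." prefix of the state_dict.
restored = OptionB.load_from_checkpoint('path/to/checkpoint', model_b=MyModelB())
restored.eval()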