Best way to use load_from_checkpoint when model contains other models

Hello!

I’m working on a LightningModule (A) that contains other LightningModules (A.1, A.2, A.3, etc.). When I initialize module A, some of the contained LightningModules are created from scratch, but others are already pretrained, so they are loaded from their own checkpoints.

My question is: what is the best way to save module A so it can be used later? Right now, to load the model I also have to put the pretrained submodules’ checkpoints in the same place they were during training, because __init__() looks for them there. I feel this could be organized better, so that a single checkpoint is enough to use module A, but I’m not sure about the best way to do it.

Thanks!

Hey, so basically you have two options: instantiate the submodel inside the LightningModule, or instantiate it outside and pass it in.

Option A

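import lightning as L

class OptionA(L.LightningModule):
    def __init__(self, **kwargs):
        super().__init__()
        # Saves kwargs so load_from_checkpoint can call OptionA(**kwargs) again
        self.save_hyperparameters()
        # MyModelA is your own nn.Module, rebuilt from the saved hyperparameters
        self.model_a = MyModelA(**kwargs)

    ...

model = OptionA.load_from_checkpoint('path/to/checkpoint')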

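In this case, the submodel (model_a) is always part of the LightningModule’s checkpoint and is instantiated automatically when the checkpoint is loaded: load_from_checkpoint calls OptionA(**hparams) with the saved hyperparameters, which rebuilds model_a, and then restores its weights from the checkpoint’s state_dict. Because all weights come from that single checkpoint, __init__ should only build the architecture; if it also loads the submodules’ own pretrained checkpoints, those files are still required every time you load A.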

Option B

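import torch
import lightning as L

class OptionB(L.LightningModule):
    def __init__(self, model_b: torch.nn.Module):
        super().__init__()
        # Keep the module out of the hyperparameters; its weights are still
        # saved in the checkpoint's state_dict
        self.save_hyperparameters(ignore=["model_b"])
        self.model_b = model_b

    ...

model = OptionB.load_from_checkpoint('path/to/checkpoint', model_b=MyModelB())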

In this case, model_b is still part of the checkpoint (its weights are in the state_dict), but it is not stored in the hyperparameters, since an nn.Module is not easily serializable (Lightning recommends excluding it with save_hyperparameters(ignore=["model_b"])). As a result, OptionB.load_from_checkpoint needs the keyword argument model_b=MyModelB() to be able to instantiate the class. The weights are then loaded from the checkpoint as usual, because it still contains model_b’s parameters.