Best way to use load_from_checkpoint when model contains other models

Hey, so basically you have two options: instantiate the submodel inside the LightningModule, or pass it in from outside.

Option A

import lightning as L

class OptionA(L.LightningModule):
    def __init__(self, **kwargs):
        super().__init__()
        # **kwargs are serialized into the checkpoint
        self.save_hyperparameters()
        self.model_a = MyModelA(**kwargs)

    ...

# no constructor arguments needed; they are restored from the checkpoint
OptionA.load_from_checkpoint('path/to/checkpoint')

In this case, the submodel (model_a) is always part of the LightningModule's checkpoint, and because its constructor arguments are stored via save_hyperparameters(), it is re-instantiated automatically whenever a checkpoint is loaded.
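For illustration, here is a minimal sketch of what MyModelA could look like; this particular definition and the hidden_dim argument are assumptions for the example, not part of the original answer:

import torch

class MyModelA(torch.nn.Module):
    # placeholder submodel, purely for illustration
    def __init__(self, hidden_dim: int = 16):
        super().__init__()
        self.layer = torch.nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        return self.layer(x)

# hidden_dim is captured by save_hyperparameters() and written to the
# checkpoint, so after training no arguments are needed at load time:
model = OptionA(hidden_dim=16)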

Option B

import torch
import lightning as L

class OptionB(L.LightningModule):
    def __init__(self, model_b: torch.nn.Module):
        super().__init__()
        # exclude the module object from the saved hyperparameters
        self.save_hyperparameters(ignore=['model_b'])
        self.model_b = model_b
    ...

# the architecture has to be supplied again at load time
OptionB.load_from_checkpoint('path/to/checkpoint', model_b=MyModelB())

In this case, model_b's weights are still part of the checkpoint's state_dict, but the module itself is not saved in the hyperparameters because it is not easily serializable. OptionB.load_from_checkpoint therefore needs the keyword argument model_b=MyModelB() to instantiate the class; once the class is constructed, the checkpoint is loaded as usual and model_b's weights are restored from it.
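Under the same assumptions (MyModelB is a placeholder nn.Module and hidden_dim is made up for the example), the loading step could look like this; the freshly built module only supplies the architecture, and its random weights are overwritten by the weights stored in the checkpoint:

import torch

class MyModelB(torch.nn.Module):
    # placeholder submodel, purely for illustration
    def __init__(self, hidden_dim: int = 16):
        super().__init__()
        self.layer = torch.nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        return self.layer(x)

# MyModelB() here only defines the architecture; its randomly initialized
# weights are replaced by the checkpoint's state_dict during loading.
restored = OptionB.load_from_checkpoint('path/to/checkpoint', model_b=MyModelB())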
