Load checkpoint when passing a model

I want to create a “Base” LightningModule that can accommodate various model architectures and loss functions, as well as a config dictionary for various utilities, via its arguments, since the remaining steps are the same across configurations. I made a small example to demonstrate what I am trying to do.

from pytorch_lightning import LightningModule, Trainer
from typing import Dict, Any
import torch.nn as nn
import tempfile
import torch
import os


class MLP(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.model = nn.Sequential(nn.Linear(1, 10), nn.ReLU(), nn.Linear(10, 1))

    def forward(self, X: torch.Tensor):
        return self.model(X)


class TestModule(LightningModule):
    def __init__(
        self, config: Dict[str, Any], model: nn.Module, criterion: nn.Module
    ) -> None:
        super().__init__()

        self.model = model
        self.config = config

        self.criterion = criterion
        self.save_hyperparameters()  # stores all __init__ arguments as hyperparameters

    def training_step(self, *args: Any, **kwargs: Any):
        X, y = args[0]  # args[0] is the batch yielded by the dataloader
        out = self.model(X)
        loss = self.criterion(out, y)
        return loss

    def configure_optimizers(self) -> Any:
        optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=self.config["optimizer"]["lr"]
        )
        return {"optimizer": optimizer}


datamodule = MyDataModule()  # assume a LightningDataModule defined elsewhere

mlp = MLP()
save_path = os.path.join(tempfile.gettempdir(), "my_test")
config_dict = {
    "model": {"model_name": "mlp"},
    "optimizer": {"lr": 1e-3},
    "pl": {"max_epochs": 1, "default_root_dir": save_path},
}
model = TestModule(config_dict, mlp, nn.MSELoss())

trainer = Trainer(**config_dict["pl"])

trainer.fit(model, datamodule)

ckpt_path = os.path.join(
    save_path,
    "lightning_logs",
    "version_0",
    "checkpoints",
    "epoch=0-step=1.ckpt",
)

loaded_model = TestModule.load_from_checkpoint(ckpt_path)

I was hoping I could still conveniently load the checkpoint, but I am getting TypeError: __init__() missing 3 required positional arguments: 'config', 'model', and 'criterion'. Is there a way to make it work like this, or do I need to structure my setup differently?

You can pass in the required arguments like so:

TestModule.load_from_checkpoint(ckpt_path, config=config_dict, model=mlp, criterion=nn.MSELoss())
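Any keyword arguments you pass to load_from_checkpoint this way are forwarded to your __init__ and take precedence over whatever was stored in the checkpoint’s hyperparameters.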

Thank you. Is this the recommended approach in this setting, or is there a more “correct” way to do it given how checkpointing in Lightning was designed?

For the purpose of instantiating the model and loading the weights in one line of code, yes, that is what I would recommend.

Given your description:

can accommodate various model architectures and loss functions as well as a config dictionary for various utilities

You might want to integrate a configuration system. Hydra is a popular library for this, and it may be worth adopting if you need to support many different kinds of models, optimizers, etc.
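To give a rough idea, here is a minimal sketch of how that could look with Hydra/OmegaConf (the config file name, keys, and _target_ paths are illustrative assumptions, not part of your setup):

# config.yaml (illustrative):
#   model:
#     _target_: my_project.models.MLP
#   criterion:
#     _target_: torch.nn.MSELoss
#   optimizer:
#     lr: 1e-3

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("config.yaml")           # hypothetical config file
model = instantiate(cfg.model)                # builds my_project.models.MLP()
criterion = instantiate(cfg.criterion)        # builds torch.nn.MSELoss()
module = TestModule(OmegaConf.to_container(cfg), model, criterion)

The nice part is that adding a new architecture or loss then becomes a config change rather than a code change.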

But if you want to keep things simple, then instantiating your objects yourself and using load_from_checkpoint is probably fine 🙂

Thank you for the recommendation about Hydra; that looks interesting and potentially very helpful. I also came across https://lightning.ai/forums/t/best-way-to-use-load-from-checkpoint-when-model-contains-other-models/2094/2. Going with option A from that thread would eliminate the need to pass additional arguments when loading the checkpoint, and one could still pass the model class name as a config variable to keep the flexibility of using different backbone models.
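For reference, a minimal sketch of what that option could look like (the registry and class names here are illustrative assumptions): the LightningModule builds the backbone itself from the config, so save_hyperparameters() only has to store plain values and load_from_checkpoint works without extra arguments.

MODEL_REGISTRY = {"mlp": MLP}  # map config names to backbone classes

class RegistryModule(LightningModule):
    def __init__(self, config: Dict[str, Any]) -> None:
        super().__init__()
        self.save_hyperparameters()  # only the picklable config is stored
        self.config = config
        self.model = MODEL_REGISTRY[config["model"]["model_name"]]()
        self.criterion = nn.MSELoss()  # could also be chosen via the config

    # training_step / configure_optimizers as before

# later, no extra arguments are needed:
loaded_model = RegistryModule.load_from_checkpoint(ckpt_path)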
