I want to create a “Base” LightningModule that can accommodate various model architectures and loss functions, as well as a config dictionary for various utilities, via constructor arguments, because the remaining steps are the same across configurations. I made a small example to demonstrate what I am trying to do.
from pytorch_lightning import LightningModule, Trainer
from typing import Dict, Any
import torch.nn as nn
import tempfile
import torch
import os


class MLP(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Sequential(nn.Linear(1, 10), nn.ReLU(), nn.Linear(10, 1))

    def forward(self, X: torch.Tensor):
        return self.model(X)


class TestModule(LightningModule):
    def __init__(
        self, config: Dict[str, Any], model: nn.Module, criterion: nn.Module
    ) -> None:
        super().__init__()
        self.model = model
        self.config = config
        self.criterion = criterion
        self.save_hyperparameters()

    def training_step(self, *args: Any, **kwargs: Any):
        out = self.model(args[0][0])
        loss = self.criterion(out, args[0][1])
        return loss

    def configure_optimizers(self) -> Any:
        optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=self.config["optimizer"]["lr"]
        )
        return {"optimizer": optimizer}
datamodule = MyDataModule()
mlp = MLP()
save_path = os.path.join(tempfile.gettempdir(), "my_test")
config_dict = {
    "model": {"model_name": "mlp"},
    "optimizer": {"lr": 1e-3},
    "pl": {"max_epochs": 1, "default_root_dir": save_path},
}
model = TestModule(config_dict, mlp, nn.MSELoss())
trainer = Trainer(**config_dict["pl"])
trainer.fit(model, datamodule)
ckpt_path = os.path.join(
    save_path,
    "lightning_logs",
    "version_0",
    "checkpoints",
    "epoch=0-step=1.ckpt",
)
loaded_model = TestModule.load_from_checkpoint(ckpt_path)
I was hoping I could still conveniently load the checkpoint, but I am getting TypeError: __init__() missing 3 required positional arguments: 'config', 'model', and 'criterion'. My question therefore is: is there a way to make it work like this, or do I need to structure my setup differently?
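
To illustrate what I mean by “conveniently”: as far as I understand, extra keyword arguments to load_from_checkpoint are forwarded to __init__, so a sketch like the one below should work, but having to rebuild the constructor arguments by hand is exactly what I was hoping to avoid.

# Explicitly re-creating the constructor arguments at load time (what I want to avoid).
loaded_model = TestModule.load_from_checkpoint(
    ckpt_path,
    config=config_dict,
    model=MLP(),
    criterion=nn.MSELoss(),
)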