The timm library offers a great variety of models, and with Hydra/OmegaConf one can write a config file like

model:
  _target_: timm.create_model
  model_name: resnet18

and then initialize a LightningModule that takes model: nn.Module as an argument. I tried to do the same with LightningCLI, but I am not sure how to configure it correctly, because parsing fails with:

Import path timm.create_model does not correspond to a subclass of <class 'torch.nn.modules.module.Module'>
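For reference, this is roughly how the Hydra version resolves such a config (a minimal sketch, assuming hydra-core is installed):

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create({"model": {"_target_": "timm.create_model", "model_name": "resnet18"}})
model = instantiate(cfg.model)  # calls timm.create_model(model_name="resnet18")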
Here is my LightningCLI example:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

from lightning import LightningDataModule, LightningModule
from lightning.pytorch.cli import ArgsType, LightningCLI, OptimizerCallable

class RandomDataset(Dataset):
    def __init__(self, size=5, num_samples=10):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class RandomDataModule(LightningDataModule):
    def __init__(self, size=5, num_samples=10, batch_size=2):
        super().__init__()
        self.size = size
        self.num_samples = num_samples
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.train_dataset = RandomDataset(self.size, self.num_samples)
        self.val_dataset = RandomDataset(self.size, self.num_samples)
        self.test_dataset = RandomDataset(self.size, self.num_samples)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

class BoringModel(LightningModule):
    def __init__(
        self,
        model: nn.Module,
        loss_fn: nn.Module = nn.CrossEntropyLoss(),
        optimizer: OptimizerCallable = torch.optim.Adam,
    ):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)
    def exclude_some_params(self):
        # optimize only the classifier head; timm models expose it via get_classifier()
        return self.model.get_classifier().parameters()

    def configure_optimizers(self):
        return self.optimizer(self.exclude_some_params())

def cli_main(args: ArgsType = None):
    LightningCLI(args=args, auto_configure_optimizers=False, parser_kwargs={"error_handler": None})


if __name__ == "__main__":
    cli_main(["fit", "--config", "example_yaml.yaml"])
And the config file:
trainer: # pytorch lightning trainer arguments
  max_epochs: 2
data: # datamodule arguments
  class_path: example.RandomDataModule
  init_args:
    size: 2
model:
  class_path: example.BoringModel
  init_args:
    model:
      class_path: timm.create_model
      init_args:
        model_name: "resnet18"
        num_classes: 2
        in_chans: 3
        pretrained: True
optimizer:
  class_path: torch.optim.SGD
  init_args:
    lr: 0.003
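For clarity, this is what I would like the config to resolve to, written out in plain Python (a sketch: the lambda stands in for the optimizer callable that the CLI builds from class_path and init_args):

import timm
import torch
from example import BoringModel, RandomDataModule

datamodule = RandomDataModule(size=2)
module = BoringModel(
    model=timm.create_model("resnet18", num_classes=2, in_chans=3, pretrained=True),
    optimizer=lambda params: torch.optim.SGD(params, lr=0.003),
)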