I have a case where I want to use LightningCLI to instantiate an object inside my LightningModule, but I cannot define all of the arguments for that object in the config file. With Hydra this could be done via partial instantiation, but I am not sure how to achieve the same with LightningCLI. I have looked at the class type defaults section in the docs, but it doesn't exactly cover this case, since I don't know the arguments to that object beforehand.
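For reference, this is roughly the Hydra pattern I have in mind (a minimal sketch, assuming Hydra's partial instantiation feature: with _partial_: true, hydra.utils.instantiate returns a functools.partial instead of a finished instance):

extra_object:
  _target_: example.MyExtraObject
  _partial_: true

and then, once the late arguments are known:

from hydra.utils import instantiate

extra_factory = instantiate(cfg.extra_object)  # a functools.partial(MyExtraObject)
extra_object = extra_factory(arg1, arg2)       # supply the late arguments here

I have tried to create an example to illustrate what I am looking to do with LightningCLI: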
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

from lightning import LightningDataModule, LightningModule
from lightning.pytorch.cli import ArgsType, LightningCLI, OptimizerCallable
class MyExtraObject(nn.Module):
    def __init__(self, arg1: torch.Tensor, arg2: torch.Tensor) -> None:
        super().__init__()
        self.arg1 = arg1
        self.arg2 = arg2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.arg1 * x + self.arg2


class MyNetwork(nn.Module):
    def __init__(self, num_inputs: int, num_outputs: int):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(num_inputs, 32),
            nn.Linear(32, num_outputs),
        )

    def forward(self, x):
        return self.network(x)


class RandomDataset(Dataset):
    def __init__(self, size=5, num_samples=10):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class RandomDataModule(LightningDataModule):
    def __init__(self, size=5, num_samples=10, batch_size=2):
        super().__init__()
        self.size = size
        self.num_samples = num_samples
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.train_dataset = RandomDataset(self.size, self.num_samples)
        self.val_dataset = RandomDataset(self.size, self.num_samples)
        self.test_dataset = RandomDataset(self.size, self.num_samples)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)
class BoringModel(LightningModule):
    def __init__(
        self,
        model: nn.Module,
        extra_object: MyExtraObject,
        loss_fn: nn.Module = nn.CrossEntropyLoss(),
        optimizer: OptimizerCallable = torch.optim.Adam,
    ):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        # Some other process or function call defines the missing arguments
        # for the extra object (this is a toy example to illustrate the point),
        # and only then is the extra object initialized. Note that extra_object
        # is annotated as an instance here but used as a factory -- expressing
        # this in the config is exactly what I don't know how to do.
        arg1, arg2 = do_the_computation()
        self.extra_object = extra_object(arg1, arg2)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def exclude_some_params(self):
        # optimize only a subset of the parameters, e.g. the first layer
        return self.model.network[0].parameters()

    def configure_optimizers(self):
        return self.optimizer(self.exclude_some_params())
def cli_main(args: ArgsType = None):
    LightningCLI(args=args, auto_configure_optimizers=False, parser_kwargs={"error_handler": None})


if __name__ == "__main__":
    cli_main(["fit", "--config", "example_yaml.yaml"])
and the corresponding YAML file (example_yaml.yaml):
trainer: # pytorch lightning trainer arguments
  max_epochs: 2
data: # datamodule arguments
  class_path: example.RandomDataModule
  init_args:
    size: 2
model:
  class_path: example.BoringModel
  init_args:
    model:
      class_path: example.MyNetwork
      init_args:
        num_inputs: 2
        num_outputs: 2
    extra_object:
      class_path: example.MyExtraObject
    optimizer:
      class_path: torch.optim.SGD
      init_args:
        lr: 0.003
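In other words, I would like extra_object to behave the way optimizer already does: the config names the class (and possibly some init_args), and the module receives a callable that I can finish with the late arguments. If jsonargparse resolves user-defined callable types the same way it resolves OptimizerCallable, I imagine something along these lines (an untested sketch; ExtraObjectCallable is my own alias, not a Lightning name, and do_the_computation is the same placeholder as above):

from typing import Callable

import torch
import torch.nn as nn
from lightning import LightningModule

# By analogy with OptimizerCallable: a function that takes the late
# arguments and returns the constructed module.
ExtraObjectCallable = Callable[[torch.Tensor, torch.Tensor], nn.Module]


class BoringModel(LightningModule):
    def __init__(self, model: nn.Module, extra_object: ExtraObjectCallable):
        super().__init__()
        self.model = model
        arg1, arg2 = do_the_computation()             # late arguments
        self.extra_object = extra_object(arg1, arg2)  # partial-style call

with the YAML entry staying as it is now, i.e. only the class path:

extra_object:
  class_path: example.MyExtraObject

Is something like this supported, or is there another intended way to do partial initialization with LightningCLI?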