Lightning CLI "partial" instances of lightning module arguments where arguments to that object cannot be defined in config

I have a case where I want to use LightningCLI to instantiate an object inside my LightningModule, but I cannot define all of that object's arguments in the config file. With Hydra this could be done via partial initialization, but I am not sure how to achieve the same with LightningCLI. I have looked at the class type defaults section in the docs, but it doesn't quite cover this case, since I don't know the arguments to that object beforehand. I have tried to create an example to illustrate what I am looking to do:

from lightning import LightningModule, LightningDataModule
from lightning.pytorch.cli import LightningCLI, ArgsType, OptimizerCallable

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


class MyExtraObject(nn.Module):
    def __init__(self, arg1: torch.Tensor, arg2: torch.Tensor) -> None:
        super().__init__()

        self.arg1 = arg1
        self.arg2 = arg2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.arg1 * x + self.arg2
    
class MyNetwork(nn.Module):
    def __init__(self, num_inputs: int, num_outputs: int):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(num_inputs, 32),
            nn.Linear(32, num_outputs)
        )

    def forward(self, x):
        return self.network(x)

class RandomDataset(Dataset):
    def __init__(self, size=5, num_samples=10):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len
    
class RandomDataModule(LightningDataModule):
    def __init__(self, size=5, num_samples=10, batch_size=2):
        super().__init__()
        self.size = size
        self.num_samples = num_samples
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.train_dataset = RandomDataset(self.size, self.num_samples)
        self.val_dataset = RandomDataset(self.size, self.num_samples)
        self.test_dataset = RandomDataset(self.size, self.num_samples)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

class BoringModel(LightningModule):
    def __init__(self, model: nn.Module, extra_object: MyExtraObject, loss_fn: nn.Module = nn.CrossEntropyLoss(), optimizer: OptimizerCallable = torch.optim.Adam):
        super().__init__()

        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        
        # Some other process or function call computes the missing arguments
        # to the extra object (do_the_computation is a placeholder in this toy
        # example), and only then is the extra object initialized.
        arg1, arg2 = do_the_computation()
        self.extra_object = extra_object(arg1, arg2)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def exclude_some_params(self):
        return self.model.network[0].parameters()

    def configure_optimizers(self):
        return self.optimizer(self.exclude_some_params())


def cli_main(args: ArgsType = None):
    LightningCLI(args=args, auto_configure_optimizers=False, parser_kwargs={"error_handler": None})

if __name__ == "__main__":
    cli_main(["fit", "--config", "example_yaml.yaml"])

and the corresponding yaml file:

trainer: # pytorch lightning trainer arguments
  max_epochs: 2

data: # datamodule arguments
  class_path: example.RandomDataModule
  init_args:
    size: 2

model:
  class_path: example.BoringModel
  init_args:
    model:
      class_path: example.MyNetwork
      init_args:
        num_inputs: 2
        num_outputs: 2
    extra_object:
      class_path: example.MyExtraObject
    optimizer:
      class_path: torch.optim.SGD
      init_args:
        lr: 0.003

@mauvilsa

jsonargparse uses the type hints of the parameters to decide how to parse and instantiate them. Having extra_object: MyExtraObject means that an instance of MyExtraObject must be provided to extra_object; anything that is not an instance of MyExtraObject is invalid.
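To make "uses the type hints" concrete, here is a small sketch that reads the `__init__` annotations with the standard-library `inspect` and `typing` modules, which is the kind of information a type-hint-driven parser works from. It is not jsonargparse code. `MyExtraObject` is a simplified stand-in (floats instead of tensors, no torch import), and `ExpectsInstance`, `ExpectsFactory` and `init_hint` are hypothetical names for illustration.

```python
import collections.abc
import inspect
import typing


class MyExtraObject:
    # Simplified stand-in for the question's class (floats instead of tensors).
    def __init__(self, arg1: float, arg2: float) -> None:
        self.arg1 = arg1
        self.arg2 = arg2


class ExpectsInstance:
    # Hint says: pass an already constructed MyExtraObject.
    def __init__(self, extra_object: MyExtraObject) -> None:
        self.extra_object = extra_object


class ExpectsFactory:
    # Hint says: pass something that builds a MyExtraObject from two floats.
    def __init__(self, create_extra_object: typing.Callable[[float, float], MyExtraObject]) -> None:
        self.extra_object = create_extra_object(1.0, 2.0)


def init_hint(cls: type, name: str) -> object:
    # Read the annotation of one __init__ parameter.
    return inspect.signature(cls.__init__).parameters[name].annotation


instance_hint = init_hint(ExpectsInstance, "extra_object")
factory_hint = init_hint(ExpectsFactory, "create_extra_object")

print(instance_hint is MyExtraObject)                               # True
print(typing.get_origin(factory_hint) is collections.abc.Callable)  # True
print(typing.get_args(factory_hint))  # list of argument types, then the return type
```

The first hint is the class itself, so only an instance fits; the second is a `Callable`, so a class or function that returns a `MyExtraObject` fits.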

Note that your case is the same as that of optimizers. To instantiate an optimizer, the parameters of the model need to be given, so the model can't take an optimizer instance as a parameter. Have a look at "Configure hyperparameters from the CLI (Advanced)" in the PyTorch Lightning 2.1.2 documentation.
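A minimal plain-Python sketch of the optimizer pattern, assuming that a config such as `class_path: torch.optim.SGD` with `init_args: {lr: 0.003}` ends up as a callable with `lr` already bound that still waits for the parameters. `functools.partial` is used here only to emulate that callable; it is not how Lightning or jsonargparse build it. `ToySGD` and `ToyModule` are hypothetical stand-ins so the snippet runs without torch.

```python
import functools
from typing import Callable, Iterable, List


class ToySGD:
    # Hypothetical stand-in for torch.optim.SGD: params are only known at
    # runtime, lr is known up front and can live in the config.
    def __init__(self, params: Iterable[float], lr: float = 0.01) -> None:
        self.params: List[float] = list(params)
        self.lr = lr


# Emulates what the configured optimizer becomes: lr bound, params still missing.
optimizer_factory: Callable[[Iterable[float]], ToySGD] = functools.partial(ToySGD, lr=0.003)


class ToyModule:
    def __init__(self, optimizer: Callable[[Iterable[float]], ToySGD]) -> None:
        self.weights = [0.5, -1.0]  # parameters exist only after construction
        self.optimizer = optimizer

    def configure_optimizers(self) -> ToySGD:
        # Same shape as the question's configure_optimizers: params supplied here.
        return self.optimizer(self.weights)


opt = ToyModule(optimizer_factory).configure_optimizers()
print(opt.lr, opt.params)  # 0.003 [0.5, -1.0]
```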

Basically, the parameter should be defined like create_extra_object: Callable[[torch.Tensor, torch.Tensor], MyExtraObject]. That is, the module should receive a function that, when called with two positional tensors, returns an instance of MyExtraObject. Then in the body you would do self.extra_object = create_extra_object(arg1, arg2).
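Applied to the first example, the module body would look roughly like the sketch below. To keep it self-contained it uses floats instead of tensors and a plain class instead of a LightningModule; `BoringModelSketch` and the body of `do_the_computation` are hypothetical placeholders. The assumption is that with only `class_path: example.MyExtraObject` in the config (no `init_args`), the class itself is what satisfies the callable type, since calling it with two positional values returns an instance.

```python
from typing import Callable, Tuple


class MyExtraObject:
    # Simplified stand-in for the nn.Module in the question (floats instead of tensors).
    def __init__(self, arg1: float, arg2: float) -> None:
        self.arg1 = arg1
        self.arg2 = arg2

    def forward(self, x: float) -> float:
        return self.arg1 * x + self.arg2


def do_the_computation() -> Tuple[float, float]:
    # Hypothetical placeholder for whatever derives the arguments at runtime.
    return 2.0, 1.0


class BoringModelSketch:
    def __init__(self, create_extra_object: Callable[[float, float], MyExtraObject]) -> None:
        # The factory is called only once the missing arguments are known.
        arg1, arg2 = do_the_computation()
        self.extra_object = create_extra_object(arg1, arg2)


# The class itself matches Callable[[float, float], MyExtraObject].
model = BoringModelSketch(create_extra_object=MyExtraObject)
print(model.extra_object.forward(3.0))  # 7.0
```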

Thanks for your reply. I have a follow-up question: what happens if some of the arguments to MyExtraObject can be defined through the config (e.g. a string), and others need to be computed dynamically as before? I get the error:

argparse.ArgumentError: Parser key "model":
  Problem with given class_path 'example.BoringModel':
    Parser key "extra_object":
      Type typing.Callable[[str, torch.Tensor], example.MyExtraObject] expects a function or a callable class: Validation failed: No action for key "arg1" to check its value.. Got value: Namespace(class_path='example.MyExtraObject', init_args=Namespace(arg1='hello'))

But I am not sure what this error message means, even with JSONARGPARSE_DEBUG=true.

I have updated the example:

from lightning import LightningModule, LightningDataModule
from lightning.pytorch.cli import LightningCLI, ArgsType, OptimizerCallable

import torch
from typing import Callable
import torch.nn as nn
from torch import Tensor
from torch.utils.data import Dataset, DataLoader


class MyExtraObject(nn.Module):
    def __init__(self, arg1: str, arg2: torch.Tensor) -> None:
        super().__init__()

        self.arg1 = arg1
        self.arg2 = arg2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.arg2
    
class MyNetwork(nn.Module):
    def __init__(self, num_inputs: int, num_outputs: int):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(num_inputs, 32),
            nn.Linear(32, num_outputs)
        )

    def forward(self, x):
        return self.network(x)

class RandomDataset(Dataset):
    def __init__(self, size=5, num_samples=10):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len
    
class RandomDataModule(LightningDataModule):
    def __init__(self, size=5, num_samples=10, batch_size=2):
        super().__init__()
        self.size = size
        self.num_samples = num_samples
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.train_dataset = RandomDataset(self.size, self.num_samples)
        self.val_dataset = RandomDataset(self.size, self.num_samples)
        self.test_dataset = RandomDataset(self.size, self.num_samples)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

class BoringModel(LightningModule):
    def __init__(self, model: nn.Module, extra_object: Callable[[str, Tensor], MyExtraObject], loss_fn: nn.Module = nn.CrossEntropyLoss(), optimizer: OptimizerCallable = torch.optim.Adam):
        super().__init__()

        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        
        # Some other process or function call computes the missing argument
        # to the extra object (this is a toy example to illustrate the point),
        # and only then is the extra object initialized.
        arg2 = torch.randn(1)
        self.extra_object = extra_object(arg2)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def exclude_some_params(self):
        return self.model.network[0].parameters()

    def configure_optimizers(self):
        return self.optimizer(self.exclude_some_params())


def cli_main(args: ArgsType = None):
    LightningCLI(args=args, auto_configure_optimizers=False, parser_kwargs={"error_handler": None})

if __name__ == "__main__":
    cli_main(["fit", "--config", "example_yaml.yaml"])

and the updated YAML file:

trainer: # pytorch lightning trainer arguments
  max_epochs: 2

data: # datamodule arguments
  class_path: example.RandomDataModule
  init_args:
    size: 2

model:
  class_path: example.BoringModel
  init_args:
    model:
      class_path: example.MyNetwork
      init_args:
        num_inputs: 2
        num_outputs: 2
    extra_object:
      class_path: example.MyExtraObject
      init_args:
        arg1: hello
    optimizer:
      class_path: torch.optim.SGD
      init_args:
        lr: 0.003

A type Callable[[str, Tensor], MyExtraObject] means that a function must be provided that receives two positional arguments of types str and Tensor and returns an instance of MyExtraObject. In a config you could give the import path to a function that implements such a signature, like extra_object: path.to.a.function.
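A sketch of such a function, with a float standing in for the Tensor so it runs without torch. `make_extra_object` is a hypothetical name; if it lived in the question's `example.py`, the config would presumably reference it as `extra_object: example.make_extra_object`. Note that a two-argument callable also has to be called with two arguments inside the module, whereas the updated `BoringModel` calls `extra_object(arg2)` with one.

```python
from typing import Callable


class MyExtraObject:
    # Stand-in for the question's class; a float replaces the tensor argument.
    def __init__(self, arg1: str, arg2: float) -> None:
        self.arg1 = arg1
        self.arg2 = arg2


def make_extra_object(name: str, offset: float) -> MyExtraObject:
    # Hypothetical factory: two positional inputs, returns a MyExtraObject,
    # so it matches Callable[[str, float], MyExtraObject].
    return MyExtraObject(arg1=name, arg2=offset)


factory: Callable[[str, float], MyExtraObject] = make_extra_object
obj = factory("hello", 0.5)
print(obj.arg1, obj.arg2)  # hello 0.5
```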

Note that the class MyExtraObject is not a function. When you give a class_path in a config, jsonargparse internally creates a function with such a signature. The two positional arguments of the callable type are associated with the first two arguments of MyExtraObject.__init__, so you can't give these two arguments in init_args, since they must be given as arguments when this function is called. This is why arg1 is rejected in your config. It is the same as with the optimizer: you can't give params in the config, but you can give lr.
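One way to apply this to the updated example, assuming the runtime-computed value is moved to the first position of `MyExtraObject.__init__`, is to annotate the module parameter as `Callable[[Tensor], MyExtraObject]` (one positional argument, matching the existing call `extra_object(arg2)`) and keep `arg1` in `init_args`. The sketch below does not use jsonargparse: it emulates the binding rule described above with `functools.partial`, and uses floats instead of tensors. `emulate_class_path`, `OriginalOrder` and `ComputedFirst` are hypothetical names. jsonargparse reports the conflict at parse time, as in the error earlier in this thread, whereas the emulation only surfaces it as a `TypeError` at call time.

```python
import functools
from typing import Callable


class OriginalOrder:
    # Parameter order from the question: the configurable str comes first.
    def __init__(self, arg1: str, arg2: float) -> None:
        self.arg1, self.arg2 = arg1, arg2


class ComputedFirst:
    # Reordered: the runtime-computed value is the first parameter, so it is
    # the one bound to the single positional argument of the callable.
    def __init__(self, arg2: float, arg1: str) -> None:
        self.arg1, self.arg2 = arg1, arg2


def emulate_class_path(cls, init_args):
    # Rough emulation, not jsonargparse's code: config init_args are bound by
    # keyword, and positional call arguments fill the leading __init__ params.
    return functools.partial(cls, **init_args)


config_init_args = {"arg1": "hello"}

# Works: the positional value lands on arg2, arg1 comes from the config.
create_ok: Callable[[float], ComputedFirst] = emulate_class_path(ComputedFirst, config_init_args)
ok = create_ok(0.5)
print(ok.arg1, ok.arg2)  # hello 0.5

# Fails: the positional value lands on arg1, which the config already set.
create_bad = emulate_class_path(OriginalOrder, config_init_args)
try:
    create_bad(0.5)
    bad_call_failed = False
except TypeError as err:
    bad_call_failed = True
    print("TypeError:", err)
```

Under that assumption, the YAML from the follow-up could stay as written, with `arg1: hello` under `init_args`.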