[CLI] How to Pass Arguments to Initialize an Object in L.LightningModule?

I want to use Lightning CLI to pass arguments that initialize a LightningModule and some objects inside it (e.g., an nn.Module). Lightning CLI provides helpful features that let me create and configure objects just by setting values in a config file and/or passing command line arguments.

Meanwhile, creating a model often involves more than passing parameters to an nn.Module class (e.g., nn.Linear(**kwargs)). Extra steps may be needed before and after, so people often wrap them in a function such as create_model(...).

Sometimes I want to wrap other people's models in LightningModules to use Lightning for easier training, and I want to modify their code as little as possible.

In such a situation, I need to pass arguments to a function (create_model(...)) instead of a class (e.g., nn.Linear(**kwargs)), but I could not get it to work. More specifically, I tried to put the parameters in a dict (e.g., model_kwargs={…}) and pass this dict to the create_model(...) function, but this makes the data types of the values ambiguous, and in my case some int values were parsed as str (see (3) below).

To sum up, my question is: is there a (better) way to create an object in a LightningModule that can be configured by (1) a config file and (2) command line arguments?

Below are some of my attempts to create an nn.Conv2d object using Lightning CLI, but none of them really solves my problem.

===== 1. The Preferred Way (Probably) to Create and Configure an Object in LightningModule

My LightningModule:

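from typing import Callable

import lightning as L  # assumed: Lightning 2.x import alias
from torch import nn

class BoringModel(L.LightningModule):
    def __init__(
        self,
        model: Callable = nn.Conv2d,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.model = model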
...

My config yaml:

model:
  class_path: debug.boring_model.BoringModel
  init_args:
    model:
      class_path: torch.nn.Conv2d
      init_args:
        in_channels: 1
        out_channels: 10
        kernel_size: 1
...

My command:

python3 main.py fit \
--config boring_model.yaml

This works without issue. Below, I will try to create the object using a function instead.

===== 2. Passing Arguments to a Function to Create an Object

As far as I know, it is not possible to make the LightningModule take a “function-type” parameter instead of the Callable-type parameter we have seen above.

Instead, I tried to make the LightningModule take a dict parameter, which is then passed to the function create_model(...) like this:

My create_model(…) function:

def create_model(
    in_channels: int = 3,
    out_channels: int = 1000,
    kernel_size: int = 3,
):
    return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size)

My LightningModule:

class BoringModel(L.LightningModule):
    def __init__(
        self,
        model_kwargs: dict = {},
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        print("in_channels: ", model_kwargs['in_channels'], type(model_kwargs['in_channels']))
        print("out_channels: ", model_kwargs['out_channels'], type(model_kwargs['out_channels']))

        self.model = create_model(**model_kwargs)
...

My config yaml:

model:
  class_path: boring_model.BoringModel
  init_args:
    model_kwargs:
      in_channels: 1
      out_channels: 10
      kernel_size: 1
    optimizer: 
...

My command:

python3 main.py fit \
--config boring_model.yaml

Outputs:

in_channels:  1 <class 'int'>
out_channels:  10 <class 'int'>

This works as expected. Below, I will configure the object in another way.

===== 3. Configure Object Created in a Function through Command Line Interface

If I pass the values as command line arguments, the parser cannot figure out their data types because the parameters are passed inside a dict, so it parses them as str instead of int.

The LightningModule and yaml config file are the same as in (2).

My command:

python3 main.py fit \
--config boring_model.yaml \
--model.model_kwargs.in_channels 64

Outputs:

in_channels:  64 <class 'str'>
out_channels:  10 <class 'int'>

Now in_channels is parsed as a str by mistake, and this does not work with create_model(...).
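For illustration, the same effect can be reproduced with the standard library's argparse. This is only an analogy, not how Lightning CLI or jsonargparse is implemented: the option names --untyped and --typed are made up for this sketch. A command line value always arrives as text, and it is only converted when the parser has been told a type for it.

```python
import argparse

# Analogy only: plain argparse, not Lightning CLI / jsonargparse internals.
parser = argparse.ArgumentParser()
parser.add_argument("--untyped")          # no type information: value stays a str
parser.add_argument("--typed", type=int)  # type information: value is converted

# Simulate passing both options on the command line.
args = parser.parse_args(["--untyped", "64", "--typed", "64"])

print("untyped:", args.untyped, type(args.untyped))
print("typed:  ", args.typed, type(args.typed))
```

A bare dict type hint puts the parser in the first situation for every key inside it.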

What you can do is to use a type hint that represents the parameters of create_model. This could be done via a dataclass. Implement a function that creates such a dataclass like:

import inspect
from dataclasses import make_dataclass, field

def dataclass_from_signature(component):
    fields = []
    for param in inspect.signature(component).parameters.values():
        if param.default is inspect.Parameter.empty:
            fields.append((param.name, param.annotation))
        else:
            fields.append((param.name, param.annotation, field(default=param.default)))
    return make_dataclass(f'{component.__name__}_parameters', fields)

Then create the dataclass and use it as the type, i.e.:

model_kwargs_type = dataclass_from_signature(create_model)

class BoringModel(L.LightningModule):
    def __init__(
        self, 
        model_kwargs: model_kwargs_type = model_kwargs_type(),
        **kwargs,
    ):
        ...

I have thought about adding a dataclass_from_signature function to jsonargparse, one that would also support resolving parameters from **kwargs.
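To see what the generated type actually contains, here is a small self-contained check. It is a sketch under one assumption: create_model below is a stand-in that returns a plain dict instead of an nn.Conv2d, so that it runs without torch. The helper is the same as above.

```python
import inspect
from dataclasses import field, fields, make_dataclass


def create_model(in_channels: int = 3, out_channels: int = 1000, kernel_size: int = 3):
    # Stand-in for the real create_model(...): returns a dict, not an nn.Conv2d.
    return {"in_channels": in_channels, "out_channels": out_channels, "kernel_size": kernel_size}


def dataclass_from_signature(component):
    # Build one dataclass field per parameter, keeping annotation and default.
    dc_fields = []
    for param in inspect.signature(component).parameters.values():
        if param.default is inspect.Parameter.empty:
            dc_fields.append((param.name, param.annotation))
        else:
            dc_fields.append((param.name, param.annotation, field(default=param.default)))
    return make_dataclass(f"{component.__name__}_parameters", dc_fields)


model_kwargs_type = dataclass_from_signature(create_model)

# Each field carries the annotation, which is what lets a parser pick the right type.
for f in fields(model_kwargs_type):
    print(f.name, f.type, f.default)
```

Because every field now has an int annotation, a value given on the command line no longer has to be guessed as a str.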

Would you mind explaining in more detail how to pass model_kwargs in this case? I tried the following, but none of them work.

model = create_model(model_kwargs)
# File "/root/anaconda3/envs/diff_dev/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 89, in __init__
#    if in_channels % groups != 0:
#       ~~~~~~~~~~~~^~~~~~~~
# TypeError: unsupported operand type(s) for %: 'create_model_parameters' and 'int'
# 
# (as I am creating an nn.Conv2d object in create_model(...))

model = create_model(*model_kwargs)
# create_model() argument after * must be an iterable, not create_model_parameters

model = create_model(**model_kwargs) 
# create_model() argument after ** must be a mapping, not create_model_parameters

I do not know if this is related, but I am setting up my config file this way:

model:
  class_path: boring_model.BoringModel
  init_args:
    model_kwargs:
      in_channels: 1
      out_channels: 10
      kernel_size: 1

and I am starting a trainer using CLI with this command:

python3 main.py fit \
--config boring_model.yaml \
--model.model_kwargs.in_channels 64

from dataclasses import asdict

model = create_model(**asdict(model_kwargs))
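A minimal sketch of why the ** attempt fails and asdict works. The names here are stand-ins chosen for illustration: CreateModelParameters is a hand-written dataclass playing the role of model_kwargs_type, and create_model returns a tuple instead of an nn.Conv2d so the snippet runs without torch.

```python
from dataclasses import asdict, dataclass


@dataclass
class CreateModelParameters:
    # Hypothetical stand-in for the generated model_kwargs_type.
    in_channels: int = 3
    out_channels: int = 1000
    kernel_size: int = 3


def create_model(in_channels: int = 3, out_channels: int = 1000, kernel_size: int = 3):
    # Stand-in for the real create_model(...): returns a tuple, not an nn.Conv2d.
    return (in_channels, out_channels, kernel_size)


model_kwargs = CreateModelParameters(in_channels=64, out_channels=10, kernel_size=1)

# A dataclass instance is not a mapping, so ** unpacking raises TypeError.
unpack_error = None
try:
    create_model(**model_kwargs)
except TypeError as err:
    unpack_error = err

# asdict(...) converts the instance to a plain dict, which ** accepts.
model = create_model(**asdict(model_kwargs))
print(model)
```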

This works! Thank you very much for your help.

I am building on this example and trying to make model_kwargs an optional parameter, so that the model is not created if nothing about it is specified. I sometimes find this helpful when a model has multiple sub-modules that can be switched on/off through the CLI. However, I have not yet found a way to do this. This is my current attempt:

class BoringModel(L.LightningModule):
    def __init__(
        self,
        model_kwargs: model_kwargs_type = None,
        ...
    ):

        if model_kwargs is not None:
            self.model = create_model(**asdict(model_kwargs))
        else:
            self.model = None

In fact, this if-else condition works so that when model_kwargs is specified create_model is called, and if no model_kwargs is specified self.model is set to None.

However, I found that if some kwargs of the model are set, the model is always created with its default parameters, no matter what values I set in model_kwargs. I guess this is because dataclass_from_signature(...) from the solution is never called to get the specified kwargs values?

I would like to ask if there is a way to extend the suggested solution to handle optional parameters.

I guess one could add an additional bool argument such as use_model=True/False to decide whether to call create_model, so that one can still set model_kwargs: model_kwargs_type = model_kwargs_type(). But maybe this is not a very neat solution, as it introduces more parameters.

If you want to make model_kwargs optional, then it should be defined as model_kwargs: Optional[model_kwargs_type] = None. I don’t really understand what you say after that about dataclass_from_signature and handling optional parameters.
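A minimal sketch of the constructor logic with the Optional type hint. It only shows the Python side, not how the CLI parser fills in the value. The names are stand-ins for illustration: ModelKwargs plays the role of model_kwargs_type, BoringModelSketch is a plain class in place of the LightningModule, and create_model returns a tuple instead of an nn.Conv2d.

```python
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class ModelKwargs:
    # Hypothetical stand-in for the generated model_kwargs_type.
    in_channels: int = 3
    out_channels: int = 1000
    kernel_size: int = 3


def create_model(in_channels: int = 3, out_channels: int = 1000, kernel_size: int = 3):
    # Stand-in for the real create_model(...): returns a tuple, not an nn.Conv2d.
    return (in_channels, out_channels, kernel_size)


class BoringModelSketch:
    # Plain class used instead of L.LightningModule to keep the sketch runnable.
    def __init__(self, model_kwargs: Optional[ModelKwargs] = None):
        if model_kwargs is not None:
            self.model = create_model(**asdict(model_kwargs))
        else:
            self.model = None  # sub-module switched off


disabled = BoringModelSketch()
enabled = BoringModelSketch(ModelKwargs(in_channels=64))
print(disabled.model, enabled.model)
```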