I want to use Lightning CLI to pass arguments that initialize a LightningModule and some objects inside it (e.g., an nn.Module). Lightning CLI provides helpful features that let me create and configure objects just by setting values in a config file and/or passing command line arguments.
Meanwhile, when people create a model, they quite often do more than pass parameters to an nn.Module class (e.g., nn.Linear(**kwargs)): there may be extra steps before and after, so they often wrap everything in a function such as create_model(...).
Sometimes I want to wrap other people's models in LightningModules so that I can use Lightning for easier training, and I want to modify their code as little as possible.
In such a situation, I need to pass arguments to a function (create_model(...)) instead of a class (e.g., nn.Linear(**kwargs)), but I could not get it to work. More specifically, I tried to put the parameters in a dict (e.g., model_kwargs={…}) and pass this dict to the create_model(...) function, but this makes the data types of the values ambiguous, and in my case some int values got parsed as str (see (3) below).
To sum up, my question is: is there a (better) way to create an object in a LightningModule that can be configured by both (1) a config file and (2) command line arguments?
Below are some of my attempts to create an nn.Conv2d object using Lightning CLI; none of them really solves my problem.
===== 1. The Preferred Way (Probably) to Create and Configure an Object in LightningModule
My LightningModule:
from typing import Callable

import lightning as L
from torch import nn


class BoringModel(L.LightningModule):
    def __init__(
        self,
        model: Callable = nn.Conv2d,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.model = model
    ...
My config yaml:
model:
  class_path: debug.boring_model.BoringModel
  init_args:
    model:
      class_path: torch.nn.Conv2d
      init_args:
        in_channels: 1
        out_channels: 10
        kernel_size: 1
...
My command:
python3 main.py fit \
    --config boring_model.yaml
This works without issue. Below, I will try to create the object using a function instead.
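For readers unfamiliar with the class_path / init_args layout used in the config above, the idea can be illustrated with a minimal, standard-library-only sketch. This is not how Lightning CLI is actually implemented; the helper name instantiate_from_config is hypothetical, and fractions.Fraction stands in for torch.nn.Conv2d so that the snippet runs without PyTorch.

```python
import importlib
from typing import Any, Dict


def instantiate_from_config(cfg: Dict[str, Any]) -> Any:
    """Hypothetical helper: build an object from a class_path/init_args dict."""
    # Split "package.module.Name" into the module path and the attribute name.
    module_name, _, attr_name = cfg["class_path"].rpartition(".")
    target = getattr(importlib.import_module(module_name), attr_name)
    # Call the resolved target with the configured keyword arguments.
    return target(**cfg.get("init_args", {}))


# Stand-in for the torch.nn.Conv2d entry in the YAML above (assumption).
cfg = {
    "class_path": "fractions.Fraction",
    "init_args": {"numerator": 1, "denominator": 4},
}
obj = instantiate_from_config(cfg)
print(obj)  # 1/4
```

In this sketch the resolved target only needs to be callable, so a dotted path to a factory function would resolve the same way. Whether Lightning CLI accepts that is a separate question.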
===== 2. Passing Arguments to a Function to Create an Object
As far as I know, it is not possible to make the LightningModule take a “function-type” parameter in place of the Callable-type parameter shown above.
Instead, I made the LightningModule take a dict parameter, which is passed on to the function create_model(...), like this:
My create_model(…) function:
from torch import nn


def create_model(
    in_channels: int = 3,
    out_channels: int = 1000,
    kernel_size: int = 3,
):
    return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size)
My LightningModule:
import lightning as L


class BoringModel(L.LightningModule):
    def __init__(
        self,
        model_kwargs: dict = {},
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        # Print each value with its type to see how the CLI parsed it.
        print("in_channels: ", model_kwargs['in_channels'], type(model_kwargs['in_channels']))
        print("out_channels: ", model_kwargs['out_channels'], type(model_kwargs['out_channels']))
        self.model = create_model(**model_kwargs)
    ...
My config yaml:
model:
  class_path: boring_model.BoringModel
  init_args:
    model_kwargs:
      in_channels: 1
      out_channels: 10
      kernel_size: 1
optimizer:
  ...
My command:
python3 main.py fit \
    --config boring_model.yaml
Outputs:
in_channels: 1 <class 'int'>
out_channels: 10 <class 'int'>
This works as expected. Below, I will configure the object using another way.
===== 3. Configure Object Created in a Function through Command Line Interface
If I pass the values as command line arguments instead, the parser cannot figure out their data types because the parameters are passed inside a dict, so it parses them as str instead of int.
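The difference in available type information can be seen by inspecting the annotations directly. The sketch below uses stand-ins (create_model_stub and init_stub are hypothetical names that mirror the signatures in (2), without PyTorch or Lightning). It shows that the function's parameters are annotated as int, while a bare dict annotation says nothing about the types of its values.

```python
from typing import get_args, get_type_hints


def create_model_stub(in_channels: int = 3, out_channels: int = 1000, kernel_size: int = 3):
    """Stand-in with the same parameters as create_model(...)."""
    return (in_channels, out_channels, kernel_size)


def init_stub(model_kwargs: dict = {}, **kwargs):
    """Stand-in with the same parameters as BoringModel.__init__ (minus self)."""


function_hints = get_type_hints(create_model_stub)
init_hints = get_type_hints(init_stub)

print(function_hints["in_channels"])         # <class 'int'>
print(init_hints["model_kwargs"])            # <class 'dict'>
print(get_args(init_hints["model_kwargs"]))  # () -> no key/value types available
```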
The LightningModule and YAML config file are the same as in (2).
My command:
python3 main.py fit \
    --config boring_model.yaml \
    --model.model_kwargs.in_channels 64
Outputs:
in_channels: 64 <class 'str'>
out_channels: 10 <class 'int'>
Now in_channels is parsed as a str by mistake, and this does not work with create_model(...).
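A possible workaround, shown only as a sketch, is to cast the values in model_kwargs to the types annotated on create_model(...) before calling it. The helper coerce_kwargs is a hypothetical name, the stand-in create_model below returns a tuple instead of an nn.Conv2d so that the snippet runs without PyTorch, and the approach assumes every annotation is a plain class such as int or float.

```python
from typing import Any, Callable, Dict, get_type_hints


def create_model(in_channels: int = 3, out_channels: int = 1000, kernel_size: int = 3):
    """Stand-in for the real create_model(...); returns the arguments it received."""
    return (in_channels, out_channels, kernel_size)


def coerce_kwargs(func: Callable[..., Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: cast each value to the type annotated on func."""
    hints = get_type_hints(func)
    coerced = {}
    for name, value in kwargs.items():
        target_type = hints.get(name)
        # Only cast when an annotation exists, it is a plain class,
        # and the value is not already an instance of it.
        if isinstance(target_type, type) and not isinstance(value, target_type):
            value = target_type(value)
        coerced[name] = value
    return coerced


# Values as they arrived in (3): in_channels came from the command line as a str.
model_kwargs = {"in_channels": "64", "out_channels": 10, "kernel_size": 1}
fixed = coerce_kwargs(create_model, model_kwargs)
print(fixed)                  # {'in_channels': 64, 'out_channels': 10, 'kernel_size': 1}
print(create_model(**fixed))  # (64, 10, 1)
```

This only papers over the parsing issue inside the module: the CLI still sees an untyped dict, and a naive cast is wrong for some types (for example, bool("False") is True), so it is not a general answer to the question.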