LightningCLI: using timm.create_model to initialize a model from the config file

The Timm library offers a great variety of models, and with hydra/omegaconf one could write a config file like

model:
    _target_: timm.create_model
    model_name: resnet18

And be able to initialize a LightningModule that takes model: nn.Module as an argument. I tried doing the same with LightningCLI, but I am not sure how to configure it correctly, because it fails with the error: Import path timm.create_model does not correspond to a subclass of <class 'torch.nn.modules.module.Module'>. Here is the example:

from lightning import LightningModule, LightningDataModule
from lightning.pytorch.cli import LightningCLI, ArgsType
from lightning.pytorch.cli import OptimizerCallable

import os
import torch
from typing import Optional
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


from lightning.pytorch.cli import OptimizerCallable, LightningCLI


class RandomDataset(Dataset):
    """Map-style dataset of random vectors drawn once at construction.

    Holds a ``(num_samples, size)`` tensor sampled with ``torch.randn``;
    item ``i`` is row ``i`` of that tensor.

    Args:
        size: Length of each sample vector.
        num_samples: Number of samples, which is also the dataset length.
    """

    def __init__(self, size=5, num_samples=10):
        # All samples are generated up front, so indexing the same position
        # twice returns the same values.
        self.data = torch.randn(num_samples, size)
        self.len = num_samples

    def __len__(self):
        """Return the number of samples fixed at construction."""
        return self.len

    def __getitem__(self, index):
        """Return the stored sample(s) selected by ``index``."""
        return self.data[index]
    
class RandomDataModule(LightningDataModule):
    """Data module serving ``RandomDataset`` splits for train, val and test.

    Args:
        size: Length of each sample vector.
        num_samples: Number of samples in each split.
        batch_size: Batch size used by every dataloader.
    """

    def __init__(self, size=5, num_samples=10, batch_size=2):
        super().__init__()
        self.size, self.num_samples, self.batch_size = size, num_samples, batch_size

    def setup(self, stage=None):
        """Create the three splits; ``stage`` is accepted but not consulted."""
        # Each split is an independent random draw with the same shape,
        # created in train -> val -> test order.
        self.train_dataset, self.val_dataset, self.test_dataset = (
            RandomDataset(self.size, self.num_samples) for _ in range(3)
        )

    def _loader(self, dataset):
        """Wrap ``dataset`` in a ``DataLoader`` using the configured batch size."""
        return DataLoader(dataset, batch_size=self.batch_size)

    def train_dataloader(self):
        """Return the dataloader over the training split."""
        return self._loader(self.train_dataset)

    def val_dataloader(self):
        """Return the dataloader over the validation split."""
        return self._loader(self.val_dataset)

    def test_dataloader(self):
        """Return the dataloader over the test split."""
        return self._loader(self.test_dataset)

class BoringModel(LightningModule):
    """Minimal LightningModule wrapping an injected network.

    Each step feeds the batch through ``model`` and uses the sum of the
    output as the loss. ``loss_fn`` is stored on the instance but is not
    called by any of the step methods defined here.

    Args:
        model: Network that ``forward`` delegates to.
        loss_fn: Loss module to store. When ``None`` (the default), a fresh
            ``nn.CrossEntropyLoss`` is created for this instance.
        optimizer: Callable that receives an iterable of parameters and
            returns an optimizer.
    """

    def __init__(
        self,
        model: nn.Module,
        loss_fn: Optional[nn.Module] = None,
        optimizer: OptimizerCallable = torch.optim.Adam,
    ):
        super().__init__()

        self.model = model
        self.optimizer = optimizer
        # Build the default loss per instance. A module instance written
        # directly in the signature is created once at definition time and
        # would be shared by every BoringModel that relies on the default.
        self.loss_fn = nn.CrossEntropyLoss() if loss_fn is None else loss_fn

    def forward(self, x):
        """Run ``x`` through the wrapped model and return its output."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Log ``train_loss`` and return it under the ``"loss"`` key."""
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Log the summed model output as ``valid_loss``."""
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        """Log the summed model output as ``test_loss``."""
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def exclude_some_params(self):
        """Return the parameters of the first entry of ``model.network``.

        The wrapped model must expose an indexable ``network`` attribute;
        a model without one raises ``AttributeError`` here.
        """
        return self.model.network[0].parameters()

    def configure_optimizers(self):
        """Build the optimizer over the subset from ``exclude_some_params``."""
        return self.optimizer(self.exclude_some_params())


def cli_main(args: ArgsType = None):
    """Construct a ``LightningCLI`` from the given command-line arguments.

    Args:
        args: Arguments forwarded unchanged to ``LightningCLI``. Defaults
            to ``None``.
    """
    # Parser options are passed through to LightningCLI as-is.
    parser_options = {"error_handler": None}
    LightningCLI(
        args=args,
        auto_configure_optimizers=False,
        parser_kwargs=parser_options,
    )

if __name__ == "__main__":
    # Run the `fit` subcommand against the YAML configuration file.
    default_args = ["fit", "--config", "example_yaml.yaml"]
    cli_main(default_args)

And the config file:

# Configuration loaded by the script through `fit --config example_yaml.yaml`.
trainer: # pytorch lightning trainer arguments
  max_epochs: 2

data: # datamodule arguments
  # presumably the script above is saved as example.py -- confirm the module name
  class_path: example.RandomDataModule
  init_args:
    size: 2

model: # LightningModule arguments
  class_path: example.BoringModel
  init_args:
    # Value for BoringModel's `model` parameter, which is annotated as nn.Module.
    model:
      # This is the import path named in the reported error: it "does not
      # correspond to a subclass of" torch.nn.Module.
      class_path: timm.create_model
      init_args:
        model_name: "resnet18"
        num_classes: 2
        in_chans: 3
        pretrained: True
    # Value for BoringModel's `optimizer` parameter (its signature default is Adam).
    optimizer:
      class_path: torch.optim.SGD
      init_args:
        lr: 0.003