I have a DataModule
that looks like this:
from torch.vision.transforms.v2 import Resize
class BioDataModule(L.LightningDataModule):
def __init__(
self, path_to_X='biodata.npy',
path_to_Y='biolabels.csv',
split_ratio=[0.8, 0.1, 0.1],
batch_size=32,
transform = Resize(23), # The problem.
):
super().__init__()
self.path_to_X = path_to_X
self.path_to_Y = path_to_Y
self.batch_size = batch_size
self.split_ratio = split_ratio
self.transform = transform
When I run the training from the command line:
python run.py fit
I get the following error:
Traceback (most recent call last):
File "/home/asarikas/biochem/cli.py", line 16, in <module>
cli_main()
File "/home/asarikas/biochem/cli.py", line 13, in cli_main
cli = LightningCLI(BioNN, BioDataModule)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/asarikas/venvir/testing/lib/python3.11/site-packages/lightning/pytorch/cli.py", line 388, in __init__
self._run_subcommand(self.subcommand)
File "/home/asarikas/venvir/testing/lib/python3.11/site-packages/lightning/pytorch/cli.py", line 679, in _run_subcommand
fn(**fn_kwargs)
File "/home/asarikas/venvir/testing/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 543, in fit
call._call_and_handle_interrupt(
File "/home/asarikas/venvir/testing/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/asarikas/venvir/testing/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 579, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/asarikas/venvir/testing/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 948, in _run
call._call_setup_hook(self) # allow user to set up LightningModule in accelerator environment
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/asarikas/venvir/testing/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 93, in _call_setup_hook
_call_callback_hooks(trainer, "setup", stage=fn)
File "/home/asarikas/venvir/testing/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 208, in _call_callback_hooks
fn(trainer, trainer.lightning_module, *args, **kwargs)
File "/home/asarikas/venvir/testing/lib/python3.11/site-packages/lightning/pytorch/cli.py", line 269, in setup
self.parser.save(
File "/home/asarikas/venvir/testing/lib/python3.11/site-packages/jsonargparse/_deprecated.py", line 160, in patched_save
return self._unpatched_save(cfg, *args, multifile=multifile, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/asarikas/venvir/testing/lib/python3.11/site-packages/jsonargparse/_core.py", line 832, in save
f.write(self.dump(cfg, **dump_kwargs)) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/asarikas/venvir/testing/lib/python3.11/site-packages/jsonargparse/_deprecated.py", line 151, in patched_dump
return self._unpatched_dump(cfg, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/asarikas/venvir/testing/lib/python3.11/site-packages/jsonargparse/_core.py", line 735, in dump
return dump_using_format(self, cfg_dict, "yaml_comments" if yaml_comments else format)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/asarikas/venvir/testing/lib/python3.11/site-packages/jsonargparse/_loaders_dumpers.py", line 176, in dump_using_format
dump = dumpers[dump_format](*args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/asarikas/venvir/testing/lib/python3.11/site-packages/jsonargparse/_loaders_dumpers.py", line 131, in yaml_dump
return yaml.safe_dump(data, **dump_yaml_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/asarikas/venvir/testing/lib/python3.11/site-packages/yaml/__init__.py", line 269, in safe_dump
return dump_all([data], stream, Dumper=SafeDumper, **kwds)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/asarikas/venvir/testing/lib/python3.11/site-packages/yaml/__init__.py", line 241, in dump_all
dumper.represent(data)
File "/home/asarikas/venvir/testing/lib/python3.11/site-packages/yaml/representer.py", line 27, in represent
node = self.represent_data(data)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/asarikas/venvir/testing/lib/python3.11/site-packages/yaml/representer.py", line 48, in represent_data
node = self.yaml_representers[data_types[0]](self, data)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/asarikas/venvir/testing/lib/python3.11/site-packages/yaml/representer.py", line 207, in represent_dict
return self.represent_mapping('tag:yaml.org,2002:map', data)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/asarikas/venvir/testing/lib/python3.11/site-packages/yaml/representer.py", line 118, in represent_mapping
node_value = self.represent_data(item_value)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/asarikas/venvir/testing/lib/python3.11/site-packages/yaml/representer.py", line 48, in represent_data
node = self.yaml_representers[data_types[0]](self, data)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/asarikas/venvir/testing/lib/python3.11/site-packages/yaml/representer.py", line 207, in represent_dict
return self.represent_mapping('tag:yaml.org,2002:map', data)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/asarikas/venvir/testing/lib/python3.11/site-packages/yaml/representer.py", line 118, in represent_mapping
node_value = self.represent_data(item_value)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/asarikas/venvir/testing/lib/python3.11/site-packages/yaml/representer.py", line 58, in represent_data
node = self.yaml_representers[None](self, data)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/asarikas/venvir/testing/lib/python3.11/site-packages/yaml/representer.py", line 231, in represent_undefined
raise RepresenterError("cannot represent an object", data)
yaml.representer.RepresenterError: ('cannot represent an object', Resize(size=[23], interpolation=InterpolationMode.BILINEAR, antialias=warn))
I assume that this has to do with the fact that I am passing an instance in the __init__
method of the DataModule
(same happens also if I pass just the class Resize
). Can the LightningCLI
work with LightningDataModule
’s or LightningModule
’s that take such kind of init_args
?
Note
If I train the model with:
python run.py fit --config config.yaml
with config.yaml
:
data:
transform:
class_path: torchvision.transforms.v2.Resize
init_args: {size: 300}
everything works fine.