My problem
I’d like to ask the community for recommendations on how to implement save/load operations for LightningModules that are organized in a particular pattern. The pattern I’m talking about comes from the docs’ recommendation on how to set up models for production:
```python
class ClassificationTask(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model

    # The rest of the module defines the `*_step` computations and optimizers
```
Here, the `LightningModule` describes computations rather than a network; the network itself is injected into the module. I really like this approach (it’s quite modular and easily configurable IMO), but I’m a bit puzzled about how to make saving/loading modules work with it.
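For reference, here is roughly how I use such a task (a minimal sketch; the `resnet18` backbone, `max_epochs`, and the dataloaders are placeholders, not my actual setup):

```python
import pytorch_lightning as pl
import torchvision

# Build the network separately and inject it into the task
backbone = torchvision.models.resnet18(num_classes=10)
task = ClassificationTask(model=backbone)

# `train_loader` / `val_loader` are assumed to be defined elsewhere
trainer = pl.Trainer(max_epochs=5)
trainer.fit(task, train_loader, val_loader)

# Checkpoints saved by the trainer contain the weights of `self.model`,
# but rebuilding the task from them requires re-creating `backbone` first.
```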
What I tried
I figure that I have to rebuild `model` myself and inject it while loading `ClassificationTask`; however, I’m hitting a wall when trying to make it work in practice. What I’m doing is something like this:
```python
class ClassificationTask(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model

    @classmethod
    def load_from_checkpoint(cls, ...):
        # Rebuild `model` from the configuration stored in the checkpoint
        model = ...
        # Here is the tricky part
        return super().load_from_checkpoint(..., model=model)
```
Issues with my solution so far
When calling `super().load_from_checkpoint(...)`, I thought I could just inject `model` there and be done with it; it would be forwarded to the class’ `__init__` and all would be well. However, digging a little deeper into the code, I came across this bit in the base `load_from_checkpoint` implementation in `LightningModule`:
```python
# for past checkpoint need to add the new key
if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint:
    checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {}
# override the hparams with values that were passed in
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)

model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
```
In my case, `kwargs` would be `{'model': model}`. The `_load_model_state` call means it will indeed be forwarded to the class’ `__init__` eventually; however, it also gets added to the checkpoint’s hyperparameters, which I’d rather not pollute with the injected model.
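To make the concern concrete, here is a rough sketch of the effect I want to avoid (the checkpoint path, the `lr` hyperparameter, the `ResNet` backbone, and the `build_backbone_from_config` helper are hypothetical, and I haven’t verified this against every Lightning version):

```python
# Suppose I rebuild the backbone and inject it at load time:
backbone = build_backbone_from_config(...)  # hypothetical helper
task = ClassificationTask.load_from_checkpoint("path/to.ckpt", model=backbone)

# Because `kwargs` are merged into the checkpoint's hyperparameters,
# the loaded task's hparams now carry the whole nn.Module:
print(task.hparams)
# {'lr': 0.001, 'model': ResNet(...)}  <- the injected network, not a plain config
```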
Thus follows my question: would you have any recommendations on how to modify my setup so that I can inject the network into the `LightningModule` upon loading it, without having the network added to the checkpoint’s hyperparams?
Additional context
You might be wondering why I’m so insistent on not having `model` added to the checkpoint’s hyperparams. Well, I’m trying to set up a project seed merging Lightning and Hydra, so my hyperparams are a typed structured config that flat out refuses to accept a custom class. The error message I get is something like:
```
omegaconf.errors.UnsupportedValueType: Value 'model.__class__' is not a supported primitive type
```
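For context, this is roughly the behavior, reproduced with plain OmegaConf outside of Lightning (the `Linear` layer just stands in for the actual network):

```python
import torch.nn as nn
from omegaconf import OmegaConf

conf = OmegaConf.create({"lr": 0.001})
# Assigning an arbitrary Python object (here an nn.Module) to a config node
# fails, since OmegaConf only accepts primitives, lists, dicts, and
# dataclass-backed structured configs.
conf.model = nn.Linear(2, 2)  # raises omegaconf.errors.UnsupportedValueType
```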
Furthermore, as a general principle, I think it’s best to keep saved/loaded hyperparams strictly to plain configuration values, for reproducibility and portability. I’ve had bad experiences in the past with internal states that varied slightly depending on whether they were built from scratch or loaded from a checkpoint.
Thanks in advance for your help/recommendations!
EDIT: I’d have added more links to Hydra’s documentation for better context, but as a new forum user I’m unfortunately limited to 2 links per post.