I’m using the following setup:
```python
import pytorch_lightning as pl

class BaseNetwork(pl.LightningModule):
    # network code
    ...

class ModelTrainer(pl.LightningModule):
    # model training code
    ...

class MyModel(ModelTrainer, BaseNetwork):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # new code
        # anything that needs to be overridden
```
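With this pattern, `super().__init__(**kwargs)` walks the MRO (`MyModel` → `ModelTrainer` → `BaseNetwork` → `pl.LightningModule`), so each base class has to cooperate by consuming its own keyword arguments and forwarding the rest. A minimal sketch of what I mean, where `hidden_size` and `learning_rate` are made-up placeholder arguments:

```python
import pytorch_lightning as pl

class BaseNetwork(pl.LightningModule):
    def __init__(self, hidden_size=64, **kwargs):  # hidden_size is a placeholder
        super().__init__(**kwargs)  # forward leftover kwargs up the MRO
        self.hidden_size = hidden_size

class ModelTrainer(pl.LightningModule):
    def __init__(self, learning_rate=1e-3, **kwargs):  # learning_rate is a placeholder
        super().__init__(**kwargs)
        self.learning_rate = learning_rate
```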
Here is a gist demonstrating the setup, which I made for this issue: Hparams not restored when using load_from_checkpoint (default argument values are the problem?)
Make sure you apply the `self.save_hyperparameters()` fix in `__init__`, otherwise the hyperparameters won't be restored by `load_from_checkpoint`.
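For reference, here's roughly how that fix fits into `MyModel`; the checkpoint path is just a placeholder:

```python
class MyModel(ModelTrainer, BaseNetwork):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Record all __init__ arguments in self.hparams and in the checkpoint,
        # so load_from_checkpoint can rebuild the model with the same values.
        self.save_hyperparameters()

# The saved hyperparameters are then restored automatically:
model = MyModel.load_from_checkpoint("path/to/checkpoint.ckpt")  # placeholder path
```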