Hi,
I created a model that takes the argparse arguments (`args`), a neural network, and a loss function. After training, I want to load a checkpoint. Since I already saved the hyperparameters (contained in `args`), I thought I only needed to pass the checkpoint, the neural network, and the loss function to `load_from_checkpoint()`. However, it throws an error unless I also pass `args`.
```python
import pytorch_lightning as pl

args = parser.parse_args()

class SSLModel(pl.LightningModule):
    def __init__(self, args, network, loss):
        """
        args: parsed command-line arguments (argparse.Namespace), the parameters required for training the model
        network: nn.Module, consisting of an encoder and the projection head
        loss: nn.Module, the loss used for training the model
        """
        super().__init__()
        self.save_hyperparameters(args)
        self.network = network
        self.loss = loss
```
I trained it with the following code:
```python
from pytorch_lightning import Trainer

learning_module = SSLModel(args=args, network=network, loss=loss)
trainer = Trainer.from_argparse_args(args, logger=logger, callbacks=[early_stopping, checkpoint_callback, lr_monitor, custom_callback])
trainer.fit(learning_module, train_dataloader, valid_dataloader)
```
After training, I try to load a checkpoint as shown below, but it throws an error (‘args missing’) unless I also pass `args`:
```python
model = SSLModel.load_from_checkpoint(ckpt_file, network=network, loss=l)
```
I assumed that I only need to pass `network` and `loss`, since they are not saved as hyperparameters.
Also, is there a way to avoid passing `network` and `loss` altogether, for example by saving them as hyperparameters?
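For the second question, one common pattern is sketched below in plain Python. It is an illustration under stated assumptions, not Lightning's API: `__init__` receives only plain, serializable values (names, sizes) and builds the network and loss itself. With that design, calling `self.save_hyperparameters()` in a LightningModule would record every init argument, so the saved hyperparameters alone would be enough to call `__init__` again when loading. The registries `NETWORKS` and `LOSSES`, the class `ConfigBuiltModel`, the helper `rebuild_from_hparams`, and all argument names are hypothetical, and plain dicts stand in for real `nn.Module` objects so the sketch runs without torch.

```python
from typing import Any, Callable, Dict

# Hypothetical registries mapping a string name to a factory.
# Plain dicts stand in for real nn.Module objects so this runs without torch.
NETWORKS: Dict[str, Callable[[int], Dict[str, Any]]] = {
    "mlp": lambda hidden_dim: {"kind": "mlp", "hidden_dim": hidden_dim},
}
LOSSES: Dict[str, Callable[[float], Dict[str, Any]]] = {
    "contrastive": lambda temperature: {"kind": "contrastive", "temperature": temperature},
}


class ConfigBuiltModel:
    """Builds its components from plain values instead of receiving objects."""

    def __init__(self, network_name: str, hidden_dim: int,
                 loss_name: str, temperature: float) -> None:
        # Only plain values are recorded, so they can be written to a checkpoint
        # (in a LightningModule, self.save_hyperparameters() would play this role).
        self.hparams: Dict[str, Any] = {
            "network_name": network_name,
            "hidden_dim": hidden_dim,
            "loss_name": loss_name,
            "temperature": temperature,
        }
        # The components are constructed here rather than passed in.
        self.network = NETWORKS[network_name](hidden_dim)
        self.loss = LOSSES[loss_name](temperature)


def rebuild_from_hparams(saved: Dict[str, Any]) -> ConfigBuiltModel:
    # Recreate the model from the saved values alone; no objects are passed in.
    return ConfigBuiltModel(**saved)


original = ConfigBuiltModel("mlp", 128, "contrastive", 0.1)
restored = rebuild_from_hparams(dict(original.hparams))
print(restored.network == original.network and restored.loss == original.loss)  # True
```

The trade-off is that the set of possible networks and losses has to be known in advance (here through the registries), whereas passing objects into `__init__` keeps the module generic but means they must be supplied again at load time.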