I’m trying to understand how I should save and load my trained model for inference.
Lightning lets me save checkpoint files, but the problem is that those files are quite large, because they contain a lot of information that isn’t relevant to inference.
Instead, I could do torch.save(model.state_dict(), "model.pt"), which I believe saves only the trained weights, and then load the model like this:
model = FullModel()
model.load_state_dict(torch.load("model.pt"))
model.eval()
My problem here is that my FullModel
class takes a config dict, which was used to tune hyperparameters during training, so calling FullModel() with no arguments fails with:
TypeError: __init__() missing 1 required positional argument: 'config'
Is the way around this to save config
to disk during training, and load it alongside the weights during inference? Or is there a more “correct” way of doing it?
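To make concrete what I mean, here is a minimal sketch of that approach. FullModel here is a toy stand-in for my actual model (the real one is bigger, and the config keys in_dim/out_dim are made up for this example); the config is saved as JSON next to the weights and reloaded before reconstructing the model:

```python
import json

import torch
import torch.nn as nn


class FullModel(nn.Module):
    """Toy stand-in for my real model: built from a config dict."""

    def __init__(self, config):
        super().__init__()
        self.net = nn.Linear(config["in_dim"], config["out_dim"])

    def forward(self, x):
        return self.net(x)


config = {"in_dim": 4, "out_dim": 2}
model = FullModel(config)

# After training: save the weights and the config side by side.
torch.save(model.state_dict(), "model.pt")
with open("config.json", "w") as f:
    json.dump(config, f)

# At inference time: reload the config, rebuild the model, load the weights.
with open("config.json") as f:
    loaded_config = json.load(f)

model = FullModel(loaded_config)
model.load_state_dict(torch.load("model.pt"))
model.eval()
```

This keeps model.pt small (weights only), but it does mean two files have to travel together.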
I could simply save the entire model with torch.save(model) (and not just the state_dict), which really simplifies loading, but that file ends up almost as big as the checkpoint files.
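For comparison, the whole-model version looks like this (again with a toy stand-in model). Note that this pickles the module object, so loading requires the class definition to be importable, and recent PyTorch versions need weights_only=False to unpickle a full module:

```python
import torch
import torch.nn as nn

# Toy stand-in for a trained model.
model = nn.Sequential(nn.Linear(4, 2))
model.eval()

# Save the whole module object (architecture + weights) via pickle.
torch.save(model, "full_model.pt")

# Loading needs no constructor arguments, but the class must be importable.
# weights_only=False is required for pickled modules in newer PyTorch.
model2 = torch.load("full_model.pt", weights_only=False)
model2.eval()
```

So loading is a one-liner, but I pay for it in file size and in coupling the saved file to my class definition.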