I have a Lightning module:
```python
import pytorch_lightning as pl


class MyClassifier(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()
        self.num_classes = config["arch"]["args"]["num_classes"]
        self.model_args = config["arch"]["args"]
        self.model, self.tokenizer = get_model_and_tokenizer(**self.model_args)
        self.bias_loss = False
```
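For context, `get_model_and_tokenizer` wraps a Hugging Face RoBERTa model (the missing key in the error below mentions `model.roberta`). Roughly, it does something like this (a simplified sketch; the argument names are placeholders, and my real helper has more options):

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

def get_model_and_tokenizer(model_name, num_classes, **kwargs):
    # Simplified sketch: loads a pretrained RoBERTa backbone with a
    # classification head, plus the matching tokenizer.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=num_classes
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer
```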
I train the model using this:
```python
from pytorch_lightning.callbacks import ModelCheckpoint

model = MyClassifier(config)
checkpoint_callback = ModelCheckpoint(
    save_top_k=5,
    verbose=True,
    monitor="val_f1",
    mode="max",
)
trainer = pl.Trainer(
    accelerator="auto",
    strategy="ddp",
    max_epochs=args.n_epochs,
    accumulate_grad_batches=config["accumulate_grad_batches"],
    callbacks=[checkpoint_callback],
    default_root_dir="saved/" + config["name"],
    deterministic=True,
    precision=16,
)
trainer.fit(model, data_loader, valid_data_loader)
```
I try to load a checkpoint using this:
```python
model = MyClassifier.load_from_checkpoint(PATH)
```
but I get this error:
```
RuntimeError: Error(s) in loading state_dict for MyClassifier:
    Missing key(s) in state_dict: "model.roberta.embeddings.position_ids".
```
I have also tried passing the config explicitly:
```python
model = MyClassifier.load_from_checkpoint(PATH, config=config)
```
but I get the same error. I also tried loading the state_dict directly:
```python
ckpoint = torch.load(PATH)
model = MyClassifier.load_state_dict(ckpoint["state_dict"])
```
but I get this error:
```
TypeError: Module.load_state_dict() missing 1 required positional argument: 'state_dict'
```
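As far as I can tell, this last TypeError happens because `load_state_dict` is an instance method, so it has to be called on a constructed model rather than on the class itself, i.e. something like this (an untested sketch):

```python
import torch

ckpoint = torch.load(PATH)
model = MyClassifier(config)  # construct the model first
model.load_state_dict(ckpoint["state_dict"])
```

but I expect this would just lead me back to the same missing-key RuntimeError as above.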
My PyTorch Lightning version is 2.0.7, and I run training on 4 GPUs.
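I am aware that `load_from_checkpoint` accepts a `strict` argument, so presumably something like this would skip the missing key (a sketch, not verified):

```python
model = MyClassifier.load_from_checkpoint(PATH, config=config, strict=False)
```

Is it safe to ignore the missing `position_ids` buffer this way, or does the error mean my checkpoint is incompatible with my current environment (e.g. a different `transformers` version than the one used for training)?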