I'm new to PyTorch Lightning. I built a very simple model with PL and checked the model's weights before and after training, but they are exactly the same, even though the loss decreases during training.
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# wandb_connect, apply_tokenization, Dataset_SM, NCM, and device are defined elsewhere

def main(args, df_train, df_dev, df_test):
    """Main function."""
    # Wandb connect
    wandb_connect()
    wandb_logger = WandbLogger(project="project name", name="Run name")
    # Tokenization
    [df_train, df_dev, df_test], params, tokenizer_qid, tokenizer_uid, tokenizer_qu_id, tokenizer_rank = apply_tokenization([df_train, df_dev, df_test])
    # Dataloaders
    [train_loader, dev_loader, test_loader] = list(map(lambda x: Dataset_SM(x).get_dataloader(args.batch_size), [df_train, df_dev, df_test]))
    # Model definition
    model = NCM(**params).to(device)
    # Weights before training
    WW = model.emb_qid.weight
    print(torch.mean(model.emb_qid.weight))
    # Train & eval
    es = EarlyStopping(monitor='dev_loss', patience=4)
    checkpoint_callback = ModelCheckpoint(dirpath=args.result_path)
    trainer = pl.Trainer(max_epochs=args.n_epochs, callbacks=[es, checkpoint_callback],
                         val_check_interval=args.val_check_interval,
                         logger=wandb_logger, gpus=1)
    trainer.fit(model, train_loader, dev_loader)
    trainer.save_checkpoint(args.result_path + "example.ckpt")
    # Weights after training, loaded back from the checkpoint
    loaded_model = NCM.load_from_checkpoint(checkpoint_path=args.result_path + "example.ckpt", **params)
    print(loaded_model.emb_qid.weight == WW)
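For reference, here is a minimal standalone sketch of the kind of before/after check I have in mind (the embedding sizes are illustrative, and it snapshots the initial weights with .clone().detach() so the copy is independent of the live Parameter):

import torch
import torch.nn as nn

# Toy embedding standing in for model.emb_qid (illustrative sizes)
emb = nn.Embedding(num_embeddings=10, embedding_dim=4)

# Keep an independent copy of the initial weights; a plain assignment
# (w_before = emb.weight) would only store a reference to the live Parameter
w_before = emb.weight.clone().detach()

# Simulate one optimization step so the weights actually change
opt = torch.optim.SGD(emb.parameters(), lr=0.1)
loss = emb(torch.tensor([1, 2, 3])).sum()
loss.backward()
opt.step()

# Prints False once the weights have been updated
print(torch.equal(w_before, emb.weight))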
Can someone tell me if I'm missing something?