Hi, I am new to PyTorch Lightning, and I am currently testing checkpointing because I cannot finish a training session before my GPU allocation times out (12 hours).
I am logging accuracy, loss, and learning rate using TensorBoardLogger.
From TensorBoard, I can see that my code does not properly restore the model or trainer state from the checkpoint. In the screenshots, version_0 (orange) is the trace of the initial training and version_1 (blue) is the trace of the resumed training. As you can see, the resumed training starts again from epoch 0, and the previous learning rate is not loaded either; it is re-initialized to 0.01. If I understand correctly, all of these states should be stored in the checkpoint and restored exactly as they were saved, but that is not what happens in my case.
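For reference, this is how I would inspect what the checkpoint file actually stores (a minimal sketch; the exact keys can differ between Lightning versions, and ckpt_file is the path built in my code below):

import torch

# Load the .ckpt file directly and list its top-level keys. As far as I
# understand, it should contain the epoch counter, optimizer states, and
# LR scheduler states in addition to the model weights.
ckpt = torch.load(ckpt_file, map_location='cpu')
print(list(ckpt.keys()))  # e.g. 'epoch', 'global_step', 'state_dict', 'optimizer_states', 'lr_schedulers', ...
print(ckpt.get('epoch'), ckpt.get('global_step'))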
Here is my source code, together with examples of the save paths it produces:
file_name = get_filename(hparams)
save_path = hparams.save_path or os.path.join(os.getcwd(), 'logs', hparams.dataset)
ckpt_path = None if hparams.v_num is None else os.path.join(save_path, file_name, 'version_' + str(hparams.v_num), 'checkpoints')
ckpt_file = None if hparams.v_num is None else os.path.join(ckpt_path, os.listdir(ckpt_path)[0])
statedict_path = os.path.join(os.getcwd(), 'trained_models', hparams.dataset + '_' + hparams.arch + '.pt')
'''
### Path Examples:
file_name = MODEL_NAME
save_path = HOME_DIR/logs/cifar10
ckpt_path = HOME_DIR/logs/cifar10/MODEL_NAME/version_0/checkpoints
ckpt_file = HOME_DIR/logs/cifar10/MODEL_NAME/version_0/checkpoints/epoch=6-step=1243.ckpt
statedict_path = HOME_DIR/trained_models/MODEL_NAME.pt
'''
if hparams.v_num is None:
    model = MODEL(hparams,
                  num_classes=dm.num_classes,
                  train_size=len(dm.train_dataloader().dataset))
else:
    model = MODEL.load_from_checkpoint(ckpt_file,
                                       num_classes=dm.num_classes,
                                       train_size=len(dm.train_dataloader().dataset))
logger = TensorBoardLogger(save_path, name=file_name)
lr_monitor = LearningRateMonitor(logging_interval='step')
trainer = Trainer(default_root_dir=save_path,
                  gpus=hparams.gpus,
                  max_epochs=hparams.epochs,
                  resume_from_checkpoint=ckpt_file,
                  distributed_backend=hparams.distributed_backend,
                  num_nodes=hparams.num_nodes,
                  logger=logger,
                  callbacks=[lr_monitor],
                  deterministic=deterministic)
trainer.fit(model, dm)
# Save the bare (non-Lightning) model weights after training finishes
torch.save(model.model.state_dict(), statedict_path)
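For completeness, a sketch of how those exported weights would be loaded back into the plain model later, outside Lightning (build_model is a hypothetical factory standing in for whatever hparams.arch selects; I assume the LightningModule keeps the inner network as self.model, which is why I call model.model.state_dict() above):

import torch

# Rebuild the plain PyTorch model and load the exported weights into it.
backbone = build_model(hparams.arch, num_classes=dm.num_classes)  # hypothetical helper
state_dict = torch.load(statedict_path, map_location='cpu')
backbone.load_state_dict(state_dict)
backbone.eval()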
FYI, hparams.v_num is set to an int value only when resuming from a checkpoint; otherwise it is None.
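One thing I was already unsure about: os.listdir(ckpt_path)[0] does not guarantee any particular order, so if the checkpoints directory ever holds more than one file I might resume from the wrong one. A sketch of what I considered instead (picking the most recently written .ckpt file):

import glob
import os

# Choose the newest checkpoint in the directory by modification time,
# instead of relying on the (unordered) result of os.listdir().
ckpt_candidates = glob.glob(os.path.join(ckpt_path, '*.ckpt'))
ckpt_file = max(ckpt_candidates, key=os.path.getmtime) if ckpt_candidates else None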
I am sure I am not using these methods properly, and the save paths might also be incorrect, but I could not find a good example that covers my case. Please let me know if you can spot any mistakes in my source code.