- PyTorch version: 2.0.0
- PyTorch Lightning version: 2.0.2
- PyTorch-Forecasting version: 1.0.0
- Python version: 3.8.10
- Operating System: Ubuntu 20.04
Expected behavior
Saving a checkpoint and resuming from it in PyTorch Lightning (2.0.2) should yield the same final model (weights, optimizer state, loss).
Actual behavior
Saving a checkpoint and resuming from it in PyTorch Lightning (2.0.2) yields a different final training result (verified via an MD5 hash of the state dict, the loss, etc.).
Code to reproduce the problem
In this code I create 3 data frames and train on each, one epoch at a time, then compute an MD5 sum of the model's state dict. Rerunning the code produces the same MD5 sum, so the run itself is reproducible.
In a second run of the exact same code, I save a checkpoint to disk before the last iteration and resume training from that checkpoint. This produces a different MD5 sum, which should not be the case.
import os
import warnings

warnings.filterwarnings("ignore")  # avoid printing out absolute paths

import tempfile
import json
import hashlib
import logging

import numpy as np
import pandas as pd
import torch
import lightning.pytorch as pl
from pytorch_forecasting import TimeSeriesDataSet, NHiTS

# logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
# logging.getLogger("lightning.pytorch.utilities.rank_zero").setLevel(logging.WARNING)
logging.getLogger("lightning.pytorch.accelerators.cuda").setLevel(logging.WARNING)
##################### Eval
class NHitsRunner:
    def __init__(self):
        self.batch_size = 256
        self.reset_seed()
        self.create_datasets()
        self.create_model()
        self.path = None

    def create_model(self):
        # Model and trainer
        self.nhits_model = NHiTS.from_dataset(
            self.test_datasets[0],
            learning_rate=0.1,
            hidden_size=16,
            optimizer='adam',
        )
        self.trainer = pl.Trainer(
            devices=1, accelerator="gpu",
            max_epochs=1,
            gradient_clip_val=0.1,
            enable_model_summary=False,
            enable_progress_bar=True,
            enable_checkpointing=False,
            logger=False,
        )
    def create_dataset(self, frame):
        self.max_prediction_length = 10
        self.max_encoder_length = 24
        training_cutoff = frame["time_idx"].max() - self.max_prediction_length
        return TimeSeriesDataSet(
            frame[lambda x: x.time_idx <= training_cutoff],
            time_idx="time_idx",
            target="price",
            group_ids=["group_id"],
            max_encoder_length=self.max_encoder_length,
            max_prediction_length=self.max_prediction_length,
            static_categoricals=[],
            time_varying_known_reals=[],
            time_varying_unknown_categoricals=[],
            time_varying_unknown_reals=[
                "price",
                "volume",
            ],
            target_normalizer=None,
        )
    def train_epoch(self, set_id):
        train_dataloader = self.test_datasets[set_id].to_dataloader(
            train=True, batch_size=self.batch_size, num_workers=0
        )
        if self.path is not None:
            print('Consumed trainer checkpoint')
            # Resume training from the checkpoint saved in test_save_load_trainer_checkpoint()
            # self.nhits_model = NHiTS.load_from_checkpoint(self.path)
            self.trainer.fit(
                self.nhits_model,
                ckpt_path=self.path,
                train_dataloaders=train_dataloader,
            )
            os.remove(self.path)
            self.path = None
        else:
            self.trainer.fit(
                self.nhits_model,
                train_dataloaders=train_dataloader,
            )
        # Raise the epoch budget so the next fit() call trains one more epoch
        self.trainer.fit_loop.max_epochs += 1
    def reset_seed(self):
        pl.seed_everything(0)
        torch.set_float32_matmul_precision('medium')

    def create_frame(self):
        # Create sine-wave price/volume data
        dataframe_size = 128 * 64
        test_dataframe = pd.DataFrame({'time_idx': np.arange(0, dataframe_size, 1, dtype=int)})
        test_dataframe = test_dataframe.assign(price=lambda x: np.sin(x.time_idx / 3) * 100)
        test_dataframe = test_dataframe.assign(volume=lambda x: np.sin(x.time_idx / 2) * 200)
        test_dataframe['group_id'] = 'group_a'
        return test_dataframe
    def create_datasets(self):
        self.test_frames = []
        self.test_datasets = []
        for i in range(3):
            new_frame = self.create_frame()
            self.test_frames.append(new_frame)
            self.test_datasets.append(self.create_dataset(new_frame))
    def convert_dict(self, d, flat_dict):
        # Recursively flatten a (possibly nested) dict, converting tensors to lists
        # so the result is JSON-serializable
        for k, v in d.items():
            if isinstance(v, dict):
                self.convert_dict(v, flat_dict)
            else:
                if isinstance(v, torch.Tensor):
                    v = v.cpu().detach().numpy().tolist()
                flat_dict[k] = v

    def dict_md5(self):
        # Hash the model's state dict so final weights can be compared across runs
        _dict = self.nhits_model.state_dict()
        flat_dict = {}
        self.convert_dict(_dict, flat_dict)
        data_md5 = hashlib.md5(json.dumps(flat_dict, sort_keys=True).encode('utf-8')).hexdigest()
        return data_md5
    def test_save_load_trainer_checkpoint(self):
        print("Saving trainer checkpoint")
        # Save the trainer state to a temp file; train_epoch() will resume from it
        model_filename = "test_checkpoint_forecast.ckpt"
        file_name = os.path.join(tempfile.gettempdir(), model_filename)
        self.trainer.save_checkpoint(file_name)
        self.path = file_name
#### Test run WITHOUT saving and loading
runner_1 = NHitsRunner()
for i in range(3):
    runner_1.train_epoch(i)
run1_md5 = runner_1.dict_md5()

#### Test run WITH saving and loading
runner_2 = NHitsRunner()
for i in range(3):
    if i == 2:
        runner_2.test_save_load_trainer_checkpoint()
    runner_2.train_epoch(i)
run2_md5 = runner_2.dict_md5()

#### Compare
if run1_md5 != run2_md5:
    print('Model MD5 sums differ: {} != {}'.format(run1_md5, run2_md5))
else:
    print('Model MD5 sums match: {} == {}'.format(run1_md5, run2_md5))
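
For anyone triaging this: the MD5 comparison only shows that something differs. The sketch below (my own helper, not part of the original repro) diffs the two state dicts directly and prints which tensors diverge; it assumes both runners have finished training as above.

import torch

def diff_state_dicts(sd_a, sd_b):
    # Print every parameter/buffer whose values differ between two state dicts.
    # Assumes both come from the same architecture (identical keys and shapes).
    for key in sd_a:
        if not torch.equal(sd_a[key].cpu(), sd_b[key].cpu()):
            # cast to float so the diff works for any dtype (e.g. int buffers)
            max_diff = (sd_a[key].cpu().float() - sd_b[key].cpu().float()).abs().max().item()
            print(f"{key}: max abs diff {max_diff:.3e}")

diff_state_dicts(runner_1.nhits_model.state_dict(),
                 runner_2.nhits_model.state_dict())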