Resuming training from a checkpoint gives a different model (weights / results)

  • PyTorch version: 2.0.0
  • PyTorch Lightning version: 2.0.2
  • PyTorch-Forecasting version: 1.0.0
  • Python version: 3.8.10
  • Operating System: Ubuntu 20.04

Expected behavior

Saving a checkpoint and resuming training from it in PyTorch Lightning (2.0.2) should produce the same final model (weights, states, loss) as training without interruption.

Actual behavior

Saving a checkpoint and resuming training from it in PyTorch Lightning (2.0.2) produces a different final result (checked via an MD5 of the weights, the loss, etc.).

Code to reproduce the problem

In this code I create three data frames and train on each of them, one epoch at a time. At the end I compute an MD5 sum of the model's state dict. Rerunning the code produces the same MD5 sum.

In the second run of the same code, just before the last iteration, I save a checkpoint to disk and then resume training from that checkpoint. This produces a different MD5 sum, which should not happen.

import os
import warnings

warnings.filterwarnings("ignore")  # avoid printing out absolute paths

import copy
from pathlib import Path
import warnings
from matplotlib import pyplot as plt
import tempfile
import json, hashlib

import numpy as np
import pandas as pd
import torch
import random
import lightning.pytorch  as pl

from pytorch_forecasting import TimeSeriesDataSet, NHiTS

import logging
#logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
#logging.getLogger("lightning.pytorch.utilities.rank_zero").setLevel(logging.WARNING)
logging.getLogger("lightning.pytorch.accelerators.cuda").setLevel(logging.WARNING)

##################### Eval

class NHitsRunner:
    """Train an NHiTS model one epoch at a time over three synthetic datasets.

    Used to compare a run that trains straight through against a run that
    saves a trainer checkpoint before the last epoch and resumes from it; the
    two runs are compared via ``dict_md5`` of the final model state dict.
    """

    def __init__(self):
        self.batch_size = 256
        # Seed before building data and model so every runner starts identically.
        self.reset_seed()
        self.create_datasets()
        self.create_model()
        # Checkpoint file the next train_epoch() call should resume from, if any.
        self.path = None

    def create_model(self):
        """Build the NHiTS model (from the first dataset) and a single-GPU trainer."""
        self.nhits_model = NHiTS.from_dataset(
            self.test_datasets[0],
            learning_rate=0.1,
            hidden_size=16,
            optimizer='adam',
        )

        # max_epochs starts at 1 and is incremented after every train_epoch().
        self.trainer = pl.Trainer(
            devices=1,
            accelerator="gpu",
            max_epochs=1,
            gradient_clip_val=0.1,
            enable_model_summary=False,
            enable_progress_bar=True,
            enable_checkpointing=False,
            logger=False,
        )

    def create_dataset(self, frame):
        """Wrap ``frame`` in a TimeSeriesDataSet with a price/volume target setup.

        Rows whose ``time_idx`` lies within the last ``max_prediction_length``
        steps are excluded. Also stores the encoder/prediction lengths on
        ``self``.
        """
        self.max_prediction_length = 10
        self.max_encoder_length = 24
        training_cutoff = frame["time_idx"].max() - self.max_prediction_length

        return TimeSeriesDataSet(
            frame[lambda x: x.time_idx <= training_cutoff],
            time_idx="time_idx",
            target="price",
            group_ids=["group_id"],
            max_encoder_length=self.max_encoder_length,
            max_prediction_length=self.max_prediction_length,
            static_categoricals=[],
            time_varying_known_reals=[],
            time_varying_unknown_categoricals=[],
            time_varying_unknown_reals=[
                "price",
                "volume",
            ],
            target_normalizer=None,
        )

    def train_epoch(self, set_id):
        """Train one more epoch on ``self.test_datasets[set_id]``.

        If a checkpoint was registered by ``test_save_load_trainer_checkpoint``,
        training resumes from it via ``ckpt_path``; the checkpoint file is then
        deleted and ``self.path`` cleared.

        NOTE(review): Lightning appears to re-run ``configure_optimizers`` on
        every ``fit()`` call and to restore optimizer / LR-scheduler state only
        when ``ckpt_path`` is passed. If so, the non-resumed path starts each
        epoch with a fresh Adam state, while the resumed path continues the
        saved one. That alone would make the two runs' weights diverge.
        Confirm against the Lightning version in use.
        """
        train_dataloader = self.test_datasets[set_id].to_dataloader(
            train=True, batch_size=self.batch_size, num_workers=0
        )

        if self.path is not None:
            print('Consumed trainer checkpoint')
            self.trainer.fit(
                self.nhits_model,
                ckpt_path=self.path,
                train_dataloaders=train_dataloader,
            )
            os.remove(self.path)
            self.path = None
        else:
            self.trainer.fit(
                self.nhits_model,
                train_dataloaders=train_dataloader,
            )

        # Presumably lets the next fit() call run exactly one further epoch.
        self.trainer.fit_loop.max_epochs += 1

    def reset_seed(self):
        """Seed all RNGs with 0 and set float32 matmul precision to 'medium'."""
        pl.seed_everything(0)
        torch.set_float32_matmul_precision('medium')

    def create_frame(self):
        """Return a deterministic single-group frame of two sine waves.

        Columns: ``time_idx`` (0..8191), ``price`` = 100*sin(t/3),
        ``volume`` = 200*sin(t/2), and ``group_id`` = 'group_a'.
        """
        dataframe_size = 128 * 64
        test_dataframe = pd.DataFrame({'time_idx': np.arange(0, dataframe_size, 1, dtype=int)})
        test_dataframe = test_dataframe.assign(price=lambda x: np.sin(x.time_idx / 3) * 100)
        test_dataframe = test_dataframe.assign(volume=lambda x: np.sin(x.time_idx / 2) * 200)
        test_dataframe['group_id'] = 'group_a'
        return test_dataframe

    def create_datasets(self):
        """Build three (identical) frames and their TimeSeriesDataSets."""
        self.test_frames = []
        self.test_datasets = []

        for _ in range(3):
            new_frame = self.create_frame()
            self.test_frames.append(new_frame)
            self.test_datasets.append(self.create_dataset(new_frame))

    def convert_dict(self, d, flat_dict, prefix=""):
        """Flatten the (possibly nested) dict ``d`` into ``flat_dict``.

        Top-level keys are stored unchanged. Nested keys are joined to their
        parent key with '.', so equal leaf names under different parents no
        longer overwrite each other. Tensors are converted to (nested) Python
        lists so the result is JSON-serialisable.
        """
        for key, value in d.items():
            full_key = f"{prefix}{key}" if prefix else key
            if isinstance(value, dict):
                self.convert_dict(value, flat_dict, prefix=f"{full_key}.")
            else:
                if isinstance(value, torch.Tensor):
                    # Tensor.tolist() also handles dtypes numpy lacks (e.g. bfloat16).
                    value = value.detach().cpu().tolist()
                flat_dict[full_key] = value

    def dict_md5(self):
        """Return the MD5 hex digest of the model's flattened, JSON-encoded state dict."""
        flat_dict = {}
        self.convert_dict(self.nhits_model.state_dict(), flat_dict)
        # sort_keys makes the digest independent of dict insertion order.
        payload = json.dumps(flat_dict, sort_keys=True).encode('utf-8')
        return hashlib.md5(payload).hexdigest()

    def test_save_load_trainer_checkpoint(self):
        """Save a trainer checkpoint to the system temp dir and register it.

        The next ``train_epoch`` call resumes from this checkpoint and then
        deletes it.
        """
        print("Saving trainer checkpoint")
        model_filename = "test_checkpoint_forecast.ckpt"
        file_name = str(Path(tempfile.gettempdir()) / model_filename)
        self.trainer.save_checkpoint(file_name)
        self.path = file_name

#### Test running WITHOUT saving and loading

runner_1 = NHitsRunner()

for i in range(3):
    runner_1.train_epoch(i)

run1_md5 = runner_1.dict_md5()

#### Test running WITH saving and loading: save a checkpoint before the last
#### epoch, then resume training from it

runner_2 = NHitsRunner()

for i in range(3):
    if i == 2:
        runner_2.test_save_load_trainer_checkpoint()
    runner_2.train_epoch(i)

run2_md5 = runner_2.dict_md5()

#### Compare

if run1_md5 != run2_md5:
    print(f'Models md5 differs {run1_md5} != {run2_md5}')
else:
    print(f'Models md5 equals {run1_md5} == {run2_md5}')