I am implementing a Temporal Fusion Transformer (TFT) using pytorch (2.0.1+cpu), pytorch_forecasting (1.0.0), and pytorch_lightning (2.0.8). When I run the script, trainer.fit completes, but the checkpoint check at the end prints "Checkpoint 'c:/…/checkpoints/best_model.ckpt' not found. Please check the file path." Please see the code below:
import pandas as pd
import pytorch_lightning as pl
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.metrics import SMAPE
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.data.encoders import NaNLabelEncoder
from pytorch_forecasting.data.examples import generate_ar_data
from pytorch_lightning.callbacks import ModelCheckpoint
import os
import torch
from pytorch_lightning.loggers import TensorBoardLogger
class TFTLightning(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.save_hyperparameters(ignore=["model"])

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = SMAPE()(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = SMAPE()(y_hat, y)
        self.log("val_loss", loss)  # Log the validation loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.03)
        return optimizer

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = SMAPE()(y_hat, y)
        self.log("test_loss", loss)  # Log the test loss
        return loss
# Setting up the logger
logger = TensorBoardLogger(save_dir='logs/', name='my_experiment')
# Generate synthetic time series data
data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100, seed=42)
# generate_ar_data already returns a DataFrame, so this is effectively a copy
df = pd.DataFrame(data)
# Define the time series dataset
max_encoder_length = 60
max_prediction_length = 20
training_cutoff = df["time_idx"].max() - max_prediction_length
context_length = max_encoder_length
prediction_length = max_prediction_length
# creating TimeSeriesDataSet
training = TimeSeriesDataSet(
    df[df["time_idx"] <= training_cutoff],
    time_idx="time_idx",
    target="value",
    categorical_encoders={"series": NaNLabelEncoder(add_nan=True)},
    group_ids=["series"],
    time_varying_unknown_reals=["value"],
    max_encoder_length=context_length,
    max_prediction_length=prediction_length,
)
validation = TimeSeriesDataSet.from_dataset(training, df, min_prediction_idx=training_cutoff + 1)
batch_size = 64
# Define data loaders
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0)
# Define the Temporal Fusion Transformer (TFT) model
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=32,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=8,
    output_size=7,
    loss=SMAPE(),
    log_interval=10,
    reduce_on_plateau_patience=4,
)
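# checkpoint_root_dir_path is set elsewhere in my actual script; placeholder
# value here only so the snippet runs as posted
checkpoint_root_dir_path = "checkpoints"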
# Configure ModelCheckpoint to save the best model based on validation loss
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",  # Metric to monitor (here, the logged validation loss)
    mode="min",  # "min" when lower values are better, "max" when higher are better
    save_top_k=1,  # Keep only the single best model
    # dirpath=checkpoint_root_dir_path,  # Directory where checkpoints would be saved
    filename="best_model",  # File name for the best model checkpoint
)
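# Note: my understanding of the Lightning defaults is that with dirpath unset,
# the checkpoint should land under the trainer's log directory, i.e. roughly
# <default_root_dir>/lightning_logs/version_N/checkpoints/best_model.ckpt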
# Wrap the TFT model in a LightningModule
model = TFTLightning(tft)
# Define a PyTorch Lightning trainer with the ModelCheckpoint callback
trainer = pl.Trainer(
    max_epochs=10,
    gradient_clip_val=0.1,
    callbacks=[checkpoint_callback],  # Add the ModelCheckpoint callback
    default_root_dir=checkpoint_root_dir_path,
)
# Train the model
trainer.fit(
    model=model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)
# After training, you can test the best model by specifying the checkpoint path
best_model_checkpoint = os.path.join(trainer.logger.log_dir, "best_model.ckpt")
if os.path.exists(best_model_checkpoint):
    # Test the best model from the saved checkpoint
    results = trainer.test(ckpt_path=best_model_checkpoint)
else:
    print(f"Checkpoint '{best_model_checkpoint}' not found. Please check the file path.")
While training, versioned directories (version_0, version_1, etc.) are created inside the lightning_logs directory, but each one contains only an hparams.yaml file with no content. No best_model.ckpt appears in any of these versioned directories.
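If it helps with diagnosing this, my understanding is that the callback itself can report where (or whether) it saved anything; a minimal check that could go right after trainer.fit (best_model_path is a standard ModelCheckpoint attribute and, as far as I know, is an empty string when no checkpoint was written):

# Where does the callback say it saved the best checkpoint? ("" = nothing saved)
print("best_model_path:", checkpoint_callback.best_model_path)
# Directory the trainer's logger is writing to (e.g. lightning_logs/version_N)
print("log_dir:", trainer.logger.log_dir)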
Environment
OS: Windows 11
IDE: VS Code
Python: 3.10.2