Training slowing down

Hi, I wrote the following code to train a regression model:

import os

import pandas as pd
import timm
import torch
import torch.nn as nn
import torchmetrics
import torchvision.transforms as transforms
from lion_pytorch import Lion
from PIL import Image
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import TrivialAugmentWide
from torchvision.transforms.functional import to_grayscale

#os.environ['WANDB_API_KEY'] = '<redacted>'  # set this from the environment; don't commit API keys
#wandb_logger = WandbLogger(project="ProjectBabble", log_model=True)
class BabbleDataset(Dataset):
    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.labels = self.dataframe[[
            "cheekPuffLeft", "cheekPuffRight", "cheekSuckLeft", "cheekSuckRight", "jawOpen", "jawForward",
            "jawLeft", "jawRight", "noseSneerLeft", "noseSneerRight", "mouthFunnel", "mouthPucker",
            "mouthLeft", "mouthRight", "mouthRollUpper", "mouthRollLower", "mouthShrugUpper", "mouthShrugLower",
            "mouthClose", "mouthSmileLeft", "mouthSmileRight", "mouthFrownLeft", "mouthFrownRight",
            "mouthDimpleLeft", "mouthDimpleRight", "mouthUpperUpLeft", "mouthUpperUpRight",
            "mouthLowerDownLeft", "mouthLowerDownRight", "mouthPressLeft", "mouthPressRight",
            "mouthStretchLeft", "mouthStretchRight", "tongueOut", "tongueUp", "tongueDown", "tongueLeft",
            "tongueRight", "tongueRoll", "tongueBendDown", "tongueCurlUp", "tongueSquish", "tongueFlat",
            "tongueTwistLeft", "tongueTwistRight",
        ]]
        self.images = self.dataframe[["filename"]]
        # Use the caller-supplied transform if given; otherwise default to TrivialAugmentWide
        self.transform = transform if transform is not None else TrivialAugmentWide()
        self.box = (0, 0, 256, 255)  # (left, upper, right, lower) crop bounds

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        filename = self.images["filename"].iloc[idx]
        image = Image.open(filename).crop(self.box).resize((256, 256))
        image = self.transform(image)
        image = to_grayscale(image)
        image = transforms.ToTensor()(image)
        label = torch.tensor(self.labels.iloc[idx].to_list()).float()
        return image, label


class BabbleDataModule(LightningDataModule):
    def __init__(self, data_file, batch_size=64, num_workers=8):  # control batch size and workers from here
        super().__init__()
        self.data_file = data_file
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        # This method is used for downloading and preprocessing data (if any)
        pass

    def setup(self, stage=None):
        df = pd.read_csv(self.data_file)
        train, test = train_test_split(df, test_size=0.2)
        train = train.reset_index(drop=True)
        test = test.reset_index(drop=True)
        self.train_dataset = BabbleDataset(train)
        self.test_dataset = BabbleDataset(test)

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
        )


class BabbleModel(LightningModule):
    def __init__(self):
        super().__init__()
        # 45 regression outputs (one per blendshape label) on single-channel grayscale input
        self.model = timm.create_model('tf_efficientnetv2_b0.in1k', num_classes=45, pretrained=True, in_chans=1)
        self.criterion = nn.HuberLoss()
        self.metric = torchmetrics.MeanMetric()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)
        loss = self.criterion(outputs, labels)
        self.log('train_loss', loss, prog_bar=True, on_step=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)
        loss = self.criterion(outputs, labels)
        self.metric.update(loss)  # accumulated but never logged or reset; consider removing
        self.log('val_loss', loss, prog_bar=True, on_step=True, sync_dist=True)

    def configure_optimizers(self):
        return Lion(self.model.parameters(), lr=6.9e-5)


if __name__ == '__main__':
    data_module = BabbleDataModule(data_file='/home/babble/Documents/ProjectBabble/Babble_data.csv')  # CSV of filenames and labels
    model = BabbleModel()

    checkpoint_callback = ModelCheckpoint(monitor='val_loss', save_on_train_epoch_end=True, save_last=True)
    trainer = Trainer(
        max_epochs=100,
        callbacks=[checkpoint_callback],
        #logger=wandb_logger,
    )

    trainer.fit(model, data_module)

For some reason my per-epoch training time drops at first, then slowly creeps back up the longer training runs! I don't really know why, heh!

Try setting logger=False to disable the default CSV logger, or update Lightning to 2.1.0+, where this inefficiency was fixed.
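In case it's useful, a minimal sketch of the first workaround, reusing the checkpoint_callback from the script above:

# Same Trainer as above, but with the default logger disabled.
# self.log() still updates the progress bar and the 'val_loss' value that
# ModelCheckpoint monitors; it just no longer hands every step to a logger.
trainer = Trainer(
    max_epochs=100,
    callbacks=[checkpoint_callback],
    logger=False,
)
trainer.fit(model, data_module)

If you'd rather keep a logger, changing the validation self.log call to on_step=False, on_epoch=True (the Lightning default for validation) also cuts the logged entries from one per batch to one per epoch.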