Hi — I wrote the following code to train a model for regression:
from torch.utils.data import Dataset
import torch
from PIL import Image
import torchvision.transforms as transforms
import pandas as pd
from torchvision.transforms.functional import to_grayscale
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import torch.nn as nn
import torchmetrics
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
import numpy as np
from pytorch_lightning.loggers import WandbLogger
import os
from PIL import Image, ImageEnhance, ImageFilter
import torchvision.transforms as T
import timm
#os.environ['WANDB_API_KEY'] = '<REDACTED>'  # SECURITY: a live wandb API key was committed here — revoke it and load the key from the environment or a secrets manager instead of hard-coding it.
from lion_pytorch import Lion
#from FasterVIT import faster_vit_0_224
from torchvision.transforms import TrivialAugmentWide
#wandb_logger = WandbLogger(project = "ProjectBabble",log_model = True)
class BabbleDataset(Dataset):
    """Face-image regression dataset.

    Each item is a ``(image, label)`` pair: a 1x256x256 grayscale float
    tensor and a float tensor of 45 blendshape targets.

    Args:
        dataframe: pandas DataFrame with a ``filename`` column (paths to
            the face images) plus the 45 blendshape columns listed in
            ``LABEL_COLUMNS``.
        transform: optional PIL-image transform applied after crop/resize.
            Defaults to ``TrivialAugmentWide()`` (training-time
            augmentation). Pass e.g. an identity transform for validation.
    """

    # The 45 blendshape target columns, in output order.
    LABEL_COLUMNS = [
        "cheekPuffLeft", "cheekPuffRight", "cheekSuckLeft", "cheekSuckRight",
        "jawOpen", "jawForward", "jawLeft", "jawRight",
        "noseSneerLeft", "noseSneerRight", "mouthFunnel", "mouthPucker",
        "mouthLeft", "mouthRight", "mouthRollUpper", "mouthRollLower",
        "mouthShrugUpper", "mouthShrugLower", "mouthClose", "mouthSmileLeft",
        "mouthSmileRight", "mouthFrownLeft", "mouthFrownRight",
        "mouthDimpleLeft", "mouthDimpleRight", "mouthUpperUpLeft",
        "mouthUpperUpRight", "mouthLowerDownLeft", "mouthLowerDownRight",
        "mouthPressLeft", "mouthPressRight", "mouthStretchLeft",
        "mouthStretchRight", "tongueOut", "tongueUp", "tongueDown",
        "tongueLeft", "tongueRight", "tongueRoll", "tongueBendDown",
        "tongueCurlUp", "tongueSquish", "tongueFlat", "tongueTwistLeft",
        "tongueTwistRight",
    ]

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.labels = self.dataframe[self.LABEL_COLUMNS]
        self.images = self.dataframe[["filename"]]
        # BUG FIX: `transform` was previously ignored — every dataset
        # (including validation) unconditionally got TrivialAugmentWide.
        # An explicitly passed transform is now honored; the default keeps
        # the old augmentation behavior.
        self.transform = transform if transform is not None else TrivialAugmentWide()
        # NOTE(review): 256x255 crop box followed by a 256x256 resize looks
        # like a typo for (0, 0, 256, 256) — kept as-is to preserve current
        # behavior; confirm intent.
        self.box = (0, 0, 256, 255)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Positional (.iloc) access so lookup works regardless of the
        # DataFrame's index labels (previously required a reset 0..N-1 index).
        filename = self.images.iloc[idx, 0]
        image = Image.open(filename).crop(self.box).resize((256, 256))
        image = self.transform(image)
        image = to_grayscale(image)
        image = transforms.ToTensor()(image)
        label = torch.tensor(self.labels.iloc[idx].to_list()).float()
        return image, label
class BabbleDataModule(LightningDataModule):
    """Loads the Babble CSV and serves 80/20 train/validation dataloaders.

    Args:
        data_file: path to the labels CSV (must contain a ``filename``
            column plus the blendshape columns BabbleDataset expects).
        batch_size: samples per batch for both loaders.
        num_workers: DataLoader worker processes. BUG FIX: this argument
            was previously accepted but ignored (both loaders hard-coded 8
            workers); the default is now 8 so default construction behaves
            exactly as before, and explicit values are finally respected.
    """

    def __init__(self, data_file, batch_size=64, num_workers=8):
        super().__init__()
        self.data_file = data_file
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        # Nothing to download — the CSV and images are already on disk.
        pass

    def setup(self, stage=None):
        df = pd.read_csv(self.data_file)
        # BUG FIX: a fixed random_state makes the split deterministic.
        # Lightning may call setup() more than once (e.g. per stage); the
        # unseeded split re-shuffled train/val membership on every call,
        # silently leaking validation rows into training across runs.
        train, test = train_test_split(df, test_size=0.2, random_state=42)
        self.train_dataset = BabbleDataset(train.reset_index(drop=True))
        self.test_dataset = BabbleDataset(test.reset_index(drop=True))

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            # Keep worker processes alive between epochs instead of
            # re-forking num_workers processes every epoch — the respawn
            # overhead grows with dataset/caching state and is a common
            # cause of per-epoch time creeping upward.
            persistent_workers=self.num_workers > 0,
        )

    def val_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )
class BabbleModel(LightningModule):
    """EfficientNetV2-B0 regressor: 1-channel 256x256 image -> 45 blendshapes.

    Uses Huber loss (robust to label outliers) and the Lion optimizer.
    """

    def __init__(self):
        super().__init__()
        # ImageNet-pretrained backbone adapted to single-channel input,
        # with a 45-unit regression head.
        self.model = timm.create_model(
            'tf_efficientnetv2_b0.in1k', num_classes=45, pretrained=True, in_chans=1
        )
        self.criterion = nn.HuberLoss()
        # Running mean of the per-batch validation loss; logged (and
        # auto-reset) once per validation epoch.
        self.metric = torchmetrics.MeanMetric()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)
        loss = self.criterion(outputs, labels)
        self.log('train_loss', loss, prog_bar=True, on_step=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)
        loss = self.criterion(outputs, labels)
        # BUG FIX: the metric was updated every step but never logged,
        # computed, or reset, so its internal state accumulated over the
        # entire run. Logging the Metric object itself lets Lightning
        # compute and reset it at each epoch end.
        self.metric.update(loss)
        self.log('val_mean_loss', self.metric, on_step=False, on_epoch=True)
        # Unchanged: ModelCheckpoint monitors 'val_loss'.
        self.log('val_loss', loss, prog_bar=True, on_step=True, sync_dist=True)

    def configure_optimizers(self):
        # Lion generally wants a ~3-10x smaller lr than AdamW; 6.9e-5 kept
        # from the original tuning.
        return Lion(self.model.parameters(), lr=0.000069)
if __name__ == '__main__':
    # Labels CSV; the filenames inside it point at the face images.
    csv_path = '/home/babble/Documents/ProjectBabble/Babble_data.csv'

    model = BabbleModel()
    data_module = BabbleDataModule(data_file=csv_path)

    # Keep the best checkpoint by validation loss, plus the latest one.
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        save_on_train_epoch_end=True,
        save_last=True,
    )

    trainer = Trainer(
        max_epochs=100,
        callbacks=[checkpoint_callback],
        # logger=wandb_logger,
    )
    trainer.fit(model, data_module)
For some reason my per-epoch training time drops at first and then slowly climbs as training goes on — I can't figure out why. Any ideas?