Number of training steps per epoch is higher than expected

I’m trying to train a ResNet-50 model on the FGVC8 dataset, which has 18632 images in total. When I train with this LightningModule, each epoch runs for 292 steps; with a batch size of 64, that corresponds to 292 × 64 = 18688 images. I didn’t set max_steps explicitly. I used StratifiedKFold with 5 splits to divide the dataset into training and validation sets.

After splitting, the training set contains 14905 images and the validation set contains 3727. I pass each PyTorch dataset to a DataLoader with batch_size=64 and num_workers=4.

Given this split, the number of training steps should be 14905 / 64 ≈ 232.9, i.e. 233 batches (232 full batches plus a final partial batch of 57 images). But the network trains for 292 steps. Why is this happening? Is the network being trained on the whole dataset? I’m not applying any augmentations either.

Link to my Source code: Resnet50 Training FGVC8

Dataset:

def train_augs():
    """Build the preprocessing pipeline used for the training split.

    The pipeline is deterministic: Resize(256, 256), then ``Normalize()`` with
    its default statistics, then ``ToTensorV2()``. No random augmentation is
    applied.
    """
    steps = [Resize(256, 256), Normalize(), ToTensorV2()]
    return Compose(steps)
def val_augs():
    """Build the preprocessing pipeline used for the validation split.

    Applies Resize(256, 256), then ``Normalize()`` with its default
    statistics, then ``ToTensorV2()``, in that order.
    """
    resize = Resize(256, 256)
    normalize = Normalize()
    to_tensor = ToTensorV2()
    return Compose([resize, normalize, to_tensor])



class FGVC(Dataset):
    """Map-style dataset that loads images listed in a DataFrame.

    Args:
        df: DataFrame with an ``'image'`` column (file name relative to
            ``root_dir``) and an ``'lbc'`` column holding the label. The label
            is presumably an integer class index, since it feeds
            ``F.cross_entropy``; confirm against how ``lbc`` is built.
        root_dir: Directory containing the image files.
        transforms: Optional albumentations-style callable, invoked as
            ``transforms(image=img)`` and returning a dict with an ``'image'``
            entry.

    ``__getitem__`` returns ``(image, torch.tensor(label))``.
    """

    def __init__(self, df, root_dir, transforms=None):
        self.root_dir = root_dir
        self.transforms = transforms
        self.dataf = df

    def __len__(self):
        return len(self.dataf)

    def __getitem__(self, idx):
        # Positional, column-first lookup: it agrees with __len__ even when the
        # DataFrame index is not 0..n-1, and it keeps each column's own dtype
        # (a row-wise .iloc[idx] could upcast mixed columns).
        img_name = self.dataf['image'].iloc[idx]
        label = self.dataf['lbc'].iloc[idx]

        path = os.path.join(self.root_dir, img_name)
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning None
            # instead of raising; fail here with the offending path.
            raise FileNotFoundError(f"could not read image: {path}")
        # OpenCV decodes to BGR, but ImageNet-pretrained weights and
        # albumentations' default Normalize statistics assume RGB order.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transforms is not None:
            img = self.transforms(image=img)['image']
        return img, torch.tensor(label)

Training Loop:

# Snapshot the starting weights so that every fold trains from the same point.
# Otherwise fold k continues from fold k-1's weights, and its validation rows
# were part of the training split in earlier folds (leakage).
initial_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

skf = StratifiedKFold(n_splits=5)
for fold, (train_ind, val_ind) in enumerate(skf.split(train_data.image.values, train_data.labels.values)):

    print(f"***** Training Fold {fold}/5 *****")

    # split() yields positional indices, so select rows with iloc; .loc only
    # coincides with that when train_data has a default RangeIndex.
    train_data1 = train_data.iloc[train_ind].reset_index(drop=True)
    val_data1 = train_data.iloc[val_ind].reset_index(drop=True)

    # Reset to the initial weights for this fold. configure_optimizers builds
    # a fresh optimizer on every fit() call.
    model.load_state_dict(initial_state)

    train_dataset = FGVC(df=train_data1, transforms=train_augs(), root_dir="train_images")
    val_dataset = FGVC(df=val_data1, transforms=val_augs(), root_dir="train_images")

    train_loader = DataLoader(train_dataset, batch_size=64, num_workers=4, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=64, num_workers=4)

    # Samples, then batches. With drop_last=False each loader yields
    # ceil(len / 64) batches: 233 train + 59 val for 14905 / 3727 samples.
    # Lightning's main progress bar counts both, which is where 292 comes
    # from; the model still only trains on the training split.
    print(len(train_loader.dataset), len(val_loader.dataset))
    print(len(train_loader), len(val_loader))

    # NOTE(review): `trainer` is shared across folds. A Trainer may keep its
    # epoch/step counters between fit() calls, so later folds could stop
    # immediately. Confirm this, and create a new pl.Trainer per fold if so.
    trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)

My Lightning Module:

class FGVCNet(pl.LightningModule):
    """ImageNet-pretrained ResNet-50 with a new classification head.

    Args:
        num_classes: Output size of the replacement ``fc`` layer. This is also
            the class count used for F1. When omitted, it falls back to
            ``train_data.labels.nunique()`` (a module-level DataFrame defined
            elsewhere), which is what this constructor always used.
    """

    def __init__(self, num_classes=None):
        super().__init__()
        if num_classes is None:
            num_classes = train_data.labels.nunique()
        self.num_classes = num_classes
        self.model = resnet50(pretrained=True)
        # Swap the ImageNet head for one sized to this task's classes.
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        """Adam with ReduceLROnPlateau driven by the epoch-level ``val_loss``."""
        opt = torch.optim.Adam(self.parameters(), lr=0.01)
        scheduler = {
            # Halve the LR after 2 epochs without a decrease in val_loss.
            'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer=opt,
                mode='min',
                factor=0.5,
                patience=2,
                verbose=True,
            ),
            'monitor': 'val_loss',
            'interval': 'epoch',
            'frequency': 1,
            # Raise if 'val_loss' is not available when the scheduler steps.
            'strict': True,
        }

        self.opt = opt
        self.scheduler = scheduler

        return [opt], [scheduler]

    def _shared_step(self, batch):
        """Return (weighted cross-entropy loss, F1) for one batch."""
        X, y = batch
        y_hat = self(X)
        # NOTE(review): `weights` is a module-level class-weight tensor defined
        # elsewhere; it must be on the same device as y_hat.
        loss = F.cross_entropy(y_hat, y, weight=weights)
        # Use the same class count as the head instead of a hard-coded 12.
        f1 = pl.metrics.functional.f1(y_hat, y, self.num_classes)
        return loss, f1

    def training_step(self, batch, batch_idx):
        loss_tr, f1_tr = self._shared_step(batch)
        self.log("TrainLoss", loss_tr, prog_bar=True, on_step=True, on_epoch=True)
        self.log("TrainF1", f1_tr, prog_bar=True, on_epoch=True, on_step=True)
        return loss_tr

    def validation_step(self, batch, batch_idx):
        loss_val, f1_val = self._shared_step(batch)
        # Epoch-level only: the scheduler monitors 'val_loss' once per epoch.
        # Logging it with on_step=True as well makes Lightning emit
        # '_step'/'_epoch'-suffixed keys instead of a single 'val_loss'.
        self.log("val_loss", loss_val, prog_bar=True, on_step=False, on_epoch=True)
        self.log("val_f1", f1_val, prog_bar=True, on_epoch=True)
        return loss_val