GPU memory surge after training epochs causing CUDA memory error

I also posted this, with the notebook log information, on a GitHub discussion:
github discussion

I use PyTorch Lightning to train a model, but it always fails in a strange way at the very end: after all the validations have completed, the trainer starts one more epoch beyond max_epochs, and a GPU memory allocation failure (CUDA out of memory) occurs right after this extra epoch (which should not run at all) starts. In my case I set max_epochs=5, so there should only be epochs 0-4, but an additional epoch 5 always starts after the 5 validations are done, and a few seconds later the CUDA out of memory error is raised.
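
In case it helps narrow things down, this is roughly how I log GPU memory at the end of each training epoch to confirm that the surge only happens at the very end (a minimal sketch; MemoryMonitor is just a name I made up, and the exact callback hook signature may differ between Lightning versions):

import torch
import pytorch_lightning as pl

class MemoryMonitor(pl.Callback):
    # Print allocated / reserved GPU memory after every training epoch.
    def on_train_epoch_end(self, trainer, pl_module):
        allocated = torch.cuda.memory_allocated() / 1024 ** 2  # MiB currently allocated to tensors
        reserved = torch.cuda.memory_reserved() / 1024 ** 2    # MiB held by the caching allocator
        print(f'epoch {trainer.current_epoch}: '
              f'allocated={allocated:.0f} MiB, reserved={reserved:.0f} MiB')

# e.g. trainer = pl.Trainer(..., callbacks=[MemoryMonitor(), early_stopping_callback, checkpoint_callback])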

Wandb system info:

My dataset should be fine, since both GPU memory and system memory stay stable during training, except for the GPU memory surge at the very end. Here is my code for the LightningModule and the training loop, which I think may be causing this trouble:

class BaseModel(pl.LightningModule):
    def __init__(self, model_name=params['model'], out_features=params['out_features'], inp_channels=params['inp_channels'], pretrained=True):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained, in_chans=inp_channels)
        
        # Change the output features; input features stay the same.
        if model_name == 'resnet18d':
            n_features = self.model.fc.in_features
            self.model.fc = nn.Linear(n_features, out_features, bias=True)
            
        elif model_name == 'nfnet_f1':
            n_features = self.model.head.fc.in_features
            self.model.head.fc = nn.Linear(n_features, out_features, bias=True)
            
        elif model_name == 'efficientnet_b1':
            n_features = self.model.classifier.in_features
            self.model.classifier = nn.Linear(n_features, out_features, bias=True)
            
        self.criterion = nn.BCEWithLogitsLoss()
        
    def forward(self, x):
        output = self.model(x)
        return output

    def training_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        labels = y.unsqueeze(1)
        loss = self.criterion(output, labels)

        try:
            auc = roc_auc_score(labels.detach().cpu(), output.sigmoid().detach().cpu())
            self.log('auc', auc, on_step=True, prog_bar=True, logger=True)
            self.log('Train Loss', loss, on_step=True, prog_bar=True, logger=True)
        except:
            # roc_auc_score fails when a batch contains only one class; skip logging in that case.
            pass

        return {'loss': loss, 'predictions': output, "labels": labels}

    def training_epoch_end(self, outputs):
        preds = []
        labels = []

        for output in outputs:
            preds += output['predictions'].detach()
            labels += output['labels'].detach()

        preds = torch.stack(preds)
        labels = torch.stack(labels)

        train_auc = roc_auc_score(labels.detach().cpu(), preds.sigmoid().detach().cpu())
        self.log('mean_train_auc', train_auc, prog_bar=True, logger=True)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        labels = y.unsqueeze(1)
        loss = self.criterion(output, labels)

        self.log('val_loss', loss, on_step=True, prog_bar=True, logger=True)
        return {'predictions': output, 'labels': labels}

    def validation_epoch_end(self, outputs):
        preds = []
        labels = []

        for output in outputs:
            preds += output['predictions'].detach()
            labels += output['labels'].detach()

        preds = torch.stack(preds)
        labels = torch.stack(labels)

        val_auc = roc_auc_score(labels.detach().cpu(), preds.sigmoid().detach().cpu())
        self.log('val_auc', val_auc, prog_bar=True, logger=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        output = self(x).sigmoid()
        return output

    def configure_optimizers(self):
        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] # no decay
        optimizer_parameters = [
            {
                'params': [
                    p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
                ], # Apply weight decay to all parameters except those in the no-decay list.
                'weight_decay': params['weight_decay'],
            },
            {
                'params': [
                    p for n, p in param_optimizer if any(nd in n for nd in no_decay)
                ],
                'weight_decay': 0.0,
            }
        ]

        optimizer = FusedAdam(optimizer_parameters, lr=params['lr'])

        scheduler = CosineAnnealingLR(optimizer,
                                      T_max=params['T_max'],
                                      eta_min=params['min_lr'],
                                      last_epoch=-1)

        # Return the optimizer & scheduler to PyTorch Lightning as a dict;
        # the scheduler goes under the 'lr_scheduler' key.
        return dict(optimizer=optimizer,
                    lr_scheduler=scheduler)

kfolds = StratifiedKFold(n_splits=params['nfolds'], shuffle=True, random_state=params['seed'])

model = BaseModel()

for fold, (trn_idx, val_idx) in enumerate(kfolds.split(train_df["id"], train_df['target'])):
    # Only run the first fold for now.
    if fold != 0:
        continue
    
    # PL + wandb
    wandb_logger = WandbLogger(project='G2Net-steady-exp',
                               config=params,
                               group='Effnet-CQT',
                               job_type='train',
                               name=f'Fold{fold}')
    print(f"{'='*20} Fold: {fold} {'='*20}")
    
    # Set up data module.
    train_data = train_df.loc[trn_idx]
    train_sample_data = data_sample(train_data)
    valid_data = train_df.loc[val_idx] # About 65k samples.
    data_module = DataModule(train_sample_data,
                             valid_data,
                             valid_data) # No test data yet.
    data_module.setup()
    
    # Add callbacks.
    early_stopping_callback = EarlyStopping(monitor='val_auc',
                                            mode='max',
                                            patience=5)
    checkpoint_callback = ModelCheckpoint(dirpath='./checkpoints/',
                                          filename= f'fold-{fold}-best' + '-val_auc{val_auc:.3f}',
                                          save_top_k=2,
                                          verbose=True,
                                          monitor='val_auc',
                                          mode='max')
    
    trainer = pl.Trainer(gpus=1,
                         callbacks=[early_stopping_callback,
                                    checkpoint_callback],
                         max_epochs=params['epochs'],
                         precision=params['precision'],
                         progress_bar_refresh_rate=1,
                         stochastic_weight_avg=True,
                         logger=wandb_logger)
    
    trainer.fit(model, data_module)
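
One thing I am not sure about: training_step and validation_step return the raw predictions and labels, and Lightning keeps those returned GPU tensors around until the *_epoch_end hooks run. If that matters here, I suppose the returns could be detached and moved to CPU, roughly like this (just a sketch of what I mean, not something I have verified fixes the error):

    def training_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        labels = y.unsqueeze(1)
        loss = self.criterion(output, labels)

        # Keep only CPU copies for the epoch-end aggregation, so the full
        # epoch's predictions are not retained in GPU memory.
        return {'loss': loss,
                'predictions': output.detach().cpu(),
                'labels': labels.detach().cpu()}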

Can anyone give me a clue about why this happens? I'm new to PyTorch Lightning, so there might be problems I'm not aware of. Thanks a lot!