Growing GPU memory consumption leading to a CUDA out-of-memory error after a few epochs

Hello there,

I am currently encountering an issue with my PyTorch Lightning pipeline related to memory consumption. During the third epoch, I hit a CUDA out-of-memory error:

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 5.00 GiB (GPU 0; 79.15 GiB total capacity; 66.54 GiB already allocated; 3.43 GiB free; 74.69 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
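For completeness, the allocator hint from the error message could be tried by setting the environment variable before CUDA is initialized (a sketch; the 512 MiB value is an arbitrary choice of mine), but the steady growth described below makes me doubt that fragmentation is the root cause:

import os

# assumption: this must be set before the first CUDA allocation,
# e.g. at the very top of the training script
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"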

The issue is that my batch size is already 1 (I work with large objects), and I check that every sample fits the model by doing a dry run on the largest sample of the dataset before training (with precheck_memory()), which completes correctly. I looked at the GPU memory allocated panels (on Weights & Biases - see below), which show that my batches fit in memory (peak around t = 10 min), but then there is a slow, constant increase in memory over the course of training, which ultimately leads to a memory overflow at the end.

I assume this issue comes from some results or metrics of the training being kept in GPU memory; as this data grows over the course of training, it eventually causes the memory overflow.
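To confirm this, I am considering logging the allocator counters at the end of every training epoch, roughly like this (a sketch; the callback and the per-epoch reset of the peak counter are my own choices, not part of my current pipeline):

import torch as pt
import pytorch_lightning as ptl

class CudaMemoryLogger(ptl.Callback):
    # log the currently allocated and the peak allocated GPU memory after every training epoch
    def on_train_epoch_end(self, trainer, pl_module):
        pl_module.log_dict({
            'cuda_mem_allocated_gb': pt.cuda.memory_allocated() / 2**30,
            'cuda_mem_max_allocated_gb': pt.cuda.max_memory_allocated() / 2**30,
        })
        pt.cuda.reset_peak_memory_stats()  # restart peak tracking for the next epoch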

Here is the model I am using (I shortened the code to make it easier to read):

# aliases used below
import numpy as np
import torch as pt
import pytorch_lightning as ptl
import torchmetrics as tm

class Model(ptl.LightningModule):
    def __init__(self, config, pos_weight_factor=0.9, global_step=0, criterion=pt.nn.BCEWithLogitsLoss(reduction="none")):
        super(Model, self).__init__()
        self.em = pt.nn.Sequential(
            pt.nn.Linear(N0, N1),
            pt.nn.ELU(),
            ......
        )
        self.sum = pt.nn.Sequential(....)
        self.spl = StatePoolLayer(.....)
        self.dm = pt.nn.Sequential(
            pt.nn.Linear(2*N0, N1),
            pt.nn.ELU(),
            ...
        )

        self.criterion = criterion

        self.pos_ratios = 0.05*pt.ones(N2, dtype=pt.float)
        self.pos_weight_factor = pos_weight_factor

        self.metrics = {'auc': tm.AUROC(task='multiclass', num_classes=20),
                        'f1': tm.F1Score(task='multiclass', num_classes=20),
                        'mcc': tm.MatthewsCorrCoef(task='multiclass', num_classes=20),
                        'mlrap': tm.classification.MultilabelRankingAveragePrecision(num_labels=20),
                        'mlrl': tm.classification.MultilabelRankingLoss(num_labels=20)}

    def setup(self, stage: str) -> None:
        # Change device here - setup is called after the model is moved to the device
        self.criterion = self.criterion.to(self.device)
        self.pos_ratios = self.pos_ratios.to(self.device)
        self.metrics = {k: v.to(self.device) for k, v in self.metrics.items()}
        return super().setup(stage)
    
    def precheck_memory(self, dataset):
        # quick forward pass on the largest sample: memory check and pre-allocation
        data = dataset.get_largest()
        try:
            self.eval()
            with pt.no_grad():
                p = self.predict_step(data)
        except RuntimeError:
            raise MemoryError(f"Not enough memory to pre-allocate a model containing {data[0].shape[0]} residues (pdb of {np.max(dataset.sizes[dataset.m,0])}o)")

    def forward(self, X, ids_topk, q0, M):
        ........
        return z
    
    def apply_loss(self, gt, pred, is_training=False):
        # compute weighted loss
        if is_training:
            self.pos_ratios += ((pt.mean(gt,dim=0) - self.pos_ratios) / (1.0 + np.sqrt(self.global_step))).to(self.device)
            self.criterion.pos_weight = self.pos_weight_factor * (1.0 - self.pos_ratios) / (self.pos_ratios + 1e-6)
            self.log_dict({'pos_ratios_mean': pt.mean(self.pos_ratios), 'pos_weight_mean': pt.mean(self.criterion.pos_weight)}, sync_dist=True)
        dloss = self.criterion(pred, gt).to(self.device)

        # re-weighted losses
        loss_factors = (self.pos_ratios / pt.sum(self.pos_ratios)).reshape(1,-1).to(self.device)
        losses = (loss_factors * dloss) / dloss.shape[0]

        # total loss (backward propagation is handled by Lightning from the returned loss)
        loss = pt.sum(losses)

        return loss

    def apply_recovery(self, gt, pred):
        aa_mask = pt.any(gt > 0.5, axis=1)

        if pt.any(aa_mask):
            gt_seq = pt.argmax(gt[aa_mask], axis=1)

            # Compute ability to recover sequence from a certain confidence
            acc_any = pt.sum(pt.gt(pred[aa_mask, gt_seq], self.pred_threshold)) / gt_seq.shape[0]
            # Compute direct sequence recovery
            acc = pt.sum(pt.argmax(pred[aa_mask], axis=1) == gt_seq) / gt_seq.shape[0]
        else:
            raise NotImplementedError("....")
        return acc_any.detach().cpu().item(), acc.detach().cpu().item()
    
    def apply_metrics(self, gt, pred):              
        return {
            'auc': pt.mean(self.metrics['auc'](pred, pt.argmax(gt, axis=1))).detach().cpu().item(),
            'f1': pt.mean(self.metrics['f1'](pred, gt)).detach().cpu().item(),
            'mcc': pt.mean(self.metrics['mcc'](pred, pt.argmax(gt, axis=1)).float()).detach().cpu().item(),
            'mlrap': pt.mean(self.metrics['mlrap'](pred, gt)).detach().cpu().item(),
            'mlrl': pt.mean(self.metrics['mlrl'](pred, gt)).detach().cpu().item()
        }
    
    def training_step(self, batch):
        X, q, M, y, rids_sel = batch

        ids_topk, _, _, _, _ = extract_topology(X, 64)
        p = self.forward(X, ids_topk, q, M)

        loss = self.apply_loss(y, p, is_training=True)
        acc_any, acc = self.apply_recovery(y, p)
        scores = self.apply_metrics(y, p)

        self.log_dict({'train_loss': loss, 'train_accuracy_any': acc_any, 'train_recovery': acc, **{'train_'+k:v for k,v in scores.items()}}, sync_dist=True)

        return loss
    
    def validation_step(self, batch):
        X, q, M, y, rids_sel = batch

        ids_topk, _, _, _, _ = extract_topology(X, 64)
        p = self.forward(X, ids_topk, q, M)

        loss = self.apply_loss(y, p, is_training=False)
        acc_any, acc = self.apply_recovery(y, p)
        scores = self.apply_metrics(y, p)

        self.log_dict({'val_loss': loss, 'val_accuracy_any': acc_any, 'val_recovery': acc, **{'val_'+k:v for k,v in scores.items()}}, sync_dist=True)

        return loss
    
    def test_step(self, batch):
        X, q, M, y, rids_sel = batch

        ids_topk, _, _, _, _ = extract_topology(X, 64)
        p = self.forward(X, ids_topk, q, M)

        loss = self.apply_loss(y, p, is_training=False)
        acc_any, acc = self.apply_recovery(y, p)
        scores = self.apply_metrics(y, p)

        self.log_dict({'test_loss': loss, 'test_accuracy_any': acc_any, 'test_recovery': acc, **{'test_'+k:v for k,v in scores.items()}}, sync_dist=True)

        return loss
    
    def predict_step(self, batch):
        X, q, M, y, rids_sel = batch
        ids_topk, _, _, _, _ = extract_topology(X, 64)
        p = self.forward(X, ids_topk, q, M)

        return p
    
    
    def configure_optimizers(self, lr: float = 1e-3) -> pt.optim.Optimizer:
        return pt.optim.Adam(self.parameters(), lr=lr)

Would you have any idea how to solve this issue, or how to save some memory by placing the monitoring tensors on the CPU?
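For instance, would something along these lines be the right direction? This is only a rough sketch of what I mean (untested, and only two of the metrics are shown): keep the metric objects on the CPU (i.e. skip moving them to self.device in setup) and feed them detached CPU copies of the predictions and targets.

import torch as pt
import torchmetrics as tm

# CPU-resident metric objects (never moved to the GPU)
cpu_metrics = {'auc': tm.AUROC(task='multiclass', num_classes=20),
               'f1': tm.F1Score(task='multiclass', num_classes=20)}

def apply_metrics_cpu(metrics, gt, pred):
    # detach and move to CPU so nothing from the metric computation stays in GPU memory
    pred_cpu = pred.detach().cpu()
    gt_cpu = gt.detach().cpu()
    labels = pt.argmax(gt_cpu, dim=1)
    return {'auc': metrics['auc'](pred_cpu, labels).item(),
            'f1': metrics['f1'](pred_cpu, labels).item()}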