OOM error due to tensor accumulation when using the functional metrics API

Hey everyone,

in my LightningModule I would like to compute some metrics during training using the functional metrics API. For that I just need the network's prediction and the target (ground truth) during the training phase. In my training_step function, I perform a shared step that unpacks a batch, computes the loss and applies softmax to my output logits. It returns these tensors in a dict. However, if I access e.g. the prediction with shared_step['pred'] for the metrics computation, I end up accumulating these output tensors and my GPU reaches its memory limit very quickly. The more data I have in my training dataset, the more likely it is to get an OOM error. What surprises me is that if I don't access the content of shared_step, there is no accumulation…

def shared_step(self, batch):
    # Batch
    x, y = batch['x'], batch['y']

    # Prediction
    out = self.model(x)

    # Softmax
    out_soft = torch.nn.functional.softmax(out, dim=1)

    # Loss
    loss = self.ce_loss(out, y)  # cross entropy loss (LogSoftmax + NLLLoss)

    return {**batch, 'pred': out_soft, 'loss': loss}

This is the training_step():

def training_step(self, batch, batch_idx):
    shared_step = self.shared_step(batch)
    acc_batch = pl.metrics.functional.accuracy(pred=shared_step['pred'],
                                               target=shared_step['y'],
                                               num_classes=3)

    return shared_step  # the whole dict (inputs, targets, predictions, loss) gets returned


What’s the best practice in PL to access the model’s output to use in metrics computation (functional API)? Any help appreciated!

Hey Johannes!

The functional API for metrics is stateless, so there should be no accumulation of tensors.

What are you returning from training_step? If you are returning the batch and predictions from training_step (or validation_step), those outputs are collected across the epoch so they can be passed to training_epoch_end and validation_epoch_end respectively, which could be causing the OOM errors.
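
For example, something like this (just a minimal sketch): compute what you need inside the step and return only the loss:

def training_step(self, batch, batch_idx):
    out = self.shared_step(batch)
    # compute/log your metrics here, while the tensors are still in scope
    return out['loss']  # only the loss is kept for the epoch-end hooks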

Let me know if you are still encountering this issue, or if you could provide more code to reproduce! - Teddy

Hey Teddy,

I was returning the predictions, the input, the target and the loss (all packaged in a dict) and then accessed those in training_step_end and training_epoch_end. This probably caused the OOM. Looking at the docs (the pseudocode), I realized that for the training/validation/test phase PL iterates over the batches and performs a training_step for each. Each output is stored in a list which can be accessed in ...epoch_end. If I return the predictions and then access them in ...epoch_end, this list could get pretty big, depending on the dataset and batch size. I think I managed to solve the problem by returning only the loss and instead storing the computed metrics in an nn.Module within the LightningModule. I'll share the code here:

The UNet model can be downloaded from my GitHub gists here.

# %% imports
import numpy as np
import torch
from torch.utils import data
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer


# %% custom metric class
class CustomMetric(torch.nn.Module):
    """Wraps a functional metric and accumulates per-batch scores for epoch-level aggregation."""

    def __init__(self, metric, metric_name, **kwargs):
        super().__init__()
        self.metric = metric            # functional metric, e.g. pl.metrics.functional.f1
        self.metric_name = metric_name
        self.kwargs = kwargs            # forwarded to the metric call, e.g. num_classes, average

        self.scores = []
        self.valid_classes = []
        self.valid_matrices = []

        self.score = None
        self.valid_class = None
        self.valid_matrix = None

        self.last_scores = None
        self.last_valid_classes = None
        self.last_valid_matrices = None

    def batch(self, prediction, target):
        # compute score for every batch
        self.score = self.metric(prediction, target, **self.kwargs)
        # classes actually present in this batch
        self.valid_class = target.unique()
        # boolean mask over all classes: True where the class is present in this batch
        dummy = torch.zeros_like(self.score)
        dummy[self.valid_class] = 1
        self.valid_matrix = dummy.type(torch.bool)

        self.scores.append(self.score)
        self.valid_classes.append(self.valid_class)
        self.valid_matrices.append(self.valid_matrix)

    def get_metrics_batch(self, mean=True):
        # returns the class metrics of the batch for the classes that are present in the image
        if mean:
            return self.score[self.valid_class].mean()
        else:
            return self.score[self.valid_class]

    def get_metrics_epoch(self, mean=True, last=False):
        if last:
            scores = torch.stack(self.last_scores).T
            masks = torch.stack(self.last_valid_matrices).T
        else:
            scores = torch.stack(self.scores).T
            masks = torch.stack(self.valid_matrices).T

        # iterate over the classes (rows after transposing) and keep only the batches where the class is present
        filtered = [s[m] for s, m in zip(scores, masks)]

        if mean:
            return torch.stack([c.mean() for c in filtered]).mean()
        else:
            return torch.stack([c.mean() for c in filtered])

    def epoch(self):
        # aggregate the accumulated batch scores into an epoch score, then reset

        self.last_scores = self.scores
        self.last_valid_classes = self.valid_classes
        self.last_valid_matrices = self.valid_matrices

        result = self.get_metrics_epoch(mean=True)

        self.reset()
        return result

    def reset(self):
        self.scores = []
        self.valid_classes = []
        self.valid_matrices = []

    def __repr__(self):
        return self.metric_name
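

# %% optional quick demo of CustomMetric (illustrative only: _demo and the fake shapes
# are made up, but the f1 arguments mirror the setup in the LightningModule below)
_demo = CustomMetric(metric=pl.metrics.functional.f1,
                     metric_name='F1_Demo',
                     num_classes=4,
                     average='none')
_demo.batch(torch.softmax(torch.randn(2, 4, 8, 8), dim=1),  # fake softmax output (N, C, H, W)
            torch.randint(low=0, high=4, size=(2, 8, 8)))   # fake targets (N, H, W)
print(_demo.get_metrics_batch(mean=True))  # mean F1 over the classes present in the fake targets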


# %% lightningModule
class Segmentation_Lightning(pl.LightningModule):
    def __init__(self, model, lr, num_classes):
        super().__init__()
        # model
        self.model = model

        # learning rate
        self.lr = lr

        # number of classes
        self.num_classes = num_classes

        # loss
        self.ce_loss = torch.nn.CrossEntropyLoss()

        # metrics
        self.f1_train = CustomMetric(metric=pl.metrics.functional.f1,
                                     metric_name='F1_Train',
                                     num_classes=num_classes,
                                     average='none')

        self.f1_valid = CustomMetric(metric=pl.metrics.functional.f1,
                                     metric_name='F1_Valid',
                                     num_classes=num_classes,
                                     average='none')

        self.f1_test = CustomMetric(metric=pl.metrics.functional.f1,
                                    metric_name='F1_Test',
                                    num_classes=num_classes,
                                    average='none')

        # save hyperparameters
        self.save_hyperparameters()

    def shared_step(self, batch):
        # Batch
        x, y = batch['x'], batch['y']

        # Prediction
        out = self.model(x)

        # Softmax
        out_soft = torch.nn.functional.softmax(out, dim=1)

        # Loss
        loss = self.ce_loss(out, y)  # cross entropy loss (LogSoftmax + NLLLoss)

        return {**batch, 'pred': out_soft, 'loss': loss}

    def training_step(self, batch, batch_idx):
        # Loss
        shared_step = self.shared_step(batch)

        # Metrics
        self.f1_train.batch(shared_step['pred'], shared_step['y'])  # e.g. [0.2, 0.3, 0.25, 0.25]

        # Logging
        name = 'Train'
        self.logger.experiment.log_metric(f'{name}/F1/Batch', self.f1_train.get_metrics_batch(mean=True))  # Total F1
        for class_idx, metric in zip(self.f1_train.valid_class, self.f1_train.get_metrics_batch(mean=False)):
            self.logger.experiment.log_metric(f'{name}/F1/Batch/Class/{class_idx}', metric)

        return shared_step['loss']

    def training_epoch_end(self, outputs):
        # Logging
        name = 'Train'

        # Class
        for class_idx, value in enumerate(self.f1_train.get_metrics_epoch(mean=False)):
            self.logger.experiment.log_metric(f'{name}/F1/Epoch/Class/{class_idx}', value)

        # Total
        self.logger.experiment.log_metric(f'{name}/F1/Epoch', self.f1_train.epoch())  # Total F1

    def validation_step(self, batch, batch_idx):
        # Loss
        shared_step = self.shared_step(batch)

        # Metrics
        self.f1_valid.batch(shared_step['pred'], shared_step['y'])

        # Logging
        name = 'Valid'
        self.logger.experiment.log_metric(f'{name}/F1/Batch', self.f1_valid.get_metrics_batch(mean=True))  # Total F1
        for class_idx, metric in zip(self.f1_valid.valid_class, self.f1_valid.get_metrics_batch(mean=False)):
            self.logger.experiment.log_metric(f'{name}/F1/Batch/Class/{class_idx}', metric)

        # Logging for checkpoint
        self.log('checkpoint_valid_f1_epoch', self.f1_valid.get_metrics_batch(mean=True))  # per epoch automatically

        return shared_step['loss']

    def validation_epoch_end(self, outputs):
        # Logging
        name = 'Valid'

        # Class
        for class_idx, value in enumerate(self.f1_valid.get_metrics_epoch(mean=False)):
            self.logger.experiment.log_metric(f'{name}/F1/Epoch/Class/{class_idx}', value)

        # Total
        self.logger.experiment.log_metric(f'{name}/F1/Epoch', self.f1_valid.epoch())  # Total F1

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        return optimizer


# %% dataset class
class RandomDataSet(data.Dataset):
    def __init__(self,
                 num_samples,
                 size,
                 num_classes=4,
                 inputs_dtype=torch.float32,
                 targets_dtype=torch.long
                 ):
        self.num_samples = num_samples
        self.size = size
        self.num_classes = num_classes
        self.inputs_dtype = inputs_dtype
        self.targets_dtype = targets_dtype
        self.cached_data = []

        # Generate some random input target pairs
        for num in range(self.num_samples):
            inp = torch.from_numpy(np.random.uniform(low=0, high=1, size=(3,) + size))
            tar = torch.randint(low=0, high=num_classes, size=size)
            self.cached_data.append((inp, tar))

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index: int):
        x, y = self.cached_data[index]

        # Typecasting
        x, y = x.type(self.inputs_dtype), y.type(self.targets_dtype)

        return {'x': x, 'y': y}


# %% dataloader

size = (512, 512)
batch_size = 8
num_classes = 4

dataset_train = RandomDataSet(num_samples=40, size=size, num_classes=num_classes)
dataset_valid = RandomDataSet(num_samples=16, size=size, num_classes=num_classes)

dataloader_training = DataLoader(dataset=dataset_train,
                                 batch_size=batch_size,
                                 shuffle=True,
                                 num_workers=0)

dataloader_valid = DataLoader(dataset=dataset_valid,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=0)

# quick sanity check: fetch one batch from the training loader
batch = next(iter(dataloader_training))
x, y = batch['x'], batch['y']

# %% model

from unet import UNet

model = UNet(in_channels=3,
             out_channels=4,
             n_blocks=4,
             start_filters=32,
             activation='relu',
             normalization='group8',
             conv_mode='same',
             dim=2,
             up_mode='transposed')

# %% task init
task = Segmentation_Lightning(model=model,
                              lr=0.001,
                              num_classes=num_classes)

# %% logger init
from pytorch_lightning.loggers.neptune import NeptuneLogger
from api_key_neptune import get_api_key  # I created a .py file from which I import the api key

api_key = get_api_key()

neptune_logger = NeptuneLogger(
    api_key=api_key,
    project_name='johschmidt42/Test',  # this has to be created beforehand otherwise an error is thrown
    experiment_name='testing',
)

# %% trainer init
trainer = Trainer(gpus=1,
                  max_epochs=10,
                  precision=32,
                  benchmark=True,
                  checkpoint_callback=False,
                  logger=neptune_logger,
                  log_every_n_steps=1,
                  num_sanity_val_steps=0,
                  enable_pl_optimizer=False,
                  )

# %% start training
trainer.fit(task,
            train_dataloader=dataloader_training,
            val_dataloaders=dataloader_valid)

If you notice anything unusual, bad practice for PL, etc., please let me know!

Some comments:

  • I prefer to use logger.experiment.log_metric() instead of log(), because on_step=True in validation_step() doesn't seem to work (or is intended not to work).
  • The CustomMetric class is basically a wrapper that performs the bookkeeping steps for the metrics computation.
  • My goal is to properly compute the metrics for a given batch that could be missing one or several classes, e.g. when there is no instance/pixel for class 1 in the batch. Computing the unreduced metric f1 (f1 for every class) could then result in something like [0.5, 0.0, 0.25, 0.25]. Taking the mean of that would return 0.25, which is wrong, because class 1 has to be ignored. The result should be 0.333 (see the small example after this list). Unfortunately, there is no ignore_index flag like there is for IoU.
  • I could properly implement a metric (with update() and compute() functions) to support multi-GPU setups (see the sketch at the very end of this post), but I only use 1 GPU right now, so I think this should be fine for now.
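
To make the masking idea concrete, here is a tiny standalone example with the numbers from the third bullet:

import torch

scores = torch.tensor([0.5, 0.0, 0.25, 0.25])   # unreduced F1, one value per class
valid = torch.tensor([True, False, True, True])  # class 1 is absent from the batch
print(scores.mean())         # tensor(0.2500) -> wrong, the absent class drags the mean down
print(scores[valid].mean())  # tensor(0.3333) -> mean over the present classes only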

Edit:
The stored metrics within the CustomMetric class should be sent to the CPU. Otherwise the created tensors build up on the GPU and can cause an OOM error there as well. Furthermore, there is probably no practical benefit to inheriting from torch.nn.Module.
Here’s the change:

class CustomMetric:
    def __init__(self, metric, metric_name, **kwargs):
        ...

    def batch(self, prediction, target):
        # compute scores for every batch and move them off the GPU right away
        self.score = self.metric(prediction, target, **self.kwargs).to('cpu')
        # compute valid classes for every batch
        self.valid_class = target.unique().to('cpu')
        # compute valid_matrix for every batch (zeros_like of a CPU tensor is already on the CPU)
        dummy = torch.zeros_like(self.score)
        dummy[self.valid_class] = 1
        self.valid_matrix = dummy.type(torch.bool)

        self.scores.append(self.score)
        self.valid_classes.append(self.valid_class)
        self.valid_matrices.append(self.valid_matrix)
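
Regarding the last bullet above: for reference, a minimal sketch of what a proper stateful metric with update() and compute() could look like (assuming the pl.metrics.Metric base class from PL 1.x; MaskedF1 and its internals are just an illustration, not something I have tested on multiple GPUs):

import torch
import pytorch_lightning as pl


class MaskedF1(pl.metrics.Metric):
    """Illustrative sketch: per-class F1, averaged only over batches where the class is present."""

    def __init__(self, num_classes, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.num_classes = num_classes
        # both states are summed across processes, which makes the metric DDP-safe
        self.add_state('score_sum', default=torch.zeros(num_classes), dist_reduce_fx='sum')
        self.add_state('score_count', default=torch.zeros(num_classes), dist_reduce_fx='sum')

    def update(self, preds, target):
        score = pl.metrics.functional.f1(preds, target,
                                         num_classes=self.num_classes,
                                         average='none')
        valid = torch.zeros(self.num_classes, dtype=torch.bool, device=score.device)
        valid[target.unique()] = True
        self.score_sum += torch.where(valid, score, torch.zeros_like(score))
        self.score_count += valid.float()

    def compute(self):
        # mean per-class F1 over the batches in which the class actually appeared
        per_class = self.score_sum / self.score_count.clamp(min=1)
        return per_class.mean()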