Compute Precision Recall Curve without OOM

Hello.
I’m training on the GPU and computing metrics in the validation loop.
I need heavyweight metrics such as PrecisionRecallCurve and ROC over 512x512 images in a multilabel segmentation problem, and I easily run out of memory (OOM).
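A rough back-of-the-envelope estimate shows why the non-binned metrics run out of memory so quickly. It assumes the metric keeps every pixel's prediction (float32) and target (int64) until compute() is called. The dataset size and label count in the example are hypothetical, and `pr_curve_state_bytes` is an illustrative helper, not a library function.

```python
def pr_curve_state_bytes(num_images, num_classes, height=512, width=512,
                         pred_bytes=4, target_bytes=8):
    """Rough size of the state a non-binned PR curve keeps in memory.

    Assumption: every pixel's prediction (float32 by default) and target
    (int64 by default) is stored until compute() is called.
    """
    pixels = num_images * num_classes * height * width
    return pixels * (pred_bytes + target_bytes)


# Hypothetical example: 1,000 validation images with 5 labels each.
size = pr_curve_state_bytes(num_images=1000, num_classes=5)
print(f"{size / 2**30:.1f} GiB")  # prints "14.6 GiB"
```

Under these assumptions even a modest validation set produces well over 10 GiB of state, before counting any extra copies made while gathering across processes.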

I computed the metrics correctly with the GPU disabled, but then 80% of the time is spent in the model's forward() function.

So the metrics work perfectly on the CPU: both .update() and .compute() behave properly.

The problem arises when I enable the GPU: I initialize the metric on the CPU, run inference on the GPU, and compute the metric on the CPU.

I initialized the metric like this:

precision_recall_curve = PrecisionRecallCurve(pos_label=1, num_classes=1).to('cpu')

and in step_end:

precision_recall_curve.update(image.to('cpu'),
                              mask.to('cpu'))

So far, so good: the metric seems to update correctly.

The problem is:

precision, recall, threshold = precision_recall_curve.compute()

raises the following exception:

work = _default_pg.allgather([tensor_list], [tensor])
RuntimeError: Tensors must be CUDA and dense

I guess this kind of hybrid setup is not supported: when compute() synchronizes the metric state across processes, the all-gather apparently expects CUDA tensors.

I suspect my only option is to compute the metrics per sample or per batch and then aggregate them somehow, since the whole validation set (which is not very big) does not fit in VRAM.
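For a fixed set of thresholds, this kind of per-batch aggregation can be done exactly: precision and recall at a threshold depend only on the true-positive, false-positive and false-negative counts at that threshold, and those counts can simply be summed over batches. The sketch below uses NumPy rather than torchmetrics and handles a single binary label. `StreamingPRCurve`, its methods, and the choice of precision = 1.0 when nothing is predicted positive are assumptions made for illustration, not a library API.

```python
import numpy as np


class StreamingPRCurve:
    """Accumulates per-threshold TP/FP/FN counts batch by batch.

    The state kept between batches is three count vectors, so it grows
    with the number of thresholds, not with the number of pixels seen.
    """

    def __init__(self, thresholds):
        self.thresholds = np.asarray(thresholds, dtype=np.float64)
        n = self.thresholds.size
        self.tp = np.zeros(n, dtype=np.int64)
        self.fp = np.zeros(n, dtype=np.int64)
        self.fn = np.zeros(n, dtype=np.int64)

    def update(self, preds, target):
        # Flatten so any shape (e.g. N x H x W masks) works; target is 0/1.
        preds = np.asarray(preds, dtype=np.float64).ravel()
        target = np.asarray(target).ravel().astype(bool)
        for i, t in enumerate(self.thresholds):
            predicted = preds >= t
            self.tp[i] += np.count_nonzero(predicted & target)
            self.fp[i] += np.count_nonzero(predicted & ~target)
            self.fn[i] += np.count_nonzero(~predicted & target)

    def compute(self):
        # Convention chosen here: precision is 1.0 where nothing is predicted.
        denom = self.tp + self.fp
        precision = np.where(denom > 0, self.tp / np.maximum(denom, 1), 1.0)
        recall = self.tp / np.maximum(self.tp + self.fn, 1)
        return precision, recall, self.thresholds


# Usage with synthetic data: four "validation batches" of two 64x64 masks.
rng = np.random.default_rng(0)
curve = StreamingPRCurve(thresholds=np.linspace(0.0, 1.0, 11))
for _ in range(4):
    target = rng.integers(0, 2, size=(2, 64, 64))
    preds = 0.6 * target + 0.4 * rng.random((2, 64, 64))
    curve.update(preds, target)
precision, recall, thresholds = curve.compute()
```

For a multilabel problem, one accumulator per label would be needed. The price is that the curve is only known at the chosen thresholds.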

I would like to know if anyone has faced a similar problem and managed to solve it.

Thanks

Hello,
Have you found a solution? I am getting exactly the same problem. Some metrics are too heavy to compute on the GPU, and when training in a multi-GPU setup, the metric states are expected to be on CUDA when they are gathered together.

Thanks!

Hey,

We know that this is a problem, and that is why we introduced binned versions of several metrics (like the PR curve) a while ago; they compute the result in constant memory, independent of the number of samples.
The only downside is that, depending on the thresholds you choose, they are not exactly as accurate as the memory-hungry versions.
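To illustrate this accuracy trade-off: the exact curve has one operating point per distinct score, whereas a binned curve only evaluates the rule `pred >= t` on a fixed grid of thresholds. The NumPy sketch below uses synthetic data and a hypothetical helper, `pr_at_thresholds`, to compute both from the same predictions. It demonstrates accuracy only, not memory, since both curves here are computed from all predictions at once. In torchmetrics releases from around the time of this thread, the binned metric was, to my knowledge, `BinnedPrecisionRecallCurve`, which takes a `thresholds` argument; check the documentation for your installed version.

```python
import numpy as np


def pr_at_thresholds(preds, target, thresholds):
    """Precision and recall for the rule `pred >= t` at each threshold t."""
    preds = np.asarray(preds, dtype=np.float64).ravel()
    target = np.asarray(target).ravel().astype(bool)
    n_pos = max(np.count_nonzero(target), 1)
    precision, recall = [], []
    for t in thresholds:
        predicted = preds >= t
        tp = np.count_nonzero(predicted & target)
        n_pred = np.count_nonzero(predicted)
        # Convention chosen here: precision is 1.0 when nothing is predicted.
        precision.append(tp / n_pred if n_pred else 1.0)
        recall.append(tp / n_pos)
    return np.array(precision), np.array(recall)


rng = np.random.default_rng(0)
target = rng.integers(0, 2, size=2_000)
preds = 0.5 * target + 0.5 * rng.random(2_000)

# "Exact" curve: one operating point per distinct score (needs all preds).
exact_thr = np.unique(preds)
exact_p, exact_r = pr_at_thresholds(preds, target, exact_thr)

# Binned curve: operating points only at a fixed grid of thresholds.
binned_thr = np.linspace(0.0, 1.0, 11)
binned_p, binned_r = pr_at_thresholds(preds, target, binned_thr)

print(len(exact_thr), "exact points vs", len(binned_thr), "binned points")
```

Each binned point coincides with the exact point at the smallest distinct score that is at least t, so the binned curve is a coarser sampling of the same curve rather than a different one. A finer threshold grid recovers more of it, at the cost of keeping more per-threshold counts.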

Cheers,
Justus