Get batch’s datapoints across all GPUs


I'm running my model on a cluster with multiple GPUs (2). My problem is that I would like to access all the datapoints in the batch. Because I'm using two GPUs, my batch is divided between those two devices for parallelisation purposes, which means that when I access the data in the batch during eval/training, I'm getting just half the batch.

How can I obtain the complete batch and the model's predictions when they are divided among the different devices/GPUs? I tried setting the flag accelerator="ddp", but the problem persists.


hey @fermoren

if you need just the predictions, you can use self.all_gather within your LightningModule.

class LitModel(LightningModule):
    def some_hook(...):
        preds = ...
        # all_gather is a collective op: every rank must call it,
        # so don't guard the call itself behind is_global_zero
        preds = self.all_gather(preds)
        if self.trainer.is_global_zero:
            # do whatever with the full predictions
            ...
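
Conceptually, all_gather returns, on every rank, the tensors from all ranks stacked along a new leading dimension. Here is a minimal pure-Python sketch of that semantics, with no real distributed setup; simulated_all_gather is a hypothetical stand-in for illustration, not a Lightning or torch API:

```python
def simulated_all_gather(per_rank_preds):
    # In real DDP, every rank calls all_gather and each one receives
    # the same stacked result covering all ranks.
    return [list(per_rank_preds) for _ in per_rank_preds]

# Two GPUs, each seeing half of a batch of 4 samples:
rank0_preds = [0.1, 0.2]   # first half of the batch
rank1_preds = [0.3, 0.4]   # second half of the batch

gathered = simulated_all_gather([rank0_preds, rank1_preds])

# Every rank now holds the complete set of predictions;
# flatten the per-rank lists to recover the full batch:
full = [p for rank in gathered[0] for p in rank]
print(full)  # [0.1, 0.2, 0.3, 0.4]
```

In the real self.all_gather case the gathered tensor has shape (world_size, *local_shape), so you would typically reshape or concatenate along the first dimension to recover the full batch.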

Also, we have moved the discussions to GitHub Discussions. You might want to check that out instead to get a quick response. The forums will be marked read-only soon.

Thank you

Hey, thanks for your answer!

I tried your suggested solution in the LightningModule's forward, but I'm afraid it is still returning only half the data, that is, just the data on one of the two GPUs I'm using. Any idea what's going on and how I can solve it?


PS: I will move the question to GitHub Discussions.