Get a batch’s datapoints across all GPUs

hey @fermoren

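If you only need the predictions, you can use `self.all_gather` inside your LightningModule. Note that `all_gather` is a collective operation: every process has to call it, otherwise the ones that do will hang waiting for the rest. Only the code that uses the gathered result should be restricted to global rank 0.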

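```python
from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    def validation_step(self, batch, batch_idx):  # or any other hook
        preds = ...  # this batch's predictions on the current GPU
        # called on every rank; adds a leading world_size dim when world_size > 1
        preds = self.all_gather(preds)
        if self.trainer.is_global_zero:
            ...  # preds now holds the batch's predictions from all GPUs
```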
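When running on more than one GPU, the gathered tensor has shape `(world_size, batch, ...)`; with a single process no extra dimension is added. A minimal sketch for merging it back into one flat batch is below. `flatten_gathered` is a hypothetical helper, not a Lightning API, and it assumes `preds` is a single tensor rather than a dict or list of tensors.

```python
import torch


def flatten_gathered(gathered: torch.Tensor, world_size: int) -> torch.Tensor:
    """Merge the leading world_size dim added by all_gather into the batch dim.

    Hypothetical helper for illustration; assumes a single tensor input.
    """
    if world_size == 1:
        # no extra dimension is added when there is only one process
        return gathered
    # (world_size, batch, ...) -> (world_size * batch, ...), rank 0's samples first
    return gathered.flatten(0, 1)
```

Inside the hook you could call it as `flatten_gathered(preds, self.trainer.world_size)`. The result is ordered by rank, so if your dataloader uses PyTorch's `DistributedSampler` it will not follow the original dataset order, and it may contain a few repeated samples that the sampler adds as padding so every rank gets the same number of samples.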

Also, we have moved the discussions to GitHub Discussions. You might want to check that out instead to get a quick response. The forums will be marked read-only soon.

Thank you