Compute Loss After Sharing Tensor Across GPUs

I’m currently writing a multi-GPU CLIP training script, but am hitting a wall. The loss depends on two matrices built from whole-batch statistics: the image and text embeddings of the entire (global) batch. Only once each GPU has the full-batch embeddings can it compute its sub-batch losses.

How can I first calculate and share the whole batch matrices across GPUs before computing losses?