DDP: replacing torch dist. calls with PL directives for inter-node communication?

I’m working on a self-supervised learning problem where I’m trying to use Facebook’s VICReg. In their repo, they use torch.distributed.all_gather() and all_reduce() to compute variance and covariance across the full distributed batch in a differentiable fashion, in the following way (link to their code):

import torch
import torch.distributed as dist


class FullGatherLayer(torch.autograd.Function):
    """
    Gather tensors from all processes and support backward propagation
    for the gradients across processes.
    """

    @staticmethod
    def forward(ctx, x):
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        return all_gradients[dist.get_rank()]

That routine is called via FullGatherLayer.apply(data). (The .backward part is never explicitly called from the user’s code; rather, it is invoked by autograd.)
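
For context, the loss computation downstream of that call looks roughly like this (my paraphrase, not an exact copy of their code; x and y are the per-GPU projector outputs of shape (batch, dim)):

import torch
import torch.nn.functional as F

def variance_covariance_terms(x, y):
    # concatenate the per-process shards into the full effective batch
    x = torch.cat(FullGatherLayer.apply(x), dim=0)
    y = torch.cat(FullGatherLayer.apply(y), dim=0)

    x = x - x.mean(dim=0)
    y = y - y.mean(dim=0)

    # variance (hinge) term, computed over the gathered batch
    std_x = torch.sqrt(x.var(dim=0) + 1e-4)
    std_y = torch.sqrt(y.var(dim=0) + 1e-4)
    std_loss = torch.mean(F.relu(1.0 - std_x)) / 2 + torch.mean(F.relu(1.0 - std_y)) / 2

    # covariance term, normalized by the gathered batch size and feature dim
    def off_diag(m):
        return m - torch.diag(torch.diag(m))

    n, d = x.shape
    cov_x = (x.T @ x) / (n - 1)
    cov_y = (y.T @ y) / (n - 1)
    cov_loss = off_diag(cov_x).pow(2).sum() / d + off_diag(cov_y).pow(2).sum() / d

    return std_loss, cov_loss

So if all_gather only sees one node’s worth of GPUs, these statistics get computed over the wrong effective batch.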

The problem I’m having is that when running via PL DDP with SLURM & srun on multiple nodes of 8 GPUs each (NCCL backend, data on the GPU devices), the dist.all_gather() operation only gathers across the 8 GPUs in each node, not across the total effective batch size of 8 * [number of nodes].

How do we modify this to use the appropriate PyTorch lightning directives instead?

Thanks,


PS -
I tried, but the @staticmethod aspect is throwing me off, and when I tried passing in my pl.LightningModule as module I got a Python error about too many levels of recursion.

class FullGatherLayer_pl(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, module):
        # module is supposed to be PyTorch Lightning module? 
        output = [torch.zeros_like(x) for _ in range(module.world_size())]
        module.all_gather(output, x)
        ctx.module = module # try overloading ctx to save for backward? 
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        all_gradients = torch.stack(grads)
        ctx.module.all_reduce(all_gradients)  # actually it's not called all_reduce in pl.Trainer; what is it? 
        return all_gradients[dist.get_rank()]

Hi @scott-9Uixp

This is very odd! Is the world size set correctly? It certainly should include results from all processes.

How do we modify this to use the appropriate PyTorch lightning directives instead?

For all_gather with support for backward/gradients, you can use

self.all_gather(..., sync_grads=True)

in the LightningModule directly.
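
Roughly, inside your LightningModule’s training_step it would look something like this (just a sketch; the forward call and vicreg_loss are placeholders for your own code):

def training_step(self, batch, batch_idx):
    x, y = self(batch)                           # placeholder forward pass, each (batch, dim)

    # gathers from every process on every node; sync_grads=True keeps gradients flowing
    x_all = self.all_gather(x, sync_grads=True)  # (world_size, batch, dim)
    y_all = self.all_gather(y, sync_grads=True)

    # flatten the world dimension to recover the full effective batch
    x_all = x_all.flatten(0, 1)                  # (world_size * batch, dim)
    y_all = y_all.flatten(0, 1)

    return self.vicreg_loss(x_all, y_all)        # placeholder loss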

This should be equivalent to the old (and today unneeded) FullGatherLayer. Back then, this wasn’t readily available in PyTorch, hence the special class. Could you give this a try?

PS: In 2022, I also converted vicreg to Lightning. If you want, I can invite you to the private repo if you share your GH username with me.

Cheers

Thanks for your reply!

Re. the world size being set correctly: well, all other aspects of training work OK. We (@harmonai_org) have been using the recommended Lightning SLURM setup since last August or so; we just never had to do our own gather/reduce manually until now.

Re. the self.all_gather(..., sync_grads=True): it was the self part that was giving the Python recursion errors.

YES I would LOVE to see your VICReg implementation! Thanks. GitHub username is @drscotthawley


Re. the self.all_gather(..., sync_grads=True) : it was the self part that was giving the Python recursion errors.

That’s interesting and weird. I’ve never seen this. Could you show me how to reproduce it? Do you have anything you could share that I could run to debug? The full error with stack trace would be useful too.


I invited you to my repo with the old code. I’m not sure how useful it is; it’s been a long time, and it may be too outdated (it also doesn’t use self.all_gather). I only produced it for a quick demo, and it might be missing some parts.
Here is the file with the LightningModule implementation:
https://github.com/awaelchli/vicreg-lightning/blob/convert-to-lightning/main.py


Thanks for sharing. Yeah, it looks like your code uses the same FullGatherLayer as the original.

And you tried running this on multiple nodes, not just multiple GPUs?

I’m working on putting together a publicly-readable example; should be able to share tomorrow. Thanks for your interest in helping.

This may be due to an SBATCH script error regarding the number of tasks per node. A little more work on my end may produce a solution. :crossed_fingers: Will update later.

Thanks again for your help Adrian. Turns out that my problem was all due to an improper SBATCH submission script.

Without any changes to the actual Python code, I’m now getting the expected all_gather behavior! :slight_smile: I’ll mark this thread as solved.

Very excited to be able to move on.
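
For anyone who lands on this thread later: the gist of the fix was making sure the SLURM allocation matches what the Trainer expects. Roughly (a sketch, not my exact script; adjust the node count to your cluster):

# In the SBATCH script, request one task per GPU, e.g. for 2 nodes of 8 GPUs:
#   #SBATCH --nodes=2
#   #SBATCH --ntasks-per-node=8
#   #SBATCH --gpus-per-node=8
#   srun python train.py
# and give the Trainer the same topology:
import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="gpu",
    devices=8,      # GPUs per node, must match --ntasks-per-node
    num_nodes=2,    # must match --nodes
    strategy="ddp",
)

(In my case the culprit was most likely the tasks-per-node setting suspected above.)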


Awesome! I’m very glad :tada:

Hello @scott-9Uixp,

I am also working on a self-supervised learning task that requires constructing a similarity matrix (with dimensions (batch_size, batch_size)) before computing the loss value, but I want to concatenate the embeddings across the multiple GPUs before constructing the similarity matrix and computing the loss. That similarity matrix should then have dimensions (num_gpus * batch_size, num_gpus * batch_size).
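
Concretely, this is what I am after inside my LightningModule’s training_step (a rough sketch with placeholder names like encoder and contrastive_loss, borrowing the self.all_gather(..., sync_grads=True) suggestion from above):

import torch.nn.functional as F

def training_step(self, batch, batch_idx):
    z = self.encoder(batch)                      # placeholder: per-GPU embeddings, (batch_size, dim)

    z_all = self.all_gather(z, sync_grads=True)  # (num_gpus, batch_size, dim)
    z_all = z_all.flatten(0, 1)                  # (num_gpus * batch_size, dim)
    z_all = F.normalize(z_all, dim=1)

    # similarity matrix over the full effective batch
    sim = z_all @ z_all.T                        # (num_gpus * batch_size, num_gpus * batch_size)
    return self.contrastive_loss(sim)            # placeholder loss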

Since I am using PyTorch Lightning for distributed data parallelism, how were you able to implement the FullGatherLayer in the training_step hook? Is your publicly-readable example available, and if so, could you share the link with me?

Thanks!

Hello @awaelchli ,

Is it possible for me to see the private repo implementation of VICReg? I am working on self-supervised learning and want to put together a training procedure for multiple GPUs as well.

Thanks for your time!

Hi @PraljakReps
I can share my old code with you, yes. What’s your GitHub username?

Thanks! Here is the GitHub username: PraljakReps

Cool. The implementation is on the convert-to-lightning branch.
https://github.com/awaelchli/vicreg-lightning/tree/convert-to-lightning
Again, it’s a very old version so I’m not sure how helpful it will be.

Hey, here is also the “one file add-on” I made for VICReg that includes the relevant code:
