Hi, in my research code I have a huge dense classification layer with millions of classes, which does not fit into the memory of a single GPU. Best practice suggests using a model-parallel dense layer, where the matrix multiplication is split across GPUs; otherwise the huge dense layer becomes a memory and performance bottleneck.
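For context, this is roughly the kind of layer I mean. A minimal sketch (not my actual implementation, names like `ShardedClassifier` are made up): each rank owns `num_classes // world_size` rows of the weight matrix and produces only its shard of the logits; the cross-entropy over the full class dimension then needs additional all-reduces, which are omitted here.

```python
import torch
import torch.nn as nn
import torch.distributed as dist


class ShardedClassifier(nn.Module):
    """Output-sharded dense layer: each rank holds a different weight shard."""

    def __init__(self, hidden_dim: int, num_classes: int):
        super().__init__()
        world_size = dist.get_world_size()
        assert num_classes % world_size == 0, "pad num_classes to a multiple of world_size"
        self.shard_size = num_classes // world_size
        # These parameters differ across ranks and must NOT be averaged by DDP.
        self.weight = nn.Parameter(torch.empty(self.shard_size, hidden_dim))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # Local partial logits of shape [batch, shard_size]; the full softmax
        # requires reducing the per-rank max and sum-exp terms across ranks.
        return features @ self.weight.t()
```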
So far I have implemented a model-parallel loss that distributes the computation across workers via the nccl/gloo backends and also propagates gradients correctly. In this layer each GPU holds different parameters. Unfortunately, this does not play well with the DDP plugin in PyTorch Lightning (https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/plugins/ddp_plugin.py), since it aggregates gradients during backward under the assumption that all parameters are identical (replicated) across processes.
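One workaround I am considering, sketched here at the raw PyTorch level (function and argument names are hypothetical): wrap only the replicated backbone in `DistributedDataParallel` and keep the model-parallel classifier outside of it, so DDP never all-reduces the per-rank parameters. In Lightning terms this would presumably mean customizing how the plugin wraps the LightningModule.

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_for_training(backbone: nn.Module, sharded_classifier: nn.Module, local_rank: int):
    backbone = backbone.to(local_rank)
    sharded_classifier = sharded_classifier.to(local_rank)
    # DDP only sees the backbone, so it only synchronizes gradients for
    # parameters that really are replicated across ranks. The sharded
    # classifier's gradients come from the model-parallel loss itself.
    backbone = DDP(backbone, device_ids=[local_rank])
    return backbone, sharded_classifier
```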
What would you suggest as a workaround for my problem, or is there one already?