How to use multiple GPUs outside of `training_step`?

I am implementing a dual-encoder for entity linking. Training it on 8 GPUs with DDPStrategy already works.

As a next step, I am trying to add hard-negative mining during training. At the start of every epoch, I need to encode all candidates (say, all Wikipedia articles) and save the vectors. The code looks like this:

def on_train_epoch_start(self):
    all_vecs = []
    all_wikipedia_dataloader = create_dataloader(..)
    for batch in tqdm(all_wikipedia_dataloader):
        # hidden vector of the CLS token for every candidate in the batch
        all_vecs.append(self.encoder(batch)[0][:, 0, :])

Although this code works, I noticed that every GPU encodes all candidates; the data is NOT split across GPUs. This wastes time, because the 8 GPUs could share the work. In training_step, by contrast, the DDP strategy does split the data.

So my question is: how can I use multiple GPUs outside of training_step?

My code is similar to this, but it is separate from the training code.

Hey, you would have to do the splitting yourself. In your case it’s probably quite easy to just use the DistributedSampler from PyTorch in your dataloader and then call all_gather on the resulting all_vecs. Note, however, that the distributed sampler repeats some samples to pad the data so that every rank receives the same number of samples, so the gathered result can contain a few duplicates.
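To make the splitting and the repeated samples concrete, here is a minimal sketch. The 10-item toy dataset and the world size of 3 are assumptions chosen for illustration; passing num_replicas and rank explicitly lets it run in a single process without an initialized process group.

```python
import torch
from torch.utils.data import DistributedSampler, TensorDataset

# Toy stand-in for the candidate set (assumption): 10 items, 3 pretend ranks.
dataset = TensorDataset(torch.arange(10))
world_size = 3

# One sampler per pretend rank; shuffle=False keeps the order deterministic.
shards = [
    list(DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False))
    for rank in range(world_size)
]

for rank, indices in enumerate(shards):
    print(rank, indices)
# 0 [0, 3, 6, 9]
# 1 [1, 4, 7, 0]
# 2 [2, 5, 8, 1]
```

Each rank gets four indices even though 10 is not divisible by 3: indices 0 and 1 are repeated as padding at the end. Passing drop_last=True would drop the tail instead of padding, at the cost of leaving some candidates unencoded.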


Hi @justusschock

Yeah, you’re right. By simply using DistributedSampler and dist.all_gather_object from PyTorch (not PyTorch Lightning), I got it working!

Although my code still does not handle the repeated samples in the last batch, I would like to share it for anyone who wants to see it.

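import itertools

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm

def on_train_start(self):
    dataset = MyDataset(...)
    # each rank receives its own shard of the dataset
    sampler = DistributedSampler(dataset, shuffle=False)
    dataloader = DataLoader(dataset=dataset, sampler=sampler, ...)
    all_vecs = []
    # show the progress bar on rank 0 only
    for batch in tqdm(dataloader, disable=self.global_rank != 0):
        with torch.no_grad():
            # loop body assumed from the snippet in the question: CLS hidden vectors,
            # stored as one CPU tensor per candidate so they can be stacked later
            all_vecs.extend(self.encoder(batch)[0][:, 0, :].cpu())
    # gather every rank's list of vectors onto every rank
    all_rank_cand_vecs = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(all_rank_cand_vecs, all_vecs)
    # flatten list of list
    all_rank_cand_vecs = torch.stack(list(itertools.chain.from_iterable(all_rank_cand_vecs)), dim=0)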
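For anyone who also needs to deal with the last batch: flattening the gathered lists rank by rank does not give dataset order, and the padded duplicates are still included. Below is a rough sketch of one way to restore the order and drop the padding. It assumes shuffle=False and the default drop_last=False, with one entry per sample in each rank's list. The helper name merge_rank_outputs is hypothetical, and plain integers stand in for the vectors.

```python
import math


def merge_rank_outputs(per_rank, dataset_len):
    """Interleave per-rank lists back into dataset order and drop the padding."""
    world_size = len(per_rank)
    merged = []
    # With shuffle=False, rank r handled positions r, r + world_size, r + 2 * world_size, ...
    for step in range(len(per_rank[0])):
        for rank in range(world_size):
            merged.append(per_rank[rank][step])
    # The padded repeats sit at the very end, so truncating removes them.
    return merged[:dataset_len]


# Simulate 10 candidates sharded over 3 ranks (toy numbers, not from the thread).
dataset_len, world_size = 10, 3
per_rank_len = math.ceil(dataset_len / world_size)
padded = list(range(dataset_len)) + list(range(per_rank_len * world_size - dataset_len))
per_rank = [padded[rank::world_size] for rank in range(world_size)]

merged = merge_rank_outputs(per_rank, dataset_len)
print(merged)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

With real vectors, the same interleaving can be applied to the gathered lists before calling torch.stack, so that row i of the result corresponds to candidate i.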

That’s great news, thanks for sharing with everybody!