Yeah, you’re right. By simply using DistributedSampler and dist.all_gather_object from plain PyTorch (not PyTorch Lightning), I could do it!
Although my code still does not take care of the last batch, I’d like to share it for anyone who wants to see it.
import itertools

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm


def on_train_start(self):
    dataset = MyDataset(...)
    # DistributedSampler splits the dataset across ranks (shuffle=False keeps a fixed order)
    sampler = DistributedSampler(dataset, shuffle=False)
    dataloader = DataLoader(dataset=dataset, sampler=sampler, ...)

    # encode this rank's shard of the dataset
    all_vecs = []
    for batch in tqdm(dataloader, disable=self.global_rank != 0):  # progress bar on rank 0 only
        with torch.no_grad():
            all_vecs.append(self.my_model(batch['xxx']))

    # gather every rank's list of per-batch outputs
    all_rank_vecs = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(all_rank_vecs, all_vecs)

    # flatten the list of lists and concatenate into a single [num_samples, dim] tensor
    all_rank_cand_vecs = torch.cat(list(itertools.chain.from_iterable(all_rank_vecs)), dim=0)
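About the last-batch issue: DistributedSampler pads every rank to the same length by repeating samples from the start of the dataset, so the gathered tensor still contains those duplicates. Here is a minimal sketch (not part of my code above) of how they could be dropped, assuming shuffle=False and 2-D [batch, dim] model outputs; unpad_and_reorder is just a hypothetical helper name:

import torch

def unpad_and_reorder(all_rank_vecs, dataset_len):
    # all_rank_vecs[r] is the list of per-batch outputs gathered from rank r
    per_rank = [torch.cat(vecs, dim=0) for vecs in all_rank_vecs]
    # with shuffle=False the sampler assigns sample i to rank i % world_size,
    # so interleaving the per-rank rows restores the original dataset order
    interleaved = torch.stack(per_rank, dim=1).reshape(-1, per_rank[0].size(-1))
    # every rank is padded to the same length with wrap-around duplicates,
    # so truncating to the true dataset length drops them
    return interleaved[:dataset_len]

Calling something like unpad_and_reorder(all_rank_vecs, len(dataset)) right after the all_gather_object step should then give exactly one vector per original sample.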