Hey, you would have to do the splitting yourself. In your case it's probably easiest to use PyTorch's `DistributedSampler` in your dataloader and then call `all_gather` on the resulting `all_vecs`. Note, however, that with the default `drop_last=False` the `DistributedSampler` pads the dataset by repeating a few samples so that every process gets the same number of samples. The gathered results can therefore contain duplicates, which you'll want to drop.
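To see where those duplicates end up, here is a plain-Python sketch, assuming `shuffle=False` and `drop_last=False`. `shard_indices` mimics how `DistributedSampler` pads the index list and hands each rank every `num_replicas`-th index. `merge_gathered` undoes that striding once you have one result list per rank (in real code that list would come from `all_gather`, which fills one tensor per rank). Both function names are hypothetical helpers for illustration, not part of PyTorch.

```python
import math
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shard_indices(dataset_len: int, num_replicas: int, rank: int) -> List[int]:
    """Hypothetical helper mimicking DistributedSampler(shuffle=False, drop_last=False)."""
    num_samples = math.ceil(dataset_len / num_replicas)  # per-rank count, rounded up
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    padding = total_size - dataset_len
    # Pad by repeating indices from the start (cycling if padding > dataset_len)
    indices += (indices * math.ceil(padding / dataset_len))[:padding]
    # Each rank takes a strided slice starting at its own rank
    return indices[rank:total_size:num_replicas]


def merge_gathered(gathered: Sequence[Sequence[T]], dataset_len: int) -> List[T]:
    """Hypothetical helper: gathered[r] holds rank r's outputs, in sampler order."""
    num_replicas = len(gathered)
    num_samples = len(gathered[0])
    merged: List[T] = []
    # Interleave: position i on rank r came from padded index i * num_replicas + r
    for i in range(num_samples):
        for r in range(num_replicas):
            merged.append(gathered[r][i])
    # The padded (duplicate) entries are all at the end, so truncate them away
    return merged[:dataset_len]


# Simulate 4 ranks over a 10-sample dataset; each "output" is index * 10
gathered = [[idx * 10 for idx in shard_indices(10, 4, r)] for r in range(4)]
print(gathered[2])                   # rank 2 also got a repeat of sample 0
print(merge_gathered(gathered, 10))  # original order, duplicates removed
```

With `shuffle=True` the order depends on the sampler's seed and epoch, so rather than un-striding like this, it may be simpler to gather the sample indices alongside `all_vecs` and deduplicate by index.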