Skip instances during training

Hi, I am using a LightningModule to train a neural network across many instances/GPUs. However, the data is imbalanced (I cannot change this), so I want to skip some instances during training to balance it.

Here is the logic of my code, which is called from training_step():

```python
# Called from training_step() elsewhere, passing 4 batches at a time...

def forward(...):
    batch_size = len(inputs)  # passing 4 batches at a time...
    total_loss = 0.0

    for batch_idx in range(batch_size):

        train_model = True
        filtered_tags = tag[batch_idx][self.most_frequent_feature_indexes]
        if filtered_tags.sum() == 0:
            # Skip 3 out of every 4 data points where all scenario tags are 0
            if self.count_zero % 4 != 0:
                self.count_zero += 1
                train_model = False
            else:
                self.count_zero += 1
                self.total_zero_data += 1
        if (filtered_tags.sum() == 1) and (torch.argmax(filtered_tags) == 5):
            # Skip 9 out of 10 data points where the only active tag is index 5
            if self.count_stationary % 10 < 9:
                self.count_stationary += 1
                train_model = False
            else:
                self.count_stationary += 1

        if train_model:
            logits = self.model(...)
        else:
            # Forward pass without building a graph for skipped data points
            with torch.no_grad():
                logits = self.model(...)
```

But I get this error:

Latest log:
[E ProcessGroupNCCL.cpp:414] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data. To avoid this inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:737] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=27, OpType=BROADCAST, Timeout(ms)=7200000) ran for 7208733 milliseconds before timing out.

There must be a way to do this, right? Thank you!

Hey

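Due to the nature of distributed training, you can’t “skip” a training step conditionally unless the decision is the same across all processes. In the case above, if one process takes the if train_model branch, then all processes need to take it in that iteration. The processes stay in lockstep through collective operations (such as the BROADCAST in your log), which only complete once every rank has called them, so a rank that takes a different path leaves the others waiting until the timeout. Based on your error, that looks like what is happening.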

You need to find a way to make this decision consistent across all processes.
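One way to do that, sketched here assuming a torch.distributed process group is already initialized (as it is under DDP): each rank computes its own train_model flag, and an all-reduce with MAX turns those flags into a single decision that every rank shares. The helper any_rank_wants_to_train is hypothetical, not a Lightning API.

```python
from typing import Optional

import torch
import torch.distributed as dist


def any_rank_wants_to_train(local_train: bool, device: Optional[torch.device] = None) -> bool:
    """Turn each rank's local train/skip decision into one shared decision.

    Every rank contributes 1 (train) or 0 (skip); a MAX all-reduce yields 1
    if at least one rank wants to train, so all ranks take the same branch.
    Without an initialized process group, the local decision is returned.
    """
    flag = torch.tensor([1 if local_train else 0], dtype=torch.int32, device=device)
    if dist.is_available() and dist.is_initialized():
        # Collective call: every rank must reach this line the same number of times
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())
```

Inside a LightningModule this would be called right before the if train_model: branch, e.g. train_model = any_rank_wants_to_train(train_model, device=self.device); with the NCCL backend the flag tensor has to live on the rank's GPU. Using MAX means a data point that one rank wanted to skip may still be trained on; ReduceOp.MIN would instead skip whenever any rank wants to.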

Thanks a lot for the quick reply.

Can I somehow force the LightningModule not to use distributed training? That would presumably fix this, right? Is there an easy way to do that?

I had tried to keep the “flow” the same across all instances by doing forward passes etc. in all of them for all data, but conditionally wrapping some of them in torch.no_grad(). I can’t think of any other way to accomplish this.
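An alternative that keeps the flow fully identical, offered as a sketch under assumptions rather than a confirmed fix: run every data point with gradients enabled and give skipped ones a loss weight of 0, so every rank builds the same graph and takes part in the same backward pass. The helper names, the counters dict (standing in for self.count_zero and self.count_stationary), and the use of binary_cross_entropy_with_logits on multi-label tags are all hypothetical stand-ins for the real model and loss.

```python
import torch
import torch.nn.functional as F


def build_keep_mask(tags, freq_idx, counters):
    """Per-sample keep mask mirroring the counter rules in the question.

    `counters` is a hypothetical dict with "zero" and "stationary" keys;
    it is updated in place.
    """
    keep = torch.ones(tags.shape[0], dtype=torch.bool)
    for i in range(tags.shape[0]):
        filtered = tags[i][freq_idx]
        if filtered.sum() == 0:
            # Keep 1 of every 4 samples whose filtered tags are all 0
            keep[i] = counters["zero"] % 4 == 0
            counters["zero"] += 1
        elif filtered.sum() == 1 and torch.argmax(filtered) == 5:
            # Keep 1 of every 10 samples whose only active tag is index 5
            keep[i] = counters["stationary"] % 10 == 9
            counters["stationary"] += 1
    return keep


def masked_mean_loss(logits, targets, keep):
    """Mean loss over kept samples; skipped samples get weight 0.

    All samples go through the graph with gradients enabled, so every rank
    runs the same backward pass; skipped samples simply contribute 0.
    """
    per_sample = F.binary_cross_entropy_with_logits(
        logits, targets, reduction="none"
    ).mean(dim=1)
    weights = keep.to(device=per_sample.device, dtype=per_sample.dtype)
    # clamp(min=1) avoids 0/0 when every sample on this rank is skipped;
    # the loss is then 0 but still attached to the graph.
    return (per_sample * weights).sum() / weights.sum().clamp(min=1.0)
```

Since the skip decision now only changes a weight, ranks can disagree about which data points to skip without diverging in which operations they run.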

So, e.g., four different batches are sent to four different instances, and because some run under torch.no_grad() and some don’t, that messes things up?

Thanks again for any advice.