I am trying to move all validation outputs to one process to calculate my metric. My code looks something like the follows
At each validation step
def validation_step(self, batch, batch_idx):
# forward
outputs = self.forward(batch)
y_pred = outputs.logits
y_true = batch["rbd_labels"]
# loss
loss = F.cross_entropy(y_pred, y_true)
self.log("valid_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
instances = batch["instances"]
labels = batch["rbd_labels"]
scores = y_pred[:, 1]
return (instances, labels, scores)
Then I call self.all_gather
at the output
def validation_epoch_end(self, validation_step_outputs):
# Step 1: collect instances and labels
instance_label_score_map = defaultdict(list)
instances = []
labels = []
scores = []
for i, l, s in validation_step_outputs:
instances += i.tolist()
labels += l.tolist()
scores += s.tolist()
out_instances = self.all_gather(instances)
out_labels = self.all_gather(labels)
out_scores = self.all_gather(scores)
if dist.get_rank() == 0:
print("dist rank: 0")
# Note: they stack in a weird way so we need to convert it back
out_instances = torch.stack(out_instances).cpu().tolist()
out_labels = torch.stack(out_labels).cpu().tolist()
out_scores = torch.stack(out_scores).cpu().tolist()
score = compute_score()
self.log("valid_score", score)
Then my programs hangs right after the sanity check after the validation_epoch_end
.
Appreciate any help!