Hi, I was wondering what the proper way of logging metrics is when using DDP. I noticed that if I print something inside validation_epoch_end, it is printed twice when using 2 GPUs. I was expecting validation_epoch_end to be called only on rank 0 and to receive the outputs from all GPUs, but I am no longer sure this is correct. Therefore I have several questions:
- validation_epoch_end(self, outputs)
  - When using DDP, does every subprocess receive only the data processed on its own GPU, or the data processed on all GPUs? In other words, does the input parameter outputs contain the outputs of the entire validation set, from all GPUs?
  - If outputs is GPU/process specific, what is the proper way to calculate a metric on the entire validation set in validation_epoch_end when using DDP?
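To make the second question concrete: as far as I can tell, averaging F1 values computed separately in each process is not the same as computing F1 once over the whole validation set. Below is a minimal sketch in plain Python. The counts are made up and f1_from_counts is a hypothetical helper, not part of my code or of Lightning.

```python
from typing import List, Tuple


def f1_from_counts(tp: int, fp: int, fn: int) -> float:
    """F1 from true positives, false positives and false negatives."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


# Made-up (tp, fp, fn) counts, one tuple per DDP process.
per_rank_counts: List[Tuple[int, int, int]] = [(8, 2, 0), (1, 0, 9)]

# Option A: every process computes its own F1 and the values are averaged.
mean_of_f1 = sum(f1_from_counts(*c) for c in per_rank_counts) / len(per_rank_counts)

# Option B: the counts are summed across processes, then F1 is computed once.
tp, fp, fn = (sum(column) for column in zip(*per_rank_counts))
pooled_f1 = f1_from_counts(tp, fp, fn)

print(f"mean of per-rank F1: {mean_of_f1:.3f}")
print(f"F1 on pooled counts: {pooled_f1:.3f}")
```

The two numbers differ, so if outputs is process specific I assume I need to pool either the raw predictions or the underlying counts before computing the metric.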
I understand that I can solve the printing by checking self.global_rank == 0 and printing/logging only in that case; however, I am trying to get a deeper understanding of what I am printing/logging in this case.
Here is a code snippet from my use case. I would like to be able to report f1, precision and recall on the entire validation dataset and I am wondering what is the correct way of doing it when using DDP.
def _process_epoch_outputs(self,
                           outputs: List[Dict[str, Any]]
                           ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Creates and returns tensors containing all labels and predictions

    Goes over the outputs accumulated from every batch, detaches the
    necessary tensors and stacks them together.

    Args:
        outputs (List[Dict])
    """
    all_labels = []
    all_predictions = []

    for output in outputs:
        for labels in output['labels'].detach():
            all_labels.append(labels)

        for predictions in output['predictions'].detach():
            all_predictions.append(predictions)

    all_labels = torch.stack(all_labels).long().cpu()
    all_predictions = torch.stack(all_predictions).cpu()

    return all_predictions, all_labels
def validation_epoch_end(self, outputs: List[Dict[str, Any]]) -> None:
    """Logs f1, precision and recall on the validation set."""
    if self.global_rank == 0:
        print(f'Validation Epoch: {self.current_epoch}')

    predictions, labels = self._process_epoch_outputs(outputs)
    for i, name in enumerate(self.label_columns):
        f1, prec, recall, t = metrics.get_f1_prec_recall(predictions[:, i],
                                                         labels[:, i],
                                                         threshold=None)
        self.logger.experiment.add_scalar(f'{name}_f1/Val',
                                          f1,
                                          self.current_epoch)
        self.logger.experiment.add_scalar(f'{name}_Precision/Val',
                                          prec,
                                          self.current_epoch)
        self.logger.experiment.add_scalar(f'{name}_Recall/Val',
                                          recall,
                                          self.current_epoch)
        if self.global_rank == 0:
            print((f'F1: {f1}, Precision: {prec}, '
                   f'Recall: {recall}, Threshold {t}'))
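If outputs does turn out to be process specific, one direction I am considering is to gather the stacked tensors from all processes with torch.distributed.all_gather before computing the metrics. This is only a rough sketch under my own assumptions: gather_across_ranks is a hypothetical helper, every process is assumed to hold a tensor of the same shape, and I have not verified that this is the recommended way in Lightning.

```python
import torch
import torch.distributed as dist


def gather_across_ranks(tensor: torch.Tensor) -> torch.Tensor:
    """Concatenates `tensor` from every process along dim 0.

    Hypothetical helper. Assumes all ranks hold tensors of identical shape,
    which all_gather requires. Falls back to returning the input unchanged
    when no process group is initialized (e.g. single-GPU or CPU runs).
    """
    if not (dist.is_available() and dist.is_initialized()):
        return tensor

    # One receive buffer per process, filled in rank order by all_gather.
    gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, tensor)
    return torch.cat(gathered, dim=0)


# Single-process usage: without an initialized process group the helper
# returns the local tensor as-is.
local_predictions = torch.arange(6, dtype=torch.float32).reshape(3, 2)
all_predictions = gather_across_ranks(local_predictions)
print(all_predictions.shape)
```

In my case I would call this on predictions and labels before the metrics loop. Two caveats I am aware of: with the NCCL backend the tensors have to stay on the GPU, so the gather would need to happen before the .cpu() calls in _process_epoch_outputs, and DistributedSampler may pad the dataset with repeated samples so that every process gets the same number, which means the gathered set could contain a few duplicates.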