"""Example: track a classification and a regression metric with one MultitaskWrapper.

Predictions and targets for each task are passed as dicts keyed by task name.
``MultitaskWrapper`` pairs each entry with the metric registered under the same
key, then computes and plots all task metrics together.
"""
import torch
from torchmetrics.classification import BinaryAccuracy
from torchmetrics.regression import MeanSquaredError
from torchmetrics.wrappers import MultitaskWrapper

# Ground-truth values per task; the dict keys must match the task names the
# metrics are registered under below.
classification_target = torch.tensor([0, 1, 0])
regression_target = torch.tensor([2.5, 5.0, 4.0])
targets = {"Classification": classification_target, "Regression": regression_target}

# Model outputs per task. Classification predictions are given as hard 0/1 labels.
classification_preds = torch.tensor([0, 0, 1])
regression_preds = torch.tensor([3.0, 5.0, 2.5])
preds = {"Classification": classification_preds, "Regression": regression_preds}

metrics = MultitaskWrapper(
    {
        "Classification": BinaryAccuracy(),
        "Regression": MeanSquaredError(),
    }
)
metrics.update(preds, targets)

# Dict of results keyed by task name. With the data above: accuracy = 1/3
# (only the first label matches) and MSE = (0.5**2 + 0 + 1.5**2) / 3 ≈ 0.8333.
value = metrics.compute()

# MultitaskWrapper.plot returns one (figure, axes) pair per task, in task
# registration order -- not a single (figure, axes) pair -- so keep the
# sequence rather than unpacking it as ``fig, ax`` (which only "works" by
# coincidence when there are exactly two tasks).
plots = metrics.plot(value)