Multi-task Wrapper

Module Interface

class torchmetrics.wrappers.MultitaskWrapper(task_metrics, prefix=None, postfix=None)

Wrapper class for computing different metrics on different tasks in the context of multitask learning.

In multitask learning, different tasks require different metrics to be evaluated. This wrapper makes such evaluation easy by accepting multiple predictions and targets through dictionaries keyed by task name. Note that only metrics whose update signature follows the standard (preds, target) convention are supported.

Parameters:
  • task_metrics (Dict[str, Union[Metric, MetricCollection]]) – Dictionary associating each task to a Metric or a MetricCollection. The keys of the dictionary represent the names of the tasks, and the values represent the metrics to use for each task.

  • prefix (Optional[str]) – A string to prepend to the metric keys. If not provided, will default to an empty string.

  • postfix (Optional[str]) – A string to append after the keys of the output dict. If not provided, will default to an empty string.

Note

Using prefix and postfix makes it easy to create separate task wrappers for training, validation and testing. These arguments only change the output keys of the computed metrics, not the input keys. This means that a MultitaskWrapper initialized as MultitaskWrapper({"task": Metric()}, prefix="train_") will still expect the input to be a dictionary with the key “task”, but the output will be a dictionary with the key “train_task”.

Raises:
  • TypeError – If argument task_metrics is not a dictionary

  • TypeError – If not all values in the task_metrics dictionary are instances of Metric or MetricCollection

  • ValueError – If prefix is not a string

  • ValueError – If postfix is not a string

Example (with a single metric per task):
>>> import torch
>>> from torchmetrics.wrappers import MultitaskWrapper
>>> from torchmetrics.regression import MeanSquaredError
>>> from torchmetrics.classification import BinaryAccuracy
>>>
>>> classification_target = torch.tensor([0, 1, 0])
>>> regression_target = torch.tensor([2.5, 5.0, 4.0])
>>> targets = {"Classification": classification_target, "Regression": regression_target}
>>>
>>> classification_preds = torch.tensor([0, 0, 1])
>>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
>>> preds = {"Classification": classification_preds, "Regression": regression_preds}
>>>
>>> metrics = MultitaskWrapper({
...     "Classification": BinaryAccuracy(),
...     "Regression": MeanSquaredError()
... })
>>> metrics.update(preds, targets)
>>> metrics.compute()
{'Classification': tensor(0.3333), 'Regression': tensor(0.8333)}
Example (with several metrics per task):
>>> import torch
>>> from torchmetrics import MetricCollection
>>> from torchmetrics.wrappers import MultitaskWrapper
>>> from torchmetrics.regression import MeanSquaredError, MeanAbsoluteError
>>> from torchmetrics.classification import BinaryAccuracy, BinaryF1Score
>>>
>>> classification_target = torch.tensor([0, 1, 0])
>>> regression_target = torch.tensor([2.5, 5.0, 4.0])
>>> targets = {"Classification": classification_target, "Regression": regression_target}
>>>
>>> classification_preds = torch.tensor([0, 0, 1])
>>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
>>> preds = {"Classification": classification_preds, "Regression": regression_preds}
>>>
>>> metrics = MultitaskWrapper({
...     "Classification": MetricCollection(BinaryAccuracy(), BinaryF1Score()),
...     "Regression": MetricCollection(MeanSquaredError(), MeanAbsoluteError())
... })
>>> metrics.update(preds, targets)
>>> metrics.compute()
{'Classification': {'BinaryAccuracy': tensor(0.3333), 'BinaryF1Score': tensor(0.)},
 'Regression': {'MeanSquaredError': tensor(0.8333), 'MeanAbsoluteError': tensor(0.6667)}}
Example (with a prefix):
>>> import torch
>>> from torchmetrics.wrappers import MultitaskWrapper
>>> from torchmetrics.regression import MeanSquaredError
>>> from torchmetrics.classification import BinaryAccuracy
>>>
>>> classification_target = torch.tensor([0, 1, 0])
>>> regression_target = torch.tensor([2.5, 5.0, 4.0])
>>> targets = {"Classification": classification_target, "Regression": regression_target}
>>> classification_preds = torch.tensor([0, 0, 1])
>>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
>>> preds = {"Classification": classification_preds, "Regression": regression_preds}
>>>
>>> metrics = MultitaskWrapper({
...     "Classification": BinaryAccuracy(),
...     "Regression": MeanSquaredError()
... }, prefix="train_")
>>> metrics.update(preds, targets)
>>> metrics.compute()
{'train_Classification': tensor(0.3333), 'train_Regression': tensor(0.8333)}
clone(prefix=None, postfix=None)

Make a copy of the metric.

Parameters:
  • prefix (Optional[str]) – a string to prepend to the metric keys.

  • postfix (Optional[str]) – a string to append after the keys of the output dict.

Return type:

MultitaskWrapper

compute()

Compute metrics for all tasks.

Return type:

Dict[str, Any]
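When a task is backed by a MetricCollection, compute returns a nested dictionary, which some loggers do not accept. The helper below is hypothetical (it is not part of torchmetrics) and flattens one level of nesting into "task/metric" keys. The input is a stand-in for compute() output, mixing a MetricCollection task and a single-Metric task, with plain floats taken from the examples on this page instead of tensors.

```python
from typing import Any, Dict


def flatten_results(results: Dict[str, Any], sep: str = "/") -> Dict[str, Any]:
    """Hypothetical helper: turn {task: {metric: value}} into {"task/metric": value}."""
    flat: Dict[str, Any] = {}
    for task, value in results.items():
        if isinstance(value, dict):
            # A task backed by a MetricCollection returns a dict of sub-results.
            for metric_name, metric_value in value.items():
                flat[f"{task}{sep}{metric_name}"] = metric_value
        else:
            # A task backed by a single Metric returns its value directly.
            flat[task] = value
    return flat


results = {
    "Classification": {"BinaryAccuracy": 0.3333, "BinaryF1Score": 0.0},
    "Regression": 0.8333,
}
print(flatten_results(results))
# {'Classification/BinaryAccuracy': 0.3333, 'Classification/BinaryF1Score': 0.0, 'Regression': 0.8333}
```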

forward(task_preds, task_targets)

Call underlying forward methods for all tasks and return the result as a dictionary.

Return type:

Dict[str, Any]

items(flatten=True)

Iterate over task names and their metrics.

Parameters:

flatten (bool) – If True, will iterate over all sub-metrics in the case of a MetricCollection. If False, will iterate over the task names and the corresponding metrics.

Return type:

Iterable[Tuple[str, Module]]

keys(flatten=True)

Iterate over task names.

Parameters:

flatten (bool) – If True, will iterate over the names of all sub-metrics in the case of a MetricCollection. If False, will iterate over the task names only.

Return type:

Iterable[str]

plot(val=None, axes=None)

Plot a single or multiple values from the metric.

All tasks’ results are plotted on individual axes.

Parameters:
  • val (Union[Dict, Sequence[Dict], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • axes (Optional[Sequence[Axes]]) – Sequence of matplotlib axis objects. If provided, will add the plots to the provided axis objects. If not provided, will create them.

Return type:

Sequence[Tuple[Figure, Union[Axes, ndarray]]]

Returns:

Sequence of tuples with Figure and Axes object for each task.

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.wrappers import MultitaskWrapper
>>> from torchmetrics.regression import MeanSquaredError
>>> from torchmetrics.classification import BinaryAccuracy
>>>
>>> classification_target = torch.tensor([0, 1, 0])
>>> regression_target = torch.tensor([2.5, 5.0, 4.0])
>>> targets = {"Classification": classification_target, "Regression": regression_target}
>>>
>>> classification_preds = torch.tensor([0, 0, 1])
>>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
>>> preds = {"Classification": classification_preds, "Regression": regression_preds}
>>>
>>> metrics = MultitaskWrapper({
...     "Classification": BinaryAccuracy(),
...     "Regression": MeanSquaredError()
... })
>>> metrics.update(preds, targets)
>>> value = metrics.compute()
>>> fig_axs = metrics.plot(value)  # one (Figure, Axes) tuple per task
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.wrappers import MultitaskWrapper
>>> from torchmetrics.regression import MeanSquaredError
>>> from torchmetrics.classification import BinaryAccuracy
>>>
>>> classification_target = torch.tensor([0, 1, 0])
>>> regression_target = torch.tensor([2.5, 5.0, 4.0])
>>> targets = {"Classification": classification_target, "Regression": regression_target}
>>>
>>> classification_preds = torch.tensor([0, 0, 1])
>>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
>>> preds = {"Classification": classification_preds, "Regression": regression_preds}
>>>
>>> metrics = MultitaskWrapper({
...     "Classification": BinaryAccuracy(),
...     "Regression": MeanSquaredError()
... })
>>> values = []
>>> for _ in range(10):
...     values.append(metrics(preds, targets))
>>> fig_axs = metrics.plot(values)  # one (Figure, Axes) tuple per task
reset()

Reset all underlying metrics.

Return type:

None

update(task_preds, task_targets)

Update each task’s metric with its corresponding pred and target.

Parameters:
  • task_preds (Dict[str, Any]) – Dictionary mapping each task name to its predictions tensor.

  • task_targets (Dict[str, Any]) – Dictionary mapping each task name to its targets tensor.

Return type:

None

values(flatten=True)

Iterate over task metrics.

Parameters:

flatten (bool) – If True, will iterate over all sub-metrics in the case of a MetricCollection. If False, will iterate over the task metrics themselves (one Metric or MetricCollection per task).

Return type:

Iterable[Module]