Multi-task Wrapper
Module Interface
- class torchmetrics.wrappers.MultitaskWrapper(task_metrics, prefix=None, postfix=None)
Wrapper class for computing different metrics on different tasks in the context of multitask learning.
In multitask learning, different tasks require different metrics to be evaluated. This wrapper makes evaluation easy in such cases by accepting multiple predictions and targets through a dictionary. Note that only metrics whose update signature follows the standard preds, target convention are supported.
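To illustrate the dispatch pattern described above, here is a toy, pure-Python sketch. ToyAccuracy, ToyMSE and ToyMultitaskWrapper are hypothetical stand-ins, not torchmetrics classes and not the library's implementation; the sketch only assumes that each metric exposes update(preds, target) and compute(), as the note above requires.

```python
from typing import Any, Dict, List


class ToyAccuracy:
    """Hypothetical accumulating accuracy: fraction of exact matches."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, preds: List[int], target: List[int]) -> None:
        self.correct += sum(int(p == t) for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self) -> float:
        return self.correct / self.total


class ToyMSE:
    """Hypothetical accumulating mean squared error."""

    def __init__(self) -> None:
        self.sum_sq = 0.0
        self.total = 0

    def update(self, preds: List[float], target: List[float]) -> None:
        self.sum_sq += sum((p - t) ** 2 for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self) -> float:
        return self.sum_sq / self.total


class ToyMultitaskWrapper:
    """Hypothetical wrapper: route each task's preds/target to that task's metric."""

    def __init__(self, task_metrics: Dict[str, Any]) -> None:
        if not isinstance(task_metrics, dict):
            raise TypeError("Expected `task_metrics` to be a dict")
        self.task_metrics = task_metrics

    def update(self, task_preds: Dict[str, list], task_targets: Dict[str, list]) -> None:
        # Input keys are the task names; each metric only sees its own task's data.
        for task, metric in self.task_metrics.items():
            metric.update(task_preds[task], task_targets[task])

    def compute(self) -> Dict[str, float]:
        # Output is a dict keyed by task name, one computed value per task.
        return {task: metric.compute() for task, metric in self.task_metrics.items()}


wrapper = ToyMultitaskWrapper({"Classification": ToyAccuracy(), "Regression": ToyMSE()})
wrapper.update(
    {"Classification": [0, 0, 1], "Regression": [3.0, 5.0, 2.5]},
    {"Classification": [0, 1, 0], "Regression": [2.5, 5.0, 4.0]},
)
print({task: round(value, 4) for task, value in wrapper.compute().items()})
# {'Classification': 0.3333, 'Regression': 0.8333}
```

The data is the same as in the single-metric example for this class, so the toy values agree with the tensors reported there.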
- Parameters:
task_metrics (Dict[str, Union[Metric, MetricCollection]]) – Dictionary associating each task with a Metric or a MetricCollection. The keys of the dictionary are the task names, and the values are the metrics to use for each task.
prefix (Optional[str]) – A string to prepend to the keys of the output dict. If not provided, defaults to an empty string.
postfix (Optional[str]) – A string to append to the keys of the output dict. If not provided, defaults to an empty string.
Tip
Using prefix and postfix makes it easy to create separate task wrappers for training, validation and testing. These arguments only change the keys of the computed output, not the expected input keys. This means that a MultitaskWrapper initialized as MultitaskWrapper({"task": Metric()}, prefix="train_") will still expect its input to be a dictionary with the key “task”, but its output will be a dictionary with the key “train_task”.
- Raises:
TypeError – If argument task_metrics is not a dictionary
TypeError – If not all values in the task_metrics dictionary are instances of Metric or MetricCollection
ValueError – If prefix is not a string
ValueError – If postfix is not a string
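The prefix/postfix rules from the Tip and the ValueError conditions above can be pictured with a small pure-Python sketch of output-key renaming. normalize_affix and rename_output_keys are hypothetical helpers, not torchmetrics functions; they only model the documented behaviour: None defaults to an empty string, a non-string raises ValueError, and only the output keys change while the task names used for input stay the same.

```python
from typing import Any, Dict, Optional


def normalize_affix(name: str, value: Optional[str]) -> str:
    """Hypothetical check: None becomes "", any non-string raises ValueError."""
    if value is None:
        return ""
    if not isinstance(value, str):
        raise ValueError(f"Expected `{name}` to be None or a string, got {value!r}")
    return value


def rename_output_keys(
    results: Dict[str, Any], prefix: Optional[str] = None, postfix: Optional[str] = None
) -> Dict[str, Any]:
    """Hypothetical helper: wrap each task name as prefix + task + postfix in the output."""
    pre = normalize_affix("prefix", prefix)
    post = normalize_affix("postfix", postfix)
    return {f"{pre}{task}{post}": value for task, value in results.items()}


results = {"Classification": 0.3333, "Regression": 0.8333}
print(rename_output_keys(results, prefix="train_"))
# {'train_Classification': 0.3333, 'train_Regression': 0.8333}
print(rename_output_keys(results, prefix="val_", postfix="_epoch"))
# {'val_Classification_epoch': 0.3333, 'val_Regression_epoch': 0.8333}
```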
- Example (with a single metric per task):
>>> import torch
>>> from torchmetrics.wrappers import MultitaskWrapper
>>> from torchmetrics.regression import MeanSquaredError
>>> from torchmetrics.classification import BinaryAccuracy
>>>
>>> classification_target = torch.tensor([0, 1, 0])
>>> regression_target = torch.tensor([2.5, 5.0, 4.0])
>>> targets = {"Classification": classification_target, "Regression": regression_target}
>>>
>>> classification_preds = torch.tensor([0, 0, 1])
>>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
>>> preds = {"Classification": classification_preds, "Regression": regression_preds}
>>>
>>> metrics = MultitaskWrapper({
...     "Classification": BinaryAccuracy(),
...     "Regression": MeanSquaredError()
... })
>>> metrics.update(preds, targets)
>>> metrics.compute()
{'Classification': tensor(0.3333), 'Regression': tensor(0.8333)}
- Example (with several metrics per task):
>>> import torch
>>> from torchmetrics import MetricCollection
>>> from torchmetrics.wrappers import MultitaskWrapper
>>> from torchmetrics.regression import MeanSquaredError, MeanAbsoluteError
>>> from torchmetrics.classification import BinaryAccuracy, BinaryF1Score
>>>
>>> classification_target = torch.tensor([0, 1, 0])
>>> regression_target = torch.tensor([2.5, 5.0, 4.0])
>>> targets = {"Classification": classification_target, "Regression": regression_target}
>>>
>>> classification_preds = torch.tensor([0, 0, 1])
>>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
>>> preds = {"Classification": classification_preds, "Regression": regression_preds}
>>>
>>> metrics = MultitaskWrapper({
...     "Classification": MetricCollection(BinaryAccuracy(), BinaryF1Score()),
...     "Regression": MetricCollection(MeanSquaredError(), MeanAbsoluteError())
... })
>>> metrics.update(preds, targets)
>>> metrics.compute()
{'Classification': {'BinaryAccuracy': tensor(0.3333), 'BinaryF1Score': tensor(0.)}, 'Regression': {'MeanSquaredError': tensor(0.8333), 'MeanAbsoluteError': tensor(0.6667)}}
- Example (with a prefix and postfix):
>>> import torch
>>> from torchmetrics.wrappers import MultitaskWrapper
>>> from torchmetrics.regression import MeanSquaredError
>>> from torchmetrics.classification import BinaryAccuracy
>>>
>>> classification_target = torch.tensor([0, 1, 0])
>>> regression_target = torch.tensor([2.5, 5.0, 4.0])
>>> targets = {"Classification": classification_target, "Regression": regression_target}
>>> classification_preds = torch.tensor([0, 0, 1])
>>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
>>> preds = {"Classification": classification_preds, "Regression": regression_preds}
>>>
>>> metrics = MultitaskWrapper({
...     "Classification": BinaryAccuracy(),
...     "Regression": MeanSquaredError()
... }, prefix="train_")
>>> metrics.update(preds, targets)
>>> metrics.compute()
{'train_Classification': tensor(0.3333), 'train_Regression': tensor(0.8333)}
- forward(task_preds, task_targets)
Call underlying forward methods for all tasks and return the result as a dictionary.
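In torchmetrics, a metric's forward both updates its accumulated state and returns the value for the current batch only; the wrapper's forward collects those per-batch values into a dictionary keyed by task. The pure-Python sketch below mimics that behaviour under stated assumptions: ToyMSE and multitask_forward are hypothetical, and the forward semantics are modelled on the general torchmetrics Metric.forward contract rather than taken from the wrapper's source.

```python
from typing import Dict, List, Tuple


class ToyMSE:
    """Hypothetical accumulating MSE with assumed torchmetrics-style forward semantics."""

    def __init__(self) -> None:
        self.sum_sq = 0.0
        self.total = 0

    @staticmethod
    def _batch_stats(preds: List[float], target: List[float]) -> Tuple[float, int]:
        return sum((p - t) ** 2 for p, t in zip(preds, target)), len(target)

    def update(self, preds: List[float], target: List[float]) -> None:
        s, n = self._batch_stats(preds, target)
        self.sum_sq += s
        self.total += n

    def compute(self) -> float:
        # Value over everything seen so far.
        return self.sum_sq / self.total

    def forward(self, preds: List[float], target: List[float]) -> float:
        # Accumulate into the global state, but return the value for this batch only.
        self.update(preds, target)
        s, n = self._batch_stats(preds, target)
        return s / n


def multitask_forward(
    task_metrics: Dict[str, ToyMSE],
    task_preds: Dict[str, List[float]],
    task_targets: Dict[str, List[float]],
) -> Dict[str, float]:
    """Hypothetical helper: call each task's forward and collect results by task name."""
    return {task: m.forward(task_preds[task], task_targets[task]) for task, m in task_metrics.items()}


metrics = {"RegA": ToyMSE(), "RegB": ToyMSE()}
batch1 = multitask_forward(metrics, {"RegA": [1.0, 2.0], "RegB": [0.0]}, {"RegA": [1.0, 4.0], "RegB": [1.0]})
batch2 = multitask_forward(metrics, {"RegA": [3.0], "RegB": [2.0]}, {"RegA": [3.0], "RegB": [2.0]})
print(batch1, batch2)
# {'RegA': 2.0, 'RegB': 1.0} {'RegA': 0.0, 'RegB': 0.0}
print({task: m.compute() for task, m in metrics.items()})
# {'RegA': 1.3333333333333333, 'RegB': 0.5}
```

The per-batch dictionaries differ from the accumulated values returned by compute, which is why the plotting example further down can collect one forward result per step into a list.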
- plot(val=None, axes=None)
Plot a single or multiple values from the metric.
All tasks’ results are plotted on individual axes.
- Parameters:
val (Union[Dict, Sequence[Dict], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
axes (Optional[Sequence[Axes]]) – Sequence of matplotlib axis objects. If provided, the plots are added to these axes. If not provided, new axes are created.
- Return type:
- Returns:
Sequence of (Figure, Axes) tuples, one for each task.
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.wrappers import MultitaskWrapper
>>> from torchmetrics.regression import MeanSquaredError
>>> from torchmetrics.classification import BinaryAccuracy
>>>
>>> classification_target = torch.tensor([0, 1, 0])
>>> regression_target = torch.tensor([2.5, 5.0, 4.0])
>>> targets = {"Classification": classification_target, "Regression": regression_target}
>>>
>>> classification_preds = torch.tensor([0, 0, 1])
>>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
>>> preds = {"Classification": classification_preds, "Regression": regression_preds}
>>>
>>> metrics = MultitaskWrapper({
...     "Classification": BinaryAccuracy(),
...     "Regression": MeanSquaredError()
... })
>>> metrics.update(preds, targets)
>>> value = metrics.compute()
>>> fig_, ax_ = metrics.plot(value)
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.wrappers import MultitaskWrapper
>>> from torchmetrics.regression import MeanSquaredError
>>> from torchmetrics.classification import BinaryAccuracy
>>>
>>> classification_target = torch.tensor([0, 1, 0])
>>> regression_target = torch.tensor([2.5, 5.0, 4.0])
>>> targets = {"Classification": classification_target, "Regression": regression_target}
>>>
>>> classification_preds = torch.tensor([0, 0, 1])
>>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
>>> preds = {"Classification": classification_preds, "Regression": regression_preds}
>>>
>>> metrics = MultitaskWrapper({
...     "Classification": BinaryAccuracy(),
...     "Regression": MeanSquaredError()
... })
>>> values = []
>>> for _ in range(10):
...     values.append(metrics(preds, targets))
>>> fig_, ax_ = metrics.plot(values)
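Since plot puts each task's results on its own axes, a list of result dictionaries (one per step, as collected in the example above) has to be regrouped into one series per task. The pure-Python sketch below shows one way to picture that regrouping; per_task_series is a hypothetical helper and makes no claim about how plot is implemented internally. It accepts either a single dict or a sequence of dicts, mirroring the two accepted forms of val.

```python
from typing import Dict, List, Sequence, Union


def per_task_series(val: Union[Dict[str, float], Sequence[Dict[str, float]]]) -> Dict[str, List[float]]:
    """Hypothetical helper: regroup one or many result dicts into one list of values per task."""
    results = [val] if isinstance(val, dict) else list(val)
    series: Dict[str, List[float]] = {}
    for result in results:
        for task, value in result.items():
            # Each task gets its own series, i.e. its own axes in the plot.
            series.setdefault(task, []).append(value)
    return series


history = [
    {"Classification": 0.33, "Regression": 0.83},
    {"Classification": 0.50, "Regression": 0.60},
]
print(per_task_series(history))
# {'Classification': [0.33, 0.5], 'Regression': [0.83, 0.6]}
print(per_task_series(history[0]))
# {'Classification': [0.33], 'Regression': [0.83]}
```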