Transformations
Transformations modify the input a metric receives by wrapping its pred and target arguments.
A transformation can be implemented either by subclassing the MetricInputTransformer base class and overriding its transform_pred() and/or transform_target() methods, or by supplying a function (for example, a lambda) to the LambdaInputTransformer.
For convenience, a BinaryTargetTransformer is also provided, which casts target labels to 0/1 given a threshold.
Module Interface
- class torchmetrics.wrappers.MetricInputTransformer(wrapped_metric, **kwargs)
Abstract base class for metric input transformations.
An input transformation applies a transformation to the metric's input data and then forwards all calls to the wrapped metric with the modified inputs.
- transform_pred(pred)
Define transform operations on the prediction data.
Overridden by subclasses. Identity by default.
- Return type: Tensor
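The forwarding behaviour described above can be modelled in a few lines of plain Python. This is a simplified, dependency-free sketch, not torchmetrics' implementation: ToyInputTransformer, RecordingMetric and DoublePreds are hypothetical names, the wrapped object is assumed to expose update() and compute(), and the rule that extra positional and keyword arguments pass through unchanged is an assumption of the sketch. The subclass overrides only transform_pred(), so targets keep the identity default.

```python
class ToyInputTransformer:
    """Hypothetical stand-in for an input-transforming wrapper; not torchmetrics code."""

    def __init__(self, wrapped_metric):
        self.wrapped_metric = wrapped_metric

    def transform_pred(self, pred):
        return pred  # identity by default

    def transform_target(self, target):
        return target  # identity by default

    def _wrap_transform(self, *args):
        # First positional arg -> transform_pred, second -> transform_target,
        # anything after that is forwarded unchanged (assumption of this sketch).
        if not args:
            return args
        if len(args) == 1:
            return (self.transform_pred(args[0]),)
        return (self.transform_pred(args[0]), self.transform_target(args[1]), *args[2:])

    def update(self, *args, **kwargs):
        # Transform the inputs, then forward the call to the wrapped metric.
        self.wrapped_metric.update(*self._wrap_transform(*args), **kwargs)

    def compute(self):
        # compute() is forwarded as-is; no transformation happens here.
        return self.wrapped_metric.compute()


class RecordingMetric:
    """Hypothetical metric that simply records the inputs it receives."""

    def __init__(self):
        self.calls = []

    def update(self, *args, **kwargs):
        self.calls.append((args, kwargs))

    def compute(self):
        return self.calls


class DoublePreds(ToyInputTransformer):
    """Overrides only transform_pred; transform_target keeps the identity default."""

    def transform_pred(self, pred):
        return [2 * p for p in pred]


metric = DoublePreds(RecordingMetric())
metric.update([1, 2], [0, 1], indexes=[0, 0])
print(metric.compute())  # [(([2, 4], [0, 1]), {'indexes': [0, 0]})]
```

With torchmetrics installed, the same override would go on a subclass of torchmetrics.wrappers.MetricInputTransformer wrapping a real Metric or MetricCollection.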
- class torchmetrics.wrappers.LambdaInputTransformer(wrapped_metric, transform_pred=None, transform_target=None, **kwargs)
Wrapper class for transforming a metric's inputs with user-defined functions (for example, lambdas).
- Parameters:
wrapped_metric (Metric) – The underlying Metric or MetricCollection.
transform_pred (Optional[Callable[[Tensor], Tensor]]) – The function to apply to the predictions before computing the metric.
transform_target (Optional[Callable[[Tensor], Tensor]]) – The function to apply to the target before computing the metric.
- Raises:
TypeError – If transform_pred or transform_target is given but is not callable.
Example
>>> import torch
>>> from torchmetrics.classification import BinaryAccuracy
>>> from torchmetrics.wrappers import LambdaInputTransformer
>>>
>>> preds = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5, 0.6, 0.7, 0.8, 0.5, 0.4])
>>> targets = torch.tensor([1, 0, 0, 0, 0, 1, 1, 0, 0, 0])
>>>
>>> metric = LambdaInputTransformer(BinaryAccuracy(), lambda preds: 1 - preds)
>>> metric.update(preds, targets)
>>> metric.compute()
tensor(0.6000)
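To see where tensor(0.6000) comes from, the example's arithmetic can be retraced in plain Python. This sketch does not call torchmetrics; it assumes BinaryAccuracy's default threshold of 0.5 and that a prediction counts as positive only when it is strictly greater than that threshold.

```python
preds = [0.9, 0.8, 0.7, 0.6, 0.5, 0.6, 0.7, 0.8, 0.5, 0.4]
targets = [1, 0, 0, 0, 0, 1, 1, 0, 0, 0]
threshold = 0.5  # assumed default threshold of BinaryAccuracy

# What the lambda does to the predictions before the metric sees them.
transformed = [1 - p for p in preds]

# Binarize: positive only if strictly above the threshold (assumption).
labels = [int(p > threshold) for p in transformed]

# Accuracy = fraction of positions where label and target agree.
accuracy = sum(l == t for l, t in zip(labels, targets)) / len(targets)
print(labels)    # [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
print(accuracy)  # 0.6
```

Only the last transformed prediction (1 - 0.4 = 0.6) clears the threshold, so 6 of the 10 labels match their targets.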
- class torchmetrics.wrappers.BinaryTargetTransformer(wrapped_metric, threshold=0, **kwargs)
Wrapper class for computing a metric on binarized targets.
Useful when the given ground-truth targets are continuous, but the metric requires binary targets.
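- Parameters:
wrapped_metric (Metric) – The underlying Metric or MetricCollection.
threshold (float) – The binarization threshold for the targets. A target value t is cast to binary as t > threshold, so with the default threshold=0 every positive target becomes 1.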
- Raises:
TypeError – If threshold is not an int or float.
Example
>>> import torch
>>> from torchmetrics.retrieval import RetrievalMRR
>>> from torchmetrics.wrappers import BinaryTargetTransformer
>>>
>>> preds = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5, 0.6, 0.7, 0.8, 0.5, 0.4])
>>> targets = torch.tensor([1, 0, 0, 0, 0, 2, 1, 0, 0, 0])
>>> topics = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
>>>
>>> metric = BinaryTargetTransformer(RetrievalMRR())
>>> metric.update(preds, targets, indexes=topics)
>>> metric.compute()
tensor(0.7500)
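The tensor(0.7500) result can be reproduced by hand. The plain-Python sketch below assumes that a target t is kept as relevant when t > threshold (threshold defaults to 0), and that reciprocal rank is 1 divided by the position of the first relevant document after sorting a query's predictions in descending order, averaged over queries. The helper reciprocal_rank is hypothetical, and returning 0.0 for a query with no relevant document is a choice made only for this sketch.

```python
from collections import defaultdict

preds = [0.9, 0.8, 0.7, 0.6, 0.5, 0.6, 0.7, 0.8, 0.5, 0.4]
targets = [1, 0, 0, 0, 0, 2, 1, 0, 0, 0]
topics = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
threshold = 0

# Binarize the graded targets: 2 and 1 both become 1, 0 stays 0.
binary = [int(t > threshold) for t in targets]


def reciprocal_rank(pairs):
    """Hypothetical helper: 1/rank of the first relevant item, 0.0 if none."""
    ranked = sorted(pairs, key=lambda pair: pair[0], reverse=True)
    for rank, (_, relevant) in enumerate(ranked, start=1):
        if relevant:
            return 1.0 / rank
    return 0.0


# Group (pred, relevance) pairs by query index.
groups = defaultdict(list)
for p, r, q in zip(preds, binary, topics):
    groups[q].append((p, r))

scores = [reciprocal_rank(groups[q]) for q in sorted(groups)]
mrr = sum(scores) / len(scores)
print(binary)  # [1, 0, 0, 0, 0, 1, 1, 0, 0, 0]
print(scores)  # [1.0, 0.5]
print(mrr)     # 0.75
```

In topic 0 the top-scored document (0.9) is relevant, giving 1.0; in topic 1 the top-scored document (0.8) is not, but the second (0.7) is, giving 0.5.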