Transformations

Transformations modify the inputs a metric receives by wrapping its pred and target arguments. A transformation can be implemented either by subclassing the MetricInputTransformer base class and overriding its .transform_pred() and/or .transform_target() methods, or by passing a callable (for example, a lambda) to LambdaInputTransformer. For convenience, a BinaryTargetTransformer is also provided, which casts target labels to 0/1 given a threshold.

Module Interface

class torchmetrics.wrappers.MetricInputTransformer(wrapped_metric, **kwargs)

Abstract base class for metric input transformations.

An input transformer applies a transformation to the metric's input data and then forwards all calls to the wrapped metric with the transformed inputs.
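
The forwarding pattern can be modelled without torchmetrics at all. The dependency-free sketch below assumes, consistent with the examples on this page, that the first positional argument is the prediction, the second is the target, and any further arguments (such as indexes=...) pass through untouched. ToyMetric, ToyInputTransformer and RoundPreds are hypothetical stand-ins, not torchmetrics code.

```python
from typing import Any


class ToyMetric:
    """Hypothetical metric: fraction of (pred, target) pairs that are equal."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, pred: list, target: list) -> None:
        self.correct += sum(p == t for p, t in zip(pred, target))
        self.total += len(target)

    def compute(self) -> float:
        return self.correct / self.total


class ToyInputTransformer:
    """Hypothetical wrapper mimicking the transform-then-forward pattern."""

    def __init__(self, wrapped_metric: ToyMetric) -> None:
        self.wrapped_metric = wrapped_metric

    def transform_pred(self, pred: list) -> list:
        return pred  # identity by default

    def transform_target(self, target: list) -> list:
        return target  # identity by default

    def _transform_args(self, *args: Any) -> tuple:
        # 1st positional arg = pred, 2nd = target, anything after is untouched
        if not args:
            return args
        if len(args) == 1:
            return (self.transform_pred(args[0]),)
        return (self.transform_pred(args[0]), self.transform_target(args[1]), *args[2:])

    def update(self, *args: Any, **kwargs: Any) -> None:
        # keyword arguments are forwarded without transformation
        self.wrapped_metric.update(*self._transform_args(*args), **kwargs)

    def compute(self) -> Any:
        # nothing to transform at compute time; delegate directly
        return self.wrapped_metric.compute()


class RoundPreds(ToyInputTransformer):
    """Hypothetical subclass: rounds probabilities to 0/1 labels."""

    def transform_pred(self, pred: list) -> list:
        return [round(p) for p in pred]


wrapped = RoundPreds(ToyMetric())
wrapped.update([0.9, 0.2, 0.6], [1, 0, 0])
print(wrapped.compute())  # 0.6666666666666666 (rounded preds [1, 0, 1] vs targets [1, 0, 0])
```

The real wrapper also exposes forward(); it is left out of this toy model to keep the sketch short.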

compute()

Wrap the compute call of the underlying metric.

Return type:

Any

forward(*args, **kwargs)

Wrap the forward call of the underlying metric.

Return type:

Any

transform_pred(pred)

Define transform operations on the prediction data.

Overridden by subclasses. Identity by default.

Return type:

Tensor

transform_target(target)

Define transform operations on the target data.

Overridden by subclasses. Identity by default.

Return type:

Tensor

update(*args, **kwargs)

Wrap the update call of the underlying metric.

Return type:

None

class torchmetrics.wrappers.LambdaInputTransformer(wrapped_metric, transform_pred=None, transform_target=None, **kwargs)

Wrapper class for transforming a metric's inputs with user-defined functions (for example, lambdas).

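Parameters:
  • wrapped_metric (Union[Metric, MetricCollection]) – The underlying Metric or MetricCollection.

  • transform_pred (Optional[Callable[[Tensor], Tensor]]) – The function applied to the predictions before they are passed to the wrapped metric. If None, predictions are passed unchanged.

  • transform_target (Optional[Callable[[Tensor], Tensor]]) – The function applied to the targets before they are passed to the wrapped metric. If None, targets are passed unchanged.

Raises:
  • TypeError – If transform_pred is not a Callable.

  • TypeError – If transform_target is not a Callable.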

Example

>>> import torch
>>> from torchmetrics.classification import BinaryAccuracy
>>> from torchmetrics.wrappers import LambdaInputTransformer
>>>
>>> preds = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5, 0.6, 0.7, 0.8, 0.5, 0.4])
>>> targets = torch.tensor([1, 0, 0, 0, 0, 1, 1, 0, 0, 0])
>>>
>>> metric = LambdaInputTransformer(BinaryAccuracy(), lambda preds: 1 - preds)
>>> metric.update(preds, targets)
>>> metric.compute()
tensor(0.6000)
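
Assuming wrapped_metric may also be a MetricCollection (as the Union[Metric, MetricCollection] type documented for BinaryTargetTransformer suggests), a single transform can feed several metrics at once. In this sketch transform_pred is passed by keyword, and the data are made up for illustration.

```python
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision
from torchmetrics.wrappers import LambdaInputTransformer

# One transform shared by every metric in the collection
metrics = LambdaInputTransformer(
    MetricCollection([BinaryAccuracy(), BinaryPrecision()]),
    transform_pred=lambda preds: 1 - preds,
)

preds = torch.tensor([0.2, 0.9, 0.4, 0.1])  # 1 - preds -> [0.8, 0.1, 0.6, 0.9]
targets = torch.tensor([1, 0, 0, 1])

metrics.update(preds, targets)
results = metrics.compute()  # dict keyed by metric class name
# Thresholded at 0.5 the transformed preds are [1, 0, 1, 1]:
# accuracy = 3/4, precision = 2/3
print(results)
```
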

class torchmetrics.wrappers.BinaryTargetTransformer(wrapped_metric, threshold=0, **kwargs)

Wrapper class for computing a metric on binarized targets.

Useful when the given ground-truth targets are continuous, but the metric requires binary targets.

Parameters:
  • wrapped_metric (Union[Metric, MetricCollection]) – The underlying Metric or MetricCollection.

  • threshold (float) – The binarization threshold for the targets. Target values t are cast to binary via t > threshold.

Raises:

TypeError – If threshold is not an int or float.

Example

>>> import torch
>>> from torchmetrics.retrieval import RetrievalMRR
>>> from torchmetrics.wrappers import BinaryTargetTransformer
>>>
>>> preds = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5, 0.6, 0.7, 0.8, 0.5, 0.4])
>>> targets = torch.tensor([1, 0, 0, 0, 0, 2, 1, 0, 0, 0])
>>> topics = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
>>>
>>> metric = BinaryTargetTransformer(RetrievalMRR())
>>> metric.update(preds, targets, indexes=topics)
>>> metric.compute()
tensor(0.7500)

transform_target(target)

Cast the target tensor to binary values according to the threshold.

The output has the same dtype as the input.

Return type:

Tensor