Min / Max
Module Interface
- class torchmetrics.wrappers.MinMaxMetric(base_metric, **kwargs)
Wrapper metric that tracks both the minimum and maximum of a scalar/tensor across an experiment.
The min/max values are updated each time compute is called.
- Parameters:
base_metric (Metric) – The metric whose minimum and maximum values should be tracked.
kwargs (Any) – Additional keyword arguments; see Advanced metric settings for more info.
- Raises:
ValueError – If the base_metric argument is not an instance of a torchmetrics.Metric subclass.
- Example:
>>> import torch
>>> from torchmetrics.wrappers import MinMaxMetric
>>> from torchmetrics.classification import BinaryAccuracy
>>> from pprint import pprint
>>> base_metric = BinaryAccuracy()
>>> minmax_metric = MinMaxMetric(base_metric)
>>> preds_1 = torch.Tensor([[0.1, 0.9], [0.2, 0.8]])
>>> preds_2 = torch.Tensor([[0.9, 0.1], [0.2, 0.8]])
>>> labels = torch.Tensor([[0, 1], [0, 1]]).long()
>>> pprint(minmax_metric(preds_1, labels))
{'max': tensor(1.), 'min': tensor(1.), 'raw': tensor(1.)}
>>> pprint(minmax_metric.compute())
{'max': tensor(1.), 'min': tensor(1.), 'raw': tensor(1.)}
>>> minmax_metric.update(preds_2, labels)
>>> pprint(minmax_metric.compute())
{'max': tensor(1.), 'min': tensor(0.7500), 'raw': tensor(0.7500)}
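The 0.7500 in the last output can be checked by hand. The documented value matches a base metric that keeps accumulating across updates: after the second update, accuracy is taken over all eight label positions seen so far (6 correct), not just the second batch (2 of 4 correct, i.e. 0.5). The plain-Python sketch below redoes that arithmetic. It assumes BinaryAccuracy's default decision threshold of 0.5 and counts every matrix entry as one binary prediction; the helper name count_correct is hypothetical.

```python
# Hypothetical helper; this reproduces only the arithmetic, not torchmetrics itself.
THRESHOLD = 0.5  # assumed default decision threshold of BinaryAccuracy


def count_correct(preds, labels, threshold=THRESHOLD):
    """Return (number correct, number seen) for nested lists of probabilities and 0/1 labels."""
    correct = total = 0
    for pred_row, label_row in zip(preds, labels):
        for p, y in zip(pred_row, label_row):
            # A prediction is correct when the thresholded probability matches the label.
            correct += int((p > threshold) == bool(y))
            total += 1
    return correct, total


preds_1 = [[0.1, 0.9], [0.2, 0.8]]
preds_2 = [[0.9, 0.1], [0.2, 0.8]]
labels = [[0, 1], [0, 1]]

c1, n1 = count_correct(preds_1, labels)  # 4 of 4 correct
c2, n2 = count_correct(preds_2, labels)  # 2 of 4 correct

after_first = c1 / n1                # 1.0: raw, min and max after the first compute
batch_2_only = c2 / n2               # 0.5: NOT what the wrapper reports
accumulated = (c1 + c2) / (n1 + n2)  # 0.75: raw after the second compute

tracked_min = min(after_first, accumulated)
tracked_max = max(after_first, accumulated)
print(tracked_min, tracked_max, accumulated)
```

Because min and max only move when compute runs, the 1.0 from the first compute survives as max while min drops to the accumulated 0.75.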
- compute()
Compute the underlying metric and update the tracked minimum and maximum values.
Returns a dictionary with the computed value (raw) as well as the tracked minimum (min) and maximum (max) values.
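The bookkeeping that compute describes can be pictured with a small framework-free sketch. MinMaxTracker and RunningMean are hypothetical names, not torchmetrics classes. The tracker accepts any object with update and compute methods (the real wrapper instead requires a torchmetrics.Metric instance), and plain floats stand in for tensors.

```python
import math


class RunningMean:
    """Hypothetical stateful metric: the mean of every value passed to update() so far."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.total += value
        self.count += 1

    def compute(self):
        return self.total / self.count


class MinMaxTracker:
    """Hypothetical sketch of the min/max bookkeeping; not the torchmetrics implementation."""

    def __init__(self, base_metric):
        # Duck-typed stand-in for the documented ValueError check.
        if not (callable(getattr(base_metric, "update", None))
                and callable(getattr(base_metric, "compute", None))):
            raise ValueError("base_metric must provide update() and compute()")
        self.base_metric = base_metric
        self.min_val = math.inf   # any real result is smaller, so the first compute sets it
        self.max_val = -math.inf  # any real result is larger, so the first compute sets it

    def update(self, *args, **kwargs):
        # Only the wrapped metric accumulates data; min/max are not touched here.
        self.base_metric.update(*args, **kwargs)

    def compute(self):
        raw = self.base_metric.compute()
        # The tracked extremes move only when compute() is called.
        self.min_val = min(self.min_val, raw)
        self.max_val = max(self.max_val, raw)
        return {"max": self.max_val, "min": self.min_val, "raw": raw}


tracker = MinMaxTracker(RunningMean())
tracker.update(4.0)
r1 = tracker.compute()  # {'max': 4.0, 'min': 4.0, 'raw': 4.0}
tracker.update(0.0)
r2 = tracker.compute()  # {'max': 4.0, 'min': 2.0, 'raw': 2.0}
tracker.update(8.0)
tracker.update(8.0)
r3 = tracker.compute()  # {'max': 5.0, 'min': 2.0, 'raw': 5.0}
print(r1, r2, r3)
```

Starting min_val at +inf and max_val at -inf means the first compute sets both to raw, which is why min, max and raw coincide after the first compute in the documented example.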
- forward(*args, **kwargs)
Use the original forward method of the base metric class.
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.wrappers import MinMaxMetric
>>> from torchmetrics.classification import BinaryAccuracy
>>> metric = MinMaxMetric(BinaryAccuracy())
>>> metric.update(torch.randint(2, (20,)), torch.randint(2, (20,)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.wrappers import MinMaxMetric
>>> from torchmetrics.classification import BinaryAccuracy
>>> metric = MinMaxMetric(BinaryAccuracy())
>>> values = []
>>> for _ in range(3):
...     values.append(metric(torch.randint(2, (20,)), torch.randint(2, (20,))))
>>> fig_, ax_ = metric.plot(values)