Bootstrapper
Module Interface
- class torchmetrics.wrappers.BootStrapper(base_metric, num_bootstraps=10, mean=True, std=True, quantile=None, raw=False, sampling_strategy='poisson', **kwargs)
Turn a metric into a bootstrapped metric.
This automates the process of computing confidence intervals for metric values. The wrapper keeps multiple copies of the same base metric in memory, and whenever update or forward is called, all input tensors are resampled (with replacement) along the first dimension.
- Parameters:
  - base_metric (Metric) – base metric to wrap
  - num_bootstraps (int) – number of copies of the base metric to make for bootstrapping
  - mean (bool) – if True, return the mean of the bootstraps
  - std (bool) – if True, return the standard deviation of the bootstraps
  - quantile (Union[float, Tensor, None]) – if given, return the quantile of the bootstraps. Can only be used with PyTorch version 1.6 or higher
  - raw (bool) – if True, return all bootstrapped values
  - sampling_strategy (str) – determines how bootstrapped samples are produced. Either 'poisson' or 'multinomial'. If 'poisson' is chosen, the number of times each sample is included in the bootstrap is drawn as \(n \sim \mathrm{Poisson}(\lambda = 1)\), which approximates the true bootstrap distribution when the number of samples is large. If 'multinomial' is chosen, true bootstrapping is applied at the batch level to approximate bootstrapping over the whole dataset.
  - kwargs (Any) – additional keyword arguments, see Advanced metric settings for more info.
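The two sampling strategies can be sketched in plain Python. This is a framework-free illustration under stated assumptions, not torchmetrics' own code: the helper names `poisson1` and `resample_indices` are hypothetical, and the Poisson draw uses Knuth's multiplication method because the standard library has no Poisson sampler. Each call returns the row indices for one bootstrap replicate of a batch of `n` samples.

```python
import math
import random
from typing import List


def poisson1(rng: random.Random) -> int:
    """Hypothetical helper: draw k ~ Poisson(lambda=1) with Knuth's multiplication method."""
    limit = math.exp(-1.0)
    k, p = 0, 1.0
    while True:
        p *= rng.random()
        if p <= limit:
            return k
        k += 1


def resample_indices(n: int, strategy: str, rng: random.Random) -> List[int]:
    """Hypothetical helper: row indices for one bootstrap replicate of a batch of size n."""
    if strategy == "poisson":
        # each row i is repeated Poisson(1) times, so the replicate length varies around n
        idx: List[int] = []
        for i in range(n):
            idx.extend([i] * poisson1(rng))
        return idx
    if strategy == "multinomial":
        # draw exactly n rows uniformly at random, with replacement
        return [rng.randrange(n) for _ in range(n)]
    raise ValueError(f"unknown sampling strategy: {strategy!r}")


# usage: resample a toy batch once with each strategy
rng = random.Random(0)
batch = [10, 20, 30, 40, 50]
for strategy in ("poisson", "multinomial"):
    idx = resample_indices(len(batch), strategy, rng)
    print(strategy, [batch[i] for i in idx])
```

With 'poisson', the replicate length itself varies from batch to batch, which is why it only approximates the classical bootstrap for large batches; 'multinomial' always returns exactly `n` rows.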
- Example::
>>> from pprint import pprint
>>> from torch import randint
>>> from torchmetrics.wrappers import BootStrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> base_metric = MulticlassAccuracy(num_classes=5, average='micro')
>>> bootstrap = BootStrapper(base_metric, num_bootstraps=20)
>>> bootstrap.update(randint(5, (20,)), randint(5, (20,)))
>>> output = bootstrap.compute()
>>> # the inputs are random, so the exact numbers below vary between runs
>>> pprint(output)
{'mean': tensor(0.2089), 'std': tensor(0.0772)}
- compute()
Compute the bootstrapped metric values.
Always returns a dict of tensors, which can contain the following keys: mean, std, quantile and raw, depending on how the class was initialized.
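To make the keys of the returned dict concrete, here is a minimal sketch of the wrapper idea in plain Python, using a hypothetical `Accuracy` stand-in for the base metric. `ToyBootStrapper` and `linear_quantile` are illustrative names, not part of torchmetrics. The sketch uses 'multinomial'-style resampling only, and the sample (n-1) standard deviation and linear-interpolation quantile are assumptions about how the aggregates are computed.

```python
import copy
import random
import statistics
from typing import Dict, List, Optional, Sequence


class Accuracy:
    """Hypothetical minimal base metric with update/compute."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, preds: Sequence[int], target: Sequence[int]) -> None:
        self.correct += sum(int(p == t) for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self) -> float:
        return self.correct / self.total if self.total else 0.0


def linear_quantile(values: List[float], q: float) -> float:
    """Quantile with linear interpolation between the two nearest order statistics."""
    s = sorted(values)
    pos = q * (len(s) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(s) - 1)
    return s[lo] + (s[hi] - s[lo]) * (pos - lo)


class ToyBootStrapper:
    """Hypothetical sketch of the wrapper idea; not torchmetrics' implementation."""

    def __init__(self, base_metric, num_bootstraps: int = 10, mean: bool = True,
                 std: bool = True, quantile: Optional[float] = None,
                 raw: bool = False, seed: int = 0) -> None:
        # one independent copy of the base metric per bootstrap replicate
        self.metrics = [copy.deepcopy(base_metric) for _ in range(num_bootstraps)]
        self.mean, self.std, self.quantile, self.raw = mean, std, quantile, raw
        self.rng = random.Random(seed)

    def update(self, preds: Sequence[int], target: Sequence[int]) -> None:
        n = len(target)
        for m in self.metrics:
            # resample n rows with replacement, separately for every copy
            idx = [self.rng.randrange(n) for _ in range(n)]
            m.update([preds[i] for i in idx], [target[i] for i in idx])

    def compute(self) -> Dict[str, object]:
        vals = [m.compute() for m in self.metrics]
        out: Dict[str, object] = {}
        if self.mean:
            out["mean"] = statistics.mean(vals)
        if self.std:
            out["std"] = statistics.stdev(vals)  # sample (n-1) standard deviation
        if self.quantile is not None:
            out["quantile"] = linear_quantile(vals, self.quantile)
        if self.raw:
            out["raw"] = vals
        return out


# usage: only the keys enabled at construction appear in the result
preds = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
target = [0, 1, 2, 3, 0, 0, 1, 1, 3, 4]
boot = ToyBootStrapper(Accuracy(), num_bootstraps=20, quantile=0.95, raw=True)
boot.update(preds, target)
print(boot.compute())
```

Keeping whole metric copies, rather than storing resampled inputs, means memory grows with `num_bootstraps` times the size of the metric state, not with the size of the dataset.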
- forward(*args, **kwargs)
Use the original forward method of the base metric class.
- plot(val=None, ax=None)
Plot a single value or multiple values from the metric.
- Parameters:
  - val (Union[Tensor, Sequence[Tensor], None]) – either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
  - ax (Optional[Axes]) – a matplotlib axis object. If provided, the plot is added to that axis
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.wrappers import BootStrapper
>>> from torchmetrics.regression import MeanSquaredError
>>> metric = BootStrapper(MeanSquaredError(), num_bootstraps=20)
>>> metric.update(torch.randn(100,), torch.randn(100,))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.wrappers import BootStrapper
>>> from torchmetrics.regression import MeanSquaredError
>>> metric = BootStrapper(MeanSquaredError(), num_bootstraps=20)
>>> values = []
>>> for _ in range(3):
...     values.append(metric(torch.randn(100,), torch.randn(100,)))
>>> fig_, ax_ = metric.plot(values)