Total Variation (TV)
Module Interface
- class torchmetrics.image.TotalVariation(reduction='sum', **kwargs)
Compute Total Variation loss (TV).
As input to forward and update the metric accepts the following input:
img (Tensor): A tensor of shape (N, C, H, W) consisting of images
As output of forward and compute the metric returns the following output:
tv (Tensor): if reduction is 'sum' or 'mean', a float scalar tensor with the TV value summed or averaged over samples; otherwise a tensor of shape (N,) with the TV value per sample
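To make the per-sample quantity concrete, the sketch below computes it in plain Python. It assumes the common anisotropic definition of total variation: the sum of absolute differences between vertically and horizontally adjacent pixels, over all channels. The function name total_variation_per_sample and the nested-list input are illustrative assumptions and not part of torchmetrics, which operates on tensors.

```python
def total_variation_per_sample(img):
    """Anisotropic total variation for each image in a nested-list batch.

    `img` is indexed as img[n][c][h][w], mirroring the (N, C, H, W) layout.
    This is an illustrative helper, not a torchmetrics function.
    """
    scores = []
    for sample in img:
        tv = 0.0
        for channel in sample:
            height = len(channel)
            width = len(channel[0])
            # Absolute differences between vertically adjacent pixels.
            for i in range(height - 1):
                for j in range(width):
                    tv += abs(channel[i + 1][j] - channel[i][j])
            # Absolute differences between horizontally adjacent pixels.
            for i in range(height):
                for j in range(width - 1):
                    tv += abs(channel[i][j + 1] - channel[i][j])
        scores.append(tv)
    return scores


# One single-channel 2x2 image with rows [0, 1] and [2, 4].
batch = [[[[0.0, 1.0], [2.0, 4.0]]]]
print(total_variation_per_sample(batch))  # [8.0]
```

In this toy image the vertical differences contribute 2 + 3 and the horizontal differences contribute 1 + 2, so the single sample scores 8.0. A constant image scores 0.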
- Parameters:
reduction (Optional[Literal['mean', 'sum', 'none']]) – a method to reduce metric score over samples
    'mean': takes the mean over samples
    'sum': takes the sum over samples
    None or 'none': returns the score per sample
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If reduction is not one of 'sum', 'mean', 'none' or None
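The reduction options and the ValueError condition can be sketched as follows. This is a minimal illustration of the documented behaviour applied to a list of per-sample scores. The helper reduce_scores and the wording of its error message are hypothetical, not the library's implementation.

```python
def reduce_scores(scores, reduction="sum"):
    """Hypothetical helper mirroring the documented `reduction` options."""
    if reduction == "sum":
        # Single value: sum over samples.
        return sum(scores)
    if reduction == "mean":
        # Single value: mean over samples.
        return sum(scores) / len(scores)
    if reduction is None or reduction == "none":
        # One score per sample, unchanged.
        return list(scores)
    raise ValueError("reduction must be one of 'sum', 'mean', 'none' or None")


per_sample = [8.0, 2.0, 5.0]
print(reduce_scores(per_sample))          # 15.0
print(reduce_scores(per_sample, "mean"))  # 5.0
print(reduce_scores(per_sample, "none"))  # [8.0, 2.0, 5.0]
```

Because the default is 'sum', the reported value grows with batch size; 'mean' is the option to use when batches of different sizes need to be compared.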
Example
>>> from torch import rand
>>> from torchmetrics.image import TotalVariation
>>> tv = TotalVariation()
>>> img = rand(5, 3, 28, 28)
>>> tv(img)
tensor(7546.8018)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image import TotalVariation
>>> metric = TotalVariation()
>>> metric.update(torch.rand(5, 3, 28, 28))
>>> fig_, ax_ = metric.plot()

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image import TotalVariation
>>> metric = TotalVariation()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(5, 3, 28, 28)))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.image.total_variation(img, reduction='sum')
Compute total variation loss.
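- Parameters:
img (Tensor) – A tensor of shape (N, C, H, W) consisting of images
reduction (Optional[Literal['mean', 'sum', 'none']]) – a method to reduce metric score over samples: 'mean', 'sum', or None / 'none' for the score per sample
- Return type:
Tensor
- Returns:
A scalar tensor containing the total variation, or a tensor of shape (N,) with per-sample values when reduction is 'none' or None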
- Raises:
ValueError – If reduction is not one of 'sum', 'mean', 'none' or None
RuntimeError – If img is not a 4D tensor
Example
>>> from torch import rand
>>> from torchmetrics.functional.image import total_variation
>>> img = rand(5, 3, 28, 28)
>>> total_variation(img)
tensor(7546.8018)