torchmetrics.Metric

The base Metric class is an abstract base class that is used as the building block for all other Module metrics.

class torchmetrics.Metric(**kwargs)

Base class for all metrics present in the Metrics API.

This class is inherited by all metrics and implements the following functionality:

  1. Handles the transfer of metric states to the correct device.

  2. Handles the synchronization of metric states across processes.

  3. Provides properties and methods to control the overall behavior of the metric and its states.

The three core methods of the base class are add_state(), forward() and reset(), which should almost never be overridden by child classes. Instead, child classes should override update() and compute().

Parameters:

kwargs (Any) –

Additional keyword arguments; see Advanced metric settings for more information.

  • compute_on_cpu:

    Whether the metric state should be stored on CPU during computations. Only works for list states.

  • dist_sync_on_step:

    Whether the metric state should be synchronized on forward(). Default is False.

  • process_group:

    The process group on which the synchronization is called. Default is the world.

  • dist_sync_fn:

    Function that performs the allgather operation on the metric state. Default is a custom implementation that calls torch.distributed.all_gather internally.

  • distributed_available_fn:

    Function that checks if the distributed backend is available. Defaults to a check of torch.distributed.is_available() and torch.distributed.is_initialized().

  • sync_on_compute:

    Whether the metric state should be synchronized when compute() is called. Default is True.

  • compute_with_cache:

    Whether the results of compute() should be cached. Default is True.

add_state(name, default, dist_reduce_fx=None, persistent=False)

Add metric state variable. Only used by subclasses.

Metric state variables are either a Tensor or an empty list, which can be appended to by the metric. Each state variable must have a unique name associated with it. State variables are accessible as attributes of the metric, i.e., if name is "my_state", its value can be accessed from a metric instance as metric.my_state. Metric states behave like buffers and parameters of Module in that they are also updated when .to() is called. Unlike parameters and buffers, metric states are not saved in the module's state_dict by default.

Parameters:
  • name (str) – The name of the state variable. The variable will then be accessible at self.name.

  • default (Union[list, Tensor]) – Default value of the state; can either be a Tensor or an empty list. The state will be reset to this value when self.reset() is called.

  • dist_reduce_fx (Optional) – Function to reduce state across multiple processes in distributed mode. If value is "sum", "mean", "cat", "min" or "max" we will use torch.sum, torch.mean, torch.cat, torch.min and torch.max respectively, each with argument dim=0. Note that the "cat" reduction only makes sense if the state is a list, and not a tensor. The user can also pass a custom function in this parameter.

  • persistent (Optional) – Whether the state will be saved as part of the module's state_dict. Default is False.

Return type:

None

Note

Setting dist_reduce_fx to None still synchronizes the metric state across processes, but no reduction function is applied to the synchronized state.

The metric states are synced as follows:

  • If the metric state is a Tensor, the synced value is a Tensor stacked along a new leading process dimension. Each process's state keeps its original shape, so the synchronized output has shape (num_process, ...).

  • If the metric state is a list, the synced value will be a list containing the combined elements from all processes.

Important

When passing a custom function to dist_reduce_fx, expect it to receive the synchronized metric state in the format described in the note above.

Caution

The values inserted into a list state are deleted whenever reset() is called. This allows device memory to be automatically reallocated, but may produce unexpected effects when referencing list states. To retain such values after reset() is called, you must first copy them to another object.

Raises:
  • ValueError – If default is not a tensor or an empty list.

  • ValueError – If dist_reduce_fx is not callable or one of "mean", "sum", "cat", "min", "max" or None.

clone()

Make a copy of the metric.

Return type:

Metric

abstract compute()

Override this method to compute the final metric value.

This method automatically synchronizes state variables when running with a distributed backend.

Return type:

Any

double()

Override default and prevent dtype casting.

Please use Metric.set_dtype() instead.

Return type:

Metric

float()

Override default and prevent dtype casting.

Please use Metric.set_dtype() instead.

Return type:

Metric

forward(*args, **kwargs)

Aggregate and evaluate batch input directly.

Serves the dual purpose of computing the metric on the current batch of inputs and adding the batch statistics to the overall accumulated metric state. Input arguments are exactly the same as those of the corresponding update method. The returned output has the same form as the output of compute, but is evaluated on the current batch only.

Parameters:
  • args (Any) – Any arguments as required by the metric update method.

  • kwargs (Any) – Any keyword arguments as required by the metric update method.

Return type:

Any

Returns:

The output of the compute method evaluated on the current batch.

Raises:

TorchMetricsUserError – If the metric is already synced and forward is called again.

half()

Override default and prevent dtype casting.

Please use Metric.set_dtype() instead.

Return type:

Metric

merge_state(incoming_state)

Merge an incoming metric state into the current state of the metric.

Parameters:

incoming_state (Union[dict[str, Any], Metric]) – either a dict containing a metric state similar to the metric itself or an instance of the metric class.

Raises:
  • ValueError – If the incoming state is neither a dict nor an instance of the metric class.

  • RuntimeError – If the metric has full_state_update=True or dist_sync_on_step=True. In these cases, the metric cannot be merged with another metric state in a simple way. The user should override the method in the metric class to handle the merge operation.

  • ValueError – If the incoming state is a metric instance but the class is different from the current metric class.

Return type:

None

Example with a metric instance:

>>> from torchmetrics.aggregation import SumMetric
>>> metric1 = SumMetric()
>>> metric2 = SumMetric()
>>> metric1.update(1)
>>> metric2.update(2)
>>> metric1.merge_state(metric2)
>>> metric1.compute()
tensor(3.)

Example with a dict:

>>> import torch
>>> from torchmetrics.aggregation import SumMetric
>>> metric = SumMetric()
>>> metric.update(1)
>>> # SumMetric has one state variable called `sum_value`
>>> metric.merge_state({"sum_value": torch.tensor(2)})
>>> metric.compute()
tensor(3.)

persistent(mode=False)

Change, after initialization, whether metric states should be saved to the state_dict.

Return type:

None

plot(*_, **__)

Override this method to plot the metric value.

Return type:

Any

reset()

Reset metric state variables to their default value.

Return type:

None

set_dtype(dst_type)

Transfer all metric states to a specific dtype. Special version of the standard type method.

Parameters:

dst_type (Union[str, dtype]) – the desired type as string or dtype object

Return type:

Metric

state_dict(destination=None, prefix='', keep_vars=False)

Get the current state of the metric as a dictionary.

Parameters:
  • destination (Optional[dict[str, Any]]) – Optional dictionary; if provided, the state of the module is written into it and the same object is returned. Otherwise, an OrderedDict is created and returned.

  • prefix (str) – optional string, a prefix added to parameter and buffer names to compose the keys in state_dict.

  • keep_vars (bool) – By default, the Tensors returned in the state dict are detached from autograd. If set to True, detaching is not performed.

Return type:

dict[str, Any]

sync(dist_sync_fn=None, process_group=None, should_sync=True, distributed_available=None)

Sync function for manually controlling when metric states should be synced across processes.

Parameters:
  • dist_sync_fn (Optional[Callable]) – Function used to perform state synchronization

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • should_sync (bool) – Whether to apply state synchronization. This only has an effect when running in a distributed setting.

  • distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting

Raises:

TorchMetricsUserError – If the metric is already synced and sync is called again.

Return type:

None

sync_context(dist_sync_fn=None, process_group=None, should_sync=True, should_unsync=True, distributed_available=None)

Context manager to synchronize states.

This context manager is used in a distributed setting and makes sure that the local cached states are restored after yielding the synchronized state.

Parameters:
  • dist_sync_fn (Optional[Callable]) – Function used to perform state synchronization

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • should_sync (bool) – Whether to apply state synchronization. This only has an effect when running in a distributed setting.

  • should_unsync (bool) – Whether to restore the cached local state so that the metric can continue accumulating.

  • distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting

Return type:

Generator

type(dst_type)

Override default and prevent dtype casting.

Please use Metric.set_dtype() instead.

Return type:

Metric

unsync(should_unsync=True)

Unsync function for manually controlling when metric states should be reverted to their local states.

Parameters:

should_unsync (bool) – Whether to perform unsync

Return type:

None

abstract update(*_, **__)

Override this method to update the state variables of your metric class.

Return type:

None

property device: device

Return the device of the metric.

property dtype: dtype

Return the default dtype of the metric.

property metric_state: dict[str, Union[List[torch.Tensor], torch.Tensor]]

Get the current state of the metric.

property update_called: bool

Returns True if update or forward has been called since initialization or last reset.

property update_count: int

Get the number of times update and/or forward has been called since initialization or last reset.