torchmetrics.Metric
The base Metric class is an abstract base class that is used as the building block for all other Module metrics.
- class torchmetrics.Metric(**kwargs)
Base class for all metrics present in the Metrics API.
This class is inherited by all metrics and implements the following functionality:
1. Handles the transfer of metric states to the correct device.
2. Handles the synchronization of metric states across processes.
The three core methods of the base class are:
* add_state()
* forward()
* reset()
which should almost never be overwritten by child classes. Instead, the following methods should be overwritten:
* update()
* compute()
- Parameters:
kwargs – additional keyword arguments, see Advanced metric settings for more info.
compute_on_cpu: If metric state should be stored on CPU during computations. Only works for list states.
dist_sync_on_step: If metric state should synchronize on forward(). Default is False.
process_group: The process group on which the synchronization is called. Default is the world.
dist_sync_fn: Function that performs the allgather operation on the metric state. Default is a custom implementation that calls torch.distributed.all_gather internally.
distributed_available_fn: Function that checks if the distributed backend is available. Defaults to a check of torch.distributed.is_available() and torch.distributed.is_initialized().
sync_on_compute: If metric state should synchronize when compute is called. Default is True.
compute_with_cache: If results from compute should be cached. Default is True.
- add_state(name, default, dist_reduce_fx=None, persistent=False)
Add metric state variable. Only used by subclasses.
Metric state variables are either a Tensor or an empty list, which can be appended to by the metric. Each state variable must have a unique name associated with it. State variables are accessible as attributes of the metric, i.e., if name is "my_state" then its value can be accessed from an instance metric as metric.my_state. Metric states behave like buffers and parameters of Module as they are also updated when .to() is called. Unlike parameters and buffers, metric states are not by default saved in the module's state_dict.
- Parameters:
name (str) – The name of the state variable. The variable will then be accessible at self.name.
default (Union[list, Tensor]) – Default value of the state; can either be a Tensor or an empty list. The state will be reset to this value when self.reset() is called.
dist_reduce_fx (Optional) – Function to reduce state across multiple processes in distributed mode. If value is "sum", "mean", "cat", "min" or "max" we will use torch.sum, torch.mean, torch.cat, torch.min and torch.max respectively, each with argument dim=0. Note that the "cat" reduction only makes sense if the state is a list, and not a tensor. The user can also pass a custom function in this parameter.
persistent (Optional) – Whether the state will be saved as part of the module's state_dict. Default is False.
Note
Setting dist_reduce_fx to None will return the metric state synchronized across different processes. However, there won't be any reduction function applied to the synchronized metric state.
The metric states would be synced as follows:
* If the metric state is a Tensor, the synced value will be a stacked Tensor across the process dimension. The original Tensor metric state retains its dimensions, and hence the synchronized output will be of shape (num_process, ...).
* If the metric state is a list, the synced value will be a list containing the combined elements from all processes.
Note
When passing a custom function to dist_reduce_fx, expect the synchronized metric state to follow the format discussed in the above note.
Note
The values inserted into a list state are deleted whenever reset() is called. This allows device memory to be automatically reallocated, but may produce unexpected effects when referencing list states. To retain such values after reset() is called, you must first copy them to another object.
- Raises:
ValueError – If default is not a tensor or an empty list.
ValueError – If dist_reduce_fx is not callable or one of "mean", "sum", "cat", "min", "max" or None.
- abstract compute()
Override this method to compute the final metric value.
This method will automatically synchronize state variables when running with a distributed backend.
- double()
Override default and prevent dtype casting.
Please use Metric.set_dtype() instead.
- float()
Override default and prevent dtype casting.
Please use Metric.set_dtype() instead.
- forward(*args, **kwargs)
Aggregate and evaluate batch input directly.
Serves the dual purpose of both computing the metric on the current batch of inputs and adding the batch statistics to the overall accumulating metric state. Input arguments are exactly the same as for the corresponding update method. The returned output is exactly the same as the output of compute.
- Returns:
The output of the compute method evaluated on the current batch.
- Raises:
TorchMetricsUserError – If the metric is already synced and forward is called again.
- half()
Override default and prevent dtype casting.
Please use Metric.set_dtype() instead.
- persistent(mode=False)
Change, after initialization, whether metric states should be saved to the metric's state_dict.
- set_dtype(dst_type)
Transfer all metric states to a specific dtype. Special version of the standard type method.
- state_dict(destination=None, prefix='', keep_vars=False)
Get the current state of the metric as a dictionary.
- Parameters:
destination (Optional[Dict[str, Any]]) – Optional dictionary; if provided, the state of the module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned.
prefix (str) – Optional string, a prefix added to parameter and buffer names to compose the keys in state_dict.
keep_vars (bool) – By default the Tensors returned in the state dict are detached from autograd. If set to True, detaching will not be performed.
- sync(dist_sync_fn=None, process_group=None, should_sync=True, distributed_available=None)
Sync function for manually controlling when metric states should be synced across processes.
- Parameters:
dist_sync_fn (Optional[Callable]) – Function to be used to perform state synchronization.
process_group (Optional[Any]) – Specify the process group on which synchronization is called. Default: None (which selects the entire world).
should_sync (bool) – Whether to apply state synchronization. This will have an impact only when running in a distributed setting.
distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting.
- Raises:
TorchMetricsUserError – If the metric is already synced and sync is called again.
- sync_context(dist_sync_fn=None, process_group=None, should_sync=True, should_unsync=True, distributed_available=None)
Context manager to synchronize states.
This context manager is used in a distributed setting and makes sure that the local cache states are restored after yielding the synchronized state.
- Parameters:
dist_sync_fn (Optional[Callable]) – Function to be used to perform state synchronization.
process_group (Optional[Any]) – Specify the process group on which synchronization is called. Default: None (which selects the entire world).
should_sync (bool) – Whether to apply state synchronization. This will have an impact only when running in a distributed setting.
should_unsync (bool) – Whether to restore the cache state so that the metrics can continue to be accumulated.
distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting.
- type(dst_type)
Override default and prevent dtype casting.
Please use Metric.set_dtype() instead.
- unsync(should_unsync=True)
Unsync function for manually controlling when metric states should be reverted back to their local states.
- abstract update(*_, **__)
Override this method to update the state variables of your metric class.
- property metric_state: Dict[str, Union[List[Tensor], Tensor]]
Get the current state of the metric.