Implementing a Metric
To implement your own custom metric, subclass the base Metric class and implement the following methods:

- __init__(): Each state variable should be registered using self.add_state(...).
- update(): Any code needed to update the state given any inputs to the metric.
- compute(): Computes a final value from the state of the metric.
We provide the remaining interface, such as reset(), which makes sure that all metric states added using add_state are correctly reset. You should therefore not implement reset() yourself.
Additionally, adding metric states with add_state makes sure that states are correctly synchronized
in distributed settings (DDP). To see how metric states are synchronized across distributed processes,
refer to the add_state() docs of the base Metric class.
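Example implementation:

```python
import torch
from torch import Tensor
from torchmetrics import Metric


class MyAccuracy(Metric):
    def __init__(self):
        super().__init__()
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    # Hypothetical helper (not provided by the Metric base class):
    # turns class scores of shape (N, C) into predicted labels of shape (N,)
    def _input_format(self, preds: Tensor, target: Tensor):
        if preds.ndim == target.ndim + 1:
            preds = preds.argmax(dim=-1)
        return preds, target

    def update(self, preds: Tensor, target: Tensor):
        preds, target = self._input_format(preds, target)
        if preds.shape != target.shape:
            raise ValueError("preds and target must have the same shape")

        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self):
        return self.correct.float() / self.total
```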
Additionally, you may want to set the class properties is_differentiable, higher_is_better and full_state_update. None of them are strictly required for the metric to work.
```python
from typing import Optional

from torchmetrics import Metric


class MyMetric(Metric):
    # Set to True if the metric is differentiable, else set to False
    is_differentiable: Optional[bool] = None

    # Set to True if the metric reaches its optimal value when the metric is maximized.
    # Set to False if it reaches its optimal value when the metric is minimized.
    higher_is_better: Optional[bool] = True

    # Set to True if the metric during 'update' requires access to the global metric
    # state for its calculations. If not, setting this to False indicates that all
    # batch states are independent and we will optimize the runtime of 'forward'
    full_state_update: bool = True
```
Finally, from torchmetrics v1.0.0 onwards, we also support plotting of metrics through the .plot method. By default this method raises NotImplementedError, but it can be implemented by the user to provide a custom plot for the metric. For any metric that returns a simple scalar tensor or a dict of scalar tensors, the internal ._plot method can be used, which provides the common plotting functionality for most metrics in torchmetrics.
```python
from typing import Optional, Sequence, Union

from torch import Tensor
from torchmetrics import Metric
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE


class MyMetric(Metric):
    ...

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        return self._plot(val, ax)
```
If the metric returns a more complex output, a custom implementation of the plot method is required. For more details on the plotting API, see this page.
Internal implementation details
This section briefly describes how metrics work internally. We encourage looking at the source code for more info.
Internally, TorchMetrics wraps the user-defined update() and compute() methods. We do this to automatically
synchronize and reduce metric states across multiple devices. More precisely, calling update() does the
following internally:

1. Clears the computed cache.
2. Calls the user-defined update().
Similarly, calling compute() does the following internally:

1. Syncs metric states between processes.
2. Reduces the gathered metric states.
3. Calls the user-defined compute() method on the gathered metric states.
4. Caches the computed result.
From a user’s standpoint this has one important side effect: computed results are cached. This means that no
matter how many times compute is called one after another, it will continue to return the same result.
The cache is only emptied on the next call to update.
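The two wrapped call sequences and the resulting cache behaviour can be sketched in plain Python. This is a simplified, single-process illustration, not TorchMetrics' actual source: the names ToyMetric, ToySum, _update, _compute, _sync_and_reduce and compute_calls are hypothetical, and the sync and reduce steps are a no-op here.

```python
from typing import Any, Optional


class ToyMetric:
    """Hypothetical, single-process stand-in for the wrapping done by the base Metric class."""

    def __init__(self) -> None:
        self._computed: Optional[Any] = None
        self.compute_calls = 0  # counts how often the user-defined compute actually runs

    def update(self, *args: Any) -> None:
        self._computed = None  # 1. clear the computed cache
        self._update(*args)    # 2. call the user-defined update

    def compute(self) -> Any:
        if self._computed is not None:
            return self._computed     # cached: no sync and no recomputation
        self._sync_and_reduce()       # sync and reduce states (no-op in this sketch)
        self.compute_calls += 1
        self._computed = self._compute()  # call the user-defined compute, then cache it
        return self._computed

    def _sync_and_reduce(self) -> None:
        pass

    def _update(self, *args: Any) -> None:
        raise NotImplementedError

    def _compute(self) -> Any:
        raise NotImplementedError


class ToySum(ToyMetric):
    def __init__(self) -> None:
        super().__init__()
        self.total = 0

    def _update(self, x: int) -> None:
        self.total += x

    def _compute(self) -> int:
        return self.total


m = ToySum()
m.update(2)
m.update(3)
print(m.compute(), m.compute(), m.compute_calls)  # 5 5 1
m.update(1)  # the next update empties the cache
print(m.compute(), m.compute_calls)  # 6 2
```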
forward serves the dual purpose of both returning the metric on the current data and updating the internal
metric state for accumulating over multiple batches. The forward() method achieves this by combining calls
to update, compute and reset. Depending on the class property full_state_update, forward
can behave in two ways:
- If full_state_update is True, the metric requires access to the full metric state during update, so we need two calls to update to ensure that the metric is calculated correctly:
   1. Calls update() to update the global metric state (for accumulation over multiple batches).
   2. Caches the global state.
   3. Calls reset() to clear the global metric state.
   4. Calls update() to update the local metric state.
   5. Calls compute() to calculate the metric for the current batch.
   6. Restores the global state.
- If full_state_update is False (default), the metric state of one batch is completely independent of the state of other batches, which means that we only need to call update once:
   1. Caches the global state.
   2. Calls reset() to reset the metric to its default state.
   3. Calls update() to update the state with local batch statistics.
   4. Calls compute() to calculate the metric for the current batch.
   5. Reduces the global state and batch state into a single state that becomes the new global state.
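Both forward strategies can be sketched in plain Python. This is a simplified single-process illustration with hypothetical names (ToyAccuracy, forward_full_state_update, forward_reduce_state_update), not TorchMetrics' implementation; a "sum" reduction stands in for the state reduction of the second strategy, and the numbered comments follow the steps listed above.

```python
from typing import Dict, List


class ToyAccuracy:
    """Hypothetical single-process metric whose two states are reduced with "sum"."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.correct, self.total = 0, 0

    def update(self, preds: List[int], target: List[int]) -> None:
        self.correct += sum(p == t for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self) -> float:
        return self.correct / self.total

    def _get_state(self) -> Dict[str, int]:
        # A fresh dict of ints works as a snapshot (ints are immutable)
        return {"correct": self.correct, "total": self.total}

    def _set_state(self, state: Dict[str, int]) -> None:
        self.correct, self.total = state["correct"], state["total"]

    def forward_full_state_update(self, preds: List[int], target: List[int]) -> float:
        self.update(preds, target)        # 1. update the global state
        global_state = self._get_state()  # 2. cache the global state
        self.reset()                      # 3. clear the global state
        self.update(preds, target)        # 4. update the local state
        batch_value = self.compute()      # 5. metric for the current batch
        self._set_state(global_state)     # 6. restore the global state
        return batch_value

    def forward_reduce_state_update(self, preds: List[int], target: List[int]) -> float:
        global_state = self._get_state()  # 1. cache the global state
        self.reset()                      # 2. reset to the default state
        self.update(preds, target)        # 3. local batch statistics only
        batch_value = self.compute()      # 4. metric for the current batch
        batch_state = self._get_state()
        # 5. "sum"-reduce global and batch state into the new global state
        self._set_state({k: global_state[k] + batch_state[k] for k in global_state})
        return batch_value


batches = [([0, 1, 1], [0, 1, 0]), ([1, 0], [1, 0])]
full, reduced = ToyAccuracy(), ToyAccuracy()
for preds, target in batches:
    # Per-batch values agree: 2/3 for the first batch, 1.0 for the second
    print(full.forward_full_state_update(preds, target), reduced.forward_reduce_state_update(preds, target))
print(full.compute(), reduced.compute())  # 0.8 0.8
```

Because both states here are plain counts, the two strategies give the same per-batch values and the same accumulated result, which is the situation in which the cheaper single-update path is safe.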
If implementing your own metric, we recommend trying out the metric with full_state_update class property set to
both True and False. If the results are equal, then setting it to False will usually give the best performance.
- class torchmetrics.Metric(**kwargs)
Base class for all metrics present in the Metrics API.
This class is inherited by all metrics and implements the following functionality:

1. Handles the transfer of metric states to the correct device.
2. Handles the synchronization of metric states across processes.
The three core methods of the base class are add_state(), forward() and reset(), which should almost never be overwritten by child classes. Instead, the following methods should be overwritten: update() and compute().
- Parameters:
  **kwargs – additional keyword arguments, see Advanced metric settings for more info.
  - compute_on_cpu: If metric state should be stored on CPU during computations. Only works for list states.
  - dist_sync_on_step: If metric state should synchronize on forward(). Default is False.
  - process_group: The process group on which the synchronization is called. Default is the world.
  - dist_sync_fn: Function that performs the allgather operation on the metric state. Default is a custom implementation that calls torch.distributed.all_gather internally.
  - distributed_available_fn: Function that checks if the distributed backend is available. Defaults to a check of torch.distributed.is_available() and torch.distributed.is_initialized().
  - sync_on_compute: If metric state should synchronize when compute is called. Default is True.
  - compute_with_cache: If results from compute should be cached. Default is False.
- add_state(name, default, dist_reduce_fx=None, persistent=False)
Add metric state variable. Only used by subclasses.
Metric state variables are either a Tensor or an empty list, which can be appended to by the metric. Each state variable must have a unique name associated with it. State variables are accessible as attributes of the metric, i.e., if name is "my_state" then its value can be accessed from an instance metric as metric.my_state. Metric states behave like buffers and parameters of Module, as they are also updated when .to() is called. Unlike parameters and buffers, metric states are not saved in the module's state_dict by default.
- Parameters:
  - name (str) – The name of the state variable. The variable will then be accessible at self.name.
  - default (Union[list, Tensor]) – Default value of the state; can either be a Tensor or an empty list. The state will be reset to this value when self.reset() is called.
  - dist_reduce_fx (Optional) – Function to reduce state across multiple processes in distributed mode. If value is "sum", "mean", "cat", "min" or "max" we will use torch.sum, torch.mean, torch.cat, torch.min and torch.max respectively, each with argument dim=0. Note that the "cat" reduction only makes sense if the state is a list, and not a tensor. The user can also pass a custom function in this parameter.
  - persistent (Optional) – Whether the state will be saved as part of the module's state_dict. Default is False.
- Return type:
Note
Setting dist_reduce_fx to None will return the metric state synchronized across different processes. However, there won't be any reduction function applied to the synchronized metric state. The metric states would be synced as follows:
- If the metric state is a Tensor, the synced value will be a Tensor stacked across the process dimension. The original Tensor metric state retains its dimensions, and hence the synchronized output will be of shape (num_process, ...).
- If the metric state is a list, the synced value will be a list containing the combined elements from all processes.
Note
When passing a custom function to dist_reduce_fx, expect the synchronized metric state to follow the format discussed in the above note.
- Raises:
  - ValueError – If default is not a tensor or an empty list.
  - ValueError – If dist_reduce_fx is not callable or one of "mean", "sum", "cat", "min", "max" or None.
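To make the reductions concrete, the sketch below simulates in a single process what a synchronized state looks like for the different dist_reduce_fx options. The per-process values are made up for illustration, the stacking and dim=0 reductions follow the notes above rather than calling TorchMetrics' internal sync code, and mean_of_maxes is a hypothetical custom reduction.

```python
import torch

# Hypothetical values of the same tensor state on two processes (shape (3,) each)
gathered = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0])]

# dist_reduce_fx=None: states are stacked along a new process dimension
synced = torch.stack(gathered)
print(synced.shape)  # torch.Size([2, 3]) -> (num_process, ...)

# The built-in string options reduce the stacked state along dim=0
print(torch.sum(synced, dim=0))         # "sum"  -> tensor([5., 7., 9.])
print(torch.mean(synced, dim=0))        # "mean" -> tensor([2.5000, 3.5000, 4.5000])
print(torch.max(synced, dim=0).values)  # "max"  -> tensor([4., 5., 6.])
print(torch.min(synced, dim=0).values)  # "min"  -> tensor([1., 2., 3.])


def mean_of_maxes(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical custom reduction: receives the stacked (num_process, ...) state,
    # takes the max over the last dim, then averages over processes
    return x.max(dim=-1).values.mean()


print(mean_of_maxes(synced))  # tensor(4.5000)

# List states: elements from all processes are combined, and "cat" concatenates them
list_states = [[torch.tensor([1.0]), torch.tensor([2.0])], [torch.tensor([3.0])]]
combined = [t for per_process in list_states for t in per_process]
print(torch.cat(combined, dim=0))  # tensor([1., 2., 3.])
```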
- abstract compute()
Override this method to compute the final metric value.
This method will automatically synchronize state variables when running in a distributed backend.
- Return type:
- double()
Override default and prevent dtype casting.
Please use Metric.set_dtype() instead.
- Return type:
- float()
Override default and prevent dtype casting.
Please use Metric.set_dtype() instead.
- Return type:
- forward(*args, **kwargs)
Aggregate and evaluate batch input directly.
Serves the dual purpose of both computing the metric on the current batch of inputs and adding the batch statistics to the overall accumulating metric state. Input arguments are exactly the same as for the corresponding update method. The returned output is exactly the same as the output of compute.
- Parameters:
- Return type:
- Returns:
  The output of the compute method evaluated on the current batch.
- Raises:
  TorchMetricsUserError – If the metric is already synced and forward is called again.
- half()
Override default and prevent dtype casting.
Please use Metric.set_dtype() instead.
- Return type:
- persistent(mode=False)
Change, after initialization, whether metric states should be saved to the metric's state_dict.
- Return type:
- set_dtype(dst_type)
Transfer all metric states to a specific dtype. Special version of the standard type method.
- state_dict(destination=None, prefix='', keep_vars=False)
Get the current state of the metric as a dictionary.
- Parameters:
  - destination (Optional[Dict[str, Any]]) – Optional dictionary; if provided, the state of the module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned.
  - prefix (str) – Optional string, a prefix added to parameter and buffer names to compose the keys in state_dict.
  - keep_vars (bool) – By default, the tensors returned in the state dict are detached from autograd. If set to True, detaching will not be performed.
- Return type:
- sync(dist_sync_fn=None, process_group=None, should_sync=True, distributed_available=None)
Sync function for manually controlling when metric states should be synced across processes.
- Parameters:
  - dist_sync_fn (Optional[Callable]) – Function to be used to perform states synchronization.
  - process_group (Optional[Any]) – Specify the process group on which synchronization is called. Default: None (which selects the entire world).
  - should_sync (bool) – Whether to apply state synchronization. This will have an impact only when running in a distributed setting.
  - distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting.
- Raises:
  TorchMetricsUserError – If the metric is already synced and sync is called again.
- Return type:
- sync_context(dist_sync_fn=None, process_group=None, should_sync=True, should_unsync=True, distributed_available=None)
Context manager to synchronize states.
This context manager is used in a distributed setting and makes sure that the local cache states are restored after yielding the synchronized state.
- Parameters:
  - dist_sync_fn (Optional[Callable]) – Function to be used to perform states synchronization.
  - process_group (Optional[Any]) – Specify the process group on which synchronization is called. Default: None (which selects the entire world).
  - should_sync (bool) – Whether to apply state synchronization. This will have an impact only when running in a distributed setting.
  - should_unsync (bool) – Whether to restore the cache state so that the metrics can continue to be accumulated.
  - distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting.
- Return type:
- type(dst_type)
Override default and prevent dtype casting.
Please use Metric.set_dtype() instead.
- Return type:
- unsync(should_unsync=True)
Unsync function for manually controlling when metric states should be reverted to their local states.
- abstract update(*_, **__)
Override this method to update the state variables of your metric class.
- Return type:
Contributing your metric to TorchMetrics
Wanting to contribute the metric you have implemented? Great, we are always open to adding more metrics to torchmetrics
as long as they serve a general purpose. However, to keep all our metrics consistent, we request that the implementation
and tests get formatted in the following way:
1. Start by reading our contribution guidelines.
2. First implement the functional backend. This takes care of all the logic that goes into the metric. The code should be put into a single file placed under torchmetrics/functional/"domain"/"new_metric".py where domain is the type of metric (classification, regression, nlp etc.) and new_metric is the name of the metric. In this file, there should be the following three functions:
   - _new_metric_update(...): everything that has to do with type/shape checking and all logic required before distributed syncing needs to go here.
   - _new_metric_compute(...): all remaining logic.
   - new_metric(...): essentially wraps the _update and _compute private functions into one public function that makes up the functional interface for the metric.
   Note
   The functional accuracy metric is a great example of this division of logic.
3. In a corresponding file placed in torchmetrics/"domain"/"new_metric".py, create the module interface:
   - Create a new module metric by subclassing torchmetrics.Metric.
   - In the __init__ of the module, call self.add_state for as many metric states as are needed for the metric to properly accumulate metric statistics.
   - The module interface should essentially call the private _new_metric_update(...) in its update method and similarly the _new_metric_compute(...) function in its compute. No logic should really be implemented in the module interface. We do this to avoid having duplicate code to maintain.
   Note
   The module Accuracy metric that corresponds to the above functional example showcases these steps.
4. Remember to add bindings to the relevant __init__ files.
5. Testing is key to keeping torchmetrics trustworthy. This is why we have a very rigid testing protocol. This means that in most cases we require the metric to be tested against some other common framework (sklearn, scipy etc.).
   - Create a testing file in unittests/"domain"/test_"new_metric".py. Only one file is needed, as it is intended to test both the functional and module interface.
   - In that file, start by defining a number of test inputs that your metric should be evaluated on.
   - Create a testclass class NewMetric(MetricTester) that inherits from tests.helpers.testers.MetricTester. This testclass should essentially implement the test_"new_metric"_class and test_"new_metric"_fn methods that respectively test the module interface and the functional interface.
   - The testclass should be parameterized (using @pytest.mark.parametrize) by the different test inputs defined initially. Additionally, the test_"new_metric"_class method should also be parameterized with a ddp parameter such that it gets tested in a distributed setting. If your metric has additional parameters, then make sure to also parameterize these such that different combinations of inputs and parameters get tested.
   - (optional) If your metric raises any exception, please add tests that showcase this.
   Note
   The test file for the accuracy metric shows how to implement such tests.
If you can only figure out part of the steps, don't be afraid to send a PR. We would much rather receive working metrics that are not formatted exactly like our codebase than not receive any. Formatting can always be applied. We will gladly guide and/or help implement the remaining steps :]