Implementing a Metric

While we strive to include as many metrics as possible in torchmetrics, we cannot include them all. We have
therefore made it easy to implement your own metric, and you can contribute it to torchmetrics if you wish.
This page will guide you through the process. If you are afterwards interested in contributing your metric to
torchmetrics, please read the contribution guidelines and see this section.
Base interface

To implement your own custom metric, subclass the base Metric class and implement the following methods:

- __init__(): Each state variable should be registered using self.add_state(...).
- update(): Any code needed to update the metric states given the inputs to the metric.
- compute(): Computes a final value from the states of the metric.

We provide the remaining interface, such as reset(), which makes sure that all metric states added using
add_state are correctly reset. You should therefore not implement reset() yourself, except in rare cases where
not all state variables should be reset to their default values. Adding metric states with add_state also
makes sure that states are correctly synchronized in distributed settings (DDP). To see how metric states are
synchronized across distributed processes, refer to the add_state() docs of the base Metric class.
Below is a basic implementation of a custom accuracy metric. In the __init__ method we add the metric states
correct and total, which will be used to accumulate the number of correct predictions and the total number of
predictions, respectively. In the update method we update the metric states based on the inputs to the metric.
Finally, in the compute method we compute the final metric value from the metric states.
import torch
from torch import Tensor

from torchmetrics import Metric


class MyAccuracy(Metric):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # states accumulate across batches; "sum" reduces them across processes in DDP
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        # optionally pre-process the inputs here, e.g.
        # preds, target = self._input_format(preds, target)
        # where `_input_format` would be a user-defined helper
        if preds.shape != target.shape:
            raise ValueError("preds and target must have the same shape")

        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self) -> Tensor:
        return self.correct.float() / self.total
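A minimal usage sketch (the tensor values are made up for illustration):

metric = MyAccuracy()

# accumulate over two batches
metric.update(torch.tensor([0, 1, 1]), torch.tensor([0, 1, 0]))
metric.update(torch.tensor([1, 1]), torch.tensor([1, 1]))

# 4 correct predictions out of 5 in total
print(metric.compute())  # tensor(0.8000)

# clear the states before starting a new accumulation
metric.reset()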
A few important things to note:

- The dist_reduce_fx argument to add_state is used to specify how the metric states should be reduced between
  processes in distributed settings. In this case we use "sum" to sum the metric states across processes. A
  couple of built-in options are available: "sum", "mean", "cat", "min" or "max", but a custom reduction is
  also supported (see the short sketch after this list).
- In update we do not return anything but instead update the metric states in-place.
- In compute, when running in distributed mode, the states will have been synced before the compute method is
  called. Thus self.correct and self.total will contain the sum of the metric states across all processes.
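As a sketch of the custom-reduction option, dist_reduce_fx also accepts a callable. The function below is
purely illustrative; following the add_state() documentation further down this page, a tensor state arrives
stacked along the process dimension, so the reduction is applied over dim=0:

import torch


def my_trimmed_sum(gathered_state: torch.Tensor) -> torch.Tensor:
    # `gathered_state` is the state stacked across processes, shape (num_processes, ...);
    # here we clamp each per-process contribution before summing over the process dimension
    return gathered_state.clamp(max=1e6).sum(dim=0)


# inside your metric's __init__:
# self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx=my_trimmed_sum)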
Working with list states

When initializing metric states with add_state, the default argument can either be a single tensor (as in the
example above) or an empty list. Most metrics only require a single tensor to accumulate the metric states,
but for metrics that need access to the individual batch states, it can be useful to use a list of tensors. In
the following example we show how to implement the Spearman correlation, which requires access to the
individual batch states because we need to calculate the rank of the predictions and targets.
import torch
from torch import Tensor

from torchmetrics import Metric
from torchmetrics.utilities import dim_zero_cat
# `_rank_data` converts values to their ranks; here we assume the private helper that ships with
# torchmetrics, but any ranking function with the same behaviour would do
from torchmetrics.functional.regression.spearman import _rank_data

# small constant to avoid division by zero; the value is chosen here for illustration
eps = 1e-6


class MySpearmanCorrCoef(Metric):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.add_state("preds", default=[], dist_reduce_fx="cat")
        self.add_state("target", default=[], dist_reduce_fx="cat")

    def update(self, preds: Tensor, target: Tensor) -> None:
        # list states are appended to, one entry per batch
        self.preds.append(preds)
        self.target.append(target)

    def compute(self):
        # parse inputs
        preds = dim_zero_cat(self.preds)
        target = dim_zero_cat(self.target)
        # some intermediate computation...
        r_preds, r_target = _rank_data(preds), _rank_data(target)
        preds_diff = r_preds - r_preds.mean(0)
        target_diff = r_target - r_target.mean(0)
        cov = (preds_diff * target_diff).mean(0)
        preds_std = torch.sqrt((preds_diff * preds_diff).mean(0))
        target_std = torch.sqrt((target_diff * target_diff).mean(0))
        # finalize the computations
        corrcoef = cov / (preds_std * target_std + eps)
        return torch.clamp(corrcoef, -1.0, 1.0)
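A minimal usage sketch (again with made-up values):

metric = MySpearmanCorrCoef()
metric.update(torch.tensor([0.1, 0.4, 0.3]), torch.tensor([0.0, 1.0, 0.5]))
metric.update(torch.tensor([0.9, 0.2]), torch.tensor([1.0, 0.2]))
print(metric.compute())  # a single correlation coefficient in [-1, 1], close to 1 for these inputs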
A few important things to note for this example:

- When working with list states, the dist_reduce_fx argument to add_state should be set to "cat" to
  concatenate the list of tensors across processes.
- When working with list states, the update(...) method should append the batch states to the list.
- In the compute method the list states behave a bit differently depending on whether you are running in
  distributed mode or not. In non-distributed mode the list states will be a list of tensors, while in
  distributed mode the list will already have been concatenated into a single tensor. For this reason, we
  recommend always using the dim_zero_cat helper function, which standardizes the list states into a single
  concatenated tensor regardless of the mode.
- Calling the reset method will clear the list state, deleting any values inserted into it. For this reason,
  care must be taken when referencing list states. If you need the values after your metric is reset, you
  must first copy the attribute to another object (e.g. using copy.deepcopy), as sketched below.
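For example, a short sketch of keeping the raw batch predictions around after a reset, reusing the
MySpearmanCorrCoef metric defined above:

import copy

import torch

metric = MySpearmanCorrCoef()
metric.update(torch.tensor([0.1, 0.4]), torch.tensor([0.0, 1.0]))

# copy the list state before resetting; after reset() the list is emptied
saved_preds = copy.deepcopy(metric.preds)
metric.reset()

print(len(metric.preds))  # 0 -> the state was cleared
print(len(saved_preds))   # 1 -> the copied values are still available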
Metric attributes

When you are done implementing your own metric, there are a few properties and attributes that you may want to
set to add additional functionality. The three attributes to consider are is_differentiable, higher_is_better
and full_state_update. Note that none of them are strictly required to be set for the metric to work.
from typing import Optional

from torchmetrics import Metric


class MyMetric(Metric):
    # Set to True if the metric is differentiable, else set to False
    is_differentiable: Optional[bool] = None

    # Set to True if the metric reaches its optimal value when it is maximized.
    # Set to False if it reaches its optimal value when it is minimized.
    higher_is_better: Optional[bool] = True

    # Set to True if the metric during 'update' requires access to the global metric
    # state for its calculations. If not, setting this to False indicates that all
    # batch states are independent and we will optimize the runtime of 'forward'
    full_state_update: bool = True
Plot interface

From torchmetrics v1.0.0 onwards, we also support plotting of metrics through the .plot() method. By default
this method raises NotImplementedError, but it can be implemented by the user to provide a custom plot for the
metric. For any metric that returns a simple scalar tensor, or a dict of scalar tensors, the internal ._plot
method can be used; it provides the common plotting functionality for most metrics in torchmetrics.
from typing import Optional, Sequence, Union

from torch import Tensor

from torchmetrics import Metric
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE


class MyMetric(Metric):
    # set these attributes if you want to use the internal ._plot method
    # bounds are automatically added to the generated plot
    plot_lower_bound: Optional[float] = None
    plot_upper_bound: Optional[float] = None

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        return self._plot(val, ax)
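A minimal usage sketch, assuming the MyAccuracy metric from earlier has been extended with the plot method
shown above; ._plot returns a matplotlib figure and axis that can be customized or saved:

metric = MyAccuracy()
metric.update(torch.tensor([0, 1, 1]), torch.tensor([0, 1, 0]))

# by default the current value of compute() is plotted
fig, ax = metric.plot()

# a sequence of previously computed values can also be passed explicitly
fig, ax = metric.plot(val=[torch.tensor(0.6), torch.tensor(0.7), torch.tensor(0.8)])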
If the metric returns a more complex output, a custom implementation of the plot method is required. For more
details on the plotting API, see this page.
Internal implementation details

This section briefly describes how metrics work internally. We encourage looking at the source code for more
info. Internally, TorchMetrics wraps the user-defined update() and compute() methods. We do this to
automatically synchronize and reduce metric states across multiple devices. More precisely, calling update()
does the following internally:

- Clears the computed cache.
- Calls the user-defined update().

Similarly, calling compute() does the following internally:

- Syncs metric states between processes.
- Reduces the gathered metric states.
- Calls the user-defined compute() method on the gathered metric states.
- Caches the computed result.

From a user's standpoint this has one important side effect: computed results are cached. This means that no
matter how many times compute is called, one call after another, it will continue to return the same result.
The cache is first emptied on the next call to update.
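The caching behavior can be illustrated with a short sketch, reusing the MyAccuracy metric from above:

metric = MyAccuracy()
metric.update(torch.tensor([1, 1]), torch.tensor([1, 0]))

first = metric.compute()   # computed from the current states
second = metric.compute()  # returned from the cache, identical to `first`

metric.update(torch.tensor([1]), torch.tensor([1]))
third = metric.compute()   # update() emptied the cache, so this value is recomputed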
forward serves the dual purpose of both returning the metric on the current data and updating the internal
metric state for accumulating over multiple batches. The forward() method achieves this by combining calls to
update, compute and reset. Depending on the class property full_state_update, forward can behave in two ways:

- If full_state_update is True, the metric during update requires access to the full metric state, and we
  therefore need two calls to update to ensure that the metric is calculated correctly. In this case forward:

  1. Calls update() to update the global metric state (for accumulation over multiple batches).
  2. Caches the global state.
  3. Calls reset() to clear the global metric state.
  4. Calls update() to update the local metric state.
  5. Calls compute() to calculate the metric for the current batch.
  6. Restores the global state.

- If full_state_update is False (default), the metric state of one batch is completely independent of the
  state of other batches, which means that we only need to call update once. In this case forward:

  1. Caches the global state.
  2. Calls reset to revert the metric to its default state.
  3. Calls update to update the state with the local batch statistics.
  4. Calls compute to calculate the metric for the current batch.
  5. Reduces the global state and the batch state into a single state that becomes the new global state.
If implementing your own metric, we recommend trying out the metric with the full_state_update class property
set to both True and False. If the results are equal, then setting it to False will usually give the best
performance.
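A quick sketch of such a check, comparing the forward output of two variants of the custom accuracy metric
from above that differ only in the full_state_update class property (the input values are made up):

class MyAccuracyFull(MyAccuracy):
    full_state_update: bool = True


class MyAccuracyFast(MyAccuracy):
    full_state_update: bool = False


preds, target = torch.tensor([0, 1, 1, 0]), torch.tensor([0, 1, 0, 0])

# if both variants return the same value from forward, full_state_update=False is safe
# and will usually be faster
assert torch.allclose(MyAccuracyFull()(preds, target), MyAccuracyFast()(preds, target))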
- class torchmetrics.Metric(**kwargs)[source]

Base class for all metrics present in the Metrics API.

This class is inherited by all metrics and implements the following functionality:

- Handles the transfer of metric states to the correct device.
- Handles the synchronization of metric states across processes.
- Provides properties and methods to control the overall behavior of the metric and its states.

The three core methods of the base class are add_state(), forward() and reset(), which should almost never be
overwritten by child classes. Instead, the following methods should be overwritten: update() and compute().

- Parameters:
  kwargs – additional keyword arguments, see Advanced metric settings for more info.
- compute_on_cpu:
  If metric state should be stored on CPU during computations. Only works for list states.
- dist_sync_on_step:
  If metric state should synchronize on forward(). Default is False.
- process_group:
  The process group on which the synchronization is called. Default is the world.
- dist_sync_fn:
  Function that performs the allgather operation on the metric state. Default is a custom implementation that
  calls torch.distributed.all_gather internally.
- distributed_available_fn:
  Function that checks if the distributed backend is available. Defaults to a check of
  torch.distributed.is_available() and torch.distributed.is_initialized().
- sync_on_compute:
  If metric state should synchronize when compute is called. Default is True.
- compute_with_cache:
  If results from compute should be cached. Default is True.
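These keyword arguments are simply passed on through the **kwargs of your subclass. A short sketch, using the
MyAccuracy metric defined earlier on this page:

# configure the base-class behavior through keyword arguments
metric = MyAccuracy(
    dist_sync_on_step=False,   # only sync states when compute() is called
    sync_on_compute=True,      # synchronize states across processes on compute()
    compute_with_cache=True,   # cache the result of compute() until the next update()
)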
- add_state(name, default, dist_reduce_fx=None, persistent=False)[source]
Add metric state variable. Only used by subclasses.
Metric state variables are either a Tensor or an empty list, which can be appended to by the metric. Each
state variable must have a unique name associated with it. State variables are accessible as attributes of the
metric, i.e., if name is "my_state" then its value can be accessed from an instance metric as metric.my_state.
Metric states behave like buffers and parameters of Module as they are also updated when .to() is called.
Unlike parameters and buffers, metric states are not by default saved in the module's state_dict.

- Parameters:
  - name (str) – The name of the state variable. The variable will then be accessible at self.name.
  - default (Union[list, Tensor]) – Default value of the state; can either be a Tensor or an empty list. The
    state will be reset to this value when self.reset() is called.
  - dist_reduce_fx (Optional) – Function to reduce state across multiple processes in distributed mode. If
    the value is "sum", "mean", "cat", "min" or "max", we will use torch.sum, torch.mean, torch.cat, torch.min
    or torch.max respectively, each with argument dim=0. Note that the "cat" reduction only makes sense if the
    state is a list, and not a tensor. The user can also pass a custom function in this parameter.
  - persistent (Optional) – whether the state will be saved as part of the module's state_dict. Default is
    False.
- Return type:
Note
Setting dist_reduce_fx to None will return the metric state synchronized across different processes. However,
there won't be any reduction function applied to the synchronized metric state. The metric states will be
synced as follows:

- If the metric state is a Tensor, the synced value will be a stacked Tensor across the process dimension.
  The original Tensor metric state retains its dimensions, so the synchronized output will be of shape
  (num_process, ...).
- If the metric state is a list, the synced value will be a list containing the combined elements from all
  processes.

Important

When passing a custom function to dist_reduce_fx, expect the synchronized metric state to follow the format
discussed in the above note.

Caution

The values inserted into a list state are deleted whenever reset() is called. This allows device memory to be
automatically reallocated, but may produce unexpected effects when referencing list states. To retain such
values after reset() is called, you must first copy them to another object.

- Raises:
  - ValueError – If default is not a tensor or an empty list.
  - ValueError – If dist_reduce_fx is not callable or one of "mean", "sum", "cat", "min", "max" or None.
- abstract compute()[source]
Override this method to compute the final metric value.
This method will automatically synchronize state variables when running in a distributed backend.
- Return type:
- double()[source]
Override default and prevent dtype casting.
Please use Metric.set_dtype() instead.

- Return type:
- float()[source]
Override default and prevent dtype casting.
Please use Metric.set_dtype() instead.

- Return type:
- forward(*args, **kwargs)[source]
Aggregate and evaluate batch input directly.
Serves the dual purpose of both computing the metric on the current batch of inputs and adding the batch
statistics to the overall accumulating metric state. Input arguments are the exact same as for the
corresponding update method. The returned output is the exact same as the output of compute.

- Parameters:
- Return type:
- Returns:
  The output of the compute method evaluated on the current batch.
- Raises:
  TorchMetricsUserError – If the metric is already synced and forward is called again.
- half()[source]
Override default and prevent dtype casting.
Please use Metric.set_dtype() instead.

- Return type:
- merge_state(incoming_state)[source]
Merge incoming metric state into the current state of the metric.

- Parameters:
  incoming_state (Union[dict[str, Any], Metric]) – either a dict containing a metric state similar to the
  metric itself or an instance of the metric class.
- Raises:
  - ValueError – If the incoming state is neither a dict nor an instance of the metric class.
  - RuntimeError – If the metric has full_state_update=True or dist_sync_on_step=True. In these cases, the
    metric cannot be merged with another metric state in a simple way. The user should overwrite the method in
    the metric class to handle the merge operation.
  - ValueError – If the incoming state is a metric instance but the class is different from the current metric
    class.
- Return type:
  None

Example with a metric instance:

>>> from torchmetrics.aggregation import SumMetric
>>> metric1 = SumMetric()
>>> metric2 = SumMetric()
>>> metric1.update(1)
>>> metric2.update(2)
>>> metric1.merge_state(metric2)
>>> metric1.compute()
tensor(3.)

Example with a dict:

>>> from torchmetrics.aggregation import SumMetric
>>> metric = SumMetric()
>>> metric.update(1)
>>> # SumMetric has one state variable called `sum_value`
>>> metric.merge_state({"sum_value": torch.tensor(2)})
>>> metric.compute()
tensor(3.)
- persistent(mode=False)[source]
Change, post-init, whether metric states should be saved to the metric's state_dict.
- Return type:
- set_dtype(dst_type)[source]
Transfer all metric states to a specific dtype. Special version of the standard type method.
- state_dict(destination=None, prefix='', keep_vars=False)[source]
Get the current state of the metric as a dictionary.

- Parameters:
  - destination (Optional[dict[str, Any]]) – Optional dictionary; if provided, the state of the module will
    be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and
    returned.
  - prefix (str) – optional string, a prefix added to parameter and buffer names to compose the keys in the
    state_dict.
  - keep_vars (bool) – by default the Tensor values returned in the state dict are detached from autograd. If
    set to True, detaching will not be performed.
- Return type:
- sync(dist_sync_fn=None, process_group=None, should_sync=True, distributed_available=None)[source]
Sync function for manually controlling when metric states should be synced across processes.

- Parameters:
  - dist_sync_fn (Optional[Callable]) – Function to be used to perform states synchronization.
  - process_group (Optional[Any]) – Specify the process group on which synchronization is called.
    Default: None (which selects the entire world).
  - should_sync (bool) – Whether to apply state synchronization. This will have an impact only when running
    in a distributed setting.
  - distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed
    setting.
- Raises:
  TorchMetricsUserError – If the metric is already synced and sync is called again.
- Return type:
- sync_context(dist_sync_fn=None, process_group=None, should_sync=True, should_unsync=True, distributed_available=None)[source]
Context manager to synchronize states.
This context manager is used in distributed settings and makes sure that the local cache states are restored
after yielding the synchronized state.

- Parameters:
  - dist_sync_fn (Optional[Callable]) – Function to be used to perform states synchronization.
  - process_group (Optional[Any]) – Specify the process group on which synchronization is called.
    Default: None (which selects the entire world).
  - should_sync (bool) – Whether to apply state synchronization. This will have an impact only when running
    in a distributed setting.
  - should_unsync (bool) – Whether to restore the cache state so that the metrics can continue to be
    accumulated.
  - distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed
    setting.
- Return type:
- type(dst_type)[source]
Override default and prevent dtype casting.
Please use Metric.set_dtype() instead.

- Return type:
- unsync(should_unsync=True)[source]
Unsync function for manually controlling when metric states should be reverted back to their local states.
- abstract update(*_, **__)[source]
Override this method to update the state variables of your metric class.
- Return type:
- property metric_state: dict[str, Union[List[torch.Tensor], torch.Tensor]][source]
Get the current state of the metric.
Contributing your metric to TorchMetrics

Want to contribute the metric you have implemented? Great, we are always open to adding more metrics to
torchmetrics as long as they serve a general purpose. However, to keep all our metrics consistent, we request
that the implementation and tests are formatted in the following way:

1. Start by reading our contribution guidelines.
2. First implement the functional backend. This takes care of all the logic that goes into the metric. The
   code should be put into a single file placed under src/torchmetrics/functional/"domain"/"new_metric".py,
   where domain is the type of metric (classification, regression, text etc.) and new_metric is the name of
   the metric. In this file, there should be the following three functions:

   - _new_metric_update(...): everything that has to do with type/shape checking and all logic required
     before distributed syncing needs to go here.
   - _new_metric_compute(...): all remaining logic.
   - new_metric(...): essentially wraps the _update and _compute private functions into one public function
     that makes up the functional interface for the metric.

   Hint

   The functional mean squared error metric is a great example of how to divide the logic.
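As a rough, illustrative sketch of this layout (the metric logic below is mean-squared-error-like, mirroring
the hint above; the function and file names simply follow the naming pattern described):

# src/torchmetrics/functional/"domain"/"new_metric".py  (illustrative skeleton)
from typing import Tuple

import torch
from torch import Tensor


def _new_metric_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
    # type/shape checking and everything that must happen before distributed syncing
    if preds.shape != target.shape:
        raise ValueError("preds and target must have the same shape")
    sum_squared_error = torch.sum((preds - target) ** 2)
    num_obs = torch.tensor(target.numel(), device=preds.device)
    return sum_squared_error, num_obs


def _new_metric_compute(sum_squared_error: Tensor, num_obs: Tensor) -> Tensor:
    # all remaining logic, applied to the (possibly synced) states
    return sum_squared_error / num_obs


def new_metric(preds: Tensor, target: Tensor) -> Tensor:
    # public functional interface: simply wraps the two private helpers
    sum_squared_error, num_obs = _new_metric_update(preds, target)
    return _new_metric_compute(sum_squared_error, num_obs)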
3. In a corresponding file placed in src/torchmetrics/"domain"/"new_metric".py, create the module interface:

   - Create a new module metric by subclassing torchmetrics.Metric.
   - In the __init__ of the module, call self.add_state for as many metric states as are needed for the
     metric to properly accumulate metric statistics.
   - The module interface should essentially call the private _new_metric_update(...) in its update method
     and similarly the _new_metric_compute(...) function in its compute. No logic should really be
     implemented in the module interface. We do this to avoid having duplicate code to maintain.

   Note

   The module MeanSquaredError metric that corresponds to the above functional example showcases these steps.
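A matching sketch of the module interface, assuming the _new_metric_update and _new_metric_compute helpers
from the previous sketch are importable from the functional file:

# src/torchmetrics/"domain"/"new_metric".py  (illustrative skeleton)
import torch
from torch import Tensor

from torchmetrics import Metric

# in real code these would be imported from the functional file created above, e.g.
# from torchmetrics.functional."domain"."new_metric" import _new_metric_compute, _new_metric_update


class NewMetric(Metric):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("num_obs", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        # delegate all logic to the private functional helpers
        sum_squared_error, num_obs = _new_metric_update(preds, target)
        self.sum_squared_error += sum_squared_error
        self.num_obs += num_obs

    def compute(self) -> Tensor:
        return _new_metric_compute(self.sum_squared_error, self.num_obs)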
4. Remember to add bindings to the different relevant __init__ files.

5. Testing is key to keeping torchmetrics trustworthy. This is why we have a very rigid testing protocol. This
   means that in most cases we require the metric to be tested against some other common framework (sklearn,
   scipy etc.).

   - Create a testing file in tests/unittests/"domain"/test_"new_metric".py. Only one file is needed, as it is
     intended to test both the functional and module interface.
   - In that file, start by defining a number of test inputs that your metric should be evaluated on.
   - Create a test class class NewMetric(MetricTester) that inherits from tests.helpers.testers.MetricTester.
     This test class should essentially implement the test_"new_metric"_class and test_"new_metric"_fn methods
     that test the module interface and the functional interface, respectively.
   - The test class should be parameterized (using @pytest.mark.parametrize) by the different test inputs
     defined initially. Additionally, the test_"new_metric"_class method should also be parameterized with a
     ddp parameter such that it gets tested in a distributed setting. If your metric has additional
     parameters, then make sure to also parameterize these so that different combinations of inputs and
     parameters get tested.
   - (optional) If your metric raises any exceptions, please add tests that showcase this.

   Hint

   The test file for the MSE metric shows how to implement such tests.
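A heavily hedged sketch of what such a test file might look like, again using mean squared error as the
stand-in metric. The MetricTester helper method names (run_class_metric_test, run_functional_metric_test) and
their argument names are assumptions based on the existing test suite and may differ between versions:

# tests/unittests/"domain"/test_"new_metric".py  (illustrative skeleton)
import pytest
import torch
from sklearn.metrics import mean_squared_error as sk_mean_squared_error  # reference implementation

from torchmetrics.functional import mean_squared_error
from torchmetrics.regression import MeanSquaredError

from tests.helpers.testers import MetricTester  # import path as described above; may differ between versions

# test inputs defined up front, shaped (num_batches, batch_size)
_inputs = [(torch.rand(4, 32), torch.rand(4, 32))]


def _reference_mse(preds, target):
    return sk_mean_squared_error(target.numpy(), preds.numpy())


@pytest.mark.parametrize("preds, target", _inputs)
class TestNewMetric(MetricTester):
    @pytest.mark.parametrize("ddp", [False, True])
    def test_new_metric_class(self, preds, target, ddp):
        # assumed helper: tests the module interface, optionally in a distributed setting
        self.run_class_metric_test(
            ddp=ddp,
            preds=preds,
            target=target,
            metric_class=MeanSquaredError,
            reference_metric=_reference_mse,
        )

    def test_new_metric_fn(self, preds, target):
        # assumed helper: tests the functional interface
        self.run_functional_metric_test(
            preds=preds,
            target=target,
            metric_functional=mean_squared_error,
            reference_metric=_reference_mse,
        )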
If you can only figure out part of the steps, do not be afraid to send a PR. We would much rather receive
working metrics that are not formatted exactly like our codebase than receive none at all. Formatting can
always be applied afterwards. We will gladly guide and/or help implement the rest :]