TorchMetrics
TorchMetrics is a collection of machine learning metrics for distributed, scalable PyTorch models, plus an easy-to-use API for creating custom metrics. It provides 60+ PyTorch metric implementations and is rigorously tested, including on edge cases.
```bash
pip install torchmetrics
```
TorchMetrics offers the following benefits:
- A standardized interface to increase reproducibility
- Reduced boilerplate
- Distributed-training compatible
- Rigorously tested
- Automatic accumulation over batches
- Automatic synchronization across multiple devices
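Automatic accumulation matters most when batches differ in size, because the mean of per-batch accuracies is not the same as the accuracy over all samples. The plain-Python sketch below illustrates the idea. It uses a hypothetical `RunningAccuracy` class, which is not part of TorchMetrics. The class mimics the update/compute/reset pattern by keeping running counts of correct and total predictions instead of per-batch scores.

```python
class RunningAccuracy:
    """Hypothetical accumulator mimicking an update/compute/reset metric API."""

    def __init__(self):
        self.reset()

    def update(self, preds, target):
        # accumulate sufficient statistics (counts), not per-batch scores
        self.correct += sum(int(p == t) for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self):
        return self.correct / self.total

    def reset(self):
        self.correct = 0
        self.total = 0


# two batches of different sizes: (predicted labels, true labels)
batches = [
    ([0, 1, 2, 3], [0, 1, 2, 0]),  # 3 of 4 correct
    ([1, 1], [0, 0]),              # 0 of 2 correct
]

metric = RunningAccuracy()
batch_scores = []
for preds, target in batches:
    metric.update(preds, target)
    batch_scores.append(sum(p == t for p, t in zip(preds, target)) / len(target))

print(metric.compute())                       # 3 / 6 = 0.5
print(sum(batch_scores) / len(batch_scores))  # (0.75 + 0.0) / 2 = 0.375
```

The accumulated value (0.5) is the true accuracy over all six samples. Naively averaging the two batch scores gives 0.375, because it over-weights the smaller batch.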
Example 1: Functional Metrics
Below is a simple example of computing accuracy with the functional interface:
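```python
import torch
import torchmetrics

# simulate a 5-class classification problem with 10 samples
preds = torch.randn(10, 5).softmax(dim=-1)  # per-class probabilities
target = torch.randint(5, (10,))  # ground-truth class indices

# recent TorchMetrics versions require the task type (and class count)
acc = torchmetrics.functional.accuracy(preds, target, task="multiclass", num_classes=5)
print(f"Accuracy: {acc}")
```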
Example 2: Module Metrics
The example below shows how to use the class-based interface:
```python
import torch
import torchmetrics

# initialize metric (recent TorchMetrics versions require task and num_classes)
metric = torchmetrics.Accuracy(task="multiclass", num_classes=5)

n_batches = 10
for i in range(n_batches):
    # simulate a classification problem
    preds = torch.randn(10, 5).softmax(dim=-1)
    target = torch.randint(5, (10,))
    # metric on current batch (calling the metric also updates its accumulated state)
    acc = metric(preds, target)
    print(f"Accuracy on batch {i}: {acc}")

# metric over all batches, computed from the accumulated state
acc = metric.compute()
print(f"Accuracy on all data: {acc}")

# resetting internal state so the metric is ready for new data
metric.reset()
```
Example 3: TorchMetrics with Lightning
The example below shows how to use a metric in your LightningModule:
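```python
import torchmetrics
from pytorch_lightning import LightningModule


class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        ...
        # registered as a submodule, so the metric state moves with the model's device
        # (num_classes=5 is an assumption: set it to your dataset's class count)
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=5)

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        ...
        # update the metric with the current batch
        self.accuracy(preds, y)
        # log the step value; on_epoch=True also logs the value accumulated over the epoch
        self.log("train_acc_step", self.accuracy, on_epoch=True)
        ...
```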