:orphan: TorchMetrics ============ `TorchMetrics `_ is a collection of machine learning metrics for distributed, scalable PyTorch models and an easy-to-use API to create custom metrics. It has a collection of 60+ PyTorch metrics implementations and is rigorously tested for all edge cases. .. code-block:: bash pip install torchmetrics In TorchMetrics, we offer the following benefits: - A standardized interface to increase reproducibility - Reduced Boilerplate - Distributed-training compatible - Rigorously tested - Automatic accumulation over batches - Automatic synchronization across multiple devices ----------------- Example 1: Functional Metrics ----------------------------- Below is a simple example for calculating the accuracy using the functional interface: .. code-block:: python import torch import torchmetrics # simulate a classification problem preds = torch.randn(10, 5).softmax(dim=-1) target = torch.randint(5, (10,)) acc = torchmetrics.functional.accuracy(preds, target) ------------ Example 2: Module Metrics ------------------------- The example below shows how to use the class-based interface: .. code-block:: python import torch import torchmetrics # initialize metric metric = torchmetrics.Accuracy() n_batches = 10 for i in range(n_batches): # simulate a classification problem preds = torch.randn(10, 5).softmax(dim=-1) target = torch.randint(5, (10,)) # metric on current batch acc = metric(preds, target) print(f"Accuracy on batch {i}: {acc}") # metric on all batches using custom accumulation acc = metric.compute() print(f"Accuracy on all data: {acc}") # Resetting internal state such that metric ready for new data metric.reset() ------------ Example 3: TorchMetrics with Lightning -------------------------------------- The example below shows how to use a metric in your :doc:`LightningModule <../common/lightning_module>`: .. code-block:: python class MyModel(LightningModule): def __init__(self): ... self.accuracy = torchmetrics.Accuracy() def training_step(self, batch, batch_idx): x, y = batch preds = self(x) ... # log step metric self.accuracy(preds, y) self.log("train_acc_step", self.accuracy, on_epoch=True) ...