Confusion Matrix

Module Interface

class torchmetrics.ConfusionMatrix(**kwargs)[source]

Compute the confusion matrix.

This class is a simple wrapper that returns the task-specific version of this metric, selected by setting the task argument to 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryConfusionMatrix, MulticlassConfusionMatrix and MultilabelConfusionMatrix for details on how each argument is used and for examples.

Legacy Example:
>>> from torch import tensor
>>> from torchmetrics import ConfusionMatrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(task="binary", num_classes=2)
>>> confmat(preds, target)
tensor([[2, 0],
        [1, 1]])
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> confmat = ConfusionMatrix(task="multiclass", num_classes=3)
>>> confmat(preds, target)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> confmat = ConfusionMatrix(task="multilabel", num_labels=3)
>>> confmat(preds, target)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
static __new__(cls, task, threshold=0.5, num_classes=None, num_labels=None, normalize=None, ignore_index=None, validate_args=True, **kwargs)[source]

Initialize task metric.

Return type:

Metric

BinaryConfusionMatrix

class torchmetrics.classification.BinaryConfusionMatrix(threshold=0.5, ignore_index=None, normalize=None, validate_args=True, **kwargs)[source]

Compute the confusion matrix for binary tasks.

The confusion matrix \(C\) is constructed such that \(C_{i, j}\) is equal to the number of observations known to be in class \(i\) but predicted to be in class \(j\). Thus row indices of the confusion matrix correspond to the true class labels and column indices correspond to the predicted class labels.

For binary tasks, the confusion matrix is a 2x2 matrix with the following structure:

  • \(C_{0, 0}\): True negatives

  • \(C_{0, 1}\): False positives

  • \(C_{1, 0}\): False negatives

  • \(C_{1, 1}\): True positives
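
To make the cell layout concrete, here is a small pure-Python sketch that builds the 2x2 matrix by counting (true, predicted) pairs. It is an illustration of the definition above, not the library's implementation; the labels are the ones used in the examples on this page.

```python
# Labels from the binary examples on this page.
target = [1, 1, 0, 0]
preds = [0, 1, 0, 0]

# C[i][j] counts samples whose true class is i and predicted class is j.
C = [[0, 0], [0, 0]]
for t, p in zip(target, preds):
    C[t][p] += 1

tn, fp = C[0]  # row 0: true class 0
fn, tp = C[1]  # row 1: true class 1
print(C)               # [[2, 0], [1, 1]]
print(tn, fp, fn, tp)  # 2 0 1 1
```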

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (Tensor): An int tensor of shape (N, ...).
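
The preprocessing of float preds described above can be sketched in plain Python. The helper name binarize is hypothetical, and the strict comparison against the threshold is an assumption of this sketch rather than something this page states.

```python
import math


def binarize(preds, threshold=0.5):
    """Hypothetical helper mirroring the described handling of float preds."""
    # Any value outside [0, 1] means the whole input is treated as logits,
    # so a sigmoid is applied to every element.
    if any(p < 0.0 or p > 1.0 for p in preds):
        preds = [1.0 / (1.0 + math.exp(-p)) for p in preds]
    # Assumption: a value counts as positive when strictly above the threshold.
    return [int(p > threshold) for p in preds]


probs = [0.35, 0.85, 0.48, 0.01]   # already probabilities
logits = [-0.6, 1.7, -0.1, -4.6]   # contains values outside [0, 1]

print(binarize(probs))   # [0, 1, 0, 0]
print(binarize(logits))  # [0, 1, 0, 0]
```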

As output to forward and compute the metric returns the following output:

  • confusion_matrix (Tensor): A tensor containing a (2, 2) matrix

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • normalize (Optional[Literal['true', 'pred', 'all', 'none']]) –

    Normalization mode for confusion matrix. Choose from:

    • None or 'none': no normalization (default)

    • 'true': normalization over the targets (most commonly used)

    • 'pred': normalization over the predictions

    • 'all': normalization over the whole matrix

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
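
The normalization modes listed above can be illustrated with a short pure-Python sketch. The helper normalize_matrix is hypothetical, and returning 0 for an empty row or column is an assumption made here to avoid division by zero.

```python
def _safe_div(value, total):
    # Assumption of this sketch: an empty row/column normalizes to 0.
    return value / total if total else 0.0


def normalize_matrix(C, mode=None):
    """Hypothetical helper illustrating the 'true', 'pred' and 'all' modes."""
    if mode in (None, "none"):
        return [row[:] for row in C]
    if mode == "true":  # each row (true class) sums to 1
        return [[_safe_div(v, sum(row)) for v in row] for row in C]
    if mode == "pred":  # each column (predicted class) sums to 1
        col_sums = [sum(col) for col in zip(*C)]
        return [[_safe_div(v, s) for v, s in zip(row, col_sums)] for row in C]
    if mode == "all":   # the whole matrix sums to 1
        total = sum(sum(row) for row in C)
        return [[_safe_div(v, total) for v in row] for row in C]
    raise ValueError(f"unknown normalization mode: {mode!r}")


C = [[2, 0], [1, 1]]
print(normalize_matrix(C, "true"))  # [[1.0, 0.0], [0.5, 0.5]]
print(normalize_matrix(C, "all"))   # [[0.5, 0.0], [0.25, 0.25]]
```

With 'true', each row reads as "of the samples that truly belong to this class, what fraction was predicted as each class", which is why it is the most commonly used mode.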

Example (preds is int tensor):
>>> import torch
>>> from torchmetrics.classification import BinaryConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> bcm = BinaryConfusionMatrix()
>>> bcm(preds, target)
tensor([[2, 0],
        [1, 1]])
Example (preds is float tensor):
>>> from torchmetrics.classification import BinaryConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01])
>>> bcm = BinaryConfusionMatrix()
>>> bcm(preds, target)
tensor([[2, 0],
        [1, 1]])
plot(val=None, ax=None, add_text=True, labels=None, cmap=None)[source]

Plot a single or multiple values from the metric.

Return type:

tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randint
>>> from torchmetrics.classification import BinaryConfusionMatrix
>>> metric = BinaryConfusionMatrix()
>>> metric.update(randint(2, (20,)), randint(2, (20,)))
>>> fig_, ax_ = metric.plot()

MulticlassConfusionMatrix

class torchmetrics.classification.MulticlassConfusionMatrix(num_classes, ignore_index=None, normalize=None, validate_args=True, **kwargs)[source]

Compute the confusion matrix for multiclass tasks.

The confusion matrix \(C\) is constructed such that \(C_{i, j}\) is equal to the number of observations known to be in class \(i\) but predicted to be in class \(j\). Thus row indices of the confusion matrix correspond to the true class labels and column indices correspond to the predicted class labels.

For multiclass tasks, the confusion matrix is an NxN matrix (with N = num_classes), where:

  • \(C_{i, i}\) represents the number of true positives for class \(i\)

  • \(\sum_{j=1, j\neq i}^N C_{i, j}\) represents the number of false negatives for class \(i\)

  • \(\sum_{j=1, j\neq i}^N C_{j, i}\) represents the number of false positives for class \(i\)

  • the sum of the remaining cells in the matrix represents the number of true negatives for class \(i\)
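
The per-class quantities above can be read off a matrix with a few lines of plain Python. This sketch uses the matrix from the multiclass examples on this page; the stats dictionary is purely illustrative.

```python
C = [[1, 1, 0],
     [0, 1, 0],
     [0, 0, 1]]
n = len(C)
total = sum(sum(row) for row in C)

stats = {}
for i in range(n):
    tp = C[i][i]
    fn = sum(C[i][j] for j in range(n) if j != i)  # rest of row i
    fp = sum(C[j][i] for j in range(n) if j != i)  # rest of column i
    tn = total - tp - fn - fp                      # all remaining cells
    stats[i] = {"tp": tp, "fn": fn, "fp": fp, "tn": tn}

print(stats[0])  # {'tp': 1, 'fn': 1, 'fp': 0, 'tn': 2}
print(stats[1])  # {'tp': 1, 'fn': 0, 'fp': 1, 'tn': 2}
print(stats[2])  # {'tp': 1, 'fn': 0, 'fp': 0, 'tn': 3}
```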

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int tensor of shape (N, ...) or a float tensor of shape (N, C, ...). If preds is a floating point tensor, we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (Tensor): An int tensor of shape (N, ...).
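
For float preds of shape (N, C), the functional multiclass_confusion_matrix documentation on this page states that class labels are obtained with an argmax over the C dimension. A plain-Python sketch of that step, using the float example from this page:

```python
# Per-sample class scores, shape (N, C) = (4, 3).
preds = [[0.16, 0.26, 0.58],
         [0.22, 0.61, 0.17],
         [0.71, 0.09, 0.20],
         [0.05, 0.82, 0.13]]

# Index of the largest score in each row, i.e. an argmax over the class dimension.
labels = [max(range(len(row)), key=row.__getitem__) for row in preds]
print(labels)  # [2, 1, 0, 1]
```

The resulting labels equal the int preds used in the integer example, which is why both examples produce the same matrix.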

As output to forward and compute the metric returns the following output:

  • confusion_matrix: [num_classes, num_classes] matrix

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • normalize (Optional[Literal['true', 'pred', 'all', 'none']]) –

    Normalization mode for confusion matrix. Choose from:

    • None or 'none': no normalization (default)

    • 'true': normalization over the targets (most commonly used)

    • 'pred': normalization over the predictions

    • 'all': normalization over the whole matrix

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
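
The effect of ignore_index can be sketched in plain Python: samples whose target equals the ignored value are simply skipped. The helper confusion_with_ignore and the choice of -1 as the ignored value are assumptions made for illustration, not the library's implementation.

```python
def confusion_with_ignore(preds, target, num_classes, ignore_index=None):
    """Hypothetical helper: count (true, predicted) pairs, skipping ignored targets."""
    C = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(preds, target):
        if ignore_index is not None and t == ignore_index:
            continue  # this sample does not contribute to the matrix
        C[t][p] += 1
    return C


target = [2, 1, 0, -1]  # last sample carries the ignored value
preds = [2, 1, 0, 1]

print(confusion_with_ignore(preds, target, num_classes=3, ignore_index=-1))
# [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
```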

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassConfusionMatrix(num_classes=3)
>>> metric(preds, target)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
Example (preds is float tensor):
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassConfusionMatrix(num_classes=3)
>>> metric(preds, target)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
plot(val=None, ax=None, add_text=True, labels=None, cmap=None)[source]

Plot a single or multiple values from the metric.

Return type:

tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randint
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> metric = MulticlassConfusionMatrix(num_classes=5)
>>> metric.update(randint(5, (20,)), randint(5, (20,)))
>>> fig_, ax_ = metric.plot()

MultilabelConfusionMatrix

class torchmetrics.classification.MultilabelConfusionMatrix(num_labels, threshold=0.5, ignore_index=None, normalize=None, validate_args=True, **kwargs)[source]

Compute the confusion matrix for multilabel tasks.

The confusion matrix \(C\) is constructed such that \(C_{i, j}\) is equal to the number of observations known to be in class \(i\) but predicted to be in class \(j\). Thus row indices of the confusion matrix correspond to the true class labels and column indices correspond to the predicted class labels.

For multilabel tasks, the confusion matrix is an Nx2x2 tensor (with N = num_labels), where each 2x2 matrix corresponds to the confusion for that label. The structure of each 2x2 matrix is as follows:

  • \(C_{0, 0}\): True negatives

  • \(C_{0, 1}\): False positives

  • \(C_{1, 0}\): False negatives

  • \(C_{1, 1}\): True positives
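
As a pure-Python illustration of the per-label layout above, the sketch below builds one 2x2 matrix per label from the multilabel example used on this page. It illustrates the definition only and is not the library's implementation.

```python
# Shape (N, C) = (2 samples, 3 labels), taken from the examples on this page.
target = [[0, 1, 0], [1, 0, 1]]
preds = [[0, 0, 1], [1, 0, 1]]
num_labels = 3

# One independent 2x2 matrix per label: C[label][true][predicted].
C = [[[0, 0], [0, 0]] for _ in range(num_labels)]
for t_row, p_row in zip(target, preds):
    for label, (t, p) in enumerate(zip(t_row, p_row)):
        C[label][t][p] += 1

for label, matrix in enumerate(C):
    print(label, matrix)
# 0 [[1, 0], [0, 1]]
# 1 [[1, 0], [1, 0]]
# 2 [[0, 1], [0, 1]]
```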

As input to forward and update the metric accepts the following input:

  • preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, C, ...)

As output to forward and compute the metric returns the following output:

  • confusion_matrix (Tensor): A tensor containing a (num_labels, 2, 2) matrix

Parameters:
  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • normalize (Optional[Literal['true', 'pred', 'all', 'none']]) –

    Normalization mode for confusion matrix. Choose from:

    • None or 'none': no normalization (default)

    • 'true': normalization over the targets (most commonly used)

    • 'pred': normalization over the predictions

    • 'all': normalization over the whole matrix

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelConfusionMatrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelConfusionMatrix(num_labels=3)
>>> metric(preds, target)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
Example (preds is float tensor):
>>> from torchmetrics.classification import MultilabelConfusionMatrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelConfusionMatrix(num_labels=3)
>>> metric(preds, target)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
plot(val=None, ax=None, add_text=True, labels=None, cmap=None)[source]

Plot a single or multiple values from the metric.

Return type:

tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randint
>>> from torchmetrics.classification import MultilabelConfusionMatrix
>>> metric = MultilabelConfusionMatrix(num_labels=3)
>>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()

Functional Interface

confusion_matrix

torchmetrics.functional.confusion_matrix(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, normalize=None, ignore_index=None, validate_args=True)[source]

Compute the confusion matrix.

This function is a simple wrapper that dispatches to the task-specific version of this metric, selected by setting the task argument to 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_confusion_matrix(), multiclass_confusion_matrix() and multilabel_confusion_matrix() for details on how each argument is used and for examples.

Return type:

Tensor

Legacy Example:
>>> from torch import tensor
>>> from torchmetrics.classification import ConfusionMatrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(task="binary")
>>> confmat(preds, target)
tensor([[2, 0],
        [1, 1]])
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> confmat = ConfusionMatrix(task="multiclass", num_classes=3)
>>> confmat(preds, target)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> confmat = ConfusionMatrix(task="multilabel", num_labels=3)
>>> confmat(preds, target)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])

binary_confusion_matrix

torchmetrics.functional.classification.binary_confusion_matrix(preds, target, threshold=0.5, normalize=None, ignore_index=None, validate_args=True)[source]

Compute the confusion matrix for binary tasks.

Accepts the following input tensors:

  • preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, ...)

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • normalize (Optional[Literal['true', 'pred', 'all', 'none']]) –

    Normalization mode for confusion matrix. Choose from:

    • None or 'none': no normalization (default)

    • 'true': normalization over the targets (most commonly used)

    • 'pred': normalization over the predictions

    • 'all': normalization over the whole matrix

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Returns:

A [2, 2] tensor

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_confusion_matrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> binary_confusion_matrix(preds, target)
tensor([[2, 0],
        [1, 1]])
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import binary_confusion_matrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0.35, 0.85, 0.48, 0.01])
>>> binary_confusion_matrix(preds, target)
tensor([[2, 0],
        [1, 1]])

multiclass_confusion_matrix

torchmetrics.functional.classification.multiclass_confusion_matrix(preds, target, num_classes, normalize=None, ignore_index=None, validate_args=True)[source]

Compute the confusion matrix for multiclass tasks.

Accepts the following input tensors:

  • preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor, we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (int tensor): (N, ...)

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • normalize (Optional[Literal['true', 'pred', 'all', 'none']]) –

    Normalization mode for confusion matrix. Choose from:

    • None or 'none': no normalization (default)

    • 'true': normalization over the targets (most commonly used)

    • 'pred': normalization over the predictions

    • 'all': normalization over the whole matrix

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Returns:

A [num_classes, num_classes] tensor

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_confusion_matrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_confusion_matrix(preds, target, num_classes=3)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multiclass_confusion_matrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_confusion_matrix(preds, target, num_classes=3)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])

multilabel_confusion_matrix

torchmetrics.functional.classification.multilabel_confusion_matrix(preds, target, num_labels, threshold=0.5, normalize=None, ignore_index=None, validate_args=True)[source]

Compute the confusion matrix for multilabel tasks.

Accepts the following input tensors:

  • preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, C, ...)

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • normalize (Optional[Literal['true', 'pred', 'all', 'none']]) –

    Normalization mode for confusion matrix. Choose from:

    • None or 'none': no normalization (default)

    • 'true': normalization over the targets (most commonly used)

    • 'pred': normalization over the predictions

    • 'all': normalization over the whole matrix

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Returns:

A [num_labels, 2, 2] tensor

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_confusion_matrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_confusion_matrix(preds, target, num_labels=3)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multilabel_confusion_matrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_confusion_matrix(preds, target, num_labels=3)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])