Confusion Matrix
Module Interface
- class torchmetrics.ConfusionMatrix(**kwargs)
Compute the confusion matrix.
This class is a simple wrapper that returns the task-specific version of this metric, selected by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryConfusionMatrix, MulticlassConfusionMatrix and MultilabelConfusionMatrix for how each argument affects the result and for examples.
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics import ConfusionMatrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(task="binary", num_classes=2)
>>> confmat(preds, target)
tensor([[2, 0], [1, 1]])
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> confmat = ConfusionMatrix(task="multiclass", num_classes=3)
>>> confmat(preds, target)
tensor([[1, 1, 0], [0, 1, 0], [0, 0, 1]])
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> confmat = ConfusionMatrix(task="multilabel", num_labels=3)
>>> confmat(preds, target)
tensor([[[1, 0], [0, 1]], [[1, 0], [1, 0]], [[0, 1], [0, 1]]])
BinaryConfusionMatrix
- class torchmetrics.classification.BinaryConfusionMatrix(threshold=0.5, ignore_index=None, normalize=None, validate_args=True, **kwargs)
Compute the confusion matrix for binary tasks.
The confusion matrix \(C\) is constructed such that \(C_{i, j}\) is equal to the number of observations known to be in class \(i\) but predicted to be in class \(j\). Thus row indices of the confusion matrix correspond to the true class labels and column indices correspond to the predicted class labels.
For binary tasks, the confusion matrix is a 2x2 matrix with the following structure:
\(C_{0, 0}\): True negatives
\(C_{0, 1}\): False positives
\(C_{1, 0}\): False negatives
\(C_{1, 1}\): True positives
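The counting rule above can be checked by hand with a minimal pure-Python sketch. It is not the torchmetrics implementation: binary_confmat is a hypothetical helper, plain lists stand in for tensors, and it assumes preds and target are already 0/1 integers (no thresholding). It reproduces the [[2, 0], [1, 1]] result of the examples for this class.

```python
def binary_confmat(preds, target):
    """Count [[TN, FP], [FN, TP]] from 0/1 integer labels.

    Rows index the true class and columns the predicted class,
    matching the definition of C[i][j] above.
    """
    confmat = [[0, 0], [0, 0]]
    for p, t in zip(preds, target):
        confmat[t][p] += 1  # one observation with true class t, predicted class p
    return confmat


target = [1, 1, 0, 0]
preds = [0, 1, 0, 0]
(tn, fp), (fn, tp) = binary_confmat(preds, target)
print(binary_confmat(preds, target))  # [[2, 0], [1, 1]]
print(tn, fp, fn, tp)                 # 2 0 1 1
```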
As input to forward and update the metric accepts the following input:
- preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside the [0, 1] range, the input is treated as logits and sigmoid is applied per element. The predictions are then converted to an int tensor by thresholding with the value in threshold.
- target (Tensor): An int tensor of shape (N, ...).
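The preprocessing described for preds can be sketched in plain Python as follows. This is an illustration under assumptions, not the library's code: binarize_preds is a hypothetical name, plain lists stand in for tensors, and the strict ">" comparison with threshold is an assumption of the sketch.

```python
import math


def binarize_preds(preds, threshold=0.5):
    """Hypothetical helper: map float binary predictions to 0/1 labels."""
    # If any value falls outside [0, 1], the input is treated as logits
    # and every element is passed through the sigmoid.
    if any(p < 0.0 or p > 1.0 for p in preds):
        preds = [1.0 / (1.0 + math.exp(-p)) for p in preds]
    # Assumption of this sketch: a strict ">" comparison with the threshold.
    return [int(p > threshold) for p in preds]


print(binarize_preds([0.35, 0.85, 0.48, 0.01]))  # probabilities -> [0, 1, 0, 0]
print(binarize_preds([-1.0, 2.0, -0.5, 3.0]))    # logits -> [0, 1, 0, 1]
```

Note that the logits check looks at the whole input at once: a single value outside [0, 1] switches every element to the sigmoid path.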
As output of forward and compute the metric returns the following output:
- confusion_matrix (Tensor): A tensor containing a (2, 2) matrix
Any additional dimensions ... are flattened into the batch dimension.
- Parameters:
  - threshold (float) – Threshold for transforming probabilities into binary (0, 1) predictions
  - ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
  - normalize (Optional[Literal['true', 'pred', 'all', 'none']]) – Normalization mode for the confusion matrix. Choose from:
    - None or 'none': no normalization (default)
    - 'true': normalization over the targets, i.e. each row is divided by its sum (most commonly used)
    - 'pred': normalization over the predictions, i.e. each column is divided by its sum
    - 'all': normalization over the whole matrix, i.e. every cell is divided by the total count
  - validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computation.
  - kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
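The normalize modes listed above can be illustrated with a small pure-Python sketch operating on a nested list. normalize_confmat is a hypothetical helper, not part of torchmetrics, and returning 0.0 for an empty row or column (0 / 0) is a choice made by this sketch only.

```python
def normalize_confmat(confmat, normalize=None):
    """Hypothetical helper applying the documented normalize modes to a 2D list."""
    if normalize is None or normalize == "none":
        return [row[:] for row in confmat]
    n_rows, n_cols = len(confmat), len(confmat[0])
    row_sums = [sum(row) for row in confmat]
    col_sums = [sum(confmat[i][j] for i in range(n_rows)) for j in range(n_cols)]
    total = sum(row_sums)

    def denom(i, j):
        if normalize == "true":
            return row_sums[i]   # number of targets in class i
        if normalize == "pred":
            return col_sums[j]   # number of predictions of class j
        if normalize == "all":
            return total         # number of observations
        raise ValueError(f"Unknown normalize mode: {normalize!r}")

    # Assumption of this sketch: an empty row/column (0 / 0) is reported as 0.0.
    return [
        [confmat[i][j] / denom(i, j) if denom(i, j) else 0.0 for j in range(n_cols)]
        for i in range(n_rows)
    ]


cm = [[2, 0], [1, 1]]
print(normalize_confmat(cm, "true"))  # [[1.0, 0.0], [0.5, 0.5]]
print(normalize_confmat(cm, "all"))   # [[0.5, 0.0], [0.25, 0.25]]
```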
- Example (preds is int tensor):
>>> import torch
>>> from torchmetrics.classification import BinaryConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> bcm = BinaryConfusionMatrix()
>>> bcm(preds, target)
tensor([[2, 0], [1, 1]])
- Example (preds is float tensor):
>>> import torch
>>> from torchmetrics.classification import BinaryConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01])
>>> bcm = BinaryConfusionMatrix()
>>> bcm(preds, target)
tensor([[2, 0], [1, 1]])
- plot(val=None, ax=None, add_text=True, labels=None, cmap=None)
Plot a single or multiple values from the metric.
- Parameters:
  - val (Optional[Tensor]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
  - ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
  - add_text (bool) – Whether the value of each cell should be added to the plot.
  - labels (Optional[list[str]]) – A list of strings which, if provided, is added to the plot to indicate the different classes.
  - cmap (Union[Colormap, str, None]) – Matplotlib colormap to use for the confusion matrix, see https://matplotlib.org/stable/users/explain/colors/colormaps.html
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> from torchmetrics.classification import BinaryConfusionMatrix
>>> metric = BinaryConfusionMatrix()
>>> metric.update(rand(10), randint(2, (10,)))
>>> fig_, ax_ = metric.plot()
MulticlassConfusionMatrix
- class torchmetrics.classification.MulticlassConfusionMatrix(num_classes, ignore_index=None, normalize=None, validate_args=True, **kwargs)
Compute the confusion matrix for multiclass tasks.
The confusion matrix \(C\) is constructed such that \(C_{i, j}\) is equal to the number of observations known to be in class \(i\) but predicted to be in class \(j\). Thus row indices of the confusion matrix correspond to the true class labels and column indices correspond to the predicted class labels.
For multiclass tasks, the confusion matrix is an \(N \times N\) matrix, where \(N\) is num_classes and the classes are indexed \(0, \dots, N-1\):
\(C_{i, i}\) represents the number of true positives for class \(i\)
\(\sum_{j=0, j\neq i}^{N-1} C_{i, j}\) (the rest of row \(i\)) represents the number of false negatives for class \(i\)
\(\sum_{j=0, j\neq i}^{N-1} C_{j, i}\) (the rest of column \(i\)) represents the number of false positives for class \(i\)
the sum of the remaining cells in the matrix represents the number of true negatives for class \(i\)
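The four per-class counts above can be read off a square matrix with a short pure-Python sketch. per_class_stats is a hypothetical helper used only for illustration; it operates on nested lists rather than tensors.

```python
def per_class_stats(confmat):
    """Return (tp, fp, fn, tn) for each class of a square confusion matrix."""
    n = len(confmat)
    total = sum(sum(row) for row in confmat)
    stats = []
    for i in range(n):
        tp = confmat[i][i]
        fn = sum(confmat[i][j] for j in range(n) if j != i)  # rest of row i
        fp = sum(confmat[j][i] for j in range(n) if j != i)  # rest of column i
        tn = total - tp - fn - fp                            # all remaining cells
        stats.append((tp, fp, fn, tn))
    return stats


cm = [[1, 1, 0], [0, 1, 0], [0, 0, 1]]
print(per_class_stats(cm))  # [(1, 0, 1, 2), (1, 1, 0, 2), (1, 0, 0, 3)]
```

Applied to the binary result [[2, 0], [1, 1]], the entry for class 1 gives TP=1, FP=0, FN=1, TN=2, matching the binary layout described earlier.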
As input to forward and update the metric accepts the following input:
- preds (Tensor): An int tensor of shape (N, ...) or a float tensor of shape (N, C, ...). If preds is a floating point tensor, torch.argmax is applied along the C dimension to convert probabilities/logits into an int tensor.
- target (Tensor): An int tensor of shape (N, ...).
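The argmax conversion for float preds can be mimicked in plain Python, which also shows why the float example for this class yields the same matrix as the integer one. argmax and multiclass_confmat are hypothetical helpers; breaking ties by taking the first index is a property of this sketch, not a claim about torch.argmax.

```python
def argmax(scores):
    """Index of the largest score (the first one wins on ties in this sketch)."""
    return max(range(len(scores)), key=lambda k: scores[k])


def multiclass_confmat(preds, target, num_classes):
    """Hypothetical helper: build a num_classes x num_classes count matrix."""
    # Float scores of shape (N, C) are reduced to labels with argmax over C;
    # int labels of shape (N,) are used as-is.
    if preds and isinstance(preds[0], (list, tuple)):
        preds = [argmax(row) for row in preds]
    confmat = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(preds, target):
        confmat[t][p] += 1  # row = true class, column = predicted class
    return confmat


target = [2, 1, 0, 0]
probs = [[0.16, 0.26, 0.58],
         [0.22, 0.61, 0.17],
         [0.71, 0.09, 0.20],
         [0.05, 0.82, 0.13]]
print([argmax(row) for row in probs])              # [2, 1, 0, 1]
print(multiclass_confmat(probs, target, num_classes=3))
# [[1, 1, 0], [0, 1, 0], [0, 0, 1]]
```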
As output of forward and compute the metric returns the following output:
- confusion_matrix: A [num_classes, num_classes] matrix
- Parameters:
  - num_classes (int) – Integer specifying the number of classes
  - ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
  - normalize (Optional[Literal['true', 'pred', 'all', 'none']]) – Normalization mode for the confusion matrix. Choose from:
    - None or 'none': no normalization (default)
    - 'true': normalization over the targets, i.e. each row is divided by its sum (most commonly used)
    - 'pred': normalization over the predictions, i.e. each column is divided by its sum
    - 'all': normalization over the whole matrix, i.e. every cell is divided by the total count
  - validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computation.
  - kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
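The effect of ignore_index can be shown with a pure-Python sketch: positions whose target equals ignore_index are skipped before counting. multiclass_confmat_ignore is a hypothetical helper, and using -1 as a padding marker is only an example value.

```python
def multiclass_confmat_ignore(preds, target, num_classes, ignore_index=None):
    """Hypothetical helper: multiclass counts that skip ignored targets."""
    confmat = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(preds, target):
        if ignore_index is not None and t == ignore_index:
            continue  # this position contributes to no cell at all
        confmat[t][p] += 1
    return confmat


target = [2, 1, 0, 0, -1]  # -1 marks e.g. a padded position
preds = [2, 1, 0, 1, 2]
print(multiclass_confmat_ignore(preds, target, 3, ignore_index=-1))
# [[1, 1, 0], [0, 1, 0], [0, 0, 1]]
```

Without the skip, a negative target such as -1 would be silently counted into the last row by Python's negative indexing, which is one reason to filter ignored positions before any indexing.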
- Example (preds is integer tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassConfusionMatrix(num_classes=3)
>>> metric(preds, target)
tensor([[1, 1, 0], [0, 1, 0], [0, 0, 1]])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassConfusionMatrix(num_classes=3)
>>> metric(preds, target)
tensor([[1, 1, 0], [0, 1, 0], [0, 0, 1]])
- plot(val=None, ax=None, add_text=True, labels=None, cmap=None)
Plot a single or multiple values from the metric.
- Parameters:
  - val (Optional[Tensor]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
  - ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
  - add_text (bool) – Whether the value of each cell should be added to the plot.
  - labels (Optional[list[str]]) – A list of strings which, if provided, is added to the plot to indicate the different classes.
  - cmap (Union[Colormap, str, None]) – Matplotlib colormap to use for the confusion matrix, see https://matplotlib.org/stable/users/explain/colors/colormaps.html
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randint
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> metric = MulticlassConfusionMatrix(num_classes=5)
>>> metric.update(randint(5, (20,)), randint(5, (20,)))
>>> fig_, ax_ = metric.plot()
MultilabelConfusionMatrix
- class torchmetrics.classification.MultilabelConfusionMatrix(num_labels, threshold=0.5, ignore_index=None, normalize=None, validate_args=True, **kwargs)
Compute the confusion matrix for multilabel tasks.
The confusion matrix \(C\) is constructed such that \(C_{i, j}\) is equal to the number of observations known to be in class \(i\) but predicted to be in class \(j\). Thus row indices of the confusion matrix correspond to the true class labels and column indices correspond to the predicted class labels.
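For multilabel tasks, the confusion matrix is a num_labels x 2 x 2 tensor, where each 2x2 matrix holds the confusion counts for one label: its rows index the true value (0 or 1) of that label and its columns the predicted value. The structure of each 2x2 matrix is as follows: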
\(C_{0, 0}\): True negatives
\(C_{0, 1}\): False positives
\(C_{1, 0}\): False negatives
\(C_{1, 1}\): True positives
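The per-label layout above can be reproduced with a pure-Python sketch that builds one [[TN, FP], [FN, TP]] block per label. multilabel_confmat is a hypothetical helper working on nested lists of 0/1 values; it reproduces the output of the integer example for this class.

```python
def multilabel_confmat(preds, target, num_labels):
    """Hypothetical helper: one [[TN, FP], [FN, TP]] block per label."""
    confmat = [[[0, 0], [0, 0]] for _ in range(num_labels)]
    for pred_row, target_row in zip(preds, target):  # loop over samples (N)
        for k in range(num_labels):                   # loop over labels (C)
            # Inside block k: row = true value of label k, column = predicted value.
            confmat[k][target_row[k]][pred_row[k]] += 1
    return confmat


target = [[0, 1, 0], [1, 0, 1]]
preds = [[0, 0, 1], [1, 0, 1]]
print(multilabel_confmat(preds, target, num_labels=3))
# [[[1, 0], [0, 1]], [[1, 0], [1, 0]], [[0, 1], [0, 1]]]
```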
As input to update the metric accepts the following input:
- preds (int or float tensor): (N, C, ...), where C is the number of labels. If preds is a floating point tensor with values outside the [0, 1] range, the input is treated as logits and sigmoid is applied per element. The predictions are then converted to an int tensor by thresholding with the value in threshold.
- target (int tensor): (N, C, ...)
As output of compute the metric returns the following output:
- confusion matrix: A [num_labels, 2, 2] matrix
- Parameters:
  - num_labels (int) – Integer specifying the number of labels
  - threshold (float) – Threshold for transforming probabilities into binary (0, 1) predictions
  - ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
  - normalize (Optional[Literal['true', 'pred', 'all', 'none']]) – Normalization mode for the confusion matrix. Choose from:
    - None or 'none': no normalization (default)
    - 'true': normalization over the targets, i.e. each row is divided by its sum (most commonly used)
    - 'pred': normalization over the predictions, i.e. each column is divided by its sum
    - 'all': normalization over the whole matrix, i.e. every cell is divided by the total count
  - validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computation.
  - kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
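A pure-Python sketch of the normalize modes applied to a [num_labels, 2, 2] result follows. It assumes, as the per-label output shape suggests, that each label's 2x2 block is normalized on its own; treat that per-label behavior, the hypothetical helper name normalize_per_label, and reporting 0 / 0 as 0.0 as assumptions of the sketch.

```python
def normalize_per_label(confmat, normalize=None):
    """Hypothetical helper: normalize each label's [[TN, FP], [FN, TP]] block on its own."""
    if normalize is None or normalize == "none":
        return [[row[:] for row in block] for block in confmat]
    if normalize not in ("true", "pred", "all"):
        raise ValueError(f"Unknown normalize mode: {normalize!r}")
    out = []
    for block in confmat:
        row_sums = [block[0][0] + block[0][1], block[1][0] + block[1][1]]
        col_sums = [block[0][0] + block[1][0], block[0][1] + block[1][1]]
        total = row_sums[0] + row_sums[1]
        normalized = []
        for i in range(2):
            row = []
            for j in range(2):
                d = {"true": row_sums[i], "pred": col_sums[j], "all": total}[normalize]
                # Assumption of this sketch: 0 / 0 is reported as 0.0.
                row.append(block[i][j] / d if d else 0.0)
            normalized.append(row)
        out.append(normalized)
    return out


cm = [[[1, 0], [0, 1]], [[1, 0], [1, 0]], [[0, 1], [0, 1]]]
print(normalize_per_label(cm, "pred"))
# [[[1.0, 0.0], [0.0, 1.0]], [[0.5, 0.0], [0.5, 0.0]], [[0.0, 0.5], [0.0, 0.5]]]
```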
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelConfusionMatrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelConfusionMatrix(num_labels=3)
>>> metric(preds, target)
tensor([[[1, 0], [0, 1]], [[1, 0], [1, 0]], [[0, 1], [0, 1]]])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelConfusionMatrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelConfusionMatrix(num_labels=3)
>>> metric(preds, target)
tensor([[[1, 0], [0, 1]], [[1, 0], [1, 0]], [[0, 1], [0, 1]]])
- plot(val=None, ax=None, add_text=True, labels=None, cmap=None)
Plot a single or multiple values from the metric.
- Parameters:
  - val (Optional[Tensor]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
  - ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
  - add_text (bool) – Whether the value of each cell should be added to the plot.
  - labels (Optional[list[str]]) – A list of strings which, if provided, is added to the plot to indicate the different classes.
  - cmap (Union[Colormap, str, None]) – Matplotlib colormap to use for the confusion matrix, see https://matplotlib.org/stable/users/explain/colors/colormaps.html
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randint
>>> from torchmetrics.classification import MultilabelConfusionMatrix
>>> metric = MultilabelConfusionMatrix(num_labels=3)
>>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
Functional Interface
confusion_matrix
- torchmetrics.functional.confusion_matrix(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, normalize=None, ignore_index=None, validate_args=True)
Compute the confusion matrix.
This function is a simple wrapper that calls the task-specific version of this metric, selected by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_confusion_matrix(), multiclass_confusion_matrix() and multilabel_confusion_matrix() for how each argument affects the result and for examples.
- Return type:
Tensor
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics.functional import confusion_matrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> confusion_matrix(preds, target, task="binary")
tensor([[2, 0], [1, 1]])
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> confusion_matrix(preds, target, task="multiclass", num_classes=3)
tensor([[1, 1, 0], [0, 1, 0], [0, 0, 1]])
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> confusion_matrix(preds, target, task="multilabel", num_labels=3)
tensor([[[1, 0], [0, 1]], [[1, 0], [1, 0]], [[0, 1], [0, 1]]])
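The task dispatch can be pictured with a small pure-Python sketch. confusion_matrix_sketch and _count_pairs are hypothetical names, not torchmetrics functions; the sketch works on integer labels only (no thresholding or argmax), and its ValueError messages are its own. It shows which extra argument each task needs: nothing for 'binary', num_classes for 'multiclass', num_labels for 'multilabel'.

```python
def _count_pairs(pairs, size):
    """Count (true, predicted) pairs into a size x size matrix."""
    confmat = [[0] * size for _ in range(size)]
    for t, p in pairs:
        confmat[t][p] += 1
    return confmat


def confusion_matrix_sketch(preds, target, task, num_classes=None, num_labels=None):
    """Hypothetical dispatcher over integer labels."""
    if task == "binary":
        return _count_pairs(zip(target, preds), 2)
    if task == "multiclass":
        if not isinstance(num_classes, int):
            raise ValueError("task='multiclass' needs an integer num_classes")
        return _count_pairs(zip(target, preds), num_classes)
    if task == "multilabel":
        if not isinstance(num_labels, int):
            raise ValueError("task='multilabel' needs an integer num_labels")
        # One 2x2 block per label, built from column k of target and preds.
        return [
            _count_pairs([(t_row[k], p_row[k]) for t_row, p_row in zip(target, preds)], 2)
            for k in range(num_labels)
        ]
    raise ValueError(f"task must be 'binary', 'multiclass' or 'multilabel', got {task!r}")


print(confusion_matrix_sketch([0, 1, 0, 0], [1, 1, 0, 0], task="binary"))
# [[2, 0], [1, 1]]
```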
binary_confusion_matrix
- torchmetrics.functional.classification.binary_confusion_matrix(preds, target, threshold=0.5, normalize=None, ignore_index=None, validate_args=True)
Compute the confusion matrix for binary tasks.
Accepts the following input tensors:
- preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0, 1] range, the input is treated as logits and sigmoid is applied per element. The predictions are then converted to an int tensor by thresholding with the value in threshold.
- target (int tensor): (N, ...)
Any additional dimensions ... are flattened into the batch dimension.
- Parameters:
  - threshold (float) – Threshold for transforming probabilities into binary (0, 1) predictions
  - normalize (Optional[Literal['true', 'pred', 'all', 'none']]) – Normalization mode for the confusion matrix. Choose from:
    - None or 'none': no normalization (default)
    - 'true': normalization over the targets, i.e. each row is divided by its sum (most commonly used)
    - 'pred': normalization over the predictions, i.e. each column is divided by its sum
    - 'all': normalization over the whole matrix, i.e. every cell is divided by the total count
  - ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
  - validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computation.
- Return type:
Tensor
- Returns:
A [2, 2] tensor
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_confusion_matrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> binary_confusion_matrix(preds, target)
tensor([[2, 0], [1, 1]])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_confusion_matrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0.35, 0.85, 0.48, 0.01])
>>> binary_confusion_matrix(preds, target)
tensor([[2, 0], [1, 1]])
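Flattening the extra dimensions into the batch dimension means a (N, ...) input is counted exactly like its flattened (N * ...,) version. The pure-Python sketch below shows this with nested lists; flatten and binary_confmat_nd are hypothetical helpers, and the inputs are assumed to be 0/1 integers already.

```python
def flatten(x):
    """Recursively flatten nested lists into one flat list (the "batch")."""
    if isinstance(x, (list, tuple)):
        return [v for item in x for v in flatten(item)]
    return [x]


def binary_confmat_nd(preds, target):
    """Hypothetical helper: binary counts for 0/1 labels of any nested shape (N, ...)."""
    confmat = [[0, 0], [0, 0]]
    for p, t in zip(flatten(preds), flatten(target)):
        confmat[t][p] += 1  # row = true label, column = predicted label
    return confmat


# Shape (N, 2) with N = 2: two samples with two positions each.
target = [[1, 1], [0, 0]]
preds = [[0, 1], [0, 0]]
print(binary_confmat_nd(preds, target))  # [[2, 0], [1, 1]]
```

The result equals the one from the flat example above, since both inputs flatten to target [1, 1, 0, 0] and preds [0, 1, 0, 0].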
multiclass_confusion_matrix
- torchmetrics.functional.classification.multiclass_confusion_matrix(preds, target, num_classes, normalize=None, ignore_index=None, validate_args=True)
Compute the confusion matrix for multiclass tasks.
Accepts the following input tensors:
- preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor, torch.argmax is applied along the C dimension to automatically convert probabilities/logits into an int tensor.
- target (int tensor): (N, ...)
Any additional dimensions ... are flattened into the batch dimension.
- Parameters:
  - num_classes (int) – Integer specifying the number of classes
  - normalize (Optional[Literal['true', 'pred', 'all', 'none']]) – Normalization mode for the confusion matrix. Choose from:
    - None or 'none': no normalization (default)
    - 'true': normalization over the targets, i.e. each row is divided by its sum (most commonly used)
    - 'pred': normalization over the predictions, i.e. each column is divided by its sum
    - 'all': normalization over the whole matrix, i.e. every cell is divided by the total count
  - ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
  - validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computation.
- Return type:
Tensor
- Returns:
A [num_classes, num_classes] tensor
- Example (preds is integer tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_confusion_matrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_confusion_matrix(preds, target, num_classes=3)
tensor([[1, 1, 0], [0, 1, 0], [0, 0, 1]])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_confusion_matrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_confusion_matrix(preds, target, num_classes=3)
tensor([[1, 1, 0], [0, 1, 0], [0, 0, 1]])
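One common way to fill the matrix without a per-cell loop is to encode each (true, predicted) pair as a single integer t * num_classes + p, count the codes, and reshape the counts row-major. In tensor code the counting step is typically torch.bincount(codes, minlength=num_classes ** 2) followed by a reshape; whether torchmetrics uses exactly this is not stated here. The sketch below uses collections.Counter in place of bincount, and multiclass_confmat_flat is a hypothetical name.

```python
from collections import Counter


def multiclass_confmat_flat(preds, target, num_classes):
    """Hypothetical helper: num_classes x num_classes counts via a single flat index."""
    # Encode each (true t, predicted p) pair as one integer in [0, num_classes**2).
    codes = [t * num_classes + p for p, t in zip(preds, target)]
    counts = Counter(codes)
    flat = [counts[i] for i in range(num_classes * num_classes)]  # missing codes count as 0
    # Row-major reshape: row = true class, column = predicted class.
    return [flat[r * num_classes:(r + 1) * num_classes] for r in range(num_classes)]


target = [2, 1, 0, 0]
preds = [2, 1, 0, 1]
print(multiclass_confmat_flat(preds, target, num_classes=3))
# [[1, 1, 0], [0, 1, 0], [0, 0, 1]]
```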
multilabel_confusion_matrix
- torchmetrics.functional.classification.multilabel_confusion_matrix(preds, target, num_labels, threshold=0.5, normalize=None, ignore_index=None, validate_args=True)
Compute the confusion matrix for multilabel tasks.
Accepts the following input tensors:
- preds (int or float tensor): (N, C, ...), where C is the number of labels. If preds is a floating point tensor with values outside the [0, 1] range, the input is treated as logits and sigmoid is applied per element. The predictions are then converted to an int tensor by thresholding with the value in threshold.
- target (int tensor): (N, C, ...)
Any additional dimensions ... are flattened into the batch dimension.
- Parameters:
  - num_labels (int) – Integer specifying the number of labels
  - threshold (float) – Threshold for transforming probabilities into binary (0, 1) predictions
  - normalize (Optional[Literal['true', 'pred', 'all', 'none']]) – Normalization mode for the confusion matrix. Choose from:
    - None or 'none': no normalization (default)
    - 'true': normalization over the targets, i.e. each row is divided by its sum (most commonly used)
    - 'pred': normalization over the predictions, i.e. each column is divided by its sum
    - 'all': normalization over the whole matrix, i.e. every cell is divided by the total count
  - ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
  - validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computation.
- Return type:
Tensor
- Returns:
A [num_labels, 2, 2] tensor
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_confusion_matrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_confusion_matrix(preds, target, num_labels=3)
tensor([[[1, 0], [0, 1]], [[1, 0], [1, 0]], [[0, 1], [0, 1]]])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_confusion_matrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_confusion_matrix(preds, target, num_labels=3)
tensor([[[1, 0], [0, 1]], [[1, 0], [1, 0]], [[0, 1], [0, 1]]])
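The same flat-index counting idea extends to the multilabel case: each (label k, true t, predicted p) triple gets the code 4 * k + 2 * t + p, so counting codes and slicing them four at a time yields the [num_labels, 2, 2] layout. This is a pure-Python sketch with the hypothetical name multilabel_confmat_flat, using collections.Counter where tensor code would typically use a bincount; it is not presented as the library's implementation.

```python
from collections import Counter


def multilabel_confmat_flat(preds, target, num_labels):
    """Hypothetical helper: [num_labels, 2, 2] counts via a single flat index."""
    # Each (label k, true t, predicted p) triple maps to a unique code 4*k + 2*t + p.
    codes = [
        4 * k + 2 * t_row[k] + p_row[k]
        for p_row, t_row in zip(preds, target)
        for k in range(num_labels)
    ]
    counts = Counter(codes)
    flat = [counts[i] for i in range(4 * num_labels)]  # missing codes count as 0
    # Reshape to [num_labels, 2, 2]; block k is [[TN, FP], [FN, TP]] for label k.
    return [[flat[4 * k:4 * k + 2], flat[4 * k + 2:4 * k + 4]] for k in range(num_labels)]


target = [[0, 1, 0], [1, 0, 1]]
preds = [[0, 0, 1], [1, 0, 1]]
print(multilabel_confmat_flat(preds, target, num_labels=3))
# [[[1, 0], [0, 1]], [[1, 0], [1, 0]], [[0, 1], [0, 1]]]
```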