Confusion Matrix
Module Interface
- class torchmetrics.ConfusionMatrix(**kwargs)[source]
Compute the confusion matrix.
This class is a simple wrapper that returns the task-specific version of this metric, selected by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryConfusionMatrix, MulticlassConfusionMatrix and MultilabelConfusionMatrix for the specific details of how each argument influences the metric, and for examples.
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics import ConfusionMatrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(task="binary", num_classes=2)
>>> confmat(preds, target)
tensor([[2, 0],
        [1, 1]])

>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> confmat = ConfusionMatrix(task="multiclass", num_classes=3)
>>> confmat(preds, target)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])

>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> confmat = ConfusionMatrix(task="multilabel", num_labels=3)
>>> confmat(preds, target)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
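The three task settings differ mainly in the shape of the matrix they return. The helper below is a hypothetical, dependency-free sketch (expected_confmat_shape is not part of torchmetrics); it only restates the output shapes documented for the task-specific classes on this page.

```python
from typing import Optional, Tuple


def expected_confmat_shape(
    task: str,
    num_classes: Optional[int] = None,
    num_labels: Optional[int] = None,
) -> Tuple[int, ...]:
    """Return the documented output shape for a task (hypothetical helper)."""
    if task == "binary":
        return (2, 2)
    if task == "multiclass":
        if num_classes is None:
            raise ValueError("num_classes is required for task='multiclass'")
        return (num_classes, num_classes)
    if task == "multilabel":
        if num_labels is None:
            raise ValueError("num_labels is required for task='multilabel'")
        return (num_labels, 2, 2)
    raise ValueError(f"Unsupported task: {task!r}")


print(expected_confmat_shape("binary"))
print(expected_confmat_shape("multiclass", num_classes=3))
print(expected_confmat_shape("multilabel", num_labels=3))
```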
BinaryConfusionMatrix
- class torchmetrics.classification.BinaryConfusionMatrix(threshold=0.5, ignore_index=None, normalize=None, validate_args=True, **kwargs)[source]
Compute the confusion matrix for binary tasks.
As input to forward and update the metric accepts the following input:
- preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside the [0,1] range, the input is treated as logits and sigmoid is applied per element. Additionally, the result is converted to an int tensor by thresholding with the value in threshold.
- target (Tensor): An int tensor of shape (N, ...).
As output of forward and compute the metric returns the following output:
- confusion_matrix (Tensor): A tensor containing a (2, 2) matrix
Additional dimensions ... will be flattened into the batch dimension.
- Parameters:
threshold (float) – Threshold for transforming probability to binary (0,1) predictions
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
normalize (Optional[Literal['true', 'pred', 'all', 'none']]) – Normalization mode for confusion matrix. Choose from:
  - None or 'none': no normalization (default)
  - 'true': normalization over the targets (most commonly used)
  - 'pred': normalization over the predictions
  - 'all': normalization over the whole matrix
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
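To make the normalize modes concrete, here is a minimal pure-Python sketch. It assumes rows index the target and columns index the prediction, which is the layout the examples on this page follow. The function name normalize_confmat and the choice to map a zero denominator to 0.0 are assumptions made for illustration, not the library's code.

```python
from typing import List, Optional

Matrix = List[List[float]]


def normalize_confmat(confmat: Matrix, normalize: Optional[str] = None) -> Matrix:
    """Normalize a square matrix with rows = targets and columns = predictions."""
    n = len(confmat)
    if normalize is None or normalize == "none":
        return [[float(v) for v in row] for row in confmat]
    if normalize == "true":
        # Each row is divided by the number of samples with that target.
        row_sums = [sum(row) for row in confmat]
        return [
            [v / row_sums[i] if row_sums[i] else 0.0 for v in row]
            for i, row in enumerate(confmat)
        ]
    if normalize == "pred":
        # Each column is divided by the number of samples with that prediction.
        col_sums = [sum(confmat[i][j] for i in range(n)) for j in range(n)]
        return [
            [v / col_sums[j] if col_sums[j] else 0.0 for j, v in enumerate(row)]
            for row in confmat
        ]
    if normalize == "all":
        # Every cell is divided by the total number of samples.
        total = sum(sum(row) for row in confmat)
        return [[v / total if total else 0.0 for v in row] for row in confmat]
    raise ValueError(f"Unknown normalize mode: {normalize!r}")


counts = [[2, 0], [1, 1]]
for mode in (None, "true", "pred", "all"):
    print(mode, normalize_confmat(counts, mode))
```

With 'true', each row sums to 1 and the diagonal can be read as per-class recall; with 'pred', each column sums to 1 and the diagonal can be read as per-class precision.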
- Example (preds is int tensor):
>>> import torch
>>> from torchmetrics.classification import BinaryConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> bcm = BinaryConfusionMatrix()
>>> bcm(preds, target)
tensor([[2, 0],
        [1, 1]])
- Example (preds is float tensor):
>>> import torch
>>> from torchmetrics.classification import BinaryConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01])
>>> bcm = BinaryConfusionMatrix()
>>> bcm(preds, target)
tensor([[2, 0],
        [1, 1]])
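The preprocessing described for preds (optional sigmoid, then thresholding) can be sketched without any dependencies. This is an illustration under assumptions, not the library's implementation: the name binary_confmat_sketch is hypothetical, the comparison is assumed to be a strict greater-than, and sigmoid is assumed to be applied to every element as soon as any value falls outside [0, 1].

```python
import math
from typing import List, Sequence


def binary_confmat_sketch(
    preds: Sequence[float],
    target: Sequence[int],
    threshold: float = 0.5,
) -> List[List[int]]:
    """Count a 2x2 matrix with rows = target and columns = prediction."""
    # Assumption: values outside [0, 1] mean the inputs are logits.
    if any(p < 0.0 or p > 1.0 for p in preds):
        preds = [1.0 / (1.0 + math.exp(-p)) for p in preds]
    # Assumption: a prediction is positive when it is strictly above threshold.
    hard = [int(p > threshold) for p in preds]
    confmat = [[0, 0], [0, 0]]
    for t, p in zip(target, hard):
        confmat[t][p] += 1
    return confmat


target = [1, 1, 0, 0]
print(binary_confmat_sketch([0.35, 0.85, 0.48, 0.01], target))  # probabilities
print(binary_confmat_sketch([-0.6, 1.7, -0.1, -4.6], target))   # logits
```

Both calls yield the same counts as the documented examples, because the logits in the second call map to probabilities on the same side of the 0.5 threshold.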
- plot(val=None, ax=None, add_text=True, labels=None)[source]
Plot a single or multiple values from the metric.
- Parameters:
val (Optional[Tensor]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis
add_text (bool) – Whether the value of each cell should be added to the plot
labels (Optional[List[str]]) – A list of strings; if provided, they will be added to the plot to indicate the different classes
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randint
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> metric = MulticlassConfusionMatrix(num_classes=5)
>>> metric.update(randint(5, (20,)), randint(5, (20,)))
>>> fig_, ax_ = metric.plot()
MulticlassConfusionMatrix
- class torchmetrics.classification.MulticlassConfusionMatrix(num_classes, ignore_index=None, normalize=None, validate_args=True, **kwargs)[source]
Compute the confusion matrix for multiclass tasks.
As input to forward and update the metric accepts the following input:
- preds (Tensor): An int tensor of shape (N, ...) or a float tensor of shape (N, C, ...). If preds is a floating point tensor, torch.argmax is applied along the C dimension to automatically convert probabilities/logits into an int tensor.
- target (Tensor): An int tensor of shape (N, ...).
As output of forward and compute the metric returns the following output:
- confusion_matrix: [num_classes, num_classes] matrix
- Parameters:
num_classes (int) – Integer specifying the number of classes
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
normalize (Optional[Literal['true', 'pred', 'all', 'none']]) – Normalization mode for confusion matrix. Choose from:
  - None or 'none': no normalization (default)
  - 'true': normalization over the targets (most commonly used)
  - 'pred': normalization over the predictions
  - 'all': normalization over the whole matrix
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Example (pred is integer tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassConfusionMatrix(num_classes=3)
>>> metric(preds, target)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
- Example (pred is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassConfusionMatrix(num_classes=3)
>>> metric(preds, target)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
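The integer and float examples above give the same matrix because each row of float scores reduces to the index of its largest entry. The dependency-free sketch below illustrates that reduction and the counting step; multiclass_confmat_sketch is a hypothetical name, and the code is an illustration of the documented behaviour rather than the library's implementation.

```python
from typing import List, Sequence, Union

Pred = Union[int, Sequence[float]]


def multiclass_confmat_sketch(
    preds: Sequence[Pred],
    target: Sequence[int],
    num_classes: int,
) -> List[List[int]]:
    """Count a matrix with rows = target class and columns = predicted class."""
    hard = []
    for p in preds:
        if isinstance(p, (list, tuple)):
            # Float scores per class: take the argmax over the class dimension.
            hard.append(max(range(len(p)), key=p.__getitem__))
        else:
            # Already an integer class index.
            hard.append(int(p))
    confmat = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(target, hard):
        confmat[t][p] += 1
    return confmat


target = [2, 1, 0, 0]
scores = [
    [0.16, 0.26, 0.58],
    [0.22, 0.61, 0.17],
    [0.71, 0.09, 0.20],
    [0.05, 0.82, 0.13],
]
print(multiclass_confmat_sketch(scores, target, num_classes=3))
print(multiclass_confmat_sketch([2, 1, 0, 1], target, num_classes=3))
```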
- plot(val=None, ax=None, add_text=True, labels=None)[source]
Plot a single or multiple values from the metric.
- Parameters:
val (Optional[Tensor]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis
add_text (bool) – Whether the value of each cell should be added to the plot
labels (Optional[List[str]]) – A list of strings; if provided, they will be added to the plot to indicate the different classes
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randint
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> metric = MulticlassConfusionMatrix(num_classes=5)
>>> metric.update(randint(5, (20,)), randint(5, (20,)))
>>> fig_, ax_ = metric.plot()
MultilabelConfusionMatrix
- class torchmetrics.classification.MultilabelConfusionMatrix(num_labels, threshold=0.5, ignore_index=None, normalize=None, validate_args=True, **kwargs)[source]
Compute the confusion matrix for multilabel tasks.
As input to 'update' the metric accepts the following input:
- preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range, the input is treated as logits and sigmoid is applied per element. Additionally, the result is converted to an int tensor by thresholding with the value in threshold.
- target (int tensor): (N, C, ...)
As output of 'compute' the metric returns the following output:
- confusion matrix: [num_labels, 2, 2] matrix
- Parameters:
num_labels (int) – Integer specifying the number of labels
threshold (float) – Threshold for transforming probability to binary (0,1) predictions
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
normalize (Optional[Literal['true', 'pred', 'all', 'none']]) – Normalization mode for confusion matrix. Choose from:
  - None or 'none': no normalization (default)
  - 'true': normalization over the targets (most commonly used)
  - 'pred': normalization over the predictions
  - 'all': normalization over the whole matrix
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelConfusionMatrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelConfusionMatrix(num_labels=3)
>>> metric(preds, target)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelConfusionMatrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelConfusionMatrix(num_labels=3)
>>> metric(preds, target)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
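In the multilabel case each label gets its own independent 2x2 matrix. The sketch below reproduces the example above in plain Python. It assumes preds are already probabilities or 0/1 values (the sigmoid step for logits is left out) and that thresholding is a strict greater-than; multilabel_confmat_sketch is a hypothetical name, not a library function.

```python
from typing import List, Sequence


def multilabel_confmat_sketch(
    preds: Sequence[Sequence[float]],
    target: Sequence[Sequence[int]],
    num_labels: int,
    threshold: float = 0.5,
) -> List[List[List[int]]]:
    """Build one 2x2 matrix per label, rows = target and columns = prediction."""
    confmat = [[[0, 0], [0, 0]] for _ in range(num_labels)]
    for pred_row, target_row in zip(preds, target):
        for label in range(num_labels):
            p = int(pred_row[label] > threshold)  # assumed strict comparison
            t = target_row[label]
            confmat[label][t][p] += 1
    return confmat


target = [[0, 1, 0], [1, 0, 1]]
preds = [[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]
print(multilabel_confmat_sketch(preds, target, num_labels=3))
```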
- plot(val=None, ax=None, add_text=True, labels=None)[source]
Plot a single or multiple values from the metric.
- Parameters:
val (Optional[Tensor]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis
add_text (bool) – Whether the value of each cell should be added to the plot
labels (Optional[List[str]]) – A list of strings; if provided, they will be added to the plot to indicate the different classes
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randint
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> metric = MulticlassConfusionMatrix(num_classes=5)
>>> metric.update(randint(5, (20,)), randint(5, (20,)))
>>> fig_, ax_ = metric.plot()
Functional Interface
confusion_matrix
- torchmetrics.functional.confusion_matrix(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, normalize=None, ignore_index=None, validate_args=True)[source]
Compute the confusion matrix.
This function is a simple wrapper that calls the task-specific version of this metric, selected by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_confusion_matrix(), multiclass_confusion_matrix() and multilabel_confusion_matrix() for the specific details of how each argument influences the metric, and for examples.
- Return type:
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics.classification import ConfusionMatrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(task="binary")
>>> confmat(preds, target)
tensor([[2, 0],
        [1, 1]])

>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> confmat = ConfusionMatrix(task="multiclass", num_classes=3)
>>> confmat(preds, target)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])

>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> confmat = ConfusionMatrix(task="multilabel", num_labels=3)
>>> confmat(preds, target)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
binary_confusion_matrix
- torchmetrics.functional.classification.binary_confusion_matrix(preds, target, threshold=0.5, normalize=None, ignore_index=None, validate_args=True)[source]
Compute the confusion matrix for binary tasks.
Accepts the following input tensors:
- preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0,1] range, the input is treated as logits and sigmoid is applied per element. Additionally, the result is converted to an int tensor by thresholding with the value in threshold.
- target (int tensor): (N, ...)
Additional dimensions ... will be flattened into the batch dimension.
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
threshold (float) – Threshold for transforming probability to binary (0,1) predictions
normalize (Optional[Literal['true', 'pred', 'all', 'none']]) – Normalization mode for confusion matrix. Choose from:
  - None or 'none': no normalization (default)
  - 'true': normalization over the targets (most commonly used)
  - 'pred': normalization over the predictions
  - 'all': normalization over the whole matrix
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
- Returns:
A [2, 2] tensor
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_confusion_matrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> binary_confusion_matrix(preds, target)
tensor([[2, 0],
        [1, 1]])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_confusion_matrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0.35, 0.85, 0.48, 0.01])
>>> binary_confusion_matrix(preds, target)
tensor([[2, 0],
        [1, 1]])
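The note that additional dimensions are flattened into the batch dimension means that a (2, 2) input is treated like a batch of four samples. A small dependency-free sketch of that idea follows; the flatten helper is hypothetical and assumes row-major order.

```python
from typing import Any, Iterator


def flatten(nested: Any) -> Iterator[Any]:
    """Yield the leaves of a nested list in row-major order."""
    if isinstance(nested, (list, tuple)):
        for item in nested:
            yield from flatten(item)
    else:
        yield nested


# Inputs of shape (2, 2): two batch items with two extra positions each.
preds = [[0, 1], [0, 0]]
target = [[1, 1], [0, 0]]

flat_preds = list(flatten(preds))    # four predictions
flat_target = list(flatten(target))  # four targets

confmat = [[0, 0], [0, 0]]
for t, p in zip(flat_target, flat_preds):
    confmat[t][p] += 1  # rows = target, columns = prediction
print(confmat)
```

After flattening, the inputs match the 1-D example above, so the counts are the same.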
multiclass_confusion_matrix
- torchmetrics.functional.classification.multiclass_confusion_matrix(preds, target, num_classes, normalize=None, ignore_index=None, validate_args=True)[source]
Compute the confusion matrix for multiclass tasks.
Accepts the following input tensors:
- preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor, torch.argmax is applied along the C dimension to automatically convert probabilities/logits into an int tensor.
- target (int tensor): (N, ...)
Additional dimensions ... will be flattened into the batch dimension.
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_classes (int) – Integer specifying the number of classes
normalize (Optional[Literal['true', 'pred', 'all', 'none']]) – Normalization mode for confusion matrix. Choose from:
  - None or 'none': no normalization (default)
  - 'true': normalization over the targets (most commonly used)
  - 'pred': normalization over the predictions
  - 'all': normalization over the whole matrix
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
- Returns:
A [num_classes, num_classes] tensor
- Example (pred is integer tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_confusion_matrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_confusion_matrix(preds, target, num_classes=3)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
- Example (pred is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_confusion_matrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_confusion_matrix(preds, target, num_classes=3)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
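The ignore_index parameter is described as a target value that does not contribute to the metric. One way to read that, sketched here in plain Python, is that samples whose target equals ignore_index are skipped before counting. The name confmat_with_ignore and the choice of -1 as the ignored value are assumptions for illustration; the library's exact handling may differ between versions.

```python
from typing import List, Optional, Sequence


def confmat_with_ignore(
    preds: Sequence[int],
    target: Sequence[int],
    num_classes: int,
    ignore_index: Optional[int] = None,
) -> List[List[int]]:
    """Count rows = target, columns = prediction, skipping ignored targets."""
    confmat = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(target, preds):
        if ignore_index is not None and t == ignore_index:
            continue  # this sample contributes nothing to the matrix
        confmat[t][p] += 1
    return confmat


preds = [2, 1, 0, 1]
target = [2, 1, 0, -1]  # the last sample is marked with the ignored value
print(confmat_with_ignore(preds, target, num_classes=3, ignore_index=-1))
```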
multilabel_confusion_matrix
- torchmetrics.functional.classification.multilabel_confusion_matrix(preds, target, num_labels, threshold=0.5, normalize=None, ignore_index=None, validate_args=True)[source]
Compute the confusion matrix for multilabel tasks.
Accepts the following input tensors:
- preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range, the input is treated as logits and sigmoid is applied per element. Additionally, the result is converted to an int tensor by thresholding with the value in threshold.
- target (int tensor): (N, C, ...)
Additional dimensions ... will be flattened into the batch dimension.
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_labels (int) – Integer specifying the number of labels
threshold (float) – Threshold for transforming probability to binary (0,1) predictions
normalize (Optional[Literal['true', 'pred', 'all', 'none']]) – Normalization mode for confusion matrix. Choose from:
  - None or 'none': no normalization (default)
  - 'true': normalization over the targets (most commonly used)
  - 'pred': normalization over the predictions
  - 'all': normalization over the whole matrix
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
- Returns:
A [num_labels, 2, 2] tensor
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_confusion_matrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_confusion_matrix(preds, target, num_labels=3)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_confusion_matrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_confusion_matrix(preds, target, num_labels=3)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
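Each 2x2 block of the [num_labels, 2, 2] result can be read as [[TN, FP], [FN, TP]], given that rows index the target and columns index the prediction, as the examples on this page suggest. The sketch below unpacks the example output, written as a nested list, into named counts; unpack_counts is a hypothetical helper, not part of the library.

```python
from typing import Dict, List


def unpack_counts(confmat: List[List[List[int]]]) -> List[Dict[str, int]]:
    """Turn each per-label 2x2 matrix into named TN/FP/FN/TP counts."""
    counts = []
    for (tn, fp), (fn, tp) in confmat:
        counts.append({"tn": tn, "fp": fp, "fn": fn, "tp": tp})
    return counts


# The example output above, written as a plain nested list.
confmat = [[[1, 0], [0, 1]], [[1, 0], [1, 0]], [[0, 1], [0, 1]]]
for label, c in enumerate(unpack_counts(confmat)):
    print(label, c)
```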