Confusion Matrix
Module Interface
- class torchmetrics.ConfusionMatrix(**kwargs)
Compute the confusion matrix.
This class is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryConfusionMatrix, MulticlassConfusionMatrix and MultilabelConfusionMatrix for the specific details of how each argument influences the result, and for examples.
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics import ConfusionMatrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(task="binary", num_classes=2)
>>> confmat(preds, target)
tensor([[2, 0],
        [1, 1]])
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> confmat = ConfusionMatrix(task="multiclass", num_classes=3)
>>> confmat(preds, target)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> confmat = ConfusionMatrix(task="multilabel", num_labels=3)
>>> confmat(preds, target)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
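The task argument determines which task-specific class is built, and therefore the shape of the result. As a rough sketch, the hypothetical helper below (not part of torchmetrics) records the output shape each task value produces, taken from the output shapes documented for the three task-specific classes further down. In all cases rows index targets and columns index predictions.

```python
from typing import Optional, Tuple


def expected_confmat_shape(
    task: str,
    num_classes: Optional[int] = None,
    num_labels: Optional[int] = None,
) -> Tuple[int, ...]:
    """Hypothetical helper: shape of the confusion matrix for each `task` value."""
    if task == "binary":
        # A single 2x2 matrix; num_classes and num_labels are not needed.
        return (2, 2)
    if task == "multiclass":
        if num_classes is None:
            raise ValueError("task='multiclass' requires num_classes")
        # One row per target class, one column per predicted class.
        return (num_classes, num_classes)
    if task == "multilabel":
        if num_labels is None:
            raise ValueError("task='multilabel' requires num_labels")
        # One independent 2x2 matrix per label.
        return (num_labels, 2, 2)
    raise ValueError(f"unknown task: {task!r}")


print(expected_confmat_shape("multiclass", num_classes=3))  # (3, 3)
```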
BinaryConfusionMatrix
- class torchmetrics.classification.BinaryConfusionMatrix(threshold=0.5, ignore_index=None, normalize=None, validate_args=True, **kwargs)
Compute the confusion matrix for binary tasks.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside the [0, 1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (Tensor): An int tensor of shape (N, ...).
As output to forward and compute the metric returns the following output:
confusion_matrix (Tensor): A tensor containing a (2, 2) matrix
Additional dimensions ... will be flattened into the batch dimension.
- Parameters:
threshold (float) – Threshold for transforming probability to binary (0, 1) predictions
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
normalize (Optional[Literal['true', 'pred', 'all', 'none']]) – Normalization mode for confusion matrix. Choose from:
  - None or 'none': no normalization (default)
  - 'true': normalization over the targets (most commonly used)
  - 'pred': normalization over the predictions
  - 'all': normalization over the whole matrix
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
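To make the normalize options concrete, here is a small PyTorch sketch applied to a count matrix like the ones in the examples (rows are targets, columns are predictions). The function name is hypothetical, and replacing 0/0 cells with zero is an assumption of this sketch rather than documented behavior. Using the last two dimensions means the same code also works on a stack of per-label 2x2 matrices.

```python
from typing import Optional

import torch


def normalize_confmat(confmat: torch.Tensor, normalize: Optional[str] = None) -> torch.Tensor:
    """Sketch of the four normalize modes for a count matrix (rows = targets, cols = preds)."""
    if normalize in (None, "none"):
        return confmat  # raw counts are returned unchanged
    confmat = confmat.float()
    if normalize == "true":
        # Each row sums to 1: divide by the number of samples per target class.
        confmat = confmat / confmat.sum(dim=-1, keepdim=True)
    elif normalize == "pred":
        # Each column sums to 1: divide by the number of samples per predicted class.
        confmat = confmat / confmat.sum(dim=-2, keepdim=True)
    elif normalize == "all":
        # The whole matrix sums to 1.
        confmat = confmat / confmat.sum(dim=(-2, -1), keepdim=True)
    else:
        raise ValueError(f"unknown normalize mode: {normalize!r}")
    # Assumption of this sketch: empty rows/columns (0 / 0) become 0 instead of NaN.
    return torch.nan_to_num(confmat, nan=0.0)


cm = torch.tensor([[2, 0], [1, 1]])
print(normalize_confmat(cm, "true"))  # each row now sums to 1
```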
- Example (preds is int tensor):
>>> import torch
>>> from torchmetrics.classification import BinaryConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> bcm = BinaryConfusionMatrix()
>>> bcm(preds, target)
tensor([[2, 0],
        [1, 1]])
- Example (preds is float tensor):
>>> import torch
>>> from torchmetrics.classification import BinaryConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01])
>>> bcm = BinaryConfusionMatrix()
>>> bcm(preds, target)
tensor([[2, 0],
        [1, 1]])
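Both examples give the same matrix because thresholding the float predictions at 0.5 yields [0, 1, 0, 0]. The PyTorch sketch below spells out the input handling described above (flattening, the logit check, thresholding) and then counts the four (target, prediction) combinations. The helper name is hypothetical, ignore_index and normalize are left out, and treating a value exactly equal to threshold as negative (a strict >) is an assumption.

```python
import torch


def binary_confmat_sketch(
    preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5
) -> torch.Tensor:
    """Hypothetical re-implementation of the binary input handling described above."""
    # Any extra dimensions are flattened into one batch dimension.
    preds, target = preds.flatten(), target.flatten().long()
    if preds.is_floating_point():
        # Values outside [0, 1] are taken to be logits and passed through a sigmoid.
        if ((preds < 0) | (preds > 1)).any():
            preds = preds.sigmoid()
        # Probabilities are turned into 0/1 predictions with the threshold.
        preds = preds > threshold
    preds = preds.long()
    # Encode each (target, pred) pair as 2 * target + pred in 0..3, then count.
    counts = torch.bincount(2 * target + preds, minlength=4)
    return counts.reshape(2, 2)  # [[TN, FP], [FN, TP]]


target = torch.tensor([1, 1, 0, 0])
logits = torch.tensor([-2.0, 3.0, -0.5, -4.0])  # outside [0, 1], so treated as logits
print(binary_confmat_sketch(logits, target))
```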
- plot(val=None, ax=None, add_text=True, labels=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Optional[Tensor]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
add_text (bool) – if the value of each cell should be added to the plot
labels (Optional[List[str]]) – a list of strings; if provided, they are added to the plot to indicate the different classes
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randint
>>> from torchmetrics.classification import BinaryConfusionMatrix
>>> metric = BinaryConfusionMatrix()
>>> metric.update(randint(2, (20,)), randint(2, (20,)))
>>> fig_, ax_ = metric.plot()
MulticlassConfusionMatrix
- class torchmetrics.classification.MulticlassConfusionMatrix(num_classes, ignore_index=None, normalize=None, validate_args=True, **kwargs)
Compute the confusion matrix for multiclass tasks.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int tensor of shape (N, ...) or a float tensor of shape (N, C, ...). If preds is a floating point tensor, we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.
target (Tensor): An int tensor of shape (N, ...).
As output to forward and compute the metric returns the following output:
confusion_matrix (Tensor): A [num_classes, num_classes] matrix
- Parameters:
num_classes (int) – Integer specifying the number of classes
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
normalize (Optional[Literal['true', 'pred', 'all', 'none']]) – Normalization mode for confusion matrix. Choose from:
  - None or 'none': no normalization (default)
  - 'true': normalization over the targets (most commonly used)
  - 'pred': normalization over the predictions
  - 'all': normalization over the whole matrix
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
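A minimal PyTorch sketch of what ignore_index means for integer multiclass inputs, assuming positions whose target equals ignore_index are simply dropped before counting. The helper name is hypothetical; it accumulates counts with Tensor.index_put_ rather than whatever torchmetrics does internally.

```python
from typing import Optional

import torch


def multiclass_confmat_ignore(
    preds: torch.Tensor,
    target: torch.Tensor,
    num_classes: int,
    ignore_index: Optional[int] = None,
) -> torch.Tensor:
    """Count (target, pred) pairs of integer labels, skipping ignored targets."""
    preds, target = preds.flatten().long(), target.flatten().long()
    if ignore_index is not None:
        # Positions whose target equals ignore_index do not contribute at all.
        keep = target != ignore_index
        preds, target = preds[keep], target[keep]
    confmat = torch.zeros(num_classes, num_classes, dtype=torch.long)
    # Add 1 at (row=target, col=pred) for every remaining sample; duplicates accumulate.
    confmat.index_put_((target, preds), torch.ones_like(target), accumulate=True)
    return confmat


target = torch.tensor([2, 1, -1, 0])  # -1 marks a position to ignore
preds = torch.tensor([2, 1, 0, 1])
print(multiclass_confmat_ignore(preds, target, num_classes=3, ignore_index=-1))
```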
- Example (pred is integer tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassConfusionMatrix(num_classes=3)
>>> metric(preds, target)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
- Example (pred is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassConfusionMatrix(num_classes=3)
>>> metric(preds, target)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
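The two examples agree because the argmax of each float row is [2, 1, 0, 1], exactly the integer predictions of the first example. The sketch below reproduces them in plain PyTorch: argmax over the class dimension, then each (target, prediction) pair is encoded as one index and counted with torch.bincount. It is an illustration with a hypothetical helper name, not the library's implementation, and it skips ignore_index and normalize.

```python
import torch


def multiclass_confmat_sketch(
    preds: torch.Tensor, target: torch.Tensor, num_classes: int
) -> torch.Tensor:
    """Hypothetical re-implementation: argmax for float scores, then pair counting."""
    if preds.is_floating_point():
        # (N, C) probabilities or logits -> (N,) predicted class indices.
        preds = preds.argmax(dim=1)
    preds, target = preds.flatten().long(), target.flatten().long()
    # Encode each (target, pred) pair as a single index target * C + pred.
    codes = target * num_classes + preds
    counts = torch.bincount(codes, minlength=num_classes**2)
    # Reshaping puts targets on the rows and predictions on the columns.
    return counts.reshape(num_classes, num_classes)


target = torch.tensor([2, 1, 0, 0])
preds = torch.tensor([[0.16, 0.26, 0.58],
                      [0.22, 0.61, 0.17],
                      [0.71, 0.09, 0.20],
                      [0.05, 0.82, 0.13]])
print(multiclass_confmat_sketch(preds, target, num_classes=3))
```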
- plot(val=None, ax=None, add_text=True, labels=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Optional[Tensor]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
add_text (bool) – if the value of each cell should be added to the plot
labels (Optional[List[str]]) – a list of strings; if provided, they are added to the plot to indicate the different classes
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randint >>> from torchmetrics.classification import MulticlassConfusionMatrix >>> metric = MulticlassConfusionMatrix(num_classes=5) >>> metric.update(randint(5, (20,)), randint(5, (20,))) >>> fig_, ax_ = metric.plot()
MultilabelConfusionMatrix
- class torchmetrics.classification.MultilabelConfusionMatrix(num_labels, threshold=0.5, ignore_index=None, normalize=None, validate_args=True, **kwargs)
Compute the confusion matrix for multilabel tasks.
As input to update the metric accepts the following input:
preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0, 1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, C, ...)
As output of compute the metric returns the following output:
confusion matrix: [num_labels, 2, 2] matrix
- Parameters:
num_labels (int) – Integer specifying the number of labels
threshold (float) – Threshold for transforming probability to binary (0, 1) predictions
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
normalize (Optional[Literal['true', 'pred', 'all', 'none']]) – Normalization mode for confusion matrix. Choose from:
  - None or 'none': no normalization (default)
  - 'true': normalization over the targets (most commonly used)
  - 'pred': normalization over the predictions
  - 'all': normalization over the whole matrix
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelConfusionMatrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelConfusionMatrix(num_labels=3)
>>> metric(preds, target)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelConfusionMatrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelConfusionMatrix(num_labels=3)
>>> metric(preds, target)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
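Each label is scored as its own binary problem, so the result stacks one [[TN, FP], [FN, TP]] matrix per label. The PyTorch sketch below reproduces the examples for 2-D (N, num_labels) inputs only: it applies the binary rule (sigmoid for logits, then threshold) and gives every label its own block of four bincount bins. The helper name is hypothetical, and extra dimensions, ignore_index and normalize are not handled.

```python
import torch


def multilabel_confmat_sketch(
    preds: torch.Tensor, target: torch.Tensor, num_labels: int, threshold: float = 0.5
) -> torch.Tensor:
    """Hypothetical sketch for (N, num_labels) inputs: one 2x2 matrix per label."""
    if preds.is_floating_point():
        # Same rule as the binary case: sigmoid for logits, then threshold.
        if ((preds < 0) | (preds > 1)).any():
            preds = preds.sigmoid()
        preds = preds > threshold
    preds, target = preds.long(), target.long()
    # Per element, 2 * target + pred is a cell index in 0..3; offsetting by
    # 4 * label index gives every label its own block of four bins.
    codes = 2 * target + preds + 4 * torch.arange(num_labels)
    counts = torch.bincount(codes.flatten(), minlength=4 * num_labels)
    return counts.reshape(num_labels, 2, 2)


target = torch.tensor([[0, 1, 0], [1, 0, 1]])
preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
print(multilabel_confmat_sketch(preds, target, num_labels=3))
```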
- plot(val=None, ax=None, add_text=True, labels=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Optional[Tensor]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
add_text (bool) – if the value of each cell should be added to the plot
labels (Optional[List[str]]) – a list of strings; if provided, they are added to the plot to indicate the different classes
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randint
>>> from torchmetrics.classification import MultilabelConfusionMatrix
>>> metric = MultilabelConfusionMatrix(num_labels=3)
>>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
Functional Interface
confusion_matrix
- torchmetrics.functional.confusion_matrix(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, normalize=None, ignore_index=None, validate_args=True)
Compute the confusion matrix.
This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_confusion_matrix(), multiclass_confusion_matrix() and multilabel_confusion_matrix() for the specific details of how each argument influences the result, and for examples.
- Return type:
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics.functional import confusion_matrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> confusion_matrix(preds, target, task="binary")
tensor([[2, 0],
        [1, 1]])
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> confusion_matrix(preds, target, task="multiclass", num_classes=3)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> confusion_matrix(preds, target, task="multilabel", num_labels=3)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
binary_confusion_matrix
- torchmetrics.functional.classification.binary_confusion_matrix(preds, target, threshold=0.5, normalize=None, ignore_index=None, validate_args=True)
Compute the confusion matrix for binary tasks.
Accepts the following input tensors:
preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0, 1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, ...)
Additional dimensions ... will be flattened into the batch dimension.
- Parameters:
threshold (float) – Threshold for transforming probability to binary (0, 1) predictions
normalize (Optional[Literal['true', 'pred', 'all', 'none']]) – Normalization mode for confusion matrix. Choose from:
  - None or 'none': no normalization (default)
  - 'true': normalization over the targets (most commonly used)
  - 'pred': normalization over the predictions
  - 'all': normalization over the whole matrix
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
- Returns:
A [2, 2] tensor
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_confusion_matrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> binary_confusion_matrix(preds, target)
tensor([[2, 0],
        [1, 1]])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_confusion_matrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0.35, 0.85, 0.48, 0.01])
>>> binary_confusion_matrix(preds, target)
tensor([[2, 0],
        [1, 1]])
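Because additional dimensions are flattened into the batch dimension, a batch of N binary segmentation masks of shape (N, H, W) is counted pixel by pixel. The sketch below spells out the four cells with boolean masks for inputs that are already 0/1 integers. It is illustrative only: the helper name is hypothetical, and it skips the thresholding, ignore_index and normalize handling described above.

```python
import torch


def binary_confmat_from_masks(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch for 0/1 integer masks of shape (N, ...)."""
    # Extra dimensions (e.g. height and width) are flattened into the batch dimension.
    p = preds.reshape(-1).bool()
    t = target.reshape(-1).bool()
    tn = (~t & ~p).sum()
    fp = (~t & p).sum()
    fn = (t & ~p).sum()
    tp = (t & p).sum()
    # Same layout as the examples: rows are targets, columns are predictions.
    return torch.stack([tn, fp, fn, tp]).reshape(2, 2)


target = torch.tensor([[[1, 0], [0, 1]], [[1, 1], [0, 0]]])  # shape (2, 2, 2)
preds = torch.tensor([[[1, 0], [0, 0]], [[1, 1], [0, 0]]])
print(binary_confmat_from_masks(preds, target))
```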
multiclass_confusion_matrix
- torchmetrics.functional.classification.multiclass_confusion_matrix(preds, target, num_classes, normalize=None, ignore_index=None, validate_args=True)
Compute the confusion matrix for multiclass tasks.
Accepts the following input tensors:
preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor, we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.
target (int tensor): (N, ...)
Additional dimensions ... will be flattened into the batch dimension.
- Parameters:
num_classes (int) – Integer specifying the number of classes
normalize (Optional[Literal['true', 'pred', 'all', 'none']]) – Normalization mode for confusion matrix. Choose from:
  - None or 'none': no normalization (default)
  - 'true': normalization over the targets (most commonly used)
  - 'pred': normalization over the predictions
  - 'all': normalization over the whole matrix
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
- Returns:
A [num_classes, num_classes] tensor
- Example (pred is integer tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_confusion_matrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_confusion_matrix(preds, target, num_classes=3)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
- Example (pred is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_confusion_matrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_confusion_matrix(preds, target, num_classes=3)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
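For inputs with extra dimensions, for example per-pixel class scores of shape (N, C, H, W) with an (N, H, W) target, the argmax is taken along C and the remaining dimensions are flattened into the batch. As an alternative to pair counting, the sketch below builds the same matrix from one-hot encodings with a single matrix product. It is an illustration with a hypothetical helper name, not necessarily how torchmetrics computes it, and it ignores ignore_index and normalize.

```python
import torch
import torch.nn.functional as F


def multiclass_confmat_nd(
    preds: torch.Tensor, target: torch.Tensor, num_classes: int
) -> torch.Tensor:
    """Hypothetical sketch: preds are float scores (N, C, ...), target is int (N, ...)."""
    # argmax over the class dimension C (dim=1) gives labels of shape (N, ...).
    labels = preds.argmax(dim=1)
    # Flatten the remaining dimensions so every position counts as one sample.
    t = F.one_hot(target.reshape(-1).long(), num_classes).float()
    p = F.one_hot(labels.reshape(-1), num_classes).float()
    # (C, M) @ (M, C): entry [i, j] counts samples with target i and prediction j.
    return (t.T @ p).long()


# One 2x2 "image" with 3 classes: scores are one-hot so the argmax is easy to read.
pred_labels = torch.tensor([[[0, 2], [2, 1]]])
scores = F.one_hot(pred_labels, 3).permute(0, 3, 1, 2).float()  # (1, 3, 2, 2)
target = torch.tensor([[[0, 1], [2, 2]]])  # (1, 2, 2)
print(multiclass_confmat_nd(scores, target, num_classes=3))
```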
multilabel_confusion_matrix
- torchmetrics.functional.classification.multilabel_confusion_matrix(preds, target, num_labels, threshold=0.5, normalize=None, ignore_index=None, validate_args=True)
Compute the confusion matrix for multilabel tasks.
Accepts the following input tensors:
preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0, 1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, C, ...)
Additional dimensions ... will be flattened into the batch dimension.
- Parameters:
num_labels (int) – Integer specifying the number of labels
threshold (float) – Threshold for transforming probability to binary (0, 1) predictions
normalize (Optional[Literal['true', 'pred', 'all', 'none']]) – Normalization mode for confusion matrix. Choose from:
  - None or 'none': no normalization (default)
  - 'true': normalization over the targets (most commonly used)
  - 'pred': normalization over the predictions
  - 'all': normalization over the whole matrix
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
- Returns:
A [num_labels, 2, 2] tensor
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_confusion_matrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_confusion_matrix(preds, target, num_labels=3)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_confusion_matrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_confusion_matrix(preds, target, num_labels=3)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
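In the multilabel case ignore_index acts on individual (sample, label) entries, so an ignored target only removes that entry from its own label's 2x2 matrix. The sketch below assumes exactly that per-entry behavior for already-binarized (N, num_labels) integer inputs; the helper name is hypothetical, and it loops over labels for readability rather than speed.

```python
from typing import Optional

import torch


def multilabel_confmat_ignore(
    preds: torch.Tensor,
    target: torch.Tensor,
    num_labels: int,
    ignore_index: Optional[int] = None,
) -> torch.Tensor:
    """Hypothetical sketch: (N, num_labels) 0/1 inputs, ignored targets skipped per label."""
    confmat = torch.zeros(num_labels, 2, 2, dtype=torch.long)
    for label in range(num_labels):
        p, t = preds[:, label].long(), target[:, label].long()
        if ignore_index is not None:
            # Drop only the entries of this label whose target is ignore_index.
            keep = t != ignore_index
            p, t = p[keep], t[keep]
        # For this label: row = target (0/1), column = prediction (0/1).
        confmat[label] = torch.bincount(2 * t + p, minlength=4).reshape(2, 2)
    return confmat


target = torch.tensor([[0, 1, -1], [1, -1, 1]])  # -1 marks entries to ignore
preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
print(multilabel_confmat_ignore(preds, target, num_labels=3, ignore_index=-1))
```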