Jaccard Index

Module Interface

class torchmetrics.JaccardIndex(**kwargs)

Calculate the Jaccard index.

The Jaccard index (also known as the intersection over union or Jaccard similarity coefficient) is a statistic that can be used to determine the similarity and diversity of sample sets. It is defined as the size of the intersection divided by the size of the union of the sample sets:

\[J(A,B) = \frac{|A\cap B|}{|A\cup B|}\]
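To make the definition concrete, the following plain-Python sketch computes J(A, B) for two Python sets. The helper name jaccard and its empty_value argument (the value returned when both sets are empty, where the ratio would be 0/0) are illustrative assumptions, not part of the torchmetrics API.

```python
def jaccard(a: set, b: set, empty_value: float = 0.0) -> float:
    """Return |A intersect B| / |A union B| (hypothetical helper)."""
    union = a | b
    if not union:
        # Both sets are empty, so the ratio is 0/0; return a fixed value instead.
        return empty_value
    return len(a & b) / len(union)


A = {"cat", "dog", "fish"}
B = {"dog", "fish", "bird", "frog"}
print(jaccard(A, B))  # 2 shared items out of 5 distinct items -> 0.4
```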

This class is a simple wrapper that returns the task-specific version of this metric, selected by setting the task argument to 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryJaccardIndex, MulticlassJaccardIndex and MultilabelJaccardIndex for details on how each argument influences the result, and for examples.

Legacy Example:
>>> from torch import randint
>>> from torchmetrics import JaccardIndex
>>> target = randint(0, 2, (10, 25, 25))
>>> pred = target.clone()
>>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15]
>>> jaccard = JaccardIndex(task="multiclass", num_classes=2)
>>> jaccard(pred, target)
tensor(0.9660)
static __new__(cls, task, threshold=0.5, num_classes=None, num_labels=None, average='macro', ignore_index=None, validate_args=True, **kwargs)

Initialize task metric.

Return type:

Metric

BinaryJaccardIndex

class torchmetrics.classification.BinaryJaccardIndex(threshold=0.5, ignore_index=None, validate_args=True, zero_division=0, **kwargs)

Calculate the Jaccard index for binary tasks.

The Jaccard index (also known as the intersection over union or Jaccard similarity coefficient) is a statistic that can be used to determine the similarity and diversity of sample sets. It is defined as the size of the intersection divided by the size of the union of the sample sets:

\[J(A,B) = \frac{|A\cap B|}{|A\cup B|}\]

As input to forward and update, the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating-point tensor with values outside the [0, 1] range, the input is treated as logits and a sigmoid is applied element-wise. The values are then converted to an int tensor by thresholding with the value of threshold.

  • target (Tensor): An int tensor of shape (N, ...).

Tip

Additional dimensions ... will be flattened into the batch dimension.
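The preprocessing described above can be sketched in plain Python on flat lists. This is an illustrative approximation, not the torchmetrics implementation: the helper names binarize_preds and binary_jaccard are hypothetical, inputs are flattened by hand, and the sketch assumes that values strictly greater than threshold map to 1.

```python
import math


def binarize_preds(preds, threshold=0.5):
    """Turn float scores into 0/1 predictions (hypothetical helper)."""
    # If any value falls outside [0, 1], treat the whole input as logits.
    if any(p < 0.0 or p > 1.0 for p in preds):
        preds = [1.0 / (1.0 + math.exp(-p)) for p in preds]
    # Assumption: values strictly above the threshold become 1.
    return [1 if p > threshold else 0 for p in preds]


def binary_jaccard(preds, target, threshold=0.5):
    """Jaccard index of the positive class for flat 0/1 targets."""
    hard = binarize_preds(preds, threshold)
    tp = sum(1 for p, t in zip(hard, target) if p == 1 and t == 1)
    fp = sum(1 for p, t in zip(hard, target) if p == 1 and t == 0)
    fn = sum(1 for p, t in zip(hard, target) if p == 0 and t == 1)
    denom = tp + fp + fn
    return tp / denom if denom else 0.0


# Extra dimensions are flattened into one long batch before scoring.
batch = [[0.35, 0.85], [0.48, 0.01]]
flat_preds = [v for row in batch for v in row]
print(binary_jaccard(flat_preds, [1, 1, 0, 0]))              # probabilities -> 0.5
print(binary_jaccard([-2.0, 3.0, 0.5, -1.0], [1, 1, 0, 0]))  # logits -> 1/3
```

In the second call, 3.0 is outside [0, 1], so every value goes through the sigmoid first; 0.5 then maps to about 0.62 and is counted as a positive prediction.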

As output of forward and compute, the metric returns the following output:

  • bji (Tensor): A tensor containing the Binary Jaccard Index.
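With the module interface, each call to update adds to the metric state and compute reports the score over everything seen so far, which in general differs from averaging per-batch scores. The hypothetical RunningBinaryJaccard class below illustrates this with running true-positive, false-positive and false-negative counts on already-binarized 0/1 inputs; it is a simplified sketch, not the torchmetrics implementation.

```python
class RunningBinaryJaccard:
    """Hypothetical accumulator mimicking an update/compute workflow."""

    def __init__(self):
        self.tp = self.fp = self.fn = 0

    def update(self, preds, target):
        """Add the counts from one batch of 0/1 predictions and targets."""
        for p, t in zip(preds, target):
            if p == 1 and t == 1:
                self.tp += 1
            elif p == 1 and t == 0:
                self.fp += 1
            elif p == 0 and t == 1:
                self.fn += 1

    def compute(self):
        """Jaccard index over all batches seen so far (0/0 mapped to 0.0)."""
        denom = self.tp + self.fp + self.fn
        return self.tp / denom if denom else 0.0


batches = [([1, 1, 0, 0], [1, 0, 0, 0]),  # (preds, target) for batch 1
           ([0, 0, 1, 0], [1, 1, 1, 0])]  # (preds, target) for batch 2

metric = RunningBinaryJaccard()
per_batch = []
for preds, target in batches:
    metric.update(preds, target)
    single = RunningBinaryJaccard()
    single.update(preds, target)
    per_batch.append(single.compute())

print(per_batch)                        # [0.5, 0.3333333333333333]
print(sum(per_batch) / len(per_batch))  # mean of batch scores, about 0.4167
print(metric.compute())                 # accumulated counts: 2 / 5 = 0.4
```

The accumulated value (0.4) differs from the mean of the per-batch values (about 0.4167) because each batch contributes in proportion to its counts rather than equally.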

Parameters:
  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • zero_division (float) – Value returned when a division by zero occurs. Should be 0 or 1.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
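The effect of ignore_index and zero_division can be sketched on flat 0/1 lists as follows. The helper name is hypothetical, and the code is a simplified illustration that assumes ignored positions are dropped before counting and that a 0/0 ratio (no positive predictions and no positive targets) is replaced by zero_division.

```python
def binary_jaccard(preds, target, ignore_index=None, zero_division=0.0):
    """Simplified Jaccard index for flat 0/1 lists (hypothetical helper)."""
    tp = fp = fn = 0
    for p, t in zip(preds, target):
        # Positions whose target equals ignore_index do not contribute.
        if ignore_index is not None and t == ignore_index:
            continue
        if p == 1 and t == 1:
            tp += 1
        elif p == 1 and t == 0:
            fp += 1
        elif p == 0 and t == 1:
            fn += 1
    denom = tp + fp + fn
    # 0/0 happens when neither preds nor target contain a positive.
    return tp / denom if denom else float(zero_division)


print(binary_jaccard([1, 1, 0, 0], [1, -1, 0, 0], ignore_index=-1))  # 1.0
print(binary_jaccard([0, 0, 0], [0, 0, 0]))                          # 0.0
print(binary_jaccard([0, 0, 0], [0, 0, 0], zero_division=1))         # 1.0
```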

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryJaccardIndex
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> metric = BinaryJaccardIndex()
>>> metric(preds, target)
tensor(0.5000)
Example (preds is float tensor):
>>> from torchmetrics.classification import BinaryJaccardIndex
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0.35, 0.85, 0.48, 0.01])
>>> metric = BinaryJaccardIndex()
>>> metric(preds, target)
tensor(0.5000)
plot(val=None, ax=None)

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.

Return type:

tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torch import rand, randint
>>> from torchmetrics.classification import BinaryJaccardIndex
>>> metric = BinaryJaccardIndex()
>>> metric.update(rand(10), randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torch import rand, randint
>>> from torchmetrics.classification import BinaryJaccardIndex
>>> metric = BinaryJaccardIndex()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(rand(10), randint(2,(10,))))
>>> fig_, ax_ = metric.plot(values)

MulticlassJaccardIndex

class torchmetrics.classification.MulticlassJaccardIndex(num_classes, average='macro', ignore_index=None, validate_args=True, zero_division=0, **kwargs)

Calculate the Jaccard index for multiclass tasks.

The Jaccard index (also known as the intersection over union or Jaccard similarity coefficient) is a statistic that can be used to determine the similarity and diversity of sample sets. It is defined as the size of the intersection divided by the size of the union of the sample sets:

\[J(A,B) = \frac{|A\cap B|}{|A\cup B|}\]

As input to forward and update, the metric accepts the following input:

  • preds (Tensor): An int tensor of shape (N, ...) or a float tensor of shape (N, C, ...). If preds is a floating-point tensor, torch.argmax is applied along the C dimension to convert probabilities/logits into an int tensor.

  • target (Tensor): An int tensor of shape (N, ...).

Tip

Additional dimensions ... will be flattened into the batch dimension.
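The argmax conversion and the per-class scoring can be sketched in plain Python as follows. Both helper names are hypothetical, scores are given as an (N, C) nested list, and a class whose ratio would be 0/0 is scored as 0.0 here purely for illustration.

```python
def argmax_rows(scores):
    """Map each row of an (N, C) score list to its highest-scoring class index."""
    return [max(range(len(row)), key=row.__getitem__) for row in scores]


def per_class_jaccard(preds, target, num_classes):
    """Jaccard index of each class, treating it as the positive class in turn."""
    result = []
    for c in range(num_classes):
        tp = sum(1 for p, t in zip(preds, target) if p == c and t == c)
        fp = sum(1 for p, t in zip(preds, target) if p == c and t != c)
        fn = sum(1 for p, t in zip(preds, target) if p != c and t == c)
        denom = tp + fp + fn
        result.append(tp / denom if denom else 0.0)
    return result


scores = [[0.16, 0.26, 0.58],
          [0.22, 0.61, 0.17],
          [0.71, 0.09, 0.20],
          [0.05, 0.82, 0.13]]
preds = argmax_rows(scores)             # [2, 1, 0, 1]
per_class = per_class_jaccard(preds, [2, 1, 0, 0], num_classes=3)
print(per_class)                        # [0.5, 0.5, 1.0]
print(sum(per_class) / len(per_class))  # unweighted mean, about 0.6667
```

The unweighted mean of these per-class scores agrees with the value of 0.6667 shown in the float-tensor example for this class.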

As output of forward and compute, the metric returns the following output:

  • mcji (Tensor): A tensor containing the Multi-class Jaccard Index.

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: Calculate statistics for each label and compute their average, weighted by support

    • "none" or None: Calculate statistics for each label and apply no reduction

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • zero_division (float) – Value returned when a division by zero occurs. Should be 0 or 1.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
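The four average options listed above can be compared on per-class counts with the hypothetical helper below. It is a simplified sketch, not the torchmetrics implementation: support is taken as the number of target occurrences of each class (tp + fn), 0/0 ratios become 0.0, and "macro" is a plain mean over all classes, whereas the library may treat classes that never occur, or the ignore_index class, differently. The counts in the usage lines are derived from target [2, 1, 0, 0] and preds [2, 1, 0, 1], the data of the integer-tensor example for this class.

```python
def reduce_jaccard(tp, fp, fn, support, average="macro"):
    """Combine per-class counts (lists indexed by class) into a score."""

    def safe(num, den):
        # Map 0/0 to 0.0 for this illustration.
        return num / den if den else 0.0

    per_class = [safe(t, t + p + n) for t, p, n in zip(tp, fp, fn)]
    if average == "micro":
        # Pool the counts of all classes before dividing.
        return safe(sum(tp), sum(tp) + sum(fp) + sum(fn))
    if average == "macro":
        return sum(per_class) / len(per_class)
    if average == "weighted":
        total = sum(support)
        return safe(sum(w * s for w, s in zip(support, per_class)), total)
    if average in ("none", None):
        return per_class
    raise ValueError(f"unknown average: {average!r}")


tp, fp, fn = [1, 1, 1], [0, 1, 0], [1, 0, 0]
support = [t + n for t, n in zip(tp, fn)]  # [2, 1, 1]
for mode in ("micro", "macro", "weighted", "none"):
    print(mode, reduce_jaccard(tp, fp, fn, support, average=mode))
# micro 0.6, macro about 0.6667, weighted 0.625, none [0.5, 0.5, 1.0]
```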

Example (pred is integer tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassJaccardIndex
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassJaccardIndex(num_classes=3)
>>> metric(preds, target)
tensor(0.6667)
Example (pred is float tensor):
>>> from torchmetrics.classification import MulticlassJaccardIndex
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassJaccardIndex(num_classes=3)
>>> metric(preds, target)
tensor(0.6667)
plot(val=None, ax=None)

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.

Return type:

tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value per class
>>> from torch import randint
>>> from torchmetrics.classification import MulticlassJaccardIndex
>>> metric = MulticlassJaccardIndex(num_classes=3, average=None)
>>> metric.update(randint(3, (20,)), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values per class
>>> from torch import randint
>>> from torchmetrics.classification import MulticlassJaccardIndex
>>> metric = MulticlassJaccardIndex(num_classes=3, average=None)
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randint(3, (20,)), randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)

MultilabelJaccardIndex

class torchmetrics.classification.MultilabelJaccardIndex(num_labels, threshold=0.5, average='macro', ignore_index=None, validate_args=True, zero_division=0, **kwargs)

Calculate the Jaccard index for multilabel tasks.

The Jaccard index (also known as the intersection over union or Jaccard similarity coefficient) is a statistic that can be used to determine the similarity and diversity of sample sets. It is defined as the size of the intersection divided by the size of the union of the sample sets:

\[J(A,B) = \frac{|A\cap B|}{|A\cup B|}\]

As input to forward and update, the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, C, ...). If preds is a floating-point tensor with values outside the [0, 1] range, the input is treated as logits and a sigmoid is applied element-wise. The values are then converted to an int tensor by thresholding with the value of threshold.

  • target (Tensor): An int tensor of shape (N, C, ...).

Tip

Additional dimensions ... will be flattened into the batch dimension.
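For multilabel inputs, each label column is scored as its own binary problem. The following plain-Python sketch uses a hypothetical helper on (N, C) nested lists; it assumes preds are probabilities or 0/1 values (no logit handling), thresholds with a strict greater-than comparison, and maps 0/0 to 0.0.

```python
def multilabel_jaccard(preds, target, threshold=0.5):
    """Per-label Jaccard index for (N, C) nested lists (hypothetical helper)."""
    num_labels = len(target[0])
    scores = []
    for c in range(num_labels):
        tp = fp = fn = 0
        for p_row, t_row in zip(preds, target):
            p = 1 if p_row[c] > threshold else 0  # binarize label c of this sample
            t = t_row[c]
            if p == 1 and t == 1:
                tp += 1
            elif p == 1 and t == 0:
                fp += 1
            elif p == 0 and t == 1:
                fn += 1
        denom = tp + fp + fn
        scores.append(tp / denom if denom else 0.0)
    return scores


target = [[0, 1, 0], [1, 0, 1]]
preds = [[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]
per_label = multilabel_jaccard(preds, target)
print(per_label)                        # [1.0, 0.0, 0.5]
print(sum(per_label) / len(per_label))  # unweighted mean: 0.5
```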

As output of forward and compute, the metric returns the following output:

  • mlji (Tensor): A tensor containing the Multi-label Jaccard Index.

Parameters:
  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: Calculate statistics for each label and compute their average, weighted by support

    • "none" or None: Calculate statistics for each label and apply no reduction

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • zero_division (float) – Value returned when a division by zero occurs. Should be 0 or 1.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelJaccardIndex
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelJaccardIndex(num_labels=3)
>>> metric(preds, target)
tensor(0.5000)
Example (preds is float tensor):
>>> from torchmetrics.classification import MultilabelJaccardIndex
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelJaccardIndex(num_labels=3)
>>> metric(preds, target)
tensor(0.5000)
plot(val=None, ax=None)

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.

Return type:

tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torch import rand, randint
>>> from torchmetrics.classification import MultilabelJaccardIndex
>>> metric = MultilabelJaccardIndex(num_labels=3)
>>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torch import rand, randint
>>> from torchmetrics.classification import MultilabelJaccardIndex
>>> metric = MultilabelJaccardIndex(num_labels=3)
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

jaccard_index

torchmetrics.functional.jaccard_index(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, average='macro', ignore_index=None, validate_args=True, zero_division=0.0)

Calculate the Jaccard index.

The Jaccard index (also known as the intersection over union or Jaccard similarity coefficient) is a statistic that can be used to determine the similarity and diversity of sample sets. It is defined as the size of the intersection divided by the size of the union of the sample sets:

\[J(A,B) = \frac{|A\cap B|}{|A\cup B|}\]

This function is a simple wrapper that dispatches to the task-specific version of this metric, selected by setting the task argument to 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_jaccard_index(), multiclass_jaccard_index() and multilabel_jaccard_index() for details on how each argument influences the result, and for examples.
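The dispatch performed by this wrapper can be sketched as follows. The helper pick_jaccard_variant is hypothetical and only returns the name of the task-specific function documented in this section; its checks that num_classes or num_labels are integers reflect the arguments each variant requires, but the exact validation and error messages in torchmetrics may differ.

```python
def pick_jaccard_variant(task, num_classes=None, num_labels=None):
    """Return the name of the task-specific function for `task` (sketch)."""
    if task == "binary":
        return "binary_jaccard_index"
    if task == "multiclass":
        # The multiclass variant needs to know how many classes exist.
        if not isinstance(num_classes, int):
            raise ValueError("task='multiclass' requires an integer num_classes")
        return "multiclass_jaccard_index"
    if task == "multilabel":
        # The multilabel variant needs to know how many labels exist.
        if not isinstance(num_labels, int):
            raise ValueError("task='multilabel' requires an integer num_labels")
        return "multilabel_jaccard_index"
    raise ValueError(f"unsupported task: {task!r}")


print(pick_jaccard_variant("binary"))                     # binary_jaccard_index
print(pick_jaccard_variant("multiclass", num_classes=3))  # multiclass_jaccard_index
try:
    pick_jaccard_variant("multilabel")
except ValueError as err:
    missing_arg_error = str(err)
print(missing_arg_error)
```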

Legacy Example:
>>> from torch import randint
>>> from torchmetrics.functional import jaccard_index
>>> target = randint(0, 2, (10, 25, 25))
>>> pred = target.clone()
>>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15]
>>> jaccard_index(pred, target, task="multiclass", num_classes=2)
tensor(0.9660)

binary_jaccard_index

torchmetrics.functional.classification.binary_jaccard_index(preds, target, threshold=0.5, ignore_index=None, validate_args=True, zero_division=0.0)

Calculate the Jaccard index for binary tasks.

The Jaccard index (also known as the intersection over union or Jaccard similarity coefficient) is a statistic that can be used to determine the similarity and diversity of sample sets. It is defined as the size of the intersection divided by the size of the union of the sample sets:

\[J(A,B) = \frac{|A\cap B|}{|A\cup B|}\]

Accepts the following input tensors:

  • preds (int or float tensor): (N, ...). If preds is a floating-point tensor with values outside the [0, 1] range, the input is treated as logits and a sigmoid is applied element-wise. The values are then converted to an int tensor by thresholding with the value of threshold.

  • target (int tensor): (N, ...)

Additional dimensions ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • zero_division (float) – Value returned when a division by zero occurs. Should be 0 or 1.

Return type:

Tensor

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_jaccard_index
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> binary_jaccard_index(preds, target)
tensor(0.5000)
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import binary_jaccard_index
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0.35, 0.85, 0.48, 0.01])
>>> binary_jaccard_index(preds, target)
tensor(0.5000)

multiclass_jaccard_index

torchmetrics.functional.classification.multiclass_jaccard_index(preds, target, num_classes, average='macro', ignore_index=None, validate_args=True, zero_division=0.0)

Calculate the Jaccard index for multiclass tasks.

The Jaccard index (also known as the intersection over union or Jaccard similarity coefficient) is a statistic that can be used to determine the similarity and diversity of sample sets. It is defined as the size of the intersection divided by the size of the union of the sample sets:

\[J(A,B) = \frac{|A\cap B|}{|A\cup B|}\]

Accepts the following input tensors:

  • preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating-point tensor, torch.argmax is applied along the C dimension to convert probabilities/logits into an int tensor.

  • target (int tensor): (N, ...)

Additional dimensions ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: Calculate statistics for each label and compute their average, weighted by support

    • "none" or None: Calculate statistics for each label and apply no reduction

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • zero_division (float) – Value returned when a division by zero occurs. Should be 0 or 1.

Return type:

Tensor

Example (pred is integer tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_jaccard_index
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_jaccard_index(preds, target, num_classes=3)
tensor(0.6667)
Example (pred is float tensor):
>>> from torchmetrics.functional.classification import multiclass_jaccard_index
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_jaccard_index(preds, target, num_classes=3)
tensor(0.6667)

multilabel_jaccard_index

torchmetrics.functional.classification.multilabel_jaccard_index(preds, target, num_labels, threshold=0.5, average='macro', ignore_index=None, validate_args=True, zero_division=0.0)

Calculate the Jaccard index for multilabel tasks.

The Jaccard index (also known as the intersection over union or Jaccard similarity coefficient) is a statistic that can be used to determine the similarity and diversity of sample sets. It is defined as the size of the intersection divided by the size of the union of the sample sets:

\[J(A,B) = \frac{|A\cap B|}{|A\cup B|}\]

Accepts the following input tensors:

  • preds (int or float tensor): (N, C, ...). If preds is a floating-point tensor with values outside the [0, 1] range, the input is treated as logits and a sigmoid is applied element-wise. The values are then converted to an int tensor by thresholding with the value of threshold.

  • target (int tensor): (N, C, ...)

Additional dimensions ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: Calculate statistics for each label and compute their average, weighted by support

    • "none" or None: Calculate statistics for each label and apply no reduction

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • zero_division (float) – Value returned when a division by zero occurs. Should be 0 or 1.

Return type:

Tensor

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_jaccard_index
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_jaccard_index(preds, target, num_labels=3)
tensor(0.5000)
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multilabel_jaccard_index
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_jaccard_index(preds, target, num_labels=3)
tensor(0.5000)