# Critical Success Index (CSI)

## Module Interface

class torchmetrics.regression.CriticalSuccessIndex(threshold, keep_sequence_dim=None, **kwargs)

Calculate critical success index (CSI).

Critical success index (also known as the threat score) is a statistic used in weather forecasting to measure forecast performance on inputs that have been binarized at a specified threshold. It is defined as:

$\text{CSI} = \frac{\text{TP}}{\text{TP}+\text{FN}+\text{FP}}$

Where $\text{TP}$, $\text{FN}$ and $\text{FP}$ represent the number of true positives, false negatives and false positives, respectively, after binarizing the input tensors.
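The definition can be illustrated with a small plain-Python sketch. It is not the TorchMetrics implementation: the helper names `flatten` and `csi` are hypothetical, inputs are nested lists rather than tensors, and returning `0.0` when the denominator is zero is an assumption made here only so the function is total.

```python
from typing import List


def flatten(values) -> List[float]:
    """Recursively flatten nested lists/tuples of numbers into a flat list."""
    if isinstance(values, (list, tuple)):
        out: List[float] = []
        for v in values:
            out.extend(flatten(v))
        return out
    return [values]


def csi(preds, target, threshold: float) -> float:
    """Critical success index over all elements of two equally shaped nested lists."""
    p = [v >= threshold for v in flatten(preds)]   # binarize predictions (>= threshold -> True)
    t = [v >= threshold for v in flatten(target)]  # binarize target the same way
    if len(p) != len(t):
        raise ValueError("preds and target must have the same number of elements")
    tp = sum(a and b for a, b in zip(p, t))        # predicted 1, observed 1
    fn = sum((not a) and b for a, b in zip(p, t))  # predicted 0, observed 1
    fp = sum(a and (not b) for a, b in zip(p, t))  # predicted 1, observed 0
    denom = tp + fn + fp
    # Assumption: report 0.0 when neither input has any positive element
    return tp / denom if denom else 0.0


preds = [[0.2, 0.7], [0.9, 0.3]]
target = [[0.4, 0.2], [0.8, 0.6]]
print(round(csi(preds, target, 0.5), 4))  # 0.3333
```

True negatives never enter the formula, which is why CSI is insensitive to how many "no event" cells a forecast gets right.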

Parameters:
• threshold (float) – Values greater than or equal to threshold are mapped to 1 and values below it to 0.

• keep_sequence_dim (Optional[int]) – Index of the sequence dimension if the inputs are sequences of images. If specified, a separate score is calculated for each image in the sequence. If None, a single score is calculated across all dimensions.
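To show what `keep_sequence_dim` changes, the following plain-Python sketch compares per-image scores with a single pooled score. It assumes the sequence dimension is the first one (as with `keep_sequence_dim=0`) and uses hypothetical helpers (`counts`, `ratio`, `csi_per_sequence`, `csi_pooled`); it is an illustration, not the library's code.

```python
from typing import List, Tuple


def flatten(values) -> List[float]:
    """Recursively flatten nested lists of numbers into a flat list."""
    if isinstance(values, (list, tuple)):
        out: List[float] = []
        for v in values:
            out.extend(flatten(v))
        return out
    return [values]


def counts(preds, target, threshold: float) -> Tuple[int, int, int]:
    """Return (TP, FN, FP) after binarizing both inputs with `>= threshold`."""
    p = [v >= threshold for v in flatten(preds)]
    t = [v >= threshold for v in flatten(target)]
    if len(p) != len(t):
        raise ValueError("preds and target must have the same number of elements")
    tp = sum(a and b for a, b in zip(p, t))
    fn = sum((not a) and b for a, b in zip(p, t))
    fp = sum(a and (not b) for a, b in zip(p, t))
    return tp, fn, fp


def ratio(tp: int, fn: int, fp: int) -> float:
    denom = tp + fn + fp
    return tp / denom if denom else 0.0  # assumption: 0.0 for an empty denominator


def csi_per_sequence(preds, target, threshold: float) -> List[float]:
    """One CSI per element of the first (sequence) dimension."""
    return [ratio(*counts(p, t, threshold)) for p, t in zip(preds, target)]


def csi_pooled(preds, target, threshold: float) -> float:
    """A single CSI with TP, FN and FP summed over every dimension."""
    return ratio(*counts(preds, target, threshold))


# Two 2x2 "images": the first matches the example below, the second is a perfect forecast
preds = [[[0.2, 0.7], [0.9, 0.3]], [[0.9, 0.1], [0.1, 0.1]]]
target = [[[0.4, 0.2], [0.8, 0.6]], [[0.8, 0.0], [0.2, 0.3]]]
print([round(s, 4) for s in csi_per_sequence(preds, target, 0.5)])  # [0.3333, 1.0]
print(csi_pooled(preds, target, 0.5))  # 0.5
```

The pooled score (2 / 4 = 0.5) is computed from summed counts, so it differs from the mean of the per-image scores (about 0.6667).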

Example

>>> import torch
>>> from torchmetrics.regression import CriticalSuccessIndex
>>> x = torch.Tensor([[0.2, 0.7], [0.9, 0.3]])
>>> y = torch.Tensor([[0.4, 0.2], [0.8, 0.6]])
>>> csi = CriticalSuccessIndex(0.5)
>>> csi(x, y)
tensor(0.3333)
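Working the first example by hand, with `x` as the predictions, `y` as the target and a threshold of 0.5, shows where the value 0.3333 comes from:

$$
\begin{aligned}
\mathbf{1}[x \ge 0.5] &= \begin{pmatrix} 0 & 1 \\ 1 & 0 \end{pmatrix}, \qquad
\mathbf{1}[y \ge 0.5] = \begin{pmatrix} 0 & 0 \\ 1 & 1 \end{pmatrix} \\
\text{TP} &= 1 \;\text{(row 2, col 1)}, \quad \text{FP} = 1 \;\text{(row 1, col 2)}, \quad \text{FN} = 1 \;\text{(row 2, col 2)} \\
\text{CSI} &= \frac{1}{1 + 1 + 1} = \tfrac{1}{3} \approx 0.3333
\end{aligned}
$$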


Example

>>> import torch
>>> from torchmetrics.regression import CriticalSuccessIndex
>>> x = torch.Tensor([[[0.2, 0.7], [0.9, 0.3]], [[0.2, 0.7], [0.9, 0.3]]])
>>> y = torch.Tensor([[[0.4, 0.2], [0.8, 0.6]], [[0.4, 0.2], [0.8, 0.6]]])
>>> csi = CriticalSuccessIndex(0.5, keep_sequence_dim=0)
>>> csi(x, y)
tensor([0.3333, 0.3333])


## Functional Interface

torchmetrics.functional.regression.critical_success_index(preds, target, threshold, keep_sequence_dim=None)

Compute critical success index.

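Parameters:
• preds (Tensor) – Predictions from the model

• target (Tensor) – Ground truth values

• threshold (float) – Values greater than or equal to threshold are mapped to 1 and values below it to 0.

• keep_sequence_dim (Optional[int]) – Index of the sequence dimension if the inputs are sequences of images. If specified, a separate score is calculated for each image in the sequence. If None, a single score is calculated across all dimensions.

Return type:

Tensor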

Returns:

If keep_sequence_dim is specified, the metric returns a 1-D tensor with one CSI score for each image in the sequence. Otherwise, it returns a scalar tensor with the CSI score.

Example

>>> import torch
>>> from torchmetrics.functional.regression import critical_success_index
>>> x = torch.Tensor([[0.2, 0.7], [0.9, 0.3]])
>>> y = torch.Tensor([[0.4, 0.2], [0.8, 0.6]])
>>> critical_success_index(x, y, 0.5)
tensor(0.3333)


Example

>>> import torch
>>> from torchmetrics.functional.regression import critical_success_index
>>> x = torch.Tensor([[[0.2, 0.7], [0.9, 0.3]], [[0.2, 0.7], [0.9, 0.3]]])
>>> y = torch.Tensor([[[0.4, 0.2], [0.8, 0.6]], [[0.4, 0.2], [0.8, 0.6]]])
>>> critical_success_index(x, y, 0.5, keep_sequence_dim=0)
tensor([0.3333, 0.3333])