Log AUC
Module Interface
- class torchmetrics.LogAUC(**kwargs)
Compute the Log AUC score for classification tasks.
The score is computed by first computing the ROC curve, which is then interpolated to the specified range of false positive rates (FPR); the logarithm of the FPR is taken before the area under the curve (AUC) is computed. The score is commonly used in applications where the positive and negative classes are imbalanced and a low false positive rate is of high importance.
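The steps above can be traced with a small from-scratch sketch in plain Python. It is not the torchmetrics implementation: the helper names (roc_points, tpr_at, log_auc) are hypothetical, and the sketch assumes linear interpolation of the ROC curve at the two FPR bounds, the trapezoidal rule with log10(FPR) on the x-axis, and division of the area by the width of the log-FPR interval so that a perfect ranking scores 1.0 (the log base cancels after that division). Multiclass scores are treated one-vs-rest, multilabel columns independently, a class without positive samples gets a TPR of 0, and logit handling is left out. Under these assumptions it reproduces the outputs of the BinaryLogAUC, MulticlassLogAUC and MultilabelLogAUC examples in this section.

```python
import math

# Sketch only, under the assumptions stated above; not torchmetrics' own code.


def roc_points(scores, labels):
    """Exact (non-binned) ROC curve: (0, 0) plus one point per distinct score."""
    pos = sum(labels)
    neg = len(labels) - pos
    ranked = sorted(zip(scores, labels), key=lambda pair: -pair[0])
    fpr, tpr = [0.0], [0.0]
    tp = fp = i = 0
    while i < len(ranked):
        score = ranked[i][0]
        # every sample sharing this score crosses the threshold together
        while i < len(ranked) and ranked[i][0] == score:
            tp += ranked[i][1]
            fp += 1 - ranked[i][1]
            i += 1
        fpr.append(fp / neg if neg else 0.0)
        tpr.append(tp / pos if pos else 0.0)  # no positives: TPR stays 0
    return fpr, tpr


def tpr_at(x, fpr, tpr):
    """Linearly interpolate the ROC curve at false positive rate x (x > 0)."""
    left = max(i for i, f in enumerate(fpr) if f <= x)
    right = next((i for i, f in enumerate(fpr) if f > x), None)
    if right is None or fpr[left] == x:
        return tpr[left]
    weight = (x - fpr[left]) / (fpr[right] - fpr[left])
    return tpr[left] + weight * (tpr[right] - tpr[left])


def log_auc(scores, labels, fpr_range=(0.001, 0.1)):
    fpr, tpr = roc_points(scores, labels)
    lo, hi = fpr_range
    inside = [(f, t) for f, t in zip(fpr, tpr) if lo < f < hi]
    # ROC points strictly inside the range, bracketed by the interpolated bounds
    xs = [lo] + [f for f, _ in inside] + [hi]
    ys = [tpr_at(lo, fpr, tpr)] + [t for _, t in inside] + [tpr_at(hi, fpr, tpr)]
    log_xs = [math.log10(x) for x in xs]
    # trapezoidal rule with log10(FPR) on the x-axis
    area = sum((log_xs[k + 1] - log_xs[k]) * (ys[k] + ys[k + 1]) / 2 for k in range(len(xs) - 1))
    # rescale by the width of the log interval so a perfect ranking scores 1.0
    return area / (math.log10(hi) - math.log10(lo))


# BinaryLogAUC example: the single positive is ranked above all negatives
binary = log_auc([0.75, 0.05, 0.05, 0.05, 0.05], [1, 0, 0, 0, 0])

# MulticlassLogAUC example, one-vs-rest: class c is positive where target == c
mc_preds = [[0.75, 0.05, 0.05, 0.05, 0.05],
            [0.05, 0.75, 0.05, 0.05, 0.05],
            [0.05, 0.05, 0.75, 0.05, 0.05],
            [0.05, 0.05, 0.05, 0.75, 0.05]]
mc_target = [0, 1, 3, 2]
per_class = [log_auc([row[c] for row in mc_preds], [int(t == c) for t in mc_target])
             for c in range(5)]

# MultilabelLogAUC example: each label column is its own binary problem
ml_preds = [[0.75, 0.05, 0.35], [0.45, 0.75, 0.05], [0.05, 0.55, 0.75], [0.05, 0.65, 0.05]]
ml_target = [[1, 0, 1], [0, 0, 0], [0, 1, 1], [1, 1, 1]]
per_label = [log_auc([row[j] for row in ml_preds], [row[j] for row in ml_target])
             for j in range(3)]

print(binary)                             # 1.0
print([round(s, 4) for s in per_class])   # [1.0, 1.0, 0.0, 0.0, 0.0]
print(round(sum(per_class) / 5, 4))       # 0.4 (macro)
print([round(s, 4) for s in per_label])   # [0.5, 0.0, 0.6835]
print(round(sum(per_label) / 3, 4))       # 0.3945 (macro)
```

Because the x-axis is log10(FPR), the interval from 0.001 to 0.01 weighs as much as the interval from 0.01 to 0.1, so the score rewards classifiers that stay accurate at very low false positive rates.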
This module is a simple wrapper that returns the task-specific version of this metric, selected by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryLogAUC, MulticlassLogAUC and MultilabelLogAUC for the specific details of how each argument influences the result, and for examples.
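The routing done by the wrapper can be pictured with the hypothetical helper below, which only returns the name of the task-specific class and the extra argument that class needs; in real code one would construct the wrapper directly, for example LogAUC(task="multiclass", num_classes=5). The required-argument checks are an assumption based on num_classes and num_labels being positional arguments of MulticlassLogAUC and MultilabelLogAUC.

```python
# Hypothetical sketch of task-based dispatch; not torchmetrics' own code.
TASK_TO_CLASS = {
    "binary": "BinaryLogAUC",
    "multiclass": "MulticlassLogAUC",
    "multilabel": "MultilabelLogAUC",
}


def resolve_task(task, num_classes=None, num_labels=None):
    """Return the task-specific class name and the extra argument it needs."""
    if task not in TASK_TO_CLASS:
        raise ValueError(f"task must be one of {sorted(TASK_TO_CLASS)}, got {task!r}")
    # the multiclass/multilabel classes cannot be built without a size argument
    if task == "multiclass" and not isinstance(num_classes, int):
        raise ValueError("task='multiclass' needs an integer num_classes")
    if task == "multilabel" and not isinstance(num_labels, int):
        raise ValueError("task='multilabel' needs an integer num_labels")
    extra = {"multiclass": {"num_classes": num_classes},
             "multilabel": {"num_labels": num_labels}}.get(task, {})
    return TASK_TO_CLASS[task], extra


print(resolve_task("binary"))                     # ('BinaryLogAUC', {})
print(resolve_task("multiclass", num_classes=5))  # ('MulticlassLogAUC', {'num_classes': 5})
```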
BinaryLogAUC
- class torchmetrics.classification.BinaryLogAUC(fpr_range=(0.001, 0.1), thresholds=None, ignore_index=None, validate_args=False, **kwargs)
Compute the Log AUC score for binary classification tasks.
The score is computed by first computing the ROC curve, which is then interpolated to the specified range of false positive rates (FPR); the logarithm of the FPR is taken before the area under the curve (AUC) is computed. The score is commonly used in applications where the positive and negative classes are imbalanced and a low false positive rate is of high importance.
As input to forward and update the metric accepts the following input:
- preds (Tensor): A float tensor of shape (N, ...) containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element.
- target (Tensor): An int tensor of shape (N, ...) containing ground truth labels, and therefore only containing {0, 1} values (except if ignore_index is specified). The value 1 always encodes the positive class.
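A minimal plain-Python sketch of the binary input handling described above: extra dimensions are flattened, positions whose target equals ignore_index are dropped, and if any remaining prediction lies outside [0, 1] the sigmoid is applied element-wise. The function name prepare_binary is hypothetical, nested lists stand in for tensors, and the order of the steps (filtering before the logit check) is an assumption.

```python
import math


def prepare_binary(preds, target, ignore_index=None):
    """Flatten (N, ...) inputs, drop ignored targets, map logits to probabilities."""

    def flatten(x):
        # recursively flatten nested lists (the extra "..." dimensions)
        return [v for item in x for v in flatten(item)] if isinstance(x, list) else [x]

    p, t = flatten(preds), flatten(target)
    if ignore_index is not None:
        keep = [k for k, v in enumerate(t) if v != ignore_index]
        p, t = [p[k] for k in keep], [t[k] for k in keep]
    # any value outside [0, 1] => treat all predictions as logits
    if any(v < 0 or v > 1 for v in p):
        p = [1 / (1 + math.exp(-v)) for v in p]
    return p, t


preds = [[2.0, -1.0], [0.5, -3.0]]  # shape (2, 2); values outside [0, 1] -> logits
target = [[1, 0], [-1, 0]]          # -1 marks a position to ignore
p, t = prepare_binary(preds, target, ignore_index=-1)
print([round(v, 3) for v in p], t)  # [0.881, 0.269, 0.047] [1, 0, 0]
```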
As output to forward and compute the metric returns the following output:
- logauc (Tensor): A single scalar with the Log AUC score.
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).
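To see why the binned version needs constant memory, the hypothetical sketch below accumulates, for each threshold, a 2×2 table of counts (true class × predicted class) over several batches. It assumes a sample is predicted positive when its score is greater than or equal to the threshold, and the state layout is illustrative rather than torchmetrics' internal format. The state holds one table per threshold no matter how many samples arrive, and FPR/TPR at each threshold follow from the counts alone, while the non-binned version has to keep every score and target.

```python
def binned_update(state, preds, target, thresholds):
    """Add one batch to a per-threshold 2x2 table of counts."""
    for k, thr in enumerate(thresholds):
        for p, t in zip(preds, target):
            predicted = int(p >= thr)  # assumption: positive when score >= threshold
            state[k][t][predicted] += 1  # rows: true class, columns: predicted class
    return state


thresholds = [0.0, 0.25, 0.5, 0.75, 1.0]
state = [[[0, 0], [0, 0]] for _ in thresholds]  # size fixed by len(thresholds)

batches = [([0.9, 0.2, 0.6], [1, 0, 0]), ([0.4, 0.8], [1, 1])]
for preds, target in batches:
    binned_update(state, preds, target, thresholds)  # raw scores can now be discarded

# FPR = FP / (FP + TN), TPR = TP / (TP + FN), read off the counts alone
fpr = [s[0][1] / (s[0][0] + s[0][1]) for s in state]
tpr = [s[1][1] / (s[1][0] + s[1][1]) for s in state]
print(len(state), fpr, [round(v, 3) for v in tpr])
# 5 [1.0, 0.5, 0.5, 0.0, 0.0] [1.0, 1.0, 0.667, 0.667, 0.0]
```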
- Parameters:
fpr_range (Tuple[float, float]) – 2-element tuple with the lower and upper bound of the false positive rate range to compute the log AUC score.
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
  - If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
  - If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
  - If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
  - If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
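A hypothetical helper that normalizes the thresholds argument following the four cases listed above (a 1d tensor is handled like a list by iterating over its items). Rejecting an integer below 2 follows the "larger than 1" requirement; the check that list values lie in [0, 1] is an added assumption, based on the thresholds serving as bins for probabilities.

```python
def normalize_thresholds(thresholds):
    """Turn the `thresholds` argument into a list of floats (None = non-binned)."""
    if thresholds is None:
        return None  # non-binned: thresholds come from the data itself
    if isinstance(thresholds, int):
        if thresholds < 2:
            raise ValueError("an integer `thresholds` must be larger than 1")
        # linearly spaced from 0 to 1, both ends included
        return [i / (thresholds - 1) for i in range(thresholds)]
    # list of floats, or the items of a 1d tensor
    values = [float(v) for v in thresholds]
    if any(v < 0 or v > 1 for v in values):
        raise ValueError("thresholds are assumed to lie in [0, 1]")
    return values


print(normalize_thresholds(5))           # [0.0, 0.25, 0.5, 0.75, 1.0]
print(normalize_thresholds([0.1, 0.9]))  # [0.1, 0.9]
print(normalize_thresholds(None))        # None
```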
Example
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryLogAUC
>>> preds = tensor([0.75, 0.05, 0.05, 0.05, 0.05])
>>> target = tensor([1, 0, 0, 0, 0])
>>> metric = BinaryLogAUC()
>>> metric(preds, target)
tensor(1.)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.classification import BinaryLogAUC
>>> metric = BinaryLogAUC()
>>> metric.update(torch.rand(20,), torch.randint(2, (20,)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.classification import BinaryLogAUC
>>> metric = BinaryLogAUC()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(20,), torch.randint(2, (20,))))
>>> fig_, ax_ = metric.plot(values)
MulticlassLogAUC
- class torchmetrics.classification.MulticlassLogAUC(num_classes, fpr_range=(0.001, 0.1), average=None, thresholds=None, ignore_index=None, validate_args=True, **kwargs)
Compute the Log AUC score for multiclass classification tasks.
The score is computed by first computing the ROC curve, which is then interpolated to the specified range of false positive rates (FPR); the logarithm of the FPR is taken before the area under the curve (AUC) is computed. The score is commonly used in applications where the positive and negative classes are imbalanced and a low false positive rate is of high importance.
As input to forward and update the metric accepts the following input:
- preds (Tensor): A float tensor of shape (N, C, ...) containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and softmax is automatically applied per sample.
- target (Tensor): An int tensor of shape (N, ...) containing ground truth labels, and therefore only containing values in the [0, n_classes-1] range (except if ignore_index is specified).
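For the multiclass input, the sketch below shows two steps in plain Python, under the assumption (standard for per-class ROC curves) that each class is scored one-vs-rest: rows of logits are turned into probabilities with a per-sample softmax, and the (N, C) problem is split into C binary (scores, labels) pairs. The function names are hypothetical. A class that never occurs in the target, like class 2 here, ends up with labels that are all 0.

```python
import math


def softmax_rows(preds):
    """Apply softmax to each sample (row) of an (N, C) list of logits."""
    out = []
    for row in preds:
        m = max(row)  # subtract the row maximum for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def one_vs_rest(preds, target, num_classes):
    """Split an (N, C) multiclass problem into C binary (scores, labels) pairs."""
    if any(v < 0 or v > 1 for row in preds for v in row):
        preds = softmax_rows(preds)  # outside [0, 1] -> treat as logits
    return [([row[c] for row in preds], [int(t == c) for t in target])
            for c in range(num_classes)]


pairs = one_vs_rest([[2.0, 0.0, -1.0], [0.0, 3.0, 0.0]], [0, 1], num_classes=3)
for c, (scores, labels) in enumerate(pairs):
    print(c, [round(s, 3) for s in scores], labels)
# 0 [0.844, 0.045] [1, 0]
# 1 [0.114, 0.909] [0, 1]
# 2 [0.042, 0.045] [0, 0]
```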
As output to forward and compute the metric returns the following output:
- logauc (Tensor): If average=None or "none", a 1d tensor of shape (n_classes, ) is returned with the Log AUC score per class. If average="macro", a single scalar is returned.
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).
- Parameters:
num_classes (int) – Integer specifying the number of classes.
fpr_range (Tuple[float, float]) – 2-element tuple with the lower and upper bound of the false positive rate range to compute the log AUC score.
average (Optional[Literal['macro', 'none']]) – Defines the reduction that is applied over classes. Should be one of the following:
  - "macro": Calculate the score for each class and average them.
  - "none" or None: Calculate the score for each class and apply no reduction.
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
  - If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
  - If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
  - If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
  - If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
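The average reduction can be sketched with a hypothetical plain-Python helper. With the per-class scores from the example below, the macro average is 0.4 rather than 0.5 because class 4, which never appears in the target, still contributes a score of 0; treating such classes as 0 is inferred from that example output rather than guaranteed for every input.

```python
def reduce_scores(per_class, average="macro"):
    """Apply the `average` reduction to a list of per-class Log AUC scores."""
    if average == "macro":
        return sum(per_class) / len(per_class)  # unweighted mean over all classes
    if average in ("none", None):
        return per_class  # one score per class, no reduction
    raise ValueError(f"unsupported average: {average!r}")


per_class = [1.0, 1.0, 0.0, 0.0, 0.0]      # per-class output of the example below
print(reduce_scores(per_class, "macro"))  # 0.4
print(reduce_scores(per_class, None))     # [1.0, 1.0, 0.0, 0.0, 0.0]
```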
Example
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassLogAUC
>>> preds = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                 [0.05, 0.75, 0.05, 0.05, 0.05],
...                 [0.05, 0.05, 0.75, 0.05, 0.05],
...                 [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = tensor([0, 1, 3, 2])
>>> metric = MulticlassLogAUC(num_classes=5, average="macro", thresholds=None)
>>> metric(preds, target)
tensor(0.4000)
>>> metric = MulticlassLogAUC(num_classes=5, average=None, thresholds=None)
>>> metric(preds, target)
tensor([1., 1., 0., 0., 0.])
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.classification import MulticlassLogAUC
>>> metric = MulticlassLogAUC(num_classes=3)
>>> metric.update(torch.randn(20, 3), torch.randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.classification import MulticlassLogAUC
>>> metric = MulticlassLogAUC(num_classes=3)
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.randn(20, 3), torch.randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
MultilabelLogAUC
- class torchmetrics.classification.MultilabelLogAUC(num_labels, fpr_range=(0.001, 0.1), average=None, thresholds=None, ignore_index=None, validate_args=True, **kwargs)
Compute the Log AUC score for multilabel classification tasks.
The score is computed by first computing the ROC curve, which is then interpolated to the specified range of false positive rates (FPR); the logarithm of the FPR is taken before the area under the curve (AUC) is computed. The score is commonly used in applications where the positive and negative classes are imbalanced and a low false positive rate is of high importance.
As input to forward and update the metric accepts the following input:
- preds (Tensor): A float tensor of shape (N, C, ...) containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element.
- target (Tensor): An int tensor of shape (N, C, ...) containing ground truth labels, and therefore only containing {0, 1} values (except if ignore_index is specified).
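In the multilabel case each label column is its own binary problem. The hypothetical sketch below applies the sigmoid element-wise when any value is outside [0, 1] (labels are never normalized against each other, unlike the multiclass softmax) and then splits the (N, L) inputs into per-label (scores, labels) pairs, dropping entries whose target equals ignore_index. Applying ignore_index per element is an assumption.

```python
import math


def split_labels(preds, target, ignore_index=None):
    """Turn (N, L) multilabel inputs into L independent binary (scores, labels) pairs."""
    if any(v < 0 or v > 1 for row in preds for v in row):
        # logits: sigmoid per element, each label handled on its own
        preds = [[1 / (1 + math.exp(-v)) for v in row] for row in preds]
    pairs = []
    for j in range(len(preds[0])):
        kept = [(row_p[j], row_t[j]) for row_p, row_t in zip(preds, target)
                if row_t[j] != ignore_index]
        pairs.append(([p for p, _ in kept], [t for _, t in kept]))
    return pairs


preds = [[0.9, 0.2], [0.3, 0.8], [0.6, 0.1]]
target = [[1, 0], [0, -1], [1, 1]]  # -1 marks an ignored entry
for j, (scores, labels) in enumerate(split_labels(preds, target, ignore_index=-1)):
    print(j, scores, labels)
# 0 [0.9, 0.3, 0.6] [1, 0, 1]
# 1 [0.2, 0.1] [0, 1]
```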
As output to forward and compute the metric returns the following output:
- logauc (Tensor): If average=None or "none", a 1d tensor of shape (num_labels, ) is returned with the Log AUC score per label. If average="macro", a single scalar is returned.
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).
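A rough count of stored values makes the two memory regimes concrete. The hypothetical function below assumes the non-binned version keeps one score and one target per sample and label, and the binned version keeps a 2×2 count table per threshold and label; actual tensor layouts and byte sizes in torchmetrics may differ.

```python
def state_size(n_samples, n_labels, n_thresholds=None):
    """Illustrative number of stored values: non-binned vs binned mode."""
    if n_thresholds is None:
        # non-binned: every score and every target for every label
        return 2 * n_samples * n_labels
    # binned: one 2x2 count table per (threshold, label), independent of n_samples
    return n_thresholds * n_labels * 4


for n in (1_000, 1_000_000):
    print(n, state_size(n, 3), state_size(n, 3, n_thresholds=100))
# 1000 6000 1200
# 1000000 6000000 1200
```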
- Parameters:
num_labels (int) – Integer specifying the number of labels.
fpr_range (Tuple[float, float]) – 2-element tuple with the lower and upper bound of the false positive rate range to compute the log AUC score.
average (Optional[Literal['macro', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
  - "macro": Calculate the score for each label and average them.
  - "none" or None: Calculate the score for each label and apply no reduction.
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
  - If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
  - If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
  - If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
  - If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelLogAUC
>>> preds = tensor([[0.75, 0.05, 0.35],
...                 [0.45, 0.75, 0.05],
...                 [0.05, 0.55, 0.75],
...                 [0.05, 0.65, 0.05]])
>>> target = tensor([[1, 0, 1],
...                  [0, 0, 0],
...                  [0, 1, 1],
...                  [1, 1, 1]])
>>> metric = MultilabelLogAUC(num_labels=3, average="macro", thresholds=None)
>>> metric(preds, target)
tensor(0.3945)
>>> metric = MultilabelLogAUC(num_labels=3, average=None, thresholds=None)
>>> metric(preds, target)
tensor([0.5000, 0.0000, 0.6835])
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.classification import MultilabelLogAUC
>>> metric = MultilabelLogAUC(num_labels=3)
>>> metric.update(torch.rand(20, 3), torch.randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.classification import MultilabelLogAUC
>>> metric = MultilabelLogAUC(num_labels=3)
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(20, 3), torch.randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.logauc(preds, target, task, thresholds=None, num_classes=None, num_labels=None, fpr_range=(0.001, 0.1), average=None, ignore_index=None, validate_args=True)
Compute the Log AUC score for classification tasks.
The score is computed by first computing the ROC curve, which is then interpolated to the specified range of false positive rates (FPR); the logarithm of the FPR is taken before the area under the curve (AUC) is computed. The score is commonly used in applications where the positive and negative classes are imbalanced and a low false positive rate is of high importance.
binary_logauc
- torchmetrics.functional.classification.binary_logauc(preds, target, fpr_range=(0.001, 0.1), thresholds=None, ignore_index=None, validate_args=True)
Compute the Log AUC score for binary classification tasks.
The score is computed by first computing the ROC curve, which is then interpolated to the specified range of false positive rates (FPR); the logarithm of the FPR is taken before the area under the curve (AUC) is computed. The score is commonly used in applications where the positive and negative classes are imbalanced and a low false positive rate is of high importance.
Accepts the following input tensors:
- preds (float tensor): (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element.
- target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0, 1} values (except if ignore_index is specified). The value 1 always encodes the positive class.
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).
- Parameters:
fpr_range (Tuple[float, float]) – 2-element tuple with the lower and upper bound of the false positive rate range to compute the log AUC score.
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
  - If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
  - If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
  - If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
  - If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
- Returns:
A single scalar with the log auc score
Example
>>> from torchmetrics.functional.classification import binary_logauc
>>> from torch import tensor
>>> preds = tensor([0.75, 0.05, 0.05, 0.05, 0.05])
>>> target = tensor([1, 0, 0, 0, 0])
>>> binary_logauc(preds, target)
tensor(1.)
multiclass_logauc
- torchmetrics.functional.classification.multiclass_logauc(preds, target, num_classes, fpr_range=(0.001, 0.1), average='macro', thresholds=None, ignore_index=None, validate_args=True)
Compute the Log AUC score for multiclass classification tasks.
The score is computed by first computing the ROC curve, which is then interpolated to the specified range of false positive rates (FPR); the logarithm of the FPR is taken before the area under the curve (AUC) is computed. The score is commonly used in applications where the positive and negative classes are imbalanced and a low false positive rate is of high importance.
Accepts the following input tensors:
- preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and softmax is automatically applied per sample.
- target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).
- Parameters:
num_classes (int) – Integer specifying the number of classes.
fpr_range (Tuple[float, float]) – 2-element tuple with the lower and upper bound of the false positive rate range to compute the log AUC score.
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
  - If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
  - If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
  - If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
  - If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
average (Optional[Literal['macro', 'none']]) – Defines the reduction that is applied over classes. Should be one of the following:
  - "macro": Calculate the score for each class and average them.
  - "none" or None: Calculate the score for each class and apply no reduction.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Example
>>> import torch
>>> from torchmetrics.functional.classification import multiclass_logauc
>>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                       [0.05, 0.75, 0.05, 0.05, 0.05],
...                       [0.05, 0.05, 0.75, 0.05, 0.05],
...                       [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> multiclass_logauc(preds, target, num_classes=5, average="macro", thresholds=None)
tensor(0.4000)
>>> multiclass_logauc(preds, target, num_classes=5, average=None, thresholds=None)
tensor([1., 1., 0., 0., 0.])
multilabel_logauc
- torchmetrics.functional.classification.multilabel_logauc(preds, target, num_labels, fpr_range=(0.001, 0.1), average='macro', thresholds=None, ignore_index=None, validate_args=True)
Compute the Log AUC score for multilabel classification tasks.
The score is computed by first computing the ROC curve, which is then interpolated to the specified range of false positive rates (FPR); the logarithm of the FPR is taken before the area under the curve (AUC) is computed. The score is commonly used in applications where the positive and negative classes are imbalanced and a low false positive rate is of high importance.
Accepts the following input tensors:
- preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element.
- target (int tensor): (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0, 1} values (except if ignore_index is specified).
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).
- Parameters:
num_labels (int) – Integer specifying the number of labels.
fpr_range (Tuple[float, float]) – 2-element tuple with the lower and upper bound of the false positive rate range to compute the log AUC score.
average (Optional[Literal['macro', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
  - "macro": Calculate the score for each label and average them.
  - "none" or None: Calculate the score for each label and apply no reduction.
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
  - If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
  - If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
  - If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
  - If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Example
>>> import torch
>>> from torchmetrics.functional.classification import multilabel_logauc
>>> preds = torch.tensor([[0.75, 0.05, 0.35],
...                       [0.45, 0.75, 0.05],
...                       [0.05, 0.55, 0.75],
...                       [0.05, 0.65, 0.05]])
>>> target = torch.tensor([[1, 0, 1],
...                        [0, 0, 0],
...                        [0, 1, 1],
...                        [1, 1, 1]])
>>> multilabel_logauc(preds, target, num_labels=3, average="macro", thresholds=None)
tensor(0.3945)
>>> multilabel_logauc(preds, target, num_labels=3, average=None, thresholds=None)
tensor([0.5000, 0.0000, 0.6835])