torchmetrics.utilities
The following utility functions may be useful in your own code. Note that they are not part of the public API and may change at any time.
torchmetrics.utilities.data
The data utilities are used to help with data manipulation, such as converting labels in classification from one
format to another.
select_topk
torchmetrics.utilities.data.select_topk(prob_tensor, topk=1, dim=1)
Convert a probability tensor to binary by setting the top-k highest entries to 1 and all other entries to 0.
- Parameters:
prob_tensor (Tensor) – dense tensor of shape [..., C, ...], where C is in the position defined by the dim argument
topk (int) – number of highest entries to turn into 1s
dim (int) – dimension along which to compare entries
- Return type:
Tensor
- Returns:
A binary tensor of the same shape as the input tensor of type torch.int32
Example
>>> import torch
>>> from torchmetrics.utilities.data import select_topk
>>> x = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]])
>>> select_topk(x, topk=2)
tensor([[0, 1, 1],
        [1, 1, 0]], dtype=torch.int32)
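The selection can be reproduced with plain torch operations. The following is a minimal sketch, not the library's source: the name `select_topk_sketch` is hypothetical, and it assumes the behaviour described above (1s at the `topk` largest positions along `dim`, everything cast to `torch.int32`). Ties are resolved by whatever `Tensor.topk` returns.

```python
import torch


def select_topk_sketch(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor:
    """Mark the ``topk`` largest entries along ``dim`` with 1 and all others with 0."""
    # Indices of the k largest entries along the chosen dimension.
    topk_indices = prob_tensor.topk(k=topk, dim=dim).indices
    # Start from zeros, write 1s at those indices, then cast to int32.
    zeros = torch.zeros_like(prob_tensor)
    return zeros.scatter(dim, topk_indices, 1.0).int()


x = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]])
print(select_topk_sketch(x, topk=2))         # top-2 per row (dim=1)
print(select_topk_sketch(x, topk=1, dim=0))  # top-1 per column (dim=0)
```

With `dim=0`, entries are compared within each column instead of within each row.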
to_categorical
torchmetrics.utilities.data.to_categorical(x, argmax_dim=1)
Convert a tensor of probabilities to a dense label tensor.
- Parameters:
x (Tensor) – probabilities to get the categorical label from, with shape [N, d1, d2, …]
argmax_dim (int) – dimension along which to apply argmax
- Return type:
Tensor
- Returns:
A tensor with categorical labels [N, d2, …]
Example
>>> import torch
>>> from torchmetrics.utilities.data import to_categorical
>>> x = torch.tensor([[0.2, 0.5], [0.9, 0.1]])
>>> to_categorical(x)
tensor([1, 0])
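Because the conversion is an argmax over the probability dimension, that dimension disappears from the output. A minimal sketch with plain torch, where `to_categorical_sketch` is a hypothetical name used only for illustration:

```python
import torch


def to_categorical_sketch(x: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
    """Turn class probabilities into integer labels via argmax along ``argmax_dim``."""
    # argmax removes ``argmax_dim``, so an [N, d1, d2, ...] input becomes [N, d2, ...].
    return torch.argmax(x, dim=argmax_dim)


# Batch of 2 samples, 3 classes, 4 positions each: shape [N=2, d1=3, d2=4].
probs = torch.rand(2, 3, 4)
labels = to_categorical_sketch(probs)
print(labels.shape)  # torch.Size([2, 4])
```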
to_onehot
torchmetrics.utilities.data.to_onehot(label_tensor, num_classes=None)
Convert a dense label tensor to one-hot format.
- Parameters:
label_tensor (Tensor) – dense label tensor, with shape [N, d1, d2, …]
num_classes (Optional[int]) – number of classes C
- Return type:
Tensor
- Returns:
A one-hot encoded label tensor with shape [N, C, d1, d2, …]
Example
>>> import torch
>>> from torchmetrics.utilities.data import to_onehot
>>> x = torch.tensor([1, 2, 3])
>>> to_onehot(x)
tensor([[0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1]])
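One-hot encoding inserts a class axis at dimension 1. The sketch below is a hypothetical re-implementation (`to_onehot_sketch` is not a library function); it assumes that when `num_classes` is `None`, the number of classes is inferred as the largest label plus one, which matches the example above.

```python
from typing import Optional

import torch


def to_onehot_sketch(label_tensor: torch.Tensor, num_classes: Optional[int] = None) -> torch.Tensor:
    """One-hot encode labels of shape [N, d1, ...] into shape [N, C, d1, ...]."""
    if num_classes is None:
        # Assumption: infer C from the largest label present (labels are 0-based).
        num_classes = int(label_tensor.max().item()) + 1
    # Output keeps the input dtype and device, with the class axis inserted at dim 1.
    onehot = torch.zeros(
        label_tensor.shape[0], num_classes, *label_tensor.shape[1:],
        dtype=label_tensor.dtype, device=label_tensor.device,
    )
    # Write a 1 at each element's label position along the class axis.
    index = label_tensor.long().unsqueeze(1)
    return onehot.scatter_(1, index, 1)


x = torch.tensor([1, 2, 3])
print(to_onehot_sketch(x))
```

For int64 labels, `torch.nn.functional.one_hot(x, num_classes).movedim(-1, 1)` yields the same layout, since `one_hot` appends the class axis last.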
dim_zero_cat
torchmetrics.utilities.data.dim_zero_cat(x)
Concatenation along the zero dimension.
- Return type:
Tensor
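The type of `x` is not documented here. Assuming it is either a single tensor or a list of tensors (for example, per-batch results accumulated in a list), concatenation along dimension 0 can be sketched as follows. `dim_zero_cat_sketch` is a hypothetical name, and promoting 0-d tensors to shape [1] is an assumption made so that `torch.cat` accepts them.

```python
from typing import List, Union

import torch


def dim_zero_cat_sketch(x: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:
    """Concatenate a list of tensors along dimension 0 (a plain tensor is returned as-is)."""
    if isinstance(x, torch.Tensor):
        return x
    # Assumption: 0-d (scalar) tensors are promoted to shape [1] so torch.cat accepts them.
    x = [t.unsqueeze(0) if t.ndim == 0 else t for t in x]
    return torch.cat(x, dim=0)


# E.g. values collected over three batches of different sizes.
batches = [torch.tensor([1.0, 2.0]), torch.tensor([3.0]), torch.tensor(4.0)]
print(dim_zero_cat_sketch(batches))  # tensor([1., 2., 3., 4.])
```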
dim_zero_max
torchmetrics.utilities.data.dim_zero_max(x)
Max along the zero dimension.
- Return type:
Tensor
dim_zero_mean
torchmetrics.utilities.data.dim_zero_mean(x)
Average along the zero dimension.
- Return type:
Tensor
dim_zero_min
torchmetrics.utilities.data.dim_zero_min(x)
Min along the zero dimension.
- Return type:
Tensor
dim_zero_sum
torchmetrics.utilities.data.dim_zero_sum(x)
Summation along the zero dimension.
- Return type:
Tensor
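Assuming `x` is a single tensor, the reductions dim_zero_max, dim_zero_mean, dim_zero_min and dim_zero_sum behave like the following plain torch calls along dimension 0. This is an illustration based on their one-line descriptions, not the library source, and the input values are illustrative.

```python
import torch

# Illustrative stacked values, shape [3, 2]: dimension 0 is the one being reduced.
x = torch.tensor([[1.0, 5.0], [4.0, 2.0], [7.0, 3.0]])

# Each reduction collapses dimension 0 and keeps the remaining dimensions.
dim0_max = torch.max(x, dim=0).values  # torch.max with dim returns (values, indices)
dim0_mean = torch.mean(x, dim=0)
dim0_min = torch.min(x, dim=0).values  # torch.min with dim returns (values, indices)
dim0_sum = torch.sum(x, dim=0)

print(dim0_max, dim0_mean, dim0_min, dim0_sum)
```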
torchmetrics.utilities.distributed
The distributed utilities are used to help with synchronization of metrics across multiple processes.
gather_all_tensors
torchmetrics.utilities.distributed.gather_all_tensors(result, group=None)
Gather all tensors from several DDP processes into a list that is broadcast to all processes.
Works on tensors that have the same number of dimensions but possibly different sizes in each dimension. In this case, the tensors are padded to a common shape, gathered, and then trimmed back to their original shapes, so that all processes exchange tensors of equal size.
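- Parameters:
result (Tensor) – the tensor to gather from the current process
group – the process group to gather results from; if None, the default group containing all processes is used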
- Return type:
List[Tensor]
- Returns:
list whose length equals the size of the process group, where element i is the result tensor from process i
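The pad, gather, trim scheme can be illustrated in a single process. In this sketch the helper names `pad_to_shape` and `simulated_gather` are hypothetical, and the collective is replaced by a plain Python list, so it shows only the shape bookkeeping, not real communication; a multi-process run would perform a collective such as `torch.distributed.all_gather` at the marked step.

```python
from typing import List

import torch
import torch.nn.functional as F


def pad_to_shape(t: torch.Tensor, shape: List[int]) -> torch.Tensor:
    """Zero-pad ``t`` at the end of every dimension up to ``shape``."""
    # F.pad takes (left, right) pairs starting from the LAST dimension.
    pad = []
    for size, target in zip(reversed(t.shape), reversed(shape)):
        pad.extend([0, target - size])
    return F.pad(t, pad)


def simulated_gather(per_rank: List[torch.Tensor]) -> List[torch.Tensor]:
    """Pad -> gather -> trim, with the all-gather replaced by a local list."""
    # Step 1 (simulated): ranks exchange their local shapes; take the max per dimension.
    shapes = torch.tensor([list(t.shape) for t in per_rank])
    max_shape = shapes.max(dim=0).values.tolist()
    # Step 2: pad so every rank contributes a tensor of identical shape.
    padded = [pad_to_shape(t, max_shape) for t in per_rank]
    # Step 3 (simulated): a real run would all-gather the padded tensors here.
    gathered = [p.clone() for p in padded]
    # Step 4: trim each gathered tensor back to the shape its rank reported.
    return [g[tuple(slice(0, s) for s in shape)] for g, shape in zip(gathered, shapes.tolist())]


# Three "ranks" holding 2-D results with the same ndim but different sizes.
results = [torch.ones(2, 3), torch.full((1, 2), 2.0), torch.zeros(3, 1)]
out = simulated_gather(results)
print([tuple(t.shape) for t in out])  # [(2, 3), (1, 2), (3, 1)]
```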
torchmetrics.utilities.exceptions
TorchMetricsUserError
class torchmetrics.utilities.exceptions.TorchMetricsUserError
Error used to inform users of a wrong combination of Metric API calls.
TorchMetricsUserWarning
class torchmetrics.utilities.exceptions.TorchMetricsUserWarning
Warning used to inform users of specific issues arising from the torchmetrics API.