torchmetrics.utilities
The following utility functions are publicly accessible and may be useful in your own code. Note, however, that they are not part of the public API and may change at any time.
torchmetrics.utilities.data
The data utilities help with data manipulation, such as converting classification labels from one format to another.
select_topk
- torchmetrics.utilities.data.select_topk(prob_tensor, topk=1, dim=1)
Convert a probability tensor to a binary tensor by setting the top-k highest entries along dim to 1 and all other entries to 0.
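- Parameters:
prob_tensor (Tensor) – dense tensor of scores or probabilities.
topk (int) – number of highest entries to mark with 1 along dim.
dim (int) – dimension along which the entries are compared.
- Return type:
Tensor
- Returns:
A binary tensor of the same shape as the input tensor, of type torch.int32.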
Example
>>> import torch
>>> from torchmetrics.utilities.data import select_topk
>>> x = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]])
>>> select_topk(x, topk=2)
tensor([[0, 1, 1],
        [1, 1, 0]], dtype=torch.int32)
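As a rough illustration of what select_topk does, here is a minimal pure-PyTorch sketch. It assumes only the behaviour shown in the example (ones at the top-k positions along dim, int32 output) and is not presented as the library's actual implementation; the name select_topk_sketch is hypothetical.

```python
import torch


def select_topk_sketch(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor:
    """Hypothetical helper: mark the `topk` largest entries along `dim` with 1."""
    # Indices of the k largest entries along `dim`.
    idx = prob_tensor.topk(k=topk, dim=dim).indices
    # Write 1.0 at those indices into an all-zero tensor of the same shape.
    mask = torch.zeros_like(prob_tensor).scatter(dim, idx, 1.0)
    # Cast to int32 to match the documented return type.
    return mask.int()


x = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]])
print(select_topk_sketch(x, topk=2))
```

In this sketch, when several entries tie, which of them gets selected depends on the indices returned by torch.topk.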
to_categorical
- torchmetrics.utilities.data.to_categorical(x, argmax_dim=1)
Convert a tensor of probabilities to a dense label tensor.
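- Parameters:
x (Tensor) – tensor of probabilities, with the class dimension at argmax_dim.
argmax_dim (int) – dimension along which the argmax is taken to obtain the labels.
- Return type:
Tensor
- Returns: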
A tensor with categorical labels [N, d2, …]
Example
>>> import torch
>>> from torchmetrics.utilities.data import to_categorical
>>> x = torch.tensor([[0.2, 0.5], [0.9, 0.1]])
>>> to_categorical(x)
tensor([1, 0])
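A minimal sketch of the same conversion, assuming it amounts to torch.argmax along argmax_dim (consistent with the example and with the documented output shape [N, d2, …], where the class dimension is dropped). The name to_categorical_sketch is hypothetical; the second input shows how an extra trailing dimension is preserved.

```python
import torch


def to_categorical_sketch(x: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
    """Hypothetical helper: index of the largest value along `argmax_dim`."""
    return torch.argmax(x, dim=argmax_dim)


# Same input as the example: shape [N=2, C=2] -> labels of shape [N=2].
probs = torch.tensor([[0.2, 0.5], [0.9, 0.1]])
labels = to_categorical_sketch(probs)  # [1, 0]

# With an extra trailing dimension: shape [N=2, C=3, d2=2] -> labels of shape [N=2, d2=2].
probs_3d = torch.tensor([
    [[0.1, 0.7], [0.6, 0.2], [0.3, 0.1]],
    [[0.8, 0.1], [0.1, 0.1], [0.1, 0.8]],
])
labels_3d = to_categorical_sketch(probs_3d)  # [[1, 0], [0, 2]]
```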
to_onehot
- torchmetrics.utilities.data.to_onehot(label_tensor, num_classes=None)
Convert a dense label tensor to one-hot format.
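- Parameters:
label_tensor (Tensor) – dense label tensor of shape [N, d1, d2, …].
num_classes (Optional[int]) – number of classes C; if None, it is inferred as the maximum label value plus 1.
- Return type:
Tensor
- Returns: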
A sparse label tensor with shape [N, C, d1, d2, …]
Example
>>> import torch
>>> from torchmetrics.utilities.data import to_onehot
>>> x = torch.tensor([1, 2, 3])
>>> to_onehot(x)
tensor([[0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1]])
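A minimal sketch of the one-hot conversion, assuming the documented output layout [N, C, d1, d2, …] and that num_classes defaults to the maximum label plus one (as the example suggests). The name to_onehot_sketch is hypothetical. It also shows the main layout difference from torch.nn.functional.one_hot, which appends the class dimension last instead of inserting it at position 1.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def to_onehot_sketch(label_tensor: torch.Tensor, num_classes: Optional[int] = None) -> torch.Tensor:
    """Hypothetical helper: dense labels [N, d1, ...] -> one-hot [N, C, d1, ...]."""
    if num_classes is None:
        # Assumption: infer C as the largest label value plus one.
        num_classes = int(label_tensor.max().item()) + 1
    # Output shape with the class dimension inserted at position 1.
    shape = (label_tensor.shape[0], num_classes, *label_tensor.shape[1:])
    onehot = torch.zeros(shape, dtype=label_tensor.dtype, device=label_tensor.device)
    # For every element, write a 1 at its class index along dimension 1.
    index = label_tensor.long().unsqueeze(1)
    return onehot.scatter_(1, index, 1)


x = torch.tensor([1, 2, 3])
print(to_onehot_sketch(x))  # 3 rows, 4 classes, a single 1 per row

# F.one_hot puts the class dimension last, so move it to position 1 to compare.
labels_2d = torch.tensor([[0, 2], [1, 0]])  # shape [N=2, d1=2]
ours = to_onehot_sketch(labels_2d, num_classes=3)  # shape [2, 3, 2]
ref = F.one_hot(labels_2d, num_classes=3).movedim(-1, 1)
print(torch.equal(ours, ref))  # True
```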
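dim_zero_cat
Concatenation along the zero dimension.
dim_zero_max
Maximum along the zero dimension.
dim_zero_mean
Average along the zero dimension.
dim_zero_min
Minimum along the zero dimension.
dim_zero_sum
Summation along the zero dimension.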
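Judging by their names, the dim_zero_* helpers reduce a metric state along dimension 0. The sketch below assumes they behave like torch.cat, torch.sum, torch.mean, torch.max and torch.min with dim=0. The name dim_zero_cat_sketch is hypothetical, and its step that promotes scalar tensors before concatenating is an assumption rather than something this page confirms.

```python
from typing import List, Union

import torch


def dim_zero_cat_sketch(x: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:
    """Hypothetical helper: concatenate a list of tensors along dimension 0."""
    if isinstance(x, torch.Tensor):
        return x  # already a single tensor, nothing to concatenate
    # Assumption: promote 0-d (scalar) tensors to shape [1] so torch.cat accepts them.
    parts = [t.unsqueeze(0) if t.ndim == 0 else t for t in x]
    return torch.cat(parts, dim=0)


# A list-valued state, e.g. per-batch values appended to a list.
cat_state = dim_zero_cat_sketch([torch.tensor([1, 2]), torch.tensor(3), torch.tensor([4])])
# -> [1, 2, 3, 4]

# A stacked state, e.g. one row per process after gathering.
states = torch.tensor([[1.0, 5.0], [3.0, 2.0], [2.0, 8.0]])
summed = torch.sum(states, dim=0)          # [6., 15.]
averaged = torch.mean(states, dim=0)       # [2., 5.]
maximum = torch.max(states, dim=0).values  # [3., 8.]
minimum = torch.min(states, dim=0).values  # [1., 2.]
```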
torchmetrics.utilities.distributed
The distributed utilities help synchronize metrics across multiple processes.
gather_all_tensors
- torchmetrics.utilities.distributed.gather_all_tensors(result, group=None)
Gather tensors from several DDP processes into a list that is broadcast to all processes.
Works on tensors that have the same number of dimensions but whose individual dimension sizes may differ. In that case, the tensors are padded to a common shape, gathered, and then trimmed back to their original shapes, so that every process exchanges tensors of equal size.
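The pad, gather, trim strategy can be simulated in a single process. This is a hypothetical sketch (pad_gather_trim_sketch is not part of torchmetrics): a list of tensors stands in for the per-rank result values, zero-padding at the end of each dimension up to the element-wise maximum shape is an assumption about how the padding is done, and the actual collective communication is replaced by a comment. In a real DDP run, the exchange would go through torch.distributed.

```python
from typing import List

import torch
import torch.nn.functional as F


def pad_gather_trim_sketch(per_rank: List[torch.Tensor]) -> List[torch.Tensor]:
    """Single-process simulation of pad -> gather -> trim (hypothetical helper).

    Each list element stands in for the local `result` tensor of one rank.
    All tensors must have the same number of dimensions (at least one).
    """
    # 1. Every rank shares its local shape; the element-wise maximum is the common size.
    shapes = [list(t.shape) for t in per_rank]
    max_shape = [max(sizes) for sizes in zip(*shapes)]

    # 2. Pad each tensor with zeros at the end of every dimension up to max_shape,
    #    so all ranks exchange tensors of identical size.
    padded = []
    for t, shape in zip(per_rank, shapes):
        pad: List[int] = []
        # F.pad expects (before, after) pairs starting from the LAST dimension.
        for size, target in zip(reversed(shape), reversed(max_shape)):
            pad.extend([0, target - size])
        padded.append(F.pad(t, pad))

    # 3. In a real multi-process run, a collective such as torch.distributed.all_gather
    #    would exchange the padded tensors here; this simulation already holds all of them.
    gathered = padded

    # 4. Trim every gathered tensor back to the original shape of the rank it came from.
    return [g[tuple(slice(0, s) for s in shape)] for g, shape in zip(gathered, shapes)]


rank0 = torch.ones(2, 3)
rank1 = torch.full((4, 1), 2.0)
result = pad_gather_trim_sketch([rank0, rank1])
print([tuple(r.shape) for r in result])  # [(2, 3), (4, 1)]
```

Padding to a common shape matters because all_gather-style collectives expect every participant to send a tensor of the same size.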
torchmetrics.utilities.exceptions
TorchMetricsUserError
- class torchmetrics.utilities.exceptions.TorchMetricsUserError
Error used to inform users of a wrong combination of Metric API calls.
TorchMetricsUserWarning
- class torchmetrics.utilities.exceptions.TorchMetricsUserWarning
Warning used to inform users of specific issues arising from use of the torchmetrics API.