Complete Intersection Over Union (CIoU)
Module Interface
- class torchmetrics.detection.ciou.CompleteIntersectionOverUnion(box_format='xyxy', iou_threshold=None, class_metrics=False, **kwargs)
Computes Complete Intersection Over Union (CIoU), as described in https://arxiv.org/abs/2005.03572.
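For reference, CIoU starts from the plain IoU and subtracts two penalties: the squared distance between the box centres divided by the squared diagonal of the smallest enclosing box, and an aspect-ratio term alpha * v. The sketch below computes this for a single pair of (xmin, ymin, xmax, ymax) boxes in plain Python. The name `ciou_pair` and the small `eps` guard are illustrative assumptions, not part of the torchmetrics API.

```python
import math


def ciou_pair(pred, gt, eps=1e-7):
    """Complete IoU of two boxes given as (xmin, ymin, xmax, ymax).

    Hypothetical helper for illustration. Assumes positive widths and heights;
    eps only guards the divisions against a zero denominator.
    """
    px1, py1, px2, py2 = pred
    gx1, gy1, gx2, gy2 = gt

    # Plain IoU: intersection area over union area.
    iw = max(0.0, min(px2, gx2) - max(px1, gx1))
    ih = max(0.0, min(py2, gy2) - max(py1, gy1))
    inter = iw * ih
    pw, ph = px2 - px1, py2 - py1
    gw, gh = gx2 - gx1, gy2 - gy1
    union = pw * ph + gw * gh - inter
    iou = inter / (union + eps)

    # Distance penalty: squared centre distance over the squared diagonal
    # of the smallest box enclosing both inputs.
    rho2 = ((px1 + px2) / 2 - (gx1 + gx2) / 2) ** 2 + ((py1 + py2) / 2 - (gy1 + gy2) / 2) ** 2
    cw = max(px2, gx2) - min(px1, gx1)
    ch = max(py2, gy2) - min(py1, gy1)
    c2 = cw ** 2 + ch ** 2 + eps

    # Aspect-ratio penalty v, weighted by alpha.
    v = (4 / math.pi ** 2) * (math.atan(gw / gh) - math.atan(pw / ph)) ** 2
    alpha = v / ((1 - iou) + v + eps)

    return iou - rho2 / c2 - alpha * v


# Same boxes as the functional example further down this page.
print(round(ciou_pair((100, 100, 200, 200), (110, 110, 210, 210)), 4))  # 0.6724
```

Both boxes here are squares, so v = 0 and the result is IoU (8100 / 11900) minus the distance penalty (200 / 24200).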
As input to forward and update, the metric accepts the following input:
preds (List): A list of dictionaries, one per image. Each dictionary should provide the following key-values:
  - boxes (FloatTensor): of shape (num_boxes, 4) containing num_boxes detection boxes in the format specified in the constructor. By default, this method expects (xmin, ymin, xmax, ymax) in absolute image coordinates.
  - scores (FloatTensor): of shape (num_boxes,) containing detection scores for the boxes.
  - labels (IntTensor): of shape (num_boxes,) containing 0-indexed detection classes for the boxes.
target (List): A list of dictionaries, one per image. Each dictionary should provide the following key-values:
  - boxes (FloatTensor): of shape (num_boxes, 4) containing num_boxes ground truth boxes in the format specified in the constructor. By default, this method expects (xmin, ymin, xmax, ymax) in absolute image coordinates.
  - labels (IntTensor): of shape (num_boxes,) containing 0-indexed ground truth classes for the boxes.
As output of forward and compute, the metric returns the following output:
ciou_dict: A dictionary containing the following key-values:
  - ciou (Tensor): the overall CIoU value over all classes and samples.
  If class_metrics=True, the dictionary additionally contains per-class CIoU values.
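The input format described above can be checked before calling the metric. The sketch below is a hypothetical helper, `check_detection_inputs`, not a torchmetrics function; it assumes `preds` and `target` hold one dictionary per image (so their lengths match) and only verifies keys and shapes, not dtypes or box coordinates.

```python
import torch


def check_detection_inputs(preds, target):
    """Hypothetical sanity check for the preds/target format.

    Raises ValueError on a missing key or a shape mismatch; returns None otherwise.
    """
    if len(preds) != len(target):
        raise ValueError("preds and target must contain one dictionary per image")
    required = {"preds": ("boxes", "scores", "labels"), "target": ("boxes", "labels")}
    for pred, gt in zip(preds, target):
        for name, sample in (("preds", pred), ("target", gt)):
            keys = required[name]
            missing = [k for k in keys if k not in sample]
            if missing:
                raise ValueError(f"{name} dict is missing keys: {missing}")
            boxes = sample["boxes"]
            # Boxes must be (num_boxes, 4).
            if boxes.ndim != 2 or boxes.shape[1] != 4:
                raise ValueError(f"{name} boxes must be (num_boxes, 4), got {tuple(boxes.shape)}")
            # Scores and labels must be 1-D with one entry per box.
            for k in keys[1:]:
                if tuple(sample[k].shape) != (boxes.shape[0],):
                    raise ValueError(f"{name} {k} must be ({boxes.shape[0]},), got {tuple(sample[k].shape)}")


preds = [{"boxes": torch.tensor([[10.0, 20.0, 50.0, 60.0]]),
          "scores": torch.tensor([0.9]),
          "labels": torch.tensor([1])}]
target = [{"boxes": torch.tensor([[12.0, 18.0, 48.0, 62.0]]),
           "labels": torch.tensor([1])}]
check_detection_inputs(preds, target)  # passes silently
```

Failing early with a message that names the offending key tends to be easier to debug than a shape error raised deep inside the box computation.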
- Parameters:
  box_format (str) – Input format of given boxes. Supported formats are `xyxy`, `xywh` and `cxcywh`.
  iou_threshold (Optional[float]) – Optional IoU threshold for evaluation. If set to None, the threshold is ignored.
  class_metrics (bool) – Option to enable per-class CIoU values. Has a performance impact.
  kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> import torch
>>> from torchmetrics.detection import CompleteIntersectionOverUnion
>>> preds = [
...    {
...        "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
...        "scores": torch.tensor([0.236, 0.56]),
...        "labels": torch.tensor([4, 5]),
...    }
... ]
>>> target = [
...    {
...        "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]),
...        "labels": torch.tensor([5]),
...    }
... ]
>>> metric = CompleteIntersectionOverUnion()
>>> metric(preds, target)
{'ciou': tensor(-0.5694)}
- Raises:
ModuleNotFoundError – If torchvision is not installed with version 0.13.0 or newer.
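The box_format parameter only changes how each 4-vector is interpreted. Assuming torchvision's conventions (`xywh` is the top-left corner plus width and height, `cxcywh` is the centre plus width and height), the sketch below converts a single box to the default `xyxy` layout. `to_xyxy` is a hypothetical helper; for tensors, torchvision.ops.box_convert performs the same conversion.

```python
def to_xyxy(box, box_format="xyxy"):
    """Convert one box to (xmin, ymin, xmax, ymax). Hypothetical helper."""
    a, b, c, d = box
    if box_format == "xyxy":
        return (a, b, c, d)
    if box_format == "xywh":
        # (xmin, ymin, width, height)
        return (a, b, a + c, b + d)
    if box_format == "cxcywh":
        # (center_x, center_y, width, height)
        return (a - c / 2, b - d / 2, a + c / 2, b + d / 2)
    raise ValueError(f"unsupported box_format: {box_format!r}")


# The target box from the example above, expressed in all three formats:
print(to_xyxy((300.0, 100.0, 315.0, 150.0), "xyxy"))    # (300.0, 100.0, 315.0, 150.0)
print(to_xyxy((300.0, 100.0, 15.0, 50.0), "xywh"))      # (300.0, 100.0, 315.0, 150.0)
print(to_xyxy((307.5, 125.0, 15.0, 50.0), "cxcywh"))    # (300.0, 100.0, 315.0, 150.0)
```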
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
  val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
  ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
  Figure object and Axes object
- Raises:
  ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting single value
>>> import torch
>>> from torchmetrics.detection import CompleteIntersectionOverUnion
>>> preds = [
...    {
...        "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
...        "scores": torch.tensor([0.236, 0.56]),
...        "labels": torch.tensor([4, 5]),
...    }
... ]
>>> target = [
...    {
...        "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]),
...        "labels": torch.tensor([5]),
...    }
... ]
>>> metric = CompleteIntersectionOverUnion()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.detection import CompleteIntersectionOverUnion
>>> preds = [
...    {
...        "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
...        "scores": torch.tensor([0.236, 0.56]),
...        "labels": torch.tensor([4, 5]),
...    }
... ]
>>> target = lambda: [
...    {
...        "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]) + torch.randint(-10, 10, (1, 4)),
...        "labels": torch.tensor([5]),
...    }
... ]
>>> metric = CompleteIntersectionOverUnion()
>>> vals = []
>>> for _ in range(20):
...     vals.append(metric(preds, target()))
>>> fig_, ax_ = metric.plot(vals)
Functional Interface
- torchmetrics.functional.detection.ciou.complete_intersection_over_union(preds, target, iou_threshold=None, replacement_val=0, aggregate=True)
Compute Complete Intersection over Union between two sets of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format with 0 <= x1 < x2 and 0 <= y1 < y2.
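One way to check this precondition on a (num_boxes, 4) tensor before calling the function is a vectorised mask. `valid_xyxy_mask` below is a hypothetical helper, not part of torchmetrics.

```python
import torch


def valid_xyxy_mask(boxes: torch.Tensor) -> torch.Tensor:
    """Boolean mask of rows satisfying 0 <= x1 < x2 and 0 <= y1 < y2.

    Hypothetical helper; boxes has shape (num_boxes, 4) in (x1, y1, x2, y2) order.
    """
    x1, y1, x2, y2 = boxes.unbind(dim=-1)
    return (x1 >= 0) & (y1 >= 0) & (x1 < x2) & (y1 < y2)


boxes = torch.tensor([[100.0, 100.0, 200.0, 200.0],   # valid
                      [50.0, 80.0, 40.0, 120.0],      # x2 < x1
                      [-5.0, 0.0, 10.0, 10.0]])       # negative x1
print(valid_xyxy_mask(boxes))  # tensor([ True, False, False])
```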
- Parameters:
  preds (Tensor) – The input tensor containing the predicted bounding boxes.
  target (Tensor) – The tensor containing the ground truth bounding boxes.
  iou_threshold (Optional[float]) – Optional IoU threshold for evaluation. If set to None, the threshold is ignored.
  replacement_val (float) – Value used to replace values below the threshold.
  aggregate (bool) – If True, return the average value instead of the full CIoU matrix.
- Return type:
  Tensor
Example
>>> import torch
>>> from torchmetrics.functional.detection import complete_intersection_over_union
>>> preds = torch.Tensor([[100, 100, 200, 200]])
>>> target = torch.Tensor([[110, 110, 210, 210]])
>>> complete_intersection_over_union(preds, target)
tensor(0.6724)