]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/xvision/ops/boxes.py
docs - added deprecation notice
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / ops / boxes.py
1 import torch
2 from torch.jit.annotations import Tuple
3 from torch import Tensor
4 import torchvision
7 def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
8     """
9     Performs non-maximum suppression (NMS) on the boxes according
10     to their intersection-over-union (IoU).
12     NMS iteratively removes lower scoring boxes which have an
13     IoU greater than iou_threshold with another (higher scoring)
14     box.
16     If multiple boxes have the exact same score and satisfy the IoU
17     criterion with respect to a reference box, the selected box is
18     not guaranteed to be the same between CPU and GPU. This is similar
19     to the behavior of argsort in PyTorch when repeated values are present.
21     Parameters
22     ----------
23     boxes : Tensor[N, 4])
24         boxes to perform NMS on. They
25         are expected to be in (x1, y1, x2, y2) format
26     scores : Tensor[N]
27         scores for each one of the boxes
28     iou_threshold : float
29         discards all overlapping
30         boxes with IoU > iou_threshold
32     Returns
33     -------
34     keep : Tensor
35         int64 tensor with the indices
36         of the elements that have been kept
37         by NMS, sorted in decreasing order of scores
38     """
39     return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
42 @torch.jit._script_if_tracing
43 def batched_nms(
44     boxes: Tensor,
45     scores: Tensor,
46     idxs: Tensor,
47     iou_threshold: float,
48 ) -> Tensor:
49     """
50     Performs non-maximum suppression in a batched fashion.
52     Each index value correspond to a category, and NMS
53     will not be applied between elements of different categories.
55     Parameters
56     ----------
57     boxes : Tensor[N, 4]
58         boxes where NMS will be performed. They
59         are expected to be in (x1, y1, x2, y2) format
60     scores : Tensor[N]
61         scores for each one of the boxes
62     idxs : Tensor[N]
63         indices of the categories for each one of the boxes.
64     iou_threshold : float
65         discards all overlapping boxes
66         with IoU > iou_threshold
68     Returns
69     -------
70     keep : Tensor
71         int64 tensor with the indices of
72         the elements that have been kept by NMS, sorted
73         in decreasing order of scores
74     """
75     if boxes.numel() == 0:
76         return torch.empty((0,), dtype=torch.int64, device=boxes.device)
77     # strategy: in order to perform NMS independently per class.
78     # we add an offset to all the boxes. The offset is dependent
79     # only on the class idx, and is large enough so that boxes
80     # from different classes do not overlap
81     else:
82         max_coordinate = boxes.max()
83         offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
84         boxes_for_nms = boxes + offsets[:, None]
85         keep = nms(boxes_for_nms, scores, iou_threshold)
86         return keep
89 def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
90     """
91     Remove boxes which contains at least one side smaller than min_size.
93     Arguments:
94         boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
95         min_size (float): minimum size
97     Returns:
98         keep (Tensor[K]): indices of the boxes that have both sides
99             larger than min_size
100     """
101     ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
102     keep = (ws >= min_size) & (hs >= min_size)
103     keep = keep.nonzero().squeeze(1)
104     return keep
107 def clip_boxes_to_image(boxes: Tensor, size: Tuple[int, int]) -> Tensor:
108     """
109     Clip boxes so that they lie inside an image of size `size`.
111     Arguments:
112         boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
113         size (Tuple[height, width]): size of the image
115     Returns:
116         clipped_boxes (Tensor[N, 4])
117     """
118     dim = boxes.dim()
119     boxes_x = boxes[..., 0::2]
120     boxes_y = boxes[..., 1::2]
121     height, width = size
123     if torchvision._is_tracing():
124         boxes_x = torch.max(boxes_x, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
125         boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
126         boxes_y = torch.max(boxes_y, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
127         boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
128     else:
129         boxes_x = boxes_x.clamp(min=0, max=width)
130         boxes_y = boxes_y.clamp(min=0, max=height)
132     clipped_boxes = torch.stack((boxes_x, boxes_y), dim=dim)
133     return clipped_boxes.reshape(boxes.shape)
136 def box_area(boxes: Tensor) -> Tensor:
137     """
138     Computes the area of a set of bounding boxes, which are specified by its
139     (x1, y1, x2, y2) coordinates.
141     Arguments:
142         boxes (Tensor[N, 4]): boxes for which the area will be computed. They
143             are expected to be in (x1, y1, x2, y2) format
145     Returns:
146         area (Tensor[N]): area for each box
147     """
148     return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
151 # implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
152 # with slight modifications
153 def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
154     """
155     Return intersection-over-union (Jaccard index) of boxes.
157     Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
159     Arguments:
160         boxes1 (Tensor[N, 4])
161         boxes2 (Tensor[M, 4])
163     Returns:
164         iou (Tensor[N, M]): the NxM matrix containing the pairwise
165             IoU values for every element in boxes1 and boxes2
166     """
167     area1 = box_area(boxes1)
168     area2 = box_area(boxes2)
170     lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
171     rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]
173     wh = (rb - lt).clamp(min=0)  # [N,M,2]
174     inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]
176     iou = inter / (area1[:, None] + area2 - inter)
177     return iou