# modules/pytorch_jacinto_ai/xvision/ops/boxes.py
1 import torch
2 from torch.jit.annotations import Tuple
3 from torch import Tensor
4 import torchvision
def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
    """
    Perform non-maximum suppression (NMS) on boxes according to their
    intersection-over-union (IoU).

    Lower-scoring boxes whose IoU with an already-kept, higher-scoring box
    exceeds ``iou_threshold`` are iteratively discarded.

    If multiple boxes share the exact same score and satisfy the IoU
    criterion with respect to a reference box, the selected box is not
    guaranteed to match between CPU and GPU — the same caveat as PyTorch's
    argsort with repeated values.

    Parameters
    ----------
    boxes : Tensor[N, 4]
        boxes to perform NMS on, expected in (x1, y1, x2, y2) format
    scores : Tensor[N]
        a score for each one of the boxes
    iou_threshold : float
        all overlapping boxes with IoU > iou_threshold are discarded

    Returns
    -------
    keep : Tensor
        int64 tensor with the indices of the elements kept by NMS,
        sorted in decreasing order of scores
    """
    # Defer to the compiled torchvision kernel (CPU/CUDA dispatch happens
    # inside the registered op).
    return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
@torch.jit._script_if_tracing
def batched_nms(
    boxes: Tensor,
    scores: Tensor,
    idxs: Tensor,
    iou_threshold: float,
) -> Tensor:
    """
    Perform non-maximum suppression in a batched fashion.

    Each ``idxs`` value corresponds to a category; NMS is never applied
    between elements of different categories.

    Parameters
    ----------
    boxes : Tensor[N, 4]
        boxes where NMS will be performed, expected in (x1, y1, x2, y2)
        format
    scores : Tensor[N]
        a score for each one of the boxes
    idxs : Tensor[N]
        indices of the categories for each one of the boxes
    iou_threshold : float
        all overlapping boxes with IoU > iou_threshold are discarded

    Returns
    -------
    keep : Tensor
        int64 tensor with the indices of the elements kept by NMS,
        sorted in decreasing order of scores
    """
    if boxes.numel() == 0:
        # Nothing to suppress; preserve dtype/device of the index result.
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
    # Strategy: to perform NMS independently per class, shift every box by
    # an offset that depends only on its class index and is larger than any
    # coordinate, so boxes from different classes can never overlap and a
    # single NMS call covers all classes at once.
    max_coord = boxes.max()
    # torch.tensor(1).to(boxes) keeps dtype/device consistent under tracing.
    per_class_shift = idxs.to(boxes) * (max_coord + torch.tensor(1).to(boxes))
    shifted_boxes = boxes + per_class_shift[:, None]
    return nms(shifted_boxes, scores, iou_threshold)
def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
    """
    Remove boxes which contain at least one side smaller than min_size.

    Arguments:
        boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
        min_size (float): minimum width and height a box must have (inclusive)
            to be kept

    Returns:
        keep (Tensor[K]): int64 indices of the boxes that have both sides
            of at least min_size
    """
    ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
    keep = (ws >= min_size) & (hs >= min_size)
    # torch.where(mask)[0] yields a 1-D index tensor directly, avoiding the
    # deprecated argument-less Tensor.nonzero() call (UserWarning since
    # torch 1.5) plus the squeeze.
    keep = torch.where(keep)[0]
    return keep
def clip_boxes_to_image(boxes: Tensor, size: Tuple[int, int]) -> Tensor:
    """
    Clip boxes so that they lie inside an image of size `size`.

    Arguments:
        boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
        size (Tuple[height, width]): size of the image

    Returns:
        clipped_boxes (Tensor[N, 4])
    """
    ndim = boxes.dim()
    xs = boxes[..., 0::2]  # x1 and x2 columns
    ys = boxes[..., 1::2]  # y1 and y2 columns
    height, width = size

    if torchvision._is_tracing():
        # Under tracing, clamp bounds must be tensors (not Python scalars)
        # so dtype/device are captured in the trace.
        zero = torch.tensor(0, dtype=boxes.dtype, device=boxes.device)
        x_max = torch.tensor(width, dtype=boxes.dtype, device=boxes.device)
        y_max = torch.tensor(height, dtype=boxes.dtype, device=boxes.device)
        xs = torch.min(torch.max(xs, zero), x_max)
        ys = torch.min(torch.max(ys, zero), y_max)
    else:
        xs = xs.clamp(min=0, max=width)
        ys = ys.clamp(min=0, max=height)

    # Interleave the clipped x and y columns back into the original layout.
    clipped = torch.stack((xs, ys), dim=ndim)
    return clipped.reshape(boxes.shape)
def box_area(boxes: Tensor) -> Tensor:
    """
    Compute the area of a set of bounding boxes specified by their
    (x1, y1, x2, y2) coordinates.

    Arguments:
        boxes (Tensor[N, 4]): boxes for which the area will be computed,
            expected in (x1, y1, x2, y2) format

    Returns:
        area (Tensor[N]): area for each box
    """
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    return widths * heights
# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
    """
    Return the pairwise intersection-over-union (Jaccard index) of two
    box sets.

    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.

    Arguments:
        boxes1 (Tensor[N, 4])
        boxes2 (Tensor[M, 4])

    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU
            values for every element in boxes1 and boxes2
    """
    areas1 = box_area(boxes1)
    areas2 = box_area(boxes2)

    # Broadcast to [N, M, 2] corner tensors of each intersection rectangle.
    top_left = torch.max(boxes1[:, None, :2], boxes2[:, :2])
    bottom_right = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])

    # Clamp at 0 so disjoint boxes contribute zero intersection area.
    extents = (bottom_right - top_left).clamp(min=0)  # [N, M, 2]
    intersection = extents[:, :, 0] * extents[:, :, 1]  # [N, M]

    union = areas1[:, None] + areas2 - intersection
    return intersection / union