[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / models / detection / roi_heads.py
1 import torch
3 import torch.nn.functional as F
4 from torch import nn
6 from ...ops import boxes as box_ops
7 from ...ops import misc as misc_nn_ops
8 from ...ops import roi_align
10 from . import _utils as det_utils
def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
    """
    Computes the loss for Faster R-CNN.

    Arguments:
        class_logits (Tensor): classification logits with one row per
            sampled proposal.
        box_regression (Tensor): predicted box deltas with one row per
            sampled proposal; reshaped here to (num_proposals, -1, 4),
            i.e. four values per class.
        labels (list[Tensor]): per-image class index of every sampled
            proposal. Entries greater than zero are treated as positives.
        regression_targets (list[Tensor]): per-image regression targets,
            concatenated along dim 0 before use.

    Returns:
        classification_loss (Tensor): cross entropy over all proposals.
        box_loss (Tensor): summed smooth-L1 loss over the positive
            proposals, divided by the total number of proposals.
    """
    labels = torch.cat(labels, dim=0)
    regression_targets = torch.cat(regression_targets, dim=0)

    classification_loss = F.cross_entropy(class_logits, labels)

    # get indices that correspond to the regression targets for
    # the corresponding ground truth labels, to be used with
    # advanced indexing
    sampled_pos_inds_subset = torch.nonzero(labels > 0).squeeze(1)
    labels_pos = labels[sampled_pos_inds_subset]
    num_proposals = class_logits.shape[0]
    box_regression = box_regression.reshape(num_proposals, -1, 4)

    # only the deltas predicted for the ground-truth class of each
    # positive proposal contribute to the regression loss
    box_loss = F.smooth_l1_loss(
        box_regression[sampled_pos_inds_subset, labels_pos],
        regression_targets[sampled_pos_inds_subset],
        reduction="sum",
    )
    # normalise by every sampled proposal, not just the positive ones
    box_loss = box_loss / labels.numel()

    return classification_loss, box_loss
def maskrcnn_inference(x, labels):
    """
    Post-process the mask logits produced by the CNN: convert them to
    probabilities and, for every detection, keep only the mask channel
    that belongs to its predicted class.

    Arguments:
        x (Tensor): the mask logits, indexed as [detection, class, ...]
        labels (list[Tensor]): predicted class indices, one tensor per
            image; their lengths define how the result is split

    Returns:
        mask_prob (tuple[Tensor]): one tensor per image holding the
            selected mask probabilities with a singleton dim at index 1
    """
    probabilities = x.sigmoid()

    # remember how many detections each image owns before flattening
    detections_per_image = [len(image_labels) for image_labels in labels]
    flat_labels = torch.cat(labels)

    # select masks corresponding to the predicted classes
    row_index = torch.arange(x.shape[0], device=flat_labels.device)
    selected = probabilities[row_index, flat_labels]
    selected = selected[:, None]

    return selected.split(detections_per_image, dim=0)
def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
    """
    Crop and resize ground-truth segmentation masks to the regions given
    by `boxes`, producing M x M targets for the mask loss.

    `matched_idxs` is prepended to each box as the first roi column, so
    every box is pooled from the mask it was matched to.
    """
    # cast the indices to the dtype/device of the boxes so they can be
    # concatenated into a single roi tensor
    roi_index = matched_idxs.to(boxes)[:, None]
    rois = torch.cat([roi_index, boxes], dim=1)

    # add a singleton dim at index 1 and match the rois' dtype/device
    masks_as_features = gt_masks[:, None].to(rois)

    pooled = roi_align(masks_as_features, rois, (M, M), 1)
    # drop the singleton dim again
    return pooled[:, 0]
def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
    """
    Arguments:
        mask_logits (Tensor): predicted mask logits, indexed as
            [proposal, class, ...]; the last dim gives the target size
        proposals (list[Tensor]): per-image boxes the logits were
            computed for
        gt_masks (list[Tensor]): per-image ground-truth masks
        gt_labels (list[Tensor]): per-image ground-truth class indices
        mask_matched_idxs (list[Tensor]): per-image index of the
            ground-truth instance matched to each proposal

    Return:
        mask_loss (Tensor): scalar tensor containing the loss
    """
    target_size = mask_logits.shape[-1]

    per_image_labels = [
        image_labels[idxs]
        for image_labels, idxs in zip(gt_labels, mask_matched_idxs)
    ]
    per_image_targets = [
        project_masks_on_boxes(masks, boxes, idxs, target_size)
        for masks, boxes, idxs in zip(gt_masks, proposals, mask_matched_idxs)
    ]

    labels = torch.cat(per_image_labels, dim=0)
    mask_targets = torch.cat(per_image_targets, dim=0)

    # the mean reduction inside binary_cross_entropy_with_logits does
    # not accept empty tensors, so return a zero that still depends on
    # the logits
    if mask_targets.numel() == 0:
        return mask_logits.sum() * 0

    # keep only the logits of the ground-truth class of each proposal
    row_index = torch.arange(labels.shape[0], device=labels.device)
    return F.binary_cross_entropy_with_logits(mask_logits[row_index, labels], mask_targets)
def keypoints_to_heatmap(keypoints, rois, heatmap_size):
    """
    Map keypoints to flattened bin indices of a heatmap_size x
    heatmap_size grid laid over each roi.

    Arguments:
        keypoints (Tensor): last dim holds (x, y, visibility)
        rois (Tensor): one (x1, y1, x2, y2) row per roi
        heatmap_size (int): side length of the square heatmap

    Returns:
        heatmaps (Tensor): y * heatmap_size + x for valid keypoints,
            0 for invalid ones
        valid (Tensor): 1 where the keypoint falls inside the grid and
            has visibility > 0, else 0
    """
    # per-roi edges with a trailing singleton dim so they broadcast
    # across the keypoints of the roi
    left = rois[:, 0][:, None]
    top = rois[:, 1][:, None]
    right = rois[:, 2][:, None]
    bottom = rois[:, 3][:, None]

    kp_x = keypoints[..., 0]
    kp_y = keypoints[..., 1]

    # keypoints lying exactly on the far edge would land one bin past
    # the end, so remember them and snap them to the last bin below
    on_right_edge = kp_x == right
    on_bottom_edge = kp_y == bottom

    col = ((kp_x - left) * (heatmap_size / (right - left))).floor().long()
    row = ((kp_y - top) * (heatmap_size / (bottom - top))).floor().long()

    last_bin = heatmap_size - 1
    col = col.masked_fill(on_right_edge, last_bin)
    row = row.masked_fill(on_bottom_edge, last_bin)

    inside = (col >= 0) & (row >= 0) & (col < heatmap_size) & (row < heatmap_size)
    visible = keypoints[..., 2] > 0
    valid = (inside & visible).long()

    # multiplying by `valid` zeroes the index of every invalid keypoint
    heatmaps = (row * heatmap_size + col) * valid

    return heatmaps, valid
def heatmaps_to_keypoints(maps, rois):
    """Extract predicted keypoint locations from heatmaps.

    Arguments:
        maps (Tensor): heatmaps indexed as [roi, keypoint, y, x]
        rois (Tensor): one (x1, y1, x2, y2) row per roi

    Returns:
        xy_preds (Tensor): shape (#rois, #keypoints, 3); the last dim
            holds (x, y, 1) with x and y in the coordinates of `rois`
        end_scores (Tensor): shape (#rois, #keypoints); the value of the
            resized heatmap at the selected location
    """
    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
    # consistency with keypoints_to_heatmap by using the conversion from
    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
    # continuous coordinate.
    offset_x = rois[:, 0]
    offset_y = rois[:, 1]

    # rois thinner than one pixel are treated as one pixel wide/high
    widths = (rois[:, 2] - rois[:, 0]).clamp(min=1)
    heights = (rois[:, 3] - rois[:, 1]).clamp(min=1)
    widths_ceil = widths.ceil()
    heights_ceil = heights.ceil()

    num_keypoints = maps.shape[1]
    xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
    end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)

    # built once, on the same device as the heatmaps, so it can be mixed
    # with the argmax-derived indices when maps live on an accelerator
    keypoint_index = torch.arange(num_keypoints, device=maps.device)

    for i in range(len(rois)):
        roi_map_width = int(widths_ceil[i].item())
        roi_map_height = int(heights_ceil[i].item())
        # the map is resized to an integer size; these factors map its
        # pixel grid back onto the (possibly fractional) roi extent
        width_correction = widths[i] / roi_map_width
        height_correction = heights[i] / roi_map_height
        roi_map = torch.nn.functional.interpolate(
            maps[i][None], size=(roi_map_height, roi_map_width), mode='bicubic', align_corners=False)[0]

        # location of the maximum of every keypoint channel
        w = roi_map.shape[2]
        pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
        x_int = pos % w
        y_int = (pos - x_int) // w

        x = (x_int.float() + 0.5) * width_correction
        y = (y_int.float() + 0.5) * height_correction
        xy_preds[i, 0, :] = x + offset_x[i]
        xy_preds[i, 1, :] = y + offset_y[i]
        xy_preds[i, 2, :] = 1
        end_scores[i, :] = roi_map[keypoint_index, y_int, x_int]

    return xy_preds.permute(0, 2, 1), end_scores
def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
    """
    Cross-entropy loss over keypoint heatmaps.

    Arguments:
        keypoint_logits (Tensor): 4-D logits unpacked as (N, K, H, W);
            H must equal W
        proposals (list[Tensor]): per-image boxes the logits belong to
        gt_keypoints (list[Tensor]): per-image ground-truth keypoints
        keypoint_matched_idxs (list[Tensor]): per-image index of the
            ground-truth instance matched to each proposal

    Returns:
        keypoint_loss (Tensor): scalar loss; a zero that still depends
            on the logits when there is no valid keypoint
    """
    num_rois, num_kps, height, width = keypoint_logits.shape
    assert height == width

    target_chunks = []
    valid_chunks = []
    for boxes, kps_in_image, idxs in zip(proposals, gt_keypoints, keypoint_matched_idxs):
        targets_in_image, valid_in_image = keypoints_to_heatmap(
            kps_in_image[idxs], boxes, height
        )
        target_chunks.append(targets_in_image.view(-1))
        valid_chunks.append(valid_in_image.view(-1))

    keypoint_targets = torch.cat(target_chunks, dim=0)
    valid_mask = torch.cat(valid_chunks, dim=0).to(dtype=torch.uint8)
    valid_idx = torch.nonzero(valid_mask).squeeze(1)

    # the mean reduction inside cross_entropy does not accept empty
    # tensors, so handle that case separately
    if keypoint_targets.numel() == 0 or len(valid_idx) == 0:
        return keypoint_logits.sum() * 0

    # one row per (roi, keypoint) pair, one column per heatmap bin
    flat_logits = keypoint_logits.view(num_rois * num_kps, height * width)

    return F.cross_entropy(flat_logits[valid_idx], keypoint_targets[valid_idx])
def keypointrcnn_inference(x, boxes):
    """
    Decode keypoint heatmaps image by image.

    Arguments:
        x (Tensor): heatmaps for the boxes of all images, stacked
            along dim 0
        boxes (list[Tensor]): per-image boxes; their lengths define how
            `x` is split

    Returns:
        kp_probs (list[Tensor]): per-image keypoint predictions
        kp_scores (list[Tensor]): per-image keypoint scores
    """
    counts = [len(image_boxes) for image_boxes in boxes]
    decoded = [
        heatmaps_to_keypoints(image_maps, image_boxes)
        for image_maps, image_boxes in zip(x.split(counts, dim=0), boxes)
    ]

    kp_probs = [prediction for prediction, _ in decoded]
    kp_scores = [score for _, score in decoded]
    return kp_probs, kp_scores
# the next two functions should be merged inside Masker
# but are kept here for the moment while we need them
# temporarily for paste_mask_in_image
def expand_boxes(boxes, scale):
    """
    Scale every (x1, y1, x2, y2) box about its centre by `scale`.

    Returns a new tensor created with `torch.zeros_like(boxes)`; the
    input is not modified.
    """
    x_min, y_min = boxes[:, 0], boxes[:, 1]
    x_max, y_max = boxes[:, 2], boxes[:, 3]

    center_x = (x_max + x_min) * .5
    center_y = (y_max + y_min) * .5

    # half extents, grown in place by the requested factor
    half_w = (x_max - x_min) * .5
    half_h = (y_max - y_min) * .5
    half_w.mul_(scale)
    half_h.mul_(scale)

    expanded = torch.zeros_like(boxes)
    expanded[:, 0] = center_x - half_w
    expanded[:, 1] = center_y - half_h
    expanded[:, 2] = center_x + half_w
    expanded[:, 3] = center_y + half_h
    return expanded
def expand_masks(mask, padding):
    """
    Pad the last two dims of `mask` by `padding` on every side.

    Returns the padded mask together with the ratio between the padded
    and the original size of the last dim.
    """
    side = mask.shape[-1]
    scale = float(side + 2 * padding) / side
    pad_spec = (padding, padding, padding, padding)
    padded_mask = torch.nn.functional.pad(mask, pad_spec)
    return padded_mask, scale
def paste_mask_in_image(mask, box, im_h, im_w):
    """
    Resize `mask` to the size of `box` and paste it into an otherwise
    zero image of size (im_h, im_w).

    `box` is indexed as (x1, y1, x2, y2) and its coordinates are used
    directly as slice bounds; the part of the box that falls outside the
    image is cropped away.
    """
    TO_REMOVE = 1
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]

    # box extent in pixels, never smaller than one pixel
    w = max(int(x_max - x_min + TO_REMOVE), 1)
    h = max(int(y_max - y_min + TO_REMOVE), 1)

    # Set shape to [batchxCxHxW], resize, then drop the two extra dims
    resized = misc_nn_ops.interpolate(
        mask.expand((1, 1, -1, -1)), size=(h, w), mode='bilinear', align_corners=False
    )
    resized = resized[0][0]

    canvas = torch.zeros((im_h, im_w), dtype=resized.dtype, device=resized.device)

    # clip the paste region to the image bounds
    x_0 = max(x_min, 0)
    x_1 = min(x_max + 1, im_w)
    y_0 = max(y_min, 0)
    y_1 = min(y_max + 1, im_h)

    # source slice is expressed relative to the box origin
    canvas[y_0:y_1, x_0:x_1] = resized[
        (y_0 - y_min):(y_1 - y_min), (x_0 - x_min):(x_1 - x_min)
    ]
    return canvas
def paste_masks_in_image(masks, boxes, img_shape, padding=1):
    """
    Paste every mask into a full-size image at the location of its box.

    Masks are padded by `padding` and the boxes are expanded by the
    matching scale before pasting. `img_shape` is unpacked as
    (im_h, im_w). The result has shape (len(masks), 1, im_h, im_w).
    """
    masks, scale = expand_masks(masks, padding=padding)
    int_boxes = expand_boxes(boxes, scale).to(dtype=torch.int64).tolist()
    im_h, im_w = img_shape

    pasted = [
        paste_mask_in_image(mask[0], box, im_h, im_w)
        for mask, box in zip(masks, int_boxes)
    ]

    # torch.stack cannot handle an empty list, so build the empty
    # result explicitly
    if len(pasted) == 0:
        return masks.new_empty((0, 1, im_h, im_w))
    return torch.stack(pasted, dim=0)[:, None]
class RoIHeads(torch.nn.Module):
    """
    Second-stage heads of the detector: a box branch that is always
    present plus optional mask and keypoint branches.

    In training mode `forward` samples proposals, computes the losses of
    every enabled branch and returns them; in evaluation mode it returns
    one dict of detections per image instead.
    """

    def __init__(self,
                 box_roi_pool,
                 box_head,
                 box_predictor,
                 # Faster R-CNN training
                 fg_iou_thresh, bg_iou_thresh,
                 batch_size_per_image, positive_fraction,
                 bbox_reg_weights,
                 # Faster R-CNN inference
                 score_thresh,
                 nms_thresh,
                 detections_per_img,
                 # Mask
                 mask_roi_pool=None,
                 mask_head=None,
                 mask_predictor=None,
                 # Keypoint
                 keypoint_roi_pool=None,
                 keypoint_head=None,
                 keypoint_predictor=None,
                 ):
        super(RoIHeads, self).__init__()

        self.box_similarity = box_ops.box_iou
        # assign ground-truth boxes for each proposal
        self.proposal_matcher = det_utils.Matcher(
            fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)

        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(
            batch_size_per_image, positive_fraction)

        # fall back to the default weights when none are given
        if bbox_reg_weights is None:
            bbox_reg_weights = (10., 10., 5., 5.)
        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)

        # box branch
        self.box_roi_pool = box_roi_pool
        self.box_head = box_head
        self.box_predictor = box_predictor

        # inference-time filtering parameters
        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img

        # optional mask branch
        self.mask_roi_pool = mask_roi_pool
        self.mask_head = mask_head
        self.mask_predictor = mask_predictor

        # optional keypoint branch
        self.keypoint_roi_pool = keypoint_roi_pool
        self.keypoint_head = keypoint_head
        self.keypoint_predictor = keypoint_predictor

    @property
    def has_mask(self):
        """True only when all three mask components are set."""
        return (
            self.mask_roi_pool is not None
            and self.mask_head is not None
            and self.mask_predictor is not None
        )

    @property
    def has_keypoint(self):
        """True only when all three keypoint components are set."""
        return (
            self.keypoint_roi_pool is not None
            and self.keypoint_head is not None
            and self.keypoint_predictor is not None
        )

    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
        """
        Match every proposal to a ground-truth box, image by image.

        Returns two lists with one tensor per image: the matched
        ground-truth index (clamped to be >= 0) and the assigned label,
        where 0 marks background and -1 marks proposals to ignore.
        """
        matcher = self.proposal_matcher
        matched_idxs = []
        labels = []
        for image_proposals, image_gt_boxes, image_gt_labels in zip(proposals, gt_boxes, gt_labels):
            quality = self.box_similarity(image_gt_boxes, image_proposals)
            raw_match = matcher(quality)

            # the matcher reports unmatched proposals with negative
            # codes; clamp so they can still be used as indices
            safe_match = raw_match.clamp(min=0)

            image_labels = image_gt_labels[safe_match].to(dtype=torch.int64)

            # Label background (below the low threshold)
            image_labels[raw_match == matcher.BELOW_LOW_THRESHOLD] = 0

            # Label ignore proposals (between low and high thresholds);
            # -1 is ignored by sampler
            image_labels[raw_match == matcher.BETWEEN_THRESHOLDS] = -1

            matched_idxs.append(safe_match)
            labels.append(image_labels)
        return matched_idxs, labels

    def subsample(self, labels):
        """
        Run the foreground/background sampler and return, per image, the
        indices selected as either positive or negative.
        """
        pos_masks, neg_masks = self.fg_bg_sampler(labels)
        return [
            torch.nonzero(pos_mask | neg_mask).squeeze(1)
            for pos_mask, neg_mask in zip(pos_masks, neg_masks)
        ]

    def add_gt_proposals(self, proposals, gt_boxes):
        """Append the ground-truth boxes of each image to its proposals."""
        merged = []
        for image_proposals, image_gt_boxes in zip(proposals, gt_boxes):
            merged.append(torch.cat((image_proposals, image_gt_boxes)))
        return merged

    def check_targets(self, targets):
        """Assert that `targets` carries the keys the enabled branches need."""
        assert targets is not None
        for required_key in ("boxes", "labels"):
            assert all(required_key in t for t in targets)
        if self.has_mask:
            assert all("masks" in t for t in targets)

    def select_training_samples(self, proposals, targets):
        """
        Build the training set for the box branch.

        Ground-truth boxes are added to the proposals, every proposal is
        matched and labelled, and a subset is sampled per image.

        Returns the sampled proposals, their matched ground-truth
        indices, their labels and the encoded regression targets.
        """
        self.check_targets(targets)
        gt_boxes = [t["boxes"] for t in targets]
        gt_labels = [t["labels"] for t in targets]

        # append ground-truth bboxes to proposals
        proposals = self.add_gt_proposals(proposals, gt_boxes)

        # get matching gt indices for each proposal
        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)

        # sample a fixed proportion of positive-negative proposals
        sampled_inds = self.subsample(labels)

        matched_gt_boxes = []
        for img_id, image_proposals in enumerate(proposals):
            keep = sampled_inds[img_id]
            proposals[img_id] = image_proposals[keep]
            labels[img_id] = labels[img_id][keep]
            matched_idxs[img_id] = matched_idxs[img_id][keep]
            matched_gt_boxes.append(gt_boxes[img_id][matched_idxs[img_id]])

        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
        return proposals, matched_idxs, labels, regression_targets

    def postprocess_detections(self, class_logits, box_regression, proposals, image_shapes):
        """
        Turn raw box-branch outputs into per-image detections.

        Boxes are decoded and clipped to the image, the background class
        is dropped, low-scoring entries are removed, per-class NMS is
        applied and at most `detections_per_img` entries are kept.

        Returns three lists (boxes, scores, labels) with one tensor per
        image.
        """
        device = class_logits.device
        num_classes = class_logits.shape[-1]

        boxes_per_image = [len(image_proposals) for image_proposals in proposals]
        decoded_boxes = self.box_coder.decode(box_regression, proposals)
        probabilities = F.softmax(class_logits, -1)

        # split boxes and scores per image
        boxes_by_image = decoded_boxes.split(boxes_per_image, 0)
        scores_by_image = probabilities.split(boxes_per_image, 0)

        all_boxes, all_scores, all_labels = [], [], []
        for boxes, scores, image_shape in zip(boxes_by_image, scores_by_image, image_shapes):
            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)

            # create labels for each prediction
            labels = torch.arange(num_classes, device=device).view(1, -1).expand_as(scores)

            # remove predictions with the background label
            boxes, scores, labels = boxes[:, 1:], scores[:, 1:], labels[:, 1:]

            # batch everything, by making every class prediction be a separate instance
            boxes = boxes.reshape(-1, 4)
            scores = scores.flatten()
            labels = labels.flatten()

            # remove low scoring boxes
            confident = torch.nonzero(scores > self.score_thresh).squeeze(1)
            boxes, scores, labels = boxes[confident], scores[confident], labels[confident]

            # non-maximum suppression, independently done per class,
            # then keep only the first detections_per_img survivors
            survivors = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
            survivors = survivors[:self.detections_per_img]

            all_boxes.append(boxes[survivors])
            all_scores.append(scores[survivors])
            all_labels.append(labels[survivors])

        return all_boxes, all_scores, all_labels

    @staticmethod
    def _positive_samples(proposals, labels, matched_idxs):
        """Return, per image, the proposals with label > 0 and their matched indices."""
        positive_proposals = []
        positive_matched_idxs = []
        for img_id in range(len(proposals)):
            pos = torch.nonzero(labels[img_id] > 0).squeeze(1)
            positive_proposals.append(proposals[img_id][pos])
            positive_matched_idxs.append(matched_idxs[img_id][pos])
        return positive_proposals, positive_matched_idxs

    def forward(self, features, proposals, image_shapes, targets=None):
        """
        Arguments:
            features (List[Tensor])
            proposals (List[Tensor[N, 4]])
            image_shapes (List[Tuple[H, W]])
            targets (List[Dict])

        Returns:
            result (List[Dict]): per-image detections; empty in training
            losses (Dict): losses of the enabled branches; empty in
                evaluation
        """
        if self.training:
            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)

        box_features = self.box_roi_pool(features, proposals, image_shapes)
        box_features = self.box_head(box_features)
        class_logits, box_regression = self.box_predictor(box_features)

        result = []
        losses = {}
        if self.training:
            loss_classifier, loss_box_reg = fastrcnn_loss(
                class_logits, box_regression, labels, regression_targets)
            losses = dict(loss_classifier=loss_classifier, loss_box_reg=loss_box_reg)
        else:
            boxes, scores, labels = self.postprocess_detections(
                class_logits, box_regression, proposals, image_shapes)
            for image_boxes, image_labels, image_scores in zip(boxes, labels, scores):
                result.append(dict(boxes=image_boxes, labels=image_labels, scores=image_scores))

        if self.has_mask:
            if self.training:
                # during training, only focus on positive boxes
                mask_proposals, pos_matched_idxs = self._positive_samples(
                    proposals, labels, matched_idxs)
            else:
                mask_proposals = [detection["boxes"] for detection in result]

            mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
            mask_features = self.mask_head(mask_features)
            mask_logits = self.mask_predictor(mask_features)

            if self.training:
                gt_masks = [t["masks"] for t in targets]
                gt_labels = [t["labels"] for t in targets]
                losses.update(loss_mask=maskrcnn_loss(
                    mask_logits, mask_proposals,
                    gt_masks, gt_labels, pos_matched_idxs))
            else:
                detected_labels = [detection["labels"] for detection in result]
                masks_probs = maskrcnn_inference(mask_logits, detected_labels)
                for mask_prob, detection in zip(masks_probs, result):
                    detection["masks"] = mask_prob

        if self.has_keypoint:
            if self.training:
                # during training, only focus on positive boxes
                keypoint_proposals, pos_matched_idxs = self._positive_samples(
                    proposals, labels, matched_idxs)
            else:
                keypoint_proposals = [detection["boxes"] for detection in result]

            keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
            keypoint_features = self.keypoint_head(keypoint_features)
            keypoint_logits = self.keypoint_predictor(keypoint_features)

            if self.training:
                gt_keypoints = [t["keypoints"] for t in targets]
                losses.update(loss_keypoint=keypointrcnn_loss(
                    keypoint_logits, keypoint_proposals,
                    gt_keypoints, pos_matched_idxs))
            else:
                keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
                for keypoint_prob, kps, detection in zip(keypoints_probs, kp_scores, result):
                    detection["keypoints"] = keypoint_prob
                    detection["keypoints_scores"] = kps

        return result, losses