import torch

import torch.nn.functional as F
from torch import nn

from ...ops import boxes as box_ops
from ...ops import misc as misc_nn_ops
from ...ops import roi_align

from . import _utils as det_utils


def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
    """
    Computes the loss for Faster R-CNN.

    Arguments:
        class_logits (Tensor)
        box_regression (Tensor)
        labels (list[Tensor])
        regression_targets (list[Tensor])

    Returns:
        classification_loss (Tensor)
        box_loss (Tensor)
    """

    labels = torch.cat(labels, dim=0)
    regression_targets = torch.cat(regression_targets, dim=0)

    classification_loss = F.cross_entropy(class_logits, labels)

    # get indices that correspond to the regression targets for
    # the corresponding ground truth labels, to be used with
    # advanced indexing
    sampled_pos_inds_subset = torch.nonzero(labels > 0).squeeze(1)
    labels_pos = labels[sampled_pos_inds_subset]
    N, num_classes = class_logits.shape
    box_regression = box_regression.reshape(N, -1, 4)

    box_loss = F.smooth_l1_loss(
        box_regression[sampled_pos_inds_subset, labels_pos],
        regression_targets[sampled_pos_inds_subset],
        reduction="sum",
    )
    box_loss = box_loss / labels.numel()

    return classification_loss, box_loss
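

# Illustrative usage sketch (not part of the original module; shapes assumed
# from the code above): class_logits is (R, num_classes), box_regression is
# (R, num_classes * 4), and labels / regression_targets are per-image lists.
def _example_fastrcnn_loss():
    num_classes = 3
    # two images, with 4 and 2 sampled proposals respectively
    labels = [torch.tensor([0, 1, 2, 0]), torch.tensor([1, 0])]
    regression_targets = [torch.randn(4, 4), torch.randn(2, 4)]
    class_logits = torch.randn(6, num_classes)
    box_regression = torch.randn(6, num_classes * 4)
    return fastrcnn_loss(class_logits, box_regression, labels, regression_targets)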


def maskrcnn_inference(x, labels):
    """
    From the results of the CNN, post process the masks
    by taking the mask corresponding to the class with max
    probability (which are of fixed size and directly output
    by the CNN).

    Arguments:
        x (Tensor): the mask logits
        labels (list[Tensor]): the predicted labels, one tensor per image,
            used to select the mask channel for each detection

    Returns:
        mask_prob (list[Tensor]): the per-image mask probabilities
    """
    mask_prob = x.sigmoid()

    # select masks corresponding to the predicted classes
    num_masks = x.shape[0]
    boxes_per_image = [len(l) for l in labels]
    labels = torch.cat(labels)
    index = torch.arange(num_masks, device=labels.device)
    mask_prob = mask_prob[index, labels][:, None]

    mask_prob = mask_prob.split(boxes_per_image, dim=0)

    return mask_prob
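

# Illustrative sketch (shapes assumed): mask logits from the mask predictor
# are (num_detections, num_classes, M, M); labels pick one channel per box.
def _example_maskrcnn_inference():
    num_classes, M = 3, 28
    mask_logits = torch.randn(5, num_classes, M, M)
    labels = [torch.tensor([1, 2]), torch.tensor([0, 1, 2])]  # 2 + 3 detections
    probs = maskrcnn_inference(mask_logits, labels)
    # probs[0]: (2, 1, M, M); probs[1]: (3, 1, M, M)
    return probs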


def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
    """
    Given segmentation masks and the bounding boxes corresponding
    to the location of the masks in the image, this function
    crops and resizes the masks in the position defined by the
    boxes. This prepares the masks for them to be fed to the
    loss computation as the targets.
    """
    matched_idxs = matched_idxs.to(boxes)
    rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
    gt_masks = gt_masks[:, None].to(rois)
    return roi_align(gt_masks, rois, (M, M), 1)[:, 0]
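

# Illustrative sketch (values assumed): crop two ground-truth masks to their
# matched boxes and resize them to the M x M resolution of the mask head.
def _example_project_masks_on_boxes():
    gt_masks = torch.rand(2, 64, 64)             # two full-image masks
    boxes = torch.tensor([[0., 0., 32., 32.],
                          [16., 16., 48., 48.]])
    matched_idxs = torch.tensor([0, 1])          # proposal i matched to mask i
    return project_masks_on_boxes(gt_masks, boxes, matched_idxs, M=28)  # (2, 28, 28)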


def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
    """
    Arguments:
        mask_logits (Tensor)
        proposals (list[Tensor])
        gt_masks (list[Tensor])
        gt_labels (list[Tensor])
        mask_matched_idxs (list[Tensor])

    Return:
        mask_loss (Tensor): scalar tensor containing the loss
    """

    discretization_size = mask_logits.shape[-1]
    labels = [l[idxs] for l, idxs in zip(gt_labels, mask_matched_idxs)]
    mask_targets = [
        project_masks_on_boxes(m, p, i, discretization_size)
        for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
    ]

    labels = torch.cat(labels, dim=0)
    mask_targets = torch.cat(mask_targets, dim=0)

    # torch.mean (in binary_cross_entropy_with_logits) doesn't
    # accept empty tensors, so handle it separately
    if mask_targets.numel() == 0:
        return mask_logits.sum() * 0

    mask_loss = F.binary_cross_entropy_with_logits(
        mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
    )
    return mask_loss
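

# Illustrative sketch (values assumed): two positive proposals in one image,
# both matched to the single ground-truth instance of class 1.
def _example_maskrcnn_loss():
    num_classes, M = 3, 28
    mask_logits = torch.randn(2, num_classes, M, M)
    proposals = [torch.tensor([[0., 0., 32., 32.],
                               [8., 8., 40., 40.]])]
    gt_masks = [torch.rand(1, 64, 64)]
    gt_labels = [torch.tensor([1])]
    mask_matched_idxs = [torch.tensor([0, 0])]
    return maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels,
                         mask_matched_idxs)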


def keypoints_to_heatmap(keypoints, rois, heatmap_size):
    offset_x = rois[:, 0]
    offset_y = rois[:, 1]
    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])

    offset_x = offset_x[:, None]
    offset_y = offset_y[:, None]
    scale_x = scale_x[:, None]
    scale_y = scale_y[:, None]

    x = keypoints[..., 0]
    y = keypoints[..., 1]

    x_boundary_inds = x == rois[:, 2][:, None]
    y_boundary_inds = y == rois[:, 3][:, None]

    x = (x - offset_x) * scale_x
    x = x.floor().long()
    y = (y - offset_y) * scale_y
    y = y.floor().long()

    x[x_boundary_inds] = heatmap_size - 1
    y[y_boundary_inds] = heatmap_size - 1

    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
    vis = keypoints[..., 2] > 0
    valid = (valid_loc & vis).long()

    lin_ind = y * heatmap_size + x
    heatmaps = lin_ind * valid

    return heatmaps, valid
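

# Illustrative sketch (values assumed): map two keypoints of one RoI to
# flattened heatmap bin indices; the third coordinate is the visibility flag.
def _example_keypoints_to_heatmap():
    rois = torch.tensor([[0., 0., 56., 56.]])
    keypoints = torch.tensor([[[14., 14., 1.],    # visible keypoint
                               [70., 70., 0.]]])  # invisible and outside the roi
    heatmaps, valid = keypoints_to_heatmap(keypoints, rois, heatmap_size=56)
    # heatmaps holds y * 56 + x per keypoint; valid is 1 / 0 per keypoint
    return heatmaps, valid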


def heatmaps_to_keypoints(maps, rois):
    """Extract predicted keypoint locations from heatmaps. Returns xy_preds
    with shape (#rois, #keypoints, 3), holding (x, y, visibility) for each
    keypoint, and end_scores with shape (#rois, #keypoints), holding the
    heatmap score at each predicted location.
    """
    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
    # consistency with keypoints_to_heatmap_labels by using the conversion from
    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
    # continuous coordinate.
    offset_x = rois[:, 0]
    offset_y = rois[:, 1]

    widths = rois[:, 2] - rois[:, 0]
    heights = rois[:, 3] - rois[:, 1]
    widths = widths.clamp(min=1)
    heights = heights.clamp(min=1)
    widths_ceil = widths.ceil()
    heights_ceil = heights.ceil()

    num_keypoints = maps.shape[1]
    xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
    end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
    for i in range(len(rois)):
        roi_map_width = int(widths_ceil[i].item())
        roi_map_height = int(heights_ceil[i].item())
        width_correction = widths[i] / roi_map_width
        height_correction = heights[i] / roi_map_height
        roi_map = torch.nn.functional.interpolate(
            maps[i][None], size=(roi_map_height, roi_map_width), mode='bicubic', align_corners=False)[0]
        # roi_map_probs = scores_to_probs(roi_map.copy())
        w = roi_map.shape[2]
        pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
        x_int = pos % w
        y_int = (pos - x_int) // w
        # assert (roi_map_probs[k, y_int, x_int] ==
        #         roi_map_probs[k, :, :].max())
        x = (x_int.float() + 0.5) * width_correction
        y = (y_int.float() + 0.5) * height_correction
        xy_preds[i, 0, :] = x + offset_x[i]
        xy_preds[i, 1, :] = y + offset_y[i]
        xy_preds[i, 2, :] = 1
        end_scores[i, :] = roi_map[torch.arange(num_keypoints), y_int, x_int]

    return xy_preds.permute(0, 2, 1), end_scores
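

# Illustrative sketch (sizes assumed): recover keypoint coordinates from a
# single RoI heatmap with one keypoint channel.
def _example_heatmaps_to_keypoints():
    maps = torch.randn(1, 1, 56, 56)             # (#rois, #keypoints, H, W)
    rois = torch.tensor([[10., 20., 66., 76.]])  # one 56 x 56 roi
    xy_preds, end_scores = heatmaps_to_keypoints(maps, rois)
    # xy_preds: (1, 1, 3) holding (x, y, visibility); end_scores: (1, 1)
    return xy_preds, end_scores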


def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
    N, K, H, W = keypoint_logits.shape
    assert H == W
    discretization_size = H
    heatmaps = []
    valid = []
    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
        kp = gt_kp_in_image[midx]
        heatmaps_per_image, valid_per_image = keypoints_to_heatmap(
            kp, proposals_per_image, discretization_size
        )
        heatmaps.append(heatmaps_per_image.view(-1))
        valid.append(valid_per_image.view(-1))

    keypoint_targets = torch.cat(heatmaps, dim=0)
    valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
    valid = torch.nonzero(valid).squeeze(1)

    # torch.mean (in binary_cross_entropy_with_logits) doesn't
    # accept empty tensors, so handle it separately
    if keypoint_targets.numel() == 0 or len(valid) == 0:
        return keypoint_logits.sum() * 0

    keypoint_logits = keypoint_logits.view(N * K, H * W)

    keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
    return keypoint_loss
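

# Illustrative sketch (values assumed): one image with two positive proposals
# and three keypoints per instance.
def _example_keypointrcnn_loss():
    N, K, S = 2, 3, 56
    keypoint_logits = torch.randn(N, K, S, S)
    proposals = [torch.tensor([[0., 0., 56., 56.],
                               [10., 10., 66., 66.]])]
    gt_keypoints = [torch.tensor([[[20., 20., 2.],
                                   [30., 10., 2.],
                                   [5., 40., 2.]]])]   # one GT instance, all visible
    keypoint_matched_idxs = [torch.tensor([0, 0])]
    return keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints,
                             keypoint_matched_idxs)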


def keypointrcnn_inference(x, boxes):
    kp_probs = []
    kp_scores = []

    boxes_per_image = [len(box) for box in boxes]
    x2 = x.split(boxes_per_image, dim=0)

    for xx, bb in zip(x2, boxes):
        kp_prob, scores = heatmaps_to_keypoints(xx, bb)
        kp_probs.append(kp_prob)
        kp_scores.append(scores)

    return kp_probs, kp_scores


# the next two functions should be merged inside Masker
# but are kept here for the moment while we need them
# temporarily for paste_mask_in_image
def expand_boxes(boxes, scale):
    w_half = (boxes[:, 2] - boxes[:, 0]) * .5
    h_half = (boxes[:, 3] - boxes[:, 1]) * .5
    x_c = (boxes[:, 2] + boxes[:, 0]) * .5
    y_c = (boxes[:, 3] + boxes[:, 1]) * .5

    w_half *= scale
    h_half *= scale

    boxes_exp = torch.zeros_like(boxes)
    boxes_exp[:, 0] = x_c - w_half
    boxes_exp[:, 2] = x_c + w_half
    boxes_exp[:, 1] = y_c - h_half
    boxes_exp[:, 3] = y_c + h_half
    return boxes_exp
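

# Illustrative sketch (values assumed): scale a box about its centre by the
# padding-derived factor computed in expand_masks below.
def _example_expand_boxes():
    boxes = torch.tensor([[10., 10., 30., 30.]])
    return expand_boxes(boxes, scale=1.5)  # -> [[5., 5., 35., 35.]]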


def expand_masks(mask, padding):
    M = mask.shape[-1]
    scale = float(M + 2 * padding) / M
    padded_mask = torch.nn.functional.pad(mask, (padding,) * 4)
    return padded_mask, scale


def paste_mask_in_image(mask, box, im_h, im_w):
    TO_REMOVE = 1
    w = int(box[2] - box[0] + TO_REMOVE)
    h = int(box[3] - box[1] + TO_REMOVE)
    w = max(w, 1)
    h = max(h, 1)

    # Set shape to [batchxCxHxW]
    mask = mask.expand((1, 1, -1, -1))

    # Resize mask
    mask = misc_nn_ops.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
    mask = mask[0][0]

    im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
    x_0 = max(box[0], 0)
    x_1 = min(box[2] + 1, im_w)
    y_0 = max(box[1], 0)
    y_1 = min(box[3] + 1, im_h)

    im_mask[y_0:y_1, x_0:x_1] = mask[
        (y_0 - box[1]):(y_1 - box[1]), (x_0 - box[0]):(x_1 - box[0])
    ]
    return im_mask
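

# Illustrative sketch (values assumed): paste a 28 x 28 mask into a
# 100 x 120 image canvas at the integer box location.
def _example_paste_mask_in_image():
    mask = torch.rand(28, 28)
    box = [10, 20, 40, 60]  # integer (x1, y1, x2, y2), as produced by paste_masks_in_image
    return paste_mask_in_image(mask, box, im_h=100, im_w=120)  # (100, 120)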


def paste_masks_in_image(masks, boxes, img_shape, padding=1):
    masks, scale = expand_masks(masks, padding=padding)
    boxes = expand_boxes(boxes, scale).to(dtype=torch.int64).tolist()
    # im_h, im_w = img_shape.tolist()
    im_h, im_w = img_shape
    res = [
        paste_mask_in_image(m[0], b, im_h, im_w)
        for m, b in zip(masks, boxes)
    ]
    if len(res) > 0:
        res = torch.stack(res, dim=0)[:, None]
    else:
        res = masks.new_empty((0, 1, im_h, im_w))
    return res
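

# Illustrative end-to-end sketch (shapes assumed): paste two predicted masks
# back into full-image coordinates.
def _example_paste_masks_in_image():
    masks = torch.rand(2, 1, 28, 28)
    boxes = torch.tensor([[10., 20., 40., 60.],
                          [50., 30., 90., 80.]])
    return paste_masks_in_image(masks, boxes, img_shape=(100, 120))  # (2, 1, 100, 120)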


class RoIHeads(torch.nn.Module):
    def __init__(self,
                 box_roi_pool,
                 box_head,
                 box_predictor,
                 # Faster R-CNN training
                 fg_iou_thresh, bg_iou_thresh,
                 batch_size_per_image, positive_fraction,
                 bbox_reg_weights,
                 # Faster R-CNN inference
                 score_thresh,
                 nms_thresh,
                 detections_per_img,
                 # Mask
                 mask_roi_pool=None,
                 mask_head=None,
                 mask_predictor=None,
                 keypoint_roi_pool=None,
                 keypoint_head=None,
                 keypoint_predictor=None,
                 ):
        super(RoIHeads, self).__init__()

        self.box_similarity = box_ops.box_iou
        # assign ground-truth boxes for each proposal
        self.proposal_matcher = det_utils.Matcher(
            fg_iou_thresh,
            bg_iou_thresh,
            allow_low_quality_matches=False)

        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(
            batch_size_per_image,
            positive_fraction)

        if bbox_reg_weights is None:
            bbox_reg_weights = (10., 10., 5., 5.)
        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)

        self.box_roi_pool = box_roi_pool
        self.box_head = box_head
        self.box_predictor = box_predictor

        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img

        self.mask_roi_pool = mask_roi_pool
        self.mask_head = mask_head
        self.mask_predictor = mask_predictor

        self.keypoint_roi_pool = keypoint_roi_pool
        self.keypoint_head = keypoint_head
        self.keypoint_predictor = keypoint_predictor

    @property
    def has_mask(self):
        if self.mask_roi_pool is None:
            return False
        if self.mask_head is None:
            return False
        if self.mask_predictor is None:
            return False
        return True

    @property
    def has_keypoint(self):
        if self.keypoint_roi_pool is None:
            return False
        if self.keypoint_head is None:
            return False
        if self.keypoint_predictor is None:
            return False
        return True

    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
        matched_idxs = []
        labels = []
        for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
            match_quality_matrix = self.box_similarity(gt_boxes_in_image, proposals_in_image)
            matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)

            clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)

            labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
            labels_in_image = labels_in_image.to(dtype=torch.int64)

            # Label background (below the low threshold)
            bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
            labels_in_image[bg_inds] = 0

            # Label ignore proposals (between low and high thresholds)
            ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
            labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler

            matched_idxs.append(clamped_matched_idxs_in_image)
            labels.append(labels_in_image)
        return matched_idxs, labels

    def subsample(self, labels):
        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        sampled_inds = []
        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(
            zip(sampled_pos_inds, sampled_neg_inds)
        ):
            img_sampled_inds = torch.nonzero(pos_inds_img | neg_inds_img).squeeze(1)
            sampled_inds.append(img_sampled_inds)
        return sampled_inds

    def add_gt_proposals(self, proposals, gt_boxes):
        proposals = [
            torch.cat((proposal, gt_box))
            for proposal, gt_box in zip(proposals, gt_boxes)
        ]

        return proposals

    def check_targets(self, targets):
        assert targets is not None
        assert all("boxes" in t for t in targets)
        assert all("labels" in t for t in targets)
        if self.has_mask:
            assert all("masks" in t for t in targets)

    def select_training_samples(self, proposals, targets):
        self.check_targets(targets)
        gt_boxes = [t["boxes"] for t in targets]
        gt_labels = [t["labels"] for t in targets]

        # append ground-truth bboxes to proposals
        proposals = self.add_gt_proposals(proposals, gt_boxes)

        # get matching gt indices for each proposal
        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
        # sample a fixed proportion of positive-negative proposals
        sampled_inds = self.subsample(labels)
        matched_gt_boxes = []
        num_images = len(proposals)
        for img_id in range(num_images):
            img_sampled_inds = sampled_inds[img_id]
            proposals[img_id] = proposals[img_id][img_sampled_inds]
            labels[img_id] = labels[img_id][img_sampled_inds]
            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
            matched_gt_boxes.append(gt_boxes[img_id][matched_idxs[img_id]])

        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
        return proposals, matched_idxs, labels, regression_targets

    def postprocess_detections(self, class_logits, box_regression, proposals, image_shapes):
        device = class_logits.device
        num_classes = class_logits.shape[-1]

        boxes_per_image = [len(boxes_in_image) for boxes_in_image in proposals]
        pred_boxes = self.box_coder.decode(box_regression, proposals)

        pred_scores = F.softmax(class_logits, -1)

        # split boxes and scores per image
        pred_boxes = pred_boxes.split(boxes_per_image, 0)
        pred_scores = pred_scores.split(boxes_per_image, 0)

        all_boxes = []
        all_scores = []
        all_labels = []
        for boxes, scores, image_shape in zip(pred_boxes, pred_scores, image_shapes):
            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)

            # create labels for each prediction
            labels = torch.arange(num_classes, device=device)
            labels = labels.view(1, -1).expand_as(scores)

            # remove predictions with the background label
            boxes = boxes[:, 1:]
            scores = scores[:, 1:]
            labels = labels[:, 1:]

            # batch everything, by making every class prediction be a separate instance
            boxes = boxes.reshape(-1, 4)
            scores = scores.flatten()
            labels = labels.flatten()

            # remove low scoring boxes
            inds = torch.nonzero(scores > self.score_thresh).squeeze(1)
            boxes, scores, labels = boxes[inds], scores[inds], labels[inds]

            # non-maximum suppression, independently done per class
            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
            # keep only topk scoring predictions
            keep = keep[:self.detections_per_img]
            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

            all_boxes.append(boxes)
            all_scores.append(scores)
            all_labels.append(labels)

        return all_boxes, all_scores, all_labels

    def forward(self, features, proposals, image_shapes, targets=None):
        """
        Arguments:
            features (List[Tensor])
            proposals (List[Tensor[N, 4]])
            image_shapes (List[Tuple[H, W]])
            targets (List[Dict])
        """
        if self.training:
            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)

        box_features = self.box_roi_pool(features, proposals, image_shapes)
        box_features = self.box_head(box_features)
        class_logits, box_regression = self.box_predictor(box_features)

        result, losses = [], {}
        if self.training:
            loss_classifier, loss_box_reg = fastrcnn_loss(
                class_logits, box_regression, labels, regression_targets)
            losses = dict(loss_classifier=loss_classifier, loss_box_reg=loss_box_reg)
        else:
            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
            num_images = len(boxes)
            for i in range(num_images):
                result.append(
                    dict(
                        boxes=boxes[i],
                        labels=labels[i],
                        scores=scores[i],
                    )
                )

        if self.has_mask:
            mask_proposals = [p["boxes"] for p in result]
            if self.training:
                # during training, only focus on positive boxes
                num_images = len(proposals)
                mask_proposals = []
                pos_matched_idxs = []
                for img_id in range(num_images):
                    pos = torch.nonzero(labels[img_id] > 0).squeeze(1)
                    mask_proposals.append(proposals[img_id][pos])
                    pos_matched_idxs.append(matched_idxs[img_id][pos])

            mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
            mask_features = self.mask_head(mask_features)
            mask_logits = self.mask_predictor(mask_features)

            loss_mask = {}
            if self.training:
                gt_masks = [t["masks"] for t in targets]
                gt_labels = [t["labels"] for t in targets]
                loss_mask = maskrcnn_loss(
                    mask_logits, mask_proposals,
                    gt_masks, gt_labels, pos_matched_idxs)
                loss_mask = dict(loss_mask=loss_mask)
            else:
                labels = [r["labels"] for r in result]
                masks_probs = maskrcnn_inference(mask_logits, labels)
                for mask_prob, r in zip(masks_probs, result):
                    r["masks"] = mask_prob

            losses.update(loss_mask)

        if self.has_keypoint:
            keypoint_proposals = [p["boxes"] for p in result]
            if self.training:
                # during training, only focus on positive boxes
                num_images = len(proposals)
                keypoint_proposals = []
                pos_matched_idxs = []
                for img_id in range(num_images):
                    pos = torch.nonzero(labels[img_id] > 0).squeeze(1)
                    keypoint_proposals.append(proposals[img_id][pos])
                    pos_matched_idxs.append(matched_idxs[img_id][pos])

            keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
            keypoint_features = self.keypoint_head(keypoint_features)
            keypoint_logits = self.keypoint_predictor(keypoint_features)

            loss_keypoint = {}
            if self.training:
                gt_keypoints = [t["keypoints"] for t in targets]
                loss_keypoint = keypointrcnn_loss(
                    keypoint_logits, keypoint_proposals,
                    gt_keypoints, pos_matched_idxs)
                loss_keypoint = dict(loss_keypoint=loss_keypoint)
            else:
                keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
                for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
                    r["keypoints"] = keypoint_prob
                    r["keypoints_scores"] = kps

            losses.update(loss_keypoint)

        return result, losses