]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/xvision/models/detection/keypoint_rcnn.py
docs - added deprecation notice
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / models / detection / keypoint_rcnn.py
1 import torch
2 from torch import nn
4 from torchvision.ops import MultiScaleRoIAlign
6 from ..utils import load_state_dict_from_url
8 from .faster_rcnn import FasterRCNN
9 from .backbone_utils import resnet_fpn_backbone
12 __all__ = [
13     "KeypointRCNN", "keypointrcnn_resnet50_fpn"
14 ]
17 class KeypointRCNN(FasterRCNN):
18     """
19     Implements Keypoint R-CNN.
21     The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
22     image, and should be in 0-1 range. Different images can have different sizes.
24     The behavior of the model changes depending if it is in training or evaluation mode.
26     During training, the model expects both the input tensors, as well as a targets (list of dictionary),
27     containing:
28         - boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with values of x
29           between 0 and W and values of y between 0 and H
30         - labels (Int64Tensor[N]): the class label for each ground-truth box
31         - keypoints (FloatTensor[N, K, 3]): the K keypoints location for each of the N instances, in the
32           format [x, y, visibility], where visibility=0 means that the keypoint is not visible.
34     The model returns a Dict[Tensor] during training, containing the classification and regression
35     losses for both the RPN and the R-CNN, and the keypoint loss.
37     During inference, the model requires only the input tensors, and returns the post-processed
38     predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
39     follows:
40         - boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values of x
41           between 0 and W and values of y between 0 and H
42         - labels (Int64Tensor[N]): the predicted labels for each image
43         - scores (Tensor[N]): the scores or each prediction
44         - keypoints (FloatTensor[N, K, 3]): the locations of the predicted keypoints, in [x, y, v] format.
46     Arguments:
47         backbone (nn.Module): the network used to compute the features for the model.
48             It should contain a out_channels attribute, which indicates the number of output
49             channels that each feature map has (and it should be the same for all feature maps).
50             The backbone should return a single Tensor or and OrderedDict[Tensor].
51         num_classes (int): number of output classes of the model (including the background).
52             If box_predictor is specified, num_classes should be None.
53         min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
54         max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
55         image_mean (Tuple[float, float, float]): mean values used for input normalization.
56             They are generally the mean values of the dataset on which the backbone has been trained
57             on
58         image_std (Tuple[float, float, float]): std values used for input normalization.
59             They are generally the std values of the dataset on which the backbone has been trained on
60         rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
61             maps.
62         rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
63         rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
64         rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
65         rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
66         rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
67         rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
68         rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
69             considered as positive during training of the RPN.
70         rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
71             considered as negative during training of the RPN.
72         rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
73             for computing the loss
74         rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
75             of the RPN
76         box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
77             the locations indicated by the bounding boxes
78         box_head (nn.Module): module that takes the cropped feature maps as input
79         box_predictor (nn.Module): module that takes the output of box_head and returns the
80             classification logits and box regression deltas.
81         box_score_thresh (float): during inference, only return proposals with a classification score
82             greater than box_score_thresh
83         box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
84         box_detections_per_img (int): maximum number of detections per image, for all classes.
85         box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
86             considered as positive during training of the classification head
87         box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
88             considered as negative during training of the classification head
89         box_batch_size_per_image (int): number of proposals that are sampled during training of the
90             classification head
91         box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
92             of the classification head
93         bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
94             bounding boxes
95         keypoint_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
96              the locations indicated by the bounding boxes, which will be used for the keypoint head.
97         keypoint_head (nn.Module): module that takes the cropped feature maps as input
98         keypoint_predictor (nn.Module): module that takes the output of the keypoint_head and returns the
99             heatmap logits
101     Example::
103         >>> import torch
104         >>> import torchvision
105         >>> from torchvision.models.detection import KeypointRCNN
106         >>> from torchvision.models.detection.rpn import AnchorGenerator
107         >>>
108         >>> # load a pre-trained model for classification and return
109         >>> # only the features
110         >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features
111         >>> # KeypointRCNN needs to know the number of
112         >>> # output channels in a backbone. For mobilenet_v2, it's 1280
113         >>> # so we need to add it here
114         >>> backbone.out_channels = 1280
115         >>>
116         >>> # let's make the RPN generate 5 x 3 anchors per spatial
117         >>> # location, with 5 different sizes and 3 different aspect
118         >>> # ratios. We have a Tuple[Tuple[int]] because each feature
119         >>> # map could potentially have different sizes and
120         >>> # aspect ratios
121         >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
122         >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
123         >>>
124         >>> # let's define what are the feature maps that we will
125         >>> # use to perform the region of interest cropping, as well as
126         >>> # the size of the crop after rescaling.
127         >>> # if your backbone returns a Tensor, featmap_names is expected to
128         >>> # be ['0']. More generally, the backbone should return an
129         >>> # OrderedDict[Tensor], and in featmap_names you can choose which
130         >>> # feature maps to use.
131         >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
132         >>>                                                 output_size=7,
133         >>>                                                 sampling_ratio=2)
134         >>>
135         >>> keypoint_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
136         >>>                                                          output_size=14,
137         >>>                                                          sampling_ratio=2)
138         >>> # put the pieces together inside a KeypointRCNN model
139         >>> model = KeypointRCNN(backbone,
140         >>>                      num_classes=2,
141         >>>                      rpn_anchor_generator=anchor_generator,
142         >>>                      box_roi_pool=roi_pooler,
143         >>>                      keypoint_roi_pool=keypoint_roi_pooler)
144         >>> model.eval()
145         >>> model.eval()
146         >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
147         >>> predictions = model(x)
148     """
149     def __init__(self, backbone, num_classes=None,
150                  # transform parameters
151                  min_size=None, max_size=1333,
152                  image_mean=None, image_std=None,
153                  # RPN parameters
154                  rpn_anchor_generator=None, rpn_head=None,
155                  rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000,
156                  rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000,
157                  rpn_nms_thresh=0.7,
158                  rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
159                  rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
160                  # Box parameters
161                  box_roi_pool=None, box_head=None, box_predictor=None,
162                  box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,
163                  box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5,
164                  box_batch_size_per_image=512, box_positive_fraction=0.25,
165                  bbox_reg_weights=None,
166                  # keypoint parameters
167                  keypoint_roi_pool=None, keypoint_head=None, keypoint_predictor=None,
168                  num_keypoints=17):
170         assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None)))
171         if min_size is None:
172             min_size = (640, 672, 704, 736, 768, 800)
174         if num_classes is not None:
175             if keypoint_predictor is not None:
176                 raise ValueError("num_classes should be None when keypoint_predictor is specified")
178         out_channels = backbone.out_channels
180         if keypoint_roi_pool is None:
181             keypoint_roi_pool = MultiScaleRoIAlign(
182                 featmap_names=['0', '1', '2', '3'],
183                 output_size=14,
184                 sampling_ratio=2)
186         if keypoint_head is None:
187             keypoint_layers = tuple(512 for _ in range(8))
188             keypoint_head = KeypointRCNNHeads(out_channels, keypoint_layers)
190         if keypoint_predictor is None:
191             keypoint_dim_reduced = 512  # == keypoint_layers[-1]
192             keypoint_predictor = KeypointRCNNPredictor(keypoint_dim_reduced, num_keypoints)
194         super(KeypointRCNN, self).__init__(
195             backbone, num_classes,
196             # transform parameters
197             min_size, max_size,
198             image_mean, image_std,
199             # RPN-specific parameters
200             rpn_anchor_generator, rpn_head,
201             rpn_pre_nms_top_n_train, rpn_pre_nms_top_n_test,
202             rpn_post_nms_top_n_train, rpn_post_nms_top_n_test,
203             rpn_nms_thresh,
204             rpn_fg_iou_thresh, rpn_bg_iou_thresh,
205             rpn_batch_size_per_image, rpn_positive_fraction,
206             # Box parameters
207             box_roi_pool, box_head, box_predictor,
208             box_score_thresh, box_nms_thresh, box_detections_per_img,
209             box_fg_iou_thresh, box_bg_iou_thresh,
210             box_batch_size_per_image, box_positive_fraction,
211             bbox_reg_weights)
213         self.roi_heads.keypoint_roi_pool = keypoint_roi_pool
214         self.roi_heads.keypoint_head = keypoint_head
215         self.roi_heads.keypoint_predictor = keypoint_predictor
218 class KeypointRCNNHeads(nn.Sequential):
219     def __init__(self, in_channels, layers):
220         d = []
221         next_feature = in_channels
222         for out_channels in layers:
223             d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
224             d.append(nn.ReLU(inplace=True))
225             next_feature = out_channels
226         super(KeypointRCNNHeads, self).__init__(*d)
227         for m in self.children():
228             if isinstance(m, nn.Conv2d):
229                 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
230                 nn.init.constant_(m.bias, 0)
233 class KeypointRCNNPredictor(nn.Module):
234     def __init__(self, in_channels, num_keypoints):
235         super(KeypointRCNNPredictor, self).__init__()
236         input_features = in_channels
237         deconv_kernel = 4
238         self.kps_score_lowres = nn.ConvTranspose2d(
239             input_features,
240             num_keypoints,
241             deconv_kernel,
242             stride=2,
243             padding=deconv_kernel // 2 - 1,
244         )
245         nn.init.kaiming_normal_(
246             self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu"
247         )
248         nn.init.constant_(self.kps_score_lowres.bias, 0)
249         self.up_scale = 2
250         self.out_channels = num_keypoints
252     def forward(self, x):
253         x = self.kps_score_lowres(x)
254         return torch.nn.functional.interpolate(
255             x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
256         )
259 model_urls = {
260     # legacy model for BC reasons, see https://github.com/pytorch/vision/issues/1606
261     'keypointrcnn_resnet50_fpn_coco_legacy':
262         'https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth',
263     'keypointrcnn_resnet50_fpn_coco':
264         'https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth',
268 def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
269                               num_classes=2, num_keypoints=17,
270                               pretrained_backbone=True, trainable_backbone_layers=3, **kwargs):
271     """
272     Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.
274     The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
275     image, and should be in ``0-1`` range. Different images can have different sizes.
277     The behavior of the model changes depending if it is in training or evaluation mode.
279     During training, the model expects both the input tensors, as well as a targets (list of dictionary),
280     containing:
281         - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with values of ``x``
282           between ``0`` and ``W`` and values of ``y`` between ``0`` and ``H``
283         - labels (``Int64Tensor[N]``): the class label for each ground-truth box
284         - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the
285           format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.
287     The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
288     losses for both the RPN and the R-CNN, and the keypoint loss.
290     During inference, the model requires only the input tensors, and returns the post-processed
291     predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
292     follows:
293         - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format,  with values of ``x``
294           between ``0`` and ``W`` and values of ``y`` between ``0`` and ``H``
295         - labels (``Int64Tensor[N]``): the predicted labels for each image
296         - scores (``Tensor[N]``): the scores or each prediction
297         - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.
299     Keypoint R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.
301     Example::
303         >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(pretrained=True)
304         >>> model.eval()
305         >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
306         >>> predictions = model(x)
307         >>>
308         >>> # optionally, if you want to export the model to ONNX:
309         >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11)
311     Arguments:
312         pretrained (bool): If True, returns a model pre-trained on COCO train2017
313         progress (bool): If True, displays a progress bar of the download to stderr
314         pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
315         num_classes (int): number of output classes of the model (including the background)
316         trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
317             Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
318     """
319     assert trainable_backbone_layers <= 5 and trainable_backbone_layers >= 0
320     # dont freeze any layers if pretrained model or backbone is not used
321     if not (pretrained or pretrained_backbone):
322         trainable_backbone_layers = 5
323     if pretrained:
324         # no need to download the backbone if pretrained is set
325         pretrained_backbone = False
326     backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, trainable_layers=trainable_backbone_layers)
327     model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
328     if pretrained:
329         key = 'keypointrcnn_resnet50_fpn_coco'
330         if pretrained == 'legacy':
331             key += '_legacy'
332         state_dict = load_state_dict_from_url(model_urls[key],
333                                               progress=progress)
334         model.load_state_dict(state_dict)
335     return model