import torch
from torch import nn

from torchvision.ops import MultiScaleRoIAlign

from ..utils import load_state_dict_from_url

from .faster_rcnn import FasterRCNN
from .backbone_utils import resnet_fpn_backbone


__all__ = [
    "KeypointRCNN", "keypointrcnn_resnet50_fpn"
]


class KeypointRCNN(FasterRCNN):
    """
    Implements Keypoint R-CNN.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on whether it is in training or evaluation mode.

    During training, the model expects both the input tensors, as well as targets (a list of
    dictionaries), containing:
        - boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with values of x
          between 0 and W and values of y between 0 and H
        - labels (Int64Tensor[N]): the class label for each ground-truth box
        - keypoints (FloatTensor[N, K, 3]): the K keypoint locations for each of the N instances, in the
          format [x, y, visibility], where visibility=0 means that the keypoint is not visible.

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses for both the RPN and the R-CNN, and the keypoint loss.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values of x
          between 0 and W and values of y between 0 and H
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores of each prediction
        - keypoints (FloatTensor[N, K, 3]): the locations of the predicted keypoints, in [x, y, v] format.

    Arguments:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
            If box_predictor is specified, num_classes should be None.
        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained
        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
            the locations indicated by the bounding boxes
        box_head (nn.Module): module that takes the cropped feature maps as input
        box_predictor (nn.Module): module that takes the output of box_head and returns the
            classification logits and box regression deltas.
        box_score_thresh (float): during inference, only return proposals with a classification score
            greater than box_score_thresh
        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
        box_detections_per_img (int): maximum number of detections per image, for all classes.
        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
            considered as positive during training of the classification head
        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
            considered as negative during training of the classification head
        box_batch_size_per_image (int): number of proposals that are sampled during training of the
            classification head
        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
            of the classification head
        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
            bounding boxes
        keypoint_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
            the locations indicated by the bounding boxes, which will be used for the keypoint head.
        keypoint_head (nn.Module): module that takes the cropped feature maps as input
        keypoint_predictor (nn.Module): module that takes the output of the keypoint_head and returns the
            heatmap logits

    Example::

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models.detection import KeypointRCNN
        >>> from torchvision.models.detection.rpn import AnchorGenerator
        >>>
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features
        >>> # KeypointRCNN needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the RPN generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
        >>>
        >>> # let's define what are the feature maps that we will
        >>> # use to perform the region of interest cropping, as well as
        >>> # the size of the crop after rescaling.
        >>> # if your backbone returns a Tensor, featmap_names is expected to
        >>> # be ['0']. More generally, the backbone should return an
        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
        >>> # feature maps to use.
        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
        >>>                                                 output_size=7,
        >>>                                                 sampling_ratio=2)
        >>>
        >>> keypoint_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
        >>>                                                          output_size=14,
        >>>                                                          sampling_ratio=2)
        >>> # put the pieces together inside a KeypointRCNN model
        >>> model = KeypointRCNN(backbone,
        >>>                      num_classes=2,
        >>>                      rpn_anchor_generator=anchor_generator,
        >>>                      box_roi_pool=roi_pooler,
        >>>                      keypoint_roi_pool=keypoint_roi_pooler)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
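        >>>
        >>> # a sketch of the training-mode call, with illustrative targets;
        >>> # during training the model takes the images together with a list
        >>> # of target dicts (see above) and returns a dict of losses
        >>> model.train()
        >>> images = [torch.rand(3, 300, 400)]
        >>> targets = [{'boxes': torch.tensor([[50., 50., 250., 280.]]),
        >>>             'labels': torch.tensor([1]),
        >>>             'keypoints': torch.cat([torch.rand(1, 17, 2) * 200 + 50,
        >>>                                     torch.ones(1, 17, 1)], dim=2)}]
        >>> losses = model(images, targets)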
148 """
149 def __init__(self, backbone, num_classes=None,
150 # transform parameters
151 min_size=None, max_size=1333,
152 image_mean=None, image_std=None,
153 # RPN parameters
154 rpn_anchor_generator=None, rpn_head=None,
155 rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000,
156 rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000,
157 rpn_nms_thresh=0.7,
158 rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
159 rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
160 # Box parameters
161 box_roi_pool=None, box_head=None, box_predictor=None,
162 box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,
163 box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5,
164 box_batch_size_per_image=512, box_positive_fraction=0.25,
165 bbox_reg_weights=None,
166 # keypoint parameters
167 keypoint_roi_pool=None, keypoint_head=None, keypoint_predictor=None,
168 num_keypoints=17):
        assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None)))
        if min_size is None:
            min_size = (640, 672, 704, 736, 768, 800)

        if num_classes is not None:
            if keypoint_predictor is not None:
                raise ValueError("num_classes should be None when keypoint_predictor is specified")

        out_channels = backbone.out_channels

        if keypoint_roi_pool is None:
            keypoint_roi_pool = MultiScaleRoIAlign(
                featmap_names=['0', '1', '2', '3'],
                output_size=14,
                sampling_ratio=2)

        if keypoint_head is None:
            keypoint_layers = tuple(512 for _ in range(8))
            keypoint_head = KeypointRCNNHeads(out_channels, keypoint_layers)

        if keypoint_predictor is None:
            keypoint_dim_reduced = 512  # == keypoint_layers[-1]
            keypoint_predictor = KeypointRCNNPredictor(keypoint_dim_reduced, num_keypoints)

        super(KeypointRCNN, self).__init__(
            backbone, num_classes,
            # transform parameters
            min_size, max_size,
            image_mean, image_std,
            # RPN-specific parameters
            rpn_anchor_generator, rpn_head,
            rpn_pre_nms_top_n_train, rpn_pre_nms_top_n_test,
            rpn_post_nms_top_n_train, rpn_post_nms_top_n_test,
            rpn_nms_thresh,
            rpn_fg_iou_thresh, rpn_bg_iou_thresh,
            rpn_batch_size_per_image, rpn_positive_fraction,
            # Box parameters
            box_roi_pool, box_head, box_predictor,
            box_score_thresh, box_nms_thresh, box_detections_per_img,
            box_fg_iou_thresh, box_bg_iou_thresh,
            box_batch_size_per_image, box_positive_fraction,
            bbox_reg_weights)

        self.roi_heads.keypoint_roi_pool = keypoint_roi_pool
        self.roi_heads.keypoint_head = keypoint_head
        self.roi_heads.keypoint_predictor = keypoint_predictor


class KeypointRCNNHeads(nn.Sequential):
    def __init__(self, in_channels, layers):
        d = []
        next_feature = in_channels
        for out_channels in layers:
            d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
            d.append(nn.ReLU(inplace=True))
            next_feature = out_channels
        super(KeypointRCNNHeads, self).__init__(*d)
        for m in self.children():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                nn.init.constant_(m.bias, 0)
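

# Illustrative shape check for KeypointRCNNHeads (a sketch, not part of the
# module API): the head is a stack of 3x3, stride-1, padding-1 convolutions
# with ReLUs, so it only changes the channel count and preserves the spatial
# size of the pooled RoI features.
#
#   >>> head = KeypointRCNNHeads(256, tuple(512 for _ in range(8)))
#   >>> head(torch.rand(4, 256, 14, 14)).shape
#   torch.Size([4, 512, 14, 14])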


class KeypointRCNNPredictor(nn.Module):
    def __init__(self, in_channels, num_keypoints):
        super(KeypointRCNNPredictor, self).__init__()
        input_features = in_channels
        deconv_kernel = 4
        self.kps_score_lowres = nn.ConvTranspose2d(
            input_features,
            num_keypoints,
            deconv_kernel,
            stride=2,
            padding=deconv_kernel // 2 - 1,
        )
        nn.init.kaiming_normal_(
            self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu"
        )
        nn.init.constant_(self.kps_score_lowres.bias, 0)
        self.up_scale = 2
        self.out_channels = num_keypoints

    def forward(self, x):
        x = self.kps_score_lowres(x)
        return torch.nn.functional.interpolate(
            x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
        )
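

# A quick size sketch for the predictor (illustrative, assuming the default
# 14x14 keypoint RoI features): the transposed convolution (kernel 4, stride 2,
# padding 1) doubles the spatial size, (14 - 1) * 2 - 2 + 4 = 28, and the
# bilinear interpolation doubles it again, giving one 56x56 heatmap of logits
# per keypoint.
#
#   >>> predictor = KeypointRCNNPredictor(512, num_keypoints=17)
#   >>> predictor(torch.rand(4, 512, 14, 14)).shape
#   torch.Size([4, 17, 56, 56])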


model_urls = {
    # legacy model for BC reasons, see https://github.com/pytorch/vision/issues/1606
    'keypointrcnn_resnet50_fpn_coco_legacy':
        'https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth',
    'keypointrcnn_resnet50_fpn_coco':
        'https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth',
}


def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
                              num_classes=2, num_keypoints=17,
                              pretrained_backbone=True, trainable_backbone_layers=3, **kwargs):
    """
    Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on whether it is in training or evaluation mode.

    During training, the model expects both the input tensors, as well as targets (a list of
    dictionaries), containing:
        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with values of ``x``
          between ``0`` and ``W`` and values of ``y`` between ``0`` and ``H``
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
        - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoint locations for each of the ``N`` instances, in the
          format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses for both the RPN and the R-CNN, and the keypoint loss.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows:
        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with values of ``x``
          between ``0`` and ``W`` and values of ``y`` between ``0`` and ``H``
        - labels (``Int64Tensor[N]``): the predicted labels for each image
        - scores (``Tensor[N]``): the scores of each prediction
        - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.

    Keypoint R-CNN is exportable to ONNX for a fixed batch size with input images of fixed size.

    Example::

        >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(pretrained=True)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
        >>>
        >>> # optionally, if you want to export the model to ONNX:
        >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version=11)

    Arguments:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        num_keypoints (int): number of keypoints predicted for each instance
        pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
        trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
    """
    assert 0 <= trainable_backbone_layers <= 5
    # don't freeze any layers if neither the model nor the backbone is pretrained
    if not (pretrained or pretrained_backbone):
        trainable_backbone_layers = 5
    if pretrained:
        # no need to download the backbone if pretrained is set
        pretrained_backbone = False
    backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, trainable_layers=trainable_backbone_layers)
    model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
    if pretrained:
        key = 'keypointrcnn_resnet50_fpn_coco'
        if pretrained == 'legacy':
            key += '_legacy'
        state_dict = load_state_dict_from_url(model_urls[key],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model
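

# A minimal smoke test (an illustrative sketch, not part of the original
# module): builds the model without downloading any pretrained weights and
# runs one inference pass on random images.
if __name__ == "__main__":
    model = keypointrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False)
    model.eval()
    images = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
    with torch.no_grad():
        predictions = model(images)
    # each prediction dict includes 'boxes', 'labels', 'scores' and 'keypoints'
    print(sorted(predictions[0].keys()))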