summary | shortlog | log | commit | commitdiff | tree
raw | patch | inline | side by side (parent: b14f424)
raw | patch | inline | side by side (parent: b14f424)
author | Manu Mathew <a0393608@ti.com> | |
Sat, 20 Jun 2020 11:05:28 +0000 (16:35 +0530) | ||
committer | Manu Mathew <a0393608@ti.com> | |
Sat, 20 Jun 2020 11:46:43 +0000 (17:16 +0530) |
12 files changed:
index 9b8388d4a9c576881adec12d47b83617e4eb0fdc..c72df19a910e4eedb371c383459dfedb94177c26 100644 (file)
fpn_start_level = 1
fpn_num_outs = 5
-#retinanet_base_stride = (8 if fpn_start_level==1 else (4 if fpn_start_level==0 else None))
+fcos_num_levels = 5
+fcos_base_stride = (8 if fpn_start_level==1 else (4 if fpn_start_level==0 else None))
# for multi-scale training
input_size_ms = [(input_size[0], (input_size[1]*8)//10),
norm_eval=False,
style='pytorch'),
neck=dict(
- type='FPN',
+ type='JaiFPN',
in_channels=fpn_in_channels,
out_channels=fpn_out_channels,
start_level=fpn_start_level,
diff --git a/configs/retinanet/retinanet_regnet_fpn_bgr.py b/configs/retinanet/retinanet_regnet_fpn_bgr.py
index 74d81cce9c91abb0da65086c6e205d29e177ac39..730ce641f987aa012ca95503c855e104818d15f1 100644 (file)
'../_jacinto_ai_base_/hyper_params/ssd_config.py',
]
-dataset_type = 'VOCDataset'
+dataset_type = 'CocoDataset'
if dataset_type == 'VOCDataset':
_base_ += ['../_jacinto_ai_base_/datasets/voc0712_det.py']
norm_eval=False,
style='pytorch'),
neck=dict(
- type='FPN',
+ type='JaiFPN',
in_channels=fpn_in_channels,
out_channels=fpn_out_channels,
start_level=fpn_start_level,
]
data = dict(
- samples_per_gpu=1, #16,
+ samples_per_gpu=16,
workers_per_gpu=3,
train=dict(dataset=dict(pipeline=train_pipeline)),
val=dict(pipeline=test_pipeline),
index c71c66ba85c0adbdaa19265deb7a5eef827d56d6..069c8c1310551ff417be8eb0b8cd09c1997e323c 100644 (file)
'../_jacinto_ai_base_/hyper_params/ssd_config.py',
]
-dataset_type = 'VOCDataset'
+dataset_type = 'CocoDataset'
if dataset_type == 'VOCDataset':
_base_ += ['../_jacinto_ai_base_/datasets/voc0712_det.py']
index a7605d349c9ef8587b300f3701df433a6ea38fe5..0de3d48680dc1211ae548d8c74d2c18c0726ff7d 100644 (file)
out_feature_indices=None,
l2_norm_scale=None),
neck=dict(
- type='FPN',
+ type='JaiFPN',
in_channels=fpn_in_channels,
out_channels=fpn_out_channels,
start_level=fpn_start_level,
index 0a2fa5ff97b4bd7842bdd9c38821a4f70d203c06..d02f0799c45788fa6b6bd68dfdf3c5e7f3c04630 100644 (file)
norm_eval=False,
style='pytorch'),
neck=dict(
- type='FPN',
+ type='JaiFPN',
in_channels=fpn_in_channels,
out_channels=fpn_out_channels,
start_level=fpn_start_level,
index c24b1a02990be9c7cc80ce5e059ba3ecf67639d7..f9680715dba57299e1e7f92487e8261047e33ac8 100644 (file)
type='SSDHead',
in_channels=[fpn_out_channels for _ in range(6)],
num_classes=num_classes,
- conv_cfg=conv_cfg,
anchor_generator=dict(
type='SSDAnchorGenerator',
scale_major=False,
index 27e15c3d17fd654dab4da5b9bb2c7308ffdb1050..4c66b65141b2a68b61e59535a82396014c7539be 100644 (file)
from mmdet.core import distance2bbox, force_fp32, multi_apply, multiclass_nms
from mmdet.models.builder import HEADS, build_loss
+from mmdet.models.dense_heads.fcos_head import FCOSHead
+from ...ops import ConvModuleWrapper
INF = 1e8
-@HEADS.register_module(force=True)
-class FCOSHead(nn.Module):
+@HEADS.register_module()
+class JaiFCOSHead(FCOSHead):
"""Anchor-free head used in `FCOS <https://arxiv.org/abs/1904.01355>`_.
The FCOS head does not use anchor boxes. Instead bounding boxes are
low-quality predictions.
Example:
- >>> self = FCOSHead(11, 7)
+ >>> self = JaiFCOSHead(11, 7)
>>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]]
>>> cls_score, bbox_pred, centerness = self.forward(feats)
>>> assert len(cls_score) == len(self.scales)
"""
- def __init__(self,
- num_classes,
- in_channels,
- feat_channels=256,
- stacked_convs=4,
- strides=(4, 8, 16, 32, 64),
- regress_ranges=((-1, 64), (64, 128), (128, 256), (256, 512),
- (512, INF)),
- center_sampling=False,
- center_sample_radius=1.5,
- background_label=None,
- loss_cls=dict(
- type='FocalLoss',
- use_sigmoid=True,
- gamma=2.0,
- alpha=0.25,
- loss_weight=1.0),
- loss_bbox=dict(type='IoULoss', loss_weight=1.0),
- loss_centerness=dict(
- type='CrossEntropyLoss',
- use_sigmoid=True,
- loss_weight=1.0),
- conv_cfg=None,
- norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
- train_cfg=None,
- test_cfg=None):
- super(FCOSHead, self).__init__()
- self.num_classes = num_classes
- self.cls_out_channels = num_classes
- self.in_channels = in_channels
- self.feat_channels = feat_channels
- self.stacked_convs = stacked_convs
- self.strides = strides
- self.regress_ranges = regress_ranges
- self.loss_cls = build_loss(loss_cls)
- self.loss_bbox = build_loss(loss_bbox)
- self.loss_centerness = build_loss(loss_centerness)
- self.train_cfg = train_cfg
- self.test_cfg = test_cfg
- self.conv_cfg = conv_cfg
- self.norm_cfg = norm_cfg
- self.fp16_enabled = False
- self.center_sampling = center_sampling
- self.center_sample_radius = center_sample_radius
- self.background_label = (
- num_classes if background_label is None else background_label)
- # background_label should be either 0 or num_classes
- assert (self.background_label == 0
- or self.background_label == num_classes)
-
- self._init_layers()
-
def _init_layers(self):
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels
self.cls_convs.append(
- ConvModule(
+ ConvModuleWrapper(
chn,
self.feat_channels,
3,
norm_cfg=self.norm_cfg,
bias=self.norm_cfg is None))
self.reg_convs.append(
- ConvModule(
+ ConvModuleWrapper(
chn,
self.feat_channels,
3,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
bias=self.norm_cfg is None))
- self.fcos_cls = nn.Conv2d(
- self.feat_channels, self.cls_out_channels, 3, padding=1)
- self.fcos_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)
- self.fcos_centerness = nn.Conv2d(self.feat_channels, 1, 3, padding=1)
+ self.fcos_cls = ConvModuleWrapper(self.feat_channels, self.cls_out_channels, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=None, act_cfg=None)
+ self.fcos_reg = ConvModuleWrapper(self.feat_channels, 4, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=None, act_cfg=None)
+ self.fcos_centerness = ConvModuleWrapper(self.feat_channels, 1, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=None, act_cfg=None)
self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])
normal_init(self.fcos_reg, std=0.01)
if hasattr(self.fcos_centerness, 'weight'):
normal_init(self.fcos_centerness, std=0.01)
-
- def forward(self, feats):
- # if multiple features are provided, take the first (highest resolution) one
- num_strides = len(self.strides)
- feats = feats[:num_strides] if isinstance(feats, (list,tuple)) else feats
- return multi_apply(self.forward_single, feats, self.scales)
-
- def forward_single(self, x, scale):
- cls_feat = x
- reg_feat = x
-
- for cls_layer in self.cls_convs:
- cls_feat = cls_layer(cls_feat)
- cls_score = self.fcos_cls(cls_feat)
- centerness = self.fcos_centerness(cls_feat)
-
- for reg_layer in self.reg_convs:
- reg_feat = reg_layer(reg_feat)
- # scale the bbox_pred of different level
- # float to avoid overflow when enabling FP16
- bbox_pred = scale(self.fcos_reg(reg_feat)).float().exp()
- return cls_score, bbox_pred, centerness
-
- @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
- def loss(self,
- cls_scores,
- bbox_preds,
- centernesses,
- gt_bboxes,
- gt_labels,
- img_metas,
- gt_bboxes_ignore=None):
- assert len(cls_scores) == len(bbox_preds) == len(centernesses)
- featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
- all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
- bbox_preds[0].device)
- labels, bbox_targets = self.get_targets(all_level_points, gt_bboxes,
- gt_labels)
-
- num_imgs = cls_scores[0].size(0)
- # flatten cls_scores, bbox_preds and centerness
- flatten_cls_scores = [
- cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
- for cls_score in cls_scores
- ]
- flatten_bbox_preds = [
- bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
- for bbox_pred in bbox_preds
- ]
- flatten_centerness = [
- centerness.permute(0, 2, 3, 1).reshape(-1)
- for centerness in centernesses
- ]
- flatten_cls_scores = torch.cat(flatten_cls_scores)
- flatten_bbox_preds = torch.cat(flatten_bbox_preds)
- flatten_centerness = torch.cat(flatten_centerness)
- flatten_labels = torch.cat(labels)
- flatten_bbox_targets = torch.cat(bbox_targets)
- # repeat points to align with bbox_preds
- flatten_points = torch.cat(
- [points.repeat(num_imgs, 1) for points in all_level_points])
-
- # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
- bg_class_ind = self.num_classes
- pos_inds = ((flatten_labels >= 0)
- & (flatten_labels < bg_class_ind)).nonzero().reshape(-1)
- num_pos = len(pos_inds)
- loss_cls = self.loss_cls(
- flatten_cls_scores, flatten_labels,
- avg_factor=num_pos + num_imgs) # avoid num_pos is 0
-
- pos_bbox_preds = flatten_bbox_preds[pos_inds]
- pos_centerness = flatten_centerness[pos_inds]
-
- if num_pos > 0:
- pos_bbox_targets = flatten_bbox_targets[pos_inds]
- pos_centerness_targets = self.centerness_target(pos_bbox_targets)
- pos_points = flatten_points[pos_inds]
- pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
- pos_decoded_target_preds = distance2bbox(pos_points,
- pos_bbox_targets)
- # centerness weighted iou loss
- loss_bbox = self.loss_bbox(
- pos_decoded_bbox_preds,
- pos_decoded_target_preds,
- weight=pos_centerness_targets,
- avg_factor=pos_centerness_targets.sum())
- loss_centerness = self.loss_centerness(pos_centerness,
- pos_centerness_targets)
- else:
- loss_bbox = pos_bbox_preds.sum()
- loss_centerness = pos_centerness.sum()
-
- return dict(
- loss_cls=loss_cls,
- loss_bbox=loss_bbox,
- loss_centerness=loss_centerness)
-
- @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
- def get_bboxes(self,
- cls_scores,
- bbox_preds,
- centernesses,
- img_metas,
- cfg=None,
- rescale=None):
- assert len(cls_scores) == len(bbox_preds)
- num_levels = len(cls_scores)
-
- featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
- mlvl_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
- bbox_preds[0].device)
- result_list = []
- for img_id in range(len(img_metas)):
- cls_score_list = [
- cls_scores[i][img_id].detach() for i in range(num_levels)
- ]
- bbox_pred_list = [
- bbox_preds[i][img_id].detach() for i in range(num_levels)
- ]
- centerness_pred_list = [
- centernesses[i][img_id].detach() for i in range(num_levels)
- ]
- img_shape = img_metas[img_id]['img_shape']
- scale_factor = img_metas[img_id]['scale_factor']
- det_bboxes = self._get_bboxes_single(cls_score_list,
- bbox_pred_list,
- centerness_pred_list,
- mlvl_points, img_shape,
- scale_factor, cfg, rescale)
- result_list.append(det_bboxes)
- return result_list
-
- def _get_bboxes_single(self,
- cls_scores,
- bbox_preds,
- centernesses,
- mlvl_points,
- img_shape,
- scale_factor,
- cfg,
- rescale=False):
- cfg = self.test_cfg if cfg is None else cfg
- assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
- mlvl_bboxes = []
- mlvl_scores = []
- mlvl_centerness = []
- for cls_score, bbox_pred, centerness, points in zip(
- cls_scores, bbox_preds, centernesses, mlvl_points):
- assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
- scores = cls_score.permute(1, 2, 0).reshape(
- -1, self.cls_out_channels).sigmoid()
- centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()
-
- bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
- nms_pre = cfg.get('nms_pre', -1)
- if nms_pre > 0 and scores.shape[0] > nms_pre:
- max_scores, _ = (scores * centerness[:, None]).max(dim=1)
- _, topk_inds = max_scores.topk(nms_pre)
- points = points[topk_inds, :]
- bbox_pred = bbox_pred[topk_inds, :]
- scores = scores[topk_inds, :]
- centerness = centerness[topk_inds]
- bboxes = distance2bbox(points, bbox_pred, max_shape=img_shape)
- mlvl_bboxes.append(bboxes)
- mlvl_scores.append(scores)
- mlvl_centerness.append(centerness)
- mlvl_bboxes = torch.cat(mlvl_bboxes)
- if rescale:
- mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
- mlvl_scores = torch.cat(mlvl_scores)
- padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
- # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
- # BG cat_id: num_class
- mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
- mlvl_centerness = torch.cat(mlvl_centerness)
- det_bboxes, det_labels = multiclass_nms(
- mlvl_bboxes,
- mlvl_scores,
- cfg.score_thr,
- cfg.nms,
- cfg.max_per_img,
- score_factors=mlvl_centerness)
- return det_bboxes, det_labels
-
- def get_points(self, featmap_sizes, dtype, device):
- """Get points according to feature map sizes.
-
- Args:
- featmap_sizes (list[tuple]): Multi-level feature map sizes.
- dtype (torch.dtype): Type of points.
- device (torch.device): Device of points.
-
- Returns:
- tuple: points of each image.
- """
- mlvl_points = []
- for i in range(len(featmap_sizes)):
- mlvl_points.append(
- self._get_points_single(featmap_sizes[i], self.strides[i],
- dtype, device))
- return mlvl_points
-
- def _get_points_single(self, featmap_size, stride, dtype, device):
- h, w = featmap_size
- x_range = torch.arange(
- 0, w * stride, stride, dtype=dtype, device=device)
- y_range = torch.arange(
- 0, h * stride, stride, dtype=dtype, device=device)
- y, x = torch.meshgrid(y_range, x_range)
- points = torch.stack(
- (x.reshape(-1), y.reshape(-1)), dim=-1) + stride // 2
- return points
-
- def get_targets(self, points, gt_bboxes_list, gt_labels_list):
- assert len(points) == len(self.regress_ranges)
- num_levels = len(points)
- # expand regress ranges to align with points
- expanded_regress_ranges = [
- points[i].new_tensor(self.regress_ranges[i])[None].expand_as(
- points[i]) for i in range(num_levels)
- ]
- # concat all levels points and regress ranges
- concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0)
- concat_points = torch.cat(points, dim=0)
-
- # the number of points per img, per lvl
- num_points = [center.size(0) for center in points]
-
- # get labels and bbox_targets of each image
- labels_list, bbox_targets_list = multi_apply(
- self._get_target_single,
- gt_bboxes_list,
- gt_labels_list,
- points=concat_points,
- regress_ranges=concat_regress_ranges,
- num_points_per_lvl=num_points)
-
- # split to per img, per level
- labels_list = [labels.split(num_points, 0) for labels in labels_list]
- bbox_targets_list = [
- bbox_targets.split(num_points, 0)
- for bbox_targets in bbox_targets_list
- ]
-
- # concat per level image
- concat_lvl_labels = []
- concat_lvl_bbox_targets = []
- for i in range(num_levels):
- concat_lvl_labels.append(
- torch.cat([labels[i] for labels in labels_list]))
- concat_lvl_bbox_targets.append(
- torch.cat(
- [bbox_targets[i] for bbox_targets in bbox_targets_list]))
- return concat_lvl_labels, concat_lvl_bbox_targets
-
- def _get_target_single(self, gt_bboxes, gt_labels, points, regress_ranges,
- num_points_per_lvl):
- num_points = points.size(0)
- num_gts = gt_labels.size(0)
- if num_gts == 0:
- return gt_labels.new_full((num_points, ), self.background_label), \
- gt_bboxes.new_zeros((num_points, 4))
-
- areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * (
- gt_bboxes[:, 3] - gt_bboxes[:, 1])
- # TODO: figure out why these two are different
- # areas = areas[None].expand(num_points, num_gts)
- areas = areas[None].repeat(num_points, 1)
- regress_ranges = regress_ranges[:, None, :].expand(
- num_points, num_gts, 2)
- gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
- xs, ys = points[:, 0], points[:, 1]
- xs = xs[:, None].expand(num_points, num_gts)
- ys = ys[:, None].expand(num_points, num_gts)
-
- left = xs - gt_bboxes[..., 0]
- right = gt_bboxes[..., 2] - xs
- top = ys - gt_bboxes[..., 1]
- bottom = gt_bboxes[..., 3] - ys
- bbox_targets = torch.stack((left, top, right, bottom), -1)
-
- if self.center_sampling:
- # condition1: inside a `center bbox`
- radius = self.center_sample_radius
- center_xs = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) / 2
- center_ys = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) / 2
- center_gts = torch.zeros_like(gt_bboxes)
- stride = center_xs.new_zeros(center_xs.shape)
-
- # project the points on current lvl back to the `original` sizes
- lvl_begin = 0
- for lvl_idx, num_points_lvl in enumerate(num_points_per_lvl):
- lvl_end = lvl_begin + num_points_lvl
- stride[lvl_begin:lvl_end] = self.strides[lvl_idx] * radius
- lvl_begin = lvl_end
-
- x_mins = center_xs - stride
- y_mins = center_ys - stride
- x_maxs = center_xs + stride
- y_maxs = center_ys + stride
- center_gts[..., 0] = torch.where(x_mins > gt_bboxes[..., 0],
- x_mins, gt_bboxes[..., 0])
- center_gts[..., 1] = torch.where(y_mins > gt_bboxes[..., 1],
- y_mins, gt_bboxes[..., 1])
- center_gts[..., 2] = torch.where(x_maxs > gt_bboxes[..., 2],
- gt_bboxes[..., 2], x_maxs)
- center_gts[..., 3] = torch.where(y_maxs > gt_bboxes[..., 3],
- gt_bboxes[..., 3], y_maxs)
-
- cb_dist_left = xs - center_gts[..., 0]
- cb_dist_right = center_gts[..., 2] - xs
- cb_dist_top = ys - center_gts[..., 1]
- cb_dist_bottom = center_gts[..., 3] - ys
- center_bbox = torch.stack(
- (cb_dist_left, cb_dist_top, cb_dist_right, cb_dist_bottom), -1)
- inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0
- else:
- # condition1: inside a gt bbox
- inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0
-
- # condition2: limit the regression range for each location
- max_regress_distance = bbox_targets.max(-1)[0]
- inside_regress_range = (
- max_regress_distance >= regress_ranges[..., 0]) & (
- max_regress_distance <= regress_ranges[..., 1])
-
- # if there are still more than one objects for a location,
- # we choose the one with minimal area
- areas[inside_gt_bbox_mask == 0] = INF
- areas[inside_regress_range == 0] = INF
- min_area, min_area_inds = areas.min(dim=1)
-
- labels = gt_labels[min_area_inds]
- labels[min_area == INF] = self.background_label # set as BG
- bbox_targets = bbox_targets[range(num_points), min_area_inds]
-
- return labels, bbox_targets
-
- def centerness_target(self, pos_bbox_targets):
- # only calculate pos centerness targets, otherwise there may be nan
- left_right = pos_bbox_targets[:, [0, 2]]
- top_bottom = pos_bbox_targets[:, [1, 3]]
- centerness_targets = (
- left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * (
- top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
- return torch.sqrt(centerness_targets)
-
-
-
-@HEADS.register_module()
-class JaiFCOSHead(FCOSHead):
- def _init_layers(self):
- self.cls_convs = nn.ModuleList()
- self.reg_convs = nn.ModuleList()
- for i in range(self.stacked_convs):
- chn = self.in_channels if i == 0 else self.feat_channels
- self.cls_convs.append(
- ConvModule(
- chn,
- self.feat_channels,
- 3,
- stride=1,
- padding=1,
- conv_cfg=self.conv_cfg,
- norm_cfg=self.norm_cfg))
- self.reg_convs.append(
- ConvModule(
- chn,
- self.feat_channels,
- 3,
- stride=1,
- padding=1,
- conv_cfg=self.conv_cfg,
- norm_cfg=self.norm_cfg))
- self.fcos_cls = ConvModule(self.feat_channels, self.cls_out_channels, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=None, act_cfg=None)
- self.fcos_reg = ConvModule(self.feat_channels, 4, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=None, act_cfg=None)
- self.fcos_centerness = ConvModule(self.feat_channels, 1, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=None, act_cfg=None)
-
- self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])
\ No newline at end of file
index 27031acdaa7321f8a0b2a3dff1b035601cd8ac00..b76b6cc880d3d065bddafca0d1f73c83b0e47a13 100644 (file)
from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init
from mmdet.models.builder import HEADS
-from mmdet.models.dense_heads.anchor_head import AnchorHead
+from mmdet.models.dense_heads.retina_head import RetinaHead
+from ...ops import ConvModuleWrapper
-@HEADS.register_module(force=True)
-class RetinaHead(AnchorHead):
+@HEADS.register_module()
+class JaiRetinaHead(RetinaHead):
"""An anchor-based head used in
`RetinaNet <https://arxiv.org/pdf/1708.02002.pdf>`_.
Example:
>>> import torch
- >>> self = RetinaHead(11, 7)
+ >>> self = JaiRetinaHead(11, 7)
>>> x = torch.rand(1, 7, 32, 32)
>>> cls_score, bbox_pred = self.forward_single(x)
>>> # Each anchor predicts a score for each class except background
>>> assert box_per_anchor == 4
"""
- def __init__(self,
- num_classes,
- in_channels,
- stacked_convs=4,
- conv_cfg=None,
- norm_cfg=None,
- anchor_generator=dict(
- type='AnchorGenerator',
- octave_base_scale=4,
- scales_per_octave=3,
- ratios=[0.5, 1.0, 2.0],
- strides=[8, 16, 32, 64, 128]),
- **kwargs):
- self.stacked_convs = stacked_convs
- self.conv_cfg = conv_cfg
- self.norm_cfg = norm_cfg
- super(RetinaHead, self).__init__(
- num_classes,
- in_channels,
- anchor_generator=anchor_generator,
- **kwargs)
-
def _init_layers(self):
- self.relu = nn.ReLU(inplace=True)
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels
self.cls_convs.append(
- ConvModule(
+ ConvModuleWrapper(
chn,
self.feat_channels,
3,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
self.reg_convs.append(
- ConvModule(
+ ConvModuleWrapper(
chn,
self.feat_channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
- self.retina_cls = nn.Conv2d(
+ self.retina_cls = ConvModuleWrapper(
self.feat_channels,
self.num_anchors * self.cls_out_channels,
3,
- padding=1)
- self.retina_reg = nn.Conv2d(
- self.feat_channels, self.num_anchors * 4, 3, padding=1)
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=None,
+ act_cfg=None)
+ self.retina_reg = ConvModuleWrapper(
+ self.feat_channels, self.num_anchors * 4, 3, padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=None,
+ act_cfg=None)
def init_weights(self):
for m in self.cls_convs:
if hasattr(self.retina_reg, 'weight'):
normal_init(self.retina_reg, std=0.01)
- def forward_single(self, x):
- cls_feat = x
- reg_feat = x
- for cls_conv in self.cls_convs:
- cls_feat = cls_conv(cls_feat)
- for reg_conv in self.reg_convs:
- reg_feat = reg_conv(reg_feat)
- cls_score = self.retina_cls(cls_feat)
- bbox_pred = self.retina_reg(reg_feat)
- return cls_score, bbox_pred
-@HEADS.register_module()
-class JaiRetinaHead(RetinaHead):
- def _init_layers(self):
- self.cls_convs = nn.ModuleList()
- self.reg_convs = nn.ModuleList()
- for i in range(self.stacked_convs):
- chn = self.in_channels if i == 0 else self.feat_channels
- self.cls_convs.append(
- ConvModule(
- chn,
- self.feat_channels,
- 3,
- stride=1,
- padding=1,
- conv_cfg=self.conv_cfg,
- norm_cfg=self.norm_cfg))
- self.reg_convs.append(
- ConvModule(
- chn,
- self.feat_channels,
- 3,
- stride=1,
- padding=1,
- conv_cfg=self.conv_cfg,
- norm_cfg=self.norm_cfg))
- self.retina_cls = ConvModule(
- self.feat_channels,
- self.num_anchors * self.cls_out_channels,
- 3,
- padding=1,
- conv_cfg=self.conv_cfg,
- norm_cfg=None,
- act_cfg=None)
- self.retina_reg = ConvModule(
- self.feat_channels, self.num_anchors * 4, 3, padding=1,
- conv_cfg=self.conv_cfg,
- norm_cfg=None,
- act_cfg=None)
\ No newline at end of file
index 486343f16f0b3c4bd5c8e68e64cef72ea23110c2..a449373a883f5593b15633b8609580487378f149 100644 (file)
build_bbox_coder, build_sampler, multi_apply)
from mmdet.models.builder import HEADS
from mmdet.models.losses import smooth_l1_loss
-from mmdet.models.dense_heads.anchor_head import AnchorHead
+from mmdet.models.dense_heads.ssd_head import SSDHead
+from ...ops import ConvModuleWrapper
# TODO: add loss evaluator for SSD
-@HEADS.register_module(force=True)
-class SSDHead(AnchorHead):
-
- def __init__(self,
- num_classes=80,
- in_channels=(512, 1024, 512, 256, 256, 256),
- anchor_generator=dict(
- type='SSDAnchorGenerator',
- scale_major=False,
- input_size=300,
- strides=[8, 16, 32, 64, 100, 300],
- ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]),
- basesize_ratio_range=(0.1, 0.9)),
- background_label=None,
- bbox_coder=dict(
- type='DeltaXYWHBBoxCoder',
- target_means=[.0, .0, .0, .0],
- target_stds=[1.0, 1.0, 1.0, 1.0],
- ),
- reg_decoded_bbox=False,
- train_cfg=None,
- test_cfg=None,
- conv_cfg=None):
- super(AnchorHead, self).__init__()
- self.num_classes = num_classes
- self.in_channels = in_channels
- self.cls_out_channels = num_classes + 1 # add background class
- self.anchor_generator = build_anchor_generator(anchor_generator)
- self.num_anchors = self.anchor_generator.num_base_anchors
- self.conv_cfg = conv_cfg
- self._init_layers()
- self.background_label = (
- num_classes if background_label is None else background_label)
- # background_label should be either 0 or num_classes
- assert (self.background_label == 0
- or self.background_label == num_classes)
-
- self.bbox_coder = build_bbox_coder(bbox_coder)
- self.reg_decoded_bbox = reg_decoded_bbox
- self.use_sigmoid_cls = False
- self.cls_focal_loss = False
- self.train_cfg = train_cfg
- self.test_cfg = test_cfg
- # set sampling=False for archor_target
- self.sampling = False
- if self.train_cfg:
- self.assigner = build_assigner(self.train_cfg.assigner)
- # SSD sampling=False so use PseudoSampler
- sampler_cfg = dict(type='PseudoSampler')
- self.sampler = build_sampler(sampler_cfg, context=self)
- self.fp16_enabled = False
-
- def _init_layers(self):
- reg_convs = []
- cls_convs = []
- for i in range(len(self.in_channels)):
- reg_convs.append(
- nn.Conv2d(
- self.in_channels[i],
- self.num_anchors[i] * 4,
- kernel_size=3,
- padding=1))
- cls_convs.append(
- nn.Conv2d(
- self.in_channels[i],
- self.num_anchors[i] * (self.num_classes + 1),
- kernel_size=3,
- padding=1))
- self.reg_convs = nn.ModuleList(reg_convs)
- self.cls_convs = nn.ModuleList(cls_convs)
-
- def init_weights(self):
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- xavier_init(m, distribution='uniform', bias=0)
-
- def forward(self, feats):
- cls_scores = []
- bbox_preds = []
- for feat, reg_conv, cls_conv in zip(feats, self.reg_convs,
- self.cls_convs):
- cls_scores.append(cls_conv(feat))
- bbox_preds.append(reg_conv(feat))
- return cls_scores, bbox_preds
-
- def loss_single(self, cls_score, bbox_pred, anchor, labels, label_weights,
- bbox_targets, bbox_weights, num_total_samples):
- loss_cls_all = F.cross_entropy(
- cls_score, labels, reduction='none') * label_weights
- # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
- pos_inds = ((labels >= 0) &
- (labels < self.background_label)).nonzero().reshape(-1)
- neg_inds = (labels == self.background_label).nonzero().view(-1)
-
- num_pos_samples = pos_inds.size(0)
- num_neg_samples = self.train_cfg.neg_pos_ratio * num_pos_samples
- if num_neg_samples > neg_inds.size(0):
- num_neg_samples = neg_inds.size(0)
- topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
- loss_cls_pos = loss_cls_all[pos_inds].sum()
- loss_cls_neg = topk_loss_cls_neg.sum()
- loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples
-
- if self.reg_decoded_bbox:
- bbox_pred = self.bbox_coder.decode(anchor, bbox_pred)
-
- loss_bbox = smooth_l1_loss(
- bbox_pred,
- bbox_targets,
- bbox_weights,
- beta=self.train_cfg.smoothl1_beta,
- avg_factor=num_total_samples)
- return loss_cls[None], loss_bbox
-
- def loss(self,
- cls_scores,
- bbox_preds,
- gt_bboxes,
- gt_labels,
- img_metas,
- gt_bboxes_ignore=None):
- featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
- assert len(featmap_sizes) == self.anchor_generator.num_levels
-
- device = cls_scores[0].device
-
- anchor_list, valid_flag_list = self.get_anchors(
- featmap_sizes, img_metas, device=device)
- cls_reg_targets = self.get_targets(
- anchor_list,
- valid_flag_list,
- gt_bboxes,
- img_metas,
- gt_bboxes_ignore_list=gt_bboxes_ignore,
- gt_labels_list=gt_labels,
- label_channels=1,
- unmap_outputs=False)
- if cls_reg_targets is None:
- return None
- (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
- num_total_pos, num_total_neg) = cls_reg_targets
-
- num_images = len(img_metas)
- all_cls_scores = torch.cat([
- s.permute(0, 2, 3, 1).reshape(
- num_images, -1, self.cls_out_channels) for s in cls_scores
- ], 1)
- all_labels = torch.cat(labels_list, -1).view(num_images, -1)
- all_label_weights = torch.cat(label_weights_list,
- -1).view(num_images, -1)
- all_bbox_preds = torch.cat([
- b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
- for b in bbox_preds
- ], -2)
- all_bbox_targets = torch.cat(bbox_targets_list,
- -2).view(num_images, -1, 4)
- all_bbox_weights = torch.cat(bbox_weights_list,
- -2).view(num_images, -1, 4)
-
- # concat all level anchors to a single tensor
- all_anchors = []
- for i in range(num_images):
- all_anchors.append(torch.cat(anchor_list[i]))
-
- # check NaN and Inf
- assert torch.isfinite(all_cls_scores).all().item(), \
- 'classification scores become infinite or NaN!'
- assert torch.isfinite(all_bbox_preds).all().item(), \
- 'bbox predications become infinite or NaN!'
-
- losses_cls, losses_bbox = multi_apply(
- self.loss_single,
- all_cls_scores,
- all_bbox_preds,
- all_anchors,
- all_labels,
- all_label_weights,
- all_bbox_targets,
- all_bbox_weights,
- num_total_samples=num_total_pos)
- return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
-
-
-
@HEADS.register_module()
class JaiSSDHead(SSDHead):
- def _init_layers(self):
+
+ def __init__(self, *args, **kwargs):
+ conv_cfg = kwargs.pop('conv_cfg', None) # not supported in base class - so pop it
+ super(JaiSSDHead, self).__init__(*args, **kwargs)
+ num_anchors = self.anchor_generator.num_base_anchors
+
reg_convs = []
cls_convs = []
for i in range(len(self.in_channels)):
reg_convs.append(
- ConvModule(
+ ConvModuleWrapper(
self.in_channels[i],
- self.num_anchors[i] * 4,
+ num_anchors[i] * 4,
kernel_size=3,
padding=1,
- conv_cfg=self.conv_cfg,
+ conv_cfg=conv_cfg,
norm_cfg=None,
act_cfg=None))
cls_convs.append(
- ConvModule(
+ ConvModuleWrapper(
self.in_channels[i],
- self.num_anchors[i] * (self.num_classes + 1),
+ num_anchors[i] * (self.num_classes + 1),
kernel_size=3,
padding=1,
- conv_cfg=self.conv_cfg,
+ conv_cfg=conv_cfg,
norm_cfg=None,
act_cfg=None))
self.reg_convs = nn.ModuleList(reg_convs)
- self.cls_convs = nn.ModuleList(cls_convs)
\ No newline at end of file
+ self.cls_convs = nn.ModuleList(cls_convs)
index 36402c815998dd28356dfea829a09a9d0049e1cf..134b35bd7ee49095361e0449a54a97bd0bb27d8f 100644 (file)
from mmdet.models.necks import *
-from .fpn import InLoopFPN
+from .fpn import JaiFPN, JaiInLoopFPN
index f773a01c569135e98d1a784078d13317f683ef16..8931f0b02b96b5de5829f81850b34e6fe7463ccf 100644 (file)
from mmdet.core import auto_fp16
from mmdet.models.builder import NECKS
+from ...ops import ConvModuleWrapper
from pytorch_jacinto_ai import xnn
-@NECKS.register_module(force=True)
-class FPN(nn.Module):
+@NECKS.register_module()
+class JaiFPN(nn.Module):
"""
Feature Pyramid Network.
norm_cfg=None,
act_cfg=None,
upsample_cfg=dict(mode='nearest')):
- super(FPN, self).__init__()
+ super(JaiFPN, self).__init__()
assert isinstance(in_channels, list)
self.in_channels = in_channels
self.out_channels = out_channels
self.fpn_convs = nn.ModuleList()
for i in range(self.start_level, self.backbone_end_level):
- l_conv = ConvModule(
+ l_conv = ConvModuleWrapper(
in_channels[i],
out_channels,
1,
norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
act_cfg=act_cfg,
inplace=False)
- fpn_conv = ConvModule(
+ fpn_conv = ConvModuleWrapper(
out_channels,
out_channels,
3,
in_channels = self.in_channels[self.backbone_end_level - 1]
else:
in_channels = out_channels
- extra_fpn_conv = ConvModule(
+ extra_fpn_conv = ConvModuleWrapper(
in_channels,
out_channels,
3,
@NECKS.register_module()
-class InLoopFPN(FPN):
+class JaiInLoopFPN(JaiFPN):
@auto_fp16()
def forward(self, inputs):
assert len(inputs) == len(self.in_channels)
index 0a13ee2dfaec8888d05533a78356bc378efa04fe..8a215f2046be9236cb20145afca55463f1f6d50a 100644 (file)
return y
+def ConvModuleWrapper(*args, **kwargs):
+ conv_cfg = kwargs.get('conv_cfg', dict(type=None))
+ has_type = conv_cfg and ('type' in conv_cfg)
+
+ kernel_size = kwargs.get('kernel_size', None)
+ kernel_size = kernel_size or args[2]
+ assert kernel_size is not None, 'kernel_size must be specified'
+
+ if not has_type:
+ return mmcv.cnn.ConvModule(*args, **kwargs)
+ elif conv_cfg.type == 'ConvNormAct' or (conv_cfg.type == 'ConvDWSep' and kernel_size == 1):
+ return ConvNormAct2d(*args, **kwargs)
+ elif conv_cfg.type == 'ConvDWSep':
+ return ConvDWSep2d(*args, **kwargs)
+ elif conv_cfg.type == 'ConvDWTriplet':
+ return ConvDWTriplet2d(*args, **kwargs)
+ else:
+ return mmcv.cnn.ConvModule(*args, **kwargs)
-################################################################################
-if not hasattr(mmcv.cnn, '_ConvModule'):
- # first, backup the original ConvModule. this will be called inside the wrapper
- mmcv.cnn._ConvModule = mmcv.cnn.ConvModule
-
-
- def ConvModuleWrapper(*args, **kwargs):
- conv_cfg = kwargs.get('conv_cfg', dict(type=None))
- has_type = conv_cfg and ('type' in conv_cfg)
-
- kernel_size = kwargs.get('kernel_size', None)
- kernel_size = kernel_size or args[2]
- assert kernel_size is not None, 'kernel_size must be specified'
-
- if not has_type:
- return mmcv.cnn._ConvModule(*args, **kwargs)
- elif conv_cfg.type == 'ConvNormAct' or (conv_cfg.type == 'ConvDWSep' and kernel_size == 1):
- return ConvNormAct2d(*args, **kwargs)
- elif conv_cfg.type == 'ConvDWSep':
- return ConvDWSep2d(*args, **kwargs)
- elif conv_cfg.type == 'ConvDWTriplet':
- return ConvDWTriplet2d(*args, **kwargs)
- else:
- return mmcv.cnn._ConvModule(*args, **kwargs)
-
- # finally, replace the original ConvModule with ConvModuleWrapper
- mmcv.cnn.ConvModule = ConvModuleWrapper
-#
\ No newline at end of file