aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorManu Mathew2021-01-23 03:09:38 -0600
committerManu Mathew2021-01-23 03:24:51 -0600
commit9588c89b29fd8fcd862c76b964833bdd553d56c6 (patch)
tree067409493e892a2cb45cd33dbd43787c5d897bbe
parent7012ecf747c74a4600329bfc263536bbcbcb2175 (diff)
downloadpytorch-mmdetection-9588c89b29fd8fcd862c76b964833bdd553d56c6.tar.gz
pytorch-mmdetection-9588c89b29fd8fcd862c76b964833bdd553d56c6.tar.xz
pytorch-mmdetection-9588c89b29fd8fcd862c76b964833bdd553d56c6.zip
yolov3 lite version has been added
-rw-r--r--scripts/detection_configs.py1
-rw-r--r--xmmdet/models/backbones/__init__.py2
-rw-r--r--xmmdet/models/backbones/darknet.py199
-rw-r--r--xmmdet/models/dense_heads/__init__.py1
-rw-r--r--xmmdet/models/dense_heads/yolo_head.py52
-rw-r--r--xmmdet/models/necks/__init__.py2
-rw-r--r--xmmdet/models/necks/yolo_neck.py142
-rw-r--r--xmmdet/ops/conv_wrapper.py2
8 files changed, 396 insertions, 5 deletions
diff --git a/scripts/detection_configs.py b/scripts/detection_configs.py
index 2a330685..0eeda542 100644
--- a/scripts/detection_configs.py
+++ b/scripts/detection_configs.py
@@ -16,6 +16,7 @@ config='./configs/retinanet/retinanet-lite_regnet_fpn_bgr.py'
16 16
17config='./configs/yolo/yolov3_d53.py' 17config='./configs/yolo/yolov3_d53.py'
18config='./configs/yolo/yolov3_d53_relu.py' 18config='./configs/yolo/yolov3_d53_relu.py'
19config='./configs/yolo/yolov3-lite_d53.py'
19''' 20'''
20 21
21config='./configs/ssd/ssd-lite_regnet_fpn_bgr.py' 22config='./configs/ssd/ssd-lite_regnet_fpn_bgr.py'
diff --git a/xmmdet/models/backbones/__init__.py b/xmmdet/models/backbones/__init__.py
index 9921d2f7..ccb94a12 100644
--- a/xmmdet/models/backbones/__init__.py
+++ b/xmmdet/models/backbones/__init__.py
@@ -4,3 +4,5 @@ from .mobilenetv2 import MobileNetV2
4from .resnet import ResNet 4from .resnet import ResNet
5from .regnet import RegNet 5from .regnet import RegNet
6from .resnext import ResNeXt 6from .resnext import ResNeXt
7from .darknet import DarknetLite
8
diff --git a/xmmdet/models/backbones/darknet.py b/xmmdet/models/backbones/darknet.py
new file mode 100644
index 00000000..a9a5667b
--- /dev/null
+++ b/xmmdet/models/backbones/darknet.py
@@ -0,0 +1,199 @@
1# Copyright (c) 2019 Western Digital Corporation or its affiliates.
2
3import logging
4
5from mmdet.models.backbones.darknet import *
6
7import torch.nn as nn
8from mmcv.cnn import ConvModule, constant_init, kaiming_init
9from mmcv.runner import load_checkpoint
10from torch.nn.modules.batchnorm import _BatchNorm
11
12from ...ops import ConvModuleWrapper
13
class ResBlockLite(nn.Module):
    """Residual block of the lite Darknet backbone.

    Two conv modules are applied one after the other and the block input is
    added to their result. The first module uses a 1x1 kernel and halves the
    channel count; the second uses a 3x3 kernel (padding 1) and restores it.
    Both modules are created through ``ConvModuleWrapper`` with the same
    conv / norm / activation configuration.

    Args:
        in_channels (int): Number of input (and output) channels. Must be
            even.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN', requires_grad=True)
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
    """

    def __init__(self,
                 in_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='ReLU')):
        super().__init__()
        # The bottleneck is exactly half as wide as the block itself.
        assert in_channels % 2 == 0
        bottleneck_channels = in_channels // 2

        # Layer settings shared by both conv modules.
        shared_cfg = dict(
            conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

        self.conv1 = ConvModuleWrapper(
            in_channels, bottleneck_channels, 1, **shared_cfg)
        self.conv2 = ConvModuleWrapper(
            bottleneck_channels, in_channels, 3, padding=1, **shared_cfg)

    def forward(self, x):
        """Return ``conv2(conv1(x)) + x``."""
        return self.conv2(self.conv1(x)) + x
54
55
@BACKBONES.register_module()
class DarknetLite(nn.Module):
    """DarknetLite backbone.

    The network is a 3x3 stem conv (``conv1``) followed by one conv-res block
    per entry of ``arch_settings[depth]``. Each conv-res block is a stride-2
    3x3 conv module followed by a number of ``ResBlockLite`` units. The stem
    is always a plain ``ConvModule`` built with ``conv_cfg=None``; all other
    layers are created through ``ConvModuleWrapper`` and receive ``conv_cfg``.

    Args:
        depth (int): Depth of Darknet. Currently only support 53.
        out_indices (Sequence[int]): Output from which stages. Index 0 is the
            stem ``conv1``; index ``i >= 1`` is ``conv_res_block{i}``.
            Default: (3, 4, 5).
        frozen_stages (int): Number of leading stages (counted from the stem)
            to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Default: -1.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN', requires_grad=True)
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: True.

    Example:
        >>> from xmmdet.models.backbones import DarknetLite
        >>> import torch
        >>> model = DarknetLite(depth=53)
        >>> _ = model.eval()
        >>> inputs = torch.rand(1, 3, 416, 416)
        >>> level_outputs = model.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        ...
        (1, 256, 52, 52)
        (1, 512, 26, 26)
        (1, 1024, 13, 13)
    """

    # Dict(depth: (layers, channels))
    arch_settings = {
        53: ((1, 2, 8, 8, 4), ((32, 64), (64, 128), (128, 256), (256, 512),
                               (512, 1024)))
    }

    def __init__(self,
                 depth=53,
                 out_indices=(3, 4, 5),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='ReLU'),
                 norm_eval=True):
        super(DarknetLite, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for darknet')
        self.depth = depth
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.layers, self.channels = self.arch_settings[depth]

        cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        # The stem deliberately drops conv_cfg: it is always a plain
        # ConvModule, whatever convolution type the other layers use.
        cfg1 = dict(conv_cfg=None, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.conv1 = ConvModule(3, 32, 3, padding=1, **cfg1)

        # Ordered names of all stages; forward() and _freeze_stages() index
        # into this list, so position 0 is the stem.
        self.cr_blocks = ['conv1']
        for i, n_layers in enumerate(self.layers):
            layer_name = f'conv_res_block{i + 1}'
            in_c, out_c = self.channels[i]
            self.add_module(
                layer_name,
                self.make_conv_res_block(in_c, out_c, n_layers, **cfg))
            self.cr_blocks.append(layer_name)

        self.norm_eval = norm_eval

    def forward(self, x):
        """Run all stages in order and collect the requested outputs.

        Args:
            x (Tensor): Input passed to the stem ``conv1``.

        Returns:
            tuple: Outputs of the stages whose index is in ``out_indices``,
            in stage order.
        """
        outs = []
        for i, layer_name in enumerate(self.cr_blocks):
            cr_block = getattr(self, layer_name)
            x = cr_block(x)
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)

    def init_weights(self, pretrained=None):
        """Load a checkpoint or initialize the weights from scratch.

        Args:
            pretrained (str, optional): Checkpoint to load (non-strict).
                If None, Conv2d layers get ``kaiming_init`` and norm layers
                get ``constant_init(m, 1)``. Default: None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

        else:
            raise TypeError('pretrained must be a str or None')

    def _freeze_stages(self):
        """Put the first ``frozen_stages`` stages in eval mode, without grad."""
        if self.frozen_stages >= 0:
            for i in range(self.frozen_stages):
                m = getattr(self, self.cr_blocks[i])
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        """Set the training mode while keeping frozen parts frozen.

        After switching the mode, the frozen stages are put back in eval
        mode, and when ``norm_eval`` is set every batch-norm layer is kept in
        eval mode during training.

        Args:
            mode (bool): True for training mode, False for evaluation mode.
                Default: True.

        Returns:
            DarknetLite: ``self``, as ``nn.Module.train`` does, so that
            ``model.train()`` and ``model.eval()`` can be chained or assigned.
        """
        super(DarknetLite, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
        # nn.Module.eval() is implemented as ``return self.train(False)``;
        # without this return both train() and eval() would yield None.
        return self

    @staticmethod
    def make_conv_res_block(in_channels,
                            out_channels,
                            res_repeat,
                            conv_cfg=None,
                            norm_cfg=dict(type='BN', requires_grad=True),
                            act_cfg=dict(type='ReLU')):
        """In Darknet backbone, ConvLayer is usually followed by ResBlock. This
        function will make that. The Conv layers always have 3x3 filters with
        stride=2. The number of the filters in Conv layer is the same as the
        out channels of the ResBlock.

        Args:
            in_channels (int): The number of input channels.
            out_channels (int): The number of output channels.
            res_repeat (int): The number of ResBlocks.
            conv_cfg (dict): Config dict for convolution layer. Default: None.
            norm_cfg (dict): Dictionary to construct and config norm layer.
                Default: dict(type='BN', requires_grad=True)
            act_cfg (dict): Config dict for activation layer.
                Default: dict(type='ReLU').

        Returns:
            nn.Sequential: Module ``conv`` followed by ``res0`` ..
            ``res{res_repeat - 1}``.
        """

        cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

        model = nn.Sequential()
        model.add_module(
            'conv',
            ConvModuleWrapper(
                in_channels, out_channels, 3, stride=2, padding=1, **cfg))
        for idx in range(res_repeat):
            model.add_module('res{}'.format(idx),
                             ResBlockLite(out_channels, **cfg))
        return model
diff --git a/xmmdet/models/dense_heads/__init__.py b/xmmdet/models/dense_heads/__init__.py
index b9f02773..f1704a27 100644
--- a/xmmdet/models/dense_heads/__init__.py
+++ b/xmmdet/models/dense_heads/__init__.py
@@ -2,3 +2,4 @@ from mmdet.models.dense_heads import *
2from .ssd_head import SSDLiteHead 2from .ssd_head import SSDLiteHead
3from .retina_head import RetinaLiteHead 3from .retina_head import RetinaLiteHead
4from .fcos_head import FCOSLiteHead 4from .fcos_head import FCOSLiteHead
5from .yolo_head import YOLOV3LiteHead
diff --git a/xmmdet/models/dense_heads/yolo_head.py b/xmmdet/models/dense_heads/yolo_head.py
new file mode 100644
index 00000000..ffeecebf
--- /dev/null
+++ b/xmmdet/models/dense_heads/yolo_head.py
@@ -0,0 +1,52 @@
1# Copyright (c) 2019 Western Digital Corporation or its affiliates.
2
3import warnings
4
5import torch
6import torch.nn as nn
7import torch.nn.functional as F
8from mmcv.cnn import ConvModule, normal_init
9from mmcv.runner import force_fp32
10from mmcv.cnn import constant_init, kaiming_init
11
12from mmdet.core import (build_anchor_generator, build_assigner,
13 build_bbox_coder, build_sampler, images_to_levels,
14 multi_apply, multiclass_nms)
15from mmdet.models.builder import HEADS
16from mmdet.models.losses import smooth_l1_loss
17from mmdet.models.dense_heads.yolo_head import YOLOV3Head
18from ...ops import ConvModuleWrapper
19
20
@HEADS.register_module()
class YOLOV3LiteHead(YOLOV3Head):
    """YOLOV3 head whose bridge convs are built by ``ConvModuleWrapper``.

    Paper link: https://arxiv.org/abs/1804.02767.

    Only layer construction and weight initialization are overridden here;
    everything else is inherited from ``YOLOV3Head``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _init_layers(self):
        """Create one bridge conv module and one prediction conv per level."""
        bridges = []
        predictors = []
        # Bridge and predictor of a level are created back to back so that
        # the layers are constructed in level order.
        for level in range(self.num_levels):
            bridges.append(
                ConvModuleWrapper(
                    self.in_channels[level],
                    self.out_channels[level],
                    3,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
            predictors.append(
                nn.Conv2d(self.out_channels[level],
                          self.num_anchors * self.num_attrib, 1))

        self.convs_bridge = nn.ModuleList(bridges)
        self.convs_pred = nn.ModuleList(predictors)

    def init_weights(self):
        """Initialize weights of the head."""
        norm_types = (nn.BatchNorm2d, nn.GroupNorm)
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kaiming_init(module)
                continue
            if isinstance(module, norm_types):
                constant_init(module, 1)
diff --git a/xmmdet/models/necks/__init__.py b/xmmdet/models/necks/__init__.py
index b098fd74..0b25f7ab 100644
--- a/xmmdet/models/necks/__init__.py
+++ b/xmmdet/models/necks/__init__.py
@@ -1,4 +1,4 @@
1from mmdet.models.necks import * 1from mmdet.models.necks import *
2from .fpn import FPNLite, BiFPNLite 2from .fpn import FPNLite, BiFPNLite
3from .yolo_neck import YOLOV3Neck 3from .yolo_neck import YOLOV3Neck, YOLOV3LiteNeck
4 4
diff --git a/xmmdet/models/necks/yolo_neck.py b/xmmdet/models/necks/yolo_neck.py
index 890432d7..ea73f8d1 100644
--- a/xmmdet/models/necks/yolo_neck.py
+++ b/xmmdet/models/necks/yolo_neck.py
@@ -4,9 +4,11 @@ import torch
4import torch.nn as nn 4import torch.nn as nn
5import torch.nn.functional as F 5import torch.nn.functional as F
6from mmcv.cnn import ConvModule 6from mmcv.cnn import ConvModule
7from mmcv.cnn import constant_init, kaiming_init
7 8
8from mmdet.models.builder import NECKS 9from mmdet.models.builder import NECKS
9from mmdet.models.necks.yolo_neck import DetectionBlock 10from mmdet.models.necks.yolo_neck import DetectionBlock
11from ...ops import ConvModuleWrapper
10 12
11from pytorch_jacinto_ai import xnn 13from pytorch_jacinto_ai import xnn
12 14
@@ -84,6 +86,140 @@ class YOLOV3Neck(nn.Module):
84 return tuple(outs) 86 return tuple(outs)
85 87
86 def init_weights(self): 88 def init_weights(self):
87 """Initialize the weights of module.""" 89 """Initialize weights of the head."""
88 # init is done in ConvModule 90 for m in self.modules():
89 pass 91 if isinstance(m, nn.Conv2d):
92 kaiming_init(m)
93 elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
94 constant_init(m, 1)
95
96
class DetectionLiteBlock(nn.Module):
    """Lite detection block used in the YOLO neck.

    Let out_channels = n, the block is a chain of five conv modules:
        1x1xn, 3x3x2n, 1x1xn, 3x3x2n, 1x1xn.
    The input channel is arbitrary (in_channels). All five modules are
    created through ``ConvModuleWrapper`` with the same configuration.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN', requires_grad=True)
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='ReLU')):
        super().__init__()
        wide_channels = out_channels * 2

        # Layer settings shared by all five conv modules.
        shared_cfg = dict(
            conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

        # Alternate 1x1 (narrow) and 3x3 (wide) modules, ending narrow.
        self.conv1 = ConvModuleWrapper(
            in_channels, out_channels, 1, **shared_cfg)
        self.conv2 = ConvModuleWrapper(
            out_channels, wide_channels, 3, padding=1, **shared_cfg)
        self.conv3 = ConvModuleWrapper(
            wide_channels, out_channels, 1, **shared_cfg)
        self.conv4 = ConvModuleWrapper(
            out_channels, wide_channels, 3, padding=1, **shared_cfg)
        self.conv5 = ConvModuleWrapper(
            wide_channels, out_channels, 1, **shared_cfg)

    def forward(self, x):
        """Apply ``conv1`` .. ``conv5`` in order and return the result."""
        feat = x
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4,
                      self.conv5):
            feat = layer(feat)
        return feat
144
145
@NECKS.register_module(force=True)
class YOLOV3LiteNeck(nn.Module):
    """The lite neck of YOLOV3.

    It can be treated as a simplified version of FPN. It takes the feature
    maps of the backbone, and walks from the last one to the first one,
    resizing by a factor of 2 and concatenating along the way.

    Note:
        The last element of the input feats is processed first (by
        ``detect1``); the remaining ones follow in reversed order. The
        outputs are returned in that processing order.

    Args:
        num_scales (int): The number of scales / stages.
        in_channels (Sequence[int]): Input channels per scale; its length
            must equal ``num_scales``.
        out_channels (Sequence[int]): Output channels per scale; its length
            must equal ``num_scales``.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN', requires_grad=True)
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
    """

    def __init__(self,
                 num_scales,
                 in_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='ReLU')):
        super().__init__()
        assert (num_scales == len(in_channels) == len(out_channels))
        self.num_scales = num_scales
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Layer settings shared by every sub-module of the neck.
        shared_cfg = dict(
            conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

        # Modules are registered under numbered names (conv{i}, detect{i+1})
        # so that an arbitrary number of scales is supported.
        self.detect1 = DetectionLiteBlock(
            in_channels[0], out_channels[0], **shared_cfg)
        for scale in range(1, self.num_scales):
            scale_in = self.in_channels[scale]
            scale_out = self.out_channels[scale]
            self.add_module(
                f'conv{scale}',
                ConvModuleWrapper(scale_in, scale_out, 1, **shared_cfg))
            # scale_in + scale_out: the resized conv output is concatenated
            # with a lower-level feature before the detection block.
            self.add_module(
                f'detect{scale + 1}',
                DetectionLiteBlock(
                    scale_in + scale_out, scale_out, **shared_cfg))

    def forward(self, feats):
        """Fuse the feature maps, starting from the last one.

        Args:
            feats (Sequence[Tensor]): ``num_scales`` feature maps.

        Returns:
            tuple: One output per scale, in processing order.
        """
        assert len(feats) == self.num_scales

        results = [self.detect1(feats[-1])]
        for step, lateral in enumerate(reversed(feats[:-1]), start=1):
            reduced = getattr(self, f'conv{step}')(results[-1])

            # Resize by 2, then concatenate with the lower-level feature
            # along the channel dimension.
            upsampled = xnn.layers.resize_with(reduced, scale_factor=2)
            merged = torch.cat((upsampled, lateral), 1)

            results.append(getattr(self, f'detect{step + 1}')(merged))

        return tuple(results)

    def init_weights(self):
        """Initialize weights of the neck."""
        norm_types = (nn.BatchNorm2d, nn.GroupNorm)
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kaiming_init(module)
                continue
            if isinstance(module, norm_types):
                constant_init(module, 1)
diff --git a/xmmdet/ops/conv_wrapper.py b/xmmdet/ops/conv_wrapper.py
index 14cb04f7..a08520a9 100644
--- a/xmmdet/ops/conv_wrapper.py
+++ b/xmmdet/ops/conv_wrapper.py
@@ -90,7 +90,7 @@ def ConvModuleWrapper(*args, **kwargs):
90 kernel_size = kernel_size or args[2] 90 kernel_size = kernel_size or args[2]
91 assert kernel_size is not None, 'kernel_size must be specified' 91 assert kernel_size is not None, 'kernel_size must be specified'
92 92
93 is_dw_conv = conv_cfg is not None and conv_cfg.type in ('ConvDWSep', 'ConvDWTripletRes') 93 is_dw_conv = conv_cfg is not None and conv_cfg.type in ('ConvDWSep', 'ConvDWTripletRes', 'ConvDWTripletAlwaysRes')
94 if not has_type: 94 if not has_type:
95 return mmcv.cnn.ConvModule(*args, **kwargs) 95 return mmcv.cnn.ConvModule(*args, **kwargs)
96 elif conv_cfg.type == 'ConvNormAct' or (is_dw_conv and kernel_size == 1): 96 elif conv_cfg.type == 'ConvNormAct' or (is_dw_conv and kernel_size == 1):