-rw-r--r--  README.md                                   4
-rw-r--r--  tools/test.py                              10
-rw-r--r--  xmmdet/__init__.py                          6
-rw-r--r--  xmmdet/models/__init__.py                   3
-rw-r--r--  xmmdet/models/backbones/resnet.py           4
-rw-r--r--  xmmdet/models/dense_heads/fcos_head.py      2
-rw-r--r--  xmmdet/models/dense_heads/retina_head.py    2
-rw-r--r--  xmmdet/models/dense_heads/ssd_head.py       2
-rw-r--r--  xmmdet/models/necks/fpn.py                  2
-rw-r--r--  xmmdet/utils/__init__.py                    4
-rw-r--r--  xmmdet/utils/flops_counter.py             444
11 files changed, 470 insertions, 13 deletions
diff --git a/README.md b/README.md
index 5066e034..61b2817f 100644
--- a/README.md
+++ b/README.md
@@ -15,7 +15,9 @@ This repository is released under the following [LICENSE](./LICENSE).
 
 ## Installation
 
 Please refer to [mmdetection install.md](https://github.com/open-mmlab/mmdetection/docs/install.md) for installation and dataset preparation.
+
+We used version **v2.1.0** of mmdetection to test our changes. If you run into issues with the master branch of mmdetection, try checking out that tag.
 
 After installing mmdetection, please install [PyTorch-Jacinto-AI-DevKit](https://bitbucket.itg.ti.com/projects/JACINTO-AI/repos/pytorch-jacinto-ai-devkit/browse/) as our repository uses several components from there - especially to define low complexity models and to Quantization Aware Training (QAT).
 
diff --git a/tools/test.py b/tools/test.py
index e5c36446..803e3860 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -8,11 +8,11 @@ from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 from mmcv.runner import get_dist_info, init_dist, load_checkpoint
 from tools.fuse_conv_bn import fuse_module
 
-from mmdet.apis import multi_gpu_test, single_gpu_test
-from mmdet.core import wrap_fp16_model
-from mmdet.datasets import build_dataloader, build_dataset
-from mmdet.models import build_detector
-from mmdet.utils import MMDetQuantTestModule, save_model_proto, mmdet_load_checkpoint
+from xmmdet.apis import multi_gpu_test, single_gpu_test
+from xmmdet.core import wrap_fp16_model
+from xmmdet.datasets import build_dataloader, build_dataset
+from xmmdet.models import build_detector
+from xmmdet.utils import MMDetQuantTestModule, save_model_proto, mmdet_load_checkpoint
 
 from pytorch_jacinto_ai import xnn
 
diff --git a/xmmdet/__init__.py b/xmmdet/__init__.py
index f9d66cc7..325fb75f 100644
--- a/xmmdet/__init__.py
+++ b/xmmdet/__init__.py
@@ -1,2 +1,8 @@
 from mmdet import *
+from .ops import *
+from .core import *
+from .datasets import *
+from .models import *
+from .utils import *
+from .apis import *
 
diff --git a/xmmdet/models/__init__.py b/xmmdet/models/__init__.py
index 608be121..718b897e 100644
--- a/xmmdet/models/__init__.py
+++ b/xmmdet/models/__init__.py
@@ -1 +1,4 @@
 from mmdet.models import *
+from .backbones import *
+from .dense_heads import *
+from .necks import *
diff --git a/xmmdet/models/backbones/resnet.py b/xmmdet/models/backbones/resnet.py
index 08a26836..5048a4e9 100644
--- a/xmmdet/models/backbones/resnet.py
+++ b/xmmdet/models/backbones/resnet.py
@@ -290,7 +290,7 @@ class Bottleneck(nn.Module):
         return out
 
 
-@BACKBONES.register_module()
+@BACKBONES.register_module(force=True)
 class ResNet(nn.Module):
     """ResNet backbone.
 
@@ -619,7 +619,7 @@ class ResNet(nn.Module):
                     m.eval()
 
 
-@BACKBONES.register_module()
+@BACKBONES.register_module(force=True)
 class ResNetV1d(ResNet):
     """ResNetV1d variant described in
     `Bag of Tricks <https://arxiv.org/pdf/1812.01187.pdf>`_.
diff --git a/xmmdet/models/dense_heads/fcos_head.py b/xmmdet/models/dense_heads/fcos_head.py
index 1961245c..27e15c3d 100644
--- a/xmmdet/models/dense_heads/fcos_head.py
+++ b/xmmdet/models/dense_heads/fcos_head.py
@@ -8,7 +8,7 @@ from mmdet.models.builder import HEADS, build_loss
 INF = 1e8
 
 
-@HEADS.register_module()
+@HEADS.register_module(force=True)
 class FCOSHead(nn.Module):
     """Anchor-free head used in `FCOS <https://arxiv.org/abs/1904.01355>`_.
 
diff --git a/xmmdet/models/dense_heads/retina_head.py b/xmmdet/models/dense_heads/retina_head.py
index 68232c45..27031acd 100644
--- a/xmmdet/models/dense_heads/retina_head.py
+++ b/xmmdet/models/dense_heads/retina_head.py
@@ -5,7 +5,7 @@ from mmdet.models.builder import HEADS
 from mmdet.models.dense_heads.anchor_head import AnchorHead
 
 
-@HEADS.register_module()
+@HEADS.register_module(force=True)
 class RetinaHead(AnchorHead):
     """An anchor-based head used in
     `RetinaNet <https://arxiv.org/pdf/1708.02002.pdf>`_.
diff --git a/xmmdet/models/dense_heads/ssd_head.py b/xmmdet/models/dense_heads/ssd_head.py
index d5f4658e..486343f1 100644
--- a/xmmdet/models/dense_heads/ssd_head.py
+++ b/xmmdet/models/dense_heads/ssd_head.py
@@ -11,7 +11,7 @@ from mmdet.models.dense_heads.anchor_head import AnchorHead
 
 
 # TODO: add loss evaluator for SSD
-@HEADS.register_module()
+@HEADS.register_module(force=True)
 class SSDHead(AnchorHead):
 
     def __init__(self,
diff --git a/xmmdet/models/necks/fpn.py b/xmmdet/models/necks/fpn.py
index c55ef05e..f773a01c 100644
--- a/xmmdet/models/necks/fpn.py
+++ b/xmmdet/models/necks/fpn.py
@@ -8,7 +8,7 @@ from mmdet.models.builder import NECKS
 from pytorch_jacinto_ai import xnn
 
 
-@NECKS.register_module()
+@NECKS.register_module(force=True)
 class FPN(nn.Module):
     """
     Feature Pyramid Network.
diff --git a/xmmdet/utils/__init__.py b/xmmdet/utils/__init__.py
index 062baecd..2ce3a842 100644
--- a/xmmdet/utils/__init__.py
+++ b/xmmdet/utils/__init__.py
@@ -1,6 +1,8 @@
 from mmdet.utils import *
+from .flops_counter import get_model_complexity_info
 from .logger import LoggerStream, get_root_logger
-from .runner import MMDetRunner, MMDetNoOptimizerHook
+from .runner import MMDetRunner, MMDetNoOptimizerHook, \
+    mmdet_load_checkpoint, mmdet_save_checkpoint
 from .save_model import save_model_proto
 from .quantize import MMDetQuantTrainModule, MMDetQuantCalibrateModule, \
     MMDetQuantTestModule, is_mmdet_quant_module
diff --git a/xmmdet/utils/flops_counter.py b/xmmdet/utils/flops_counter.py
new file mode 100644
index 00000000..04f27f8b
--- /dev/null
+++ b/xmmdet/utils/flops_counter.py
@@ -0,0 +1,444 @@
+# Modified from flops-counter.pytorch by Vladislav Sovrasov
+# original repo: https://github.com/sovrasov/flops-counter.pytorch
+
+# MIT License
+
+# Copyright (c) 2018 Vladislav Sovrasov
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import sys
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin
+from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd,
+                                      _AvgPoolNd, _MaxPoolNd)
+
+
+def get_model_complexity_info(model,
+                              input_res,
+                              print_per_layer_stat=True,
+                              as_strings=True,
+                              input_constructor=None,
+                              ost=sys.stdout):
+    assert type(input_res) is tuple
+    assert len(input_res) >= 2
+    flops_model = add_flops_counting_methods(model)
+    flops_model.eval().start_flops_count()
+    if input_constructor:
+        input = input_constructor(input_res)
+        _ = flops_model(**input)
+    else:
+        batch = torch.ones(()).new_empty(
+            (1, *input_res),
+            dtype=next(flops_model.parameters()).dtype,
+            device=next(flops_model.parameters()).device)
+        flops_model.forward_dummy(batch)
+
+    if print_per_layer_stat:
+        print_model_with_flops(flops_model, ost=ost)
+    flops_count = flops_model.compute_average_flops_cost()
+    params_count = get_model_parameters_number(flops_model)
+    flops_model.stop_flops_count()
+
+    if as_strings:
+        return flops_to_string(flops_count), params_to_string(params_count)
+
+    return flops_count, params_count
+
+
+def flops_to_string(flops, units='GMac', precision=2):
+    if units is None:
+        if flops // 10**9 > 0:
+            return str(round(flops / 10.**9, precision)) + ' GMac'
+        elif flops // 10**6 > 0:
+            return str(round(flops / 10.**6, precision)) + ' MMac'
+        elif flops // 10**3 > 0:
+            return str(round(flops / 10.**3, precision)) + ' KMac'
+        else:
+            return str(flops) + ' Mac'
+    else:
+        if units == 'GMac':
+            return str(round(flops / 10.**9, precision)) + ' ' + units
+        elif units == 'MMac':
+            return str(round(flops / 10.**6, precision)) + ' ' + units
+        elif units == 'KMac':
+            return str(round(flops / 10.**3, precision)) + ' ' + units
+        else:
+            return str(flops) + ' Mac'
+
+
+def params_to_string(params_num):
+    """Convert a parameter count to a readable string.
+
+    :param float params_num: number of parameters
+    :returns str: formatted parameter count
+
+    >>> params_to_string(1e9)
+    '1000.0 M'
+    >>> params_to_string(2e5)
+    '200.0 k'
+    >>> params_to_string(3e-9)
+    '3e-09'
+    """
+    if params_num // 10**6 > 0:
+        return str(round(params_num / 10**6, 2)) + ' M'
+    elif params_num // 10**3:
+        return str(round(params_num / 10**3, 2)) + ' k'
+    else:
+        return str(params_num)
+
+
+def print_model_with_flops(model, units='GMac', precision=3, ost=sys.stdout):
+    total_flops = model.compute_average_flops_cost()
+
+    def accumulate_flops(self):
+        if is_supported_instance(self):
+            return self.__flops__ / (model.__batch_counter__ or 1)
+        else:
+            sum = 0
+            for m in self.children():
+                sum += m.accumulate_flops()
+            return sum
+
+    def flops_repr(self):
+        accumulated_flops_cost = self.accumulate_flops()
+        return ', '.join([
+            flops_to_string(
+                accumulated_flops_cost, units=units, precision=precision),
+            f'{accumulated_flops_cost / total_flops:.3%} MACs',
+            self.original_extra_repr()
+        ])
+
+    def add_extra_repr(m):
+        m.accumulate_flops = accumulate_flops.__get__(m)
+        flops_extra_repr = flops_repr.__get__(m)
+        if m.extra_repr != flops_extra_repr:
+            m.original_extra_repr = m.extra_repr
+            m.extra_repr = flops_extra_repr
+            assert m.extra_repr != m.original_extra_repr
+
+    def del_extra_repr(m):
+        if hasattr(m, 'original_extra_repr'):
+            m.extra_repr = m.original_extra_repr
+            del m.original_extra_repr
+        if hasattr(m, 'accumulate_flops'):
+            del m.accumulate_flops
+
+    model.apply(add_extra_repr)
+    print(model, file=ost)
+    model.apply(del_extra_repr)
+
+
+def get_model_parameters_number(model):
+    params_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    return params_num
+
+
+def add_flops_counting_methods(net_main_module):
+    # adding additional methods to the existing module object,
+    # this is done this way so that each function has access to self object
+    net_main_module.start_flops_count = start_flops_count.__get__(
+        net_main_module)
+    net_main_module.stop_flops_count = stop_flops_count.__get__(
+        net_main_module)
+    net_main_module.reset_flops_count = reset_flops_count.__get__(
+        net_main_module)
+    net_main_module.compute_average_flops_cost = \
+        compute_average_flops_cost.__get__(net_main_module)
+
+    net_main_module.reset_flops_count()
+
+    # Adding variables necessary for masked flops computation
+    net_main_module.apply(add_flops_mask_variable_or_reset)
+
+    return net_main_module
+
+
+def compute_average_flops_cost(self):
+    """
+    A method that will be available after add_flops_counting_methods() is
+    called on a desired net object.
+    Returns current mean flops consumption per image.
+    """
+
+    batches_count = (self.__batch_counter__ or 1)
+    flops_sum = 0
+    for module in self.modules():
+        if is_supported_instance(module):
+            flops_sum += module.__flops__
+
+    return flops_sum / batches_count
+
+
+def start_flops_count(self):
+    """
+    A method that will be available after add_flops_counting_methods() is
+    called on a desired net object.
+    Activates the computation of mean flops consumption per image.
+    Call it before you run the network.
+    """
+    add_batch_counter_hook_function(self)
+    self.apply(add_flops_counter_hook_function)
+
+
+def stop_flops_count(self):
+    """
+    A method that will be available after add_flops_counting_methods() is
+    called on a desired net object.
+    Stops computing the mean flops consumption per image.
+    Call whenever you want to pause the computation.
+    """
+    remove_batch_counter_hook_function(self)
+    self.apply(remove_flops_counter_hook_function)
+
+
+def reset_flops_count(self):
+    """
+    A method that will be available after add_flops_counting_methods() is
+    called on a desired net object.
+    Resets statistics computed so far.
+    """
+    add_batch_counter_variables_or_reset(self)
+    self.apply(add_flops_counter_variable_or_reset)
+
+
+def add_flops_mask(module, mask):
+
+    def add_flops_mask_func(module):
+        if isinstance(module, torch.nn.Conv2d):
+            module.__mask__ = mask
+
+    module.apply(add_flops_mask_func)
+
+
+def remove_flops_mask(module):
+    module.apply(add_flops_mask_variable_or_reset)
+
+
+def is_supported_instance(module):
+    for mod in hook_mapping:
+        if issubclass(type(module), mod):
+            return True
+    return False
+
+
+def empty_flops_counter_hook(module, input, output):
+    module.__flops__ += 0
+
+
+def upsample_flops_counter_hook(module, input, output):
+    output_size = output[0]
+    batch_size = output_size.shape[0]
+    output_elements_count = batch_size
+    for val in output_size.shape[1:]:
+        output_elements_count *= val
+    module.__flops__ += int(output_elements_count)
+
+
+def relu_flops_counter_hook(module, input, output):
+    active_elements_count = output.numel()
+    module.__flops__ += int(active_elements_count)
+
+
+def linear_flops_counter_hook(module, input, output):
+    input = input[0]
+    batch_size = input.shape[0]
+    module.__flops__ += int(batch_size * input.shape[1] * output.shape[1])
+
+
+def pool_flops_counter_hook(module, input, output):
+    input = input[0]
+    module.__flops__ += int(np.prod(input.shape))
+
+
+def bn_flops_counter_hook(module, input, output):
+    input = input[0]
+
+    batch_flops = np.prod(input.shape)
+    if module.affine:
+        batch_flops *= 2
+    module.__flops__ += int(batch_flops)
+
+
+def gn_flops_counter_hook(module, input, output):
+    elems = np.prod(input[0].shape)
+    # there is no precise FLOPs estimate for computing mean and variance;
+    # approximate them as 2 * elems (half muladds for the means, half for
+    # the variances), plus elems for the normalization itself
+    batch_flops = 3 * elems
+    if module.affine:
+        batch_flops += elems
+    module.__flops__ += int(batch_flops)
+
+
+def deconv_flops_counter_hook(conv_module, input, output):
+    # Can have multiple inputs, getting the first one
+    input = input[0]
+
+    batch_size = input.shape[0]
+    input_height, input_width = input.shape[2:]
+
+    kernel_height, kernel_width = conv_module.kernel_size
+    in_channels = conv_module.in_channels
+    out_channels = conv_module.out_channels
+    groups = conv_module.groups
+
+    filters_per_channel = out_channels // groups
+    conv_per_position_flops = (
+        kernel_height * kernel_width * in_channels * filters_per_channel)
+
+    active_elements_count = batch_size * input_height * input_width
+    overall_conv_flops = conv_per_position_flops * active_elements_count
+    bias_flops = 0
+    if conv_module.bias is not None:
+        output_height, output_width = output.shape[2:]
+        bias_flops = out_channels * batch_size * output_height * output_width
+    overall_flops = overall_conv_flops + bias_flops
+
+    conv_module.__flops__ += int(overall_flops)
+
+
+def conv_flops_counter_hook(conv_module, input, output):
+    # Can have multiple inputs, getting the first one
+    input = input[0]
+
+    batch_size = input.shape[0]
+    output_dims = list(output.shape[2:])
+
+    kernel_dims = list(conv_module.kernel_size)
+    in_channels = conv_module.in_channels
+    out_channels = conv_module.out_channels
+    groups = conv_module.groups
+
+    filters_per_channel = out_channels // groups
+    conv_per_position_flops = np.prod(
+        kernel_dims) * in_channels * filters_per_channel
+
+    active_elements_count = batch_size * np.prod(output_dims)
+
+    if conv_module.__mask__ is not None:
+        # (b, 1, h, w)
+        output_height, output_width = output.shape[2:]
+        flops_mask = conv_module.__mask__.expand(batch_size, 1, output_height,
+                                                 output_width)
+        active_elements_count = flops_mask.sum()
+
+    overall_conv_flops = conv_per_position_flops * active_elements_count
+
+    bias_flops = 0
+
+    if conv_module.bias is not None:
+
+        bias_flops = out_channels * active_elements_count
+
+    overall_flops = overall_conv_flops + bias_flops
+
+    conv_module.__flops__ += int(overall_flops)
+
+
+hook_mapping = {
+    # conv
+    _ConvNd: conv_flops_counter_hook,
+    # deconv
+    _ConvTransposeMixin: deconv_flops_counter_hook,
+    # fc
+    nn.Linear: linear_flops_counter_hook,
+    # pooling
+    _AvgPoolNd: pool_flops_counter_hook,
+    _MaxPoolNd: pool_flops_counter_hook,
+    _AdaptiveAvgPoolNd: pool_flops_counter_hook,
+    _AdaptiveMaxPoolNd: pool_flops_counter_hook,
+    # activation
+    nn.ReLU: relu_flops_counter_hook,
+    nn.PReLU: relu_flops_counter_hook,
+    nn.ELU: relu_flops_counter_hook,
+    nn.LeakyReLU: relu_flops_counter_hook,
+    nn.ReLU6: relu_flops_counter_hook,
+    # normalization
+    _BatchNorm: bn_flops_counter_hook,
+    nn.GroupNorm: gn_flops_counter_hook,
+    # upsample
+    nn.Upsample: upsample_flops_counter_hook,
+}
+
+
+def batch_counter_hook(module, input, output):
+    batch_size = 1
+    if len(input) > 0:
+        # Can have multiple inputs, getting the first one
+        input = input[0]
+        batch_size = len(input)
+    else:
+        print('Warning! No positional inputs found for a module, '
+              'assuming batch size is 1.')
+    module.__batch_counter__ += batch_size
+
+
+def add_batch_counter_variables_or_reset(module):
+    module.__batch_counter__ = 0
+
+
+def add_batch_counter_hook_function(module):
+    if hasattr(module, '__batch_counter_handle__'):
+        return
+
+    handle = module.register_forward_hook(batch_counter_hook)
+    module.__batch_counter_handle__ = handle
+
+
+def remove_batch_counter_hook_function(module):
+    if hasattr(module, '__batch_counter_handle__'):
+        module.__batch_counter_handle__.remove()
+        del module.__batch_counter_handle__
+
+
+def add_flops_counter_variable_or_reset(module):
+    if is_supported_instance(module):
+        module.__flops__ = 0
+
+
+def add_flops_counter_hook_function(module):
+    if is_supported_instance(module):
+        if hasattr(module, '__flops_handle__'):
+            return
+
+        for mod_type, counter_hook in hook_mapping.items():
+            if issubclass(type(module), mod_type):
+                handle = module.register_forward_hook(counter_hook)
+                break
+
+        module.__flops_handle__ = handle
+
+
+def remove_flops_counter_hook_function(module):
+    if is_supported_instance(module):
+        if hasattr(module, '__flops_handle__'):
+            module.__flops_handle__.remove()
+            del module.__flops_handle__
+
+
+# --- Masked flops counting
+# Also run during initialization (from add_flops_counting_methods)
+def add_flops_mask_variable_or_reset(module):
+    if is_supported_instance(module):
+        module.__mask__ = None
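
For reference, a minimal usage sketch of the flops counter introduced above, assuming the repository and its dependencies are installed; it is not part of the patch. TinyNet is a hypothetical stand-in for a real model: detectors built with build_detector() would normally already provide the forward_dummy() method that get_model_complexity_info() falls back to when no input_constructor is supplied.

# Illustrative only: count MACs and parameters with the helper re-exported
# from xmmdet.utils by this patch.
import torch
import torch.nn as nn

from xmmdet.utils import get_model_complexity_info


class TinyNet(nn.Module):
    # Hypothetical stand-in for a detector. forward_dummy() mirrors what
    # mmdet-style models expose for complexity measurement.
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1))
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        return self.fc(self.body(x).flatten(1))

    def forward_dummy(self, x):
        # Called by get_model_complexity_info() with a dummy batch of shape
        # (1, *input_res) when no input_constructor is given.
        return self.forward(x)


flops, params = get_model_complexity_info(
    TinyNet(), (3, 224, 224), print_per_layer_stat=False)
print(flops, params)  # for this toy model, roughly: 0.03 GMac 650

Separately, the register_module(force=True) changes above rely on mmcv's Registry permitting an existing entry to be overwritten; that is what lets these xmmdet classes replace the stock mmdet implementations registered under the same names.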