diff options
-rw-r--r-- | README.md | 4 | ||||
-rw-r--r-- | tools/test.py | 10 | ||||
-rw-r--r-- | xmmdet/__init__.py | 6 | ||||
-rw-r--r-- | xmmdet/models/__init__.py | 3 | ||||
-rw-r--r-- | xmmdet/models/backbones/resnet.py | 4 | ||||
-rw-r--r-- | xmmdet/models/dense_heads/fcos_head.py | 2 | ||||
-rw-r--r-- | xmmdet/models/dense_heads/retina_head.py | 2 | ||||
-rw-r--r-- | xmmdet/models/dense_heads/ssd_head.py | 2 | ||||
-rw-r--r-- | xmmdet/models/necks/fpn.py | 2 | ||||
-rw-r--r-- | xmmdet/utils/__init__.py | 4 | ||||
-rw-r--r-- | xmmdet/utils/flops_counter.py | 444 |
11 files changed, 470 insertions, 13 deletions
@@ -15,7 +15,9 @@ This repository is released under the following [LICENSE](./LICENSE). | |||
15 | 15 | ||
16 | ## Installation | 16 | ## Installation |
17 | 17 | ||
18 | Please refer to [mmdetection install.md](https://github.com/open-mmlab/mmdetection/docs/install.md) for installation and dataset preparation. | 18 | Please refer to [mmdetection install.md](https://github.com/open-mmlab/mmdetection/docs/install.md) for installation and dataset preparation. |
19 | |||
20 | We used the the version **v2.1.0** of mmdetection to test our changes. If you get any issues with the master branch of mmdetection, try checking out that tag. | ||
19 | 21 | ||
20 | After installing mmdetection, please install [PyTorch-Jacinto-AI-DevKit](https://bitbucket.itg.ti.com/projects/JACINTO-AI/repos/pytorch-jacinto-ai-devkit/browse/) as our repository uses several components from there - especially to define low complexity models and to Quantization Aware Training (QAT). | 22 | After installing mmdetection, please install [PyTorch-Jacinto-AI-DevKit](https://bitbucket.itg.ti.com/projects/JACINTO-AI/repos/pytorch-jacinto-ai-devkit/browse/) as our repository uses several components from there - especially to define low complexity models and to Quantization Aware Training (QAT). |
21 | 23 | ||
diff --git a/tools/test.py b/tools/test.py index e5c36446..803e3860 100644 --- a/tools/test.py +++ b/tools/test.py | |||
@@ -8,11 +8,11 @@ from mmcv.parallel import MMDataParallel, MMDistributedDataParallel | |||
8 | from mmcv.runner import get_dist_info, init_dist, load_checkpoint | 8 | from mmcv.runner import get_dist_info, init_dist, load_checkpoint |
9 | from tools.fuse_conv_bn import fuse_module | 9 | from tools.fuse_conv_bn import fuse_module |
10 | 10 | ||
11 | from mmdet.apis import multi_gpu_test, single_gpu_test | 11 | from xmmdet.apis import multi_gpu_test, single_gpu_test |
12 | from mmdet.core import wrap_fp16_model | 12 | from xmmdet.core import wrap_fp16_model |
13 | from mmdet.datasets import build_dataloader, build_dataset | 13 | from xmmdet.datasets import build_dataloader, build_dataset |
14 | from mmdet.models import build_detector | 14 | from xmmdet.models import build_detector |
15 | from mmdet.utils import MMDetQuantTestModule, save_model_proto, mmdet_load_checkpoint | 15 | from xmmdet.utils import MMDetQuantTestModule, save_model_proto, mmdet_load_checkpoint |
16 | 16 | ||
17 | from pytorch_jacinto_ai import xnn | 17 | from pytorch_jacinto_ai import xnn |
18 | 18 | ||
diff --git a/xmmdet/__init__.py b/xmmdet/__init__.py index f9d66cc7..325fb75f 100644 --- a/xmmdet/__init__.py +++ b/xmmdet/__init__.py | |||
@@ -1,2 +1,8 @@ | |||
1 | from mmdet import * | 1 | from mmdet import * |
2 | from .ops import * | ||
3 | from .core import * | ||
4 | from .datasets import * | ||
5 | from .models import * | ||
6 | from .utils import * | ||
7 | from .apis import * | ||
2 | 8 | ||
diff --git a/xmmdet/models/__init__.py b/xmmdet/models/__init__.py index 608be121..718b897e 100644 --- a/xmmdet/models/__init__.py +++ b/xmmdet/models/__init__.py | |||
@@ -1 +1,4 @@ | |||
1 | from mmdet.models import * | 1 | from mmdet.models import * |
2 | from .backbones import * | ||
3 | from .dense_heads import * | ||
4 | from .necks import * | ||
diff --git a/xmmdet/models/backbones/resnet.py b/xmmdet/models/backbones/resnet.py index 08a26836..5048a4e9 100644 --- a/xmmdet/models/backbones/resnet.py +++ b/xmmdet/models/backbones/resnet.py | |||
@@ -290,7 +290,7 @@ class Bottleneck(nn.Module): | |||
290 | return out | 290 | return out |
291 | 291 | ||
292 | 292 | ||
293 | @BACKBONES.register_module() | 293 | @BACKBONES.register_module(force=True) |
294 | class ResNet(nn.Module): | 294 | class ResNet(nn.Module): |
295 | """ResNet backbone. | 295 | """ResNet backbone. |
296 | 296 | ||
@@ -619,7 +619,7 @@ class ResNet(nn.Module): | |||
619 | m.eval() | 619 | m.eval() |
620 | 620 | ||
621 | 621 | ||
622 | @BACKBONES.register_module() | 622 | @BACKBONES.register_module(force=True) |
623 | class ResNetV1d(ResNet): | 623 | class ResNetV1d(ResNet): |
624 | """ResNetV1d variant described in | 624 | """ResNetV1d variant described in |
625 | `Bag of Tricks <https://arxiv.org/pdf/1812.01187.pdf>`_. | 625 | `Bag of Tricks <https://arxiv.org/pdf/1812.01187.pdf>`_. |
diff --git a/xmmdet/models/dense_heads/fcos_head.py b/xmmdet/models/dense_heads/fcos_head.py index 1961245c..27e15c3d 100644 --- a/xmmdet/models/dense_heads/fcos_head.py +++ b/xmmdet/models/dense_heads/fcos_head.py | |||
@@ -8,7 +8,7 @@ from mmdet.models.builder import HEADS, build_loss | |||
8 | INF = 1e8 | 8 | INF = 1e8 |
9 | 9 | ||
10 | 10 | ||
11 | @HEADS.register_module() | 11 | @HEADS.register_module(force=True) |
12 | class FCOSHead(nn.Module): | 12 | class FCOSHead(nn.Module): |
13 | """Anchor-free head used in `FCOS <https://arxiv.org/abs/1904.01355>`_. | 13 | """Anchor-free head used in `FCOS <https://arxiv.org/abs/1904.01355>`_. |
14 | 14 | ||
diff --git a/xmmdet/models/dense_heads/retina_head.py b/xmmdet/models/dense_heads/retina_head.py index 68232c45..27031acd 100644 --- a/xmmdet/models/dense_heads/retina_head.py +++ b/xmmdet/models/dense_heads/retina_head.py | |||
@@ -5,7 +5,7 @@ from mmdet.models.builder import HEADS | |||
5 | from mmdet.models.dense_heads.anchor_head import AnchorHead | 5 | from mmdet.models.dense_heads.anchor_head import AnchorHead |
6 | 6 | ||
7 | 7 | ||
8 | @HEADS.register_module() | 8 | @HEADS.register_module(force=True) |
9 | class RetinaHead(AnchorHead): | 9 | class RetinaHead(AnchorHead): |
10 | """An anchor-based head used in | 10 | """An anchor-based head used in |
11 | `RetinaNet <https://arxiv.org/pdf/1708.02002.pdf>`_. | 11 | `RetinaNet <https://arxiv.org/pdf/1708.02002.pdf>`_. |
diff --git a/xmmdet/models/dense_heads/ssd_head.py b/xmmdet/models/dense_heads/ssd_head.py index d5f4658e..486343f1 100644 --- a/xmmdet/models/dense_heads/ssd_head.py +++ b/xmmdet/models/dense_heads/ssd_head.py | |||
@@ -11,7 +11,7 @@ from mmdet.models.dense_heads.anchor_head import AnchorHead | |||
11 | 11 | ||
12 | 12 | ||
13 | # TODO: add loss evaluator for SSD | 13 | # TODO: add loss evaluator for SSD |
14 | @HEADS.register_module() | 14 | @HEADS.register_module(force=True) |
15 | class SSDHead(AnchorHead): | 15 | class SSDHead(AnchorHead): |
16 | 16 | ||
17 | def __init__(self, | 17 | def __init__(self, |
diff --git a/xmmdet/models/necks/fpn.py b/xmmdet/models/necks/fpn.py index c55ef05e..f773a01c 100644 --- a/xmmdet/models/necks/fpn.py +++ b/xmmdet/models/necks/fpn.py | |||
@@ -8,7 +8,7 @@ from mmdet.models.builder import NECKS | |||
8 | from pytorch_jacinto_ai import xnn | 8 | from pytorch_jacinto_ai import xnn |
9 | 9 | ||
10 | 10 | ||
11 | @NECKS.register_module() | 11 | @NECKS.register_module(force=True) |
12 | class FPN(nn.Module): | 12 | class FPN(nn.Module): |
13 | """ | 13 | """ |
14 | Feature Pyramid Network. | 14 | Feature Pyramid Network. |
diff --git a/xmmdet/utils/__init__.py b/xmmdet/utils/__init__.py index 062baecd..2ce3a842 100644 --- a/xmmdet/utils/__init__.py +++ b/xmmdet/utils/__init__.py | |||
@@ -1,6 +1,8 @@ | |||
1 | from mmdet.utils import * | 1 | from mmdet.utils import * |
2 | from .flops_counter import get_model_complexity_info | ||
2 | from .logger import LoggerStream, get_root_logger | 3 | from .logger import LoggerStream, get_root_logger |
3 | from .runner import MMDetRunner, MMDetNoOptimizerHook | 4 | from .runner import MMDetRunner, MMDetNoOptimizerHook, \ |
5 | mmdet_load_checkpoint, mmdet_save_checkpoint | ||
4 | from .save_model import save_model_proto | 6 | from .save_model import save_model_proto |
5 | from .quantize import MMDetQuantTrainModule, MMDetQuantCalibrateModule, \ | 7 | from .quantize import MMDetQuantTrainModule, MMDetQuantCalibrateModule, \ |
6 | MMDetQuantTestModule, is_mmdet_quant_module | 8 | MMDetQuantTestModule, is_mmdet_quant_module |
diff --git a/xmmdet/utils/flops_counter.py b/xmmdet/utils/flops_counter.py new file mode 100644 index 00000000..04f27f8b --- /dev/null +++ b/xmmdet/utils/flops_counter.py | |||
@@ -0,0 +1,444 @@ | |||
1 | # Modified from flops-counter.pytorch by Vladislav Sovrasov | ||
2 | # original repo: https://github.com/sovrasov/flops-counter.pytorch | ||
3 | |||
4 | # MIT License | ||
5 | |||
6 | # Copyright (c) 2018 Vladislav Sovrasov | ||
7 | |||
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy | ||
9 | # of this software and associated documentation files (the "Software"), to deal | ||
10 | # in the Software without restriction, including without limitation the rights | ||
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
12 | # copies of the Software, and to permit persons to whom the Software is | ||
13 | # furnished to do so, subject to the following conditions: | ||
14 | |||
15 | # The above copyright notice and this permission notice shall be included in | ||
16 | # all copies or substantial portions of the Software. | ||
17 | |||
18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
24 | # SOFTWARE. | ||
25 | |||
26 | import sys | ||
27 | |||
28 | import numpy as np | ||
29 | import torch | ||
30 | import torch.nn as nn | ||
31 | from torch.nn.modules.batchnorm import _BatchNorm | ||
32 | from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin | ||
33 | from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, | ||
34 | _AvgPoolNd, _MaxPoolNd) | ||
35 | |||
36 | |||
37 | def get_model_complexity_info(model, | ||
38 | input_res, | ||
39 | print_per_layer_stat=True, | ||
40 | as_strings=True, | ||
41 | input_constructor=None, | ||
42 | ost=sys.stdout): | ||
43 | assert type(input_res) is tuple | ||
44 | assert len(input_res) >= 2 | ||
45 | flops_model = add_flops_counting_methods(model) | ||
46 | flops_model.eval().start_flops_count() | ||
47 | if input_constructor: | ||
48 | input = input_constructor(input_res) | ||
49 | _ = flops_model(**input) | ||
50 | else: | ||
51 | batch = torch.ones(()).new_empty( | ||
52 | (1, *input_res), | ||
53 | dtype=next(flops_model.parameters()).dtype, | ||
54 | device=next(flops_model.parameters()).device) | ||
55 | flops_model.forward_dummy(batch) | ||
56 | |||
57 | if print_per_layer_stat: | ||
58 | print_model_with_flops(flops_model, ost=ost) | ||
59 | flops_count = flops_model.compute_average_flops_cost() | ||
60 | params_count = get_model_parameters_number(flops_model) | ||
61 | flops_model.stop_flops_count() | ||
62 | |||
63 | if as_strings: | ||
64 | return flops_to_string(flops_count), params_to_string(params_count) | ||
65 | |||
66 | return flops_count, params_count | ||
67 | |||
68 | |||
69 | def flops_to_string(flops, units='GMac', precision=2): | ||
70 | if units is None: | ||
71 | if flops // 10**9 > 0: | ||
72 | return str(round(flops / 10.**9, precision)) + ' GMac' | ||
73 | elif flops // 10**6 > 0: | ||
74 | return str(round(flops / 10.**6, precision)) + ' MMac' | ||
75 | elif flops // 10**3 > 0: | ||
76 | return str(round(flops / 10.**3, precision)) + ' KMac' | ||
77 | else: | ||
78 | return str(flops) + ' Mac' | ||
79 | else: | ||
80 | if units == 'GMac': | ||
81 | return str(round(flops / 10.**9, precision)) + ' ' + units | ||
82 | elif units == 'MMac': | ||
83 | return str(round(flops / 10.**6, precision)) + ' ' + units | ||
84 | elif units == 'KMac': | ||
85 | return str(round(flops / 10.**3, precision)) + ' ' + units | ||
86 | else: | ||
87 | return str(flops) + ' Mac' | ||
88 | |||
89 | |||
90 | def params_to_string(params_num): | ||
91 | """converting number to string | ||
92 | |||
93 | :param float params_num: number | ||
94 | :returns str: number | ||
95 | |||
96 | >>> params_to_string(1e9) | ||
97 | '1000.0 M' | ||
98 | >>> params_to_string(2e5) | ||
99 | '200.0 k' | ||
100 | >>> params_to_string(3e-9) | ||
101 | '3e-09' | ||
102 | """ | ||
103 | if params_num // 10**6 > 0: | ||
104 | return str(round(params_num / 10**6, 2)) + ' M' | ||
105 | elif params_num // 10**3: | ||
106 | return str(round(params_num / 10**3, 2)) + ' k' | ||
107 | else: | ||
108 | return str(params_num) | ||
109 | |||
110 | |||
111 | def print_model_with_flops(model, units='GMac', precision=3, ost=sys.stdout): | ||
112 | total_flops = model.compute_average_flops_cost() | ||
113 | |||
114 | def accumulate_flops(self): | ||
115 | if is_supported_instance(self): | ||
116 | return self.__flops__ / (model.__batch_counter__ or 1) | ||
117 | else: | ||
118 | sum = 0 | ||
119 | for m in self.children(): | ||
120 | sum += m.accumulate_flops() | ||
121 | return sum | ||
122 | |||
123 | def flops_repr(self): | ||
124 | accumulated_flops_cost = self.accumulate_flops() | ||
125 | return ', '.join([ | ||
126 | flops_to_string( | ||
127 | accumulated_flops_cost, units=units, precision=precision), | ||
128 | f'{accumulated_flops_cost / total_flops:.3%} MACs', | ||
129 | self.original_extra_repr() | ||
130 | ]) | ||
131 | |||
132 | def add_extra_repr(m): | ||
133 | m.accumulate_flops = accumulate_flops.__get__(m) | ||
134 | flops_extra_repr = flops_repr.__get__(m) | ||
135 | if m.extra_repr != flops_extra_repr: | ||
136 | m.original_extra_repr = m.extra_repr | ||
137 | m.extra_repr = flops_extra_repr | ||
138 | assert m.extra_repr != m.original_extra_repr | ||
139 | |||
140 | def del_extra_repr(m): | ||
141 | if hasattr(m, 'original_extra_repr'): | ||
142 | m.extra_repr = m.original_extra_repr | ||
143 | del m.original_extra_repr | ||
144 | if hasattr(m, 'accumulate_flops'): | ||
145 | del m.accumulate_flops | ||
146 | |||
147 | model.apply(add_extra_repr) | ||
148 | print(model, file=ost) | ||
149 | model.apply(del_extra_repr) | ||
150 | |||
151 | |||
152 | def get_model_parameters_number(model): | ||
153 | params_num = sum(p.numel() for p in model.parameters() if p.requires_grad) | ||
154 | return params_num | ||
155 | |||
156 | |||
157 | def add_flops_counting_methods(net_main_module): | ||
158 | # adding additional methods to the existing module object, | ||
159 | # this is done this way so that each function has access to self object | ||
160 | net_main_module.start_flops_count = start_flops_count.__get__( | ||
161 | net_main_module) | ||
162 | net_main_module.stop_flops_count = stop_flops_count.__get__( | ||
163 | net_main_module) | ||
164 | net_main_module.reset_flops_count = reset_flops_count.__get__( | ||
165 | net_main_module) | ||
166 | net_main_module.compute_average_flops_cost = \ | ||
167 | compute_average_flops_cost.__get__(net_main_module) | ||
168 | |||
169 | net_main_module.reset_flops_count() | ||
170 | |||
171 | # Adding variables necessary for masked flops computation | ||
172 | net_main_module.apply(add_flops_mask_variable_or_reset) | ||
173 | |||
174 | return net_main_module | ||
175 | |||
176 | |||
177 | def compute_average_flops_cost(self): | ||
178 | """ | ||
179 | A method that will be available after add_flops_counting_methods() is | ||
180 | called on a desired net object. | ||
181 | Returns current mean flops consumption per image. | ||
182 | """ | ||
183 | |||
184 | batches_count = (self.__batch_counter__ or 1) | ||
185 | flops_sum = 0 | ||
186 | for module in self.modules(): | ||
187 | if is_supported_instance(module): | ||
188 | flops_sum += module.__flops__ | ||
189 | |||
190 | return flops_sum / batches_count | ||
191 | |||
192 | |||
193 | def start_flops_count(self): | ||
194 | """ | ||
195 | A method that will be available after add_flops_counting_methods() is | ||
196 | called on a desired net object. | ||
197 | Activates the computation of mean flops consumption per image. | ||
198 | Call it before you run the network. | ||
199 | """ | ||
200 | add_batch_counter_hook_function(self) | ||
201 | self.apply(add_flops_counter_hook_function) | ||
202 | |||
203 | |||
204 | def stop_flops_count(self): | ||
205 | """ | ||
206 | A method that will be available after add_flops_counting_methods() is | ||
207 | called on a desired net object. | ||
208 | Stops computing the mean flops consumption per image. | ||
209 | Call whenever you want to pause the computation. | ||
210 | """ | ||
211 | remove_batch_counter_hook_function(self) | ||
212 | self.apply(remove_flops_counter_hook_function) | ||
213 | |||
214 | |||
215 | def reset_flops_count(self): | ||
216 | """ | ||
217 | A method that will be available after add_flops_counting_methods() is | ||
218 | called on a desired net object. | ||
219 | Resets statistics computed so far. | ||
220 | """ | ||
221 | add_batch_counter_variables_or_reset(self) | ||
222 | self.apply(add_flops_counter_variable_or_reset) | ||
223 | |||
224 | |||
225 | def add_flops_mask(module, mask): | ||
226 | |||
227 | def add_flops_mask_func(module): | ||
228 | if isinstance(module, torch.nn.Conv2d): | ||
229 | module.__mask__ = mask | ||
230 | |||
231 | module.apply(add_flops_mask_func) | ||
232 | |||
233 | |||
234 | def remove_flops_mask(module): | ||
235 | module.apply(add_flops_mask_variable_or_reset) | ||
236 | |||
237 | |||
238 | def is_supported_instance(module): | ||
239 | for mod in hook_mapping: | ||
240 | if issubclass(type(module), mod): | ||
241 | return True | ||
242 | return False | ||
243 | |||
244 | |||
245 | def empty_flops_counter_hook(module, input, output): | ||
246 | module.__flops__ += 0 | ||
247 | |||
248 | |||
249 | def upsample_flops_counter_hook(module, input, output): | ||
250 | output_size = output[0] | ||
251 | batch_size = output_size.shape[0] | ||
252 | output_elements_count = batch_size | ||
253 | for val in output_size.shape[1:]: | ||
254 | output_elements_count *= val | ||
255 | module.__flops__ += int(output_elements_count) | ||
256 | |||
257 | |||
258 | def relu_flops_counter_hook(module, input, output): | ||
259 | active_elements_count = output.numel() | ||
260 | module.__flops__ += int(active_elements_count) | ||
261 | |||
262 | |||
263 | def linear_flops_counter_hook(module, input, output): | ||
264 | input = input[0] | ||
265 | batch_size = input.shape[0] | ||
266 | module.__flops__ += int(batch_size * input.shape[1] * output.shape[1]) | ||
267 | |||
268 | |||
269 | def pool_flops_counter_hook(module, input, output): | ||
270 | input = input[0] | ||
271 | module.__flops__ += int(np.prod(input.shape)) | ||
272 | |||
273 | |||
274 | def bn_flops_counter_hook(module, input, output): | ||
275 | input = input[0] | ||
276 | |||
277 | batch_flops = np.prod(input.shape) | ||
278 | if module.affine: | ||
279 | batch_flops *= 2 | ||
280 | module.__flops__ += int(batch_flops) | ||
281 | |||
282 | |||
283 | def gn_flops_counter_hook(module, input, output): | ||
284 | elems = np.prod(input[0].shape) | ||
285 | # there is no precise FLOPs estimation of computing mean and variance, | ||
286 | # and we just set it 2 * elems: half muladds for computing | ||
287 | # means and half for computing vars | ||
288 | batch_flops = 3 * elems | ||
289 | if module.affine: | ||
290 | batch_flops += elems | ||
291 | module.__flops__ += int(batch_flops) | ||
292 | |||
293 | |||
294 | def deconv_flops_counter_hook(conv_module, input, output): | ||
295 | # Can have multiple inputs, getting the first one | ||
296 | input = input[0] | ||
297 | |||
298 | batch_size = input.shape[0] | ||
299 | input_height, input_width = input.shape[2:] | ||
300 | |||
301 | kernel_height, kernel_width = conv_module.kernel_size | ||
302 | in_channels = conv_module.in_channels | ||
303 | out_channels = conv_module.out_channels | ||
304 | groups = conv_module.groups | ||
305 | |||
306 | filters_per_channel = out_channels // groups | ||
307 | conv_per_position_flops = ( | ||
308 | kernel_height * kernel_width * in_channels * filters_per_channel) | ||
309 | |||
310 | active_elements_count = batch_size * input_height * input_width | ||
311 | overall_conv_flops = conv_per_position_flops * active_elements_count | ||
312 | bias_flops = 0 | ||
313 | if conv_module.bias is not None: | ||
314 | output_height, output_width = output.shape[2:] | ||
315 | bias_flops = out_channels * batch_size * output_height * output_height | ||
316 | overall_flops = overall_conv_flops + bias_flops | ||
317 | |||
318 | conv_module.__flops__ += int(overall_flops) | ||
319 | |||
320 | |||
321 | def conv_flops_counter_hook(conv_module, input, output): | ||
322 | # Can have multiple inputs, getting the first one | ||
323 | input = input[0] | ||
324 | |||
325 | batch_size = input.shape[0] | ||
326 | output_dims = list(output.shape[2:]) | ||
327 | |||
328 | kernel_dims = list(conv_module.kernel_size) | ||
329 | in_channels = conv_module.in_channels | ||
330 | out_channels = conv_module.out_channels | ||
331 | groups = conv_module.groups | ||
332 | |||
333 | filters_per_channel = out_channels // groups | ||
334 | conv_per_position_flops = np.prod( | ||
335 | kernel_dims) * in_channels * filters_per_channel | ||
336 | |||
337 | active_elements_count = batch_size * np.prod(output_dims) | ||
338 | |||
339 | if conv_module.__mask__ is not None: | ||
340 | # (b, 1, h, w) | ||
341 | output_height, output_width = output.shape[2:] | ||
342 | flops_mask = conv_module.__mask__.expand(batch_size, 1, output_height, | ||
343 | output_width) | ||
344 | active_elements_count = flops_mask.sum() | ||
345 | |||
346 | overall_conv_flops = conv_per_position_flops * active_elements_count | ||
347 | |||
348 | bias_flops = 0 | ||
349 | |||
350 | if conv_module.bias is not None: | ||
351 | |||
352 | bias_flops = out_channels * active_elements_count | ||
353 | |||
354 | overall_flops = overall_conv_flops + bias_flops | ||
355 | |||
356 | conv_module.__flops__ += int(overall_flops) | ||
357 | |||
358 | |||
359 | hook_mapping = { | ||
360 | # conv | ||
361 | _ConvNd: conv_flops_counter_hook, | ||
362 | # deconv | ||
363 | _ConvTransposeMixin: deconv_flops_counter_hook, | ||
364 | # fc | ||
365 | nn.Linear: linear_flops_counter_hook, | ||
366 | # pooling | ||
367 | _AvgPoolNd: pool_flops_counter_hook, | ||
368 | _MaxPoolNd: pool_flops_counter_hook, | ||
369 | _AdaptiveAvgPoolNd: pool_flops_counter_hook, | ||
370 | _AdaptiveMaxPoolNd: pool_flops_counter_hook, | ||
371 | # activation | ||
372 | nn.ReLU: relu_flops_counter_hook, | ||
373 | nn.PReLU: relu_flops_counter_hook, | ||
374 | nn.ELU: relu_flops_counter_hook, | ||
375 | nn.LeakyReLU: relu_flops_counter_hook, | ||
376 | nn.ReLU6: relu_flops_counter_hook, | ||
377 | # normalization | ||
378 | _BatchNorm: bn_flops_counter_hook, | ||
379 | nn.GroupNorm: gn_flops_counter_hook, | ||
380 | # upsample | ||
381 | nn.Upsample: upsample_flops_counter_hook, | ||
382 | } | ||
383 | |||
384 | |||
385 | def batch_counter_hook(module, input, output): | ||
386 | batch_size = 1 | ||
387 | if len(input) > 0: | ||
388 | # Can have multiple inputs, getting the first one | ||
389 | input = input[0] | ||
390 | batch_size = len(input) | ||
391 | else: | ||
392 | print('Warning! No positional inputs found for a module, ' | ||
393 | 'assuming batch size is 1.') | ||
394 | module.__batch_counter__ += batch_size | ||
395 | |||
396 | |||
397 | def add_batch_counter_variables_or_reset(module): | ||
398 | module.__batch_counter__ = 0 | ||
399 | |||
400 | |||
401 | def add_batch_counter_hook_function(module): | ||
402 | if hasattr(module, '__batch_counter_handle__'): | ||
403 | return | ||
404 | |||
405 | handle = module.register_forward_hook(batch_counter_hook) | ||
406 | module.__batch_counter_handle__ = handle | ||
407 | |||
408 | |||
409 | def remove_batch_counter_hook_function(module): | ||
410 | if hasattr(module, '__batch_counter_handle__'): | ||
411 | module.__batch_counter_handle__.remove() | ||
412 | del module.__batch_counter_handle__ | ||
413 | |||
414 | |||
415 | def add_flops_counter_variable_or_reset(module): | ||
416 | if is_supported_instance(module): | ||
417 | module.__flops__ = 0 | ||
418 | |||
419 | |||
420 | def add_flops_counter_hook_function(module): | ||
421 | if is_supported_instance(module): | ||
422 | if hasattr(module, '__flops_handle__'): | ||
423 | return | ||
424 | |||
425 | for mod_type, counter_hook in hook_mapping.items(): | ||
426 | if issubclass(type(module), mod_type): | ||
427 | handle = module.register_forward_hook(counter_hook) | ||
428 | break | ||
429 | |||
430 | module.__flops_handle__ = handle | ||
431 | |||
432 | |||
433 | def remove_flops_counter_hook_function(module): | ||
434 | if is_supported_instance(module): | ||
435 | if hasattr(module, '__flops_handle__'): | ||
436 | module.__flops_handle__.remove() | ||
437 | del module.__flops_handle__ | ||
438 | |||
439 | |||
440 | # --- Masked flops counting | ||
441 | # Also being run in the initialization | ||
442 | def add_flops_mask_variable_or_reset(module): | ||
443 | if is_supported_instance(module): | ||
444 | module.__mask__ = None | ||