diff options
author | Manu Mathew | 2020-12-08 04:57:31 -0600 |
---|---|---|
committer | Manu Mathew | 2020-12-08 04:57:31 -0600 |
commit | 1740669e0a76adf02d0a6e3dbb90314d60b7618c (patch) | |
tree | fd31c492b4d8666d90401201d4e832e492cd6fe3 | |
parent | 51d70cc298f3b61928921a2684c05fc537945efb (diff) | |
download | pytorch-mmdetection-1740669e0a76adf02d0a6e3dbb90314d60b7618c.tar.gz pytorch-mmdetection-1740669e0a76adf02d0a6e3dbb90314d60b7618c.tar.xz pytorch-mmdetection-1740669e0a76adf02d0a6e3dbb90314d60b7618c.zip |
re-implementation of Freeze BN and Rang for quantize
-rw-r--r-- | xmmdet/apis/train.py | 12 | ||||
-rw-r--r-- | xmmdet/utils/__init__.py | 2 | ||||
-rw-r--r-- | xmmdet/utils/runner.py | 37 |
3 files changed, 24 insertions, 27 deletions
diff --git a/xmmdet/apis/train.py b/xmmdet/apis/train.py index 785664db..a47cd082 100644 --- a/xmmdet/apis/train.py +++ b/xmmdet/apis/train.py | |||
@@ -10,7 +10,7 @@ from mmcv.utils import build_from_cfg | |||
10 | from mmdet.core import DistEvalHook, EvalHook | 10 | from mmdet.core import DistEvalHook, EvalHook |
11 | from mmdet.datasets import build_dataloader, build_dataset | 11 | from mmdet.datasets import build_dataloader, build_dataset |
12 | from ..utils import (get_root_logger, XMMDetEpochBasedRunner, XMMDetNoOptimizerHook, \ | 12 | from ..utils import (get_root_logger, XMMDetEpochBasedRunner, XMMDetNoOptimizerHook, \ |
13 | XMMDetDataParallel) | 13 | XMMDetDataParallel, FreezeRangeHook) |
14 | 14 | ||
15 | 15 | ||
16 | def set_random_seed(seed, deterministic=False): | 16 | def set_random_seed(seed, deterministic=False): |
@@ -84,15 +84,13 @@ def train_detector(model, | |||
84 | 84 | ||
85 | # build runner | 85 | # build runner |
86 | quantize = cfg.get('quantize', False) | 86 | quantize = cfg.get('quantize', False) |
87 | freeze_range = bool(quantize) | ||
88 | optimizer = build_optimizer(model, cfg.optimizer) | 87 | optimizer = build_optimizer(model, cfg.optimizer) |
89 | runner = XMMDetEpochBasedRunner( | 88 | runner = XMMDetEpochBasedRunner( |
90 | model, | 89 | model, |
91 | optimizer=optimizer, | 90 | optimizer=optimizer, |
92 | work_dir=cfg.work_dir, | 91 | work_dir=cfg.work_dir, |
93 | logger=logger, | 92 | logger=logger, |
94 | meta=meta, | 93 | meta=meta) |
95 | freeze_range=freeze_range) | ||
96 | # an ugly workaround to make .log and .log.json filenames the same | 94 | # an ugly workaround to make .log and .log.json filenames the same |
97 | runner.timestamp = timestamp | 95 | runner.timestamp = timestamp |
98 | 96 | ||
@@ -115,6 +113,12 @@ def train_detector(model, | |||
115 | if distributed: | 113 | if distributed: |
116 | runner.register_hook(DistSamplerSeedHook()) | 114 | runner.register_hook(DistSamplerSeedHook()) |
117 | 115 | ||
116 | # register train hooks | ||
117 | freeze_range = bool(quantize) | ||
118 | if freeze_range: | ||
119 | runner.register_hook(FreezeRangeHook()) | ||
120 | # | ||
121 | |||
118 | # register eval hooks | 122 | # register eval hooks |
119 | if validate: | 123 | if validate: |
120 | # Support batch_size > 1 in validation | 124 | # Support batch_size > 1 in validation |
diff --git a/xmmdet/utils/__init__.py b/xmmdet/utils/__init__.py index fe27d9f5..8b6923cd 100644 --- a/xmmdet/utils/__init__.py +++ b/xmmdet/utils/__init__.py | |||
@@ -1,7 +1,7 @@ | |||
1 | from mmdet.utils import * | 1 | from mmdet.utils import * |
2 | from .flops_counter import get_model_complexity_info | 2 | from .flops_counter import get_model_complexity_info |
3 | from .logger import LoggerStream, get_root_logger | 3 | from .logger import LoggerStream, get_root_logger |
4 | from .runner import XMMDetEpochBasedRunner, XMMDetNoOptimizerHook, \ | 4 | from .runner import XMMDetEpochBasedRunner, XMMDetNoOptimizerHook, FreezeRangeHook, \ |
5 | mmdet_load_checkpoint, mmdet_save_checkpoint | 5 | mmdet_load_checkpoint, mmdet_save_checkpoint |
6 | from .save_model import save_model_proto | 6 | from .save_model import save_model_proto |
7 | from .quantize import XMMDetQuantTrainModule, XMMDetQuantCalibrateModule, \ | 7 | from .quantize import XMMDetQuantTrainModule, XMMDetQuantCalibrateModule, \ |
diff --git a/xmmdet/utils/runner.py b/xmmdet/utils/runner.py index 52edbc56..07195703 100644 --- a/xmmdet/utils/runner.py +++ b/xmmdet/utils/runner.py | |||
@@ -1,7 +1,7 @@ | |||
1 | import torch | 1 | import torch |
2 | import mmcv | 2 | import mmcv |
3 | from mmcv.runner import EpochBasedRunner | 3 | from mmcv.runner import EpochBasedRunner |
4 | from mmcv.runner import OptimizerHook | 4 | from mmcv.runner import OptimizerHook, HOOKS, Hook |
5 | from pytorch_jacinto_ai import xnn | 5 | from pytorch_jacinto_ai import xnn |
6 | from .quantize import is_mmdet_quant_module | 6 | from .quantize import is_mmdet_quant_module |
7 | 7 | ||
@@ -21,27 +21,6 @@ def mmdet_save_checkpoint(model, *args, **kwargs): | |||
21 | 21 | ||
22 | 22 | ||
23 | class XMMDetEpochBasedRunner(EpochBasedRunner): | 23 | class XMMDetEpochBasedRunner(EpochBasedRunner): |
24 | def __init__(self, *args, **kwargs): | ||
25 | freeze_range = kwargs.pop('freeze_range', False) | ||
26 | super().__init__(*args, **kwargs) | ||
27 | self.freeze_range = freeze_range | ||
28 | |||
29 | def train(self, data_loader, **kwargs): | ||
30 | if self.freeze_range: | ||
31 | # currently we don't have a parameter that indicates whether we are doing QAT or not. | ||
32 | # Let us do it for all cases of training for the time being. | ||
33 | freeze_bn_epoch = (self.max_epochs//2)-1 | ||
34 | freeze_range_epoch = (self.max_epochs//2)+1 | ||
35 | if self.epoch > 0 and self.epoch >= freeze_bn_epoch: | ||
36 | xnn.utils.freeze_bn(self.model) | ||
37 | # | ||
38 | if self.epoch > 1 and self.epoch >= freeze_range_epoch: | ||
39 | xnn.layers.freeze_quant_range(self.model) | ||
40 | # | ||
41 | # | ||
42 | super().train(data_loader, **kwargs) | ||
43 | |||
44 | |||
45 | def _get_model_orig(self): | 24 | def _get_model_orig(self): |
46 | model_orig = self.model | 25 | model_orig = self.model |
47 | is_model_orig = True | 26 | is_model_orig = True |
@@ -84,9 +63,23 @@ class XMMDetEpochBasedRunner(EpochBasedRunner): | |||
84 | # | 63 | # |
85 | 64 | ||
86 | 65 | ||
66 | @HOOKS.register_module() | ||
87 | class XMMDetNoOptimizerHook(OptimizerHook): | 67 | class XMMDetNoOptimizerHook(OptimizerHook): |
88 | def after_train_iter(self, runner): | 68 | def after_train_iter(self, runner): |
89 | pass | 69 | pass |
90 | 70 | ||
91 | 71 | ||
72 | @HOOKS.register_module() | ||
73 | class FreezeRangeHook(Hook): | ||
74 | def before_train_epoch(self, runner): | ||
75 | freeze_bn_epoch = (runner.max_epochs // 2) - 1 | ||
76 | freeze_range_epoch = (runner.max_epochs // 2) + 1 | ||
77 | if runner.epoch >= 1 and runner.epoch >= freeze_bn_epoch: | ||
78 | xnn.utils.print_once('Freezing BN') | ||
79 | xnn.utils.freeze_bn(runner.model) | ||
80 | # | ||
81 | if runner.epoch >= 2 and runner.epoch >= freeze_range_epoch: | ||
82 | xnn.utils.print_once('Freezing Activation ranges') | ||
83 | xnn.layers.freeze_quant_range(runner.model) | ||
84 | # | ||
92 | 85 | ||