aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorManu Mathew2020-12-08 04:57:31 -0600
committerManu Mathew2020-12-08 04:57:31 -0600
commit1740669e0a76adf02d0a6e3dbb90314d60b7618c (patch)
treefd31c492b4d8666d90401201d4e832e492cd6fe3
parent51d70cc298f3b61928921a2684c05fc537945efb (diff)
downloadpytorch-mmdetection-1740669e0a76adf02d0a6e3dbb90314d60b7618c.tar.gz
pytorch-mmdetection-1740669e0a76adf02d0a6e3dbb90314d60b7618c.tar.xz
pytorch-mmdetection-1740669e0a76adf02d0a6e3dbb90314d60b7618c.zip
re-implementation of Freeze BN and Rang for quantize
-rw-r--r--xmmdet/apis/train.py12
-rw-r--r--xmmdet/utils/__init__.py2
-rw-r--r--xmmdet/utils/runner.py37
3 files changed, 24 insertions, 27 deletions
diff --git a/xmmdet/apis/train.py b/xmmdet/apis/train.py
index 785664db..a47cd082 100644
--- a/xmmdet/apis/train.py
+++ b/xmmdet/apis/train.py
@@ -10,7 +10,7 @@ from mmcv.utils import build_from_cfg
10from mmdet.core import DistEvalHook, EvalHook 10from mmdet.core import DistEvalHook, EvalHook
11from mmdet.datasets import build_dataloader, build_dataset 11from mmdet.datasets import build_dataloader, build_dataset
12from ..utils import (get_root_logger, XMMDetEpochBasedRunner, XMMDetNoOptimizerHook, \ 12from ..utils import (get_root_logger, XMMDetEpochBasedRunner, XMMDetNoOptimizerHook, \
13 XMMDetDataParallel) 13 XMMDetDataParallel, FreezeRangeHook)
14 14
15 15
16def set_random_seed(seed, deterministic=False): 16def set_random_seed(seed, deterministic=False):
@@ -84,15 +84,13 @@ def train_detector(model,
84 84
85 # build runner 85 # build runner
86 quantize = cfg.get('quantize', False) 86 quantize = cfg.get('quantize', False)
87 freeze_range = bool(quantize)
88 optimizer = build_optimizer(model, cfg.optimizer) 87 optimizer = build_optimizer(model, cfg.optimizer)
89 runner = XMMDetEpochBasedRunner( 88 runner = XMMDetEpochBasedRunner(
90 model, 89 model,
91 optimizer=optimizer, 90 optimizer=optimizer,
92 work_dir=cfg.work_dir, 91 work_dir=cfg.work_dir,
93 logger=logger, 92 logger=logger,
94 meta=meta, 93 meta=meta)
95 freeze_range=freeze_range)
96 # an ugly workaround to make .log and .log.json filenames the same 94 # an ugly workaround to make .log and .log.json filenames the same
97 runner.timestamp = timestamp 95 runner.timestamp = timestamp
98 96
@@ -115,6 +113,12 @@ def train_detector(model,
115 if distributed: 113 if distributed:
116 runner.register_hook(DistSamplerSeedHook()) 114 runner.register_hook(DistSamplerSeedHook())
117 115
116 # register train hooks
117 freeze_range = bool(quantize)
118 if freeze_range:
119 runner.register_hook(FreezeRangeHook())
120 #
121
118 # register eval hooks 122 # register eval hooks
119 if validate: 123 if validate:
120 # Support batch_size > 1 in validation 124 # Support batch_size > 1 in validation
diff --git a/xmmdet/utils/__init__.py b/xmmdet/utils/__init__.py
index fe27d9f5..8b6923cd 100644
--- a/xmmdet/utils/__init__.py
+++ b/xmmdet/utils/__init__.py
@@ -1,7 +1,7 @@
1from mmdet.utils import * 1from mmdet.utils import *
2from .flops_counter import get_model_complexity_info 2from .flops_counter import get_model_complexity_info
3from .logger import LoggerStream, get_root_logger 3from .logger import LoggerStream, get_root_logger
4from .runner import XMMDetEpochBasedRunner, XMMDetNoOptimizerHook, \ 4from .runner import XMMDetEpochBasedRunner, XMMDetNoOptimizerHook, FreezeRangeHook, \
5 mmdet_load_checkpoint, mmdet_save_checkpoint 5 mmdet_load_checkpoint, mmdet_save_checkpoint
6from .save_model import save_model_proto 6from .save_model import save_model_proto
7from .quantize import XMMDetQuantTrainModule, XMMDetQuantCalibrateModule, \ 7from .quantize import XMMDetQuantTrainModule, XMMDetQuantCalibrateModule, \
diff --git a/xmmdet/utils/runner.py b/xmmdet/utils/runner.py
index 52edbc56..07195703 100644
--- a/xmmdet/utils/runner.py
+++ b/xmmdet/utils/runner.py
@@ -1,7 +1,7 @@
1import torch 1import torch
2import mmcv 2import mmcv
3from mmcv.runner import EpochBasedRunner 3from mmcv.runner import EpochBasedRunner
4from mmcv.runner import OptimizerHook 4from mmcv.runner import OptimizerHook, HOOKS, Hook
5from pytorch_jacinto_ai import xnn 5from pytorch_jacinto_ai import xnn
6from .quantize import is_mmdet_quant_module 6from .quantize import is_mmdet_quant_module
7 7
@@ -21,27 +21,6 @@ def mmdet_save_checkpoint(model, *args, **kwargs):
21 21
22 22
23class XMMDetEpochBasedRunner(EpochBasedRunner): 23class XMMDetEpochBasedRunner(EpochBasedRunner):
24 def __init__(self, *args, **kwargs):
25 freeze_range = kwargs.pop('freeze_range', False)
26 super().__init__(*args, **kwargs)
27 self.freeze_range = freeze_range
28
29 def train(self, data_loader, **kwargs):
30 if self.freeze_range:
31 # currently we don't have a parameter that indicates whether we are doing QAT or not.
32 # Let us do it for all cases of training for the time being.
33 freeze_bn_epoch = (self.max_epochs//2)-1
34 freeze_range_epoch = (self.max_epochs//2)+1
35 if self.epoch > 0 and self.epoch >= freeze_bn_epoch:
36 xnn.utils.freeze_bn(self.model)
37 #
38 if self.epoch > 1 and self.epoch >= freeze_range_epoch:
39 xnn.layers.freeze_quant_range(self.model)
40 #
41 #
42 super().train(data_loader, **kwargs)
43
44
45 def _get_model_orig(self): 24 def _get_model_orig(self):
46 model_orig = self.model 25 model_orig = self.model
47 is_model_orig = True 26 is_model_orig = True
@@ -84,9 +63,23 @@ class XMMDetEpochBasedRunner(EpochBasedRunner):
84 # 63 #
85 64
86 65
66@HOOKS.register_module()
87class XMMDetNoOptimizerHook(OptimizerHook): 67class XMMDetNoOptimizerHook(OptimizerHook):
88 def after_train_iter(self, runner): 68 def after_train_iter(self, runner):
89 pass 69 pass
90 70
91 71
72@HOOKS.register_module()
73class FreezeRangeHook(Hook):
74 def before_train_epoch(self, runner):
75 freeze_bn_epoch = (runner.max_epochs // 2) - 1
76 freeze_range_epoch = (runner.max_epochs // 2) + 1
77 if runner.epoch >= 1 and runner.epoch >= freeze_bn_epoch:
78 xnn.utils.print_once('Freezing BN')
79 xnn.utils.freeze_bn(runner.model)
80 #
81 if runner.epoch >= 2 and runner.epoch >= freeze_range_epoch:
82 xnn.utils.print_once('Freezing Activation ranges')
83 xnn.layers.freeze_quant_range(runner.model)
84 #
92 85