[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / quantize / quant_train_module.py
1 ###########################################################
2 # Approximate quantized floating point simulation with gradients.
3 # Can be used for quantized training of models.
4 ###########################################################
6 import torch
7 import numpy as np
8 import copy
9 import warnings
11 from .. import layers
12 from .. import utils
13 from .quant_train_utils import *
# Module-level side effect: from import time onwards, every warning of category
# torch.jit.TracerWarning is ignored process-wide.
# NOTE(review): presumably these come from tracing done during graph analysis - confirm.
warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
18 ###########################################################
class QuantTrainModule(QuantBaseModule):
    """Wrapper that simulates quantization (with gradients) around a float model.

    Approximate quantized floating point simulation that can be used for
    quantized training of models. On construction the wrapped module goes
    through ``model_surgery_quantize``, which swaps supported layers for their
    ``QuantTrain*`` counterparts, and the quantization settings are then pushed
    to every sub-module through ``utils.apply_setattr``.

    Args:
        module: the model to wrap (handed to the base class constructor).
        bitwidth_weights: bit width used for weights.
        bitwidth_activations: bit width used for activations.
        per_channel_q: per-channel quantization flag forwarded to sub-modules.
        histogram_range: when true, ``layers.PAct2.PACT2_RANGE_SHRINK`` is used
            as ``percentile_range_shrink``; otherwise 0.0 (no shrink).
        constrain_weights: flag forwarded to sub-modules.
        bias_calibration: when true, ``forward`` in training mode is routed
            through ``forward_calibrate_bias`` (not implemented in this class).
        dummy_input: example input, mandatory; it is passed on to the model
            surgery step.
    """

    def __init__(self, module, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False, histogram_range=True,
                 constrain_weights=True, bias_calibration=False, dummy_input=None):
        super().__init__(module)
        assert dummy_input is not None, 'dummy input is needed by quantized models to analyze graph'
        self.bitwidth_weights = bitwidth_weights
        self.bitwidth_activations = bitwidth_activations
        self.per_channel_q = per_channel_q
        self.constrain_weights = constrain_weights  # and (not bool(self.per_channel_q))
        self.bias_calibration = bias_calibration

        # for help in debug/print
        utils.add_module_names(self)

        # put in eval mode before analyze
        self.eval()

        with torch.no_grad():
            # model surgery for quantization
            utils.print_yellow("=> model surgery by '{}'".format(self.model_surgery_quantize.__name__))
            self.model_surgery_quantize(dummy_input)

        # range shrink - 0.0 indicates no shrink
        percentile_range_shrink = (layers.PAct2.PACT2_RANGE_SHRINK if histogram_range else 0.0)
        # set attributes to all modules - can control the behaviour from here
        utils.apply_setattr(self, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                            per_channel_q=per_channel_q, bias_calibration=bias_calibration,
                            quantize_enable=True, quantize_weights=True, quantize_bias=True, quantize_activations=True,
                            percentile_range_shrink=percentile_range_shrink, constrain_weights=self.constrain_weights)

        # for help in debug/print - done again because surgery replaced modules
        utils.add_module_names(self)

    def train(self, mode=True):
        """Switch training/eval mode and reset the in-epoch iteration counter.

        ``self.iter_in_epoch`` is filled with -1.0 on every mode switch.

        Returns:
            Whatever the parent ``train`` returns. ``torch.nn.Module.train``
            returns ``self``, and ``eval()`` is implemented as
            ``self.train(False)``, so the result must be propagated for
            ``model = model.train()`` / ``model.eval()`` to keep working.
        """
        self.iter_in_epoch.fill_(-1.0)
        return super().train(mode)

    def forward(self, inputs):
        """Analyze the graph for ``inputs`` and run the wrapped module.

        In training mode with ``bias_calibration`` enabled, the call is routed
        through ``forward_calibrate_bias`` instead of the plain module call.
        """
        # analyze
        self.analyze_graph(inputs=inputs, cleanup_states=True)

        # actual forward call
        if self.training and self.bias_calibration:
            # bias calibration
            outputs = self.forward_calibrate_bias(inputs)
        else:
            outputs = self.module(inputs)
        return outputs

    def forward_calibrate_bias(self, inputs):
        """Placeholder for bias calibration; always fails in this class."""
        assert False, 'forward_calibrate_bias is not implemented'

    def model_surgery_quantize(self, dummy_input):
        """Replace supported layers with their quantized-training versions.

        After the base class surgery, every conv, deconv, batch-norm, PAct2,
        NoAct, ReLU and ReLU6 child is swapped for a ``QuantTrain*`` module
        that receives the state of the module it replaces. A forward hook is
        then registered on every module so that an ``activation_q`` attribute,
        when present, is applied to the module output.
        """
        super().model_surgery_quantize(dummy_input)

        def replace_func(op):
            # iterate over a snapshot: children are re-assigned inside the loop
            for name, m in list(op._modules.items()):
                if utils.is_conv(m):
                    bias = (m.bias is not None)
                    padding_mode = m.padding_mode
                    new_m = QuantTrainConv2d(in_channels=m.in_channels, out_channels=m.out_channels,
                                             kernel_size=m.kernel_size, stride=m.stride,
                                             padding=m.padding, dilation=m.dilation, groups=m.groups,
                                             bias=bias, padding_mode=padding_mode)
                elif utils.is_deconv(m):
                    bias = (m.bias is not None)
                    padding_mode = m.padding_mode
                    new_m = QuantTrainConvTranspose2d(in_channels=m.in_channels, out_channels=m.out_channels,
                                                      kernel_size=m.kernel_size, stride=m.stride,
                                                      padding=m.padding, output_padding=m.output_padding,
                                                      groups=m.groups, bias=bias, dilation=m.dilation,
                                                      padding_mode=padding_mode)
                elif utils.is_bn(m):
                    new_m = QuantTrainBatchNorm2d(num_features=m.num_features, eps=m.eps, momentum=m.momentum,
                                                  affine=m.affine, track_running_stats=m.track_running_stats)
                elif isinstance(m, layers.PAct2):
                    new_m = QuantTrainPAct2(inplace=m.inplace, signed=m.signed, clip_range=m.clip_range,
                                            bitwidth_weights=self.bitwidth_weights,
                                            bitwidth_activations=self.bitwidth_activations,
                                            per_channel_q=self.per_channel_q)
                elif isinstance(m, layers.NoAct):
                    new_m = QuantTrainPAct2(signed=None, bitwidth_weights=self.bitwidth_weights,
                                            bitwidth_activations=self.bitwidth_activations,
                                            per_channel_q=self.per_channel_q)
                elif isinstance(m, (torch.nn.ReLU, torch.nn.ReLU6)):
                    new_m = QuantTrainPAct2(signed=False, bitwidth_weights=self.bitwidth_weights,
                                            bitwidth_activations=self.bitwidth_activations,
                                            per_channel_q=self.per_channel_q)
                else:
                    new_m = None

                if new_m is None:
                    # nothing to replace for this child
                    continue

                # carry the state of the old module over to the new one
                for attr in dir(m):
                    value = getattr(m, attr)
                    if isinstance(value, torch.Tensor):
                        # tensors (parameters/buffers): copy the values in place
                        getattr(new_m, attr).data.copy_(value.data)
                    elif isinstance(value, torch.nn.Module):
                        # sub-modules are shared, not copied
                        setattr(new_m, attr, value)
                    elif attr in ('weight', 'bias', 'eps', 'clips_act', 'clips_w'):  # add more attribute name here
                        # non-tensor values of these attributes (e.g. a bias that is None)
                        setattr(new_m, attr, value)

                new_m.train(m.training)
                setattr(op, name, new_m)

        # apply recursively
        self.apply(replace_func)

        # add a forward hook to call the extra activation that we added
        def _forward_activation(op, inputs, outputs):
            if hasattr(op, 'activation_q'):
                outputs = op.activation_q(outputs)
            return outputs

        for m in self.modules():
            m.register_forward_hook(_forward_activation)

        # clear
        self.clear_states()
142 ###########################################################
class QuantCalibrateModule(QuantTrainModule):
    """Quantization wrapper that calibrates biases instead of learning by gradients.

    For every training-mode forward pass (with ``bias_calibration`` enabled)
    the wrapped model is run twice: once with quantization disabled to record
    the per-layer output mean/std, and once with quantization enabled, during
    which the bias of each layer is nudged so that the quantized output mean
    moves towards the recorded float mean.

    Args:
        module: the model to wrap.
        bitwidth_weights: bit width used for weights.
        bitwidth_activations: bit width used for activations.
        per_channel_q: per-channel quantization flag forwarded to sub-modules.
        bias_calibration: enables the calibration path in ``forward``.
        histogram_range: forwarded to ``QuantTrainModule``.
        constrain_weights: forwarded to ``QuantTrainModule``.
        lr_calib: base step size of the bias update; the effective factor is
            ``lr_calib * bias_calibration_gamma ** epoch``.
        dummy_input: example input, forwarded to ``QuantTrainModule``.
    """

    def __init__(self, module, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False, bias_calibration=True,
                 histogram_range=True, constrain_weights=True, lr_calib=0.1, dummy_input=None):
        # NOTE(review): these attributes are assigned before super().__init__();
        # presumably so they already exist when the parent constructor runs
        # (it calls eval() and apply_setattr) - confirm before reordering.
        self.bias_calibration = bias_calibration
        self.lr_calib = lr_calib
        self.bias_calibration_factor = lr_calib
        self.bias_calibration_gamma = 0.5
        self.calibrate_weights = False
        self.calibrate_repeats = 1
        self.quantize_enable = True
        # BNs can be adjusted based on the input provided - however this is not really required
        self.calibrate_bn = False
        super().__init__(module, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                         per_channel_q=per_channel_q, histogram_range=histogram_range,
                         constrain_weights=constrain_weights, bias_calibration=bias_calibration,
                         dummy_input=dummy_input)

    @staticmethod
    def _calibration_reduce_dims(output):
        """Return the dims to reduce over when computing per-layer statistics.

        4-D outputs reduce over dims 0, 2, 3 and 2-D outputs over dim 0 (i.e.
        everything except dim 1). Any other rank yields ``None``, which is
        passed unchanged to ``torch.mean`` / ``torch.std``.
        """
        rank = len(output.size())
        if rank == 4:
            return [0, 2, 3]
        if rank == 2:
            return [0]
        return None

    def forward_calibrate_bias(self, inputs):
        """Run one calibration step and return the quantized model output.

        The training mode that was active on entry is restored before
        returning, also when one of the forward passes raises.
        """
        # we don't need gradients for calibration
        with torch.no_grad():
            # prepare/backup weights
            if self.num_batches_tracked == 0:
                # lr_calib
                self.bias_calibration_factor = self.lr_calib * np.power(self.bias_calibration_gamma, float(self.epoch))
                # backup original weights
                self._backup_weights_orig()
                # backup quantized weights
                self._backup_weights_quant()

            # backup the current state
            training = self.training
            try:
                # compute the mean output in float
                # also, set all bns to eval. we can't set the whole model to eval because
                # we need the pact to learn the ranges - which will happen only in training mode.
                # also the model output itself may be different in eval mode.
                if self.calibrate_bn:
                    outputs = self.forward_compute_oputput_stats(inputs)
                    utils.freeze_bn(self)
                else:
                    utils.freeze_bn(self)
                    outputs = self.forward_compute_oputput_stats(inputs)

                # adjust the quantized output to match the mean
                outputs = self.forward_adjust_bias(inputs)
            finally:
                # undo freeze_bn even if a forward pass failed
                self.train(training)

            return outputs

    def forward_compute_oputput_stats(self, inputs):
        """Run the model with quantization disabled and record output statistics.

        The original weights are restored first; a call hook stores the
        per-layer output mean and std on the layers themselves. The hook and
        the ``quantize_enable`` override are undone even if the forward fails.
        """
        self._restore_weights_orig()
        # disable quantization for a moment
        quantize_enable_backup_value = self.quantize_enable
        utils.apply_setattr(self, quantize_enable=False)

        self.add_call_hook(self.module, self._forward_compute_oputput_stats_hook)
        try:
            outputs = self.module(inputs)
        finally:
            self.remove_call_hook(self.module)
            # turn quantization back on - not a clean method
            utils.apply_setattr(self, quantize_enable=quantize_enable_backup_value)

        self._backup_weights_orig()
        return outputs

    def _forward_compute_oputput_stats_hook(self, op, *inputs_orig):
        """Call hook: run ``op`` and store its output mean/std on ``op``.

        Statistics are stored as ``op.__output_mean_orig__`` and
        ``op.__output_std_orig__``, only for ops that have a bias (or a weight
        when ``calibrate_weights`` is set).
        """
        outputs = op.__forward_orig__(*inputs_orig)
        # calibration at specific layers
        bias = getattr(op, 'bias', None)
        weight = getattr(op, 'weight', None)
        if (bias is not None) or (self.calibrate_weights and weight is not None):
            output = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
            reduce_dims = self._calibration_reduce_dims(output)
            op.__output_mean_orig__ = torch.mean(output, dim=reduce_dims)
            op.__output_std_orig__ = torch.std(output, dim=reduce_dims)
        return outputs

    def forward_adjust_bias(self, input):
        """Run the quantized model ``calibrate_repeats`` times, adjusting biases.

        The quantized weights are restored before and backed up after the
        passes. The call hook is removed even if a forward pass fails.
        """
        self._restore_weights_quant()
        self.add_call_hook(self.module, self._forward_adjust_bias_hook)
        try:
            for _ in range(self.calibrate_repeats):
                output = self.module(input)
        finally:
            self.remove_call_hook(self.module)
        self._backup_weights_quant()
        return output

    def _forward_adjust_bias_hook(self, op, *inputs_orig):
        """Call hook: run ``op`` and move its bias towards the float output mean.

        The bias is incremented in place by
        ``(float_mean - quantized_mean) * bias_calibration_factor``. The
        output returned for this call is not corrected.
        """
        outputs = op.__forward_orig__(*inputs_orig)
        # calibration at specific layers
        bias = getattr(op, 'bias', None)
        if bias is not None:
            output = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
            reduce_dims = self._calibration_reduce_dims(output)
            output_mean = torch.mean(output, dim=reduce_dims)
            output_delta = op.__output_mean_orig__ - output_mean
            output_delta = output_delta * self.bias_calibration_factor
            bias.data += (output_delta)
            # # TODO: is this required?
            # if len(output.size()) == 4:
            #     output.data += output_delta.data.view(1,-1,1,1)
            # elif len(output.size()) == 2:
            #     output.data += output_delta.data.view(1,-1)
            # else:
            #     assert False, 'unknown dimensions'
            # #

        # disabled experimental weight calibration, kept for reference:
        # weight = op.weight if (hasattr(op, 'weight') and (op.weight is not None)) else None
        # iter_threshold = (1.0/self.bias_calibration_factor)
        # if self.calibrate_weights and (weight is not None) and (self.num_batches_tracked > iter_threshold):
        #     eps = 1e-3
        #     output = outputs[0] if isinstance(outputs, (list,tuple)) else outputs
        #     reduce_dims = [0, 2, 3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)
        #     output_std = torch.std(output, dim=reduce_dims)
        #     output_ratio = (op.__output_std_orig__ + eps) / (output_std + eps)
        #     channels = output.size(1)
        #     output_ratio = output_ratio.view(channels, 1, 1, 1) if len(weight.data.size()) > 1 else output_ratio
        #     output_ratio = torch.pow(output_ratio, self.bias_calibration_factor)
        #     output_ratio = torch.clamp(output_ratio, 1.0-self.bias_calibration_factor, 1.0+self.bias_calibration_factor)
        #     weight.data *= output_ratio
        #     #
        # #

        return outputs