# modules/pytorch_jacinto_ai/xnn/quantize/quant_test_module.py
import torch
import math
import copy
import warnings
import numpy as np

from .quant_base_module import *
from .quant_utils import *

class QuantTestModule(QuantBaseModule):
    def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False, histogram_range=True,
                 range_calibration_online=False, model_surgery_quantize=False):
        super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                         per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=False,
                         constrain_weights=False, model_surgery_quantize=model_surgery_quantize)
        # use power-of-2 scales for the weights for now
        self.power2_weights = True
        # whether to adjust the calibration online, using the previous frame's range
        self.range_calibration_online = range_calibration_online
        # number of offline calibration iterations. during offline calibration, the current frame's range is used.
        self.range_calibration_offline_iters = 25 #10
        # minimum speed of the range update
        self.range_update_factor_min = 0.001 #0.1
        # range expansion is not needed now, as the ranges are computed from the actual floating point values.
        # earlier they were computed from the quantized values - that is when the expansion was needed.
        self.range_expansion_factor = 1.0
        # set these to 0 to use faster min/max based range computation (lower accuracy) instead of histogram based ranges.
        # the shrink value is a percentile: 0.01 means the 0.01 percentile, not the 1 percentile.
        self.percentile_range_shrink_activations = (layers.PAct2.PACT2_RANGE_SHRINK if histogram_range else 0)
        # range shrinking of the weights hurts accuracy in some models
        self.percentile_range_shrink_weights = 0 #(0.01 if histogram_range else 0)
        self.idx_large_mse_for_act = 0
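
    # A minimal usage sketch (illustrative assumption, not an API reference from this
    # file): wrap a trained float model and run inference to simulate quantization.
    #
    #   model = SomeFloatModel()                                  # hypothetical model
    #   dummy_input = torch.rand(1, 3, 224, 224)
    #   model = QuantTestModule(model, dummy_input=dummy_input,
    #                           bitwidth_weights=8, bitwidth_activations=8)
    #   model.eval()
    #   with torch.no_grad():
    #       outputs = model(dummy_input)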

    def model_surgery_quantize(self, dummy_input):
        super().model_surgery_quantize(dummy_input)

        # replace each NoAct placeholder with a PAct2 module, carrying over its attributes
        def replace_func(op):
            for name, m in op._modules.items():
                if isinstance(m, layers.NoAct):
                    new_m = layers.PAct2(signed=None)
                else:
                    new_m = None
                #
                if new_m is not None:
                    for attr in dir(m):
                        value = getattr(m, attr)
                        if isinstance(value, torch.Tensor):
                            getattr(new_m, attr).data.copy_(value.data)
                        elif isinstance(value, torch.nn.Module):
                            setattr(new_m, attr, getattr(m, attr))
                        elif attr in ('weight', 'bias', 'eps', 'clips_act', 'clips_w'): # add more attribute names here if needed
                            setattr(new_m, attr, getattr(m, attr))
                    #
                    new_m.train(m.training)
                    setattr(op, name, new_m)
                #
            #
        #
        # apply recursively
        self.apply(replace_func)

        # clear
        self.clear_states()
    #
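
    # The attribute copy above carries any state the NoAct placeholder had already
    # accumulated (e.g. clips_act) into the new PAct2 module, so the surgery does
    # not discard collected range information.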

    def forward(self, inputs):
        # analyze - merge_weights is needed - so call analyze_graph() instead of just update_counters()
        self.analyze_graph(inputs=inputs, force_update=True, merge_weights=True, cleanup_states=True)

        # batch_size = inputs[0].size(0) if utils.is_list(inputs) else inputs.size(0)
        # if batch_size != 1:
        #     warnings.warn('suggest (not mandatory) to set batch size to 1 for quantized inference to simulate a realistic scenario')
        # #

        # calibration does not need gradients
        with torch.no_grad():
            # quantize
            outputs = self.forward_quantize(inputs)
            # the frame is done - save the current qparams to serve as the previous frame's qparams in the next inference
            self.get_state().qparams_prev = self.copy_qparams(self.get_state().qparams, inputs)
            # return
            return outputs
        #

    def _forward_quantize_hook(self, op, *inputs_orig):
        inputs = utils.squeeze_list(inputs_orig)
        self.start_node(op)
        self.start_quantize(op)

        # the weights are quantized in-place, so this needs to happen only on the first iteration
        if self.iter_in_epoch == 0:
            self.process_weights(op, inputs)
        #
        self.process_inputs(op, inputs, None)

        outputs = op.__forward_orig__(*inputs_orig)

        self.process_outputs(op, inputs, outputs)
        self.finish_node(op, inputs, outputs)
        return outputs
    #

    def forward_quantize(self, inputs):
        self.start_call()
        self.add_call_hook(self.module, self._forward_quantize_hook)
        self.current_scale = 1.0
        outputs = self.module(inputs)
        self.remove_call_hook(self.module)
        self.finish_call()
        return outputs
    #
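
    # For reference, a minimal sketch of the hook pattern used above. add_call_hook /
    # remove_call_hook come from the base/graph module; this is an illustrative
    # approximation of that mechanism, not this repo's exact implementation:
    #
    #   import functools
    #   def add_call_hook(module, hook):
    #       for m in module.modules():
    #           m.__forward_orig__ = m.forward
    #           m.forward = functools.partial(hook, m)
    #
    # so that every submodule call routes through _forward_quantize_hook, which
    # quantizes weights/inputs/outputs around the original forward (op.__forward_orig__).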

    # implement this in a derived class to clamp weights
    def apply_constrain_weights(self, module):
        pass

    # implement this in a derived class to do bias calibration
    def calibrate_bias(self, inputs):
        pass

    def start_quantize(self, op):
        qparams = self.get_qparams(op)
        qparams.qrange_in = []
        qparams.qrange_out = []

    def process_weights(self, module, inputs, outputs=None):
        weight = module.weight if hasattr(module, 'weight') else None
        bias = module.bias if hasattr(module, 'bias') else None
        qparams = self.get_qparams(module)
        if (self.bitwidth_weights is None) or (not qparams.quantize_w):
            return

        if qparams.quantize_w and weight is not None:
            qparams.qrange_w = Dict()
            self.quantize_weights(module, weight, qparams.qrange_w)
        else:
            qparams.qrange_w = None

        if qparams.quantize_b and bias is not None:
            qparams.qparams_b = Dict()
            self.quantize_bias(module, bias, qparams.qparams_b)
        else:
            qparams.qparams_b = None

    def process_inputs(self, module, inputs, outputs=None):
        if self.bitwidth_activations is None:
            return

        inputs = self.format_tensors(inputs)
        outputs = self.format_tensors(outputs)
        qparams = self.get_qparams(module)
        qparams_prev = self.get_qparams_prev(module)

        # track the scale across non-modules (e.g. functionals) via current_scale
        for inp in inputs:
            inp.scale = inp.scale if hasattr(inp, 'scale') else self.current_scale

        qrange_cur = self.quantize_inputs(module, inputs, outputs, qparams_prev, qparams)

        # create the current scale in process_inputs instead of process_outputs -
        # otherwise the exit condition for aggregate modules (e.g. torch.nn.Sequential, Bottleneck in ResNet) will cause trouble.
        # all the input scales are assumed to be aligned at this point (see align_inputs).
        # any module that needs special handling must be considered in quantize_inputs / align_inputs.
        has_weight_scale = (hasattr(module, 'weight') and (module.weight is not None) and hasattr(module.weight, 'scale'))
        if has_weight_scale:
            is_dw = utils.is_dwconv(module)
            use_per_channel_q = (is_dw and self.per_channel_q is True) or (self.per_channel_q == 'all')
            if use_per_channel_q:
                # a different scale for each channel
                self.current_scale = [inputs[0].scale * module.weight.scale[chan] for chan in range(module.weight.shape[0])]
            else:
                self.current_scale = (inputs[0].scale * module.weight.scale)
            #
        else:
            self.current_scale = inputs[0].scale
        #

        # update the range
        if qparams.quantize_in:
            # a running update is not possible in the first frame; from the second frame onwards it is.
            running_update = (qparams_prev is not None) and len(qparams_prev.qrange_in) > 0
            for idx, inp in enumerate(inputs):
                qrange_prev = qparams_prev.qrange_in[idx] if running_update else (0, 0)
                qrange_running = self._update_activation_ranges(module, inp, running_update, qrange_cur[idx], qrange_prev)
                qparams.qrange_in.append(qrange_running)

    def process_outputs(self, module, inputs, outputs):
        if self.bitwidth_activations is None:
            return

        inputs = self.format_tensors(inputs)
        output = self.format_tensors(outputs)
        qparams = self.get_qparams(module)
        qparams_prev = self.get_qparams_prev(module)

        # the scale has already been adjusted for the weights, in process_inputs
        for opt in output:
            opt.scale = self.current_scale

        qrange_cur = self.quantize_outputs(module, inputs, output, qparams_prev, qparams)
        self.viz_act(en=False, opt_q=output[0], opt_fl=inputs[0])
        self.current_scale = output[0].scale

        # update the range
        if qparams.quantize_out or qparams.unquantize_out:
            # a running update is not possible in the first frame; from the second frame onwards it is.
            running_update = (qparams_prev is not None) and len(qparams_prev.qrange_out) > 0
            for idx, opt in enumerate(output):
                # skip integer (e.g. index) tensors - they are not quantized
                if isinstance(opt, (torch.LongTensor, torch.cuda.LongTensor)):
                    continue
                #
                qrange_prev = qparams_prev.qrange_out[idx] if running_update else None
                qrange_running = self._update_activation_ranges(module, opt, running_update, qrange_cur[idx], qrange_prev)
                qparams.qrange_out.append(qrange_running)

        self.unquantize_outputs(module, inputs, output, qparams_prev, qparams)
        self.current_scale = output[0].scale

    def compute_tensor_range(self, module, tensor_in, percentile_range_shrink):
        if hasattr(tensor_in, 'scale') and utils.is_list(tensor_in.scale):
            # per-channel scale: undo the scale before computing the range
            scale_inv = [(1/s) for s in tensor_in.scale]
            tensor_scale_inv = torch.tensor(scale_inv).view(1, -1, 1, 1).to(tensor_in.device)
            tensor_scaled = tensor_in * tensor_scale_inv
            (mn, mx) = self._compute_tensor_range_noscale(module, tensor_scaled, percentile_range_shrink)
        else:
            scale = tensor_in.scale if hasattr(tensor_in, 'scale') else 1.0
            (mn, mx) = self._compute_tensor_range_noscale(module, tensor_in, percentile_range_shrink)
            (mn, mx) = (mn / scale, mx / scale)
        #
        return mn, mx

    def _compute_tensor_range_noscale(self, module, tensor, percentile_range_shrink):
        mn, mx = utils.extrema_fast(tensor.data, percentile_range_shrink)
        return mn, mx

    def _update_activation_ranges(self, module, tensor_in, running_update, qrange_cur, qrange_prev):
        is_calibration = (self.iter_in_epoch < self.range_calibration_offline_iters)
        update_range = (is_calibration or self.range_calibration_online)
        if update_range:
            # in the case of a fixed range module, do not expand the range
            fixed_range_module = utils.is_fixed_range(module)
            if fixed_range_module:
                qrange_running = qrange_cur
            else:
                (mn, mx) = (float(qrange_cur.min)*self.range_expansion_factor, float(qrange_cur.max)*self.range_expansion_factor)
                # a running update is not possible in the first frame; from the second frame onwards it is.
                if running_update:
                    update_factor = (1.0 / (self.iter_in_epoch + 1))
                    update_factor = max(update_factor, self.range_update_factor_min) if self.range_update_factor_min else update_factor
                    mn = update_factor * mn + (1 - update_factor) * qrange_prev.min
                    mx = update_factor * mx + (1 - update_factor) * qrange_prev.max
                #
                qrange_running = Dict()
                qrange_running.min = mn; qrange_running.max = mx
            #
        else:
            qrange_running = qrange_prev

        return qrange_running
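
    # Worked example of the running update above (illustrative numbers only):
    # at iter_in_epoch = 9, update_factor = 1/10 = 0.1 (still above range_update_factor_min);
    # with qrange_prev.max = 4.0 and a current-frame max of mx = 6.0:
    #   mx = 0.1 * 6.0 + 0.9 * 4.0 = 4.2
    # i.e. the running range drifts toward the current frame's range, and the step size
    # shrinks (down to range_update_factor_min) as calibration iterations accumulate.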

    def get_bitwidth_weights(self, module):
        bitwidth_weights_last = (self.bitwidth_weights[2] if utils.is_list(self.bitwidth_weights) else self.bitwidth_weights)
        bitwidth_weights_dw = (self.bitwidth_weights[1] if utils.is_list(self.bitwidth_weights) else self.bitwidth_weights)
        bitwidth_weights_nodw = (self.bitwidth_weights[0] if utils.is_list(self.bitwidth_weights) else self.bitwidth_weights)
        bitwidth_weights = bitwidth_weights_last if self.is_last_conv(module) else \
                           (bitwidth_weights_dw if utils.is_dwconv(module) else bitwidth_weights_nodw)
        return bitwidth_weights

    def get_bitwidth_activations(self, module):
        bitwidth_activations_last = (self.bitwidth_activations[2] if utils.is_list(self.bitwidth_activations) else self.bitwidth_activations)
        bitwidth_activations_dw = (self.bitwidth_activations[1] if utils.is_list(self.bitwidth_activations) else self.bitwidth_activations)
        bitwidth_activations_nodw = (self.bitwidth_activations[0] if utils.is_list(self.bitwidth_activations) else self.bitwidth_activations)
        bitwidth_activations = bitwidth_activations_last if self.is_last_conv(module) else \
                               (bitwidth_activations_dw if utils.is_dwconv(module) else bitwidth_activations_nodw)
        return bitwidth_activations
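
    # Note: when bitwidth_weights / bitwidth_activations are lists, the indexing in the
    # two methods above implies the convention (regular, depthwise, last). For example
    # (illustrative values, not a recommendation):
    #   bitwidth_weights = (8, 8, 16)   # 8-bit convs, 8-bit depthwise convs, 16-bit last layer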

    def quantize_weights(self, module, tensor_in, qrange):
        self.apply_constrain_weights(module)

        bitwidth_weights = self.get_bitwidth_weights(module)
        with torch.no_grad():
            is_dw = utils.is_dwconv(module)
            use_per_channel_q = (self.per_channel_q == 'all') or (bool(self.per_channel_q) and is_dw)
            if use_per_channel_q:
                qrange.min = []
                qrange.max = []
                tensor_in.scale = []
                for chan in range(tensor_in.shape[0]):
                    # Range
                    mn, mx = self.compute_tensor_range(module, tensor_in[chan], percentile_range_shrink=self.percentile_range_shrink_weights)
                    tensor_scale, clamp_limits = compute_tensor_scale(tensor_in[chan], mn, mx, bitwidth_weights, self.power2_weights)
                    qrange.min.append(mn)
                    qrange.max.append(mx)
                    # Quantize
                    tensor = symmetric_round_tensor(tensor_in[chan] * tensor_scale)
                    tensor = tensor.clamp(clamp_limits[0], clamp_limits[1]).float()
                    # Convert back to float - since this module does only simulation
                    tensor_in[chan].data[...] = (tensor.data / tensor_scale)
                    tensor_in.scale.append(1.0)
                #
            else:
                # Range
                mn, mx = self.compute_tensor_range(module, tensor_in, percentile_range_shrink=self.percentile_range_shrink_weights)
                tensor_scale, clamp_limits = compute_tensor_scale(tensor_in, mn, mx, bitwidth_weights, self.power2_weights)
                qrange.min = mn
                qrange.max = mx
                # Quantize
                tensor = symmetric_round_tensor(tensor_in * tensor_scale)
                tensor = tensor.clamp(clamp_limits[0], clamp_limits[1])
                # Convert back to float - since this module does only simulation
                tensor_in.data = (tensor.data / tensor_scale)
                tensor_in.scale = 1.0
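
    # A self-contained sketch of the per-tensor fake-quantization simulated above,
    # assuming compute_tensor_scale picks a power-of-2 scale for the symmetric range
    # (as the power2_weights flag suggests; the exact rules live in quant_utils):
    #
    #   def fake_quantize_symmetric(w, bitwidth=8):
    #       qmax = 2 ** (bitwidth - 1) - 1                           # e.g. 127 for 8 bits
    #       scale = 2.0 ** torch.floor(torch.log2(qmax / w.abs().max()))
    #       return torch.round(w * scale).clamp(-qmax, qmax) / scale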

    def quantize_bias(self, module, tensor_in, qparams):
        quant_for_bias = True
        if quant_for_bias:
            bitwidth_weights = self.get_bitwidth_weights(module)

            # use the same bitwidth as the weights
            bitwidth_bias = bitwidth_weights

            mn, mx = self.compute_tensor_range(module, tensor_in, percentile_range_shrink=0.0)
            tensor_scale, clamp_limits = compute_tensor_scale(tensor_in, mn, mx, bitwidth_bias, self.power2_weights)

            tensor = symmetric_round_tensor(tensor_in * tensor_scale)
            tensor = tensor.clamp(clamp_limits[0], clamp_limits[1])

            # Convert back to float - since this module does only simulation
            tensor_in.data = (tensor.data / tensor_scale)
            tensor_in.scale = 1.0
        else:
            tensor_in.scale = 1.0

    def quantize_inputs(self, module, input, output, qparams_prev, qparams):
        qrange_cur = []
        use_current_range = (not self.range_calibration_offline_iters) or (self.iter_in_epoch < self.range_calibration_offline_iters)
        for idx, inp in enumerate(input):
            if qparams.quantize_in:
                qrange_tensor_approx = None if use_current_range else qparams_prev.qrange_in[idx]
                qrange_tensor = self._quantize_activation(module, inp, qrange_tensor_approx)
                qrange_cur.append(qrange_tensor)

        return qrange_cur

    def quantize_outputs(self, module, input, output, qparams_prev, qparams):
        qrange_cur = []
        use_current_range = (not self.range_calibration_offline_iters) or (self.iter_in_epoch < self.range_calibration_offline_iters)
        for idx, opt in enumerate(output):
            if qparams.quantize_out:
                qrange_tensor_approx = None if use_current_range else qparams_prev.qrange_out[idx]
                qrange_tensor = self._quantize_activation(module, opt, qrange_tensor_approx)
                qrange_cur.append(qrange_tensor)

        return qrange_cur

    # implement this in a derived class if the outputs need to be un-quantized
    def unquantize_outputs(self, module, input, output, qparams_prev, qparams):
        pass

    def _quantize_activation(self, module, tensor_in, qrange):
        bitwidth_activations = self.get_bitwidth_activations(module)
        with torch.no_grad():
            if qrange:
                # after calibration, use the range obtained from the previous frame directly
                mn = qrange.min
                mx = qrange.max
            else:
                # range expansion is not required when quantizing using the current frame's range (calibration).
                # for fixed range modules, use that range directly.
                fixed_range_module = utils.is_fixed_range(module)
                if fixed_range_module:
                    op_range = utils.get_range(module)
                    mn = op_range[0]
                    mx = op_range[1]
                else:
                    mn, mx = self.compute_tensor_range(module, tensor_in, percentile_range_shrink=self.percentile_range_shrink_activations)

            tensor_scale, clamp_limits = compute_tensor_scale(None, mn, mx, bitwidth_activations, True)
            tensor = upward_round_tensor(tensor_in * tensor_scale)
            tensor = tensor.clamp(clamp_limits[0], clamp_limits[1])

            # Convert back to float - since this module does only simulation
            tensor_in.data = tensor.data / tensor_scale
            tensor_in.scale = 1.0
            qrange_tensor = Dict(); qrange_tensor.min = mn; qrange_tensor.max = mx
            return qrange_tensor

    def wt_mse_based_clip(self, tensor_wt_fl, bitwidth_weights=8):
        mn, mx = utils.extrema_fast(tensor_wt_fl)
        mn = mn.cpu().numpy()
        mx = mx.cpu().numpy()
        mx_abs = max(abs(mn), abs(mx))
        # print("******** New Wt Tensor Starts *******")
        # print("mn,mx: ", mn, mx, end=' ')
        if mx_abs == 0:
            best_clip_value = 0.0
            # print("All weights 0. Dead channel !!! Check")
        else:
            # how many clip values need to be searched?
            search_num_clips = 100
            step_size = mx_abs / search_num_clips
            # print("step_size: ", step_size)
            min_mse = float("inf")
            best_clip_value = 0
            for clip_value in np.arange(step_size, mx_abs + step_size, step_size):
                # print("Target clip_value: {:05f}".format(clip_value), end=' ')
                mse, actual_clip = self.compute_mse_for_quant(tensor_wt_fl=tensor_wt_fl, mn=-clip_value,
                                                              mx=clip_value, bitwidth_weights=bitwidth_weights)
                if mse < min_mse:
                    min_mse = mse
                    best_clip_value = actual_clip
                # print("clip_value: mse : {:05f} : {:8.5f}".format(actual_clip, mse))
            # print("******** Clip value Ends *******")
            # if min_mse < mse:
            #     print("best_clip_value: {:8.5f} min_mse : {:0.7f} clip_value: {:8.5f} mse : {:0.7f} ".format(
            #         best_clip_value, min_mse, actual_clip, mse))
        # return the symmetric range [mn, mx]
        return [-best_clip_value, best_clip_value]
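
    # Usage sketch (an illustrative call; nothing in this file invokes it):
    #   mn, mx = self.wt_mse_based_clip(module.weight.data, bitwidth_weights=8)
    # The returned range is symmetric around zero and can be passed to
    # compute_tensor_scale in place of the raw min/max extrema.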

    def compute_mse_for_quant(self, tensor_wt_fl, mn=0, mx=0, bitwidth_weights=8):
        tensor_scale, clamp_limits = compute_tensor_scale(tensor_wt_fl, mn, mx, bitwidth_weights, self.power2_weights)
        # print("mn : mx {} {}".format(mn, mx))

        # print("tensor_wt_fl: ", tensor_wt_fl.cpu().numpy().flatten()[0:20])
        tensor_wt_q = symmetric_round_tensor(tensor_wt_fl * tensor_scale)
        # print("tensor_wt_q:fl*scale ", tensor_wt_q.cpu().numpy().flatten()[0:20])
        tensor_wt_q = tensor_wt_q.clamp(clamp_limits[0], clamp_limits[1])
        # print("tensor_wt_q:clamp(flt*scale) ", tensor_wt_q.cpu().numpy().flatten()[0:20])

        # Convert back to float - since this module does only simulation
        tensor_wt_q.data = (tensor_wt_q.data / tensor_scale)
        tensor_wt_q.scale = 1.0
        # print("tensor_wt_q:Final ", tensor_wt_q.cpu().numpy().flatten()[0:20])

        mse = ((tensor_wt_fl.cpu().numpy() - tensor_wt_q.cpu().numpy()) ** 2).mean(axis=None)
        actual_clip = clamp_limits[1] / tensor_scale
        return [mse, actual_clip]

    def viz_act(self, en=False, opt_q=None, opt_fl=None):
        if not en:
            return
        opt_q = opt_q.cpu().numpy().flatten()
        opt_fl = opt_fl.cpu().numpy().flatten()
        # act_mse = ((opt_q - opt_fl)**2).mean(axis=None)
        if True: # (act_mse > 1E-4):
            # print("act_mse: {:.6f}".format(act_mse))
            if self.idx_large_mse_for_act >= 0:
                mn = opt_fl.min()
                mx = opt_fl.max()
                hist_fl = utils.hist_weight_tensor2D(x_ch=opt_fl, log=True, dir='act_study_fl',
                                                     name='act_{:03d}_fl_{:.3f}_{:.3f}'.format(self.idx_large_mse_for_act, mn, mx),
                                                     ch=0, en=True)

                mn = opt_q.min()
                mx = opt_q.max()
                hist_q = utils.hist_weight_tensor2D(x_ch=opt_q, log=True, dir='act_study_q',
                                                    name='act_{:03d}_q_{:.3f}_{:.3f}'.format(self.idx_large_mse_for_act, mn, mx),
                                                    ch=0, en=True)
                # print('hist_fl: ', hist_fl)
                # print('hist_q: ', hist_q)
                self.idx_large_mse_for_act += 1