# jacinto-ai/pytorch-jacinto-ai-devkit: modules/pytorch_jacinto_ai/xnn/quantize/quant_test_module.py
import torch
import math
import copy
import warnings
import numpy as np

########################################################################
from .quant_train_module import *

class QuantTestModule(QuantTrainModule):
    def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
                 histogram_range=True, bias_calibration=False, constrain_weights=None, model_surgery_quantize=True,
                 power2_weight_range=None, power2_activation_range=None, constrain_bias=None):
        constrain_weights = (not per_channel_q) if constrain_weights is None else constrain_weights
        super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                         per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
                         constrain_weights=constrain_weights,
                         power2_weight_range=power2_weight_range, power2_activation_range=power2_activation_range, constrain_bias=constrain_bias)
        assert model_surgery_quantize == True, f'{self.__class__.__name__} does not support model_surgery_quantize=False. Please use a QAT or calibrated module instead.'
        self.eval()

    def train(self, mode=True):
        assert mode == False, 'QuantTestModule cannot be used in train mode'
        super().train(mode)
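
# A minimal usage sketch (not from this file - MyModel and the input shape are
# placeholder assumptions): wrap a trained float model and run simulated-quantized
# inference; the module forces eval mode and asserts if train mode is requested.
#
#   model = MyModel()
#   dummy_input = torch.rand((1, 3, 224, 224))
#   model_q = QuantTestModule(model, dummy_input=dummy_input,
#                             bitwidth_weights=8, bitwidth_activations=8)
#   with torch.no_grad():
#       outputs = model_q(dummy_input)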

########################################################################
from .quant_base_module import *
from .quant_utils import *

class QuantEstimateModule(QuantBaseModule):
    '''
    QuantEstimateModule can be used to estimate the quantization accuracy of a float model
    that has not gone through QAT or calibration. However, this is an approximate method.
    '''
    def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
                 histogram_range=True, range_calibration_online=False, model_surgery_quantize=True,
                 power2_weight_range=None, power2_activation_range=None, constrain_bias=None):
        super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                         per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=False,
                         constrain_weights=False, model_surgery_quantize=model_surgery_quantize,
                         power2_weight_range=power2_weight_range, power2_activation_range=power2_activation_range, constrain_bias=constrain_bias)
        assert False, 'QuantEstimateModule is deprecated - it is recommended to use QuantTestModule instead'
        # whether to do online adjustment of the calibrated ranges using the previous frame's range
        self.range_calibration_online = range_calibration_online
        # number of offline calibration iterations; during offline calibration, the current frame's range is used
        self.range_calibration_offline_iters = 25 #10
        # minimum speed for the range update
        self.range_update_factor_min = 0.001 #0.1
        # range expansion is not needed now, as the ranges are not computed from the actual floating point values.
        # earlier they were based on quantized values - that is when the expansion was needed.
        self.range_expansion_factor = 1.0

        # set these to 0 to use faster min/max based range computation (lower accuracy) instead of histogram based range.
        # shrink range: 0.01 means 0.01 percentile_range_shrink, not 1 percentile_range_shrink
        self.percentile_range_shrink_activations = (layers.PAct2.PACT2_RANGE_SHRINK if histogram_range else 0)
        # range shrinking of weights hurts accuracy in some models
        self.percentile_range_shrink_weights = 0 #(0.01 if histogram_range else 0)

        self.idx_large_mse_for_act = 0

    def model_surgery_quantize(self, dummy_input):
        super().model_surgery_quantize(dummy_input)

        def replace_func(op):
            for name, m in op._modules.items():
                if isinstance(m, layers.QAct):
                    new_m = layers.PAct2(signed=None)
                else:
                    new_m = None
                #
                if new_m is not None:
                    for attr in dir(m):
                        value = getattr(m, attr)
                        if isinstance(value, torch.Tensor) and value is not None:
                            getattr(new_m, attr).data.copy_(value.data)
                        elif isinstance(value, torch.nn.Module) and value is not None:
                            setattr(new_m, attr, getattr(m, attr))
                        elif attr in ('weight', 'bias', 'eps', 'clips_act', 'clips_w'):  # add more attribute names here
                            setattr(new_m, attr, getattr(m, attr))
                    #
                    new_m.train(m.training)
                    setattr(op, name, new_m)
                #
            #
        #
        # apply recursively
        self.apply(replace_func)

        # clear
        self.clear_qstate()
    #

    def forward(self, inputs):
        # analyze - we need to merge_weights - so call analyze_graph() instead of just update_counters()
        self.analyze_graph(inputs=inputs, force_update=True, merge_weights=True)

        # batch_size = inputs[0].size(0) if utils.is_list(inputs) else inputs.size(0)
        # if batch_size != 1:
        #     warnings.warn('it is suggested (not mandatory) to set the batch size to 1 for quantized inference to simulate a realistic scenario')
        # #

        # calibration does not need gradients
        with torch.no_grad():
            # quantize
            outputs = self.forward_quantize(inputs)
            # at the start of a new frame, keep a copy of the qparams from the previous frame of inference
            self.get_qstate().qparams_prev = self.copy_qparams(self.get_qstate().qparams, inputs)
            # return
            return outputs
        #

    def _forward_quantize_hook(self, op, *inputs_orig):
        inputs = utils.squeeze_list2(inputs_orig)
        self.start_node(op)
        self.start_quantize(op)

        if (self.iter_in_epoch == 0):
            self.process_weights(op, inputs)
        #
        self.process_inputs(op, inputs, None)

        outputs = op.__forward_orig__(*inputs_orig)

        self.process_outputs(op, inputs, outputs)
        self.finish_node(op, inputs, outputs)
        return outputs
    #
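    # Note on the hook above: it wraps each leaf module's original forward.
    # process_weights runs only on the first iteration of the epoch (weights do not
    # change at test time), while input/output ranges are processed on every frame.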

    def forward_quantize(self, inputs):
        self.start_call()
        self.add_call_hook(self.module, self._forward_quantize_hook)
        self.current_scale = 1.0
        outputs = self.module(inputs)
        self.remove_call_hook(self.module)
        self.finish_call()
        return outputs
    #

    # implement this in a derived class to clamp weights
    def apply_constrain_weights(self, module):
        pass

    # implement this in a derived class to do bias calibration
    def calibrate_bias(self, inputs):
        pass

    def start_quantize(self, op):
        qparams = self.get_qparams(op)
        qparams.qrange_in = []
        qparams.qrange_out = []

    def process_weights(self, module, inputs, outputs=None):
        weight = module.weight if hasattr(module, 'weight') else None
        bias = module.bias if hasattr(module, 'bias') else None
        qparams = self.get_qparams(module)
        if (self.bitwidth_weights is None) or (not qparams.quantize_w):
            return

        if qparams.quantize_w and weight is not None:
            qparams.qrange_w = Dict()
            self.quantize_weights_tensor(module, weight, qparams.qrange_w)
        else:
            qparams.qrange_w = None

        if qparams.quantize_b and bias is not None:
            qparams.qparams_b = Dict()
            self.quantize_bias_tensor(module, bias, qparams.qparams_b)
        else:
            qparams.qparams_b = None

    def process_inputs(self, module, inputs, outputs=None):
        if self.bitwidth_activations is None:
            return

        inputs = self.format_tensors(inputs)
        outputs = self.format_tensors(outputs)
        qparams = self.get_qparams(module)
        qparams_prev = self.get_qparams_prev(module)

        # track the scale across non-modules (eg. functionals) via current_scale
        for inp in inputs:
            inp.scale = inp.scale if hasattr(inp, 'scale') else self.current_scale

        qrange_cur = self.quantize_input_tensors(module, inputs, outputs, qparams_prev, qparams)

        # create the current scale in process_inputs instead of process_outputs -
        # otherwise the exit condition for aggregate modules (eg. torch.nn.Sequential, Bottleneck in ResNet) will cause trouble.
        # all the input scales are assumed to be aligned at this point (see align_input_tensors).
        # any module that needs special handling must be considered in quantize_input_tensors / align_input_tensors.
        has_weight_scale = (hasattr(module, 'weight') and (module.weight is not None) and hasattr(module.weight, 'scale'))
        if has_weight_scale:
            is_dw = utils.is_dwconv(module)
            use_per_channel_q = (is_dw and self.per_channel_q is True) or (self.per_channel_q == 'all')
            if use_per_channel_q:
                # different scales for different channels
                self.current_scale = [inputs[0].scale * module.weight.scale[chan] for chan in range(module.weight.shape[0])]
            else:
                self.current_scale = (inputs[0].scale * module.weight.scale)
            #
        else:
            self.current_scale = inputs[0].scale
        #

        # update range
        if qparams.quantize_in:
            # in the first frame we cannot do a running update; after that we can.
            running_update = (qparams_prev is not None) and len(qparams_prev.qrange_in) > 0
            for idx, inp in enumerate(inputs):
                qrange_prev = qparams_prev.qrange_in[idx] if running_update else (0, 0)
                qrange_running = self._update_activation_ranges(module, inp, running_update, qrange_cur[idx], qrange_prev)
                qparams.qrange_in.append(qrange_running)

    def process_outputs(self, module, inputs, outputs):
        if self.bitwidth_activations is None:
            return

        inputs = self.format_tensors(inputs)
        output = self.format_tensors(outputs)
        qparams = self.get_qparams(module)
        qparams_prev = self.get_qparams_prev(module)

        # the scale has already been adjusted for the weights in process_inputs
        for idx, opt in enumerate(output):
            opt.scale = self.current_scale

        qrange_cur = self.quantize_output_tensors(module, inputs, output, qparams_prev, qparams)
        self.viz_act(en=False, opt_q=output[0], opt_fl=inputs[0])
        self.current_scale = output[0].scale

        # update range
        if qparams.quantize_out or qparams.unquantize_out:
            # in the first frame we cannot do a running update; after that we can.
            running_update = (qparams_prev is not None) and len(qparams_prev.qrange_out) > 0
            for idx, opt in enumerate(output):
                if isinstance(opt, (torch.LongTensor, torch.cuda.LongTensor)):
                    continue
                #
                qrange_prev = qparams_prev.qrange_out[idx] if running_update else None
                qrange_running = self._update_activation_ranges(module, opt, running_update, qrange_cur[idx], qrange_prev)
                qparams.qrange_out.append(qrange_running)

        self.unquantize_outputs(module, inputs, output, qparams_prev, qparams)
        self.current_scale = output[0].scale

    def compute_tensor_range(self, module, tensor_in, percentile_range_shrink):
        if hasattr(tensor_in, 'scale') and utils.is_list(tensor_in.scale):
            # per-channel scale: undo the scaling before measuring the range
            scale_inv = [(1 / s) for s in tensor_in.scale]
            tensor_scale_inv = torch.tensor(scale_inv).view(1, -1, 1, 1).to(tensor_in.device)
            tensor_scaled = tensor_in * tensor_scale_inv
            (mn, mx) = self._compute_tensor_range_noscale(module, tensor_scaled, percentile_range_shrink)
        else:
            scale = tensor_in.scale if hasattr(tensor_in, 'scale') else 1.0
            (mn, mx) = self._compute_tensor_range_noscale(module, tensor_in, percentile_range_shrink)
            (mn, mx) = (mn / scale, mx / scale)
        #
        return mn, mx

    def _compute_tensor_range_noscale(self, module, tensor, percentile_range_shrink):
        mn, mx = utils.extrema_fast(tensor.data, percentile_range_shrink)
        return mn, mx

    def _update_activation_ranges(self, module, tensor_in, running_update, qrange_cur, qrange_prev):
        is_calibration = (self.iter_in_epoch < self.range_calibration_offline_iters)
        update_activation_range = (is_calibration or self.range_calibration_online)
        if update_activation_range:
            # in the case of a fixed range module, we do not expand the range
            fixed_range_module = utils.is_fixed_range(module)
            if fixed_range_module:
                qrange_running = qrange_cur
            else:
                (mn, mx) = (float(qrange_cur.min) * self.range_expansion_factor, float(qrange_cur.max) * self.range_expansion_factor)
                # in the first frame we cannot do a running update; after that we can.
                if running_update:
                    update_factor = (1.0 / (self.iter_in_epoch + 1))
                    update_factor = max(update_factor, self.range_update_factor_min) if self.range_update_factor_min else update_factor
                    mn = update_factor * mn + (1 - update_factor) * qrange_prev.min
                    mx = update_factor * mx + (1 - update_factor) * qrange_prev.max
                #
                qrange_running = Dict()
                qrange_running.min = mn; qrange_running.max = mx
            #
        else:
            qrange_running = qrange_prev

        return qrange_running
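    # Illustrative numbers for the running update above: at iter_in_epoch=3,
    # update_factor = max(1/4, 0.001) = 0.25, so mn = 0.25*mn_cur + 0.75*mn_prev.
    # The 1/(iter+1) factor accumulates an arithmetic mean of the per-frame ranges;
    # once it falls below range_update_factor_min, the update becomes an exponential
    # moving average with a fixed weight of 0.001 on the current frame.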

    def get_bitwidth_weights(self, module):
        bitwidth_weights_last = (self.bitwidth_weights[2] if utils.is_list(self.bitwidth_weights) else self.bitwidth_weights)
        bitwidth_weights_dw = (self.bitwidth_weights[1] if utils.is_list(self.bitwidth_weights) else self.bitwidth_weights)
        bitwidth_weights_nodw = (self.bitwidth_weights[0] if utils.is_list(self.bitwidth_weights) else self.bitwidth_weights)
        bitwidth_weights = bitwidth_weights_last if self.is_last_conv(module) else \
                           (bitwidth_weights_dw if utils.is_dwconv(module) else bitwidth_weights_nodw)
        return bitwidth_weights

    def get_bitwidth_activations(self, module):
        bitwidth_activations_last = (self.bitwidth_activations[2] if utils.is_list(self.bitwidth_activations) else self.bitwidth_activations)
        bitwidth_activations_dw = (self.bitwidth_activations[1] if utils.is_list(self.bitwidth_activations) else self.bitwidth_activations)
        bitwidth_activations_nodw = (self.bitwidth_activations[0] if utils.is_list(self.bitwidth_activations) else self.bitwidth_activations)
        bitwidth_activations = bitwidth_activations_last if self.is_last_conv(module) else \
                               (bitwidth_activations_dw if utils.is_dwconv(module) else bitwidth_activations_nodw)
        return bitwidth_activations
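    # For example (hypothetical values), bitwidth_weights=(8, 8, 16) would use 8 bits
    # for regular convolutions (index 0), 8 bits for depthwise convolutions (index 1)
    # and 16 bits for the last convolution layer (index 2); a plain integer applies
    # the same bitwidth everywhere.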

    def quantize_weights_tensor(self, module, tensor_in, qrange):
        self.apply_constrain_weights(module)

        bitwidth_weights = self.get_bitwidth_weights(module)
        with torch.no_grad():
            is_dw = utils.is_dwconv(module)
            use_per_channel_q = (self.per_channel_q == 'all' or (bool(self.per_channel_q) == True and is_dw))
            if use_per_channel_q:
                qrange.min = []
                qrange.max = []
                tensor_in.scale = []
                for chan in range(tensor_in.shape[0]):
                    # range
                    mn, mx = self.compute_tensor_range(module, tensor_in[chan], percentile_range_shrink=self.percentile_range_shrink_weights)
                    tensor_scale, clamp_limits = compute_tensor_scale(tensor_in[chan], mn, mx, bitwidth_weights, self.power2_weight_range)
                    qrange.min.append(mn)
                    qrange.max.append(mx)
                    # quantize
                    tensor = symmetric_round_tensor(tensor_in[chan] * tensor_scale)
                    tensor = tensor.clamp(clamp_limits[0], clamp_limits[1]).float()
                    # convert back to float - since this module only does simulation
                    tensor_in[chan].data[...] = (tensor.data / tensor_scale)
                    tensor_in.scale.append(1.0)
                #
            else:
                # range
                mn, mx = self.compute_tensor_range(module, tensor_in, percentile_range_shrink=self.percentile_range_shrink_weights)
                tensor_scale, clamp_limits = compute_tensor_scale(tensor_in, mn, mx, bitwidth_weights, self.power2_weight_range)
                qrange.min = mn
                qrange.max = mx
                # quantize
                tensor = symmetric_round_tensor(tensor_in * tensor_scale)
                tensor = tensor.clamp(clamp_limits[0], clamp_limits[1])
                # convert back to float - since this module only does simulation
                tensor_in.data = (tensor.data / tensor_scale)
                tensor_in.scale = 1.0
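    # Worked round trip for the simulation above (assuming, for illustration, that
    # compute_tensor_scale returns a symmetric 8-bit scale of 127/max_abs): with
    # max_abs=0.5 the scale is 254, so a weight of 0.37 becomes round(0.37*254)=94,
    # is clamped to [-127, 127], and is written back as 94/254 = 0.3701... - the
    # tensor stays in float, but only takes values on the integer grid.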

    def quantize_bias_tensor(self, module, tensor_in, qparams):
        quant_for_bias = True
        if quant_for_bias:
            bitwidth_weights = self.get_bitwidth_weights(module)

            # use the same bitwidth as the weights
            bitwidth_bias = bitwidth_weights

            mn, mx = self.compute_tensor_range(module, tensor_in, percentile_range_shrink=0.0)
            tensor_scale, clamp_limits = compute_tensor_scale(tensor_in, mn, mx, bitwidth_bias, self.power2_weight_range)

            # --
            tensor = symmetric_round_tensor(tensor_in * tensor_scale)
            tensor = tensor.clamp(clamp_limits[0], clamp_limits[1])

            # convert back to float - since this module only does simulation
            tensor_in.data = (tensor.data / tensor_scale)
            tensor_in.scale = 1.0
        else:
            tensor_in.scale = 1.0

    def quantize_input_tensors(self, module, input, output, qparams_prev, qparams):
        qrange_cur = []
        use_current_range = (not self.range_calibration_offline_iters) or (self.iter_in_epoch < self.range_calibration_offline_iters)
        for idx, inp in enumerate(input):
            if qparams.quantize_in:
                qrange_tensor_approx = None if use_current_range else qparams_prev.qrange_in[idx]
                qrange_tensor = self._quantize_activation(module, inp, qrange_tensor_approx)
                qrange_cur.append(qrange_tensor)

        return qrange_cur

    def quantize_output_tensors(self, module, input, output, qparams_prev, qparams):
        qrange_cur = []
        use_current_range = (not self.range_calibration_offline_iters) or (self.iter_in_epoch < self.range_calibration_offline_iters)
        for idx, opt in enumerate(output):
            if qparams.quantize_out:
                qrange_tensor_approx = None if use_current_range else qparams_prev.qrange_out[idx]
                qrange_tensor = self._quantize_activation(module, opt, qrange_tensor_approx)
                qrange_cur.append(qrange_tensor)

        return qrange_cur

    def unquantize_outputs(self, module, input, output, qparams_prev, qparams):
        pass

    def _quantize_activation(self, module, tensor_in, qrange):
        bitwidth_activations = self.get_bitwidth_activations(module)
        with torch.no_grad():
            if qrange:
                # after calibration, use the range obtained from the previous frame directly
                mn = qrange.min
                mx = qrange.max
            else:
                # range expansion is not required when quantizing using the current frame's range (calibration).
                # for fixed range modules, that range is used directly.
                fixed_range_module = utils.is_fixed_range(module)
                if fixed_range_module:
                    op_range = utils.get_range(module)
                    mn = op_range[0]
                    mx = op_range[1]
                else:
                    mn, mx = self.compute_tensor_range(module, tensor_in, percentile_range_shrink=self.percentile_range_shrink_activations)

            tensor_scale, clamp_limits = compute_tensor_scale(None, mn, mx, bitwidth_activations, True)
            tensor = upward_round_tensor(tensor_in * tensor_scale)
            tensor = tensor.clamp(clamp_limits[0], clamp_limits[1])

            # convert back to float - since this module only does simulation
            tensor_in.data = tensor.data / tensor_scale
            tensor_in.scale = 1.0
            qrange_tensor = Dict(); qrange_tensor.min = mn; qrange_tensor.max = mx
            return qrange_tensor

    def wt_mse_based_clip(self, tensor_wt_fl, bitwidth_weights=8):
        mn, mx = utils.extrema_fast(tensor_wt_fl)
        mn = mn.cpu().numpy()
        mx = mx.cpu().numpy()
        mx_abs = max(abs(mn), abs(mx))
        if mx_abs == 0:
            # all weights are 0 - a dead channel
            best_clip_value = 0.0
        else:
            # how many clip values need to be searched?
            search_num_clips = 100
            step_size = mx_abs / search_num_clips
            min_mse = float("inf")
            best_clip_value = 0
            for clip_value in np.arange(step_size, mx_abs + step_size, step_size):
                [mse, actual_clip] = self.compute_mse_for_quant(tensor_wt_fl=tensor_wt_fl, mn=-clip_value,
                                                                mx=clip_value, bitwidth_weights=bitwidth_weights)
                if mse < min_mse:
                    min_mse = mse
                    best_clip_value = actual_clip
        # returns [mn, mx]
        return [-best_clip_value, best_clip_value]
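    # Hypothetical usage sketch: a derived class could use this to clamp weights to
    # the MSE-optimal symmetric range before quantization, e.g.:
    #
    #   def apply_constrain_weights(self, module):
    #       if hasattr(module, 'weight') and module.weight is not None:
    #           clip_min, clip_max = self.wt_mse_based_clip(module.weight.data,
    #                                                       bitwidth_weights=self.get_bitwidth_weights(module))
    #           module.weight.data.clamp_(clip_min, clip_max)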

    def compute_mse_for_quant(self, tensor_wt_fl, mn=0, mx=0, bitwidth_weights=8):
        tensor_scale, clamp_limits = compute_tensor_scale(tensor_wt_fl, mn, mx, bitwidth_weights, self.power2_weight_range)

        tensor_wt_q = symmetric_round_tensor(tensor_wt_fl * tensor_scale)
        tensor_wt_q = tensor_wt_q.clamp(clamp_limits[0], clamp_limits[1])

        # convert back to float - since this module only does simulation
        tensor_wt_q.data = (tensor_wt_q.data / tensor_scale)
        tensor_wt_q.scale = 1.0

        mse = ((tensor_wt_fl.cpu().numpy() - tensor_wt_q.cpu().numpy()) ** 2).mean(axis=None)
        actual_clip = clamp_limits[1] / tensor_scale
        return [mse, actual_clip]

    def viz_act(self, en=False, opt_q=[], opt_fl=[]):
        if not en:
            return
        opt_q = opt_q.cpu().numpy().flatten()
        opt_fl = opt_fl.cpu().numpy().flatten()
        # act_mse = ((opt_q - opt_fl)**2).mean(axis=None)
        if True:  # (act_mse > 1E-4):
            if (self.idx_large_mse_for_act >= 0):
                mn = opt_fl.min()
                mx = opt_fl.max()
                hist_fl = utils.hist_weight_tensor2D(x_ch=opt_fl, log=True, dir='act_study_fl',
                                                     name='act_{:03d}_fl_{:.3f}_{:.3f}'.format(self.idx_large_mse_for_act, mn, mx),
                                                     ch=0, en=True)

                mn = opt_q.min()
                mx = opt_q.max()
                hist_q = utils.hist_weight_tensor2D(x_ch=opt_q, log=True, dir='act_study_q',
                                                    name='act_{:03d}_q_{:.3f}_{:.3f}'.format(self.idx_large_mse_for_act, mn, mx),
                                                    ch=0, en=True)
                self.idx_large_mse_for_act += 1
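
########################################################################
# Illustrative calibration loop for QuantEstimateModule (a sketch, not from this
# file - `model` and `val_loader` are placeholder assumptions; note that the
# assert in __init__ above deliberately blocks instantiation, so this only shows
# the intended data flow and QuantTestModule is the recommended path). The first
# range_calibration_offline_iters frames estimate activation ranges from the
# current frame; later frames reuse (and, with range_calibration_online=True,
# keep refining) the ranges carried over in qparams_prev.
#
#   model_q = QuantEstimateModule(model, dummy_input=dummy_input,
#                                 range_calibration_online=True)
#   with torch.no_grad():
#       for images, _ in val_loader:
#           outputs = model_q(images)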