# modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py
import warnings
import torch
import copy
from .. import utils
from .. import layers
from ..utils import AttrDict as Dict
from .hooked_module import *


class QuantGraphModule(HookedModule):
    def __init__(self, module):
        super().__init__()
        self.module = module
        self.init_qstate()
        self.num_batches_tracked = -1
        self.iter_in_epoch = -1
        self.epoch = -1
        # these are the blocks whose output we quantize for sure.
        # outputs of other blocks such as Conv2d, ConvTranspose2d, BatchNorm2d, Linear are quantized conditionally
        self.quantize_out_blocks = (torch.nn.ReLU, torch.nn.ReLU6, torch.nn.Hardtanh, layers.QAct, layers.PAct2,
                                    layers.AddBlock, layers.CatBlock, layers.MultBlock, torch.nn.MaxPool2d, torch.nn.AvgPool2d)

        # these blocks are not quantized. also, if the next block is one of these, the current block is not quantized
        self.ignore_out_blocks = (layers.NoQAct, torch.nn.Dropout2d)

        # TBD: is this required
        # # if the original module has load_weights, add it to the quant module also
        # if hasattr(module, 'load_weights'):
        #     def load_weights(m, state_dict, change_names_dict=None):
        #         utils.load_weights(m.module, state_dict, change_names_dict=change_names_dict)
        #     #
        #     self.load_weights = types.MethodType(load_weights, self)
        # #

    def init_qstate(self):
        if not hasattr(self, '__qstate__'):
            self.__qstate__ = Dict()
        #
        if 'qparams' not in self.get_qstate():
            self.get_qstate().qparams = Dict()
        #
        if 'qparams_prev' not in self.get_qstate():
            self.get_qstate().qparams_prev = Dict()
        #
        if 'analyzed_graph' not in self.get_qstate():
            self.get_qstate().analyzed_graph = False
        #

    def clear_qstate(self):
        self.__qstate__ = Dict()
        self.init_qstate()

    def get_qstate(self):
        return self.__qstate__

    def forward(self, inputs, *args, **kwargs):
        assert False, 'forward is not defined'

    def update_counters(self, force_update=False):
        self.iter_in_epoch += 1
        if self.training or force_update:
            self.num_batches_tracked += 1
            if self.iter_in_epoch == 0:
                self.epoch += 1.0
            #
        #
    #

    # force_update is used to increment the counters even when not in training mode -
    # used for validation in QuantTestModule
    def analyze_graph(self, inputs, *args, force_update=False, merge_weights=False, clear_qstate=False, **kwargs):
        with torch.no_grad():
            self.init_qstate()
            self.update_counters(force_update=force_update)
            if (self.get_qstate().analyzed_graph == False):
                # forward and analyze
                self.forward_analyze_modules(inputs, *args, **kwargs)
                # analyze the connections
                self.analyze_connections()
                self.get_qstate().analyzed_graph = True

                # merge weights so that weight quantization can be done
                if merge_weights:
                    self.merge_weights()
                #

                if clear_qstate:
                    self.clear_qstate()
                #
            #
        #

    def model_surgery_quantize(self, dummy_input, *args, **kwargs):
        # clear the states - just to be sure
        self.clear_qstate()
        # analyze
        self.analyze_graph(dummy_input, *args, **kwargs)
        # insert QAct wherever range clipping needs to be done
        self.model_surgery_activations()
        # since we might have added new activations, clear the states as they may no longer be valid
        self.clear_qstate()
        # need to call analyze_graph in the derived class
    #

    def model_surgery_activations(self):
        for module_hash, qparams in self.get_qstate().qparams.items():
            module = self.get_module(module_hash)
            if isinstance(module, layers.PAct2):
                pass
            elif qparams.quantize_out:
                if utils.is_activation(module):
                    if isinstance(module, (torch.nn.ReLU, torch.nn.ReLU6)):
                        activation_q = layers.PAct2(signed=False)
                    elif isinstance(module, torch.nn.Hardtanh):
                        activation_q = layers.PAct2(clip_range=(module.min_val, module.max_val))
                    elif isinstance(module, layers.QAct):
                        activation_q = layers.PAct2(signed=None)
                    else:
                        activation_q = layers.PAct2(signed=None)
                    #
                    # replace the existing activation by PAct2
                    parent = utils.get_parent_module(self, module)
                    name = utils.get_module_name(parent, module)
                    activation_q.train(self.training)
                    setattr(parent, name, activation_q)
                elif not hasattr(module, 'activation_q'):
                    activation_q = layers.PAct2(signed=None)
                    activation_q.train(self.training)
                    module.activation_q = activation_q
                #
            elif qparams.quantize_in:
                if not hasattr(module, 'activation_in'):
                    # TODO: set percentile_range_shrink=0.0 to avoid shrinking of input range, if needed.
                    activation_in = layers.PAct2(signed=None)
                    activation_in.train(self.training)
                    module.activation_in = activation_in
                #
            else:
                pass
            #
        #
    #

    def train(self, mode=True):
        self.iter_in_epoch = -1
        super().train(mode)

    ################################################################
    def forward_analyze_modules(self, inputs, *args, **kwargs):
        '''
        analyzing the modules needs a call hook - the call hook does not work with DataParallel.
        So, do the analysis on a copy.
        '''
        self_copy = copy.deepcopy(self)
        self_copy._forward_analyze_modules_impl(inputs, *args, **kwargs)
        self.get_qstate().qparams = self_copy.get_qstate().qparams

    def _forward_analyze_modules_impl(self, inputs, *args, **kwargs):
        self.start_call()
        self.add_call_hook(self, self._analyze_modules_op)
        forward_analyze_method_name = kwargs.pop('forward_analyze_method', None)
        if forward_analyze_method_name is not None and hasattr(self.module, forward_analyze_method_name):
            # get the bound method to be used as forward
            forward_analyze_method = getattr(self.module, forward_analyze_method_name)
            output = forward_analyze_method(inputs, *args, **kwargs)
        else:
            output = self.module(inputs, *args, **kwargs)
        #
        self.remove_call_hook(self.module)
        self.finish_call()
        return output

    def _analyze_modules_op(self, op, inputs, *args, **kwargs):
        inputs = utils.squeeze_list2(inputs)
        self.start_node(op)
        self.add_node(op, inputs)
        outputs = op.__forward_orig__(inputs, *args, **kwargs)
        self.add_node(op, inputs, outputs)
        self.finish_node(op, inputs, outputs)
        return outputs
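
    # note on add_node: it is called twice per op from _analyze_modules_op above - once before
    # the original forward (to register the node and link it to its producers via the 'last_node'
    # tag on the input tensors) and once after the forward (to tag the output tensors with this
    # node as their producer for the consumers that follow).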
    def add_node(self, module, inputs, outputs=None):
        inputs = self.format_tensors(inputs)
        module_hash = self.module_hash(module)

        if module_hash not in list(self.get_qstate().qparams.keys()):
            self.get_qstate().qparams[module_hash] = Dict()
            self.get_qstate().qparams[module_hash].qrange_w = None
            self.get_qstate().qparams[module_hash].qrange_b = None
            self.get_qstate().qparams[module_hash].qrange_in = []
            self.get_qstate().qparams[module_hash].qrange_out = []
            self.get_qstate().qparams[module_hash].is_input = (self.module is module)
            self.get_qstate().qparams[module_hash].previous_node = []
            self.get_qstate().qparams[module_hash].next_node = []
            self.get_qstate().qparams[module_hash].current_node = module_hash

        current_node = self.get_qstate().qparams[module_hash].current_node
        for inp in inputs:
            if hasattr(inp, 'qparams') and hasattr(inp.qparams, 'last_node'):
                prev_module_hash = inp.qparams.last_node
                prev_module = self.get_module(prev_module_hash)
                previous_node = self.get_qstate().qparams[module_hash].previous_node
                next_node = self.get_qstate().qparams[prev_module_hash].next_node

                if str(inp.qparams.last_node) not in [str(p) for p in previous_node]:
                    self.get_qstate().qparams[module_hash].previous_node += [inp.qparams.last_node]
                if str(current_node) not in [str(n) for n in next_node]:
                    self.get_qstate().qparams[prev_module_hash].next_node += [current_node]

        if outputs is not None:
            outputs = self.format_tensors(outputs)
            for opt in outputs:
                if not hasattr(opt, 'qparams'):
                    opt.qparams = Dict()
                #
                # update last_node if this is not a container module.
                # if this is a container module, last_node would have already been filled in by the last leaf module
                if len(module._modules) == 0:
                    opt.qparams.last_node = current_node
                #

    ################################################################
    def analyze_connections(self):
        first_module = None
        prediction_module = None
        for module_hash, qparams in self.get_qstate().qparams.items():
            module = self.get_module(module_hash)
            if utils.is_conv_deconv_linear(module) or utils.is_normalization(module) or utils.is_activation(module):
                first_module = module if first_module is None else first_module
                prediction_module = module
            #
        #
        for module_hash, qparams in self.get_qstate().qparams.items():
            module = self.get_module(module_hash)
            is_first_module = (first_module is module)
            is_prediction_module = (prediction_module is module)
            self._analyse_connections_op(module_hash, module, qparams, is_first_module, is_prediction_module)
        #
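
    # the rules below decide, per module, whether its output needs an explicit quantization /
    # range-clipping step: roughly, a block's output is quantized unless it is directly followed
    # by a single activation (or, for conv/deconv/linear, a normalization or activation) that
    # will take care of the quantization instead; modules in ignore_out_blocks are never
    # quantized at the output.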
    def _analyse_connections_op(self, module_hash, module, qparams, is_first_module, is_prediction_module):
        previous_modules = [self.get_module(p) for p in qparams.previous_node] if len(qparams.previous_node) > 0 else []
        next_modules = [self.get_module(n) for n in qparams.next_node] if len(qparams.next_node) > 0 else []

        quantize_out = False
        if isinstance(module, self.ignore_out_blocks):
            quantize_out = False
        elif utils.is_activation(module):
            if len(next_modules) == 1 and utils.is_activation(next_modules[0]):
                quantize_out = False
            else:
                quantize_out = True
            #
        elif isinstance(module, self.quantize_out_blocks):
            if len(next_modules) == 1 and utils.is_activation(next_modules[0]):
                quantize_out = False
            else:
                quantize_out = True
            #
        elif utils.is_normalization(module):
            if len(next_modules) == 1 and utils.is_activation(next_modules[0]):
                quantize_out = False
            else:
                quantize_out = True
            #
        elif utils.is_conv(module) or utils.is_deconv(module):
            if len(next_modules) == 1 and (utils.is_normalization(next_modules[0]) or utils.is_activation(next_modules[0])):
                quantize_out = False
            else:
                quantize_out = True
            #
        elif utils.is_linear(module):
            if len(next_modules) == 1 and (utils.is_normalization(next_modules[0]) or utils.is_activation(next_modules[0])):
                quantize_out = False
            else:
                quantize_out = True
            #
        # elif isinstance(module, (torch.nn.AdaptiveAvgPool2d, torch.nn.Upsample, layers.ResizeTo, torch.nn.Flatten)):
        #     quantize_out = True
        # #

        if len(qparams.previous_node) > 0:
            previous_module_hash = qparams.previous_node[-1]
            previous_module = self.get_module(previous_module_hash)
            previous_module_qparams = self.get_qstate().qparams[previous_module_hash]
            is_input_ignored = isinstance(previous_module, self.ignore_out_blocks)
            is_input_quantized = previous_module_qparams.quantize_out if \
                hasattr(previous_module_qparams, 'quantize_out') else False
        else:
            is_input_ignored = False
            is_input_quantized = False
        #

        quantize_in = utils.is_conv_deconv_linear(module) and not is_input_quantized and \
            not is_input_ignored and is_first_module
        qparams.quantize_w = utils.is_conv_deconv_linear(module)    # weights of all conv/deconv/linear layers will be quantized
        qparams.quantize_b = utils.is_conv_deconv_linear(module)    # biases of all conv/deconv/linear layers will be quantized
        qparams.quantize_out = quantize_out                         # selectively quantize the output
        qparams.quantize_in = quantize_in                           # only the top module's input needs to be quantized
        multi_input_blocks = (layers.AddBlock, layers.CatBlock, torch.nn.AdaptiveAvgPool2d)
        qparams.align_in = isinstance(module, multi_input_blocks)   # all input tensors have to be brought to the same q
        qparams.scale_in = 64.0 if isinstance(module, torch.nn.AdaptiveAvgPool2d) else 1.0  # additional scaleup to simulate fixed point
        qparams.unquantize_out = qparams.is_input                   # only the top module's output needs to be unquantized
        qparams.is_dwconv = utils.is_dwconv(module)
        qparams.next_modules = next_modules
        qparams.is_first_module = is_first_module
        qparams.is_prediction_module = is_prediction_module

    ################################################################
    def merge_weights(self, make_backup=False):
        assert self.get_qstate().analyzed_graph == True, 'graph must be analyzed before merge_weights()'
        with torch.no_grad():
            for module_hash, qparams in self.get_qstate().qparams.items():
                module = self.get_module(module_hash)
                self._merge_weight_op(module_hash, module, qparams, make_backup)
            #
        #
    #
    def _merge_weight_op(self, module_hash, module, qparams, make_backup):
        is_conv = utils.is_conv_deconv(module)

        # note: we consider merging only if there is a single next node
        next_module = qparams.next_modules[0] if len(qparams.next_modules) == 1 else None
        next_bn = isinstance(next_module, torch.nn.BatchNorm2d) if (next_module is not None) else None

        # if the next module is a bn, apply the bn merging step
        if is_conv and next_bn:
            conv = module
            bn = next_module

            # weight/bias
            conv_bias = conv.bias.data if (conv.bias is not None) else 0.0
            bn_weight = bn.weight.data if bn.affine else 1.0
            bn_bias = bn.bias.data if bn.affine else 0.0

            # merged weight and offset
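            # the folding follows from BN(x) = bn_weight * (x - running_mean) / sqrt(running_var + eps) + bn_bias:
            # with scale = bn_weight / sqrt(running_var + eps), the merged layer computes
            #     weight_merged = conv.weight * scale,  bias_merged = (conv.bias - running_mean) * scale + bn_bias
            # which is what merged_scale / merged_weight / merged_bias implement below.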
            merged_scale = torch.rsqrt(bn.running_var.data + bn.eps) * bn_weight
            if utils.is_conv(conv):
                merged_scale = merged_scale.view(-1, 1, 1, 1)
            elif utils.is_deconv(conv):
                merged_scale = merged_scale.view(1, -1, 1, 1)
            else:
                assert False, 'unable to merge convolution and BN'
            #
            merged_weight = conv.weight.data * merged_scale
            merged_bias = (conv_bias - bn.running_mean.data) * merged_scale.view(-1) + bn_bias

            # bn is set to unity
            bn.running_mean.data.fill_(0.0)
            bn.running_var.data.fill_(1.0 - bn.eps)
            if bn.affine:
                bn.weight.data.fill_(1.0)
                bn.bias.data.fill_(0.0)
            #

            # copy merged weights to conv
            conv.weight.data.copy_(merged_weight)

            # copy merged bias
            if conv.bias is not None:
                conv.bias.data.copy_(merged_bias)
            elif bn.affine:
                bn.bias.data.copy_(merged_bias.data)
            else:
                warnings.warn('problem detected in conv+bn layer pair: either conv.bias or bn.affine is required for successful calibration - preferably bn.affine')
            #
        #
        return

    ################################################################
    def get_qparams(self, module):
        module_hash = self.module_hash(module)
        return self.get_qstate().qparams[module_hash]

    def get_qparams_prev(self, module):
        module_hash = self.module_hash(module)
        return self.get_qstate().qparams_prev[module_hash] if self.get_qstate().qparams_prev else None

    def start_call(self):
        self.call_count = Dict()

    def finish_call(self):
        self.call_count = None

    def start_node(self, module):
        module_name = self.module_name(module)
        if module_name not in list(self.call_count.keys()):
            self.call_count[module_name] = 0
        #
        return

    def finish_node(self, module, inputs, outputs):
        module_name = self.module_name(module)
        self.call_count[module_name] = self.call_count[module_name] + 1
        return

    def module_hash(self, module):
        '''
        A module may be called multiple times in a model. This method creates a unique name/hash for each call
        using the call_count. call_count needs to be correct for this to work as expected.
        call_count is kept up to date by the start_node() / finish_node() calls.
        '''
        module_name = self.module_name(module)
        module_hash = module_name + '-call:{}'.format(self.call_count[module_name])
        return module_hash

    def module_name(self, module):
        name = None
        for n, m in self.named_modules():
            if m is module:
                name = n
            #
        #
        return name

    def get_module(self, module_hash):
        module_name = module_hash.split('-call:')[0]
        for mname, mod in self.named_modules():
            if module_name == mname:
                return mod
            #
        return None

    def is_last_conv(self, module):
        # implementation is not correct. disable it for the time being
        return False  # (module is self.last_conv_linear_module)

    def format_tensors(self, inputs):
        # make inputs a list/tuple if it is not. if it is a double list, remove the extra one
        inputs = utils.squeeze_list2(utils.make_list(inputs))
        # keep only tensors - drop any remaining lists/tuples and other non-tensors
        inputs = [ipt for ipt in inputs if utils.is_tensor(ipt)]
        return inputs

    def copy_qparams(self, qparams, inputs):
        qparams_copy = Dict()
        for module_hash, qparam_entry in qparams.items():
            qparams_copy[module_hash] = Dict()
            for key, value in qparam_entry.items():
                # deep copy may not work in some cases, so do it conditionally
                try:
                    qparams_copy[module_hash][key] = copy.deepcopy(value)
                except Exception:
                    qparams_copy[module_hash][key] = value
                #
            #
        #
        return qparams_copy
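

# Example usage (a minimal sketch, not part of the module: the model, input size and keyword
# below are illustrative only. In this devkit, QuantGraphModule is normally used through derived
# wrappers such as QuantTestModule that define forward(); forward() here is intentionally undefined):
#
#   import torchvision
#   model = torchvision.models.mobilenet_v2()
#   quant_model = QuantGraphModule(model)
#   dummy_input = torch.rand(1, 3, 224, 224)
#   quant_model.analyze_graph(dummy_input, merge_weights=True)   # build the graph and fold BN into conv
#   qparams = quant_model.get_qstate().qparams                   # per-module quantization decisions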