import warnings
import torch
import copy
from .. import utils
from .. import layers
from ..utils import AttrDict as Dict
from .hooked_module import *


class QuantGraphModule(HookedModule):
    # instance member states are not retained across forward calls when using DataParallel
    # as a workaround, use class member variables instead, so that these can be retained
    states = Dict()
    def __init__(self, module):
        super().__init__()
        self.module = module
        self.init_states()
        self.register_buffer('num_batches_tracked', torch.tensor(-1.0))
        self.register_buffer('iter_in_epoch', torch.tensor(-1.0))
        self.register_buffer('epoch', torch.tensor(-1.0))

        # TBD: is this required
        # # if the original module has load_weights, add it to the quant module also
        # if hasattr(module, 'load_weights'):
        #     def load_weights(m, state_dict, change_names_dict=None):
        #         utils.load_weights(m.module, state_dict, change_names_dict=change_names_dict)
        #     #
        #     self.load_weights = types.MethodType(load_weights, self)
        # #

    # create the state object required to keep some quantization parameters that need to be preserved
    # if cuda() has been called on the module, copy the states from the one that was created for cpu
    def get_state(self):
        states = self.get_states()
        module_device = self.module_device(self.module)
        if module_device not in states:
            module_device_src = None
            for key, value in states.items():
                if key.type == 'cpu':
                    module_device_src = key
            #
            if module_device_src is not None and module_device_src in states:
                states[module_device] = copy.deepcopy(states[module_device_src])
            else:
                states[module_device] = Dict()
        #
        return states[module_device]

    def get_states(self):
        return __class__.states

    def clear_states(self):
        __class__.states = Dict()

    # these entries will prevent this module from being used with DataParallel - clean them up
    def cleanup_states(self):
        assert self.get_state().analyzed_graph == True, 'graph must be analyzed before cleanup_states()'
        with torch.no_grad():
            for module_hash, qparams in self.get_state().qparams.items():
                if hasattr(qparams, 'previous_node'):
                    del qparams.previous_node
                #
                if hasattr(qparams, 'previous_module'):
                    del qparams.previous_module
                #
                if hasattr(qparams, 'next_node'):
                    del qparams.next_node
                #
                if hasattr(qparams, 'next_module'):
                    del qparams.next_module
                #
            #
        #

    # data parallel does not initialize the replicas correctly, so explicitly initialize them.
    # there is no use in doing this only in __init__; it has to be done in forward as well, even if it was already done in __init__.
    def init_states(self):
        if 'qparams' not in self.get_state().keys():
            self.get_state().qparams = Dict()
        if 'qparams_prev' not in self.get_state().keys():
            self.get_state().qparams_prev = Dict()
        if 'analyzed_graph' not in self.get_state().keys():
            self.get_state().analyzed_graph = False

    def forward(self, inputs):
        assert False, 'forward is not defined'

    def update_counters(self, force_update=False):
        if self.training or force_update:
            self.num_batches_tracked += 1
            if self.num_batches_tracked == 0:
                self.epoch += 1.0
            #
        #
        self.iter_in_epoch += 1
    #

    # force_update is used to increment the counters even in non-training mode
    # used for validation in QuantTestModule
    def analyze_graph(self, inputs, force_update=False, merge_weights=False, cleanup_states=False):
        with torch.no_grad():
            self.init_states()
            self.update_counters(force_update=force_update)
            if (self.get_state().analyzed_graph == False):
                # forward and analyze
                self.forward_analyze_modules(inputs)
                # analyze the connections
                self.analyze_connections()
                self.get_state().analyzed_graph = True

                # merge weights so that weight quantization can be done
                if merge_weights:
                    self.merge_weights()
                #

                if cleanup_states:
                    self.cleanup_states()
                #
            #
        #
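
    # an illustrative sketch (not from the original file) of how this is typically driven,
    # assuming a derived class (hypothetical name) that defines forward(); analyze_graph()
    # runs a dummy forward pass with hooks to discover the module connectivity:
    #
    #   qmodel = MyQuantModule(net)                 # net is a plain torch.nn.Module
    #   dummy_input = torch.rand(1, 3, 224, 224)    # shape is an assumption
    #   qmodel.analyze_graph(dummy_input, merge_weights=True, cleanup_states=True)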

    def model_surgery_quantize(self, dummy_input):
        # clear the states - just to be sure
        self.clear_states()
        # analyze
        self.analyze_graph(dummy_input)
        # insert PAct2 wherever range clipping needs to be done
        self.model_surgery_activations()
        # since we might have added new activations, clear the states as they may not be valid
        self.clear_states()
        # need to call analyze_graph in the derived class
    #
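
    # a minimal sketch of the derived-class responsibility mentioned above; the class name
    # and method body are hypothetical, only the need to call analyze_graph again after the
    # surgery is from this file:
    #
    #   class MyQuantModule(QuantGraphModule):
    #       def forward(self, inputs):
    #           self.analyze_graph(inputs)   # re-analyze now that PAct2 modules are in place
    #           return self.module(inputs)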

    def model_surgery_activations(self):
        for module_hash, qparams in self.get_state().qparams.items():
            module = self.get_module(module_hash)
            if isinstance(module, layers.PAct2):
                pass
            elif qparams.quantize_out:
                if utils.is_activation(module):
                    if isinstance(module, (torch.nn.ReLU, torch.nn.ReLU6)):
                        activation_q = layers.PAct2(signed=False)
                    elif isinstance(module, torch.nn.Hardtanh):
                        activation_q = layers.PAct2(clip_range=(module.min_val, module.max_val))
                    elif isinstance(module, layers.NoAct):
                        activation_q = layers.PAct2(signed=None)
                    else:
                        activation_q = layers.PAct2(signed=None)
                    #
                    # replace the existing activation by PAct2
                    parent = utils.get_parent_module(self, module)
                    name = utils.get_module_name(parent, module)
                    activation_q.train(self.training)
                    setattr(parent, name, activation_q)
                elif not hasattr(module, 'activation_q'):
                    activation_q = layers.PAct2(signed=None)
                    activation_q.train(self.training)
                    module.activation_q = activation_q
                #
            else:
                pass
            #
        #
    #
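
    # on the signedness choices above: ReLU/ReLU6 outputs are non-negative, so an unsigned
    # range suffices; Hardtanh carries its own explicit clip range; for NoAct and anything
    # else, signed=None presumably lets PAct2 determine signedness from observed statistics.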

    def start_call(self):
        self.call_count = Dict()

    def finish_call(self):
        self.call_count = None

    def train(self, mode=True):
        self.iter_in_epoch.fill_(-1.0)
        # return self to preserve torch.nn.Module.train()'s chaining contract
        return super().train(mode)

    ################################################################
    def forward_analyze_modules(self, inputs):
        self.start_call()
        self.add_call_hook(self.module, self._analyze_modules_op)
        output = self.module(inputs)
        self.remove_call_hook(self.module)
        self.finish_call()
        return output

    def _analyze_modules_op(self, op, *inputs_orig):
        inputs = utils.squeeze_list(inputs_orig)
        self.start_node(op)
        self.add_node(op, inputs)
        outputs = op.__forward_orig__(*inputs_orig)
        self.add_node(op, inputs, outputs)
        self.finish_node(op, inputs, outputs)
        return outputs

    def add_node(self, module, inputs, outputs=None):
        inputs = self.format_tensors(inputs)
        module_hash = self.module_hash(module)

        if module_hash not in list(self.get_state().qparams.keys()):
            self.get_state().qparams[module_hash] = Dict()
            self.get_state().qparams[module_hash].qrange_w = None
            self.get_state().qparams[module_hash].qrange_b = None
            self.get_state().qparams[module_hash].qrange_in = []
            self.get_state().qparams[module_hash].qrange_out = []
            self.get_state().qparams[module_hash].is_input = (self.module is module)
            self.get_state().qparams[module_hash].previous_node = []
            self.get_state().qparams[module_hash].next_node = []
            self.get_state().qparams[module_hash].current_node = module_hash

        current_node = self.get_state().qparams[module_hash].current_node
        for inp in inputs:
            if hasattr(inp, 'qparams') and hasattr(inp.qparams, 'last_node'):
                prev_module_hash = inp.qparams.last_node
                prev_module = self.get_module(prev_module_hash)
                previous_node = self.get_state().qparams[module_hash].previous_node
                next_node = self.get_state().qparams[prev_module_hash].next_node

                if str(inp.qparams.last_node) not in [str(p) for p in previous_node]:
                    self.get_state().qparams[module_hash].previous_node += [inp.qparams.last_node]
                if str(current_node) not in [str(n) for n in next_node]:
                    self.get_state().qparams[prev_module_hash].next_node += [current_node]

        if outputs is not None:
            outputs = self.format_tensors(outputs)
            for opt in outputs:
                if not hasattr(opt, 'qparams'):
                    opt.qparams = Dict()
                #
                # update last_node if this is not a container module
                # if this is a container module, last_node would already have been filled in by the last leaf module
                if len(module._modules) == 0:
                    opt.qparams.last_node = current_node
                #
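
    # net effect of add_node: the qparams tags attached to tensors act as edge markers, so
    # after one analysis pass each module's qparams holds previous_node/next_node lists that
    # describe the dataflow graph; e.g. (hypothetical module names):
    #   qparams['conv1-call:0'].next_node     -> ['bn1-call:0']
    #   qparams['bn1-call:0'].previous_node   -> ['conv1-call:0']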

    ################################################################
    def analyze_connections(self):
        prediction_module = None
        for module_hash, qparams in self.get_state().qparams.items():
            module = self.get_module(module_hash)
            if utils.is_conv(module) or utils.is_linear(module) or utils.is_normalization(module) or utils.is_activation(module):
                prediction_module = module
            #
        #
        for module_hash, qparams in self.get_state().qparams.items():
            module = self.get_module(module_hash)
            is_prediction = (prediction_module is module)
            self._analyse_connections_op(module_hash, module, qparams, is_prediction)
        #

    def _analyse_connections_op(self, module_hash, module, qparams, is_prediction):
        previous_module = [self.get_module(p) for p in qparams.previous_node] if len(qparams.previous_node) > 0 else []
        next_module = [self.get_module(n) for n in qparams.next_node] if len(qparams.next_node) > 0 else []

        quantize_out = False
        if utils.is_activation(module):
            if len(next_module) == 1 and utils.is_activation(next_module[0]):
                quantize_out = False
            else:
                quantize_out = True
            #
        elif isinstance(module, (layers.AddBlock, layers.CatBlock, layers.MultBlock)):
            if len(next_module) == 1 and utils.is_activation(next_module[0]):
                quantize_out = False
            else:
                quantize_out = True
            #
        elif utils.is_normalization(module):
            if len(next_module) == 1 and utils.is_activation(next_module[0]):
                quantize_out = False
            else:
                quantize_out = True
            #
        elif utils.is_conv(module) or utils.is_deconv(module):
            if len(next_module) == 1 and (utils.is_normalization(next_module[0]) or utils.is_activation(next_module[0])):
                quantize_out = False
            else:
                quantize_out = True
            #
        elif utils.is_linear(module):
            if len(next_module) == 1 and (utils.is_normalization(next_module[0]) or utils.is_activation(next_module[0])):
                quantize_out = False
            else:
                quantize_out = True
            #
        # elif isinstance(module, (torch.nn.AdaptiveAvgPool2d, torch.nn.Upsample, layers.ResizeTo, torch.nn.Flatten)):
        #     quantize_out = True
        # #

        qparams.quantize_w = isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d, torch.nn.Linear))  # weights of all conv/deconv/linear layers will be quantized
        qparams.quantize_b = isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d, torch.nn.Linear))  # biases of all conv/deconv/linear layers will be quantized
        qparams.quantize_out = quantize_out  # selectively quantize output
        qparams.quantize_in = qparams.is_input  # only the top module's input needs to be quantized
        qparams.align_in = isinstance(module, (layers.AddBlock, layers.CatBlock, torch.nn.AdaptiveAvgPool2d))  # all input tensors must be brought to the same q at the input
        qparams.scale_in = 64.0 if isinstance(module, torch.nn.AdaptiveAvgPool2d) else 1.0  # additional scale-up to simulate fixed point
        qparams.unquantize_out = qparams.is_input  # only the top module's output needs to be unquantized
        qparams.is_dwconv = utils.is_dwconv(module)
        qparams.next_module = next_module
        qparams.is_prediction = is_prediction
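
    # in short: an output is left unquantized only when exactly one consumer follows and
    # that consumer (an activation, or a normalization after conv/linear) will re-range
    # the values anyway; e.g. in a conv->bn->relu chain only the relu output is quantized.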

    ################################################################
    def merge_weights(self, make_backup=False):
        assert self.get_state().analyzed_graph == True, 'graph must be analyzed before merge_weights()'
        with torch.no_grad():
            for module_hash, qparams in self.get_state().qparams.items():
                module = self.get_module(module_hash)
                self._merge_weight_op(module_hash, module, qparams, make_backup)
            #
        #
    #
    def _merge_weight_op(self, module_hash, module, qparams, make_backup):
        is_conv = utils.is_conv_deconv(module)

        # note: we consider merging only if there is a single next node
        next_module = qparams.next_module[0] if len(qparams.next_module) == 1 else None
        next_bn = isinstance(next_module, torch.nn.BatchNorm2d) if (next_module is not None) else None

        # if the next module is a bn, apply the bn merging step
        if is_conv and next_bn:
            conv = module
            bn = next_module

            # weight/bias
            conv_bias = conv.bias.data if (conv.bias is not None) else 0.0
            bn_weight = bn.weight.data if bn.affine else 1.0
            bn_bias = bn.bias.data if bn.affine else 0.0

            # merged weight and offset
            merged_scale = torch.rsqrt(bn.running_var.data + bn.eps) * bn_weight
            if utils.is_conv(conv):
                merged_scale = merged_scale.view(-1, 1, 1, 1)
            elif utils.is_deconv(conv):
                merged_scale = merged_scale.view(1, -1, 1, 1)
            else:
                assert False, 'unable to merge convolution and BN'
            #
            merged_weight = conv.weight.data * merged_scale
            merged_bias = (conv_bias - bn.running_mean.data) * merged_scale.view(-1) + bn_bias

            # bn is set to unity
            bn.running_mean.data.fill_(0.0)
            bn.running_var.data.fill_(1.0 - bn.eps)
            if bn.affine:
                bn.weight.data.fill_(1.0)
                bn.bias.data.fill_(0.0)
            #

            # copy merged weights to conv
            conv.weight.data.copy_(merged_weight)

            # copy merged bias
            if conv.bias is not None:
                conv.bias.data.copy_(merged_bias)
            elif bn.affine:
                bn.bias.data.copy_(merged_bias.data)
            else:
                warnings.warn('problem detected in conv+bn layer pair: either conv.bias or bn.affine is required for successful calibration - preferably bn.affine')
            #
        #
        return
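
    # the folding above follows the usual conv+bn merging identity: with
    #   bn(x) = gamma * (x - mean) / sqrt(var + eps) + beta
    # and s = gamma / sqrt(var + eps) (i.e. merged_scale), the merged layer uses
    #   W' = W * s    (s broadcast along the output-channel axis)
    #   b' = (b - mean) * s + beta
    # and bn is then reset to identity, so conv followed by bn computes the same function.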

    ################################################################
    def get_qparams(self, module):
        module_hash = self.module_hash(module)
        return self.get_state().qparams[module_hash]

    def get_qparams_prev(self, module):
        module_hash = self.module_hash(module)
        return self.get_state().qparams_prev[module_hash] if self.get_state().qparams_prev else None

    def start_node(self, module):
        module_name = self.module_name(module)
        if module_name not in list(self.call_count.keys()):
            self.call_count[module_name] = 0
        #
        return

    def finish_node(self, module, inputs, outputs):
        module_name = self.module_name(module)
        self.call_count[module_name] = self.call_count[module_name] + 1
        return

    def module_hash(self, module):
        module_name = self.module_name(module)
        module_hash = module_name + '-call:{}'.format(self.call_count[module_name])
        return module_hash
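
    # e.g. a module named 'features.0.conv' (hypothetical name) seen for the first time in
    # a call hashes to 'features.0.conv-call:0'; the call index distinguishes modules that
    # are invoked more than once per forward pass.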

    def module_name(self, module):
        name = None
        for n, m in self.named_modules():
            if m is module:
                name = n
        #
        return name

    def module_device(self, module=None):
        module = module if module is not None else self.module
        try:
            module_device = next(module.parameters()).device
        except StopIteration:
            # the module has no parameters
            module_device = None
        #
        return module_device

    def get_module(self, module_hash):
        module_name = module_hash.split('-call:')[0]
        for mname, mod in self.named_modules():
            if module_name == mname:
                return mod
        #
        return None

    def is_last_conv(self, module):
        # the implementation is not correct. disable it for the time being
        return False  # (module is self.last_conv_linear_module)

    def format_tensors(self, inputs):
        # make inputs a list/tuple if it is not one; if it is a doubly nested list, remove the extra nesting
        inputs = utils.squeeze_list(utils.make_list(inputs))
        # keep only the tensors, dropping lists/tuples and other entries
        inputs = [ipt for ipt in inputs if utils.is_tensor(ipt)]
        return inputs
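
    # e.g. a single tensor t becomes [t], and a nested list like [[t1, t2]] is flattened to
    # [t1, t2] with non-tensor entries dropped (behavior inferred from the helper names).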

    def copy_qparams(self, qparams, inputs):
        qparams_copy = Dict()

        # deep copy may not work in some cases, so do it conditionally
        for module_hash, qparam_entry in qparams.items():
            qparams_copy[module_hash] = Dict()
            for key, value in qparam_entry.items():
                try:
                    qparams_copy[module_hash][key] = copy.deepcopy(value)
                except Exception:
                    qparams_copy[module_hash][key] = value
                #
            #
        #
        return qparams_copy