# modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py
import warnings
import torch
import copy
from .. import utils
from .. import layers
from ..utils import AttrDict as Dict
from .hooked_module import *


class QuantGraphModule(HookedModule):
    # instance member states are not retained across forward calls when using DataParallel.
    # as a workaround, use class member variables instead, so that these can be retained.
    states = Dict()

    def __init__(self, module):
        super().__init__()
        self.module = module
        self.init_states()
        self.register_buffer('num_batches_tracked', torch.tensor(-1.0))
        self.register_buffer('iter_in_epoch', torch.tensor(-1.0))
        self.register_buffer('epoch', torch.tensor(-1.0))

    # create the state object that holds the quantization parameters which need to be preserved.
    # if cuda() has been called on the module, copy over the states that were created for cpu.
    def get_state(self):
        states = self.get_states()
        module_device = self.module_device(self.module)
        if module_device not in states:
            module_device_src = None
            for key, value in states.items():
                if key.type == 'cpu':
                    module_device_src = key
            #
            if module_device_src is not None and module_device_src in states:
                states[module_device] = copy.deepcopy(states[module_device_src])
            else:
                states[module_device] = Dict()
            #
        #
        return states[module_device]
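
    # for illustration: after the wrapped module has been moved to gpu, the class-level
    # `states` dict is keyed by torch.device (a sketch of its shape, not actual contents):
    #   QuantGraphModule.states == {device(type='cpu'): Dict(qparams=..., analyzed_graph=...),
    #                               device(type='cuda', index=0): <deep copy of the cpu entry>}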

    def get_states(self):
        return __class__.states

    def clear_states(self):
        __class__.states = Dict()

    # these entries would prevent this module from being used with DataParallel - clean them up
    def cleanup_states(self):
        assert self.get_state().analyzed_graph == True, 'graph must be analyzed before cleanup_states()'
        with torch.no_grad():
            for module_hash, qparams in self.get_state().qparams.items():
                if hasattr(qparams, 'previous_node'):
                    del qparams.previous_node
                #
                if hasattr(qparams, 'previous_module'):
                    del qparams.previous_module
                #
                if hasattr(qparams, 'next_node'):
                    del qparams.next_node
                #
                if hasattr(qparams, 'next_module'):
                    del qparams.next_module
                #
            #
        #

    # DataParallel does not initialize the replicas correctly, so initialize them explicitly.
    # doing this only in __init__ is not enough - it has to be done in forward as well,
    # even if it was already called in __init__.
    def init_states(self):
        if 'qparams' not in self.get_state().keys():
            self.get_state().qparams = Dict()
        if 'qparams_prev' not in self.get_state().keys():
            self.get_state().qparams_prev = Dict()
        if 'analyzed_graph' not in self.get_state().keys():
            self.get_state().analyzed_graph = False

    def forward(self, inputs):
        raise NotImplementedError('forward is not defined - it must be implemented in a derived class')

    # force_update is used to increment the iteration counters even when not in training mode -
    # used for validation in QuantTestModule.
    def analyze_graph(self, inputs, force_update=False, merge_weights=False, cleanup_states=False):
        with torch.no_grad():
            self.init_states()
            if self.training or force_update:
                self.num_batches_tracked += 1
                if self.num_batches_tracked == 0:
                    self.epoch += 1.0
                #
            #
            self.iter_in_epoch += 1
            if not self.get_state().analyzed_graph:
                # forward and analyze
                self.forward_analyze_modules(inputs)
                # analyze the connections
                self.analyze_connections()
                self.get_state().analyzed_graph = True

                # merge weights so that weight quantization can be done
                if merge_weights:
                    self.merge_weights()
                #

                if cleanup_states:
                    self.cleanup_states()
                #
            #
        #
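
    # a hypothetical usage sketch (the derived class, model and input shape are assumptions,
    # not part of this file):
    #   qmodel = MyQuantModule(model)               # a subclass that implements forward()
    #   dummy_input = torch.rand(1, 3, 224, 224)
    #   qmodel.analyze_graph(dummy_input, merge_weights=True)
    #   # get_state().qparams now holds per-module connectivity and quantization flags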

    def model_surgery_quantize(self, dummy_input):
        # clear the states - just to be sure
        self.clear_states()
        # analyze
        self.analyze_graph(dummy_input)
        # insert PAct2 wherever range clipping needs to be done
        self.model_surgery_activations()
        # since we might have added new activations, clear the states as they may no longer be valid
        self.clear_states()
        # analyze_graph needs to be called again in the derived class
    #

    def model_surgery_activations(self):
        for module_hash, qparams in self.get_state().qparams.items():
            module = self.get_module(module_hash)
            if isinstance(module, layers.PAct2):
                pass
            elif qparams.quantize_out:
                if utils.is_activation(module):
                    if isinstance(module, (torch.nn.ReLU, torch.nn.ReLU6, layers.ReLUN)):
                        activation_q = layers.PAct2(signed=False)
                    else:
                        activation_q = layers.PAct2(signed=None)
                    #
                    # replace the existing activation with PAct2
                    parent = utils.get_parent_module(self, module)
                    name = utils.get_module_name(parent, module)
                    activation_q.train(self.training)
                    setattr(parent, name, activation_q)
                elif not hasattr(module, 'activation_q'):
                    # attach a new activation to clip/quantize this module's output
                    activation_q = layers.PAct2(signed=None)
                    activation_q.train(self.training)
                    module.activation_q = activation_q
                #
            else:
                pass
            #
        #
    #
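
    # for example (an illustrative sketch, not code from this module): after the surgery,
    # a torch.nn.ReLU inside the wrapped model is replaced in-place by layers.PAct2(signed=False),
    # while a module whose output needs clipping but is not an activation (e.g. an AddBlock)
    # gets an extra `activation_q` attribute holding layers.PAct2(signed=None).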

    def start_call(self):
        self.call_count = Dict()

    def finish_call(self):
        self.call_count = None

    def train(self, mode=True):
        self.iter_in_epoch.fill_(-1.0)
        return super().train(mode)

    ################################################################
    def forward_analyze_modules(self, inputs):
        self.start_call()
        self.add_call_hook(self.module, self._analyze_modules_op)
        output = self.module(inputs)
        self.remove_call_hook(self.module)
        self.finish_call()
        return output

    def _analyze_modules_op(self, op, *inputs_orig):
        inputs = utils.squeeze_list(inputs_orig)
        self.start_node(op)
        self.add_node(op, inputs)
        outputs = op.__forward_orig__(*inputs_orig)
        self.add_node(op, inputs, outputs)
        self.finish_node(op, inputs, outputs)
        return outputs

    def add_node(self, module, inputs, outputs=None):
        inputs = self.format_tensors(inputs)
        module_hash = self.module_hash(module)

        if module_hash not in self.get_state().qparams:
            self.get_state().qparams[module_hash] = Dict()
            self.get_state().qparams[module_hash].qrange_w = None
            self.get_state().qparams[module_hash].qrange_b = None
            self.get_state().qparams[module_hash].qrange_in = []
            self.get_state().qparams[module_hash].qrange_out = []
            self.get_state().qparams[module_hash].is_input = (self.module is module)
            self.get_state().qparams[module_hash].previous_node = []
            self.get_state().qparams[module_hash].next_node = []
            self.get_state().qparams[module_hash].current_node = module_hash

        current_node = self.get_state().qparams[module_hash].current_node
        for inp in inputs:
            if hasattr(inp, 'qparams') and hasattr(inp.qparams, 'last_node'):
                prev_module_hash = inp.qparams.last_node
                prev_module = self.get_module(prev_module_hash)
                previous_node = self.get_state().qparams[module_hash].previous_node
                next_node = self.get_state().qparams[prev_module_hash].next_node

                if str(inp.qparams.last_node) not in [str(p) for p in previous_node]:
                    self.get_state().qparams[module_hash].previous_node += [inp.qparams.last_node]
                if str(current_node) not in [str(n) for n in next_node]:
                    self.get_state().qparams[prev_module_hash].next_node += [current_node]

        if outputs is not None:
            outputs = self.format_tensors(outputs)
            for opt in outputs:
                if not hasattr(opt, 'qparams'):
                    opt.qparams = Dict()
                #
                opt.qparams.last_node = current_node
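
    # to illustrate the linkage (module names here are hypothetical): for a
    # conv -> bn -> relu chain that is called once per forward, after analysis:
    #   qparams['module.conv-call:0'].next_node      == ['module.bn-call:0']
    #   qparams['module.bn-call:0'].previous_node    == ['module.conv-call:0']
    #   qparams['module.bn-call:0'].next_node        == ['module.relu-call:0']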

    ################################################################
    def analyze_connections(self):
        # the last conv/linear/normalization/activation module seen is treated as the prediction module
        prediction_module = None
        for module_hash, qparams in self.get_state().qparams.items():
            module = self.get_module(module_hash)
            if utils.is_conv(module) or utils.is_linear(module) or utils.is_normalization(module) or utils.is_activation(module):
                prediction_module = module
            #
        #
        for module_hash, qparams in self.get_state().qparams.items():
            module = self.get_module(module_hash)
            is_prediction = (prediction_module is module)
            self._analyse_connections_op(module_hash, module, qparams, is_prediction)
        #

    def _analyse_connections_op(self, module_hash, module, qparams, is_prediction):
        previous_module = [self.get_module(p) for p in qparams.previous_node] if len(qparams.previous_node) > 0 else []
        next_module = [self.get_module(n) for n in qparams.next_node] if len(qparams.next_node) > 0 else []

        quantize_out = False
        if utils.is_activation(module):
            quantize_out = True
        elif isinstance(module, (layers.AddBlock, layers.CatBlock, layers.MultBlock)):
            if len(next_module) == 1 and utils.is_activation(next_module[0]):
                quantize_out = False
            else:
                quantize_out = True
            #
        elif utils.is_normalization(module):
            if len(next_module) == 1 and utils.is_activation(next_module[0]):
                quantize_out = False
            else:
                quantize_out = True
            #
        elif utils.is_conv(module):
            if len(next_module) == 1 and (utils.is_normalization(next_module[0]) or utils.is_activation(next_module[0])):
                quantize_out = False
            else:
                quantize_out = True
            #
        elif utils.is_linear(module):
            if len(next_module) == 1 and (utils.is_normalization(next_module[0]) or utils.is_activation(next_module[0])):
                quantize_out = False
            else:
                quantize_out = True
            #
        # elif isinstance(module, (torch.nn.AdaptiveAvgPool2d, torch.nn.Upsample, layers.ResizeTo, torch.nn.Flatten)):
        #     quantize_out = True
        # #

        qparams.quantize_w = isinstance(module, (torch.nn.Conv2d, torch.nn.Linear))  # weights of all conv/linear layers will be quantized
        qparams.quantize_b = isinstance(module, (torch.nn.Conv2d, torch.nn.Linear))  # biases of all conv/linear layers will be quantized
        qparams.quantize_out = quantize_out                 # selectively quantize the output
        qparams.quantize_in = qparams.is_input              # only the top module's input needs to be quantized
        qparams.align_in = isinstance(module, (layers.AddBlock, layers.CatBlock, torch.nn.AdaptiveAvgPool2d))  # all input tensors must be brought to the same q
        qparams.scale_in = 64.0 if isinstance(module, torch.nn.AdaptiveAvgPool2d) else 1.0  # additional scale-up to simulate fixed point
        qparams.unquantize_out = qparams.is_input           # only the top module's output needs to be unquantized
        qparams.is_dwconv = utils.is_dwconv(module)
        qparams.next_module = next_module
        qparams.is_prediction = is_prediction
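
    # informal summary of the quantize_out rule above: activations always quantize their
    # output; add/cat/mult, normalization, conv and linear modules skip output quantization
    # only when followed by exactly one module that will quantize it anyway (an activation,
    # or also a normalization in the conv/linear case).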

    ################################################################
    def merge_weights(self, make_backup=False):
        assert self.get_state().analyzed_graph == True, 'graph must be analyzed before merge_weights()'
        with torch.no_grad():
            for module_hash, qparams in self.get_state().qparams.items():
                module = self.get_module(module_hash)
                self._merge_weight_op(module_hash, module, qparams, make_backup)
            #
        #
    #

    def _merge_weight_op(self, module_hash, module, qparams, make_backup):
        is_conv = isinstance(module, torch.nn.Conv2d)

        # note: we consider merging only if there is a single next node
        next_module = qparams.next_module[0] if len(qparams.next_module) == 1 else None
        next_bn = isinstance(next_module, torch.nn.BatchNorm2d) if (next_module is not None) else False

        # if the next module is a bn, apply the bn merging step
        if is_conv and next_bn:
            conv = module
            bn = next_module

            # weight/bias
            conv_bias = conv.bias.data if (conv.bias is not None) else 0.0
            bn_weight = bn.weight.data if bn.affine else 1.0
            bn_bias = bn.bias.data if bn.affine else 0.0

            # merged weight and offset
            merged_scale = torch.rsqrt(bn.running_var.data + bn.eps) * bn_weight
            merged_weight = conv.weight.data * merged_scale.view(-1, 1, 1, 1)
            merged_bias = (conv_bias - bn.running_mean.data) * merged_scale + bn_bias

            # set bn to identity (rsqrt(running_var + eps) becomes exactly 1)
            bn.running_mean.data.fill_(0.0)
            bn.running_var.data.fill_(1.0 - bn.eps)
            if bn.affine:
                bn.weight.data.fill_(1.0)
                bn.bias.data.fill_(0.0)
            #

            # copy the merged weight to conv
            conv.weight.data.copy_(merged_weight)

            # copy the merged bias
            if conv.bias is not None:
                conv.bias.data.copy_(merged_bias)
            elif bn.affine:
                bn.bias.data.copy_(merged_bias.data)
            else:
                warnings.warn('problem detected in conv+bn layer pair: either conv.bias or bn.affine is required for successful calibration - preferably bn.affine')
            #
        #
        return
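
    # the folding algebra used above, for reference: with y = conv(x) followed by
    # bn(y) = gamma * (y - mu) / sqrt(var + eps) + beta, the merged conv computes
    # the same function using, per output channel:
    #   scale = gamma / sqrt(var + eps)
    #   W'    = W * scale
    #   b'    = (b - mu) * scale + beta
    # a minimal numerical check (an illustrative sketch; shapes are assumptions):
    #   conv = torch.nn.Conv2d(3, 8, 3, bias=True); bn = torch.nn.BatchNorm2d(8).eval()
    #   x = torch.rand(1, 3, 16, 16); y_ref = bn(conv(x))
    #   # ... fold as in _merge_weight_op above; then conv(x) alone should match y_ref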

    ################################################################
    def get_qparams(self, module):
        module_hash = self.module_hash(module)
        return self.get_state().qparams[module_hash]

    def get_qparams_prev(self, module):
        module_hash = self.module_hash(module)
        return self.get_state().qparams_prev[module_hash] if self.get_state().qparams_prev else None

    def start_node(self, module):
        module_name = self.module_name(module)
        if module_name not in self.call_count:
            self.call_count[module_name] = 0
        #
        return

    def finish_node(self, module, inputs, outputs):
        module_name = self.module_name(module)
        self.call_count[module_name] += 1
        return

    def module_hash(self, module):
        module_name = self.module_name(module)
        module_hash = module_name + '-call:{}'.format(self.call_count[module_name])
        return module_hash
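
    # for example, a conv invoked for the first time in a forward pass hashes to something
    # like 'module.backbone.conv1-call:0' (the name is hypothetical); the call index
    # disambiguates modules that are invoked more than once in the same forward pass.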

    def module_name(self, module):
        name = None
        for n, m in self.named_modules():
            if m is module:
                name = n
            #
        #
        return name

    def module_device(self, module=None):
        module = module if module is not None else self.module
        try:
            module_device = next(module.parameters()).device
        except StopIteration:
            # module has no parameters
            module_device = None
        #
        return module_device

    def get_module(self, module_hash):
        module_name = module_hash.split('-call:')[0]
        for mname, mod in self.named_modules():
            if module_name == mname:
                return mod
            #
        #
        return None

    def is_last_conv(self, module):
        # the implementation is not correct - disable it for the time being
        return False  # (module is self.last_conv_linear_module)

    def format_tensors(self, inputs):
        # make inputs a list/tuple if it is not; if it is a doubly nested list, remove the extra level
        inputs = utils.squeeze_list(utils.make_list(inputs))
        # keep only the tensors
        inputs = [ipt for ipt in inputs if utils.is_tensor(ipt)]
        return inputs
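
    # informally (assuming the utils helpers behave as their names suggest):
    #   format_tensors(t)           -> [t]
    #   format_tensors([[t1, t2]])  -> [t1, t2]
    #   format_tensors((t1, 1.0))   -> [t1]    # non-tensors are dropped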

    def copy_qparams(self, qparams, inputs):
        qparams_copy = Dict()

        # deep copy may not work in some cases, so do it conditionally
        for module_hash, qparam_entry in qparams.items():
            qparams_copy[module_hash] = Dict()
            for key, value in qparam_entry.items():
                try:
                    qparams_copy[module_hash][key] = copy.deepcopy(value)
                except Exception:
                    # fall back to keeping a reference for values that cannot be deep-copied
                    qparams_copy[module_hash][key] = value
                #
            #
        #
        return qparams_copy