]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py
updated python package requirements (don't need tensorflow for tensorboard). not...
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / quantize / quant_graph_module.py
1 import warnings
2 import torch
3 import copy
4 from .. import utils
5 from .. import layers
6 from ..utils import AttrDict as Dict
7 from .hooked_module import *
9 class QuantGraphModule(HookedModule):
10     def __init__(self, module):
11         super().__init__()
12         self.module = module
13         self.init_qstate()
14         self.num_batches_tracked = -1
15         self.iter_in_epoch = -1
16         self.epoch = -1
17         # these are the blocks whose output we quantize for sure.
18         # outputs of other clocks such as Conv2d, ConvTranspose2d, BatchNorm2d, Lindear are quantized conditionally
19         self.quantize_out_blocks = (torch.nn.ReLU, torch.nn.ReLU6, torch.nn.Hardtanh, layers.QAct, layers.PAct2,
20                                     layers.AddBlock, layers.CatBlock, layers.MultBlock, torch.nn.MaxPool2d, torch.nn.AvgPool2d)
22         # this block is not quantized. Also if the next block is this, current block is not quantized
23         self.ignore_out_blocks = (layers.NoQAct,torch.nn.Dropout2d)
25         # TBD: is this required
26         # # if the original module has load_weights, add it to the quant module also
27         # if hasattr(module, 'load_weights'):
28         #     def load_weights(m, state_dict, change_names_dict=None):
29         #         utils.load_weights(m.module, state_dict, change_names_dict=change_names_dict)
30         #     #
31         #     self.load_weights = types.MethodType(load_weights, self)
32         # #
35     def init_qstate(self):
36         if not hasattr(self, '__qstate__'):
37             self.__qstate__ = Dict()
38         #
39         if 'qparams' not in self.get_qstate():
40             self.get_qstate().qparams = Dict()
41         #
42         if 'qparams_prev' not in self.get_qstate():
43             self.get_qstate().qparams_prev = Dict()
44         #
45         if 'analyzed_graph' not in self.get_qstate():
46             self.get_qstate().analyzed_graph = False
47         #
50     def clear_qstate(self):
51         self.__qstate__ = Dict()
52         self.init_qstate()
55     def get_qstate(self):
56         return self.__qstate__
59     def forward(self, inputs, *args, **kwargs):
60         assert False, 'forward is not defined'
63     def update_counters(self, force_update=False):
64         self.iter_in_epoch += 1
65         if self.training or force_update:
66             self.num_batches_tracked += 1
67             if self.iter_in_epoch == 0:
68                 self.epoch += 1.0
69             #
70         #
71     #
73     # force_update is used to increment inte counters even in non training
74     # used for validation in QuantTestModule
75     def analyze_graph(self, inputs, *args, force_update=False, merge_weights=False, clear_qstate=False, **kwargs):
76         with torch.no_grad():
77             self.init_qstate()
78             self.update_counters(force_update=force_update)
79             if (self.get_qstate().analyzed_graph == False):
80                 # forward and analyze
81                 self.forward_analyze_modules(inputs, *args, **kwargs)
82                 # analyze the connections
83                 self.analyze_connections()
84                 self.get_qstate().analyzed_graph = True
86                 # merge weights so that weight quantization can be done
87                 if merge_weights:
88                     self.merge_weights()
89                 #
91                 if clear_qstate:
92                     self.clear_qstate()
93                 #
94             #
95         #
98     def model_surgery_quantize(self, dummy_input, *args, **kwargs):
99         # lear the sates - just to be sure
100         self.clear_qstate()
101         # analyze
102         self.analyze_graph(dummy_input, *args, **kwargs)
103         # insert QAct wherever range clipping needs to be done
104         self.model_surgery_activations()
105         # since we might have added new activations, clear the sates as they may not be valid
106         self.clear_qstate()
107         # need to call analyze_graph in the derived class
108     #
110     def model_surgery_activations(self):
111         for module_hash, qparams in self.get_qstate().qparams.items():
112             module = self.get_module(module_hash)
113             if isinstance(module, layers.PAct2):
114                 pass
115             elif qparams.quantize_out:
116                 if utils.is_activation(module):
117                     if isinstance(module, (torch.nn.ReLU, torch.nn.ReLU6)):
118                         activation_q = layers.PAct2(signed=False)
119                     elif isinstance(module, torch.nn.Hardtanh):
120                         activation_q = layers.PAct2(clip_range=(module.min_val, module.max_val))
121                     elif isinstance(module, layers.QAct):
122                         activation_q = layers.PAct2(signed=None)
123                     else:
124                         activation_q = layers.PAct2(signed=None)
125                     #
126                     # replace the existing activation by PAct2
127                     parent = utils.get_parent_module(self, module)
128                     name = utils.get_module_name(parent, module)
129                     activation_q.train(self.training)
130                     setattr(parent, name, activation_q)
131                 elif not hasattr(module, 'activation_q'):
132                     activation_q = layers.PAct2(signed=None)
133                     activation_q.train(self.training)
134                     module.activation_q = activation_q
135                 #
136             elif qparams.quantize_in:
137                 if not hasattr(module, 'activation_in'):
138                     # TODO: set percentile_range_shrink=0.0 to avoid shrinking of input range, if needed.
139                     activation_in = layers.PAct2(signed=None)
140                     activation_in.train(self.training)
141                     module.activation_in = activation_in
142                 #
143             else:
144                 pass
145             #
146         #
147     #
150     def train(self, mode=True):
151         self.iter_in_epoch = -1
152         super().train(mode)
155     ################################################################
156     def forward_analyze_modules(self, inputs, *args, **kwargs):
157         '''
158         analyze modules needs a call hook - the call hook does not work with DataParallel.
159         So, do the analysis on a copy.
160         '''
161         self_copy = copy.deepcopy(self)
162         self_copy._forward_analyze_modules_impl(inputs, *args, **kwargs)
163         self.get_qstate().qparams = self_copy.get_qstate().qparams
165     def _forward_analyze_modules_impl(self, inputs, *args, **kwargs):
166         self.start_call()
167         self.add_call_hook(self, self._analyze_modules_op)
168         forward_analyze_method_name = kwargs.pop('forward_analyze_method', None)
169         if forward_analyze_method_name is not None and hasattr(self.module, forward_analyze_method_name):
170             # get the bound method to be used as forward
171             forward_analyze_method = getattr(self.module, forward_analyze_method_name)
172             output = forward_analyze_method(inputs, *args, **kwargs)
173         else:
174             output = self.module(inputs, *args, **kwargs)
175         #
176         self.remove_call_hook(self.module)
177         self.finish_call()
178         return output
180     def _analyze_modules_op(self, op, inputs, *args, **kwargs):
181         inputs = utils.squeeze_list2(inputs)
182         self.start_node(op)
183         self.add_node(op, inputs)
184         outputs = op.__forward_orig__(inputs, *args, **kwargs)
185         self.add_node(op, inputs, outputs)
186         self.finish_node(op, inputs, outputs)
187         return outputs
189     def add_node(self, module, inputs, outputs=None):
190         inputs = self.format_tensors(inputs)
191         module_hash = self.module_hash(module)
193         if module_hash not in list(self.get_qstate().qparams.keys()):
194             self.get_qstate().qparams[module_hash] = Dict()
195             self.get_qstate().qparams[module_hash].qrange_w = None
196             self.get_qstate().qparams[module_hash].qrange_b = None
197             self.get_qstate().qparams[module_hash].qrange_in = []
198             self.get_qstate().qparams[module_hash].qrange_out = []
199             self.get_qstate().qparams[module_hash].is_input = (self.module is module)
200             self.get_qstate().qparams[module_hash].previous_node = []
201             self.get_qstate().qparams[module_hash].next_node = []
202             self.get_qstate().qparams[module_hash].current_node = module_hash
204         current_node = self.get_qstate().qparams[module_hash].current_node
205         for inp in inputs:
206             if hasattr(inp, 'qparams') and hasattr(inp.qparams, 'last_node'):
207                 prev_module_hash = inp.qparams.last_node
208                 prev_module = self.get_module(prev_module_hash)
209                 previous_node = self.get_qstate().qparams[module_hash].previous_node
210                 next_node = self.get_qstate().qparams[prev_module_hash].next_node
212                 if str(inp.qparams.last_node) not in [str(p) for p in previous_node]:
213                     self.get_qstate().qparams[module_hash].previous_node += [inp.qparams.last_node]
214                 if str(current_node) not in [str(n) for n in next_node]:
215                     self.get_qstate().qparams[prev_module_hash].next_node += [current_node]
217         if outputs is not None:
218             outputs = self.format_tensors(outputs)
219             for opt in outputs:
220                 if not hasattr(opt, 'qparams'):
221                     opt.qparams = Dict()
222                 #
223                 # update last_node if this is not a container module
224                 # if this is a container module, the last_node would have been already filled in in the last leaf module
225                 if len(module._modules) == 0:
226                     opt.qparams.last_node = current_node
227                 #
230     ################################################################
231     def analyze_connections(self):
232         first_module = None
233         prediction_module = None
234         for module_hash, qparams in self.get_qstate().qparams.items():
235             module = self.get_module(module_hash)
236             if utils.is_conv_deconv_linear(module) or utils.is_normalization(module) or utils.is_activation(module):
237                 first_module = module if first_module is None else first_module
238                 prediction_module = module
239             #
240         #
241         for module_hash, qparams in self.get_qstate().qparams.items():
242             module = self.get_module(module_hash)
243             is_first_module = (first_module is module)
244             is_prediction_module = (prediction_module is module)
245             self._analyse_connections_op(module_hash, module, qparams, is_first_module, is_prediction_module)
246         #
248     def _analyse_connections_op(self, module_hash, module, qparams, is_first_module, is_prediction_module):
249         previous_modules = [self.get_module(p) for p in qparams.previous_node] if len(qparams.previous_node)>0 else []
250         next_modules = [self.get_module(n) for n in qparams.next_node] if len(qparams.next_node)>0 else []
252         quantize_out = False
253         if isinstance(module, self.ignore_out_blocks):
254             quantize_out = False
255         elif utils.is_activation(module):
256             if len(next_modules)==1 and utils.is_activation(next_modules[0]):
257                 quantize_out = False
258             else:
259                 quantize_out = True
260             #
261         elif isinstance(module, self.quantize_out_blocks):
262             if len(next_modules)==1 and utils.is_activation(next_modules[0]):
263                 quantize_out = False
264             else:
265                 quantize_out = True
266             #
267         elif utils.is_normalization(module):
268             if len(next_modules)==1 and utils.is_activation(next_modules[0]):
269                 quantize_out = False
270             else:
271                 quantize_out = True
272             #
273         elif utils.is_conv(module) or utils.is_deconv(module):
274             if len(next_modules)==1 and (utils.is_normalization(next_modules[0]) or utils.is_activation(next_modules[0])):
275                 quantize_out = False
276             else:
277                 quantize_out = True
278             #
279         elif utils.is_linear(module):
280             if len(next_modules)==1 and (utils.is_normalization(next_modules[0]) or utils.is_activation(next_modules[0])):
281                 quantize_out = False
282             else:
283                 quantize_out = True
284             #
285         # elif isinstance(module, (torch.nn.AdaptiveAvgPool2d, torch.nn.Upsample, layers.ResizeTo, torch.nn.Flatten)):
286         #     quantize_out = True
287         # #
289         if len(qparams.previous_node) > 0:
290             previous_module_hash = qparams.previous_node[-1]
291             previous_module = self.get_module(previous_module_hash)
292             previous_module_qparams = self.get_qstate().qparams[previous_module_hash]
293             is_input_ignored = isinstance(previous_module, self.ignore_out_blocks)
294             is_input_quantized = previous_module_qparams.quantize_out if \
295                 hasattr(previous_module_qparams, 'quantize_out') else False
296         else:
297             is_input_ignored = False
298             is_input_quantized = False
299         #
301         quantize_in = utils.is_conv_deconv_linear(module) and not is_input_quantized and \
302                       not is_input_ignored and is_first_module
303         qparams.quantize_w = utils.is_conv_deconv_linear(module)                                # all conv/deconv layers will be quantized
304         qparams.quantize_b = utils.is_conv_deconv_linear(module)                                # all conv/deconv layers will be quantized
305         qparams.quantize_out = quantize_out                                                     # selectively quantize output
306         qparams.quantize_in = quantize_in                                                       # only top modules's input need to be quantized
307         multi_input_blocks = (layers.AddBlock, layers.CatBlock, torch.nn.AdaptiveAvgPool2d)
308         qparams.align_in = isinstance(module, multi_input_blocks)                               # all tensors to be made same q at the input
309         qparams.scale_in = 64.0 if isinstance(module, torch.nn.AdaptiveAvgPool2d) else 1.0      # additional scaleup to simulate fixed point
310         qparams.unquantize_out = qparams.is_input                                               # only top modules's output need to be unquantized
311         qparams.is_dwconv = utils.is_dwconv(module)
312         qparams.next_modules = next_modules
313         qparams.is_first_module = is_first_module
314         qparams.is_prediction_module = is_prediction_module
317     ################################################################
318     def merge_weights(self, make_backup=False):
319         assert self.get_qstate().analyzed_graph == True, 'graph must be analyzed before merge_weights()'
320         with torch.no_grad():
321             for module_hash, qparams in self.get_qstate().qparams.items():
322                 module = self.get_module(module_hash)
323                 self._merge_weight_op(module_hash, module, qparams, make_backup)
324             #
325         #
326     #
327     def _merge_weight_op(self, module_hash, module, qparams, make_backup):
328         is_conv = utils.is_conv_deconv(module)
330         # note: we consider merging only if there is a single next node
331         next_module = qparams.next_modules[0] if len(qparams.next_modules) == 1 else None
332         next_bn = isinstance(next_module, torch.nn.BatchNorm2d) if (next_module is not None) else None
334         # if the next module is a bn, appy bn merging step
335         if is_conv and next_bn:
336             conv = module
337             bn = next_module
339             # weight/bias
340             conv_bias = conv.bias.data if (conv.bias is not None) else 0.0
341             bn_weight = bn.weight.data if bn.affine else 1.0
342             bn_bias = bn.bias.data if bn.affine else 0.0
344             # merged weight and offset
345             merged_scale = torch.rsqrt(bn.running_var.data + bn.eps) * bn_weight
346             if utils.is_conv(conv):
347                 merged_scale = merged_scale.view(-1, 1, 1, 1)
348             elif utils.is_deconv(conv):
349                 merged_scale = merged_scale.view(1, -1, 1, 1)
350             else:
351                 assert False, 'unable to merge convolution and BN'
352             #
353             merged_weight = conv.weight.data * merged_scale
354             merged_bias = (conv_bias - bn.running_mean.data) * merged_scale.view(-1) + bn_bias
356             # bn is set to unity
357             bn.running_mean.data.fill_(0.0)
358             bn.running_var.data.fill_(1.0 - bn.eps)
359             if bn.affine:
360                 bn.weight.data.fill_(1.0)
361                 bn.bias.data.fill_(0.0)
362             #
364             # copy merged weights to conv
365             conv.weight.data.copy_(merged_weight)
367             # copy merge bias
368             if conv.bias is not None:
369                 conv.bias.data.copy_(merged_bias)
370             elif bn.affine:
371                 bn.bias.data.copy_(merged_bias.data)
372             else:
373                 warnings.warn('problem detected in conv+bn layer pair: either one of conv.bias or bn.affine is required for successfull calibration - preferably bn.affine')
374             #
375         #
376         return
379     ################################################################
380     def get_qparams(self, module):
381         module_hash = self.module_hash(module)
382         return self.get_qstate().qparams[module_hash]
385     def get_qparams_prev(self, module):
386         module_hash = self.module_hash(module)
387         return self.get_qstate().qparams_prev[module_hash] if self.get_qstate().qparams_prev else None
390     def start_call(self):
391         self.call_count = Dict()
394     def finish_call(self):
395         self.call_count = None
398     def start_node(self, module):
399         module_name = self.module_name(module)
400         if module_name not in list(self.call_count.keys()):
401             self.call_count[module_name] = 0
402         #
403         return
406     def finish_node(self, module, inputs, outputs):
407         module_name = self.module_name(module)
408         self.call_count[module_name] = self.call_count[module_name] + 1
409         return
412     def module_hash(self, module):
413         '''
414         A module may be called multiple times in a model. This module has creates a unique name/hash for each call
415         using teh call_count. call_count needs tobe correct for this to work as expected.
416         call_count is kep up to date by using start_node() / finish_node() calls.
417         '''
418         module_name = self.module_name(module)
419         module_hash = module_name + '-call:{}'.format(self.call_count[module_name])
420         return module_hash
423     def module_name(self, module):
424         name = None
425         for n, m in self.named_modules():
426             if m is module:
427                 name = n
428             #
429         #
430         return name
433     def get_module(self, module_hash):
434         module_name = module_hash.split('-call:')[0]
435         for mname, mod in self.named_modules():
436             if module_name == mname:
437                 return mod
438         #
439         return None
442     def is_last_conv(self, module):
443         # implementation is not correct. disable it for the time being
444         return False #(module is self.last_conv_linear_module)
447     def format_tensors(self, inputs):
448         # make a list/tuple if inputs is not. if it is a double list, remove the extra one
449         inputs = utils.squeeze_list2(utils.make_list(inputs))
450         # remove lists/tuple
451         inputs = [ipt for ipt in inputs if utils.is_tensor(ipt)]
452         return inputs
455     def copy_qparams(self, qparams, inputs):
456         qparams_copy = Dict()
457         for module_hash, qparam_entry in qparams.items():
458             qparams_copy[module_hash] = Dict()
459             for key, value in qparam_entry.items():
460                 # deep copy may not work in some cases, so do it conditionally
461                 try:
462                     qparams_copy[module_hash][key] = copy.deepcopy(value)
463                 except Exception:
464                     qparams_copy[module_hash][key] = value
465                 #
466             #
467         #
468         return qparams_copy