[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / quantize / quant_graph_module.py
diff --git a/modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py b/modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py
index cbcb3c1ebb6fd8c4a465ab5d1eb80d0b87b53cec..b9dc4c1311bb2f90480a789efd8947ee01f9b727 100644 (file)
return self.__qstate__
- def forward(self, inputs):
+ def forward(self, inputs, *args, **kwargs):
assert False, 'forward is not defined'
# force_update is used to increment the counters even in non training
# used for validation in QuantTestModule
- def analyze_graph(self, inputs, force_update=False, merge_weights=False, clear_qstate=False):
+ def analyze_graph(self, inputs, *args, force_update=False, merge_weights=False, clear_qstate=False, **kwargs):
with torch.no_grad():
self.init_qstate()
self.update_counters(force_update=force_update)
if (self.get_qstate().analyzed_graph == False):
# forward and analyze
- self.forward_analyze_modules(inputs)
+ self.forward_analyze_modules(inputs, *args, **kwargs)
# analyze the connections
self.analyze_connections()
self.get_qstate().analyzed_graph = True
#
- def model_surgery_quantize(self, dummy_input):
+ def model_surgery_quantize(self, dummy_input, *args, **kwargs):
# clear the states - just to be sure
self.clear_qstate()
# analyze
- self.analyze_graph(dummy_input)
+ self.analyze_graph(dummy_input, *args, **kwargs)
# insert QAct wherever range clipping needs to be done
self.model_surgery_activations()
# since we might have added new activations, clear the states as they may not be valid
#
elif qparams.quantize_in:
if not hasattr(module, 'activation_in'):
- activation_in = layers.PAct2(signed=None)
+ # do not want to clip input, so set percentile_range_shrink=0.0
+ activation_in = layers.PAct2(signed=None, percentile_range_shrink=0.0)
activation_in.train(self.training)
module.activation_in = activation_in
#
################################################################
- def forward_analyze_modules(self, inputs):
+ def forward_analyze_modules(self, inputs, *args, **kwargs):
'''
analyze modules needs a call hook - the call hook does not work with DataParallel.
So, do the analysis on a copy.
'''
self_copy = copy.deepcopy(self)
- self_copy._forward_analyze_modules_impl(inputs)
+ self_copy._forward_analyze_modules_impl(inputs, *args, **kwargs)
self.get_qstate().qparams = self_copy.get_qstate().qparams
- def _forward_analyze_modules_impl(self, inputs):
+ def _forward_analyze_modules_impl(self, inputs, *args, **kwargs):
self.start_call()
self.add_call_hook(self, self._analyze_modules_op)
- output = self.module(inputs)
+ forward_analyze_method_name = kwargs.pop('forward_analyze_method', None)
+ if forward_analyze_method_name is not None and hasattr(self.module, forward_analyze_method_name):
+ # get the bound method to be used as forward
+ forward_analyze_method = getattr(self.module, forward_analyze_method_name)
+ output = forward_analyze_method(inputs, *args, **kwargs)
+ else:
+ output = self.module(inputs, *args, **kwargs)
+ #
self.remove_call_hook(self.module)
self.finish_call()
return output
- def _analyze_modules_op(self, op, *inputs_orig):
- inputs = utils.squeeze_list2(inputs_orig)
+ def _analyze_modules_op(self, op, inputs, *args, **kwargs):
+ inputs = utils.squeeze_list2(inputs)
self.start_node(op)
self.add_node(op, inputs)
- outputs = op.__forward_orig__(*inputs_orig)
+ outputs = op.__forward_orig__(inputs, *args, **kwargs)
self.add_node(op, inputs, outputs)
self.finish_node(op, inputs, outputs)
return outputs