diff --git a/modules/pytorch_jacinto_ai/xnn/quantize/quant_test_module.py b/modules/pytorch_jacinto_ai/xnn/quantize/quant_test_module.py
index 5814f025cb1e7836b0930616b86a707d224e2dfe..7166caade59c444511edbe57ec147cb33eba2fab 100644
class QuantTestModule(QuantBaseModule):
- def __init__(self, module, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False, power2_weights=True, histogram_range=True,
- range_calibration_online=False, bias_calibration=False, model_surgery_quantize=False, dummy_input=None):
- super().__init__(module)
- assert dummy_input is not None, 'dummy input is needed by quantized models to analyze graph'
- self.bitwidth_weights = bitwidth_weights
- self.bitwidth_activations = bitwidth_activations
- self.per_channel_q = per_channel_q
- self.power2_weights = power2_weights
- # this flag indicates bias calibration - i.e. that this was called from the derived class
- self.bias_calibration = bias_calibration
-
+ def __init__(self, module, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False, histogram_range=True,
+ range_calibration_online=False, bias_calibration=False, dummy_input=None, model_surgery_quantize=False):
+ super().__init__(module, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
+ per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
+ constrain_weights=False, dummy_input=dummy_input, model_surgery_quantize=model_surgery_quantize)
+ # power2_weights is hard-coded to True for now
+ self.power2_weights = True
# whether to do online adjustment of calibration using previous frame range
self.range_calibration_online = range_calibration_online
# number of offline calibration iters. during offline calibration, current frame range is used
self.idx_large_mse_for_act = 0
- # put in eval mode before analysis
- self.eval()
-
- with torch.no_grad():
- # model surgery for quantization
- if model_surgery_quantize:
- utils.print_yellow("=> model surgery by '{}'".format(self.model_surgery_quantize.__name__))
- self.model_surgery_quantize(dummy_input)
- #
- #
-
- # for help in debug/print
- for n, m in self.named_modules():
- m.name = n
- #
-
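# Usage sketch (not repo code): after this change the constructor forwards the
# quantization options to QuantBaseModule and requires dummy_input, while
# power-of-2 weight scaling is fixed internally. MyModel, the import path and
# the input shape below are assumptions for illustration.
#
# import torch
# from pytorch_jacinto_ai import xnn
#
# model = MyModel()                                # hypothetical float model
# dummy_input = torch.rand(1, 3, 224, 224)         # assumed input shape
# quant_model = xnn.quantize.QuantTestModule(model, bitwidth_weights=8,
#                 bitwidth_activations=8, dummy_input=dummy_input,
#                 model_surgery_quantize=True)
# quant_model.eval()
# with torch.no_grad():
#     output = quant_model(dummy_input)
#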
def model_surgery_quantize(self, dummy_input):
super().model_surgery_quantize(dummy_input)
# apply replace_func to all submodules recursively
self.apply(replace_func)
- # add a forward hook to invoke the extra activation quantizer that was added
- def _forward_activation(op, inputs, outputs):
- if hasattr(op, 'activation_q'):
- outputs = op.activation_q(outputs)
- #
- return outputs
- #
- for m in self.modules():
- m.register_forward_hook(_forward_activation)
- #
-
# clear stored states
self.clear_states()
#
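# For reference, the hook removed above relies on standard PyTorch behavior: a
# forward hook that returns a non-None value replaces the module's output. A
# minimal standalone sketch (names and the clamp are illustrative, not repo code):
#
# import torch
#
# relu = torch.nn.ReLU()
# relu.register_forward_hook(lambda mod, inp, out: torch.clamp(out, max=6.0))
# y = relu(torch.randn(4))   # the hook post-processes (clamps) the ReLU output
#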
# override this in a derived class to constrain (clamp) weights
- def constrain_weights(self, module):
+ def apply_constrain_weights(self, module):
pass
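# A derived class is expected to override the stub above; a hypothetical sketch
# that clamps outlier weights before quantization (the 3-sigma threshold is an
# assumption for illustration, not this repo's policy):
#
# class MyQuantModule(QuantTestModule):
#     def apply_constrain_weights(self, module):
#         if hasattr(module, 'weight') and module.weight is not None:
#             w = module.weight.data
#             limit = 3.0 * w.std().item()
#             module.weight.data.copy_(w.clamp(-limit, limit))
#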
def quantize_weights(self, module, tensor_in, qrange):
- self.constrain_weights(module)
+ self.apply_constrain_weights(module)
bitwidth_weights = self.get_bitwidth_weights(module)
with torch.no_grad():
# #
#
#
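# For context on power2_weights above: power-of-2 weight scaling typically rounds
# the quantization scale to a power of two so that scaling can be done with bit
# shifts in fixed-point hardware. A generic sketch, not this repo's exact
# implementation:
#
# import math
# import torch
#
# def quantize_weights_pow2(w, bitwidth=8):
#     qmax = 2.0 ** (bitwidth - 1) - 1             # e.g. 127 for 8-bit signed
#     scale = w.abs().max().item() / qmax          # assumes w is not all-zero
#     scale = 2.0 ** math.ceil(math.log2(scale))   # round scale up to a power of 2
#     return torch.round(w / scale).clamp(-qmax - 1, qmax) * scale
#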
-# def constrain_weights(self, module):
+# def apply_constrain_weights(self, module):
# if not self.constrain_weights_enabled:
# return
# #
-# constrained_weight = constrain_weight(module.weight.data)
+# constrained_weight = quant_utils.constrain_weight(module.weight.data)
# module.weight.data.copy_(constrained_weight.data)
# #
#
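# quant_utils.constrain_weight above is assumed to return a range-limited copy of
# the weights; a hypothetical stand-in using percentile clipping (the percentile
# value is an assumption, not this repo's behavior):
#
# import torch
#
# def constrain_weight(w, percentile=0.999):
#     # clip the extreme tails so the quantization range is not wasted on outliers
#     limit = torch.quantile(w.abs().flatten(), percentile).item()
#     return w.clamp(-limit, limit)
#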