# jacinto-ai/pytorch-jacinto-ai-devkit: modules/pytorch_jacinto_ai/xnn/quantize/quant_base_module.py
import copy
import torch

# the wildcard import also provides QuantGraphModule and the utils/layers helpers used below
from .quant_graph_module import *

###########################################################
class QuantEstimationType:
    QUANTIZED_THROUGH_ESTIMATION = 0
    STRAIGHT_THROUGH_ESTIMATION = 1
    ALPHA_BLENDING_ESTIMATION = 2
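
# Note: these constants name common schemes for backpropagating through the
# non-differentiable quantizer; the mapping is inferred from the names, as this
# file does not define them further. STRAIGHT_THROUGH_ESTIMATION is the usual STE
# (gradients pass through the rounding op unchanged); ALPHA_BLENDING_ESTIMATION
# refers to blending the float and quantized tensors with a coefficient instead.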

# base module to be used for all quantization modules
class QuantBaseModule(QuantGraphModule):
    def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
                 histogram_range=True, bias_calibration=False, constrain_weights=False, constrain_bias=None,
                 model_surgery_quantize=True, power2_weight_range=None, power2_activation_range=None):
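        """
        Base module for quantization wrappers around a full-precision model.

        Arguments (as handled below):
        - module: the float model to wrap.
        - dummy_input: example input, needed to analyze the graph during model surgery.
        - bitwidth_weights, bitwidth_activations: quantization bitwidths (default 8).
        - per_channel_q: per-channel quantization of weights.
        - histogram_range: enable percentile based range shrinking of activations.
        - bias_calibration: enable bias calibration.
        - constrain_weights: constrain weight values for quantization.
        - constrain_bias, power2_weight_range, power2_activation_range: default to
          True when None is passed.
        - model_surgery_quantize: perform model surgery in the constructor itself.
        """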
        super().__init__(module)
        self.bitwidth_weights = bitwidth_weights
        self.bitwidth_activations = bitwidth_activations
        self.per_channel_q = per_channel_q
        self.histogram_range = histogram_range
        self.constrain_weights = constrain_weights
        self.constrain_bias = True if (constrain_bias is None) else constrain_bias
        self.bias_calibration = bias_calibration
        self.power2_weight_range = True if (power2_weight_range is None) else power2_weight_range
        self.power2_activation_range = True if (power2_activation_range is None) else power2_activation_range
        # range shrink - 0.0 indicates no shrink
        self.percentile_range_shrink = (layers.PAct2.PACT2_RANGE_SHRINK if histogram_range else 0.0)
        # for help in debug/print
        utils.add_module_names(self)
        # put in eval mode before analyze
        self.eval()
        # model surgery for quantization
        if model_surgery_quantize:
            with torch.no_grad():
                utils.print_yellow("=> model surgery by '{}'".format(self.model_surgery_quantize.__name__))
                assert dummy_input is not None, 'dummy input is needed by quantized models to analyze graph'
                self.model_surgery_quantize(dummy_input)
            #
            # add hooks to execute the pact modules
            self.add_activation_hooks()
        #
        # for help in debug/print
        utils.add_module_names(self)

        # set attributes to all modules - can control the behaviour from here
        utils.apply_setattr(self, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
                            histogram_range=histogram_range, bias_calibration=self.bias_calibration, per_channel_q=self.per_channel_q,
                            percentile_range_shrink=self.percentile_range_shrink, constrain_weights=self.constrain_weights,
                            power2_weight_range=self.power2_weight_range, power2_activation_range=self.power2_activation_range,
                            constrain_bias=self.constrain_bias)
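
    # Usage is via the derived classes of this base (e.g. QuantTrainModule /
    # QuantCalibrateModule elsewhere in this package); a minimal sketch, with an
    # illustrative model and input shape:
    #
    #   model = torchvision.models.mobilenet_v2()
    #   dummy_input = torch.rand(1, 3, 224, 224)
    #   model = QuantTrainModule(model, dummy_input=dummy_input,
    #                            bitwidth_weights=8, bitwidth_activations=8)
    #   # ... then train as usual; quantization is simulated in the forward pass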
    def add_activation_hooks(self):
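        """
        Register forward hooks on every submodule so that the activation modules
        inserted by model surgery ('activation_in' / 'activation_q') run around
        each op's forward(). A forward pre-hook that returns a value replaces the
        op's input, and a forward hook that returns a value replaces the op's
        output - that is what the two hooks below rely on.
        """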
        # add a forward hook to call the extra activation that we added
        def _forward_input_activation(op, inputs):
            if hasattr(op, 'activation_in'):
                # hook passes the input as tuple - expand it
                to_squeeze = isinstance(inputs, tuple) and len(inputs) == 1
                inputs = inputs[0] if to_squeeze else inputs
                inputs = op.activation_in(inputs)
                inputs = (inputs,) if to_squeeze else inputs
            #
            return inputs
        #
        def _forward_output_activation(op, inputs, outputs):
            if hasattr(op, 'activation_q'):
                # the output may be wrapped in a tuple - expand it
                to_squeeze = isinstance(outputs, tuple) and len(outputs) == 1
                outputs = outputs[0] if to_squeeze else outputs
                outputs = op.activation_q(outputs)
                outputs = (outputs,) if to_squeeze else outputs
            #
            return outputs
        #
        for m in self.modules():
            m.register_forward_pre_hook(_forward_input_activation)
            m.register_forward_hook(_forward_output_activation)
        #

    def train(self, mode=True):
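        """Switch mode and reset the per-epoch iteration counter (iter_in_epoch
        is presumably advanced elsewhere, e.g. in the forward of derived modules,
        to track calibration/training progress within an epoch)."""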
        self.iter_in_epoch = -1
        super().train(mode)
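
    # The helpers below snapshot and restore all parameters and buffers. They are
    # used (e.g. by calibration in derived modules) to switch between the original
    # float weights and the quantized/calibrated weights within an iteration.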
    def _backup_weights_orig(self):
        self.__params_orig__ = {}
        for n, p in self.named_parameters():
            self.__params_orig__[n] = copy.deepcopy(p.data)
        #
        self.__buffers_orig__ = {}
        for n, p in self.named_buffers():
            self.__buffers_orig__[n] = copy.deepcopy(p.data)
        #

    def _restore_weights_orig(self):
        for n, p in self.named_parameters():
            p.data.copy_(self.__params_orig__[n].data)
        #
        for n, p in self.named_buffers():
            p.data.copy_(self.__buffers_orig__[n].data)
        #

    def _backup_weights_quant(self):
        self.__params_quant__ = {}
        for n, p in self.named_parameters():
            self.__params_quant__[n] = copy.deepcopy(p.data)
        #
        self.__buffers_quant__ = {}
        for n, p in self.named_buffers():
            self.__buffers_quant__[n] = copy.deepcopy(p.data)
        #

    def _restore_weights_quant(self):
        for n, p in self.named_parameters():
            p.data.copy_(self.__params_quant__[n].data)
        #
        for n, p in self.named_buffers():
            p.data.copy_(self.__buffers_quant__[n].data)
        #

    def _remove_backups(self):
        if hasattr(self, '__params_orig__'):
            del self.__params_orig__
        if hasattr(self, '__params_quant__'):
            del self.__params_quant__
        if hasattr(self, '__buffers_orig__'):
            del self.__buffers_orig__
        if hasattr(self, '__buffers_quant__'):
            del self.__buffers_quant__
        #
        # output means are some temp buffers used for calibration
        def _remove_output_means_op(op):
            if hasattr(op, '__output_mean_orig__'):
                del op.__output_mean_orig__
            if hasattr(op, '__output_std_orig__'):
                del op.__output_std_orig__
            #
        #
        self.apply(_remove_output_means_op)
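
# A typical calibration iteration in a derived module (a sketch of how the
# backup/restore helpers above fit together; the exact sequence lives in the
# derived classes, e.g. QuantCalibrateModule, not in this file):
#   _backup_weights_orig()   - snapshot the float weights before quantization
#   ... run forward with quantization, adjust biases/activation ranges ...
#   _backup_weights_quant()  - snapshot the calibrated/quantized weights
#   _restore_weights_orig()  - restore float weights for the next iteration
#   _remove_backups()        - drop all snapshots once calibration is done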