# pytorch-jacinto-ai-devkit: modules/pytorch_jacinto_ai/xnn/quantize/quant_base_module.py
import copy
from .quant_graph_module import *

###########################################################
class QuantEstimationType:
    QUANTIZED_THROUGH_ESTIMATION = 0
    STRAIGHT_THROUGH_ESTIMATION = 1
    ALPHA_BLENDING_ESTIMATION = 2
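
# (Illustrative note, not from the original source: these constants select how the
# gradient of the non-differentiable quantizer is approximated during training.
# For example, STRAIGHT_THROUGH_ESTIMATION corresponds to the usual straight-through
# estimator: forward y = round(x/s)*s, backward dy/dx approximated as 1.)
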
# base module to be used for all quantization modules
class QuantBaseModule(QuantGraphModule):
    def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
                 histogram_range=True, bias_calibration=False, constrain_weights=False,
                 model_surgery_quantize=False):
        super().__init__(module)
        self.bitwidth_weights = bitwidth_weights
        self.bitwidth_activations = bitwidth_activations
        self.per_channel_q = per_channel_q
        self.histogram_range = histogram_range
        self.constrain_weights = constrain_weights
        self.bias_calibration = bias_calibration
        # for help in debug/print
        utils.add_module_names(self)
        # put in eval mode before analyze
        self.eval()
        # model surgery for quantization
        if model_surgery_quantize:
            with torch.no_grad():
                utils.print_yellow("=> model surgery by '{}'".format(self.model_surgery_quantize.__name__))
                assert dummy_input is not None, 'dummy input is needed by quantized models to analyze graph'
                self.model_surgery_quantize(dummy_input)
            #
            # add hooks to execute the pact modules
            self.add_activation_hooks()
        #
        # add module names again, since surgery may have replaced modules (for help in debug/print)
        utils.add_module_names(self)

    def add_activation_hooks(self):
        # add a forward hook to call the extra activation that we added
        def _forward_activation(op, inputs, outputs):
            if hasattr(op, 'activation_q'):
                outputs = op.activation_q(outputs)
            #
            return outputs
        #
        for m in self.modules():
            m.register_forward_hook(_forward_activation)
        #
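
    # Note: a forward hook that returns a value makes PyTorch replace the module's
    # output with that value, which is how the extra activation module above takes
    # effect. Illustrative sketch (assumption: the actual 'activation_q' module is
    # attached during model surgery elsewhere in this package; Hardtanh is only a
    # stand-in for a PACT-style clipped activation):
    #
    #   conv = torch.nn.Conv2d(3, 16, kernel_size=3)
    #   conv.activation_q = torch.nn.Hardtanh(0.0, 6.0)
    #   # after add_activation_hooks(), conv's output is passed through conv.activation_q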

    def train(self, mode=True):
        # reset the iteration counter whenever the mode is switched
        self.iter_in_epoch.fill_(-1.0)
        return super().train(mode)

    def _backup_weights_orig(self):
        self.__params_orig__ = {}
        for n,p in self.named_parameters():
            self.__params_orig__[n] = copy.deepcopy(p.data)
        #
        self.__buffers_orig__ = {}
        for n,p in self.named_buffers():
            self.__buffers_orig__[n] = copy.deepcopy(p.data)
        #

    def _restore_weights_orig(self):
        for n,p in self.named_parameters():
            p.data.copy_(self.__params_orig__[n].data)
        #
        for n,p in self.named_buffers():
            p.data.copy_(self.__buffers_orig__[n].data)
        #

    def _backup_weights_quant(self):
        self.__params_quant__ = {}
        for n,p in self.named_parameters():
            self.__params_quant__[n] = copy.deepcopy(p.data)
        #
        self.__buffers_quant__ = {}
        for n,p in self.named_buffers():
            self.__buffers_quant__[n] = copy.deepcopy(p.data)
        #

    def _restore_weights_quant(self):
        for n,p in self.named_parameters():
            p.data.copy_(self.__params_quant__[n].data)
        #
        for n,p in self.named_buffers():
            p.data.copy_(self.__buffers_quant__[n].data)
        #
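
    # Illustrative note (assumption, not from the original source): a calibration
    # subclass might use the backup/restore helpers in a loop like this, so that
    # each iteration starts again from the original float weights:
    #
    #   self._backup_weights_orig()     # save the original float weights/buffers
    #   ... forward pass with quantization, adjust weights/biases ...
    #   self._backup_weights_quant()    # save the calibrated (quantized) state
    #   self._restore_weights_orig()    # go back to float weights for the next step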

    def _remove_backups(self):
        if hasattr(self, '__params_orig__'):
            del self.__params_orig__
        if hasattr(self, '__params_quant__'):
            del self.__params_quant__
        if hasattr(self, '__buffers_orig__'):
            del self.__buffers_orig__
        if hasattr(self, '__buffers_quant__'):
            del self.__buffers_quant__
        #
        # output means are some temp buffers used for calibration
        def _remove_output_means_op(op):
            if hasattr(op, '__output_mean_orig__'):
                del op.__output_mean_orig__
            if hasattr(op, '__output_std_orig__'):
                del op.__output_std_orig__
            #
        #
        self.apply(_remove_output_means_op)
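
# Illustrative usage sketch (assumption: in practice one of the concrete subclasses
# in this package is used rather than QuantBaseModule directly; 'MyFloatModel' is a
# hypothetical user-defined float model, and the arguments shown are simply the
# constructor arguments defined above):
#
#   model = MyFloatModel()
#   dummy_input = torch.rand(1, 3, 224, 224)
#   quant_model = QuantBaseModule(model, dummy_input=dummy_input,
#                                 bitwidth_weights=8, bitwidth_activations=8,
#                                 per_channel_q=False, histogram_range=True,
#                                 bias_calibration=False, constrain_weights=False,
#                                 model_surgery_quantize=True)
#   quant_model.eval()
#   with torch.no_grad():
#       output = quant_model(dummy_input)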