import copy
import torch
from .. import utils    # assumed to resolve to xnn.utils, matching this package layout
from .quant_graph_module import *

###########################################################
class QuantEstimationType:
    QUANTIZED_THROUGH_ESTIMATION = 0
    STRAIGHT_THROUGH_ESTIMATION = 1
    ALPHA_BLENDING_ESTIMATION = 2
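
# Note: the constants above name gradient-estimation schemes for the
# (non-differentiable) quantizer: STRAIGHT_THROUGH_ESTIMATION is the usual STE,
# which passes gradients through the rounding op unchanged. Which scheme is
# used is presumably decided by the derived training/calibration modules.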

# base module to be used by all quantization modules
class QuantBaseModule(QuantGraphModule):
    def __init__(self, module, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
                 histogram_range=True, bias_calibration=False, constrain_weights=False, dummy_input=None,
                 model_surgery_quantize=False):
        super().__init__(module)
        self.bitwidth_weights = bitwidth_weights
        self.bitwidth_activations = bitwidth_activations
        self.per_channel_q = per_channel_q
        self.histogram_range = histogram_range
        self.constrain_weights = constrain_weights
        self.bias_calibration = bias_calibration
        # name the modules, to help in debug/print
        utils.add_module_names(self)
        # put the model in eval mode before it is analyzed
        self.eval()
        # model surgery for quantization
        if model_surgery_quantize:
            with torch.no_grad():
                utils.print_yellow("=> model surgery by '{}'".format(self.model_surgery_quantize.__name__))
                assert dummy_input is not None, 'dummy input is needed by quantized models to analyze the graph'
                self.model_surgery_quantize(dummy_input)
            #
            # add hooks to execute the PACT modules
            self.add_activation_hooks()
        #
        # add the module names again, as surgery may have replaced modules
        utils.add_module_names(self)

    def add_activation_hooks(self):
        # add a forward hook to call the extra activation that we added
        def _forward_activation(op, inputs, outputs):
            if hasattr(op, 'activation_q'):
                outputs = op.activation_q(outputs)
            #
            return outputs
        #
        for m in self.modules():
            m.register_forward_hook(_forward_activation)
        #
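
    # Note: PyTorch uses the (non-None) return value of a forward hook as the
    # module's output, so the hook above effectively splices the extra
    # 'activation_q' op into every module's forward pass.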

    def train(self, mode=True):
        # reset the in-epoch iteration counter (defined in the base class)
        # whenever the training/eval mode changes
        self.iter_in_epoch.fill_(-1.0)
        super().train(mode)

    def _backup_weights_orig(self):
        # keep a copy of the original (float) parameters and buffers
        self.__params_orig__ = {}
        for n,p in self.named_parameters():
            self.__params_orig__[n] = copy.deepcopy(p.data)
        #
        self.__buffers_orig__ = {}
        for n,p in self.named_buffers():
            self.__buffers_orig__[n] = copy.deepcopy(p.data)
        #

    def _restore_weights_orig(self):
        # restore the original (float) parameters and buffers
        for n,p in self.named_parameters():
            p.data.copy_(self.__params_orig__[n].data)
        #
        for n,p in self.named_buffers():
            p.data.copy_(self.__buffers_orig__[n].data)
        #

    def _backup_weights_quant(self):
        # keep a copy of the quantized/calibrated parameters and buffers
        self.__params_quant__ = {}
        for n,p in self.named_parameters():
            self.__params_quant__[n] = copy.deepcopy(p.data)
        #
        self.__buffers_quant__ = {}
        for n,p in self.named_buffers():
            self.__buffers_quant__[n] = copy.deepcopy(p.data)
        #

    def _restore_weights_quant(self):
        # restore the quantized/calibrated parameters and buffers
        for n,p in self.named_parameters():
            p.data.copy_(self.__params_quant__[n].data)
        #
        for n,p in self.named_buffers():
            p.data.copy_(self.__buffers_quant__[n].data)
        #

    def _remove_backups(self):
        if hasattr(self, '__params_orig__'):
            del self.__params_orig__
        if hasattr(self, '__params_quant__'):
            del self.__params_quant__
        if hasattr(self, '__buffers_orig__'):
            del self.__buffers_orig__
        if hasattr(self, '__buffers_quant__'):
            del self.__buffers_quant__
        #
        # output means/stds are temp buffers used during calibration
        # (note: Module.apply() passes each sub-module as the only argument)
        def _remove_output_means_op(op):
            if hasattr(op, '__output_mean_orig__'):
                del op.__output_mean_orig__
            if hasattr(op, '__output_std_orig__'):
                del op.__output_std_orig__
            #
        #
        self.apply(_remove_output_means_op)
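
###########################################################
# Usage sketch (illustrative only; `float_model` and `dummy_input` are
# hypothetical names): one plausible way the backup/restore helpers could be
# paired around a calibration pass in a derived module.
#
#   qmodel = QuantBaseModule(float_model, bitwidth_weights=8, bitwidth_activations=8,
#                            bias_calibration=True, dummy_input=dummy_input,
#                            model_surgery_quantize=True)
#   qmodel._backup_weights_orig()      # snapshot the float params/buffers
#   ...                                # forward passes that calibrate/adjust weights
#   qmodel._backup_weights_quant()     # snapshot the calibrated params/buffers
#   qmodel._restore_weights_orig()     # or switch back to the float weights
#   qmodel._remove_backups()           # free the temporary copies when done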