###########################################################
# Approximate quantized floating point simulation with gradients.
# Can be used for quantized training of models.
###########################################################

import torch
import numpy as np
import copy
import warnings

from .. import layers
from .. import utils
from .quant_train_utils import *

warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)


###########################################################
class QuantTrainModule(QuantBaseModule):
    def __init__(self, module, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False, histogram_range=True,
                 constrain_weights=True, bias_calibration=False, dummy_input=None):
        super().__init__(module)
        assert dummy_input is not None, 'dummy input is needed by quantized models to analyze graph'
        self.bitwidth_weights = bitwidth_weights
        self.bitwidth_activations = bitwidth_activations
        self.per_channel_q = per_channel_q
        self.constrain_weights = constrain_weights #and (not bool(self.per_channel_q))
        self.bias_calibration = bias_calibration

        # add module names to help in debug/print
        utils.add_module_names(self)

        # put in eval mode before analyze
        self.eval()

        with torch.no_grad():
            # model surgery for quantization
            utils.print_yellow("=> model surgery by '{}'".format(self.model_surgery_quantize.__name__))
            self.model_surgery_quantize(dummy_input)
        #

        # range shrink - 0.0 indicates no shrink
        percentile_range_shrink = (layers.PAct2.PACT2_RANGE_SHRINK if histogram_range else 0.0)
        # set attributes on all modules - the quantization behaviour can be controlled from here
        utils.apply_setattr(self, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations, per_channel_q=per_channel_q, bias_calibration=bias_calibration,
                            quantize_enable=True, quantize_weights=True, quantize_bias=True, quantize_activations=True,
                            percentile_range_shrink=percentile_range_shrink, constrain_weights=self.constrain_weights)

        # add module names again, since model surgery has replaced modules
        utils.add_module_names(self)


    def train(self, mode=True):
        self.iter_in_epoch.fill_(-1.0)
        super().train(mode)


    def forward(self, inputs):
        # analyze
        self.analyze_graph(inputs=inputs, cleanup_states=True)

        # actual forward call
        if self.training and self.bias_calibration:
            # bias calibration
            outputs = self.forward_calibrate_bias(inputs)
        else:
            outputs = self.module(inputs)
        #
        return outputs


    def forward_calibrate_bias(self, inputs):
        assert False, 'forward_calibrate_bias is not implemented'


    def model_surgery_quantize(self, dummy_input):
        super().model_surgery_quantize(dummy_input)

        def replace_func(op):
            for name, m in op._modules.items():
                if utils.is_conv(m):
                    bias = (m.bias is not None)
                    padding_mode = m.padding_mode
                    new_m = QuantTrainConv2d(in_channels=m.in_channels, out_channels=m.out_channels, kernel_size=m.kernel_size, stride=m.stride,
                                             padding=m.padding, dilation=m.dilation, groups=m.groups, bias=bias, padding_mode=padding_mode)
                elif utils.is_deconv(m):
                    bias = (m.bias is not None)
                    padding_mode = m.padding_mode
                    new_m = QuantTrainConvTranspose2d(in_channels=m.in_channels, out_channels=m.out_channels, kernel_size=m.kernel_size, stride=m.stride,
                                                      padding=m.padding, output_padding=m.output_padding, groups=m.groups, bias=bias, dilation=m.dilation, padding_mode=padding_mode)
                elif utils.is_bn(m):
                    new_m = QuantTrainBatchNorm2d(num_features=m.num_features, eps=m.eps, momentum=m.momentum, affine=m.affine,
                                                  track_running_stats=m.track_running_stats)
                elif isinstance(m, layers.PAct2):
                    new_m = QuantTrainPAct2(inplace=m.inplace, signed=m.signed, clip_range=m.clip_range,
                                            bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
                                            per_channel_q=self.per_channel_q)
                elif isinstance(m, layers.NoAct):
                    new_m = QuantTrainPAct2(signed=None, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
                                            per_channel_q=self.per_channel_q)
                elif isinstance(m, (torch.nn.ReLU, torch.nn.ReLU6)):
                    new_m = QuantTrainPAct2(signed=False, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
                                            per_channel_q=self.per_channel_q)
                else:
                    new_m = None
                #
                if new_m is not None:
                    # copy tensors, submodules and a few known attributes from the old module to the new one
                    for attr in dir(m):
                        value = getattr(m, attr)
                        if isinstance(value, torch.Tensor) and value is not None:
                            getattr(new_m, attr).data.copy_(value.data)
                        elif isinstance(value, torch.nn.Module) and value is not None:
                            setattr(new_m, attr, getattr(m, attr))
                        elif attr in ('weight', 'bias', 'eps', 'clips_act', 'clips_w'):  # add more attribute names here if needed
                            setattr(new_m, attr, getattr(m, attr))
                        #
                    #
                    new_m.train(m.training)
                    setattr(op, name, new_m)
                #
            #
        #
        # apply recursively
        self.apply(replace_func)

        # add a forward hook to call the extra activation that we added
        def _forward_activation(op, inputs, outputs):
            if hasattr(op, 'activation_q'):
                outputs = op.activation_q(outputs)
            #
            return outputs
        #
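        # note: when a forward hook returns a value, PyTorch uses it as the module's output,
        # so activation_q (where present) gets applied to the output of every module below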
        for m in self.modules():
            m.register_forward_hook(_forward_activation)
        #

        # clear the temporary analysis states
        self.clear_states()
    #
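
# A minimal usage sketch for QuantTrainModule (not part of the original file).
# It assumes a torchvision classification model and a standard training loop
# defined elsewhere (train_loader, optimizer); the wrapped model is trained as
# usual while quantization is simulated inside the replaced modules.
#
#   import torch, torchvision
#   model = torchvision.models.mobilenet_v2(pretrained=True)
#   dummy_input = torch.rand(1, 3, 224, 224)
#   model = QuantTrainModule(model, bitwidth_weights=8, bitwidth_activations=8,
#                            per_channel_q=False, dummy_input=dummy_input)
#   model.train()
#   for images, targets in train_loader:
#       outputs = model(images)
#       loss = torch.nn.functional.cross_entropy(outputs, targets)
#       loss.backward()
#       optimizer.step()
#       optimizer.zero_grad()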

###########################################################
class QuantCalibrateModule(QuantTrainModule):
    def __init__(self, module, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False, bias_calibration=True,
                 histogram_range=True, constrain_weights=True, lr_calib=0.1, dummy_input=None):
        self.bias_calibration = bias_calibration
        self.lr_calib = lr_calib
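        # the effective calibration step decays across epochs as lr_calib * gamma**epoch
        # (see forward_calibrate_bias below), so later epochs make smaller bias corrections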
        self.bias_calibration_factor = lr_calib
        self.bias_calibration_gamma = 0.5
        self.calibrate_weights = False
        self.calibrate_repeats = 1
        self.quantize_enable = True
        # BNs can be adjusted based on the input provided - however this is not really required
        self.calibrate_bn = False
        super().__init__(module, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations, per_channel_q=per_channel_q,
                         histogram_range=histogram_range, constrain_weights=constrain_weights, bias_calibration=bias_calibration, dummy_input=dummy_input)
    #


    def forward_calibrate_bias(self, inputs):
        # we don't need gradients for calibration
        with torch.no_grad():
            # prepare/backup weights
            if self.num_batches_tracked == 0:
                # update the calibration factor for this epoch
                self.bias_calibration_factor = self.lr_calib * np.power(self.bias_calibration_gamma, float(self.epoch))
                # backup original weights
                self._backup_weights_orig()
                # backup quantized weights
                self._backup_weights_quant()
            #

            # backup the current state
            training = self.training

            # compute the mean output in float mode.
            # also, set all BNs to eval: we can't set the whole model to eval because
            # the PAct2 layers need to learn the activation ranges, which happens only in training mode,
            # and the model output itself may be different in eval mode.
            if self.calibrate_bn:
                outputs = self.forward_compute_output_stats(inputs)
                utils.freeze_bn(self)
            else:
                utils.freeze_bn(self)
                outputs = self.forward_compute_output_stats(inputs)
            #

            # adjust the quantized output to match the float mean
            outputs = self.forward_adjust_bias(inputs)

            self.train(training)

            return outputs
        #


    def forward_compute_output_stats(self, inputs):
        self._restore_weights_orig()
        # disable quantization for a moment
        quantize_enable_backup_value = self.quantize_enable
        utils.apply_setattr(self, quantize_enable=False)

        self.add_call_hook(self.module, self._forward_compute_output_stats_hook)
        outputs = self.module(inputs)
        self.remove_call_hook(self.module)

        # turn quantization back on by restoring the backed-up flag - not the cleanest method
        utils.apply_setattr(self, quantize_enable=quantize_enable_backup_value)
        self._backup_weights_orig()
        return outputs
    #
    def _forward_compute_output_stats_hook(self, op, *inputs_orig):
        outputs = op.__forward_orig__(*inputs_orig)
        # calibration at specific layers
        bias = op.bias if (hasattr(op, 'bias') and (op.bias is not None)) else None
        weight = op.weight if (hasattr(op, 'weight') and (op.weight is not None)) else None
        if (bias is not None) or (self.calibrate_weights and weight is not None):
            output = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
            reduce_dims = [0, 2, 3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)
            op.__output_mean_orig__ = torch.mean(output, dim=reduce_dims)
            op.__output_std_orig__ = torch.std(output, dim=reduce_dims)
        #
        return outputs
    #


    def forward_adjust_bias(self, input):
        self._restore_weights_quant()
        self.add_call_hook(self.module, self._forward_adjust_bias_hook)
        for _ in range(self.calibrate_repeats):
            output = self.module(input)
        #
        self.remove_call_hook(self.module)
        self._backup_weights_quant()
        return output
    #
    def _forward_adjust_bias_hook(self, op, *inputs_orig):
        outputs = op.__forward_orig__(*inputs_orig)
        # calibration at specific layers
        bias = op.bias if (hasattr(op, 'bias') and (op.bias is not None)) else None
        if bias is not None:
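            # nudge the bias toward the float-mode statistics captured earlier:
            # bias += bias_calibration_factor * (mean_float - mean_quant), per output channel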
            output = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
            reduce_dims = [0, 2, 3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)
            output_mean = torch.mean(output, dim=reduce_dims)
            output_delta = op.__output_mean_orig__ - output_mean
            output_delta = output_delta * self.bias_calibration_factor
            bias.data += (output_delta)
            # # TODO: is this required?
            # if len(output.size()) == 4:
            #     output.data += output_delta.data.view(1,-1,1,1)
            # elif len(output.size()) == 2:
            #     output.data += output_delta.data.view(1,-1)
            # else:
            #     assert False, 'unknown dimensions'
            # #
        #

        # weight = op.weight if (hasattr(op, 'weight') and (op.weight is not None)) else None
        # iter_threshold = (1.0/self.bias_calibration_factor)
        # if self.calibrate_weights and (weight is not None) and (self.num_batches_tracked > iter_threshold):
        #         eps = 1e-3
        #         output = outputs[0] if isinstance(outputs, (list,tuple)) else outputs
        #         reduce_dims = [0, 2, 3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)
        #         output_std = torch.std(output, dim=reduce_dims)
        #         output_ratio = (op.__output_std_orig__ + eps) / (output_std + eps)
        #         channels = output.size(1)
        #         output_ratio = output_ratio.view(channels, 1, 1, 1) if len(weight.data.size()) > 1 else output_ratio
        #         output_ratio = torch.pow(output_ratio, self.bias_calibration_factor)
        #         output_ratio = torch.clamp(output_ratio, 1.0-self.bias_calibration_factor, 1.0+self.bias_calibration_factor)
        #         weight.data *= output_ratio
        #     #
        # #

        return outputs
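
# A minimal calibration sketch for QuantCalibrateModule (not part of the original file).
# It assumes a pretrained model and a calib_loader yielding representative inputs;
# the wrapped model is run in train() mode for a few batches so bias calibration and
# activation-range estimation can take place, without any optimizer or backward pass.
#
#   import torch, torchvision
#   model = torchvision.models.mobilenet_v2(pretrained=True)
#   dummy_input = torch.rand(1, 3, 224, 224)
#   model = QuantCalibrateModule(model, bitwidth_weights=8, bitwidth_activations=8,
#                                dummy_input=dummy_input)
#   model.train()
#   for images, _ in calib_loader:
#       model(images)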