# [jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / quantize / quant_train_module.py
1 ###########################################################
2 # Approximate quantized floating point simulation with gradients.
3 # Can be used for quantized training of models.
4 ###########################################################
6 import torch
7 import numpy as np
8 import copy
9 import warnings
11 from .. import layers
12 from .. import utils
13 from .quant_train_utils import *
# torch.jit tracing (used during graph analysis of the quantized model)
# emits TracerWarnings that are expected here; suppress them process-wide.
warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
18 ###########################################################
class QuantTrainModule(QuantBaseModule):
    """Approximate quantized floating point simulation with gradients.

    Wraps a float `module`, performs model surgery to replace supported
    layers with quantization-aware variants (see model_surgery_quantize)
    and pushes the quantization settings to all sub-modules.
    """
    def __init__(self, module, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False, histogram_range=True,
                 constrain_weights=True, bias_calibration=False, dummy_input=None):
        # dummy_input is mandatory: it drives the forward pass used to
        # analyze the module graph during surgery.
        super().__init__(module)
        assert dummy_input is not None, 'dummy input is needed by quantized models to analyze graph'
        self.bitwidth_weights = bitwidth_weights
        self.bitwidth_activations = bitwidth_activations
        self.per_channel_q = per_channel_q
        self.constrain_weights = constrain_weights #and (not bool(self.per_channel_q))
        self.bias_calibration = bias_calibration

        # for help in debug/print
        utils.add_module_names(self)

        # put in eval mode before analyze
        self.eval()

        with torch.no_grad():
            # model surgery for quantization
            utils.print_yellow("=> model surgery by '{}'".format(self.model_surgery_quantize.__name__))
            self.model_surgery_quantize(dummy_input)
        #

        # range shrink - 0.0 indicates no shrink
        percentile_range_shrink = (layers.PAct2.PACT2_RANGE_SHRINK if histogram_range else 0.0)
        # set attributes to all modules - can control the behaviour from here
        utils.apply_setattr(self, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                            per_channel_q=per_channel_q, bias_calibration=bias_calibration,
                            percentile_range_shrink=percentile_range_shrink, constrain_weights=self.constrain_weights,
                            update_range=True, quantize_enable=True, quantize_weights=True, quantize_bias=True, quantize_activations=True)

        # for help in debug/print - surgery replaced modules, so refresh names
        utils.add_module_names(self)
54 def train(self, mode=True):
55 self.iter_in_epoch.fill_(-1.0)
56 super().train(mode)
59 def forward(self, inputs):
60 # analyze
61 self.analyze_graph(inputs=inputs, cleanup_states=True)
63 # actual forward call
64 if self.training and self.bias_calibration:
65 # bias calibration
66 outputs = self.forward_calibrate_bias(inputs)
67 else:
68 outputs = self.module(inputs)
69 #
70 return outputs
73 def forward_calibrate_bias(self, inputs):
74 assert False, 'forward_calibrate_bias is not implemented'
    def model_surgery_quantize(self, dummy_input):
        """Replace supported float layers with quantization-aware variants.

        Conv2d / ConvTranspose2d / BatchNorm2d and activations (PAct2,
        NoAct, ReLU, ReLU6) are swapped in-place for their QuantTrain*
        counterparts; parameters, buffers and selected attributes are
        copied from the old module to the new one.
        """
        super().model_surgery_quantize(dummy_input)

        def replace_func(op):
            # replace the immediate children of op; self.apply below makes
            # this effectively recursive over the whole model
            for name, m in op._modules.items():
                if utils.is_conv(m):
                    bias = (m.bias is not None)
                    padding_mode = m.padding_mode
                    new_m = QuantTrainConv2d(in_channels=m.in_channels, out_channels=m.out_channels, kernel_size=m.kernel_size, stride=m.stride,
                                             padding=m.padding, dilation=m.dilation, groups=m.groups, bias=bias, padding_mode=padding_mode)
                elif utils.is_deconv(m):
                    bias = (m.bias is not None)
                    padding_mode = m.padding_mode
                    new_m = QuantTrainConvTranspose2d(in_channels=m.in_channels, out_channels=m.out_channels, kernel_size=m.kernel_size, stride=m.stride,
                                                      padding=m.padding, output_padding=m.output_padding, groups=m.groups, bias=bias, dilation=m.dilation, padding_mode=padding_mode)
                elif utils.is_bn(m):
                    new_m = QuantTrainBatchNorm2d(num_features=m.num_features, eps=m.eps, momentum=m.momentum, affine=m.affine,
                                                  track_running_stats=m.track_running_stats)
                elif isinstance(m, layers.PAct2):
                    new_m = QuantTrainPAct2(inplace=m.inplace, signed=m.signed, clip_range=m.clip_range,
                                            bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
                                            per_channel_q=self.per_channel_q)
                elif isinstance(m, layers.NoAct):
                    # signed=None: presumably lets the activation infer the
                    # sign from observed data - TODO confirm in QuantTrainPAct2
                    new_m = QuantTrainPAct2(signed=None, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
                                            per_channel_q=self.per_channel_q)
                elif isinstance(m, (torch.nn.ReLU, torch.nn.ReLU6)):
                    new_m = QuantTrainPAct2(signed=False, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
                                            per_channel_q=self.per_channel_q)
                else:
                    new_m = None
                #
                if new_m is not None:
                    # copy tensors (parameters/buffers) by value, sub-modules
                    # by reference, plus a few known scalar attributes
                    for attr in dir(m):
                        value = getattr(m,attr)
                        if isinstance(value,torch.Tensor) and value is not None:
                            getattr(new_m,attr).data.copy_(value.data)
                        elif isinstance(value,torch.nn.Module) and value is not None:
                            setattr(new_m, attr, getattr(m,attr))
                        elif attr in ('weight','bias','eps','clips_act','clips_w'): # add more attribute name here
                            setattr(new_m, attr, getattr(m, attr))
                        #
                    #
                    # preserve the training/eval mode of the replaced module
                    new_m.train(m.training)
                    setattr(op, name, new_m)
                #
            #
        #
        # apply recursively
        self.apply(replace_func)

        # add a forward hook to call the extra activation that we added
        def _forward_activation(op, inputs, outputs):
            if hasattr(op, 'activation_q'):
                outputs = op.activation_q(outputs)
            #
            return outputs
        #
        for m in self.modules():
            m.register_forward_hook(_forward_activation)
        #

        # clear
        self.clear_states()
        #
143 ###########################################################
class QuantCalibrateModule(QuantTrainModule):
    """QuantTrainModule variant that calibrates biases instead of training.

    During the forward pass (when bias_calibration is enabled) it compares
    the float and quantized output statistics of each layer and nudges the
    biases to compensate for the quantization shift.
    """
    def __init__(self, module, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False, bias_calibration=True,
                 histogram_range=True, constrain_weights=True, lr_calib=0.1, dummy_input=None):
        self.bias_calibration = bias_calibration
        # base step size for the bias adjustment (decayed per-epoch later)
        self.lr_calib = lr_calib
        self.bias_calibration_factor = lr_calib
        # per-epoch decay applied to the calibration factor
        self.bias_calibration_gamma = 0.5
        self.calibrate_weights = False
        # number of adjust-bias passes per calibration step
        self.calibrate_repeats = 1
        self.quantize_enable = True
        self.update_range = True
        # BNs can be adjusted based on the input provided - however this is not really required
        self.calibrate_bn = False
        # NOTE(review): attributes are set before super().__init__ -
        # presumably so the base constructor's surgery/analysis can see them;
        # confirm before reordering.
        super().__init__(module, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations, per_channel_q=per_channel_q,
                         histogram_range=histogram_range, constrain_weights=constrain_weights, bias_calibration=bias_calibration, dummy_input=dummy_input)
    #
    def forward_calibrate_bias(self, inputs):
        """One bias-calibration step.

        Records float output statistics per layer, then re-runs with
        quantized weights and shifts biases towards the float means
        (via the hooks installed by the helper methods below).
        """
        # we don't need gradients for calibration
        with torch.no_grad():
            # prepare/backup weights
            if self.num_batches_tracked == 0:
                # lr_calib: decay the calibration step size as epochs progress
                self.bias_calibration_factor = self.lr_calib * np.power(self.bias_calibration_gamma, float(self.epoch))
                # backup original weights
                self._backup_weights_orig()
                # backup quantized weights
                self._backup_weights_quant()
            #

            # backup the current state
            training = self.training

            # compute the mean output in float
            # also, set all bns to eval. we can't set the whole model to eval because
            # we need the pact to learn the ranges - which will happen only in training mode.
            # also the model output itself may be different in eval mode.
            if self.calibrate_bn:
                outputs = self.forward_compute_oputput_stats(inputs)
                utils.freeze_bn(self)
            else:
                utils.freeze_bn(self)
                outputs = self.forward_compute_oputput_stats(inputs)
            #

            # adjust the quantized output to match the mean
            outputs = self.forward_adjust_bias(inputs)

            # restore the train/eval mode saved above
            self.train(training)

            return outputs
    #
    def forward_compute_oputput_stats(self, inputs):
        """Forward pass in float with quantization temporarily disabled;
        the installed hook caches output mean/std on each relevant layer.

        NOTE(review): the name keeps the historical 'oputput' typo - it is
        referenced by name elsewhere, so renaming needs a coordinated change.
        """
        self._restore_weights_orig()
        # disable quantization for a moment
        quantize_enable_backup_value, update_range_backup_value = self.quantize_enable, self.update_range
        utils.apply_setattr(self, quantize_enable=False, update_range=False)

        self.add_call_hook(self.module, self._forward_compute_oputput_stats_hook)
        outputs = self.module(inputs)
        self.remove_call_hook(self.module)

        # turn quantization back on - not a clean method
        utils.apply_setattr(self, quantize_enable=quantize_enable_backup_value, update_range=update_range_backup_value)
        self._backup_weights_orig()
        return outputs
    #
    def _forward_compute_oputput_stats_hook(self, op, *inputs_orig):
        """Call hook: run op's original forward and cache the float output
        mean/std on the op itself (consumed by _forward_adjust_bias_hook)."""
        outputs = op.__forward_orig__(*inputs_orig)
        # calibration at specific layers
        bias = op.bias if (hasattr(op, 'bias') and (op.bias is not None)) else None
        weight = op.weight if (hasattr(op, 'weight') and (op.weight is not None)) else None
        if (bias is not None) or (self.calibrate_weights and weight is not None):
            output = outputs[0] if isinstance(outputs, (list,tuple)) else outputs
            # reduce over every dim except dim 1 for 4D/2D outputs;
            # other ranks fall back to a scalar (all-dims) reduction
            reduce_dims = [0,2,3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)
            op.__output_mean_orig__ = torch.mean(output, dim=reduce_dims)
            op.__output_std_orig__ = torch.std(output, dim=reduce_dims)
        #
        return outputs
    #
229 def forward_adjust_bias(self, input):
230 self._restore_weights_quant()
231 self.add_call_hook(self.module, self._forward_adjust_bias_hook)
232 for _ in range(self.calibrate_repeats):
233 output = self.module(input)
234 #
235 self.remove_call_hook(self.module)
236 self._backup_weights_quant()
237 return output
238 #
    def _forward_adjust_bias_hook(self, op, *inputs_orig):
        """Call hook: run op's original forward, then shift op.bias by
        bias_calibration_factor times the difference between the cached
        float output mean and the current (quantized) output mean."""
        outputs = op.__forward_orig__(*inputs_orig)
        # calibration at specific layers
        bias = op.bias if (hasattr(op, 'bias') and (op.bias is not None)) else None
        if bias is not None:
            output = outputs[0] if isinstance(outputs, (list,tuple)) else outputs
            # reduction dims mirror _forward_compute_oputput_stats_hook
            reduce_dims = [0,2,3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)
            output_mean = torch.mean(output, dim=reduce_dims)
            output_delta = op.__output_mean_orig__ - output_mean
            output_delta = output_delta * self.bias_calibration_factor
            bias.data += (output_delta)
            # # TODO: is this required?
            # if len(output.size()) == 4:
            #     output.data += output_delta.data.view(1,-1,1,1)
            # elif len(output.size()) == 2:
            #     output.data += output_delta.data.view(1,-1)
            # else:
            #     assert False, 'unknown dimensions'
            # #
        #

        # experimental weight (std) calibration - kept for reference, disabled
        # weight = op.weight if (hasattr(op, 'weight') and (op.weight is not None)) else None
        # iter_threshold = (1.0/self.bias_calibration_factor)
        # if self.calibrate_weights and (weight is not None) and (self.num_batches_tracked > iter_threshold):
        #     eps = 1e-3
        #     output = outputs[0] if isinstance(outputs, (list,tuple)) else outputs
        #     reduce_dims = [0, 2, 3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)
        #     output_std = torch.std(output, dim=reduce_dims)
        #     output_ratio = (op.__output_std_orig__ + eps) / (output_std + eps)
        #     channels = output.size(1)
        #     output_ratio = output_ratio.view(channels, 1, 1, 1) if len(weight.data.size()) > 1 else output_ratio
        #     output_ratio = torch.pow(output_ratio, self.bias_calibration_factor)
        #     output_ratio = torch.clamp(output_ratio, 1.0-self.bias_calibration_factor, 1.0+self.bias_calibration_factor)
        #     weight.data *= output_ratio
        #     #
        # #

        return outputs