]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/xnn/quantize/quant_train_utils.py
23965eb49caf3385b2a6a2d09cfab2e3d824a146
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / quantize / quant_train_utils.py
1 import warnings
2 import numpy as np
3 import torch
4 from .. import utils
5 from .. import layers
6 from .quant_base_module import *
7 from .quant_utils import *
9 ###########################################################
10 class QuantTrainParams:
11     pass
14 def get_qparams():
15     qparams = QuantTrainParams()
16     qparams.inputs = []
17     qparams.modules = []
18     return qparams
21 def is_merged_layer(x):
22     is_merged = (hasattr(x, 'qparams') and isinstance(x.qparams, QuantTrainParams) and len(x.qparams.modules)>0)
23     return is_merged
26 ###########################################################
27 class QuantTrainConv2d(torch.nn.Conv2d):
28     def __init__(self, *args, **kwargs):
29         super().__init__(*args, **kwargs)
30         self.quantize_enable = True
31         self.bitwidth_weights = None
32         self.bitwidth_activations = None
33         self.per_channel_q = False
35     def forward(self, x):
36         is_merged = is_merged_layer(x)
37         if is_merged:
38            warnings.warn('please see if a PAct can be inserted before this module to collect ranges')
39         #
41         y = super().forward(x)
43         if not self.quantize_enable:
44             # if quantization is disabled - return
45             return y
46         #
48         qparams = get_qparams()
49         qparams.inputs.append(x)
50         qparams.modules.append(self)
51         y.qparams = qparams
52         #
53         return y
54     #
57 ###########################################################
58 class QuantTrainConvTranspose2d(torch.nn.ConvTranspose2d):
59     def __init__(self, *args, **kwargs):
60         super().__init__(*args, **kwargs)
61         self.quantize_enable = True
62         self.bitwidth_weights = None
63         self.bitwidth_activations = None
64         self.per_channel_q = False
66     def forward(self, x):
67         is_merged = is_merged_layer(x)
68         if is_merged:
69            warnings.warn('please see if a PAct can be inserted before this module to collect ranges')
70         #
72         y = super().forward(x)
74         if not self.quantize_enable:
75             # if quantization is disabled - return
76             return y
77         #
79         qparams = get_qparams()
80         qparams.inputs.append(x)
81         qparams.modules.append(self)
82         y.qparams = qparams
83         #
84         return y
85     #
88 ###########################################################
89 class QuantTrainBatchNorm2d(torch.nn.BatchNorm2d):
90     def __init__(self, *args, **kwargs):
91         super().__init__(*args, **kwargs)
92         self.quantize_enable = True
95     def forward(self, x):
96         y = super().forward(x)
98         if not self.quantize_enable:
99             # if quantization is disabled - return
100             return y
101         #
103         if is_merged_layer(x) and utils.is_conv_deconv(x.qparams.modules[-1]):
104             qparams = get_qparams()
105             qparams.inputs = [x.qparams.inputs[0], x]
106             qparams.modules = [x.qparams.modules[0], self]
107             y.qparams = qparams
108         #
110         return y
111     #
114 ###########################################################
115 # fake quantized PAct2 for training
116 class QuantTrainPAct2(layers.PAct2):
117     def __init__(self, inplace=False, signed=False, clip_range=None, bitwidth_weights=None, bitwidth_activations=None, per_channel_q=False):
118         super().__init__(inplace=inplace, signed=signed, clip_range=clip_range)
120         self.bitwidth_weights = bitwidth_weights
121         self.bitwidth_activations = bitwidth_activations
122         self.per_channel_q = per_channel_q
123         # weight shrinking is done by clamp weights - set this factor to zero.
124         # this must me zero - as in pact we do not have the actual weight param, but just a temporary tensor
125         # so any clipping we do here is not stored int he weight params
126         self.range_shrink_weights = 0.0
127         self.round_dither = 0.0
128         self.update_range = True
129         self.quantize_enable = True
130         self.quantize_weights = True
131         self.quantize_bias = True
132         self.quantize_activations = True
133         self.constrain_weights = True
134         self.bias_calibration = False
135         # save quantized weight/bias once in a while into the params - not needed
136         self.params_save_frequency = None #(10 if self.bias_calibration else None)
138         # set to STRAIGHT_THROUGH_ESTIMATION or ALPHA_BLENDING_ESTIMATION
139         # For a comparison of STE and ABE, read:
140         # Learning low-precision neural networks without Straight-Through Estimator (STE):
141         # https://arxiv.org/pdf/1903.01061.pdf
142         self.quantized_estimation_type = QuantEstimationType.STRAIGHT_THROUGH_ESTIMATION
143         self.alpha_blending_estimation_factor = 0.5
145         if (layers.PAct2.PACT2_RANGE_LEARN):
146             assert self.quantized_estimation_type != QuantEstimationType.STRAIGHT_THROUGH_ESTIMATION, \
147                 'straight through estimation should not used when PACT clip values are being learned as it doesnt backpropagate gradients though quantization'
148         #
151     def forward(self, x):
152         assert (self.bitwidth_weights is not None) and (self.bitwidth_activations is not None), \
153                         'bitwidth_weights and bitwidth_activations must not be None'
155         # the pact range update happens here - but range clipping depends on quantize_enable
156         y = super().forward(x, update_range=self.update_range, enable=self.quantize_enable)
158         if not self.quantize_enable:
159             return y
160         #
162         # previous intermediate outputs and other infoirmation are avaliable
163         # for example - conv-bn-relu may need to be merged together.
164         is_merged = is_merged_layer(x)
165         if is_merged:
166             qparams = x.qparams
167             xorg = qparams.inputs[0]
168             conv, weight, bias = self.merge_quantize_weights(qparams, is_merged)
169         else:
170             conv, weight, bias = None, None, None
171         #
173         if is_merged and utils.is_conv(conv):
174             xq = torch.nn.functional.conv2d(xorg, weight, bias, stride=conv.stride, padding=conv.padding, dilation=conv.dilation, groups=conv.groups)
175         elif is_merged and utils.is_deconv(conv):
176             xq = torch.nn.functional.conv_transpose2d(xorg, weight, bias, stride=conv.stride, padding=conv.padding, output_padding=conv.output_padding,
177                                                       dilation=conv.dilation, groups=conv.groups)
178         else:
179             xq = x
180         #
182         if (self.quantize_enable and self.quantize_activations):
183             clip_min, clip_max, scale, scale_inv = self.get_clips_scale_act()
184             width_min, width_max = self.get_widths_act()
185             # no need to call super().forward here as clipping with width_min/windth_max-1 after scaling has the same effect.
186             # currently the gradient is set to 1 in round_up_g. shouldn't the gradient be (y-x)?
187             # see eqn 6 in https://arxiv.org/pdf/1903.08066v2.pdf
188             # yq = layers.clamp_g(layers.round_up_g(xq * scale), width_min, width_max-1, self.training) * scale_inv
189             yq = layers.quantize_dequantize_g(xq, scale, width_min, width_max-1, self.power2, 'round_up')
190         else:
191             yq = super().forward(xq, update_range=False, enable=True)
192         #
194         if (self.quantized_estimation_type == QuantEstimationType.STRAIGHT_THROUGH_ESTIMATION):
195             # replace the float output with quantized version
196             # the entire weight merging and quantization process is bypassed in the forward pass
197             # however, the backward gradients flow through only the float path - this is called straight through estimation (STE)
198             with torch.no_grad():
199                 y.data.copy_(yq.data)
200             #
201         elif self.training and (self.quantized_estimation_type == QuantEstimationType.ALPHA_BLENDING_ESTIMATION):
202             # TODO: vary the alpha blending factor over the epochs
203             y = y * (1.0-self.alpha_blending_estimation_factor) + yq * self.alpha_blending_estimation_factor
204         elif (self.quantized_estimation_type == QuantEstimationType.QUANTIZED_THROUGH_ESTIMATION):
205             # pass on the quantized output - the backward gradients also flow through quantization.
206             # however, note the gradients of round and ceil operators are forced to be unity (1.0).
207             y = yq
208         else:
209             assert False, f'unsupported value for quantized_estimation_type: {self.quantized_estimation_type}'
210         #
212         return y
213     #
215     def apply_constrain_weights(self, merged_weight):
216         return constrain_weight(merged_weight)
219     def merge_quantize_weights(self, qparams, is_merged):
220         # store the quantized weights and biases in-frequently - otherwise learning will be poor
221         # since this may not be done at the of the epoch, there can be a slight mismatch in validation accuracy
222         first_training_iter = self.training and (self.num_batches_tracked == 0)
223         is_store_weight_bias_iter = (self.params_save_frequency is not None) and (torch.remainder(self.num_batches_tracked, self.params_save_frequency) == 0)
225         conv, bn = None, None
226         # merge weight and bias (if possible) across layers
227         if len(qparams.modules) == 2 and utils.is_conv_deconv(qparams.modules[-2]) and isinstance(qparams.modules[-1], torch.nn.BatchNorm2d):
228             conv = qparams.modules[-2]
229             conv_bias = conv.bias if (conv.bias is not None) else torch.tensor(0.0).to(conv.weight.device)
230             #
231             bn = qparams.modules[-1]
232             bn_weight = bn.weight if (bn.weight is not None) else torch.tensor(0.0).to(bn.running_mean.device)
233             bn_bias = bn.bias if (bn.bias is not None) else torch.tensor(0.0).to(bn.running_mean.device)
234             #
235             merged_scale = bn_weight / torch.sqrt(bn.running_var + bn.eps)
236             merged_bias = (conv_bias - bn.running_mean) * merged_scale + bn_bias
237             merged_weight = conv.weight * merged_scale.view(-1, 1, 1, 1)
238             #
239             merged_scale_sign = merged_scale.sign()
240             merged_scale_sign = merged_scale_sign + (merged_scale_sign == 0) # make the 0s in sign to 1
241             merged_scale_eps = merged_scale.abs().clamp(min=bn.eps) * merged_scale_sign
242             merged_scale_inv = 1.0 / merged_scale_eps
243             #
244         elif len(qparams.modules) == 1 and utils.is_conv_deconv(qparams.modules[-1]):
245             conv = qparams.modules[-1]
246             merged_weight = conv.weight
247             merged_bias = conv.bias if (conv.bias is not None) else torch.zeros(conv.out_channels).to(conv.weight.device)
248             merged_scale = torch.ones(conv.out_channels).to(conv.weight.device)
249             merged_scale_inv = torch.ones(conv.out_channels).to(conv.weight.device)
250         elif len(qparams.modules) == 1 and isinstance(qparams.modules[-1], torch.nn.BatchNorm2d):
251             assert False, f'quantization: previous layer is a BN without Conv {qparams.modules} - prease inspect the model carefully'
252             bn = qparams.modules[-1]
253             merged_weight = bn.weight if (bn.weight is not None) else torch.zeros(bn.num_features).to(bn.running_mean.device)
254             merged_bias = bn.bias if (bn.bias is not None) else torch.zeros(bn.num_features).to(bn.running_mean.device)
255         else:
256             assert False, f'quantization: previous layer is unrecognized {qparams.modules} - prease inspect the model carefully'
257             merged_weight = 0.0
258             merged_bias = 0.0
259         #
261         # quantize weight and bias
262         if (conv is not None):
263             if (self.quantize_enable and self.quantize_weights):
264                 if self.constrain_weights and first_training_iter:
265                     with torch.no_grad():
266                         # clamp merged weights, invert the bn and copy to conv weight
267                         constrained_weight = self.apply_constrain_weights(merged_weight.data)
268                         merged_weight.data.copy_(constrained_weight.data)
269                         # store clipped weight after inverting bn
270                         conv.weight.data.copy_(merged_weight.data * merged_scale_inv.view(-1, 1, 1, 1))
271                     #
272                 #
274                 is_dw = utils.is_dwconv(conv)
275                 use_per_channel_q = (is_dw and self.per_channel_q is True) or (self.per_channel_q == 'all')
276                 if use_per_channel_q:
277                     channels = int(merged_weight.size(0))
278                     scale2 = torch.zeros(channels,1,1,1).to(merged_weight.device)
279                     scale_inv2 = torch.zeros(channels,1,1,1).to(merged_weight.device)
280                     for chan_id in range(channels):
281                         clip_min, clip_max, scale2_value, scale_inv2_value = self.get_clips_scale_w(merged_weight[chan_id])
282                         scale2[chan_id,0,0,0] = scale2_value
283                         scale_inv2[chan_id,0,0,0] = scale_inv2_value
284                     #
285                 else:
286                     clip_min, clip_max, scale2, scale_inv2 = self.get_clips_scale_w(merged_weight)
287                 #
288                 width_min, width_max = self.get_widths_w()
289                 # merged_weight = layers.clamp_g(layers.round_sym_g(merged_weight * scale2), width_min, width_max-1, self.training) * scale_inv2
290                 merged_weight = layers.quantize_dequantize_g(merged_weight, scale2, width_min, width_max-1, self.power2, 'round_sym')
291             #
293             if (self.quantize_enable and self.quantize_bias):
294                 bias_width_min, bias_width_max = self.get_widths_bias()
295                 bias_clip_min, bias_clip_max, bias_scale2, bias_scale_inv2 = self.get_clips_scale_bias(merged_bias)
296                 # merged_bias = layers.clamp_g(layers.round_sym_g(merged_bias * bias_scale2), bias_width_min, bias_width_max-1, self.training) * bias_scale_inv2
297                 merged_bias = layers.quantize_dequantize_g(merged_bias, bias_scale2, bias_width_min, bias_width_max - 1, self.power2, 'round_sym')
298             #
300             # invert the bn operation and store weights/bias
301             if self.training and is_store_weight_bias_iter:
302                 with torch.no_grad():
303                     if self.quantize_enable and self.quantize_weights:
304                         conv.weight.data.copy_(merged_weight.data * merged_scale_inv.view(-1, 1, 1, 1))
305                     #
306                     if self.quantize_enable and self.quantize_bias:
307                         if conv.bias is not None:
308                             if bn is not None:
309                                 conv_bias = (merged_bias - bn_bias) * merged_scale_inv.view(-1) + bn.running_mean
310                                 conv.bias.data.copy_(conv_bias.data)
311                             else:
312                                 conv.bias.data.copy_(merged_bias.data)
313                             #
314                         elif bn is not None and bn.bias is not None:
315                             bn_bias = merged_bias + bn.running_mean * merged_scale.view(-1)
316                             bn.bias.data.copy_(bn_bias.data)
317                         #
318                     #
319                 #
320             #
321         #
322         return conv, merged_weight, merged_bias
325     def get_clips_w(self, tensor):
326         # find the clip values
327         w_min, w_max = utils.extrema_fast(tensor.data, percentile_range_shrink=self.range_shrink_weights)
328         clip_max = torch.max(torch.abs(w_min), torch.abs(w_max))
329         clip_max = torch.clamp(clip_max, min=self.eps)
330         # in range learning mode + training - this power2 is taken care in the quantize function
331         use_power2 = (self.power2 and (not (self.PACT2_RANGE_LEARN and self.training)))
332         clip_max2 = layers.functional.ceil2_g(clip_max) if use_power2 else clip_max
333         clip_min2 = -clip_max2
334         return (clip_min2, clip_max2)
336     # bias uses the same kind of clips
337     get_clips_bias = get_clips_w
340     def get_clips_scale_w(self, weight):
341         # convert to scale
342         clip_min, clip_max = self.get_clips_w(weight)
343         width_min, width_max = self.get_widths_w()
344         scale2 = (width_max / clip_max)
345         scale2 = torch.clamp(scale2, min=self.eps)
346         scale_inv2 = scale2.pow(-1.0)
347         return (clip_min, clip_max, scale2, scale_inv2)
350     # in reality, bias quantization will also depend on the activation scale
351     # this is not perfect - just a quick and dirty quantization for bias
352     def get_clips_scale_bias(self, bias):
353         # convert to scale
354         clip_min, clip_max = self.get_clips_bias(bias)
355         width_min, width_max = self.get_widths_bias()
356         scale2 = (width_max / clip_max)
357         scale2 = torch.clamp(scale2, min=self.eps)
358         scale_inv2 = scale2.pow(-1.0)
359         return (clip_min, clip_max, scale2, scale_inv2)
362     def get_widths_w(self):
363         # weights
364         bw = (self.bitwidth_activations - 1)
365         width_max = np.power(2.0, bw)
366         width_min = -width_max
367         # return
368         return (width_min, width_max)
371     def get_widths_bias(self):
372         # bias
373         bitwidth_bias = (2*self.bitwidth_activations)
374         bias_width_max = np.power(2.0, bitwidth_bias-1)
375         bias_width_min = -bias_width_max
376         # return
377         return (bias_width_min, bias_width_max)
380     # activation utility functions
381     def get_clips_scale_act(self):
382         # convert to scale
383         clip_min, clip_max = self.get_clips_act()
384         width_min, width_max = self.get_widths_act()
385         scale2 = width_max / clip_max
386         scale2 = torch.clamp(scale2, min=self.eps)
387         scale_inv2 = scale2.pow(-1.0)
388         return (clip_min, clip_max, scale2, scale_inv2)
391     def get_widths_act(self):
392         if self.signed is None:
393             clip_min, clip_max = self.get_clips_act()
394             signed = (clip_min < 0.0)
395         else:
396             signed = self.signed
397         #
398         bw = (self.bitwidth_activations - 1) if signed else self.bitwidth_activations
399         width_max = np.power(2.0, bw)
400         width_min = -width_max if signed else 0.0
401         return width_min, width_max