###########################################################
# Approximate quantized floating point simulation with gradients.
# Can be used for quantized training of models.
###########################################################

import torch
import numpy as np
import copy
import warnings

from .. import layers
from .. import utils
from . import quant_utils
from .quant_base_module import *

warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)


###########################################################
class QuantTrainModule(QuantBaseModule):
    def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
                 histogram_range=True, bias_calibration=False, constrain_weights=True):
        super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                         per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
                         constrain_weights=constrain_weights, model_surgery_quantize=True)
        # range shrink - 0.0 indicates no shrink
        percentile_range_shrink = (layers.PAct2.PACT2_RANGE_SHRINK if histogram_range else 0.0)
        # set attributes on all modules - the behaviour can be controlled from here
        utils.apply_setattr(self, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
                            per_channel_q=self.per_channel_q, bias_calibration=self.bias_calibration,
                            percentile_range_shrink=percentile_range_shrink, constrain_weights=self.constrain_weights,
                            update_range=True, quantize_enable=True, quantize_weights=True, quantize_bias=True, quantize_activations=True)

    def forward(self, inputs):
        # counters such as num_batches_tracked are used - update them.
        self.update_counters()
        # outputs
        outputs = self.module(inputs)
        return outputs

    def model_surgery_quantize(self, dummy_input):
        super().model_surgery_quantize(dummy_input)

        def replace_func(op):
            for name, m in op._modules.items():
                if utils.is_conv(m):
                    bias = (m.bias is not None)
                    padding_mode = m.padding_mode
                    new_m = QuantTrainConv2d(in_channels=m.in_channels, out_channels=m.out_channels, kernel_size=m.kernel_size, stride=m.stride,
                                            padding=m.padding, dilation=m.dilation, groups=m.groups, bias=bias, padding_mode=padding_mode)
                elif utils.is_deconv(m):
                    bias = (m.bias is not None)
                    padding_mode = m.padding_mode
                    new_m = QuantTrainConvTranspose2d(in_channels=m.in_channels, out_channels=m.out_channels, kernel_size=m.kernel_size, stride=m.stride,
                                            padding=m.padding, output_padding=m.output_padding, groups=m.groups, bias=bias, dilation=m.dilation, padding_mode=padding_mode)
                elif utils.is_bn(m):
                    new_m = QuantTrainBatchNorm2d(num_features=m.num_features, eps=m.eps, momentum=m.momentum, affine=m.affine,
                                            track_running_stats=m.track_running_stats)
                elif isinstance(m, layers.PAct2):
                    new_m = QuantTrainPAct2(inplace=m.inplace, signed=m.signed, clip_range=m.clip_range,
                                             bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
                                             per_channel_q=self.per_channel_q)
                elif isinstance(m, layers.NoAct):
                    new_m = QuantTrainPAct2(signed=None, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
                                             per_channel_q=self.per_channel_q)
                elif isinstance(m, (torch.nn.ReLU, torch.nn.ReLU6)):
                    new_m = QuantTrainPAct2(signed=False, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
                                             per_channel_q=self.per_channel_q)
                else:
                    new_m = None
                #
                if new_m is not None:
                    for attr in dir(m):
                        value = getattr(m, attr)
                        if isinstance(value, torch.Tensor) and value is not None:
                            getattr(new_m, attr).data.copy_(value.data)
                        elif isinstance(value, torch.nn.Module) and value is not None:
                            setattr(new_m, attr, getattr(m, attr))
                        elif attr in ('weight', 'bias', 'eps', 'clips_act', 'clips_w'): # add more attribute names here
                            setattr(new_m, attr, getattr(m, attr))
                        #
                    #
                    new_m.train(m.training)
                    setattr(op, name, new_m)
                #
            #
        #
        # apply recursively
        self.apply(replace_func)

        # clear
        self.clear_states()
    #
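

# Usage sketch (illustrative only - the torchvision model and the input shape below
# are assumptions, not part of this module):
#   import torchvision
#   float_model = torchvision.models.mobilenet_v2()
#   dummy_input = torch.rand(1, 3, 224, 224)
#   qat_model = QuantTrainModule(float_model, dummy_input=dummy_input,
#                                bitwidth_weights=8, bitwidth_activations=8)
#   output = qat_model(dummy_input)  # forward pass through the fake-quantized model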


###########################################################
class QuantTrainParams:
    pass


def get_qparams():
    qparams = QuantTrainParams()
    qparams.inputs = []
    qparams.modules = []
    return qparams


def is_merged_layer(x):
    is_merged = (hasattr(x, 'qparams') and isinstance(x.qparams, QuantTrainParams) and len(x.qparams.modules)>0)
    return is_merged
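

# Note on the qparams tagging (descriptive): QuantTrainConv2d and QuantTrainConvTranspose2d
# attach a QuantTrainParams object to their output tensor, recording the layer input and the
# layer itself. QuantTrainBatchNorm2d extends this record when it directly follows a conv/deconv.
# A downstream QuantTrainPAct2 then uses the record to fold conv+bn, quantize the merged
# weights/bias and re-run the convolution on the original input:
#   x -> Conv2d (tags qparams) -> BatchNorm2d (extends qparams) -> PAct2 (merge + fake-quantize)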


###########################################################
class QuantTrainConv2d(torch.nn.Conv2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.quantize_enable = True
        self.bitwidth_weights = None
        self.bitwidth_activations = None
        self.per_channel_q = False

    def forward(self, x):
        is_merged = is_merged_layer(x)
        if is_merged:
            warnings.warn('please check whether a PAct2 can be inserted before this module to collect ranges')
        #

        y = super().forward(x)

        if not self.quantize_enable:
            # if quantization is disabled - return
            return y
        #

        qparams = get_qparams()
        qparams.inputs.append(x)
        qparams.modules.append(self)
        y.qparams = qparams
        #
        return y
    #
147 ###########################################################
148 class QuantTrainConvTranspose2d(torch.nn.ConvTranspose2d):
149     def __init__(self, *args, **kwargs):
150         super().__init__(*args, **kwargs)
151         self.quantize_enable = True
152         self.bitwidth_weights = None
153         self.bitwidth_activations = None
154         self.per_channel_q = False
156     def forward(self, x):
157         is_merged = is_merged_layer(x)
158         if is_merged:
159            warnings.warn('please see if a PAct can be inserted before this module to collect ranges')
160         #
162         y = super().forward(x)
164         if not self.quantize_enable:
165             # if quantization is disabled - return
166             return y
167         #
169         qparams = get_qparams()
170         qparams.inputs.append(x)
171         qparams.modules.append(self)
172         y.qparams = qparams
173         #
174         return y
175     #


###########################################################
class QuantTrainBatchNorm2d(torch.nn.BatchNorm2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.quantize_enable = True

    def forward(self, x):
        y = super().forward(x)

        if not self.quantize_enable:
            # if quantization is disabled - return
            return y
        #

        if is_merged_layer(x) and utils.is_conv_deconv(x.qparams.modules[-1]):
            qparams = get_qparams()
            qparams.inputs = [x.qparams.inputs[0], x]
            qparams.modules = [x.qparams.modules[0], self]
            y.qparams = qparams
        #

        return y
    #


###########################################################
# fake quantized PAct2 for training
class QuantTrainPAct2(layers.PAct2):
    def __init__(self, inplace=False, signed=False, clip_range=None, bitwidth_weights=None, bitwidth_activations=None, per_channel_q=False):
        super().__init__(inplace=inplace, signed=signed, clip_range=clip_range)

        self.bitwidth_weights = bitwidth_weights
        self.bitwidth_activations = bitwidth_activations
        self.per_channel_q = per_channel_q
        # weight shrinking is done by clamping weights - set this factor to zero.
        # this must be zero - in PAct2 we do not have the actual weight param, only a temporary tensor,
        # so any clipping done here is not stored in the weight params
        self.range_shrink_weights = 0.0
        self.round_dither = 0.0
        self.update_range = True
        self.quantize_enable = True
        self.quantize_weights = True
        self.quantize_bias = True
        self.quantize_activations = True
        self.constrain_weights = True
        self.bias_calibration = False
        # save quantized weight/bias once in a while into the params - not needed
        self.params_save_frequency = None #(10 if self.bias_calibration else None)

        # set to STRAIGHT_THROUGH_ESTIMATION or ALPHA_BLENDING_ESTIMATION
        # For a comparison of STE and ABE, read:
        # Learning low-precision neural networks without Straight-Through Estimator (STE):
        # https://arxiv.org/pdf/1903.01061.pdf
        self.quantized_estimation_type = QuantEstimationType.STRAIGHT_THROUGH_ESTIMATION
        self.alpha_blending_estimation_factor = 0.5

        if (layers.PAct2.PACT2_RANGE_LEARN):
            assert self.quantized_estimation_type != QuantEstimationType.STRAIGHT_THROUGH_ESTIMATION, \
                'straight through estimation should not be used when PACT2 clip values are being learned, as it does not backpropagate gradients through quantization'
        #

    def forward(self, x):
        assert (self.bitwidth_weights is not None) and (self.bitwidth_activations is not None), \
                        'bitwidth_weights and bitwidth_activations must not be None'

        # the pact range update happens here - but range clipping depends on quantize_enable
        y = super().forward(x, update_range=self.update_range, enable=self.quantize_enable)

        if not self.quantize_enable:
            return y
        #

        # previous intermediate outputs and other information are available here -
        # for example, conv-bn-relu may need to be merged together.
        is_merged = is_merged_layer(x)
        if is_merged:
            qparams = x.qparams
            xorg = qparams.inputs[0]

            conv, bn = None, None
            # merge weight and bias (if possible) across layers
            if len(qparams.modules) == 2 and utils.is_conv_deconv(qparams.modules[-2]) and isinstance(
                    qparams.modules[-1], torch.nn.BatchNorm2d):
                conv = qparams.modules[-2]
                bn = qparams.modules[-1]
            elif len(qparams.modules) == 1 and utils.is_conv_deconv(qparams.modules[-1]):
                conv = qparams.modules[-1]
            elif len(qparams.modules) == 1 and isinstance(qparams.modules[-1], torch.nn.BatchNorm2d):
                assert False, f'quantization: previous layer is a BN without Conv {qparams.modules} - please inspect the model carefully'
                bn = qparams.modules[-1]
            #
            else:
                assert False, f'QuantTrainPAct2: both conv & bn layers cannot be None in a merged scenario - please inspect the model carefully'
            #
            conv, weight, bias = self.merge_quantize_weights(conv, bn)
        else:
            conv, weight, bias = None, None, None
        #

        if is_merged and utils.is_conv(conv):
            xq = torch.nn.functional.conv2d(xorg, weight, bias, stride=conv.stride, padding=conv.padding, dilation=conv.dilation, groups=conv.groups)
        elif is_merged and utils.is_deconv(conv):
            xq = torch.nn.functional.conv_transpose2d(xorg, weight, bias, stride=conv.stride, padding=conv.padding, output_padding=conv.output_padding,
                                                      dilation=conv.dilation, groups=conv.groups)
        else:
            xq = x
        #

        if (self.quantize_enable and self.quantize_activations):
            clip_min, clip_max, scale, scale_inv = self.get_clips_scale_act()
            width_min, width_max = self.get_widths_act()
            # no need to call super().forward here as clipping with width_min/width_max-1 after scaling has the same effect.
            # currently the gradient is set to 1 in round_up_g. shouldn't the gradient be (y-x)?
            # see eqn 6 in https://arxiv.org/pdf/1903.08066v2.pdf
            # yq = layers.clamp_g(layers.round_up_g(xq * scale), width_min, width_max-1, self.training) * scale_inv
            yq = layers.quantize_dequantize_g(xq, scale, width_min, width_max-1, self.power2, 'round_up')
        else:
            yq = super().forward(xq, update_range=False, enable=True)
        #

        if (self.quantized_estimation_type == QuantEstimationType.STRAIGHT_THROUGH_ESTIMATION):
            # replace the float output with the quantized version.
            # the entire weight merging and quantization process is bypassed in the backward pass:
            # the gradients flow through only the float path - this is called straight through estimation (STE)
            with torch.no_grad():
                y.data.copy_(yq.data)
            #
        elif self.training and (self.quantized_estimation_type == QuantEstimationType.ALPHA_BLENDING_ESTIMATION):
            # TODO: vary the alpha blending factor over the epochs
            y = y * (1.0-self.alpha_blending_estimation_factor) + yq * self.alpha_blending_estimation_factor
        elif (self.quantized_estimation_type == QuantEstimationType.QUANTIZED_THROUGH_ESTIMATION):
            # pass on the quantized output - the backward gradients also flow through quantization.
            # however, note that the gradients of the round and ceil operators are forced to be unity (1.0).
            y = yq
        else:
            assert False, f'unsupported value for quantized_estimation_type: {self.quantized_estimation_type}'
        #

        return y
    #
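
    # Minimal sketch of the STE pattern used in forward() above (descriptive note,
    # not the exact code path): the float output carries the autograd graph, and only
    # its values are overwritten with the fake-quantized result, e.g.
    #   y = float_path(x)                   # differentiable float forward
    #   yq = fake_quantize(y)               # quantize + dequantize
    #   with torch.no_grad():
    #       y.data.copy_(yq)                # forward emits quantized values,
    #                                       # backward follows the float graph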

    def apply_constrain_weights(self, merged_weight):
        return quant_utils.constrain_weight(merged_weight)

    def merge_quantize_weights(self, conv, bn):
        # store the quantized weights and biases only infrequently - otherwise learning will be poor.
        # since this may not be done at the end of the epoch, there can be a slight mismatch in validation accuracy
        first_training_iter = self.training and (self.num_batches_tracked == 0)
        is_store_weight_bias_iter = (self.params_save_frequency is not None) and (torch.remainder(self.num_batches_tracked, self.params_save_frequency) == 0)

        # merge weight and bias (if possible) across layers
        if conv is not None and bn is not None:
            conv_bias = conv.bias if (conv.bias is not None) else torch.tensor(0.0).to(conv.weight.device)
            #
            bn_weight = bn.weight if (bn.weight is not None) else torch.tensor(0.0).to(bn.running_mean.device)
            bn_bias = bn.bias if (bn.bias is not None) else torch.tensor(0.0).to(bn.running_mean.device)
            #
            merged_scale = bn_weight / torch.sqrt(bn.running_var + bn.eps)
            if utils.is_conv(conv):
                merged_scale = merged_scale.view(-1, 1, 1, 1)
            elif utils.is_deconv(conv):
                merged_scale = merged_scale.view(1, -1, 1, 1)
            else:
                assert False, 'unable to merge convolution and BN'
            #
            merged_bias = (conv_bias - bn.running_mean) * merged_scale.view(-1) + bn_bias
            merged_weight = conv.weight * merged_scale
            #
            merged_scale_sign = merged_scale.sign()
            merged_scale_sign = merged_scale_sign + (merged_scale_sign == 0) # convert 0s in the sign to 1s
            merged_scale_eps = merged_scale.abs().clamp(min=bn.eps) * merged_scale_sign
            merged_scale_inv = 1.0 / merged_scale_eps
            #
        elif conv is not None:
            merged_weight = conv.weight
            merged_bias = conv.bias if (conv.bias is not None) else torch.zeros(conv.out_channels).to(conv.weight.device)
            merged_scale = 1.0
            merged_scale_inv = 1.0
        elif bn is not None:
            merged_weight = bn.weight if (bn.weight is not None) else torch.zeros(bn.num_features).to(bn.running_mean.device)
            merged_bias = bn.bias if (bn.bias is not None) else torch.zeros(bn.num_features).to(bn.running_mean.device)
        else:
            assert False, f'merge_quantize_weights(): both conv & bn layers cannot be None in a merged scenario - please inspect the model carefully'
            merged_weight = 0.0
            merged_bias = 0.0
        #

        # quantize weight and bias
        if (conv is not None):
            if (self.quantize_enable and self.quantize_weights):
                if self.constrain_weights and first_training_iter:
                    with torch.no_grad():
                        # clamp merged weights, invert the bn and copy to conv weight
                        constrained_weight = self.apply_constrain_weights(merged_weight.data)
                        merged_weight.data.copy_(constrained_weight.data)
                        # store clipped weight after inverting bn - not really needed as there is a saving below as well
                        # conv.weight.data.copy_(merged_weight.data * merged_scale_inv)
                    #
                #

                is_dw = utils.is_dwconv(conv)
                use_per_channel_q = (is_dw and self.per_channel_q is True) or (self.per_channel_q == 'all')
                if use_per_channel_q:
                    channels = int(merged_weight.size(0))
                    scale2 = torch.zeros(channels,1,1,1).to(merged_weight.device)
                    scale_inv2 = torch.zeros(channels,1,1,1).to(merged_weight.device)
                    for chan_id in range(channels):
                        clip_min, clip_max, scale2_value, scale_inv2_value = self.get_clips_scale_w(merged_weight[chan_id])
                        scale2[chan_id,0,0,0] = scale2_value
                        scale_inv2[chan_id,0,0,0] = scale_inv2_value
                    #
                else:
                    clip_min, clip_max, scale2, scale_inv2 = self.get_clips_scale_w(merged_weight)
                #
                width_min, width_max = self.get_widths_w()
                # merged_weight = layers.clamp_g(layers.round_sym_g(merged_weight * scale2), width_min, width_max-1, self.training) * scale_inv2
                merged_weight = layers.quantize_dequantize_g(merged_weight, scale2, width_min, width_max-1, self.power2, 'round_sym')
            #

            if (self.quantize_enable and self.quantize_bias):
                bias_width_min, bias_width_max = self.get_widths_bias()
                bias_clip_min, bias_clip_max, bias_scale2, bias_scale_inv2 = self.get_clips_scale_bias(merged_bias)
                # merged_bias = layers.clamp_g(layers.round_sym_g(merged_bias * bias_scale2), bias_width_min, bias_width_max-1, self.training) * bias_scale_inv2
                merged_bias = layers.quantize_dequantize_g(merged_bias, bias_scale2, bias_width_min, bias_width_max - 1, self.power2, 'round_sym')
            #

            # invert the bn operation and store weights/bias
            if first_training_iter or (self.training and is_store_weight_bias_iter):
                with torch.no_grad():
                    if self.quantize_enable and self.quantize_weights:
                        conv.weight.data.copy_(merged_weight.data * merged_scale_inv)
                    #
                    if self.quantize_enable and self.quantize_bias:
                        if conv.bias is not None:
                            if bn is not None:
                                conv_bias = (merged_bias - bn_bias) * merged_scale_inv.view(-1) + bn.running_mean
                                conv.bias.data.copy_(conv_bias.data)
                            else:
                                conv.bias.data.copy_(merged_bias.data)
                            #
                        elif bn is not None and bn.bias is not None:
                            bn_bias = merged_bias + bn.running_mean * merged_scale.view(-1)
                            bn.bias.data.copy_(bn_bias.data)
                        #
                    #
                #
            #
        #
        return conv, merged_weight, merged_bias
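
    # For reference, the conv+bn folding performed above follows the standard equations
    # (gamma/beta are the BN weight/bias):
    #   scale  = gamma / sqrt(running_var + eps)
    #   w_fold = w_conv * scale                          # scale broadcast over the output channels
    #   b_fold = (b_conv - running_mean) * scale + beta
    # The inverse mapping (w_conv = w_fold / scale) is what gets written back into
    # conv.weight / conv.bias when the quantized parameters are stored.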

    def get_clips_w(self, tensor):
        # find the clip values
        w_min, w_max = utils.extrema_fast(tensor.data, percentile_range_shrink=self.range_shrink_weights)
        clip_max = torch.max(torch.abs(w_min), torch.abs(w_max))
        clip_max = torch.clamp(clip_max, min=self.eps)
        # in range learning mode + training - this power2 is taken care of in the quantize function
        use_power2 = (self.power2 and (not (self.PACT2_RANGE_LEARN and self.training)))
        clip_max2 = layers.functional.ceil2_g(clip_max) if use_power2 else clip_max
        clip_min2 = -clip_max2
        return (clip_min2, clip_max2)

    # bias uses the same kind of clips
    get_clips_bias = get_clips_w

    def get_clips_scale_w(self, weight):
        # convert to scale
        clip_min, clip_max = self.get_clips_w(weight)
        width_min, width_max = self.get_widths_w()
        scale2 = (width_max / clip_max)
        scale2 = torch.clamp(scale2, min=self.eps)
        scale_inv2 = scale2.pow(-1.0)
        return (clip_min, clip_max, scale2, scale_inv2)

    # in reality, bias quantization will also depend on the activation scale.
    # this is not perfect - just a quick and dirty quantization for bias
    def get_clips_scale_bias(self, bias):
        # convert to scale
        clip_min, clip_max = self.get_clips_bias(bias)
        width_min, width_max = self.get_widths_bias()
        scale2 = (width_max / clip_max)
        scale2 = torch.clamp(scale2, min=self.eps)
        scale_inv2 = scale2.pow(-1.0)
        return (clip_min, clip_max, scale2, scale_inv2)

    def get_widths_w(self):
        # weights
        bw = (self.bitwidth_weights - 1)
        width_max = np.power(2.0, bw)
        width_min = -width_max
        # return
        return (width_min, width_max)
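
    # Worked example (illustrative numbers, not taken from any particular model):
    # with bitwidth_weights = 8, get_widths_w() gives width_max = 2**7 = 128 and
    # width_min = -128. If the symmetric weight clip is clip_max = 0.5, then
    # scale2 = 128 / 0.5 = 256, so a weight w is fake-quantized roughly as
    #   w_q = clamp(round(w * 256), -128, 127) / 256
    # which corresponds to the commented-out clamp_g/round_sym_g form next to the
    # quantize_dequantize_g call in merge_quantize_weights.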

    def get_widths_bias(self):
        # bias
        bitwidth_bias = (2*self.bitwidth_activations)
        bias_width_max = np.power(2.0, bitwidth_bias-1)
        bias_width_min = -bias_width_max
        # return
        return (bias_width_min, bias_width_max)

    # activation utility functions
    def get_clips_scale_act(self):
        # convert to scale
        clip_min, clip_max = self.get_clips_act()
        width_min, width_max = self.get_widths_act()
        scale2 = width_max / clip_max
        scale2 = torch.clamp(scale2, min=self.eps)
        scale_inv2 = scale2.pow(-1.0)
        return (clip_min, clip_max, scale2, scale_inv2)

    def get_widths_act(self):
        if self.signed is None:
            clip_min, clip_max = self.get_clips_act()
            signed = (clip_min < 0.0)
        else:
            signed = self.signed
        #
        bw = (self.bitwidth_activations - 1) if signed else self.bitwidth_activations
        width_max = np.power(2.0, bw)
        width_min = -width_max if signed else 0.0
        return width_min, width_max
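
    # For example (illustrative): with bitwidth_activations = 8, a signed activation
    # uses width_max = 2**7 = 128 and width_min = -128, while an unsigned activation
    # (e.g. after ReLU, where signed=False) uses width_max = 2**8 = 256 and width_min = 0.
    # Bias uses 2*bitwidth_activations bits, i.e. width_max = 2**15 = 32768 for 8-bit activations.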