###########################################################
# Approximate quantized floating point simulation with gradients.
# Can be used for quantized training of models.
# (A usage sketch is included at the end of this file.)
###########################################################

import torch
import numpy as np
import copy
import warnings

from .. import layers
from .. import utils
from .quant_train_module import *

warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)


###########################################################
class QuantCalibrateModule(QuantTrainModule):
    def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
                 histogram_range=True, bias_calibration=True, constrain_weights=True, lr_calib=0.05):
        self.weights_calibration = False
        self.lr_calib = lr_calib
        self.calibration_factor = lr_calib
        self.calibration_gamma = 0.5
        self.calibrate_repeats = 1
        self.quantize_enable = True
        self.update_range = True
        # constraining the weights is meaningful only when bias calibration is enabled
        constrain_weights = (constrain_weights and bias_calibration)
        super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                         per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
                         constrain_weights=constrain_weights)

    def forward(self, inputs):
        # calibration doesn't need gradients
        with torch.no_grad():
            # backup the current training state
            training = self.training

            # Set all BNs to eval so that they do not change. We can't set the whole model to eval
            # because the PACT activations need to learn the ranges, and that happens only in
            # training mode. The model output itself may also differ in eval mode (for example,
            # a segmentation model may apply argmax instead of softmax in eval mode).
            utils.freeze_bn(self)

            # counters such as num_batches_tracked are used - update them
            self.update_counters()

            # actual forward call
            if self.training and (self.bias_calibration or self.weights_calibration):
                # calibration
                outputs = self.forward_calibrate(inputs)
            else:
                outputs = self.module(inputs)
            #

            self.train(training)
        #
        return outputs

    def forward_calibrate(self, inputs):
        # we don't need gradients for calibration
        # prepare/backup weights
        if self.num_batches_tracked == 0:
            # decay the calibration factor once per epoch
            self.calibration_factor = self.lr_calib * np.power(self.calibration_gamma, float(self.epoch))
            # backup original weights
            self._backup_weights_orig()
            # backup quantized weights
            self._backup_weights_quant()
        #

        # Compute the mean output in float first.
        outputs = self.forward_float(inputs)
        # Then adjust the weights/biases so that the quantized output matches the float output.
        outputs = self.forward_quantized(inputs)

        return outputs
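
    # A worked sketch of the per-batch updates applied by the calibration hooks
    # below; the symbols are illustrative shorthand, not attributes of this class:
    #
    #   lr      = lr_calib * calibration_gamma ** epoch               # decays each epoch
    #   bias   += lr * (mean_float - mean_quant)                      # bias calibration
    #   weight *= ((std_float + eps) / (std_quant + eps)) ** lr       # weights calibration
    #
    # Repeated over batches, the per-channel output statistics of the quantized
    # model drift towards those of the float model.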

    def forward_float(self, inputs):
        self._restore_weights_orig()
        # disable quantization for a moment
        quantize_enable_backup_value, update_range_backup_value = self.quantize_enable, self.update_range
        utils.apply_setattr(self, quantize_enable=False, update_range=False)

        self.add_call_hook(self.module, self.forward_float_hook)
        outputs = self.module(inputs)
        self.remove_call_hook(self.module)

        # turn quantization back on - not a clean method
        utils.apply_setattr(self, quantize_enable=quantize_enable_backup_value, update_range=update_range_backup_value)
        self._backup_weights_orig()
        return outputs
    #
    def forward_float_hook(self, op, *inputs_orig):
        outputs = op.__forward_orig__(*inputs_orig)

        # calibration at specific layers
        output = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
        # reduce over batch and spatial dims for 4D outputs, over batch only for 2D outputs
        reduce_dims = [0, 2, 3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)

        bias = op.bias if hasattr(op, 'bias') else None
        if (self.bias_calibration and bias is not None):
            # record the float per-channel output mean for the quantized pass to match
            op.__output_mean_orig__ = torch.mean(output, dim=reduce_dims).data
        #

        if self.weights_calibration and utils.is_conv_deconv(op):
            # record the float per-channel output std for the quantized pass to match
            op.__output_std_orig__ = torch.std(output, dim=reduce_dims).data
        #
        return outputs
    #

    def forward_quantized(self, inputs):
        self._restore_weights_quant()
        self.add_call_hook(self.module, self.forward_quantized_hook)
        # run the hooked forward calibrate_repeats times on this batch
        for _ in range(self.calibrate_repeats):
            outputs = self.module(inputs)
        #
        self.remove_call_hook(self.module)
        self._backup_weights_quant()
        return outputs
    #
    def forward_quantized_hook(self, op, *inputs_orig):
        outputs = op.__forward_orig__(*inputs_orig)

        # calibration at specific layers
        output = outputs[0] if isinstance(outputs, (list, tuple)) else outputs

        bias = op.bias if hasattr(op, 'bias') else None
        if self.bias_calibration and bias is not None:
            reduce_dims = [0, 2, 3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)
            # nudge the bias so the quantized output mean moves towards the float output mean
            output_mean = torch.mean(output, dim=reduce_dims).data
            output_delta = op.__output_mean_orig__ - output_mean
            output_delta = output_delta * self.calibration_factor
            bias.data += output_delta
        #

        if self.weights_calibration and utils.is_conv_deconv(op):
            eps = 1e-6
            weight = op.weight
            reduce_dims = [0, 2, 3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)
            # scale the weights so the quantized output std moves towards the float output std
            output_std = torch.std(output, dim=reduce_dims).data
            output_ratio = (op.__output_std_orig__ + eps) / (output_std + eps)
            channels = output.size(1)
            output_ratio = output_ratio.view(channels, 1, 1, 1) if len(weight.data.size()) > 1 else output_ratio
            # take a partial step: raising the ratio to calibration_factor (typically < 1) damps the update
            output_ratio = torch.pow(output_ratio, self.calibration_factor)
            weight.data *= output_ratio
        #
        return outputs
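

###########################################################
# Minimal usage sketch (an illustration of the intended workflow, not part of
# the original module). `MyModel` and `calibration_loader` are placeholders,
# and the import path is an assumption based on this repository's layout:
#
#   import torch
#   from pytorch_jacinto_ai import xnn
#
#   model = MyModel()                                 # a pretrained float model
#   dummy_input = torch.rand(1, 3, 224, 224)          # match the model's input size
#   model = xnn.quantize.QuantCalibrateModule(model, dummy_input=dummy_input)
#
#   model.train()                                     # calibration runs in training mode
#   for images in calibration_loader:                 # a small calibration dataset
#       outputs = model(images)                       # each forward pass adapts biases/weights
###########################################################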