[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / quantize / quant_train_module.py
diff --git a/modules/pytorch_jacinto_ai/xnn/quantize/quant_train_module.py b/modules/pytorch_jacinto_ai/xnn/quantize/quant_train_module.py
index b1a882d07534ff19fcbc606038651f956a97bcbd..7fe340e705ebf742f4a1ed8f9abed618f7c89c18 100644 (file)
from .. import layers
from .. import utils
-from .quant_train_utils import *
+from . import quant_utils
+from .quant_base_module import *
warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
###########################################################
class QuantTrainModule(QuantBaseModule):
- def __init__(self, module, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False, histogram_range=True,
- constrain_weights=True, bias_calibration=False, dummy_input=None):
- super().__init__(module)
- assert dummy_input is not None, 'dummy input is needed by quantized models to analyze graph'
- self.bitwidth_weights = bitwidth_weights
- self.bitwidth_activations = bitwidth_activations
- self.per_channel_q = per_channel_q
- self.constrain_weights = constrain_weights #and (not bool(self.per_channel_q))
- self.bias_calibration = bias_calibration
-
- # for help in debug/print
- utils.add_module_names(self)
-
- # put in eval mode before analyze
- self.eval()
-
- with torch.no_grad():
- # model surgery for quantization
- utils.print_yellow("=> model surgery by '{}'".format(self.model_surgery_quantize.__name__))
- self.model_surgery_quantize(dummy_input)
- #
-
+ def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
+ histogram_range=True, bias_calibration=False, constrain_weights=None):
+ constrain_weights = (not per_channel_q) if constrain_weights is None else constrain_weights
+ super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
+ per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
+ constrain_weights=constrain_weights, model_surgery_quantize=True)
# range shrink - 0.0 indicates no shrink
percentile_range_shrink = (layers.PAct2.PACT2_RANGE_SHRINK if histogram_range else 0.0)
# set attributes to all modules - can control the behaviour from here
- utils.apply_setattr(self, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations, per_channel_q=per_channel_q, bias_calibration=bias_calibration,
- quantize_enable=True, quantize_weights=True, quantize_bias=True, quantize_activations=True,
- percentile_range_shrink=percentile_range_shrink, constrain_weights=self.constrain_weights)
-
- # for help in debug/print
- utils.add_module_names(self)
-
-
- def train(self, mode=True):
- self.iter_in_epoch.fill_(-1.0)
- super().train(mode)
-
+ utils.apply_setattr(self, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
+ per_channel_q=self.per_channel_q, bias_calibration=self.bias_calibration,
+ percentile_range_shrink=percentile_range_shrink, constrain_weights=self.constrain_weights,
+ update_range=True, quantize_enable=True, quantize_weights=True, quantize_bias=True, quantize_activations=True)
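+
+    # typical usage (illustrative sketch; the user model and input shape are assumptions):
+    #   model = QuantTrainModule(my_float_model, dummy_input=torch.rand(1, 3, 224, 224))
+    #   model.train()  # quantization-aware training with fake-quantized weights/activations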
def forward(self, inputs):
- # analyze
- self.analyze_graph(inputs=inputs, cleanup_states=True)
-
- # actual forward call
- if self.training and self.bias_calibration:
- # bias calibration
- outputs = self.forward_calibrate_bias(inputs)
- else:
- outputs = self.module(inputs)
- #
+ # counters such as num_batches_tracked are used. update them.
+ self.update_counters()
+ # outputs
+ outputs = self.module(inputs)
return outputs
- def forward_calibrate_bias(self, inputs):
- assert False, 'forward_calibrate_bias is not implemented'
-
-
def model_surgery_quantize(self, dummy_input):
super().model_surgery_quantize(dummy_input)
elif isinstance(m, layers.NoAct):
new_m = QuantTrainPAct2(signed=None, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
per_channel_q=self.per_channel_q)
- elif isinstance(m, (torch.nn.ReLU, torch.nn.ReLU6, layers.ReLUN)):
+ elif isinstance(m, (torch.nn.ReLU, torch.nn.ReLU6)):
new_m = QuantTrainPAct2(signed=False, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
per_channel_q=self.per_channel_q)
else:
# apply recursively
self.apply(replace_func)
- # add a forward hook to call the extra activation that we added
- def _forward_activation(op, inputs, outputs):
- if hasattr(op, 'activation_q'):
- outputs = op.activation_q(outputs)
- #
- return outputs
+ # clear
+ self.clear_states()
+ #
+
+
+###########################################################
+class QuantTrainParams:
+ pass
+
+
+def get_qparams():
+ qparams = QuantTrainParams()
+ qparams.inputs = []
+ qparams.modules = []
+ return qparams
+
+
+def is_merged_layer(x):
+ is_merged = (hasattr(x, 'qparams') and isinstance(x.qparams, QuantTrainParams) and len(x.qparams.modules)>0)
+ return is_merged
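+
+# how the qparams tag flows through a merged layer sequence (illustrative sketch):
+#   y = QuantTrainConv2d(...)(x)       # y.qparams.modules == [conv]
+#   y = QuantTrainBatchNorm2d(...)(y)  # y.qparams.modules == [conv, bn]
+#   y = QuantTrainPAct2(...)(y)        # folds conv+bn weights, then fake-quantizes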
+
+
+###########################################################
+class QuantTrainConv2d(torch.nn.Conv2d):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.quantize_enable = True
+ self.bitwidth_weights = None
+ self.bitwidth_activations = None
+ self.per_channel_q = False
+
+ def forward(self, x):
+ is_merged = is_merged_layer(x)
+ if is_merged:
+            warnings.warn('consider inserting a PAct2 before this module so that activation ranges can be collected')
#
- for m in self.modules():
- m.register_forward_hook(_forward_activation)
+
+ y = super().forward(x)
+
+ if not self.quantize_enable:
+ # if quantization is disabled - return
+ return y
#
- # clear
- self.clear_states()
+ qparams = get_qparams()
+ qparams.inputs.append(x)
+ qparams.modules.append(self)
+ y.qparams = qparams
+ #
+ return y
#
###########################################################
-class QuantCalibrateModule(QuantTrainModule):
- def __init__(self, module, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False, bias_calibration=True,
- histogram_range=True, constrain_weights=True, lr_calib=0.1, dummy_input=None):
- self.bias_calibration = bias_calibration
- self.lr_calib = lr_calib
- self.bias_calibration_factor = lr_calib
- self.bias_calibration_gamma = 0.5
- self.calibrate_weights = False
- self.calibrate_repeats = 1
+class QuantTrainConvTranspose2d(torch.nn.ConvTranspose2d):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
self.quantize_enable = True
- # BNs can be adjusted based on the input provided - however this is not really required
- self.calibrate_bn = False
- super().__init__(module, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations, per_channel_q=per_channel_q,
- histogram_range=histogram_range, constrain_weights=constrain_weights, bias_calibration=bias_calibration, dummy_input=dummy_input)
+ self.bitwidth_weights = None
+ self.bitwidth_activations = None
+ self.per_channel_q = False
+
+ def forward(self, x):
+ is_merged = is_merged_layer(x)
+ if is_merged:
+            warnings.warn('consider inserting a PAct2 before this module so that activation ranges can be collected')
+ #
+
+ y = super().forward(x)
+
+ if not self.quantize_enable:
+ # if quantization is disabled - return
+ return y
+ #
+
+ qparams = get_qparams()
+ qparams.inputs.append(x)
+ qparams.modules.append(self)
+ y.qparams = qparams
+ #
+ return y
#
- def forward_calibrate_bias(self, inputs):
- # we don't need gradients for calibration
- with torch.no_grad():
- # prepare/backup weights
- if self.num_batches_tracked == 0:
- # lr_calib
- self.bias_calibration_factor = self.lr_calib * np.power(self.bias_calibration_gamma, float(self.epoch))
- # backup original weights
- self._backup_weights_orig()
- # backup quantized weights
- self._backup_weights_quant()
- #
+###########################################################
+class QuantTrainBatchNorm2d(torch.nn.BatchNorm2d):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.quantize_enable = True
- # backup the current state
- training = self.training
- # compute the mean output in float
- # also, set all bns to eval. we can't set the whole model to eval because
- # we need the pact to learn the ranges - which will happen only in training mode.
- # also the model output itself may be different in eval mode.
- if self.calibrate_bn:
- outputs = self.forward_compute_oputput_stats(inputs)
- utils.freeze_bn(self)
- else:
- utils.freeze_bn(self)
- outputs = self.forward_compute_oputput_stats(inputs)
- #
+ def forward(self, x):
+ y = super().forward(x)
+
+ if not self.quantize_enable:
+ # if quantization is disabled - return
+ return y
+ #
- # adjust the quantized output to match the mean
- outputs = self.forward_adjust_bias(inputs)
+ if is_merged_layer(x) and utils.is_conv_deconv(x.qparams.modules[-1]):
+ qparams = get_qparams()
+ qparams.inputs = [x.qparams.inputs[0], x]
+ qparams.modules = [x.qparams.modules[0], self]
+ y.qparams = qparams
+ #
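+        # when merged, y.qparams.modules becomes [conv, bn]; QuantTrainPAct2 uses this
+        # chain downstream to fold the BN into the conv weights before quantizing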
+
+ return y
+ #
- self.train(training)
- return outputs
+###########################################################
+# fake quantized PAct2 for training
+class QuantTrainPAct2(layers.PAct2):
+ def __init__(self, inplace=False, signed=False, clip_range=None, bitwidth_weights=None, bitwidth_activations=None, per_channel_q=False):
+ super().__init__(inplace=inplace, signed=signed, clip_range=clip_range)
+
+ self.bitwidth_weights = bitwidth_weights
+ self.bitwidth_activations = bitwidth_activations
+ self.per_channel_q = per_channel_q
+        # weight shrinking is done by clamping weights - set this factor to zero.
+        # this must be zero - in PACT we do not have the actual weight param, just a temporary tensor,
+        # so any clipping we do here is not stored in the weight params
+ self.range_shrink_weights = 0.0
+ self.round_dither = 0.0
+ self.update_range = True
+ self.quantize_enable = True
+ self.quantize_weights = True
+ self.quantize_bias = True
+ self.quantize_activations = True
+ self.constrain_weights = True
+ self.bias_calibration = False
+ # save quantized weight/bias once in a while into the params - not needed
+ self.params_save_frequency = None #(10 if self.bias_calibration else None)
+
+ # set to STRAIGHT_THROUGH_ESTIMATION or ALPHA_BLENDING_ESTIMATION
+ # For a comparison of STE and ABE, read:
+ # Learning low-precision neural networks without Straight-Through Estimator (STE):
+ # https://arxiv.org/pdf/1903.01061.pdf
+ self.quantized_estimation_type = QuantEstimationType.STRAIGHT_THROUGH_ESTIMATION
+ self.alpha_blending_estimation_factor = 0.5
+
+ if (layers.PAct2.PACT2_RANGE_LEARN):
+ assert self.quantized_estimation_type != QuantEstimationType.STRAIGHT_THROUGH_ESTIMATION, \
+            'straight through estimation should not be used when PACT clip values are being learned, as it does not backpropagate gradients through quantization'
#
- def forward_compute_oputput_stats(self, inputs):
- self._restore_weights_orig()
- # disable quantization for a moment
- quantize_enable_backup_value = self.quantize_enable
- utils.apply_setattr(self, quantize_enable=False)
+ def forward(self, x):
+ assert (self.bitwidth_weights is not None) and (self.bitwidth_activations is not None), \
+ 'bitwidth_weights and bitwidth_activations must not be None'
- self.add_call_hook(self.module, self._forward_compute_oputput_stats_hook)
- outputs = self.module(inputs)
- self.remove_call_hook(self.module)
+ # the pact range update happens here - but range clipping depends on quantize_enable
+ y = super().forward(x, update_range=self.update_range, enable=self.quantize_enable)
- # turn quantization back on - not a clean method
- utils.apply_setattr(self, quantize_enable=quantize_enable_backup_value)
- self._backup_weights_orig()
- return outputs
- #
- def _forward_compute_oputput_stats_hook(self, op, *inputs_orig):
- outputs = op.__forward_orig__(*inputs_orig)
- # calibration at specific layers
- bias = op.bias if (hasattr(op, 'bias') and (op.bias is not None)) else None
- weight = op.weight if (hasattr(op, 'weight') and (op.weight is not None)) else None
- if (bias is not None) or (self.calibrate_weights and weight is not None):
- output = outputs[0] if isinstance(outputs, (list,tuple)) else outputs
- reduce_dims = [0,2,3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)
- op.__output_mean_orig__ = torch.mean(output, dim=reduce_dims)
- op.__output_std_orig__ = torch.std(output, dim=reduce_dims)
+ if not self.quantize_enable:
+ return y
+ #
+
+        # previous intermediate outputs and other information are available
+ # for example - conv-bn-relu may need to be merged together.
+ is_merged = is_merged_layer(x)
+ if is_merged:
+ qparams = x.qparams
+ xorg = qparams.inputs[0]
+
+ conv, bn = None, None
+ # merge weight and bias (if possible) across layers
+ if len(qparams.modules) == 2 and utils.is_conv_deconv(qparams.modules[-2]) and isinstance(
+ qparams.modules[-1], torch.nn.BatchNorm2d):
+ conv = qparams.modules[-2]
+ bn = qparams.modules[-1]
+ elif len(qparams.modules) == 1 and utils.is_conv_deconv(qparams.modules[-1]):
+ conv = qparams.modules[-1]
+ elif len(qparams.modules) == 1 and isinstance(qparams.modules[-1], torch.nn.BatchNorm2d):
+                assert False, f'quantization: previous layer is a BN without Conv {qparams.modules} - please inspect the model carefully'
+ bn = qparams.modules[-1]
+ #
+ else:
+                assert False, 'QuantTrainPAct2: both conv & bn layers cannot be None in a merged scenario - please inspect the model carefully'
+ #
+ conv, weight, bias = self.merge_quantize_weights(conv, bn)
+ else:
+ conv, weight, bias = None, None, None
#
- return outputs
- #
+ if is_merged and utils.is_conv(conv):
+ xq = torch.nn.functional.conv2d(xorg, weight, bias, stride=conv.stride, padding=conv.padding, dilation=conv.dilation, groups=conv.groups)
+ elif is_merged and utils.is_deconv(conv):
+ xq = torch.nn.functional.conv_transpose2d(xorg, weight, bias, stride=conv.stride, padding=conv.padding, output_padding=conv.output_padding,
+ dilation=conv.dilation, groups=conv.groups)
+ else:
+ xq = x
+ #
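+        # note: in the merged case the folded weight/bias were re-applied above to the
+        # original input xorg via the functional conv call, recomputing conv+bn as a
+        # single operation whose output is then fake-quantized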
+
+ if (self.quantize_enable and self.quantize_activations):
+ clip_min, clip_max, scale, scale_inv = self.get_clips_scale_act()
+ width_min, width_max = self.get_widths_act()
+            # no need to call super().forward here as clipping with width_min/width_max-1 after scaling has the same effect.
+ # currently the gradient is set to 1 in round_up_g. shouldn't the gradient be (y-x)?
+ # see eqn 6 in https://arxiv.org/pdf/1903.08066v2.pdf
+ # yq = layers.clamp_g(layers.round_up_g(xq * scale), width_min, width_max-1, self.training) * scale_inv
+ yq = layers.quantize_dequantize_g(xq, scale, width_min, width_max-1, self.power2, 'round_up')
+ else:
+ yq = super().forward(xq, update_range=False, enable=True)
+ #
- def forward_adjust_bias(self, input):
- self._restore_weights_quant()
- self.add_call_hook(self.module, self._forward_adjust_bias_hook)
- for _ in range(self.calibrate_repeats):
- output = self.module(input)
+ if (self.quantized_estimation_type == QuantEstimationType.STRAIGHT_THROUGH_ESTIMATION):
+ # replace the float output with quantized version
+ # the entire weight merging and quantization process is bypassed in the forward pass
+ # however, the backward gradients flow through only the float path - this is called straight through estimation (STE)
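+            # (equivalent formulation: y = y + (yq - y).detach() - the forward emits yq
+            # while gradients flow through the float y)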
+ with torch.no_grad():
+ y.data.copy_(yq.data)
+ #
+ elif self.training and (self.quantized_estimation_type == QuantEstimationType.ALPHA_BLENDING_ESTIMATION):
+ # TODO: vary the alpha blending factor over the epochs
+ y = y * (1.0-self.alpha_blending_estimation_factor) + yq * self.alpha_blending_estimation_factor
+ elif (self.quantized_estimation_type == QuantEstimationType.QUANTIZED_THROUGH_ESTIMATION):
+ # pass on the quantized output - the backward gradients also flow through quantization.
+ # however, note the gradients of round and ceil operators are forced to be unity (1.0).
+ y = yq
+ else:
+ assert False, f'unsupported value for quantized_estimation_type: {self.quantized_estimation_type}'
#
- self.remove_call_hook(self.module)
- self._backup_weights_quant()
- return output
+
+ return y
#
- def _forward_adjust_bias_hook(self, op, *inputs_orig):
- outputs = op.__forward_orig__(*inputs_orig)
- # calibration at specific layers
- bias = op.bias if (hasattr(op, 'bias') and (op.bias is not None)) else None
- if bias is not None:
- output = outputs[0] if isinstance(outputs, (list,tuple)) else outputs
- reduce_dims = [0,2,3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)
- output_mean = torch.mean(output, dim=reduce_dims)
- output_delta = op.__output_mean_orig__ - output_mean
- output_delta = output_delta * self.bias_calibration_factor
- bias.data += (output_delta)
- # # TODO: is this required?
- # if len(output.size()) == 4:
- # output.data += output_delta.data.view(1,-1,1,1)
- # elif len(output.size()) == 2:
- # output.data += output_delta.data.view(1,-1)
- # else:
- # assert False, 'unknown dimensions'
- # #
+
+
+ def apply_constrain_weights(self, merged_weight):
+ return quant_utils.constrain_weight(merged_weight)
+
+
+ def merge_quantize_weights(self, conv, bn):
+        # store the quantized weights and biases infrequently - otherwise learning will be poor
+        # since this may not be done at the end of the epoch, there can be a slight mismatch in validation accuracy
+ first_training_iter = self.training and (self.num_batches_tracked == 0)
+ is_store_weight_bias_iter = (self.params_save_frequency is not None) and (torch.remainder(self.num_batches_tracked, self.params_save_frequency) == 0)
+
+ # merge weight and bias (if possible) across layers
+ if conv is not None and bn is not None:
+ conv_bias = conv.bias if (conv.bias is not None) else torch.tensor(0.0).to(conv.weight.device)
+ #
+ bn_weight = bn.weight if (bn.weight is not None) else torch.tensor(0.0).to(bn.running_mean.device)
+ bn_bias = bn.bias if (bn.bias is not None) else torch.tensor(0.0).to(bn.running_mean.device)
+ #
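+            # BN folding (a sketch of the math implemented below):
+            #   merged_scale  = bn_weight / sqrt(running_var + eps)
+            #   merged_weight = conv_weight * merged_scale   (broadcast along the out-channel axis)
+            #   merged_bias   = (conv_bias - running_mean) * merged_scale + bn_bias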
+ merged_scale = bn_weight / torch.sqrt(bn.running_var + bn.eps)
+ if utils.is_conv(conv):
+ merged_scale = merged_scale.view(-1, 1, 1, 1)
+ elif utils.is_deconv(conv):
+ merged_scale = merged_scale.view(1, -1, 1, 1)
+ else:
+ assert False, 'unable to merge convolution and BN'
+ #
+ merged_bias = (conv_bias - bn.running_mean) * merged_scale.view(-1) + bn_bias
+ merged_weight = conv.weight * merged_scale
+ #
+ merged_scale_sign = merged_scale.sign()
+            merged_scale_sign = merged_scale_sign + (merged_scale_sign == 0) # map zeros in the sign tensor to 1
+ merged_scale_eps = merged_scale.abs().clamp(min=bn.eps) * merged_scale_sign
+ merged_scale_inv = 1.0 / merged_scale_eps
+ #
+ elif conv is not None:
+ merged_weight = conv.weight
+ merged_bias = conv.bias if (conv.bias is not None) else torch.zeros(conv.out_channels).to(conv.weight.device)
+ merged_scale = 1.0
+ merged_scale_inv = 1.0
+ elif bn is not None:
+ merged_weight = bn.weight if (bn.weight is not None) else torch.zeros(bn.num_features).to(bn.running_mean.device)
+ merged_bias = bn.bias if (bn.bias is not None) else torch.zeros(bn.num_features).to(bn.running_mean.device)
+ else:
+            assert False, 'merge_quantize_weights(): both conv & bn layers cannot be None in a merged scenario - please inspect the model carefully'
+ merged_weight = 0.0
+ merged_bias = 0.0
#
- # weight = op.weight if (hasattr(op, 'weight') and (op.weight is not None)) else None
- # iter_threshold = (1.0/self.bias_calibration_factor)
- # if self.calibrate_weights and (weight is not None) and (self.num_batches_tracked > iter_threshold):
- # eps = 1e-3
- # output = outputs[0] if isinstance(outputs, (list,tuple)) else outputs
- # reduce_dims = [0, 2, 3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)
- # output_std = torch.std(output, dim=reduce_dims)
- # output_ratio = (op.__output_std_orig__ + eps) / (output_std + eps)
- # channels = output.size(1)
- # output_ratio = output_ratio.view(channels, 1, 1, 1) if len(weight.data.size()) > 1 else output_ratio
- # output_ratio = torch.pow(output_ratio, self.bias_calibration_factor)
- # output_ratio = torch.clamp(output_ratio, 1.0-self.bias_calibration_factor, 1.0+self.bias_calibration_factor)
- # weight.data *= output_ratio
- # #
- # #
+ # quantize weight and bias
+ if (conv is not None):
+ if (self.quantize_enable and self.quantize_weights):
+ if self.constrain_weights and first_training_iter:
+ with torch.no_grad():
+ # clamp merged weights, invert the bn and copy to conv weight
+ constrained_weight = self.apply_constrain_weights(merged_weight.data)
+ merged_weight.data.copy_(constrained_weight.data)
+                        # store clipped weight after inverting bn - not really needed as the weights are saved below as well
+ # conv.weight.data.copy_(merged_weight.data * merged_scale_inv)
+ #
+ #
- return outputs
+ is_dw = utils.is_dwconv(conv)
+ use_per_channel_q = (is_dw and self.per_channel_q is True) or (self.per_channel_q == 'all')
+ if use_per_channel_q:
+ channels = int(merged_weight.size(0))
+ scale2 = torch.zeros(channels,1,1,1).to(merged_weight.device)
+ scale_inv2 = torch.zeros(channels,1,1,1).to(merged_weight.device)
+ for chan_id in range(channels):
+ clip_min, clip_max, scale2_value, scale_inv2_value = self.get_clips_scale_w(merged_weight[chan_id])
+ scale2[chan_id,0,0,0] = scale2_value
+ scale_inv2[chan_id,0,0,0] = scale_inv2_value
+ #
+ else:
+ clip_min, clip_max, scale2, scale_inv2 = self.get_clips_scale_w(merged_weight)
+ #
+ width_min, width_max = self.get_widths_w()
+ # merged_weight = layers.clamp_g(layers.round_sym_g(merged_weight * scale2), width_min, width_max-1, self.training) * scale_inv2
+ merged_weight = layers.quantize_dequantize_g(merged_weight, scale2, width_min, width_max-1, self.power2, 'round_sym')
+ #
+
+ if (self.quantize_enable and self.quantize_bias):
+ bias_width_min, bias_width_max = self.get_widths_bias()
+ bias_clip_min, bias_clip_max, bias_scale2, bias_scale_inv2 = self.get_clips_scale_bias(merged_bias)
+ # merged_bias = layers.clamp_g(layers.round_sym_g(merged_bias * bias_scale2), bias_width_min, bias_width_max-1, self.training) * bias_scale_inv2
+ merged_bias = layers.quantize_dequantize_g(merged_bias, bias_scale2, bias_width_min, bias_width_max - 1, self.power2, 'round_sym')
+ #
+
+ # invert the bn operation and store weights/bias
+ if first_training_iter or (self.training and is_store_weight_bias_iter):
+ with torch.no_grad():
+ if self.quantize_enable and self.quantize_weights:
+ conv.weight.data.copy_(merged_weight.data * merged_scale_inv)
+ #
+ if self.quantize_enable and self.quantize_bias:
+ if conv.bias is not None:
+ if bn is not None:
+ conv_bias = (merged_bias - bn_bias) * merged_scale_inv.view(-1) + bn.running_mean
+ conv.bias.data.copy_(conv_bias.data)
+ else:
+ conv.bias.data.copy_(merged_bias.data)
+ #
+ elif bn is not None and bn.bias is not None:
+ bn_bias = merged_bias + bn.running_mean * merged_scale.view(-1)
+ bn.bias.data.copy_(bn_bias.data)
+ #
+ #
+ #
+ #
+ #
+ return conv, merged_weight, merged_bias
+
+
+ def get_clips_w(self, tensor):
+ # find the clip values
+ w_min, w_max = utils.extrema_fast(tensor.data, percentile_range_shrink=self.range_shrink_weights)
+ clip_max = torch.max(torch.abs(w_min), torch.abs(w_max))
+ clip_max = torch.clamp(clip_max, min=self.eps)
+        # in range learning mode + training - this power2 is taken care of in the quantize function
+ use_power2 = (self.power2 and (not (self.PACT2_RANGE_LEARN and self.training)))
+ clip_max2 = layers.functional.ceil2_g(clip_max) if use_power2 else clip_max
+ clip_min2 = -clip_max2
+ return (clip_min2, clip_max2)
+
+ # bias uses the same kind of clips
+ get_clips_bias = get_clips_w
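+
+    # example (illustrative): weights spanning [-1.3, 0.9] give clip_max = 1.3;
+    # with power2 clips, ceil2_g rounds it up to the next power of two: (-2.0, 2.0)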
+
+
+ def get_clips_scale_w(self, weight):
+ # convert to scale
+ clip_min, clip_max = self.get_clips_w(weight)
+ width_min, width_max = self.get_widths_w()
+ scale2 = (width_max / clip_max)
+ scale2 = torch.clamp(scale2, min=self.eps)
+ scale_inv2 = scale2.pow(-1.0)
+ return (clip_min, clip_max, scale2, scale_inv2)
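+
+    # example (illustrative): with 8-bit weights, width_max = 2**7 = 128, so a
+    # clip_max of 2.0 gives scale2 = 64.0 and scale_inv2 = 1/64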
+
+
+ # in reality, bias quantization will also depend on the activation scale
+ # this is not perfect - just a quick and dirty quantization for bias
+ def get_clips_scale_bias(self, bias):
+ # convert to scale
+ clip_min, clip_max = self.get_clips_bias(bias)
+ width_min, width_max = self.get_widths_bias()
+ scale2 = (width_max / clip_max)
+ scale2 = torch.clamp(scale2, min=self.eps)
+ scale_inv2 = scale2.pow(-1.0)
+ return (clip_min, clip_max, scale2, scale_inv2)
+
+
+ def get_widths_w(self):
+ # weights
+ bw = (self.bitwidth_weights - 1)
+ width_max = np.power(2.0, bw)
+ width_min = -width_max
+ # return
+ return (width_min, width_max)
+
+
+ def get_widths_bias(self):
+ # bias
+ bitwidth_bias = (2*self.bitwidth_activations)
+ bias_width_max = np.power(2.0, bitwidth_bias-1)
+ bias_width_min = -bias_width_max
+ # return
+ return (bias_width_min, bias_width_max)
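+
+    # example: bitwidth_weights=8 gives weight widths (-128.0, 128.0), while
+    # bitwidth_activations=8 gives a 16-bit bias with widths (-32768.0, 32768.0)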
+
+
+ # activation utility functions
+ def get_clips_scale_act(self):
+ # convert to scale
+ clip_min, clip_max = self.get_clips_act()
+ width_min, width_max = self.get_widths_act()
+ scale2 = width_max / clip_max
+ scale2 = torch.clamp(scale2, min=self.eps)
+ scale_inv2 = scale2.pow(-1.0)
+ return (clip_min, clip_max, scale2, scale_inv2)
+
+
+ def get_widths_act(self):
+ if self.signed is None:
+ clip_min, clip_max = self.get_clips_act()
+ signed = (clip_min < 0.0)
+ else:
+ signed = self.signed
+ #
+ bw = (self.bitwidth_activations - 1) if signed else self.bitwidth_activations
+ width_max = np.power(2.0, bw)
+ width_min = -width_max if signed else 0.0
+ return width_min, width_max
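+
+    # example: with bitwidth_activations=8, signed activations get widths (-128.0, 128.0)
+    # and unsigned activations get (0.0, 256.0)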