updated quantization modules to support mmdetection, using Hardtanh for fixed range...
authorManu Mathew <a0393608@ti.com>
Tue, 16 Jun 2020 08:02:49 +0000 (13:32 +0530)
committerManu Mathew <a0393608@ti.com>
Tue, 16 Jun 2020 08:06:14 +0000 (13:36 +0530)
20 files changed:
docs/Quantization.md
modules/pytorch_jacinto_ai/engine/train_classification.py
modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py
modules/pytorch_jacinto_ai/vision/models/pixel2pixel/pixel2pixelnet_utils.py
modules/pytorch_jacinto_ai/xnn/__init__.py
modules/pytorch_jacinto_ai/xnn/layers/activation.py
modules/pytorch_jacinto_ai/xnn/layers/common_blocks.py
modules/pytorch_jacinto_ai/xnn/layers/conv_blocks.py
modules/pytorch_jacinto_ai/xnn/onnx/__init__.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/optim/lr_scheduler.py
modules/pytorch_jacinto_ai/xnn/quantize/hooked_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_base_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_calib_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_test_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_train_module.py
modules/pytorch_jacinto_ai/xnn/utils/load_weights.py
run_quantization.sh
scripts/train_classification_main.py
scripts/train_segmentation_main.py

index 30350506dabac9eb41f9a30dfe22ceb95fb7d415..83cbffe28faa3de87b15be66a252e82827ff9d83 100644 (file)
@@ -25,13 +25,14 @@ To get best accuracy at the quantization stage, it is important that the model i
 
 ## Implementation Notes, Limitations & Recommendations
 - **Please read carefully** - closely following these recommendations can save hours or days of debugging related to quantization accuracy issues.
-- **Use Modules instead of functions** (by Module we mean classes derived from torch.nn.Module). We make use of Modules heavily in our quantization tools - in order to do range collection, in order to merge Convolution/BN/ReLU in order to decide whether to quantize a certain tensor and so on. For example use torch.nn.ReLU instead of torch.nn.functional.relu(), torch.nn.AdaptiveAvgPool2d() instead of torch.nn.functional.adaptive_avg_pool2d(), torch.nn.Flatten() instead of torch.nn.functional.flatten() etc.<br>
 - **The same module should not be re-used multiple times within the model** so that the feature map range estimation is correct. Unfortunately, in the torchvision ResNet models, the ReLU modules in BasicBlock and BottleneckBlock are re-used multiple times. We have corrected this by defining separate ReLU modules. This change is minor and **does not** affect the loading of existing pretrained weights. See [our modified ResNet model definition here](../modules/pytorch_jacinto_ai/vision/models/resnet.py).<br>
+- **Use Modules instead of functionals or tensor operations** (by Module we mean classes derived from torch.nn.Module). We make use of Modules heavily in our quantization tools - to do range collection, to merge Convolution/BatchNorm/ReLU, to decide whether to quantize a certain tensor, and so on. For example, use torch.nn.ReLU instead of torch.nn.functional.relu(), torch.nn.AdaptiveAvgPool2d() instead of torch.nn.functional.adaptive_avg_pool2d(), torch.nn.Flatten() instead of torch.flatten(), etc.<br>
+- Other notable modules provided are: [xnn.layers.AddBlock](../modules/pytorch_jacinto_ai/xnn/layers/common_blocks.py) to do elementwise addition and [xnn.layers.CatBlock](../modules/pytorch_jacinto_ai/xnn/layers/common_blocks.py) to do concatenation of tensors. Use these in models instead of tensor operations. Note that if there are multiple elementwise additions in a model, each of them should use a different instance of xnn.layers.AddBlock (since the same module should not be re-used multiple times - see above). The same restriction applies to xnn.layers.CatBlock or any other module as well.
+- **Interpolation/Upsample/Resize** has been tricky in PyTorch in the sense that the generated ONNX graph used to be unnecessarily complicated. Recent versions of PyTorch have fixed this - but the right options must be used to get a clean graph. We have provided both a functional form and a module form of this operator with the capability to export a clean ONNX graph: [xnn.layers.resize_with, xnn.layers.ResizeWith](../modules/pytorch_jacinto_ai/xnn/layers/resize_blocks.py)
 - If you have done QAT and are getting poor accuracy either in the Python code or during inference on the platform, please inspect your model carefully to see if the above recommendations have been followed - some of these can be easily missed by oversight, resulting in painful debugging that could have been avoided.<br>
 - However, if a function does not change the range of the feature map, it is not critical to use it in Module form. An example of this is torch.nn.functional.interpolate.<br>
-- **Multi-GPU training/calibration/validation with DataParallel is supported with our QAT module** QuantTrainModule. This takes care of a major concern that was earlier there in doing QAT with QuantTrainModule. (However it is not supported for QuantCalibrateModule/QuantTestModule - these calibration/test phases take much less time - so hopefully this is not a big issue. In our example training scripts train_classification.py and train_pixel2pixel.py in pytorch_jacinto_ai/engine, we do not wrap the model in DataParallel if the model is QuantCalibrateModule or QuantTestModule, but we do that for QuantTrainModule).<br>
+- **Multi-GPU training/validation with DataParallel** is supported with our QAT module QuantTrainModule and test module QuantTestModule. This takes care of a major concern that previously existed when doing QAT with QuantTrainModule. (However, it is not supported for QuantCalibrateModule - calibration takes much less time, so hopefully this is not a big issue. In our example training scripts train_classification.py and train_pixel2pixel.py in pytorch_jacinto_ai/engine, we do not wrap the model in DataParallel if the model is a QuantCalibrateModule, but we do so for QuantTrainModule and QuantTestModule.)<br>
 - If your training/calibration crashes because of insufficient GPU memory, reduce the batch size and try again.
-- This repository has several useful functions and Modules as part of the xnn python module. Most notable ones are: [xnn.layers.resize_with, xnn.layers.ResizeWith](../modules/pytorch_jacinto_ai/xnn/layers/resize_blocks.py) to export a clean resize/interpolate/upsamle graph, [xnn.layers.AddBlock, xnn.layers.CatBlock](../modules/pytorch_jacinto_ai/xnn/layers/common_blocks.py) to do elementwise addition & concatenation in a torch.nn.Module form.
 - If you are using TIDL to infer a model trained using QAT (or calibrated using PTQ) tools provided in this repository, please set the following in the import config file for best accuracy: **quantizationStyle = 3** to use power of 2 quantization. **foldPreBnConv2D = 0** to avoid a slight accuracy degradation due to incorrect folding of BatchNormalization that comes before Convolution (input mean/scale is implemented in TIDL as a PreBN - so this affects most networks).
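The "use Modules, don't re-use them" recommendations above can be sketched as a small residual block. This is a minimal standalone illustration - the AddBlock here is a stand-in for xnn.layers.AddBlock, reduced to plain addition:

```python
import torch

# Stand-in for xnn.layers.AddBlock: elementwise addition in Module form,
# so the quantization tools can collect ranges at this point in the graph.
class AddBlock(torch.nn.Module):
    def forward(self, x):
        # x is a tuple/list of tensors to be added elementwise
        return x[0] + x[1]

class ResidualBlock(torch.nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(channels, channels, 3, padding=1)
        self.relu1 = torch.nn.ReLU()   # separate ReLU instance per use
        self.conv2 = torch.nn.Conv2d(channels, channels, 3, padding=1)
        self.add = AddBlock()          # Module form of '+', one instance per addition
        self.relu2 = torch.nn.ReLU()   # do NOT re-use self.relu1 here

    def forward(self, x):
        y = self.relu1(self.conv1(x))
        y = self.conv2(y)
        y = self.add((x, y))
        return self.relu2(y)

block = ResidualBlock(8)
y = block(torch.randn(2, 8, 16, 16))
print(y.shape)   # torch.Size([2, 8, 16, 16])
```

Each feature map produced by a distinct module instance gets its own range statistics, which is exactly what the re-use restriction protects.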
 
 ## Post Training Calibration For Quantization (PTQ a.k.a. Calibration)
index 609b134ee255de0c3b277be17932f6890aa8e253..1747931b9abdab99496b1cb1587f957471569c6d 100644 (file)
@@ -223,8 +223,16 @@ def main(args):
     #################################################
     # create model
     print("=> creating model '{}'".format(args.model_name))
-    
-    model = vision.models.classification.__dict__[args.model_name](args.model_config) if args.model == None else args.model
+
+    is_onnx_model = False
+    if isinstance(args.model, torch.nn.Module):
+        model = args.model
+    elif isinstance(args.model, str) and args.model.endswith('.onnx'):
+        model = xnn.onnx.import_onnx(args.model)
+        is_onnx_model = True
+    else:
+        model = vision.models.classification.__dict__[args.model_name](args.model_config)
+    #
 
     # check if we got the model as well as parameters to change the names in pretrained
     model, change_names_dict = model if isinstance(model, (list,tuple)) else (model,None)
@@ -256,7 +264,7 @@ def main(args):
     #
 
     # load pretrained
-    if pretrained_data is not None:
+    if pretrained_data is not None and not is_onnx_model:
         xnn.utils.load_weights(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict)
     #
     
index 6e9e6235059715c890905a41f55a4a0a4b509cc6..6f6ffd92acc1e20f49401f507ae4ede3242561ea 100644 (file)
@@ -337,9 +337,13 @@ def main(args):
 
     #################################################
     # create model
-    if args.model is not None:
+    is_onnx_model = False
+    if isinstance(args.model, torch.nn.Module):
         model, change_names_dict = args.model if isinstance(args.model, (list, tuple)) else (args.model, None)
         assert isinstance(model, torch.nn.Module), 'args.model, if provided must be a valid torch.nn.Module'
+    elif isinstance(args.model, str) and args.model.endswith('.onnx'):
+        model = xnn.onnx.import_onnx(args.model)
+        is_onnx_model = True
     else:
         xnn.utils.print_yellow("=> creating model '{}'".format(args.model_name))
         model = vision.models.pixel2pixel.__dict__[args.model_name](args.model_config)
@@ -371,7 +375,7 @@ def main(args):
     #
 
     # load pretrained model
-    if pretrained_data is not None:
+    if pretrained_data is not None and not is_onnx_model:
         for (p_data,p_file) in zip(pretrained_data, pretrained_files):
             print("=> using pretrained weights from: {}".format(p_file))
             xnn.utils.load_weights(get_model_orig(model), pretrained=p_data, change_names_dict=change_names_dict)
index 3073552365211a63424e89cae2d81e28a8701f57..dce04718d023ccbe968ee2ea06f48292a4d1af29 100644 (file)
@@ -9,7 +9,9 @@ def add_lite_prediction_modules(self, model_config, current_channels, module_nam
         UpsampleClass = xnn.layers.UpsampleWith
 
         # can control the range of final output with output_range
-        final_activation = xnn.layers.get_fixed_pact2(output_range=model_config.output_range) if (model_config.output_range is not None) else False
+        output_range = model_config.output_range
+        final_activation = xnn.layers.get_fixed_hardtanh_type(output_range[0],output_range[1]) \
+            if (output_range is not None) else False
         upstride2 = model_config.shortcut_strides[0]
 
         if self.model_config.final_upsample and self.model_config.interpolation_type in ('deconv','upsample_conv','subpixel_conv'):
index 79a8ee713d5542d065794fead3cbd59d0cbd35b9..6cf402db5d0d6743c9882aba4a8f8a972b1a1fa8 100644 (file)
@@ -2,3 +2,6 @@ from . import layers
 from . import optim
 from . import utils
 from . import quantize
+from . import onnx
+
+
index a1e6ee68af0dab1a6891e821eea07959b237b5a8..a70b464297156e1b4610e5f8c497d627a1a26ac1 100644 (file)
@@ -15,7 +15,8 @@ class PAct2(torch.nn.Module):
     PACT2_RANGE_INIT = 8.0      # this is the starting range
     PACT2_RANGE_EXPANSION = 1.1 # expand the calculated range for margin
 
-    def __init__(self, inplace=False, signed=None, percentile_range_shrink=PACT2_RANGE_SHRINK, clip_range=None, power2_activation_range=True, **kwargs):
+    def __init__(self, inplace=False, signed=None, percentile_range_shrink=PACT2_RANGE_SHRINK, clip_range=None,
+                 power2_activation_range=True, **kwargs):
         super().__init__()
         if (clip_range is not None) and (signed is not None):
             assert signed == (clip_range[0]<0.0), 'the signed flag provided did not match the clip_range provided'
@@ -121,14 +122,26 @@ class PAct2(torch.nn.Module):
 
 ###############################################################
 # return a function that creates PAct2 with the given fixed range
-def get_fixed_pact2(inplace=False, signed=None, output_range=None):
-        def FixedPAct2Creator(inplace=inplace, signed=signed):
+# remember: this function returns a type and not an instance
+def get_fixed_pact2_type(inplace=False, signed=None, output_range=None):
+        def FixedPAct2Type(inplace=inplace, signed=signed):
             assert output_range is not None, 'output_range must be specified for FixedPact2'
             clip_range = output_range #max(abs(np.array(output_range)))
             signed = True if ((output_range[0] < 0.0) or (signed is True)) else signed
             return PAct2(inplace=inplace, signed=signed, clip_range=clip_range)
         #
-        return FixedPAct2Creator
+        return FixedPAct2Type
+
+
+###############################################################
+# return a derivative of Hardtanh with the given fixed range
+# remember: this function returns a type and not an instance
+def get_fixed_hardtanh_type(*args, **kwargs):
+        class FixedHardtanhType(torch.nn.Hardtanh):
+            def __init__(self, *args_, **kwargs_):
+                super().__init__(*args, **kwargs)
+        #
+        return FixedHardtanhType
 
 
 ###############################################################
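The get_fixed_hardtanh_type() factory added above returns a class rather than an instance, so model code can construct the activation later with its own arguments (which are deliberately ignored in favor of the fixed range). A minimal usage sketch, with the factory copied standalone:

```python
import torch

# Standalone copy of the factory above: it returns a type, not an instance.
def get_fixed_hardtanh_type(*args, **kwargs):
    class FixedHardtanhType(torch.nn.Hardtanh):
        def __init__(self, *args_, **kwargs_):
            # per-instance arguments are ignored; the fixed range always wins
            super().__init__(*args, **kwargs)
    return FixedHardtanhType

FinalAct = get_fixed_hardtanh_type(0.0, 1.0)   # fixed output_range (0, 1)
act = FinalAct(inplace=True)                   # 'inplace' is ignored by design
out = act(torch.tensor([-2.0, 0.5, 3.0]))
print(out)                                     # values clipped to [0, 1]
```

This is what allows model_config.output_range to be baked into the prediction head as a plain Hardtanh, which exports cleanly and has a fixed quantization range.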
index 0e9b4b40c5a6905c7d412cfaa90ed1ada1bcabab..f7e3b9e294583cd63a0b0e6b93ea80df2dd30d03 100644 (file)
@@ -194,3 +194,13 @@ class ShuffleBlock(torch.nn.Module):
          else:
              return x
 
+
+###############################################################
+class ArgMax(torch.nn.Module):
+    def __init__(self, dim=1, keepdim=True):
+        super().__init__()
+        self.dim = dim
+        self.keepdim = keepdim
+    def forward(self, x):
+        y = torch.argmax(x, dim=self.dim, keepdim=self.keepdim)
+        return y
\ No newline at end of file
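The new ArgMax class wraps torch.argmax in Module form, in line with the Modules-over-tensor-operations recommendation in Quantization.md. A minimal usage sketch, e.g. for converting segmentation logits into a class-index map (the shapes are illustrative):

```python
import torch

# Same ArgMax module as added above, copied standalone.
class ArgMax(torch.nn.Module):
    def __init__(self, dim=1, keepdim=True):
        super().__init__()
        self.dim = dim
        self.keepdim = keepdim

    def forward(self, x):
        return torch.argmax(x, dim=self.dim, keepdim=self.keepdim)

logits = torch.randn(2, 19, 64, 64)   # (batch, classes, H, W)
pred = ArgMax(dim=1)(logits)          # per-pixel class index
print(pred.shape)                     # torch.Size([2, 1, 64, 64])
```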
index 1f9ce2728a595d5b680cfe1306067c6dbebda88e..b6fa6488b084e09ad0b168053a22473175a93f57 100644 (file)
@@ -2,24 +2,35 @@ import torch
 from .layer_config import *
 from . import functional
 
+def check_groups(in_planes, out_planes, groups, group_size):
+    assert groups is None or group_size is None, 'only one of groups or group_size must be specified'
+    assert groups is not None or group_size is not None, 'at least one of groups or group_size must be specified'
+    groups = (in_planes//group_size) if groups is None else groups
+    group_size = (in_planes//groups) if group_size is None else group_size
+    assert in_planes%groups == 0, 'in_planes must be a multiple of groups'
+    assert group_size != 1 or in_planes == out_planes, 'in DW layer channels must not change'
+    return groups, group_size
 
 ############################################################### 
 def ConvLayer2d(in_planes, out_planes, kernel_size, stride=1, groups=1, dilation=1, padding=None, bias=False):
     """convolution with padding"""
     padding = padding if padding else ((kernel_size-1)//2)*dilation
+    groups, group_size = check_groups(in_planes, out_planes, groups=groups, group_size=None)
     return DefaultConv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, groups=groups)
 
 
-def ConvDWLayer2d(in_planes, out_planes, stride=1, dilation=1, kernel_size=None, bias=False):
+def ConvDWLayer2d(in_planes, out_planes, kernel_size=None, stride=1, dilation=1, groups_dw=None, group_size_dw=None, bias=False, padding=None):
     """convolution with padding"""
-    assert in_planes == out_planes, 'in DW layer channels must not change'
-    return ConvLayer2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, groups=in_planes, bias=bias)
+    groups_dw = in_planes if (groups_dw is None and group_size_dw is None) else groups_dw
+    groups_dw, group_size_dw = check_groups(in_planes, out_planes, groups=groups_dw, group_size=group_size_dw)
+    return ConvLayer2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, groups=groups_dw, bias=bias, padding=padding)
     
 
 ############################################################### 
 def ConvNormAct2d(in_planes, out_planes, kernel_size=None, stride=1, groups=1, dilation=1, padding=None, bias=False, \
               normalization=DefaultNorm2d, activation=DefaultAct2d):
     """convolution with padding, BN, ReLU"""
+    groups, group_size = check_groups(in_planes, out_planes, groups=groups, group_size=None)
     if type(kernel_size) in (list,tuple):
         padding = padding if padding else (((kernel_size[0]-1)//2)*dilation,((kernel_size[1]-1)//2)*dilation)
     else:
@@ -42,26 +53,31 @@ def ConvNormAct2d(in_planes, out_planes, kernel_size=None, stride=1, groups=1, d
     return layers
 
     
-def ConvDWNormAct2d(in_planes, out_planes, stride=1, kernel_size=None, dilation=1, bias=False, normalization=DefaultNorm2d, activation=DefaultAct2d):
+def ConvDWNormAct2d(in_planes, out_planes, kernel_size=None, stride=1, dilation=1, groups_dw=None, group_size_dw=None, bias=False, padding=None,
+                    normalization=DefaultNorm2d, activation=DefaultAct2d):
     """convolution with padding, BN, ReLU"""
-    assert in_planes == out_planes, 'in DW layer channels must not change'
-    return ConvNormAct2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, bias=bias, groups=in_planes, \
-                     normalization=normalization, activation=activation)
+    groups_dw = in_planes if (groups_dw is None and group_size_dw is None) else groups_dw
+    groups_dw, group_size_dw = check_groups(in_planes, out_planes, groups=groups_dw, group_size=group_size_dw)
+    return ConvNormAct2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, bias=bias, groups=groups_dw, \
+                     padding=padding, normalization=normalization, activation=activation)
 
 
 ###########################################################
-def ConvDWSepNormAct2d(in_planes, out_planes, stride=1, kernel_size=None, groups=1, dilation=1, bias=False, \
+def ConvDWSepNormAct2d(in_planes, out_planes, kernel_size=None, stride=1, groups=1, groups_dw=None, group_size_dw=None, dilation=1, bias=False, padding=None, \
                    first_1x1=False, normalization=(DefaultNorm2d,DefaultNorm2d), activation=(DefaultAct2d,DefaultAct2d)):
+    bias = bias if isinstance(bias, (list,tuple)) else (bias,bias)
     if first_1x1:
-        layers = [ConvNormAct2d(in_planes, out_planes, kernel_size=1, groups=groups, bias=bias,
+        layers = [ConvNormAct2d(in_planes, out_planes, kernel_size=1, bias=bias[0], groups=groups,
                       normalization=normalization[0], activation=activation[0]),
-                  ConvDWNormAct2d(out_planes, out_planes, stride=stride, kernel_size=kernel_size, dilation=dilation, bias=bias,
-                        normalization=normalization[1], activation=activation[1])]
+                  ConvDWNormAct2d(out_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, bias=bias[1],
+                      groups_dw=groups_dw, group_size_dw=group_size_dw,
+                      padding=padding, normalization=normalization[1], activation=activation[1])]
     else:
-        layers = [ConvDWNormAct2d(in_planes, in_planes, stride=stride, kernel_size=kernel_size, dilation=dilation, bias=bias,
-                              normalization=normalization[0], activation=activation[0]),
-                  ConvNormAct2d(in_planes, out_planes, groups=groups, kernel_size=1, bias=bias,
-                            normalization=normalization[1], activation=activation[1])]
+        layers = [ConvDWNormAct2d(in_planes, in_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, bias=bias[0],
+                      groups_dw=groups_dw, group_size_dw=group_size_dw,
+                      padding=padding, normalization=normalization[0], activation=activation[0]),
+                  ConvNormAct2d(in_planes, out_planes, kernel_size=1, bias=bias[1], groups=groups,
+                      normalization=normalization[1], activation=activation[1])]
 
     layers = torch.nn.Sequential(*layers)
     return layers
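The check_groups() helper introduced above derives whichever of groups/group_size was not given and validates the depthwise case. A standalone sketch of its behavior:

```python
# Standalone copy of the check_groups() helper added above.
def check_groups(in_planes, out_planes, groups, group_size):
    assert groups is None or group_size is None, 'only one of groups or group_size must be specified'
    assert groups is not None or group_size is not None, 'at least one of groups or group_size must be specified'
    # derive the missing parameter from the one that was given
    groups = (in_planes // group_size) if groups is None else groups
    group_size = (in_planes // groups) if group_size is None else group_size
    assert in_planes % groups == 0, 'in_planes must be a multiple of groups'
    # group_size == 1 means depthwise: channel count must be preserved
    assert group_size != 1 or in_planes == out_planes, 'in DW layer channels must not change'
    return groups, group_size

print(check_groups(64, 64, groups=64, group_size=None))   # (64, 1)  -> depthwise
print(check_groups(64, 128, groups=None, group_size=16))  # (4, 16)  -> grouped conv
```

This is what lets ConvDWLayer2d/ConvDWNormAct2d accept either groups_dw or group_size_dw while still defaulting to classic depthwise (groups = in_planes) when neither is given.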
diff --git a/modules/pytorch_jacinto_ai/xnn/onnx/__init__.py b/modules/pytorch_jacinto_ai/xnn/onnx/__init__.py
new file mode 100644 (file)
index 0000000..ba04b03
--- /dev/null
@@ -0,0 +1,4 @@
+try:
+    from .onnx2pytorch_internal import import_onnx
+except:
+    pass
index a7b44609f1749894ab406d2a07ec6ff011ff44d4..2c9c945dccea515d0fcbe7c0b299a2f2aae04023 100644 (file)
@@ -11,7 +11,7 @@ class SchedulerWrapper(torch.optim.lr_scheduler._LRScheduler):
         if scheduler_type == 'step':
             lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=multistep_gamma, last_epoch=start_epoch-1)
         elif scheduler_type == 'poly':
-            lambda_scheduler = lambda iter: ((1.0-iter/max_iter)**polystep_power)
+            lambda_scheduler = lambda last_epoch: ((1.0-last_epoch/epochs)**polystep_power)
             lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_scheduler, last_epoch=start_epoch-1)
         elif scheduler_type == 'cosine':
             lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0, last_epoch=start_epoch-1)
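The 'poly' scheduler fix above switches the lambda from an iteration counter to the epoch counter that LambdaLR actually passes in. A minimal sketch of the corrected schedule (the model and hyperparameters are placeholders):

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
epochs, polystep_power = 10, 0.9

# LambdaLR multiplies the base lr by lambda(last_epoch) at every step,
# so the decay variable must be the epoch index, not an iteration count.
lambda_scheduler = lambda last_epoch: ((1.0 - last_epoch / epochs) ** polystep_power)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_scheduler)

for epoch in range(3):
    optimizer.step()       # one epoch of training would go here
    lr_scheduler.step()
print(lr_scheduler.get_last_lr())   # lr decayed polynomially over 3 epochs
```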
index 34d49ed0c3766538b3f3458cbf4a623c3e067d86..ecaffe0c47f5648969f9261132736a457f99bac4 100644 (file)
@@ -13,7 +13,7 @@ class HookedModule(torch.nn.Module):
         def _call_hook_enable(op):
             # do not patch the top level modules. makes it easy to invoke by self.module(x)
             if op is not module:
-                assert not hasattr(op, backup_name), f'detected an existing function {backup_name} : please double check'
+                assert not hasattr(op, backup_name), f'in {op.__class__.__name__} detected an existing function {backup_name} : please double check'
                 # backup the original forward of op into backup_name
                 method_orig = getattr(op, method_name)
                 setattr(op, backup_name, method_orig)
index fa6b75396e10e1ac25de878c5c6a2b02aae217fe..11467d1f2495b0193ef235f47a9fe6607f668ccb 100644 (file)
@@ -10,9 +10,9 @@ class QuantEstimationType:
 
 # base module to be use for all quantization modules
 class QuantBaseModule(QuantGraphModule):
-    def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
+    def __init__(self, module, dummy_input, *args, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
                  histogram_range=True, bias_calibration=False, constrain_weights=False, constrain_bias=None,
-                 model_surgery_quantize=True, power2_weight_range=None, power2_activation_range=None):
+                 model_surgery_quantize=True, power2_weight_range=None, power2_activation_range=None, **kwargs):
         super().__init__(module)
         self.bitwidth_weights = bitwidth_weights
         self.bitwidth_activations = bitwidth_activations
@@ -34,7 +34,7 @@ class QuantBaseModule(QuantGraphModule):
             with torch.no_grad():
                 utils.print_yellow("=> model surgery by '{}'".format(self.model_surgery_quantize.__name__))
                 assert dummy_input is not None, 'dummy input is needed by quantized models to analyze graph'
-                self.model_surgery_quantize(dummy_input)
+                self.model_surgery_quantize(dummy_input, *args, **kwargs)
             #
             # add hooks to execute the pact modules
             self.add_activation_hooks()
@@ -52,31 +52,35 @@ class QuantBaseModule(QuantGraphModule):
     def add_activation_hooks(self):
         # add a forward hook to call the extra activation that we added
         def _forward_input_activation(op, inputs):
-            if hasattr(op, 'activation_in'):
-                # hook passes the input as tuple - expand it
-                to_squeeze = isinstance(inputs, tuple) and len(inputs) == 1
-                inputs = inputs[0] if to_squeeze else inputs
-                inputs = op.activation_in(inputs)
-                inputs = (inputs,) if to_squeeze else inputs
-            #
+            # hook passes the input as tuple - expand it
+            to_squeeze = isinstance(inputs, tuple) and len(inputs) == 1
+            inputs = inputs[0] if to_squeeze else inputs
+            inputs = op.activation_in(inputs)
+            inputs = (inputs,) if to_squeeze else inputs
             return inputs
         #
         def _forward_output_activation(op, inputs, outputs):
-            if hasattr(op, 'activation_q'):
-                # hook passes the input as tuple - expand it
-                to_squeeze = isinstance(outputs, tuple) and len(outputs) == 1
-                outputs = outputs[0] if to_squeeze else outputs
-                outputs = op.activation_q(outputs)
-                outputs = (outputs,) if to_squeeze else outputs
-            #
+            # hook passes the input as tuple - expand it
+            to_squeeze = isinstance(outputs, tuple) and len(outputs) == 1
+            outputs = outputs[0] if to_squeeze else outputs
+            outputs = op.activation_q(outputs)
+            outputs = (outputs,) if to_squeeze else outputs
             return outputs
         #
         for m in self.modules():
-            m.register_forward_pre_hook(_forward_input_activation)
-            m.register_forward_hook(_forward_output_activation)
+            if hasattr(m, 'activation_in'):
+                m.register_forward_pre_hook(_forward_input_activation)
+            #
+            if hasattr(m, 'activation_q'):
+                m.register_forward_hook(_forward_output_activation)
+            #
         #
 
 
+    def apply_setattr(self, **kwargs):
+        utils.apply_setattr(self, **kwargs)
+
+
     def train(self, mode=True):
         self.iter_in_epoch = -1
         super().train(mode)
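The hook refactoring above moves the hasattr() check from inside the hook bodies to registration time, so hooks are attached only to modules that actually carry the extra activation. A standalone sketch of the output-activation half of the pattern (the toy model and the attached ReLU are illustrative stand-ins for the PAct2 modules inserted by model surgery):

```python
import torch

def _forward_output_activation(op, inputs, outputs):
    # hook may receive the output as a 1-tuple - expand and re-wrap it
    to_squeeze = isinstance(outputs, tuple) and len(outputs) == 1
    outputs = outputs[0] if to_squeeze else outputs
    outputs = op.activation_q(outputs)
    return (outputs,) if to_squeeze else outputs

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Identity())
model[0].activation_q = torch.nn.ReLU()   # pretend model surgery attached this

for m in model.modules():
    if hasattr(m, 'activation_q'):        # register only where it is needed
        m.register_forward_hook(_forward_output_activation)

y = model(torch.randn(8, 4))
print((y >= 0).all())   # the hook applied ReLU to the Linear layer's output
```

Registering conditionally keeps modules without extra activations on their normal fast path instead of entering a hook that immediately returns.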
index 954ea45368fa25a72eeb16f59e28c63f545e1d58..a36adccaaaa98fc0b3672409623a56a8ca9932ef 100644 (file)
@@ -17,9 +17,9 @@ warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
 
 ###########################################################
 class QuantCalibrateModule(QuantTrainModule):
-    def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
+    def __init__(self, module, dummy_input, *args, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
                  histogram_range=True, bias_calibration=True, constrain_weights=None,
-                 power2_weight_range=None, power2_activation_range=None, constrain_bias=None, lr_calib=0.05):
+                 power2_weight_range=None, power2_activation_range=None, constrain_bias=None, lr_calib=0.05, **kwargs):
         self.weights_calibration = False
         self.lr_calib = lr_calib
         self.calibration_factor = lr_calib
@@ -28,12 +28,12 @@ class QuantCalibrateModule(QuantTrainModule):
         self.quantize_enable = True
         self.update_activation_range = True
         constrain_weights = (bias_calibration and (not per_channel_q)) if constrain_weights is None else constrain_weights
-        super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
+        super().__init__(module, dummy_input, *args, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                          per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration, constrain_weights=constrain_weights,
-                         power2_weight_range=power2_weight_range, power2_activation_range=power2_activation_range, constrain_bias=constrain_bias)
+                         power2_weight_range=power2_weight_range, power2_activation_range=power2_activation_range, constrain_bias=constrain_bias, **kwargs)
 
 
-    def forward(self, inputs):
+    def forward(self, inputs, *args, **kwargs):
         # calibration doesn't need gradients
         with torch.no_grad():
             # counters such as num_batches_tracked are used. update them.
@@ -51,9 +51,9 @@ class QuantCalibrateModule(QuantTrainModule):
             # actual forward call
             if self.training and (self.bias_calibration or self.weights_calibration):
                 # calibration
-                outputs = self.forward_calibrate(inputs)
+                outputs = self.forward_calibrate(inputs, *args, **kwargs)
             else:
-                outputs = self.module(inputs)
+                outputs = self.module(inputs, *args, **kwargs)
             #
 
             self.train(training)
@@ -61,7 +61,7 @@ class QuantCalibrateModule(QuantTrainModule):
         return outputs
 
 
-    def forward_calibrate(self, inputs):
+    def forward_calibrate(self, inputs, *args, **kwargs):
         # we don't need gradients for calibration
         # prepare/backup weights
         if self.num_batches_tracked == 0:
@@ -74,21 +74,21 @@ class QuantCalibrateModule(QuantTrainModule):
         #
 
         # Compute the mean output in float first.
-        outputs = self.forward_float(inputs)
+        outputs = self.forward_float(inputs, *args, **kwargs)
         # Then adjust weights/bias so that the quantized output matches float output
-        outputs = self.forward_quantized(inputs)
+        outputs = self.forward_quantized(inputs, *args, **kwargs)
 
         return outputs
 
 
-    def forward_float(self, inputs):
+    def forward_float(self, inputs, *args, **kwargs):
         self._restore_weights_orig()
         # disable quantization for a moment
         quantize_enable_backup_value, update_activation_range_backup_value = self.quantize_enable, self.update_activation_range
         utils.apply_setattr(self, quantize_enable=False, update_activation_range=False)
 
         self.add_call_hook(self.module, self.forward_float_hook)
-        outputs = self.module(inputs)
+        outputs = self.module(inputs, *args, **kwargs)
         self.remove_call_hook(self.module)
 
         # turn quantization back on - not a clean method
@@ -101,6 +101,9 @@ class QuantCalibrateModule(QuantTrainModule):
 
         # calibration at specific layers
         output = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
+        while isinstance(output, (list, tuple)):
+            output = output[0]
+
         reduce_dims = [0, 2, 3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)
 
         bias = op.bias if hasattr(op, 'bias') else None
@@ -115,21 +118,23 @@ class QuantCalibrateModule(QuantTrainModule):
     #
 
 
-    def forward_quantized(self, input):
+    def forward_quantized(self, input, *args, **kwargs):
         self._restore_weights_quant()
         self.add_call_hook(self.module, self.forward_quantized_hook)
         for _ in range(self.calibrate_repeats):
-            output = self.module(input)
+            output = self.module(input, *args, **kwargs)
         #
         self.remove_call_hook(self.module)
         self._backup_weights_quant()
         return output
     #
-    def forward_quantized_hook(self, op, *inputs_orig):
-        outputs = op.__forward_orig__(*inputs_orig)
+    def forward_quantized_hook(self, op, input, *args, **kwargs):
+        outputs = op.__forward_orig__(input, *args, **kwargs)
 
         # calibration at specific layers
         output = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
+        while isinstance(output, (list, tuple)):
+            output = output[0]
 
         bias = op.bias if hasattr(op, 'bias') else None
         if self.bias_calibration and bias is not None:
index cbcb3c1ebb6fd8c4a465ab5d1eb80d0b87b53cec..b9dc4c1311bb2f90480a789efd8947ee01f9b727 100644 (file)
@@ -56,7 +56,7 @@ class QuantGraphModule(HookedModule):
         return self.__qstate__
 
 
-    def forward(self, inputs):
+    def forward(self, inputs, *args, **kwargs):
         assert False, 'forward is not defined'
 
 
@@ -72,13 +72,13 @@ class QuantGraphModule(HookedModule):
 
     # force_update is used to increment the counters even when not training
     # used for validation in QuantTestModule
-    def analyze_graph(self, inputs, force_update=False, merge_weights=False, clear_qstate=False):
+    def analyze_graph(self, inputs, *args, force_update=False, merge_weights=False, clear_qstate=False, **kwargs):
         with torch.no_grad():
             self.init_qstate()
             self.update_counters(force_update=force_update)
             if (self.get_qstate().analyzed_graph == False):
                 # forward and analyze
-                self.forward_analyze_modules(inputs)
+                self.forward_analyze_modules(inputs, *args, **kwargs)
                 # analyze the connections
                 self.analyze_connections()
                 self.get_qstate().analyzed_graph = True
@@ -95,11 +95,11 @@ class QuantGraphModule(HookedModule):
         #
 
 
-    def model_surgery_quantize(self, dummy_input):
+    def model_surgery_quantize(self, dummy_input, *args, **kwargs):
         # clear the states - just to be sure
         self.clear_qstate()
         # analyze
-        self.analyze_graph(dummy_input)
+        self.analyze_graph(dummy_input, *args, **kwargs)
         # insert QAct wherever range clipping needs to be done
         self.model_surgery_activations()
         # since we might have added new activations, clear the states as they may not be valid
@@ -135,7 +135,8 @@ class QuantGraphModule(HookedModule):
                 #
             elif qparams.quantize_in:
                 if not hasattr(module, 'activation_in'):
-                    activation_in = layers.PAct2(signed=None)
+                    # do not want to clip input, so set percentile_range_shrink=0.0
+                    activation_in = layers.PAct2(signed=None, percentile_range_shrink=0.0)
                     activation_in.train(self.training)
                     module.activation_in = activation_in
                 #
@@ -152,28 +153,35 @@ class QuantGraphModule(HookedModule):
 
 
     ################################################################
-    def forward_analyze_modules(self, inputs):
+    def forward_analyze_modules(self, inputs, *args, **kwargs):
         '''
         Analyzing modules needs a call hook, and call hooks do not work with DataParallel.
         So, do the analysis on a copy.
         '''
         self_copy = copy.deepcopy(self)
-        self_copy._forward_analyze_modules_impl(inputs)
+        self_copy._forward_analyze_modules_impl(inputs, *args, **kwargs)
         self.get_qstate().qparams = self_copy.get_qstate().qparams
 
-    def _forward_analyze_modules_impl(self, inputs):
+    def _forward_analyze_modules_impl(self, inputs, *args, **kwargs):
         self.start_call()
         self.add_call_hook(self, self._analyze_modules_op)
-        output = self.module(inputs)
+        forward_analyze_method_name = kwargs.pop('forward_analyze_method', None)
+        if forward_analyze_method_name is not None and hasattr(self.module, forward_analyze_method_name):
+            # get the bound method to be used as forward
+            forward_analyze_method = getattr(self.module, forward_analyze_method_name)
+            output = forward_analyze_method(inputs, *args, **kwargs)
+        else:
+            output = self.module(inputs, *args, **kwargs)
+        #
         self.remove_call_hook(self.module)
         self.finish_call()
         return output
 
-    def _analyze_modules_op(self, op, *inputs_orig):
-        inputs = utils.squeeze_list2(inputs_orig)
+    def _analyze_modules_op(self, op, inputs, *args, **kwargs):
+        inputs = utils.squeeze_list2(inputs)
         self.start_node(op)
         self.add_node(op, inputs)
-        outputs = op.__forward_orig__(*inputs_orig)
+        outputs = op.__forward_orig__(inputs, *args, **kwargs)
         self.add_node(op, inputs, outputs)
         self.finish_node(op, inputs, outputs)
         return outputs
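The hook signatures above change from `*inputs_orig` to an explicit `(input, *args, **kwargs)` so that keyword arguments reach the wrapped forward. A standalone sketch of the call-hook pattern, simplified from the real `add_call_hook` in `hooked_module.py`:

```python
def add_call_hook(module, hook):
    # stash the original forward on the instance, then re-route calls
    # through the hook; *args/**kwargs must be forwarded, otherwise
    # keyword arguments (common in mmdetection forwards) get dropped
    module.__forward_orig__ = module.forward
    module.forward = lambda inp, *args, **kw: hook(module, inp, *args, **kw)

def remove_call_hook(module):
    module.forward = module.__forward_orig__
    del module.__forward_orig__

class Head:
    def forward(self, x, scale=1):
        return x * scale

def analyze_hook(op, inp, *args, **kwargs):
    # a range-collection hook would inspect inp and the outputs here
    return op.__forward_orig__(inp, *args, **kwargs)

h = Head()
add_call_hook(h, analyze_hook)
print(h.forward(3, scale=2))  # 6 - the keyword argument survives the hook
remove_call_hook(h)
print(h.forward(3))           # 3
```

With the old `*inputs_orig`-only signature, `scale=2` would raise a `TypeError` at the hook boundary, which is exactly the failure mode this commit fixes for models with keyword arguments in `forward`.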
index 76a139160718e44658dfca591a108eb12e5b4c6f..bc5e3d5a04af8e12ab1654d77c3ab974fb2758f7 100644 (file)
@@ -9,14 +9,15 @@ import numpy as np
 from .quant_train_module import *
 
 class QuantTestModule(QuantTrainModule):
-    def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
+    def __init__(self, module, dummy_input, *args, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
                  histogram_range=True, bias_calibration=False, constrain_weights=None, model_surgery_quantize=True,
-                 power2_weight_range=None, power2_activation_range=None, constrain_bias=None):
+                 power2_weight_range=None, power2_activation_range=None, constrain_bias=None, **kwargs):
         constrain_weights = (not per_channel_q) if constrain_weights is None else constrain_weights
-        super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
+        super().__init__(module, dummy_input, *args, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                          per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
                          constrain_weights=constrain_weights,
-                         power2_weight_range=power2_weight_range, power2_activation_range=power2_activation_range, constrain_bias=constrain_bias)
+                         power2_weight_range=power2_weight_range, power2_activation_range=power2_activation_range,
+                         constrain_bias=constrain_bias, **kwargs)
         assert model_surgery_quantize == True, f'{self.__class__.__name__} does not support model_surgery_quantize=False. please use a qat or calibrated module.'
         self.eval()
 
index f0b4529f1235068fad816207fa914b1a89195288..679261e842ab2da9428a8a3ddabb6ffc24d6b5d8 100644 (file)
@@ -18,25 +18,26 @@ warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
 
 ###########################################################
 class QuantTrainModule(QuantBaseModule):
-    def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
+    def __init__(self, module, dummy_input, *args, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
                  histogram_range=True, bias_calibration=False, constrain_weights=None,
-                 power2_weight_range=None, power2_activation_range=None, constrain_bias=None):
+                 power2_weight_range=None, power2_activation_range=None, constrain_bias=None, **kwargs):
         constrain_weights = (not per_channel_q) if constrain_weights is None else constrain_weights
-        super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
+        super().__init__(module, dummy_input, *args, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                          per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
                          constrain_weights=constrain_weights, model_surgery_quantize=True,
-                         power2_weight_range=power2_weight_range, power2_activation_range=power2_activation_range, constrain_bias=constrain_bias)
+                         power2_weight_range=power2_weight_range, power2_activation_range=power2_activation_range,
+                         constrain_bias=constrain_bias, **kwargs)
 
-    def forward(self, inputs):
+    def forward(self, inputs, *args, **kwargs):
         # counters such as num_batches_tracked are used. update them.
         self.update_counters()
         # outputs
-        outputs = self.module(inputs)
+        outputs = self.module(inputs, *args, **kwargs)
         return outputs
 
 
-    def model_surgery_quantize(self, dummy_input):
-        super().model_surgery_quantize(dummy_input)
+    def model_surgery_quantize(self, dummy_input, *args, **kwargs):
+        super().model_surgery_quantize(dummy_input, *args, **kwargs)
 
         def replace_func(op):
             for name, m in op._modules.items():
@@ -73,7 +74,7 @@ class QuantTrainModule(QuantBaseModule):
                     new_m = None
                 #
                 if new_m is not None:
-                    copy_attr_list = ('weight', 'bias', 'eps', 'clips_act', 'clips_w')
+                    copy_attr_list = ('weight', 'bias', 'eps', 'clips_act', 'clips_w', 'percentile_range_shrink')
                     for attr in dir(m):
                         value = getattr(m,attr)
                         if isinstance(value,torch.Tensor) and value is not None:
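`percentile_range_shrink` joins the whitelist of attributes carried over when a layer is swapped during model surgery; without it the replacement activation would fall back to its default range shrink. A sketch of the copy loop on plain objects (no torch dependency; `OldAct`/`NewAct` are stand-ins for the real modules):

```python
class OldAct:
    clips_act = (-6.0, 6.0)
    percentile_range_shrink = 0.01   # newly whitelisted in this commit
    scratch_buffer = "should not be copied"

class NewAct:
    pass

copy_attr_list = ('weight', 'bias', 'eps', 'clips_act', 'clips_w',
                  'percentile_range_shrink')
old, new = OldAct(), NewAct()
for attr in copy_attr_list:
    if hasattr(old, attr):
        # transplant whitelisted state onto the replacement module
        setattr(new, attr, getattr(old, attr))

print(new.clips_act, new.percentile_range_shrink)  # (-6.0, 6.0) 0.01
print(hasattr(new, 'scratch_buffer'))              # False
```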
index 3bc9743abb5204804cf0078bc3977b1fbfc52904..5a933682a6f3b17edb60b7a5c9dc2b8a7895a115 100644 (file)
@@ -10,6 +10,15 @@ from . import print_utils
 from . import utils_data
 
 ######################################################
+# the method used in vision/models
+try:
+    from torch.hub import load_state_dict_from_url
+except ImportError:
+    from torch.utils.model_zoo import load_url as load_state_dict_from_url
+
+
+######################################################
+# our custom load function with more features
 def load_weights(model, pretrained, change_names_dict=None, keep_original_names=False, width_mult=1.0,
                        ignore_size=True, verbose=False, num_batches_tracked = None, download_root=None, **kwargs):
     download_root = './' if (download_root is None) else download_root
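The try/except import above handles the helper's move from `torch.utils.model_zoo` to `torch.hub` across torch versions. The same fallback idea can be written generically; this `import_first` helper is hypothetical, not part of the repository:

```python
import importlib

def import_first(*candidates):
    # return the first attribute importable from the given
    # (module, attribute) pairs - a generic version of the
    # torch.hub / torch.utils.model_zoo fallback above
    last_err = None
    for mod_name, attr_name in candidates:
        try:
            return getattr(importlib.import_module(mod_name), attr_name)
        except (ImportError, AttributeError) as err:
            last_err = err
    raise ImportError(f'none of the candidates could be imported: {last_err}')

# falls back to the stdlib when torch is absent; either way a callable results
fetch = import_first(('torch.hub', 'load_state_dict_from_url'),
                     ('urllib.request', 'urlretrieve'))
print(callable(fetch))  # True
```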
index 461dbc4bb0147bc9c5b506770bb411b351b9ee32..3a288e7b59c7d1103712192843a7b103e7911f7b 100755 (executable)
@@ -5,41 +5,53 @@
 ## =====================================================================================
 #
 #### Image Classification - Post Training Calibration & Quantization - ResNet50
-#python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name resnet50_x1 --data_path ./data/datasets/image_folder_classification \
+#python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name resnet50_x1 --data_path ./data/datasets/image_folder_classification --gpus 0 \
 #--pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth \
 #--batch_size 64 --quantize True --epochs 1 --epoch_size 0.1 --evaluate_start False
 #
 #
 #### Image Classification - Post Training Calibration & Quantization - ResNet18
-#python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name resnet18_x1 --data_path ./data/datasets/image_folder_classification \
+#python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name resnet18_x1 --data_path ./data/datasets/image_folder_classification --gpus 0 \
 #--pretrained https://download.pytorch.org/models/resnet18-5c106cde.pth \
 #--batch_size 64 --quantize True --epochs 1 --epoch_size 0.1 --evaluate_start False
 #
 #
 #### Image Classification - Post Training Calibration & Quantization - MobileNetV2
-#python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification \
+#python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification --gpus 0 \
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth \
 #--batch_size 64 --quantize True --epochs 1 --epoch_size 0.1 --evaluate_start False
 #
 #
 #### Image Classification - Post Training Calibration & Quantization for a TOUGH MobileNetV2 pretrained model
-#python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_shicai_x1 --data_path ./data/datasets/image_folder_classification \
+#python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_shicai_x1 --data_path ./data/datasets/image_folder_classification --gpus 0 \
 #--pretrained ./data/modelzoo/pytorch/image_classification/imagenet1k/shicai/mobilenetv2_shicai_rgb.pth \
 #--batch_size 64 --quantize True --epochs 1 --epoch_size 0.1 --evaluate_start False
 #
 #
+#### Image Classification - Post Training Calibration & Quantization - ONNX Model Import
+#python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --gpus 0 \
+#--model_name resnet18-v1-7 --model /data/tensorlabdata1/modelzoo/pytorch/image_classification/imagenet1k/onnx-model-zoo/resnet18-v1-7.onnx \
+#--data_path ./data/datasets/image_folder_classification --batch_size 64 --quantize True --epochs 1 --epoch_size 0.1 --evaluate_start False
+#
+#
 ### Semantic Segmentation - Post Training Calibration &  Quantization for MobileNetV2+DeeplabV3Lite
-#python ./scripts/train_segmentation_main.py --phase calibration --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 \
+#python ./scripts/train_segmentation_main.py --phase calibration --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 \
 #--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/deeplabv3lite_mobilenetv2_tv_768x384_best.pth \
 #--batch_size 6 --quantize True --epochs 1 --evaluate_start False
 #
 #
 ### Semantic Segmentation - Post Training Calibration &  Quantization for MobileNetV2+UNetLite
-#python ./scripts/train_segmentation_main.py --phase calibration --dataset_name cityscapes_segmentation --model_name unetlite_pixel2pixel_aspp_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 \
+#python ./scripts/train_segmentation_main.py --phase calibration --dataset_name cityscapes_segmentation --model_name unetlite_pixel2pixel_aspp_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 \
 #--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/unet_aspp_mobilenetv2_tv_768x384_best.pth \
 #--batch_size 6 --quantize True --epochs 1 --evaluate_start False
 #
 #
+### Depth Estimation - Post Training Calibration & Quantization for MobileNetV2+DeeplabV3Lite
+#python ./scripts/train_depth_main.py --phase calibration --dataset_name kitti_depth --model_name deeplabv3lite_mobilenetv2_tv  --gpus 0 \
+#--pretrained ./data/modelzoo/pytorch/monocular_depth/kitti_depth/jacinto_ai/deeplabv3lite_mobilenetv2_tv_768x384_best.pth \
+#--batch_size 32 --quantize True --epochs 1 --evaluate_start False
+#
+#
 ## =====================================================================================
 ## Quantization Aware Training
 ## =====================================================================================
 #### Image Classification - Quantization Aware Training - MobileNetV2
 #python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification \
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth \
-#--batch_size 64 --quantize True --epochs 25 --epoch_size 0.1 --lr 1e-5 --evaluate_start False
+#--batch_size 64 --quantize True --epochs 50 --epoch_size 0.1 --lr 1e-5 --evaluate_start False
 #
 #
 #### Image Classification - Quantization Aware Training - MobileNetV2(Shicai) - a TOUGH MobileNetV2 pretrained model
 #python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_shicai_x1 --data_path ./data/datasets/image_folder_classification \
 #--pretrained ./data/modelzoo/pytorch/image_classification/imagenet1k/shicai/mobilenetv2_shicai_rgb.pth \
-#--batch_size 64 --quantize True --epochs 25 --epoch_size 0.1 --lr 1e-5 --evaluate_start False
+#--batch_size 64 --quantize True --epochs 50 --epoch_size 0.1 --lr 1e-5 --evaluate_start False
 #
 #
 #### Semantic Segmentation - Quantization Aware Training for MobileNetV2+DeeplabV3Lite
 #python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/deeplabv3lite_mobilenetv2_tv_768x384_best.pth \
-#--batch_size 6 --quantize True --epochs 150 --lr 1e-5 --evaluate_start False
+#--batch_size 12 --quantize True --epochs 50 --lr 1e-5 --evaluate_start False
 #
 #
 #### Semantic Segmentation - Quantization Aware Training for MobileNetV2+UNetLite
 #python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name unetlite_pixel2pixel_aspp_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/unet_aspp_mobilenetv2_tv_768x384_best.pth \
-#--batch_size 6 --quantize True --epochs 150 --lr 1e-5 --evaluate_start False
+#--batch_size 12 --quantize True --epochs 50 --lr 1e-5 --evaluate_start False
+#
 #
+### Depth Estimation - Quantization Aware Training for MobileNetV2+DeeplabV3Lite
+#python ./scripts/train_depth_main.py --dataset_name kitti_depth --model_name deeplabv3lite_mobilenetv2_tv --img_resize 384 768 --output_size 1024 2048 \
+#--pretrained ./data/modelzoo/pytorch/monocular_depth/kitti_depth/jacinto_ai/deeplabv3lite_mobilenetv2_tv_768x384_best.pth \
+#--batch_size 32 --quantize True --epochs 50 --lr 1e-5 --evaluate_start False
 #
 #
 ## =====================================================================================
index 19b364c2fead1736bba1607b7ec8599d5da0fa74..ebd9b02f1089011e7dfcfa7b3644a853692d09b3 100755 (executable)
@@ -18,6 +18,7 @@ parser.add_argument('--lr', type=float, default=None, help='Base learning rate')
 parser.add_argument('--lr_clips', type=float, default=None, help='Learning rate for clips in PAct2')
 parser.add_argument('--lr_calib', type=float, default=None, help='Learning rate for calibration')
 parser.add_argument('--model_name', type=str, default=None, help='model name')
+parser.add_argument('--model', default=None, help='model object or path to a model file (e.g. .onnx)')
 parser.add_argument('--dataset_name', type=str, default=None, help='dataset name')
 parser.add_argument('--data_path', type=str, default=None, help='data path')
 parser.add_argument('--epochs', type=int, default=None, help='number of epochs')
@@ -52,6 +53,7 @@ parser.add_argument('--epoch_size_val', type=float, default=None, help='epoch si
                                                                    '0 will use the full epoch. '
                                                                    'using a number will cause the epoch to have that many images. '
                                                                    'using a fraction will reduce the number of images used for one epoch. ')
+parser.add_argument('--parallel_model', type=str2bool, default=None, help='whether to use DataParallel for models')
 #
 cmds = parser.parse_args()
 
@@ -154,19 +156,22 @@ for key in vars(cmds):
 
 ################################
 # these dependent on the dataset chosen
-args.model_config.num_classes = (100 if 'cifar100' in args.dataset_name else (10  if 'cifar10' in args.dataset_name else 1000))
-
-
+args.model_config.num_classes = (100 if 'cifar100' in args.dataset_name else (10 if 'cifar10' in args.dataset_name else 1000))
 
 ################################
 # Run the training
 train_classification.main(args)
 
 ################################
-# In addition run a quantized calibration, starting from the trained model
+# if the previous phase was training, run a quantization aware training, starting from the trained model
 if 'training' in args.phase and (not args.quantize):
-    save_path = train_classification.get_save_path(args)
-    args.pretrained = os.path.join(save_path, 'model_best.pth')
+    if args.epochs > 0:
+        save_path = train_classification.get_save_path(args)
+        if isinstance(args.model, str) and args.model.endswith('.onnx'):
+            args.model = os.path.join(save_path, 'model_best.onnx')
+        #
+        args.pretrained = os.path.join(save_path, 'model_best.pth')
+    #
     args.phase = 'training_quantize'
     args.quantize = True
     args.lr = 1e-5
@@ -178,6 +183,9 @@ if 'training' in args.phase and (not args.quantize):
 # In addition run a separate validation, starting from the calibrated model - to estimate the quantized accuracy accurately
 if 'training' in args.phase or 'calibration' in args.phase:
     save_path = train_classification.get_save_path(args)
+    if isinstance(args.model, str) and args.model.endswith('.onnx'):
+        args.model = os.path.join(save_path, 'model_best.onnx')
+    #
     args.pretrained = os.path.join(save_path, 'model_best.pth')
     if 'training' in args.phase:
         # DataParallel isn't enabled for QuantCalibrateModule and QuantTestModule.
@@ -188,4 +196,6 @@ if 'training' in args.phase or 'calibration' in args.phase:
     args.phase = 'validation'
     args.quantize = True
     train_classification.main(args)
-#
\ No newline at end of file
+#
+
+
index f85b5f0edc4cf715c157c8219a9d7efd6f3ad3a6..2fb707818086bd2dd310832d63585f68296de585 100755 (executable)
@@ -17,6 +17,7 @@ parser.add_argument('--lr', type=float, default=None, help='Base learning rate')
 parser.add_argument('--lr_clips', type=float, default=None, help='Learning rate for clips in PAct2')
 parser.add_argument('--lr_calib', type=float, default=None, help='Learning rate for calibration')
 parser.add_argument('--model_name', type=str, default=None, help='model name')
+parser.add_argument('--model', default=None, help='model object or path to a model file (e.g. .onnx)')
 parser.add_argument('--dataset_name', type=str, default=None, help='dataset name')
 parser.add_argument('--data_path', type=str, default=None, help='data path')
 parser.add_argument('--epochs', type=int, default=None, help='number of epochs')
@@ -52,6 +53,7 @@ parser.add_argument('--epoch_size_val', type=float, default=None, help='epoch si
                                                                    '0 will use the full epoch. '
                                                                    'using a number will cause the epoch to have that many images. '
                                                                    'using a fraction will reduce the number of images used for one epoch. ')
+parser.add_argument('--parallel_model', type=str2bool, default=None, help='whether to use DataParallel for models')
 #
 cmds = parser.parse_args()
 
@@ -158,10 +160,15 @@ for key in vars(cmds):
 train_pixel2pixel.main(args)
 
 ################################
-# In addition run a quantization aware training, starting from the trained model
+# if the previous phase was training, run a quantization aware training, starting from the trained model
 if 'training' in args.phase and (not args.quantize):
-    save_path = train_pixel2pixel.get_save_path(args)
-    args.pretrained = os.path.join(save_path, 'model_best.pth') if (args.epochs>0) else args.pretrained
+    if args.epochs > 0:
+        save_path = train_pixel2pixel.get_save_path(args)
+        if isinstance(args.model, str) and args.model.endswith('.onnx'):
+            args.model = os.path.join(save_path, 'model_best.onnx')
+        #
+        args.pretrained = os.path.join(save_path, 'model_best.pth')
+    #
     args.phase = 'training_quantize'
     args.quantize = True
     args.lr = 1e-5
@@ -173,6 +180,9 @@ if 'training' in args.phase and (not args.quantize):
 # In addition run a separate validation
 if 'training' in args.phase or 'calibration' in args.phase:
     save_path = train_pixel2pixel.get_save_path(args)
+    if isinstance(args.model, str) and args.model.endswith('.onnx'):
+        args.model = os.path.join(save_path, 'model_best.onnx')
+    #
     args.pretrained = os.path.join(save_path, 'model_best.pth')
     if 'training' in args.phase:
         # DataParallel isn't enabled for QuantCalibrateModule and QuantTestModule.