support DataParallel for QuantTrainModule
authorManu Mathew <a0393608@ti.com>
Tue, 19 May 2020 17:13:48 +0000 (22:43 +0530)
committerManu Mathew <a0393608@ti.com>
Tue, 19 May 2020 18:48:38 +0000 (00:18 +0530)
quantization docs update

release commit

docs/Calibration.md
docs/Quantization.md
modules/pytorch_jacinto_ai/engine/train_classification.py
modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_test_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_train_module.py
scripts/train_classification_main.py
scripts/train_motion_segmentation_main.py
scripts/train_pixel2pixel_multitask_main.py
scripts/train_segmentation_main.py

index 385c0c7bde8515ef9d6ade7da081a5ac3c2009f1..0d1098cb1b488037db0b0e04538b749e9ca31b86 100644 (file)
@@ -90,3 +90,6 @@ python ./scripts/train_segmentation_main.py --phase calibration --dataset_name c
 --batch_size 12 --quantize True --epochs 1
 ```
 
+## Guidelines, Implementation Notes, Limitations & Recommendations
+- Please refer to the section on Quantization Aware Training, as the same guidelines, recomendations & limitations apply to QuantCalibrateModule.<br>
+- An additional limitation is that multi gpu processing with DataParallel / DistributedDataParallel is not supported for QuantCalibrateModule (also for QuantTestModule). In our example training scripts train_classification.py and train_pixel2pixel.py in pytorch_jacinto_ai/engine, we do not wrap the model in DataParallel if the model is QuantCalibrateModule or QuantTestModule. The original floating point training (without quantization) can use Multi-GPU as usual and we do not have any restrictions on that. (However multi gpu support with DataParallel works for QuantTrainModule - more details of this in the QAT section).<br>
index 803b2a6aba35467235fec4ca6412823d3f02ac32..5ae553150c33e181505131148c9f6c14e8a41be2 100644 (file)
@@ -29,9 +29,8 @@ To get best accuracy at the quantization stage, it is important that the model i
 - **The same module should not be re-used multiple times within the module** in order that the feature map range estimation is correct. Unfortunately, in the torchvision ResNet models, the ReLU module in the BasicBlock and BottleneckBlock are re-used multiple times. We have corrected this by defining separate ReLU modules. This change is minor and **does not** affect the loading of existing pretrained weights. See the [our modified ResNet model definition here](../modules/pytorch_jacinto_ai/vision/models/resnet.py).<br>
 - If you have done QAT and is getting poor accuracy either in the Python code or during inference in the platform, please inspect your model carefully to see if the above recommendations have been followed - some of these can be easily missed by oversight - and can result in painful debugging that could have been avoided.<br>
 - However, if a function does not change the range of feature map, it is not critical to use it in Module form. An example of this is torch.nn.functional.interpolate<br>
-- **Multi-GPU training/calibration/validation with DataParallel is not yet working with our quantization modules** QuantTrainModule/QuantCalibrateModule/QuantTestModule. We recommend not to wrap the modules in DataParallel if you are training/calibrating/testing with quantization - i.e. if your model is wrapped in QuantTrainModule/QuantCalibrateModule/QuantTestModule.<br>
-- If you get an error during training related to weights and input not being in the same GPU, please check and ensure that you are not using DataParallel with QuantTrainModule/QuantCalibrateModule/QuantTestModule. This may not be such a problem as calibration and quantization may not take as much time as the original floating point training. The original floating point training (without quantization) can use Multi-GPU as usual and we do not have any restrictions on that.<br>
-- If your calibration/training crashes with insufficient GPU memory, reduce the batch size and try again.
+- **Multi-GPU training/calibration/validation with DataParallel is supported with our QAT module** QuantTrainModule. This takes care of a major concern that was earlier there in doing QAT with QuantTrainModule. (However it is not supported for QuantCalibrateModule/QuantTestModule - these calibration/test phases take much less time - so hopefully this is not a big issue. In our example training scripts train_classification.py and train_pixel2pixel.py in pytorch_jacinto_ai/engine, we do not wrap the model in DataParallel if the model is QuantCalibrateModule or QuantTestModule, but we do that for QuantTrainModule).<br>
+- If your training/calibration crashes because of insufficient GPU memory, reduce the batch size and try again.
 - This repository has several useful functions and Modules as part of the xnn python module. Most notable ones are: [xnn.layers.resize_with, xnn.layers.ResizeWith](../modules/pytorch_jacinto_ai/xnn/resize_blocks.py) to export a clean resize/interpolate/upsamle graph, [xnn.layers.AddBlock, xnn.layers.CatBlock](../modules/pytorch_jacinto_ai/xnn/common_blocks.py) to do elementwise addition & concatenation in a torch.nn.Module form.
 - If you are using TIDL to infer a model trained using QAT (or calibratied using PTQ) tools provided in this repository, please set the following in the import config file for best accuracy: **quantizationStyle = 3** to use power of 2 quantization. **foldPreBnConv2D = 0** to avoid a slight accuracy degradation due to incorrect folding of BatchNormalization that comes before Convolution (input mean/scale is implemented in TIDL as a PreBN - so this affects most networks).
 
index 89f08b6cea894848a1551dc781e0da5b8aeec5c7..609b134ee255de0c3b277be17932f6890aa8e253 100644 (file)
@@ -282,12 +282,13 @@ def main(args):
         exit()
 
     #################################################
-    # multi gpu mode is not working for quantized model
-    if args.parallel_model and (not args.quantize):
+    # DataParallel does not work for QuantCalibrateModule or QuantTestModule
+    if args.parallel_model and (not isinstance(model, (xnn.quantize.QuantCalibrateModule, xnn.quantize.QuantTestModule))):
         if args.distributed:
             model = torch.nn.parallel.DistributedDataParallel(model)
         else:
             model = torch.nn.DataParallel(model)
+        #
     #
 
     #################################################
index d838078a1c038ffe4aec02e3d4ff7b0f3a5c2de3..6e9e6235059715c890905a41f55a4a0a4b509cc6 100644 (file)
@@ -401,9 +401,8 @@ def main(args):
         exit()
 
     #################################################
-    # multi gpu mode does not work for calibration/training for quantization
-    # so use it only when args.quantize is False
-    if args.parallel_model and ((not args.quantize)):
+    # DataParallel does not work for QuantCalibrateModule or QuantTestModule
+    if args.parallel_model and (not isinstance(model, (xnn.quantize.QuantCalibrateModule, xnn.quantize.QuantTestModule))):
         model = torch.nn.DataParallel(model)
 
     #################################################
index 39c6482841f761a9263693bc6d6746ae590c183a..3f40d38e49dfe90ebafe1df7bfb0cb7cf9425ad7 100644 (file)
@@ -7,13 +7,10 @@ from ..utils import AttrDict as Dict
 from .hooked_module import *
 
 class QuantGraphModule(HookedModule):
-    # instance member states are not retained across forward calls when using DataParallel
-    # as a workaround, use class member variables instead, so that these can be retained
-    states = Dict()
     def __init__(self, module):
         super().__init__()
         self.module = module
-        self.init_states()
+        self.init_qstate()
         self.register_buffer('num_batches_tracked', torch.tensor(-1.0))
         self.register_buffer('iter_in_epoch', torch.tensor(-1.0))
         self.register_buffer('epoch', torch.tensor(-1.0))
@@ -28,62 +25,28 @@ class QuantGraphModule(HookedModule):
         # #
 
 
-    # create the state object required to keep some quantization parameters that need to be preserved
-    # a cuda() has been called on the module - copy the states from that was created for cpu
-    def get_state(self):
-        states = self.get_states()
-        module_device = self.module_device(self.module)
-        if module_device not in states:
-            module_device_src = None
-            for key, value in states.items():
-                if key.type == 'cpu':
-                    module_device_src = key
-            #
-            if module_device_src is not None and module_device_src in states:
-                states[module_device] = copy.deepcopy(states[module_device_src])
-            else:
-                states[module_device] = Dict()
+    def init_qstate(self):
+        if not hasattr(self, '__qstate__'):
+            self.__qstate__ = Dict()
+        #
+        if 'qparams' not in self.get_qstate():
+            self.get_qstate().qparams = Dict()
+        #
+        if 'qparams_prev' not in self.get_qstate():
+            self.get_qstate().qparams_prev = Dict()
+        #
+        if 'analyzed_graph' not in self.get_qstate():
+            self.get_qstate().analyzed_graph = False
         #
-        return states[module_device]
-
-
-    def get_states(self):
-        return __class__.states
-
 
-    def clear_states(self):
-        __class__.states = Dict()
 
-    # these entries will prevent this modules from being used with DataParallel - cleanup
-    def cleanup_states(self):
-        assert self.get_state().analyzed_graph == True, 'graph must be analyzed before cleanup_states()'
-        with torch.no_grad():
-            for module_hash, qparams in self.get_state().qparams.items():
-                if hasattr(qparams, 'previous_node'):
-                    del qparams.previous_node
-                #
-                if hasattr(qparams, 'previous_module'):
-                    del qparams.previous_module
-                #
-                if hasattr(qparams, 'next_node'):
-                    del qparams.next_node
-                #
-                if hasattr(qparams, 'next_module'):
-                    del qparams.next_module
-                #
-            #
-        #
+    def clear_qstate(self):
+        self.__qstate__ = Dict()
+        self.init_qstate()
 
 
-    # data parallel does not initialize the replicas correctly, explicitly initialize them.
-    # there is no use in doing this in __init__. it has to be done in forward even if it is called in __init__.
-    def init_states(self):
-        if 'qparams' not in self.get_state().keys():
-            self.get_state().qparams = Dict()
-        if 'qparams_prev' not in self.get_state().keys():
-            self.get_state().qparams_prev = Dict()
-        if 'analyzed_graph' not in self.get_state().keys():
-            self.get_state().analyzed_graph = False
+    def get_qstate(self):
+        return self.__qstate__
 
 
     def forward(self, inputs):
@@ -102,24 +65,24 @@ class QuantGraphModule(HookedModule):
 
     # force_update is used to increment inte counters even in non training
     # used for validation in QuantTestModule
-    def analyze_graph(self, inputs, force_update=False, merge_weights=False, cleanup_states=False):
+    def analyze_graph(self, inputs, force_update=False, merge_weights=False, clear_qstate=False):
         with torch.no_grad():
-            self.init_states()
+            self.init_qstate()
             self.update_counters(force_update=force_update)
-            if (self.get_state().analyzed_graph == False):
+            if (self.get_qstate().analyzed_graph == False):
                 # forward and analyze
                 self.forward_analyze_modules(inputs)
                 # analyze the connections
                 self.analyze_connections()
-                self.get_state().analyzed_graph = True
+                self.get_qstate().analyzed_graph = True
 
                 # merge weights so that weight quantization can be done
                 if merge_weights:
                     self.merge_weights()
                 #
 
-                if cleanup_states:
-                    self.cleanup_states()
+                if clear_qstate:
+                    self.clear_qstate()
                 #
             #
         #
@@ -127,18 +90,18 @@ class QuantGraphModule(HookedModule):
 
     def model_surgery_quantize(self, dummy_input):
         # lear the sates - just to be sure
-        self.clear_states()
+        self.clear_qstate()
         # analyze
         self.analyze_graph(dummy_input)
         # insert NoAct wherever range clipping needs to be done
         self.model_surgery_activations()
         # since we might have added new activations, clear the sates as they may not be valid
-        self.clear_states()
+        self.clear_qstate()
         # need to call analyze_graph in the derived class
     #
 
     def model_surgery_activations(self):
-        for module_hash, qparams in self.get_state().qparams.items():
+        for module_hash, qparams in self.get_qstate().qparams.items():
             module = self.get_module(module_hash)
             if isinstance(module, layers.PAct2):
                 pass
@@ -170,14 +133,6 @@ class QuantGraphModule(HookedModule):
     #
 
 
-    def start_call(self):
-        self.call_count = Dict()
-
-
-    def finish_call(self):
-        self.call_count = None
-
-
     def train(self, mode=True):
         self.iter_in_epoch.fill_(-1.0)
         super().train(mode)
@@ -185,6 +140,15 @@ class QuantGraphModule(HookedModule):
 
     ################################################################
     def forward_analyze_modules(self, inputs):
+        '''
+        analyze modules needs a call hook - the call hook does not work with DataParallel.
+        So, do the analysis on a copy.
+        '''
+        self_copy = copy.deepcopy(self)
+        self_copy._forward_analyze_modules_impl(inputs)
+        self.get_qstate().qparams = self_copy.get_qstate().qparams
+
+    def _forward_analyze_modules_impl(self, inputs):
         self.start_call()
         self.add_call_hook(self.module, self._analyze_modules_op)
         output = self.module(inputs)
@@ -205,29 +169,29 @@ class QuantGraphModule(HookedModule):
         inputs = self.format_tensors(inputs)
         module_hash = self.module_hash(module)
 
-        if module_hash not in list(self.get_state().qparams.keys()):
-            self.get_state().qparams[module_hash] = Dict()
-            self.get_state().qparams[module_hash].qrange_w = None
-            self.get_state().qparams[module_hash].qrange_b = None
-            self.get_state().qparams[module_hash].qrange_in = []
-            self.get_state().qparams[module_hash].qrange_out = []
-            self.get_state().qparams[module_hash].is_input = (self.module is module)
-            self.get_state().qparams[module_hash].previous_node = []
-            self.get_state().qparams[module_hash].next_node = []
-            self.get_state().qparams[module_hash].current_node = module_hash
-
-        current_node = self.get_state().qparams[module_hash].current_node
+        if module_hash not in list(self.get_qstate().qparams.keys()):
+            self.get_qstate().qparams[module_hash] = Dict()
+            self.get_qstate().qparams[module_hash].qrange_w = None
+            self.get_qstate().qparams[module_hash].qrange_b = None
+            self.get_qstate().qparams[module_hash].qrange_in = []
+            self.get_qstate().qparams[module_hash].qrange_out = []
+            self.get_qstate().qparams[module_hash].is_input = (self.module is module)
+            self.get_qstate().qparams[module_hash].previous_node = []
+            self.get_qstate().qparams[module_hash].next_node = []
+            self.get_qstate().qparams[module_hash].current_node = module_hash
+
+        current_node = self.get_qstate().qparams[module_hash].current_node
         for inp in inputs:
             if hasattr(inp, 'qparams') and hasattr(inp.qparams, 'last_node'):
                 prev_module_hash = inp.qparams.last_node
                 prev_module = self.get_module(prev_module_hash)
-                previous_node = self.get_state().qparams[module_hash].previous_node
-                next_node = self.get_state().qparams[prev_module_hash].next_node
+                previous_node = self.get_qstate().qparams[module_hash].previous_node
+                next_node = self.get_qstate().qparams[prev_module_hash].next_node
 
                 if str(inp.qparams.last_node) not in [str(p) for p in previous_node]:
-                    self.get_state().qparams[module_hash].previous_node += [inp.qparams.last_node]
+                    self.get_qstate().qparams[module_hash].previous_node += [inp.qparams.last_node]
                 if str(current_node) not in [str(n) for n in next_node]:
-                    self.get_state().qparams[prev_module_hash].next_node += [current_node]
+                    self.get_qstate().qparams[prev_module_hash].next_node += [current_node]
 
         if outputs is not None:
             outputs = self.format_tensors(outputs)
@@ -245,13 +209,13 @@ class QuantGraphModule(HookedModule):
     ################################################################
     def analyze_connections(self):
         prediction_module = None
-        for module_hash, qparams in self.get_state().qparams.items():
+        for module_hash, qparams in self.get_qstate().qparams.items():
             module = self.get_module(module_hash)
             if utils.is_conv(module) or utils.is_linear(module) or utils.is_normalization(module) or utils.is_activation(module):
                 prediction_module = module
             #
         #
-        for module_hash, qparams in self.get_state().qparams.items():
+        for module_hash, qparams in self.get_qstate().qparams.items():
             module = self.get_module(module_hash)
             is_prediction = (prediction_module is module)
             self._analyse_connections_op(module_hash, module, qparams, is_prediction)
@@ -310,9 +274,9 @@ class QuantGraphModule(HookedModule):
 
     ################################################################
     def merge_weights(self, make_backup=False):
-        assert self.get_state().analyzed_graph == True, 'graph must be analyzed before merge_weights()'
+        assert self.get_qstate().analyzed_graph == True, 'graph must be analyzed before merge_weights()'
         with torch.no_grad():
-            for module_hash, qparams in self.get_state().qparams.items():
+            for module_hash, qparams in self.get_qstate().qparams.items():
                 module = self.get_module(module_hash)
                 self._merge_weight_op(module_hash, module, qparams, make_backup)
             #
@@ -367,52 +331,62 @@ class QuantGraphModule(HookedModule):
                 warnings.warn('problem detected in conv+bn layer pair: either one of conv.bias or bn.affine is required for successfull calibration - preferably bn.affine')
             #
         #
-
         return
 
 
     ################################################################
     def get_qparams(self, module):
         module_hash = self.module_hash(module)
-        return self.get_state().qparams[module_hash]
+        return self.get_qstate().qparams[module_hash]
+
 
     def get_qparams_prev(self, module):
         module_hash = self.module_hash(module)
-        return self.get_state().qparams_prev[module_hash] if self.get_state().qparams_prev else None
+        return self.get_qstate().qparams_prev[module_hash] if self.get_qstate().qparams_prev else None
+
+
+    def start_call(self):
+        self.call_count = Dict()
+
+
+    def finish_call(self):
+        self.call_count = None
+
 
     def start_node(self, module):
         module_name = self.module_name(module)
         if module_name not in list(self.call_count.keys()):
             self.call_count[module_name] = 0
-
+        #
         return
 
+
     def finish_node(self, module, inputs, outputs):
         module_name = self.module_name(module)
         self.call_count[module_name] = self.call_count[module_name] + 1
         return
 
+
     def module_hash(self, module):
+        '''
+        A module may be called multiple times in a model. This module has creates a unique name/hash for each call
+        using teh call_count. call_count needs tobe correct for this to work as expected.
+        call_count is kep up to date by using start_node() / finish_node() calls.
+        '''
         module_name = self.module_name(module)
         module_hash = module_name + '-call:{}'.format(self.call_count[module_name])
         return module_hash
 
+
     def module_name(self, module):
         name = None
         for n, m in self.named_modules():
             if m is module:
                 name = n
+            #
         #
         return name
 
-    def module_device(self, module=None):
-        module = module if module is not None else self.module
-        try:
-            module_device = next(module.parameters()).device
-        except:
-            module_device = None
-        #
-        return module_device
 
     def get_module(self, module_hash):
         module_name = module_hash.split('-call:')[0]
@@ -422,10 +396,12 @@ class QuantGraphModule(HookedModule):
         #
         return None
 
+
     def is_last_conv(self, module):
         # implementation is not correct. disable it for the time being
         return False #(module is self.last_conv_linear_module)
 
+
     def format_tensors(self, inputs):
         # make a list/tuple if inputs is not. if it is a double list, remove the extra one
         inputs = utils.squeeze_list(utils.make_list(inputs))
@@ -436,15 +412,16 @@ class QuantGraphModule(HookedModule):
 
     def copy_qparams(self, qparams, inputs):
         qparams_copy = Dict()
-
-        # deep copy may not work in some cases, so do it conditionally
         for module_hash, qparam_entry in qparams.items():
             qparams_copy[module_hash] = Dict()
             for key, value in qparam_entry.items():
+                # deep copy may not work in some cases, so do it conditionally
                 try:
                     qparams_copy[module_hash][key] = copy.deepcopy(value)
                 except Exception:
                     qparams_copy[module_hash][key] = value
-
+                #
+            #
+        #
         return qparams_copy
 
index 55f2301e9d720b41bc69ea8859442107d60e22a4..64bce323f6a8635059ee5aac1beb572a2ad5fb55 100644 (file)
@@ -66,13 +66,13 @@ class QuantTestModule(QuantBaseModule):
         self.apply(replace_func)
 
         # clear
-        self.clear_states()
+        self.clear_qstate()
     #
 
 
     def forward(self, inputs):
         # analyze - need to merge_weights - so call analyze_graph() instead of just update_counters()
-        self.analyze_graph(inputs=inputs, force_update=True, merge_weights=True, cleanup_states=True)
+        self.analyze_graph(inputs=inputs, force_update=True, merge_weights=True)
 
         # batch_size = inputs[0].size(0) if utils.is_list(inputs) else inputs.size(0)
         # if batch_size != 1:
@@ -84,7 +84,7 @@ class QuantTestModule(QuantBaseModule):
             # quantize
             outputs = self.forward_quantize(inputs)
             # start and new frame, copy the qparams for previous frame of inference
-            self.get_state().qparams_prev = self.copy_qparams(self.get_state().qparams, inputs)
+            self.get_qstate().qparams_prev = self.copy_qparams(self.get_qstate().qparams, inputs)
             # return
             return outputs
         #
index 7fe340e705ebf742f4a1ed8f9abed618f7c89c18..e7601cdc40a60c95981beb51079c38197a37124f 100644 (file)
@@ -91,7 +91,7 @@ class QuantTrainModule(QuantBaseModule):
         self.apply(replace_func)
 
         # clear
-        self.clear_states()
+        self.clear_qstate()
     #
 
 
index e7e01f1d864d4288cb27d87a875ef6c7e1aa3bde..19b364c2fead1736bba1607b7ec8599d5da0fa74 100755 (executable)
@@ -171,9 +171,6 @@ if 'training' in args.phase and (not args.quantize):
     args.quantize = True
     args.lr = 1e-5
     args.epochs = 50
-    # quantized training will use only one GPU in the engine - so reduce the batch_size
-    num_gpus = len(str(os.environ["CUDA_VISIBLE_DEVICES"]).split(','))
-    args.batch_size = args.batch_size//num_gpus
     train_classification.main(args)
 #
 
@@ -182,6 +179,12 @@ if 'training' in args.phase and (not args.quantize):
 if 'training' in args.phase or 'calibration' in args.phase:
     save_path = train_classification.get_save_path(args)
     args.pretrained = os.path.join(save_path, 'model_best.pth')
+    if 'training' in args.phase:
+        # DataParallel isn't enabled for QuantCalibrateModule and QuantTestModule.
+        # If the previous phase was training, then it is likely that the batch_size was high and won't fit in a single gpu - reduce it.
+        num_gpus = len(str(os.environ["CUDA_VISIBLE_DEVICES"]).split(',')) if ("CUDA_VISIBLE_DEVICES" in os.environ) else None
+        args.batch_size = max(args.batch_size//num_gpus, 1) if (num_gpus is not None) else args.batch_size
+    #
     args.phase = 'validation'
     args.quantize = True
     train_classification.main(args)
index 2f13e151d210bad3de75fd945804c55b39022677..1904225e1f780de7474b17011d6a1c3bcfbca12b 100755 (executable)
@@ -151,9 +151,6 @@ if 'training' in args.phase and (not args.quantize):
     args.quantize = True
     args.lr = 1e-5
     args.epochs = 50
-    # quantized training will use only one GPU in the engine - so reduce the batch_size
-    num_gpus = len(str(os.environ["CUDA_VISIBLE_DEVICES"]).split(',')) if ("CUDA_VISIBLE_DEVICES" in os.environ) else None
-    args.batch_size = (args.batch_size//num_gpus) if (num_gpus is not None) else args.batch_size
     train_pixel2pixel.main(args)
 #
 
@@ -162,6 +159,12 @@ if 'training' in args.phase and (not args.quantize):
 if 'training' in args.phase or 'calibration' in args.phase:
     save_path = train_pixel2pixel.get_save_path(args)
     args.pretrained = os.path.join(save_path, 'model_best.pth')
+    if 'training' in args.phase:
+        # DataParallel isn't enabled for QuantCalibrateModule and QuantTestModule.
+        # If the previous phase was training, then it is likely that the batch_size was high and won't fit in a single gpu - reduce it.
+        num_gpus = len(str(os.environ["CUDA_VISIBLE_DEVICES"]).split(',')) if ("CUDA_VISIBLE_DEVICES" in os.environ) else None
+        args.batch_size = max(args.batch_size//num_gpus, 1) if (num_gpus is not None) else args.batch_size
+    #
     args.phase = 'validation'
     args.quantize = True
     train_pixel2pixel.main(args)
index a50550025ba11f49f2b74d484a9e454d47a74e98..802b5ffc5e12136c256ac9c7bd8fe9aa951fca9c 100755 (executable)
@@ -163,9 +163,6 @@ if 'training' in args.phase and (not args.quantize):
     args.quantize = True
     args.lr = 1e-5
     args.epochs = 50
-    # quantized training will use only one GPU in the engine - so reduce the batch_size
-    num_gpus = len(str(os.environ["CUDA_VISIBLE_DEVICES"]).split(',')) if ("CUDA_VISIBLE_DEVICES" in os.environ) else None
-    args.batch_size = (args.batch_size//num_gpus) if (num_gpus is not None) else args.batch_size
     train_pixel2pixel.main(args)
 #
 
@@ -174,6 +171,12 @@ if 'training' in args.phase and (not args.quantize):
 if 'training' in args.phase or 'calibration' in args.phase:
     save_path = train_pixel2pixel.get_save_path(args)
     args.pretrained = os.path.join(save_path, 'model_best.pth.tar')
+    if 'training' in args.phase:
+        # DataParallel isn't enabled for QuantCalibrateModule and QuantTestModule.
+        # If the previous phase was training, then it is likely that the batch_size was high and won't fit in a single gpu - reduce it.
+        num_gpus = len(str(os.environ["CUDA_VISIBLE_DEVICES"]).split(',')) if ("CUDA_VISIBLE_DEVICES" in os.environ) else None
+        args.batch_size = max(args.batch_size//num_gpus, 1) if (num_gpus is not None) else args.batch_size
+    #
     args.phase = 'validation'
     args.quantize = True
     train_pixel2pixel.main(args)
index a87dd911c1c56e35cfbdf6abc2fb98bf424639d0..f85b5f0edc4cf715c157c8219a9d7efd6f3ad3a6 100755 (executable)
@@ -166,9 +166,6 @@ if 'training' in args.phase and (not args.quantize):
     args.quantize = True
     args.lr = 1e-5
     args.epochs = 50
-    # quantized training will use only one GPU in the engine - so reduce the batch_size
-    num_gpus = len(str(os.environ["CUDA_VISIBLE_DEVICES"]).split(',')) if ("CUDA_VISIBLE_DEVICES" in os.environ) else None
-    args.batch_size = (args.batch_size//num_gpus) if (num_gpus is not None) else args.batch_size
     train_pixel2pixel.main(args)
 #
 
@@ -177,6 +174,12 @@ if 'training' in args.phase and (not args.quantize):
 if 'training' in args.phase or 'calibration' in args.phase:
     save_path = train_pixel2pixel.get_save_path(args)
     args.pretrained = os.path.join(save_path, 'model_best.pth')
+    if 'training' in args.phase:
+        # DataParallel isn't enabled for QuantCalibrateModule and QuantTestModule.
+        # If the previous phase was training, then it is likely that the batch_size was high and won't fit in a single gpu - reduce it.
+        num_gpus = len(str(os.environ["CUDA_VISIBLE_DEVICES"]).split(',')) if ("CUDA_VISIBLE_DEVICES" in os.environ) else None
+        args.batch_size = max(args.batch_size//num_gpus, 1) if (num_gpus is not None) else args.batch_size
+    #
     args.phase = 'validation'
     args.quantize = True
     train_pixel2pixel.main(args)