support the Hardtanh activation function in quantization-aware training as well
author Manu Mathew <a0393608@ti.com>
Thu, 12 Mar 2020 08:29:16 +0000 (13:59 +0530)
committer Manu Mathew <a0393608@ti.com>
Thu, 12 Mar 2020 08:30:50 +0000 (14:00 +0530)
modules/pytorch_jacinto_ai/engine/infer_pixel2pixel.py
modules/pytorch_jacinto_ai/engine/infer_pixel2pixel_onnx_rt.py
modules/pytorch_jacinto_ai/engine/test_pixel2pixel_onnx.py
modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py
modules/pytorch_jacinto_ai/vision/models/__init__.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py
modules/pytorch_jacinto_ai/xnn/utils/module_utils.py

diff --git a/modules/pytorch_jacinto_ai/engine/infer_pixel2pixel.py b/modules/pytorch_jacinto_ai/engine/infer_pixel2pixel.py
index b7e45d91b8c9a0f65917c33d1edfd6e30de5daf9..94c01b8f476260fc7ce0357cb37222ba2405eb98 100644
@@ -78,7 +78,6 @@ def get_config():
 
     args.multistep_gamma = 0.5                  # steps for step scheduler
     args.polystep_power = 1.0                   # power for polynomial scheduler
-    args.train_fwbw = False                     # do forward backward step while training
 
     args.rand_seed = 1                          # random seed
     args.img_border_crop = None                 # image border crop rectangle. can be relative or absolute
diff --git a/modules/pytorch_jacinto_ai/engine/infer_pixel2pixel_onnx_rt.py b/modules/pytorch_jacinto_ai/engine/infer_pixel2pixel_onnx_rt.py
index e243d9ddd4be2721038612f8d40602d129ca48a3..5823c53deb5790c730d1132a3734b2a71451bb13 100644
@@ -83,7 +83,6 @@ def get_config():
 
     args.multistep_gamma = 0.5                  # steps for step scheduler
     args.polystep_power = 1.0                   # power for polynomial scheduler
-    args.train_fwbw = False                     # do forward backward step while training
 
     args.rand_seed = 1                          # random seed
     args.img_border_crop = None                 # image border crop rectangle. can be relative or absolute
diff --git a/modules/pytorch_jacinto_ai/engine/test_pixel2pixel_onnx.py b/modules/pytorch_jacinto_ai/engine/test_pixel2pixel_onnx.py
index 2b8ccfba5c7572880e8e1edc52653460d6275df5..84978fefaafb5975f49b98e7f7ea5655e68f22d9 100644
@@ -70,7 +70,6 @@ def get_config():
 
     args.multistep_gamma = 0.5                  # steps for step scheduler
     args.polystep_power = 1.0                   # power for polynomial scheduler
-    args.train_fwbw = False                     # do forward backward step while training
 
     args.rand_seed = 1                          # random seed
     args.img_border_crop = None                 # image border crop rectangle. can be relative or absolute
diff --git a/modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py b/modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py
index fb5837c9fd8387b2d60d49f292896f3f87747efa..605cef2d44ae3464aa7b8e98f549fcdb59211501 100644
@@ -112,7 +112,6 @@ def get_config():
 
     args.multistep_gamma = 0.5                          # steps for step scheduler
     args.polystep_power = 1.0                           # power for polynomial scheduler
-    args.train_fwbw = False                             # do forward backward step while training
 
     args.rand_seed = 1                                  # random seed
     args.img_border_crop = None                         # image border crop rectangle. can be relative or absolute
diff --git a/modules/pytorch_jacinto_ai/vision/models/__init__.py b/modules/pytorch_jacinto_ai/vision/models/__init__.py
index e7329cfcb65cd77f91099379af9b8774c97fff72..cab21e92b395d4c6c25c28656d72ae30f532d0be 100644
@@ -19,9 +19,6 @@ from . import classification
 # derived / multi purpose / multi input models
 from .multi_input_net import *
 
-try: from .fwbwnet_internal import *
-except: pass
-
 try: from .mobilenetv2_ericsun_internal import *
 except: pass
 
@@ -31,10 +28,6 @@ except: pass
 try: from .mobilenetv2_shicai_internal import *
 except: pass
 
-try: from .flownetbase_internal import *
-except: pass
-
-
 @property
 def name():
     return 'pytorch_jacinto_ai.vision.models'
diff --git a/modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py b/modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py
index 9a0161d11c3f2e5f30bb03a1a7e92f74c32196de..31cbbc35452193f1b9551d051eacb1335779bd02 100644
@@ -141,6 +141,8 @@ class QuantGraphModule(HookedModule):
                 if utils.is_activation(module):
                     if isinstance(module, (torch.nn.ReLU, torch.nn.ReLU6, layers.ReLUN)):
                         activation_q = layers.PAct2(signed=False)
+                    elif isinstance(module, torch.nn.Hardtanh):
+                        activation_q = layers.PAct2(clip_range=(module.min_val, module.max_val))
                     else:
                         activation_q = layers.PAct2(signed=None)
                     #
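The new branch above maps a Hardtanh activation to a PAct2 that clips to the same range during quantization-aware training. A minimal sketch of that mapping, using only the PAct2 keyword visible in the diff; the layers import path and the example Hardtanh range are assumptions:

    import torch
    from pytorch_jacinto_ai.xnn import layers   # assumed import path for the layers module used above

    # a Hardtanh with an explicit clip range, e.g. used as a bounded ReLU substitute
    act = torch.nn.Hardtanh(min_val=0.0, max_val=4.0)
    # the quantized stand-in clips to the same (min_val, max_val) range
    act_q = layers.PAct2(clip_range=(act.min_val, act.max_val))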
diff --git a/modules/pytorch_jacinto_ai/xnn/utils/module_utils.py b/modules/pytorch_jacinto_ai/xnn/utils/module_utils.py
index 2e048aad99d578e36237f2de8493463c3a453036..20b03577bde09060c3966a57538aac05c2ad12cb 100644
@@ -9,8 +9,8 @@ def is_normalization(module):
 
 
 def is_activation(module):
-    is_act = isinstance(module, (torch.nn.ReLU, torch.nn.ReLU6, layers.NoAct,
-                                 layers.PAct2, layers.ReLUN))
+    is_act = isinstance(module, (torch.nn.ReLU, torch.nn.ReLU6, torch.nn.Hardtanh,
+                                 layers.NoAct, layers.PAct2, layers.ReLUN))
     return is_act
 
 def is_pact2(module):
@@ -92,8 +92,10 @@ def get_range(op):
         return 0.0, 6.0
     elif isinstance(op, torch.nn.Sigmoid):
         return 0.0, 1.0
-    elif isinstance(op, (torch.nn.Tanh,torch.nn.Hardtanh)):
+    elif isinstance(op, torch.nn.Tanh):
         return -1.0, 1.0
+    elif isinstance(op, torch.nn.Hardtanh):
+        return op.min_val, op.max_val
     else:
         assert False, 'dont know the range of the module'
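For reference, a small usage sketch of the updated helpers; it assumes the modules directory is on the Python path so module_utils is importable as shown, with the expected results noted in comments:

    import torch
    from pytorch_jacinto_ai.xnn.utils import module_utils   # assumed import path

    ht = torch.nn.Hardtanh(min_val=-2.0, max_val=2.0)
    print(module_utils.is_activation(ht))            # True: Hardtanh now counts as an activation
    print(module_utils.get_range(ht))                # (-2.0, 2.0): its own configured clip range
    print(module_utils.get_range(torch.nn.Tanh()))   # (-1.0, 1.0), unchanged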