]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blobdiff - modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py
default of model_surgery_quantize is now True for QuantTestModule.
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / quantize / quant_graph_module.py
index 3f40d38e49dfe90ebafe1df7bfb0cb7cf9425ad7..1086cb559fa50063a7d95bd99e3adff9a9c95d4c 100644 (file)
@@ -11,9 +11,9 @@ class QuantGraphModule(HookedModule):
         super().__init__()
         self.module = module
         self.init_qstate()
-        self.register_buffer('num_batches_tracked', torch.tensor(-1.0))
-        self.register_buffer('iter_in_epoch', torch.tensor(-1.0))
-        self.register_buffer('epoch', torch.tensor(-1.0))
+        self.num_batches_tracked = -1
+        self.iter_in_epoch = -1
+        self.epoch = -1
 
         # TBD: is this required
         # # if the original module has load_weights, add it to the quant module also
@@ -54,13 +54,13 @@ class QuantGraphModule(HookedModule):
 
 
     def update_counters(self, force_update=False):
+        self.iter_in_epoch += 1
         if self.training or force_update:
             self.num_batches_tracked += 1
-            if self.num_batches_tracked == 0:
+            if self.iter_in_epoch == 0:
                 self.epoch += 1.0
             #
         #
-        self.iter_in_epoch += 1
     #
 
     # force_update is used to increment inte counters even in non training
@@ -134,7 +134,7 @@ class QuantGraphModule(HookedModule):
 
 
     def train(self, mode=True):
-        self.iter_in_epoch.fill_(-1.0)
+        self.iter_in_epoch = -1
         super().train(mode)