]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blobdiff - modules/pytorch_jacinto_ai/xnn/optim/lr_scheduler.py
updated quantization modules to support mmdetection, using Hardtanh for fixed range...
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / optim / lr_scheduler.py
index a7b44609f1749894ab406d2a07ec6ff011ff44d4..2c9c945dccea515d0fcbe7c0b299a2f2aae04023 100644 (file)
@@ -11,7 +11,7 @@ class SchedulerWrapper(torch.optim.lr_scheduler._LRScheduler):
         if scheduler_type == 'step':
             lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=multistep_gamma, last_epoch=start_epoch-1)
         elif scheduler_type == 'poly':
-            lambda_scheduler = lambda iter: ((1.0-iter/max_iter)**polystep_power)
+            lambda_scheduler = lambda last_epoch: ((1.0-last_epoch/epochs)**polystep_power)
             lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_scheduler, last_epoch=start_epoch-1)
         elif scheduler_type == 'cosine':
             lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0, last_epoch=start_epoch-1)