]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blobdiff - modules/pytorch_jacinto_ai/xnn/optim/lr_scheduler.py
docs - added deprecation notice
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / optim / lr_scheduler.py
index 2c9c945dccea515d0fcbe7c0b299a2f2aae04023..e67a5320b54b5644e3901464598c3fb4c09d1108 100644 (file)
@@ -1,12 +1,79 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
 import numpy as np
 import torch
 
+
+class MultiStepLRWarmup(torch.optim.lr_scheduler.MultiStepLR):
+    def __init__(self, *args, warmup_epochs=5, warmup_ratio=1e-2, **kwargs):
+        self.warmup_epochs = warmup_epochs
+        self.warmup_ratio = warmup_ratio
+        super().__init__(*args, **kwargs)
+
+    def get_lr(self):
+        if self.last_epoch == 0:
+            return [lr * self.warmup_ratio for lr in self.base_lrs]
+        elif self.last_epoch < self.warmup_epochs:
+            return [lr * self.last_epoch / self.warmup_epochs for lr in self.base_lrs]
+        elif self.last_epoch == self.warmup_epochs:
+            return self._get_closed_form_lr()
+        else:
+            return super().get_lr()
+
+
+class CosineAnnealingLRWarmup(torch.optim.lr_scheduler.CosineAnnealingLR):
+    def __init__(self, *args, warmup_epochs=5, warmup_ratio=1e-2, **kwargs):
+        self.warmup_epochs = warmup_epochs
+        self.warmup_ratio = warmup_ratio
+        super().__init__(*args, **kwargs)
+
+    def get_lr(self):
+        if self.last_epoch == 0:
+            return [lr * self.warmup_ratio for lr in self.base_lrs]
+        elif self.last_epoch < self.warmup_epochs:
+            return [lr * self.last_epoch / self.warmup_epochs for lr in self.base_lrs]
+        elif self.last_epoch == self.warmup_epochs and hasattr(self, "_get_closed_form_lr"):
+            return self._get_closed_form_lr()
+        else:
+            return super().get_lr()
+
+
 class SchedulerWrapper(torch.optim.lr_scheduler._LRScheduler):
-    def __init__(self, scheduler_type, optimizer, epochs, start_epoch=0, warmup_epochs=5, max_iter=None, \
-                 polystep_power=1.0, milestones=None, multistep_gamma=0.5):
+    def __init__(self, scheduler_type, optimizer, epochs, start_epoch=0, warmup_epochs=5, warmup_factor=None,
+                 max_iter=None, polystep_power=1.0, milestones=None, multistep_gamma=0.5):
 
         self.scheduler_type = scheduler_type
         self.warmup_epochs = warmup_epochs
+        self.warmup_factor = warmup_factor
 
         if scheduler_type == 'step':
             lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=multistep_gamma, last_epoch=start_epoch-1)
@@ -19,9 +86,12 @@ class SchedulerWrapper(torch.optim.lr_scheduler._LRScheduler):
             raise ValueError('Unknown scheduler {}'.format(scheduler_type))
         #
         self.lr_scheduler = lr_scheduler
+        self.lr_backup = [param_group['lr'] for param_group in self.lr_scheduler.optimizer.param_groups]
         if start_epoch > 0:
+            # adjust the leraning rate to that of the start_epoch
             for step in range(start_epoch):
                 self.step()
+            #
         else:
             # to take care of first iteration and set warmup lr in param_group
             self.get_lr()
@@ -29,9 +99,13 @@ class SchedulerWrapper(torch.optim.lr_scheduler._LRScheduler):
 
 
     def get_lr(self):
-        epoch = self.lr_scheduler.last_epoch + 1
+        epoch = self.lr_scheduler.last_epoch
         if self.warmup_epochs and epoch <= self.warmup_epochs:
             lr = [(epoch * l_rate / self.warmup_epochs) for l_rate in self.lr_scheduler.base_lrs]
+            if epoch == 0 and self.warmup_factor is not None:
+                warmup_lr = [w_rate*self.warmup_factor for w_rate in self.lr_scheduler.base_lrs]
+                lr = [max(l_rate, w_rate) for l_rate, w_rate in zip(lr,warmup_lr)]
+            #
         else:
             lr = [param_group['lr'] for param_group in self.lr_scheduler.optimizer.param_groups]
         #
@@ -43,7 +117,16 @@ class SchedulerWrapper(torch.optim.lr_scheduler._LRScheduler):
 
 
     def step(self):
+        # some of the scheduler implementations in torch.option may be recursive (depends on previous lr) eg. cosine
+        # so it is necessary to restore the original lr from scheduler
+        for param_group, l_rate in zip(self.lr_scheduler.optimizer.param_groups, self.lr_backup):
+            param_group['lr'] = l_rate
+        #
+        # actual scheduler call
         self.lr_scheduler.step()
+        # backup the lr from scheduler
+        self.lr_backup = [param_group['lr'] for param_group in self.lr_scheduler.optimizer.param_groups]
+        # return the lr - warmup will be applied in this step
         return self.get_lr()