# modules/pytorch_jacinto_ai/xnn/optim/lr_scheduler.py
import numpy as np
import torch

class SchedulerWrapper(torch.optim.lr_scheduler._LRScheduler):
    # note: the base class __init__ is not called here; the wrapper keeps all
    # scheduling state in the underlying scheduler stored in self.lr_scheduler
    def __init__(self, scheduler_type, optimizer, epochs, start_epoch=0, warmup_epochs=5, max_iter=None,
                 polystep_power=1.0, milestones=None, multistep_gamma=0.5):
        self.scheduler_type = scheduler_type
        self.warmup_epochs = warmup_epochs

        # create the underlying scheduler that takes over after the warmup phase
        if scheduler_type == 'step':
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=multistep_gamma, last_epoch=start_epoch-1)
        elif scheduler_type == 'poly':
            # polynomial decay from the base lr towards 0 over max_iter steps
            lambda_scheduler = lambda iter: ((1.0 - iter/max_iter) ** polystep_power)
            lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_scheduler, last_epoch=start_epoch-1)
        elif scheduler_type == 'cosine':
            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0, last_epoch=start_epoch-1)
        else:
            raise ValueError('Unknown scheduler {}'.format(scheduler_type))
        #
        self.lr_scheduler = lr_scheduler
        if start_epoch > 0:
            # advance through the skipped epochs when resuming training
            for _ in range(start_epoch):
                self.step()
        else:
            # to take care of first iteration and set warmup lr in param_group
            self.get_lr()
        #

    def get_lr(self):
        epoch = self.lr_scheduler.last_epoch + 1
        if self.warmup_epochs and epoch <= self.warmup_epochs:
            # linear warmup: scale each base lr by epoch/warmup_epochs
            lr = [(epoch * l_rate / self.warmup_epochs) for l_rate in self.lr_scheduler.base_lrs]
        else:
            # after warmup, report the lr set by the underlying scheduler
            lr = [param_group['lr'] for param_group in self.lr_scheduler.optimizer.param_groups]
        #
        # clamp the lr values to be non-negative and write them back to the optimizer
        lr = [max(l_rate, 0.0) for l_rate in lr]
        for param_group, l_rate in zip(self.lr_scheduler.optimizer.param_groups, lr):
            param_group['lr'] = l_rate
        #
        return lr

    def step(self):
        # advance the underlying scheduler, then apply warmup/clamping via get_lr()
        self.lr_scheduler.step()
        return self.get_lr()

    def load_state_dict(self, state):
        self.lr_scheduler.load_state_dict(state)

    def state_dict(self):
        return self.lr_scheduler.state_dict()
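

# ---------------------------------------------------------------------------
# Example usage (illustrative sketch only, not part of the original module).
# The model, learning rate, and epoch count below are assumptions chosen for
# the example; the wrapper itself only needs an optimizer and the schedule
# arguments shown in __init__ above.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = torch.nn.Linear(10, 2)                              # hypothetical model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    num_epochs = 30
    scheduler = SchedulerWrapper('cosine', optimizer, epochs=num_epochs, warmup_epochs=5)
    for epoch in range(num_epochs):
        # ... run one epoch of training with optimizer.step() calls here ...
        lr = scheduler.step()                                   # advance the schedule; returns the per-param-group lrs
        print('epoch {}: lr = {}'.format(epoch, lr))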