]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/xnn/layers/activation.py
21c94c01668b3db7d19928e3f300e67cb04d1229
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / layers / activation.py
1 import numpy as np
2 import torch
4 from .functional import *
5 from .. import utils
8 ###############################################################
9 # Parametric Activation (PACT) with clip values being power of 2
10 # Supports learned mode, estimated mode or fixed range
11 class PAct2(torch.nn.Module):
12     # constants
13     PACT2_RANGE_LEARN = False   # False : Running Avg, True  : Backprop
14     PACT2_RANGE_SHRINK = 0.01   # 0.01
15     PACT2_RANGE_INIT = 8.0      # this is the starting range
16     PACT2_RANGE_EXPANSION = 1.1 # expand the calculated range for margin
18     def __init__(self, inplace=False, signed=None, percentile_range_shrink=PACT2_RANGE_SHRINK, clip_range=None, **kwargs):
19         super().__init__()
20         if (clip_range is not None) and (signed is not None):
21             assert signed == (clip_range[0]<0.0), 'the signed flag provided did not match the clip_range provided'
22         #
23         self.inplace = inplace
24         self.clip_range = clip_range
25         self.signed = signed if (clip_range is None) else (clip_range[0]<0.0)
26         self.percentile_range_shrink = percentile_range_shrink # range shrinking factor
27         self.fixed_range = (clip_range is not None)
28         self.learn_range = (self.PACT2_RANGE_LEARN and (not self.fixed_range))
29         self.eps = np.power(2.0, -16.0)
30         self.power2 = True   # power of 2 ranges
31         self.log_base = None # 2.0  # log is used only in learned mode if log_base is not None
33         # any validation before at-least one iteration of training wll use default clip values.
34         clip_init = max(abs(np.array(clip_range))) if (clip_range is not None) else self.PACT2_RANGE_INIT
35         clip_init2 = np.power(2.0, np.ceil(np.log2(clip_init)))
37         if self.learn_range:
38             clip_signed_log = self.convert_to_log(torch.tensor(clip_init2))
39             default_clips = (-clip_signed_log, clip_signed_log) \
40                 if (self.signed == True or self.signed is None) else (0.0, clip_signed_log)
41             self.register_parameter('clips_act', torch.nn.Parameter(torch.tensor(default_clips)))
42             # Initially ranges will be dominated by running average, but eventually the update factor becomes too small.
43             # Then the backprop updates will have dominance.
44             self.range_update_factor_min = 0.0
45             self.register_buffer('num_batches_tracked', torch.tensor(-1.0))
46         else:
47             default_clips = (-clip_init2, clip_init2) \
48                 if (self.signed == True or self.signed is None) else (0.0, clip_init2)
49             self.register_buffer('clips_act', torch.tensor(default_clips))
50             # range_update_factor_min is the lower bound for exponential update factor.
51             # using 0.0 will freeze the ranges, since the update_factor becomes too small after some time
52             self.range_update_factor_min = 0.001
53             self.register_buffer('num_batches_tracked', torch.tensor(-1.0))
54         #
57     def forward(self, x, update_range=True, enable=True):
58         if (self.training and update_range):
59             self.num_batches_tracked += 1
60             # even in learn_range mode - do this for a few iterations to get a good starting point
61             if (not self.fixed_range) and ((not self.learn_range) or (self.num_batches_tracked < 100)):
62                 with torch.no_grad():
63                     self.update_scale_act(x.data)
64                 #
65             #
66         #
67         if not enable:
68             signed = self.clips_act[0] < 0.0 if (self.signed is None) else self.signed
69             return x if signed else torch.nn.functional.relu(x)
70         #
71         clips = self.get_clips_act()
72         y = clamp_g(x, clips[0], clips[1], self.training, self.inplace, requires_grad=self.learn_range)
73         return y
76     def __repr__(self):
77         clips = self.get_clips_act()
78         return '{}(inplace={}, signed={}, clips={})'.format(self.__class__.__name__, self.inplace, self.signed, clips)
81     def convert_to_log(self, x):
82         if (not self.learn_range) or (self.log_base is None):
83             return x
84         #
85         return utils.signed_log(x, self.log_base)
88     def convert_to_linear(self, x):
89         if (not self.learn_range) or (self.log_base is None):
90             return x
91         #
92         return utils.signed_pow(x, self.log_base)
95     def update_scale_act(self, x):
96         # compute the new scale
97         x_min, x_max = utils.extrema_fast(x, percentile_range_shrink=self.percentile_range_shrink)
98         x_min, x_max = self.convert_to_log(x_min*self.PACT2_RANGE_EXPANSION), self.convert_to_log(x_max*self.PACT2_RANGE_EXPANSION)
99         # exponential update factor
100         update_factor = 1.0 / float(self.num_batches_tracked if self.num_batches_tracked else 1.0)
101         update_factor = max(update_factor, self.range_update_factor_min)
102         # exponential moving average update
103         self.clips_act[0].data.mul_(1.0-update_factor).add_(x_min * update_factor)
104         self.clips_act[1].data.mul_(1.0-update_factor).add_(x_max * update_factor)
107     def get_clips_act(self):
108         # find the clip values
109         signed = self.clips_act[0] < 0.0 if (self.signed is None) else self.signed
110         clip_max = torch.max(torch.abs(self.clips_act)) if signed else torch.abs(self.clips_act[1])
111         clip_max = torch.clamp(clip_max, min=self.eps)
112         clip_max = self.convert_to_linear(clip_max)
113         # in range learning mode + training - this power2 is taken care in the quantize function
114         use_power2 = (self.power2 and (not (self.PACT2_RANGE_LEARN and self.training)))
115         clip_max2 = ceil2_g(clip_max) if use_power2 else clip_max
116         clip_min2 = (-clip_max2 if signed else clip_max2*0.0)
117         return (clip_min2, clip_max2)
120 ###############################################################
121 # return a function that creates PAct2 with the given fixed range
122 def get_fixed_pact2(inplace=False, signed=None, output_range=None):
123         def FixedPAct2Creator(inplace=inplace, signed=signed):
124             assert output_range is not None, 'output_range must be specified for FixedPact2'
125             clip_range = output_range #max(abs(np.array(output_range)))
126             signed = True if ((output_range[0] < 0.0) or (signed is True)) else signed
127             return PAct2(inplace=inplace, signed=signed, clip_range=clip_range)
128         #
129         return FixedPAct2Creator
132 ###############################################################
133 class ReLUN(torch.nn.Module):
134     def __init__(self, inplace=False, signed=False, clips=None, **kwargs):
135         super().__init__()
136         self.clips_act = clips
137         self.inplace = inplace
138         self.signed = signed
140     def forward(self, x):
141         y = torch.clamp(x, 0.0, self.clips_act)
142         return y
144     def get_clips_act(self):
145         return 0.0, self.clips_act
147     def __repr__(self):
148         return '{}(inplace={}, signed={}, clips={})'.format(self.__class__.__name__, self.inplace, self.signed, self.clips_act)
151 ###############################################################
152 class ReLU8(ReLUN):
153     def __init__(self, inplace=False, signed=False, **kwargs):
154         super().__init__(inplace, signed, (0,8.0))
157 ###############################################################
158 class NoAct(torch.nn.Module):
159     def __init__(self, inplace=False, signed=True, **kwargs):
160         super().__init__()
161         self.inplace = inplace
162         self.signed = signed
164     def forward(self, x):
165         return x