]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/xnn/layers/activation.py
d89e6eea496517cbc2c320905e891abf2ba20b15
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / layers / activation.py
1 import numpy as np
2 import torch
4 from .functional import *
5 from .. import utils
8 ###############################################################
9 # Parametric Activation (PACT) with clip values being power of 2
10 # Supports learned mode, estimated mode or fixed range
11 class PAct2(torch.nn.Module):
12     # constants
13     PACT2_RANGE_LEARN = False   # False : Running Avg, True  : Backprop
14     PACT2_RANGE_SHRINK = 0.01   # 0.01
15     PACT2_RANGE_INIT = 8.0      # this is the starting range
16     PACT2_RANGE_EXPANSION = 1.0 # expand the calculated range for margin
18     def __init__(self, inplace=False, signed=None, range_shrink_percentile=PACT2_RANGE_SHRINK, clip_range=None,
19                  power2_activation_range=True, **kwargs):
20         super().__init__()
21         if (clip_range is not None) and (signed is not None):
22             assert signed == (clip_range[0]<0.0), 'the signed flag provided did not match the clip_range provided'
23         #
24         self.inplace = inplace
25         self.clip_range = clip_range
26         self.signed = signed if (clip_range is None) else (clip_range[0]<0.0)
27         self.range_shrink_percentile = range_shrink_percentile # range shrinking factor
28         self.fixed_range = (clip_range is not None)
29         self.learn_range = (self.PACT2_RANGE_LEARN and (not self.fixed_range))
30         self.eps = np.power(2.0, -16.0)
31         self.power2_activation_range = power2_activation_range   # power of 2 ranges
32         self.log_base = None # 2.0  # log is used only in learned mode if log_base is not None
33         self.range_estimator = None
35         # any validation before at-least one iteration of training wll use default clip values.
36         clip_init = max(abs(np.array(clip_range))) if (clip_range is not None) else self.PACT2_RANGE_INIT
37         clip_init2 = np.power(2.0, np.ceil(np.log2(clip_init)))
39         if self.learn_range:
40             clip_signed_log = self.convert_to_log(torch.tensor(clip_init2))
41             default_clips = (-clip_signed_log, clip_signed_log) \
42                 if (self.signed == True or self.signed is None) else (0.0, clip_signed_log)
43             self.register_parameter('clips_act', torch.nn.Parameter(torch.tensor(default_clips, dtype=torch.float32)))
44             # Initially ranges will be dominated by running average, but eventually the update factor becomes too small.
45             # Then the backprop updates will have dominance.
46             self.range_update_factor_min = 0.0
47             self.register_buffer('num_batches_tracked', torch.tensor(-1.0, dtype=torch.float32))
48         else:
49             default_clips = (-clip_init2, clip_init2) \
50                 if (self.signed == True or self.signed is None) else (0.0, clip_init2)
51             self.register_buffer('clips_act', torch.tensor(default_clips, dtype=torch.float32))
52             # range_update_factor_min is the lower bound for exponential update factor.
53             # using 0.0 will freeze the ranges, since the update_factor becomes too small after some time
54             self.range_update_factor_min = 0.001
55             self.register_buffer('num_batches_tracked', torch.tensor(-1.0, dtype=torch.float32))
57             if utils.has_range_estimator:
58                 self.range_estimator = utils.RangeEstimator(range_shrink_percentile=range_shrink_percentile,
59                                                             range_update_factor_min=self.range_update_factor_min)
60             #
61         #
64     def forward(self, x, update_activation_range=True, enable=True):
65         if (self.training and update_activation_range):
66             self.num_batches_tracked += 1
67             # even in learn_range mode - do this for a few iterations to get a good starting point
68             if (not self.fixed_range) and ((not self.learn_range) or (self.num_batches_tracked < 100)):
69                 with torch.no_grad():
70                     self.update_clips_act(x.data)
71                 #
72             #
73         #
74         if not enable:
75             signed = self.clips_act[0] < 0.0 if (self.signed is None) else self.signed
76             y = x if signed else torch.nn.functional.relu(x)
77         else:
78             clips = self.get_clips_act()
79             y = clamp_g(x, clips[0], clips[1], self.training, self.inplace, requires_grad=self.learn_range)
80         #
81         return y
84     def __repr__(self):
85         clips = self.get_clips_act()
86         return '{}(inplace={}, signed={}, clips={})'.format(self.__class__.__name__, self.inplace, self.signed, clips)
89     def convert_to_log(self, x):
90         if (not self.learn_range) or (self.log_base is None):
91             return x
92         #
93         return utils.signed_log(x, self.log_base)
96     def convert_to_linear(self, x):
97         if (not self.learn_range) or (self.log_base is None):
98             return x
99         #
100         return utils.signed_pow(x, self.log_base)
103     def update_clips_act(self, x):
104         if self.learn_range or (self.range_estimator is None):
105             x_min, x_max = utils.extrema_fast(x, range_shrink_percentile=self.range_shrink_percentile)
106             x_min, x_max = self.convert_to_log(x_min*self.PACT2_RANGE_EXPANSION), self.convert_to_log(x_max*self.PACT2_RANGE_EXPANSION)
107             # exponential update factor
108             update_factor = 1.0 / float(self.num_batches_tracked if self.num_batches_tracked else 1.0)
109             update_factor = max(update_factor, self.range_update_factor_min)
110             # exponential moving average update
111             self.clips_act[0].data.mul_(1.0-update_factor).add_(x_min * update_factor)
112             self.clips_act[1].data.mul_(1.0-update_factor).add_(x_max * update_factor)
113         else:
114             mn, mx = self.range_estimator(x)
115             self.clips_act[0].data.fill_(mn)
116             self.clips_act[1].data.fill_(mx)
117         #
120     def get_clips_act(self):
121         # find the clip values
122         signed = self.clips_act[0] < 0.0 if (self.signed is None) else self.signed
123         clip_max = torch.max(torch.abs(self.clips_act)) if signed else torch.abs(self.clips_act[1])
124         clip_max = torch.clamp(clip_max, min=self.eps)
125         clip_max = self.convert_to_linear(clip_max)
126         # in range learning mode + training - this power2_activation_range is taken care in the quantize function
127         is_learning_range = (self.PACT2_RANGE_LEARN and self.training)
128         use_power2_activation_range = (self.power2_activation_range and (not is_learning_range))
129         clip_max2 = ceil2_g(clip_max) if use_power2_activation_range else clip_max
130         clip_min2 = (-clip_max2 if signed else clip_max2*0.0)
131         return (clip_min2, clip_max2)
134 ###############################################################
135 # return a function that creates PAct2 with the given fixed range
136 # remember: this function returns a type and not an instance
137 def get_fixed_pact2_type(inplace=False, signed=None, output_range=None):
138         def FixedPAct2Type(inplace=inplace, signed=signed):
139             assert output_range is not None, 'output_range must be specified for FixedPact2'
140             clip_range = output_range #max(abs(np.array(output_range)))
141             signed = True if ((output_range[0] < 0.0) or (signed is True)) else signed
142             return PAct2(inplace=inplace, signed=signed, clip_range=clip_range)
143         #
144         return FixedPAct2Type
147 ###############################################################
148 # return a derivative of Hardtanh with the given fixed range
149 # remember: this function returns a type and not an instance
150 def get_fixed_hardtanh_type(*args, **kwargs):
151         class FixedHardtanhType(torch.nn.Hardtanh):
152             def __init__(self, *args_, **kwargs_):
153                 super().__init__(*args, **kwargs)
154         #
155         return FixedHardtanhType
158 ###############################################################
159 class ReLU1(torch.nn.Hardtanh):
160     def __init__(self, min_val=0., max_val=1., inplace=False):
161         super().__init__(min_val=min_val, max_val=max_val, inplace=inplace)
164 ###############################################################
165 # Always quantized activation function.
166 # Inserting this activation function is a simple way to ensure quantization happens at certain places.
167 class QAct(torch.nn.Module):
168     def __init__(self, inplace=False, signed=True, **kwargs):
169         super().__init__()
170         self.inplace = inplace
171         self.signed = signed
173     def forward(self, x):
174         return x
177 # Never quantized activation function.
178 # Also if the next block is this, the previous block output is also not quantized.
179 # Inserting this activation function is a simple way to avoid quantization at certain places.
180 class NoQAct(torch.nn.Module):
181     def __init__(self, inplace=False, signed=True, **kwargs):
182         super().__init__()
183         self.inplace = inplace
184         self.signed = signed
186     def forward(self, x):
187         return x