0e9b4b40c5a6905c7d412cfaa90ed1ada1bcabab
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / layers / common_blocks.py
1 import torch
2 from . import functional
4 ###########################################################
5 # add
class AddBlock(torch.nn.Module):
    """Element-wise sum of a list/tuple of inputs.

    ``inplace`` and ``signed`` are only stored as attributes (presumably read
    by quantization tooling elsewhere — they do not affect forward() here).
    """
    def __init__(self, inplace=False, signed=True, *args, **kwargs):
        super().__init__()
        self.inplace = inplace
        self.signed = signed

    def forward(self, x):
        assert isinstance(x, (list,tuple)), 'input to add block must be a list or tuple'
        total = x[0]
        for term in x[1:]:
            total = total + term
        return total

    def __repr__(self):
        return 'AddBlock(inplace={}, signed={})'.format(self.inplace, self.signed)
23 # sub
class SubtractBlock(torch.nn.Module):
    """Subtract every remaining input from the first: x[0] - x[1] - ... - x[n-1].

    ``inplace`` and ``signed`` are only stored as attributes (presumably read
    by quantization tooling elsewhere — they do not affect forward() here).
    """
    def __init__(self, inplace=False, signed=True, *args, **kwargs):
        super().__init__()
        self.inplace = inplace
        self.signed = signed

    def forward(self, x):
        assert isinstance(x, (list,tuple)), 'input to sub block must be a list or tuple'
        result = x[0]
        for operand in x[1:]:
            result = result - operand
        return result

    def __repr__(self):
        return 'SubtractBlock(inplace={}, signed={})'.format(self.inplace, self.signed)
42 ###########################################################
43 # mult
class MultBlock(torch.nn.Module):
    """Element-wise product of a list/tuple of inputs.

    ``inplace`` and ``signed`` are only stored as attributes (presumably read
    by quantization tooling elsewhere — they do not affect forward() here).
    """
    def __init__(self, inplace=False, signed=True, *args, **kwargs):
        super().__init__()
        self.inplace = inplace
        self.signed = signed

    def forward(self, x):
        # fixed assert message: it was copy-pasted from AddBlock and said 'add block'
        assert isinstance(x, (list,tuple)), 'input to mult block must be a list or tuple'
        y = x[0]
        for i in range(1, len(x)):
            y = y * x[i]
        return y

    def __repr__(self):
        return 'MultBlock(inplace={}, signed={})'.format(self.inplace, self.signed)
62 ###########################################################
63 # cat
class CatBlock(torch.nn.Module):
    """Concatenate a list/tuple of tensors along dim=1 (the channel dimension
    for NCHW tensors).

    ``inplace`` and ``signed`` are only stored as attributes (presumably read
    by quantization tooling elsewhere — they do not affect forward() here).
    """
    def __init__(self, inplace=False, signed=True, *args, **kwargs):
        super().__init__()
        self.inplace = inplace
        self.signed = signed

    def forward(self, x):
        # fixed assert message: it was copy-pasted from AddBlock and said 'add block'
        assert isinstance(x, (list,tuple)), 'input to cat block must be a list or tuple'
        y = torch.cat(x, dim=1)
        return y

    def __repr__(self):
        return 'CatBlock(inplace={}, signed={})'.format(self.inplace, self.signed)
79 ###########################################################
80 # moving sum
class MovingSumBlock(torch.nn.Module):
    """Stateful block: returns the sum of the current and previous inputs.

    Keeps the last input seen in ``self.prev_x`` across forward() calls;
    the state is never reset by this class.
    """
    def __init__(self):
        super().__init__()
        # 0 so that the very first call returns its input unchanged
        self.prev_x = 0

    def forward(self, x):
        out = self.prev_x + x
        self.prev_x = x
        return out
92 ###########################################################
# A bypass block that does nothing - can be used as a placeholder.
# Alias of torch.nn.Identity: forward(x) returns x unchanged.
BypassBlock = torch.nn.Identity
98 ###########################################################
# Convert to a flattened view that can be given to a fully connected layer.
# Alias of torch.nn.Flatten (its default start_dim=1 keeps the batch dimension).
ViewAsLinear = torch.nn.Flatten
103 ###########################################################
104 # Split channel-wise and return the first part
class SplitChannelsTakeFirst(torch.nn.Module):
    """Split the input channel-wise into ``splits`` chunks and keep only the first."""
    def __init__(self, splits=2):
        super().__init__()
        self.splits = splits

    def forward(self, x):
        # chunking is delegated to the project helper; we only keep chunk 0
        first, *_rest = functional.channel_split_by_chunks(x, self.splits)
        return first

    def __repr__(self):
        return 'SplitChannelsTakeFirst(splits={})'.format(self.splits)
118 ###########################################################
# Split channel-wise and return the last part
class SplitChannelsTakeLast(torch.nn.Module):
    """Split the input channel-wise into ``splits`` chunks and keep only the last."""
    def __init__(self, splits=2):
        super().__init__()
        self.splits = splits

    def forward(self, x):
        # chunking is delegated to the project helper; we only keep the final chunk
        *_rest, last = functional.channel_split_by_chunks(x, self.splits)
        return last

    def __repr__(self):
        return 'SplitChannelsTakeLast(splits={})'.format(self.splits)
133 ###########################################################
134 # Split channel-wise and add the parts
class SplitChannelsAdd(torch.nn.Module):
    """Split the input channel-wise into ``splits`` chunks and return their
    element-wise sum."""
    def __init__(self, splits=2):
        super().__init__()
        self.splits = splits

    def forward(self, x):
        parts = functional.channel_split_by_chunks(x, self.splits)
        # renamed accumulator: the original shadowed the builtin sum()
        total = parts[0]
        for part in parts[1:]:
            total = total + part
        return total

    def __repr__(self):
        return 'SplitChannelsAdd(splits={})'.format(self.splits)
152 ###########################################################
# Take the first element of a list/tuple input
class SplitListTakeFirst(torch.nn.Module):
    """Return the first element of a list/tuple input, discarding the rest."""
    def __init__(self):
        super().__init__()

    def forward(self, x):
        first_item = x[0]
        return first_item

    def __repr__(self):
        return 'SplitListTakeFirst()'
165 ###############################################################
# Parallel as opposed to Sequential
class ParallelBlock(torch.nn.Module):
    """Run two modules side by side (as opposed to Sequential): the input is
    split channel-wise into one chunk per module, each module processes its
    chunk, and the outputs are concatenated back along dim=1."""
    def __init__(self, *args):
        super().__init__()
        assert (len(args)==2), 'for now supporting only two modules in parallel'
        #store it as modulelist for cuda() to work
        self.blocks = torch.nn.ModuleList(args)
        self.split = len(self.blocks)

    def forward(self, x):
        chunks = functional.channel_split_by_chunks(x, self.split)
        outputs = [blk(chunk) for blk, chunk in zip(self.blocks, chunks)]
        return torch.cat(outputs, dim=1)
186 ###############################################################
class ShuffleBlock(torch.nn.Module):
    """Shuffle channels via functional.channel_shuffle when groups > 1;
    pass the input through unchanged otherwise."""
    def __init__(self, groups):
        super().__init__()
        self.groups = groups

    def forward(self, x):
        # guard clause: a single group means there is nothing to shuffle
        if self.groups <= 1:
            return x
        return functional.channel_shuffle(x, groups=self.groups)