0e0f7b0e955420a08abdcab15d3400d77bb3fb14
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / utils / module_utils.py
1 import torch
2 from .. import layers
def is_normalization(module):
    """Return True if module is one of the supported 2d normalization layers."""
    norm_types = (torch.nn.BatchNorm2d, layers.DefaultNorm2d,
                  torch.nn.GroupNorm, layers.GroupBatchNorm2d)
    return isinstance(module, norm_types)
def is_activation(module):
    """Return True if module is one of the recognized activation layers."""
    act_types = (torch.nn.ReLU, torch.nn.ReLU6, layers.NoAct,
                 layers.PAct2, layers.ReLUN)
    return isinstance(module, act_types)
def is_pact(module):
    """Return True if module is a PAct2 parametric activation layer.

    Note: the original wrapped the class in parentheses without a comma,
    which is not a tuple — a single class is passed directly for clarity.
    """
    return isinstance(module, layers.PAct2)
def is_conv(module):
    """Return True if module is a 2d convolution layer."""
    conv_type = torch.nn.Conv2d
    return isinstance(module, conv_type)
def is_deconv(module):
    """Return True if module is a 2d transposed (de-)convolution layer."""
    deconv_type = torch.nn.ConvTranspose2d
    return isinstance(module, deconv_type)
def is_conv_deconv(module):
    """Return True if module is either a 2d convolution or a 2d transposed convolution."""
    conv_types = (torch.nn.Conv2d, torch.nn.ConvTranspose2d)
    return isinstance(module, conv_types)
def is_linear(module):
    """Return True if module is a fully-connected (Linear) layer."""
    linear_type = torch.nn.Linear
    return isinstance(module, linear_type)
def is_dwconv(module):
    """Return True if module is a depthwise 2d convolution.

    A depthwise conv has one input channel per group, i.e. the weight's
    second dimension (in_channels // groups) is 1.
    """
    if not isinstance(module, torch.nn.Conv2d):
        return False
    return module.weight.size(1) == 1
def is_bn(module):
    """Return True if module is a 2d batch normalization layer."""
    bn_type = torch.nn.BatchNorm2d
    return isinstance(module, bn_type)
def get_parent_module(module, m):
    """Find the direct parent of submodule m within module's tree.

    Returns the first module (in modules() traversal order) that has m as
    an immediate child, or None if m is not found (or is the root itself).
    """
    parents = (candidate for candidate in module.modules()
               if any(child is m for child in candidate.children()))
    return next(parents, None)
def get_module_name(module, m):
    """Return the qualified name of submodule m inside module, or None.

    The root module itself has the empty string as its name.
    """
    names = (name for name, candidate in module.named_modules() if candidate is m)
    return next(names, None)
def is_tensor(inp):
    """Return True if inp is a torch Tensor."""
    tensor_type = torch.Tensor
    return isinstance(inp, tensor_type)
def is_none(inp):
    """Return True if inp is None.

    Tensors are checked first and always count as not-None, avoiding any
    tensor comparison semantics.
    """
    if isinstance(inp, torch.Tensor):
        return False
    return inp is None
def is_not_none(inp):
    """Return True if inp is anything other than None; tensors always qualify."""
    if isinstance(inp, torch.Tensor):
        return True
    return inp is not None
def is_list(inp):
    """Return True if inp is a list or a tuple."""
    sequence_types = (list, tuple)
    return isinstance(inp, sequence_types)
def is_not_list(inp):
    """Return True if inp is neither a list nor a tuple."""
    return not isinstance(inp, (list, tuple))
def is_fixed_range(op):
    """Return True if op is an activation with a known, bounded output range."""
    fixed_range_types = (torch.nn.ReLU6, torch.nn.Sigmoid, torch.nn.Tanh,
                         torch.nn.Hardtanh, layers.PAct2, layers.ReLUN)
    return isinstance(op, fixed_range_types)
def get_range(op):
    """Return the (min, max) output range of a fixed-range activation module.

    Raises AssertionError if op is not one of the modules handled here
    (see is_fixed_range for the supported set).
    """
    if isinstance(op, layers.PAct2):
        return op.get_clips_act()
    elif isinstance(op, layers.ReLUN):
        # bugfix: ReLUN is defined in the local layers package; the original
        # referenced torch.nn.ReLUN, which does not exist and raised
        # AttributeError whenever this branch was reached.
        return op.get_clips_act()
    elif isinstance(op, torch.nn.ReLU6):
        # ReLU6 subclasses Hardtanh, so it must be checked before Hardtanh
        return 0.0, 6.0
    elif isinstance(op, torch.nn.Sigmoid):
        return 0.0, 1.0
    elif isinstance(op, torch.nn.Hardtanh):
        # use the module's configured clip values (defaults are -1.0, 1.0)
        return op.min_val, op.max_val
    elif isinstance(op, torch.nn.Tanh):
        return -1.0, 1.0
    else:
        assert False, 'dont know the range of the module'
def add_module_names(model):
    """Attach each submodule's qualified name to it as a .name attribute.

    The root module gets the empty string. Returns the same model object.
    """
    for module_name, submodule in model.named_modules():
        submodule.name = module_name
    return model
def squeeze_list(inputs):
    """Unwrap a single-element list/tuple whose only element is itself a list/tuple.

    Anything else is returned unchanged.
    """
    if isinstance(inputs, (list, tuple)) and len(inputs) == 1:
        if isinstance(inputs[0], (list, tuple)):
            return inputs[0]
    return inputs
def make_list(inputs):
    """Return inputs unchanged if it is a list/tuple, else wrap it in a 1-tuple."""
    if isinstance(inputs, (list, tuple)):
        return inputs
    return (inputs,)
def apply_setattr(model, always=False, **kwargs):
    """Set the given attribute(s) on every submodule of model (including model itself).

    By default an attribute is only assigned on submodules that already have
    it; pass always=True to assign it unconditionally on every submodule.
    At least one attribute keyword argument must be given.
    """
    # typo fix: the original message read 'atlest one keyword argument ...'
    assert len(kwargs) >= 1, 'at least one keyword argument must be specified. ..=.., in addition always=.. can be specified.'
    def setattr_func(op):
        for name, value in kwargs.items():
            # only overwrite attributes that already exist, unless always is set
            if hasattr(op, name) or always:
                setattr(op, name, value)
    model.apply(setattr_func)