torch.nn.ReLU is the recommended activation module. removed the custom defined module...
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / utils / module_utils.py
1 import torch
2 from .. import layers
5 def is_normalization(module):
6     is_norm = isinstance(module, (torch.nn.BatchNorm2d, layers.DefaultNorm2d,
7                                  torch.nn.GroupNorm, layers.GroupBatchNorm2d))
8     return is_norm
11 def is_activation(module):
12     is_act = isinstance(module, (torch.nn.ReLU, torch.nn.ReLU6, torch.nn.Hardtanh,
13                                  layers.NoAct, layers.PAct2))
14     return is_act
16 def is_pact2(module):
17     is_act = isinstance(module, (layers.PAct2))
18     return is_act
20 def is_conv(module):
21     return isinstance(module, torch.nn.Conv2d)
23 def is_deconv(module):
24     return isinstance(module, torch.nn.ConvTranspose2d)
26 def is_conv_deconv(module):
27     return isinstance(module, (torch.nn.Conv2d,torch.nn.ConvTranspose2d))
29 def is_linear(module):
30     return isinstance(module, torch.nn.Linear)
32 def is_dwconv(module):
33     return is_conv(module) and (module.weight.size(1) == 1)
36 def is_bn(module):
37     return isinstance(module, torch.nn.BatchNorm2d)
40 def get_parent_module(module, m):
41     for p in module.modules():
42         for c in p.children():
43             if c is m:
44                 return p
45     #
46     return None
50 def get_module_name(module, m):
51     for name, mod in module.named_modules():
52         if mod is m:
53             return name
54     #
55     return None
58 def is_tensor(inp):
59     return isinstance(inp, torch.Tensor)
62 def is_none(inp):
63     if is_tensor(inp):
64         return False
65     else:
66         return (inp is None)
69 def is_not_none(inp):
70     return not is_none(inp)
73 def is_list(inp):
74     return isinstance(inp, (list, tuple))
77 def is_not_list(inp):
78     return not is_list(inp)
81 def is_fixed_range(op):
82     return isinstance(op, (torch.nn.ReLU6, torch.nn.Sigmoid, torch.nn.Tanh, torch.nn.Hardtanh, \
83                            layers.PAct2))
86 def get_range(op):
87     if isinstance(op, layers.PAct2):
88         return op.get_clips_act()
89     elif isinstance(op, torch.nn.ReLU6):
90         return 0.0, 6.0
91     elif isinstance(op, torch.nn.Sigmoid):
92         return 0.0, 1.0
93     elif isinstance(op, torch.nn.Tanh):
94         return -1.0, 1.0
95     elif isinstance(op, torch.nn.Hardtanh):
96         return op.min_val, op.max_val
97     else:
98         assert False, 'dont know the range of the module'
101 def add_module_names(model):
102     for n, m in model.named_modules():
103         m.name = n
104     #
105     return model
108 def squeeze_list(inputs):
109     return inputs[0] if (is_list(inputs) and (len(inputs)==1) and is_list(inputs[0])) else inputs
112 def make_list(inputs):
113     return inputs if is_list(inputs) else (inputs,)
116 def apply_setattr(model, always=False, **kwargs):
117     assert len(kwargs) >= 1, 'atlest one keyword argument must be specified. ..=.., in addition always=.. can be specified.'
118     def setattr_func(op):
119         for name, value in kwargs.items():
120             if hasattr(op, name) or always:
121                 setattr(op, name, value)
122             #
123         #
124     #
125     model.apply(setattr_func)