Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
ai-forever

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: ai-forever/sber-swap
Path: blob/main/models/networks/base_network.py
Views: 813
1
"""
2
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4
"""
5
6
import torch.nn as nn
7
from torch.nn import init
8
9
10
class BaseNetwork(nn.Module):
    """Shared base class for the project's networks.

    Provides parameter counting (`print_network`) and a configurable
    weight-initialization scheme (`init_weights`) common to all
    generator/discriminator modules.
    """

    def __init__(self):
        super(BaseNetwork, self).__init__()

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Hook for subclasses to register extra CLI options; default is a no-op."""
        return parser

    def print_network(self):
        """Print the network's class name and its total parameter count in millions."""
        if isinstance(self, list):
            self = self[0]
        total = sum(p.numel() for p in self.parameters())
        print('Network [%s] was created. Total number of parameters: %.1f million. '
              'To see the architecture, do print(network).'
              % (type(self).__name__, total / 1000000))

    def init_weights(self, init_type='normal', gain=0.02):
        """Initialize Conv/Linear/BatchNorm2d weights across the module tree.

        Args:
            init_type: one of 'normal', 'xavier', 'xavier_uniform', 'kaiming',
                'orthogonal', or 'none' (keep each layer's PyTorch default init).
            gain: scale factor for the chosen scheme (std for 'normal',
                gain for 'xavier'/'orthogonal'; unused by some schemes).

        Raises:
            NotImplementedError: if `init_type` is not one of the names above.
        """
        def init_func(m):
            layer_name = m.__class__.__name__
            has_weight = hasattr(m, 'weight')
            if 'BatchNorm2d' in layer_name:
                # Affine batch-norm params: weight ~ N(1, gain), bias = 0.
                if has_weight and m.weight is not None:
                    init.normal_(m.weight.data, 1.0, gain)
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)
            elif has_weight and ('Conv' in layer_name or 'Linear' in layer_name):
                if init_type == 'normal':
                    init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'xavier_uniform':
                    # NOTE(review): upstream fixes gain at 1.0 here, ignoring
                    # the `gain` argument — preserved as-is.
                    init.xavier_uniform_(m.weight.data, gain=1.0)
                elif init_type == 'kaiming':
                    init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    init.orthogonal_(m.weight.data, gain=gain)
                elif init_type == 'none':
                    # Re-run the layer's own default (PyTorch) initialization.
                    m.reset_parameters()
                else:
                    raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)

        self.apply(init_func)

        # Let children with their own init_weights override the generic scheme.
        for child in self.children():
            if hasattr(child, 'init_weights'):
                child.init_weights(init_type, gain)