Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Path: blob/main/models/networks/base_network.py
Views: 813
"""1Copyright (C) 2019 NVIDIA Corporation. All rights reserved.2Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).3"""45import torch.nn as nn6from torch.nn import init789class BaseNetwork(nn.Module):10def __init__(self):11super(BaseNetwork, self).__init__()1213@staticmethod14def modify_commandline_options(parser, is_train):15return parser1617def print_network(self):18if isinstance(self, list):19self = self[0]20num_params = 021for param in self.parameters():22num_params += param.numel()23print('Network [%s] was created. Total number of parameters: %.1f million. '24'To see the architecture, do print(network).'25% (type(self).__name__, num_params / 1000000))2627def init_weights(self, init_type='normal', gain=0.02):28def init_func(m):29classname = m.__class__.__name__30if classname.find('BatchNorm2d') != -1:31if hasattr(m, 'weight') and m.weight is not None:32init.normal_(m.weight.data, 1.0, gain)33if hasattr(m, 'bias') and m.bias is not None:34init.constant_(m.bias.data, 0.0)35elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):36if init_type == 'normal':37init.normal_(m.weight.data, 0.0, gain)38elif init_type == 'xavier':39init.xavier_normal_(m.weight.data, gain=gain)40elif init_type == 'xavier_uniform':41init.xavier_uniform_(m.weight.data, gain=1.0)42elif init_type == 'kaiming':43init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')44elif init_type == 'orthogonal':45init.orthogonal_(m.weight.data, gain=gain)46elif init_type == 'none': # uses pytorch's default init method47m.reset_parameters()48else:49raise NotImplementedError('initialization method [%s] is not implemented' % init_type)50if hasattr(m, 'bias') and m.bias is not None:51init.constant_(m.bias.data, 0.0)5253self.apply(init_func)5455# propagate to children56for m in self.children():57if hasattr(m, 'init_weights'):58m.init_weights(init_type, gain)596061