# models/networks/discriminator.py
"""1Copyright (C) 2019 NVIDIA Corporation. All rights reserved.2Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).3"""45import torch.nn as nn6import numpy as np7import torch, math8import torch.nn.functional as F9from models.networks.base_network import BaseNetwork10from models.networks.normalization import get_nonspade_norm_layer11import utils.inference.util as util121314class MultiscaleDiscriminator(BaseNetwork):15@staticmethod16def modify_commandline_options(parser, is_train):17parser.add_argument('--netD_subarch', type=str, default='n_layer',18help='architecture of each discriminator')19parser.add_argument('--num_D', type=int, default=2,20help='number of discriminators to be used in multiscale')21opt, _ = parser.parse_known_args()2223# define properties of each discriminator of the multiscale discriminator24subnetD = util.find_class_in_module(opt.netD_subarch + 'discriminator',25'models.networks.discriminator')26subnetD.modify_commandline_options(parser, is_train)2728return parser2930def __init__(self, opt):31super().__init__()32self.opt = opt3334for i in range(opt.num_D):35subnetD = self.create_single_discriminator(opt)36self.add_module('discriminator_%d' % i, subnetD)3738def create_single_discriminator(self, opt):39subarch = opt.netD_subarch40if subarch == 'n_layer':41netD = NLayerDiscriminator(opt)42else:43raise ValueError('unrecognized discriminator subarchitecture %s' % subarch)44return netD4546def downsample(self, input):47return F.avg_pool2d(input, kernel_size=3,48stride=2, padding=[1, 1],49count_include_pad=False)5051# Returns list of lists of discriminator outputs.52# The final result is of size opt.num_D x opt.n_layers_D53def forward(self, input):54result = []55get_intermediate_features = not self.opt.no_ganFeat_loss56for name, D in self.named_children():57out = D(input)58if not get_intermediate_features:59out = [out]60result.append(out)61input = self.downsample(input)6263return result646566# Defines the PatchGAN discriminator with the specified arguments.67class NLayerDiscriminator(BaseNetwork):68@staticmethod69def modify_commandline_options(parser, is_train):70parser.add_argument('--n_layers_D', type=int, default=4,71help='# layers in each discriminator')72return parser7374def __init__(self, opt):75super().__init__()76self.opt = opt7778kw = 479padw = int(np.ceil((kw - 1.0) / 2))80nf = opt.ndf81input_nc = self.compute_D_input_nc(opt)8283norm_layer = get_nonspade_norm_layer(opt, opt.norm_D)84sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw),85nn.LeakyReLU(0.2, False)]]8687for n in range(1, opt.n_layers_D):88nf_prev = nf89nf = min(nf * 2, 512)90stride = 1 if n == opt.n_layers_D - 1 else 291sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw,92stride=stride, padding=padw)),93nn.LeakyReLU(0.2, False)94]]9596sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]9798# We divide the layers into groups to extract intermediate layer outputs99for n in range(len(sequence)):100self.add_module('model' + str(n), nn.Sequential(*sequence[n]))101102def compute_D_input_nc(self, opt):103input_nc = opt.label_nc + opt.output_nc104if opt.contain_dontcare_label:105input_nc += 1106if not opt.no_instance:107input_nc += 1108return input_nc109110def forward(self, input):111results = [input]112for submodel in self.children():113intermediate_output = submodel(results[-1])114results.append(intermediate_output)115116get_intermediate_features = not self.opt.no_ganFeat_loss117if 
class ScaledLeakyReLU(nn.Module):
    # LeakyReLU rescaled by sqrt(2) to preserve activation variance, matching
    # the behavior of StyleGAN2's fused activation.
    def __init__(self, negative_slope=0.2):
        super().__init__()

        self.negative_slope = negative_slope

    def forward(self, input):
        out = F.leaky_relu(input, negative_slope=self.negative_slope)

        return out * math.sqrt(2)


def make_kernel(k):
    # Build a normalized 2D FIR kernel from a 1D tap list via outer product.
    k = torch.tensor(k, dtype=torch.float32)

    if k.ndim == 1:
        k = k[None, :] * k[:, None]

    k /= k.sum()

    return k


class Blur(nn.Module):
    # FIR blur applied through upfirdn2d; used for antialiasing before a
    # strided (downsampling) convolution.
    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        kernel = make_kernel(kernel)

        if upsample_factor > 1:
            kernel = kernel * (upsample_factor ** 2)

        self.register_buffer('kernel', kernel)

        self.pad = pad

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, pad=self.pad)

        return out


class EqualConv2d(nn.Module):
    # Conv2d with equalized learning rate: weights are stored at unit variance
    # and rescaled at runtime by 1 / sqrt(fan_in).
    def __init__(
        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
    ):
        super().__init__()

        self.weight = nn.Parameter(
            torch.randn(out_channel, in_channel, kernel_size, kernel_size)
        )
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

        self.stride = stride
        self.padding = padding

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))
        else:
            self.bias = None

    def forward(self, input):
        out = F.conv2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

        return out


class ConvLayer(nn.Sequential):
    # Blur (when downsampling) + equalized conv + activation, in one block.
    def __init__(self, in_channel, out_channel, kernel_size,
                 downsample=False, blur_kernel=[1, 3, 3, 1],
                 bias=True, activate=True):
        layers = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))

            stride = 2
            self.padding = 0
        else:
            stride = 1
            self.padding = kernel_size // 2

        layers.append(
            EqualConv2d(in_channel, out_channel, kernel_size,
                        padding=self.padding, stride=stride,
                        bias=bias and not activate)
        )

        if activate:
            if bias:
                layers.append(FusedLeakyReLU(out_channel))
            else:
                layers.append(ScaledLeakyReLU(0.2))

        super().__init__(*layers)
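
if __name__ == '__main__':
    # Minimal smoke test -- a sketch only -- for the building blocks that need
    # no custom CUDA kernels (Blur and FusedLeakyReLU require the compiled
    # upfirdn2d/fused ops). Shapes and values here are illustrative assumptions.
    conv = EqualConv2d(3, 8, kernel_size=3, padding=1)
    act = ScaledLeakyReLU(0.2)
    x = torch.randn(2, 3, 64, 64)
    print(act(conv(x)).shape)  # expected: torch.Size([2, 8, 64, 64])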