Real-time collaboration for Jupyter notebooks, Linux terminals, LaTeX, VS Code, the R IDE, and more,
all in one place. A commercial alternative to JupyterHub.
Real-time collaboration for Jupyter notebooks, Linux terminals, LaTeX, VS Code, the R IDE, and more,
all in one place. A commercial alternative to JupyterHub.
Path: blob/main/network/MultiscaleDiscriminator.py
Views: 792
import torch.nn as nn
import numpy as np


class NLayerDiscriminator(nn.Module):
    """Fully convolutional discriminator producing a one-channel score map.

    The network is assembled as a list of *stages*, each a list of layers:

    * a stride-2 conv (``input_nc -> ndf``) followed by LeakyReLU (no norm);
    * ``n_layers - 1`` further stride-2 stages (none when ``n_layers <= 1``),
      each conv doubling the channel count (capped at 512) followed by
      ``norm_layer`` and LeakyReLU;
    * one stride-1 stage of the same form (channels doubled, capped at 512);
    * a stride-1 conv down to a single output channel;
    * an ``nn.Sigmoid`` stage, only when ``use_sigmoid`` is True.

    All convs use a 4x4 kernel.

    Args:
        input_nc: number of channels of the input tensor.
        ndf: number of filters in the first conv.
        n_layers: number of stride-2 conv stages (at least one is always built).
        norm_layer: factory called with a channel count to build a norm layer.
        use_sigmoid: if True, append a sigmoid to the final output.
        getIntermFeat: if True, each stage is registered as its own
            ``nn.Sequential`` attribute ``model<k>`` and ``forward`` returns
            every stage's output. Otherwise all layers are chained into a
            single ``self.model`` and ``forward`` returns only the final output.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
                 use_sigmoid=False, getIntermFeat=False):
        super().__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                     nn.LeakyReLU(0.2, True)]]

        nf = ndf
        for _ in range(1, n_layers):
            nf_prev, nf = nf, min(nf * 2, 512)
            sequence.append([
                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
                norm_layer(nf),
                nn.LeakyReLU(0.2, True),
            ])

        nf_prev, nf = nf, min(nf * 2, 512)
        sequence.append([
            nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True),
        ])

        sequence.append([nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)])

        if use_sigmoid:
            sequence.append([nn.Sigmoid()])

        # The true number of stages built above. A hard-coded `n_layers + 2`
        # skips the Sigmoid stage when use_sigmoid=True and the final conv when
        # n_layers == 0, so every consumer must iterate over this count instead.
        self.num_stages = len(sequence)

        if getIntermFeat:
            for n, stage in enumerate(sequence):
                setattr(self, 'model' + str(n), nn.Sequential(*stage))
        else:
            # Flatten stages in order so `self.model` indices/state_dict keys
            # match the stage-by-stage layer order.
            self.model = nn.Sequential(*[layer for stage in sequence for layer in stage])

    def forward(self, input):
        """Run the discriminator.

        Args:
            input: image batch; assumed (N, input_nc, H, W) — TODO confirm
                against callers.

        Returns:
            With ``getIntermFeat``: a list holding the output of every stage
            in order (the last entry is the final score map). Otherwise: the
            final score map tensor.
        """
        if not self.getIntermFeat:
            return self.model(input)

        feats = []
        x = input
        for n in range(self.num_stages):
            x = getattr(self, 'model' + str(n))(x)
            feats.append(x)
        return feats


class MultiscaleDiscriminator(nn.Module):
    """``num_D`` identically configured NLayerDiscriminators applied at multiple scales.

    Sub-discriminator ``i`` is stored as ``layer<i>`` (flat mode) or as
    ``scale<i>_layer<j>`` per stage (``getIntermFeat`` mode). Between scales
    the input is reduced with a 3x3, stride-2 average pool whose padding is
    excluded from the average.

    Args:
        input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat:
            forwarded to every NLayerDiscriminator.
        num_D: number of sub-discriminators / scales.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
                 use_sigmoid=False, num_D=3, getIntermFeat=False):
        super().__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        self.getIntermFeat = getIntermFeat
        # Stage count per sub-discriminator; all are built identically, so the
        # value taken from any of them is valid (stays 0 only when num_D == 0).
        self.num_stages = 0

        for i in range(num_D):
            netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
            self.num_stages = netD.num_stages
            if getIntermFeat:
                # Re-register each stage directly on this module, including the
                # Sigmoid stage when present.
                for j in range(netD.num_stages):
                    setattr(self, 'scale' + str(i) + '_layer' + str(j), getattr(netD, 'model' + str(j)))
            else:
                setattr(self, 'layer' + str(i), netD.model)

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input):
        """Apply one sub-discriminator.

        Args:
            model: in ``getIntermFeat`` mode, a sequence of stage modules
                applied in order; otherwise a single module.
            input: the (possibly downsampled) input tensor.

        Returns:
            A list: every stage output in ``getIntermFeat`` mode, otherwise a
            one-element list holding the final output.
        """
        if not self.getIntermFeat:
            return [model(input)]

        feats = []
        x = input
        for stage in model:
            x = stage(x)
            feats.append(x)
        return feats

    def forward(self, input):
        """Run all sub-discriminators.

        Returns:
            A list of length ``num_D``. Entry ``k`` is the ``singleD_forward``
            result of sub-discriminator ``num_D - 1 - k`` on the input after
            ``k`` downsampling passes. The highest-indexed sub-discriminator
            therefore sees the full-resolution input.
        """
        num_D = self.num_D
        result = []
        input_downsampled = input
        for i in range(num_D):
            scale = num_D - 1 - i
            if self.getIntermFeat:
                model = [getattr(self, 'scale' + str(scale) + '_layer' + str(j))
                         for j in range(self.num_stages)]
            else:
                model = getattr(self, 'layer' + str(scale))
            result.append(self.singleD_forward(model, input_downsampled))
            # No need to downsample after the coarsest scale.
            if i != num_D - 1:
                input_downsampled = self.downsample(input_downsampled)
        return result