CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
ai-forever

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: ai-forever/sber-swap
Path: blob/main/network/MultiscaleDiscriminator.py
Views: 792
1
import torch.nn as nn
2
import numpy as np
3
4
5
class NLayerDiscriminator(nn.Module):
    """Fully convolutional discriminator producing a one-channel spatial score map.

    Blocks built by ``__init__``, in order:
      1. stride-2 conv (``input_nc -> ndf``) + LeakyReLU(0.2), no normalization;
      2. ``max(n_layers - 1, 0)`` blocks of stride-2 conv + ``norm_layer`` +
         LeakyReLU(0.2); the channel count doubles each time, capped at 512;
      3. one stride-1 conv + ``norm_layer`` + LeakyReLU(0.2), again doubling
         channels (capped at 512);
      4. a stride-1 conv down to a single output channel;
      5. an optional ``nn.Sigmoid`` when ``use_sigmoid`` is true.

    All convs use a 4x4 kernel with padding 2.

    Args:
        input_nc: number of channels of the input tensor.
        ndf: number of channels produced by the first conv.
        n_layers: controls how many stride-2 blocks are stacked (see above).
        norm_layer: normalization module factory, called as ``norm_layer(channels)``.
        use_sigmoid: append a Sigmoid after the final conv.
        getIntermFeat: when true, each block is registered separately as
            ``model0``, ``model1``, ... and ``forward`` returns the output of
            every block; otherwise all layers are fused into ``self.model``.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
        super(NLayerDiscriminator, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))  # == 2 for kw == 4
        sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]

        nf = ndf
        for _ in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)
            sequence.append([
                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
                norm_layer(nf),
                nn.LeakyReLU(0.2, True),
            ])

        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence.append([
            nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True),
        ])

        sequence.append([nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)])

        if use_sigmoid:
            sequence.append([nn.Sigmoid()])

        # Record how many blocks were actually built so forward() visits all
        # of them. A fixed n_layers + 2 would skip the optional Sigmoid and,
        # for n_layers == 0, the final 1-channel conv.
        self._num_blocks = len(sequence)

        if getIntermFeat:
            for n, block in enumerate(sequence):
                setattr(self, 'model' + str(n), nn.Sequential(*block))
        else:
            self.model = nn.Sequential(*[layer for block in sequence for layer in block])

    def forward(self, input):
        """Run the discriminator.

        Args:
            input: tensor fed to the first ``nn.Conv2d``; presumably
                ``(N, input_nc, H, W)``.

        Returns:
            With ``getIntermFeat``, a list holding the output of every
            block in order; the last entry is the final score map.
            Otherwise, the final score map tensor alone.
        """
        if not self.getIntermFeat:
            return self.model(input)

        feats = []
        x = input
        for n in range(self._num_blocks):
            x = getattr(self, 'model' + str(n))(x)
            feats.append(x)
        return feats
55
56
57
class MultiscaleDiscriminator(nn.Module):
    """Apply ``num_D`` independent NLayerDiscriminator networks at several input scales.

    Each scale gets its own ``NLayerDiscriminator`` instance built with the
    same constructor arguments, so each scale has separate weights.

    Sub-modules are registered as follows:
      * ``layer{i}`` (the fused ``nn.Sequential``) when ``getIntermFeat`` is false;
      * ``scale{i}_layer{j}`` (one entry per block) when ``getIntermFeat`` is true.

    Between scales the input is reduced with a 3x3 average pool (stride 2,
    padding 1, ``count_include_pad=False``).

    Args:
        input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat:
            forwarded unchanged to every ``NLayerDiscriminator``.
        num_D: number of scales / discriminators.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
                 use_sigmoid=False, num_D=3, getIntermFeat=False):
        super(MultiscaleDiscriminator, self).__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        self.getIntermFeat = getIntermFeat
        # Number of blocks each per-scale discriminator exposes in
        # intermediate-feature mode. It is discovered from the built network
        # instead of assumed to be n_layers + 2, which would drop the optional
        # Sigmoid block (and the final conv when n_layers == 0).
        self._num_blocks = 0

        for i in range(num_D):
            netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
            if getIntermFeat:
                j = 0
                while hasattr(netD, 'model' + str(j)):
                    setattr(self, 'scale' + str(i) + '_layer' + str(j), getattr(netD, 'model' + str(j)))
                    j += 1
                self._num_blocks = j
            else:
                setattr(self, 'layer' + str(i), netD.model)

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input):
        """Run one scale's discriminator.

        Args:
            model: a list of block modules when ``getIntermFeat`` is true,
                otherwise a single module.
            input: the (possibly downsampled) input tensor.

        Returns:
            A list holding every block's output (intermediate mode), or a
            one-element list with the final output.
        """
        if self.getIntermFeat:
            result = [input]
            for block in model:
                result.append(block(result[-1]))
            return result[1:]
        return [model(input)]

    def forward(self, input):
        """Evaluate all scales.

        The full-resolution input goes to the discriminator registered at
        index ``num_D - 1``. Each following (lower-index) discriminator
        receives the input average-pooled once more.

        Returns:
            A list of ``num_D`` entries, ordered from the finest to the
            coarsest input. Each entry is the list returned by
            ``singleD_forward``.
        """
        num_D = self.num_D
        result = []
        input_downsampled = input
        for i in range(num_D):
            scale = num_D - 1 - i
            if self.getIntermFeat:
                model = [getattr(self, 'scale' + str(scale) + '_layer' + str(j))
                         for j in range(self._num_blocks)]
            else:
                model = getattr(self, 'layer' + str(scale))
            result.append(self.singleD_forward(model, input_downsampled))
            # No need to downsample after the coarsest scale has been evaluated.
            if i != num_D - 1:
                input_downsampled = self.downsample(input_downsampled)
        return result
98
99