CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
ai-forever

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: ai-forever/sber-swap
Path: blob/main/models/pix2pix_model.py
Views: 792
1
"""
2
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4
"""
5
6
import torch
7
import models.networks as networks
8
import utils.inference.util as util
9
import random
10
11
12
class Pix2PixModel(torch.nn.Module):
    """Conditional image-generation GAN wrapper.

    Holds a generator ``netG``, a discriminator ``netD`` (built only when
    ``opt.isTrain``) and an encoder ``netE`` (built only when
    ``opt.use_vae``), together with the losses used to train them.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Let ``networks`` register its command-line options on ``parser``."""
        networks.modify_commandline_options(parser, is_train)
        return parser

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        # Legacy tensor constructors: CUDA variants when any GPU id is configured.
        self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() \
            else torch.FloatTensor
        self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() \
            else torch.ByteTensor

        self.netG, self.netD, self.netE = self.initialize_networks(opt)

        # Loss functions are only needed for training.
        if opt.isTrain:
            self.criterionGAN = networks.GANLoss(
                opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
            if opt.use_vae:
                self.KLDLoss = networks.KLDLoss()

    def forward(self, data, mode):
        """Single entry point for every computation, dispatched on ``mode``.

        Everything goes through ``forward`` because DataParallel can only
        parallelize ``forward``, not custom methods.

        Args:
            data: dict holding ``'label'`` and ``'image'`` tensors (see
                ``preprocess_input``).
            mode: one of ``'generator'``, ``'discriminator'``,
                ``'inference'`` or ``'inference2'``.

        Returns:
            ``'generator'``: ``(dict of generator losses, generated image)``.
            ``'discriminator'``: dict of discriminator losses.
            ``'inference'`` / ``'inference2'``: generated image, computed
            without autograd tracking.

        Raises:
            ValueError: if ``mode`` is not recognised.
        """
        input_semantics, real_image = self.preprocess_input(data)

        if mode == 'generator':
            g_loss, generated = self.compute_generator_loss(input_semantics, real_image)
            return g_loss, generated
        if mode == 'discriminator':
            return self.compute_discriminator_loss(input_semantics, real_image)
        if mode in ('inference', 'inference2'):
            # Both modes run netG on the input; generate_fake adds nothing else.
            with torch.no_grad():
                return self.generate_fake(input_semantics)
        raise ValueError("|mode| is invalid")

    def preprocess_input(self, data):
        """Return ``(data['label'], data['image'])``.

        When GPUs are configured (``opt.gpu_ids`` non-empty), both tensors are
        moved to CUDA and written back into ``data``, so the caller's dict is
        mutated. No other transformation happens here; in particular, there
        is no one-hot encoding of the label.
        """
        if self.use_gpu():
            data['label'] = data['label'].cuda()
            data['image'] = data['image'].cuda()

        return data['label'], data['image']

    def compute_generator_loss(self, input_semantics, real_image):
        """Compute the generator-side losses.

        Returns:
            ``(G_losses, fake_image)``. ``G_losses`` always has ``'GAN'``;
            ``'GAN_Feat'`` unless ``opt.no_ganFeat_loss``; and ``'VGG'``
            unless ``opt.no_vgg_loss`` or the generated image is smaller than
            64 pixels on either spatial side.
        """
        G_losses = {}

        fake_image = self.generate_fake(input_semantics)

        pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image)

        G_losses['GAN'] = self.criterionGAN(pred_fake, True,
                                            for_discriminator=False)

        if not self.opt.no_ganFeat_loss:
            # pred_* holds one list of outputs per discriminator.
            num_D = len(pred_fake)
            GAN_Feat_loss = self.FloatTensor(1).fill_(0)
            for fake_outs, real_outs in zip(pred_fake, pred_real):
                # The last output is the final prediction, so only the
                # intermediate feature maps are matched.
                for fake_feat, real_feat in zip(fake_outs[:-1], real_outs[:-1]):
                    unweighted_loss = self.criterionFeat(fake_feat, real_feat.detach())
                    GAN_Feat_loss += unweighted_loss * self.opt.lambda_feat / num_D
            G_losses['GAN_Feat'] = GAN_Feat_loss

        h, w = fake_image.shape[-2:]
        # The VGG perceptual loss is skipped for very small outputs.
        if not self.opt.no_vgg_loss and min(w, h) >= 64:
            G_losses['VGG'] = self.criterionVGG(fake_image, real_image) \
                * self.opt.lambda_vgg

        return G_losses, fake_image

    def compute_discriminator_loss(self, input_semantics, real_image):
        """Compute ``{'D_Fake': ..., 'D_real': ...}`` for the discriminator.

        The fake image is produced without tracking generator gradients, so
        this loss only updates the discriminator.
        """
        D_losses = {}
        with torch.no_grad():
            fake_image = self.generate_fake(input_semantics)
            fake_image = fake_image.detach()
            fake_image.requires_grad_()

        pred_fake, pred_real = self.discriminate(
            input_semantics, fake_image, real_image)

        D_losses['D_Fake'] = self.criterionGAN(pred_fake, False,
                                               for_discriminator=True)
        D_losses['D_real'] = self.criterionGAN(pred_real, True,
                                               for_discriminator=True)

        return D_losses

    def generate_fake(self, input_semantics):
        """Run the generator on the conditioning input."""
        return self.netG(input_semantics)

    def discriminate(self, input_semantics, fake_image, real_image):
        """Score fake and real images, each concatenated with the conditioning map.

        If the generated image's spatial size differs from the input's, the
        conditioning map and the real image are interpolated (default
        ``'nearest'`` mode) to the generated size first.

        Returns:
            ``(pred_fake, pred_real)`` as split by ``divide_pred``.
        """
        h, w = fake_image.shape[-2:]
        semantics, real = input_semantics, real_image
        if fake_image.shape[-2:] != input_semantics.shape[-2:]:
            semantics = torch.nn.functional.interpolate(input_semantics, (h, w))
            real = torch.nn.functional.interpolate(real_image, (h, w))

        # Concatenate along the channel dimension.
        fake_concat = torch.cat([semantics, fake_image], dim=1)
        real_concat = torch.cat([semantics, real], dim=1)

        # With Batch Normalization, fake and real images should share a batch
        # to avoid disparate statistics, so both are fed to D at once.
        fake_and_real = torch.cat([fake_concat, real_concat], dim=0)

        discriminator_out = self.netD(fake_and_real)

        return self.divide_pred(discriminator_out)

    def encode_z(self, real_image):
        """Encode ``real_image`` with ``netE`` and sample a latent code.

        Returns:
            ``(z, mu, logvar)`` where ``z`` is drawn via ``reparameterize``.
        """
        mu, logvar = self.netE(real_image)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def create_optimizers(self, opt):
        """Build Adam optimizers for the generator side and the discriminator.

        Generator parameters include the encoder's when ``opt.use_vae``.
        Unless ``opt.no_TTUR`` is set, G uses ``lr / 2`` and D uses
        ``lr * 2``; otherwise both use ``lr``.

        Returns:
            ``(optimizer_G, optimizer_D)``. ``optimizer_D`` is ``None`` when
            ``opt.isTrain`` is false, because ``netD`` is only built for
            training.
        """
        G_params = list(self.netG.parameters())
        if opt.use_vae:
            G_params += list(self.netE.parameters())

        beta1, beta2 = opt.beta1, opt.beta2
        if opt.no_TTUR:
            G_lr, D_lr = opt.lr, opt.lr
        else:
            G_lr, D_lr = opt.lr / 2, opt.lr * 2

        optimizer_G = torch.optim.Adam(G_params, lr=G_lr, betas=(beta1, beta2))
        optimizer_D = None
        if opt.isTrain:
            D_params = list(self.netD.parameters())
            optimizer_D = torch.optim.Adam(D_params, lr=D_lr, betas=(beta1, beta2))

        return optimizer_G, optimizer_D

    def save(self, epoch):
        """Save G, D and (when ``opt.use_vae``) E via ``util.save_network``."""
        util.save_network(self.netG, 'G', epoch, self.opt)
        util.save_network(self.netD, 'D', epoch, self.opt)
        if self.opt.use_vae:
            util.save_network(self.netE, 'E', epoch, self.opt)

    ############################################################################
    # Private helper methods
    ############################################################################

    def initialize_networks(self, opt):
        """Build the networks and optionally load saved weights.

        Weights for epoch ``opt.which_epoch`` are loaded when not training,
        or when ``opt.continue_train`` is set.

        Returns:
            ``(netG, netD, netE)``. ``netD`` is ``None`` unless
            ``opt.isTrain``; ``netE`` is ``None`` unless ``opt.use_vae``.
        """
        netG = networks.define_G(opt)
        netD = networks.define_D(opt) if opt.isTrain else None
        netE = networks.define_E(opt) if opt.use_vae else None

        if not opt.isTrain or opt.continue_train:
            netG = util.load_network(netG, 'G', opt.which_epoch, opt)
            if opt.isTrain:
                netD = util.load_network(netD, 'D', opt.which_epoch, opt)
            if opt.use_vae:
                netE = util.load_network(netE, 'E', opt.which_epoch, opt)

        return netG, netD, netE

    def divide_pred(self, pred):
        """Split a combined fake+real discriminator output back into halves.

        The first half of the batch is fake and the second half is real (see
        ``discriminate``). A multiscale discriminator returns a list (one per
        discriminator) of lists of tensors; each tensor is split. Otherwise
        ``pred`` is a single tensor.

        Returns:
            ``(fake, real)`` with the same nesting as ``pred``.
        """
        if isinstance(pred, list):
            fake = [[tensor[:tensor.size(0) // 2] for tensor in p] for p in pred]
            real = [[tensor[tensor.size(0) // 2:] for tensor in p] for p in pred]
        else:
            half = pred.size(0) // 2
            fake = pred[:half]
            real = pred[half:]

        return fake, real

    def get_edges(self, t):
        """Return a float mask that is 1 wherever ``t`` differs from a
        horizontal or vertical neighbour.

        The indexing requires a rank-4 ``t`` (last two dims are spatial).
        """
        edge = self.ByteTensor(t.size()).zero_()
        edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        return edge.float()

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps.mul(std) + mu

    def use_gpu(self):
        """True when at least one GPU id is configured in ``opt.gpu_ids``."""
        return len(self.opt.gpu_ids) > 0
233
234