GitHub Repository: ai-forever/sber-swap
Path: blob/main/models/networks/generator.py
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from models.networks.base_network import BaseNetwork
from models.networks.normalization import get_nonspade_norm_layer
from models.networks.architecture import ResnetBlock as ResnetBlock
from models.networks.architecture import SPADEResnetBlock as SPADEResnetBlock
import os
#import data  # only run from basic level!
import copy  # deepcopy


class SPADEGenerator(BaseNetwork):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.set_defaults(norm_G='spectralspadesyncbatch3x3')
        parser.add_argument('--num_upsampling_layers',
                            choices=('normal', 'more', 'most'), default='normal',
                            help="If 'more', adds an upsampling layer between the two middle resnet blocks. If 'most', also adds one more upsampling + resnet layer at the end of the generator.")

        return parser

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh, self.scale_ratio = self.compute_latent_vector_size(opt)

        if opt.use_vae:
            # In case of VAE, we will sample from random z vector
            self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
        else:
            # Otherwise, we make the network deterministic by starting with
            # a downsampled segmentation map instead of random z
            self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)

        self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)

        self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
        self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt)

        # 20200211 test 4x with only 3 stages

        self.ups = nn.ModuleList([
            SPADEResnetBlock(16 * nf, 8 * nf, opt),
            SPADEResnetBlock(8 * nf, 4 * nf, opt),
            SPADEResnetBlock(4 * nf, 2 * nf, opt),
            SPADEResnetBlock(2 * nf, 1 * nf, opt)  # here
        ])

        self.to_rgbs = nn.ModuleList([
            nn.Conv2d(8 * nf, 3, 3, padding=1),
            nn.Conv2d(4 * nf, 3, 3, padding=1),
            nn.Conv2d(2 * nf, 3, 3, padding=1),
            nn.Conv2d(1 * nf, 3, 3, padding=1)  # here
        ])

        self.up = nn.Upsample(scale_factor=2)
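
    # Illustrative sketch (not in the original source): with nf == opt.ngf == 64
    # and a 16x16 latent, the full-phase decoder schedule is roughly
    #   head_0:        1024 ch @ 16x16
    #   G_middle_0/1:  1024 ch @ 32x32
    #   ups[0..3]:     512/256/128/64 ch @ 64/128/256/512
    # with to_rgbs[i] projecting the matching width down to 3-channel RGB.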

    # 20200309 interface for flexible encoder design
    # and mid-level loss control!
    # For the basic network it is just a plain 2**scale_ratio downsampling
    def encode(self, input):
        h, w = input.size()[-2:]
        sh, sw = h // 2**self.scale_ratio, w // 2**self.scale_ratio
        x = F.interpolate(input, size=(sh, sw))
        return self.fc(x)  # 20200310: Merge fc into encoder
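
    # Shape sketch (illustrative): for a 512x512 input with scale_ratio == 5,
    # F.interpolate shrinks it to 16x16 and self.fc lifts the channels, i.e.
    # (B, semantic_nc, 512, 512) -> (B, 16 * nf, 16, 16).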

    def compute_latent_vector_size(self, opt):
        if opt.num_upsampling_layers == 'normal':
            num_up_layers = 5
        elif opt.num_upsampling_layers == 'more':
            num_up_layers = 6
        elif opt.num_upsampling_layers == 'most':
            num_up_layers = 7
        else:
            raise ValueError('opt.num_upsampling_layers [%s] not recognized' %
                             opt.num_upsampling_layers)

        # 20200211 Yang Lingbo with respect to phase
        scale_ratio = num_up_layers
        #scale_ratio = 4  # here
        sw = opt.crop_size // (2**num_up_layers)
        sh = round(sw / opt.aspect_ratio)

        return sw, sh, scale_ratio
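
    # Worked example (illustrative): crop_size=512, aspect_ratio=1.0 and
    # num_upsampling_layers='normal' give num_up_layers = 5, hence
    # sw = 512 // 2**5 = 16, sh = round(16 / 1.0) = 16, scale_ratio = 5.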

    def forward(self, input, seg=None):
        '''
        20200307: Dangerous Change
        Add separable forward to allow different
        input and segmentation maps...

        To return to the original, simply add
        seg = input at the beginning, and disable the seg parameter.

        20200308: A more elegant solution:
        @ Allow forward to take default parameters.
        @ Allow customizable input encoding

        20200310: Merge fc into encode, since the encoder directly outputs the processed feature map.

        [TODO] @ Allow customizable segmap encoding?
        '''

        if seg is None:
            seg = input  # Interesting change...

        # For the basic generator, encode() is a plain 2**scale_ratio downsampling.
        # 20200310: Merge fc into encoder
        x = self.encode(input)
        #print(x.shape, input.shape, seg.shape)

        x = self.head_0(x, seg)

        x = self.up(x)
        x = self.G_middle_0(x, seg)
        x = self.G_middle_1(x, seg)

        if self.opt.is_test:
            phase = len(self.to_rgbs)
        else:
            phase = self.opt.train_phase + 1

        for i in range(phase):
            x = self.up(x)
            x = self.ups[i](x, seg)

        x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
        x = torch.tanh(x)

        return x
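
    # Phase note (illustrative): `phase` implements progressive training.
    # With opt.train_phase == 0 only ups[0]/to_rgbs[0] are active (a 16x16
    # latent yields 64x64 output); at test time phase == len(self.to_rgbs) == 4,
    # enabling the full 16x upsampling path after G_middle_1.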

    def mixed_guidance_forward(self, input, seg=None, n=0, mode='progressive'):
        '''
        mixed_forward: input and seg are different images.
        For the first n levels (including the encoder)
        we use input; for the rest we use seg.

        With A = input and B = seg:
        If mode == 'progressive', the guidance looks like: AAABBB
        If mode == 'one_plug',    the guidance looks like: BBBABB (input plugged in at level n)
        If mode == 'one_ablate',  the guidance looks like: AAABAA (input ablated at level n)
        '''

        if seg is None:
            return self.forward(input)

        if self.opt.is_test:
            phase = len(self.to_rgbs)
        else:
            phase = self.opt.train_phase + 1

        if mode == 'progressive':
            n = max(min(n, 4 + phase), 0)
            guide_list = [input] * n + [seg] * (4 + phase - n)
        elif mode == 'one_plug':
            n = max(min(n, 4 + phase - 1), 0)
            guide_list = [seg] * (4 + phase)
            guide_list[n] = input
        elif mode == 'one_ablate':
            if n > 3 + phase:
                return self.forward(input)
            guide_list = [input] * (4 + phase)
            guide_list[n] = seg
        else:
            raise ValueError('mixed guidance mode [%s] not recognized' % mode)

        x = self.encode(guide_list[0])
        x = self.head_0(x, guide_list[1])

        x = self.up(x)
        x = self.G_middle_0(x, guide_list[2])
        x = self.G_middle_1(x, guide_list[3])

        for i in range(phase):
            x = self.up(x)
            x = self.ups[i](x, guide_list[4 + i])

        x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
        x = torch.tanh(x)

        return x
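
    # Worked example (illustrative): with phase == 2 there are 4 + 2 == 6
    # guidance slots (encode, head_0, G_middle_0, G_middle_1, ups[0], ups[1]):
    #   mode='progressive', n=3 -> [input, input, input, seg, seg, seg]      AAABBB
    #   mode='one_plug',    n=3 -> [seg, seg, seg, input, seg, seg]          BBBABB
    #   mode='one_ablate',  n=3 -> [input, input, input, seg, input, input]  AAABAA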


class HiFaceGANGenerator(SPADEGenerator):
    def __init__(self, opt):
        super().__init__(opt)
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh, self.scale_ratio = self.compute_latent_vector_size(opt)

        if opt.use_vae:
            # In case of VAE, we will sample from random z vector
            self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
        else:
            # Otherwise, we make the network deterministic by starting with
            # a downsampled segmentation map instead of random z
            self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)

        self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, 16 * nf)

        self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, 16 * nf)
        self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt, 16 * nf)

        # 20200211 test 4x with only 3 stages

        self.ups = nn.ModuleList([
            SPADEResnetBlock(16 * nf, 8 * nf, opt, 8 * nf),
            SPADEResnetBlock(8 * nf, 4 * nf, opt, 4 * nf),
            SPADEResnetBlock(4 * nf, 2 * nf, opt, 2 * nf),
            SPADEResnetBlock(2 * nf, 1 * nf, opt, 1 * nf)  # here
        ])

        self.to_rgbs = nn.ModuleList([
            nn.Conv2d(8 * nf, 3, 3, padding=1),
            nn.Conv2d(4 * nf, 3, 3, padding=1),
            nn.Conv2d(2 * nf, 3, 3, padding=1),
            nn.Conv2d(1 * nf, 3, 3, padding=1)  # here
        ])

        self.up = nn.Upsample(scale_factor=2)
        self.encoder = ContentAdaptiveSuppresor(opt, self.sw, self.sh, self.scale_ratio)

    def nested_encode(self, x):
        return self.encoder(x)

    def forward(self, input):
        xs = self.nested_encode(input)
        x = self.encode(input)
        '''
        print([_x.shape for _x in xs])
        print(x.shape)
        print(self.head_0)
        '''
        x = self.head_0(x, xs[0])

        x = self.up(x)
        x = self.G_middle_0(x, xs[1])
        x = self.G_middle_1(x, xs[1])

        if self.opt.is_test:
            phase = len(self.to_rgbs)
        else:
            phase = self.opt.train_phase + 1

        for i in range(phase):
            x = self.up(x)
            x = self.ups[i](x, xs[i + 2])

        x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
        x = torch.tanh(x)

        return x
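
    # Guidance alignment (illustrative): xs is ordered small-to-large, so xs[0]
    # matches the 16x16 latent fed to head_0 and xs[i + 2] matches the i-th
    # upsampling stage; xs[1] guides both G_middle_0 and G_middle_1 because no
    # upsampling happens between them. At full phase this consumes
    # scale_ratio + 1 pyramid levels.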


class ContentAdaptiveSuppresor(BaseNetwork):
    def __init__(self, opt, sw, sh, n_2xdown,
                 norm_layer=nn.InstanceNorm2d):
        super().__init__()
        self.sw = sw
        self.sh = sh
        self.max_ratio = 16
        self.n_2xdown = n_2xdown

        # norm_layer = get_nonspade_norm_layer(opt, opt.norm_E)

        # 20200310: Several Convolution (stride 1) + LIP blocks, 4-fold
        ngf = opt.ngf
        kw = 3
        pw = (kw - 1) // 2

        self.head = nn.Sequential(
            nn.Conv2d(opt.semantic_nc, ngf, kw, stride=1, padding=pw, bias=False),
            norm_layer(ngf),
            nn.ReLU(),
        )
        cur_ratio = 1
        for i in range(n_2xdown):
            next_ratio = min(cur_ratio * 2, self.max_ratio)
            model = [
                SimplifiedLIP(ngf * cur_ratio),
                nn.Conv2d(ngf * cur_ratio, ngf * next_ratio, kw, stride=1, padding=pw),
                norm_layer(ngf * next_ratio),
            ]
            cur_ratio = next_ratio
            if i < n_2xdown - 1:
                model += [nn.ReLU(inplace=True)]
            setattr(self, 'encoder_%d' % i, nn.Sequential(*model))

    def forward(self, x):
        # 20200628: Note the features are arranged from small to large
        x = [self.head(x)]
        for i in range(self.n_2xdown):
            net = getattr(self, 'encoder_%d' % i)
            x = [net(x[0])] + x
        return x
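
    # Shape sketch (illustrative): for a 512x512 input with n_2xdown == 5 and
    # ngf == 64, the returned pyramid is
    #   [1024ch@16, 1024ch@32, 512ch@64, 256ch@128, 128ch@256, 64ch@512]
    # -- each SimplifiedLIP block halves the resolution while the channel
    # ratio doubles, capped at max_ratio == 16.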


#########################################
# Below is deprecated code
#
# 20200309: LIP for local importance pooling
# 20200311: Self-supervised mask encoder
# Author: lingbo.ylb
# Quick trial, to be reformatted later.
# 20200324: Nah, forget about it...
#########################################


def lip2d(x, logit, kernel=3, stride=2, padding=1):
    weight = logit.exp()
    return F.avg_pool2d(x * weight, kernel, stride, padding) / F.avg_pool2d(weight, kernel, stride, padding)
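
# Sanity check (illustrative, not part of the original file): with all-zero
# logits, exp(logit) == 1 everywhere and lip2d degenerates to average pooling
# that ignores the zero padding:
#
#     x = torch.randn(1, 8, 16, 16)
#     y = lip2d(x, torch.zeros_like(x))
#     # y equals F.avg_pool2d(x, 3, 2, 1, count_include_pad=False)
#
# Learned logits instead reweight each window toward high-importance pixels.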


class SoftGate(nn.Module):
    COEFF = 12.0

    def __init__(self):
        super(SoftGate, self).__init__()

    def forward(self, x):
        return torch.sigmoid(x).mul(self.COEFF)
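
    # Design note (illustrative): the sigmoid bounds the logits to (0, COEFF),
    # so the exp() inside lip2d sees values in (1, e**12) -- finite pooling
    # weights even for unbounded pre-gate activations.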


class SimplifiedLIP(nn.Module):
    def __init__(self, channels):
        super(SimplifiedLIP, self).__init__()

        rp = channels

        self.logit = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.InstanceNorm2d(channels, affine=True),
            SoftGate()
        )
        '''
        OrderedDict((
            ('conv', nn.Conv2d(channels, channels, 3, padding=1, bias=False)),
            ('bn', nn.InstanceNorm2d(channels, affine=True)),
            ('gate', SoftGate()),
        ))
        '''

    def init_layer(self):
        self.logit[0].weight.data.fill_(0.0)

    def forward(self, x):
        frac = lip2d(x, self.logit(x))
        return frac


class LIPEncoder(BaseNetwork):
    def __init__(self, opt, sw, sh, n_2xdown,
                 norm_layer=nn.InstanceNorm2d):
        super().__init__()
        self.sw = sw
        self.sh = sh
        self.max_ratio = 16

        # norm_layer = get_nonspade_norm_layer(opt, opt.norm_E)

        # 20200310: Several Convolution (stride 1) + LIP blocks, 4-fold
        ngf = opt.ngf
        kw = 3
        pw = (kw - 1) // 2

        model = [
            nn.Conv2d(opt.semantic_nc, ngf, kw, stride=1, padding=pw, bias=False),
            norm_layer(ngf),
            nn.ReLU(),
        ]
        cur_ratio = 1
        for i in range(n_2xdown):
            next_ratio = min(cur_ratio * 2, self.max_ratio)
            model += [
                SimplifiedLIP(ngf * cur_ratio),
                nn.Conv2d(ngf * cur_ratio, ngf * next_ratio, kw, stride=1, padding=pw),
                norm_layer(ngf * next_ratio),
            ]
            cur_ratio = next_ratio
            if i < n_2xdown - 1:
                model += [nn.ReLU(inplace=True)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


class LIPSPADEGenerator(SPADEGenerator):
    '''
    20200309: SPADEGenerator with a learnable feature encoder
    Encoder design: Local Importance-based Pooling (Ziteng Gao et al., ICCV 2019)
    '''
    def __init__(self, opt):
        super().__init__(opt)
        self.lip_encoder = LIPEncoder(opt, self.sw, self.sh, self.scale_ratio)

    def encode(self, x):
        return self.lip_encoder(x)


class NoiseClassPredictor(nn.Module):
    '''
    Input: nc*sw*sw tensor, either from clean or corrupted images
    Output: n-dim tensor indicating the loss type (or intensity?)
    '''
    def __init__(self, opt, sw, nc, outdim):
        super().__init__()
        nbottleneck = 256
        middim = 256
        # Compact info
        conv = [
            nn.Conv2d(nc, nbottleneck, 1, stride=1),
            nn.InstanceNorm2d(nbottleneck),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # sw should probably be 16; downsample to 4 and let the fc convert to the output dim
        while sw > 4:
            sw = sw // 2
            conv += [
                nn.Conv2d(nbottleneck, nbottleneck, 3, stride=2, padding=1),
                nn.InstanceNorm2d(nbottleneck),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        self.conv = nn.Sequential(*conv)

        indim = sw * sw * nbottleneck
        self.fc = nn.Sequential(
            nn.Linear(indim, middim),
            nn.BatchNorm1d(middim),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(middim, outdim),
            # nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.shape[0], -1)
        return self.fc(x)
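
    # Worked example (illustrative): sw == 16 triggers two stride-2 blocks
    # (16 -> 8 -> 4), so the flattened fc input is 4 * 4 * 256 = 4096 dims.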


class NoiseIntensityPredictor(nn.Module):
    '''
    Input: nc*sw*sw tensor, either from clean or corrupted images
    Output: 1-dim tensor indicating the loss intensity
    '''
    def __init__(self, opt, sw, nc, outdim):
        super().__init__()
        nbottleneck = 256
        middim = 256
        # Compact info
        conv = [
            nn.Conv2d(nc, nbottleneck, 1, stride=1),
            nn.BatchNorm2d(nbottleneck),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # sw should probably be 16; downsample to 4 and let the fc convert to the output dim
        while sw > 4:
            sw = sw // 2
            conv += [
                nn.Conv2d(nbottleneck, nbottleneck, 3, stride=2, padding=1),
                nn.BatchNorm2d(nbottleneck),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        self.conv = nn.Sequential(*conv)

        indim = sw * sw * nbottleneck
        self.fc = nn.Sequential(
            nn.Linear(indim, middim),
            nn.BatchNorm1d(middim),
            #nn.Dropout(0.5),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(middim, outdim),
            #nn.Dropout(0.5),
            # nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x.squeeze()


class SubAddGenerator(SPADEGenerator):
    '''
    20200311:
    This generator contains a complete
    self-supervised training scheme
    that requires a separate dataloader.

    The self-supervised pre-training is
    implemented as a clean interface:
    SubAddGenerator::train_E(self, epochs)

    For the upper-level Pix2Pix_Model,
    two things need to be done:
    A) Run the pretrain script
    B) Save the encoder and adjust the learning rate.

    -----------------------------------------
    20200312:
    Pre-test problem: the discriminator has a hard time
    telling real from fake.
    Also, using the residual of feature maps doesn't work...

    Cause: the feature maps are too close to be properly separated.

    Try to test on one single reduction: arbitrary ratio of downsampling;
    try to estimate the reduction ratio?
    '''
    def __init__(self, opt):
        super().__init__(opt)
        self.encoder = LIPEncoder(opt, self.sw, self.sh, self.scale_ratio)

        self.dis_nc = self.opt.ngf * min(16, 2**self.scale_ratio)
        # intensity is a scalar
        self.discriminator = NoiseIntensityPredictor(opt, self.sw, self.dis_nc, 1)
        if opt.isTrain:
            self.attach_dataloader(opt)
            self.noise_dim = opt.noise_dim

            self.l1_loss = nn.L1Loss()
            self.gan_loss = nn.MSELoss()

            #self.discriminator = NoiseClassPredictor(opt, self.sw, self.dis_nc,
            #    self.noise_dim + 1)  # add a new label for clean images

            #self.gan_loss = nn.CrossEntropyLoss()

            beta1, beta2 = opt.beta1, opt.beta2
            if opt.no_TTUR:
                G_lr, D_lr = opt.lr, opt.lr
            else:
                G_lr, D_lr = opt.lr / 2, opt.lr * 2

            self.optimizer_E = torch.optim.Adam(
                self.encoder.parameters(), lr=G_lr, betas=(beta1, beta2)
            )
            self.optimizer_D = torch.optim.Adam(
                self.discriminator.parameters(), lr=D_lr / 2, betas=(beta1, beta2)
            )

    def _create_auxiliary_opt(self, opt):
        '''
        Create auxiliary options:
        change the necessary params
        --- dataroot
        --- dataset_mode
        '''
        aux_opt = copy.deepcopy(opt)  # just for safety
        aux_opt.dataroot = opt.dataroot_assist
        aux_opt.dataset_mode = 'assist'
        aux_opt.batchSize = 4
        aux_opt.nThreads = 4
        return aux_opt

    def attach_dataloader(self, opt):
        aux_opt = self._create_auxiliary_opt(opt)
        # NOTE: relies on the module-level `import data` that is commented out above.
        self.loader = data.create_dataloader(aux_opt)

    def encode(self, x):
        return self.encoder(x)

    def process_input(self, data):
        self.clean = data['clean'].cuda()
        self.noisy = data['noisy'].cuda()
        # for BCELoss, the class label is just an int.
        self.noise_label = data['label'].cuda()
        # label design...
        # should the clean label be 0 or [1,0,0...0]?
        # for BCELoss, the clean label is just 0.
        #self.clean_label = torch.zeros_like(self.noise_label)
        self.clean_label = torch.ones_like(self.noise_label)

    def update_E(self):
        bundle_in = torch.cat((self.clean, self.noisy), dim=0)
        bundle_out = self.encode(bundle_in)
        nb = bundle_in.shape[0] // 2
        F_real, F_fake = bundle_out[:nb], bundle_out[nb:]

        pred_fake = self.discriminator(F_fake)
        loss_l1 = self.l1_loss(F_fake, F_real)
        loss_gan = self.gan_loss(pred_fake, self.clean_label)
        loss_sum = loss_l1 * 10 + loss_gan

        self.optimizer_E.zero_grad()
        loss_sum.backward()
        self.optimizer_E.step()

        self.loss_l1 = loss_l1.item()
        self.loss_gan_E = loss_gan.item()
        self.loss_sum = loss_sum.item()

    def update_D(self):
        with torch.no_grad():
            bundle_in = torch.cat((self.clean, self.noisy), dim=0)
            bundle_out = self.encode(bundle_in)
            nb = bundle_in.shape[0] // 2
            #F_real, F_fake = bundle_out[:nb], bundle_out[nb:]
            F_fake = bundle_out[nb:] / (bundle_out[:nb] + 1e-6)
            F_real = torch.ones_like(F_fake, requires_grad=False)

        pred_real = self.discriminator(F_real)
        loss_real = self.gan_loss(pred_real, self.clean_label)
        pred_fake = self.discriminator(F_fake.detach())
        loss_fake = self.gan_loss(pred_fake, self.noise_label)
        loss_sum = (loss_real + loss_fake * self.opt.noise_dim) / 2

        self.optimizer_D.zero_grad()
        loss_sum.backward()
        self.optimizer_D.step()

        self.loss_gan_D_real = loss_real.item()
        self.loss_gan_D_fake = loss_fake.item()

    def debug_D(self):
        with torch.no_grad():
            bundle_in = torch.cat((self.clean, self.noisy), dim=0)
            bundle_out = self.encode(bundle_in)
            nb = bundle_in.shape[0] // 2
            F_real, F_fake = bundle_out[:nb], bundle_out[nb:]
            F_res = F_fake - F_real  # try to predict the residual, it's easier
            #F_real = torch.zeros_like(F_real)  # real res == 0

            pred_real = self.discriminator(F_real)  #.argmax(dim=1)
            pred_fake = self.discriminator(F_res.detach())  #.argmax(dim=1)
            print(pred_real, pred_fake)
            real_acc = (pred_real == 0).sum().item() / pred_real.shape[0]
            fake_acc = (pred_fake == self.noise_label).sum().item() / pred_fake.shape[0]
            print(real_acc, fake_acc)

    def log(self, epoch, i):
        logstring = ' Epoch [%d] iter [%d]: ' % (epoch, i)
        logstring += 'l1: %.4f ' % self.loss_l1
        logstring += 'gen: %.4f ' % self.loss_gan_E
        logstring += 'E_sum: %.4f ' % self.loss_sum
        logstring += 'dis_real: %.4f ' % self.loss_gan_D_real
        logstring += 'dis_fake: %.4f' % self.loss_gan_D_fake
        print(logstring)

    def train_E(self, epochs):
        pretrained_ckpt_dir = os.path.join(
            self.opt.checkpoints_dir, self.opt.name,
            'pretrained_net_E_%d.pth' % epochs
        )
        print(pretrained_ckpt_dir)

        print('======= Stage I: Subtraction =======')
        if os.path.isfile(pretrained_ckpt_dir):
            state_dict_E = torch.load(pretrained_ckpt_dir)
            self.encoder.load_state_dict(state_dict_E)
            print('======= Loaded cached checkpoint %s' % pretrained_ckpt_dir)
        else:
            print('======= total epochs: %d ' % epochs)
            for epoch in range(1, epochs + 1):
                for i, data in enumerate(self.loader):
                    self.process_input(data)

                    self.update_E()
                    self.update_D()

                    if i % 10 == 0:
                        self.log(epoch, i)  # output losses and other stats

                print('Epoch [%d] finished' % epoch)
                # just save the latest.
                torch.save(self.encoder.state_dict(), os.path.join(
                    self.opt.checkpoints_dir, self.opt.name, 'pretrained_net_E_%d.pth' % epoch
                ))
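
    # Usage sketch (illustrative; the surrounding training scripts may differ):
    # pretrain the encoder before the main adversarial training, e.g.
    #
    #     gen = SubAddGenerator(opt)  # opt.isTrain == True attaches the loader
    #     gen.train_E(epochs=10)      # reuses pretrained_net_E_10.pth if cached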


class ContrasiveGenerator(SPADEGenerator):
    def __init__(self, opt):
        super().__init__(opt)
        self.encoder = LIPEncoder(opt, self.sw, self.sh, self.scale_ratio)

        if opt.isTrain:
            self.attach_dataloader(opt)
            self.noise_dim = opt.noise_dim

            self.l1_loss = nn.L1Loss()

            beta1, beta2 = opt.beta1, opt.beta2
            self.optimizer_E = torch.optim.Adam(
                self.encoder.parameters(), lr=opt.lr, betas=(beta1, beta2)
            )

    def _create_auxiliary_opt(self, opt):
        '''
        Create auxiliary options:
        change the necessary params
        --- dataroot
        --- dataset_mode
        '''
        aux_opt = copy.deepcopy(opt)  # just for safety
        aux_opt.dataroot = opt.dataroot_assist
        aux_opt.dataset_mode = 'assist'
        aux_opt.batchSize = 8
        aux_opt.nThreads = 4
        return aux_opt

    def attach_dataloader(self, opt):
        aux_opt = self._create_auxiliary_opt(opt)
        # NOTE: relies on the module-level `import data` that is commented out above.
        self.loader = data.create_dataloader(aux_opt)

    def encode(self, x):
        return self.encoder(x)

    def process_input(self, data):
        self.clean = data['clean'].cuda()
        self.noisy = data['noisy'].cuda()

    def update_E(self):
        bundle_in = torch.cat((self.clean, self.noisy), dim=0)
        bundle_out = self.encode(bundle_in)
        nb = bundle_in.shape[0] // 2
        F_real, F_fake = bundle_out[:nb], bundle_out[nb:]
        loss_l1 = self.l1_loss(F_fake, F_real)

        self.optimizer_E.zero_grad()
        loss_l1.backward()
        self.optimizer_E.step()

        self.loss_l1 = loss_l1.item()

    def log(self, epoch, i):
        logstring = ' Epoch [%d] iter [%d]: ' % (epoch, i)
        logstring += 'l1: %.4f ' % self.loss_l1
        print(logstring)

    def train_E(self, epochs):
        pretrained_ckpt_dir = os.path.join(
            self.opt.checkpoints_dir, self.opt.name,
            'pretrained_net_E_%d.pth' % epochs
        )
        print(pretrained_ckpt_dir)

        print('======= Stage I: Subtraction =======')
        if os.path.isfile(pretrained_ckpt_dir):
            state_dict_E = torch.load(pretrained_ckpt_dir)
            self.encoder.load_state_dict(state_dict_E)
            print('======= Loaded cached checkpoint %s' % pretrained_ckpt_dir)
        else:
            print('======= total epochs: %d ' % epochs)
            for epoch in range(1, epochs + 1):
                for i, data in enumerate(self.loader):
                    self.process_input(data)
                    self.update_E()

                    if i % 10 == 0:
                        self.log(epoch, i)  # output losses and other stats

                print('Epoch [%d] finished' % epoch)
                # just save the latest.
                torch.save(self.encoder.state_dict(), os.path.join(
                    self.opt.checkpoints_dir, self.opt.name, 'pretrained_net_E_%d.pth' % epoch
                ))