"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.networks.base_network import BaseNetwork
from models.networks.normalization import get_nonspade_norm_layer
from models.networks.architecture import ResnetBlock as ResnetBlock
from models.networks.architecture import SPADEResnetBlock as SPADEResnetBlock
import os
import copy
import data  # the repo's data package; create_dataloader() is used by the pre-training generators below
class SPADEGenerator(BaseNetwork):
@staticmethod
def modify_commandline_options(parser, is_train):
parser.set_defaults(norm_G='spectralspadesyncbatch3x3')
parser.add_argument('--num_upsampling_layers',
choices=('normal', 'more', 'most'), default='normal',
help="If 'more', adds upsampling layer between the two middle resnet blocks. If 'most', also add one more upsampling + resnet layer at the end of the generator")
return parser
def __init__(self, opt):
super().__init__()
self.opt = opt
nf = opt.ngf
self.sw, self.sh, self.scale_ratio = self.compute_latent_vector_size(opt)
if opt.use_vae:
self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
else:
self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)
self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
self.ups = nn.ModuleList([
SPADEResnetBlock(16 * nf, 8 * nf, opt),
SPADEResnetBlock(8 * nf, 4 * nf, opt),
SPADEResnetBlock(4 * nf, 2 * nf, opt),
SPADEResnetBlock(2 * nf, 1 * nf, opt)
])
self.to_rgbs = nn.ModuleList([
nn.Conv2d(8 * nf, 3, 3, padding=1),
nn.Conv2d(4 * nf, 3, 3, padding=1),
nn.Conv2d(2 * nf, 3, 3, padding=1),
nn.Conv2d(1 * nf, 3, 3, padding=1)
])
self.up = nn.Upsample(scale_factor=2)
def encode(self, input):
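        # Downsample the raw input to the latent spatial size (sh, sw) and
        # project it to 16*nf channels with self.fc (the non-VAE path).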
h, w = input.size()[-2:]
sh, sw = h//2**self.scale_ratio, w//2**self.scale_ratio
x = F.interpolate(input, size=(sh, sw))
return self.fc(x)
def compute_latent_vector_size(self, opt):
if opt.num_upsampling_layers == 'normal':
num_up_layers = 5
elif opt.num_upsampling_layers == 'more':
num_up_layers = 6
elif opt.num_upsampling_layers == 'most':
num_up_layers = 7
else:
raise ValueError('opt.num_upsampling_layers [%s] not recognized' %
opt.num_upsampling_layers)
scale_ratio = num_up_layers
sw = opt.crop_size // (2**num_up_layers)
sh = round(sw / opt.aspect_ratio)
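        # Illustrative example (hypothetical values): with crop_size=512,
        # aspect_ratio=1.0 and num_upsampling_layers='more' (6 up-layers),
        # this yields sw = 512 // 2**6 = 8 and sh = 8, i.e. an 8x8 latent map.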
return sw, sh, scale_ratio
def forward(self, input, seg=None):
        '''
        20200307: Dangerous change
        Added a separable forward to allow different
        input and segmentation maps.
        To restore the original behaviour, simply set
        seg = input at the beginning and disable the seg parameter.
        20200308: A more elegant solution:
        @ Allow forward to take default parameters.
        @ Allow customizable input encoding.
        20200310: Merged fc into encode, since the encoder directly outputs the processed feature map.
        [TODO] @ Allow customizable segmap encoding?
        '''
if seg is None:
seg = input
x = self.encode(input)
x = self.head_0(x, seg)
x = self.up(x)
x = self.G_middle_0(x, seg)
x = self.G_middle_1(x, seg)
if self.opt.is_test:
phase = len(self.to_rgbs)
else:
phase = self.opt.train_phase+1
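        # Progressive growing: only the first `phase` up-blocks are active, and the
        # matching to_rgb head renders the output at the current training resolution.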
for i in range(phase):
x = self.up(x)
x = self.ups[i](x, seg)
x = self.to_rgbs[phase-1](F.leaky_relu(x, 2e-1))
x = torch.tanh(x)
return x
def mixed_guidance_forward(self, input, seg=None, n=0, mode='progressive'):
        '''
        mixed_guidance_forward: input and seg are different images.
        With A = input and B = seg, the guidance fed to each level is:
        'progressive': the first n levels (including the encoder) use input,
                       the rest use seg                                      -> AAABBB
        'one_plug'   : every level uses seg, except level n which uses input -> BBBABB
        'one_ablate' : every level uses input, except level n which uses seg -> AAABAA
        '''
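        # Illustrative example (hypothetical values): with phase = 2 there are
        # 4 + phase = 6 guidance slots (encoder, head_0, G_middle_0, G_middle_1, ups[0], ups[1]).
        # With n = 3:
        #   'progressive' -> [input, input, input, seg, seg, seg]
        #   'one_plug'    -> [seg, seg, seg, input, seg, seg]
        #   'one_ablate'  -> [input, input, input, seg, input, input]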
if seg is None:
return self.forward(input)
if self.opt.is_test:
phase = len(self.to_rgbs)
else:
phase = self.opt.train_phase+1
if mode == 'progressive':
n = max(min(n, 4 + phase), 0)
guide_list = [input] * n + [seg] * (4+phase-n)
elif mode == 'one_plug':
n = max(min(n, 4 + phase-1), 0)
guide_list = [seg] * (4+phase)
guide_list[n] = input
elif mode == 'one_ablate':
if n > 3+phase:
return self.forward(input)
guide_list = [input] * (4+phase)
guide_list[n] = seg
x = self.encode(guide_list[0])
x = self.head_0(x, guide_list[1])
x = self.up(x)
x = self.G_middle_0(x, guide_list[2])
x = self.G_middle_1(x, guide_list[3])
for i in range(phase):
x = self.up(x)
x = self.ups[i](x, guide_list[4+i])
x = self.to_rgbs[phase-1](F.leaky_relu(x, 2e-1))
x = torch.tanh(x)
return x
class HiFaceGANGenerator(SPADEGenerator):
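    '''
    SPADE-style generator whose modulation inputs come from a nested
    multi-scale encoder (ContentAdaptiveSuppresor) instead of semantic maps:
    each SPADEResnetBlock is conditioned on the encoder feature of the
    matching resolution.
    '''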
def __init__(self, opt):
super().__init__(opt)
self.opt = opt
nf = opt.ngf
self.sw, self.sh, self.scale_ratio = self.compute_latent_vector_size(opt)
if opt.use_vae:
self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
else:
self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)
self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, 16 * nf)
self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, 16 * nf)
self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt, 16 * nf)
self.ups = nn.ModuleList([
SPADEResnetBlock(16 * nf, 8 * nf, opt, 8 * nf),
SPADEResnetBlock(8 * nf, 4 * nf, opt, 4 * nf),
SPADEResnetBlock(4 * nf, 2 * nf, opt, 2 * nf),
SPADEResnetBlock(2 * nf, 1 * nf, opt, 1 * nf)
])
self.to_rgbs = nn.ModuleList([
nn.Conv2d(8 * nf, 3, 3, padding=1),
nn.Conv2d(4 * nf, 3, 3, padding=1),
nn.Conv2d(2 * nf, 3, 3, padding=1),
nn.Conv2d(1 * nf, 3, 3, padding=1)
])
self.up = nn.Upsample(scale_factor=2)
self.encoder = ContentAdaptiveSuppresor(opt, self.sw, self.sh, self.scale_ratio)
def nested_encode(self, x):
return self.encoder(x)
def forward(self, input):
xs = self.nested_encode(input)
x = self.encode(input)
x = self.head_0(x, xs[0])
x = self.up(x)
x = self.G_middle_0(x, xs[1])
x = self.G_middle_1(x, xs[1])
if self.opt.is_test:
phase = len(self.to_rgbs)
else:
phase = self.opt.train_phase+1
for i in range(phase):
x = self.up(x)
x = self.ups[i](x, xs[i+2])
x = self.to_rgbs[phase-1](F.leaky_relu(x, 2e-1))
x = torch.tanh(x)
return x
class ContentAdaptiveSuppresor(BaseNetwork):
def __init__(self, opt, sw, sh, n_2xdown,
norm_layer=nn.InstanceNorm2d):
super().__init__()
self.sw = sw
self.sh = sh
self.max_ratio = 16
self.n_2xdown = n_2xdown
ngf = opt.ngf
kw = 3
pw = (kw - 1) // 2
self.head = nn.Sequential(
nn.Conv2d(opt.semantic_nc, ngf, kw, stride=1, padding=pw, bias=False),
norm_layer(ngf),
nn.ReLU(),
)
cur_ratio = 1
for i in range(n_2xdown):
next_ratio = min(cur_ratio*2, self.max_ratio)
model = [
SimplifiedLIP(ngf*cur_ratio),
nn.Conv2d(ngf*cur_ratio, ngf*next_ratio, kw, stride=1, padding=pw),
norm_layer(ngf*next_ratio),
]
cur_ratio = next_ratio
if i < n_2xdown - 1:
model += [nn.ReLU(inplace=True)]
setattr(self, 'encoder_%d' % i, nn.Sequential(*model))
def forward(self, x):
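        # Collect multi-scale features; each new (more downsampled) feature is
        # prepended, so x[0] is always the coarsest map and x[-1] the head output.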
x = [self.head(x)]
for i in range(self.n_2xdown):
net = getattr(self, 'encoder_%d' % i)
x = [net(x[0])] + x
return x
def lip2d(x, logit, kernel=3, stride=2, padding=1):
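    '''
    Local Importance-based Pooling: a weighted average pooling where each
    location contributes proportionally to exp(logit), i.e.
        out = avg_pool(x * exp(logit)) / avg_pool(exp(logit))
    '''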
weight = logit.exp()
return F.avg_pool2d(x*weight, kernel, stride, padding) / F.avg_pool2d(weight, kernel, stride, padding)
class SoftGate(nn.Module):
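    # Scaled sigmoid gate following the LIP reference implementation: bounds the
    # pooling logits to (0, COEFF) so that exp(logit) stays numerically stable.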
COEFF = 12.0
def __init__(self):
super(SoftGate, self).__init__()
def forward(self, x):
return torch.sigmoid(x).mul(self.COEFF)
class SimplifiedLIP(nn.Module):
def __init__(self, channels):
super(SimplifiedLIP, self).__init__()
rp = channels
self.logit = nn.Sequential(
nn.Conv2d(channels, channels, 3, padding=1, bias=False),
nn.InstanceNorm2d(channels, affine=True),
SoftGate()
)
'''
OrderedDict((
('conv', nn.Conv2d(channels, channels, 3, padding=1, bias=False)),
('bn', nn.InstanceNorm2d(channels, affine=True)),
('gate', SoftGate()),
))
'''
def init_layer(self):
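        # Zero conv weights make the logit map constant, so lip2d starts out
        # as plain average pooling before the gating is learned.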
self.logit[0].weight.data.fill_(0.0)
def forward(self, x):
frac = lip2d(x, self.logit(x))
return frac
class LIPEncoder(BaseNetwork):
def __init__(self, opt, sw, sh, n_2xdown,
norm_layer=nn.InstanceNorm2d):
super().__init__()
self.sw = sw
self.sh = sh
self.max_ratio = 16
ngf = opt.ngf
kw = 3
pw = (kw - 1) // 2
model = [
nn.Conv2d(opt.semantic_nc, ngf, kw, stride=1, padding=pw, bias=False),
norm_layer(ngf),
nn.ReLU(),
]
cur_ratio = 1
for i in range(n_2xdown):
next_ratio = min(cur_ratio*2, self.max_ratio)
model += [
SimplifiedLIP(ngf*cur_ratio),
nn.Conv2d(ngf*cur_ratio, ngf*next_ratio, kw, stride=1, padding=pw),
norm_layer(ngf*next_ratio),
]
cur_ratio = next_ratio
if i < n_2xdown - 1:
model += [nn.ReLU(inplace=True)]
self.model = nn.Sequential(*model)
def forward(self, x):
return self.model(x)
class LIPSPADEGenerator(SPADEGenerator):
    '''
    20200309: SPADEGenerator with a learnable feature encoder.
    Encoder design: Local Importance-based Pooling (Ziteng Gao et al., ICCV 2019)
    '''
def __init__(self, opt):
super().__init__(opt)
self.lip_encoder = LIPEncoder(opt, self.sw, self.sh, self.scale_ratio)
def encode(self, x):
return self.lip_encoder(x)
class NoiseClassPredictor(nn.Module):
    '''
    Input:  an nc x sw x sw tensor, either from clean or corrupted images
    Output: an n-dim tensor indicating the noise type (or intensity?)
    '''
def __init__(self, opt, sw, nc, outdim):
super().__init__()
nbottleneck = 256
middim = 256
conv = [
nn.Conv2d(nc, nbottleneck, 1, stride=1),
nn.InstanceNorm2d(nbottleneck),
nn.LeakyReLU(0.2, inplace=True),
]
while sw > 4:
sw = sw // 2
conv += [
nn.Conv2d(nbottleneck, nbottleneck, 3, stride=2, padding=1),
nn.InstanceNorm2d(nbottleneck),
nn.LeakyReLU(0.2, inplace=True),
]
self.conv = nn.Sequential(*conv)
indim = sw * sw * nbottleneck
self.fc = nn.Sequential(
nn.Linear(indim, middim),
nn.BatchNorm1d(middim),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(middim, outdim),
)
def forward(self, x):
x = self.conv(x)
x = x.view(x.shape[0],-1)
return self.fc(x)
class NoiseIntensityPredictor(nn.Module):
    '''
    Input:  an nc x sw x sw tensor, either from clean or corrupted images
    Output: a 1-dim tensor indicating the noise intensity
    '''
def __init__(self, opt, sw, nc, outdim):
super().__init__()
nbottleneck = 256
middim = 256
conv = [
nn.Conv2d(nc, nbottleneck, 1, stride=1),
nn.BatchNorm2d(nbottleneck),
nn.LeakyReLU(0.2, inplace=True),
]
while sw > 4:
sw = sw // 2
conv += [
nn.Conv2d(nbottleneck, nbottleneck, 3, stride=2, padding=1),
nn.BatchNorm2d(nbottleneck),
nn.LeakyReLU(0.2, inplace=True),
]
self.conv = nn.Sequential(*conv)
indim = sw * sw * nbottleneck
self.fc = nn.Sequential(
nn.Linear(indim, middim),
nn.BatchNorm1d(middim),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(middim, outdim),
)
def forward(self, x):
x = self.conv(x)
x = x.view(x.shape[0],-1)
x = self.fc(x)
return x.squeeze()
class SubAddGenerator(SPADEGenerator):
    '''
    20200311:
        This generator contains a complete self-supervised
        training scheme that requires a separate dataloader.
        The self-supervised pre-training is implemented
        behind a clean interface:
            SubAddGenerator::train_E(self, epochs)
        For the upper-level Pix2Pix_Model,
        two things remain to be done:
        A) Run the pre-training script.
        B) Save the encoder and adjust the learning rate.
    -----------------------------------------
    20200312:
        Pre-test problem: the discriminator struggles to separate real from fake.
        Also, using the residual of feature maps doesn't work...
        Cause: the feature maps are too close to be properly separated.
        Try testing on a single reduction: an arbitrary downsampling ratio,
        and try to estimate the reduction ratio?
    '''
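    # Minimal usage sketch (hypothetical driver code, not part of this file):
    #   opt = TrainOptions().parse()   # must define dataroot_assist, noise_dim, etc.
    #   netG = SubAddGenerator(opt).cuda()
    #   netG.train_E(epochs=10)        # Stage I: self-supervised encoder pre-training
    # Afterwards the cached 'pretrained_net_E_<epochs>.pth' checkpoint can be
    # reloaded by the upper-level Pix2Pix model before joint training.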
def __init__(self, opt):
super().__init__(opt)
self.encoder = LIPEncoder(opt, self.sw, self.sh, self.scale_ratio)
self.dis_nc = self.opt.ngf * min(16, 2**self.scale_ratio)
self.discriminator = NoiseIntensityPredictor(opt, self.sw, self.dis_nc, 1)
if opt.isTrain:
self.attach_dataloader(opt)
self.noise_dim = opt.noise_dim
self.l1_loss = nn.L1Loss()
self.gan_loss = nn.MSELoss()
beta1, beta2 = opt.beta1, opt.beta2
if opt.no_TTUR:
G_lr, D_lr = opt.lr, opt.lr
else:
G_lr, D_lr = opt.lr / 2, opt.lr * 2
self.optimizer_E = torch.optim.Adam(
self.encoder.parameters(), lr=G_lr, betas=(beta1, beta2)
)
self.optimizer_D = torch.optim.Adam(
self.discriminator.parameters(), lr=D_lr/2, betas=(beta1, beta2)
)
def _create_auxiliary_opt(self, opt):
'''
Create auxiliary options
change necessary params
--- dataroot
--- dataset_mode
'''
aux_opt = copy.deepcopy(opt)
aux_opt.dataroot = opt.dataroot_assist
aux_opt.dataset_mode = 'assist'
aux_opt.batchSize = 4
aux_opt.nThreads = 4
return aux_opt
def attach_dataloader(self, opt):
aux_opt = self._create_auxiliary_opt(opt)
self.loader = data.create_dataloader(aux_opt)
def encode(self, x):
return self.encoder(x)
def process_input(self, data):
self.clean = data['clean'].cuda()
self.noisy = data['noisy'].cuda()
self.noise_label = data['label'].cuda()
self.clean_label = torch.ones_like(self.noise_label)
def update_E(self):
bundle_in = torch.cat((self.clean, self.noisy), dim=0)
bundle_out = self.encode(bundle_in)
nb = bundle_in.shape[0] // 2
F_real, F_fake = bundle_out[:nb], bundle_out[nb:]
pred_fake = self.discriminator(F_fake)
loss_l1 = self.l1_loss(F_fake, F_real)
loss_gan = self.gan_loss(pred_fake, self.clean_label)
loss_sum = loss_l1 * 10 + loss_gan
self.optimizer_E.zero_grad()
loss_sum.backward()
self.optimizer_E.step()
self.loss_l1 = loss_l1.item()
self.loss_gan_E = loss_gan.item()
self.loss_sum = loss_sum.item()
def update_D(self):
with torch.no_grad():
bundle_in = torch.cat((self.clean, self.noisy), dim=0)
bundle_out = self.encode(bundle_in)
nb = bundle_in.shape[0] // 2
F_fake = bundle_out[nb:] / (bundle_out[:nb] + 1e-6)
F_real = torch.ones_like(F_fake, requires_grad=False)
pred_real = self.discriminator(F_real)
loss_real = self.gan_loss(pred_real, self.clean_label)
pred_fake = self.discriminator(F_fake.detach())
loss_fake = self.gan_loss(pred_fake, self.noise_label)
loss_sum = (loss_real + loss_fake * self.opt.noise_dim) / 2
self.optimizer_D.zero_grad()
loss_sum.backward()
self.optimizer_D.step()
self.loss_gan_D_real = loss_real.item()
self.loss_gan_D_fake = loss_fake.item()
def debug_D(self):
with torch.no_grad():
bundle_in = torch.cat((self.clean, self.noisy), dim=0)
bundle_out = self.encode(bundle_in)
nb = bundle_in.shape[0] // 2
F_real, F_fake = bundle_out[:nb], bundle_out[nb:]
F_res = F_fake - F_real
pred_real = self.discriminator(F_real)
pred_fake = self.discriminator(F_res.detach())
print(pred_real, pred_fake)
real_acc = (pred_real == 0).sum().item() / pred_real.shape[0]
fake_acc = (pred_fake == self.noise_label).sum().item() / pred_fake.shape[0]
print(real_acc, fake_acc)
def log(self, epoch, i):
logstring = ' Epoch [%d] iter [%d]: ' % (epoch, i)
logstring += 'l1: %.4f ' % self.loss_l1
logstring += 'gen: %.4f ' % self.loss_gan_E
logstring += 'E_sum: %.4f ' % self.loss_sum
logstring += 'dis_real: %.4f ' % self.loss_gan_D_real
logstring += 'dis_fake: %.4f' % self.loss_gan_D_fake
print(logstring)
def train_E(self, epochs):
pretrained_ckpt_dir = os.path.join(
self.opt.checkpoints_dir, self.opt.name,
'pretrained_net_E_%d.pth' % epochs
)
print(pretrained_ckpt_dir)
print('======= Stage I: Subtraction =======')
if os.path.isfile(pretrained_ckpt_dir):
state_dict_E = torch.load(pretrained_ckpt_dir)
self.encoder.load_state_dict(state_dict_E)
print('======= Load cached checkpoints %s' % pretrained_ckpt_dir)
else:
print('======= total epochs: %d ' % epochs)
for epoch in range(1,epochs+1):
for i, data in enumerate(self.loader):
self.process_input(data)
self.update_E()
self.update_D()
if i % 10 == 0:
self.log(epoch, i)
print('Epoch [%d] finished' % epoch)
torch.save(self.encoder.state_dict(), os.path.join(
self.opt.checkpoints_dir, self.opt.name, 'pretrained_net_E_%d.pth' % epoch
))
class ContrasiveGenerator(SPADEGenerator):
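    '''
    Simplified variant of SubAddGenerator: pre-trains the LIP encoder with a
    plain L1 feature-matching loss between clean and degraded inputs, without
    the intensity predictor or adversarial term.
    '''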
def __init__(self, opt):
super().__init__(opt)
self.encoder = LIPEncoder(opt, self.sw, self.sh, self.scale_ratio)
if opt.isTrain:
self.attach_dataloader(opt)
self.noise_dim = opt.noise_dim
self.l1_loss = nn.L1Loss()
beta1, beta2 = opt.beta1, opt.beta2
self.optimizer_E = torch.optim.Adam(
self.encoder.parameters(), lr=opt.lr, betas=(beta1, beta2)
)
def _create_auxiliary_opt(self, opt):
'''
Create auxiliary options
change necessary params
--- dataroot
--- dataset_mode
'''
aux_opt = copy.deepcopy(opt)
aux_opt.dataroot = opt.dataroot_assist
aux_opt.dataset_mode = 'assist'
aux_opt.batchSize = 8
aux_opt.nThreads = 4
return aux_opt
def attach_dataloader(self, opt):
aux_opt = self._create_auxiliary_opt(opt)
self.loader = data.create_dataloader(aux_opt)
def encode(self, x):
return self.encoder(x)
def process_input(self, data):
self.clean = data['clean'].cuda()
self.noisy = data['noisy'].cuda()
def update_E(self):
bundle_in = torch.cat((self.clean, self.noisy), dim=0)
bundle_out = self.encode(bundle_in)
nb = bundle_in.shape[0] // 2
F_real, F_fake = bundle_out[:nb], bundle_out[nb:]
loss_l1 = self.l1_loss(F_fake, F_real)
self.optimizer_E.zero_grad()
loss_l1.backward()
self.optimizer_E.step()
self.loss_l1 = loss_l1.item()
def log(self, epoch, i):
logstring = ' Epoch [%d] iter [%d]: ' % (epoch, i)
logstring += 'l1: %.4f ' % self.loss_l1
print(logstring)
def train_E(self, epochs):
pretrained_ckpt_dir = os.path.join(
self.opt.checkpoints_dir, self.opt.name,
'pretrained_net_E_%d.pth' % epochs
)
print(pretrained_ckpt_dir)
print('======= Stage I: Subtraction =======')
if os.path.isfile(pretrained_ckpt_dir):
state_dict_E = torch.load(pretrained_ckpt_dir)
self.encoder.load_state_dict(state_dict_E)
print('======= Load cached checkpoints %s' % pretrained_ckpt_dir)
else:
print('======= total epochs: %d ' % epochs)
for epoch in range(1,epochs+1):
for i, data in enumerate(self.loader):
self.process_input(data)
self.update_E()
if i % 10 == 0:
self.log(epoch, i)
print('Epoch [%d] finished' % epoch)
torch.save(self.encoder.state_dict(), os.path.join(
self.opt.checkpoints_dir, self.opt.name, 'pretrained_net_E_%d.pth' % epoch
))