Path: blob/main/models/pix2pix_model.py
"""1Copyright (C) 2019 NVIDIA Corporation. All rights reserved.2Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).3"""45import torch6import models.networks as networks7import utils.inference.util as util8import random91011class Pix2PixModel(torch.nn.Module):12@staticmethod13def modify_commandline_options(parser, is_train):14networks.modify_commandline_options(parser, is_train)15return parser1617def __init__(self, opt):18super().__init__()19self.opt = opt20self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() \21else torch.FloatTensor22self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() \23else torch.ByteTensor2425self.netG, self.netD, self.netE = self.initialize_networks(opt)2627# set loss functions28if opt.isTrain:29self.criterionGAN = networks.GANLoss(30opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)31self.criterionFeat = torch.nn.L1Loss()32if not opt.no_vgg_loss:33self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)34if opt.use_vae:35self.KLDLoss = networks.KLDLoss()3637# Entry point for all calls involving forward pass38# of deep networks. We used this approach since DataParallel module39# can't parallelize custom functions, we branch to different40# routines based on |mode|.41def forward(self, data, mode):42input_semantics, real_image = self.preprocess_input(data)43# input_semantics, real_image = data['label'], data['image']4445if mode == 'generator':46g_loss, generated = self.compute_generator_loss(input_semantics, real_image)47return g_loss, generated48elif mode == 'discriminator':49d_loss = self.compute_discriminator_loss(50input_semantics, real_image)51return d_loss52elif mode == 'inference':53with torch.no_grad():54fake_image = self.generate_fake(input_semantics)55return fake_image56elif mode == 'inference2':57with torch.no_grad():58fake_image = self.netG(input_semantics)59return fake_image60else:61raise ValueError("|mode| is invalid")6263def preprocess_input(self, data):64if self.use_gpu():65data['label'] = data['label'].cuda()66data['image'] = data['image'].cuda()6768return data['label'], data['image']6970def compute_generator_loss(self, input_semantics, real_image):71G_losses = {}7273fake_image = self.generate_fake(input_semantics)7475pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image)7677G_losses['GAN'] = self.criterionGAN(pred_fake, True,78for_discriminator=False)7980if not self.opt.no_ganFeat_loss:81num_D = len(pred_fake)82GAN_Feat_loss = self.FloatTensor(1).fill_(0)83for i in range(num_D): # for each discriminator84# last output is the final prediction, so we exclude it85num_intermediate_outputs = len(pred_fake[i]) - 186for j in range(num_intermediate_outputs): # for each layer output87unweighted_loss = self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach())88GAN_Feat_loss += unweighted_loss * self.opt.lambda_feat / num_D89G_losses['GAN_Feat'] = GAN_Feat_loss9091h ,w = fake_image.shape[-2:]92if not self.opt.no_vgg_loss and min(w,h)>=64:93G_losses['VGG'] = self.criterionVGG(fake_image, real_image) \94* self.opt.lambda_vgg9596return G_losses, fake_image9798def compute_discriminator_loss(self, input_semantics, real_image):99D_losses = {}100with torch.no_grad():101fake_image = self.generate_fake(input_semantics)102fake_image = fake_image.detach()103fake_image.requires_grad_()104105pred_fake, pred_real = self.discriminate(106input_semantics, fake_image, real_image)107108D_losses['D_Fake'] = self.criterionGAN(pred_fake, 
False,109for_discriminator=True)110D_losses['D_real'] = self.criterionGAN(pred_real, True,111for_discriminator=True)112113return D_losses114115def generate_fake(self, input_semantics):116# input_semantics = torch.nn.functional.interpolate(input_semantics, size=(h//4, w//4),117# mode='nearest')#[:, :, ::4, ::4]118119fake_image = self.netG(input_semantics)120121return fake_image122123def discriminate(self, input_semantics, fake_image, real_image):124h, w = fake_image.shape[-2:]125if fake_image.shape[-2:]!=input_semantics.shape[-2:]:126semantics = torch.nn.functional.interpolate(input_semantics, (h, w))127real = torch.nn.functional.interpolate(real_image, (h, w))128fake_concat = torch.cat([semantics, fake_image], dim=1)129real_concat = torch.cat([semantics, real], dim=1)130else:131fake_concat = torch.cat([input_semantics, fake_image], dim=1)132real_concat = torch.cat([input_semantics, real_image], dim=1)133# fake_concat = fake_image134# real_concat = real_image135136# In Batch Normalization, the fake and real images are137# recommended to be in the same batch to avoid disparate138# statistics in fake and real images.139# So both fake and real images are fed to D all at once.140fake_and_real = torch.cat([fake_concat, real_concat], dim=0)141142discriminator_out = self.netD(fake_and_real)143144pred_fake, pred_real = self.divide_pred(discriminator_out)145146return pred_fake, pred_real147148def encode_z(self, real_image):149mu, logvar = self.netE(real_image)150z = self.reparameterize(mu, logvar)151return z, mu, logvar152153def create_optimizers(self, opt):154G_params = list(self.netG.parameters())155if opt.use_vae:156G_params += list(self.netE.parameters())157if opt.isTrain:158D_params = list(self.netD.parameters())159160beta1, beta2 = opt.beta1, opt.beta2161if opt.no_TTUR:162G_lr, D_lr = opt.lr, opt.lr163else:164G_lr, D_lr = opt.lr / 2, opt.lr * 2165166optimizer_G = torch.optim.Adam(G_params, lr=G_lr, betas=(beta1, beta2))167optimizer_D = torch.optim.Adam(D_params, lr=D_lr, betas=(beta1, beta2))168169return optimizer_G, optimizer_D170171def save(self, epoch):172util.save_network(self.netG, 'G', epoch, self.opt)173util.save_network(self.netD, 'D', epoch, self.opt)174if self.opt.use_vae:175util.save_network(self.netE, 'E', epoch, self.opt)176177############################################################################178# Private helper methods179############################################################################180181def initialize_networks(self, opt):182netG = networks.define_G(opt)183netD = networks.define_D(opt) if opt.isTrain else None184netE = networks.define_E(opt) if opt.use_vae else None185186if not opt.isTrain or opt.continue_train:187netG = util.load_network(netG, 'G', opt.which_epoch, opt)188if opt.isTrain:189netD = util.load_network(netD, 'D', opt.which_epoch, opt)190if opt.use_vae:191netE = util.load_network(netE, 'E', opt.which_epoch, opt)192193return netG, netD, netE194195# preprocess the input, such as moving the tensors to GPUs and196# transforming the label map to one-hot encoding197# |data|: dictionary of the input data198199200201# Take the prediction of fake and real images from the combined batch202def divide_pred(self, pred):203# the prediction contains the intermediate outputs of multiscale GAN,204# so it's usually a list205if type(pred) == list:206fake = []207real = []208for p in pred:209fake.append([tensor[:tensor.size(0) // 2] for tensor in p])210real.append([tensor[tensor.size(0) // 2:] for tensor in p])211else:212fake = pred[:pred.size(0) // 2]213real = 
pred[pred.size(0) // 2:]214215return fake, real216217def get_edges(self, t):218edge = self.ByteTensor(t.size()).zero_()219edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])220edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])221edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])222edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])223return edge.float()224225def reparameterize(self, mu, logvar):226std = torch.exp(0.5 * logvar)227eps = torch.randn_like(std)228return eps.mul(std) + mu229230def use_gpu(self):231return len(self.opt.gpu_ids) > 0232233234
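
The class above is driven entirely through the |mode| argument of forward(). As a minimal sketch of how a training script typically exercises it (an illustration, not part of the file: `opt` is assumed to be the project's parsed options object, and `dataloader` a torch.utils.data.DataLoader yielding dicts with 'label' and 'image' tensors):

# Hypothetical usage sketch; `opt` and `dataloader` are assumed to come
# from the surrounding project, not from this file.
model = Pix2PixModel(opt)
optimizer_G, optimizer_D = model.create_optimizers(opt)

for data in dataloader:
    # generator update: forward() returns a loss dict plus the generated image
    g_losses, generated = model(data, mode='generator')
    g_loss = sum(g_losses.values()).mean()
    optimizer_G.zero_grad()
    g_loss.backward()
    optimizer_G.step()

    # discriminator update: forward() returns only the loss dict
    d_losses = model(data, mode='discriminator')
    d_loss = sum(d_losses.values()).mean()
    optimizer_D.zero_grad()
    d_loss.backward()
    optimizer_D.step()

# inference: gradients are already disabled inside forward()
fake_image = model(data, mode='inference')

Summing the loss dict before backward() is how the outputs of compute_generator_loss and compute_discriminator_loss are meant to be consumed; the per-key split exists so that individual terms (GAN, GAN_Feat, VGG, D_Fake, D_real) can be weighted or logged separately.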