Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Path: blob/main/models/networks/architecture.py
Views: 813
"""1Copyright (C) 2019 NVIDIA Corporation. All rights reserved.2Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).3"""45import torch6import torch.nn as nn7import torch.nn.functional as F8import torchvision9import torch.nn.utils.spectral_norm as spectral_norm10from models.networks.normalization import SPADE111213# ResNet block that uses SPADE.14# It differs from the ResNet block of pix2pixHD in that15# it takes in the segmentation map as input, learns the skip connection if necessary,16# and applies normalization first and then convolution.17# This architecture seemed like a standard architecture for unconditional or18# class-conditional GAN architecture using residual block.19# The code was inspired from https://github.com/LMescheder/GAN_stability.20class SPADEResnetBlock(nn.Module):21def __init__(self, fin, fout, opt, semantic_nc=None):22super().__init__()23# Attributes24self.learned_shortcut = (fin != fout)25fmiddle = min(fin, fout)26if semantic_nc is None:27semantic_nc = opt.semantic_nc2829# create conv layers30self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)31self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)32if self.learned_shortcut:33self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)3435# apply spectral norm if specified36if 'spectral' in opt.norm_G:37self.conv_0 = spectral_norm(self.conv_0)38self.conv_1 = spectral_norm(self.conv_1)39if self.learned_shortcut:40self.conv_s = spectral_norm(self.conv_s)4142# define normalization layers43spade_config_str = opt.norm_G.replace('spectral', '')44self.norm_0 = SPADE(spade_config_str, fin, semantic_nc)45self.norm_1 = SPADE(spade_config_str, fmiddle, semantic_nc)46if self.learned_shortcut:47self.norm_s = SPADE(spade_config_str, fin, semantic_nc)4849# note the resnet block with SPADE also takes in |seg|,50# the semantic segmentation map as input51def forward(self, x, seg):52x_s = self.shortcut(x, seg)5354dx = self.conv_0(self.actvn(self.norm_0(x, seg)))55dx = self.conv_1(self.actvn(self.norm_1(dx, seg)))5657out = x_s + dx5859return out6061def shortcut(self, x, seg):62if self.learned_shortcut:63x_s = self.conv_s(self.norm_s(x, seg))64else:65x_s = x66return x_s6768def actvn(self, x):69return F.leaky_relu(x, 2e-1)707172# ResNet block used in pix2pixHD73# We keep the same architecture as pix2pixHD.74class ResnetBlock(nn.Module):75def __init__(self, dim, norm_layer, activation=nn.ReLU(False), kernel_size=3):76super().__init__()7778pw = (kernel_size - 1) // 279self.conv_block = nn.Sequential(80nn.ReflectionPad2d(pw),81norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),82activation,83nn.ReflectionPad2d(pw),84norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size))85)8687def forward(self, x):88y = self.conv_block(x)89out = x + y90return out919293# VGG architecter, used for the perceptual loss using a pretrained VGG network94class VGG19(torch.nn.Module):95def __init__(self, requires_grad=False):96super().__init__()97vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features98self.slice1 = torch.nn.Sequential()99self.slice2 = torch.nn.Sequential()100self.slice3 = torch.nn.Sequential()101self.slice4 = torch.nn.Sequential()102self.slice5 = torch.nn.Sequential()103for x in range(2):104self.slice1.add_module(str(x), vgg_pretrained_features[x])105for x in range(2, 7):106self.slice2.add_module(str(x), vgg_pretrained_features[x])107for x in range(7, 12):108self.slice3.add_module(str(x), vgg_pretrained_features[x])109for x in range(12, 21):110self.slice4.add_module(str(x), vgg_pretrained_features[x])111for x in range(21, 30):112self.slice5.add_module(str(x), vgg_pretrained_features[x])113if not requires_grad:114for param in self.parameters():115param.requires_grad = False116117def forward(self, X):118h_relu1 = self.slice1(X)119h_relu2 = self.slice2(h_relu1)120h_relu3 = self.slice3(h_relu2)121h_relu4 = self.slice4(h_relu3)122h_relu5 = self.slice5(h_relu4)123out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]124return out125126127