Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
ai-forever

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: ai-forever/sber-swap
Path: blob/main/models/networks/architecture.py
Views: 813
1
"""
2
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4
"""
5
6
import torch
7
import torch.nn as nn
8
import torch.nn.functional as F
9
import torchvision
10
import torch.nn.utils.spectral_norm as spectral_norm
11
from models.networks.normalization import SPADE
12
13
14
# ResNet block that uses SPADE.
15
# It differs from the ResNet block of pix2pixHD in that
16
# it takes in the segmentation map as input, learns the skip connection if necessary,
17
# and applies normalization first and then convolution.
18
# This appears to be the standard residual-block design used in
19
# unconditional and class-conditional GANs.
20
# The code was inspired by https://github.com/LMescheder/GAN_stability.
21
class SPADEResnetBlock(nn.Module):
    """Residual block whose normalization layers are SPADE modules.

    Main branch: SPADE norm -> leaky ReLU (slope 0.2) -> 3x3 conv, applied
    twice (fin -> min(fin, fout) -> fout). Every SPADE layer also receives the
    conditioning map ``seg``. When ``fin != fout`` the skip path is a learned
    SPADE norm followed by a bias-free 1x1 conv; otherwise it is the identity.

    Args:
        fin: number of input channels.
        fout: number of output channels.
        opt: options object. ``opt.norm_G`` is the normalization config string.
            If it contains ``'spectral'``, every conv is wrapped with spectral
            norm and that substring is removed before the string is handed to
            SPADE. ``opt.semantic_nc`` is read only when ``semantic_nc`` is None.
        semantic_nc: channel count of ``seg``; overrides ``opt.semantic_nc``.
    """

    def __init__(self, fin, fout, opt, semantic_nc=None):
        super().__init__()
        self.learned_shortcut = fin != fout
        hidden_nc = min(fin, fout)
        cond_nc = opt.semantic_nc if semantic_nc is None else semantic_nc

        # Convolutions. Construction order is kept fixed because each
        # constructor draws from the RNG for its weight initialization.
        self.conv_0 = nn.Conv2d(fin, hidden_nc, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(hidden_nc, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # Optional spectral normalization of every conv created above.
        if 'spectral' in opt.norm_G:
            conv_names = ['conv_0', 'conv_1']
            if self.learned_shortcut:
                conv_names.append('conv_s')
            for conv_name in conv_names:
                setattr(self, conv_name, spectral_norm(getattr(self, conv_name)))

        # SPADE layers see the config string with the 'spectral' marker removed.
        spade_cfg = opt.norm_G.replace('spectral', '')
        self.norm_0 = SPADE(spade_cfg, fin, cond_nc)
        self.norm_1 = SPADE(spade_cfg, hidden_nc, cond_nc)
        if self.learned_shortcut:
            self.norm_s = SPADE(spade_cfg, fin, cond_nc)

    def forward(self, x, seg):
        """Return ``shortcut(x, seg)`` plus the two-stage residual branch."""
        skip = self.shortcut(x, seg)

        branch = self.norm_0(x, seg)
        branch = self.conv_0(self.actvn(branch))
        branch = self.norm_1(branch, seg)
        branch = self.conv_1(self.actvn(branch))

        return skip + branch

    def shortcut(self, x, seg):
        """Skip path: identity, or SPADE norm + 1x1 conv when channels change."""
        if not self.learned_shortcut:
            return x
        return self.conv_s(self.norm_s(x, seg))

    def actvn(self, x):
        """Leaky ReLU with negative slope 0.2."""
        return F.leaky_relu(x, negative_slope=0.2)
71
72
73
# ResNet block used in pix2pixHD
74
# We keep the same architecture as pix2pixHD.
75
class ResnetBlock(nn.Module):
    """Residual block from pix2pixHD: ``out = x + conv_block(x)``.

    The residual branch is reflection-pad -> conv -> activation ->
    reflection-pad -> conv. Each ``nn.Conv2d`` is passed through
    ``norm_layer``, and the value it returns is placed in the branch.

    Args:
        dim: number of input and output channels.
        norm_layer: callable applied to each ``nn.Conv2d``; it must return a
            module to place in the branch.
        activation: module applied between the two convolutions. When None
            (the default), a fresh ``nn.ReLU(False)`` is created for this
            block.
        kernel_size: convolution kernel size. Each side is padded by
            ``(kernel_size - 1) // 2``, so only odd sizes preserve the spatial
            size that the residual addition requires.
    """

    def __init__(self, dim, norm_layer, activation=None, kernel_size=3):
        super().__init__()
        # A module used as a default argument is created once, at definition
        # time, and would be shared (hooks and all) by every block built with
        # the default. Create one per block instead.
        if activation is None:
            activation = nn.ReLU(False)

        pw = (kernel_size - 1) // 2
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(pw),
            norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),
            activation,
            nn.ReflectionPad2d(pw),
            norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),
        )

    def forward(self, x):
        """Return ``x`` plus the residual branch applied to ``x``."""
        return x + self.conv_block(x)
92
93
94
# VGG19 feature extractor, used for the perceptual loss with a pretrained VGG network
95
class VGG19(torch.nn.Module):
    """Pretrained torchvision VGG19 ``features`` split into five slices.

    ``forward`` runs the slices in sequence and returns every intermediate
    output. With the bounds below and torchvision's VGG19 layout, these are the
    activations of relu1_1, relu2_1, relu3_1, relu4_1 and relu5_1
    (``features`` indices 1, 6, 11, 20, 29).

    Args:
        requires_grad: if False (default), every parameter is frozen.
    """

    # Half-open [start, stop) index ranges of vgg19().features for slice1..slice5.
    _SLICE_BOUNDS = ((0, 2), (2, 7), (7, 12), (12, 21), (21, 30))

    def __init__(self, requires_grad=False):
        super().__init__()
        vgg_pretrained_features = self._load_pretrained_vgg19().features
        # Registered in order slice1..slice5; the child names keep the original
        # feature indices, so state_dict keys are unchanged.
        for slice_idx, (start, stop) in enumerate(self._SLICE_BOUNDS, start=1):
            block = torch.nn.Sequential()
            for x in range(start, stop):
                block.add_module(str(x), vgg_pretrained_features[x])
            setattr(self, 'slice%d' % slice_idx, block)
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    @staticmethod
    def _load_pretrained_vgg19():
        """Build torchvision's VGG19 with its ImageNet weights.

        torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the
        ``weights`` enum (``IMAGENET1K_V1`` is the same checkpoint). The legacy
        flag is used only on older releases that lack ``VGG19_Weights``.
        """
        weights_enum = getattr(torchvision.models, 'VGG19_Weights', None)
        if weights_enum is not None:
            return torchvision.models.vgg19(weights=weights_enum.IMAGENET1K_V1)
        return torchvision.models.vgg19(pretrained=True)

    def forward(self, X):
        """Return the list of the five slice outputs for input ``X``."""
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out
126
127