CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
ai-forever

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: ai-forever/sber-swap
Path: blob/main/network/AEI_Net.py
Views: 792
1
import torch
2
import torch.nn as nn
3
import torch.nn.functional as F
4
from .AADLayer import *
5
from network.resnet import MLAttrEncoderResnet
6
7
8
def weight_init(m):
    """Initialise the parameters of a single module in place.

    Meant to be passed to ``nn.Module.apply`` so that it visits every submodule.

    - ``nn.Linear``: weights drawn from a normal distribution with mean 0 and
      std 0.001; the bias, when the layer has one, is set to zero.
    - ``nn.Conv2d`` / ``nn.ConvTranspose2d``: Xavier-normal weights; any bias
      is left at PyTorch's default initialisation.
    - Any other module type is left untouched.
    """
    if isinstance(m, nn.Linear):
        m.weight.data.normal_(0, 0.001)
        # A Linear layer built with bias=False has m.bias == None.
        if m.bias is not None:
            m.bias.data.zero_()
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.xavier_normal_(m.weight.data)
18
19
def conv4x4(in_c, out_c, norm=nn.BatchNorm2d):
    """Build one downsampling stage: 4x4 conv (stride 2, pad 1) -> norm -> LeakyReLU(0.1).

    With kernel 4, stride 2 and padding 1, an even spatial size H becomes H/2.
    The conv has no bias. ``norm`` is called with ``out_c`` to build the
    normalisation layer.
    """
    stages = [
        nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=4, stride=2, padding=1, bias=False),
        norm(out_c),
        nn.LeakyReLU(0.1, inplace=True),
    ]
    return nn.Sequential(*stages)
25
26
27
class deconv4x4(nn.Module):
    """One upsampling stage followed by a skip-connection merge.

    The stage is a 4x4 transposed conv (stride 2, pad 1, no bias), then norm,
    then LeakyReLU(0.1); it doubles the spatial size. The result is merged with
    ``skip``: by elementwise addition when ``backbone == 'linknet'``, otherwise
    by concatenation along the channel dimension.
    """

    def __init__(self, in_c, out_c, norm=nn.BatchNorm2d):
        super().__init__()
        # Attribute names determine the parameter names in the state_dict; keep them.
        self.deconv = nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn = norm(out_c)
        self.lrelu = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, input, skip, backbone):
        upsampled = self.lrelu(self.bn(self.deconv(input)))
        if backbone != 'linknet':
            # U-Net style: stack decoder and encoder features channel-wise.
            return torch.cat((upsampled, skip), dim=1)
        # LinkNet style: residual-like addition (channel counts must match).
        return upsampled + skip
42
43
44
class MLAttrEncoder(nn.Module):
    """Multi-level attribute encoder with a U-Net or LinkNet style decoder.

    Seven ``conv4x4`` stages each halve the spatial size. Six ``deconv4x4``
    stages then upsample back, each merging with the encoder feature of
    matching resolution. A final bilinear 2x upsample returns to the input
    resolution.

    Args:
        backbone: 'unet' concatenates skip features, so every decoder stage
            after the first receives twice the channels. 'linknet' adds them.

    Raises:
        ValueError: if ``backbone`` is neither 'unet' nor 'linknet'.
    """

    def __init__(self, backbone):
        super(MLAttrEncoder, self).__init__()
        # Fail fast: for any other value no decoder layers would be created and
        # forward() would only fail later with an AttributeError.
        if backbone not in ('unet', 'linknet'):
            raise ValueError(
                f"MLAttrEncoder: unsupported backbone {backbone!r}; expected 'unet' or 'linknet'"
            )
        self.backbone = backbone

        self.conv1 = conv4x4(3, 32)
        self.conv2 = conv4x4(32, 64)
        self.conv3 = conv4x4(64, 128)
        self.conv4 = conv4x4(128, 256)
        self.conv5 = conv4x4(256, 512)
        self.conv6 = conv4x4(512, 1024)
        self.conv7 = conv4x4(1024, 1024)

        # Concatenated ('unet') skips double the channels entering deconv2..deconv6.
        k = 2 if backbone == 'unet' else 1
        self.deconv1 = deconv4x4(1024, 1024)
        self.deconv2 = deconv4x4(1024 * k, 512)
        self.deconv3 = deconv4x4(512 * k, 256)
        self.deconv4 = deconv4x4(256 * k, 128)
        self.deconv5 = deconv4x4(128 * k, 64)
        self.deconv6 = deconv4x4(64 * k, 32)

        self.apply(weight_init)

    def forward(self, Xt):
        """Return the 8 multi-level attribute maps (z_attr1, ..., z_attr8) for ``Xt``.

        The shape comments assume a 3x256x256 input (batch dim omitted). For
        'unet', z_attr2..z_attr8 have twice the channels of the 'linknet' case.
        """
        feat1 = self.conv1(Xt)
        # 32x128x128
        feat2 = self.conv2(feat1)
        # 64x64x64
        feat3 = self.conv3(feat2)
        # 128x32x32
        feat4 = self.conv4(feat3)
        # 256x16x16
        feat5 = self.conv5(feat4)
        # 512x8x8
        feat6 = self.conv6(feat5)
        # 1024x4x4
        z_attr1 = self.conv7(feat6)
        # 1024x2x2

        z_attr2 = self.deconv1(z_attr1, feat6, self.backbone)
        z_attr3 = self.deconv2(z_attr2, feat5, self.backbone)
        z_attr4 = self.deconv3(z_attr3, feat4, self.backbone)
        z_attr5 = self.deconv4(z_attr4, feat3, self.backbone)
        z_attr6 = self.deconv5(z_attr5, feat2, self.backbone)
        z_attr7 = self.deconv6(z_attr6, feat1, self.backbone)
        # No skip exists at full resolution, so the last level is just upsampled.
        z_attr8 = F.interpolate(z_attr7, scale_factor=2, mode='bilinear', align_corners=True)
        return z_attr1, z_attr2, z_attr3, z_attr4, z_attr5, z_attr6, z_attr7, z_attr8
96
97
98
class AADGenerator(nn.Module):
    """Generator that decodes an identity embedding through eight AAD_ResBlk stages.

    ``z_id`` is reshaped to ``(batch, -1, 1, 1)`` and lifted by ``up1`` (a 2x2
    transposed conv) to a 1024-channel 2x2 map. Stages AADBlk1..AADBlk7 each
    run and are then upsampled 2x bilinearly. AADBlk8 produces the 3-channel
    output, which is passed through tanh. Stage i is conditioned on
    ``z_attr[i - 1]`` and ``z_id``.

    Args:
        backbone: 'linknet' selects the LinkNet attribute widths. Any other
            value selects the doubled widths that match the 'unet' encoder's
            concatenated skips. NOTE(review): 'resnet' also takes this second
            branch; confirm that it matches MLAttrEncoderResnet's outputs.
        c_id: identity channels; ``z_id`` must supply c_id values per sample.
        num_blocks: forwarded to every AAD_ResBlk.
    """

    def __init__(self, backbone, c_id=256, num_blocks=2):
        super().__init__()
        self.up1 = nn.ConvTranspose2d(c_id, 1024, kernel_size=2, stride=1, padding=0)
        self.AADBlk1 = AAD_ResBlk(1024, 1024, 1024, c_id, num_blocks)

        # Positional AAD_ResBlk channel args for AADBlk2..AADBlk8. The third value
        # equals the channel count of the matching z_attr entry from MLAttrEncoder.
        if backbone == 'linknet':
            stage_specs = (
                (1024, 1024, 1024),
                (1024, 1024, 512),
                (1024, 512, 256),
                (512, 256, 128),
                (256, 128, 64),
                (128, 64, 32),
                (64, 3, 32),
            )
        else:
            stage_specs = (
                (1024, 1024, 2048),
                (1024, 1024, 1024),
                (1024, 512, 512),
                (512, 256, 256),
                (256, 128, 128),
                (128, 64, 64),
                (64, 3, 64),
            )
        # setattr registers the submodules under the same names (AADBlk2..AADBlk8)
        # and in the same order, so state_dict keys are unchanged.
        for idx, (ch_a, ch_b, ch_c) in enumerate(stage_specs, start=2):
            setattr(self, f'AADBlk{idx}', AAD_ResBlk(ch_a, ch_b, ch_c, c_id, num_blocks))

        self.apply(weight_init)

    def forward(self, z_attr, z_id):
        feats = self.up1(z_id.reshape(z_id.shape[0], -1, 1, 1))
        upsampling_stages = (
            self.AADBlk1, self.AADBlk2, self.AADBlk3, self.AADBlk4,
            self.AADBlk5, self.AADBlk6, self.AADBlk7,
        )
        for level, stage in enumerate(upsampling_stages):
            feats = F.interpolate(stage(feats, z_attr[level], z_id),
                                  scale_factor=2, mode='bilinear', align_corners=True)
        return torch.tanh(self.AADBlk8(feats, z_attr[7], z_id))
133
134
135
136
class AEI_Net(nn.Module):
    """Attribute encoder combined with an AAD generator.

    Args:
        backbone: 'unet' or 'linknet' (uses MLAttrEncoder) or 'resnet' (uses
            MLAttrEncoderResnet). The value is also passed to AADGenerator.
        num_blocks: forwarded to AADGenerator.
        c_id: identity-embedding size forwarded to AADGenerator.

    Raises:
        ValueError: if ``backbone`` is not one of the supported values.
    """

    def __init__(self, backbone, num_blocks=2, c_id=256):
        super(AEI_Net, self).__init__()
        if backbone in ('unet', 'linknet'):
            self.encoder = MLAttrEncoder(backbone)
        elif backbone == 'resnet':
            self.encoder = MLAttrEncoderResnet()
        else:
            # Fail fast: without an encoder, forward()/get_attr() would only
            # fail later with an AttributeError.
            raise ValueError(
                f"AEI_Net: unsupported backbone {backbone!r}; "
                f"expected 'unet', 'linknet' or 'resnet'"
            )
        self.generator = AADGenerator(backbone, c_id, num_blocks)

    def forward(self, Xt, z_id):
        """Encode attributes of ``Xt`` and generate an output conditioned on ``z_id``.

        Returns:
            ``(Y, attr)``: the generator output and the encoder's attribute maps.
        """
        attr = self.encoder(Xt)
        Y = self.generator(attr, z_id)
        return Y, attr

    def get_attr(self, X):
        """Return only the encoder's attribute maps for ``X``."""
        return self.encoder(X)
152
153
154
155
156