CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
ai-forever

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: ai-forever/sber-swap
Path: blob/main/network/AADLayer.py
Views: 792
1
import torch
2
import torch.nn as nn
3
4
5
class AADLayer(nn.Module):
    """Adaptive Attentional Denormalization (AAD) layer.

    Blends an attribute-conditioned affine modulation and an
    identity-conditioned affine modulation of an instance-normalized
    activation, mixed per-pixel by a learned attention mask M.
    """

    def __init__(self, c_x, attr_c, c_id):
        """
        Args:
            c_x: channel count of the activation being modulated.
            attr_c: channel count of the attribute feature map z_attr.
            c_id: dimensionality of the identity embedding z_id.
        """
        super().__init__()
        self.attr_c = attr_c
        self.c_id = c_id
        self.c_x = c_x

        # 1x1 convs: per-pixel scale/shift derived from the attribute map.
        self.conv1 = nn.Conv2d(attr_c, c_x, kernel_size=1, stride=1, padding=0, bias=True)
        self.conv2 = nn.Conv2d(attr_c, c_x, kernel_size=1, stride=1, padding=0, bias=True)
        # Linear layers: per-channel scale/shift derived from the identity vector.
        self.fc1 = nn.Linear(c_id, c_x)
        self.fc2 = nn.Linear(c_id, c_x)
        self.norm = nn.InstanceNorm2d(c_x, affine=False)

        # Single-channel conv producing the attention-mask logits.
        self.conv_h = nn.Conv2d(c_x, 1, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, h_in, z_attr, z_id):
        """
        Args:
            h_in: activation tensor, (B, c_x, H, W).
            z_attr: attribute feature map, (B, attr_c, H, W) — must match
                h_in spatially so the modulations broadcast elementwise.
            z_id: identity embedding, (B, c_id); nn.Linear requires the
                flattened 2-D form here (not the paper's 256x1x1 layout).

        Returns:
            Modulated tensor of shape (B, c_x, H, W).
        """
        h = self.norm(h_in)

        # Attribute branch: spatially varying affine modulation.
        gamma_attr = self.conv1(z_attr)
        beta_attr = self.conv2(z_attr)
        A = gamma_attr * h + beta_attr

        # Identity branch: per-channel affine modulation broadcast over H, W.
        gamma_id = self.fc1(z_id).reshape(h.shape[0], self.c_x, 1, 1).expand_as(h)
        beta_id = self.fc2(z_id).reshape(h.shape[0], self.c_x, 1, 1).expand_as(h)
        I = gamma_id * h + beta_id

        # Attention mask in (0, 1): M weights the identity branch,
        # (1 - M) the attribute branch, per pixel.
        M = torch.sigmoid(self.conv_h(h))

        # `1 - M` replaces the original `torch.ones_like(M).to(M.device) - M`:
        # scalar broadcasting avoids an extra full-size allocation and a
        # no-op device transfer (ones_like already lives on M's device).
        out = (1 - M) * A + M * I
        return out
39
40
41
class AddBlocksSequential(nn.Sequential):
    """Sequential container for repeated (AADLayer, ReLU, Conv2d) triples.

    Modules at indices 0, 3, 6, ... are called with the triple
    (h, z_attr, z_id); every other module receives a single tensor.
    The conditioning inputs z_attr and z_id are re-fed to each triple,
    while the activation flows through the whole chain.
    """

    def forward(self, *inputs):
        # First element of `inputs` is the activation; it is consumed via
        # the tuple dispatch below, so only the conditioning inputs are
        # bound here (the original bound an unused local `h`).
        _, z_attr, z_id = inputs
        for i, module in enumerate(self._modules.values()):
            # Start of a new triple (i == 0 already holds the full tuple):
            # re-bundle the previous output with the conditioning inputs.
            if i % 3 == 0 and i > 0:
                inputs = (inputs, z_attr, z_id)
            # isinstance() instead of `type(inputs) == tuple` — the
            # idiomatic (and subclass-safe) type check.
            if isinstance(inputs, tuple):
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs
52
53
54
class AAD_ResBlk(nn.Module):
    """Residual block built from AAD layers.

    The main path applies `num_blocks` (AADLayer, ReLU, Conv2d) triples;
    when cin != cout the skip path is projected to cout through one extra
    triple before the residual sum.
    """

    def __init__(self, cin, cout, c_attr, c_id, num_blocks):
        super().__init__()
        self.cin = cin
        self.cout = cout

        # Every triple keeps cin channels except the last, which maps to cout.
        last = num_blocks - 1
        blocks = []
        for idx in range(num_blocks):
            ch_out = cout if idx == last else cin
            blocks += [
                AADLayer(cin, c_attr, c_id),
                nn.ReLU(inplace=True),
                nn.Conv2d(cin, ch_out, kernel_size=3, stride=1, padding=1, bias=False),
            ]
        self.add_blocks = AddBlocksSequential(*blocks)

        # Projection triple for the skip connection when channel counts differ
        # (only created — and only used in forward — in that case).
        if cin != cout:
            self.last_add_block = AddBlocksSequential(
                AADLayer(cin, c_attr, c_id),
                nn.ReLU(inplace=True),
                nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1, bias=False),
            )

    def forward(self, h, z_attr, z_id):
        """Return main path plus (possibly projected) skip path."""
        x = self.add_blocks(h, z_attr, z_id)
        if self.cin != self.cout:
            h = self.last_add_block(h, z_attr, z_id)
        return x + h
82
83
84
85