Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. A commercial alternative to JupyterHub.
Path: blob/main/network/AADLayer.py
Views: 792
import torch
import torch.nn as nn


class AADLayer(nn.Module):
    """Adaptive Attentional Denormalization (AAD) layer.

    Instance-normalizes ``h_in`` (no affine parameters) and blends two
    denormalized versions of it:

    * ``A`` -- scaled/shifted by per-pixel ``gamma``/``beta`` computed from the
      attribute map ``z_attr`` with 1x1 convolutions;
    * ``I`` -- scaled/shifted by per-channel ``gamma``/``beta`` computed from
      the identity vector ``z_id`` with linear layers;

    using a single-channel sigmoid mask ``M`` predicted from the normalized
    activation: ``out = (1 - M) * A + M * I``.

    Args:
        c_x: number of channels of ``h_in`` and of the output.
        attr_c: number of channels of ``z_attr``.
        c_id: size of the last dimension of ``z_id``.
    """

    def __init__(self, c_x, attr_c, c_id):
        super().__init__()
        self.attr_c = attr_c
        self.c_id = c_id
        self.c_x = c_x

        # Attribute branch: per-pixel gamma / beta from z_attr.
        self.conv1 = nn.Conv2d(attr_c, c_x, kernel_size=1, stride=1, padding=0, bias=True)
        self.conv2 = nn.Conv2d(attr_c, c_x, kernel_size=1, stride=1, padding=0, bias=True)
        # Identity branch: per-channel gamma / beta from z_id.
        self.fc1 = nn.Linear(c_id, c_x)
        self.fc2 = nn.Linear(c_id, c_x)
        self.norm = nn.InstanceNorm2d(c_x, affine=False)

        # Predicts the one-channel blending mask M.
        self.conv_h = nn.Conv2d(c_x, 1, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, h_in, z_attr, z_id):
        """Return the AAD-blended activation (same shape as ``h_in``).

        Args:
            h_in: activation of shape (B, c_x, H, W).
            z_attr: attribute map with ``attr_c`` channels. The 1x1 convs keep
                its spatial size, so it must match (or broadcast against)
                ``h_in``'s H, W.
            z_id: identity embedding whose last dimension is ``c_id`` and
                which holds one vector per batch element, e.g. (B, c_id).
                ``nn.Linear`` acts on the last dimension, so a
                (B, c_id, 1, 1) tensor is NOT accepted.
        """
        h = self.norm(h_in)

        gamma_attr = self.conv1(z_attr)
        beta_attr = self.conv2(z_attr)
        A = gamma_attr * h + beta_attr

        # (B, c_x) -> (B, c_x, 1, 1) so the scale/shift broadcast over H, W.
        batch = h.shape[0]
        gamma_id = self.fc1(z_id).reshape(batch, self.c_x, 1, 1)
        beta_id = self.fc2(z_id).reshape(batch, self.c_x, 1, 1)
        I = gamma_id * h + beta_id

        # (B, 1, H, W) mask in (0, 1); broadcasts across channels.
        M = torch.sigmoid(self.conv_h(h))

        return (1 - M) * A + M * I


class AddBlocksSequential(nn.Sequential):
    """``nn.Sequential`` that feeds ``z_attr`` / ``z_id`` to its AAD layers.

    The contained modules are expected in groups of three. The first module of
    each group takes ``(h, z_attr, z_id)`` (an ``AADLayer``), and the other two
    take a single tensor (``ReLU`` then ``Conv2d`` as built by ``AAD_ResBlk``).
    """

    # Modules per (AADLayer, ReLU, Conv2d) group; the first module of every
    # group receives the conditioning inputs.
    _GROUP_SIZE = 3

    def forward(self, *inputs):
        """Run every module in order and return the final output.

        Args:
            *inputs: exactly ``(h, z_attr, z_id)``.
        """
        _, z_attr, z_id = inputs
        x = inputs
        for i, module in enumerate(self):
            # At the start of each group after the first, re-attach the
            # conditioning inputs to the running activation.
            if i > 0 and i % self._GROUP_SIZE == 0:
                x = (x, z_attr, z_id)
            x = module(*x) if isinstance(x, tuple) else module(x)
        return x


class AAD_ResBlk(nn.Module):
    """Residual block of ``num_blocks`` (AADLayer -> ReLU -> 3x3 Conv) units.

    Every unit on the main path works on ``cin`` channels; only the last conv
    maps to ``cout``. When ``cin != cout`` the shortcut is a learned
    AADLayer -> ReLU -> Conv(cin, cout) unit, otherwise it is the identity.

    Args:
        cin: number of input channels.
        cout: number of output channels.
        c_attr: channels of the attribute map given to every AADLayer.
        c_id: size of the identity embedding given to every AADLayer.
        num_blocks: number of units on the main path (must be >= 1).

    Raises:
        ValueError: if ``num_blocks < 1``.
    """

    def __init__(self, cin, cout, c_attr, c_id, num_blocks):
        super().__init__()
        if num_blocks < 1:
            # An empty main path would make forward() return the raw input
            # tuple and fail later with an opaque "tuple + Tensor" error.
            raise ValueError(f"num_blocks must be >= 1, got {num_blocks}")
        self.cin = cin
        self.cout = cout

        add_blocks = []
        for i in range(num_blocks):
            # Only the final conv of the main path changes the channel count.
            conv_out = cout if i == num_blocks - 1 else cin
            add_blocks.extend([
                AADLayer(cin, c_attr, c_id),
                nn.ReLU(inplace=True),
                nn.Conv2d(cin, conv_out, kernel_size=3, stride=1, padding=1, bias=False),
            ])
        self.add_blocks = AddBlocksSequential(*add_blocks)

        if cin != cout:
            # Learned shortcut so the residual sum has matching channels.
            self.last_add_block = AddBlocksSequential(
                AADLayer(cin, c_attr, c_id),
                nn.ReLU(inplace=True),
                nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1, bias=False),
            )

    def forward(self, h, z_attr, z_id):
        """Return the main-path output plus the (identity or learned) shortcut."""
        x = self.add_blocks(h, z_attr, z_id)
        if self.cin != self.cout:
            h = self.last_add_block(h, z_attr, z_id)
        return x + h