Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Path: blob/main/network/AEI_Net.py
Views: 792
import torch
import torch.nn as nn
import torch.nn.functional as F
from .AADLayer import *  # provides AAD_ResBlk
from network.resnet import MLAttrEncoderResnet


def weight_init(m):
    """Initialise one module's weights in place; intended for ``Module.apply``.

    * ``nn.Linear``: weight ~ N(0, 0.001); bias zeroed when the layer has one.
    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: Xavier-normal weight; the bias
      (if any) keeps PyTorch's default initialisation.
    * Every other module type is left untouched.
    """
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, mean=0.0, std=0.001)
        # A Linear built with bias=False has ``m.bias is None``; zeroing it
        # unconditionally would raise AttributeError.
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.xavier_normal_(m.weight)


def _upsample2x(x):
    """Bilinearly upsample an NCHW tensor by a factor of 2 (align_corners=True)."""
    return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)


def conv4x4(in_c, out_c, norm=nn.BatchNorm2d):
    """Build a downsampling stage: 4x4 stride-2 conv -> ``norm`` -> LeakyReLU(0.1).

    With kernel 4, stride 2 and padding 1 an even H x W input becomes H/2 x W/2.

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
        norm: normalisation layer factory, called as ``norm(out_c)``.

    Returns:
        ``nn.Sequential(Conv2d, norm, LeakyReLU)``. The Sequential layout fixes
        the state_dict keys (``.0.weight``, ``.1.weight``, ...), so keep it.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=4, stride=2, padding=1, bias=False),
        norm(out_c),
        nn.LeakyReLU(0.1, inplace=True),
    )


class deconv4x4(nn.Module):
    """Upsampling stage: 4x4 stride-2 transposed conv -> norm -> LeakyReLU(0.1),
    then merge with a skip tensor.

    The transposed conv (kernel 4, stride 2, padding 1) maps H x W to 2H x 2W.
    In ``forward`` the result is added to ``skip`` when ``backbone == 'linknet'``
    (channel count unchanged) and concatenated with it along dim 1 otherwise
    (channel counts add up).
    """

    def __init__(self, in_c, out_c, norm=nn.BatchNorm2d):
        super().__init__()
        self.deconv = nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn = norm(out_c)
        self.lrelu = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, input, skip, backbone):
        # ``input`` shadows the builtin but is kept: it is part of this
        # module's keyword-callable interface.
        x = self.lrelu(self.bn(self.deconv(input)))
        if backbone == 'linknet':
            return x + skip
        return torch.cat((x, skip), dim=1)


class MLAttrEncoder(nn.Module):
    """Multi-level attribute encoder: seven ``conv4x4`` stages followed by six
    ``deconv4x4`` stages with skip connections to the encoder features.

    Args:
        backbone: ``'unet'`` (skips concatenated) or ``'linknet'`` (skips added).

    Raises:
        ValueError: for any other ``backbone``; such a value would otherwise
            build an encoder without decoder layers that fails in ``forward``.
    """

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        self.conv1 = conv4x4(3, 32)
        self.conv2 = conv4x4(32, 64)
        self.conv3 = conv4x4(64, 128)
        self.conv4 = conv4x4(128, 256)
        self.conv5 = conv4x4(256, 512)
        self.conv6 = conv4x4(512, 1024)
        self.conv7 = conv4x4(1024, 1024)

        if backbone == 'unet':
            # From deconv2 on, the input width is doubled: the previous stage's
            # output was concatenated with an equally wide skip feature.
            self.deconv1 = deconv4x4(1024, 1024)
            self.deconv2 = deconv4x4(2048, 512)
            self.deconv3 = deconv4x4(1024, 256)
            self.deconv4 = deconv4x4(512, 128)
            self.deconv5 = deconv4x4(256, 64)
            self.deconv6 = deconv4x4(128, 32)
        elif backbone == 'linknet':
            # Skips are added, so each stage's input width equals the previous
            # stage's output width.
            self.deconv1 = deconv4x4(1024, 1024)
            self.deconv2 = deconv4x4(1024, 512)
            self.deconv3 = deconv4x4(512, 256)
            self.deconv4 = deconv4x4(256, 128)
            self.deconv5 = deconv4x4(128, 64)
            self.deconv6 = deconv4x4(64, 32)
        else:
            raise ValueError(
                f"MLAttrEncoder: unsupported backbone {backbone!r}; expected 'unet' or 'linknet'"
            )
        self.apply(weight_init)

    def forward(self, Xt):
        """Encode ``Xt`` and return eight attribute maps, coarsest first.

        Shapes in the comments are C x H x W (batch omitted) for a 3 x 256 x 256
        input; channels of z_attr2..z_attr8 are given as unet / linknet.

        Returns:
            Tuple ``(z_attr1, ..., z_attr8)``.
        """
        feat1 = self.conv1(Xt)      # 32x128x128
        feat2 = self.conv2(feat1)   # 64x64x64
        feat3 = self.conv3(feat2)   # 128x32x32
        feat4 = self.conv4(feat3)   # 256x16x16
        feat5 = self.conv5(feat4)   # 512x8x8
        feat6 = self.conv6(feat5)   # 1024x4x4
        z_attr1 = self.conv7(feat6)  # 1024x2x2

        z_attr2 = self.deconv1(z_attr1, feat6, self.backbone)  # 2048 / 1024 x 4x4
        z_attr3 = self.deconv2(z_attr2, feat5, self.backbone)  # 1024 / 512 x 8x8
        z_attr4 = self.deconv3(z_attr3, feat4, self.backbone)  # 512 / 256 x 16x16
        z_attr5 = self.deconv4(z_attr4, feat3, self.backbone)  # 256 / 128 x 32x32
        z_attr6 = self.deconv5(z_attr5, feat2, self.backbone)  # 128 / 64 x 64x64
        z_attr7 = self.deconv6(z_attr6, feat1, self.backbone)  # 64 / 32 x 128x128
        z_attr8 = _upsample2x(z_attr7)                         # 64 / 32 x 256x256
        return z_attr1, z_attr2, z_attr3, z_attr4, z_attr5, z_attr6, z_attr7, z_attr8


class AADGenerator(nn.Module):
    """Generator that grows an identity embedding into an image, injecting one
    attribute map per stage through eight ``AAD_ResBlk`` blocks.

    Args:
        backbone: ``'linknet'`` selects LinkNet attribute widths; any other value
            (including ``'resnet'``) selects the U-Net widths.
        c_id: number of channels of the identity embedding fed to ``up1``.
        num_blocks: forwarded unchanged to every ``AAD_ResBlk``.

    NOTE(review): ``AAD_ResBlk`` args are assumed to be (in_channels,
    out_channels, attribute_channels, c_id, num_blocks) — confirm in AADLayer.
    """

    def __init__(self, backbone, c_id=256, num_blocks=2):
        super().__init__()
        # 1x1 -> 2x2 (kernel 2, stride 1, no padding).
        self.up1 = nn.ConvTranspose2d(c_id, 1024, kernel_size=2, stride=1, padding=0)
        self.AADBlk1 = AAD_ResBlk(1024, 1024, 1024, c_id, num_blocks)
        if backbone == 'linknet':
            self.AADBlk2 = AAD_ResBlk(1024, 1024, 1024, c_id, num_blocks)
            self.AADBlk3 = AAD_ResBlk(1024, 1024, 512, c_id, num_blocks)
            self.AADBlk4 = AAD_ResBlk(1024, 512, 256, c_id, num_blocks)
            self.AADBlk5 = AAD_ResBlk(512, 256, 128, c_id, num_blocks)
            self.AADBlk6 = AAD_ResBlk(256, 128, 64, c_id, num_blocks)
            self.AADBlk7 = AAD_ResBlk(128, 64, 32, c_id, num_blocks)
            self.AADBlk8 = AAD_ResBlk(64, 3, 32, c_id, num_blocks)
        else:
            self.AADBlk2 = AAD_ResBlk(1024, 1024, 2048, c_id, num_blocks)
            self.AADBlk3 = AAD_ResBlk(1024, 1024, 1024, c_id, num_blocks)
            self.AADBlk4 = AAD_ResBlk(1024, 512, 512, c_id, num_blocks)
            self.AADBlk5 = AAD_ResBlk(512, 256, 256, c_id, num_blocks)
            self.AADBlk6 = AAD_ResBlk(256, 128, 128, c_id, num_blocks)
            self.AADBlk7 = AAD_ResBlk(128, 64, 64, c_id, num_blocks)
            self.AADBlk8 = AAD_ResBlk(64, 3, 64, c_id, num_blocks)

        self.apply(weight_init)

    def forward(self, z_attr, z_id):
        """Generate an image from attribute maps and an identity embedding.

        Args:
            z_attr: indexable whose items 0..7 are the attribute maps, coarsest
                first (as returned by the encoder).
            z_id: identity embedding; it is reshaped to (B, -1, 1, 1), so each
                sample must flatten to exactly ``c_id`` values for ``up1``.

        Returns:
            ``torch.tanh`` of the final block's output (values in [-1, 1]).
        """
        m = self.up1(z_id.reshape(z_id.shape[0], -1, 1, 1))
        # Each of blocks 1..7 is followed by a bilinear x2 upsample.
        m = _upsample2x(self.AADBlk1(m, z_attr[0], z_id))
        m = _upsample2x(self.AADBlk2(m, z_attr[1], z_id))
        m = _upsample2x(self.AADBlk3(m, z_attr[2], z_id))
        m = _upsample2x(self.AADBlk4(m, z_attr[3], z_id))
        m = _upsample2x(self.AADBlk5(m, z_attr[4], z_id))
        m = _upsample2x(self.AADBlk6(m, z_attr[5], z_id))
        m = _upsample2x(self.AADBlk7(m, z_attr[6], z_id))
        y = self.AADBlk8(m, z_attr[7], z_id)
        return torch.tanh(y)


class AEI_Net(nn.Module):
    """Attribute encoder followed by the AAD generator.

    Args:
        backbone: ``'unet'`` or ``'linknet'`` (built-in ``MLAttrEncoder``), or
            ``'resnet'`` (``MLAttrEncoderResnet``). Also passed to ``AADGenerator``.
        num_blocks: forwarded to ``AADGenerator``.
        c_id: identity-embedding size, forwarded to ``AADGenerator``.

    Raises:
        ValueError: for any other ``backbone``; otherwise the model would be
            built without an encoder and fail later in ``forward``/``get_attr``.

    NOTE(review): the 'resnet' generator uses the U-Net attribute widths, which
    presumably matches MLAttrEncoderResnet's outputs — confirm.
    """

    def __init__(self, backbone, num_blocks=2, c_id=256):
        super().__init__()
        if backbone in ('unet', 'linknet'):
            self.encoder = MLAttrEncoder(backbone)
        elif backbone == 'resnet':
            self.encoder = MLAttrEncoderResnet()
        else:
            raise ValueError(
                f"AEI_Net: unsupported backbone {backbone!r}; expected 'unet', 'linknet' or 'resnet'"
            )
        self.generator = AADGenerator(backbone, c_id, num_blocks)

    def forward(self, Xt, z_id):
        """Return ``(Y, attr)``: the generated image and the encoder attributes of ``Xt``."""
        attr = self.encoder(Xt)
        Y = self.generator(attr, z_id)
        return Y, attr

    def get_attr(self, X):
        """Return the encoder's multi-level attributes for ``X``."""
        return self.encoder(X)