Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Path: blob/main/models/models.py
Views: 792
import math

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Parameter

# `device` is no longer used below, but it stays imported so that existing
# `from .models import device` style lookups keep resolving.
from .config import device, num_classes


def create_model(opt):
    """Build, initialize and return the model selected by ``opt.model``.

    Args:
        opt: Options object. This function reads ``opt.model``,
            ``opt.verbose``, ``opt.isTrain``, ``opt.gpu_ids`` and
            ``opt.fp16``, and hands ``opt`` to the model's ``initialize``.

    Returns:
        The initialized model. It is wrapped in ``torch.nn.DataParallel``
        when ``opt.isTrain`` is set, ``opt.gpu_ids`` is non-empty and
        ``opt.fp16`` is not set.
    """
    # The model classes are imported inside the branches, so only the module
    # of the selected model is actually imported.
    if opt.model == 'pix2pixHD':
        # 'pix2pixHD' is served by fsModel (the Pix2PixHDModel import that
        # used to live here is disabled).
        from .fs_model import fsModel
        model = fsModel()
    else:
        from .ui_model import UIModel
        model = UIModel()

    model.initialize(opt)
    if opt.verbose:
        print("model [%s] was created" % (model.name()))

    if opt.isTrain and len(opt.gpu_ids) and not opt.fp16:
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)

    return model


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with a padding of 1.

    ``IRBlock`` calls this helper, but the module neither defined nor
    imported it, so building an ``IRBlock`` raised ``NameError``.
    ``padding=1`` keeps the spatial size unchanged at ``stride=1``.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class SEBlock(nn.Module):
    """Channel re-weighting block: global pool -> bottleneck MLP -> scale."""

    def __init__(self, channel, reduction=16):
        """Create the block.

        Args:
            channel: Number of channels of the input feature map.
            reduction: Divisor applied to ``channel`` for the hidden layer
                width (``channel // reduction``).
        """
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.PReLU(),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Scale every channel of the 4-D input ``x`` by a gate in (0, 1)."""
        b, c, _, _ = x.size()
        # Pool each channel down to a single value, then flatten to (b, c).
        y = self.avg_pool(x).view(b, c)
        # Reshape the gates to (b, c, 1, 1) so they broadcast over H and W.
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


class IRBlock(nn.Module):
    """Residual block: BN -> conv -> BN -> PReLU -> conv -> BN [-> SE]."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
        """Create the block.

        Args:
            inplanes: Number of input channels.
            planes: Number of output channels.
            stride: Stride of the second convolution.
            downsample: Optional module applied to the input to build the
                shortcut; when ``None`` the input itself is the shortcut.
            use_se: Whether to append an ``SEBlock`` after the last BN.
        """
        super(IRBlock, self).__init__()
        self.bn0 = nn.BatchNorm2d(inplanes)
        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        # A single PReLU module (one set of weights) is used twice in forward.
        self.prelu = nn.PReLU()
        self.conv2 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.use_se = use_se
        if self.use_se:
            self.se = SEBlock(planes)

    def forward(self, x):
        """Apply the residual branch and add the (optionally resampled) input."""
        residual = x
        out = self.bn0(x)
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.prelu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        if self.use_se:
            out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.prelu(out)

        return out


class ResNet(nn.Module):
    """Residual backbone that maps an image batch to 512-d embeddings."""

    def __init__(self, block, layers, use_se=True):
        """Create the network.

        Args:
            block: Residual block class; it must expose ``expansion`` and
                accept ``(inplanes, planes, stride, downsample, use_se=...)``.
            layers: Sequence of four ints, the number of blocks per stage.
            use_se: Forwarded to every block.
        """
        # Initialise nn.Module first so every later attribute assignment goes
        # through a fully set-up module.
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.use_se = use_se
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.prelu = nn.PReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.bn2 = nn.BatchNorm2d(512)
        self.dropout = nn.Dropout()
        # The flattened feature map must be 512 x 7 x 7 at this point.
        # NOTE(review): that presumably corresponds to 112x112 inputs - confirm.
        self.fc = nn.Linear(512 * 7 * 7, 512)
        self.bn3 = nn.BatchNorm1d(512)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` instances of ``block`` producing ``planes`` channels.

        Only the first block receives ``stride`` and, when the stride or the
        channel count changes, a 1x1 conv + BN shortcut projection.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [
            block(self.inplanes, planes, stride, downsample, use_se=self.use_se)
        ]
        self.inplanes = planes
        layers.extend(
            block(self.inplanes, planes, use_se=self.use_se)
            for _ in range(1, blocks)
        )

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the batch-normalised 512-d embedding for each input."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.bn2(x)
        x = self.dropout(x)
        # Flatten everything but the batch dimension for the linear layer.
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.bn3(x)

        return x


class ArcMarginModel(nn.Module):
    """Classification head that adds an angular margin to the target logit."""

    def __init__(self, args):
        """Create the head.

        Args:
            args: Object providing ``emb_size`` (embedding width),
                ``easy_margin`` (bool), ``margin_m`` (margin in radians, as
                it is fed to ``math.cos``/``math.sin``) and ``margin_s``
                (factor the logits are multiplied by).
        """
        super(ArcMarginModel, self).__init__()

        self.weight = Parameter(torch.FloatTensor(num_classes, args.emb_size))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = args.easy_margin
        self.m = args.margin_m
        self.s = args.margin_s

        self.cos_m = math.cos(self.m)
        self.sin_m = math.sin(self.m)
        # cosine > th  <=>  theta < pi - m  <=>  theta + m stays below pi.
        self.th = math.cos(math.pi - self.m)
        # Penalty subtracted from the cosine when theta + m would pass pi.
        self.mm = math.sin(math.pi - self.m) * self.m

    def forward(self, input, label):
        """Return scaled logits with the margin applied at ``label``.

        Args:
            input: 2-D batch of embeddings, one row per sample.
            label: Class index per sample; it is flattened and cast to long.

        Returns:
            Tensor with one row per sample and ``num_classes`` columns,
            multiplied by ``self.s``.
        """
        x = F.normalize(input)
        W = F.normalize(self.weight)
        cosine = F.linear(x, W)
        # Rounding can push cosine**2 marginally above 1, which would make the
        # square root NaN; clamp the radicand to its valid range first.
        sine = torch.sqrt(torch.clamp(1.0 - torch.pow(cosine, 2), 0.0, 1.0))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # Allocate the mask on the logits' own device rather than the global
        # config device, so the head works wherever its input lives.
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # Margin logit for the target class, plain cosine for all others.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output