Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. A commercial alternative to JupyterHub.
Path: blob/main/models/networks/encoder.py
Views: 813
"""
Copyright (C) 2019 NVIDIA Corporation.  All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from models.networks.base_network import BaseNetwork
from models.networks.normalization import get_nonspade_norm_layer


class ConvEncoder(BaseNetwork):
    """Convolutional image encoder with two 256-unit linear heads.

    Same architecture as the image discriminator: a stack of 3x3, stride-2
    convolutions (each wrapped by the norm layer returned by
    ``get_nonspade_norm_layer(opt, opt.norm_E)``), LeakyReLU(0.2)
    activations, and two linear heads ``fc_mu`` / ``fc_var`` applied to the
    flattened final feature map.

    Five conv layers are always built, plus a sixth when
    ``opt.crop_size >= 256``. Every stride-2 conv maps a side of H pixels to
    ceil(H / 2), and the heads are sized for a 4x4 final map. ``forward``
    therefore resizes inputs to 256x256 for the six-layer variant and to
    128x128 for the five-layer one.

    Args:
        opt: options object. Reads ``opt.ngf`` (base channel width),
            ``opt.norm_E`` (passed to ``get_nonspade_norm_layer``) and
            ``opt.crop_size`` (selects five or six conv layers).
    """

    def __init__(self, opt):
        super().__init__()

        kw = 3
        # Padding of 1 for the 3x3 kernel.
        pw = int(np.ceil((kw - 1.0) / 2))
        ndf = opt.ngf
        norm_layer = get_nonspade_norm_layer(opt, opt.norm_E)
        self.layer1 = norm_layer(nn.Conv2d(3, ndf, kw, stride=2, padding=pw))
        self.layer2 = norm_layer(nn.Conv2d(ndf * 1, ndf * 2, kw, stride=2, padding=pw))
        self.layer3 = norm_layer(nn.Conv2d(ndf * 2, ndf * 4, kw, stride=2, padding=pw))
        self.layer4 = norm_layer(nn.Conv2d(ndf * 4, ndf * 8, kw, stride=2, padding=pw))
        self.layer5 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw))
        if opt.crop_size >= 256:
            self.layer6 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw))

        # Side length of the final feature map that the linear heads are sized
        # for. ``so`` is the historical (misspelled) attribute name, kept so
        # existing readers keep working; ``s0`` is the intended spelling.
        self.s0 = self.so = s0 = 4
        self.fc_mu = nn.Linear(ndf * 8 * s0 * s0, 256)
        self.fc_var = nn.Linear(ndf * 8 * s0 * s0, 256)

        self.actvn = nn.LeakyReLU(0.2, False)
        self.opt = opt

    def forward(self, x):
        """Encode an image batch.

        Args:
            x: 4-D image batch indexed (N, C, H, W) with C == 3, since
                ``layer1`` is ``Conv2d(3, ...)``. An input whose H or W
                differs from the expected size (256 with ``layer6``, 128
                without) is first bilinearly resized to that square size.

        Returns:
            ``(mu, logvar)``: the outputs of ``fc_mu`` and ``fc_var``, each of
            shape (N, 256). These are presumably the mean and log-variance of
            a latent Gaussian; confirm against the caller.
        """
        use_layer6 = self.opt.crop_size >= 256
        # The heads need a 4x4 final map and each conv halves the side. The
        # input must therefore be 4 * 2**6 = 256 px with six convs and
        # 4 * 2**5 = 128 px with five. Always resizing to 256 left the
        # five-conv variant with an 8x8 map whose flattened size did not
        # match fc_mu / fc_var.
        size = 256 if use_layer6 else 128
        if x.size(2) != size or x.size(3) != size:
            # align_corners=False is PyTorch's default. Passing it explicitly
            # gives the same result and silences the bilinear UserWarning.
            x = F.interpolate(x, size=(size, size), mode='bilinear', align_corners=False)

        x = self.layer1(x)
        x = self.layer2(self.actvn(x))
        x = self.layer3(self.actvn(x))
        x = self.layer4(self.actvn(x))
        x = self.layer5(self.actvn(x))
        if use_layer6:
            x = self.layer6(self.actvn(x))
        x = self.actvn(x)

        # Flatten each sample's (ndf * 8, 4, 4) map for the linear heads.
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        logvar = self.fc_var(x)

        return mu, logvar