Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
ai-forever

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: ai-forever/sber-swap
Path: blob/main/models/networks/encoder.py
Views: 813
1
"""
2
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4
"""
5
6
import torch.nn as nn
7
import numpy as np
8
import torch.nn.functional as F
9
from models.networks.base_network import BaseNetwork
10
from models.networks.normalization import get_nonspade_norm_layer
11
12
13
class ConvEncoder(BaseNetwork):
    """Convolutional image encoder producing a ``(mu, logvar)`` pair.

    Same architecture as the image discriminator: a stack of stride-2
    3x3 convolutions (each wrapped by the non-SPADE norm layer selected
    via ``opt.norm_E``) followed by two linear heads with 256 outputs each.
    """

    def __init__(self, opt):
        """Build the convolution stack and the two linear heads.

        Args:
            opt: options object; this constructor reads ``opt.ngf`` (base
                channel width), ``opt.norm_E`` (passed to
                ``get_nonspade_norm_layer``) and ``opt.crop_size`` (a sixth
                conv layer is created only when it is >= 256).
        """
        super().__init__()

        kernel = 3
        pad = int(np.ceil((kernel - 1.0) / 2))
        ndf = opt.ngf
        norm_layer = get_nonspade_norm_layer(opt, opt.norm_E)

        def down(c_in, c_out):
            # One downsampling stage: stride-2 conv wrapped by the norm layer.
            return norm_layer(nn.Conv2d(c_in, c_out, kernel, stride=2, padding=pad))

        # Layers are created in the same order as they are applied in
        # forward(); attribute names are part of the state-dict layout.
        self.layer1 = down(3, ndf)
        self.layer2 = down(ndf * 1, ndf * 2)
        self.layer3 = down(ndf * 2, ndf * 4)
        self.layer4 = down(ndf * 4, ndf * 8)
        self.layer5 = down(ndf * 8, ndf * 8)
        if opt.crop_size >= 256:
            self.layer6 = down(ndf * 8, ndf * 8)

        # Spatial side length the linear heads are sized for. The attribute
        # is spelled ``so`` (letter "o") in the original and is kept as-is.
        s0 = 4
        self.so = s0
        flat_features = ndf * 8 * s0 * s0
        self.fc_mu = nn.Linear(flat_features, 256)
        self.fc_var = nn.Linear(flat_features, 256)

        self.actvn = nn.LeakyReLU(0.2, False)
        self.opt = opt

    def forward(self, x):
        """Encode ``x`` and return the tuple ``(mu, logvar)``.

        ``x`` is indexed on dims 2 and 3 as its spatial size and is
        bilinearly resized to 256x256 whenever either differs from 256.
        """
        if not (x.size(2) == 256 and x.size(3) == 256):
            x = F.interpolate(x, size=(256, 256), mode='bilinear')

        # NOTE(review): the input is always 256x256 at this point, but
        # layer6 only runs when opt.crop_size >= 256. Assuming the norm
        # wrapper keeps spatial size, five stride-2 convs would leave an
        # 8x8 map while fc_mu/fc_var are sized for 4x4 -- confirm whether
        # crop_size < 256 is a supported configuration.
        feat = self.layer1(x)
        for stage in (self.layer2, self.layer3, self.layer4, self.layer5):
            feat = stage(self.actvn(feat))
        if self.opt.crop_size >= 256:
            feat = self.layer6(self.actvn(feat))
        feat = self.actvn(feat)

        # Flatten everything but the batch dimension for the linear heads.
        flat = feat.view(feat.size(0), -1)
        mu = self.fc_mu(flat)
        logvar = self.fc_var(flat)

        return mu, logvar
56
57