Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
ai-forever

Real-time collaboration for Jupyter Notebooks, Linux terminals, LaTeX, VS Code, R IDE, and more,
all in one place. A commercial alternative to JupyterHub.

GitHub Repository: ai-forever/sber-swap
Path: blob/main/models/networks/__init__.py
Views: 813
1
"""
2
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4
"""
5
6
import torch
7
from models.networks.base_network import BaseNetwork
8
from models.networks.loss import *
9
from models.networks.discriminator import *
10
from models.networks.generator import *
11
from models.networks.encoder import *
12
import utils.inference.util as util
13
14
15
def find_network_using_name(target_network_name, filename):
    """Look up a network class inside ``models.networks.<filename>``.

    The class name searched for is ``target_network_name + filename``
    (for example ``'conv' + 'encoder'`` -> ``'convencoder'``). The actual
    matching is delegated to ``util.find_class_in_module``; its exact
    matching rules (e.g. case handling) are not visible here.

    Args:
        target_network_name: Network variant name, e.g. the value of
            ``opt.netG`` or ``opt.netD``.
        filename: Submodule of ``models.networks`` to search, e.g.
            ``'generator'``, ``'discriminator'`` or ``'encoder'``.

    Returns:
        The class found, guaranteed to be a subclass of ``BaseNetwork``.

    Raises:
        TypeError: if the lookup result is not a ``BaseNetwork`` subclass.
    """
    target_class_name = target_network_name + filename
    module_name = 'models.networks.' + filename
    network = util.find_class_in_module(target_class_name, module_name)

    # Explicit check instead of ``assert`` so the validation still runs under
    # ``python -O``. The ``isinstance(..., type)`` guard keeps ``issubclass``
    # from raising its own, less helpful, TypeError on non-class values.
    if not (isinstance(network, type) and issubclass(network, BaseNetwork)):
        raise TypeError(
            "Class %s should be a subclass of BaseNetwork" % (network,))
    return network
24
25
26
def modify_commandline_options(parser, is_train):
    """Let the selected networks add their own options to ``parser``.

    The generator named by ``opt.netG`` is always consulted. The
    discriminator named by ``opt.netD`` is consulted only when ``is_train``
    is true. The (single, ``'conv'``) encoder is consulted last. Each
    network class may return a new parser, which is threaded through to the
    next one.

    Args:
        parser: An ``argparse``-style parser.
        is_train: Whether training options are being built.

    Returns:
        The parser after every consulted network has extended it.
    """
    # parse_known_args tolerates options that are not registered yet, which
    # lets us read netG / netD before the network-specific options exist.
    opt, _ = parser.parse_known_args()

    def _register(current_parser, network_name, submodule):
        # Resolve one network class and let it extend the parser.
        network_cls = find_network_using_name(network_name, submodule)
        return network_cls.modify_commandline_options(current_parser, is_train)

    parser = _register(parser, opt.netG, 'generator')
    if is_train:
        parser = _register(parser, opt.netD, 'discriminator')
    parser = _register(parser, 'conv', 'encoder')

    return parser
38
39
40
def create_network(cls, opt):
    """Instantiate a network, optionally move it to CUDA, then init weights.

    Args:
        cls: Network class; it is constructed as ``cls(opt)`` and must
            provide ``print_network``, ``cuda`` and ``init_weights``.
        opt: Options object; ``gpu_ids``, ``init_type`` and
            ``init_variance`` are read from it.

    Returns:
        The constructed network instance.

    Raises:
        RuntimeError: if ``opt.gpu_ids`` is non-empty but CUDA is not
            available.
    """
    net = cls(opt)
    net.print_network()
    if opt.gpu_ids:
        # Explicit check instead of ``assert`` so it is not stripped by
        # ``python -O`` and fails with an actionable message.
        if not torch.cuda.is_available():
            raise RuntimeError(
                "gpu_ids=%r was requested but CUDA is not available"
                % (opt.gpu_ids,))
        net.cuda()
    net.init_weights(opt.init_type, opt.init_variance)
    return net
48
49
50
def define_G(opt):
    """Build the generator whose variant name is given by ``opt.netG``."""
    return create_network(find_network_using_name(opt.netG, 'generator'), opt)
53
54
55
def define_D(opt):
    """Build the discriminator whose variant name is given by ``opt.netD``."""
    discriminator_cls = find_network_using_name(opt.netD, 'discriminator')
    discriminator = create_network(discriminator_cls, opt)
    return discriminator
58
59
60
def define_E(opt):
    """Build the encoder network.

    Only one encoder variant (``'conv'``) exists, so the class is fixed here;
    ``opt`` is only passed on to ``create_network``.
    """
    encoder_cls = find_network_using_name('conv', 'encoder')
    encoder = create_network(encoder_cls, opt)
    return encoder
64
65