Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. A commercial alternative to JupyterHub.
Path: blob/main/models/networks/__init__.py
Views: 813
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

# Factory helpers that resolve generator / discriminator / encoder classes
# by name (from ``models.networks.<submodule>``) and build them from an
# options object.

import torch
from models.networks.base_network import BaseNetwork
# Wildcard imports are kept deliberately: they re-export the loss and
# network classes from this package namespace for other modules.
from models.networks.loss import *
from models.networks.discriminator import *
from models.networks.generator import *
from models.networks.encoder import *
import utils.inference.util as util


def find_network_using_name(target_network_name, filename):
    """Look up a network class inside ``models.networks.<filename>``.

    The class name searched for is ``target_network_name + filename``
    (for example ``'conv' + 'encoder'``). The actual name matching is
    delegated to ``util.find_class_in_module``.

    Args:
        target_network_name: Prefix of the class name, typically taken
            from an option such as ``opt.netG`` or ``opt.netD``.
        filename: Submodule of ``models.networks`` to search
            (``'generator'``, ``'discriminator'`` or ``'encoder'`` in this
            file); also used as the class-name suffix.

    Returns:
        The class object that was found.

    Raises:
        TypeError: If the object found is not a subclass of
            ``BaseNetwork``.
    """
    target_class_name = target_network_name + filename
    module_name = 'models.networks.' + filename
    network = util.find_class_in_module(target_class_name, module_name)

    # Explicit check rather than ``assert`` so validation still happens
    # under ``python -O``. The ``isinstance(..., type)`` guard stops
    # ``issubclass`` from raising its own, less helpful TypeError when the
    # lookup returns something that is not a class.
    if not (isinstance(network, type) and issubclass(network, BaseNetwork)):
        raise TypeError(
            "Class %s should be a subclass of BaseNetwork" % (network,))

    return network


def modify_commandline_options(parser, is_train):
    """Let the selected networks register their own command-line options.

    The generator (``opt.netG``) and the ``'conv'`` encoder always add
    their options. The discriminator (``opt.netD``) adds its options only
    when ``is_train`` is true.

    Args:
        parser: An ``argparse.ArgumentParser`` (or compatible) that already
            defines ``--netG`` and, for training, ``--netD``.
        is_train: Whether training-only options should be registered.

    Returns:
        The parser returned by the last network's
        ``modify_commandline_options`` call.
    """
    # parse_known_args tolerates options that are not registered yet
    # (e.g. the ones the networks below are about to add).
    opt, _ = parser.parse_known_args()

    netG_cls = find_network_using_name(opt.netG, 'generator')
    parser = netG_cls.modify_commandline_options(parser, is_train)
    if is_train:
        netD_cls = find_network_using_name(opt.netD, 'discriminator')
        parser = netD_cls.modify_commandline_options(parser, is_train)
    netE_cls = find_network_using_name('conv', 'encoder')
    parser = netE_cls.modify_commandline_options(parser, is_train)

    return parser


def create_network(cls, opt):
    """Instantiate ``cls(opt)``, place it on GPU if requested, and init weights.

    Steps, in order:

    1. Construct the network.
    2. Call its ``print_network()``.
    3. If ``opt.gpu_ids`` is non-empty, move it with ``.cuda()``.
    4. Initialize weights via ``init_weights(opt.init_type, opt.init_variance)``.

    Args:
        cls: Network class to instantiate (see ``find_network_using_name``).
        opt: Options object providing ``gpu_ids``, ``init_type`` and
            ``init_variance``.

    Returns:
        The constructed, initialized network.

    Raises:
        RuntimeError: If ``opt.gpu_ids`` is non-empty but CUDA is not
            available.
    """
    net = cls(opt)
    net.print_network()
    if len(opt.gpu_ids) > 0:
        # Explicit error instead of ``assert`` so the check survives
        # ``python -O`` and reports which GPUs were requested.
        if not torch.cuda.is_available():
            raise RuntimeError(
                "GPU ids %s were requested but CUDA is not available"
                % (opt.gpu_ids,))
        net.cuda()
    net.init_weights(opt.init_type, opt.init_variance)
    return net


def define_G(opt):
    """Build the generator class named by ``opt.netG``."""
    netG_cls = find_network_using_name(opt.netG, 'generator')
    return create_network(netG_cls, opt)


def define_D(opt):
    """Build the discriminator class named by ``opt.netD``."""
    netD_cls = find_network_using_name(opt.netD, 'discriminator')
    return create_network(netD_cls, opt)


def define_E(opt):
    """Build the ``'conv'`` encoder."""
    # there exists only one encoder type
    netE_cls = find_network_using_name('conv', 'encoder')
    return create_network(netE_cls, opt)