Path: blob/main/models/networks/generator.py
"""1Copyright (C) 2019 NVIDIA Corporation. All rights reserved.2Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).3"""45import torch6import torch.nn as nn7import torch.nn.functional as F8from models.networks.base_network import BaseNetwork9from models.networks.normalization import get_nonspade_norm_layer10from models.networks.architecture import ResnetBlock as ResnetBlock11from models.networks.architecture import SPADEResnetBlock as SPADEResnetBlock12import os13#import data # only run from basic level!14import copy # deepcopy1516class SPADEGenerator(BaseNetwork):17@staticmethod18def modify_commandline_options(parser, is_train):19parser.set_defaults(norm_G='spectralspadesyncbatch3x3')20parser.add_argument('--num_upsampling_layers',21choices=('normal', 'more', 'most'), default='normal',22help="If 'more', adds upsampling layer between the two middle resnet blocks. If 'most', also add one more upsampling + resnet layer at the end of the generator")2324return parser2526def __init__(self, opt):27super().__init__()28self.opt = opt29nf = opt.ngf3031self.sw, self.sh, self.scale_ratio = self.compute_latent_vector_size(opt)3233if opt.use_vae:34# In case of VAE, we will sample from random z vector35self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)36else:37# Otherwise, we make the network deterministic by starting with38# downsampled segmentation map instead of random z39self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)4041self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)4243self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)44self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt)4546# 20200211 test 4x with only 3 stage4748self.ups = nn.ModuleList([49SPADEResnetBlock(16 * nf, 8 * nf, opt),50SPADEResnetBlock(8 * nf, 4 * nf, opt),51SPADEResnetBlock(4 * nf, 2 * nf, opt),52SPADEResnetBlock(2 * nf, 1 * nf, opt) # here53])5455self.to_rgbs = nn.ModuleList([56nn.Conv2d(8 * nf, 3, 3, padding=1),57nn.Conv2d(4 * nf, 3, 3, padding=1),58nn.Conv2d(2 * nf, 3, 3, padding=1),59nn.Conv2d(1 * nf, 3, 3, padding=1) # here60])6162self.up = nn.Upsample(scale_factor=2)6364# 20200309 interface for flexible encoder design65# and mid-level loss control!66# For basic network, it's just a 16x downsampling67def encode(self, input):68h, w = input.size()[-2:]69sh, sw = h//2**self.scale_ratio, w//2**self.scale_ratio70x = F.interpolate(input, size=(sh, sw))71return self.fc(x) # 20200310: Merge fc into encoder7273def compute_latent_vector_size(self, opt):74if opt.num_upsampling_layers == 'normal':75num_up_layers = 576elif opt.num_upsampling_layers == 'more':77num_up_layers = 678elif opt.num_upsampling_layers == 'most':79num_up_layers = 780else:81raise ValueError('opt.num_upsampling_layers [%s] not recognized' %82opt.num_upsampling_layers)8384# 20200211 Yang Lingbo with respect to phase85scale_ratio = num_up_layers86#scale_ratio = 4 #here87sw = opt.crop_size // (2**num_up_layers)88sh = round(sw / opt.aspect_ratio)8990return sw, sh, scale_ratio9192def forward(self, input, seg=None):93'''9420200307: Dangerous Change95Add separable forward to allow different96input and segmentation maps...9798To return to original, simply add99seg = input at begining, and disable the seg parameter.10010120200308: A more elegant solution:102@ Allow forward to take default parameters.103@ Allow customizable input encoding10410520200310: Merge fc into encode, since encoder directly outputs processed feature map.106107[TODO] @ Allow customizable segmap 
class HiFaceGANGenerator(SPADEGenerator):
    def __init__(self, opt):
        super().__init__(opt)
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh, self.scale_ratio = self.compute_latent_vector_size(opt)

        if opt.use_vae:
            # In case of VAE, we will sample from random z vector
            self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
        else:
            # Otherwise, we make the network deterministic by starting with
            # downsampled segmentation map instead of random z
            self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)

        self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, 16 * nf)

        self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, 16 * nf)
        self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt, 16 * nf)

        # 20200211 test 4x with only 3 stage

        self.ups = nn.ModuleList([
            SPADEResnetBlock(16 * nf, 8 * nf, opt, 8 * nf),
            SPADEResnetBlock(8 * nf, 4 * nf, opt, 4 * nf),
            SPADEResnetBlock(4 * nf, 2 * nf, opt, 2 * nf),
            SPADEResnetBlock(2 * nf, 1 * nf, opt, 1 * nf)  # here
        ])

        self.to_rgbs = nn.ModuleList([
            nn.Conv2d(8 * nf, 3, 3, padding=1),
            nn.Conv2d(4 * nf, 3, 3, padding=1),
            nn.Conv2d(2 * nf, 3, 3, padding=1),
            nn.Conv2d(1 * nf, 3, 3, padding=1)  # here
        ])

        self.up = nn.Upsample(scale_factor=2)
        self.encoder = ContentAdaptiveSuppresor(opt, self.sw, self.sh, self.scale_ratio)

    def nested_encode(self, x):
        return self.encoder(x)

    def forward(self, input):
        xs = self.nested_encode(input)
        x = self.encode(input)
        '''
        print([_x.shape for _x in xs])
        print(x.shape)
        print(self.head_0)
        '''
        x = self.head_0(x, xs[0])

        x = self.up(x)
        x = self.G_middle_0(x, xs[1])
        x = self.G_middle_1(x, xs[1])

        if self.opt.is_test:
            phase = len(self.to_rgbs)
        else:
            phase = self.opt.train_phase + 1

        for i in range(phase):
            x = self.up(x)
            x = self.ups[i](x, xs[i + 2])

        x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
        x = torch.tanh(x)

        return x


class ContentAdaptiveSuppresor(BaseNetwork):
    def __init__(self, opt, sw, sh, n_2xdown,
                 norm_layer=nn.InstanceNorm2d):
        super().__init__()
        self.sw = sw
        self.sh = sh
        self.max_ratio = 16
        self.n_2xdown = n_2xdown

        # norm_layer = get_nonspade_norm_layer(opt, opt.norm_E)

        # 20200310: Several Convolution (stride 1) + LIP blocks, 4 fold
        ngf = opt.ngf
        kw = 3
        pw = (kw - 1) // 2

        self.head = nn.Sequential(
            nn.Conv2d(opt.semantic_nc, ngf, kw, stride=1, padding=pw, bias=False),
            norm_layer(ngf),
            nn.ReLU(),
        )
        cur_ratio = 1
        for i in range(n_2xdown):
            next_ratio = min(cur_ratio * 2, self.max_ratio)
            model = [
                SimplifiedLIP(ngf * cur_ratio),
                nn.Conv2d(ngf * cur_ratio, ngf * next_ratio, kw, stride=1, padding=pw),
                norm_layer(ngf * next_ratio),
            ]
            cur_ratio = next_ratio
            if i < n_2xdown - 1:
                model += [nn.ReLU(inplace=True)]
            setattr(self, 'encoder_%d' % i, nn.Sequential(*model))

    def forward(self, x):
        # 20200628: Note the features are arranged from small to large
        x = [self.head(x)]
        for i in range(self.n_2xdown):
            net = getattr(self, 'encoder_%d' % i)
            x = [net(x[0])] + x
        return x
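
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): inspect the multi-scale
# feature pyramid that ContentAdaptiveSuppresor feeds into the SPADE blocks.
# Only the option fields actually read above (ngf, semantic_nc) are supplied;
# the concrete values are assumptions for illustration.
def _content_adaptive_suppressor_demo():
    from argparse import Namespace
    opt = Namespace(ngf=16, semantic_nc=3)
    encoder = ContentAdaptiveSuppresor(opt, sw=16, sh=16, n_2xdown=5)
    feats = encoder(torch.randn(1, 3, 256, 256))
    # feats[0] is the most downsampled map (here 8x8), used to guide head_0;
    # later entries guide G_middle_* and the progressive upsampling blocks.
    for f in feats:
        print(tuple(f.shape))
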
#########################################
# Below is deprecated code
#
# 20200309: LIP for local importance pooling
# 20200311: Self-supervised mask encoder
# Author: lingbo.ylb
# Quick trial, to be reformatted later.
# 20200324: Nah forget about it...
#########################################


def lip2d(x, logit, kernel=3, stride=2, padding=1):
    weight = logit.exp()
    return F.avg_pool2d(x * weight, kernel, stride, padding) / F.avg_pool2d(weight, kernel, stride, padding)


class SoftGate(nn.Module):
    COEFF = 12.0

    def __init__(self):
        super(SoftGate, self).__init__()

    def forward(self, x):
        return torch.sigmoid(x).mul(self.COEFF)


class SimplifiedLIP(nn.Module):
    def __init__(self, channels):
        super(SimplifiedLIP, self).__init__()

        rp = channels

        self.logit = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.InstanceNorm2d(channels, affine=True),
            SoftGate()
        )
        '''
        OrderedDict((
            ('conv', nn.Conv2d(channels, channels, 3, padding=1, bias=False)),
            ('bn', nn.InstanceNorm2d(channels, affine=True)),
            ('gate', SoftGate()),
        ))
        '''

    def init_layer(self):
        self.logit[0].weight.data.fill_(0.0)

    def forward(self, x):
        frac = lip2d(x, self.logit(x))
        return frac
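
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): SimplifiedLIP performs
# Local Importance-based Pooling, i.e. a 2x average pooling weighted by
# exp(logit(x)), where the logits are predicted from the input itself.
def _simplified_lip_demo():
    pool = SimplifiedLIP(channels=8)
    pool.init_layer()  # zeroed logit conv -> uniform weights -> plain average pooling
    y = pool(torch.randn(2, 8, 32, 32))
    print(tuple(y.shape))  # expected: (2, 8, 16, 16)
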
class LIPEncoder(BaseNetwork):
    def __init__(self, opt, sw, sh, n_2xdown,
                 norm_layer=nn.InstanceNorm2d):
        super().__init__()
        self.sw = sw
        self.sh = sh
        self.max_ratio = 16

        # norm_layer = get_nonspade_norm_layer(opt, opt.norm_E)

        # 20200310: Several Convolution (stride 1) + LIP blocks, 4 fold
        ngf = opt.ngf
        kw = 3
        pw = (kw - 1) // 2

        model = [
            nn.Conv2d(opt.semantic_nc, ngf, kw, stride=1, padding=pw, bias=False),
            norm_layer(ngf),
            nn.ReLU(),
        ]
        cur_ratio = 1
        for i in range(n_2xdown):
            next_ratio = min(cur_ratio * 2, self.max_ratio)
            model += [
                SimplifiedLIP(ngf * cur_ratio),
                nn.Conv2d(ngf * cur_ratio, ngf * next_ratio, kw, stride=1, padding=pw),
                norm_layer(ngf * next_ratio),
            ]
            cur_ratio = next_ratio
            if i < n_2xdown - 1:
                model += [nn.ReLU(inplace=True)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


class LIPSPADEGenerator(SPADEGenerator):
    '''
    20200309: SPADEGenerator with a learnable feature encoder
    Encoder design: Local Importance-based Pooling (Ziteng Gao et al., ICCV 2019)
    '''
    def __init__(self, opt):
        super().__init__(opt)
        self.lip_encoder = LIPEncoder(opt, self.sw, self.sh, self.scale_ratio)

    def encode(self, x):
        return self.lip_encoder(x)


class NoiseClassPredictor(nn.Module):
    '''
    Input: nc*sw*sw tensor, either from clean or corrupted images
    Output: n-dim tensor indicating the loss type (or intensity?)
    '''
    def __init__(self, opt, sw, nc, outdim):
        super().__init__()
        nbottleneck = 256
        middim = 256
        # Compact info
        conv = [
            nn.Conv2d(nc, nbottleneck, 1, stride=1),
            nn.InstanceNorm2d(nbottleneck),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # sw should be probably 16, downsample to 4 and convert to 1
        while sw > 4:
            sw = sw // 2
            conv += [
                nn.Conv2d(nbottleneck, nbottleneck, 3, stride=2, padding=1),
                nn.InstanceNorm2d(nbottleneck),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        self.conv = nn.Sequential(*conv)

        indim = sw * sw * nbottleneck
        self.fc = nn.Sequential(
            nn.Linear(indim, middim),
            nn.BatchNorm1d(middim),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(middim, outdim),
            # nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.shape[0], -1)
        return self.fc(x)


class NoiseIntensityPredictor(nn.Module):
    '''
    Input: nc*sw*sw tensor, either from clean or corrupted images
    Output: 1-dim tensor indicating the loss intensity
    '''
    def __init__(self, opt, sw, nc, outdim):
        super().__init__()
        nbottleneck = 256
        middim = 256
        # Compact info
        conv = [
            nn.Conv2d(nc, nbottleneck, 1, stride=1),
            nn.BatchNorm2d(nbottleneck),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # sw should be probably 16, downsample to 4 and convert to 1
        while sw > 4:
            sw = sw // 2
            conv += [
                nn.Conv2d(nbottleneck, nbottleneck, 3, stride=2, padding=1),
                nn.BatchNorm2d(nbottleneck),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        self.conv = nn.Sequential(*conv)

        indim = sw * sw * nbottleneck
        self.fc = nn.Sequential(
            nn.Linear(indim, middim),
            nn.BatchNorm1d(middim),
            #nn.Dropout(0.5),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(middim, outdim),
            #nn.Dropout(0.5),
            # nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x.squeeze()
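
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): NoiseIntensityPredictor
# maps an (nc, sw, sw) feature map to a scalar degradation-intensity estimate.
# `opt` is never read by the class, so None is passed; the sizes are assumptions.
def _noise_intensity_predictor_demo():
    predictor = NoiseIntensityPredictor(None, sw=16, nc=64, outdim=1)
    scores = predictor(torch.randn(4, 64, 16, 16))
    print(tuple(scores.shape))  # expected: (4,)
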
class SubAddGenerator(SPADEGenerator):
    '''
    20200311:
    This generator contains a complete set
    of self-supervised training scheme
    that requires a separate dataloader.

    The self-supervised pre-training is
    implemented as a clean interface.
    SubAddGenerator::train_E(self, epochs)

    For the upper-level Pix2Pix_Model,
    two things to be done:
    A) Run the pretrain script
    B) Save encoder and adjust learning rate.

    -----------------------------------------
    20200312:
    Pre-test problem: The discriminator is hard to test real vs fake
    Also, using residual of feature maps doesn't work...

    Cause: The feature map is too close to be properly separated.

    Try to test on one single reduction: arbitrary ratio of downsampling
    try to estimate the reduction ratio?
    '''
    def __init__(self, opt):
        super().__init__(opt)
        self.encoder = LIPEncoder(opt, self.sw, self.sh, self.scale_ratio)

        self.dis_nc = self.opt.ngf * min(16, 2**self.scale_ratio)
        # intensity is a scalar
        self.discriminator = NoiseIntensityPredictor(opt, self.sw, self.dis_nc, 1)
        if opt.isTrain:
            self.attach_dataloader(opt)
            self.noise_dim = opt.noise_dim

            self.l1_loss = nn.L1Loss()
            self.gan_loss = nn.MSELoss()

            #self.discriminator = NoiseClassPredictor(opt, self.sw, self.dis_nc,
            #    self.noise_dim + 1) # add a new label for clean images

            #self.gan_loss = nn.CrossEntropyLoss()

            beta1, beta2 = opt.beta1, opt.beta2
            if opt.no_TTUR:
                G_lr, D_lr = opt.lr, opt.lr
            else:
                G_lr, D_lr = opt.lr / 2, opt.lr * 2

            self.optimizer_E = torch.optim.Adam(
                self.encoder.parameters(), lr=G_lr, betas=(beta1, beta2)
            )
            self.optimizer_D = torch.optim.Adam(
                self.discriminator.parameters(), lr=D_lr / 2, betas=(beta1, beta2)
            )

    def _create_auxiliary_opt(self, opt):
        '''
        Create auxiliary options
        change necessary params
        --- dataroot
        --- dataset_mode
        '''
        aux_opt = copy.deepcopy(opt)  # just for safety
        aux_opt.dataroot = opt.dataroot_assist
        aux_opt.dataset_mode = 'assist'
        aux_opt.batchSize = 4
        aux_opt.nThreads = 4
        return aux_opt

    def attach_dataloader(self, opt):
        aux_opt = self._create_auxiliary_opt(opt)
        self.loader = data.create_dataloader(aux_opt)

    def encode(self, x):
        return self.encoder(x)

    def process_input(self, data):
        self.clean = data['clean'].cuda()
        self.noisy = data['noisy'].cuda()
        # for BCELoss, class label is just int.
        self.noise_label = data['label'].cuda()
        # label design...
        # clean label should be 0 or [1,0,0...0]?
        # for BCELoss, class label is just 0.
        #self.clean_label = torch.zeros_like(self.noise_label)
        self.clean_label = torch.ones_like(self.noise_label)

    def update_E(self):
        bundle_in = torch.cat((self.clean, self.noisy), dim=0)
        bundle_out = self.encode(bundle_in)
        nb = bundle_in.shape[0] // 2
        F_real, F_fake = bundle_out[:nb], bundle_out[nb:]

        pred_fake = self.discriminator(F_fake)
        loss_l1 = self.l1_loss(F_fake, F_real)
        loss_gan = self.gan_loss(pred_fake, self.clean_label)
        loss_sum = loss_l1 * 10 + loss_gan

        self.optimizer_E.zero_grad()
        loss_sum.backward()
        self.optimizer_E.step()

        self.loss_l1 = loss_l1.item()
        self.loss_gan_E = loss_gan.item()
        self.loss_sum = loss_sum.item()

    def update_D(self):

        with torch.no_grad():
            bundle_in = torch.cat((self.clean, self.noisy), dim=0)
            bundle_out = self.encode(bundle_in)
            nb = bundle_in.shape[0] // 2
            #F_real, #F_fake = bundle_out[:nb], bundle_out[nb:]
            F_fake = bundle_out[nb:] / (bundle_out[:nb] + 1e-6)
            F_real = torch.ones_like(F_fake, requires_grad=False)

        pred_real = self.discriminator(F_real)
        loss_real = self.gan_loss(pred_real, self.clean_label)
        pred_fake = self.discriminator(F_fake.detach())
        loss_fake = self.gan_loss(pred_fake, self.noise_label)
        loss_sum = (loss_real + loss_fake * self.opt.noise_dim) / 2

        self.optimizer_D.zero_grad()
        loss_sum.backward()
        self.optimizer_D.step()

        self.loss_gan_D_real = loss_real.item()
        self.loss_gan_D_fake = loss_fake.item()
    def debug_D(self):
        with torch.no_grad():
            bundle_in = torch.cat((self.clean, self.noisy), dim=0)
            bundle_out = self.encode(bundle_in)
            nb = bundle_in.shape[0] // 2
            F_real, F_fake = bundle_out[:nb], bundle_out[nb:]
            F_res = F_fake - F_real  # try to predict the residual, it's easier
            #F_real = torch.zeros_like(F_real) # real res == 0

            pred_real = self.discriminator(F_real)  #.argmax(dim=1)
            pred_fake = self.discriminator(F_res.detach())  #.argmax(dim=1)
            print(pred_real, pred_fake)
            real_acc = (pred_real == 0).sum().item() / pred_real.shape[0]
            fake_acc = (pred_fake == self.noise_label).sum().item() / pred_fake.shape[0]
            print(real_acc, fake_acc)

    def log(self, epoch, i):
        logstring = ' Epoch [%d] iter [%d]: ' % (epoch, i)
        logstring += 'l1: %.4f ' % self.loss_l1
        logstring += 'gen: %.4f ' % self.loss_gan_E
        logstring += 'E_sum: %.4f ' % self.loss_sum
        logstring += 'dis_real: %.4f ' % self.loss_gan_D_real
        logstring += 'dis_fake: %.4f' % self.loss_gan_D_fake
        print(logstring)

    def train_E(self, epochs):
        pretrained_ckpt_dir = os.path.join(
            self.opt.checkpoints_dir, self.opt.name,
            'pretrained_net_E_%d.pth' % epochs
        )
        print(pretrained_ckpt_dir)

        print('======= Stage I: Subtraction =======')
        if os.path.isfile(pretrained_ckpt_dir):
            state_dict_E = torch.load(pretrained_ckpt_dir)
            self.encoder.load_state_dict(state_dict_E)
            print('======= Load cached checkpoints %s' % pretrained_ckpt_dir)
        else:
            print('======= total epochs: %d ' % epochs)
            for epoch in range(1, epochs + 1):
                for i, data in enumerate(self.loader):
                    self.process_input(data)

                    self.update_E()
                    self.update_D()

                    if i % 10 == 0:
                        self.log(epoch, i)  # output losses and things.

                print('Epoch [%d] finished' % epoch)
                # just save the latest.
                torch.save(self.encoder.state_dict(), os.path.join(
                    self.opt.checkpoints_dir, self.opt.name, 'pretrained_net_E_%d.pth' % epoch
                ))
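
# ---------------------------------------------------------------------------
# Illustrative usage sketch (assumed, not taken from the original training
# scripts): the self-supervised pre-training above expects a training-mode opt
# providing at least isTrain, dataroot_assist, noise_dim, beta1, beta2, lr,
# no_TTUR, checkpoints_dir and name, plus the `data` package (import commented
# out at the top of this file) for attach_dataloader to work:
#
#     generator = SubAddGenerator(opt).cuda()
#     generator.train_E(epochs=10)   # caches pretrained_net_E_10.pth
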
class ContrasiveGenerator(SPADEGenerator):
    def __init__(self, opt):
        super().__init__(opt)
        self.encoder = LIPEncoder(opt, self.sw, self.sh, self.scale_ratio)

        if opt.isTrain:
            self.attach_dataloader(opt)
            self.noise_dim = opt.noise_dim

            self.l1_loss = nn.L1Loss()

            beta1, beta2 = opt.beta1, opt.beta2
            self.optimizer_E = torch.optim.Adam(
                self.encoder.parameters(), lr=opt.lr, betas=(beta1, beta2)
            )

    def _create_auxiliary_opt(self, opt):
        '''
        Create auxiliary options
        change necessary params
        --- dataroot
        --- dataset_mode
        '''
        aux_opt = copy.deepcopy(opt)  # just for safety
        aux_opt.dataroot = opt.dataroot_assist
        aux_opt.dataset_mode = 'assist'
        aux_opt.batchSize = 8
        aux_opt.nThreads = 4
        return aux_opt

    def attach_dataloader(self, opt):
        aux_opt = self._create_auxiliary_opt(opt)
        self.loader = data.create_dataloader(aux_opt)

    def encode(self, x):
        return self.encoder(x)

    def process_input(self, data):
        self.clean = data['clean'].cuda()
        self.noisy = data['noisy'].cuda()

    def update_E(self):
        bundle_in = torch.cat((self.clean, self.noisy), dim=0)
        bundle_out = self.encode(bundle_in)
        nb = bundle_in.shape[0] // 2
        F_real, F_fake = bundle_out[:nb], bundle_out[nb:]
        loss_l1 = self.l1_loss(F_fake, F_real)

        self.optimizer_E.zero_grad()
        loss_l1.backward()
        self.optimizer_E.step()

        self.loss_l1 = loss_l1.item()

    def log(self, epoch, i):
        logstring = ' Epoch [%d] iter [%d]: ' % (epoch, i)
        logstring += 'l1: %.4f ' % self.loss_l1
        print(logstring)

    def train_E(self, epochs):
        pretrained_ckpt_dir = os.path.join(
            self.opt.checkpoints_dir, self.opt.name,
            'pretrained_net_E_%d.pth' % epochs
        )
        print(pretrained_ckpt_dir)

        print('======= Stage I: Subtraction =======')
        if os.path.isfile(pretrained_ckpt_dir):
            state_dict_E = torch.load(pretrained_ckpt_dir)
            self.encoder.load_state_dict(state_dict_E)
            print('======= Load cached checkpoints %s' % pretrained_ckpt_dir)
        else:
            print('======= total epochs: %d ' % epochs)
            for epoch in range(1, epochs + 1):
                for i, data in enumerate(self.loader):
                    self.process_input(data)
                    self.update_E()

                    if i % 10 == 0:
                        self.log(epoch, i)  # output losses and things.

                print('Epoch [%d] finished' % epoch)
                # just save the latest.
                torch.save(self.encoder.state_dict(), os.path.join(
                    self.opt.checkpoints_dir, self.opt.name, 'pretrained_net_E_%d.pth' % epoch
                ))
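
# ---------------------------------------------------------------------------
# Illustrative smoke test (not part of the original file): exercises only the
# self-contained sketches added above. Running the module still requires the
# repository root on PYTHONPATH because of the models.networks imports at the top.
if __name__ == '__main__':
    _simplified_lip_demo()
    _noise_intensity_predictor_demo()
    _content_adaptive_suppressor_demo()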