Path: blob/main/apex/examples/dcgan/main_amp.py
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

try:
    from apex import amp
except ImportError:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")


parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='cifar10', help='cifar10 | lsun | mnist | imagenet | folder | lfw | fake')
parser.add_argument('--dataroot', default='./', help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--classes', default='bedroom', help='comma separated list of classes for the lsun data set')
parser.add_argument('--opt_level', default='O1', help='amp opt_level, default="O1"')

opt = parser.parse_args()
print(opt)

try:
    os.makedirs(opt.outf)
except OSError:
    pass

if opt.manualSeed is None:
    opt.manualSeed = 2809
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

cudnn.benchmark = True


if opt.dataset in ['imagenet', 'folder', 'lfw']:
    # folder dataset
    dataset = dset.ImageFolder(root=opt.dataroot,
                               transform=transforms.Compose([
                                   transforms.Resize(opt.imageSize),
                                   transforms.CenterCrop(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
    nc = 3
elif opt.dataset == 'lsun':
    classes = [c + '_train' for c in opt.classes.split(',')]
    dataset = dset.LSUN(root=opt.dataroot, classes=classes,
                        transform=transforms.Compose([
                            transforms.Resize(opt.imageSize),
                            transforms.CenterCrop(opt.imageSize),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                        ]))
    nc = 3
elif opt.dataset == 'cifar10':
    dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                           transform=transforms.Compose([
                               transforms.Resize(opt.imageSize),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
    nc = 3
elif opt.dataset == 'mnist':
    dataset = dset.MNIST(root=opt.dataroot, download=True,
                         transform=transforms.Compose([
                             transforms.Resize(opt.imageSize),
                             transforms.ToTensor(),
                             transforms.Normalize((0.5,), (0.5,)),
                         ]))
    nc = 1
elif opt.dataset == 'fake':
    dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),
                            transform=transforms.ToTensor())
    nc = 3

assert dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))

device = torch.device("cuda:0")
ngpu = int(opt.ngpu)
nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)
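
# Note on initialization: the scheme below follows the DCGAN paper
# (Radford et al., 2015) -- conv weights drawn from N(0, 0.02), BatchNorm
# scales from N(1.0, 0.02) with zero bias -- which helps stabilize
# adversarial training for this architecture.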

# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output


netG = Generator(ngpu).to(device)
netG.apply(weights_init)
if opt.netG != '':
    netG.load_state_dict(torch.load(opt.netG))
print(netG)


class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)

        return output.view(-1, 1).squeeze(1)


netD = Discriminator(ngpu).to(device)
netD.apply(weights_init)
if opt.netD != '':
    netD.load_state_dict(torch.load(opt.netD))
print(netD)

criterion = nn.BCEWithLogitsLoss()

fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
real_label = 1
fake_label = 0

# setup optimizer
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

[netD, netG], [optimizerD, optimizerG] = amp.initialize(
    [netD, netG], [optimizerD, optimizerG], opt_level=opt.opt_level, num_losses=3)
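
# apex.amp note: amp.initialize() accepts lists of models and optimizers, and
# num_losses=3 tells amp to maintain a separate dynamic loss scale for each of
# the three backward passes in the loop below. Each amp.scale_loss(...) call
# then passes a loss_id so the matching scale is adjusted independently when
# a gradient overflow occurs.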

for epoch in range(opt.niter):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
        # pin a float dtype: BCEWithLogitsLoss needs float targets, and on newer
        # PyTorch torch.full with an int fill value produces an integer tensor
        label = torch.full((batch_size,), real_label, dtype=torch.float, device=device)

        output = netD(real_cpu)
        errD_real = criterion(output, label)
        with amp.scale_loss(errD_real, optimizerD, loss_id=0) as errD_real_scaled:
            errD_real_scaled.backward()
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        with amp.scale_loss(errD_fake, optimizerD, loss_id=1) as errD_fake_scaled:
            errD_fake_scaled.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, label)
        with amp.scale_loss(errG, optimizerG, loss_id=2) as errG_scaled:
            errG_scaled.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, opt.niter, i, len(dataloader),
                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        if i % 100 == 0:
            vutils.save_image(real_cpu,
                              '%s/real_samples.png' % opt.outf,
                              normalize=True)
            fake = netG(fixed_noise)
            vutils.save_image(fake.detach(),
                              '%s/amp_fake_samples_epoch_%03d.png' % (opt.outf, epoch),
                              normalize=True)

    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))
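
A quick way to try the script on a CUDA machine (the data directory below is illustrative, not mandated by the example; every flag corresponds to an argparse option defined above):

    python main_amp.py --dataset cifar10 --dataroot ./data --opt_level O1

With --dataset cifar10 the data is downloaded automatically (the loader passes download=True), and generated sample grids plus per-epoch checkpoints are written to the directory given by --outf.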