GitHub Repository: ai-forever/sber-swap
Path: blob/main/apex/examples/dcgan/main_amp.py
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

try:
    from apex import amp
except ImportError:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
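
# NOTE: apex.amp is NVIDIA's original automatic mixed-precision library. It has
# since been deprecated in favor of the built-in torch.cuda.amp module, but this
# example predates that API and keeps the apex calls throughout.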

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='cifar10', help='cifar10 | lsun | mnist | imagenet | folder | lfw | fake')
parser.add_argument('--dataroot', default='./', help='path to dataset')
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers')
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to the network')
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64, help='base number of generator feature maps')
parser.add_argument('--ndf', type=int, default=64, help='base number of discriminator feature maps')
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', default='', help='path to netG (to continue training)')
parser.add_argument('--netD', default='', help='path to netD (to continue training)')
parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--classes', default='bedroom', help='comma separated list of classes for the lsun data set')
parser.add_argument('--opt_level', default='O1', help='amp opt_level, default="O1"')

opt = parser.parse_args()
print(opt)
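
# Example invocation (the data path is a placeholder; CIFAR-10 is downloaded
# automatically if it is not already present under --dataroot):
#
#   python main_amp.py --dataset cifar10 --dataroot ./data --niter 25 --opt_level O1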

try:
    os.makedirs(opt.outf)
except OSError:
    pass

if opt.manualSeed is None:
    opt.manualSeed = 2809
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

cudnn.benchmark = True
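
# cudnn.benchmark lets cuDNN auto-tune its convolution algorithms for the fixed
# input shapes used here, which usually speeds up training. The trade-off is
# that the chosen algorithms can be nondeterministic, so the manual seed above
# does not guarantee bit-identical runs.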

if opt.dataset in ['imagenet', 'folder', 'lfw']:
    # folder dataset
    dataset = dset.ImageFolder(root=opt.dataroot,
                               transform=transforms.Compose([
                                   transforms.Resize(opt.imageSize),
                                   transforms.CenterCrop(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
    nc = 3
elif opt.dataset == 'lsun':
    classes = [c + '_train' for c in opt.classes.split(',')]
    dataset = dset.LSUN(root=opt.dataroot, classes=classes,
                        transform=transforms.Compose([
                            transforms.Resize(opt.imageSize),
                            transforms.CenterCrop(opt.imageSize),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                        ]))
    nc = 3
elif opt.dataset == 'cifar10':
    dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                           transform=transforms.Compose([
                               transforms.Resize(opt.imageSize),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
    nc = 3
elif opt.dataset == 'mnist':
    dataset = dset.MNIST(root=opt.dataroot, download=True,
                         transform=transforms.Compose([
                             transforms.Resize(opt.imageSize),
                             transforms.ToTensor(),
                             transforms.Normalize((0.5,), (0.5,)),
                         ]))
    nc = 1
elif opt.dataset == 'fake':
    dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),
                            transform=transforms.ToTensor())
    nc = 3

assert dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))
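
# Normalize((0.5, ...), (0.5, ...)) maps pixel values from [0, 1] to [-1, 1],
# matching the Tanh activation at the generator's output.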

device = torch.device("cuda:0")
ngpu = int(opt.ngpu)
nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)


# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
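
# This follows the initialization scheme from the DCGAN paper (Radford et al.,
# 2016): conv weights drawn from N(0, 0.02), BatchNorm scales from N(1, 0.02)
# with zero bias.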


class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output
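
# Each stride-2 ConvTranspose2d doubles the spatial size:
# out = (in - 1) * stride - 2 * padding + kernel_size, so with kernel 4,
# stride 2, padding 1 a 4x4 map becomes 8x8, and the 1x1 latent vector grows
# into a 64x64 image over five layers.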


netG = Generator(ngpu).to(device)
netG.apply(weights_init)
if opt.netG != '':
    netG.load_state_dict(torch.load(opt.netG))
print(netG)


class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)

        return output.view(-1, 1).squeeze(1)


netD = Discriminator(ngpu).to(device)
netD.apply(weights_init)
if opt.netD != '':
    netD.load_state_dict(torch.load(opt.netD))
print(netD)

criterion = nn.BCEWithLogitsLoss()
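
# The discriminator deliberately ends without a Sigmoid and returns raw logits:
# BCEWithLogitsLoss fuses the sigmoid with the binary cross-entropy in a
# numerically stable way, which matters under mixed precision. (apex.amp
# refuses plain binary_cross_entropy on fp16 inputs for exactly this reason.)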

fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
# Labels are floats: BCEWithLogitsLoss expects floating-point targets, and
# newer PyTorch versions infer an integer dtype in torch.full from an int
# fill value.
real_label = 1.0
fake_label = 0.0

# setup optimizer
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

[netD, netG], [optimizerD, optimizerG] = amp.initialize(
    [netD, netG], [optimizerD, optimizerG], opt_level=opt.opt_level, num_losses=3)
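
# amp.initialize accepts lists when several models and optimizers are trained
# together, and returns them patched for mixed precision. num_losses=3 gives
# each of the three losses below (D-real, D-fake, G) its own loss scaler,
# selected later via the loss_id argument to amp.scale_loss.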

for epoch in range(opt.niter):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
        label = torch.full((batch_size,), real_label, device=device)

        output = netD(real_cpu)
        errD_real = criterion(output, label)
        with amp.scale_loss(errD_real, optimizerD, loss_id=0) as errD_real_scaled:
            errD_real_scaled.backward()
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        # detach() blocks gradients from flowing back into the generator
        # while the discriminator is being updated
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        with amp.scale_loss(errD_fake, optimizerD, loss_id=1) as errD_fake_scaled:
            errD_fake_scaled.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()
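
        # The two backward passes above accumulate gradients from the real and
        # fake batches before the single optimizerD.step(). Scaling them under
        # separate loss_ids (0 and 1) lets amp keep an independent dynamic
        # scale for each, so an overflow in one term does not shrink the other.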

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, label)
        with amp.scale_loss(errG, optimizerG, loss_id=2) as errG_scaled:
            errG_scaled.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, opt.niter, i, len(dataloader),
                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        if i % 100 == 0:
            vutils.save_image(real_cpu,
                              '%s/real_samples.png' % opt.outf,
                              normalize=True)
            fake = netG(fixed_noise)
            vutils.save_image(fake.detach(),
                              '%s/amp_fake_samples_epoch_%03d.png' % (opt.outf, epoch),
                              normalize=True)

    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))
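
# To resume training from saved checkpoints via --netG/--netD (the epoch
# number below is illustrative):
#
#   python main_amp.py --dataset cifar10 --dataroot ./data \
#       --netG ./netG_epoch_24.pth --netD ./netD_epoch_24.pth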