CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
ai-forever

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: ai-forever/sber-swap
Path: blob/main/apex/tests/L1/common/main_amp.py
Views: 794
1
import argparse
2
import os
3
import shutil
4
import time
5
6
import torch
7
import torch.nn as nn
8
import torch.nn.parallel
9
import torch.backends.cudnn as cudnn
10
import torch.distributed as dist
11
import torch.optim
12
import torch.utils.data
13
import torch.utils.data.distributed
14
import torchvision.transforms as transforms
15
import torchvision.datasets as datasets
16
import torchvision.models as models
17
18
import numpy as np
19
20
try:
21
from apex.parallel import DistributedDataParallel as DDP
22
from apex.fp16_utils import *
23
from apex import amp, optimizers
24
from apex.multi_tensor_apply import multi_tensor_applier
25
except ImportError:
26
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
27
28
# Every lowercase, non-dunder callable found in torchvision.models is offered
# as a selectable architecture name.
model_names = sorted(
    name
    for name, member in models.__dict__.items()
    if name.islower() and not name.startswith("__") and callable(member)
)

# Command-line interface. Arguments are registered in a fixed order because
# that order is what argparse shows in --help.
parser = argparse.ArgumentParser(description="PyTorch ImageNet Training")

# --- dataset / model selection ---
parser.add_argument("data", metavar="DIR", help="path to dataset")
parser.add_argument(
    "--arch", "-a",
    metavar="ARCH",
    default="resnet18",
    choices=model_names,
    help="model architecture: " + " | ".join(model_names) + " (default: resnet18)",
)

# --- training schedule and optimizer hyper-parameters ---
parser.add_argument(
    "-j", "--workers",
    default=4, type=int, metavar="N",
    help="number of data loading workers (default: 4)",
)
parser.add_argument(
    "--epochs",
    default=90, type=int, metavar="N",
    help="number of total epochs to run",
)
parser.add_argument(
    "--start-epoch",
    default=0, type=int, metavar="N",
    help="manual epoch number (useful on restarts)",
)
parser.add_argument(
    "-b", "--batch-size",
    default=256, type=int, metavar="N",
    help="mini-batch size per process (default: 256)",
)
parser.add_argument(
    "--lr", "--learning-rate",
    default=0.1, type=float, metavar="LR",
    help="Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.",
)
parser.add_argument(
    "--momentum",
    default=0.9, type=float, metavar="M",
    help="momentum",
)
parser.add_argument(
    "--weight-decay", "--wd",
    default=1e-4, type=float, metavar="W",
    help="weight decay (default: 1e-4)",
)
parser.add_argument(
    "--print-freq", "-p",
    default=10, type=int, metavar="N",
    help="print frequency (default: 10)",
)

# --- checkpointing / run mode ---
parser.add_argument(
    "--resume",
    default="", type=str, metavar="PATH",
    help="path to latest checkpoint (default: none)",
)
parser.add_argument(
    "-e", "--evaluate",
    dest="evaluate", action="store_true",
    help="evaluate model on validation set",
)
parser.add_argument(
    "--pretrained",
    dest="pretrained", action="store_true",
    help="use pre-trained model",
)

parser.add_argument(
    "--prof",
    dest="prof", action="store_true",
    help="Only run 10 iterations for profiling.",
)
parser.add_argument("--deterministic", action="store_true")

# --- distributed options ---
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument(
    "--sync_bn",
    action="store_true",
    help="enabling apex sync BN.",
)

# --- Amp configuration under test ---
parser.add_argument("--has-ext", action="store_true")
parser.add_argument("--opt-level", type=str)
parser.add_argument("--keep-batchnorm-fp32", type=str, default=None)
parser.add_argument("--loss-scale", type=str, default=None)
parser.add_argument("--fused-adam", action="store_true")

# Number of logged samples train() collects before saving them and exiting.
parser.add_argument("--prints-to-process", type=int, default=10)

# cuDNN autotuning is on by default; the --deterministic handling further down
# in this file turns it off again.
cudnn.benchmark = True
80
81
def fast_collate(batch):
    """Collate (image, label) samples into a uint8 image batch and int64 targets.

    Args:
        batch: sequence of samples where ``sample[0]`` is an image exposing
            ``.size`` as ``(width, height)`` and convertible with ``np.array``
            (presumably a PIL image from ``ImageFolder`` -- confirm against
            the dataset), and ``sample[1]`` is an integer class label.
            All images are expected to share the size of the first one.

    Returns:
        ``(tensor, targets)`` where ``tensor`` is a ``uint8`` tensor of shape
        ``(len(batch), 3, height, width)`` and ``targets`` is an ``int64``
        tensor of the labels. Float conversion and normalization are not done
        here.
    """
    imgs = [sample[0] for sample in batch]
    targets = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
    w = imgs[0].size[0]
    h = imgs[0].size[1]
    tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8)
    for i, img in enumerate(imgs):
        # np.array() always copies. The buffer an image exports can be
        # read-only, and torch.from_numpy() warns on non-writable arrays.
        nump_array = np.array(img, dtype=np.uint8)
        if nump_array.ndim < 3:
            # Single-channel image: add a trailing channel axis so that the
            # assignment below broadcasts it across the three channels.
            nump_array = np.expand_dims(nump_array, axis=-1)
        # HWC -> CHW
        nump_array = np.moveaxis(nump_array, 2, 0)

        tensor[i] += torch.from_numpy(nump_array)

    return tensor, targets
96
97
# Best top-1 precision seen so far; main() updates it after each validation.
best_prec1 = 0

# NOTE: the command line is parsed at module level, so `args` is a global
# shared by every function in this file.
args = parser.parse_args()

# Let multi_tensor_applier be the canary in the coalmine
# that verifies if the backend is what we think it is
assert args.has_ext == multi_tensor_applier.available

# Echo the Amp configuration; the type is printed next to the values that
# are passed through to amp.initialize as strings or None.
print("opt_level = {}".format(args.opt_level))
print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32),
      type(args.keep_batchnorm_fp32))
print("loss_scale = {}".format(args.loss_scale),
      type(args.loss_scale))

print("\nCUDNN VERSION: {}\n".format(cudnn.version()))

if args.deterministic:
    # Disable autotuning, request deterministic cuDNN kernels, and seed the
    # RNG from the local rank.
    cudnn.benchmark = False
    cudnn.deterministic = True
    torch.manual_seed(args.local_rank)
    # More digits make small numerical differences visible in printed tensors.
    torch.set_printoptions(precision=10)
116
117
def main():
    """Build the model and optimizer under Amp, then evaluate or train.

    Reads and mutates the module-level ``args`` namespace: it adds
    ``distributed``, ``gpu`` and ``world_size``, rescales ``lr`` by the global
    batch size, and overwrites ``start_epoch`` when resuming. The best top-1
    precision is tracked in the module-level ``best_prec1``.

    With ``--evaluate`` only a validation pass is run. Otherwise each epoch
    trains, validates, and (on local rank 0) writes a checkpoint; with
    ``--prof`` the loop stops after the first call to ``train``.
    """
    global best_prec1, args

    # Distributed mode is inferred from the WORLD_SIZE environment variable.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank % torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.sync_bn:
        import apex
        print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda()

    # Scale learning rate based on global batch size
    args.lr = args.lr*float(args.batch_size*args.world_size)/256.
    if args.fused_adam:
        # NOTE: FusedAdam is built with its own defaults; args.lr, momentum
        # and weight decay are not passed to it.
        optimizer = optimizers.FusedAdam(model.parameters())
    else:
        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    model, optimizer = amp.initialize(
        model, optimizer,
        opt_level=args.opt_level,
        keep_batchnorm_fp32=args.keep_batchnorm_fp32,
        loss_scale=args.loss_scale
        )

    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss().cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        # Use a local scope to avoid dangling references
        def resume():
            # Without this declaration the assignment below would only create
            # a local variable and the restored best precision would be lost.
            global best_prec1
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})"
                      .format(args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))
        resume()

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    if(args.arch == "inception_v3"):
        crop_size = 299
        val_size = 320 # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256

    # ToTensor and normalization are deliberately left out of the transforms
    # (too slow); fast_collate and data_prefetcher handle the conversion.
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(crop_size),
            transforms.RandomHorizontalFlip(),
        ]))
    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(val_size),
            transforms.CenterCrop(crop_size),
        ]))

    train_sampler = None
    val_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)

    # DataLoader does not allow shuffle=True together with a sampler.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True,
        sampler=val_sampler,
        collate_fn=fast_collate)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        if args.prof:
            break
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer' : optimizer.state_dict(),
            }, is_best)
262
263
class data_prefetcher():
    """Fetch the next batch on a separate CUDA stream while the current one is used.

    Each batch from ``loader`` is copied to the GPU with ``non_blocking=True``,
    converted to float and normalized with per-channel mean/std (given on the
    0-255 scale) inside ``self.stream``. ``next()`` returns ``(None, None)``
    once the loader is exhausted.
    """

    def __init__(self, loader):
        """Wrap ``loader`` (any iterable of ``(input, target)`` batches) and preload the first batch."""
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        # Per-channel statistics scaled by 255 because inputs arrive as raw
        # byte-valued images; reshaped to broadcast over (N, 3, H, W).
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
        # With Amp, it isn't necessary to manually convert data to half.
        self.preload()

    def preload(self):
        """Pull the next batch from the loader and stage it on the side stream."""
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # With Amp, it isn't necessary to manually convert data to half.
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

    def next(self):
        """Return the staged ``(input, target)`` pair and start loading the following one."""
        # Make the consumer stream wait for the staged copies/normalization.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch_input = self.next_input
        batch_target = self.next_target
        # The tensors were allocated on self.stream but will be used on the
        # current stream; record that so the caching allocator does not reuse
        # their memory for a later preload while they are still in use.
        if batch_input is not None:
            batch_input.record_stream(torch.cuda.current_stream())
        if batch_target is not None:
            batch_target.record_stream(torch.cuda.current_stream())
        self.preload()
        return batch_input, batch_target
298
299
300
def train(train_loader, model, criterion, optimizer, epoch):
    """Train for one epoch and record loss/speed samples for the test harness.

    Every ``args.print_freq`` iterations (for ``i > 1``) the current loss and
    throughput are appended to a run-info dictionary. Once
    ``args.prints_to_process`` samples have been collected, local rank 0 saves
    the dictionary with ``torch.save`` under a file name built from the Amp
    configuration flags, and every process exits via ``SystemExit``.

    Args:
        train_loader: loader yielding batches; wrapped in ``data_prefetcher``.
        model: network to train (switched to train mode here).
        criterion: loss function called as ``criterion(output, target)``.
        optimizer: optimizer, used together with ``amp.scale_loss``.
        epoch: epoch index, only used in the log line.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()
    end = time.time()

    run_info_dict = {"Iteration" : [],
                     "Loss" : [],
                     "Speed" : []}

    # Samples processed per iteration across all processes.
    global_batch_size = args.world_size * args.batch_size

    prefetcher = data_prefetcher(train_loader)
    images, target = prefetcher.next()
    i = -1
    while images is not None:
        i += 1

        # No learning rate warmup for this test, to expose bitwise inaccuracies more quickly
        # (adjust_learning_rate is intentionally not called here).

        if args.prof and i > 10:
            break
        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

        if args.distributed:
            reduced_loss = reduce_tensor(loss.data)
            prec1 = reduce_tensor(prec1)
            prec5 = reduce_tensor(prec5)
        else:
            reduced_loss = loss.data

        losses.update(to_python_float(reduced_loss), images.size(0))
        top1.update(to_python_float(prec1), images.size(0))
        top5.update(to_python_float(prec5), images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()

        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()

        torch.cuda.nvtx.range_push("step")
        optimizer.step()
        torch.cuda.nvtx.range_pop()

        # Wait for the GPU so the elapsed time covers the whole iteration.
        torch.cuda.synchronize()
        # measure elapsed time
        batch_time.update(time.time() - end)

        end = time.time()

        # If you decide to refactor this test, like examples/imagenet, to sample the loss every
        # print_freq iterations, make sure to move this prefetching below the accuracy calculation.
        images, target = prefetcher.next()

        if i % args.print_freq == 0 and i > 1:
            if args.local_rank == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       epoch, i, len(train_loader),
                       global_batch_size / batch_time.val,
                       global_batch_size / batch_time.avg,
                       batch_time=batch_time,
                       data_time=data_time, loss=losses, top1=top1, top5=top5))
            # Samples are collected on every rank so that all ranks reach the
            # exit below at the same iteration.
            run_info_dict["Iteration"].append(i)
            run_info_dict["Loss"].append(losses.val)
            run_info_dict["Speed"].append(global_batch_size / batch_time.val)
            if len(run_info_dict["Loss"]) == args.prints_to_process:
                if args.local_rank == 0:
                    torch.save(run_info_dict,
                               str(args.has_ext) + "_" + str(args.opt_level) + "_" +
                               str(args.loss_scale) + "_" + str(args.keep_batchnorm_fp32) + "_" +
                               str(args.fused_adam))
                # quit() is an interactive helper injected by the `site`
                # module and is not guaranteed to exist; raising SystemExit
                # is the equivalent that always works in a program.
                raise SystemExit
396
397
398
def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` and return the average top-1 precision.

    Loss, top-1 and top-5 precision are accumulated per batch (all-reduced
    across processes when ``args.distributed`` is set). Local rank 0 logs
    progress every ``args.print_freq`` batches; the final summary line is
    printed by every process.
    """
    batch_time, losses, top1, top5 = (AverageMeter() for _ in range(4))

    # switch to evaluate mode
    model.eval()

    end = time.time()

    prefetcher = data_prefetcher(val_loader)
    images, labels = prefetcher.next()
    step = 0
    while images is not None:
        # compute output
        with torch.no_grad():
            output = model(images)
            loss = criterion(output, labels)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, labels, topk=(1, 5))

        reduced_loss = loss.data
        if args.distributed:
            reduced_loss = reduce_tensor(reduced_loss)
            prec1 = reduce_tensor(prec1)
            prec5 = reduce_tensor(prec5)

        n_samples = images.size(0)
        losses.update(to_python_float(reduced_loss), n_samples)
        top1.update(to_python_float(prec1), n_samples)
        top5.update(to_python_float(prec5), n_samples)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if args.local_rank == 0 and step % args.print_freq == 0:
            global_batch_size = args.world_size * args.batch_size
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Speed {2:.3f} ({3:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   step, len(val_loader),
                   global_batch_size / batch_time.val,
                   global_batch_size / batch_time.avg,
                   batch_time=batch_time, loss=losses,
                   top1=top1, top5=top5))

        images, labels = prefetcher.next()
        step += 1

    summary = ' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
    print(summary.format(top1=top1, top5=top5))

    return top1.avg
457
458
459
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Write ``state`` to ``filename``; also copy it to 'model_best.pth.tar' if ``is_best``."""
    torch.save(state, filename)
    if not is_best:
        return
    # Keep a separate copy of the best-scoring checkpoint.
    shutil.copyfile(filename, 'model_best.pth.tar')
463
464
465
class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes:
        val: most recent value passed to ``update``.
        sum: weighted sum of all values (``val * n``).
        count: total weight (sum of all ``n``).
        avg: ``sum / count``, or 0 while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed with weight ``n`` (e.g. the batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard the division: an update with n == 0 on an empty meter would
        # otherwise raise ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0
481
482
483
def adjust_learning_rate(optimizer, epoch, step, len_epoch):
    """Set the learning rate of every parameter group for the given epoch/step.

    LR schedule that should yield 76% converged accuracy with batch size 256.
    The base rate ``args.lr`` is multiplied by 0.1 for every 30 completed
    epochs, and once more from epoch 80 on. During the first 5 epochs the rate
    is additionally ramped up linearly with the global step count.

    Args:
        optimizer: object whose ``param_groups`` entries get their 'lr' set.
        epoch: current epoch index.
        step: iteration index within the epoch.
        len_epoch: number of iterations per epoch.
    """
    decay_steps = epoch // 30 + (1 if epoch >= 80 else 0)
    new_lr = args.lr * (0.1 ** decay_steps)

    # Warmup
    if epoch < 5:
        new_lr = new_lr * float(1 + step + epoch * len_epoch) / (5. * len_epoch)

    for group in optimizer.param_groups:
        group['lr'] = new_lr
501
502
503
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: 2-D score tensor; the top-k is taken along dimension 1, so it
            is indexed as ``(sample, class)``.
        target: 1-D tensor of true class indices, one per row of ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one single-element float tensor per k, holding the
        percentage of samples whose target is among the k highest scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk highest-scoring classes, transposed to (maxk, batch)
    # so that row j holds every sample's (j+1)-th best guess.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` inherits the transposed (non-contiguous) layout, so
        # .view(-1) raises a RuntimeError for k > 1; reshape copies if needed.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
517
518
519
def reduce_tensor(tensor):
    """Return the mean of ``tensor`` across all processes, leaving the input untouched.

    A clone is summed with ``all_reduce`` and divided by ``args.world_size``.
    Requires an initialized default process group.
    """
    rt = tensor.clone()
    # dist.reduce_op is the deprecated alias; ReduceOp is the supported enum.
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= args.world_size
    return rt
524
525
# Run the training/evaluation entry point only when executed as a script.
if __name__ == "__main__":
    main()
527
528