# GitHub Repository: ai-forever/sber-swap
# Path: blob/main/apex/examples/imagenet/main_amp.py
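# This is the stock NVIDIA Apex ImageNet example: it trains a torchvision model with
# automatic mixed precision (apex.amp) and, optionally, multi-process data parallelism
# via apex.parallel.DistributedDataParallel.
#
# A typical multi-GPU launch (assuming the torch.distributed.launch helper, which sets
# WORLD_SIZE and passes --local_rank to each process, as this script expects) might be:
#
#   python -m torch.distributed.launch --nproc_per_node=NUM_GPUS main_amp.py \
#       -a resnet50 -b 128 --workers 4 --opt-level O2 /path/to/imagenet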
import argparse
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

import numpy as np

try:
    from apex.parallel import DistributedDataParallel as DDP
    from apex.fp16_utils import *
    from apex import amp, optimizers
    from apex.multi_tensor_apply import multi_tensor_applier
except ImportError:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
def fast_collate(batch, memory_format):

    imgs = [img[0] for img in batch]
    targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
    w = imgs[0].size[0]
    h = imgs[0].size[1]
    tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8).contiguous(memory_format=memory_format)
    for i, img in enumerate(imgs):
        nump_array = np.asarray(img, dtype=np.uint8)
        if(nump_array.ndim < 3):
            nump_array = np.expand_dims(nump_array, axis=-1)
        nump_array = np.rollaxis(nump_array, 2)
        tensor[i] += torch.from_numpy(nump_array)
    return tensor, targets
def parse():
    model_names = sorted(name for name in models.__dict__
                         if name.islower() and not name.startswith("__")
                         and callable(models.__dict__[name]))

    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('data', metavar='DIR',
                        help='path to dataset')
    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                        choices=model_names,
                        help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet18)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=90, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('-b', '--batch-size', default=256, type=int,
                        metavar='N', help='mini-batch size per process (default: 256)')
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR', help='Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)')
    parser.add_argument('--print-freq', '-p', default=10, type=int,
                        metavar='N', help='print frequency (default: 10)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                        help='use pre-trained model')

    parser.add_argument('--prof', default=-1, type=int,
                        help='Only run 10 iterations for profiling.')
    parser.add_argument('--deterministic', action='store_true')

    parser.add_argument("--local_rank", default=0, type=int)
    parser.add_argument('--sync_bn', action='store_true',
                        help='enabling apex sync BN.')

    parser.add_argument('--opt-level', type=str)
    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
    parser.add_argument('--loss-scale', type=str, default=None)
    parser.add_argument('--channels-last', type=bool, default=False)
    args = parser.parse_args()
    return args
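# Note on the AMP flags above (summarized from the Apex documentation): --opt-level
# O0 = pure FP32, O1 = patch-based mixed precision, O2 = "almost FP16" with FP32
# batchnorm and FP32 master weights, O3 = pure FP16. --keep-batchnorm-fp32 and
# --loss-scale are passed straight through to amp.initialize() as strings, which
# Amp accepts (e.g. "True", "128.0", "dynamic").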
def main():
    global best_prec1, args

    args = parse()
    print("opt_level = {}".format(args.opt_level))
    print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
    print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))

    print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))

    cudnn.benchmark = True
    best_prec1 = 0
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.local_rank)
        torch.set_printoptions(precision=10)

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
    if args.channels_last:
        memory_format = torch.channels_last
    else:
        memory_format = torch.contiguous_format

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.sync_bn:
        import apex
        print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda().to(memory_format=memory_format)

    # Scale learning rate based on global batch size
    args.lr = args.lr*float(args.batch_size*args.world_size)/256.
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Initialize Amp. Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=args.opt_level,
                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                      loss_scale=args.loss_scale
                                      )
    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize. If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        # Use a local scope to avoid dangling references
        def resume():
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch']
                global best_prec1
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})"
                      .format(args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))
        resume()
    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    if(args.arch == "inception_v3"):
        raise RuntimeError("Currently, inception_v3 is not supported by this example.")
        # crop_size = 299
        # val_size = 320 # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(crop_size),
            transforms.RandomHorizontalFlip(),
            # transforms.ToTensor(), Too slow
            # normalize,
        ]))
    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(val_size),
            transforms.CenterCrop(crop_size),
        ]))

    train_sampler = None
    val_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)

    collate_fn = lambda b: fast_collate(b, memory_format)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=collate_fn)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True,
        sampler=val_sampler,
        collate_fn=collate_fn)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer' : optimizer.state_dict(),
            }, is_best)
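# data_prefetcher hides host-to-device copy latency: while the current batch is being
# processed on the default stream, the next batch is copied and normalized on a side
# CUDA stream, and record_stream()/wait_stream() keep the two streams correctly ordered.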
class data_prefetcher():
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        # if record_stream() doesn't work, another option is to make sure device inputs are created
        # on the main stream.
        # self.next_input_gpu = torch.empty_like(self.next_input, device='cuda')
        # self.next_target_gpu = torch.empty_like(self.next_target, device='cuda')
        # Need to make sure the memory allocated for next_* is not still in use by the main stream
        # at the time we start copying to next_*:
        # self.stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # more code for the alternative if record_stream() doesn't work:
            # copy_ will record the use of the pinned source tensor in this side stream.
            # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
            # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
            # self.next_input = self.next_input_gpu
            # self.next_target = self.next_target_gpu

            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
        target = self.next_target
        if input is not None:
            input.record_stream(torch.cuda.current_stream())
        if target is not None:
            target.record_stream(torch.cuda.current_stream())
        self.preload()
        return input, target
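# Training loop. The backward pass goes through amp.scale_loss(), which scales the loss
# to avoid FP16 gradient underflow and unscales the gradients before optimizer.step();
# the optional NVTX ranges (enabled with --prof) mark iteration phases for profiling.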
def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()
    end = time.time()

    prefetcher = data_prefetcher(train_loader)
    input, target = prefetcher.next()
    i = 0
    while input is not None:
        i += 1
        if args.prof >= 0 and i == args.prof:
            print("Profiling begun at iteration {}".format(i))
            torch.cuda.cudart().cudaProfilerStart()

        if args.prof >= 0: torch.cuda.nvtx.range_push("Body of iteration {}".format(i))

        adjust_learning_rate(optimizer, epoch, i, len(train_loader))

        # compute output
        if args.prof >= 0: torch.cuda.nvtx.range_push("forward")
        output = model(input)
        if args.prof >= 0: torch.cuda.nvtx.range_pop()
        loss = criterion(output, target)

        # compute gradient and do SGD step
        optimizer.zero_grad()

        if args.prof >= 0: torch.cuda.nvtx.range_push("backward")
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        if args.prof >= 0: torch.cuda.nvtx.range_pop()

        # for param in model.parameters():
        #     print(param.data.double().sum().item(), param.grad.data.double().sum().item())

        if args.prof >= 0: torch.cuda.nvtx.range_push("optimizer.step()")
        optimizer.step()
        if args.prof >= 0: torch.cuda.nvtx.range_pop()

        if i%args.print_freq == 0:
            # Every print_freq iterations, check the loss, accuracy, and speed.
            # For best performance, it doesn't make sense to print these metrics every
            # iteration, since they incur an allreduce and some host<->device syncs.

            # Measure accuracy
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

            # Average loss and accuracy across processes for logging
            if args.distributed:
                reduced_loss = reduce_tensor(loss.data)
                prec1 = reduce_tensor(prec1)
                prec5 = reduce_tensor(prec5)
            else:
                reduced_loss = loss.data

            # to_python_float incurs a host<->device sync
            losses.update(to_python_float(reduced_loss), input.size(0))
            top1.update(to_python_float(prec1), input.size(0))
            top5.update(to_python_float(prec5), input.size(0))

            torch.cuda.synchronize()
            batch_time.update((time.time() - end)/args.print_freq)
            end = time.time()

            if args.local_rank == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       epoch, i, len(train_loader),
                       args.world_size*args.batch_size/batch_time.val,
                       args.world_size*args.batch_size/batch_time.avg,
                       batch_time=batch_time,
                       loss=losses, top1=top1, top5=top5))
        if args.prof >= 0: torch.cuda.nvtx.range_push("prefetcher.next()")
        input, target = prefetcher.next()
        if args.prof >= 0: torch.cuda.nvtx.range_pop()

        # Pop range "Body of iteration {}".format(i)
        if args.prof >= 0: torch.cuda.nvtx.range_pop()

        if args.prof >= 0 and i == args.prof + 10:
            print("Profiling ended at iteration {}".format(i))
            torch.cuda.cudart().cudaProfilerStop()
            quit()
def validate(val_loader, model, criterion):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()

    prefetcher = data_prefetcher(val_loader)
    input, target = prefetcher.next()
    i = 0
    while input is not None:
        i += 1

        # compute output
        with torch.no_grad():
            output = model(input)
            loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

        if args.distributed:
            reduced_loss = reduce_tensor(loss.data)
            prec1 = reduce_tensor(prec1)
            prec5 = reduce_tensor(prec5)
        else:
            reduced_loss = loss.data

        losses.update(to_python_float(reduced_loss), input.size(0))
        top1.update(to_python_float(prec1), input.size(0))
        top5.update(to_python_float(prec5), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # TODO: Change timings to mirror train().
        if args.local_rank == 0 and i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Speed {2:.3f} ({3:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   i, len(val_loader),
                   args.world_size * args.batch_size / batch_time.val,
                   args.world_size * args.batch_size / batch_time.avg,
                   batch_time=batch_time, loss=losses,
                   top1=top1, top5=top5))

        input, target = prefetcher.next()

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')
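# Only rank 0 calls save_checkpoint (see the epoch loop in main()), so in distributed
# runs the processes do not race to write the same checkpoint.pth.tar / model_best.pth.tar.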
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
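# Worked example of the step schedule below with the defaults (base lr 0.1 after global
# batch-size scaling): epochs 0-4 warm up linearly toward 0.1, epochs 5-29 use 0.1,
# epochs 30-59 use 0.01, epochs 60-79 use 0.001, and epoch 80 onward uses 0.0001.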
def adjust_learning_rate(optimizer, epoch, step, len_epoch):
    """LR schedule that should yield 76% converged accuracy with batch size 256"""
    factor = epoch // 30

    if epoch >= 80:
        factor = factor + 1

    lr = args.lr*(0.1**factor)

    """Warmup"""
    if epoch < 5:
        lr = lr*float(1 + step + epoch*len_epoch)/(5.*len_epoch)

    # if(args.local_rank == 0):
    #     print("epoch = {}, step = {}, lr = {}".format(epoch, step, lr))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
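# reduce_tensor averages a metric across ranks (sum via all_reduce, then divide by
# world_size) so every process logs the same value; it is only called when
# args.distributed is True. Note that dist.reduce_op is the legacy spelling; newer
# PyTorch versions expose the same constant as dist.ReduceOp.SUM.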
def reduce_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.reduce_op.SUM)
    rt /= args.world_size
    return rt
if __name__ == '__main__':
    main()