Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Path: blob/main/apex/tests/L1/common/main_amp.py
Views: 794
import argparse1import os2import shutil3import time45import torch6import torch.nn as nn7import torch.nn.parallel8import torch.backends.cudnn as cudnn9import torch.distributed as dist10import torch.optim11import torch.utils.data12import torch.utils.data.distributed13import torchvision.transforms as transforms14import torchvision.datasets as datasets15import torchvision.models as models1617import numpy as np1819try:20from apex.parallel import DistributedDataParallel as DDP21from apex.fp16_utils import *22from apex import amp, optimizers23from apex.multi_tensor_apply import multi_tensor_applier24except ImportError:25raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")2627model_names = sorted(name for name in models.__dict__28if name.islower() and not name.startswith("__")29and callable(models.__dict__[name]))3031parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')32parser.add_argument('data', metavar='DIR',33help='path to dataset')34parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',35choices=model_names,36help='model architecture: ' +37' | '.join(model_names) +38' (default: resnet18)')39parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',40help='number of data loading workers (default: 4)')41parser.add_argument('--epochs', default=90, type=int, metavar='N',42help='number of total epochs to run')43parser.add_argument('--start-epoch', default=0, type=int, metavar='N',44help='manual epoch number (useful on restarts)')45parser.add_argument('-b', '--batch-size', default=256, type=int,46metavar='N', help='mini-batch size per process (default: 256)')47parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,48metavar='LR', help='Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.')49parser.add_argument('--momentum', default=0.9, type=float, metavar='M',50help='momentum')51parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,52metavar='W', help='weight decay (default: 1e-4)')53parser.add_argument('--print-freq', '-p', default=10, type=int,54metavar='N', help='print frequency (default: 10)')55parser.add_argument('--resume', default='', type=str, metavar='PATH',56help='path to latest checkpoint (default: none)')57parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',58help='evaluate model on validation set')59parser.add_argument('--pretrained', dest='pretrained', action='store_true',60help='use pre-trained model')6162parser.add_argument('--prof', dest='prof', action='store_true',63help='Only run 10 iterations for profiling.')64parser.add_argument('--deterministic', action='store_true')6566parser.add_argument("--local_rank", default=0, type=int)67parser.add_argument('--sync_bn', action='store_true',68help='enabling apex sync BN.')6970parser.add_argument('--has-ext', action='store_true')71parser.add_argument('--opt-level', type=str)72parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)73parser.add_argument('--loss-scale', type=str, default=None)74parser.add_argument('--fused-adam', action='store_true')7576parser.add_argument('--prints-to-process', type=int, default=10)7778cudnn.benchmark = True7980def fast_collate(batch):81imgs = [img[0] for img in batch]82targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)83w = imgs[0].size[0]84h = imgs[0].size[1]85tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8 )86for i, img in enumerate(imgs):87nump_array = np.asarray(img, dtype=np.uint8)88if(nump_array.ndim < 3):89nump_array = np.expand_dims(nump_array, axis=-1)90nump_array = np.rollaxis(nump_array, 2)9192tensor[i] += torch.from_numpy(nump_array)9394return tensor, targets9596best_prec1 = 097args = parser.parse_args()9899# Let multi_tensor_applier be the canary in the coalmine100# that verifies if the backend is what we think it is101assert multi_tensor_applier.available == args.has_ext102103print("opt_level = {}".format(args.opt_level))104print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))105print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))106107108print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))109110if args.deterministic:111cudnn.benchmark = False112cudnn.deterministic = True113torch.manual_seed(args.local_rank)114torch.set_printoptions(precision=10)115116def main():117global best_prec1, args118119args.distributed = False120if 'WORLD_SIZE' in os.environ:121args.distributed = int(os.environ['WORLD_SIZE']) > 1122123args.gpu = 0124args.world_size = 1125126if args.distributed:127args.gpu = args.local_rank % torch.cuda.device_count()128torch.cuda.set_device(args.gpu)129torch.distributed.init_process_group(backend='nccl',130init_method='env://')131args.world_size = torch.distributed.get_world_size()132133assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."134135# create model136if args.pretrained:137print("=> using pre-trained model '{}'".format(args.arch))138model = models.__dict__[args.arch](pretrained=True)139else:140print("=> creating model '{}'".format(args.arch))141model = models.__dict__[args.arch]()142143if args.sync_bn:144import apex145print("using apex synced BN")146model = apex.parallel.convert_syncbn_model(model)147148model = model.cuda()149150# Scale learning rate based on global batch size151args.lr = args.lr*float(args.batch_size*args.world_size)/256.152if args.fused_adam:153optimizer = optimizers.FusedAdam(model.parameters())154else:155optimizer = torch.optim.SGD(model.parameters(), args.lr,156momentum=args.momentum,157weight_decay=args.weight_decay)158159model, optimizer = amp.initialize(160model, optimizer,161# enabled=False,162opt_level=args.opt_level,163keep_batchnorm_fp32=args.keep_batchnorm_fp32,164loss_scale=args.loss_scale165)166167if args.distributed:168# By default, apex.parallel.DistributedDataParallel overlaps communication with169# computation in the backward pass.170# model = DDP(model)171# delay_allreduce delays all communication to the end of the backward pass.172model = DDP(model, delay_allreduce=True)173174# define loss function (criterion) and optimizer175criterion = nn.CrossEntropyLoss().cuda()176177# Optionally resume from a checkpoint178if args.resume:179# Use a local scope to avoid dangling references180def resume():181if os.path.isfile(args.resume):182print("=> loading checkpoint '{}'".format(args.resume))183checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))184args.start_epoch = checkpoint['epoch']185best_prec1 = checkpoint['best_prec1']186model.load_state_dict(checkpoint['state_dict'])187optimizer.load_state_dict(checkpoint['optimizer'])188print("=> loaded checkpoint '{}' (epoch {})"189.format(args.resume, checkpoint['epoch']))190else:191print("=> no checkpoint found at '{}'".format(args.resume))192resume()193194# Data loading code195traindir = os.path.join(args.data, 'train')196valdir = os.path.join(args.data, 'val')197198if(args.arch == "inception_v3"):199crop_size = 299200val_size = 320 # I chose this value arbitrarily, we can adjust.201else:202crop_size = 224203val_size = 256204205train_dataset = datasets.ImageFolder(206traindir,207transforms.Compose([208transforms.RandomResizedCrop(crop_size),209transforms.RandomHorizontalFlip(),210# transforms.ToTensor(), Too slow211# normalize,212]))213val_dataset = datasets.ImageFolder(valdir, transforms.Compose([214transforms.Resize(val_size),215transforms.CenterCrop(crop_size),216]))217218train_sampler = None219val_sampler = None220if args.distributed:221train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)222val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)223224train_loader = torch.utils.data.DataLoader(225train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),226num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)227228val_loader = torch.utils.data.DataLoader(229val_dataset,230batch_size=args.batch_size, shuffle=False,231num_workers=args.workers, pin_memory=True,232sampler=val_sampler,233collate_fn=fast_collate)234235if args.evaluate:236validate(val_loader, model, criterion)237return238239for epoch in range(args.start_epoch, args.epochs):240if args.distributed:241train_sampler.set_epoch(epoch)242243# train for one epoch244train(train_loader, model, criterion, optimizer, epoch)245if args.prof:246break247# evaluate on validation set248prec1 = validate(val_loader, model, criterion)249250# remember best prec@1 and save checkpoint251if args.local_rank == 0:252is_best = prec1 > best_prec1253best_prec1 = max(prec1, best_prec1)254save_checkpoint({255'epoch': epoch + 1,256'arch': args.arch,257'state_dict': model.state_dict(),258'best_prec1': best_prec1,259'optimizer' : optimizer.state_dict(),260}, is_best)261262class data_prefetcher():263def __init__(self, loader):264self.loader = iter(loader)265self.stream = torch.cuda.Stream()266self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)267self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)268# With Amp, it isn't necessary to manually convert data to half.269# if args.fp16:270# self.mean = self.mean.half()271# self.std = self.std.half()272self.preload()273274def preload(self):275try:276self.next_input, self.next_target = next(self.loader)277except StopIteration:278self.next_input = None279self.next_target = None280return281with torch.cuda.stream(self.stream):282self.next_input = self.next_input.cuda(non_blocking=True)283self.next_target = self.next_target.cuda(non_blocking=True)284# With Amp, it isn't necessary to manually convert data to half.285# if args.fp16:286# self.next_input = self.next_input.half()287# else:288self.next_input = self.next_input.float()289self.next_input = self.next_input.sub_(self.mean).div_(self.std)290291def next(self):292torch.cuda.current_stream().wait_stream(self.stream)293input = self.next_input294target = self.next_target295self.preload()296return input, target297298299def train(train_loader, model, criterion, optimizer, epoch):300batch_time = AverageMeter()301data_time = AverageMeter()302losses = AverageMeter()303top1 = AverageMeter()304top5 = AverageMeter()305306# switch to train mode307model.train()308end = time.time()309310run_info_dict = {"Iteration" : [],311"Loss" : [],312"Speed" : []}313314prefetcher = data_prefetcher(train_loader)315input, target = prefetcher.next()316i = -1317while input is not None:318i += 1319320# No learning rate warmup for this test, to expose bitwise inaccuracies more quickly321# adjust_learning_rate(optimizer, epoch, i, len(train_loader))322323if args.prof:324if i > 10:325break326# measure data loading time327data_time.update(time.time() - end)328329# compute output330output = model(input)331loss = criterion(output, target)332333# measure accuracy and record loss334prec1, prec5 = accuracy(output.data, target, topk=(1, 5))335336if args.distributed:337reduced_loss = reduce_tensor(loss.data)338prec1 = reduce_tensor(prec1)339prec5 = reduce_tensor(prec5)340else:341reduced_loss = loss.data342343losses.update(to_python_float(reduced_loss), input.size(0))344top1.update(to_python_float(prec1), input.size(0))345top5.update(to_python_float(prec5), input.size(0))346347# compute gradient and do SGD step348optimizer.zero_grad()349350with amp.scale_loss(loss, optimizer) as scaled_loss:351scaled_loss.backward()352353# for param in model.parameters():354# print(param.data.double().sum().item(), param.grad.data.double().sum().item())355356# torch.cuda.synchronize()357torch.cuda.nvtx.range_push("step")358optimizer.step()359torch.cuda.nvtx.range_pop()360361torch.cuda.synchronize()362# measure elapsed time363batch_time.update(time.time() - end)364365end = time.time()366367# If you decide to refactor this test, like examples/imagenet, to sample the loss every368# print_freq iterations, make sure to move this prefetching below the accuracy calculation.369input, target = prefetcher.next()370371if i % args.print_freq == 0 and i > 1:372if args.local_rank == 0:373print('Epoch: [{0}][{1}/{2}]\t'374'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'375'Speed {3:.3f} ({4:.3f})\t'376'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'377'Loss {loss.val:.10f} ({loss.avg:.4f})\t'378'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'379'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(380epoch, i, len(train_loader),381args.world_size * args.batch_size / batch_time.val,382args.world_size * args.batch_size / batch_time.avg,383batch_time=batch_time,384data_time=data_time, loss=losses, top1=top1, top5=top5))385run_info_dict["Iteration"].append(i)386run_info_dict["Loss"].append(losses.val)387run_info_dict["Speed"].append(args.world_size * args.batch_size / batch_time.val)388if len(run_info_dict["Loss"]) == args.prints_to_process:389if args.local_rank == 0:390torch.save(run_info_dict,391str(args.has_ext) + "_" + str(args.opt_level) + "_" +392str(args.loss_scale) + "_" + str(args.keep_batchnorm_fp32) + "_" +393str(args.fused_adam))394quit()395396397def validate(val_loader, model, criterion):398batch_time = AverageMeter()399losses = AverageMeter()400top1 = AverageMeter()401top5 = AverageMeter()402403# switch to evaluate mode404model.eval()405406end = time.time()407408prefetcher = data_prefetcher(val_loader)409input, target = prefetcher.next()410i = -1411while input is not None:412i += 1413414# compute output415with torch.no_grad():416output = model(input)417loss = criterion(output, target)418419# measure accuracy and record loss420prec1, prec5 = accuracy(output.data, target, topk=(1, 5))421422if args.distributed:423reduced_loss = reduce_tensor(loss.data)424prec1 = reduce_tensor(prec1)425prec5 = reduce_tensor(prec5)426else:427reduced_loss = loss.data428429losses.update(to_python_float(reduced_loss), input.size(0))430top1.update(to_python_float(prec1), input.size(0))431top5.update(to_python_float(prec5), input.size(0))432433# measure elapsed time434batch_time.update(time.time() - end)435end = time.time()436437if args.local_rank == 0 and i % args.print_freq == 0:438print('Test: [{0}/{1}]\t'439'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'440'Speed {2:.3f} ({3:.3f})\t'441'Loss {loss.val:.4f} ({loss.avg:.4f})\t'442'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'443'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(444i, len(val_loader),445args.world_size * args.batch_size / batch_time.val,446args.world_size * args.batch_size / batch_time.avg,447batch_time=batch_time, loss=losses,448top1=top1, top5=top5))449450input, target = prefetcher.next()451452print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'453.format(top1=top1, top5=top5))454455return top1.avg456457458def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):459torch.save(state, filename)460if is_best:461shutil.copyfile(filename, 'model_best.pth.tar')462463464class AverageMeter(object):465"""Computes and stores the average and current value"""466def __init__(self):467self.reset()468469def reset(self):470self.val = 0471self.avg = 0472self.sum = 0473self.count = 0474475def update(self, val, n=1):476self.val = val477self.sum += val * n478self.count += n479self.avg = self.sum / self.count480481482def adjust_learning_rate(optimizer, epoch, step, len_epoch):483"""LR schedule that should yield 76% converged accuracy with batch size 256"""484factor = epoch // 30485486if epoch >= 80:487factor = factor + 1488489lr = args.lr*(0.1**factor)490491"""Warmup"""492if epoch < 5:493lr = lr*float(1 + step + epoch*len_epoch)/(5.*len_epoch)494495# if(args.local_rank == 0):496# print("epoch = {}, step = {}, lr = {}".format(epoch, step, lr))497498for param_group in optimizer.param_groups:499param_group['lr'] = lr500501502def accuracy(output, target, topk=(1,)):503"""Computes the precision@k for the specified values of k"""504maxk = max(topk)505batch_size = target.size(0)506507_, pred = output.topk(maxk, 1, True, True)508pred = pred.t()509correct = pred.eq(target.view(1, -1).expand_as(pred))510511res = []512for k in topk:513correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)514res.append(correct_k.mul_(100.0 / batch_size))515return res516517518def reduce_tensor(tensor):519rt = tensor.clone()520dist.all_reduce(rt, op=dist.reduce_op.SUM)521rt /= args.world_size522return rt523524if __name__ == '__main__':525main()526527528