Path: blob/main/apex/examples/imagenet/main_amp.py
import argparse
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

import numpy as np

try:
    from apex.parallel import DistributedDataParallel as DDP
    from apex.fp16_utils import *
    from apex import amp, optimizers
    from apex.multi_tensor_apply import multi_tensor_applier
except ImportError:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")

def fast_collate(batch, memory_format):

    imgs = [img[0] for img in batch]
    targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
    w = imgs[0].size[0]
    h = imgs[0].size[1]
    tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8).contiguous(memory_format=memory_format)
    for i, img in enumerate(imgs):
        nump_array = np.asarray(img, dtype=np.uint8)
        if(nump_array.ndim < 3):
            nump_array = np.expand_dims(nump_array, axis=-1)
        nump_array = np.rollaxis(nump_array, 2)
        tensor[i] += torch.from_numpy(nump_array)
    return tensor, targets


def parse():
    model_names = sorted(name for name in models.__dict__
                         if name.islower() and not name.startswith("__")
                         and callable(models.__dict__[name]))

    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('data', metavar='DIR',
                        help='path to dataset')
    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                        choices=model_names,
                        help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet18)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=90, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('-b', '--batch-size', default=256, type=int,
                        metavar='N', help='mini-batch size per process (default: 256)')
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR',
                        help='Initial learning rate.  Will be scaled by <global batch size>/256: '
                             'args.lr = args.lr*float(args.batch_size*args.world_size)/256. '
                             'A warmup schedule will also be applied over the first 5 epochs.')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)')
    parser.add_argument('--print-freq', '-p', default=10, type=int,
                        metavar='N', help='print frequency (default: 10)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                        help='use pre-trained model')

    parser.add_argument('--prof', default=-1, type=int,
                        help='Only run 10 iterations for profiling.')
    parser.add_argument('--deterministic', action='store_true')

    parser.add_argument("--local_rank", default=0, type=int)
    parser.add_argument('--sync_bn', action='store_true',
                        help='enabling apex sync BN.')

    parser.add_argument('--opt-level', type=str)
    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
    parser.add_argument('--loss-scale', type=str, default=None)
    parser.add_argument('--channels-last', type=bool, default=False)
    args = parser.parse_args()
    return args

def main():
    global best_prec1, args

    args = parse()
    print("opt_level = {}".format(args.opt_level))
    print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
    print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))

    print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))

    cudnn.benchmark = True
    best_prec1 = 0
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.local_rank)
        torch.set_printoptions(precision=10)

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    if args.channels_last:
        memory_format = torch.channels_last
    else:
        memory_format = torch.contiguous_format

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.sync_bn:
        import apex
        print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda().to(memory_format=memory_format)

    # Scale learning rate based on global batch size
    args.lr = args.lr*float(args.batch_size*args.world_size)/256.
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=args.opt_level,
                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                      loss_scale=args.loss_scale
                                      )

    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize.  If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        # Use a local scope to avoid dangling references
        def resume():
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch']
                global best_prec1
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})"
                      .format(args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))
        resume()

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    if(args.arch == "inception_v3"):
        raise RuntimeError("Currently, inception_v3 is not supported by this example.")
        # crop_size = 299
        # val_size = 320 # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(crop_size),
            transforms.RandomHorizontalFlip(),
            # transforms.ToTensor(), Too slow
            # normalize,
        ]))
    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(val_size),
            transforms.CenterCrop(crop_size),
        ]))

    train_sampler = None
    val_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)

    collate_fn = lambda b: fast_collate(b, memory_format)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=collate_fn)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True,
        sampler=val_sampler,
        collate_fn=collate_fn)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer' : optimizer.state_dict(),
            }, is_best)

class data_prefetcher():
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        # if record_stream() doesn't work, another option is to make sure device inputs are created
        # on the main stream.
        # self.next_input_gpu = torch.empty_like(self.next_input, device='cuda')
        # self.next_target_gpu = torch.empty_like(self.next_target, device='cuda')
        # Need to make sure the memory allocated for next_* is not still in use by the main stream
        # at the time we start copying to next_*:
        # self.stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # more code for the alternative if record_stream() doesn't work:
            # copy_ will record the use of the pinned source tensor in this side stream.
            # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
            # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
            # self.next_input = self.next_input_gpu
            # self.next_target = self.next_target_gpu

            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
        target = self.next_target
        if input is not None:
            input.record_stream(torch.cuda.current_stream())
        if target is not None:
            target.record_stream(torch.cuda.current_stream())
        self.preload()
        return input, target


def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()
    end = time.time()

    prefetcher = data_prefetcher(train_loader)
    input, target = prefetcher.next()
    i = 0
    while input is not None:
        i += 1
        if args.prof >= 0 and i == args.prof:
            print("Profiling begun at iteration {}".format(i))
            torch.cuda.cudart().cudaProfilerStart()

        if args.prof >= 0: torch.cuda.nvtx.range_push("Body of iteration {}".format(i))

        adjust_learning_rate(optimizer, epoch, i, len(train_loader))

        # compute output
        if args.prof >= 0: torch.cuda.nvtx.range_push("forward")
        output = model(input)
        if args.prof >= 0: torch.cuda.nvtx.range_pop()
        loss = criterion(output, target)

        # compute gradient and do SGD step
        optimizer.zero_grad()

        if args.prof >= 0: torch.cuda.nvtx.range_push("backward")
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        if args.prof >= 0: torch.cuda.nvtx.range_pop()

        # for param in model.parameters():
        #     print(param.data.double().sum().item(), param.grad.data.double().sum().item())

        if args.prof >= 0: torch.cuda.nvtx.range_push("optimizer.step()")
        optimizer.step()
        if args.prof >= 0: torch.cuda.nvtx.range_pop()

        if i%args.print_freq == 0:
            # Every print_freq iterations, check the loss, accuracy, and speed.
            # For best performance, it doesn't make sense to print these metrics every
            # iteration, since they incur an allreduce and some host<->device syncs.

            # Measure accuracy
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

            # Average loss and accuracy across processes for logging
            if args.distributed:
                reduced_loss = reduce_tensor(loss.data)
                prec1 = reduce_tensor(prec1)
                prec5 = reduce_tensor(prec5)
            else:
                reduced_loss = loss.data

            # to_python_float incurs a host<->device sync
            losses.update(to_python_float(reduced_loss), input.size(0))
            top1.update(to_python_float(prec1), input.size(0))
            top5.update(to_python_float(prec5), input.size(0))

            torch.cuda.synchronize()
            batch_time.update((time.time() - end)/args.print_freq)
            end = time.time()

            if args.local_rank == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(train_loader),
                      args.world_size*args.batch_size/batch_time.val,
                      args.world_size*args.batch_size/batch_time.avg,
                      batch_time=batch_time,
                      loss=losses, top1=top1, top5=top5))
        if args.prof >= 0: torch.cuda.nvtx.range_push("prefetcher.next()")
        input, target = prefetcher.next()
        if args.prof >= 0: torch.cuda.nvtx.range_pop()

        # Pop range "Body of iteration {}".format(i)
        if args.prof >= 0: torch.cuda.nvtx.range_pop()

        if args.prof >= 0 and i == args.prof + 10:
            print("Profiling ended at iteration {}".format(i))
            torch.cuda.cudart().cudaProfilerStop()
            quit()


def validate(val_loader, model, criterion):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()

    prefetcher = data_prefetcher(val_loader)
    input, target = prefetcher.next()
    i = 0
    while input is not None:
        i += 1

        # compute output
        with torch.no_grad():
            output = model(input)
            loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

        if args.distributed:
            reduced_loss = reduce_tensor(loss.data)
            prec1 = reduce_tensor(prec1)
            prec5 = reduce_tensor(prec5)
        else:
            reduced_loss = loss.data

        losses.update(to_python_float(reduced_loss), input.size(0))
        top1.update(to_python_float(prec1), input.size(0))
        top5.update(to_python_float(prec5), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # TODO: Change timings to mirror train().
        if args.local_rank == 0 and i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Speed {2:.3f} ({3:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                  i, len(val_loader),
                  args.world_size * args.batch_size / batch_time.val,
                  args.world_size * args.batch_size / batch_time.avg,
                  batch_time=batch_time, loss=losses,
                  top1=top1, top5=top5))

        input, target = prefetcher.next()

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def adjust_learning_rate(optimizer, epoch, step, len_epoch):
    """LR schedule that should yield 76% converged accuracy with batch size 256"""
    factor = epoch // 30

    if epoch >= 80:
        factor = factor + 1

    lr = args.lr*(0.1**factor)

    """Warmup"""
    if epoch < 5:
        lr = lr*float(1 + step + epoch*len_epoch)/(5.*len_epoch)

    # if(args.local_rank == 0):
    #     print("epoch = {}, step = {}, lr = {}".format(epoch, step, lr))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def reduce_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.reduce_op.SUM)
    rt /= args.world_size
    return rt

if __name__ == '__main__':
    main()
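
# Illustrative launch sketch (comment only; flag values and the dataset path
# below are assumptions based on the argparse definitions above, not part of
# the example's documented invocation). The positional `data` argument should
# point at a directory containing `train/` and `val/` subfolders, as used by
# the ImageFolder datasets in main().
#
#   Single process, single GPU:
#     python main_amp.py -a resnet50 -b 224 --workers 4 --opt-level O2 /path/to/imagenet
#
#   Multi-process: torch.distributed.launch sets WORLD_SIZE (which switches
#   args.distributed on above) and passes --local_rank to each process:
#     python -m torch.distributed.launch --nproc_per_node=2 main_amp.py \
#         -a resnet50 -b 224 --workers 4 --opt-level O2 /path/to/imagenet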