Path: blob/main/apex/tests/L0/run_optimizers/test_dist_adam.py
import argparse
import random
import sys

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

from apex import amp
from apex.optimizers import FusedAdam
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam


class TestModel(torch.nn.Module):
    def __init__(self, args):
        super(TestModel, self).__init__()

        self.linear = torch.nn.Sequential(*[torch.nn.Linear(args.dim, args.dim, bias=args.bias) for _ in range(args.layers)])

    def forward(self, x):
        return self.linear(x)


def setup(args):
    ## Model
    ref_model = TestModel(args).cuda()
    dist_model = TestModel(args).cuda()

    # Same weights
    with torch.no_grad():
        for dp, rp in zip(dist_model.parameters(), ref_model.parameters()):
            dp.data.copy_(rp.data)

    dist_model = dist_model.half()

    ## Optimizer
    # same hyperparameters
    ref_opt_args = {'lr': 1e-3, 'eps': 1e-6, 'weight_decay': 0.01}
    ref_opt = FusedAdam(ref_model.parameters(), **ref_opt_args)

    dist_opt_args = ref_opt_args.copy()
    dist_opt_args.update({'overlap_reductions': False})
    dist_opt_args.update({'process_group_size': args.n_gpu})
    dist_opt_args.update({'dwu_group_size': args.dwu_group_size})
    dist_opt_args.update({'dwu_num_blocks': 1})
    dist_opt_args.update({'dwu_num_chunks': 1})
    dist_opt = DistributedFusedAdam(dist_model.parameters(), **dist_opt_args)
    dist_opt.set_global_scale(1.)

    ## amp-init
    amp_args = {'loss_scale': 'dynamic', 'opt_level': 'O2'}
    ref_model, ref_opt = amp.initialize(ref_model, ref_opt, **amp_args)

    ## DDP
    ref_model = DDP(ref_model, device_ids=[args.rank])

    with torch.no_grad():
        for dp in dist_model.parameters():
            torch.distributed.broadcast(dp.data, src=0)
        for rp in ref_model.parameters():
            torch.distributed.broadcast(rp.data, src=0)
    torch.cuda.synchronize()
    torch.distributed.barrier()

    if get_rank() == 0:
        print(f'dist opt with {args.n_gpu} GPUs')

    return ref_model, ref_opt, dist_model, dist_opt


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--steps', type=int, default=20)
    parser.add_argument('--batch', type=int, default=32)
    parser.add_argument('--dim', type=int, default=4)
    parser.add_argument('--layers', type=int, default=2)
    parser.add_argument('--bias', action='store_true')
    parser.add_argument('--atol', type=float, default=1e-3)
    parser.add_argument('--rtol', type=float, default=1)
    parser.add_argument('--dwu_group_size', type=float, default=1)

    args = parser.parse_args()

    return args


def setup_env(args):
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.rank = torch.distributed.get_rank()
    args.n_gpu = torch.distributed.get_world_size()

    seed = 42 + get_rank()

    random.seed(seed)
    torch.manual_seed(seed)

    return args


def get_rank():
    return torch.distributed.get_rank()


def main():
    args = parse_args()
    args = setup_env(args)
    tol_args = {'atol': args.atol, 'rtol': args.rtol}

    torch.set_printoptions(precision=16)

    ref_model, ref_opt, dist_model, dist_opt = setup(args)

    # lazy_init not called yet, initialize stash
    stash = ref_opt._amp_stash
    stash.all_fp16_params, stash.all_fp32_from_fp16_params = [], []

    # make sure everything from _first_step_init_ is ready before training
    # e.g. registering allreduce_hook
    # so that gradients are copied/reduced when necessary
    dist_opt._init_everything()

    for i in range(args.steps):
        x_ref = torch.randn(args.batch, args.dim, dtype=torch.half).cuda().requires_grad_(True)
        x_dist = x_ref.clone().detach().requires_grad_(True)

        if get_rank() == 0:
            print(f'[{i}] Checking input')
            #print("x_ref:", x_ref.flatten()[:10])
            #print("x_dist:", x_dist.flatten()[:10])
        assert(torch.allclose(x_ref, x_dist, **tol_args))

        y_ref = ref_model(x_ref).half()
        y_dist = dist_model(x_dist)

        if get_rank() == 0:
            print(f'[{i}] Checking output')
            #print("y_ref:", y_ref.flatten()[:10])
            #print("y_dist:", y_dist.flatten()[:10])
        assert(torch.allclose(y_ref, y_dist, **tol_args))

        dy = torch.randn_like(y_ref)

        y_ref.backward(dy)
        y_dist.backward(dy)

        if get_rank() == 0:
            print(f'[{i}] Checking gradients')
        torch.distributed.barrier()
        torch.cuda.synchronize()
        assert(torch.allclose(x_ref.grad, x_dist.grad, **tol_args))

        # gradient all-reduce within distributed optimizer
        dist_opt.complete_reductions()

        if get_rank() == 0:
            print(f'[{i}] Stepping')
        ref_opt.step()
        dist_opt.step()

        torch.cuda.synchronize()
        torch.distributed.barrier()
        print('Checking new weights')
        if get_rank() == 0:
            print("ref param:", ref_model.module.linear[0].weight)
            print("dist param:", dist_model.linear[0].weight)

        for i, (rp, dp) in enumerate(zip(ref_model.parameters(), dist_model.parameters())):
            if not torch.allclose(rp, dp, **tol_args):
                if get_rank() == 0:
                    print(f'Rank: {get_rank()}, Param: {i}')
                    print(f'ref: {rp.sum().item()}, dist: {dp.sum().item()}')
                    print(rp)
                    print(dp)
                    print(torch.abs(rp-dp) > tol_args['atol'])
                sys.exit(0)

        # zero grads
        for rp, dp in zip(ref_model.parameters(), dist_model.parameters()):
            rp.grad = None
            dp.grad = None


if __name__ == "__main__":
    main()
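Usage note: because the script initializes the process group with init_method='env://' and accepts a --local_rank argument, it is meant to be started through a PyTorch distributed launcher that sets the rank/world-size environment variables and passes --local_rank to each worker. A typical invocation (the GPU count of 2 below is only an illustrative assumption) would be:

python -m torch.distributed.launch --nproc_per_node=2 test_dist_adam.py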