Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial alternative to JupyterHub.
Path: blob/main/apex/tests/distributed/synced_batchnorm/test_groups.py
Views: 810
# Distributed test for apex.parallel.SyncBatchNorm restricted to process groups.
#
# Launch with, e.g.:
#   python -m torch.distributed.launch --nproc_per_node=<num gpus> \
#       test_groups.py --group_size <n> [--fp16 | --fp64]
#
# Ranks are split into groups of `group_size` consecutive global ranks. Every
# rank of a group draws the same random batch (the RNG is seeded from the group
# index). Each rank feeds only its own slice of that batch through
# SyncBatchNorm, while a DDP-wrapped torch BatchNorm2d sees the whole batch.
# The script:
#   1. checks the low-level `syncbn` kernels against a float64 reference;
#   2. checks that SyncBatchNorm over the group reproduces BatchNorm2d over the
#      full batch (outputs, input grads, running stats, and parameters after
#      one SGD step).
import torch
import numpy as np
import apex
import syncbn
import os
import sys
import argparse
import torch.optim as optim

from apex.parallel import DistributedDataParallel as DDP


def compare(desc, inp1, inp2, error):
    """Return whether ``inp1`` and ``inp2`` are element-wise close.

    Both tensors are copied to host numpy arrays and compared with
    ``np.allclose``, using ``error`` as both the relative and the absolute
    tolerance. On mismatch, ``desc`` is printed along with the differing
    elements of each input and their difference.
    """
    a = inp1.clone().detach().cpu().numpy()
    b = inp2.clone().detach().cpu().numpy()
    close = np.allclose(a, b, error, error)
    if not close:
        print(desc, close)
        z = a - b
        # Select exactly the elements np.allclose rejected (including NaNs),
        # using the same tolerances.
        index = (~np.isclose(a, b, error, error)).nonzero()
        print("dif : ", z[index])
        print("inp1 : ", a[index])
        print("inp2 : ", b[index])
    return close


feature_size = 10
space_size = 40
batch_size = 32

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument("--fp16", action='store_true', default=False)
parser.add_argument("--fp64", action='store_true', default=False)
# Forwarded unchanged to create_syncbn_process_group below. A value of 0 is
# presumably "use the default (world) group" there -- TODO confirm against apex.
parser.add_argument("--group_size", default=0, type=int)
args = parser.parse_args()

try:
    args.world_size = int(os.environ['WORLD_SIZE'])
except (KeyError, ValueError):
    # WORLD_SIZE is missing (or malformed) when not started by the launcher.
    print("This is a multi-gpu test. To run it please use 'python -m torch.distributed.launch --nproc_per_node=<num gpus> test_groups.py <more options>'")
    sys.exit(1)

torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')

# Group membership is a function of the *global* rank; local_rank only equals
# it on a single node.
rank = torch.distributed.get_rank()
# The default --group_size of 0 would otherwise divide by zero below; treat it
# as a single group containing every rank.
group_size = args.group_size or args.world_size

# This rank handles rows [start, finish) of its group's shared batch.
start = (rank % group_size) * batch_size // group_size
finish = (rank % group_size + 1) * batch_size // group_size

error = 1e-5
dtype = np.float32
if args.fp16:
    error = 1e-3
    dtype = np.float16
elif args.fp64:
    error = 1e-8
    dtype = np.float64

# Same seed for every rank in a group, so they all generate the same batch.
np.random.seed(18 + rank // group_size)

inp = np.random.randn(batch_size, feature_size, space_size, space_size).astype(dtype)
grad = np.random.randn(batch_size, feature_size, space_size, space_size).astype(dtype)
weight = np.random.randn(feature_size).astype(dtype)
bias = np.random.randn(feature_size).astype(dtype)

type_tensor = torch.cuda.FloatTensor
if args.fp16:
    type_tensor = torch.cuda.HalfTensor
if args.fp64:
    type_tensor = torch.cuda.DoubleTensor

# All reference values are computed in float64.
ref_tensor = torch.cuda.DoubleTensor

inp_t = type_tensor(inp)
weight_t = type_tensor(weight)
bias_t = type_tensor(bias)

# (C, N*H*W) layout: per-channel statistics become reductions over dim 1.
inp_r = ref_tensor(inp.transpose(1, 0, 2, 3).reshape(feature_size, -1))
inp2_r = ref_tensor(inp)
weight_r = ref_tensor(weight).view(-1, 1, 1)
bias_r = ref_tensor(bias).view(-1, 1, 1)

grad_output_t = type_tensor(grad)

m = inp_r.mean(1)
b_v = inp_r.var(1, unbiased=False)

eps = 1e-5

mean, var_biased = syncbn.welford_mean_var(inp_t)
inv_std = 1.0 / torch.sqrt(var_biased + eps)

# Baseline: plain BatchNorm2d over the full batch on every rank.
bn = torch.nn.BatchNorm2d(feature_size).cuda()
# momentum 1.0: running stats are overwritten by this step's batch statistics.
bn.momentum = 1.0
bn.weight.data = weight_t.clone()
bn.bias.data = bias_t.clone()
if args.fp16:
    bn.half()
if args.fp64:
    bn.double()
bn = DDP(bn)
inp_bn = inp_t.clone().requires_grad_()
grad_bn = grad_output_t.clone().detach()
out_bn = bn(inp_bn)
out_bn.backward(grad_bn)
# compensating the averaging over processes done by DDP
# in order to produce mathematically equivalent result
# https://github.com/NVIDIA/apex/issues/134#issuecomment-458307368
for param in bn.parameters():
    param.grad = param.grad / group_size
bn_opt = optim.SGD(bn.parameters(), lr=1.0)

# Code under test: SyncBatchNorm over this rank's slice, synchronized within
# the process group.
sbn = apex.parallel.SyncBatchNorm(feature_size, process_group=apex.parallel.create_syncbn_process_group(args.group_size)).cuda()
sbn.momentum = 1.0
sbn.weight.data = weight_t.clone()
sbn.bias.data = bias_t.clone()
if args.fp16:
    sbn.half()
if args.fp64:
    sbn.double()
sbn = DDP(sbn)
sbn_opt = optim.SGD(sbn.parameters(), lr=1.0)
inp_sbn = inp_t.clone().requires_grad_()
grad_sbn = grad_output_t.clone().detach()
out_sbn = sbn(inp_sbn[start:finish])
out_sbn.backward(grad_sbn[start:finish])

sbn_result = True

if args.local_rank == 0:
    sbn_result = compare("comparing mean: ", mean, m, error) and sbn_result
    sbn_result = compare("comparing biased variance: ", var_biased, b_v, error) and sbn_result

# Shared float64 sub-expressions of the reference forward/backward.
xmu_r = inp2_r - m.view(-1, 1, 1)
inv_std_r = torch.rsqrt(b_v.view(-1, 1, 1) + eps)

out = syncbn.batchnorm_forward(inp_t, mean, inv_std, weight_t, bias_t)
out_r = weight_r * xmu_r * inv_std_r + bias_r

if args.local_rank == 0:
    sbn_result = compare("comparing output: ", out, out_r, error) and sbn_result
    # Informational only: validates torch's own BatchNorm2d, not the code
    # under test, so it does not affect the verdict.
    compare("comparing bn output: ", out_bn, out_r, error)

grad_output_r = ref_tensor(grad.transpose(1, 0, 2, 3).reshape(feature_size, -1))
grad_output2_r = ref_tensor(grad)

grad_bias_r = grad_output_r.sum(1)
grad_weight_r = (xmu_r * inv_std_r * grad_output2_r).transpose(1, 0).contiguous().view(feature_size, -1).sum(1)

mean_dy_r = grad_output_r.mean(1)
mean_dy_xmu_r = (xmu_r * grad_output2_r).transpose(1, 0).contiguous().view(feature_size, -1).mean(1)

# Batch-norm input gradient:
#   dx = (dy - mean(dy) - (x - mu) / (var + eps) * mean(dy * (x - mu)))
#        * rsqrt(var + eps) * weight
grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - xmu_r / (b_v.view(-1, 1, 1) + eps) * mean_dy_xmu_r.view(-1, 1, 1)) * inv_std_r * weight_r

mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu)

if args.local_rank == 0:
    sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
    sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
    sbn_result = compare("comparing mean_dy grad: ", mean_dy, mean_dy_r, error) and sbn_result
    sbn_result = compare("comparing mean_dy_xmu grad: ", mean_dy_xmu, mean_dy_xmu_r, error) and sbn_result
    sbn_result = compare("comparing input grad: ", grad_input, grad_input_r, error) and sbn_result
    # Informational only (torch BatchNorm2d vs reference).
    compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error)

if args.local_rank == 0:
    sbn_result = compare("comparing running_mean: ", bn.module.running_mean.data, sbn.module.running_mean.data, error) and sbn_result
    sbn_result = compare("comparing running_variance: ", bn.module.running_var.data, sbn.module.running_var.data, error) and sbn_result

# execute by both: every rank checks its own slice against the full-batch BN.
# The results must feed the verdict; previously they were discarded.
sbn_result = compare("comparing layers output: ", out_bn[start:finish], out_sbn, error) and sbn_result
sbn_result = compare("comparing layers grad_input: ", inp_bn.grad[start:finish], inp_sbn.grad[start:finish], error) and sbn_result

bn_opt.step()
sbn_opt.step()

if args.local_rank == 0:
    # After one SGD step the parameters match only if the (DDP-averaged)
    # weight/bias gradients of BN and SyncBN agreed.
    sbn_result = compare("comparing bn vs sbn bias: ", bn.module.bias, sbn.module.bias, error) and sbn_result
    sbn_result = compare("comparing bn vs sbn weight: ", bn.module.weight, sbn.module.weight, error) and sbn_result


if sbn_result:
    print("====SBN group test passed")
else:
    print("*SBN group test failed*")