GitHub Repository: ai-forever/sber-swap
Path: blob/main/apex/tests/distributed/synced_batchnorm/test_groups.py
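Note: this is a multi-GPU test; launch one process per GPU, e.g. python -m torch.distributed.launch --nproc_per_node=<num gpus> test_groups.py --group_size=<group size> (optionally with --fp16 or --fp64), as quoted in the script's own error message.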
import torch
import numpy as np
import apex
import syncbn
import os
import argparse
import torch.optim as optim

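# Helper: compare two tensors elementwise within the given tolerance and,
# on mismatch, print the differing entries from both inputs.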
def compare(desc, inp1, inp2, error):
    a = inp1.clone().detach().cpu().numpy()
    b = inp2.clone().detach().cpu().numpy()
    close = np.allclose(a, b, error, error)
    if not close:
        print(desc, close)
        z = a - b
        index = (np.abs(z) >= error + error * np.abs(b)).nonzero()
        print("dif : ", z[index])
        print("inp1 : ", a[index])
        print("inp2 : ", b[index])
    return close

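# Problem size used throughout the test: NCHW input of shape
# (batch_size, feature_size, space_size, space_size).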
feature_size = 10
space_size = 40
batch_size = 32


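# Command-line options: --local_rank is supplied by torch.distributed.launch,
# --fp16/--fp64 select the dtype under test, and --group_size sets the number
# of ranks per SyncBatchNorm process group.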
from apex.parallel import DistributedDataParallel as DDP
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument("--fp16", action='store_true', default=False)
parser.add_argument("--fp64", action='store_true', default=False)
parser.add_argument("--group_size", default=0, type=int)
args = parser.parse_args()

try:
    args.world_size = int(os.environ['WORLD_SIZE'])
except:
    print("This is a multi-gpu test. To run it please use 'python -m torch.distributed.launch --nproc_per_node=<num gpus> test_groups.py <more options>'")
    exit(1)

torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')

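# Each rank in a group handles batch_size // group_size samples;
# start:finish is this rank's slice of the shared batch.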
start = (args.local_rank % args.group_size) * batch_size // args.group_size
finish = (args.local_rank % args.group_size + 1) * batch_size // args.group_size

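# Comparison tolerance and numpy dtype follow the requested precision.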
error = 1e-5
dtype = np.float32
if args.fp16:
    error = 1e-3
    dtype = np.float16
elif args.fp64:
    error = 1e-8
    dtype = np.float64


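# The seed depends only on local_rank // group_size, so every rank in the same
# group generates identical input, gradient, weight and bias data.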
np.random.seed(18 + args.local_rank // args.group_size)

inp = np.random.randn(batch_size, feature_size, space_size, space_size).astype(dtype)
grad = np.random.randn(batch_size, feature_size, space_size, space_size).astype(dtype)
weight = np.random.randn(feature_size).astype(dtype)
bias = np.random.randn(feature_size).astype(dtype)


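# type_tensor carries the dtype under test; ref_tensor holds double-precision
# copies used for the reference computations.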
type_tensor = torch.cuda.FloatTensor
if args.fp16:
    type_tensor = torch.cuda.HalfTensor
if args.fp64:
    type_tensor = torch.cuda.DoubleTensor

ref_tensor = torch.cuda.DoubleTensor

inp_t = type_tensor(inp)
weight_t = type_tensor(weight)
bias_t = type_tensor(bias)

inp_r = ref_tensor(inp.transpose(1, 0, 2, 3).reshape(feature_size, -1))
inp2_r = ref_tensor(inp)
weight_r = ref_tensor(weight).view(-1, 1, 1)
bias_r = ref_tensor(bias).view(-1, 1, 1)

grad_output_t = type_tensor(grad)

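# Reference per-channel statistics in double precision: mean, biased variance,
# and unbiased variance.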
m = inp_r.mean(1)
b_v = inp_r.var(1, unbiased=False)
unb_v = inp_r.var(1, unbiased=True)

eps = 1e-5

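# Mean and biased variance from the syncbn extension's welford_mean_var,
# plus the inverse standard deviation used by its forward/backward kernels.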
mean, var_biased = syncbn.welford_mean_var(inp_t)
inv_std = 1.0 / torch.sqrt(var_biased + eps)

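# Reference path: a plain torch.nn.BatchNorm2d wrapped in apex DDP and run on
# the full batch; its gradients are rescaled below to undo DDP's averaging.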
bn = torch.nn.BatchNorm2d(feature_size).cuda()
bn.momentum = 1.0
bn.weight.data = weight_t.clone()
bn.bias.data = bias_t.clone()
if args.fp16:
    bn.half()
if args.fp64:
    bn.double()
bn = DDP(bn)
inp_bn = inp_t.clone().requires_grad_()
grad_bn = grad_output_t.clone().detach()
out_bn = bn(inp_bn)
out_bn.backward(grad_bn)
# compensate for the gradient averaging over processes done by DDP
# in order to produce a mathematically equivalent result
# https://github.com/NVIDIA/apex/issues/134#issuecomment-458307368
for param in bn.parameters():
    param.grad = param.grad / args.group_size
bn_opt = optim.SGD(bn.parameters(), lr=1.0)

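# Path under test: apex SyncBatchNorm over a process group of group_size ranks,
# each rank feeding only its start:finish slice of the batch.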
sbn = apex.parallel.SyncBatchNorm(feature_size, process_group=apex.parallel.create_syncbn_process_group(args.group_size)).cuda()
sbn.momentum = 1.0
sbn.weight.data = weight_t.clone()
sbn.bias.data = bias_t.clone()
if args.fp16:
    sbn.half()
if args.fp64:
    sbn.double()
sbn = DDP(sbn)
sbn_opt = optim.SGD(sbn.parameters(), lr=1.0)
inp_sbn = inp_t.clone().requires_grad_()
grad_sbn = grad_output_t.clone().detach()
out_sbn = sbn(inp_sbn[start:finish])
out_sbn.backward(grad_sbn[start:finish])

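# Pass/fail flags; most of the comparisons below run on rank 0 only.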
sbn_result = True
bn_result = True

if args.local_rank == 0:
    sbn_result = compare("comparing mean: ", mean, m, error) and sbn_result
    sbn_result = compare("comparing biased variance: ", var_biased, b_v, error) and sbn_result

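# Forward output of the syncbn kernel vs. the double-precision reference
# normalization weight * (x - mean) / sqrt(var + eps) + bias.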
out = syncbn.batchnorm_forward(inp_t, mean, inv_std, weight_t, bias_t)
out_r = weight_r * (inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1, 1, 1) + eps) + bias_r

if args.local_rank == 0:
    sbn_result = compare("comparing output: ", out, out_r, error) and sbn_result
    compare("comparing bn output: ", out_bn, out_r, error)

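# Reference backward pass computed analytically in double precision: bias and
# weight gradients, the per-channel reductions mean_dy and mean_dy_xmu, and
# the input gradient of batch norm.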
grad_output_t = type_tensor(grad)

grad_output_r = ref_tensor(grad.transpose(1, 0, 2, 3).reshape(feature_size, -1))
grad_output2_r = ref_tensor(grad)

grad_bias_r = grad_output_r.sum(1)
grad_weight_r = ((inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1, 1, 1) + eps) * grad_output2_r).transpose(1, 0).contiguous().view(feature_size, -1).sum(1)

mean_dy_r = grad_output_r.mean(1)
mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1, 0).contiguous().view(feature_size, -1).mean(1)

grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1, 1, 1) + eps) * mean_dy_xmu_r.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1, 1, 1) + eps) * weight_r.view(-1, 1, 1)

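# Backward reductions and input gradient from the syncbn extension.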
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu)

if args.local_rank == 0:
    sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
    sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
    sbn_result = compare("comparing mean_dy grad: ", mean_dy, mean_dy_r, error) and sbn_result
    sbn_result = compare("comparing mean_dy_xmu grad: ", mean_dy_xmu, mean_dy_xmu_r, error) and sbn_result
    sbn_result = compare("comparing input grad: ", grad_input, grad_input_r, error) and sbn_result
    compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error)

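# With momentum 1.0 and identical data within a group, the running statistics
# of the reference BatchNorm2d and of SyncBatchNorm are expected to match.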
if args.local_rank == 0:
    sbn_result = compare("comparing running_mean: ", bn.module.running_mean.data, sbn.module.running_mean.data, error) and sbn_result
    sbn_result = compare("comparing running_variance: ", bn.module.running_var.data, sbn.module.running_var.data, error) and sbn_result

# executed by every rank, not just rank 0
sbn_result = compare("comparing layers output: ", out_bn[start:finish], out_sbn, error) and sbn_result
sbn_result = compare("comparing layers grad_input: ", inp_bn.grad[start:finish], inp_sbn.grad[start:finish], error) and sbn_result

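# One SGD step on each layer; the learned weight and bias of BatchNorm2d and
# SyncBatchNorm should still agree afterwards.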
bn_opt.step()
sbn_opt.step()

if args.local_rank == 0:
    compare("comparing bn vs sbn bias: ", bn.module.bias, sbn.module.bias, error)
    compare("comparing bn vs sbn weight: ", bn.module.weight, sbn.module.weight, error)


if sbn_result:
    print("====SBN group test passed")
else:
    print("*SBN group test failed*")