
GitHub Repository: ai-forever/sber-swap
Path: blob/main/apex/tests/L0/run_optimizers/test_dist_adam.py
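
"""Correctness test for apex's DistributedFusedAdam (apex.contrib).

Two copies of a small stack of Linear layers are created with identical
weights: a reference model trained with FusedAdam under amp O2 and
DistributedDataParallel, and a pure-FP16 model trained with
DistributedFusedAdam. At every step both models receive the same input, and
inputs, outputs, input gradients, and updated parameters are checked against
the requested tolerances (--atol/--rtol).
"""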
import argparse
import random
import sys

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

from apex import amp
from apex.optimizers import FusedAdam
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam


class TestModel(torch.nn.Module):
    def __init__(self, args):
        super(TestModel, self).__init__()

        self.linear = torch.nn.Sequential(*[torch.nn.Linear(args.dim, args.dim, bias=args.bias) for _ in range(args.layers)])

    def forward(self, x):
        return self.linear(x)

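# Build the reference model (FusedAdam + amp O2 + DDP) and the distributed
# model (pure FP16 + DistributedFusedAdam) with identical initial weights.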
def setup(args):
    ## Model
    ref_model = TestModel(args).cuda()
    dist_model = TestModel(args).cuda()

    # Same weights
    with torch.no_grad():
        for dp, rp in zip(dist_model.parameters(), ref_model.parameters()):
            dp.data.copy_(rp.data)

    dist_model = dist_model.half()


    ## Optimizer
    # same hyperparameters
    ref_opt_args = { 'lr': 1e-3, 'eps': 1e-6, 'weight_decay': 0.01 }
    ref_opt = FusedAdam(ref_model.parameters(), **ref_opt_args)

    dist_opt_args = ref_opt_args.copy()
    dist_opt_args.update( {'overlap_reductions' : False} )
    dist_opt_args.update( {'process_group_size' : args.n_gpu} )
    dist_opt_args.update( {'dwu_group_size' : args.dwu_group_size} )
    dist_opt_args.update( {'dwu_num_blocks' : 1} )
    dist_opt_args.update( {'dwu_num_chunks' : 1} )
    dist_opt = DistributedFusedAdam(dist_model.parameters(), **dist_opt_args)
    dist_opt.set_global_scale(1.)

    ## amp-init
    amp_args = { 'loss_scale' : 'dynamic' , 'opt_level' : 'O2'}
    ref_model, ref_opt = amp.initialize(ref_model, ref_opt, **amp_args)


    ## DDP
    ref_model = DDP(ref_model, device_ids=[args.rank])
    with torch.no_grad():
        for dp in dist_model.parameters():
            torch.distributed.broadcast(dp.data, src=0)
        for rp in ref_model.parameters():
            torch.distributed.broadcast(rp.data, src=0)
    torch.cuda.synchronize()
    torch.distributed.barrier()
    if get_rank() == 0:
        print(f'dist opt with {args.n_gpu} GPUs')

    return ref_model, ref_opt, dist_model, dist_opt

def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--steps', type=int, default=20)
    parser.add_argument('--batch', type=int, default=32)
    parser.add_argument('--dim', type=int, default=4)
    parser.add_argument('--layers', type=int, default=2)
    parser.add_argument('--bias', action='store_true')
    parser.add_argument('--atol', type=float, default=1e-3)
    parser.add_argument('--rtol', type=float, default=1)
    parser.add_argument('--dwu_group_size', type=float, default=1)

    args = parser.parse_args()

    return args

def setup_env(args):
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.rank = torch.distributed.get_rank()
    args.n_gpu = torch.distributed.get_world_size()

    seed = 42 + get_rank()

    random.seed(seed)
    torch.manual_seed(seed)

    return args

def get_rank():
    return torch.distributed.get_rank()

def main():
    args = parse_args()
    args = setup_env(args)
    tol_args = { 'atol' : args.atol, 'rtol' : args.rtol }

    torch.set_printoptions(precision=16)

    ref_model, ref_opt, dist_model, dist_opt = setup(args)

    # lazy_init not called yet, initialize stash
    stash = ref_opt._amp_stash
    stash.all_fp16_params, stash.all_fp32_from_fp16_params = [], []

    # make sure everything from _first_step_init_ is ready before training
    # e.g. registering allreduce_hook
    # so that gradients are copied/reduced when necessary
    dist_opt._init_everything()

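    # Training loop: feed both models the same random FP16 input at each step,
    # compare inputs, outputs, and input gradients, step both optimizers, and
    # then compare the updated parameters element-wise.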
    for i in range(args.steps):
        x_ref = torch.randn(args.batch, args.dim, dtype=torch.half).cuda().requires_grad_(True)
        x_dist = x_ref.clone().detach().requires_grad_(True)

        if get_rank() == 0:
            print(f'[{i}] Checking input')
            #print("x_ref:", x_ref.flatten()[:10])
            #print("x_dist:", x_dist.flatten()[:10])
        assert(torch.allclose(x_ref, x_dist, **tol_args))

        y_ref = ref_model(x_ref).half()
        y_dist = dist_model(x_dist)

        if get_rank() == 0:
            print(f'[{i}] Checking output')
            #print("y_ref:", y_ref.flatten()[:10])
            #print("y_dist:", y_dist.flatten()[:10])
        assert(torch.allclose(y_ref, y_dist, **tol_args))

        dy = torch.randn_like(y_ref)

        y_ref.backward(dy)
        y_dist.backward(dy)

        if get_rank() == 0:
            print(f'[{i}] Checking gradients')
        torch.distributed.barrier()
        torch.cuda.synchronize()
        assert(torch.allclose(x_ref.grad, x_dist.grad, **tol_args))

        # gradient all-reduce within distributed optimizer
        dist_opt.complete_reductions()

        if get_rank() == 0:
            print(f'[{i}] Stepping')
        ref_opt.step()
        dist_opt.step()

        torch.cuda.synchronize()
        torch.distributed.barrier()
        print('Checking new weights')
        if get_rank() == 0:
            print("ref param:", ref_model.module.linear[0].weight)
            print("dist param:", dist_model.linear[0].weight)

        for p_idx, (rp, dp) in enumerate(zip(ref_model.parameters(), dist_model.parameters())):
            if not torch.allclose(rp, dp, **tol_args):
                if get_rank() == 0:
                    print(f'Rank: {get_rank()}, Param: {p_idx}')
                    print(f'ref: {rp.sum().item()}, dist: {dp.sum().item()}')
                    print(rp)
                    print(dp)

                    print(torch.abs(rp-dp) > tol_args['atol'])
                sys.exit(0)

        # zero grads
        for rp, dp in zip(ref_model.parameters(), dist_model.parameters()):
            rp.grad = None
            dp.grad = None


if __name__ == "__main__":
    main()
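
# Example launch (an assumption, not part of the original file): the script
# expects one process per GPU with the env:// rendezvous used by
# init_process_group, e.g. something like
#
#   python -m torch.distributed.launch --nproc_per_node=2 test_dist_adam.py
#
# torch.distributed.launch supplies the --local_rank argument that setup_env reads.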