
GitHub Repository: ai-forever/sber-swap
Path: blob/main/apex/tests/distributed/DDP/ddp_race_condition_test.py
import torch
import torch.distributed as dist
from torch.nn import Parameter
from torch.nn import Module
from apex.parallel import DistributedDataParallel as DDP
import argparse
import os
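
# Purpose (a summary inferred from the file name and the checks below): stress
# apex DDP's overlapped gradient allreduce for race conditions by verifying,
# on every iteration, that each parameter's gradient sum matches a
# closed-form expected value.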

parser = argparse.ArgumentParser(description='allreduce hook example')
parser.add_argument("--local_rank", default=0, type=int)
args = parser.parse_args()

args.distributed = False
if 'WORLD_SIZE' in os.environ:
    args.distributed = int(os.environ['WORLD_SIZE']) > 1

if args.distributed:
    args.gpu = args.local_rank % torch.cuda.device_count()
    torch.cuda.set_device(args.gpu)
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='env://')
    args.world_size = torch.distributed.get_world_size()
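# Note on the 'env://' rendezvous: init_process_group reads MASTER_ADDR,
# MASTER_PORT, RANK, and WORLD_SIZE from the environment. The PyTorch launcher
# sets these and supplies --local_rank, which is why the script parses that
# flag above.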

torch.set_printoptions(precision=10)
torch.manual_seed(args.local_rank)

class Model(Module):
    def __init__(self):
        super(Model, self).__init__()
        self.a = Parameter(torch.cuda.FloatTensor(4096*4096).fill_(1.0))
        self.b = Parameter(torch.cuda.FloatTensor(4096*4096).fill_(2.0))
    def forward(self, input):
        return (input*self.a)*self.b
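
# Since forward is purely elementwise, for loss = out.sum() we get
# d(loss)/d(a) = input*b and d(loss)/d(b) = input*a; the expected-sum check
# inside the training loop below relies on exactly this.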

model = Model()
# model = DDP(model, message_size=1, gradient_predivide_factor=8.0)
# model = DDP(model, delay_allreduce=True)
# model = DDP(model, message_size=1, allreduce_trigger_params=[model.b])
model = DDP(model, message_size=1, allreduce_trigger_params=[model.b], num_allreduce_streams=3)
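# The commented-out lines above are alternative apex DDP configurations for
# the same test (descriptions from memory of the apex.parallel docs, so treat
# them as approximate): message_size is the bucket size in gradient elements
# that triggers an overlapped allreduce during backward; delay_allreduce
# defers all communication until backward finishes; allreduce_trigger_params
# kicks off reductions only when the listed parameters receive gradients; and
# num_allreduce_streams presumably spreads reductions over several CUDA streams.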

x = torch.cuda.FloatTensor(4096*4096)

passed = True
torch.cuda.cudart().cudaProfilerStart()
for i in range(10):
    x.fill_(i + args.local_rank)  # fill x with new values every iteration for sanity
    model.zero_grad()
    out = model(x)
    loss = out.sum()
    # torch.cuda.nvtx.range_push("backward")
    loss.backward()
    # torch.cuda.nvtx.range_pop()

    # torch.cuda.nvtx.range_push("synchronize() + info")
    # torch.cuda.synchronize()
    print("i = {}".format(i))
    def info(name, param, val):
        expected = val*4096*4096*(2.*i+1)/2.
        actual = param.grad.data.sum().item()
        print(name+": grad.data_ptr() = {}, expected sum {}, got {}".format(
              param.grad.data_ptr(), expected, actual))
        return (expected == actual)
    if not info("model.a", model.module.a, 2.): passed = False
    if not info("model.b", model.module.b, 1.): passed = False
    # torch.cuda.nvtx.range_pop()

torch.cuda.cudart().cudaProfilerStop()

print("passed = ", passed)
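
# Typical launch (an assumption; the launch command is not part of this file):
# two processes via the PyTorch launcher, which sets WORLD_SIZE in the
# environment and passes --local_rank to each copy of the script:
#   python -m torch.distributed.launch --nproc_per_node=2 ddp_race_condition_test.py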