Path: blob/main/apex/tests/distributed/DDP/ddp_race_condition_test.py
import torch
import torch.distributed as dist
from torch.nn import Parameter
from torch.nn import Module
from apex.parallel import DistributedDataParallel as DDP
import argparse
import os


parser = argparse.ArgumentParser(description='allreduce hook example')
parser.add_argument("--local_rank", default=0, type=int)
args = parser.parse_args()

args.distributed = False
if 'WORLD_SIZE' in os.environ:
    args.distributed = int(os.environ['WORLD_SIZE']) > 1

if args.distributed:
    args.gpu = args.local_rank % torch.cuda.device_count()
    torch.cuda.set_device(args.gpu)
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='env://')
    args.world_size = torch.distributed.get_world_size()

torch.set_printoptions(precision=10)
torch.manual_seed(args.local_rank)

class Model(Module):
    def __init__(self):
        super(Model, self).__init__()
        self.a = Parameter(torch.cuda.FloatTensor(4096*4096).fill_(1.0))
        self.b = Parameter(torch.cuda.FloatTensor(4096*4096).fill_(2.0))
    def forward(self, input):
        return (input*self.a)*self.b

model = Model()
# model = DDP(model, message_size=1, gradient_predivide_factor=8.0)
# model = DDP(model, delay_allreduce=True)
# model = DDP(model, message_size=1, allreduce_trigger_params=[model.b])
model = DDP(model, message_size=1, allreduce_trigger_params=[model.b], num_allreduce_streams=3)

x = torch.cuda.FloatTensor(4096*4096)

passed = True
torch.cuda.cudart().cudaProfilerStart()
for i in range(10):
    x.fill_(i + args.local_rank)  # fill x with new values every iteration for sanity
    model.zero_grad()
    out = model(x)
    loss = out.sum()
    # torch.cuda.nvtx.range_push("backward")
    loss.backward()
    # torch.cuda.nvtx.range_pop()

    # torch.cuda.nvtx.range_push("synchronize() + info")
    # torch.cuda.synchronize()
    print("i = {}".format(i))
    def info(name, param, val):
        expected = val*4096*4096*(2.*i+1)/2.
        actual = param.grad.data.sum().item()
        print(name+": grad.data_ptr() = {}, expected sum {}, got {}".format(
              param.grad.data_ptr(), expected, actual))
        return (expected == actual)
    if not info("model.a", model.module.a, 2.): passed = False
    if not info("model.b", model.module.b, 1.): passed = False
    # torch.cuda.nvtx.range_pop()
torch.cuda.cudart().cudaProfilerStop()

print("passed = ", passed)
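A note on running and interpreting this test (an editorial addition, not part of the original file). The script parses --local_rank and initializes the process group with the env:// method, so it is presumably meant to be launched with one process per GPU via PyTorch's distributed launcher, e.g. on a two-GPU node:

python -m torch.distributed.launch --nproc_per_node=2 ddp_race_condition_test.py

The hard-coded expected sums in info() assume exactly two ranks: rank r fills x with i + r, the gradients of loss = ((x*a)*b).sum() are x*b for a and x*a for b, and DDP averages gradients across ranks, giving a per-element gradient of val*((i+0)+(i+1))/2 = val*(2*i+1)/2 and hence a total of val*4096*4096*(2*i+1)/2. The following CPU-only sketch reproduces that arithmetic by simulating two ranks sequentially and averaging their gradients by hand, as a minimal stand-in for the NCCL allreduce; N is shrunk from 4096*4096 to keep it cheap, and the formula holds for any N:

# Hypothetical sanity check of the expected-gradient arithmetic in info();
# not part of the apex test. Two "ranks" are simulated sequentially on CPU
# and their gradients averaged manually, mimicking DDP's allreduce-and-average.
import torch

N = 1024          # stand-in for 4096*4096
world_size = 2    # the expected sums in the test assume exactly two ranks

for i in range(3):
    grads_a, grads_b = [], []
    for rank in range(world_size):
        a = torch.ones(N, requires_grad=True)          # matches fill_(1.0)
        b = torch.full((N,), 2.0, requires_grad=True)  # matches fill_(2.0)
        x = torch.full((N,), float(i + rank))          # matches x.fill_(i + local_rank)
        ((x * a) * b).sum().backward()
        grads_a.append(a.grad)
        grads_b.append(b.grad)
    # DDP averages gradients across ranks (allreduce followed by division).
    grad_a = sum(grads_a) / world_size
    grad_b = sum(grads_b) / world_size
    # Reproduces expected = val*N*(2.*i+1)/2. with val = 2. for a, 1. for b.
    assert grad_a.sum().item() == 2. * N * (2. * i + 1) / 2.
    assert grad_b.sum().item() == 1. * N * (2. * i + 1) / 2.
    print("i = {}: expected sums match".format(i))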