GitHub Repository: ai-forever/sber-swap
Path: blob/main/apex/examples/simple/distributed/distributed_data_parallel.py
import torch
import argparse
import os
from apex import amp
# FOR DISTRIBUTED: (can also use torch.nn.parallel.DistributedDataParallel instead)
from apex.parallel import DistributedDataParallel

parser = argparse.ArgumentParser()
# FOR DISTRIBUTED: Parse for the local_rank argument, which will be supplied
# automatically by torch.distributed.launch.
parser.add_argument("--local_rank", default=0, type=int)
args = parser.parse_args()

# FOR DISTRIBUTED: If we are running under torch.distributed.launch,
# the 'WORLD_SIZE' environment variable will also be set automatically.
args.distributed = False
if 'WORLD_SIZE' in os.environ:
    args.distributed = int(os.environ['WORLD_SIZE']) > 1

if args.distributed:
    # FOR DISTRIBUTED: Set the device according to local_rank.
    torch.cuda.set_device(args.local_rank)

    # FOR DISTRIBUTED: Initialize the backend. torch.distributed.launch will provide
    # environment variables, and requires that you use init_method=`env://`.
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='env://')

torch.backends.cudnn.benchmark = True

N, D_in, D_out = 64, 1024, 16

# Each process receives its own batch of "fake input data" and "fake target data."
# The "training loop" in each process just uses this fake batch over and over.
# https://github.com/NVIDIA/apex/tree/master/examples/imagenet provides a more realistic
# example of distributed data sampling for both training and validation.
x = torch.randn(N, D_in, device='cuda')
y = torch.randn(N, D_out, device='cuda')

model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

if args.distributed:
    # FOR DISTRIBUTED: After amp.initialize, wrap the model with
    # apex.parallel.DistributedDataParallel.
    model = DistributedDataParallel(model)
    # torch.nn.parallel.DistributedDataParallel is also fine, with some added args:
    # model = torch.nn.parallel.DistributedDataParallel(model,
    #                                                   device_ids=[args.local_rank],
    #                                                   output_device=args.local_rank)

loss_fn = torch.nn.MSELoss()

for t in range(500):
    optimizer.zero_grad()
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()

if args.local_rank == 0:
    print("final loss = ", loss)
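
# Launch note: a minimal usage sketch, assuming a single node with 4 GPUs.
# torch.distributed.launch spawns one process per GPU, sets WORLD_SIZE in the
# environment, and passes --local_rank to each process, which is exactly what
# the argument parsing and init_process_group call above rely on:
#
#   python -m torch.distributed.launch --nproc_per_node=4 distributed_data_parallel.py
#
# Running the script directly (plain "python distributed_data_parallel.py")
# leaves WORLD_SIZE unset, so args.distributed stays False and the model is
# trained in a single, non-distributed process.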