Path: blob/main/apex/examples/simple/distributed/distributed_data_parallel.py
import torch
import argparse
import os
from apex import amp
# FOR DISTRIBUTED: (can also use torch.nn.parallel.DistributedDataParallel instead)
from apex.parallel import DistributedDataParallel

parser = argparse.ArgumentParser()
# FOR DISTRIBUTED: Parse for the local_rank argument, which will be supplied
# automatically by torch.distributed.launch.
parser.add_argument("--local_rank", default=0, type=int)
args = parser.parse_args()

# FOR DISTRIBUTED: If we are running under torch.distributed.launch,
# the 'WORLD_SIZE' environment variable will also be set automatically.
args.distributed = False
if 'WORLD_SIZE' in os.environ:
    args.distributed = int(os.environ['WORLD_SIZE']) > 1

if args.distributed:
    # FOR DISTRIBUTED: Set the device according to local_rank.
    torch.cuda.set_device(args.local_rank)

    # FOR DISTRIBUTED: Initialize the backend. torch.distributed.launch will provide
    # environment variables, and requires that you use init_method=`env://`.
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='env://')

torch.backends.cudnn.benchmark = True

N, D_in, D_out = 64, 1024, 16

# Each process receives its own batch of "fake input data" and "fake target data."
# The "training loop" in each process just uses this fake batch over and over.
# https://github.com/NVIDIA/apex/tree/master/examples/imagenet provides a more realistic
# example of distributed data sampling for both training and validation.
x = torch.randn(N, D_in, device='cuda')
y = torch.randn(N, D_out, device='cuda')

model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

if args.distributed:
    # FOR DISTRIBUTED: After amp.initialize, wrap the model with
    # apex.parallel.DistributedDataParallel.
    model = DistributedDataParallel(model)
    # torch.nn.parallel.DistributedDataParallel is also fine, with some added args:
    # model = torch.nn.parallel.DistributedDataParallel(model,
    #                                                   device_ids=[args.local_rank],
    #                                                   output_device=args.local_rank)

loss_fn = torch.nn.MSELoss()

for t in range(500):
    optimizer.zero_grad()
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()

if args.local_rank == 0:
    print("final loss = ", loss)
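Usage note (a minimal sketch, not part of the example file): as the comments above indicate, the script expects to be started by torch.distributed.launch, which supplies --local_rank and sets WORLD_SIZE for each spawned process. Assuming the file is saved as distributed_data_parallel.py and two GPUs are available (the process count is an assumption, adjust --nproc_per_node to your hardware), a typical launch looks like:

python -m torch.distributed.launch --nproc_per_node=2 distributed_data_parallel.py

Run as a plain "python distributed_data_parallel.py", WORLD_SIZE is unset, args.distributed stays False, and the script falls back to single-GPU mixed-precision training with amp only.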