Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Path: blob/main/apex/tests/L0/run_amp/test_fused_sgd.py
Views: 794
import unittest12import functools as ft3import itertools as it45from apex import amp6from apex.amp import _amp_state7import torch8from torch import nn9import torch.nn.functional as F10from torch.nn import Parameter1112from utils import common_init, HALF, FLOAT,\13ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT141516try:17import amp_C18disabled = False19from apex.optimizers import FusedSGD as FusedSGD20except ImportError as err:21print("amp_C fused kernels unavailable, disabling TestMultiTensorApply. ImportError was ", err)22disabled = True232425class MyModel(torch.nn.Module):26def __init__(self, unique):27super(MyModel, self).__init__()28self.weight0 = Parameter(unique +29torch.arange(2, device='cuda', dtype=torch.float32))30self.weight1 = Parameter(1. + unique + torch.arange(2, device='cuda', dtype=torch.float16))3132@staticmethod33def ops(input, weight0, weight1):34return ((input*(weight0.float()))*(weight1.float())).sum()3536def forward(self, input):37return self.ops(input, self.weight0, self.weight1)3839# Abandon all hope, ye who enter here.4041# This is hands down the ugliest code I have ever written, but it succeeds in testing42# multiple models/optimizers/losses fairly thoroughly. Many of the different test cases43# require slightly divergent code in a way that seems near-impossible to genericize into a simple44# cross product or nested loops.4546class TestMultipleModelsOptimizersLosses(unittest.TestCase):47def setUp(self):48self.x = torch.ones((2), device='cuda', dtype=torch.float32)49common_init(self)5051def tearDown(self):52pass5354@unittest.skipIf(disabled, "amp_C is unavailable")55def test_2models2losses1optimizer(self):56model0 = MyModel(1)57model1 = MyModel(2)5859optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},60{'params' : model1.parameters(), 'lr' : 0.5}],61momentum=0.125)6263reference_grads = []64for i in range(2):65optimizer.zero_grad()66loss0 = model0(self.x)67loss1 = model1(self.x)68loss0.backward()69loss1.backward()7071reference_grads.append([param.grad.data.clone() for param in model0.parameters()] +72[param.grad.data.clone() for param in model1.parameters()])7374optimizer.step()7576final_params = [param.data.clone() for param in model0.parameters()] + \77[param.data.clone() for param in model1.parameters()]7879for materialize_master_grads in (False, True):80for opt_level in ("O0", "O1", "O2", "O3"):81for how_to_zero in ("none", "model", "optimizer"):82for use_multiple_loss_scalers in (False, True):83if opt_level == "O1" or opt_level == "O2":84inject_inf_iters = (-1, 0, 1)85else:86inject_inf_iters = (-1,)8788for inject_inf in inject_inf_iters:89if inject_inf >= 0:90inject_inf_locs = ("fp16", "fp32")91which_backwards = (0, 1)92else:93inject_inf_locs = ("fdsa",)94which_backwards = (None,)9596for inject_inf_loc in inject_inf_locs:97for which_backward in which_backwards:98if use_multiple_loss_scalers:99num_losses = 2100loss_ids = [0, 1]101else:102num_losses = 1103loss_ids = [0, 0]104105if inject_inf >= 0:106iters = 3107else:108iters = 2109110model0 = MyModel(1)111model1 = MyModel(2)112113models = [model0, model1]114115optimizer = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25},116{'params' : model1.parameters(), 'lr' : 0.5}],117momentum=0.125,118materialize_master_grads=materialize_master_grads)119120_amp_state.allow_incoming_model_not_fp32 = True121[model0, model1], optimizer = amp.initialize(122[model0, model1],123optimizer,124opt_level=opt_level,125verbosity=0,126cast_model_type=False,127num_losses=num_losses)128_amp_state.allow_incoming_model_not_fp32 = False129130_amp_state.loss_scalers[0]._loss_scale = 4.0131if use_multiple_loss_scalers:132_amp_state.loss_scalers[1]._loss_scale = 16.0133134unskipped = 0135for i in range(iters):136if how_to_zero == "none":137for model in models:138for param in model.parameters():139param.grad = None140elif how_to_zero == "model":141for model in models:142model.zero_grad()143else:144optimizer.zero_grad()145146loss0 = model0(self.x)147loss1 = model1(self.x)148149with amp.scale_loss(loss0, optimizer, loss_id=loss_ids[0]) as scaled_loss:150scaled_loss.backward()151if i == inject_inf and which_backward == 0:152if inject_inf_loc == "fp32":153model0.weight0.grad[0] = float('inf')154elif inject_inf_loc == "fp16":155model0.weight1.grad[0] = float('inf')156with amp.scale_loss(loss1, optimizer, loss_id=loss_ids[1]) as scaled_loss:157scaled_loss.backward()158if i == inject_inf and which_backward == 1:159if inject_inf_loc == "fp32":160model1.weight0.grad[0] = float('inf')161elif inject_inf_loc == "fp16":162model1.weight1.grad[0] = float('inf')163164if i != inject_inf:165master_params = amp.master_params(optimizer)166for param, reference_grad in zip(master_params, reference_grads[unskipped]):167if opt_level == "O2" and not materialize_master_grads:168continue169else:170self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()),171"opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers))172unskipped += 1173optimizer.step()174175model_params = [p for p in model0.parameters()] + [p for p in model1.parameters()]176for model, master, reference in zip(177model_params,178amp.master_params(optimizer),179final_params):180self.assertTrue(torch.allclose(model, reference))181self.assertTrue(torch.allclose(model, master.to(model.dtype)))182183if opt_level == "O1":184_amp_state.handle._deactivate()185186@unittest.skipIf(disabled, "amp_C is unavailable")187def test_3models2losses1optimizer(self):188189model0 = MyModel(1)190model1 = MyModel(2)191model2 = MyModel(3)192193optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},194{'params' : model1.parameters(), 'lr' : 0.5},195{'params' : model2.parameters(), 'lr' : 0.125}],196momentum=0.125)197198reference_grads = []199for i in range(2):200optimizer.zero_grad()201loss0 = model0(self.x) + model2(self.x)202loss1 = model1(self.x) + model2(self.x)203loss0.backward()204loss1.backward()205206reference_grads.append([param.grad.data.clone() for param in model0.parameters()] +207[param.grad.data.clone() for param in model1.parameters()] +208[param.grad.data.clone() for param in model2.parameters()])209210optimizer.step()211212213final_params = [param.data.clone() for param in model0.parameters()] + \214[param.data.clone() for param in model1.parameters()] + \215[param.data.clone() for param in model2.parameters()]216217for materialize_master_grads in (False, True):218for opt_level in ("O0", "O1", "O2", "O3"):219for how_to_zero in ("none", "model", "optimizer"):220for use_multiple_loss_scalers in (False, True):221if opt_level == "O1" or opt_level == "O2":222inject_inf_iters = (-1, 0, 1)223else:224inject_inf_iters = (-1,)225226for inject_inf in inject_inf_iters:227if inject_inf >= 0:228inject_inf_locs = ("fp16", "fp32")229which_backwards = (0, 1)230else:231inject_inf_locs = ("fdsa",)232which_backwards = (None,)233234for inject_inf_loc in inject_inf_locs:235for which_backward in which_backwards:236if use_multiple_loss_scalers:237num_losses = 2238loss_ids = [0, 1]239else:240num_losses = 1241loss_ids = [0, 0]242243if inject_inf >= 0:244iters = 3245if which_backward == 0:246which_models = (0, 2)247elif which_backward == 1:248which_models = (1, 2)249else:250iters = 2251which_models = (None,)252253for which_model in which_models:254model0 = MyModel(1)255model1 = MyModel(2)256model2 = MyModel(3)257258models = [model0, model1, model2]259260optimizer = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25},261{'params' : model1.parameters(), 'lr' : 0.5},262{'params' : model2.parameters(), 'lr' : 0.125}],263momentum=0.125,264materialize_master_grads=materialize_master_grads)265266_amp_state.allow_incoming_model_not_fp32 = True267[model0, model1, model2], optimizer = amp.initialize(268[model0, model1, model2],269optimizer,270opt_level=opt_level,271verbosity=0,272cast_model_type=False,273num_losses=num_losses)274_amp_state.allow_incoming_model_not_fp32 = False275276_amp_state.loss_scalers[0]._loss_scale = 4.0277if use_multiple_loss_scalers:278_amp_state.loss_scalers[1]._loss_scale = 16.0279280unskipped = 0281for i in range(iters):282if how_to_zero == "none":283for model in models:284for param in model.parameters():285param.grad = None286elif how_to_zero == "model":287for model in models:288model.zero_grad()289else:290optimizer.zero_grad()291292loss0 = model0(self.x) + model2(self.x)293loss1 = model1(self.x) + model2(self.x)294295with amp.scale_loss(loss0, optimizer, loss_id=loss_ids[0]) as scaled_loss:296scaled_loss.backward()297if i == inject_inf and which_backward == 0:298if which_model == 0:299inj_model = model0300elif which_model == 2:301inj_model = model2302else:303raise RuntimeError(which_model + " invalid for loss 0")304if inject_inf_loc == "fp32":305inj_model.weight0.grad[0] = float('inf')306elif inject_inf_loc == "fp16":307inj_model.weight1.grad[0] = float('inf')308with amp.scale_loss(loss1, optimizer, loss_id=loss_ids[1]) as scaled_loss:309scaled_loss.backward()310if i == inject_inf and which_backward == 1:311if which_model == 1:312inj_model = model1313elif which_model == 2:314inj_model = model2315else:316raise RuntimeError(which_model + " invalid for loss 1 ")317if inject_inf_loc == "fp32":318inj_model.weight0.grad[0] = float('inf')319elif inject_inf_loc == "fp16":320inj_model.weight1.grad[0] = float('inf')321322if i != inject_inf:323master_params = amp.master_params(optimizer)324for param, reference_grad in zip(master_params, reference_grads[unskipped]):325if opt_level == "O2" and not materialize_master_grads:326continue327else:328self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()),329"opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} which_model {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, which_model, use_multiple_loss_scalers))330unskipped += 1331332optimizer.step()333334model_params = [p for p in model0.parameters()] + \335[p for p in model1.parameters()] + \336[p for p in model2.parameters()]337for model, master, reference in zip(338model_params,339amp.master_params(optimizer),340final_params):341self.assertTrue(torch.allclose(model, reference))342self.assertTrue(torch.allclose(model, master.to(model.dtype)))343344if opt_level == "O1":345_amp_state.handle._deactivate()346347@unittest.skipIf(disabled, "amp_C is unavailable")348def test_2models2losses2optimizers(self):349model0 = MyModel(1)350model1 = MyModel(2)351352optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],353momentum=0.125)354optimizer1 = torch.optim.SGD([{'params' : model1.parameters(), 'lr' : 0.5}],355momentum=0.25)356357# Don't do it like this: reference_grads = [[]]*5358# because then it creates a list of 5 references to the same "[]" and appending359# to any of them effectively makes you append to all of them, which multiplies360# the resulting size of reference_grads by 5x and needless to say makes the test fail.361reference_grads = [[], [], [], [], []]362final_params = [None, None, None, None, None]363for i in range(2):364optimizer0.zero_grad()365optimizer1.zero_grad()366loss0 = model0(self.x)367loss1 = model1(self.x)368loss0.backward()369loss1.backward()370371reference_grads[0].append([param.grad.data.clone() for param in model0.parameters()] +372[param.grad.data.clone() for param in model1.parameters()])373374optimizer0.step()375optimizer1.step()376377final_params[0] = [param.data.clone() for param in model0.parameters()] + \378[param.data.clone() for param in model1.parameters()]379380def what_got_skipped(which_iter, which_backward):381if which_iter == 0 and which_backward == 0:382return 1383if which_iter == 0 and which_backward == 1:384return 2385if which_iter == 1 and which_backward == 0:386return 3387if which_iter == 1 and which_backward == 1:388return 4389return 0390391for which_iter in (0,1):392for which_backward in (0,1):393model0 = MyModel(1)394model1 = MyModel(2)395396optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],397momentum=0.125)398optimizer1 = torch.optim.SGD([{'params' : model1.parameters(), 'lr' : 0.5}],399momentum=0.25)400401for i in range(3):402optimizer0.zero_grad()403optimizer1.zero_grad()404loss0 = model0(self.x)405loss1 = model1(self.x)406loss0.backward()407loss1.backward()408409if i != which_iter:410reference_grads[what_got_skipped(which_iter, which_backward)].append(411[param.grad.data.clone() for param in model0.parameters()] +412[param.grad.data.clone() for param in model1.parameters()])413414if i == which_iter:415if which_backward == 0:416optimizer1.step()417else:418optimizer0.step()419else:420optimizer0.step()421optimizer1.step()422423final_params[what_got_skipped(which_iter, which_backward)] = \424[param.data.clone() for param in model0.parameters()] + \425[param.data.clone() for param in model1.parameters()]426427for materialize_master_grads in (False, True):428for opt_level in ("O0", "O1", "O2", "O3"):429for how_to_zero in ("none", "model", "optimizer"):430for use_multiple_loss_scalers in (False, True):431if opt_level == "O1" or opt_level == "O2":432inject_inf_iters = (-1, 0, 1)433else:434inject_inf_iters = (-1,)435436for inject_inf in inject_inf_iters:437if inject_inf >= 0:438inject_inf_locs = ("fp16", "fp32")439which_backwards = (0, 1)440else:441inject_inf_locs = ("fdsa",)442which_backwards = (None,)443444for inject_inf_loc in inject_inf_locs:445for which_backward in which_backwards:446if use_multiple_loss_scalers:447num_losses = 2448loss_ids = [0, 1]449else:450num_losses = 1451loss_ids = [0, 0]452453if inject_inf >= 0:454iters = 3455else:456iters = 2457458model0 = MyModel(1)459model1 = MyModel(2)460461models = [model0, model1]462463optimizer0 = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25}],464momentum=0.125, materialize_master_grads=materialize_master_grads)465optimizer1 = FusedSGD([{'params' : model1.parameters(), 'lr' : 0.5}],466momentum=0.25, materialize_master_grads=materialize_master_grads)467468_amp_state.allow_incoming_model_not_fp32 = True469[model0, model1], [optimizer0, optimizer1] = amp.initialize(470[model0, model1],471[optimizer0, optimizer1],472opt_level=opt_level,473verbosity=0,474cast_model_type=False,475num_losses=num_losses)476_amp_state.allow_incoming_model_not_fp32 = False477478_amp_state.loss_scalers[0]._loss_scale = 4.0479if use_multiple_loss_scalers:480_amp_state.loss_scalers[1]._loss_scale = 16.0481482unskipped = 0483for i in range(iters):484if how_to_zero == "none":485for model in models:486for param in model.parameters():487param.grad = None488elif how_to_zero == "model":489for model in models:490model.zero_grad()491else:492optimizer0.zero_grad()493optimizer1.zero_grad()494495loss0 = model0(self.x)496loss1 = model1(self.x)497498with amp.scale_loss(loss0, optimizer0, loss_id=loss_ids[0]) as scaled_loss:499scaled_loss.backward()500if i == inject_inf and which_backward == 0:501if inject_inf_loc == "fp32":502model0.weight0.grad[0] = float('inf')503elif inject_inf_loc == "fp16":504model0.weight1.grad[0] = float('inf')505with amp.scale_loss(loss1, optimizer1, loss_id=loss_ids[1]) as scaled_loss:506scaled_loss.backward()507if i == inject_inf and which_backward == 1:508if inject_inf_loc == "fp32":509model1.weight0.grad[0] = float('inf')510elif inject_inf_loc == "fp16":511model1.weight1.grad[0] = float('inf')512513# print("opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers))514515if i != inject_inf:516master_params = list(amp.master_params(optimizer0)) + \517list(amp.master_params(optimizer1))518for param, reference_grad in zip(master_params,519reference_grads[what_got_skipped(inject_inf, which_backward)][unskipped]):520if opt_level == "O2" and not materialize_master_grads:521continue522else:523self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()))524unskipped += 1525526optimizer0.step()527optimizer1.step()528529model_params = [p for p in model0.parameters()] + [p for p in model1.parameters()]530master_params = [p for p in amp.master_params(optimizer0)] + \531[p for p in amp.master_params(optimizer1)]532for model, master, reference in zip(533model_params,534master_params,535final_params[what_got_skipped(inject_inf, which_backward)]):536self.assertTrue(torch.allclose(model, reference))537self.assertTrue(torch.allclose(model, master.to(model.dtype)))538539if opt_level == "O1":540_amp_state.handle._deactivate()541542@unittest.skipIf(disabled, "amp_C is unavailable")543def test_3models2losses2optimizers(self):544model0 = MyModel(1)545model1 = MyModel(2)546model2 = MyModel(3)547548optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},549{'params' : model1.parameters(), 'lr' : 1.0}],550momentum=0.5)551optimizer1 = torch.optim.SGD([{'params' : model2.parameters(), 'lr' : 0.5}],552momentum=0.25)553554# Again, can't do this: reference_grads = [[]]*9555reference_grads = [[], [], [], [], [], [], [], [], []]556final_params = [None, None, None, None, None, None, None, None, None]557for i in range(2):558optimizer0.zero_grad()559optimizer1.zero_grad()560loss0 = model0(self.x) + model1(self.x)561loss1 = model2(self.x) + model1(self.x)562loss0.backward()563loss1.backward()564565reference_grads[0].append([param.grad.data.clone() for param in model0.parameters()] +566[param.grad.data.clone() for param in model1.parameters()])567568optimizer0.step()569optimizer1.step()570571final_params[0] = \572[param.data.clone() for param in model0.parameters()] + \573[param.data.clone() for param in model1.parameters()] + \574[param.data.clone() for param in model2.parameters()]575576def what_got_skipped(which_iter, which_backward, which_model):577if which_iter == 0:578if which_backward == 0:579if which_model == 0:580return 1581if which_model == 1:582return 2583if which_backward == 1:584if which_model == 2:585return 3586if which_model == 1:587return 4588if which_iter == 1:589if which_backward == 0:590if which_model == 0:591return 5592if which_model == 1:593return 6594if which_backward == 1:595if which_model == 2:596return 7597if which_model == 1:598return 8599return 0600601for which_iter in (0,1):602for which_backward in (0,1):603if which_backward == 0:604which_models = (0,1)605if which_backward == 1:606which_models = (2,1)607for which_model in which_models:608609model0 = MyModel(1)610model1 = MyModel(2)611model2 = MyModel(3)612613optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},614{'params' : model1.parameters(), 'lr' : 1.0}],615momentum=0.5)616optimizer1 = torch.optim.SGD([{'params' : model2.parameters(), 'lr' : 0.5}],617momentum=0.25)618619for i in range(3):620optimizer0.zero_grad()621optimizer1.zero_grad()622loss0 = model0(self.x) + model1(self.x)623loss1 = model2(self.x) + model1(self.x)624loss0.backward()625loss1.backward()626627if i != which_iter:628reference_grads[what_got_skipped(which_iter,629which_backward, which_model)].append(630[param.grad.data.clone() for param in model0.parameters()] +631[param.grad.data.clone() for param in model1.parameters()])632633if i == which_iter:634if which_backward == 0:635# if which_model == 0:636optimizer1.step()637# if which_model == 1:638# optimizer1.step()639if which_backward == 1:640# if which_model == 2:641# optimizer0.step()642# if which_model == 1:643continue644else:645optimizer0.step()646optimizer1.step()647648final_params[what_got_skipped(which_iter, which_backward, which_model)] = \649[param.data.clone() for param in model0.parameters()] + \650[param.data.clone() for param in model1.parameters()] + \651[param.data.clone() for param in model2.parameters()]652653for materialize_master_grads in (False, True):654for opt_level in ("O0", "O1", "O2", "O3"):655for how_to_zero in ("none", "model", "optimizer"):656for use_multiple_loss_scalers in (False, True):657if opt_level == "O1" or opt_level == "O2":658inject_inf_iters = (-1, 0, 1)659else:660inject_inf_iters = (-1,)661662for inject_inf in inject_inf_iters:663if inject_inf >= 0:664inject_inf_locs = ("fp16", "fp32")665which_backwards = (0, 1)666else:667inject_inf_locs = ("fdsa",)668which_backwards = (None,)669670for inject_inf_loc in inject_inf_locs:671for which_backward in which_backwards:672if use_multiple_loss_scalers:673num_losses = 2674loss_ids = [0, 1]675else:676num_losses = 1677loss_ids = [0, 0]678679if inject_inf >= 0:680iters = 3681if which_backward == 0:682which_models = (0, 1)683elif which_backward == 1:684which_models = (2, 1)685else:686iters = 2687which_models = (None,)688689for which_model in which_models:690model0 = MyModel(1)691model1 = MyModel(2)692model2 = MyModel(3)693694models = [model0, model1, model2]695696optimizer0 = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25},697{'params' : model1.parameters(), 'lr' : 1.0}],698momentum=0.5, materialize_master_grads=materialize_master_grads)699optimizer1 = FusedSGD([{'params' : model2.parameters(), 'lr' : 0.5}],700momentum=0.25, materialize_master_grads=materialize_master_grads)701702_amp_state.allow_incoming_model_not_fp32 = True703[model0, model1, model2], [optimizer0, optimizer1] = amp.initialize(704[model0, model1, model2],705[optimizer0, optimizer1],706opt_level=opt_level,707verbosity=0,708cast_model_type=False,709num_losses=num_losses)710_amp_state.allow_incoming_model_not_fp32 = False711712_amp_state.loss_scalers[0]._loss_scale = 4.0713if use_multiple_loss_scalers:714_amp_state.loss_scalers[1]._loss_scale = 16.0715716unskipped = 0717for i in range(iters):718if how_to_zero == "none":719for model in models:720for param in model.parameters():721param.grad = None722elif how_to_zero == "model":723for model in models:724model.zero_grad()725else:726optimizer0.zero_grad()727optimizer1.zero_grad()728729loss0 = model0(self.x) + model1(self.x)730loss1 = model2(self.x) + model1(self.x)731732with amp.scale_loss(loss0, optimizer0, loss_id=loss_ids[0]) as scaled_loss:733scaled_loss.backward()734if i == inject_inf and which_backward == 0:735if which_model == 0:736inj_model = model0737elif which_model == 1:738inj_model = model1739else:740raise RuntimeError(which_model + " invalid for loss 0")741if inject_inf_loc == "fp32":742inj_model.weight0.grad[0] = float('inf')743elif inject_inf_loc == "fp16":744inj_model.weight1.grad[0] = float('inf')745with amp.scale_loss(loss1, [optimizer0, optimizer1], loss_id=loss_ids[1]) as scaled_loss:746scaled_loss.backward()747if i == inject_inf and which_backward == 1:748if which_model == 2:749inj_model = model2750elif which_model == 1:751inj_model = model1752else:753raise RuntimeError(which_model + " invalid for loss 1 ")754if inject_inf_loc == "fp32":755inj_model.weight0.grad[0] = float('inf')756elif inject_inf_loc == "fp16":757inj_model.weight1.grad[0] = float('inf')758759if i != inject_inf:760master_params = list(amp.master_params(optimizer0)) + \761list(amp.master_params(optimizer1))762for param, reference_grad in zip(master_params,763reference_grads[what_got_skipped(inject_inf,764which_backward, which_model)][unskipped]):765if opt_level == "O2" and not materialize_master_grads:766continue767else:768self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()))769unskipped += 1770771optimizer0.step()772optimizer1.step()773774model_params = [p for p in model0.parameters()] + \775[p for p in model1.parameters()] + \776[p for p in model2.parameters()]777master_params = [p for p in amp.master_params(optimizer0)] + \778[p for p in amp.master_params(optimizer1)]779780# print("opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {} which_model {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers, which_model))781782for model, master, reference in zip(783model_params,784master_params,785final_params[what_got_skipped(inject_inf, which_backward, which_model)]):786self.assertTrue(torch.allclose(model, reference))787self.assertTrue(torch.allclose(model, master.to(model.dtype)))788789if opt_level == "O1":790_amp_state.handle._deactivate()791792if __name__ == '__main__':793unittest.main()794795796