# GitHub repository: ai-forever/sber-swap
# Path: blob/main/apex/tests/L0/run_amp/test_fused_sgd.py
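"""Tests for apex.optimizers.FusedSGD driven through apex.amp.

The tests compare FusedSGD under amp against a plain fp32 torch.optim.SGD
reference across multiple models, optimizers, and losses. They require a
CUDA device and an apex build that provides the amp_C fused kernels;
otherwise the whole test class is skipped.
"""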
import unittest

import functools as ft
import itertools as it

from apex import amp
from apex.amp import _amp_state
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Parameter

from utils import common_init, HALF, FLOAT,\
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT


try:
    import amp_C
    disabled = False
    from apex.optimizers import FusedSGD as FusedSGD
except ImportError as err:
    print("amp_C fused kernels unavailable, disabling TestMultipleModelsOptimizersLosses. ImportError was ", err)
    disabled = True


class MyModel(torch.nn.Module):
    # weight0 is fp32 and weight1 is fp16, so the tests below can inject an inf
    # into a gradient of either precision.
    def __init__(self, unique):
        super(MyModel, self).__init__()
        self.weight0 = Parameter(unique +
            torch.arange(2, device='cuda', dtype=torch.float32))
        self.weight1 = Parameter(1. + unique + torch.arange(2, device='cuda', dtype=torch.float16))

    @staticmethod
    def ops(input, weight0, weight1):
        return ((input*(weight0.float()))*(weight1.float())).sum()

    def forward(self, input):
        return self.ops(input, self.weight0, self.weight1)


# Abandon all hope, ye who enter here.

# This is hands down the ugliest code I have ever written, but it succeeds in testing
# multiple models/optimizers/losses fairly thoroughly. Many of the different test cases
# require slightly divergent code in a way that seems near-impossible to genericize into a simple
# cross product or nested loops.

class TestMultipleModelsOptimizersLosses(unittest.TestCase):
    def setUp(self):
        self.x = torch.ones((2), device='cuda', dtype=torch.float32)
        common_init(self)

    def tearDown(self):
        pass

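    # Each test below follows the same recipe:
    #   1. Run plain fp32 torch.optim.SGD to record reference gradients and final parameters.
    #   2. Repeat the schedule with FusedSGD wrapped by amp.initialize, sweeping opt levels
    #      O0-O3, the three ways of zeroing gradients, and one vs. two loss scalers.
    #   3. Optionally inject an inf into an fp16 or fp32 gradient during one backward pass
    #      and check that the corresponding step is skipped while the surviving updates
    #      still match the reference.
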
    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_2models2losses1optimizer(self):
        model0 = MyModel(1)
        model1 = MyModel(2)

        optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
                                     {'params' : model1.parameters(), 'lr' : 0.5}],
                                    momentum=0.125)

        # Reference run: ordinary fp32 SGD without amp.
        reference_grads = []
        for i in range(2):
            optimizer.zero_grad()
            loss0 = model0(self.x)
            loss1 = model1(self.x)
            loss0.backward()
            loss1.backward()

            reference_grads.append([param.grad.data.clone() for param in model0.parameters()] +
                                   [param.grad.data.clone() for param in model1.parameters()])

            optimizer.step()

        final_params = [param.data.clone() for param in model0.parameters()] + \
                       [param.data.clone() for param in model1.parameters()]

        # Repeat the same schedule with FusedSGD under amp for every configuration.
        for materialize_master_grads in (False, True):
          for opt_level in ("O0", "O1", "O2", "O3"):
            for how_to_zero in ("none", "model", "optimizer"):
              for use_multiple_loss_scalers in (False, True):
                if opt_level == "O1" or opt_level == "O2":
                    inject_inf_iters = (-1, 0, 1)
                else:
                    inject_inf_iters = (-1,)

                for inject_inf in inject_inf_iters:
                    if inject_inf >= 0:
                        inject_inf_locs = ("fp16", "fp32")
                        which_backwards = (0, 1)
                    else:
                        inject_inf_locs = ("fdsa",)
                        which_backwards = (None,)

                    for inject_inf_loc in inject_inf_locs:
                        for which_backward in which_backwards:
                            if use_multiple_loss_scalers:
                                num_losses = 2
                                loss_ids = [0, 1]
                            else:
                                num_losses = 1
                                loss_ids = [0, 0]

                            if inject_inf >= 0:
                                iters = 3
                            else:
                                iters = 2

                            model0 = MyModel(1)
                            model1 = MyModel(2)

                            models = [model0, model1]

                            optimizer = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25},
                                                  {'params' : model1.parameters(), 'lr' : 0.5}],
                                                 momentum=0.125,
                                                 materialize_master_grads=materialize_master_grads)

                            _amp_state.allow_incoming_model_not_fp32 = True
                            [model0, model1], optimizer = amp.initialize(
                                [model0, model1],
                                optimizer,
                                opt_level=opt_level,
                                verbosity=0,
                                cast_model_type=False,
                                num_losses=num_losses)
                            _amp_state.allow_incoming_model_not_fp32 = False

                            _amp_state.loss_scalers[0]._loss_scale = 4.0
                            if use_multiple_loss_scalers:
                                _amp_state.loss_scalers[1]._loss_scale = 16.0

                            unskipped = 0
                            for i in range(iters):
                                if how_to_zero == "none":
                                    for model in models:
                                        for param in model.parameters():
                                            param.grad = None
                                elif how_to_zero == "model":
                                    for model in models:
                                        model.zero_grad()
                                else:
                                    optimizer.zero_grad()

                                loss0 = model0(self.x)
                                loss1 = model1(self.x)

                                with amp.scale_loss(loss0, optimizer, loss_id=loss_ids[0]) as scaled_loss:
                                    scaled_loss.backward()
                                    if i == inject_inf and which_backward == 0:
                                        if inject_inf_loc == "fp32":
                                            model0.weight0.grad[0] = float('inf')
                                        elif inject_inf_loc == "fp16":
                                            model0.weight1.grad[0] = float('inf')
                                with amp.scale_loss(loss1, optimizer, loss_id=loss_ids[1]) as scaled_loss:
                                    scaled_loss.backward()
                                    if i == inject_inf and which_backward == 1:
                                        if inject_inf_loc == "fp32":
                                            model1.weight0.grad[0] = float('inf')
                                        elif inject_inf_loc == "fp16":
                                            model1.weight1.grad[0] = float('inf')

                                if i != inject_inf:
                                    master_params = amp.master_params(optimizer)
                                    for param, reference_grad in zip(master_params, reference_grads[unskipped]):
                                        if opt_level == "O2" and not materialize_master_grads:
                                            continue
                                        else:
                                            self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()),
                                                            "opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers))
                                    unskipped += 1
                                optimizer.step()

                            model_params = [p for p in model0.parameters()] + [p for p in model1.parameters()]
                            for model, master, reference in zip(
                                    model_params,
                                    amp.master_params(optimizer),
                                    final_params):
                                self.assertTrue(torch.allclose(model, reference))
                                self.assertTrue(torch.allclose(model, master.to(model.dtype)))

                            if opt_level == "O1":
                                _amp_state.handle._deactivate()

    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_3models2losses1optimizer(self):

        model0 = MyModel(1)
        model1 = MyModel(2)
        model2 = MyModel(3)

        optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
                                     {'params' : model1.parameters(), 'lr' : 0.5},
                                     {'params' : model2.parameters(), 'lr' : 0.125}],
                                    momentum=0.125)

        reference_grads = []
        for i in range(2):
            optimizer.zero_grad()
            loss0 = model0(self.x) + model2(self.x)
            loss1 = model1(self.x) + model2(self.x)
            loss0.backward()
            loss1.backward()

            reference_grads.append([param.grad.data.clone() for param in model0.parameters()] +
                                   [param.grad.data.clone() for param in model1.parameters()] +
                                   [param.grad.data.clone() for param in model2.parameters()])

            optimizer.step()

        final_params = [param.data.clone() for param in model0.parameters()] + \
                       [param.data.clone() for param in model1.parameters()] + \
                       [param.data.clone() for param in model2.parameters()]

        for materialize_master_grads in (False, True):
          for opt_level in ("O0", "O1", "O2", "O3"):
            for how_to_zero in ("none", "model", "optimizer"):
              for use_multiple_loss_scalers in (False, True):
                if opt_level == "O1" or opt_level == "O2":
                    inject_inf_iters = (-1, 0, 1)
                else:
                    inject_inf_iters = (-1,)

                for inject_inf in inject_inf_iters:
                    if inject_inf >= 0:
                        inject_inf_locs = ("fp16", "fp32")
                        which_backwards = (0, 1)
                    else:
                        inject_inf_locs = ("fdsa",)
                        which_backwards = (None,)

                    for inject_inf_loc in inject_inf_locs:
                        for which_backward in which_backwards:
                            if use_multiple_loss_scalers:
                                num_losses = 2
                                loss_ids = [0, 1]
                            else:
                                num_losses = 1
                                loss_ids = [0, 0]

                            if inject_inf >= 0:
                                iters = 3
                                if which_backward == 0:
                                    which_models = (0, 2)
                                elif which_backward == 1:
                                    which_models = (1, 2)
                            else:
                                iters = 2
                                which_models = (None,)

                            for which_model in which_models:
                                model0 = MyModel(1)
                                model1 = MyModel(2)
                                model2 = MyModel(3)

                                models = [model0, model1, model2]

                                optimizer = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25},
                                                      {'params' : model1.parameters(), 'lr' : 0.5},
                                                      {'params' : model2.parameters(), 'lr' : 0.125}],
                                                     momentum=0.125,
                                                     materialize_master_grads=materialize_master_grads)

                                _amp_state.allow_incoming_model_not_fp32 = True
                                [model0, model1, model2], optimizer = amp.initialize(
                                    [model0, model1, model2],
                                    optimizer,
                                    opt_level=opt_level,
                                    verbosity=0,
                                    cast_model_type=False,
                                    num_losses=num_losses)
                                _amp_state.allow_incoming_model_not_fp32 = False

                                _amp_state.loss_scalers[0]._loss_scale = 4.0
                                if use_multiple_loss_scalers:
                                    _amp_state.loss_scalers[1]._loss_scale = 16.0

                                unskipped = 0
                                for i in range(iters):
                                    if how_to_zero == "none":
                                        for model in models:
                                            for param in model.parameters():
                                                param.grad = None
                                    elif how_to_zero == "model":
                                        for model in models:
                                            model.zero_grad()
                                    else:
                                        optimizer.zero_grad()

                                    loss0 = model0(self.x) + model2(self.x)
                                    loss1 = model1(self.x) + model2(self.x)

                                    with amp.scale_loss(loss0, optimizer, loss_id=loss_ids[0]) as scaled_loss:
                                        scaled_loss.backward()
                                        if i == inject_inf and which_backward == 0:
                                            if which_model == 0:
                                                inj_model = model0
                                            elif which_model == 2:
                                                inj_model = model2
                                            else:
                                                raise RuntimeError(str(which_model) + " invalid for loss 0")
                                            if inject_inf_loc == "fp32":
                                                inj_model.weight0.grad[0] = float('inf')
                                            elif inject_inf_loc == "fp16":
                                                inj_model.weight1.grad[0] = float('inf')
                                    with amp.scale_loss(loss1, optimizer, loss_id=loss_ids[1]) as scaled_loss:
                                        scaled_loss.backward()
                                        if i == inject_inf and which_backward == 1:
                                            if which_model == 1:
                                                inj_model = model1
                                            elif which_model == 2:
                                                inj_model = model2
                                            else:
                                                raise RuntimeError(str(which_model) + " invalid for loss 1")
                                            if inject_inf_loc == "fp32":
                                                inj_model.weight0.grad[0] = float('inf')
                                            elif inject_inf_loc == "fp16":
                                                inj_model.weight1.grad[0] = float('inf')

                                    if i != inject_inf:
                                        master_params = amp.master_params(optimizer)
                                        for param, reference_grad in zip(master_params, reference_grads[unskipped]):
                                            if opt_level == "O2" and not materialize_master_grads:
                                                continue
                                            else:
                                                self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()),
                                                                "opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} which_model {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, which_model, use_multiple_loss_scalers))
                                        unskipped += 1

                                    optimizer.step()

                                model_params = [p for p in model0.parameters()] + \
                                               [p for p in model1.parameters()] + \
                                               [p for p in model2.parameters()]
                                for model, master, reference in zip(
                                        model_params,
                                        amp.master_params(optimizer),
                                        final_params):
                                    self.assertTrue(torch.allclose(model, reference))
                                    self.assertTrue(torch.allclose(model, master.to(model.dtype)))

                                if opt_level == "O1":
                                    _amp_state.handle._deactivate()

    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_2models2losses2optimizers(self):
        model0 = MyModel(1)
        model1 = MyModel(2)

        optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
                                     momentum=0.125)
        optimizer1 = torch.optim.SGD([{'params' : model1.parameters(), 'lr' : 0.5}],
                                     momentum=0.25)

        # Don't do it like this: reference_grads = [[]]*5
        # because then it creates a list of 5 references to the same "[]" and appending
        # to any of them effectively makes you append to all of them, which multiplies
        # the resulting size of reference_grads by 5x and needless to say makes the test fail.
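        # A quick illustration of that pitfall (not part of the test itself):
        #     shared = [[]] * 5        # five references to one and the same list object
        #     shared[0].append("grad")
        #     shared                   # -> [['grad'], ['grad'], ['grad'], ['grad'], ['grad']]
        # The literal below (or [[] for _ in range(5)]) creates five independent lists instead.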
        reference_grads = [[], [], [], [], []]
        final_params = [None, None, None, None, None]
        for i in range(2):
            optimizer0.zero_grad()
            optimizer1.zero_grad()
            loss0 = model0(self.x)
            loss1 = model1(self.x)
            loss0.backward()
            loss1.backward()

            reference_grads[0].append([param.grad.data.clone() for param in model0.parameters()] +
                                      [param.grad.data.clone() for param in model1.parameters()])

            optimizer0.step()
            optimizer1.step()

        final_params[0] = [param.data.clone() for param in model0.parameters()] + \
                          [param.data.clone() for param in model1.parameters()]

        # Map (which_iter, which_backward) to the slot in reference_grads/final_params that
        # models that particular skipped step; slot 0 is the baseline with nothing skipped.
        def what_got_skipped(which_iter, which_backward):
            if which_iter == 0 and which_backward == 0:
                return 1
            if which_iter == 0 and which_backward == 1:
                return 2
            if which_iter == 1 and which_backward == 0:
                return 3
            if which_iter == 1 and which_backward == 1:
                return 4
            return 0

        for which_iter in (0,1):
            for which_backward in (0,1):
                model0 = MyModel(1)
                model1 = MyModel(2)

                optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
                                             momentum=0.125)
                optimizer1 = torch.optim.SGD([{'params' : model1.parameters(), 'lr' : 0.5}],
                                             momentum=0.25)

                for i in range(3):
                    optimizer0.zero_grad()
                    optimizer1.zero_grad()
                    loss0 = model0(self.x)
                    loss1 = model1(self.x)
                    loss0.backward()
                    loss1.backward()

                    if i != which_iter:
                        reference_grads[what_got_skipped(which_iter, which_backward)].append(
                            [param.grad.data.clone() for param in model0.parameters()] +
                            [param.grad.data.clone() for param in model1.parameters()])

                    if i == which_iter:
                        if which_backward == 0:
                            optimizer1.step()
                        else:
                            optimizer0.step()
                    else:
                        optimizer0.step()
                        optimizer1.step()

                final_params[what_got_skipped(which_iter, which_backward)] = \
                    [param.data.clone() for param in model0.parameters()] + \
                    [param.data.clone() for param in model1.parameters()]

        for materialize_master_grads in (False, True):
          for opt_level in ("O0", "O1", "O2", "O3"):
            for how_to_zero in ("none", "model", "optimizer"):
              for use_multiple_loss_scalers in (False, True):
                if opt_level == "O1" or opt_level == "O2":
                    inject_inf_iters = (-1, 0, 1)
                else:
                    inject_inf_iters = (-1,)

                for inject_inf in inject_inf_iters:
                    if inject_inf >= 0:
                        inject_inf_locs = ("fp16", "fp32")
                        which_backwards = (0, 1)
                    else:
                        inject_inf_locs = ("fdsa",)
                        which_backwards = (None,)

                    for inject_inf_loc in inject_inf_locs:
                        for which_backward in which_backwards:
                            if use_multiple_loss_scalers:
                                num_losses = 2
                                loss_ids = [0, 1]
                            else:
                                num_losses = 1
                                loss_ids = [0, 0]

                            if inject_inf >= 0:
                                iters = 3
                            else:
                                iters = 2

                            model0 = MyModel(1)
                            model1 = MyModel(2)

                            models = [model0, model1]

                            optimizer0 = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25}],
                                                  momentum=0.125, materialize_master_grads=materialize_master_grads)
                            optimizer1 = FusedSGD([{'params' : model1.parameters(), 'lr' : 0.5}],
                                                  momentum=0.25, materialize_master_grads=materialize_master_grads)

                            _amp_state.allow_incoming_model_not_fp32 = True
                            [model0, model1], [optimizer0, optimizer1] = amp.initialize(
                                [model0, model1],
                                [optimizer0, optimizer1],
                                opt_level=opt_level,
                                verbosity=0,
                                cast_model_type=False,
                                num_losses=num_losses)
                            _amp_state.allow_incoming_model_not_fp32 = False

                            _amp_state.loss_scalers[0]._loss_scale = 4.0
                            if use_multiple_loss_scalers:
                                _amp_state.loss_scalers[1]._loss_scale = 16.0

                            unskipped = 0
                            for i in range(iters):
                                if how_to_zero == "none":
                                    for model in models:
                                        for param in model.parameters():
                                            param.grad = None
                                elif how_to_zero == "model":
                                    for model in models:
                                        model.zero_grad()
                                else:
                                    optimizer0.zero_grad()
                                    optimizer1.zero_grad()

                                loss0 = model0(self.x)
                                loss1 = model1(self.x)

                                with amp.scale_loss(loss0, optimizer0, loss_id=loss_ids[0]) as scaled_loss:
                                    scaled_loss.backward()
                                    if i == inject_inf and which_backward == 0:
                                        if inject_inf_loc == "fp32":
                                            model0.weight0.grad[0] = float('inf')
                                        elif inject_inf_loc == "fp16":
                                            model0.weight1.grad[0] = float('inf')
                                with amp.scale_loss(loss1, optimizer1, loss_id=loss_ids[1]) as scaled_loss:
                                    scaled_loss.backward()
                                    if i == inject_inf and which_backward == 1:
                                        if inject_inf_loc == "fp32":
                                            model1.weight0.grad[0] = float('inf')
                                        elif inject_inf_loc == "fp16":
                                            model1.weight1.grad[0] = float('inf')

                                # print("opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers))

                                if i != inject_inf:
                                    master_params = list(amp.master_params(optimizer0)) + \
                                                    list(amp.master_params(optimizer1))
                                    for param, reference_grad in zip(master_params,
                                            reference_grads[what_got_skipped(inject_inf, which_backward)][unskipped]):
                                        if opt_level == "O2" and not materialize_master_grads:
                                            continue
                                        else:
                                            self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()))
                                    unskipped += 1

                                optimizer0.step()
                                optimizer1.step()

                            model_params = [p for p in model0.parameters()] + [p for p in model1.parameters()]
                            master_params = [p for p in amp.master_params(optimizer0)] + \
                                            [p for p in amp.master_params(optimizer1)]
                            for model, master, reference in zip(
                                    model_params,
                                    master_params,
                                    final_params[what_got_skipped(inject_inf, which_backward)]):
                                self.assertTrue(torch.allclose(model, reference))
                                self.assertTrue(torch.allclose(model, master.to(model.dtype)))

                            if opt_level == "O1":
                                _amp_state.handle._deactivate()

    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_3models2losses2optimizers(self):
        model0 = MyModel(1)
        model1 = MyModel(2)
        model2 = MyModel(3)

        optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
                                      {'params' : model1.parameters(), 'lr' : 1.0}],
                                     momentum=0.5)
        optimizer1 = torch.optim.SGD([{'params' : model2.parameters(), 'lr' : 0.5}],
                                     momentum=0.25)

        # Again, can't do this: reference_grads = [[]]*9
        reference_grads = [[], [], [], [], [], [], [], [], []]
        final_params = [None, None, None, None, None, None, None, None, None]
        for i in range(2):
            optimizer0.zero_grad()
            optimizer1.zero_grad()
            loss0 = model0(self.x) + model1(self.x)
            loss1 = model2(self.x) + model1(self.x)
            loss0.backward()
            loss1.backward()

            reference_grads[0].append([param.grad.data.clone() for param in model0.parameters()] +
                                      [param.grad.data.clone() for param in model1.parameters()])

            optimizer0.step()
            optimizer1.step()

        final_params[0] = \
            [param.data.clone() for param in model0.parameters()] + \
            [param.data.clone() for param in model1.parameters()] + \
            [param.data.clone() for param in model2.parameters()]

        # Same idea as in the previous test, but the slot now also depends on which
        # model received the injected inf.
        def what_got_skipped(which_iter, which_backward, which_model):
            if which_iter == 0:
                if which_backward == 0:
                    if which_model == 0:
                        return 1
                    if which_model == 1:
                        return 2
                if which_backward == 1:
                    if which_model == 2:
                        return 3
                    if which_model == 1:
                        return 4
            if which_iter == 1:
                if which_backward == 0:
                    if which_model == 0:
                        return 5
                    if which_model == 1:
                        return 6
                if which_backward == 1:
                    if which_model == 2:
                        return 7
                    if which_model == 1:
                        return 8
            return 0

        for which_iter in (0,1):
            for which_backward in (0,1):
                if which_backward == 0:
                    which_models = (0,1)
                if which_backward == 1:
                    which_models = (2,1)
                for which_model in which_models:

                    model0 = MyModel(1)
                    model1 = MyModel(2)
                    model2 = MyModel(3)

                    optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
                                                  {'params' : model1.parameters(), 'lr' : 1.0}],
                                                 momentum=0.5)
                    optimizer1 = torch.optim.SGD([{'params' : model2.parameters(), 'lr' : 0.5}],
                                                 momentum=0.25)

                    for i in range(3):
                        optimizer0.zero_grad()
                        optimizer1.zero_grad()
                        loss0 = model0(self.x) + model1(self.x)
                        loss1 = model2(self.x) + model1(self.x)
                        loss0.backward()
                        loss1.backward()

                        if i != which_iter:
                            reference_grads[what_got_skipped(which_iter,
                                which_backward, which_model)].append(
                                [param.grad.data.clone() for param in model0.parameters()] +
                                [param.grad.data.clone() for param in model1.parameters()])

                        if i == which_iter:
                            if which_backward == 0:
                                # if which_model == 0:
                                optimizer1.step()
                                # if which_model == 1:
                                #     optimizer1.step()
                            if which_backward == 1:
                                # if which_model == 2:
                                #     optimizer0.step()
                                # if which_model == 1:
                                continue
                        else:
                            optimizer0.step()
                            optimizer1.step()

                    final_params[what_got_skipped(which_iter, which_backward, which_model)] = \
                        [param.data.clone() for param in model0.parameters()] + \
                        [param.data.clone() for param in model1.parameters()] + \
                        [param.data.clone() for param in model2.parameters()]

        for materialize_master_grads in (False, True):
          for opt_level in ("O0", "O1", "O2", "O3"):
            for how_to_zero in ("none", "model", "optimizer"):
              for use_multiple_loss_scalers in (False, True):
                if opt_level == "O1" or opt_level == "O2":
                    inject_inf_iters = (-1, 0, 1)
                else:
                    inject_inf_iters = (-1,)

                for inject_inf in inject_inf_iters:
                    if inject_inf >= 0:
                        inject_inf_locs = ("fp16", "fp32")
                        which_backwards = (0, 1)
                    else:
                        inject_inf_locs = ("fdsa",)
                        which_backwards = (None,)

                    for inject_inf_loc in inject_inf_locs:
                        for which_backward in which_backwards:
                            if use_multiple_loss_scalers:
                                num_losses = 2
                                loss_ids = [0, 1]
                            else:
                                num_losses = 1
                                loss_ids = [0, 0]

                            if inject_inf >= 0:
                                iters = 3
                                if which_backward == 0:
                                    which_models = (0, 1)
                                elif which_backward == 1:
                                    which_models = (2, 1)
                            else:
                                iters = 2
                                which_models = (None,)

                            for which_model in which_models:
                                model0 = MyModel(1)
                                model1 = MyModel(2)
                                model2 = MyModel(3)

                                models = [model0, model1, model2]

                                optimizer0 = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25},
                                                       {'params' : model1.parameters(), 'lr' : 1.0}],
                                                      momentum=0.5, materialize_master_grads=materialize_master_grads)
                                optimizer1 = FusedSGD([{'params' : model2.parameters(), 'lr' : 0.5}],
                                                      momentum=0.25, materialize_master_grads=materialize_master_grads)

                                _amp_state.allow_incoming_model_not_fp32 = True
                                [model0, model1, model2], [optimizer0, optimizer1] = amp.initialize(
                                    [model0, model1, model2],
                                    [optimizer0, optimizer1],
                                    opt_level=opt_level,
                                    verbosity=0,
                                    cast_model_type=False,
                                    num_losses=num_losses)
                                _amp_state.allow_incoming_model_not_fp32 = False

                                _amp_state.loss_scalers[0]._loss_scale = 4.0
                                if use_multiple_loss_scalers:
                                    _amp_state.loss_scalers[1]._loss_scale = 16.0

                                unskipped = 0
                                for i in range(iters):
                                    if how_to_zero == "none":
                                        for model in models:
                                            for param in model.parameters():
                                                param.grad = None
                                    elif how_to_zero == "model":
                                        for model in models:
                                            model.zero_grad()
                                    else:
                                        optimizer0.zero_grad()
                                        optimizer1.zero_grad()

                                    loss0 = model0(self.x) + model1(self.x)
                                    loss1 = model2(self.x) + model1(self.x)

                                    with amp.scale_loss(loss0, optimizer0, loss_id=loss_ids[0]) as scaled_loss:
                                        scaled_loss.backward()
                                        if i == inject_inf and which_backward == 0:
                                            if which_model == 0:
                                                inj_model = model0
                                            elif which_model == 1:
                                                inj_model = model1
                                            else:
                                                raise RuntimeError(str(which_model) + " invalid for loss 0")
                                            if inject_inf_loc == "fp32":
                                                inj_model.weight0.grad[0] = float('inf')
                                            elif inject_inf_loc == "fp16":
                                                inj_model.weight1.grad[0] = float('inf')
                                    with amp.scale_loss(loss1, [optimizer0, optimizer1], loss_id=loss_ids[1]) as scaled_loss:
                                        scaled_loss.backward()
                                        if i == inject_inf and which_backward == 1:
                                            if which_model == 2:
                                                inj_model = model2
                                            elif which_model == 1:
                                                inj_model = model1
                                            else:
                                                raise RuntimeError(str(which_model) + " invalid for loss 1")
                                            if inject_inf_loc == "fp32":
                                                inj_model.weight0.grad[0] = float('inf')
                                            elif inject_inf_loc == "fp16":
                                                inj_model.weight1.grad[0] = float('inf')

                                    if i != inject_inf:
                                        master_params = list(amp.master_params(optimizer0)) + \
                                                        list(amp.master_params(optimizer1))
                                        for param, reference_grad in zip(master_params,
                                                reference_grads[what_got_skipped(inject_inf,
                                                    which_backward, which_model)][unskipped]):
                                            if opt_level == "O2" and not materialize_master_grads:
                                                continue
                                            else:
                                                self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()))
                                        unskipped += 1

                                    optimizer0.step()
                                    optimizer1.step()

                                model_params = [p for p in model0.parameters()] + \
                                               [p for p in model1.parameters()] + \
                                               [p for p in model2.parameters()]
                                master_params = [p for p in amp.master_params(optimizer0)] + \
                                                [p for p in amp.master_params(optimizer1)]

                                # print("opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {} which_model {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers, which_model))

                                for model, master, reference in zip(
                                        model_params,
                                        master_params,
                                        final_params[what_got_skipped(inject_inf, which_backward, which_model)]):
                                    self.assertTrue(torch.allclose(model, reference))
                                    self.assertTrue(torch.allclose(model, master.to(model.dtype)))

                                if opt_level == "O1":
                                    _amp_state.handle._deactivate()


if __name__ == '__main__':
    unittest.main()