Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
ai-forever
GitHub Repository: ai-forever/sber-swap
Path: blob/main/apex/tests/L0/run_amp/test_cache.py
1075 views
1
import unittest
2
3
import functools as ft
4
import itertools as it
5
6
from apex import amp
7
from apex.amp import _amp_state
8
import torch
9
from torch import nn
10
import torch.nn.functional as F
11
12
from utils import common_init, HALF, FLOAT,\
13
ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
14
15
def get_reference_grad(i, w, ops):
    """Compute a float32 reference gradient of ``ops`` with respect to the weight.

    Fresh float32 copies of ``i`` and ``w`` are made, detached from any
    existing autograd graph. Being new tensors, they are not in amp's cache,
    and since they are plain tensors rather than ``torch.nn.Parameter``
    objects they are guaranteed not to use that cache at all.

    Args:
        i: input tensor; detached, cloned and cast to float32.
        w: weight tensor; detached, cloned, cast to float32 and marked as
            requiring grad.
        ops: callable ``ops(input, weight)`` returning a loss tensor on which
            ``.backward()`` can be called with no arguments.

    Returns:
        The gradient accumulated on the float32 weight copy.
    """
    ref_input = i.detach().clone().float()
    ref_weight = w.detach().clone().float()
    ref_weight.requires_grad_()

    # The loss is only needed to drive backward(); the gradient is the result.
    ops(ref_input, ref_weight).backward()
    return ref_weight.grad
23
24
class WhitelistModule(torch.nn.Module):
    """Module whose loss chains two ``mm`` calls against the same 8x8 weight.

    The name presumably refers to amp's fp16 "whitelist" of ops (``mm``
    among them) -- TODO confirm against apex's cast lists.

    Args:
        dtype: dtype of the weight parameter.
        device: device on which the weight is created. Defaults to ``'cuda'``,
            the previously hard-coded value; pass ``'cpu'`` to build the module
            without a GPU.
    """

    def __init__(self, dtype, device='cuda'):
        super(WhitelistModule, self).__init__()
        # Deterministic weight: values 0..63 laid out as an 8x8 matrix.
        self.weight = torch.nn.Parameter(
            torch.arange(8 * 8, device=device, dtype=dtype).view(8, 8))

    @staticmethod
    def ops(input, weight):
        """Return ``((input @ weight) @ weight).sum()`` as a scalar tensor."""
        return (input.mm(weight)).mm(weight).sum()

    def forward(self, input):
        return self.ops(input, self.weight)
35
36
37
class BlacklistModule(torch.nn.Module):
    """Module whose loss applies ``torch.pow`` twice to the same 2x8 weight.

    The name presumably refers to amp's fp32 "blacklist" of ops (``pow``
    among them) -- TODO confirm against apex's cast lists.

    Args:
        dtype: dtype of the weight parameter.
        device: device on which the weight is created. Defaults to ``'cuda'``,
            the previously hard-coded value; pass ``'cpu'`` to build the module
            without a GPU.
    """

    def __init__(self, dtype, device='cuda'):
        super(BlacklistModule, self).__init__()
        # Deterministic weight: values 0..15 laid out as a 2x8 matrix.
        self.weight = torch.nn.Parameter(
            torch.arange(2 * 8, device=device, dtype=dtype).view(2, 8))

    @staticmethod
    def ops(input, weight):
        """Return ``(input + weight**2 + weight**2).sum()`` as a scalar tensor."""
        return (input + torch.pow(weight, 2) + torch.pow(weight, 2)).sum()

    def forward(self, input):
        return self.ops(input, self.weight)
48
49
50
class PromoteModule(torch.nn.Module):
    """Module whose loss multiplies the input elementwise by the same weight twice.

    When input and weight have different dtypes, the elementwise multiply
    presumably exercises amp's type-promotion path -- TODO confirm.

    Args:
        dtype: dtype of the weight parameter.
        device: device on which the weight is created. Defaults to ``'cuda'``,
            the previously hard-coded value; pass ``'cpu'`` to build the module
            without a GPU.
    """

    def __init__(self, dtype, device='cuda'):
        super(PromoteModule, self).__init__()
        # Deterministic weight: values 0..15 laid out as a 2x8 matrix.
        self.weight = torch.nn.Parameter(
            torch.arange(2 * 8, device=device, dtype=dtype).view(2, 8))

    @staticmethod
    def ops(input, weight):
        """Return ``((input * weight) * weight).sum()`` as a scalar tensor."""
        return ((input * weight) * weight).sum()

    def forward(self, input):
        return self.ops(input, self.weight)
61
62
class TestCache(unittest.TestCase):
    """Check amp O1 gradients across a train -> eval -> train sequence.

    Each test builds a small module and initializes amp with ``opt_level="O1"``.
    It then runs a training step, a forward pass under ``torch.no_grad()``
    (simulated eval), and a second training step. Each training step checks
    the weight gradient against a float32 reference.
    """

    def setUp(self):
        self.x = torch.ones((2, 8), device='cuda', dtype=torch.float32)
        common_init(self)

    def tearDown(self):
        # amp deactivation is registered via addCleanup in
        # train_eval_train_test so it also runs when an assertion fails.
        pass

    def train_eval_train_test(self, module, t):
        """Run train/eval/train on ``module(t)`` under amp O1 and check gradients.

        Args:
            module: module class, constructed as ``module(t)`` and moved to CUDA.
            t: dtype passed to the module for its weight.
        """
        model = module(t).cuda()
        optimizer = torch.optim.SGD(model.parameters(), lr=1.0)

        # Set the flag only for the duration of amp.initialize (presumably it
        # lets amp accept a model whose params are not fp32 -- TODO confirm).
        # Reset it in `finally` so an exception cannot leave it enabled for
        # later tests.
        _amp_state.allow_incoming_model_not_fp32 = True
        try:
            model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0)
        finally:
            _amp_state.allow_incoming_model_not_fp32 = False

        # Previously the handle was deactivated only at the very end of the
        # test, so a failing assertion left it active for subsequent tests.
        # addCleanup runs it whether the test passes or fails.
        self.addCleanup(_amp_state.handle._deactivate)

        def training_step():
            for param in model.parameters():
                param.grad = None

            loss = model(self.x).sum()
            # Pin the loss scale to a fixed value before scaling the loss.
            _amp_state.loss_scalers[0]._loss_scale = 4.0
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()

            self.assertEqual(len([p.grad for p in model.parameters() if p.grad is not None]), 1)
            self.assertEqual(model.weight.grad.type(), model.weight.type())

            reference_grad = get_reference_grad(self.x, model.weight, model.ops)

            # Currently there's no difference in the allclose calls, so no need for branching,
            # but I'm keeping this in case we want different tolerances for fp16 and fp32 checks.
            if model.weight.grad.type() == "torch.cuda.HalfTensor":
                self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad))
            elif model.weight.grad.type() == "torch.cuda.FloatTensor":
                self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad))
            else:
                raise RuntimeError("model.weight.grad.type = {}".format(model.weight.grad.type()))

            # Modify the weight in place via .data (outside autograd), so the
            # next step runs with different weight values.
            model.weight.data -= 1.

        # Simulates first epoch
        training_step()

        # Simulates eval; the result is intentionally discarded.
        with torch.no_grad():
            model(self.x).sum()

        # Simulates resuming training after eval
        training_step()

    # I could easily have these as a set of for loops in a single test,
    # instead of going for granularity.
    def test_whitelist_module_fp16_weight(self):
        self.train_eval_train_test(WhitelistModule, torch.float16)

    def test_whitelist_module_fp32_weight(self):
        self.train_eval_train_test(WhitelistModule, torch.float32)

    def test_blacklist_module_fp16_weight(self):
        self.train_eval_train_test(BlacklistModule, torch.float16)

    def test_blacklist_module_fp32_weight(self):
        self.train_eval_train_test(BlacklistModule, torch.float32)

    def test_promote_module_fp16_weight(self):
        self.train_eval_train_test(PromoteModule, torch.float16)

    def test_promote_module_fp32_weight(self):
        self.train_eval_train_test(PromoteModule, torch.float32)
134
135
136
if __name__ == "__main__":
    # Script entry point: hand control to the standard unittest runner,
    # which discovers and runs the TestCase classes in this module.
    unittest.main()
138
139