CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
ai-forever

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: ai-forever/sber-swap
Path: blob/main/apex/tests/L0/run_amp/test_checkpointing.py
Views: 794
1
import unittest
2
3
import torch
4
import torch.nn as nn
5
import torch.nn.functional as F
6
import torch.optim as optim
7
8
from apex import amp
9
10
11
from utils import common_init, FLOAT
12
13
14
class MyModel(nn.Module):
    """Small test network: learnable scalar gain, 3x3 conv (3 -> 6 channels,
    stride 1, padding 1), ReLU, then BatchNorm2d.

    Keep the construction order (conv1, bn1, param) unchanged: the tests seed
    the global RNG right before building models, so the order decides which
    random values each layer receives.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(6)
        self.param = nn.Parameter(torch.randn(1))

    def forward(self, x):
        """Multiply ``x`` by ``param``, then apply conv1, ReLU and bn1."""
        scaled = x * self.param
        activated = F.relu(self.conv1(scaled))
        return self.bn1(activated)
26
27
28
class TestCheckpointing(unittest.TestCase):
    """Checkpoint save/restore tests for models and optimizers wrapped by apex amp.

    Every test allocates tensors on ``'cuda'`` and therefore needs a GPU.
    """

    def setUp(self):
        self.initial_lr = 1e-3
        self.test_opt_levels = ("O0", "O1", "O2", "O3")

    def seed(self):
        """Seed torch and force deterministic cuDNN so runs are reproducible."""
        torch.manual_seed(2809)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    def check_state_dict_fp32(self, state_dict):
        """Assert that every entry of ``state_dict`` has tensor type ``FLOAT``.

        Keys containing ``num_batches_tracked`` are skipped; they are not
        expected to be FLOAT (presumably integer counters).
        """
        for key, param in state_dict.items():
            if 'num_batches_tracked' in key:
                continue
            self.assertEqual(param.type(), FLOAT,
                             'Parameter in state_dict not FLOAT: {}'.format(key))

    def train_step(self, model, optimizer, data, loss_ids):
        """Run one optimization step, backpropagating ``output.mean()`` once
        per id in ``loss_ids``.

        Each id selects the amp loss scaler used for that backward pass. The
        graph is retained because every loss derives from the same forward
        output.

        Returns:
            The output of the single forward pass.
        """
        optimizer.zero_grad()
        output = model(data)
        for idx in loss_ids:
            loss = output.mean()
            with amp.scale_loss(loss, optimizer, loss_id=idx) as scaled_loss:
                scaled_loss.backward(retain_graph=True)
        optimizer.step()
        return output

    def compare_models(self, modelA, modelB, test_setup=''):
        """Assert that both models have state_dicts with the same keys and
        element-wise equal tensors.

        ``test_setup`` is appended to failure messages to identify the case.
        """
        state_dictA = modelA.state_dict()
        state_dictB = modelB.state_dict()
        self.assertEqual(len(state_dictA), len(state_dictB),
                         'state_dicts have different lengths\n' + test_setup)
        for key, paramA in state_dictA.items():
            self.assertIn(key, state_dictB,
                          'key {} missing from restored state_dict\n{}'.format(
                              key, test_setup))
            paramB = state_dictB[key]
            # Only build the (potentially large) diagnostic message on failure.
            if not bool((paramA == paramB).all()):
                self.fail('Parameters in state_dicts not equal.\n'
                          'key: {}\nparam: {}\nrestored: {}\ndiff: {} for {}'.format(
                              key, paramA, paramB, paramA - paramB, test_setup))

    def test_restoring(self):
        """Train a model, checkpoint it halfway, restore the checkpoint into a
        fresh model and check both continue training identically.

        Only combinations whose saving and restoring opt_level match are
        compared (see ``_run_restoring_case``).
        """
        nb_epochs = 10
        nb_epochs_restore = nb_epochs // 2
        for opt_level in self.test_opt_levels:
            for res_opt_level in self.test_opt_levels:
                for amp_before_load in [True, False]:
                    for num_losses in range(1, 3):
                        self._run_restoring_case(
                            opt_level, res_opt_level, amp_before_load,
                            num_losses, nb_epochs, nb_epochs_restore)

    def _run_restoring_case(self, opt_level, res_opt_level, amp_before_load,
                            num_losses, nb_epochs, nb_epochs_restore):
        """Run one (opt_level, res_opt_level, amp_before_load, num_losses) case."""
        test_setup = ('#' * 75 + '\n' +
                      f'opt_level {opt_level}\n' +
                      f'restore_opt_level {res_opt_level}\n' +
                      f'amp_before_load {amp_before_load}\n' +
                      f'num_losses {num_losses}\n')

        self.seed()

        # Create reference model
        model = MyModel().to('cuda')
        optimizer = optim.SGD(model.parameters(), lr=self.initial_lr)
        # Initialize with num_losses*2 loss scalers: ids [0, num_losses) are
        # used by the reference model, [num_losses, 2*num_losses) by the
        # restored one.
        model, optimizer = amp.initialize(
            model, optimizer, opt_level=opt_level,
            num_losses=num_losses * 2, verbosity=0)

        # Compare training behavior only for the same restore option.
        # We cannot really generalize it, since a saved model in O0
        # would introduce a skipped step in O1, which will raise an error.
        # NOTE(review): amp.initialize above deliberately still runs for
        # mismatched levels, as it always has; amp keeps module-level state
        # (amp._amp_state, amp.state_dict()), so skipping it may change what
        # later cases observe.
        if opt_level != res_opt_level:
            return

        restore_model = restore_optimizer = None
        for epoch in range(nb_epochs):
            x = torch.randn(16, 3, 24, 24, device='cuda')
            output = self.train_step(model, optimizer, x, range(num_losses))
            # Checkpoint one step before comparing. Otherwise the batchnorm
            # layers would be updated additionally in restore_model.
            if epoch == nb_epochs_restore - 1:
                # Save model, optimizer and amp state
                checkpoint = {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'amp': amp.state_dict(),
                }
                # Check state_dict for FP32 tensors
                self.check_state_dict_fp32(checkpoint['model'])
                restore_model, restore_optimizer = self._restore_from_checkpoint(
                    checkpoint, res_opt_level, amp_before_load, num_losses)
            elif epoch >= nb_epochs_restore:
                restore_output = self.train_step(
                    restore_model, restore_optimizer, x,
                    range(num_losses, num_losses * 2))
                self.assertTrue(
                    torch.allclose(output.float(), restore_output.float()),
                    'Output of reference and restored models differ for ' + test_setup)
                self.compare_models(model, restore_model, test_setup)

    def _restore_from_checkpoint(self, checkpoint, res_opt_level,
                                 amp_before_load, num_losses):
        """Build a fresh model/optimizer pair and load ``checkpoint`` into it.

        ``amp.initialize`` runs before loading the model/optimizer state when
        ``amp_before_load`` is true, otherwise after it.

        Returns:
            ``(restore_model, restore_optimizer)``.
        """
        restore_model = MyModel().to('cuda')
        restore_optimizer = optim.SGD(restore_model.parameters(),
                                      lr=self.initial_lr)

        def init_amp(model, optimizer):
            return amp.initialize(model, optimizer, opt_level=res_opt_level,
                                  num_losses=num_losses * 2, verbosity=0)

        if amp_before_load:
            restore_model, restore_optimizer = init_amp(restore_model,
                                                        restore_optimizer)

        restore_model.load_state_dict(checkpoint['model'])
        restore_optimizer.load_state_dict(checkpoint['optimizer'])
        # FIXME: We cannot test the amp.state_dict in the same script
        # amp.load_state_dict(checkpoint['amp'])

        if not amp_before_load:
            restore_model, restore_optimizer = init_amp(restore_model,
                                                        restore_optimizer)
        return restore_model, restore_optimizer

    def test_loss_scale_decrease(self):
        """Force skipped steps and check that each dynamic loss scaler is
        halved once per skipped step, both live and in ``amp.state_dict()``.

        Opt levels without a dynamic loss scale are skipped.
        """
        num_losses = 3
        # Number of forced skipped steps per loss id.
        nb_decrease_loss_scales = [0, 1, 2]
        for opt_level in self.test_opt_levels:
            model = MyModel().to('cuda')
            optimizer = optim.SGD(model.parameters(), lr=self.initial_lr)
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=opt_level, num_losses=num_losses,
                verbosity=0)

            if amp._amp_state.opt_properties.loss_scale != 'dynamic':
                # Static loss scale set; nothing to decrease for this opt_level.
                continue

            initial_loss_scales = [amp._amp_state.loss_scalers[idx].loss_scale()
                                   for idx in range(num_losses)]

            # Force skipped updates to decrease the loss scales. The large
            # input (x * 2**17) is meant to make the scaled gradients overflow.
            x = torch.randn(16, 3, 24, 24, device='cuda')
            for idx in range(num_losses):
                for _ in range(nb_decrease_loss_scales[idx]):
                    optimizer.zero_grad()
                    output = model(x * 2**17)
                    loss = output.mean()
                    with amp.scale_loss(loss, optimizer, loss_id=idx) as scaled_loss:
                        scaled_loss.backward(retain_graph=True)
                    optimizer.step()

            # Check loss scales afterwards
            updated_loss_scales = [amp._amp_state.loss_scalers[idx].loss_scale()
                                   for idx in range(num_losses)]
            for factor, update_ls, init_ls in zip(nb_decrease_loss_scales,
                                                  updated_loss_scales,
                                                  initial_loss_scales):
                self.assertEqual(update_ls, init_ls / 2**factor)

            # Check state dict
            amp_state_dict = amp.state_dict()
            for scaler_idx, factor, init_ls in zip(amp_state_dict,
                                                   nb_decrease_loss_scales,
                                                   initial_loss_scales):
                scaler = amp_state_dict[scaler_idx]
                self.assertEqual(scaler['loss_scale'], init_ls / 2**factor)
                self.assertEqual(scaler['unskipped'], 0)

    def _mse_step(self, model, optimizer, data, target):
        """Run one amp-scaled MSE training step; return the loss value
        computed before ``optimizer.step()``."""
        optimizer.zero_grad()
        output = model(data)
        loss = F.mse_loss(output, target)
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        return loss.item()

    def test_state_dict(self):
        """Check that the exported state_dict holds no Half tensors and that
        the model keeps training afterwards (loss strictly decreasing).

        O3 is skipped.
        """
        for opt_level in self.test_opt_levels:
            # Skip O3
            if opt_level == 'O3':
                continue

            model = MyModel().to('cuda')
            optimizer = optim.Adam(model.parameters(), lr=1e-3)
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=opt_level, verbosity=0)

            # Export state_dict and check for Half
            state_dict = model.state_dict()
            for key, value in state_dict.items():
                self.assertNotIn('Half', value.type(), key)

            # Check, if model is still trainable
            data = torch.randn(10, 3, 4, 4, device='cuda')
            target = torch.randn(10, 6, 4, 4, device='cuda')

            # Get initial loss
            last_loss = self._mse_step(model, optimizer, data, target)

            # Train for some epochs; the loss must decrease every step.
            for _ in range(10):
                loss = self._mse_step(model, optimizer, data, target)
                self.assertLess(loss, last_loss)
                last_loss = loss
264
265
if __name__ == '__main__':
    # Script entry point: discover and run this module's tests with the
    # standard unittest command-line runner.
    unittest.main()
267
268
269