CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
ai-forever

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: ai-forever/sber-swap
Path: blob/main/apex/tests/L0/run_optimizers/test_lamb.py
Views: 793
1
import unittest
2
import os
3
4
import torch
5
from torch.optim import Optimizer
6
import apex
7
from apex.multi_tensor_apply import multi_tensor_applier
8
from itertools import product
9
10
class RefLAMB(Optimizer):
    r"""Reference (unfused) implementation of the LAMB algorithm.

    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
    In this test file it serves as the ground truth that
    ``apex.optimizers.FusedLAMB`` is compared against.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-6)
        weight_decay (float, optional): weight decay coefficient; ``weight_decay * p``
            is added to the Adam-style update before the trust ratio is
            computed (default: 0.01)
        max_grad_norm (float, optional): all gradients are jointly rescaled so
            that their global L2 norm does not exceed this value; must be > 0
            (default: 1.0)

    Raises:
        ValueError: if a hyper-parameter is out of range.
        RuntimeError: if apex's ``amp_C`` CUDA extension is not available.

    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
                 max_grad_norm=1.0):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 < max_grad_norm:
            # max_grad_norm is the numerator of the clip ratio in step(); zero or a
            # negative value would zero out or flip every gradient.
            raise ValueError("Invalid max_grad_norm value: {}".format(max_grad_norm))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm)
        super(RefLAMB, self).__init__(params, defaults)
        if multi_tensor_applier.available:
            import amp_C
            self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
            # Overflow-flag buffer required by multi_tensor_applier; its value is
            # never inspected by this class ("skip buffer").
            self._dummy_overflow_buf = torch.tensor(
                [0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
            self.multi_tensor_lamb = amp_C.multi_tensor_lamb
        else:
            raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions')

    def _global_grad_norm(self):
        """Return the L2 norm over the gradients of all parameters (as a tensor).

        fp32 and fp16 gradients are reduced separately with ``multi_tensor_l2norm``;
        the two partial norms are then combined by one more L2 reduction.

        Raises:
            RuntimeError: if a parameter with a gradient is neither fp32 nor fp16.
        """
        g_all_32, g_all_16 = [], []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                if p.dtype == torch.float32:
                    g_all_32.append(p.grad.data)
                elif p.dtype == torch.float16:
                    g_all_16.append(p.grad.data)
                else:
                    raise RuntimeError('FusedLAMB only support fp16 and fp32.')

        device = self.param_groups[0]["params"][0].device
        # Zero placeholders keep the final reduction valid when one dtype is absent.
        g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device)
        if len(g_all_32) > 0:
            g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
                                             self._dummy_overflow_buf,
                                             [g_all_32], False)[0]
        if len(g_all_16) > 0:
            g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
                                             self._dummy_overflow_buf,
                                             [g_all_16], False)[0]

        # sqrt(n32^2 + n16^2): blend the two partial norms into the global norm.
        return multi_tensor_applier(self.multi_tensor_l2norm,
                                    self._dummy_overflow_buf,
                                    [[g_norm_32, g_norm_16]],
                                    False)[0]

    def step(self, closure=None):
        """Performs a single optimization step.

        Gradients are first clipped in place to a global norm of at most
        ``max_grad_norm``; then every parameter receives a bias-corrected
        Adam-style update, scaled by the layer-wise trust ratio
        ``||p|| / ||update||`` (1.0 if either norm is zero).

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure`` if given, otherwise ``None``.

        Raises:
            RuntimeError: on sparse gradients or unsupported parameter dtypes.
        """
        loss = None
        if closure is not None:
            loss = closure()

        max_grad_norm = self.defaults['max_grad_norm']
        global_grad_norm = self._global_grad_norm()
        # Ratio is 1 when the global norm is already within the limit.
        clipped_ratio = max_grad_norm / max(global_grad_norm, max_grad_norm)

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue
                # Clip in place: callers sharing this grad tensor observe the change.
                p.grad.data *= clipped_ratio
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['m'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['v'] = torch.zeros_like(p.data)

                m_t, v_t = state['m'], state['v']
                state['step'] += 1

                # m_t = beta1 * m + (1 - beta1) * g_t
                m_t.mul_(beta1).add_(grad, alpha=1 - beta1)
                # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
                v_t.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Debiasing
                m_t_hat = m_t / (1.0 - beta1 ** state['step'])
                v_t_hat = v_t / (1.0 - beta2 ** state['step'])

                update = m_t_hat / v_t_hat.sqrt().add(group['eps'])

                if group['weight_decay'] != 0:
                    update.add_(p.data, alpha=group['weight_decay'])

                trust_ratio = 1.0
                w_norm = p.data.pow(2).sum().sqrt()
                g_norm = update.pow(2).sum().sqrt()
                if w_norm > 0 and g_norm > 0:
                    trust_ratio = w_norm / g_norm

                # Kept in state for inspection/debugging; not read back by step().
                state['w_norm'] = w_norm
                state['g_norm'] = g_norm
                state['trust_ratio'] = trust_ratio

                step_size = group['lr']

                p.data.add_(update, alpha=-step_size * trust_ratio)

        return loss
146
147
148
class TestFusedLAMB(unittest.TestCase):
149
def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7):
150
self.max_abs_diff = max_abs_diff
151
self.max_rel_diff = max_rel_diff
152
self.iters = iters
153
torch.cuda.manual_seed(9876)
154
155
def tearDown(self):
156
pass
157
158
def gen_param_optim(self, tensors, lamb_option):
159
ref_param = []
160
tst_param = []
161
for tensor in tensors:
162
ref_param.append(torch.nn.Parameter(tensor.clone()))
163
tst_param.append(torch.nn.Parameter(tensor.clone()))
164
165
ref_optim = RefLAMB(ref_param, **lamb_option)
166
tst_optim = apex.optimizers.FusedLAMB(tst_param, use_nvlamb=True, **lamb_option)
167
168
return (ref_param, tst_param, ref_optim, tst_optim)
169
170
def gen_grad(self, ref_param, tst_param):
171
for p_ref, p_tst in zip(ref_param, tst_param):
172
p_ref.grad = torch.rand_like(p_ref)
173
p_tst.grad = p_ref.grad
174
175
def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
176
half_grads = []
177
for p_ref, _ in zip(ref_param, tst_param):
178
half_grads.append(torch.rand_like(p_ref).half())
179
p_ref.grad = half_grads[-1].float() / scale
180
return half_grads
181
182
def get_max_diff(self, ref_param, tst_param):
183
max_abs_diff = max_rel_diff = 0
184
for p_ref, p_tst in zip(ref_param, tst_param):
185
max_abs_diff_p = (p_ref - p_tst).abs().max().item()
186
max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item()
187
188
if max_abs_diff_p > max_abs_diff: max_abs_diff = max_abs_diff_p
189
if max_rel_diff_p > max_rel_diff: max_rel_diff = max_rel_diff_p
190
191
return max_abs_diff, max_rel_diff
192
193
def gen_single_type_test(self, param_type=torch.float, device="cuda"):
194
nelem = 278011
195
tensor = torch.rand(nelem, dtype=param_type, device=device)
196
weight_decay = [0, 0.01]
197
198
for wd in weight_decay:
199
lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':wd}
200
ref_param, tst_param, ref_optim, tst_optim = \
201
self.gen_param_optim([tensor], lamb_option)
202
203
for i in range(self.iters):
204
self.gen_grad(ref_param, tst_param)
205
ref_optim.step()
206
torch.cuda.synchronize()
207
tst_optim.step()
208
torch.cuda.synchronize()
209
max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
210
211
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
212
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
213
214
def test_float(self):
215
self.gen_single_type_test(param_type=torch.float)
216
217
@unittest.skip("PyTorch optimizer is not numerically correct for fp16")
218
def test_half(self):
219
self.gen_single_type_test(param_type=torch.float16)
220
221
@unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required")
222
def test_multi_device(self):
223
devices = ("cuda:0", "cuda:1")
224
for current_dev, tensor_dev in product(devices, devices):
225
with torch.cuda.device(current_dev):
226
self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
227
228
def test_multi_params(self):
229
sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
230
weight_decay = [0, 0.01]
231
232
for wd in weight_decay:
233
lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':wd}
234
tensors = []
235
for size in sizes:
236
tensors.append(torch.rand(size, dtype=torch.float, device='cuda'))
237
ref_param, tst_param, ref_optim, tst_optim = \
238
self.gen_param_optim(tensors, lamb_option)
239
240
for i in range(self.iters):
241
self.gen_grad(ref_param, tst_param)
242
ref_optim.step()
243
tst_optim.step()
244
max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
245
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
246
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
247
248
def test_lamb_option(self):
249
nelem = 1
250
tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
251
weight_decay = [0, 0.01]
252
253
for wd in weight_decay:
254
lamb_option = {'lr':0.01, 'betas':(0.6, 0.9), 'eps':3e-06, 'weight_decay':wd}
255
ref_param, tst_param, ref_optim, tst_optim = \
256
self.gen_param_optim([tensor], lamb_option)
257
258
for i in range(self.iters):
259
self.gen_grad(ref_param, tst_param)
260
ref_optim.step()
261
tst_optim.step()
262
max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
263
264
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
265
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
266
267
268
if __name__ == '__main__':
    # Run all test cases in this module when executed as a script.
    unittest.main()
271
272