GitHub Repository: ai-forever/sber-swap
Path: blob/main/apex/csrc/multi_tensor_lamb.cu
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "type_shim.h"
#include "multi_tensor_apply.cuh"

#define BLOCK_SIZE 512
#define ILP 4

template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
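// load_store() copies ILP consecutive elements of T as one ILP*sizeof(T)-byte
// aligned chunk (8 bytes for half, 16 bytes for float with ILP=4), so the aligned
// fast paths below can move ILP elements with a single wide memory access.
// is_aligned() guards those paths: they are taken only when every pointer is
// aligned to that chunk size.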
typedef enum{
  MOMENT_MODE_0 = 0, // L2 regularization mode
  MOMENT_MODE_1 = 1  // Decoupled weight decay mode
} adamMode_t;
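// MOMENT_MODE_0 folds decay*p into the (clipped) gradient before the moment
// updates, i.e. classic L2 regularization; MOMENT_MODE_1 adds decay*p to the
// Adam step afterwards, i.e. decoupled weight decay (AdamW-style). Both branches
// appear in LAMBStage1Functor below.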
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::optional<bool> per_tensor_python);
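// Forward declaration; the norm kernel is defined in a separate translation unit
// of this extension. multi_tensor_lamb_cuda() below only uses std::get<1>() of
// the returned tuple, i.e. the per-tensor L2 norms requested via per_tensor_python.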
using MATH_T = float;
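// All intermediate LAMB arithmetic below is done in fp32 (MATH_T), even when T is
// half; values are converted to MATH_T on load and back to T on store.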
template<typename T>
struct LAMBStage1Functor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<4>& tl,
    const float beta1,
    const float beta2,
    const float beta3,
    const float beta1_correction,
    const float beta2_correction,
    const float epsilon,
    adamMode_t mode,
    const float decay,
    const float* global_grad_norm,
    const float max_global_grad_norm)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    float clipped_global_grad_norm = (*global_grad_norm) > max_global_grad_norm ? (*global_grad_norm) / max_global_grad_norm : 1.0f;

    T* g = (T*)tl.addresses[0][tensor_loc];
    g += chunk_idx*chunk_size;

    T* p = (T*)tl.addresses[1][tensor_loc];
    p += chunk_idx*chunk_size;

    T* m = (T*)tl.addresses[2][tensor_loc];
    m += chunk_idx*chunk_size;

    T* v = (T*)tl.addresses[3][tensor_loc];
    v += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    MATH_T r_g[ILP];
    MATH_T r_p[ILP];
    MATH_T r_m[ILP];
    MATH_T r_v[ILP];
    // to make things simple, we put the aligned case in a different code path
    if(n % ILP == 0 &&
       chunk_size % ILP == 0 &&
       is_aligned(g) &&
       is_aligned(p) &&
       is_aligned(m) &&
       is_aligned(v))
    {
      T l_g[ILP];
      T l_p[ILP];
      T l_m[ILP];
      T l_v[ILP];
      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
      {
        // load
        load_store(l_g, g, 0, i_start);
        if (decay != 0)
          load_store(l_p, p, 0, i_start);
        load_store(l_m, m, 0, i_start);
        load_store(l_v, v, 0, i_start);
        // unpack
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          r_g[ii] = l_g[ii];
          if (decay == 0) {
            r_p[ii] = MATH_T(0);
          }
          else {
            r_p[ii] = l_p[ii];
          }
          r_m[ii] = l_m[ii];
          r_v[ii] = l_v[ii];
        }
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          if (mode == MOMENT_MODE_0) {
            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
            // L2 on scaled grad
            scaled_grad = scaled_grad + decay*r_p[ii];
            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            r_p[ii] = next_m_unbiased / denom;
          }
          else {
            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
          }
        }
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          l_p[ii] = r_p[ii];
          l_m[ii] = r_m[ii];
          l_v[ii] = r_v[ii];
        }
        // store
        load_store(g, l_p, i_start, 0);
        load_store(m, l_m, i_start, 0);
        load_store(v, l_v, i_start, 0);
      }
    }
    else
    {
      // see note in multi_tensor_scale_kernel.cu
      for(int i_start = 0;
          i_start < n && i_start < chunk_size;
          i_start += blockDim.x*ILP)
      {
        MATH_T r_g[ILP];
        MATH_T r_p[ILP];
        MATH_T r_m[ILP];
        MATH_T r_v[ILP];
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            r_g[ii] = g[i];
            // special ?optimization? for lamb stage 1
            if (decay == 0) {
              r_p[ii] = MATH_T(0);
            }
            else {
              r_p[ii] = p[i];
            }
            r_m[ii] = m[i];
            r_v[ii] = v[i];
          } else {
            r_g[ii] = MATH_T(0);
            r_p[ii] = MATH_T(0);
            r_m[ii] = MATH_T(0);
            r_v[ii] = MATH_T(0);
          }
        }
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          if (mode == MOMENT_MODE_0) {
            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
            // L2 on scaled grad
            scaled_grad = scaled_grad + decay*r_p[ii];
            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            r_p[ii] = next_m_unbiased / denom;
          }
          else {
            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
          }
        }
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            g[i] = r_p[ii];
            m[i] = r_m[ii];
            v[i] = r_v[ii];
          }
        }
      }
    }
  }
};
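// Stage 1 leaves the parameters untouched and instead overwrites each gradient,
// in place and in fp32, with the Adam-style update direction:
//   scaled_grad = g / clipped_global_grad_norm
//   (MOMENT_MODE_0: scaled_grad += decay * p)
//   m = beta1 * m + beta3 * scaled_grad
//   v = beta2 * v + (1 - beta2) * scaled_grad^2
//   update = (m / beta1_correction) / (sqrt(v / beta2_correction) + epsilon)
//   (MOMENT_MODE_1: update += decay * p)
// g receives the update; m and v receive the new moments.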
// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
// It computes new parameter value.
template<typename T>
struct LAMBStage2Functor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<2>& tl,
    const float* per_tensor_param_norm,
    const float* per_tensor_update_norm,
    const float learning_rate,
    const float decay,
    bool use_nvlamb)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    MATH_T ratio = learning_rate;
    // nvlamb: apply adaptive learning rate to all parameters
    // otherwise, only apply to those with non-zero weight decay
    if (use_nvlamb || (decay != 0.0))
    {
      float param_norm = per_tensor_param_norm[tensor_num];
      float update_norm = per_tensor_update_norm[tensor_num];
      ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
    }

    T* update = (T*)tl.addresses[0][tensor_loc];
    update += chunk_idx*chunk_size;

    T* p = (T*)tl.addresses[1][tensor_loc];
    p += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    // to make things simple, we put the aligned case in a different code path
    if(n % ILP == 0 &&
       chunk_size % ILP == 0 &&
       is_aligned(p) &&
       is_aligned(update))
    {
      T r_p[ILP];
      T r_update[ILP];
      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
      {
        // load
        load_store(r_p, p, 0, i_start);
        load_store(r_update, update, 0, i_start);
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * static_cast<MATH_T>(r_update[ii]));
        }
        load_store(p, r_p, i_start, 0);
      }
    }
    else
    {
      for(int i_start = 0;
          i_start < n && i_start < chunk_size;
          i_start += blockDim.x*ILP)
      {
        MATH_T r_p[ILP];
        MATH_T r_update[ILP];
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            r_p[ii] = p[i];
            r_update[ii] = update[i];
          }
        }
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          r_p[ii] = r_p[ii] - (ratio * r_update[ii]);
        }
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            p[i] = r_p[ii];
          }
        }
      }
    }
  }
};
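// Stage 2 applies the LAMB trust ratio per tensor: when both norms are non-zero,
//   ratio = learning_rate * (||p|| / ||update||), otherwise just learning_rate,
// and then takes the step p <- p - ratio * update. Unless use_nvlamb is set, the
// adaptive ratio is only used for tensors with non-zero weight decay; all other
// tensors take a plain lr-scaled step.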
void multi_tensor_lamb_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  const float lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  const int step,
  const int bias_correction,
  const float weight_decay,
  const int grad_averaging,
  const int mode,
  at::Tensor global_grad_norm,
  const float max_grad_norm,
  at::optional<bool> use_nvlamb_python)
{
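  // tensor_lists layout, matching how the functors above index tl.addresses:
  //   tensor_lists[0] = gradients (overwritten with the update by stage 1)
  //   tensor_lists[1] = parameters
  //   tensor_lists[2] = first moments (m)
  //   tensor_lists[3] = second moments (v)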
  using namespace at;
  // Master weights and 32-bit momentum (potentially changing) are not handled here,
  // so we assume every tensor is of the same type.

  bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;

  // Handle bias correction mode
  float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
  if (bias_correction == 1) {
    bias_correction1 = 1 - std::pow(beta1, step);
    bias_correction2 = 1 - std::pow(beta2, step);
  }

  // Handle grad averaging mode
  float beta3 = 1.0f;
  if (grad_averaging == 1) beta3 = 1 - beta1;

  std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin()+1);
  std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2);

  // Compute per-tensor param norm
  auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);

  // We modify grad in place to store the update before computing its norm.
  // Generally this is not an issue, since people modify grad in the step() method all the time.
  // We could also grab a list of empty tensors to avoid this, but I'd like to save space/cpu code.
  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
      multi_tensor_apply<4>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        LAMBStage1Functor<scalar_t_0>(),
        beta1,
        beta2,
        beta3, // 1-beta1 or 1, depending on averaging mode
        bias_correction1,
        bias_correction2,
        epsilon,
        (adamMode_t) mode,
        weight_decay,
        global_grad_norm.DATA_PTR<float>(),
        max_grad_norm); )

  // Compute update norms
  auto update_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, true);

  std::vector<std::vector<at::Tensor>> grad_param_list(tensor_lists.begin(), tensor_lists.begin()+2);

  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
      multi_tensor_apply<2>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        grad_param_list,
        LAMBStage2Functor<scalar_t_0>(),
        std::get<1>(param_norm_tuple).DATA_PTR<float>(),
        std::get<1>(update_norm_tuple).DATA_PTR<float>(),
        lr,
        weight_decay,
        use_nvlamb); )

  AT_CUDA_CHECK(cudaGetLastError());
}
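// Note: in upstream NVIDIA Apex this function is exposed to Python through the
// extension's pybind11 frontend (typically as amp_C.multi_tensor_lamb) and driven
// by apex.optimizers.FusedLAMB via multi_tensor_applier; the same is assumed to
// hold for this vendored copy.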