GitHub Repository: ai-forever/sber-swap
Path: blob/main/apex/csrc/multi_tensor_l2norm_kernel.cu
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "type_shim.h"
#include "multi_tensor_apply.cuh"

#define BLOCK_SIZE 512
#define ILP 4

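// BLOCK_SIZE is handed to multi_tensor_apply for the kernel launches below and
// matches the 512-entry shared-memory buffers and 512-thread cleanup launches in
// this file; ILP is the per-thread unroll factor, i.e. how many elements each
// thread loads and accumulates per loop iteration in the vectorized fast path.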
template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

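// The two helpers above implement the vectorized fast path: is_aligned checks
// that a pointer sits on an ILP-element boundary, and load_store copies ILP
// consecutive elements at once by reinterpreting them as a single
// std::aligned_storage chunk of ILP*sizeof(T) bytes, so the compiler can emit
// one wide load/store. For example, with x_t = at::Half the chunk is
// 4 * 2 = 8 bytes, i.e. four halves move in a single 8-byte access (assuming
// the source pointer passes is_aligned).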
template<typename x_t>
struct L2NormFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<1>& tl,
    float* output,
    float* output_per_tensor,
    bool per_tensor,
    int max_chunks_per_tensor)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
    x += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    __shared__ float s_vals[512];

    float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
    x_t r_x[ILP];
    for(int i = 0; i < ILP; i++)
    {
      vals[i] = 0.f;
      r_x[i] = 0;
    }

    // to make things simple, we put aligned case in a different code path
    if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x))
    {
      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
      {
        // load
        load_store(r_x, x, 0, i_start);
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          float next = static_cast<float>(r_x[ii]);
          vals[ii] += next*next;
        }
      }
    }
    else
    {
      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
      {
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            float next = static_cast<float>(x[i]);
            vals[ii] += next*next;
          }
        }
      }
    }

    float val = 0.f;
    for(int i = 0; i < ILP; i++)
      val += vals[i];

    float final = reduce_block_into_lanes(s_vals, val);

    if(threadIdx.x == 0)
    {
      if(!isfinite(final))
        *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
      output[blockIdx.x] += final;
      if(per_tensor)
        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
    }
  }
};
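// Each block handled by L2NormFunctor covers one (tensor, chunk) pair described
// by TensorListMetadata, accumulates a partial sum of squares for its chunk, and
// reduces it across the block with reduce_block_into_lanes (presumably provided
// by type_shim.h). The partial lands in output[blockIdx.x], accumulated with +=
// because multi_tensor_apply may issue several launches that reuse the same
// block indices, and, when per_tensor is set, in the per-tensor slot
// output_per_tensor[tensor * max_chunks_per_tensor + chunk]. A non-finite
// partial sets *noop_gmem so callers can detect inf/nan.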
// Probably better to template, but since we are not likely to support other
// norms, the max (L-inf) norm gets its own functor.
template<typename x_t>
struct MaxNormFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<1>& tl,
    float* output,
    float* output_per_tensor,
    bool per_tensor,
    int max_chunks_per_tensor)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
    x += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    __shared__ float s_vals[512];

    float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
    x_t r_x[ILP];
    for(int i = 0; i < ILP; i++)
    {
      vals[i] = 0.f;
      r_x[i] = 0;
    }

    // to make things simple, we put aligned case in a different code path
    if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x))
    {
      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
      {
        // load
        load_store(r_x, x, 0, i_start);
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          float next = static_cast<float>(r_x[ii]);
          vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
        }
      }
    }
    else
    {
      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
      {
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            float next = static_cast<float>(x[i]);
            vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
          }
        }
      }
    }

    float val = 0.f;
    for(int i = 0; i < ILP; i++)
      val = fmaxf(fabsf(val), fabsf(vals[i]));

    float final = reduce_block_into_lanes_max_op(s_vals, val);

    if(threadIdx.x == 0)
    {
      if(!isfinite(final))
        *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
      output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final));
      if(per_tensor)
        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
    }
  }
};
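// MaxNormFunctor mirrors L2NormFunctor but keeps a running maximum of absolute
// values instead of a sum of squares, so output[blockIdx.x] ends up holding a
// partial L-inf norm rather than a partial sum of squares.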

__global__ void cleanup(
  float* output,
  float* output_per_tensor,
  float* ret,
  float* ret_per_tensor,
  bool per_tensor,
  int max_chunks_per_tensor)
{
  __shared__ float vals[512];

  if(blockIdx.x == 0)
  {
    float val = 0;
    if(threadIdx.x < 320)
      val = output[threadIdx.x];

    float final = reduce_block_into_lanes(vals, val);

    if(threadIdx.x == 0)
      *ret = sqrt(final);
  }

  if(per_tensor)
  {
    float* output_this_tensor = output_per_tensor + blockIdx.x*max_chunks_per_tensor;

    float val = 0;
    for(int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
      val += output_this_tensor[i];

    float final = reduce_block_into_lanes(vals, val);

    if(threadIdx.x == 0)
      ret_per_tensor[blockIdx.x] = sqrt(final);
  }
}
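// cleanup is the second stage of the reduction: block 0 sums the (up to 320)
// per-block partials written by L2NormFunctor and takes the square root to get
// the overall L2 norm, while each block also sums the chunk partials of its own
// tensor to produce the per-tensor norms. The 320-element read matches the size
// of the `output` buffer allocated in multi_tensor_l2norm_cuda below, which is
// presumably an upper bound on the number of blocks multi_tensor_apply launches
// per kernel.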
__global__ void cleanup_v2(
  float* output,
  float* output_per_tensor,
  float* ret,
  float* ret_per_tensor,
  bool per_tensor,
  int max_chunks_per_tensor,
  int norm_type,
  float alpha,
  float beta)
{
  __shared__ float vals[512];

  if(blockIdx.x == 0)
  {
    float val = 0;
    if(threadIdx.x < 320)
      val = output[threadIdx.x];

    if (norm_type == 0) {
      float final = reduce_block_into_lanes_max_op(vals, val);
      if(threadIdx.x == 0)
        *ret = alpha * (*ret) + beta * final;
    }
    else {
      float final = reduce_block_into_lanes(vals, val);
      if(threadIdx.x == 0)
        *ret = sqrt(alpha * (*ret) * (*ret) + beta * final);
    }
  }

  if(per_tensor)
  {
    float* output_this_tensor = output_per_tensor + blockIdx.x*max_chunks_per_tensor;

    if (norm_type == 0) {
      float val = 0;
      for(int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
        val = fmaxf(fabsf(val), fabsf(output_this_tensor[i]));

      float final = reduce_block_into_lanes_max_op(vals, val);

      if(threadIdx.x == 0)
        ret_per_tensor[blockIdx.x] = alpha * ret_per_tensor[blockIdx.x] + beta * final;
    }
    else {
      float val = 0;
      for(int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
        val += output_this_tensor[i];

      float final = reduce_block_into_lanes(vals, val);

      if(threadIdx.x == 0)
        ret_per_tensor[blockIdx.x] = sqrt(alpha * ret_per_tensor[blockIdx.x] * ret_per_tensor[blockIdx.x] + beta * final);
    }
  }
}
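// cleanup_v2 performs the same second-stage reduction but blends the fresh
// partials into an existing norm: for norm_type == 0 (L-inf) it computes
// alpha * old + beta * new_max, otherwise it treats the partials as sums of
// squares and computes sqrt(alpha * old^2 + beta * new_sum_sq), matching the
// formulas documented above multi_tensor_norm_out_cuda further down.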
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::optional<bool> per_tensor_python)
{
  bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;

  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
  auto output = at::zeros({320}, float_options);

  at::Tensor output_per_tensor;
  at::Tensor ret_per_tensor;

  int ntensors = tensor_lists[0].size();
  int max_chunks_per_tensor = -1;

  if(per_tensor)
  {
    for(int t = 0; t < ntensors; t++)
    {
      int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
      if(max_chunks_this_tensor > max_chunks_per_tensor)
        max_chunks_per_tensor = max_chunks_this_tensor;
    }
    output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options);
    ret_per_tensor = at::empty({ntensors}, float_options);
  }
  else
  {
    ret_per_tensor = at::empty({0}, float_options);
  }

  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
    multi_tensor_apply<1>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      L2NormFunctor<scalar_t_0>(),
      output.DATA_PTR<float>(),
      per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
      per_tensor,
      max_chunks_per_tensor);)

  AT_CUDA_CHECK(cudaGetLastError());
  // AT_CUDA_CHECK(cudaDeviceSynchronize());

  // This involves one more small kernel launch, but it will be negligible end to end.
  // I could get rid of it by hacking the functor + multi tensor harness with persistence
  // logic, but I'm keeping it simple for now.
  auto ret = at::empty({1}, output.options());
  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
  auto stream = at::cuda::getCurrentCUDAStream();
  cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
    output.DATA_PTR<float>(),
    per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
    ret.DATA_PTR<float>(),
    per_tensor ? ret_per_tensor.DATA_PTR<float>() : nullptr,
    per_tensor,
    max_chunks_per_tensor);

  return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
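// A minimal host-side sketch of how this entry point might be called (the
// variable names are hypothetical; apex's Python callers typically go through
// multi_tensor_applier with a chunk size of 2048*32, but any positive chunk
// size is accepted):
//
//   std::vector<at::Tensor> params = ...;  // CUDA tensors, float or half
//   auto noop_flag = at::zeros({1}, params[0].options().dtype(at::kInt));
//   std::vector<std::vector<at::Tensor>> lists = {params};
//   at::Tensor total_norm, per_tensor_norms;
//   std::tie(total_norm, per_tensor_norms) =
//       multi_tensor_l2norm_cuda(2048*32, noop_flag, lists, true);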

// Compute and update grad norm.
// Here we use a per-tensor norm and blend the new norm (n) into the old norm (gn) by:
//   L-2:   gn = sqrt(a * gn^2 + b * n^2)
//   L-inf: gn = a * gn + b * n
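// As a concrete check of the L-2 rule with a = b = 1: a tensor whose old norm
// gn is 3 and whose new norm n is 4 ends up with gn = sqrt(1*3^2 + 1*4^2) = 5.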
void multi_tensor_norm_out_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor out,
  const float alpha,
  const float beta,
  const int norm_type)
{
  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
  TORCH_CHECK(tensor_lists[0][0].device() == noop_flag.device(), "noop flag should be on the same device as tensors");
  // The global norm is not needed here, so empty is used instead of zeros.
  auto output = at::empty({320}, float_options);

  at::Tensor output_per_tensor;
  at::Tensor ret_per_tensor;

  int ntensors = tensor_lists[0].size();
  int max_chunks_per_tensor = -1;

  for(int t = 0; t < ntensors; t++)
  {
    int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
    if(max_chunks_this_tensor > max_chunks_per_tensor)
      max_chunks_per_tensor = max_chunks_this_tensor;
  }

  // Although each slot is written once and then read, the buffer still needs to be
  // zero-initialized, since the trailing elements also participate in the cleanup.
  output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options);

  if (norm_type == 0) {
    DISPATCH_FLOAT_AND_HALF(
      tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda",
      multi_tensor_apply<1>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        MaxNormFunctor<scalar_t_0>(),
        output.DATA_PTR<float>(),
        output_per_tensor.DATA_PTR<float>(),
        true,
        max_chunks_per_tensor);)
  }
  else {
    DISPATCH_FLOAT_AND_HALF(
      tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
      multi_tensor_apply<1>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        L2NormFunctor<scalar_t_0>(),
        output.DATA_PTR<float>(),
        output_per_tensor.DATA_PTR<float>(),
        true,
        max_chunks_per_tensor);)
  }
  AT_CUDA_CHECK(cudaGetLastError());

  // AT_CUDA_CHECK(cudaDeviceSynchronize());

  // This involves one more small kernel launch, but it will be negligible end to end.
  // I could get rid of it by hacking the functor + multi tensor harness with persistence
  // logic, but I'm keeping it simple for now.
  auto ret = at::empty({1}, output.options());

  // Adding the following device guard since it happens sometimes that the
  // tensors are on one device and the cuda stream is on another device, which
  // results in an ILLEGAL MEM ACCESS error.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
  auto stream = at::cuda::getCurrentCUDAStream();
  cleanup_v2<<<ntensors, 512, 0, stream>>>(
    output.DATA_PTR<float>(),
    output_per_tensor.DATA_PTR<float>(),
    ret.DATA_PTR<float>(),
    out.DATA_PTR<float>(),
    true,
    max_chunks_per_tensor,
    norm_type,
    alpha,
    beta);

  return;
}
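// Note that multi_tensor_norm_out_cuda returns void: the blended per-tensor
// norms are written directly into `out`, which cleanup_v2 receives as its
// ret_per_tensor argument, while the scalar `ret` buffer only satisfies the
// kernel's signature and its contents are not used afterwards.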