GitHub Repository: ai-forever/sber-swap
Path: blob/main/apex/csrc/welford.cu
1
#include <iostream>
2
#include <ATen/ATen.h>
3
#include <ATen/AccumulateType.h>
4
#include <ATen/cuda/CUDAContext.h>
5
6
#include <cuda.h>
7
#include <cuda_runtime.h>
8
9
#include <vector>
10
11
#include "type_shim.h"
12
#include "compat.h"
13
14
15
__device__ __forceinline__ int lastpow2(int n)
16
{
17
int out = 1 << (31 - __clz(n));
18
if(n == out)
19
out >>= 1;
20
return out;
21
}
22
23
__host__ __forceinline__ int h_next_pow2(unsigned int n) {
24
n--;
25
n |= (n >> 1);
26
n |= (n >> 2);
27
n |= (n >> 4);
28
n |= (n >> 8);
29
n |= (n >> 16);
30
return ++n;
31
}
32
33
__host__ __forceinline__ int h_last_pow2(unsigned int n) {
34
n |= (n >> 1);
35
n |= (n >> 2);
36
n |= (n >> 4);
37
n |= (n >> 8);
38
n |= (n >> 16);
39
return n - (n >> 1);
40
}
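// Note: h_next_pow2 rounds up to the next power of two (e.g. 300 -> 512), h_last_pow2
// rounds down to the largest power of two <= n (e.g. 300 -> 256), and the device-side
// lastpow2 returns the largest power of two strictly below n (e.g. 100 -> 64, 64 -> 32).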
41
42
43
#define WARP_SIZE 32
44
45
template<typename T>
46
__device__ __forceinline__ T warp_reduce_sum(T val)
47
{
48
#pragma unroll
49
for(int i = WARP_SIZE/2; i > 0; i >>= 1)
50
val = val + __shfl_down_sync(0xffffffff, val, i);
51
return val;
52
}
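// Note: each __shfl_down_sync halves the number of contributing lanes (offsets 16, 8, 4,
// 2, 1), so after the loop lane 0 of every warp holds that warp's sum; callers below only
// consume the value in lane 0.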
53
54
template<typename T>
55
__device__ __forceinline__ T reduce_block(T *x, T val)
56
{
57
int tid = threadIdx.y*blockDim.x + threadIdx.x;
58
int blockSize = blockDim.x * blockDim.y;
59
60
if (blockSize > 32) {
61
val = warp_reduce_sum(val);
62
if (tid % WARP_SIZE == 0)
63
x[tid/WARP_SIZE] = val;
64
65
__syncthreads();
66
67
val = (tid < blockSize / WARP_SIZE? x[tid%WARP_SIZE] : T(0));
68
}
69
70
if(tid/WARP_SIZE==0) val = warp_reduce_sum(val);
71
72
return val;
73
}
74
75
#define ELEMENTS_PER_ITER 4 // enables concurrency within each thread to hide latency
76
#define ELEMENTS_PER_THREAD 16
77
#define OPTIMAL_TILE_W 32
78
#define MAX_H_BLOCK 128
79
#define MAX_BLOCK_SIZE 512
80
81
__host__ int div_ru(int x, int y) {
82
return h_last_pow2(1 + (x-1)/y);
83
}
84
85
__host__ void flexible_launch_configs(
86
const int reduction,
87
const int stride,
88
dim3 &block,
89
dim3 &grid,
90
const bool coop_flag = false) {
91
int block_x = std::min(h_last_pow2(stride), OPTIMAL_TILE_W);
92
int block_y = std::min(h_last_pow2(div_ru(reduction , ELEMENTS_PER_THREAD)),
93
MAX_BLOCK_SIZE / block_x);
94
if (block_x * block_y != MAX_BLOCK_SIZE) {
95
block_x = std::min(h_last_pow2(stride), MAX_BLOCK_SIZE / block_y);
96
}
97
98
int grid_x = div_ru(stride, block_x);
99
int grid_y = std::min(div_ru(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK);
100
if (coop_flag) {
101
// it's not worth having a grid reduction if the reduction dimension is not big enough
102
grid_y = grid_y < 8 ? 1 : grid_y;
103
}
104
105
block.x = block_x;
106
block.y = block_y;
107
block.z = 1;
108
grid.x = grid_x;
109
grid.y = grid_y;
110
grid.z = 1;
111
}
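// Worked example (illustrative): div_ru is a rounding-up division snapped down to a power
// of two, i.e. h_last_pow2(ceil(x / y)). With reduction = 8192 and stride = 64:
//   block_x = min(h_last_pow2(64), 32) = 32,  block_y = min(512, 512 / 32) = 16
//   grid_x  = div_ru(64, 32) = 2,             grid_y  = min(div_ru(8192, 16 * 16), 128) = 32
// i.e. block = (32, 16) and grid = (2, 32) before the optional coop_flag adjustment.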
112
113
template<typename T, typename C>
114
__device__ __forceinline__ void welford_merge_element(C& count,
115
T& mean,
116
T& m2n,
117
const C& num_new,
118
const T& mean_new,
119
const T& m2n_new) {
120
T factor = T(1.0) / max(1, (count + num_new));
121
T delta0 = mean - mean_new;
122
mean = (mean_new * num_new + mean * count) * factor;
123
m2n += m2n_new + delta0 * delta0 * num_new * count * factor;
124
count += num_new;
125
}
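// Note: this is the standard parallel Welford (Chan-style) merge of two partial results
// (count, mean, m2n) and (num_new, mean_new, m2n_new):
//   n     = count + num_new
//   delta = mean - mean_new
//   mean  = (count * mean + num_new * mean_new) / n
//   m2n   = m2n + m2n_new + delta^2 * count * num_new / n
// The max(1, ...) guard only protects against a division by zero when both counts are zero.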
126
127
template<typename T>
128
__device__ __forceinline__ void warp_reduce_mean_m2n(T &mean, T &m2n, int &num)
129
{
130
#pragma unroll
131
for(int i = WARP_SIZE/2; i > 0; i >>= 1) {
132
auto num_new = __shfl_down_sync(0xffffffff, num, i);
133
auto mean_new = __shfl_down_sync(0xffffffff, mean, i);
134
auto m2n_new = __shfl_down_sync(0xffffffff, m2n, i);
135
welford_merge_element(num, mean, m2n, num_new, mean_new, m2n_new);
136
}
137
}
138
139
template <typename T>
140
__device__ void welford_reduce_mean_m2n(
141
T* __restrict__ x,
142
int* __restrict__ count,
143
T &mean,
144
T &m2n,
145
int &num,
146
int block_size,
147
int thread_id)
148
{
149
int lane = thread_id % WARP_SIZE;
150
int wid = thread_id / WARP_SIZE;
151
152
if (block_size > 32) {
153
warp_reduce_mean_m2n(mean, m2n, num);
154
if (lane == 0) {
155
x[wid*2] = mean;
156
x[wid*2+1] = m2n;
157
count[wid] = num;
158
}
159
__syncthreads();
160
161
if (wid == 0) {
162
mean = (thread_id < block_size / WARP_SIZE)? x[lane*2] : T(0);
163
m2n = (thread_id < block_size / WARP_SIZE)? x[lane*2+1] : T(0);
164
num = (thread_id < block_size / WARP_SIZE)? count[lane] : int(0);
165
}
166
}
167
168
if (wid==0) warp_reduce_mean_m2n(mean, m2n, num);
169
170
return;
171
}
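// Note: welford_reduce_mean_m2n first merges within each warp, parks the per-warp
// (mean, m2n, count) triples in shared memory (x holds interleaved mean/m2n pairs, count
// holds the sample counts), and then lets warp 0 merge the per-warp partials.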
172
173
// return spatial size (product of all dims after N and C) for NC+ (channels-first) tensors
174
__host__ int get_tensor_spatial_size(const at::Tensor& input)
175
{
176
auto space_size = input.size(2);
177
for (int i = 3; i < input.ndimension(); i++) {
178
space_size *= input.size(i);
179
}
180
return space_size;
181
}
182
183
// promote accumulation scalar type. promote half to float.
184
__host__ at::ScalarType promote_scalartype(const at::Tensor& input)
185
{
186
return input.scalar_type() == at::ScalarType::Half ?
187
at::ScalarType::Float : input.scalar_type();
188
}
189
190
// return single element size, optional accumulation type promotion.
191
__host__ size_t get_element_data_size(const at::Tensor& input, bool accumulation = false)
192
{
193
auto scalar_type = accumulation ? promote_scalartype(input) : input.scalar_type();
194
return at::elementSize(scalar_type);
195
}
196
197
template<typename T, typename C>
198
__device__ __forceinline__ void welford_merge_block_vertical(C& count,
199
T& mean,
200
T& m2n,
201
C* shmem_count,
202
T* shmem_mean,
203
T* shmem_m2n) {
204
// write to shared memory
205
auto address_base = threadIdx.x + threadIdx.y * blockDim.x;
206
shmem_mean[address_base] = mean;
207
shmem_m2n[address_base] = m2n;
208
shmem_count[address_base] = count;
209
210
#pragma unroll
211
for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
212
__syncthreads();
213
if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
214
auto address = address_base + offset * blockDim.x;
215
// read shared memory back to register for reduction
216
auto num_new = shmem_count[address];
217
auto mean_new = shmem_mean[address];
218
auto m2n_new = shmem_m2n[address];
219
220
welford_merge_element(count, mean, m2n, num_new, mean_new, m2n_new);
221
222
// last write is not necessary
223
shmem_mean[address_base] = mean;
224
shmem_m2n[address_base] = m2n;
225
shmem_count[address_base] = count;
226
}
227
}
228
}
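// Note: this is a tree reduction along threadIdx.y in shared memory; the number of active
// rows is halved each step, so row 0 ends up holding the block-wide (count, mean, m2n)
// for its column, i.e. for the feature handled by threadIdx.x.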
229
230
template<typename T>
231
__device__ __forceinline__ void merge_block_vertical(T& sum_dy,
232
T& sum_dy_xmu,
233
T* shmem_sum_dy,
234
T* shmem_sum_dy_xmu) {
235
// write to shared memory
236
auto address_base = threadIdx.x + threadIdx.y * blockDim.x;
237
shmem_sum_dy[address_base] = sum_dy;
238
shmem_sum_dy_xmu[address_base] = sum_dy_xmu;
239
240
#pragma unroll
241
for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
242
__syncthreads();
243
if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
244
auto address = address_base + offset * blockDim.x;
245
246
sum_dy += shmem_sum_dy[address];
247
sum_dy_xmu += shmem_sum_dy_xmu[address];
248
249
// last write is not necessary
250
shmem_sum_dy[address_base] = sum_dy;
251
shmem_sum_dy_xmu[address_base] = sum_dy_xmu;
252
}
253
}
254
}
255
256
257
// welford kernel calculating mean/biased_variance/unbiased_variance
258
template <typename scalar_t, typename accscalar_t, typename outscalar_t>
259
__global__ void welford_kernel(
260
const scalar_t* __restrict__ input,
261
outscalar_t* __restrict__ out_mean,
262
outscalar_t* __restrict__ out_var_biased,
263
const int bs,
264
const int fs,
265
const int ss) {
266
int block_size = blockDim.x * blockDim.y;
267
int count = 0;
268
accscalar_t x_mean = accscalar_t(0);
269
accscalar_t m_2_n = accscalar_t(0);
270
271
int thread_id = threadIdx.y*blockDim.x + threadIdx.x;
272
273
for (int batch_id = threadIdx.y; batch_id < bs; batch_id += blockDim.y) {
274
int input_base = blockIdx.x*ss + batch_id*ss*fs;
275
// sequential welford
276
for (int offset = threadIdx.x; offset < ss ; offset += blockDim.x) {
277
count++;
278
auto x_n = static_cast<accscalar_t>(input[offset+input_base]);
279
auto d = x_n - x_mean;
280
x_mean += d / count;
281
m_2_n += d * (x_n - x_mean);
282
}
283
}
284
285
static __shared__ int s_mem[160];
286
accscalar_t* s_mem_ac = (accscalar_t*) &s_mem[32];
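// Note: the 160-int static shared buffer backs both pointers above: the first 32 ints are
// reserved for per-warp sample counts (at most blockDim.x * blockDim.y / 32 = 16 are used),
// and the storage from s_mem[32] onward is reinterpreted as accscalar_t slots for the
// per-warp (mean, m2n) pairs consumed by welford_reduce_mean_m2n.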
287
288
welford_reduce_mean_m2n<accscalar_t>(s_mem_ac, s_mem, x_mean, m_2_n, count, block_size, thread_id);
289
290
if (thread_id == 0) {
291
out_mean[blockIdx.x] = static_cast<outscalar_t>(x_mean);
292
out_var_biased[blockIdx.x] = static_cast<outscalar_t>(m_2_n/count);
293
}
294
}
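// Note: the per-thread loop in welford_kernel is the classic sequential Welford
// recurrence: for each new sample x,
//   count += 1;  delta = x - mean;  mean += delta / count;  m2n += delta * (x - mean);
// so mean is the running mean and m2n / count the biased variance of the samples seen so far.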
295
296
// elementwise BN kernel
297
template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
298
__global__ void batchnorm_forward_kernel(
299
const scalar_t* __restrict__ input,
300
const accscalar_t* __restrict__ mean,
301
const accscalar_t* __restrict__ inv_std,
302
const layerscalar_t* __restrict__ weight,
303
const layerscalar_t* __restrict__ shift,
304
scalar_t* __restrict__ out,
305
const int ss,
306
const int bs) {
307
auto m_c = mean[blockIdx.x];
308
auto inv_std_c = inv_std[blockIdx.x];
309
auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[blockIdx.x]);
310
auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[blockIdx.x]);
311
312
for (int batch_offset = blockIdx.y*blockDim.y + threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
313
int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss;
314
for (int offset = threadIdx.x + blockIdx.z*blockDim.x; offset < ss ; offset+= gridDim.z*blockDim.x) {
315
out[address_base+offset] = static_cast<scalar_t>(w_c * (static_cast<accscalar_t>(input[address_base+offset]) - m_c ) * inv_std_c + s_c);
316
}
317
}
318
}
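// Note: batchnorm_forward_kernel applies the affine normalization
//   out = weight * (input - mean) * inv_std + shift
// per feature (blockIdx.x), defaulting to weight = 1 and shift = 0 when the pointers are NULL.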
319
320
// Backward BN kernel: computes grad_bias and grad_weight as well as the intermediate
// results needed to compute grad_input.
// The grad_input computation is split into two steps to support sync BN, which requires
// an all-reduce of the intermediate results across processes.
324
template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
325
__global__ void reduce_bn_kernel(
326
const scalar_t* __restrict__ input,
327
const scalar_t* __restrict__ grad_output,
328
const accscalar_t* __restrict__ mean,
329
const accscalar_t* __restrict__ inv_std,
330
accscalar_t* __restrict__ sum_dy_o,
331
accscalar_t* __restrict__ sum_dy_xmu_o,
332
layerscalar_t* __restrict__ grad_weight,
333
layerscalar_t* __restrict__ grad_bias,
334
const int bs,
335
const int fs,
336
const int ss) {
337
static __shared__ int s_mem[64];
338
//int total_item_num = bs * ss;
339
340
int thread_id = threadIdx.y*blockDim.x + threadIdx.x;
341
342
auto r_mean = mean[blockIdx.x];
343
auto factor = inv_std[blockIdx.x];
344
345
// Kahan sum
346
accscalar_t sum_dy = 0.0;
347
accscalar_t sum_dy_xmu = 0.0;
348
accscalar_t sum_dy_c = 0.0;
349
accscalar_t sum_dy_xmu_c = 0.0;
350
for (int batch_id = threadIdx.y; batch_id < bs; batch_id += blockDim.y) {
351
int input_base = blockIdx.x*ss + batch_id*ss*fs;
352
for (int offset = threadIdx.x; offset < ss ; offset += blockDim.x) {
353
auto e_grad = static_cast<accscalar_t>(grad_output[offset+input_base]);
354
auto e_input = static_cast<accscalar_t>(input[offset+input_base]);
355
// calculating sum_dy
356
auto sum_dy_y = e_grad - sum_dy_c;
357
auto sum_dy_t = sum_dy + sum_dy_y;
358
sum_dy_c = (sum_dy_t - sum_dy) - sum_dy_y;
359
sum_dy = sum_dy_t;
360
361
// calculating sum_dy_xmu
362
auto sum_dy_xmu_y = e_grad * (e_input - r_mean) - sum_dy_xmu_c;
363
auto sum_dy_xmu_t = sum_dy_xmu + sum_dy_xmu_y;
364
sum_dy_xmu_c = (sum_dy_xmu_t - sum_dy_xmu) - sum_dy_xmu_y;
365
sum_dy_xmu = sum_dy_xmu_t;
366
}
367
}
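// Note: the loop above uses Kahan (compensated) summation; for each term v,
//   y = v - c;  t = sum + y;  c = (t - sum) - y;  sum = t;
// where c carries the low-order bits lost in the addition, keeping the long accscalar_t
// reductions of sum_dy and sum_dy_xmu accurate.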
368
369
sum_dy = reduce_block((accscalar_t*)s_mem, sum_dy);
370
__syncthreads();
371
sum_dy_xmu = reduce_block((accscalar_t*)s_mem, sum_dy_xmu);
372
373
if (thread_id == 0) {
374
if (grad_bias != NULL) {
375
grad_bias[blockIdx.x] = static_cast<layerscalar_t>(sum_dy);
376
}
377
if (grad_weight != NULL) {
378
grad_weight[blockIdx.x] = static_cast<layerscalar_t>(sum_dy_xmu * factor);
379
}
380
//mean_dy[blockIdx.x] = sum_dy / total_item_num;
381
//mean_dy_xmu[blockIdx.x] = sum_dy_xmu / total_item_num;
382
sum_dy_o[blockIdx.x] = sum_dy;
383
sum_dy_xmu_o[blockIdx.x] = sum_dy_xmu;
384
}
385
}
386
387
// elementwise backward BN kernel
388
template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
389
__global__ void batchnorm_backward_kernel(
390
const scalar_t* __restrict__ grad_output,
391
const scalar_t* __restrict__ input,
392
const accscalar_t* __restrict__ mean,
393
const accscalar_t* __restrict__ inv_std,
394
const layerscalar_t* __restrict__ weight,
395
const accscalar_t* __restrict__ sum_dy,
396
const accscalar_t* __restrict__ sum_dy_xmu,
397
const int* __restrict__ numel,
398
scalar_t* __restrict__ grad_input,
399
const int64_t world_size,
400
const int ss,
401
const int bs) {
402
int64_t div = 0;
403
for (int i = 0; i < world_size; i++) {
404
div += numel[i];
405
}
406
auto m_c = static_cast<accscalar_t>(mean[blockIdx.x]);
407
//auto m_dy_c = static_cast<accscalar_t>(mean_dy[blockIdx.x]);
408
auto m_dy_c = static_cast<accscalar_t>(sum_dy[blockIdx.x]) / div;
409
auto factor_1_c = inv_std[blockIdx.x];
410
auto factor_2_c = (weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[blockIdx.x])) * factor_1_c;
411
//factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[blockIdx.x];
412
factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[blockIdx.x] / div;
413
414
for (int batch_offset = blockIdx.y*blockDim.y+threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
415
int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss;
416
for (int offset = threadIdx.x + blockIdx.z*blockDim.x; offset < ss ; offset+= gridDim.z*blockDim.x) {
417
grad_input[address_base+offset] = (static_cast<accscalar_t>(grad_output[address_base+offset]) - m_dy_c - (static_cast<accscalar_t>(input[address_base+offset]) - m_c) * factor_1_c) * factor_2_c;
418
}
419
}
420
}
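// For reference: writing N for the total element count summed over numel[], the kernel
// above computes the usual batch-norm input gradient expressed through the two reduced sums:
//   grad_input = (grad_output - sum_dy / N
//                 - (input - mean) * inv_std^2 * sum_dy_xmu / N) * weight * inv_std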
421
422
// welford kernel for c last tensor calculating mean/biased_variance/unbiased_variance
423
template
424
<typename scalar_t,
425
typename accscalar_t,
426
typename outscalar_t,
427
int PARALLEL_LOADS>
428
__global__ void
429
welford_kernel_c_last(
430
const scalar_t* __restrict__ input,
431
outscalar_t* __restrict__ out_mean,
432
outscalar_t* __restrict__ out_var_biased,
433
volatile accscalar_t* staging_data,
434
int* semaphores,
435
const int reduction_size,
436
const int stride) {
437
// hide latency with concurrency
438
accscalar_t x_mean[PARALLEL_LOADS];
439
accscalar_t m_2_n[PARALLEL_LOADS];
440
int count[PARALLEL_LOADS];
441
442
#pragma unroll
443
for (int i = 0; i < PARALLEL_LOADS; i++) {
444
x_mean[i] = accscalar_t(0);
445
m_2_n[i] = accscalar_t(0);
446
count[i] = 0;
447
}
448
// tensor dimension (m,c)
449
450
// loop along m dimension
451
int inner_loop_stride = blockDim.y * gridDim.y;
452
453
// offset along m dimension
454
int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
455
int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
456
457
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
458
int address_base = m_offset * stride + c_offset;
459
int address_increment = inner_loop_stride * stride;
460
461
for (int i = 0; i < loop_count; i++) {
462
accscalar_t x_math[PARALLEL_LOADS];
463
accscalar_t x_count_inv[PARALLEL_LOADS];
464
accscalar_t is_valid[PARALLEL_LOADS];
465
466
// load multiple data in
467
#pragma unroll
468
for (int j = 0; j < PARALLEL_LOADS; j++) {
469
if (c_offset < stride && m_offset < reduction_size) {
470
x_math[j] = input[address_base];
471
count[j]++;
472
x_count_inv[j] = accscalar_t(1) / count[j];
473
is_valid[j] = accscalar_t(1);
474
} else {
475
x_math[j] = accscalar_t(0);
476
x_count_inv[j] = accscalar_t(0);
477
is_valid[j] = accscalar_t(0);
478
}
479
m_offset += inner_loop_stride;
480
address_base += address_increment;
481
}
482
483
// calculate mean/m2n with welford
484
#pragma unroll
485
for (int j = 0; j < PARALLEL_LOADS; j++) {
486
accscalar_t delta0 = x_math[j] - x_mean[j];
487
x_mean[j] += delta0 * x_count_inv[j];
488
accscalar_t delta1 = x_math[j] - x_mean[j];
489
m_2_n[j] += delta0 * delta1 * is_valid[j];
490
}
491
}
492
493
// thread reduction to accumulate mean/m_2_n/count between PARALLEL_LOADS
494
#pragma unroll
495
for (int j = 1; j < PARALLEL_LOADS; j++) {
496
welford_merge_element(count[0], x_mean[0], m_2_n[0], count[j], x_mean[j], m_2_n[j]);
497
}
498
499
// release x_mean / m_2_n
500
auto mean_th = x_mean[0];
501
auto m2_th = m_2_n[0];
502
auto count_th = count[0];
503
504
// block-wise reduction with shared memory (since reduction cannot be done within a warp)
505
static __shared__ accscalar_t shmem_mean[MAX_BLOCK_SIZE];
506
static __shared__ accscalar_t shmem_m2n[MAX_BLOCK_SIZE];
507
static __shared__ int shmem_count[MAX_BLOCK_SIZE];
508
509
welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
510
511
// grid reduction if needed (cooperative launch was used in the first place)
512
if (gridDim.y > 1) {
513
volatile accscalar_t* staging_mean = staging_data;
514
volatile accscalar_t* staging_m2n = &staging_data[stride*gridDim.y];
515
volatile int* staging_count = reinterpret_cast<volatile int*>(&staging_m2n[stride*gridDim.y]);
516
517
address_base = c_offset + blockIdx.y * stride;
518
// write data to staging_data;
519
if (threadIdx.y == 0 && c_offset < stride) {
520
staging_mean[address_base] = mean_th;
521
staging_m2n[address_base] = m2_th;
522
staging_count[address_base] = count_th;
523
}
524
525
__threadfence();
526
__syncthreads(); // ensure the writes to the staging_ buffers are visible to all blocks
527
528
__shared__ bool is_last_block_done;
529
// mark block done
530
if (threadIdx.x == 0 && threadIdx.y == 0) {
531
int old = atomicAdd(&semaphores[blockIdx.x], 1);
532
is_last_block_done = (old == (gridDim.y-1));
533
}
534
535
__syncthreads();
536
537
// check that all data is now available in global memory
538
if (is_last_block_done) {
539
count_th = 0;
540
mean_th = accscalar_t(0.0);
541
m2_th = accscalar_t(0.0);
542
543
for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
544
address_base = c_offset + y * stride;
545
int num_new = c_offset < stride ? staging_count[address_base] : 0;
546
accscalar_t mean_new = c_offset < stride ? staging_mean[address_base] : accscalar_t(0.0);
547
accscalar_t m2n_new = c_offset < stride ? staging_m2n[address_base] : accscalar_t(0.0);
548
549
welford_merge_element(count_th, mean_th, m2_th, num_new, mean_new, m2n_new);
550
}
551
552
welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
553
if (threadIdx.y == 0 && c_offset < stride) {
554
out_mean[c_offset] = static_cast<outscalar_t>(mean_th);
555
out_var_biased[c_offset] = static_cast<outscalar_t>(m2_th / count_th);
556
}
557
}
558
} else {
559
if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
560
out_mean[c_offset] = static_cast<outscalar_t>(mean_th);
561
out_var_biased[c_offset] = static_cast<outscalar_t>(m2_th / count_th);
562
}
563
}
564
}
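// Note: when gridDim.y > 1 the kernel performs a software grid reduction: every block
// writes its partial (mean, m2n, count) into staging_data and increments a per-feature
// semaphore; the last block to arrive re-reads all partials, performs the final Welford
// merge in shared memory, and writes out_mean / out_var_biased.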
565
566
// parallel welford kernel to further reduce mean / biased_var
567
// into mean / unbiased_var / inv_std across multiple processes.
568
template <typename scalar_t>
569
__global__ void welford_kernel_parallel(
570
const scalar_t* __restrict__ mean,
571
const scalar_t* __restrict__ var_biased,
572
const int* __restrict__ numel,
573
scalar_t* __restrict__ out_mean,
574
scalar_t* __restrict__ out_var,
575
scalar_t* __restrict__ inv_std,
576
const int world_size,
577
const int feature_size,
578
const float eps) {
579
580
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < feature_size; i += gridDim.x * blockDim.x) {
581
// load data;
582
int address = i;
583
scalar_t x_mean = 0;
584
scalar_t m_2_n = 0;
585
int count = 0;
586
for (int j = 0; j < world_size; j++) {
587
welford_merge_element(count, x_mean, m_2_n, numel[j], mean[address], var_biased[address]*numel[j]);
588
address += feature_size;
589
}
590
out_mean[i] = x_mean;
591
out_var[i] = m_2_n/ (count - 1);
592
inv_std[i] = scalar_t(1) / sqrt(m_2_n/count + eps);
593
}
594
}
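// Note: after merging the per-process statistics, out_var is the unbiased estimate
// m2n / (count - 1), while inv_std is built from the biased one, 1 / sqrt(m2n / count + eps).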
595
596
// elementwise BN kernel
597
template <
598
typename scalar_t,
599
typename accscalar_t,
600
typename layerscalar_t,
601
int PARALLEL_LOADS>
602
__global__ void batchnorm_forward_c_last_kernel(
603
const scalar_t* __restrict__ input,
604
const scalar_t* __restrict__ z,
605
const accscalar_t* __restrict__ mean,
606
const accscalar_t* __restrict__ inv_std,
607
const layerscalar_t* __restrict__ weight,
608
const layerscalar_t* __restrict__ shift,
609
scalar_t* __restrict__ out,
610
const int reduction_size,
611
const int stride,
612
const bool fuse_relu) {
613
// tensor dimension (m,c)
614
// loop along m dimension
615
int inner_loop_stride = blockDim.y * gridDim.y;
616
617
// offset along m dimension
618
int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
619
int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
620
621
auto m_c = mean[c_offset];
622
auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
623
auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);
624
auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);
625
626
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
627
int address_base = m_offset * stride + c_offset;
628
int address_increment = inner_loop_stride * stride;
629
630
for (int i = 0; i < loop_count; i++) {
631
#pragma unroll
632
for (int j = 0; j < PARALLEL_LOADS; j++) {
633
if (c_offset < stride && m_offset < reduction_size) {
634
auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c;
635
if (z != NULL) {
636
tmp += z[address_base];
637
}
638
out[address_base] = (fuse_relu && tmp <= accscalar_t(0.0) ? scalar_t(0.0) : static_cast<scalar_t>(tmp));
639
}
640
m_offset += inner_loop_stride;
641
address_base += address_increment;
642
}
643
}
644
}
645
646
// elementwise BN kernel
647
template <
648
typename scalar_t,
649
typename accscalar_t,
650
typename layerscalar_t,
651
int PARALLEL_LOADS>
652
__global__ void relu_backward_c_last_kernel(
653
const scalar_t* __restrict__ grad_output,
654
const scalar_t* __restrict__ input,
655
const scalar_t* __restrict__ z,
656
const accscalar_t* __restrict__ mean,
657
const accscalar_t* __restrict__ inv_std,
658
const layerscalar_t* __restrict__ weight,
659
const layerscalar_t* __restrict__ shift,
660
scalar_t* __restrict__ out,
661
const int reduction_size,
662
const int stride) {
663
// tensor dimension (m,c)
664
// loop along m dimension
665
int inner_loop_stride = blockDim.y * gridDim.y;
666
667
// offset along m dimension
668
int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
669
int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
670
671
auto m_c = mean[c_offset];
672
auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
673
auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);
674
auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);
675
676
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
677
int address_base = m_offset * stride + c_offset;
678
int address_increment = inner_loop_stride * stride;
679
680
for (int i = 0; i < loop_count; i++) {
681
#pragma unroll
682
for (int j = 0; j < PARALLEL_LOADS; j++) {
683
if (c_offset < stride && m_offset < reduction_size) {
684
auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c;
685
if (z != NULL) {
686
tmp += z[address_base];
687
}
688
out[address_base] = (tmp <= accscalar_t(0.0) ? scalar_t(0.0) : grad_output[address_base]);
689
}
690
m_offset += inner_loop_stride;
691
address_base += address_increment;
692
}
693
}
694
}
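// Note: instead of saving the forward pre-activation, this kernel recomputes
// weight * (input - mean) * inv_std + shift (plus z when present) and uses its sign to
// gate the gradient: out = (activation <= 0) ? 0 : grad_output, i.e. the ReLU backward pass.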
695
696
// batchnorm backward kernel for c last tensor
697
template
698
<typename scalar_t,
699
typename accscalar_t,
700
typename layerscalar_t,
701
int PARALLEL_LOADS>
702
__global__ void reduce_bn_c_last_kernel(
703
const scalar_t* __restrict__ input,
704
const scalar_t* __restrict__ grad_output,
705
const accscalar_t* __restrict__ mean,
706
const accscalar_t* __restrict__ inv_std,
707
accscalar_t* __restrict__ sum_dy_o,
708
accscalar_t* __restrict__ sum_dy_xmu_o,
709
layerscalar_t* __restrict__ grad_weight,
710
layerscalar_t* __restrict__ grad_bias,
711
volatile accscalar_t* staging_data,
712
int* semaphores,
713
const int reduction_size,
714
const int stride) {
715
716
// hide latency with concurrency
717
accscalar_t sum_dy[PARALLEL_LOADS];
718
accscalar_t sum_dy_xmu[PARALLEL_LOADS];
719
720
#pragma unroll
721
for (int i = 0; i < PARALLEL_LOADS; i++) {
722
sum_dy[i] = accscalar_t(0);
723
sum_dy_xmu[i] = accscalar_t(0);
724
}
725
// tensor dimension (m,c)
726
727
// loop along m dimension
728
int inner_loop_stride = blockDim.y * gridDim.y;
729
730
// offset along m dimension
731
int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
732
int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
733
734
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
735
int address_base = m_offset * stride + c_offset;
736
int address_increment = inner_loop_stride * stride;
737
738
auto r_mean = mean[c_offset];
739
auto factor = inv_std[c_offset];
740
741
for (int i = 0; i < loop_count; i++) {
742
accscalar_t x_input[PARALLEL_LOADS];
743
accscalar_t x_grad_output[PARALLEL_LOADS];
744
745
// load multiple data in
746
#pragma unroll
747
for (int j = 0; j < PARALLEL_LOADS; j++) {
748
if (c_offset < stride && m_offset < reduction_size) {
749
x_input[j] = input[address_base];
750
x_grad_output[j] = grad_output[address_base];
751
} else {
752
x_input[j] = accscalar_t(0);
753
x_grad_output[j] = accscalar_t(0);
754
}
755
m_offset += inner_loop_stride;
756
address_base += address_increment;
757
}
758
759
// calculate sum_dy / sum_dy_xmu
760
#pragma unroll
761
for (int j = 0; j < PARALLEL_LOADS; j++) {
762
sum_dy[j] += x_grad_output[j];
763
sum_dy_xmu[j] += x_grad_output[j] * (x_input[j] - r_mean);
764
}
765
}
766
767
// thread reduction to accumulate sum_dy / sum_dy_xmu between PARALLEL_LOADS
768
#pragma unroll
769
for (int j = 1; j < PARALLEL_LOADS; j++) {
770
sum_dy[0] += sum_dy[j];
771
sum_dy_xmu[0] += sum_dy_xmu[j];
772
}
773
774
// release array of registers
775
auto sum_dy_th = sum_dy[0];
776
auto sum_dy_xmu_th = sum_dy_xmu[0];
777
778
// block-wise reduction with shared memory (since reduction cannot be done within a warp)
779
static __shared__ accscalar_t shmem_sum_dy[MAX_BLOCK_SIZE];
780
static __shared__ accscalar_t shmem_sum_dy_xmu[MAX_BLOCK_SIZE];
781
782
merge_block_vertical(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
783
784
// grid reduction if needed (coop launch used at the first place)
785
if (gridDim.y > 1) {
786
volatile accscalar_t* staging_sum_dy = staging_data;
787
volatile accscalar_t* staging_sum_dy_xmu = &staging_data[stride*gridDim.y];
788
789
address_base = c_offset + blockIdx.y * stride;
790
// write data to staging_data;
791
if (threadIdx.y == 0 && c_offset < stride) {
792
staging_sum_dy[address_base] = sum_dy_th;
793
staging_sum_dy_xmu[address_base] = sum_dy_xmu_th;
794
}
795
796
__threadfence();
797
__syncthreads(); // ensure the writes to the staging_ buffers are visible to all blocks
798
799
__shared__ bool is_last_block_done;
800
// mark block done
801
if (threadIdx.x == 0 && threadIdx.y == 0) {
802
int old = atomicAdd(&semaphores[blockIdx.x], 1);
803
is_last_block_done = (old == (gridDim.y-1));
804
}
805
806
__syncthreads();
807
808
// check that all data is now available in global memory
809
if (is_last_block_done) {
810
sum_dy_th = accscalar_t(0.0);
811
sum_dy_xmu_th = accscalar_t(0.0);
812
813
for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
814
address_base = c_offset + y * stride;
815
sum_dy_th += (c_offset < stride ? staging_sum_dy[address_base] : accscalar_t(0.0));
816
sum_dy_xmu_th += (c_offset < stride ? staging_sum_dy_xmu[address_base] : accscalar_t(0.0));
817
}
818
819
merge_block_vertical(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
820
if (threadIdx.y == 0 && c_offset < stride) {
821
if (grad_bias != NULL) {
822
grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
823
}
824
if (grad_weight != NULL) {
825
grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
826
}
827
//mean_dy[c_offset] = sum_dy_th / reduction_size;
828
//mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
829
sum_dy_o[c_offset] = sum_dy_th;
830
sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
831
}
832
}
833
} else {
834
if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
835
if (grad_bias != NULL) {
836
grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
837
}
838
if (grad_weight != NULL) {
839
grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
840
}
841
//mean_dy[c_offset] = sum_dy_th / reduction_size;
842
//mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
843
sum_dy_o[c_offset] = sum_dy_th;
844
sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
845
}
846
}
847
}
848
849
// elementwise BN kernel
850
template <
851
typename scalar_t,
852
typename accscalar_t,
853
typename layerscalar_t,
854
int PARALLEL_LOADS>
855
__global__ void batchnorm_backward_c_last_kernel(
856
const scalar_t* __restrict__ grad_output,
857
const scalar_t* __restrict__ input,
858
const accscalar_t* __restrict__ mean,
859
const accscalar_t* __restrict__ inv_std,
860
const layerscalar_t* __restrict__ weight,
861
const accscalar_t* __restrict__ sum_dy,
862
const accscalar_t* __restrict__ sum_dy_xmu,
863
const int* __restrict__ numel,
864
scalar_t* __restrict__ grad_input,
865
const int64_t world_size,
866
const int reduction_size,
867
const int stride) {
868
int64_t div = 0;
869
for (int i = 0; i < world_size; i++) {
870
div += numel[i];
871
}
872
// tensor dimension (m,c)
873
// loop along m dimension
874
int inner_loop_stride = blockDim.y * gridDim.y;
875
876
// offset along m dimension
877
int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
878
int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
879
880
auto m_c = mean[c_offset];
881
auto m_dy_c = sum_dy[c_offset] / div;
882
auto factor_1_c = inv_std[c_offset];
883
auto factor_2_c = (weight == NULL? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset])) * factor_1_c;
884
factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[c_offset] / div;
885
886
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
887
int address_base = m_offset * stride + c_offset;
888
int address_increment = inner_loop_stride * stride;
889
890
for (int i = 0; i < loop_count; i++) {
891
#pragma unroll
892
for (int j = 0; j < PARALLEL_LOADS; j++) {
893
if (c_offset < stride && m_offset < reduction_size) {
894
grad_input[address_base] = static_cast<scalar_t>(
895
(static_cast<accscalar_t>(grad_output[address_base]) - m_dy_c -
896
(static_cast<accscalar_t>(input[address_base]) - m_c) * factor_1_c)
897
* factor_2_c);
898
}
899
m_offset += inner_loop_stride;
900
address_base += address_increment;
901
}
902
}
903
}
904
905
std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
906
const auto batch_size = input.size(0);
907
const auto feature_size = input.size(1);
908
909
auto space_size = get_tensor_spatial_size(input);
910
auto scalar_type = promote_scalartype(input);
911
912
at::Tensor out_var_biased = at::empty({feature_size}, input.options().dtype(scalar_type));
913
at::Tensor out_mean = at::empty({feature_size}, input.options().dtype(scalar_type));
914
915
int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / 32));
916
int block_x = max(1, min(MAX_BLOCK_SIZE / block_y, h_last_pow2(space_size)));
917
const dim3 block(block_x, block_y);
918
const dim3 grid(feature_size);
919
920
auto stream = at::cuda::getCurrentCUDAStream();
921
922
{
923
using namespace at;
924
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "welford_mean_var_kernel",
925
using accscalar_t = at::acc_type<scalar_t_0, true>;
926
welford_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
927
input.DATA_PTR<scalar_t_0>(),
928
out_mean.DATA_PTR<accscalar_t>(),
929
out_var_biased.DATA_PTR<accscalar_t>(),
930
batch_size,
931
feature_size,
932
space_size);
933
);
934
}
935
936
return {out_mean, out_var_biased};
937
}
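// Note: welford_mean_var_CUDA launches one block per feature (grid.x = feature_size);
// within a block, threadIdx.y strides over the batch dimension and threadIdx.x over the
// spatial elements, so each block reduces exactly one channel.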
938
939
at::Tensor batchnorm_forward_CUDA(
940
const at::Tensor input,
941
const at::Tensor mean,
942
const at::Tensor inv_std,
943
const at::optional<at::Tensor> weight,
944
const at::optional<at::Tensor> shift) {
945
const auto batch_size = input.size(0);
946
const auto feature_size = input.size(1);
947
at::Tensor out = at::empty_like(input);
948
949
auto space_size = get_tensor_spatial_size(input);
950
951
int block_x = max(32, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
952
int block_y = max(1, min(MAX_BLOCK_SIZE/block_x, h_last_pow2(batch_size)/4));
953
const dim3 block(block_x, block_y);
954
int grid_z = max(1, min(65535, h_last_pow2(space_size)/4/block_x));
955
int batch_group_size = max(1, min(65535, h_last_pow2(batch_size)/block_y));
956
const dim3 grid(feature_size, batch_group_size, grid_z);
957
auto stream = at::cuda::getCurrentCUDAStream();
958
959
if (input.scalar_type() == at::ScalarType::Half
960
&& weight.has_value() &&
961
weight.value().scalar_type() == at::ScalarType::Float) {
962
using namespace at;
963
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
964
using accscalar_t = at::acc_type<scalar_t_0, true>;
965
batchnorm_forward_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
966
input.DATA_PTR<scalar_t_0>(),
967
mean.DATA_PTR<accscalar_t>(),
968
inv_std.DATA_PTR<accscalar_t>(),
969
weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
970
shift.has_value() ? shift.value().DATA_PTR<accscalar_t>() : NULL,
971
out.DATA_PTR<scalar_t_0>(),
972
space_size,
973
batch_size);
974
);
975
} else {
976
if (weight.has_value()) {
977
TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
978
"input.scalar_type() is not supported with weight.scalar_type()");
979
}
980
using namespace at;
981
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
982
using accscalar_t = at::acc_type<scalar_t_0, true>;
983
batchnorm_forward_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(
984
input.DATA_PTR<scalar_t_0>(),
985
mean.DATA_PTR<accscalar_t>(),
986
inv_std.DATA_PTR<accscalar_t>(),
987
weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
988
shift.has_value() ? shift.value().DATA_PTR<scalar_t_0>() : NULL,
989
out.DATA_PTR<scalar_t_0>(),
990
space_size,
991
batch_size);
992
);
993
}
994
return out;
995
}
996
997
std::vector<at::Tensor> reduce_bn_CUDA(
998
const at::Tensor grad_output,
999
const at::Tensor input,
1000
const at::Tensor mean,
1001
const at::Tensor inv_std,
1002
const at::optional<at::Tensor> weight)
1003
{
1004
const auto batch_size = input.size(0);
1005
const auto feature_size = input.size(1);
1006
1007
auto scalar_type = promote_scalartype(input);
1008
1009
at::Tensor sum_dy = at::empty({feature_size}, mean.options());
1010
at::Tensor sum_dy_xmu = at::empty({feature_size}, mean.options());
1011
1012
at::Tensor grad_weight;
1013
at::Tensor grad_bias;
1014
if (weight.has_value()) {
1015
grad_weight = at::empty({feature_size}, weight.value().options());
1016
grad_bias = at::empty({feature_size}, weight.value().options());
1017
} else {
1018
grad_weight = at::empty({0}, mean.options());
1019
grad_bias = at::empty({0}, mean.options());
1020
}
1021
1022
auto space_size = get_tensor_spatial_size(input);
1023
1024
int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE/ 32));
1025
int block_x = max(1, min(MAX_BLOCK_SIZE/ block_y, h_last_pow2(space_size)));
1026
const dim3 block(block_x, block_y);
1027
const dim3 grid(feature_size);
1028
auto stream = at::cuda::getCurrentCUDAStream();
1029
1030
if (input.scalar_type() == at::ScalarType::Half
1031
&& weight.has_value() &&
1032
weight.value().scalar_type() == at::ScalarType::Float) {
1033
using namespace at;
1034
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
1035
using accscalar_t = at::acc_type<scalar_t_0, true>;
1036
reduce_bn_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
1037
input.DATA_PTR<scalar_t_0>(),
1038
grad_output.DATA_PTR<scalar_t_0>(),
1039
mean.DATA_PTR<accscalar_t>(),
1040
inv_std.DATA_PTR<accscalar_t>(),
1041
sum_dy.DATA_PTR<accscalar_t>(),
1042
sum_dy_xmu.DATA_PTR<accscalar_t>(),
1043
weight.has_value() ? grad_weight.DATA_PTR<accscalar_t>() : NULL,
1044
weight.has_value() ? grad_bias.DATA_PTR<accscalar_t>() : NULL,
1045
batch_size,
1046
feature_size,
1047
space_size);
1048
);
1049
} else {
1050
if (weight.has_value()) {
1051
TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
1052
"input.scalar_type() is not supported with weight.scalar_type()");
1053
}
1054
using namespace at;
1055
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
1056
using accscalar_t = at::acc_type<scalar_t_0, true>;
1057
reduce_bn_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(
1058
input.DATA_PTR<scalar_t_0>(),
1059
grad_output.DATA_PTR<scalar_t_0>(),
1060
mean.DATA_PTR<accscalar_t>(),
1061
inv_std.DATA_PTR<accscalar_t>(),
1062
sum_dy.DATA_PTR<accscalar_t>(),
1063
sum_dy_xmu.DATA_PTR<accscalar_t>(),
1064
weight.has_value() ? grad_weight.DATA_PTR<scalar_t_0>() : NULL,
1065
weight.has_value() ? grad_bias.DATA_PTR<scalar_t_0>() : NULL,
1066
batch_size,
1067
feature_size,
1068
space_size);
1069
);
1070
}
1071
1072
return {sum_dy, sum_dy_xmu, grad_weight, grad_bias};
1073
}
1074
1075
at::Tensor batchnorm_backward_CUDA(
1076
const at::Tensor grad_output,
1077
const at::Tensor input,
1078
const at::Tensor mean,
1079
const at::Tensor inv_std,
1080
const at::optional<at::Tensor> weight,
1081
const at::Tensor sum_dy,
1082
const at::Tensor sum_dy_xmu,
1083
const at::Tensor count) {
1084
const auto batch_size = input.size(0);
1085
const auto feature_size = input.size(1);
1086
1087
at::Tensor grad_input = at::empty_like(input);
1088
1089
auto space_size = get_tensor_spatial_size(input);
1090
1091
int block_x = max(32, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
1092
int block_y = max(1, min(MAX_BLOCK_SIZE/block_x, h_last_pow2(batch_size)/4));
1093
const dim3 block(block_x, block_y);
1094
int grid_z = max(1, min(65535, h_last_pow2(space_size)/4/block_x));
1095
int batch_group_size = max(1, min(65535, h_last_pow2(batch_size)/block_y));
1096
const dim3 grid(feature_size, batch_group_size, grid_z);
1097
1098
auto stream = at::cuda::getCurrentCUDAStream();
1099
1100
if (input.scalar_type() == at::ScalarType::Half
1101
&& weight.has_value() &&
1102
weight.value().scalar_type() == at::ScalarType::Float) {
1103
using namespace at;
1104
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward",
1105
using accscalar_t = at::acc_type<scalar_t_0, true>;
1106
batchnorm_backward_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
1107
grad_output.DATA_PTR<scalar_t_0>(),
1108
input.DATA_PTR<scalar_t_0>(),
1109
mean.DATA_PTR<accscalar_t>(),
1110
inv_std.DATA_PTR<accscalar_t>(),
1111
weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
1112
sum_dy.DATA_PTR<accscalar_t>(),
1113
sum_dy_xmu.DATA_PTR<accscalar_t>(),
1114
count.DATA_PTR<int>(),
1115
grad_input.DATA_PTR<scalar_t_0>(),
1116
count.numel(),
1117
space_size,
1118
batch_size);
1119
);
1120
} else {
1121
if (weight.has_value()) {
1122
TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
1123
"input.scalar_type() is not supported with weight.scalar_type()");
1124
}
1125
using namespace at;
1126
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward",
1127
using accscalar_t = at::acc_type<scalar_t_0, true>;
1128
batchnorm_backward_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(
1129
grad_output.DATA_PTR<scalar_t_0>(),
1130
input.DATA_PTR<scalar_t_0>(),
1131
mean.DATA_PTR<accscalar_t>(),
1132
inv_std.DATA_PTR<accscalar_t>(),
1133
weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
1134
sum_dy.DATA_PTR<accscalar_t>(),
1135
sum_dy_xmu.DATA_PTR<accscalar_t>(),
1136
count.DATA_PTR<int>(),
1137
grad_input.DATA_PTR<scalar_t_0>(),
1138
count.numel(),
1139
space_size,
1140
batch_size);
1141
);
1142
}
1143
1144
return grad_input;
1145
}
1146
1147
std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes,
1148
const at::Tensor var_biased,
1149
const at::Tensor numel,
1150
const float eps) {
1151
const auto world_size = mean_feature_nodes.size(0);
1152
const auto feature_size = mean_feature_nodes.size(1);
1153
1154
at::Tensor out_var = at::empty({feature_size}, var_biased.options());
1155
at::Tensor inv_std = at::empty_like(out_var);
1156
at::Tensor out_mean = at::empty_like(out_var);
1157
1158
at::Tensor mean_feature_nodes_ = mean_feature_nodes.contiguous();
1159
at::Tensor var_biased_ = var_biased.contiguous();
1160
at::Tensor numel_ = numel.contiguous();
1161
1162
// TODO(jie): tile this for memory coalescing!
1163
const int block = std::min(h_last_pow2(feature_size), MAX_BLOCK_SIZE);
1164
const int grid = std::max<int>(1, feature_size / block);
1165
1166
auto stream = at::cuda::getCurrentCUDAStream();
1167
1168
{
1169
using namespace at;
1170
DISPATCH_FLOAT_AND_HALF(mean_feature_nodes.scalar_type(), 0, "welford_parallel_kernel",
1171
welford_kernel_parallel<scalar_t_0><<<grid, block, 0, stream>>>(
1172
mean_feature_nodes_.DATA_PTR<scalar_t_0>(),
1173
var_biased_.DATA_PTR<scalar_t_0>(),
1174
numel_.DATA_PTR<int>(),
1175
out_mean.DATA_PTR<scalar_t_0>(),
1176
out_var.DATA_PTR<scalar_t_0>(),
1177
inv_std.DATA_PTR<scalar_t_0>(),
1178
world_size,
1179
feature_size,
1180
eps);
1181
);
1182
}
1183
1184
return {out_mean, out_var, inv_std};
1185
}
1186
1187
std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input) {
1188
const auto stride = input.size(input.ndimension()-1);
1189
const auto reduction_size = input.numel() / stride;
1190
1191
auto scalar_type = promote_scalartype(input);
1192
auto option = input.options().dtype(scalar_type);
1193
1194
at::Tensor out_var_biased = at::empty({stride}, option);
1195
at::Tensor out_mean = at::empty({stride}, option);
1196
1197
dim3 block;
1198
dim3 grid;
1199
flexible_launch_configs(reduction_size, stride, block, grid, true);
1200
1201
at::Tensor staging_data;
1202
at::Tensor semaphores;
1203
if (grid.y > 1) {
1204
staging_data = at::empty({4*stride*grid.y}, option);
1205
semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
1206
}
1207
1208
auto stream = at::cuda::getCurrentCUDAStream();
1209
1210
{
1211
using namespace at;
1212
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "welford_mean_var_c_last",
1213
using accscalar_t = at::acc_type<scalar_t_0, true>;
1214
accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.DATA_PTR<accscalar_t>() : nullptr;
1215
int* semaphores_ptr = grid.y > 1 ? semaphores.DATA_PTR<int>() : nullptr;
1216
welford_kernel_c_last<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
1217
<<<grid, block, 0, stream>>>(
1218
input.DATA_PTR<scalar_t_0>(),
1219
out_mean.DATA_PTR<accscalar_t>(),
1220
out_var_biased.DATA_PTR<accscalar_t>(),
1221
staging_data_ptr,
1222
semaphores_ptr,
1223
reduction_size,
1224
stride);
1225
);
1226
}
1227
1228
return {out_mean, out_var_biased};
1229
}
1230
1231
at::Tensor batchnorm_forward_c_last_CUDA(
1232
const at::Tensor input,
1233
const at::optional<at::Tensor> z,
1234
const at::Tensor mean,
1235
const at::Tensor inv_std,
1236
const at::optional<at::Tensor> weight,
1237
const at::optional<at::Tensor> shift,
1238
const bool fuse_relu) {
1239
const auto stride = input.size(input.ndimension()-1);
1240
const auto reduction_size = input.numel() / stride;
1241
1242
at::Tensor out = at::empty_like(input);
1243
1244
dim3 block;
1245
dim3 grid;
1246
flexible_launch_configs(reduction_size, stride, block, grid);
1247
1248
auto stream = at::cuda::getCurrentCUDAStream();
1249
1250
if (input.scalar_type() == at::ScalarType::Half
1251
&& weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) {
1252
using namespace at;
1253
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
1254
using accscalar_t = at::acc_type<scalar_t_0, true>;
1255
batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
1256
<<<grid, block, 0, stream>>>(
1257
input.DATA_PTR<scalar_t_0>(),
1258
z.has_value() ? z.value().DATA_PTR<scalar_t_0>() : NULL,
1259
mean.DATA_PTR<accscalar_t>(),
1260
inv_std.DATA_PTR<accscalar_t>(),
1261
weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
1262
shift.has_value() ? shift.value().DATA_PTR<accscalar_t>(): NULL,
1263
out.DATA_PTR<scalar_t_0>(),
1264
reduction_size,
1265
stride,
1266
fuse_relu);
1267
);
1268
} else {
1269
if (weight.has_value()) {
1270
TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
1271
"input.scalar_type() is not supported with weight.scalar_type()");
1272
}
1273
using namespace at;
1274
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
1275
using accscalar_t = at::acc_type<scalar_t_0, true>;
1276
batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
1277
<<<grid, block, 0, stream>>>(
1278
input.DATA_PTR<scalar_t_0>(),
1279
z.has_value() ? z.value().DATA_PTR<scalar_t_0>() : NULL,
1280
mean.DATA_PTR<accscalar_t>(),
1281
inv_std.DATA_PTR<accscalar_t>(),
1282
weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
1283
shift.has_value() ? shift.value().DATA_PTR<scalar_t_0>(): NULL,
1284
out.DATA_PTR<scalar_t_0>(),
1285
reduction_size,
1286
stride,
1287
fuse_relu);
1288
);
1289
}
1290
return out;
1291
}
1292
1293
std::vector<at::Tensor> reduce_bn_c_last_CUDA(
1294
const at::Tensor grad_output,
1295
const at::Tensor input,
1296
const at::Tensor mean,
1297
const at::Tensor inv_std,
1298
const at::optional<at::Tensor> weight) {
1299
const auto stride = input.size(input.ndimension()-1);
1300
const auto reduction_size = input.numel() / stride;
1301
1302
at::Tensor sumn_dy = at::empty({stride}, mean.options());
1303
at::Tensor sum_dy_xmu = at::empty({stride}, mean.options());
1304
1305
at::Tensor grad_weight;
1306
at::Tensor grad_bias;
1307
if (weight.has_value()) {
1308
grad_weight = at::empty({stride}, weight.value().options());
1309
grad_bias = at::empty({stride}, weight.value().options());
1310
} else {
1311
// because I cannot return an uninitialized at::Tensor
1312
grad_weight = at::empty({0}, mean.options());
1313
grad_bias = at::empty({0}, mean.options());
1314
}
1315
1316
dim3 block;
1317
dim3 grid;
1318
flexible_launch_configs(reduction_size, stride, block, grid, true);
1319
1320
at::Tensor staging_data;
1321
at::Tensor semaphores;
1322
if (grid.y > 1) {
1323
staging_data = at::empty({2*stride*grid.y}, mean.options());
1324
semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
1325
}
1326
auto stream = at::cuda::getCurrentCUDAStream();
1327
1328
if (input.scalar_type() == at::ScalarType::Half
1329
&& weight.has_value()
1330
&& weight.value().scalar_type() == at::ScalarType::Float) {
1331
using namespace at;
1332
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
1333
using accscalar_t = at::acc_type<scalar_t_0, true>;
1334
accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.DATA_PTR<accscalar_t>() : nullptr;
1335
int* semaphores_ptr = grid.y > 1 ? semaphores.DATA_PTR<int>() : nullptr;
1336
reduce_bn_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
1337
<<<grid, block, 0, stream>>>(
1338
input.DATA_PTR<scalar_t_0>(),
1339
grad_output.DATA_PTR<scalar_t_0>(),
1340
mean.DATA_PTR<accscalar_t>(),
1341
inv_std.DATA_PTR<accscalar_t>(),
1342
sumn_dy.DATA_PTR<accscalar_t>(),
1343
sum_dy_xmu.DATA_PTR<accscalar_t>(),
1344
weight.has_value() ? grad_weight.DATA_PTR<accscalar_t>() : NULL,
1345
weight.has_value() ? grad_bias.DATA_PTR<accscalar_t>() : NULL,
1346
staging_data_ptr,
1347
semaphores_ptr,
1348
reduction_size,
1349
stride);
1350
);
1351
} else {
1352
if (weight.has_value()) {
1353
TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
1354
"input.scalar_type() is not supported with weight.scalar_type()");
1355
}
1356
using namespace at;
1357
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
1358
using accscalar_t = at::acc_type<scalar_t_0, true>;
1359
accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.DATA_PTR<accscalar_t>() : nullptr;
1360
int* semaphores_ptr = grid.y > 1 ? semaphores.DATA_PTR<int>() : nullptr;
1361
reduce_bn_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
1362
<<<grid, block, 0, stream>>>(
1363
input.DATA_PTR<scalar_t_0>(),
1364
grad_output.DATA_PTR<scalar_t_0>(),
1365
mean.DATA_PTR<accscalar_t>(),
1366
inv_std.DATA_PTR<accscalar_t>(),
1367
sumn_dy.DATA_PTR<accscalar_t>(),
1368
sum_dy_xmu.DATA_PTR<accscalar_t>(),
1369
weight.has_value() ? grad_weight.DATA_PTR<scalar_t_0>() : NULL,
1370
weight.has_value() ? grad_bias.DATA_PTR<scalar_t_0>() : NULL,
1371
staging_data_ptr,
1372
semaphores_ptr,
1373
reduction_size,
1374
stride);
1375
);
1376
}
1377
1378
return {sumn_dy, sum_dy_xmu, grad_weight, grad_bias};
1379
}
1380
1381
at::Tensor batchnorm_backward_c_last_CUDA(
1382
const at::Tensor grad_output,
1383
const at::Tensor input,
1384
const at::Tensor mean,
1385
const at::Tensor inv_std,
1386
const at::optional<at::Tensor> weight,
1387
const at::Tensor sum_dy,
1388
const at::Tensor sum_dy_xmu,
1389
const at::Tensor count) {
1390
const auto stride = input.size(input.ndimension()-1);
1391
const auto reduction_size = input.numel() / stride;
1392
1393
at::Tensor grad_input = at::empty_like(input);
1394
1395
dim3 block;
1396
dim3 grid;
1397
flexible_launch_configs(reduction_size, stride, block, grid);
1398
1399
auto stream = at::cuda::getCurrentCUDAStream();
1400
1401
if (input.scalar_type() == at::ScalarType::Half
1402
&& weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) {
1403
using namespace at;
1404
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
1405
using accscalar_t = at::acc_type<scalar_t_0, true>;
1406
batchnorm_backward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
1407
<<<grid, block, 0, stream>>>(
1408
grad_output.DATA_PTR<scalar_t_0>(),
1409
input.DATA_PTR<scalar_t_0>(),
1410
mean.DATA_PTR<accscalar_t>(),
1411
inv_std.DATA_PTR<accscalar_t>(),
1412
weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
1413
sum_dy.DATA_PTR<accscalar_t>(),
1414
sum_dy_xmu.DATA_PTR<accscalar_t>(),
1415
count.DATA_PTR<int>(),
1416
grad_input.DATA_PTR<scalar_t_0>(),
1417
count.numel(),
1418
reduction_size,
1419
stride);
1420
);
1421
} else {
1422
if (weight.has_value()) {
1423
TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
1424
"input.scalar_type() is not supported with weight.scalar_type()");
1425
}
1426
using namespace at;
1427
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
1428
using accscalar_t = at::acc_type<scalar_t_0, true>;
1429
batchnorm_backward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
1430
<<<grid, block, 0, stream>>>(
1431
grad_output.DATA_PTR<scalar_t_0>(),
1432
input.DATA_PTR<scalar_t_0>(),
1433
mean.DATA_PTR<accscalar_t>(),
1434
inv_std.DATA_PTR<accscalar_t>(),
1435
weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
1436
sum_dy.DATA_PTR<accscalar_t>(),
1437
sum_dy_xmu.DATA_PTR<accscalar_t>(),
1438
count.DATA_PTR<int>(),
1439
grad_input.DATA_PTR<scalar_t_0>(),
1440
count.numel(),
1441
reduction_size,
1442
stride);
1443
);
1444
}
1445
1446
return grad_input;
1447
}
1448
1449
at::Tensor relu_backward_c_last_CUDA(
1450
const at::Tensor grad_output,
1451
const at::Tensor input,
1452
const at::optional<at::Tensor> z,
1453
const at::Tensor mean,
1454
const at::Tensor inv_std,
1455
const at::optional<at::Tensor> weight,
1456
const at::optional<at::Tensor> shift) {
1457
1458
const auto stride = input.size(input.ndimension()-1);
1459
const auto reduction_size = input.numel() / stride;
1460
1461
at::Tensor out = at::empty_like(input);
1462
1463
dim3 block;
1464
dim3 grid;
1465
flexible_launch_configs(reduction_size, stride, block, grid);
1466
1467
auto stream = at::cuda::getCurrentCUDAStream();
1468
1469
if (input.scalar_type() == at::ScalarType::Half
1470
&& weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) {
1471
using namespace at;
1472
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
1473
using accscalar_t = at::acc_type<scalar_t_0, true>;
1474
relu_backward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
1475
<<<grid, block, 0, stream>>>(
1476
grad_output.DATA_PTR<scalar_t_0>(),
1477
input.DATA_PTR<scalar_t_0>(),
1478
z.has_value() ? z.value().DATA_PTR<scalar_t_0>() : NULL,
1479
mean.DATA_PTR<accscalar_t>(),
1480
inv_std.DATA_PTR<accscalar_t>(),
1481
weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
1482
shift.has_value() ? shift.value().DATA_PTR<accscalar_t>(): NULL,
1483
out.DATA_PTR<scalar_t_0>(),
1484
reduction_size,
1485
stride);
1486
);
1487
} else {
1488
if (weight.has_value()) {
1489
TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
1490
"input.scalar_type() is not supported with weight.scalar_type()");
1491
}
1492
using namespace at;
1493
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
1494
using accscalar_t = at::acc_type<scalar_t_0, true>;
1495
relu_backward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
1496
<<<grid, block, 0, stream>>>(
1497
grad_output.DATA_PTR<scalar_t_0>(),
1498
input.DATA_PTR<scalar_t_0>(),
1499
z.has_value() ? z.value().DATA_PTR<scalar_t_0>() : NULL,
1500
mean.DATA_PTR<accscalar_t>(),
1501
inv_std.DATA_PTR<accscalar_t>(),
1502
weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
1503
shift.has_value() ? shift.value().DATA_PTR<scalar_t_0>(): NULL,
1504
out.DATA_PTR<scalar_t_0>(),
1505
reduction_size,
1506
stride);
1507
);
1508
}
1509
return out;
1510
}