// Source: GitHub repository ai-forever/sber-swap, path apex/csrc/mlp_cuda.cu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <torch/torch.h>

/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>

// includes cublaslt
#include <cublasLt.h>

// constants for fused bias+relu kernel
#define BIAS_RELU_FW_NTHREADS 128 // forward: number of threads per block
#define BIAS_RELU_BW_NTHREADS_X 32 // backward: number of threads in feature dim
#define BIAS_RELU_BW_NTHREADS_Y 16 // backward: number of threads in batch dim
#define BIAS_RELU_RED_PER_THREAD 16 // backward: minimal reduction length per thread

// move to a header later on
#define ILP 4
template<typename T>
__host__ __device__ __forceinline__ bool is_aligned(T* p){
  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
template<typename T>
__device__ __forceinline__ void load_store(T* dst, volatile T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
template<typename T>
__device__ __forceinline__ void load_store(volatile T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
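// Illustrative note (added for clarity, not in the original source): with ILP == 4 and
// T == float, LT is a 16-byte aligned type, so one load_store() call moves a float4-sized
// chunk: load_store(r_x, X, 0, tid) copies elements X[4*tid .. 4*tid+3] into the register
// array r_x. This is why the vectorized kernel paths below require is_aligned(...) pointers
// and features % ILP == 0.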
// Keep ReLU in float only. When using half, cast to float before calling.
__device__ __inline__ float relu(float a) {
  float retf = max(a, 0.f);
  return (retf);
}

// Keep Sigmoid in float only. When using half, cast to float before calling.
__device__ __inline__ float sigmoid(float a) {
  float retf = 1.f / (1.f + expf(-a));
  return (retf);
}

// FP64 Wrapper around cublas GEMMEx
cublasStatus_t mlp_gemm(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    float* alpha,
    const double* A,
    int lda,
    const double* B,
    int ldb,
    const float* beta,
    double* C,
    int ldc) {
  return cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      alpha,
      A,
      CUDA_R_64F,
      lda,
      B,
      CUDA_R_64F,
      ldb,
      beta,
      C,
      CUDA_R_64F,
      ldc,
      CUDA_R_64F,
      CUBLAS_GEMM_DEFAULT);
}

// FP32 Wrapper around cublas GEMMEx
cublasStatus_t mlp_gemm(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    float* alpha,
    const float* A,
    int lda,
    const float* B,
    int ldb,
    const float* beta,
    float* C,
    int ldc) {
  return cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      alpha,
      A,
      CUDA_R_32F,
      lda,
      B,
      CUDA_R_32F,
      ldb,
      beta,
      C,
      CUDA_R_32F,
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT);
}

// FP16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t mlp_gemm(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    float* alpha,
    const at::Half* A,
    int lda,
    const at::Half* B,
    int ldb,
    float* beta,
    at::Half* C,
    int ldc) {
  return cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      alpha,
      A,
      CUDA_R_16F,
      lda,
      B,
      CUDA_R_16F,
      ldb,
      beta,
      C,
      CUDA_R_16F,
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
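// Illustrative note (added for clarity, not in the original source): these wrappers keep
// cuBLAS's column-major convention. In the forward pass further below, mlp_fp calls
// mlp_gemm(handle, CUBLAS_OP_T, CUBLAS_OP_N, ofeat, batch_size, ifeat, ...), i.e. it computes
// Y = W^T * X with W stored as [ifeat x ofeat] and X as [ifeat x batch_size], both column
// major, so Y comes out as [ofeat x batch_size].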

int mlp_gemm_lt(
    cublasLtHandle_t ltHandle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    float *alpha, /* host pointer */
    const at::Half* A,
    int lda,
    const at::Half* B,
    int ldb,
    float *beta, /* host pointer */
    at::Half* C,
    int ldc,
    void *workspace,
    size_t workspaceSize,
    cudaStream_t stream,
    bool use_bias,
    bool use_relu,
    const void* bias) {
  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;

  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};

  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;

  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  // for details about defaults; here we just set the transforms for
  // A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  if (use_bias) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
    if (status != CUBLAS_STATUS_SUCCESS) {
      goto CLEANUP;
    }
    if (use_relu) {
      epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
    } else {
      epilogue = CUBLASLT_EPILOGUE_BIAS;
    }
  } else {
    if (use_relu) {
      epilogue = CUBLASLT_EPILOGUE_RELU;
    }
  }

  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) {
    goto CLEANUP;
  }

  // Create matrix descriptors. Not setting any extra attributes.
  status = cublasLtMatrixLayoutInit(
      &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(
      &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create preference handle; In general, extra attributes can be
  // used here to disable tensor ops or to make sure algo selected
  // will work with badly aligned A, B, C. However, for simplicity
  // here we assume A,B,C are always well aligned (e.g., directly
  // come from cudaMalloc)
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(
      &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // We just need the best available heuristic to try and run matmul.
  // There is no guarantee that this will work. For example, if A is
  // badly aligned, you can request more (e.g. 32) algos and try to
  // run them one by one until something works.
  status = cublasLtMatmulAlgoGetHeuristic(
      ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }
  status = cublasLtMatmul(ltHandle,
                          &operationDesc,
                          alpha,
                          A,
                          &Adesc,
                          B,
                          &Bdesc,
                          beta,
                          C,
                          &Cdesc,
                          C,
                          &Cdesc,
                          &heuristicResult.algo,
                          workspace,
                          workspaceSize,
                          stream);

CLEANUP:
  // Descriptors are no longer needed as all GPU work was already
  // enqueued.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
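#if 0
// Illustrative sketch (added for clarity, not part of the original file): how a caller can
// try the cuBLASLt fused path first and fall back to the plain cublasGemmEx wrapper, which
// is exactly what mlp_fp() does further below. The names handle, stream, W, X, Yout, bias
// and lt_ws are hypothetical and assumed to be set up elsewhere.
void example_gemm_dispatch(cublasHandle_t handle, cudaStream_t stream,
                           const at::Half* W, const at::Half* X, at::Half* Yout,
                           const at::Half* bias, void* lt_ws,
                           int ofeat, int ifeat, int batch) {
  float one = 1.f, zero = 0.f;
  int lt_status = mlp_gemm_lt(
      (cublasLtHandle_t)handle, CUBLAS_OP_T, CUBLAS_OP_N, ofeat, batch, ifeat,
      &one, W, ifeat, X, ifeat, &zero, Yout, ofeat,
      lt_ws, 1 << 22, stream, /*use_bias=*/true, /*use_relu=*/true, bias);
  if (lt_status != 0) {
    // Fused epilogue unavailable: run the plain GEMM; bias/ReLU would then be applied by
    // the separate fprop kernels defined below.
    mlp_gemm(handle, CUBLAS_OP_T, CUBLAS_OP_N, ofeat, batch, ifeat,
             &one, W, ifeat, X, ifeat, &zero, Yout, ofeat);
  }
}
#endif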

int mlp_gemm_lt(
    cublasLtHandle_t ltHandle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    float *alpha, /* host pointer */
    const double* A,
    int lda,
    const double* B,
    int ldb,
    float *beta, /* host pointer */
    double* C,
    int ldc,
    void *workspace,
    size_t workspaceSize,
    cudaStream_t stream,
    bool use_bias,
    bool use_relu,
    const void* bias) {
  return 1;
}

int mlp_gemm_lt(
    cublasLtHandle_t ltHandle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    float *alpha, /* host pointer */
    const float *A,
    int lda,
    const float *B,
    int ldb,
    float *beta, /* host pointer */
    float *C,
    int ldc,
    void *workspace,
    size_t workspaceSize,
    cudaStream_t stream,
    bool use_bias,
    bool use_relu,
    const void* bias) {
  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;

  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};

  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;

  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  // for details about defaults; here we just set the transforms for
  // A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  if (use_bias) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
    if (status != CUBLAS_STATUS_SUCCESS) {
      goto CLEANUP;
    }
    if (use_relu) {
      epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
    } else {
      epilogue = CUBLASLT_EPILOGUE_BIAS;
    }
  } else {
    if (use_relu) {
      epilogue = CUBLASLT_EPILOGUE_RELU;
    }
  }

  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) {
    goto CLEANUP;
  }

  // Create matrix descriptors. Not setting any extra attributes.
  status = cublasLtMatrixLayoutInit(
      &Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(
      &Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create preference handle; In general, extra attributes can be
  // used here to disable tensor ops or to make sure algo selected
  // will work with badly aligned A, B, C. However, for simplicity
  // here we assume A,B,C are always well aligned (e.g., directly
  // come from cudaMalloc)
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(
      &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // We just need the best available heuristic to try and run matmul.
  // There is no guarantee that this will work. For example, if A is
  // badly aligned, you can request more (e.g. 32) algos and try to
  // run them one by one until something works.
  status = cublasLtMatmulAlgoGetHeuristic(
      ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }

  status = cublasLtMatmul(ltHandle,
                          &operationDesc,
                          alpha,
                          A,
                          &Adesc,
                          B,
                          &Bdesc,
                          beta,
                          C,
                          &Cdesc,
                          C,
                          &Cdesc,
                          &heuristicResult.algo,
                          workspace,
                          workspaceSize,
                          stream);

CLEANUP:
  // Descriptors are no longer needed as all GPU work was already
  // enqueued.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}


// Bias ADD. Assume input X is [features x batch size], column major.
// Bias is one 'features' long vector, with implicit broadcast.
template <typename T>
__global__ void biasAdd_fprop(T *X, T *b, uint batch_size, uint features) {
  T r_x[ILP];
  T r_b[ILP];
  if(is_aligned(X) && is_aligned(b) && features % ILP == 0) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
      int row = tid % (features / ILP);
      load_store(r_x, X, 0, tid);
      load_store(r_b, b, 0, row);
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);
        r_x[ii] = bias_sum;
      }
      load_store(X, r_x, tid, 0);
    }
  } else {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          int row = idx % features;
          r_x[ii] = X[idx];
          r_b[ii] = b[row];
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);
        r_x[ii] = bias_sum;
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          X[idx] = r_x[ii];
        }
      }
    }
  }
}
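// Illustrative note (added for clarity, not in the original source): in the vectorized path
// above, each loop iteration handles one ILP-wide chunk of the flattened [features x batch]
// buffer. With features = 1024 and ILP = 4 there are 256 vector "rows" per column, so for a
// given tid the chunk covers elements X[4*tid .. 4*tid+3] and row = tid % 256 picks the
// matching 4-wide slice of the bias vector.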

// Bias ADD + ReLU. Assume input X is [features x batch size], column major.
// Activation supports fused ReLU. Safe to call in-place.
template <typename T>
__global__ void biasAddRelu_fprop(T *X, T *b, uint batch_size, uint features) {
  T r_x[ILP];
  T r_b[ILP];
  if(is_aligned(X) && is_aligned(b) && features % ILP == 0) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
      int row = tid % (features / ILP);
      load_store(r_x, X, 0, tid);
      load_store(r_b, b, 0, row);
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);
        r_x[ii] = relu(bias_sum);
      }
      load_store(X, r_x, tid, 0);
    }
  } else {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          int row = idx % features;
          r_x[ii] = X[idx];
          r_b[ii] = b[row];
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);
        r_x[ii] = relu(bias_sum);
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          X[idx] = r_x[ii];
        }
      }
    }
  }
}

// ReLU. Assume input X is [features x batch size], column major.
// Safe to call in-place.
template <typename T>
__global__ void Relu_fprop(T *X, uint batch_size, uint features) {
  T r_x[ILP];
  if(is_aligned(X) && features % ILP == 0) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
      load_store(r_x, X, 0, tid);
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        r_x[ii] = relu(static_cast<float>(r_x[ii]));
      }
      load_store(X, r_x, tid, 0);
    }
  } else {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          r_x[ii] = X[idx];
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        r_x[ii] = relu(static_cast<float>(r_x[ii]));
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          X[idx] = r_x[ii];
        }
      }
    }
  }
}

// Sigmoid. Assume input X is [features x batch size], column major.
// Safe to call in-place.
template <typename T>
__global__ void Sigmoid_fprop(T *X, uint batch_size, uint features) {
  T r_x[ILP];
  if(is_aligned(X) && features % ILP == 0) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
      load_store(r_x, X, 0, tid);
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        r_x[ii] = sigmoid(static_cast<float>(r_x[ii]));
      }
      load_store(X, r_x, tid, 0);
    }
  } else {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          r_x[ii] = X[idx];
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        r_x[ii] = sigmoid(static_cast<float>(r_x[ii]));
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          X[idx] = r_x[ii];
        }
      }
    }
  }
}

// ReLU. Assume input X is [features x batch size], column major.
// Safe to call in-place.
template <typename T>
__global__ void Relu_bprop(T *dY, T *Y, uint batch_size, uint features, T *dX) {
  T r_dy[ILP];
  T r_y[ILP];
  if(is_aligned(dY) &&
     is_aligned(Y) &&
     is_aligned(dX) &&
     features % ILP == 0) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
      load_store(r_dy, dY, 0, tid);
      load_store(r_y, Y, 0, tid);
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        if ((float)r_y[ii] <= 0.f)
          r_dy[ii] = 0;
      }
      load_store(dX, r_dy, tid, 0);
    }
  } else {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          r_dy[ii] = dY[idx];
          r_y[ii] = Y[idx];
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        if ((float)r_y[ii] <= 0.f)
          r_dy[ii] = 0;
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          dX[idx] = r_dy[ii];
        }
      }
    }
  }
}

// Sigmoid. Assume input X is [features x batch size], column major.
// Safe to call in-place.
template <typename T>
__global__ void Sigmoid_bprop(T *dY, T *Y, uint batch_size, uint features, T *dX) {
  T r_dy[ILP];
  T r_y[ILP];
  if(is_aligned(dY) &&
     is_aligned(Y) &&
     is_aligned(dX) &&
     features % ILP == 0) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
      load_store(r_dy, dY, 0, tid);
      load_store(r_y, Y, 0, tid);
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        float grad_out = r_dy[ii];
        float out = r_y[ii];
        float grad_i = out * (1.f - out) * grad_out;
        r_dy[ii] = grad_i;
      }
      load_store(dX, r_dy, tid, 0);
    }
  } else {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          r_dy[ii] = dY[idx];
          r_y[ii] = Y[idx];
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        float grad_out = r_dy[ii];
        float out = r_y[ii];
        float grad_i = out * (1.f - out) * grad_out;
        r_dy[ii] = grad_i;
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          dX[idx] = r_dy[ii];
        }
      }
    }
  }
}
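// Math note (added for clarity): for y = sigmoid(x) = 1/(1 + exp(-x)) the derivative is
// dy/dx = y * (1 - y), so the backward pass above computes dX = dY * Y * (1 - Y) directly
// from the saved activations Y, without needing the pre-activation values.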

// Compute grid size for pointwise backward kernel.
// block_x/y is the total number of elements handled per block, not the number of threads.
void get_biasAddRelu_bprop_grid_size(
    int yfeat,
    int batch_size,
    int block_x,
    int block_y,
    int* grid_x,
    int* grid_y) {

  *grid_x = (yfeat + block_x - 1) / block_x;
  // Get number of SMs for efficient reduction.
  int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
  // can switch to occupancy calculation. use 4 below now for sm_70
  int max_blocks_y = (num_SMs * 4 + (*grid_x) - 1) / (*grid_x);
  // block_y is derived from the minimal reduction work per thread
  int nRedSplits = (batch_size + block_y - 1) / block_y;
  // Prefer increasing the per-thread reduction length over launching more blocks than needed;
  // the kernel adapts its per-thread work, so we just cap at the maximum block count.
  *grid_y = std::min(nRedSplits, max_blocks_y);
  return;
}
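// Worked example (added for clarity, numbers are illustrative only): with yfeat = 1024 and
// block_x = 32 this gives grid_x = 32. On a GPU with 80 SMs, max_blocks_y =
// (80*4 + 31) / 32 = 10. With batch_size = 4096 and block_y = 16*16 = 256, nRedSplits = 16,
// so grid_y = min(16, 10) = 10 partial sums per feature.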

// Addition done deterministically via a 2-pass approach. Each CTA writes out partial
// sum, and the last CTA in grid Y dimension accumulates partials serially and writes to result.
template <typename T, int UNROLL_FACTOR>
__global__ void biasAdd_bprop(
    T* dY,
    int features,
    int batch_size,
    volatile float* intermediate,
    int* semaphores,
    T* db) {
  // The feature that this thread is responsible for
  int f = blockIdx.x * blockDim.x + threadIdx.x;

  // Compute the span this thread is responsible for
  // For this block
  int b_chunkSize = (batch_size + gridDim.y - 1) / gridDim.y;
  int b_nStart = blockIdx.y * b_chunkSize;
  int b_nSpan = min(batch_size, b_nStart + b_chunkSize) - b_nStart;
  // For this thread
  int chunkSize = (b_chunkSize + blockDim.y - 1) / blockDim.y;
  int nStart = threadIdx.y * chunkSize + b_nStart;
  int nSpan = min(b_nStart + b_nSpan, nStart + chunkSize) - nStart;

  volatile float* out = intermediate + blockIdx.y * features;

  // Flag to trigger last reduction.
  __shared__ bool isLastBlock;
  // we know block size for now
  __shared__ float smem[BIAS_RELU_BW_NTHREADS_X*BIAS_RELU_BW_NTHREADS_Y];

  // Accumulate db in FP32 always
  float db_local = 0;
  if (f < features) {
    int nidx = 0;
    // Handle non-multiple of UNROLL_FACTOR residue
    for (; nidx < nSpan % UNROLL_FACTOR; nidx++) {
      int64_t row, col, flat_idx;
      row = f;
      col = nStart + nidx;
      flat_idx = col * features + row;
      db_local += (float)dY[flat_idx];
    }

    // Handle meat of work
    for (; (nidx + UNROLL_FACTOR - 1) < nSpan; nidx += UNROLL_FACTOR) {
      int64_t row, col, flat_idx;
      row = f;
      col = nStart + nidx;
      flat_idx = col * features + row;
#pragma unroll 4
      for (int u = 0; u < UNROLL_FACTOR; u++) {
        db_local += (float)dY[flat_idx];
        flat_idx += features;
      }
    }

    // naive block reduction on y-dim
    int linear_idx = threadIdx.y * blockDim.x + threadIdx.x;
    smem[linear_idx] = db_local;
  }
  __syncthreads();
  if (f < features) {
    if(threadIdx.y == 0) {
      for(int yidx = 1; yidx < blockDim.y; yidx++){
        db_local += smem[yidx * blockDim.x + threadIdx.x];
      }

      // block result is in db_local now for all threadIdx.y == 0
      // Write out partial result
      out[f] = db_local;
    }
  }
  __threadfence();
  __syncthreads();

  // Increment semaphore and check if this is the last CTA in the grid_y dimension.
  // Only thread (0,0) calls this
  if (threadIdx.x == 0 && threadIdx.y == 0 && f < features) {
    unsigned int sum_idx;
    sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1);
    isLastBlock = (sum_idx == (gridDim.y - 1));
  }
  __syncthreads();

  db_local = 0;
  // No block reduction for now, only thread (*,0) does the grid reduction
  if (isLastBlock && f < features) {
    if(threadIdx.y == 0) {
      for (int n = 0; n < gridDim.y; n++) {
        int row, col;
        row = f;
        col = n;
        db_local += (float)(intermediate[col * features + row]);
      }
      db[f] = (T)db_local;
    }
  }
}
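// Illustrative note (added for clarity, not in the original source): the two-pass reduction
// above works as follows. With, say, gridDim.y = 4, the four CTAs sharing a blockIdx.x each
// write a partial per-feature sum into intermediate[blockIdx.y * features + f]. Each then
// atomically increments semaphores[blockIdx.x]; atomicAdd returns the old value, so the CTA
// that observes gridDim.y - 1 == 3 knows all partials are visible (after __threadfence) and
// serially accumulates the four partials into db[f]. Accumulating in a fixed order keeps the
// bias gradient bit-wise deterministic across runs.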

// Addition done deterministically via a 2-pass approach. Each CTA writes out partial
// sum, and the last CTA in grid Y dimension accumulates partials serially and writes to result.
template <typename T, int UNROLL_FACTOR>
__global__ void biasAddRelu_bprop(
    T* Y,
    T* dY,
    int features,
    int batch_size,
    T* dX,
    volatile float* intermediate,
    int* semaphores,
    T* db) {
  // The feature that this thread is responsible for
  int f = blockIdx.x * blockDim.x + threadIdx.x;

  // Compute the span this thread is responsible for
  // For this block
  int b_chunkSize = (batch_size + gridDim.y - 1) / gridDim.y;
  int b_nStart = blockIdx.y * b_chunkSize;
  int b_nSpan = min(batch_size, b_nStart + b_chunkSize) - b_nStart;
  // For this thread
  int chunkSize = (b_chunkSize + blockDim.y - 1) / blockDim.y;
  int nStart = threadIdx.y * chunkSize + b_nStart;
  int nSpan = min(b_nStart + b_nSpan, nStart + chunkSize) - nStart;

  volatile float* out = intermediate + blockIdx.y * features;

  // Flag to trigger last reduction.
  __shared__ bool isLastBlock;
  // we know block size for now
  __shared__ float smem[BIAS_RELU_BW_NTHREADS_X*BIAS_RELU_BW_NTHREADS_Y];

  // Accumulate db in FP32 always
  float db_local = 0;
  if (f < features) {
    int nidx = 0;
    // Handle non-multiple of UNROLL_FACTOR residue
    for (; nidx < nSpan % UNROLL_FACTOR; nidx++) {
      int row, col, flat_idx;
      row = f;
      col = nStart + nidx;
      flat_idx = col * features + row;
      T y_val = Y[flat_idx];
      T dy_val = dY[flat_idx];
      T dx_val;
      if ((float)y_val > 0.f)
        dx_val = dy_val;
      else
        dx_val = 0;
      dX[flat_idx] = dx_val;
      db_local += (float)dx_val;
    }

    // Handle meat of work
    for (; (nidx + UNROLL_FACTOR - 1) < nSpan; nidx += UNROLL_FACTOR) {
      int row, col, flat_idx;
      row = f;
      col = nStart + nidx;
      flat_idx = col * features + row;
#pragma unroll 4
      for (int u = 0; u < UNROLL_FACTOR; u++) {
        T y_val = Y[flat_idx];
        T dy_val = dY[flat_idx];
        T dx_val;
        if ((float)y_val > 0.f)
          dx_val = dy_val;
        else
          dx_val = 0;
        dX[flat_idx] = dx_val;
        db_local += (float)dx_val;
        flat_idx += features;
      }
    }

    // naive block reduction on y-dim
    int linear_idx = threadIdx.y * blockDim.x + threadIdx.x;
    smem[linear_idx] = db_local;
  }
  __syncthreads();
  if (f < features) {
    if(threadIdx.y == 0) {
      for(int yidx = 1; yidx < blockDim.y; yidx++){
        db_local += smem[yidx * blockDim.x + threadIdx.x];
      }

      // block result is in db_local now for all threadIdx.y == 0
      // Write out partial result
      out[f] = db_local;
    }
  }
  __threadfence();
  __syncthreads();

  // Increment semaphore and check if this is the last CTA in the grid_y dimension.
  // Only thread (0,0) calls this
  if (threadIdx.x == 0 && threadIdx.y == 0 && f < features) {
    unsigned int sum_idx;
    sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1);
    isLastBlock = (sum_idx == (gridDim.y - 1));
  }
  __syncthreads();

  db_local = 0;
  // No block reduction for now, only thread (*,0) does the grid reduction
  if (isLastBlock && f < features) {
    if(threadIdx.y == 0) {
      for (int n = 0; n < gridDim.y; n++) {
        int row, col;
        row = f;
        col = n;
        db_local += (float)(intermediate[col * features + row]);
      }
      db[f] = (T)db_local;
    }
  }
}

// Addition done deterministically via a 2-pass approach. Each CTA writes out partial
// sum, and the last CTA in grid Y dimension accumulates partials serially and writes to result.
template <typename T, int UNROLL_FACTOR>
__global__ void biasAddRelu_bprop_aligned(
    T* Y,
    T* dY,
    int features,
    int batch_size,
    T* dX,
    volatile float* intermediate,
    int* semaphores,
    T* db) {
  // The feature that this thread is responsible for
  int f = blockIdx.x * blockDim.x + threadIdx.x;

  // Compute the span this thread is responsible for
  // For this block
  int b_chunkSize = (batch_size + gridDim.y - 1) / gridDim.y;
  int b_nStart = blockIdx.y * b_chunkSize;
  int b_nSpan = min(batch_size, b_nStart + b_chunkSize) - b_nStart;
  // For this thread
  int chunkSize = (b_chunkSize + blockDim.y - 1) / blockDim.y;
  int nStart = threadIdx.y * chunkSize + b_nStart;
  int nSpan = min(b_nStart + b_nSpan, nStart + chunkSize) - nStart;

  volatile float* out = intermediate + blockIdx.y * features;

  // Flag to trigger last reduction.
  __shared__ bool isLastBlock;

  // Accumulate db in FP32 always
  float db_local[ILP];
  T r_y[ILP];
  T r_dy[ILP];
#pragma unroll
  for(int ii = 0; ii < ILP; ii++){
    db_local[ii] = 0.f;
  }

  // f always <= features in this case
  //if (f < features) {
  int nidx = 0;

  // Handle non-multiple of UNROLL_FACTOR residue
  for (; nidx < nSpan % UNROLL_FACTOR; nidx++) {
    int row, col, flat_idx;
    row = f;
    col = nStart + nidx;
    flat_idx = col * features / ILP + row;

    load_store(r_y, Y, 0, flat_idx);
    load_store(r_dy, dY, 0, flat_idx);
#pragma unroll
    for(int ii = 0; ii < ILP; ii++){
      if ((float)r_y[ii] <= 0.f)
        r_dy[ii] = 0;
      db_local[ii] += (float)r_dy[ii];
    }
    load_store(dX, r_dy, flat_idx, 0);
  }

  // Handle meat of work
  for (; (nidx + UNROLL_FACTOR - 1) < nSpan; nidx += UNROLL_FACTOR) {
    int row, col, flat_idx;
    row = f;
    col = nStart + nidx;
    flat_idx = col * features / ILP + row; // total threads in x == features/ILP
#pragma unroll
    for (int u = 0; u < UNROLL_FACTOR; u++) {
      load_store(r_y, Y, 0, flat_idx);
      load_store(r_dy, dY, 0, flat_idx);
#pragma unroll
      for(int ii = 0; ii < ILP; ii++){
        if ((float)r_y[ii] <= 0.f)
          r_dy[ii] = 0;
        db_local[ii] += (float)r_dy[ii];
      }
      load_store(dX, r_dy, flat_idx, 0);
      flat_idx += features/ILP;
    }
  }

  // we know block size for now
  __shared__ float smem[BIAS_RELU_BW_NTHREADS_X*BIAS_RELU_BW_NTHREADS_Y*ILP];
  // naive block reduction on y-dim
  int linear_idx = threadIdx.y * blockDim.x + threadIdx.x;
  float* smem_out = smem + ILP * linear_idx;
#pragma unroll
  for(int ii = 0; ii < ILP; ii++){
    smem_out[ii] = db_local[ii]; // reuse local dy buffer
  }
  __syncthreads();
  if(threadIdx.y == 0) {
    for(int yidx = 1; yidx < blockDim.y; yidx++){
      float* smem_in = smem + ILP * (yidx * blockDim.x + threadIdx.x);
#pragma unroll
      for(int ii = 0; ii < ILP; ii++){
        db_local[ii] += smem_in[ii]; // reuse local dy buffer
      }
    }

    // block result is in db_local now for all threadIdx.y == 0
    if(gridDim.y == 1) {
#pragma unroll
      for(int ii = 0; ii < ILP; ii++){
        r_dy[ii] = db_local[ii]; // reuse local dy buffer
      }
      load_store(db, r_dy, f, 0);
      return;
    }

    // Write out partial result
    load_store(out, db_local, f, 0);
  }
  __threadfence();
  __syncthreads();

  // Increment semaphore and check if this is the last CTA in the grid_y dimension.
  // Only thread (0,0) calls this
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    unsigned int sum_idx;
    sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1);
    isLastBlock = (sum_idx == (gridDim.y - 1));
  }
  __syncthreads();

#pragma unroll
  for(int ii = 0; ii < ILP; ii++){
    db_local[ii] = 0.f;
  }
  float r_db[ILP];

  // No block reduction for now, only thread (*,0) does the grid reduction
  if (isLastBlock) {
    if(threadIdx.y == 0){
      for (int n = 0; n < gridDim.y; n++) {
        int row, col;
        row = f;
        col = n;
        load_store(r_db, intermediate, 0, col * features / ILP + row);
#pragma unroll
        for(int ii = 0; ii < ILP; ii++){
          db_local[ii] += r_db[ii];
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++){
        r_dy[ii] = db_local[ii]; // reuse local dy buffer
      }
      load_store(db, r_dy, f, 0);
    }
  }
}
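// Illustrative note (added for clarity, not in the original source): the aligned kernel above
// is launched with block_x = ILP * BIAS_RELU_BW_NTHREADS_X elements per block, so each thread
// in x owns one ILP-wide group of features and flat_idx indexes ILP-element vectors rather
// than scalars. For example, with features = 1024 and ILP = 4 there are 256 vectors per
// column, and thread f of column col touches elements [4*f .. 4*f+3] of that column via
// flat_idx = col * 256 + f.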

// Lists where the num_layers-1 intermediate Y buffers start in reserved space on fprop,
// starting at offset 0. The last Y value is, of course, stored in the user provided output buffer.
void get_y_offsets(
    int batch_size,
    int num_layers,
    const int* output_features,
    int* y_start_offsets) {
  y_start_offsets[0] = 0;
  for (int i = 1; i < num_layers; i++) {
    y_start_offsets[i] = y_start_offsets[i - 1] + batch_size * output_features[i - 1];
  }
}
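// Worked example (added for clarity, numbers are illustrative only): with batch_size = 32
// and output_features = {128, 64, 10}, the offsets come out as {0, 32*128, 32*128 + 32*64}
// = {0, 4096, 6144} elements; the final layer's output (10 features) goes to the user's
// output buffer rather than reserved space.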

// Returns the reserved space (in elements) needed for the MLP
size_t get_mlp_reserved_space(int64_t batch_size, int num_layers, const int* output_features) {
  size_t res_space = 0;
  // Need to store output of every intermediate MLP - size equal to output_features[i] * batch_size
  // for all 'i' in [0, num_layers-1)
  for (int l = 0; l < num_layers; l++) {
    res_space += output_features[l] * batch_size;
  }
  return res_space;
}

// Returns the size of all fprop activations combined
size_t get_all_activations_size(int64_t batch_size, int num_layers, const int* output_features) {
  size_t acts_size = 0;
  for (int l = 0; l < num_layers; l++) {
    acts_size += output_features[l] * batch_size;
  }
  return acts_size;
}
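// Worked example (added for clarity, numbers are illustrative only): for batch_size = 32
// and output_features = {128, 64, 10}, both helpers above return
// 32 * (128 + 64 + 10) = 6464 elements.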

#if 0
// Returns the work space (in elements) needed for the MLP bprop.
size_t get_mlp_bp_workspace (int batch_size, int num_layers, const int* output_features) {
  /*
  Workspace is partitioned as
  DY_GEMMs : DX_GEMMs
  */
  size_t work_space = 0;

  // Store each intermediate dY explicitly. Need 2 dYs per MLP layer (one for o/p
  // of biasReLU_bp and one for o/p of dgrad GEMM).
  work_space += 2*get_all_activations_size(batch_size, num_layers, output_features);

  return work_space;
}
#endif

// Scratch space needed for reductions, in number of elements
size_t get_reduction_scratch_space(int batch_size, int num_layers, const int* output_features) {
  size_t max_scratch_space = 0;
  // Loop over all layers to see which one needs the max scratch space
  for (int l = 0; l < num_layers; l++) {
    // need to find max(aligned, not_aligned)
    int tmp, res0, res1;

    int block_x = BIAS_RELU_BW_NTHREADS_X;
    int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
    get_biasAddRelu_bprop_grid_size(
        output_features[l], batch_size, block_x, block_y, &tmp, &res0);

    block_x = ILP * BIAS_RELU_BW_NTHREADS_X;
    get_biasAddRelu_bprop_grid_size(
        output_features[l], batch_size, block_x, block_y, &tmp, &res1);

    max_scratch_space = std::max(max_scratch_space, (size_t)(output_features[l] * res0));
    max_scratch_space = std::max(max_scratch_space, (size_t)(output_features[l] * res1));
  }

  return max_scratch_space;
}

// Buffer for semaphores
size_t get_semaphores_size(int num_layers, const int* output_features) {
  // Upper bound on semaphores is one per feature for the layer
  // with the most features.
  int max_features = 0;
  for (int l = 0; l < num_layers; l++) {
    max_features = std::max(max_features, output_features[l]);
  }
  return (size_t)max_features;
}

// Returns the work space (in bytes) needed for the MLP bprop.
template <typename T>
size_t get_mlp_bp_workspace_in_bytes(int batch_size, int num_layers, const int* output_features) {
  size_t work_space = 0;

  // Store each intermediate dY explicitly. Need 2 dYs per MLP layer (one for o/p
  // of biasReLU_bp and one for o/p of dgrad GEMM).
  work_space += 2 * get_all_activations_size(batch_size, num_layers, output_features) * sizeof(T);
  work_space +=
      get_reduction_scratch_space(batch_size, num_layers, output_features) * sizeof(float);
  work_space += get_semaphores_size(num_layers, output_features) * sizeof(int);

  return work_space;
}

// Returns pointers to each segment of the workspace
template <typename T>
void partition_mlp_bp_workspace(
    int batch_size,
    int num_layers,
    const int* output_features,
    void* work_space,
    T** dy_gemms,
    T** dx_gemms,
    float** db_scratch,
    int** semaphores) {
  /*
  Workspace is partitioned as
  DY_GEMMs : DX_GEMMs : DB_SCRATCH : SEMAPHORES
  */
  // Start address where dy_gemm tensors are stored
  *dy_gemms = reinterpret_cast<T*>(work_space);
  // Start address where dx_gemm tensors are stored
  *dx_gemms = *dy_gemms + get_all_activations_size(batch_size, num_layers, output_features);
  // Start address where db intermediate tensors are stored
  *db_scratch = reinterpret_cast<float*>(
      *dx_gemms + get_all_activations_size(batch_size, num_layers, output_features));
  // Start address of semaphores
  *semaphores = reinterpret_cast<int*>(
      *db_scratch + get_reduction_scratch_space(batch_size, num_layers, output_features));

  return;
}
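// Worked example (added for clarity, numbers are illustrative only): for a float MLP with
// batch_size = 32 and output_features = {128, 64, 10}, the bprop workspace holds
// 2 * 6464 floats of dY/dX GEMM buffers, plus get_reduction_scratch_space(...) floats of
// partial bias-gradient sums, plus 128 ints of semaphores (one per feature of the widest
// layer), laid out in that order by partition_mlp_bp_workspace.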

// Does a simple MLP fprop (GEMM+bias+ReLU).
// Can handle num_layers number of layers, each with its own shape. Output of layer i is assumed
// to be input of layer i+1. output_features, WPtr and BPtr are arrays of length num_layers, and
// must be in the same order i.e. WPtr[i] and BPtr[i] are respectively the weight and bias of layer
// 'i'.
template <typename T>
int mlp_fp(
    T* X,
    int input_features,
    int batch_size,
    T** WPtr,
    int num_layers,
    int* output_features,
    T** BPtr,
    T* Y,
    T* reserved_space,
    int use_bias,
    int activation,
    void* lt_workspace) {
  T *weight, *input, *output, *bias;
  T *reserved_space_x, *reserved_space_y;
  reserved_space_x = NULL;
  reserved_space_y = reserved_space;

  // Get cublas handle from Pytorch
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  // Get the stream from cublas handle to reuse for biasReLU kernel.
  cudaStream_t stream;
  cublasGetStream(handle, &stream);

  for (int layer = 0; layer < num_layers; layer++) {
    weight = WPtr[layer];
    input = (layer == 0) ? X : reserved_space_x;
    output = (layer == num_layers - 1) ? Y : reserved_space_y;
    if (use_bias) {
      bias = BPtr[layer];
    }
    int ifeat = (layer == 0) ? input_features : output_features[layer - 1];
    int ofeat = output_features[layer];

    float one = 1.f;
    float zero = 0.f;

    // try with cublaslt first for supported case with valid handle
    int cublaslt_status = 1;
    if(activation < 1){
      cublaslt_status = mlp_gemm_lt(
          //ltHandle,
          (cublasLtHandle_t)handle,
          CUBLAS_OP_T,
          CUBLAS_OP_N,
          ofeat,
          batch_size,
          ifeat,
          &one,
          weight,
          ifeat,
          input,
          ifeat,
          &zero,
          output,
          ofeat,
          lt_workspace,
          1 << 22,
          stream,
          use_bias == 1,
          activation == 1,
          bias);
    }

    // if cublaslt failed or was not executed, fall back to cublas
    if (cublaslt_status != 0) {
      cublasStatus_t cublas_status;
      // Call GEMM: fprop is Y = W'X
      cublas_status = mlp_gemm(
          handle,
          CUBLAS_OP_T,
          CUBLAS_OP_N,
          ofeat,
          batch_size,
          ifeat,
          &one,
          weight,
          ifeat,
          input,
          ifeat,
          &zero,
          output,
          ofeat);

      if (cublas_status != CUBLAS_STATUS_SUCCESS) {
        printf("GEMM fprop failed with %d\n", cublas_status);
        return 1;
      }

      const uint &input_size = ofeat;
      int num_blocks = 0;
      int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
      // Call biasReLU
      if(use_bias == 1) {
        if (activation == 0) { // no activation
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAdd_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          biasAdd_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
        } else if (activation == 1) { // relu
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAddRelu_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          biasAddRelu_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
        } else if (activation == 2) { // sigmoid
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAdd_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          biasAdd_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          Sigmoid_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
        }
      } else {
        // don't need to do anything in case of no activation and no bias
        if (activation == 1) { // relu
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Relu_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          Relu_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
        } else if (activation == 2) { // sigmoid
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          Sigmoid_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
        }
      }
    }
    // Set current output as next layer input
    reserved_space_x = reserved_space_y;
    // Set next layer output
    reserved_space_y += ofeat * batch_size;
  }

  return 0;
}
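#if 0
// Illustrative sketch (added for clarity, not part of the original file): a hypothetical
// host-side driver for mlp_fp<float> on a 2-layer MLP (128 -> 64 -> 10). All device pointers
// (X_d, Y_d, W_d[], B_d[]) are assumed to be allocated and filled elsewhere; weights are
// column-major [in_features x out_features], inputs [in_features x batch].
static int example_mlp_fp_usage(float* X_d, float* Y_d, float** W_d, float** B_d) {
  int output_features[2] = {64, 10};
  int batch_size = 32, input_features = 128, num_layers = 2;

  // Reserved space holds the per-layer activations: 32*64 + 32*10 = 2368 floats here.
  size_t reserved_elems = get_mlp_reserved_space(batch_size, num_layers, output_features);
  float* reserved_d = nullptr;
  cudaMalloc(&reserved_d, reserved_elems * sizeof(float));

  // Scratch for the cuBLASLt path; mlp_fp passes 1 << 22 bytes as the workspace size.
  void* lt_ws_d = nullptr;
  cudaMalloc(&lt_ws_d, 1 << 22);

  int rc = mlp_fp<float>(X_d, input_features, batch_size, W_d, num_layers, output_features,
                         B_d, Y_d, reserved_d, /*use_bias=*/1, /*activation=*/1, lt_ws_d);

  cudaFree(lt_ws_d);
  cudaFree(reserved_d);
  return rc;
}
#endif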

// Does a simple MLP bprop (GEMM+bias+ReLU).
// Needs reserved space to come back exactly as it was populated in fprop.
// Does dgrad and wgrad sequentially.
template <typename T>
int mlp_bp(
    T* X,
    T* Y,
    int input_features,
    int batch_size,
    T** WPtr,
    int num_layers,
    int* output_features,
    T* dY,
    T* reserved_space,
    T* work_space,
    T* dX,
    T** dwPtr,
    T** dbPtr,
    bool requires_grad,
    int use_bias,
    int activation) {
  T* weight;
  T *dweight, *dx, *dy, *dbias;
  T *x, *y;

  // Where the dx of the biasReLU (== dy of gemm) is stored. Can be thrown away
  // after bp call.
  T* dy_gemm_base;
  // Where the dx after GEMM is stored.
  T* dx_gemm_base;
  // Where partial reduction results are stored.
  float* db_scratch;
  // Semaphores for reduction.
  int* semaphores;

  partition_mlp_bp_workspace<T>(
      batch_size,
      num_layers,
      output_features,
      work_space,
      &dy_gemm_base,
      &dx_gemm_base,
      &db_scratch,
      &semaphores);

  size_t semaphore_size = get_semaphores_size(num_layers, output_features) * sizeof(int);

  // Get cublas handle from Pytorch
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  // Get the stream from cublas handle to reuse for biasReLU kernel.
  cudaStream_t stream;
  cublasGetStream(handle, &stream);

  int* y_offsets = (int*)malloc(num_layers * sizeof(int));
  get_y_offsets(batch_size, num_layers, output_features, y_offsets);

  for (int layer = num_layers - 1; layer >= 0; layer--) {
    weight = WPtr[layer];
    dweight = dwPtr[layer];

    // x is read from reserved space
    x = (layer == 0) ? X : reserved_space + y_offsets[layer - 1];
    // dx is written in workspace for all but layer==0
    dx = (layer == 0) ? dX : dx_gemm_base + y_offsets[layer - 1];

    // y is read from reserved space
    y = (layer == num_layers - 1) ? Y : reserved_space + y_offsets[layer];
    // dx from layer+1
    dy = (layer == num_layers - 1) ? dY : dx_gemm_base + y_offsets[layer];
    // dy_gemm is written to and read immediately
    T* dy_gemm = dy_gemm_base + y_offsets[layer];

    dbias = dbPtr[layer];
    int xfeat = (layer == 0) ? input_features : output_features[layer - 1];
    int yfeat = output_features[layer];

    float one = 1.f;
    float zero = 0.f;

    if (use_bias == 1) {
      if (activation == 0) { // no activation
        // bgrad
        dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y);
        int grid_x, grid_y;
        cudaMemsetAsync(semaphores, 0, semaphore_size, stream);

        int block_x = BIAS_RELU_BW_NTHREADS_X;
        int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
        get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
        dim3 grid(grid_x, grid_y);
        biasAdd_bprop<T, 4><<<grid, block, 0, stream>>>(
            dy, yfeat, batch_size, db_scratch, semaphores, dbias);
        // bypass dgrad through reset pointer
        dy_gemm = dy;
      } else if (activation == 1) { // relu
        dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y);
        int grid_x, grid_y;
        cudaMemsetAsync(semaphores, 0, semaphore_size, stream);

        if(yfeat % (ILP * BIAS_RELU_BW_NTHREADS_X) == 0 &&
           is_aligned(y) &&
           is_aligned(dy) &&
           is_aligned(dy_gemm) &&
           is_aligned(dbias)){
          int block_x = ILP * BIAS_RELU_BW_NTHREADS_X;
          int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
          get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
          dim3 grid(grid_x, grid_y);
          biasAddRelu_bprop_aligned<T, 4><<<grid, block, 0, stream>>>(
              y, dy, yfeat, batch_size, dy_gemm, db_scratch, semaphores, dbias);
        } else {
          int block_x = BIAS_RELU_BW_NTHREADS_X;
          int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
          get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
          dim3 grid(grid_x, grid_y);
          biasAddRelu_bprop<T, 4><<<grid, block, 0, stream>>>(
              y, dy, yfeat, batch_size, dy_gemm, db_scratch, semaphores, dbias);
        }
      } else if (activation == 2) { // sigmoid
        // activation backward
        int num_blocks = 0;
        int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_bprop<T>, BIAS_RELU_FW_NTHREADS, 0);
        Sigmoid_bprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(dy, y, batch_size, yfeat, dy_gemm);

        // bgrad, from dy_gemm
        dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y);
        int grid_x, grid_y;
        cudaMemsetAsync(semaphores, 0, semaphore_size, stream);

        int block_x = BIAS_RELU_BW_NTHREADS_X;
        int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
        get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
        dim3 grid(grid_x, grid_y);
        biasAdd_bprop<T, 4><<<grid, block, 0, stream>>>(
            dy_gemm, yfeat, batch_size, db_scratch, semaphores, dbias);
      }
    } else { // no bias below
      if (activation == 0) {
        // bypass dgrad through reset pointer
        dy_gemm = dy;
      } else if (activation == 1) { // relu
        int num_blocks = 0;
        int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Relu_bprop<T>, BIAS_RELU_FW_NTHREADS, 0);
        Relu_bprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(dy, y, batch_size, yfeat, dy_gemm);
      } else if (activation == 2) { // sigmoid
        int num_blocks = 0;
        int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_bprop<T>, BIAS_RELU_FW_NTHREADS, 0);
        Sigmoid_bprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(dy, y, batch_size, yfeat, dy_gemm);
      }
    }

    cublasStatus_t cublas_status;
    // Call GEMM dgrad
    if (layer > 0 || requires_grad == 1) {
      cublas_status = mlp_gemm(
          handle,
          CUBLAS_OP_N,
          CUBLAS_OP_N,
          xfeat,
          batch_size,
          yfeat,
          &one,
          weight,
          xfeat,
          dy_gemm,
          yfeat,
          &zero,
          dx,
          xfeat);

      if (cublas_status != CUBLAS_STATUS_SUCCESS) {
        printf("GEMM dgrad failed with %d\n", cublas_status);
        return 1;
      }
    }

    // Call GEMM wgrad
    cublas_status = mlp_gemm(
        handle,
        CUBLAS_OP_N,
        CUBLAS_OP_T,
        xfeat,
        yfeat,
        batch_size,
        &one,
        x,
        xfeat,
        dy_gemm,
        yfeat,
        &zero,
        dweight,
        xfeat);

    if (cublas_status != CUBLAS_STATUS_SUCCESS) {
      printf("GEMM wgrad failed with %d\n", cublas_status);
      return 1;
    }
  }

  return 0;
}
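// Illustrative note (added for clarity, not in the original source): with everything stored
// column-major as [features x batch], the two GEMMs above compute
//   dgrad: dX = W * dY_gemm    ([xfeat x yfeat] * [yfeat x batch] -> [xfeat x batch])
//   wgrad: dW = X * dY_gemm^T  ([xfeat x batch] * [batch x yfeat] -> [xfeat x yfeat])
// which mirrors the forward pass Y = W^T * X used in mlp_fp.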

// Instantiate for floating point types
template int mlp_fp<float>(
    float* X,
    int input_features,
    int batch_size,
    float** WPtr,
    int num_layers,
    int* output_features,
    float** BPtr,
    float* Y,
    float* reserved_space,
    int use_bias,
    int activation,
    void* lt_workspace);

template int mlp_bp<float>(
    float* X,
    float* Y,
    int input_features,
    int batch_size,
    float** WPtr,
    int num_layers,
    int* output_features,
    float* dY,
    float* reserved_space,
    float* work_space,
    float* dX,
    float** dwPtr,
    float** dbPtr,
    bool requires_grad,
    int use_bias,
    int activation);

template int mlp_fp<at::Half>(
    at::Half* X,
    int input_features,
    int batch_size,
    at::Half** WPtr,
    int num_layers,
    int* output_features,
    at::Half** BPtr,
    at::Half* Y,
    at::Half* reserved_space,
    int use_bias,
    int activation,
    void* lt_workspace);

template int mlp_bp<at::Half>(
    at::Half* X,
    at::Half* Y,
    int input_features,
    int batch_size,
    at::Half** WPtr,
    int num_layers,
    int* output_features,
    at::Half* dY,
    at::Half* reserved_space,
    at::Half* work_space,
    at::Half* dX,
    at::Half** dwPtr,
    at::Half** dbPtr,
    bool requires_grad,
    int use_bias,
    int activation);

template int mlp_fp<double>(
    double* X,
    int input_features,
    int batch_size,
    double** WPtr,
    int num_layers,
    int* output_features,
    double** BPtr,
    double* Y,
    double* reserved_space,
    int use_bias,
    int activation,
    void* lt_workspace);

template int mlp_bp<double>(
    double* X,
    double* Y,
    int input_features,
    int batch_size,
    double** WPtr,
    int num_layers,
    int* output_features,
    double* dY,
    double* reserved_space,
    double* work_space,
    double* dX,
    double** dwPtr,
    double** dbPtr,
    bool requires_grad,
    int use_bias,
    int activation);

template size_t get_mlp_bp_workspace_in_bytes<float>(
    int batch_size,
    int num_layers,
    const int* output_features);
template size_t get_mlp_bp_workspace_in_bytes<at::Half>(
    int batch_size,
    int num_layers,
    const int* output_features);
template size_t get_mlp_bp_workspace_in_bytes<double>(
    int batch_size,
    int num_layers,
    const int* output_features);