#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <torch/torch.h>

/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>

// includes cublaslt
#include <cublasLt.h>

// constants for fused bias+relu kernel
#define BIAS_RELU_FW_NTHREADS 128 // forward number of threads per block
#define BIAS_RELU_BW_NTHREADS_X 32 // backward number of threads in feature dim
#define BIAS_RELU_BW_NTHREADS_Y 16 // backward number of threads in batch dim
#define BIAS_RELU_RED_PER_THREAD 16 // backward minimal reduction length per thread

// move to a header later on
#define ILP 4
template<typename T>
__host__ __device__ __forceinline__ bool is_aligned(T* p){
  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
template<typename T>
__device__ __forceinline__ void load_store(T* dst, volatile T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
template<typename T>
__device__ __forceinline__ void load_store(volatile T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

// Keep ReLU in float only. When using half, cast to float before calling.
__device__ __inline__ float relu(float a) {
  float retf = max(a, 0.f);
  return (retf);
}

// Keep Sigmoid in float only. When using half, cast to float before calling.
__device__ __inline__ float sigmoid(float a) {
  float retf = 1.f / (1.f + expf(-a));
  return (retf);
}

// FP64 Wrapper around cublas GEMMEx
cublasStatus_t mlp_gemm(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    float* alpha,
    const double* A,
    int lda,
    const double* B,
    int ldb,
    const float* beta,
    double* C,
    int ldc) {
  // With computeType CUDA_R_64F, cublasGemmEx expects double scaling factors,
  // so widen the float alpha/beta before the call.
  double alpha_d = static_cast<double>(*alpha);
  double beta_d = static_cast<double>(*beta);
  return cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      &alpha_d,
      A,
      CUDA_R_64F,
      lda,
      B,
      CUDA_R_64F,
      ldb,
      &beta_d,
      C,
      CUDA_R_64F,
      ldc,
      CUDA_R_64F,
      CUBLAS_GEMM_DEFAULT);
}

// FP32 Wrapper around cublas GEMMEx
cublasStatus_t mlp_gemm(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    float* alpha,
    const float* A,
    int lda,
    const float* B,
    int ldb,
    const float* beta,
    float* C,
    int ldc) {
  return cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      alpha,
      A,
      CUDA_R_32F,
      lda,
      B,
      CUDA_R_32F,
      ldb,
      beta,
      C,
      CUDA_R_32F,
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT);
}

// FP16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t mlp_gemm(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    float* alpha,
    const at::Half* A,
    int lda,
    const at::Half* B,
    int ldb,
    float* beta,
    at::Half* C,
    int ldc) {
  return cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      alpha,
      A,
      CUDA_R_16F,
      lda,
      B,
      CUDA_R_16F,
      ldb,
      beta,
      C,
      CUDA_R_16F,
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

int mlp_gemm_lt(
    cublasLtHandle_t ltHandle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    float *alpha, /* host pointer */
    const at::Half* A,
    int lda,
    const at::Half* B,
    int ldb,
    float *beta, /* host pointer */
    at::Half* C,
    int ldc,
    void *workspace,
    size_t workspaceSize,
    cudaStream_t stream,
    bool use_bias,
    bool use_relu,
    const void* bias) {
  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;

  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};

  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;

  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  // for details about defaults; here we just set the transforms for
  // A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  if (use_bias) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
    if (status != CUBLAS_STATUS_SUCCESS) {
      goto CLEANUP;
    }
    if (use_relu) {
      epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
    } else {
      epilogue = CUBLASLT_EPILOGUE_BIAS;
    }
  } else {
    if (use_relu) {
      epilogue = CUBLASLT_EPILOGUE_RELU;
    }
  }

  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) {
    goto CLEANUP;
  }

  // Create matrix descriptors. Not setting any extra attributes.
  status = cublasLtMatrixLayoutInit(
      &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(
      &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create preference handle; In general, extra attributes can be
  // used here to disable tensor ops or to make sure algo selected
  // will work with badly aligned A, B, C. However, for simplicity
  // here we assume A,B,C are always well aligned (e.g., directly
  // come from cudaMalloc)
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(
      &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // We just need the best available heuristic to try and run matmul.
  // There is no guarantee that this will work. For example, if A is
  // badly aligned, you can request more (e.g. 32) algos and try to
  // run them one by one until something works.
  status = cublasLtMatmulAlgoGetHeuristic(
      ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }
  status = cublasLtMatmul(ltHandle,
                          &operationDesc,
                          alpha,
                          A,
                          &Adesc,
                          B,
                          &Bdesc,
                          beta,
                          C,
                          &Cdesc,
                          C,
                          &Cdesc,
                          &heuristicResult.algo,
                          workspace,
                          workspaceSize,
                          stream);

CLEANUP:
  // Descriptors are no longer needed as all GPU work was already
  // enqueued.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}

int mlp_gemm_lt(
    cublasLtHandle_t ltHandle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    float *alpha, /* host pointer */
    const double* A,
    int lda,
    const double* B,
    int ldb,
    float *beta, /* host pointer */
    double* C,
    int ldc,
    void *workspace,
    size_t workspaceSize,
    cudaStream_t stream,
    bool use_bias,
    bool use_relu,
    const void* bias) {
  return 1;
}

int mlp_gemm_lt(
    cublasLtHandle_t ltHandle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    float *alpha, /* host pointer */
    const float *A,
    int lda,
    const float *B,
    int ldb,
    float *beta, /* host pointer */
    float *C,
    int ldc,
    void *workspace,
    size_t workspaceSize,
    cudaStream_t stream,
    bool use_bias,
    bool use_relu,
    const void* bias) {
  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;

  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};

  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;

  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
  // for details about defaults; here we just set the transforms for
  // A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  if (use_bias) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
    if (status != CUBLAS_STATUS_SUCCESS) {
      goto CLEANUP;
    }
    if (use_relu) {
      epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
    } else {
      epilogue = CUBLASLT_EPILOGUE_BIAS;
    }
  } else {
    if (use_relu) {
      epilogue = CUBLASLT_EPILOGUE_RELU;
    }
  }

  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) {
    goto CLEANUP;
  }

  // Create matrix descriptors. Not setting any extra attributes.
  status = cublasLtMatrixLayoutInit(
      &Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(
      &Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create preference handle; In general, extra attributes can be
  // used here to disable tensor ops or to make sure algo selected
  // will work with badly aligned A, B, C. However, for simplicity
  // here we assume A,B,C are always well aligned (e.g., directly
  // come from cudaMalloc)
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(
      &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // We just need the best available heuristic to try and run matmul.
  // There is no guarantee that this will work. For example, if A is
  // badly aligned, you can request more (e.g. 32) algos and try to
  // run them one by one until something works.
  status = cublasLtMatmulAlgoGetHeuristic(
      ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }

  status = cublasLtMatmul(ltHandle,
                          &operationDesc,
                          alpha,
                          A,
                          &Adesc,
                          B,
                          &Bdesc,
                          beta,
                          C,
                          &Cdesc,
                          C,
                          &Cdesc,
                          &heuristicResult.algo,
                          workspace,
                          workspaceSize,
                          stream);

CLEANUP:
  // Descriptors are no longer needed as all GPU work was already
  // enqueued.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}


// Bias ADD. Assume input X is [features x batch size], column major.
// Bias is one 'features' long vector, with implicit broadcast.
template <typename T>
__global__ void biasAdd_fprop(T *X, T *b, uint batch_size, uint features) {
  T r_x[ILP];
  T r_b[ILP];
  if(is_aligned(X) && is_aligned(b) && features % ILP == 0) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
      int row = tid % (features / ILP);
      load_store(r_x, X, 0, tid);
      load_store(r_b, b, 0, row);
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);
        r_x[ii] = bias_sum;
      }
      load_store(X, r_x, tid, 0);
    }
  } else {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          int row = idx % features;
          r_x[ii] = X[idx];
          r_b[ii] = b[row];
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);
        r_x[ii] = bias_sum;
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          X[idx] = r_x[ii];
        }
      }
    }
  }
}

// Bias ADD + ReLU. Assume input X is [features x batch size], column major.
// Activation supports fused ReLU. Safe to call in-place.
template <typename T>
__global__ void biasAddRelu_fprop(T *X, T *b, uint batch_size, uint features) {
  T r_x[ILP];
  T r_b[ILP];
  if(is_aligned(X) && is_aligned(b) && features % ILP == 0) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
      int row = tid % (features / ILP);
      load_store(r_x, X, 0, tid);
      load_store(r_b, b, 0, row);
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);
        r_x[ii] = relu(bias_sum);
      }
      load_store(X, r_x, tid, 0);
    }
  } else {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          int row = idx % features;
          r_x[ii] = X[idx];
          r_b[ii] = b[row];
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);
        r_x[ii] = relu(bias_sum);
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          X[idx] = r_x[ii];
        }
      }
    }
  }
}

// ReLU. Assume input X is [features x batch size], column major.
// Safe to call in-place.
template <typename T>
__global__ void Relu_fprop(T *X, uint batch_size, uint features) {
  T r_x[ILP];
  if(is_aligned(X) && features % ILP == 0) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
      load_store(r_x, X, 0, tid);
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        r_x[ii] = relu(static_cast<float>(r_x[ii]));
      }
      load_store(X, r_x, tid, 0);
    }
  } else {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          r_x[ii] = X[idx];
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        r_x[ii] = relu(static_cast<float>(r_x[ii]));
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          X[idx] = r_x[ii];
        }
      }
    }
  }
}

// Sigmoid. Assume input X is [features x batch size], column major.
// Safe to call in-place.
template <typename T>
__global__ void Sigmoid_fprop(T *X, uint batch_size, uint features) {
  T r_x[ILP];
  if(is_aligned(X) && features % ILP == 0) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
      load_store(r_x, X, 0, tid);
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        r_x[ii] = sigmoid(static_cast<float>(r_x[ii]));
      }
      load_store(X, r_x, tid, 0);
    }
  } else {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          r_x[ii] = X[idx];
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        r_x[ii] = sigmoid(static_cast<float>(r_x[ii]));
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          X[idx] = r_x[ii];
        }
      }
    }
  }
}

// ReLU. Assume input X is [features x batch size], column major.
// Safe to call in-place.
template <typename T>
__global__ void Relu_bprop(T *dY, T *Y, uint batch_size, uint features, T *dX) {
  T r_dy[ILP];
  T r_y[ILP];
  if(is_aligned(dY) &&
     is_aligned(Y) &&
     is_aligned(dX) &&
     features % ILP == 0) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
      load_store(r_dy, dY, 0, tid);
      load_store(r_y, Y, 0, tid);
#pragma unroll
      for(int ii=0;ii<ILP;ii++){
        if ((float)r_y[ii] <= 0.f)
          r_dy[ii] = 0;
      }
      load_store(dX, r_dy, tid, 0);
    }
  } else {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          r_dy[ii] = dY[idx];
          r_y[ii] = Y[idx];
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        if ((float)r_y[ii] <= 0.f)
          r_dy[ii] = 0;
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          dX[idx] = r_dy[ii];
        }
      }
    }
  }
}

// Sigmoid. Assume input X is [features x batch size], column major.
// Safe to call in-place.
template <typename T>
__global__ void Sigmoid_bprop(T *dY, T *Y, uint batch_size, uint features, T *dX) {
  T r_dy[ILP];
  T r_y[ILP];
  if(is_aligned(dY) &&
     is_aligned(Y) &&
     is_aligned(dX) &&
     features % ILP == 0) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
      load_store(r_dy, dY, 0, tid);
      load_store(r_y, Y, 0, tid);
#pragma unroll
      for(int ii=0;ii<ILP;ii++){
        float grad_out = r_dy[ii];
        float out = r_y[ii];
        float grad_i = out * (1.f - out) * grad_out;
        r_dy[ii] = grad_i;
      }
      load_store(dX, r_dy, tid, 0);
    }
  } else {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          r_dy[ii] = dY[idx];
          r_y[ii] = Y[idx];
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        float grad_out = r_dy[ii];
        float out = r_y[ii];
        float grad_i = out * (1.f - out) * grad_out;
        r_dy[ii] = grad_i;
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++) {
        int idx = tid + ii * blockDim.x * gridDim.x;
        if(idx < features * batch_size) {
          dX[idx] = r_dy[ii];
        }
      }
    }
  }
}

// Compute grid size for pointwise backward kernel.
// block_x/y is the total number of elements handled per block, not the number of threads.
void get_biasAddRelu_bprop_grid_size(
    int yfeat,
    int batch_size,
    int block_x,
    int block_y,
    int* grid_x,
    int* grid_y) {

  *grid_x = (yfeat + block_x - 1) / block_x;
  // Get number of SMs for efficient reduction.
  int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
  // Could switch to an occupancy calculation; use 4 blocks per SM for now (tuned for sm_70).
  int max_blocks_y = (num_SMs * 4 + (*grid_x) - 1) / (*grid_x);
  // block_y is derived from the minimal reduction work per thread.
  int nRedSplits = (batch_size + block_y - 1) / block_y;
  // Prefer increasing the number of elements reduced per thread over launching
  // more blocks than needed; the kernel adjusts its work, so cap the launch at max_blocks_y.
  *grid_y = std::min(nRedSplits, max_blocks_y);
  return;
}

// Addition done deterministically via a 2-pass approach. Each CTA writes out partial
// sum, and the last CTA in grid Y dimension accumulates partials serially and writes to result.
template <typename T, int UNROLL_FACTOR>
__global__ void biasAdd_bprop(
    T* dY,
    int features,
    int batch_size,
    volatile float* intermediate,
    int* semaphores,
    T* db) {
  // The feature that this thread is responsible for
  int f = blockIdx.x * blockDim.x + threadIdx.x;

  // Compute the span this thread is responsible for
  // For this block
  int b_chunkSize = (batch_size + gridDim.y - 1) / gridDim.y;
  int b_nStart = blockIdx.y * b_chunkSize;
  int b_nSpan = min(batch_size, b_nStart + b_chunkSize) - b_nStart;
  // For this thread
  int chunkSize = (b_chunkSize + blockDim.y - 1) / blockDim.y;
  int nStart = threadIdx.y * chunkSize + b_nStart;
  int nSpan = min(b_nStart + b_nSpan, nStart + chunkSize) - nStart;

  volatile float* out = intermediate + blockIdx.y * features;

  // Flag to trigger last reduction.
  __shared__ bool isLastBlock;
  // we know block size for now
  __shared__ float smem[BIAS_RELU_BW_NTHREADS_X*BIAS_RELU_BW_NTHREADS_Y];

  // Accumulate db in FP32 always
  float db_local = 0;
  if (f < features) {
    int nidx = 0;
    // Handle non-multiple of UNROLL_FACTOR residue
    for (; nidx < nSpan % UNROLL_FACTOR; nidx++) {
      int64_t row, col, flat_idx;
      row = f;
      col = nStart + nidx;
      flat_idx = col * features + row;
      db_local += (float)dY[flat_idx];
    }

    // Handle meat of work
    for (; (nidx + UNROLL_FACTOR - 1) < nSpan; nidx += UNROLL_FACTOR) {
      int64_t row, col, flat_idx;
      row = f;
      col = nStart + nidx;
      flat_idx = col * features + row;
#pragma unroll 4
      for (int u = 0; u < UNROLL_FACTOR; u++) {
        db_local += (float)dY[flat_idx];
        flat_idx += features;
      }
    }

    // naive block reduction on y-dim
    int linear_idx = threadIdx.y * blockDim.x + threadIdx.x;
    smem[linear_idx] = db_local;
  }
  __syncthreads();
  if (f < features) {
    if(threadIdx.y == 0) {
      for(int yidx = 1; yidx < blockDim.y; yidx++){
        db_local += smem[yidx * blockDim.x + threadIdx.x];
      }

      // block result is in db_local now for all threadIdx.y == 0
      // Write out partial result
      out[f] = db_local;
    }
  }
  __threadfence();
  __syncthreads();

  // Increment semaphore and check if this is the last CTA in the grid_y dimension.
  // Only thread (0,0) calls this
  if (threadIdx.x == 0 && threadIdx.y == 0 && f < features) {
    unsigned int sum_idx;
    sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1);
    isLastBlock = (sum_idx == (gridDim.y - 1));
  }
  __syncthreads();

  db_local = 0;
  // No block reduction for now; only threads with threadIdx.y == 0 do the grid reduction
  if (isLastBlock && f < features) {
    if(threadIdx.y == 0) {
      for (int n = 0; n < gridDim.y; n++) {
        int row, col;
        row = f;
        col = n;
        db_local += (float)(intermediate[col * features + row]);
      }
      db[f] = (T)db_local;
    }
  }
}

// Addition done deterministically via a 2-pass approach. Each CTA writes out partial
// sum, and the last CTA in grid Y dimension accumulates partials serially and writes to result.
template <typename T, int UNROLL_FACTOR>
__global__ void biasAddRelu_bprop(
    T* Y,
    T* dY,
    int features,
    int batch_size,
    T* dX,
    volatile float* intermediate,
    int* semaphores,
    T* db) {
  // The feature that this thread is responsible for
  int f = blockIdx.x * blockDim.x + threadIdx.x;

  // Compute the span this thread is responsible for
  // For this block
  int b_chunkSize = (batch_size + gridDim.y - 1) / gridDim.y;
  int b_nStart = blockIdx.y * b_chunkSize;
  int b_nSpan = min(batch_size, b_nStart + b_chunkSize) - b_nStart;
  // For this thread
  int chunkSize = (b_chunkSize + blockDim.y - 1) / blockDim.y;
  int nStart = threadIdx.y * chunkSize + b_nStart;
  int nSpan = min(b_nStart + b_nSpan, nStart + chunkSize) - nStart;

  volatile float* out = intermediate + blockIdx.y * features;

  // Flag to trigger last reduction.
  __shared__ bool isLastBlock;
  // we know block size for now
  __shared__ float smem[BIAS_RELU_BW_NTHREADS_X*BIAS_RELU_BW_NTHREADS_Y];

  // Accumulate db in FP32 always
  float db_local = 0;
  if (f < features) {
    int nidx = 0;
    // Handle non-multiple of UNROLL_FACTOR residue
    for (; nidx < nSpan % UNROLL_FACTOR; nidx++) {
      int row, col, flat_idx;
      row = f;
      col = nStart + nidx;
      flat_idx = col * features + row;
      T y_val = Y[flat_idx];
      T dy_val = dY[flat_idx];
      T dx_val;
      if ((float)y_val > 0.f)
        dx_val = dy_val;
      else
        dx_val = 0;
      dX[flat_idx] = dx_val;
      db_local += (float)dx_val;
    }

    // Handle meat of work
    for (; (nidx + UNROLL_FACTOR - 1) < nSpan; nidx += UNROLL_FACTOR) {
      int row, col, flat_idx;
      row = f;
      col = nStart + nidx;
      flat_idx = col * features + row;
#pragma unroll 4
      for (int u = 0; u < UNROLL_FACTOR; u++) {
        T y_val = Y[flat_idx];
        T dy_val = dY[flat_idx];
        T dx_val;
        if ((float)y_val > 0.f)
          dx_val = dy_val;
        else
          dx_val = 0;
        dX[flat_idx] = dx_val;
        db_local += (float)dx_val;
        flat_idx += features;
      }
    }

    // naive block reduction on y-dim
    int linear_idx = threadIdx.y * blockDim.x + threadIdx.x;
    smem[linear_idx] = db_local;
  }
  __syncthreads();
  if (f < features) {
    if(threadIdx.y == 0) {
      for(int yidx = 1; yidx < blockDim.y; yidx++){
        db_local += smem[yidx * blockDim.x + threadIdx.x];
      }

      // block result is in db_local now for all threadIdx.y == 0
      // Write out partial result
      out[f] = db_local;
    }
  }
  __threadfence();
  __syncthreads();

  // Increment semaphore and check if this is the last CTA in the grid_y dimension.
  // Only thread (0,0) calls this
  if (threadIdx.x == 0 && threadIdx.y == 0 && f < features) {
    unsigned int sum_idx;
    sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1);
    isLastBlock = (sum_idx == (gridDim.y - 1));
  }
  __syncthreads();

  db_local = 0;
  // No block reduction for now; only threads with threadIdx.y == 0 do the grid reduction
  if (isLastBlock && f < features) {
    if(threadIdx.y == 0) {
      for (int n = 0; n < gridDim.y; n++) {
        int row, col;
        row = f;
        col = n;
        db_local += (float)(intermediate[col * features + row]);
      }
      db[f] = (T)db_local;
    }
  }
}

// Addition done deterministically via a 2-pass approach. Each CTA writes out partial
// sum, and the last CTA in grid Y dimension accumulates partials serially and writes to result.
template <typename T, int UNROLL_FACTOR>
__global__ void biasAddRelu_bprop_aligned(
    T* Y,
    T* dY,
    int features,
    int batch_size,
    T* dX,
    volatile float* intermediate,
    int* semaphores,
    T* db) {
  // The feature that this thread is responsible for
  int f = blockIdx.x * blockDim.x + threadIdx.x;

  // Compute the span this thread is responsible for
  // For this block
  int b_chunkSize = (batch_size + gridDim.y - 1) / gridDim.y;
  int b_nStart = blockIdx.y * b_chunkSize;
  int b_nSpan = min(batch_size, b_nStart + b_chunkSize) - b_nStart;
  // For this thread
  int chunkSize = (b_chunkSize + blockDim.y - 1) / blockDim.y;
  int nStart = threadIdx.y * chunkSize + b_nStart;
  int nSpan = min(b_nStart + b_nSpan, nStart + chunkSize) - nStart;

  volatile float* out = intermediate + blockIdx.y * features;

  // Flag to trigger last reduction.
  __shared__ bool isLastBlock;

  // Accumulate db in FP32 always
  float db_local[ILP];
  T r_y[ILP];
  T r_dy[ILP];
#pragma unroll
  for(int ii=0;ii<ILP;ii++){
    db_local[ii] = 0.f;
  }

  // f always <= features in this case
  //if (f < features) {
  int nidx = 0;

  // Handle non-multiple of UNROLL_FACTOR residue
  for (; nidx < nSpan % UNROLL_FACTOR; nidx++) {
    int row, col, flat_idx;
    row = f;
    col = nStart + nidx;
    flat_idx = col * features / ILP + row;

    load_store(r_y, Y, 0, flat_idx);
    load_store(r_dy, dY, 0, flat_idx);
#pragma unroll
    for(int ii=0;ii<ILP;ii++){
      if ((float)r_y[ii] <= 0.f)
        r_dy[ii] = 0;
      db_local[ii] += (float)r_dy[ii];
    }
    load_store(dX, r_dy, flat_idx, 0);
  }

  // Handle meat of work
  for (; (nidx + UNROLL_FACTOR - 1) < nSpan; nidx += UNROLL_FACTOR) {
    int row, col, flat_idx;
    row = f;
    col = nStart + nidx;
    flat_idx = col * features / ILP + row; // total threads in x == features/ILP
#pragma unroll
    for (int u = 0; u < UNROLL_FACTOR; u++) {
      load_store(r_y, Y, 0, flat_idx);
      load_store(r_dy, dY, 0, flat_idx);
#pragma unroll
      for(int ii=0;ii<ILP;ii++){
        if ((float)r_y[ii] <= 0.f)
          r_dy[ii] = 0;
        db_local[ii] += (float)r_dy[ii];
      }
      load_store(dX, r_dy, flat_idx, 0);
      flat_idx += features/ILP;
    }
  }

  // we know block size for now
  __shared__ float smem[BIAS_RELU_BW_NTHREADS_X*BIAS_RELU_BW_NTHREADS_Y*ILP];
  // naive block reduction on y-dim
  int linear_idx = threadIdx.y * blockDim.x + threadIdx.x;
  float* smem_out = smem + ILP * linear_idx;
#pragma unroll
  for(int ii=0;ii<ILP;ii++){
    smem_out[ii] = db_local[ii]; // reuse local dy buffer
  }
  __syncthreads();
  if(threadIdx.y == 0) {
    for(int yidx = 1; yidx < blockDim.y; yidx++){
      float* smem_in = smem + ILP * (yidx * blockDim.x + threadIdx.x);
#pragma unroll
      for(int ii=0;ii<ILP;ii++){
        db_local[ii] += smem_in[ii]; // reuse local dy buffer
      }
    }

    // block result is in db_local now for all threadIdx.y == 0
    if(gridDim.y == 1) {
#pragma unroll
      for(int ii=0;ii<ILP;ii++){
        r_dy[ii] = db_local[ii]; // reuse local dy buffer
      }
      load_store(db, r_dy, f, 0);
      return;
    }

    // Write out partial result
    load_store(out, db_local, f, 0);
  }
  __threadfence();
  __syncthreads();

  // Increment semaphore and check if this is the last CTA in the grid_y dimension.
  // Only thread (0,0) calls this
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    unsigned int sum_idx;
    sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1);
    isLastBlock = (sum_idx == (gridDim.y - 1));
  }
  __syncthreads();

#pragma unroll
  for(int ii=0;ii<ILP;ii++){
    db_local[ii] = 0.f;
  }
  float r_db[ILP];

  // No block reduction for now; only threads with threadIdx.y == 0 do the grid reduction
  if (isLastBlock) {
    if(threadIdx.y == 0){
      for (int n = 0; n < gridDim.y; n++) {
        int row, col;
        row = f;
        col = n;
        load_store(r_db, intermediate, 0, col * features / ILP + row);
#pragma unroll
        for(int ii=0;ii<ILP;ii++){
          db_local[ii] += r_db[ii];
        }
      }
#pragma unroll
      for(int ii=0;ii<ILP;ii++){
        r_dy[ii] = db_local[ii]; // reuse local dy buffer
      }
      load_store(db, r_dy, f, 0);
    }
  }
}

// Lists where the num_layers-1 intermediate Y buffers start in reserved space on fprop, starting
// at offset 0. The last Y value is, of course, stored in the user provided output buffer.
void get_y_offsets(
    int batch_size,
    int num_layers,
    const int* output_features,
    int* y_start_offsets) {
  y_start_offsets[0] = 0;
  for (int i = 1; i < num_layers; i++) {
    y_start_offsets[i] = y_start_offsets[i - 1] + batch_size * output_features[i - 1];
  }
}

// Returns the reserved space (in elements) needed for the MLP
size_t get_mlp_reserved_space(int64_t batch_size, int num_layers, const int* output_features) {
  size_t res_space = 0;
  // Need to store output of every intermediate MLP - size equal to output_features[i] * batch_size
  // for all 'i' in [0, num_layers-1)
  for (int l = 0; l < num_layers; l++) {
    res_space += output_features[l] * batch_size;
  }
  return res_space;
}

// Returns the size of all fprop activations combined
size_t get_all_activations_size(int64_t batch_size, int num_layers, const int* output_features) {
  size_t acts_size = 0;
  for (int l = 0; l < num_layers; l++) {
    acts_size += output_features[l] * batch_size;
  }
  return acts_size;
}

#if 0
// Returns the work space (in elements) needed for the MLP bprop.
size_t get_mlp_bp_workspace (int batch_size, int num_layers, const int* output_features) {
  /*
  Workspace is partitioned as
  DY_GEMMs : DX_GEMMs
  */
  size_t work_space = 0;

  // Store each intermediate dY explicitly. Need 2 dYs per MLP layer (one for o/p
  // of biasReLU_bp and one for o/p of dgrad GEMM).
  work_space += 2*get_all_activations_size(batch_size, num_layers, output_features);

  return work_space;
}
#endif

// Scratch space needed for reductions, in number of elements
size_t get_reduction_scratch_space(int batch_size, int num_layers, const int* output_features) {
  size_t max_scratch_space = 0;
  // Loop over all layers to see which one needs the max scratch space
  for (int l = 0; l < num_layers; l++) {
    // need to find max(aligned, not_aligned)
    int tmp, res0, res1;

    int block_x = BIAS_RELU_BW_NTHREADS_X;
    int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
    get_biasAddRelu_bprop_grid_size(
        output_features[l], batch_size, block_x, block_y, &tmp, &res0);

    block_x = ILP * BIAS_RELU_BW_NTHREADS_X;
    get_biasAddRelu_bprop_grid_size(
        output_features[l], batch_size, block_x, block_y, &tmp, &res1);

    max_scratch_space = std::max(max_scratch_space, (size_t)(output_features[l] * res0));
    max_scratch_space = std::max(max_scratch_space, (size_t)(output_features[l] * res1));
  }

  return max_scratch_space;
}

// Buffer for semaphores
size_t get_semaphores_size(int num_layers, const int* output_features) {
  // Upper bound on semaphores is one per feature for the layer
  // with the most features.
  int max_features = 0;
  for (int l = 0; l < num_layers; l++) {
    max_features = std::max(max_features, output_features[l]);
  }
  return (size_t)max_features;
}

// Returns the work space (in bytes) needed for the MLP bprop.
template <typename T>
size_t get_mlp_bp_workspace_in_bytes(int batch_size, int num_layers, const int* output_features) {
  size_t work_space = 0;

  // Store each intermediate dY explicitly. Need 2 dYs per MLP layer (one for o/p
  // of biasReLU_bp and one for o/p of dgrad GEMM).
  work_space += 2 * get_all_activations_size(batch_size, num_layers, output_features) * sizeof(T);
  work_space +=
      get_reduction_scratch_space(batch_size, num_layers, output_features) * sizeof(float);
  work_space += get_semaphores_size(num_layers, output_features) * sizeof(int);

  return work_space;
}

// Returns pointers to each segment of the workspace
template <typename T>
void partition_mlp_bp_workspace(
    int batch_size,
    int num_layers,
    const int* output_features,
    void* work_space,
    T** dy_gemms,
    T** dx_gemms,
    float** db_scratch,
    int** semaphores) {
  /*
  Workspace is partitioned as
  DY_GEMMs : DX_GEMMs : DB_SCRATCH : SEMAPHORES
  */
  // Start address where dy_gemm tensors are stored
  *dy_gemms = reinterpret_cast<T*>(work_space);
  // Start address where dx_gemm tensors are stored
  *dx_gemms = *dy_gemms + get_all_activations_size(batch_size, num_layers, output_features);
  // Start address where db intermediate tensors are stored
  *db_scratch = reinterpret_cast<float*>(
      *dx_gemms + get_all_activations_size(batch_size, num_layers, output_features));
  // Start address of semaphores
  *semaphores = reinterpret_cast<int*>(
      *db_scratch + get_reduction_scratch_space(batch_size, num_layers, output_features));

  return;
}

// Does a simple MLP fprop (GEMM+bias+ReLU).
// Can handle num_layers number of layers, each with its own shape. Output of layer i is assumed
// to be input of layer i+1. output_features, WPtr and BPtr are arrays of length num_layers, and
// must be in the same order i.e. WPtr[i] and BPtr[i] are respectively the weight and bias of layer
// 'i'.
template <typename T>
int mlp_fp(
    T* X,
    int input_features,
    int batch_size,
    T** WPtr,
    int num_layers,
    int* output_features,
    T** BPtr,
    T* Y,
    T* reserved_space,
    int use_bias,
    int activation,
    void* lt_workspace) {
  T *weight, *input, *output, *bias;
  T *reserved_space_x, *reserved_space_y;
  reserved_space_x = NULL;
  reserved_space_y = reserved_space;

  // Get cublas handle from Pytorch
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  // Get the stream from cublas handle to reuse for biasReLU kernel.
  cudaStream_t stream;
  cublasGetStream(handle, &stream);

  for (int layer = 0; layer < num_layers; layer++) {
    weight = WPtr[layer];
    input = (layer == 0) ? X : reserved_space_x;
    output = (layer == num_layers - 1) ? Y : reserved_space_y;
    if (use_bias) {
      bias = BPtr[layer];
    }
    int ifeat = (layer == 0) ? input_features : output_features[layer - 1];
    int ofeat = output_features[layer];

    float one = 1.f;
    float zero = 0.f;

    // try with cublaslt first for supported case with valid handle
    int cublaslt_status = 1;
    if(activation < 1){
      cublaslt_status = mlp_gemm_lt(
          //ltHandle,
          (cublasLtHandle_t)handle,
          CUBLAS_OP_T,
          CUBLAS_OP_N,
          ofeat,
          batch_size,
          ifeat,
          &one,
          weight,
          ifeat,
          input,
          ifeat,
          &zero,
          output,
          ofeat,
          lt_workspace,
          1 << 22,
          stream,
          use_bias == 1,
          activation == 1,
          bias);
    }

    // if cublaslt failed or not executed, fallback to cublas
    if (cublaslt_status != 0) {
      cublasStatus_t cublas_status;
      // Call GEMM: fprop is Y = W'X
      cublas_status = mlp_gemm(
          handle,
          CUBLAS_OP_T,
          CUBLAS_OP_N,
          ofeat,
          batch_size,
          ifeat,
          &one,
          weight,
          ifeat,
          input,
          ifeat,
          &zero,
          output,
          ofeat);

      if (cublas_status != CUBLAS_STATUS_SUCCESS) {
        printf("GEMM fprop failed with %d\n", cublas_status);
        return 1;
      }

      const uint &input_size = ofeat;
      int num_blocks = 0;
      int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
      // Call biasReLU
      if(use_bias == 1) {
        if (activation == 0) { // no activation
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAdd_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          biasAdd_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
        } else if (activation == 1) { // relu
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAddRelu_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          biasAddRelu_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
        } else if (activation == 2) { // sigmoid
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAdd_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          biasAdd_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          Sigmoid_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
        }
      } else {
        // don't need to do anything in case of no activation and no bias
        if (activation == 1) { // relu
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Relu_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          Relu_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
        } else if (activation == 2) { // sigmoid
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          Sigmoid_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
        }
      }
    }
    // Set current output as next layer input
    reserved_space_x = reserved_space_y;
    // Set next layer output
    reserved_space_y += ofeat * batch_size;
  }

  return 0;
}

// Does a simple MLP bprop (GEMM+bias+ReLU).
// Needs reserved space to come back exactly as it was populated in fprop.
// Does dgrad and wgrad sequentially.
template <typename T>
int mlp_bp(
    T* X,
    T* Y,
    int input_features,
    int batch_size,
    T** WPtr,
    int num_layers,
    int* output_features,
    T* dY,
    T* reserved_space,
    T* work_space,
    T* dX,
    T** dwPtr,
    T** dbPtr,
    bool requires_grad,
    int use_bias,
    int activation) {
  T* weight;
  T *dweight, *dx, *dy, *dbias;
  T *x, *y;

  // Where the dx of the biasReLU (== dy of gemm) is stored. Can be thrown away
  // after bp call.
  T* dy_gemm_base;
  // Where the dx after GEMM is stored.
  T* dx_gemm_base;
  // Where partial reduction results are stored.
  float* db_scratch;
  // Semaphores for reduction.
  int* semaphores;

  partition_mlp_bp_workspace<T>(
      batch_size,
      num_layers,
      output_features,
      work_space,
      &dy_gemm_base,
      &dx_gemm_base,
      &db_scratch,
      &semaphores);

  size_t semaphore_size = get_semaphores_size(num_layers, output_features) * sizeof(int);

  // Get cublas handle from Pytorch
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  // Get the stream from cublas handle to reuse for biasReLU kernel.
  cudaStream_t stream;
  cublasGetStream(handle, &stream);

  int* y_offsets = (int*)malloc(num_layers * sizeof(int));
  get_y_offsets(batch_size, num_layers, output_features, y_offsets);

  for (int layer = num_layers - 1; layer >= 0; layer--) {
    weight = WPtr[layer];
    dweight = dwPtr[layer];

    // x is read from reserved space
    x = (layer == 0) ? X : reserved_space + y_offsets[layer - 1];
    // dx is written in workspace for all but layer==0
    dx = (layer == 0) ? dX : dx_gemm_base + y_offsets[layer - 1];

    // y is read from reserved space
    y = (layer == num_layers - 1) ? Y : reserved_space + y_offsets[layer];
    // dx from layer+1
    dy = (layer == num_layers - 1) ? dY : dx_gemm_base + y_offsets[layer];
    // dy_gemm is written to and read immediately
    T* dy_gemm = dy_gemm_base + y_offsets[layer];

    dbias = dbPtr[layer];
    int xfeat = (layer == 0) ? input_features : output_features[layer - 1];
    int yfeat = output_features[layer];

    float one = 1.f;
    float zero = 0.f;

    if (use_bias == 1) {
      if (activation == 0) { // no activation
        // bgrad
        dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y);
        int grid_x, grid_y;
        cudaMemsetAsync(semaphores, 0, semaphore_size, stream);

        int block_x = BIAS_RELU_BW_NTHREADS_X;
        int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
        get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
        dim3 grid(grid_x, grid_y);
        biasAdd_bprop<T, 4><<<grid, block, 0, stream>>>(
            dy, yfeat, batch_size, db_scratch, semaphores, dbias);
        // bypass dgrad through reset pointer
        dy_gemm = dy;
      } else if (activation == 1) { // relu
        dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y);
        int grid_x, grid_y;
        cudaMemsetAsync(semaphores, 0, semaphore_size, stream);

        if(yfeat % (ILP * BIAS_RELU_BW_NTHREADS_X) == 0 &&
           is_aligned(y) &&
           is_aligned(dy) &&
           is_aligned(dy_gemm) &&
           is_aligned(dbias)){
          int block_x = ILP * BIAS_RELU_BW_NTHREADS_X;
          int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
          get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
          dim3 grid(grid_x, grid_y);
          biasAddRelu_bprop_aligned<T, 4><<<grid, block, 0, stream>>>(
              y, dy, yfeat, batch_size, dy_gemm, db_scratch, semaphores, dbias);
        } else {
          int block_x = BIAS_RELU_BW_NTHREADS_X;
          int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
          get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
          dim3 grid(grid_x, grid_y);
          biasAddRelu_bprop<T, 4><<<grid, block, 0, stream>>>(
              y, dy, yfeat, batch_size, dy_gemm, db_scratch, semaphores, dbias);
        }
      } else if (activation == 2) { // sigmoid
        // activation backward
        int num_blocks = 0;
        int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_bprop<T>, BIAS_RELU_FW_NTHREADS, 0);
        Sigmoid_bprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(dy, y, batch_size, yfeat, dy_gemm);

        // bgrad, from dy_gemm
        dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y);
        int grid_x, grid_y;
        cudaMemsetAsync(semaphores, 0, semaphore_size, stream);

        int block_x = BIAS_RELU_BW_NTHREADS_X;
        int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
        get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
        dim3 grid(grid_x, grid_y);
        biasAdd_bprop<T, 4><<<grid, block, 0, stream>>>(
            dy_gemm, yfeat, batch_size, db_scratch, semaphores, dbias);
      }
    } else { // no bias below
      if (activation == 0) {
        // bypass dgrad through reset pointer
        dy_gemm = dy;
      } else if (activation == 1) { // relu
        int num_blocks = 0;
        int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Relu_bprop<T>, BIAS_RELU_FW_NTHREADS, 0);
        Relu_bprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(dy, y, batch_size, yfeat, dy_gemm);
      } else if (activation == 2) { // sigmoid
        int num_blocks = 0;
        int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_bprop<T>, BIAS_RELU_FW_NTHREADS, 0);
        Sigmoid_bprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(dy, y, batch_size, yfeat, dy_gemm);
      }
    }

    cublasStatus_t cublas_status;
    // Call GEMM dgrad
    if (layer > 0 || requires_grad == 1) {
      cublas_status = mlp_gemm(
          handle,
          CUBLAS_OP_N,
          CUBLAS_OP_N,
          xfeat,
          batch_size,
          yfeat,
          &one,
          weight,
          xfeat,
          dy_gemm,
          yfeat,
          &zero,
          dx,
          xfeat);

      if (cublas_status != CUBLAS_STATUS_SUCCESS) {
        printf("GEMM dgrad failed with %d\n", cublas_status);
        return 1;
      }
    }

    // Call GEMM wgrad
    cublas_status = mlp_gemm(
        handle,
        CUBLAS_OP_N,
        CUBLAS_OP_T,
        xfeat,
        yfeat,
        batch_size,
        &one,
        x,
        xfeat,
        dy_gemm,
        yfeat,
        &zero,
        dweight,
        xfeat);

    if (cublas_status != CUBLAS_STATUS_SUCCESS) {
      printf("GEMM wgrad failed with %d\n", cublas_status);
      return 1;
    }
  }

  // Release the host-side offset table allocated above.
  free(y_offsets);

  return 0;
}

// Instantiate for floating point types
template int mlp_fp<float>(
    float* X,
    int input_features,
    int batch_size,
    float** WPtr,
    int num_layers,
    int* output_features,
    float** BPtr,
    float* Y,
    float* reserved_space,
    int use_bias,
    int activation,
    void* lt_workspace);

template int mlp_bp<float>(
    float* X,
    float* Y,
    int input_features,
    int batch_size,
    float** WPtr,
    int num_layers,
    int* output_features,
    float* dY,
    float* reserved_space,
    float* work_space,
    float* dX,
    float** dwPtr,
    float** dbPtr,
    bool requires_grad,
    int use_bias,
    int activation);

template int mlp_fp<at::Half>(
    at::Half* X,
    int input_features,
    int batch_size,
    at::Half** WPtr,
    int num_layers,
    int* output_features,
    at::Half** BPtr,
    at::Half* Y,
    at::Half* reserved_space,
    int use_bias,
    int activation,
    void* lt_workspace);

template int mlp_bp<at::Half>(
    at::Half* X,
    at::Half* Y,
    int input_features,
    int batch_size,
    at::Half** WPtr,
    int num_layers,
    int* output_features,
    at::Half* dY,
    at::Half* reserved_space,
    at::Half* work_space,
    at::Half* dX,
    at::Half** dwPtr,
    at::Half** dbPtr,
    bool requires_grad,
    int use_bias,
    int activation);

template int mlp_fp<double>(
    double* X,
    int input_features,
    int batch_size,
    double** WPtr,
    int num_layers,
    int* output_features,
    double** BPtr,
    double* Y,
    double* reserved_space,
    int use_bias,
    int activation,
    void* lt_workspace);

template int mlp_bp<double>(
    double* X,
    double* Y,
    int input_features,
    int batch_size,
    double** WPtr,
    int num_layers,
    int* output_features,
    double* dY,
    double* reserved_space,
    double* work_space,
    double* dX,
    double** dwPtr,
    double** dbPtr,
    bool requires_grad,
    int use_bias,
    int activation);

template size_t get_mlp_bp_workspace_in_bytes<float>(
    int batch_size,
    int num_layers,
    const int* output_features);
template size_t get_mlp_bp_workspace_in_bytes<at::Half>(
    int batch_size,
    int num_layers,
    const int* output_features);
template size_t get_mlp_bp_workspace_in_bytes<double>(
    int batch_size,
    int num_layers,
    const int* output_features);
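
// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the original file; left disabled like
// the get_mlp_bp_workspace block above). It shows one plausible way a host-side
// caller could size the reserved/workspace buffers and drive mlp_fp<float> with
// the fused bias+ReLU path. The function name run_mlp_fprop_example, the flat
// WPtr/BPtr pointer arrays, and the tensor shapes are hypothetical; the real
// entry points live in the extension's C++ bindings.
#if 0
int run_mlp_fprop_example(
    float* X,
    int input_features,
    int batch_size,
    float** WPtr,
    float** BPtr,
    int num_layers,
    int* output_features) {
  // Reserved space holds one activation buffer per layer; the final layer's
  // output goes to the caller-provided Y buffer instead.
  size_t reserved_elems =
      get_mlp_reserved_space(batch_size, num_layers, output_features);

  auto opts = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
  at::Tensor reserved = at::empty({(int64_t)reserved_elems}, opts);
  at::Tensor Y = at::empty({batch_size, output_features[num_layers - 1]}, opts);
  // 4 MB of scratch, matching the 1 << 22 workspace size that mlp_fp passes to
  // mlp_gemm_lt for the cublasLt path.
  at::Tensor lt_workspace = at::empty({1 << 22}, opts.dtype(at::kByte));

  // use_bias = 1 and activation = 1 (ReLU) exercise the fused bias+ReLU kernels.
  return mlp_fp<float>(
      X,
      input_features,
      batch_size,
      WPtr,
      num_layers,
      output_features,
      BPtr,
      Y.data_ptr<float>(),
      reserved.data_ptr<float>(),
      /*use_bias=*/1,
      /*activation=*/1,
      lt_workspace.data_ptr());
}
#endif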