Path: blob/main/apex/csrc/layer_norm_cuda_kernel.cu
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include <THC/THCDeviceUtils.cuh>

#include <cuda.h>
#include <cuda_runtime.h>

#include "type_shim.h"

template<typename U> __device__
void cuWelfordOnlineSum(
  const U curr,
  U& mu,
  U& sigma2,
  U& count)
{
  count = count + U(1);
  U delta = curr - mu;
  U lmean = mu + delta / count;
  mu = lmean;
  U delta2 = curr - lmean;
  sigma2 = sigma2 + delta * delta2;
}

template<typename U> __device__
void cuChanOnlineSum(
  const U muB,
  const U sigma2B,
  const U countB,
  U& mu,
  U& sigma2,
  U& count)
{
  U delta = muB - mu;
  U nA = count;
  U nB = countB;
  count = count + countB;
  U nX = count;
  if (nX > U(0)) {
    nA = nA / nX;
    nB = nB / nX;
    mu = nA*mu + nB*muB;
    sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
  } else {
    mu = U(0);
    sigma2 = U(0);
  }
}

template<typename T, typename U> __device__
void cuWelfordMuSigma2(
  const T* __restrict__ vals,
  const int n1,
  const int n2,
  const int i1,
  U& mu,
  U& sigma2,
  U* buf)
{
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensor is contiguous
  // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
  //
  // compute variance and mean over n2
  U count = U(0);
  mu = U(0);
  sigma2 = U(0);
  if (i1 < n1) {
    // one warp normalizes one n1 index,
    // synchronization is implicit
    // initialize with standard Welford algorithm
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    const T* lvals = vals + i1*n2;
    int l = 4*thrx;
    for (; l+3 < n2; l+=4*numx) {
      for (int k = 0; k < 4; ++k) {
        U curr = static_cast<U>(lvals[l+k]);
        cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
      }
    }
    for (; l < n2; ++l) {
      U curr = static_cast<U>(lvals[l]);
      cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
    }
    // intra-warp reductions
    for (int l = 0; l <= 4; ++l) {
      int srcLaneB = (threadIdx.x+(1<<l))&31;
      U muB = WARP_SHFL(mu, srcLaneB);
      U countB = WARP_SHFL(count, srcLaneB);
      U sigma2B = WARP_SHFL(sigma2, srcLaneB);
      cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
    }
    // threadIdx.x == 0 has correct values for each warp
    // inter-warp reductions
    if (blockDim.y > 1) {
      U* ubuf = (U*)buf;
      U* ibuf = (U*)(ubuf + blockDim.y);
      for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_y = threadIdx.y - offset;
          ubuf[2*wrt_y] = mu;
          ubuf[2*wrt_y+1] = sigma2;
          ibuf[wrt_y] = count;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.x == 0 && threadIdx.y < offset) {
          U muB = ubuf[2*threadIdx.y];
          U sigma2B = ubuf[2*threadIdx.y+1];
          U countB = ibuf[threadIdx.y];
          cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
        }
        __syncthreads();
      }
      // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
      if (threadIdx.x == 0 && threadIdx.y == 0) {
        ubuf[0] = mu;
        ubuf[1] = sigma2;
      }
      __syncthreads();
      mu = ubuf[0];
      sigma2 = ubuf[1]/U(n2);
      // don't care about final value of count, we know count == n2
    } else {
      mu = WARP_SHFL(mu, 0);
      sigma2 = WARP_SHFL(sigma2/U(n2), 0);
    }
  }
}

template<> __device__
void cuWelfordMuSigma2(
  const at::Half* __restrict__ vals,
  const int n1,
  const int n2,
  const int i1,
  float& mu,
  float& sigma2,
  float* buf)
{
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensor is contiguous
  // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
  //
  // compute variance and mean over n2
  float count = 0.0f;
  mu = float(0);
  sigma2 = float(0);
  if (i1 < n1) {
    // one warp normalizes one n1 index,
    // synchronization is implicit
    // initialize with standard Welford algorithm
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    const at::Half* lvals = vals + i1*n2;
    int l = 8*thrx;
    if ((((size_t)lvals)&3) != 0) {
      // 16 bit alignment
      // first thread consumes first point
      if (thrx == 0) {
        float curr = static_cast<float>(lvals[0]);
        cuWelfordOnlineSum(curr,mu,sigma2,count);
      }
      ++l;
    }
    // at this point, lvals[l] are 32 bit aligned for all threads.
    for (; l+7 < n2; l+=8*numx) {
      for (int k = 0; k < 8; k+=2) {
        float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
        cuWelfordOnlineSum(curr.x,mu,sigma2,count);
        cuWelfordOnlineSum(curr.y,mu,sigma2,count);
      }
    }
    for (; l < n2; ++l) {
      float curr = static_cast<float>(lvals[l]);
      cuWelfordOnlineSum(curr,mu,sigma2,count);
    }
    // intra-warp reductions
    for (int l = 0; l <= 4; ++l) {
      int srcLaneB = (threadIdx.x+(1<<l))&31;
      float muB = WARP_SHFL(mu, srcLaneB);
      float countB = WARP_SHFL(count, srcLaneB);
      float sigma2B = WARP_SHFL(sigma2, srcLaneB);
      cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
    }
    // threadIdx.x == 0 has correct values for each warp
    // inter-warp reductions
    if (blockDim.y > 1) {
      float* ubuf = (float*)buf;
      float* ibuf = (float*)(ubuf + blockDim.y);
      for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_y = threadIdx.y - offset;
          ubuf[2*wrt_y] = mu;
          ubuf[2*wrt_y+1] = sigma2;
          ibuf[wrt_y] = count;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.x == 0 && threadIdx.y < offset) {
          float muB = ubuf[2*threadIdx.y];
          float sigma2B = ubuf[2*threadIdx.y+1];
          float countB = ibuf[threadIdx.y];
          cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
        }
        __syncthreads();
      }
      // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
      if (threadIdx.x == 0 && threadIdx.y == 0) {
        ubuf[0] = mu;
        ubuf[1] = sigma2;
      }
      __syncthreads();
      mu = ubuf[0];
      sigma2 = ubuf[1]/float(n2);
      // don't care about final value of count, we know count == n2
    } else {
      mu = WARP_SHFL(mu, 0);
      sigma2 = WARP_SHFL(sigma2/float(n2), 0);
    }
  }
}

template<typename U> U rsqrt(U v) {
  return U(1) / sqrt(v);
}
template<> float rsqrt(float v) {
  return rsqrtf(v);
}
template<> double rsqrt(double v) {
  return rsqrt(v);
}

namespace {
// This is the un-specialized struct.  Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
//  template <typename T>
//  struct SharedMemory
//  {
//      // Ensure that we won't compile any un-specialized types
//      __device__ T *getPointer()
//      {
//          extern __device__ void error(void);
//          error();
//          return NULL;
//      }
//  };
// https://github.com/NVIDIA/apex/issues/246
template <typename T>
struct SharedMemory;

template <>
struct SharedMemory <float>
{
  __device__ float *getPointer()
  {
    extern __shared__ float s_float[];
    return s_float;
  }
};

template <>
struct SharedMemory <double>
{
  __device__ double *getPointer()
  {
    extern __shared__ double s_double[];
    return s_double;
  }
};
}

template<typename T, typename U> __global__
void cuApplyLayerNorm(
  T* __restrict__ output_vals,
  U* __restrict__ mean,
  U* __restrict__ invvar,
  const T* __restrict__ vals,
  const int n1,
  const int n2,
  const U epsilon,
  const T* __restrict__ gamma,
  const T* __restrict__ beta
  )
{
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensors are contiguous
  //
  for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
    SharedMemory<U> shared;
    U* buf = shared.getPointer();
    U mu,sigma2;
    cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf);
    const T* lvals = vals + i1*n2;
    T* ovals = output_vals + i1*n2;
    U c_invvar = rsqrt(sigma2 + epsilon);
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL && beta != NULL) {
      for (int i = thrx; i < n2; i+=numx) {
        U curr = static_cast<U>(lvals[i]);
        ovals[i] = gamma[i] * static_cast<T>(c_invvar * (curr - mu)) + beta[i];
      }
    } else {
      for (int i = thrx; i < n2; i+=numx) {
        U curr = static_cast<U>(lvals[i]);
        ovals[i] = static_cast<T>(c_invvar * (curr - mu));
      }
    }
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      mean[i1] = mu;
      invvar[i1] = c_invvar;
    }
  }
}

template<typename T, typename U> __device__
void cuLoadWriteStridedInputs(
  const int i1_block,
  const int thr_load_row_off,
  const int thr_load_col_off,
  const int i2_off,
  const int row_stride,
  U* warp_buf1,
  U* warp_buf2,
  const T* input,
  const T* dout,
  const int i1_end,
  const int n2,
  const U* __restrict__ mean,
  const U* __restrict__ invvar
  )
{
  int i1 = i1_block+thr_load_row_off;
  if (i1 < i1_end) {
    U curr_mean = mean[i1];
    U curr_invvar = invvar[i1];
    for (int k = 0; k < blockDim.y; ++k) {
      int i2 = i2_off + k;
      int load_idx = i1*n2+i2;
      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
      if (i2<n2) {
        U curr_input = static_cast<U>(input[load_idx]);
        U curr_dout = static_cast<U>(dout[load_idx]);
        warp_buf1[write_idx] = curr_dout;
        warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
      } else {
        warp_buf1[write_idx] = U(0);
        warp_buf2[write_idx] = U(0);
      }
    }
  } else {
    for (int k = 0; k < blockDim.y; ++k) {
      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
      warp_buf1[write_idx] = U(0);
      warp_buf2[write_idx] = U(0);
    }
  }
}

template<typename T, typename U> __device__
void cuLoadAddStridedInputs(
  const int i1_block,
  const int thr_load_row_off,
  const int thr_load_col_off,
  const int i2_off,
  const int row_stride,
  U* warp_buf1,
  U* warp_buf2,
  const T* input,
  const T* dout,
  const int i1_end,
  const int n2,
  const U* __restrict__ mean,
  const U* __restrict__ invvar
  )
{
  int i1 = i1_block+thr_load_row_off;
  if (i1 < i1_end) {
    U curr_mean = mean[i1];
    U curr_invvar = invvar[i1];
    for (int k = 0; k < blockDim.y; ++k) {
      int i2 = i2_off + k;
      int load_idx = i1*n2+i2;
      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
      if (i2<n2) {
        U curr_input = static_cast<U>(input[load_idx]);
        U curr_dout = static_cast<U>(dout[load_idx]);
        warp_buf1[write_idx] += curr_dout;
        warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
      }
    }
  }
}

template<typename T, typename U> __global__
void cuComputePartGradGammaBeta(
  const T* __restrict__ dout,
  const T* __restrict__ input,
  const int n1,
  const int n2,
  const U* __restrict__ mean,
  const U* __restrict__ invvar,
  U epsilon,
  U* part_grad_gamma,
  U* part_grad_beta)
{
  const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y);
  const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
  const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y;
  const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y;
  const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
  const int row_stride = blockDim.x+1;
  const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1);
  const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y;
  const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
  SharedMemory<U> shared;
  U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
  U* warp_buf1 = (U*)buf;
  U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
  // compute partial sums from strided inputs
  // do this to increase number of loads in flight
  cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
  for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) {
    cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
  }
  __syncthreads();
  // inter-warp reductions
  // sum within each warp
  U acc1 = U(0);
  U acc2 = U(0);
  for (int k = 0; k < blockDim.y; ++k) {
    int row1 = threadIdx.y + k*blockDim.y;
    int idx1 = row1*row_stride + threadIdx.x;
    acc1 += warp_buf1[idx1];
    acc2 += warp_buf2[idx1];
  }
  warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1;
  warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2;
  __syncthreads();
  // sum all warps
  for (int offset = blockDim.y/2; offset > 1; offset /= 2) {
    if (threadIdx.y < offset) {
      int row1 = threadIdx.y;
      int row2 = threadIdx.y + offset;
      int idx1 = row1*row_stride + threadIdx.x;
      int idx2 = row2*row_stride + threadIdx.x;
      warp_buf1[idx1] += warp_buf1[idx2];
      warp_buf2[idx1] += warp_buf2[idx2];
    }
    __syncthreads();
  }
  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
  if (threadIdx.y == 0 && i2 < n2) {
    int row1 = threadIdx.y;
    int row2 = threadIdx.y + 1;
    int idx1 = row1*row_stride + threadIdx.x;
    int idx2 = row2*row_stride + threadIdx.x;
    part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2];
    part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2];
  }
}

template<typename T, typename U> __global__
void cuComputeGradGammaBeta(
  const U* part_grad_gamma,
  const U* part_grad_beta,
  const int part_size,
  const int n1,
  const int n2,
  T* grad_gamma,
  T* grad_beta)
{
  // sum partial gradients for gamma and beta
  SharedMemory<U> shared;
  U* buf = shared.getPointer();
  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
  if (i2 < n2) {
    // each warp does sequential reductions until reduced part_size is num_warps
    int num_warp_reductions = part_size / blockDim.y;
    U sum_gamma = U(0);
    U sum_beta = U(0);
    const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
    const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
    for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
      sum_gamma += part_grad_gamma_ptr[warp_offset*n2];
      sum_beta += part_grad_beta_ptr[warp_offset*n2];
    }
    // inter-warp reductions
    const int nbsize3 = blockDim.x * blockDim.y / 2;
    for (int offset = blockDim.y/2; offset >= 1; offset /= 2) {
      // top half write to shared memory
      if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
        const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
        buf[write_idx] = sum_gamma;
        buf[write_idx+nbsize3] = sum_beta;
      }
      __syncthreads();
      // bottom half sums
      if (threadIdx.y < offset) {
        const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
        sum_gamma += buf[read_idx];
        sum_beta += buf[read_idx+nbsize3];
      }
      __syncthreads();
    }
    // write out fully summed gradients
    if (threadIdx.y == 0) {
      grad_gamma[i2] = sum_gamma;
      grad_beta[i2] = sum_beta;
    }
  }
}

template<typename T, typename U> __global__
void cuComputeGradInput(
  const T* __restrict__ dout,
  const T* __restrict__ input,
  const int n1,
  const int n2,
  const U* __restrict__ mean,
  const U* __restrict__ invvar,
  U epsilon,
  const T* gamma,
  T* grad_input)
{
  for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
    U sum_loss1 = U(0);
    U sum_loss2 = U(0);
    const U c_mean = mean[i1];
    const U c_invvar = invvar[i1];
    const T* k_input = input + i1*n2;
    const T* k_dout = dout + i1*n2;
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL) {
      int l = 4*thrx;
      for (; l+3 < n2; l+=4*numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l+k]);
          const U c_loss = static_cast<U>(k_dout[l+k]);
          sum_loss1 += c_loss * gamma[l+k];
          sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss * gamma[l];
        sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
      }
    } else {
      int l = 4*thrx;
      for (; l+3 < n2; l+=4*numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l+k]);
          const U c_loss = static_cast<U>(k_dout[l+k]);
          sum_loss1 += c_loss;
          sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss;
        sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
      }
    }
    // intra-warp reductions
    for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
      sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
    }
    // inter-warp reductions
    if (blockDim.y > 1) {
      SharedMemory<U> shared;
      U* buf = shared.getPointer();
      for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
          buf[2*wrt_i] = sum_loss1;
          buf[2*wrt_i+1] = sum_loss2;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.y < offset) {
          const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
          sum_loss1 += buf[2*read_i];
          sum_loss2 += buf[2*read_i+1];
        }
        __syncthreads();
      }
      if (threadIdx.y == 0) {
        buf[2*threadIdx.x] = sum_loss1;
        buf[2*threadIdx.x+1] = sum_loss2;
      }
      __syncthreads();
      if (threadIdx.y != 0) {
        sum_loss1 = buf[2*threadIdx.x];
        sum_loss2 = buf[2*threadIdx.x+1];
      }
    }
    // all threads now have the two sums over l
    U fH = (U)n2;
    U term1 = (U(1) / fH) * c_invvar;
    T* k_grad_input = grad_input + i1*n2;
    if (gamma != NULL) {
      for (int l = thrx; l < n2; l+=numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss * gamma[l];
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    } else {
      for (int l = thrx; l < n2; l+=numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss;
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    }
  }
}

template<typename T, typename U>
void HostApplyLayerNorm(
  T* output,
  U* mean,
  U* invvar,
  const T* input,
  int n1,
  int n2,
  double epsilon,
  const T* gamma,
  const T* beta
  )
{
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const dim3 threads(32,4,1);
  const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
  int nshared =
      threads.y > 1 ?
          threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
          0;
  cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
      output,
      mean,
      invvar,
      input,
      n1,n2,
      U(epsilon),
      gamma,beta);
}

void cuda_layer_norm(
  at::Tensor* output,
  at::Tensor* mean,
  at::Tensor* invvar,
  at::Tensor* input,
  int n1,
  int n2,
  #ifdef VERSION_GE_1_1
  at::IntArrayRef normalized_shape,
  #else
  at::IntList normalized_shape,
  #endif
  at::Tensor* gamma,
  at::Tensor* beta,
  double epsilon)
{
  using namespace at;
  DISPATCH_DOUBLE_FLOAT_AND_HALF(input->scalar_type(), 0, "layer_norm_cuda_kernel",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      HostApplyLayerNorm(
          output->DATA_PTR<scalar_t_0>(),
          mean->DATA_PTR<accscalar_t>(),
          invvar->DATA_PTR<accscalar_t>(),
          input->DATA_PTR<scalar_t_0>(),
          n1,n2,
          epsilon,
          gamma != NULL ? gamma->DATA_PTR<scalar_t_0>() : NULL,
          beta != NULL ? beta->DATA_PTR<scalar_t_0>() : NULL);
    )
}

template<typename T, typename U>
void HostLayerNormGradient(
  const T* dout,
  const U* mean,
  const U* invvar,
  at::Tensor* input,
  int n1,
  int n2,
  const T* gamma,
  const T* beta,
  double epsilon,
  T* grad_input,
  T* grad_gamma,
  T* grad_beta
  )
{
  auto stream = at::cuda::getCurrentCUDAStream().stream();

  if (gamma != NULL && beta != NULL) {
    // compute grad_gamma(j) and grad_beta(j)
    const int part_size = 16;
    const dim3 threads2(32,4,1);
    const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
    const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
    const int nshared2_b = threads2.x * threads2.y * sizeof(U);
    const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
    at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(input->scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input->scalar_type()));
    at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
    cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
        dout,
        input->DATA_PTR<T>(),
        n1,n2,
        mean,
        invvar,
        U(epsilon),
        part_grad_gamma.DATA_PTR<U>(),
        part_grad_beta.DATA_PTR<U>());

    const dim3 threads3(32,8,1);
    const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1);
    const int nshared3 = threads3.x * threads3.y * sizeof(U);
    cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
        part_grad_gamma.DATA_PTR<U>(),
        part_grad_beta.DATA_PTR<U>(),
        part_size,
        n1,n2,
        grad_gamma,
        grad_beta);
  }

  // compute grad_input
  const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
  const dim3 threads1(32,4,1);
  int nshared =
      threads1.y > 1 ?
          threads1.y*threads1.x*sizeof(U) :
          0;
  cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
      dout,
      input->DATA_PTR<T>(),
      n1,n2,
      mean,
      invvar,
      U(epsilon),
      gamma,
      grad_input);
}

void cuda_layer_norm_gradient(
  at::Tensor* dout,
  at::Tensor* mean,
  at::Tensor* invvar,
  at::Tensor* input,
  int n1,
  int n2,
  #ifdef VERSION_GE_1_1
  at::IntArrayRef normalized_shape,
  #else
  at::IntList normalized_shape,
  #endif
  at::Tensor* gamma,
  at::Tensor* beta,
  double epsilon,
  at::Tensor* grad_input,
  at::Tensor* grad_gamma,
  at::Tensor* grad_beta)
{
  using namespace at;
  DISPATCH_FLOAT_AND_HALF(input->scalar_type(), 0, "cuComputeGradInput",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      HostLayerNormGradient(
          dout->DATA_PTR<scalar_t_0>(),
          mean->DATA_PTR<accscalar_t>(),
          invvar->DATA_PTR<accscalar_t>(),
          input,
          n1,n2,
          // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
          // if gamma Tensor is NULL on input.
          gamma != NULL ? gamma->DATA_PTR<scalar_t_0>() : NULL,
          gamma != NULL ? beta->DATA_PTR<scalar_t_0>() : NULL,
          epsilon,
          grad_input->DATA_PTR<scalar_t_0>(),
          gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_0>() : NULL,
          gamma != NULL ? grad_beta->DATA_PTR<scalar_t_0>() : NULL);
    )
}
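The forward kernels above accumulate each row's statistics with Welford's online update (cuWelfordOnlineSum) and then merge per-thread partials with the parallel combine formula in cuChanOnlineSum, first across the lanes of a warp and then across warps through shared memory, before dividing the merged sum of squared deviations by n2. The host-only sketch below is not part of the original file; it is a minimal illustration, and the Moments struct and the welford_add/chan_merge helpers are hypothetical names introduced here. It mirrors the same algebra and checks that splitting a row between two "threads", merging, and dividing by the row length matches a plain two-pass mean/variance computation.

// Standalone host-side sketch (illustration only, not part of apex).
#include <cstdio>
#include <vector>

struct Moments { double mu = 0.0, sigma2 = 0.0, count = 0.0; };

// Welford update: fold one value into a running mean and M2 (same algebra as cuWelfordOnlineSum).
void welford_add(Moments& m, double curr) {
  m.count += 1.0;
  double delta = curr - m.mu;
  m.mu += delta / m.count;
  m.sigma2 += delta * (curr - m.mu);
}

// Chan merge: combine two partial results (same algebra as cuChanOnlineSum).
void chan_merge(Moments& a, const Moments& b) {
  double delta = b.mu - a.mu;
  double nA = a.count, nB = b.count, nX = nA + nB;
  if (nX > 0.0) {
    a.mu = (nA / nX) * a.mu + (nB / nX) * b.mu;
    a.sigma2 = a.sigma2 + b.sigma2 + delta * delta * (nA / nX) * (nB / nX) * nX;
    a.count = nX;
  } else {
    a = Moments{};
  }
}

int main() {
  std::vector<double> x = {0.5, -1.25, 3.0, 2.5, -0.75, 1.0, 4.25, -2.0};
  // Split the row between two "threads", then merge, as the kernel does per warp lane.
  Moments t0, t1;
  for (size_t i = 0; i < x.size() / 2; ++i) welford_add(t0, x[i]);
  for (size_t i = x.size() / 2; i < x.size(); ++i) welford_add(t1, x[i]);
  chan_merge(t0, t1);

  // Two-pass reference: mean first, then sum of squared deviations.
  double mean = 0.0;
  for (double v : x) mean += v;
  mean /= x.size();
  double m2 = 0.0;
  for (double v : x) m2 += (v - mean) * (v - mean);

  // The kernel divides the merged sigma2 by n2, i.e. it uses the biased variance.
  std::printf("merged   mu=%f var=%f\n", t0.mu, t0.sigma2 / x.size());
  std::printf("two-pass mu=%f var=%f\n", mean, m2 / x.size());
  return 0;
}

The same merge is what the shared-memory loops in cuWelfordMuSigma2 perform across warps: each reduction round has the upper half of warps publish (mu, sigma2, count) and the lower half fold them in, halving the number of live partials per iteration.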