Path: blob/main/apex/csrc/welford.cu
#include <iostream>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

#include "type_shim.h"
#include "compat.h"


__device__ __forceinline__ int lastpow2(int n)
{
  int out = 1 << (31 - __clz(n));
  if(n == out)
    out >>= 1;
  return out;
}

__host__ __forceinline__ int h_next_pow2(unsigned int n) {
  n--;
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  return ++n;
}

__host__ __forceinline__ int h_last_pow2(unsigned int n) {
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  return n - (n >> 1);
}


#define WARP_SIZE 32

template<typename T>
__device__ __forceinline__ T warp_reduce_sum(T val)
{
  #pragma unroll
  for(int i = WARP_SIZE/2; i > 0; i >>= 1)
    val = val + __shfl_down_sync(0xffffffff, val, i);
  return val;
}

template<typename T>
__device__ __forceinline__ T reduce_block(T *x, T val)
{
  int tid = threadIdx.y*blockDim.x + threadIdx.x;
  int blockSize = blockDim.x * blockDim.y;

  if (blockSize > 32) {
    val = warp_reduce_sum(val);
    if (tid % WARP_SIZE == 0)
      x[tid/WARP_SIZE] = val;

    __syncthreads();

    val = (tid < blockSize / WARP_SIZE? x[tid%WARP_SIZE] : T(0));
  }

  if(tid/WARP_SIZE==0) val = warp_reduce_sum(val);

  return val;
}

#define ELEMENTS_PER_ITER 4 // enables concurrency within each thread to hide latency
#define ELEMENTS_PER_THREAD 16
#define OPTIMAL_TILE_W 32
#define MAX_H_BLOCK 128
#define MAX_BLOCK_SIZE 512

__host__ int div_ru(int x, int y) {
  return h_last_pow2(1 + (x-1)/y);
}

__host__ void flexible_launch_configs(
      const int reduction,
      const int stride,
      dim3 &block,
      dim3 &grid,
      const bool coop_flag = false) {
  int block_x = std::min(h_last_pow2(stride), OPTIMAL_TILE_W);
  int block_y = std::min(h_last_pow2(div_ru(reduction , ELEMENTS_PER_THREAD)),
                         MAX_BLOCK_SIZE / block_x);
  if (block_x * block_y != MAX_BLOCK_SIZE) {
    block_x = std::min(h_last_pow2(stride), MAX_BLOCK_SIZE / block_y);
  }

  int grid_x = div_ru(stride, block_x);
  int grid_y = std::min(div_ru(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK);
  if (coop_flag) {
    // it's not worth having a grid reduction if the reduction dimension is not big enough
    grid_y = grid_y < 8 ? 1 : grid_y;
  }

  block.x = block_x;
  block.y = block_y;
  block.z = 1;
  grid.x = grid_x;
  grid.y = grid_y;
  grid.z = 1;
}

template<typename T, typename C>
__device__ __forceinline__ void welford_merge_element(C& count,
                                                      T& mean,
                                                      T& m2n,
                                                      const C& num_new,
                                                      const T& mean_new,
                                                      const T& m2n_new) {
  T factor = T(1.0) / max(1, (count + num_new));
  T delta0 = mean - mean_new;
  mean = (mean_new * num_new + mean * count) * factor;
  m2n += m2n_new + delta0 * delta0 * num_new * count * factor;
  count += num_new;
}

template<typename T>
__device__ __forceinline__ void warp_reduce_mean_m2n(T &mean, T &m2n, int &num)
{
  #pragma unroll
  for(int i = WARP_SIZE/2; i > 0; i >>= 1) {
    auto num_new = __shfl_down_sync(0xffffffff, num, i);
    auto mean_new = __shfl_down_sync(0xffffffff, mean, i);
    auto m2n_new = __shfl_down_sync(0xffffffff, m2n, i);
    welford_merge_element(num, mean, m2n, num_new, mean_new, m2n_new);
  }
}

template <typename T>
__device__ void welford_reduce_mean_m2n(
      T* __restrict__ x,
      int* __restrict__ count,
      T &mean,
      T &m2n,
      int &num,
      int block_size,
      int thread_id)
{
  int lane = thread_id % WARP_SIZE;
  int wid = thread_id / WARP_SIZE;

  if (block_size > 32) {
    warp_reduce_mean_m2n(mean, m2n, num);
    if (lane == 0) {
      x[wid*2] = mean;
      x[wid*2+1] = m2n;
      count[wid] = num;
    }
    __syncthreads();

    if (wid == 0) {
      mean = (thread_id < block_size / WARP_SIZE)? x[lane*2] : T(0);
      m2n = (thread_id < block_size / WARP_SIZE)? x[lane*2+1] : T(0);
      num = (thread_id < block_size / WARP_SIZE)? count[lane] : int(0);
    }
  }

  if (wid==0) warp_reduce_mean_m2n(mean, m2n, num);

  return;
}

// return spatial size for NC+ Tensors
__host__ int get_tensor_spatial_size(const at::Tensor& input)
{
  auto space_size = input.size(2);
  for (int i = 3; i < input.ndimension(); i++) {
    space_size *= input.size(i);
  }
  return space_size;
}

// promote accumulation scalar type. promote half to float.
__host__ at::ScalarType promote_scalartype(const at::Tensor& input)
{
  return input.scalar_type() == at::ScalarType::Half ?
           at::ScalarType::Float : input.scalar_type();
}

// return single element size, optional accumulation type promotion.
__host__ size_t get_element_data_size(const at::Tensor& input, bool accumulation = false)
{
  auto scalar_type = accumulation ? promote_scalartype(input) : input.scalar_type();
  return at::elementSize(scalar_type);
}

template<typename T, typename C>
__device__ __forceinline__ void welford_merge_block_vertical(C& count,
                                                             T& mean,
                                                             T& m2n,
                                                             C* shmem_count,
                                                             T* shmem_mean,
                                                             T* shmem_m2n) {
  // write to shared memory
  auto address_base = threadIdx.x + threadIdx.y * blockDim.x;
  shmem_mean[address_base] = mean;
  shmem_m2n[address_base] = m2n;
  shmem_count[address_base] = count;

  #pragma unroll
  for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
    __syncthreads();
    if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
      auto address = address_base + offset * blockDim.x;
      // read shared memory back to register for reduction
      auto num_new = shmem_count[address];
      auto mean_new = shmem_mean[address];
      auto m2n_new = shmem_m2n[address];

      welford_merge_element(count, mean, m2n, num_new, mean_new, m2n_new);

      // last write is not necessary
      shmem_mean[address_base] = mean;
      shmem_m2n[address_base] = m2n;
      shmem_count[address_base] = count;
    }
  }
}

template<typename T>
__device__ __forceinline__ void merge_block_vertical(T& sum_dy,
                                                     T& sum_dy_xmu,
                                                     T* shmem_sum_dy,
                                                     T* shmem_sum_dy_xmu) {
  // write to shared memory
  auto address_base = threadIdx.x + threadIdx.y * blockDim.x;
  shmem_sum_dy[address_base] = sum_dy;
  shmem_sum_dy_xmu[address_base] = sum_dy_xmu;

  #pragma unroll
  for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
    __syncthreads();
    if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
      auto address = address_base + offset * blockDim.x;

      sum_dy += shmem_sum_dy[address];
      sum_dy_xmu += shmem_sum_dy_xmu[address];

      // last write is not necessary
      shmem_sum_dy[address_base] = sum_dy;
      shmem_sum_dy_xmu[address_base] = sum_dy_xmu;
    }
  }
}


// welford kernel calculating mean/biased_variance/unbiased_variance
template <typename scalar_t, typename accscalar_t, typename outscalar_t>
__global__ void welford_kernel(
      const scalar_t* __restrict__ input,
      outscalar_t* __restrict__ out_mean,
      outscalar_t* __restrict__ out_var_biased,
      const int bs,
      const int fs,
      const int ss) {
  int block_size = blockDim.x * blockDim.y;
  int count = 0;
  accscalar_t x_mean = accscalar_t(0);
  accscalar_t m_2_n = accscalar_t(0);

  int thread_id = threadIdx.y*blockDim.x + threadIdx.x;

  for (int batch_id = threadIdx.y; batch_id < bs; batch_id += blockDim.y) {
    int input_base = blockIdx.x*ss + batch_id*ss*fs;
    // sequential welford
    for (int offset = threadIdx.x; offset < ss ; offset += blockDim.x) {
      count++;
      auto x_n = static_cast<accscalar_t>(input[offset+input_base]);
      auto d = x_n - x_mean;
      x_mean += d / count;
      m_2_n += d * (x_n - x_mean);
    }
  }

  static __shared__ int s_mem[160];
  accscalar_t* s_mem_ac = (accscalar_t*) &s_mem[32];

  welford_reduce_mean_m2n<accscalar_t>(s_mem_ac, s_mem, x_mean, m_2_n, count, block_size, thread_id);

  if (thread_id == 0) {
    out_mean[blockIdx.x] = static_cast<outscalar_t>(x_mean);
    out_var_biased[blockIdx.x] = static_cast<outscalar_t>(m_2_n/count);
  }
}

// elementwise BN kernel
template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
__global__ void batchnorm_forward_kernel(
      const scalar_t* __restrict__ input,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      const layerscalar_t* __restrict__ weight,
      const layerscalar_t* __restrict__ shift,
      scalar_t* __restrict__ out,
      const int ss,
      const int bs) {
  auto m_c = mean[blockIdx.x];
  auto inv_std_c = inv_std[blockIdx.x];
  auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[blockIdx.x]);
  auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[blockIdx.x]);

  for (int batch_offset = blockIdx.y*blockDim.y + threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
    int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss;
    for (int offset = threadIdx.x + blockIdx.z*blockDim.x; offset < ss ; offset+= gridDim.z*blockDim.x) {
      out[address_base+offset] = static_cast<scalar_t>(w_c * (static_cast<accscalar_t>(input[address_base+offset]) - m_c ) * inv_std_c + s_c);
    }
  }
}

// Backward BN kernel, calculates grad_bias, grad_weight as well as intermediate
// results to calculating grad_input.
// Breaking the grad_input to two step to support sync BN, which requires all
// reduce of the intermediate results across processes.
template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
__global__ void reduce_bn_kernel(
      const scalar_t* __restrict__ input,
      const scalar_t* __restrict__ grad_output,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      accscalar_t* __restrict__ sum_dy_o,
      accscalar_t* __restrict__ sum_dy_xmu_o,
      layerscalar_t* __restrict__ grad_weight,
      layerscalar_t* __restrict__ grad_bias,
      const int bs,
      const int fs,
      const int ss) {
  static __shared__ int s_mem[64];
  //int total_item_num = bs * ss;

  int thread_id = threadIdx.y*blockDim.x + threadIdx.x;

  auto r_mean = mean[blockIdx.x];
  auto factor = inv_std[blockIdx.x];

  // Kahan sum
  accscalar_t sum_dy = 0.0;
  accscalar_t sum_dy_xmu = 0.0;
  accscalar_t sum_dy_c = 0.0;
  accscalar_t sum_dy_xmu_c = 0.0;
  for (int batch_id = threadIdx.y; batch_id < bs; batch_id += blockDim.y) {
    int input_base = blockIdx.x*ss + batch_id*ss*fs;
    for (int offset = threadIdx.x; offset < ss ; offset += blockDim.x) {
      auto e_grad = static_cast<accscalar_t>(grad_output[offset+input_base]);
      auto e_input = static_cast<accscalar_t>(input[offset+input_base]);
      // calculating sum_dy
      auto sum_dy_y = e_grad - sum_dy_c;
      auto sum_dy_t = sum_dy + sum_dy_y;
      sum_dy_c = (sum_dy_t - sum_dy) - sum_dy_y;
      sum_dy = sum_dy_t;

      // calculating sum_dy_xmu
      auto sum_dy_xmu_y = e_grad * (e_input - r_mean) - sum_dy_xmu_c;
      auto sum_dy_xmu_t = sum_dy_xmu + sum_dy_xmu_y;
      sum_dy_xmu_c = (sum_dy_xmu_t - sum_dy_xmu) - sum_dy_xmu_y;
      sum_dy_xmu = sum_dy_xmu_t;
    }
  }

  sum_dy = reduce_block((accscalar_t*)s_mem, sum_dy);
  __syncthreads();
  sum_dy_xmu = reduce_block((accscalar_t*)s_mem, sum_dy_xmu);

  if (thread_id == 0) {
    if (grad_bias != NULL) {
      grad_bias[blockIdx.x] = static_cast<layerscalar_t>(sum_dy);
    }
    if (grad_weight != NULL) {
      grad_weight[blockIdx.x] = static_cast<layerscalar_t>(sum_dy_xmu * factor);
    }
    //mean_dy[blockIdx.x] = sum_dy / total_item_num;
    //mean_dy_xmu[blockIdx.x] = sum_dy_xmu / total_item_num;
    sum_dy_o[blockIdx.x] = sum_dy;
    sum_dy_xmu_o[blockIdx.x] = sum_dy_xmu;
  }
}

// elementwise backward BN kernel
template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
__global__ void batchnorm_backward_kernel(
      const scalar_t* __restrict__ grad_output,
      const scalar_t* __restrict__ input,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      const layerscalar_t* __restrict__ weight,
      const accscalar_t* __restrict__ sum_dy,
      const accscalar_t* __restrict__ sum_dy_xmu,
      const int* __restrict__ numel,
      scalar_t* __restrict__ grad_input,
      const int64_t world_size,
      const int ss,
      const int bs) {
  int64_t div = 0;
  for (int i = 0; i < world_size; i++) {
    div += numel[i];
  }
  auto m_c = static_cast<accscalar_t>(mean[blockIdx.x]);
  //auto m_dy_c = static_cast<accscalar_t>(mean_dy[blockIdx.x]);
  auto m_dy_c = static_cast<accscalar_t>(sum_dy[blockIdx.x]) / div;
  auto factor_1_c = inv_std[blockIdx.x];
  auto factor_2_c = (weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[blockIdx.x])) * factor_1_c;
  //factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[blockIdx.x];
  factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[blockIdx.x] / div;

  for (int batch_offset = blockIdx.y*blockDim.y+threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
    int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss;
    for (int offset = threadIdx.x + blockIdx.z*blockDim.x; offset < ss ; offset+= gridDim.z*blockDim.x) {
      grad_input[address_base+offset] = (static_cast<accscalar_t>(grad_output[address_base+offset]) - m_dy_c - (static_cast<accscalar_t>(input[address_base+offset]) - m_c) * factor_1_c) * factor_2_c;
    }
  }
}

// welford kernel for c last tensor calculating mean/biased_variance/unbiased_variance
template
   <typename scalar_t,
    typename accscalar_t,
    typename outscalar_t,
    int PARALLEL_LOADS>
__global__ void
welford_kernel_c_last(
      const scalar_t* __restrict__ input,
      outscalar_t* __restrict__ out_mean,
      outscalar_t* __restrict__ out_var_biased,
      volatile accscalar_t* staging_data,
      int* semaphores,
      const int reduction_size,
      const int stride) {
  // hide latency with concurrency
  accscalar_t x_mean[PARALLEL_LOADS];
  accscalar_t m_2_n[PARALLEL_LOADS];
  int count[PARALLEL_LOADS];

  #pragma unroll
  for (int i = 0; i < PARALLEL_LOADS; i++) {
    x_mean[i] = accscalar_t(0);
    m_2_n[i] = accscalar_t(0);
    count[i] = accscalar_t(0);
  }
  // tensor dimension (m,c)

  // loop along m dimension
  int inner_loop_stride = blockDim.y * gridDim.y;

  // offset along m dimension
  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;

  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  int address_base = m_offset * stride + c_offset;
  int address_increment = inner_loop_stride * stride;

  for (int i = 0; i < loop_count; i++) {
    accscalar_t x_math[PARALLEL_LOADS];
    accscalar_t x_count_inv[PARALLEL_LOADS];
    accscalar_t is_valid[PARALLEL_LOADS];

    // load multiple data in
    #pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      if (c_offset < stride && m_offset < reduction_size) {
        x_math[j] = input[address_base];
        count[j]++;
        x_count_inv[j] = accscalar_t(1) / count[j];
        is_valid[j] = accscalar_t(1);
      } else {
        x_math[j] = accscalar_t(0);
        x_count_inv[j] = accscalar_t(0);
        is_valid[j] = accscalar_t(0);
      }
      m_offset += inner_loop_stride;
      address_base += address_increment;
    }

    // calculate mean/m2n with welford
    #pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      accscalar_t delta0 = x_math[j] - x_mean[j];
      x_mean[j] += delta0 * x_count_inv[j];
      accscalar_t delta1 = x_math[j] - x_mean[j];
      m_2_n[j] += delta0 * delta1 * is_valid[j];
    }
  }

  // thread reduction to accumulate mean/m_2_n/count between PARALLEL_LOADS
  #pragma unroll
  for (int j = 1; j < PARALLEL_LOADS; j++) {
    welford_merge_element(count[0], x_mean[0], m_2_n[0], count[j], x_mean[j], m_2_n[j]);
  }

  // release x_mean / m_2_n
  auto mean_th = x_mean[0];
  auto m2_th = m_2_n[0];
  auto count_th = count[0];

  // block-wise reduction with shared memory (since reduction cannot be done within a warp)
  static __shared__ accscalar_t shmem_mean[MAX_BLOCK_SIZE];
  static __shared__ accscalar_t shmem_m2n[MAX_BLOCK_SIZE];
  static __shared__ int shmem_count[MAX_BLOCK_SIZE];

  welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);

  // grid reduction if needed (coop launch used at the first place)
  if (gridDim.y > 1) {
    volatile accscalar_t* staging_mean = staging_data;
    volatile accscalar_t* staging_m2n = &staging_data[stride*gridDim.y];
    volatile int* staging_count = reinterpret_cast<volatile int*>(&staging_m2n[stride*gridDim.y]);

    address_base = c_offset + blockIdx.y * stride;
    // write data to staging_data;
    if (threadIdx.y == 0 && c_offset < stride) {
      staging_mean[address_base] = mean_th;
      staging_m2n[address_base] = m2_th;
      staging_count[address_base] = count_th;
    }

    __threadfence();
    __syncthreads(); // ensuring writes to staging_ is visible to all blocks

    __shared__ bool is_last_block_done;
    // mark block done
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      int old = atomicAdd(&semaphores[blockIdx.x], 1);
      is_last_block_done = (old == (gridDim.y-1));
    }

    __syncthreads();

    // check that all data is now available in global memory
    if (is_last_block_done) {
      count_th = 0;
      mean_th = accscalar_t(0.0);
      m2_th = accscalar_t(0.0);

      for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
        address_base = c_offset + y * stride;
        int num_new = c_offset < stride ? staging_count[address_base] : 0;
        accscalar_t mean_new = c_offset < stride ? staging_mean[address_base] : accscalar_t(0.0);
        accscalar_t m2n_new = c_offset < stride ? staging_m2n[address_base] : accscalar_t(0.0);

        welford_merge_element(count_th, mean_th, m2_th, num_new, mean_new, m2n_new);
      }

      welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
      if (threadIdx.y == 0 && c_offset < stride) {
        out_mean[c_offset] = static_cast<outscalar_t>(mean_th);
        out_var_biased[c_offset] = static_cast<outscalar_t>(m2_th / count_th);
      }
    }
  } else {
    if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
      out_mean[c_offset] = static_cast<outscalar_t>(mean_th);
      out_var_biased[c_offset] = static_cast<outscalar_t>(m2_th / count_th);
    }
  }
}

// parallel welford kernel to further reduce mean / biased_var
// into mean / unbiased_var / inv_std across multiple processes.
template <typename scalar_t>
__global__ void welford_kernel_parallel(
      const scalar_t* __restrict__ mean,
      const scalar_t* __restrict__ var_biased,
      const int* __restrict__ numel,
      scalar_t* __restrict__ out_mean,
      scalar_t* __restrict__ out_var,
      scalar_t* __restrict__ inv_std,
      const int world_size,
      const int feature_size,
      const float eps) {

  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < feature_size; i += gridDim.x * blockDim.x) {
    // load data;
    int address = i;
    scalar_t x_mean = 0;
    scalar_t m_2_n = 0;
    int count = 0;
    for (int j = 0; j < world_size; j++) {
      welford_merge_element(count, x_mean, m_2_n, numel[j], mean[address], var_biased[address]*numel[j]);
      address += feature_size;
    }
    out_mean[i] = x_mean;
    out_var[i] = m_2_n / (count - 1);
    inv_std[i] = scalar_t(1) / sqrt(m_2_n/count + eps);
  }
}

// elementwise BN kernel
template <
  typename scalar_t,
  typename accscalar_t,
  typename layerscalar_t,
  int PARALLEL_LOADS>
__global__ void batchnorm_forward_c_last_kernel(
      const scalar_t* __restrict__ input,
      const scalar_t* __restrict__ z,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      const layerscalar_t* __restrict__ weight,
      const layerscalar_t* __restrict__ shift,
      scalar_t* __restrict__ out,
      const int reduction_size,
      const int stride,
      const bool fuse_relu) {
  // tensor dimension (m,c)
  // loop along m dimension
  int inner_loop_stride = blockDim.y * gridDim.y;

  // offset along m dimension
  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;

  auto m_c = mean[c_offset];
  auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
  auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);
  auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);

  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  int address_base = m_offset * stride + c_offset;
  int address_increment = inner_loop_stride * stride;

  for (int i = 0; i < loop_count; i++) {
    #pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      if (c_offset < stride && m_offset < reduction_size) {
        auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c;
        if (z != NULL) {
          tmp += z[address_base];
        }
        out[address_base] = (fuse_relu && tmp <= accscalar_t(0.0) ? scalar_t(0.0) : static_cast<scalar_t>(tmp));
      }
      m_offset += inner_loop_stride;
      address_base += address_increment;
    }
  }
}

// elementwise BN kernel
template <
  typename scalar_t,
  typename accscalar_t,
  typename layerscalar_t,
  int PARALLEL_LOADS>
__global__ void relu_backward_c_last_kernel(
      const scalar_t* __restrict__ grad_output,
      const scalar_t* __restrict__ input,
      const scalar_t* __restrict__ z,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      const layerscalar_t* __restrict__ weight,
      const layerscalar_t* __restrict__ shift,
      scalar_t* __restrict__ out,
      const int reduction_size,
      const int stride) {
  // tensor dimension (m,c)
  // loop along m dimension
  int inner_loop_stride = blockDim.y * gridDim.y;

  // offset along m dimension
  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;

  auto m_c = mean[c_offset];
  auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
  auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);
  auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);

  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  int address_base = m_offset * stride + c_offset;
  int address_increment = inner_loop_stride * stride;

  for (int i = 0; i < loop_count; i++) {
    #pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      if (c_offset < stride && m_offset < reduction_size) {
        auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c;
        if (z != NULL) {
          tmp += z[address_base];
        }
        out[address_base] = (tmp <= accscalar_t(0.0) ? scalar_t(0.0) : grad_output[address_base]);
      }
      m_offset += inner_loop_stride;
      address_base += address_increment;
    }
  }
}

// batchnorm backward kernel for c last tensor
template
   <typename scalar_t,
    typename accscalar_t,
    typename layerscalar_t,
    int PARALLEL_LOADS>
__global__ void reduce_bn_c_last_kernel(
      const scalar_t* __restrict__ input,
      const scalar_t* __restrict__ grad_output,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      accscalar_t* __restrict__ sum_dy_o,
      accscalar_t* __restrict__ sum_dy_xmu_o,
      layerscalar_t* __restrict__ grad_weight,
      layerscalar_t* __restrict__ grad_bias,
      volatile accscalar_t* staging_data,
      int* semaphores,
      const int reduction_size,
      const int stride) {

  // hide latency with concurrency
  accscalar_t sum_dy[PARALLEL_LOADS];
  accscalar_t sum_dy_xmu[PARALLEL_LOADS];

  #pragma unroll
  for (int i = 0; i < PARALLEL_LOADS; i++) {
    sum_dy[i] = accscalar_t(0);
    sum_dy_xmu[i] = accscalar_t(0);
  }
  // tensor dimension (m,c)

  // loop along m dimension
  int inner_loop_stride = blockDim.y * gridDim.y;

  // offset along m dimension
  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;

  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  int address_base = m_offset * stride + c_offset;
  int address_increment = inner_loop_stride * stride;

  auto r_mean = mean[c_offset];
  auto factor = inv_std[c_offset];

  for (int i = 0; i < loop_count; i++) {
    accscalar_t x_input[PARALLEL_LOADS];
    accscalar_t x_grad_output[PARALLEL_LOADS];

    // load multiple data in
    #pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      if (c_offset < stride && m_offset < reduction_size) {
        x_input[j] = input[address_base];
        x_grad_output[j] = grad_output[address_base];
      } else {
        x_input[j] = accscalar_t(0);
        x_grad_output[j] = accscalar_t(0);
      }
      m_offset += inner_loop_stride;
      address_base += address_increment;
    }

    // calculate sum_dy / sum_dy_xmu
    #pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      sum_dy[j] += x_grad_output[j];
      sum_dy_xmu[j] += x_grad_output[j] * (x_input[j] - r_mean);
    }
  }

  // thread reduction to accumulate sum_dy / sum_dy_xmu between PARALLEL_LOADS
  #pragma unroll
  for (int j = 1; j < PARALLEL_LOADS; j++) {
    sum_dy[0] += sum_dy[j];
    sum_dy_xmu[0] += sum_dy_xmu[j];
  }

  // release array of registers
  auto sum_dy_th = sum_dy[0];
  auto sum_dy_xmu_th = sum_dy_xmu[0];

  // block-wise reduction with shared memory (since reduction cannot be done within a warp)
  static __shared__ accscalar_t shmem_sum_dy[MAX_BLOCK_SIZE];
  static __shared__ accscalar_t shmem_sum_dy_xmu[MAX_BLOCK_SIZE];

  merge_block_vertical(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);

  // grid reduction if needed (coop launch used at the first place)
  if (gridDim.y > 1) {
    volatile accscalar_t* staging_sum_dy = staging_data;
    volatile accscalar_t* staging_sum_dy_xmu = &staging_data[stride*gridDim.y];

    address_base = c_offset + blockIdx.y * stride;
    // write data to staging_data;
    if (threadIdx.y == 0 && c_offset < stride) {
      staging_sum_dy[address_base] = sum_dy_th;
      staging_sum_dy_xmu[address_base] = sum_dy_xmu_th;
    }

    __threadfence();
    __syncthreads(); // ensuring writes to staging_ is visible to all blocks

    __shared__ bool is_last_block_done;
    // mark block done
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      int old = atomicAdd(&semaphores[blockIdx.x], 1);
      is_last_block_done = (old == (gridDim.y-1));
    }

    __syncthreads();

    // check that all data is now available in global memory
    if (is_last_block_done) {
      sum_dy_th = accscalar_t(0.0);
      sum_dy_xmu_th = accscalar_t(0.0);

      for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
        address_base = c_offset + y * stride;
        sum_dy_th += (c_offset < stride ? staging_sum_dy[address_base] : accscalar_t(0.0));
        sum_dy_xmu_th += (c_offset < stride ? staging_sum_dy_xmu[address_base] : accscalar_t(0.0));
      }

      merge_block_vertical(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
      if (threadIdx.y == 0 && c_offset < stride) {
        if (grad_bias != NULL) {
          grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
        }
        if (grad_weight != NULL) {
          grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
        }
        //mean_dy[c_offset] = sum_dy_th / reduction_size;
        //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
        sum_dy_o[c_offset] = sum_dy_th;
        sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
      }
    }
  } else {
    if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
      if (grad_bias != NULL) {
        grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
      }
      if (grad_weight != NULL) {
        grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
      }
      //mean_dy[c_offset] = sum_dy_th / reduction_size;
      //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
      sum_dy_o[c_offset] = sum_dy_th;
      sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
    }
  }
}

// elementwise BN kernel
template <
  typename scalar_t,
  typename accscalar_t,
  typename layerscalar_t,
  int PARALLEL_LOADS>
__global__ void batchnorm_backward_c_last_kernel(
      const scalar_t* __restrict__ grad_output,
      const scalar_t* __restrict__ input,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      const layerscalar_t* __restrict__ weight,
      const accscalar_t* __restrict__ sum_dy,
      const accscalar_t* __restrict__ sum_dy_xmu,
      const int* __restrict__ numel,
      scalar_t* __restrict__ grad_input,
      const int64_t world_size,
      const int reduction_size,
      const int stride) {
  int64_t div = 0;
  for (int i = 0; i < world_size; i++) {
    div += numel[i];
  }
  // tensor dimension (m,c)
  // loop along m dimension
  int inner_loop_stride = blockDim.y * gridDim.y;

  // offset along m dimension
  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;

  auto m_c = mean[c_offset];
  auto m_dy_c = sum_dy[c_offset] / div;
  auto factor_1_c = inv_std[c_offset];
  auto factor_2_c = (weight == NULL? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset])) * factor_1_c;
  factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[c_offset] / div;

  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  int address_base = m_offset * stride + c_offset;
  int address_increment = inner_loop_stride * stride;

  for (int i = 0; i < loop_count; i++) {
    #pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      if (c_offset < stride && m_offset < reduction_size) {
        grad_input[address_base] = static_cast<scalar_t>(
            (static_cast<accscalar_t>(grad_output[address_base]) - m_dy_c -
            (static_cast<accscalar_t>(input[address_base]) - m_c) * factor_1_c)
            * factor_2_c);
      }
      m_offset += inner_loop_stride;
      address_base += address_increment;
    }
  }
}

std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
  const auto batch_size = input.size(0);
  const auto feature_size = input.size(1);

  auto space_size = get_tensor_spatial_size(input);
  auto scalar_type = promote_scalartype(input);

  at::Tensor out_var_biased = at::empty({feature_size}, input.options().dtype(scalar_type));
  at::Tensor out_mean = at::empty({feature_size}, input.options().dtype(scalar_type));

  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / 32));
  int block_x = max(1, min(MAX_BLOCK_SIZE / block_y, h_last_pow2(space_size)));
  const dim3 block(block_x, block_y);
  const dim3 grid(feature_size);

  auto stream = at::cuda::getCurrentCUDAStream();

  {
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "welford_mean_var_kernel",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      welford_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          out_mean.DATA_PTR<accscalar_t>(),
          out_var_biased.DATA_PTR<accscalar_t>(),
          batch_size,
          feature_size,
          space_size);
    );
  }

  return {out_mean, out_var_biased};
}

at::Tensor batchnorm_forward_CUDA(
    const at::Tensor input,
    const at::Tensor mean,
    const at::Tensor inv_std,
    const at::optional<at::Tensor> weight,
    const at::optional<at::Tensor> shift) {
  const auto batch_size = input.size(0);
  const auto feature_size = input.size(1);
  at::Tensor out = at::empty_like(input);

  auto space_size = get_tensor_spatial_size(input);

  int block_x = max(32, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
  int block_y = max(1, min(MAX_BLOCK_SIZE/block_x, h_last_pow2(batch_size)/4));
  const dim3 block(block_x, block_y);
  int grid_z = max(1, min(65535, h_last_pow2(space_size)/4/block_x));
  int batch_group_size = max(1, min(65535, h_last_pow2(batch_size)/block_y));
  const dim3 grid(feature_size, batch_group_size, grid_z);
  auto stream = at::cuda::getCurrentCUDAStream();

  if (input.scalar_type() == at::ScalarType::Half
      && weight.has_value() &&
      weight.value().scalar_type() == at::ScalarType::Float) {
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      batchnorm_forward_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
          shift.has_value() ? shift.value().DATA_PTR<accscalar_t>() : NULL,
          out.DATA_PTR<scalar_t_0>(),
          space_size,
          batch_size);
    );
  } else {
    if (weight.has_value()) {
      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
        "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      batchnorm_forward_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
          shift.has_value() ? shift.value().DATA_PTR<scalar_t_0>() : NULL,
          out.DATA_PTR<scalar_t_0>(),
          space_size,
          batch_size);
    );
  }
  return out;
}

std::vector<at::Tensor> reduce_bn_CUDA(
    const at::Tensor grad_output,
    const at::Tensor input,
    const at::Tensor mean,
    const at::Tensor inv_std,
    const at::optional<at::Tensor> weight)
{
  const auto batch_size = input.size(0);
  const auto feature_size = input.size(1);

  auto scalar_type = promote_scalartype(input);

  at::Tensor sum_dy = at::empty({feature_size}, mean.options());
  at::Tensor sum_dy_xmu = at::empty({feature_size}, mean.options());

  at::Tensor grad_weight;
  at::Tensor grad_bias;
  if (weight.has_value()) {
    grad_weight = at::empty({feature_size}, weight.value().options());
    grad_bias = at::empty({feature_size}, weight.value().options());
  } else {
    grad_weight = at::empty({0}, mean.options());
    grad_bias = at::empty({0}, mean.options());
  }

  auto space_size = get_tensor_spatial_size(input);

  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE/ 32));
  int block_x = max(1, min(MAX_BLOCK_SIZE/ block_y, h_last_pow2(space_size)));
  const dim3 block(block_x, block_y);
  const dim3 grid(feature_size);
  auto stream = at::cuda::getCurrentCUDAStream();

  if (input.scalar_type() == at::ScalarType::Half
      && weight.has_value() &&
      weight.value().scalar_type() == at::ScalarType::Float) {
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      reduce_bn_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          grad_output.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          sum_dy.DATA_PTR<accscalar_t>(),
          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          weight.has_value() ? grad_weight.DATA_PTR<accscalar_t>() : NULL,
          weight.has_value() ? grad_bias.DATA_PTR<accscalar_t>() : NULL,
          batch_size,
          feature_size,
          space_size);
    );
  } else {
    if (weight.has_value()) {
      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
        "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      reduce_bn_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          grad_output.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          sum_dy.DATA_PTR<accscalar_t>(),
          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          weight.has_value() ? grad_weight.DATA_PTR<scalar_t_0>() : NULL,
          weight.has_value() ? grad_bias.DATA_PTR<scalar_t_0>() : NULL,
          batch_size,
          feature_size,
          space_size);
    );
  }

  return {sum_dy, sum_dy_xmu, grad_weight, grad_bias};
}

at::Tensor batchnorm_backward_CUDA(
    const at::Tensor grad_output,
    const at::Tensor input,
    const at::Tensor mean,
    const at::Tensor inv_std,
    const at::optional<at::Tensor> weight,
    const at::Tensor sum_dy,
    const at::Tensor sum_dy_xmu,
    const at::Tensor count) {
  const auto batch_size = input.size(0);
  const auto feature_size = input.size(1);

  at::Tensor grad_input = at::empty_like(input);

  auto space_size = get_tensor_spatial_size(input);

  int block_x = max(32, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
  int block_y = max(1, min(MAX_BLOCK_SIZE/block_x, h_last_pow2(batch_size)/4));
  const dim3 block(block_x, block_y);
  int grid_z = max(1, min(65535, h_last_pow2(space_size)/4/block_x));
  int batch_group_size = max(1, min(65535, h_last_pow2(batch_size)/block_y));
  const dim3 grid(feature_size, batch_group_size, grid_z);

  auto stream = at::cuda::getCurrentCUDAStream();

  if (input.scalar_type() == at::ScalarType::Half
      && weight.has_value() &&
      weight.value().scalar_type() == at::ScalarType::Float) {
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      batchnorm_backward_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
          grad_output.DATA_PTR<scalar_t_0>(),
          input.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
          sum_dy.DATA_PTR<accscalar_t>(),
          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          count.DATA_PTR<int>(),
          grad_input.DATA_PTR<scalar_t_0>(),
          count.numel(),
          space_size,
          batch_size);
    );
  } else {
    if (weight.has_value()) {
      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
        "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      batchnorm_backward_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(
          grad_output.DATA_PTR<scalar_t_0>(),
          input.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
          sum_dy.DATA_PTR<accscalar_t>(),
          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          count.DATA_PTR<int>(),
          grad_input.DATA_PTR<scalar_t_0>(),
          count.numel(),
          space_size,
          batch_size);
    );
  }

  return grad_input;
}

std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes,
                                              const at::Tensor var_biased,
                                              const at::Tensor numel,
                                              const float eps) {
  const auto world_size = mean_feature_nodes.size(0);
  const auto feature_size = mean_feature_nodes.size(1);

  at::Tensor out_var = at::empty({feature_size}, var_biased.options());
  at::Tensor inv_std = at::empty_like(out_var);
  at::Tensor out_mean = at::empty_like(out_var);

  at::Tensor mean_feature_nodes_ = mean_feature_nodes.contiguous();
  at::Tensor var_biased_ = var_biased.contiguous();
  at::Tensor numel_ = numel.contiguous();

  // TODO(jie): tile this for memory coalescing!
  const int block = std::min(h_last_pow2(feature_size), MAX_BLOCK_SIZE);
  const int grid = std::max<int>(1, feature_size / block);

  auto stream = at::cuda::getCurrentCUDAStream();

  {
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(mean_feature_nodes.scalar_type(), 0, "welford_parallel_kernel",
      welford_kernel_parallel<scalar_t_0><<<grid, block, 0, stream>>>(
          mean_feature_nodes_.DATA_PTR<scalar_t_0>(),
          var_biased_.DATA_PTR<scalar_t_0>(),
          numel_.DATA_PTR<int>(),
          out_mean.DATA_PTR<scalar_t_0>(),
          out_var.DATA_PTR<scalar_t_0>(),
          inv_std.DATA_PTR<scalar_t_0>(),
          world_size,
          feature_size,
          eps);
    );
  }

  return {out_mean, out_var, inv_std};
}

std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input) {
  const auto stride = input.size(input.ndimension()-1);
  const auto reduction_size = input.numel() / stride;

  auto scalar_type = promote_scalartype(input);
  auto option = input.options().dtype(scalar_type);

  at::Tensor out_var_biased = at::empty({stride}, option);
  at::Tensor out_mean = at::empty({stride}, option);

  dim3 block;
  dim3 grid;
  flexible_launch_configs(reduction_size, stride, block, grid, true);

  at::Tensor staging_data;
  at::Tensor semaphores;
  if (grid.y > 1) {
    staging_data = at::empty({4*stride*grid.y}, option);
    semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
  }

  auto stream = at::cuda::getCurrentCUDAStream();

  {
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "welford_mean_var_c_last",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.DATA_PTR<accscalar_t>() : nullptr;
      int* semaphores_ptr = grid.y > 1 ? semaphores.DATA_PTR<int>() : nullptr;
      welford_kernel_c_last<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          out_mean.DATA_PTR<accscalar_t>(),
          out_var_biased.DATA_PTR<accscalar_t>(),
          staging_data_ptr,
          semaphores_ptr,
          reduction_size,
          stride);
    );
  }

  return {out_mean, out_var_biased};
}

at::Tensor batchnorm_forward_c_last_CUDA(
    const at::Tensor input,
    const at::optional<at::Tensor> z,
    const at::Tensor mean,
    const at::Tensor inv_std,
    const at::optional<at::Tensor> weight,
    const at::optional<at::Tensor> shift,
    const bool fuse_relu) {
  const auto stride = input.size(input.ndimension()-1);
  const auto reduction_size = input.numel() / stride;

  at::Tensor out = at::empty_like(input);

  dim3 block;
  dim3 grid;
  flexible_launch_configs(reduction_size, stride, block, grid);

  auto stream = at::cuda::getCurrentCUDAStream();

  if (input.scalar_type() == at::ScalarType::Half
      && weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) {
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          z.has_value() ? z.value().DATA_PTR<scalar_t_0>() : NULL,
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
          shift.has_value() ? shift.value().DATA_PTR<accscalar_t>(): NULL,
          out.DATA_PTR<scalar_t_0>(),
          reduction_size,
          stride,
          fuse_relu);
    );
  } else {
    if (weight.has_value()) {
      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
        "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          z.has_value() ? z.value().DATA_PTR<scalar_t_0>() : NULL,
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
          shift.has_value() ? shift.value().DATA_PTR<scalar_t_0>(): NULL,
          out.DATA_PTR<scalar_t_0>(),
          reduction_size,
          stride,
          fuse_relu);
    );
  }
  return out;
}

std::vector<at::Tensor> reduce_bn_c_last_CUDA(
    const at::Tensor grad_output,
    const at::Tensor input,
    const at::Tensor mean,
    const at::Tensor inv_std,
    const at::optional<at::Tensor> weight) {
  const auto stride = input.size(input.ndimension()-1);
  const auto reduction_size = input.numel() / stride;

  at::Tensor sumn_dy = at::empty({stride}, mean.options());
  at::Tensor sum_dy_xmu = at::empty({stride}, mean.options());

  at::Tensor grad_weight;
  at::Tensor grad_bias;
  if (weight.has_value()) {
    grad_weight = at::empty({stride}, weight.value().options());
    grad_bias = at::empty({stride}, weight.value().options());
  } else {
    // because I cannot return an uninitialized at::Tensor
    grad_weight = at::empty({0}, mean.options());
    grad_bias = at::empty({0}, mean.options());
  }

  dim3 block;
  dim3 grid;
  flexible_launch_configs(reduction_size, stride, block, grid, true);

  at::Tensor staging_data;
  at::Tensor semaphores;
  if (grid.y > 1) {
    staging_data = at::empty({2*stride*grid.y}, mean.options());
    semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
  }
  auto stream = at::cuda::getCurrentCUDAStream();

  if (input.scalar_type() == at::ScalarType::Half
      && weight.has_value()
      && weight.value().scalar_type() == at::ScalarType::Float) {
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.DATA_PTR<accscalar_t>() : nullptr;
      int* semaphores_ptr = grid.y > 1 ? semaphores.DATA_PTR<int>() : nullptr;
      reduce_bn_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          grad_output.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          sumn_dy.DATA_PTR<accscalar_t>(),
          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          weight.has_value() ? grad_weight.DATA_PTR<accscalar_t>() : NULL,
          weight.has_value() ? grad_bias.DATA_PTR<accscalar_t>() : NULL,
          staging_data_ptr,
          semaphores_ptr,
          reduction_size,
          stride);
    );
  } else {
    if (weight.has_value()) {
      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
        "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.DATA_PTR<accscalar_t>() : nullptr;
      int* semaphores_ptr = grid.y > 1 ? semaphores.DATA_PTR<int>() : nullptr;
      reduce_bn_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          grad_output.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          sumn_dy.DATA_PTR<accscalar_t>(),
          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          weight.has_value() ? grad_weight.DATA_PTR<scalar_t_0>() : NULL,
          weight.has_value() ? grad_bias.DATA_PTR<scalar_t_0>() : NULL,
          staging_data_ptr,
          semaphores_ptr,
          reduction_size,
          stride);
    );
  }

  return {sumn_dy, sum_dy_xmu, grad_weight, grad_bias};
}

at::Tensor batchnorm_backward_c_last_CUDA(
    const at::Tensor grad_output,
    const at::Tensor input,
    const at::Tensor mean,
    const at::Tensor inv_std,
    const at::optional<at::Tensor> weight,
    const at::Tensor sum_dy,
    const at::Tensor sum_dy_xmu,
    const at::Tensor count) {
  const auto stride = input.size(input.ndimension()-1);
  const auto reduction_size = input.numel() / stride;

  at::Tensor grad_input = at::empty_like(input);

  dim3 block;
  dim3 grid;
  flexible_launch_configs(reduction_size, stride, block, grid);

  auto stream = at::cuda::getCurrentCUDAStream();

  if (input.scalar_type() == at::ScalarType::Half
      && weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) {
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      batchnorm_backward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          grad_output.DATA_PTR<scalar_t_0>(),
          input.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
          sum_dy.DATA_PTR<accscalar_t>(),
          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          count.DATA_PTR<int>(),
          grad_input.DATA_PTR<scalar_t_0>(),
          count.numel(),
          reduction_size,
          stride);
    );
  } else {
    if (weight.has_value()) {
      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
        "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      batchnorm_backward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          grad_output.DATA_PTR<scalar_t_0>(),
          input.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
          sum_dy.DATA_PTR<accscalar_t>(),
          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          count.DATA_PTR<int>(),
          grad_input.DATA_PTR<scalar_t_0>(),
          count.numel(),
          reduction_size,
          stride);
    );
  }

  return grad_input;
}

at::Tensor relu_backward_c_last_CUDA(
    const at::Tensor grad_output,
    const at::Tensor input,
    const at::optional<at::Tensor> z,
    const at::Tensor mean,
    const at::Tensor inv_std,
    const at::optional<at::Tensor> weight,
    const at::optional<at::Tensor> shift) {

  const auto stride = input.size(input.ndimension()-1);
  const auto reduction_size = input.numel() / stride;

  at::Tensor out = at::empty_like(input);

  dim3 block;
  dim3 grid;
  flexible_launch_configs(reduction_size, stride, block, grid);

  auto stream = at::cuda::getCurrentCUDAStream();

  if (input.scalar_type() == at::ScalarType::Half
      && weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) {
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      relu_backward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          grad_output.DATA_PTR<scalar_t_0>(),
          input.DATA_PTR<scalar_t_0>(),
          z.has_value() ? z.value().DATA_PTR<scalar_t_0>() : NULL,
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
          shift.has_value() ? shift.value().DATA_PTR<accscalar_t>(): NULL,
          out.DATA_PTR<scalar_t_0>(),
          reduction_size,
          stride);
    );
  } else {
    if (weight.has_value()) {
      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
        "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      relu_backward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          grad_output.DATA_PTR<scalar_t_0>(),
          input.DATA_PTR<scalar_t_0>(),
          z.has_value() ? z.value().DATA_PTR<scalar_t_0>() : NULL,
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
          shift.has_value() ? shift.value().DATA_PTR<scalar_t_0>(): NULL,
          out.DATA_PTR<scalar_t_0>(),
          reduction_size,
          stride);
    );
  }
  return out;
}
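The merge step used throughout this file (welford_merge_element) is the standard parallel update for (count, mean, M2) statistics. The following host-only C++ sketch is not part of welford.cu; it is an illustration, with made-up helper names (Stats, merge, accumulate), that checks the chunked merge against a single sequential Welford pass:

// Standalone host-side sketch (illustration only, not part of welford.cu):
// verifies that merging per-chunk (count, mean, M2) statistics with the same
// update rule as welford_merge_element reproduces a direct sequential result.
#include <algorithm>
#include <cstdio>
#include <vector>

struct Stats { int count = 0; double mean = 0.0; double m2n = 0.0; };

// Same math as welford_merge_element, written for plain doubles on the host.
void merge(Stats& a, const Stats& b) {
  double factor = 1.0 / std::max(1, a.count + b.count);
  double delta0 = a.mean - b.mean;
  a.mean = (b.mean * b.count + a.mean * a.count) * factor;
  a.m2n += b.m2n + delta0 * delta0 * b.count * a.count * factor;
  a.count += b.count;
}

Stats accumulate(const double* x, int n) {
  Stats s;
  for (int i = 0; i < n; i++) {  // sequential Welford, as in welford_kernel
    s.count++;
    double d = x[i] - s.mean;
    s.mean += d / s.count;
    s.m2n += d * (x[i] - s.mean);
  }
  return s;
}

int main() {
  std::vector<double> x = {1.0, 2.0, 4.0, 8.0, 16.0, 32.0};
  // chunked: two halves reduced independently, then merged
  Stats a = accumulate(x.data(), 3);
  Stats b = accumulate(x.data() + 3, 3);
  merge(a, b);
  // reference: one pass over everything
  Stats r = accumulate(x.data(), (int)x.size());
  std::printf("merged  mean=%f biased_var=%f\n", a.mean, a.m2n / a.count);
  std::printf("direct  mean=%f biased_var=%f\n", r.mean, r.m2n / r.count);
  return 0;
}

Both printouts should agree up to floating-point rounding; that invariance under merge order is what the block-level and grid-level reductions in the kernels above rely on.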