Path: blob/main/apex/csrc/multi_tensor_l2norm_kernel.cu
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "type_shim.h"
#include "multi_tensor_apply.cuh"

#define BLOCK_SIZE 512
#define ILP 4

template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

template<typename x_t>
struct L2NormFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<1>& tl,
    float* output,
    float* output_per_tensor,
    bool per_tensor,
    int max_chunks_per_tensor)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
    x += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    __shared__ float s_vals[512];

    float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
    x_t r_x[ILP];
    for(int i = 0; i < ILP; i++)
    {
      vals[i] = 0.f;
      r_x[i] = 0;
    }

    // to make things simple, we put aligned case in a different code path
    if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x))
    {
      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
      {
        // load
        load_store(r_x, x, 0, i_start);
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          float next = static_cast<float>(r_x[ii]);
          vals[ii] += next*next;
        }
      }
    }
    else
    {
      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
      {
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            float next = static_cast<float>(x[i]);
            vals[ii] += next*next;
          }
        }
      }
    }

    float val = 0.f;
    for(int i = 0; i < ILP; i++)
      val += vals[i];

    float final = reduce_block_into_lanes(s_vals, val);

    if(threadIdx.x == 0)
    {
      if(!isfinite(final))
        *noop_gmem = 1; // Blindly fire off a write.  These will race but that's ok.
      output[blockIdx.x] += final;
      if(per_tensor)
        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
    }
  }
};
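
// Illustrative note (added for clarity, not part of the original source):
// the functor above implements the first pass of a two-pass reduction. Each
// block reduces the sum of squares of its chunk through s_vals /
// reduce_block_into_lanes, and lane 0 accumulates that partial into
// output[blockIdx.x] (the += lets repeated launches keep adding) and, when
// per_tensor is set, records it at this tensor's chunk slot in
// output_per_tensor. In scalar form a chunk contributes
//
//   partial = sum_i (float)x[i] * (float)x[i]   over the chunk's elements,
//
// and the cleanup kernel further down sums the partials and takes the
// square root to obtain the L-2 norm.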

// Probably better to template, but since we are not likely to support other norms,
// a separate max-norm functor keeps things simple.
template<typename x_t>
struct MaxNormFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<1>& tl,
    float* output,
    float* output_per_tensor,
    bool per_tensor,
    int max_chunks_per_tensor)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
    x += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    __shared__ float s_vals[512];

    float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
    x_t r_x[ILP];
    for(int i = 0; i < ILP; i++)
    {
      vals[i] = 0.f;
      r_x[i] = 0;
    }

    // to make things simple, we put aligned case in a different code path
    if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x))
    {
      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
      {
        // load
        load_store(r_x, x, 0, i_start);
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          float next = static_cast<float>(r_x[ii]);
          vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
        }
      }
    }
    else
    {
      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
      {
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            float next = static_cast<float>(x[i]);
            vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
          }
        }
      }
    }

    float val = 0.f;
    for(int i = 0; i < ILP; i++)
      val = fmaxf(fabsf(val), fabsf(vals[i]));

    float final = reduce_block_into_lanes_max_op(s_vals, val);

    if(threadIdx.x == 0)
    {
      if(!isfinite(final))
        *noop_gmem = 1; // Blindly fire off a write.  These will race but that's ok.
      output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final));
      if(per_tensor)
        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
    }
  }
};

__global__ void cleanup(
  float* output,
  float* output_per_tensor,
  float* ret,
  float* ret_per_tensor,
  bool per_tensor,
  int max_chunks_per_tensor)
{
  __shared__ float vals[512];

  if(blockIdx.x == 0)
  {
    float val = 0;
    if(threadIdx.x < 320)
      val = output[threadIdx.x];

    float final = reduce_block_into_lanes(vals, val);

    if(threadIdx.x == 0)
      *ret = sqrt(final);
  }

  if(per_tensor)
  {
    float* output_this_tensor = output_per_tensor + blockIdx.x*max_chunks_per_tensor;

    float val = 0;
    for(int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
      val += output_this_tensor[i];

    float final = reduce_block_into_lanes(vals, val);

    if(threadIdx.x == 0)
      ret_per_tensor[blockIdx.x] = sqrt(final);
  }
}

__global__ void cleanup_v2(
  float* output,
  float* output_per_tensor,
  float* ret,
  float* ret_per_tensor,
  bool per_tensor,
  int max_chunks_per_tensor,
  int norm_type,
  float alpha,
  float beta)
{
  __shared__ float vals[512];

  if(blockIdx.x == 0)
  {
    float val = 0;
    if(threadIdx.x < 320)
      val = output[threadIdx.x];

    if (norm_type == 0) {
      float final = reduce_block_into_lanes_max_op(vals, val);
      if(threadIdx.x == 0)
        *ret = alpha * (*ret) + beta * final;
    }
    else {
      float final = reduce_block_into_lanes(vals, val);
      if(threadIdx.x == 0)
        *ret = sqrt(alpha * (*ret) * (*ret) + beta * final);
    }
  }

  if(per_tensor)
  {
    float* output_this_tensor = output_per_tensor + blockIdx.x*max_chunks_per_tensor;

    if (norm_type == 0) {
      float val = 0;
      for(int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
        val = fmaxf(fabsf(val), fabsf(output_this_tensor[i]));

      float final = reduce_block_into_lanes_max_op(vals, val);

      if(threadIdx.x == 0)
        ret_per_tensor[blockIdx.x] = alpha * ret_per_tensor[blockIdx.x] + beta * final;
    }
    else {
      float val = 0;
      for(int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
        val += output_this_tensor[i];

      float final = reduce_block_into_lanes(vals, val);

      if(threadIdx.x == 0)
        ret_per_tensor[blockIdx.x] = sqrt(alpha * ret_per_tensor[blockIdx.x] * ret_per_tensor[blockIdx.x] + beta * final);
    }
  }
}
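
// Illustrative sketch (added example, not part of the original file): the
// cleanup kernels above assume at most 320 per-block partials per launch,
// which is why `output` is allocated with 320 floats below and a single
// 512-thread block can read them all under the `threadIdx.x < 320` guard
// (this presumably mirrors the block cap of the multi_tensor_apply harness).
// A plain CPU reference for the per-tensor L-2 branch of cleanup_v2, under
// those assumptions, could look like this:
inline void cleanup_v2_l2_reference(const float* output_per_tensor,
                                    float* ret_per_tensor,
                                    int ntensors,
                                    int max_chunks_per_tensor,
                                    float alpha,
                                    float beta)
{
  for(int t = 0; t < ntensors; t++)
  {
    // output_per_tensor holds one sum-of-squares partial per chunk.
    float sum_sq = 0.f;
    for(int c = 0; c < max_chunks_per_tensor; c++)
      sum_sq += output_per_tensor[t*max_chunks_per_tensor + c];
    // Blend with the old norm: gn = sqrt(a * gn^2 + b * n^2).
    ret_per_tensor[t] = sqrtf(alpha * ret_per_tensor[t] * ret_per_tensor[t] + beta * sum_sq);
  }
}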

std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::optional<bool> per_tensor_python)
{
  bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;

  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
  auto output = at::zeros({320}, float_options);

  at::Tensor output_per_tensor;
  at::Tensor ret_per_tensor;

  int ntensors = tensor_lists[0].size();
  int max_chunks_per_tensor = -1;

  if(per_tensor)
  {
    for(int t = 0; t < ntensors; t++)
    {
      int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
      if(max_chunks_this_tensor > max_chunks_per_tensor)
        max_chunks_per_tensor = max_chunks_this_tensor;
    }
    output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options);
    ret_per_tensor = at::empty({ntensors}, float_options);
  }
  else
  {
    ret_per_tensor = at::empty({0}, float_options);
  }

  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
    multi_tensor_apply<1>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      L2NormFunctor<scalar_t_0>(),
      output.DATA_PTR<float>(),
      per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
      per_tensor,
      max_chunks_per_tensor);)

  AT_CUDA_CHECK(cudaGetLastError());
  // AT_CUDA_CHECK(cudaDeviceSynchronize());

  // This involves one more small kernel launch, but it will be negligible end to end.
  // I could get rid of it by hacking the functor + multi-tensor harness with persistence
  // logic, but am keeping it simple for now.
  auto ret = at::empty({1}, output.options());
  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
  auto stream = at::cuda::getCurrentCUDAStream();
  cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
    output.DATA_PTR<float>(),
    per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
    ret.DATA_PTR<float>(),
    per_tensor ? ret_per_tensor.DATA_PTR<float>() : nullptr,
    per_tensor,
    max_chunks_per_tensor);

  return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
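
// Illustrative usage sketch (added example, not part of the original file):
// one way the binding above might be driven from host code. The helper name
// and the chunk size are assumptions; apex's Python side normally reaches
// this function through its multi-tensor applier wrapper instead.
inline std::tuple<at::Tensor, at::Tensor> example_l2norm_of_tensors(
  std::vector<at::Tensor> tensors)
{
  // noop_flag is a device int32 scalar that the functor sets to 1 whenever a
  // non-finite value is encountered.
  auto noop_flag = at::zeros({1}, tensors[0].options().dtype(at::kInt));
  std::vector<std::vector<at::Tensor>> tensor_lists = {tensors};
  // 2048*32 mirrors the chunk size commonly passed by apex's optimizers
  // (an assumption here, not something this file prescribes).
  return multi_tensor_l2norm_cuda(2048*32, noop_flag, tensor_lists,
                                  /*per_tensor_python=*/true);
}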

// Compute and update the gradient norm.
// Here we use a per-tensor norm and blend the new norm (n) and old norm (gn) by
// L-2:   gn = sqrt(a * gn^2 + b * n^2)
// L-inf: gn = a * gn + b * n
void multi_tensor_norm_out_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor out,
  const float alpha,
  const float beta,
  const int norm_type)
{
  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
  TORCH_CHECK(tensor_lists[0][0].device() == noop_flag.device(), "noop flag should be on the same device as tensors");
  // We don't need the global norm here, thus use empty instead of zeros.
  auto output = at::empty({320}, float_options);

  at::Tensor output_per_tensor;
  at::Tensor ret_per_tensor;

  int ntensors = tensor_lists[0].size();
  int max_chunks_per_tensor = -1;

  for(int t = 0; t < ntensors; t++)
  {
    int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
    if(max_chunks_this_tensor > max_chunks_per_tensor)
      max_chunks_per_tensor = max_chunks_this_tensor;
  }

  // Although it is a single write then read, it still needs to be zeroed,
  // since the trailing elements also participate in the cleanup.
  output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options);

  if (norm_type == 0) {
    DISPATCH_FLOAT_AND_HALF(
      tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda",
      multi_tensor_apply<1>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        MaxNormFunctor<scalar_t_0>(),
        output.DATA_PTR<float>(),
        output_per_tensor.DATA_PTR<float>(),
        true,
        max_chunks_per_tensor);)
  }
  else {
    DISPATCH_FLOAT_AND_HALF(
      tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
      multi_tensor_apply<1>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        L2NormFunctor<scalar_t_0>(),
        output.DATA_PTR<float>(),
        output_per_tensor.DATA_PTR<float>(),
        true,
        max_chunks_per_tensor);)
  }
  AT_CUDA_CHECK(cudaGetLastError());

  // AT_CUDA_CHECK(cudaDeviceSynchronize());

  // This involves one more small kernel launch, but it will be negligible end to end.
  // I could get rid of it by hacking the functor + multi-tensor harness with persistence
  // logic, but am keeping it simple for now.
  auto ret = at::empty({1}, output.options());

  // Add the following device guard since it sometimes happens that the
  // tensors are on one device while the current CUDA stream is on another
  // device, which results in an ILLEGAL MEM ACCESS error.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
  auto stream = at::cuda::getCurrentCUDAStream();
  cleanup_v2<<<ntensors, 512, 0, stream>>>(
    output.DATA_PTR<float>(),
    output_per_tensor.DATA_PTR<float>(),
    ret.DATA_PTR<float>(),
    out.DATA_PTR<float>(),
    true,
    max_chunks_per_tensor,
    norm_type,
    alpha,
    beta);

  return;
}
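
// Illustrative usage sketch (added example, not part of the original file):
// one way multi_tensor_norm_out_cuda could be invoked to fold freshly
// computed per-tensor L-2 norms into an already-initialized running
// per-tensor norm buffer, following the gn = sqrt(a*gn^2 + b*n^2) rule
// documented above. The helper name and chunk size are assumptions.
inline void example_blend_per_tensor_norms(
  std::vector<at::Tensor> grads,
  at::Tensor running_norms)   // float CUDA tensor, one entry per grad, pre-initialized
{
  auto noop_flag = at::zeros({1}, grads[0].options().dtype(at::kInt));
  std::vector<std::vector<at::Tensor>> tensor_lists = {grads};
  // alpha = beta = 1 accumulates the new squared norms into the old ones;
  // norm_type != 0 selects the L-2 blend, norm_type == 0 the L-inf blend.
  multi_tensor_norm_out_cuda(/*chunk_size=*/2048*32, noop_flag, tensor_lists,
                             running_norms, /*alpha=*/1.f, /*beta=*/1.f,
                             /*norm_type=*/2);
}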