Path: blob/main/apex/csrc/multi_tensor_novograd.cu
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "type_shim.h"
#include "multi_tensor_apply.cuh"

#define BLOCK_SIZE 512
#define ILP 4

typedef enum{
  MOMENT_MODE_0 = 0, // Novograd paper mode, momentum calculation with denom then decay inside
  MOMENT_MODE_1 = 1  // Decoupled weight decay mode
} momentMode_t;

void multi_tensor_norm_out_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor out,
  const float alpha,
  const float beta,
  const int norm_type);

using MATH_T = float;

template<typename T>
struct NovoGradFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<3>& tl,
    const float beta1,
    const float beta2,
    const float beta3,
    const float beta1_correction,
    const float beta2_correction,
    const float epsilon,
    const float lr,
    momentMode_t m_mode,
    const float decay,
    const float* per_tensor_grad_norm)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    float grad_norm = per_tensor_grad_norm[tensor_num];

    T* g = (T*)tl.addresses[0][tensor_loc];
    g += chunk_idx*chunk_size;

    T* p = (T*)tl.addresses[1][tensor_loc];
    p += chunk_idx*chunk_size;

    T* m = (T*)tl.addresses[2][tensor_loc];
    m += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    // see note in multi_tensor_scale_kernel.cu
    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
      MATH_T r_g[ILP];
      MATH_T r_p[ILP];
      MATH_T r_m[ILP];
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          r_g[ii] = g[i];
          r_p[ii] = p[i];
          r_m[ii] = m[i];
        } else {
          r_g[ii] = MATH_T(0);
          r_p[ii] = MATH_T(0);
          r_m[ii] = MATH_T(0);
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        if (m_mode == MOMENT_MODE_0) {
          MATH_T next_v_unbiased = grad_norm / beta2_correction;
          MATH_T denom = next_v_unbiased + epsilon;
          r_g[ii] = (r_g[ii] / denom) + (decay * r_p[ii]);
          r_m[ii] = beta1 * r_m[ii] + beta3 * r_g[ii];
          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
          r_p[ii] = r_p[ii] - (lr * next_m_unbiased);
        }
        else {
          r_m[ii] = beta1 * r_m[ii] + beta3 * r_g[ii];
          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
          MATH_T next_v_unbiased = grad_norm / beta2_correction;
          MATH_T denom = next_v_unbiased + epsilon;
          MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
          r_p[ii] = r_p[ii] - (lr * update);
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          p[i] = r_p[ii];
          m[i] = r_m[ii];
        }
      }
    }
  }
};

void multi_tensor_novograd_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor grad_norms,
  const float lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  const int step,
  const int bias_correction,
  const float weight_decay,
  const int grad_averaging,
  const int moment_mode,
  const int norm_type)
{
  using namespace at;

  // Handle bias correction mode
  float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
  if (bias_correction == 1) {
    bias_correction1 = 1 - std::pow(beta1, step);
    bias_correction2 = std::sqrt(1 - std::pow(beta2, step));
  }

  // Handle grad averaging mode
  float beta3 = 1;
  if (grad_averaging == 1) beta3 = 1 - beta1;

  std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin()+1);

  // Compute and update grad norm
  // Here use a per tensor norm, and blend new norm(n) and old norm(gn) by
  // L-2: gn = sqrt(a * gn^2 + b * n^2)
  // L-inf: gn = a * gn + b * n
  multi_tensor_norm_out_cuda(chunk_size, noop_flag, grad_list, grad_norms, beta2, (1.0f - beta2), norm_type);

  // Assume single type across p,g,m1,m2 now
  DISPATCH_DOUBLE_FLOAT_AND_HALF(
    tensor_lists[0][0].scalar_type(), 0, "novograd",
    multi_tensor_apply<3>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      NovoGradFunctor<scalar_t_0>(),
      beta1,
      beta2,
      beta3, // 1-beta1 or 1 depends on averaging mode
      bias_correction1,
      bias_correction2,
      epsilon,
      lr,
      (momentMode_t) moment_mode,
      weight_decay,
      grad_norms.DATA_PTR<float>()); )

  AT_CUDA_CHECK(cudaGetLastError());
}
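Note: the sketch below is not part of apex. It is a minimal host-side C++ re-statement of what NovoGradFunctor does per element for a single tensor, useful only for reading the math; the names novograd_step_reference, blend_grad_norm_l2, and RefMomentMode are hypothetical, and the real kernel fuses this update across many tensors and chunks on the GPU.

#include <cmath>
#include <cstddef>
#include <vector>

// Moment modes mirrored from momentMode_t in the source above.
enum RefMomentMode { REF_MODE_0 = 0, REF_MODE_1 = 1 };

// Blended per-tensor gradient norm, matching the comment and the call
// multi_tensor_norm_out_cuda(..., beta2, 1 - beta2, norm_type):
//   L-2: gn = sqrt(beta2 * gn^2 + (1 - beta2) * n^2)
float blend_grad_norm_l2(float gn, float n, float beta2) {
  return std::sqrt(beta2 * gn * gn + (1.0f - beta2) * n * n);
}

// Per-element update of NovoGradFunctor for one tensor on the host.
void novograd_step_reference(std::vector<float>& p,        // parameters
                             const std::vector<float>& g,  // gradients
                             std::vector<float>& m,        // first moment
                             float grad_norm,              // blended per-tensor norm
                             float lr, float beta1, float beta3,
                             float beta1_correction, float beta2_correction,
                             float epsilon, float decay, RefMomentMode mode)
{
  // beta2_correction arrives as sqrt(1 - beta2^step) from the host code,
  // so grad_norm / beta2_correction is the bias-corrected second moment.
  float v_unbiased = grad_norm / beta2_correction;
  float denom = v_unbiased + epsilon;
  for (std::size_t i = 0; i < p.size(); ++i) {
    if (mode == REF_MODE_0) {
      // Paper mode: normalize the gradient by the per-tensor norm, add
      // weight decay, then accumulate into the momentum buffer.
      float d = g[i] / denom + decay * p[i];
      m[i] = beta1 * m[i] + beta3 * d;
      p[i] -= lr * (m[i] / beta1_correction);
    } else {
      // Decoupled weight decay: momentum on the raw gradient; the norm
      // and the decay term enter only in the final update.
      m[i] = beta1 * m[i] + beta3 * g[i];
      float update = (m[i] / beta1_correction) / denom + decay * p[i];
      p[i] -= lr * update;
    }
  }
}

As in multi_tensor_novograd_cuda above, beta1_correction would be 1 - beta1^step and beta2_correction sqrt(1 - beta2^step) when bias_correction == 1 (otherwise 1), and beta3 would be 1 - beta1 when grad_averaging == 1 (otherwise 1).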