Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Path: blob/main/apex/csrc/multi_tensor_adam.cu
Views: 792
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "type_shim.h"
#include "multi_tensor_apply.cuh"

// Threads per block for every launch made through multi_tensor_apply.
#define BLOCK_SIZE 512
// Instruction-level parallelism: elements staged per thread per outer-loop pass.
#define ILP 4

typedef enum{
  ADAM_MODE_0 =0, // L2 regularization mode
  ADAM_MODE_1 =1 // Decoupled weight decay mode(AdamW)
} adamMode_t;

// All per-element math is carried out in fp32 regardless of the storage type T.
using MATH_T = float;

// Device functor applied by multi_tensor_apply<4>: each CUDA block processes
// one chunk of one tensor. The four tensor lists are, in order:
//   [0] g: gradients, [1] p: parameters,
//   [2] m: first moment (exp. avg of grad), [3] v: second moment (exp. avg of grad^2).
// Performs a fused Adam / AdamW step on the chunk, updating p, m, v in place.
template<typename T>
struct AdamFunctor
{
  // chunk_size:            max elements per chunk
  // noop_gmem:             global early-exit flag (deliberately ignored; see body)
  // beta1, beta2:          Adam moment decay rates
  // beta1/2_correction:    1 - beta^step (or 1 when bias correction is disabled)
  // epsilon:               added to the denominator for numerical stability
  // lr:                    learning rate
  // mode:                  ADAM_MODE_0 (L2) or ADAM_MODE_1 (decoupled decay, AdamW)
  // decay:                 weight-decay coefficient
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<4>& tl,
    const float beta1,
    const float beta2,
    const float beta1_correction,
    const float beta2_correction,
    const float epsilon,
    const float lr,
    adamMode_t mode,
    const float decay)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    // Map this block to its assigned (tensor, chunk) pair.
    int tensor_loc = tl.block_to_tensor[blockIdx.x];

    // potentially use to pass in list of scalar
    // int tensor_num = tl.start_tensor_this_launch + tensor_loc;

    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    // Advance each base pointer to the start of this block's chunk.
    T* g = (T*)tl.addresses[0][tensor_loc];
    g += chunk_idx*chunk_size;

    T* p = (T*)tl.addresses[1][tensor_loc];
    p += chunk_idx*chunk_size;

    T* m = (T*)tl.addresses[2][tensor_loc];
    m += chunk_idx*chunk_size;

    T* v = (T*)tl.addresses[3][tensor_loc];
    v += chunk_idx*chunk_size;

    // Elements remaining from this chunk's start; for non-final chunks this
    // exceeds chunk_size, hence the dual (i < n && i < chunk_size) guards below.
    n -= chunk_idx*chunk_size;

    // see note in multi_tensor_scale_kernel.cu
    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
      // Stage ILP elements per thread into registers at fp32 working precision.
      MATH_T r_g[ILP];
      MATH_T r_p[ILP];
      MATH_T r_m[ILP];
      MATH_T r_v[ILP];
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          r_g[ii] = g[i];
          r_p[ii] = p[i];
          r_m[ii] = m[i];
          r_v[ii] = v[i];
        } else {
          // Out-of-range slots compute on zeros and are never written back.
          r_g[ii] = MATH_T(0);
          r_p[ii] = MATH_T(0);
          r_m[ii] = MATH_T(0);
          r_v[ii] = MATH_T(0);
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        if(mode == ADAM_MODE_0) { // L2
          // Classic Adam: L2 regularization is folded into the gradient
          // before the moment updates.
          r_g[ii] = r_g[ii] + (decay * r_p[ii]);
          r_m[ii] = beta1 * r_m[ii] + (1-beta1) * r_g[ii];
          r_v[ii] = beta2 * r_v[ii] + (1-beta2) * r_g[ii] * r_g[ii];
          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
          MATH_T update = next_m_unbiased / denom;
          r_p[ii] = r_p[ii] - (lr * update);
        }
        else { // weight decay
          // AdamW: decay is applied directly to the parameter (added to the
          // update), not mixed into the gradient / moments.
          r_m[ii] = beta1 * r_m[ii] + (1-beta1) * r_g[ii];
          r_v[ii] = beta2 * r_v[ii] + (1-beta2) * r_g[ii] * r_g[ii];
          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
          MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
          r_p[ii] = r_p[ii] - (lr * update);
        }
      }
      // Write back updated parameter and moments; gradients are left untouched.
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          p[i] = r_p[ii];
          m[i] = r_m[ii];
          v[i] = r_v[ii];
        }
      }
    }
  }
};

// Host entry point: runs one fused Adam/AdamW step across all tensors in
// tensor_lists = {grads, params, exp_avgs, exp_avg_sqs}.
//
// chunk_size / noop_flag:  forwarded to multi_tensor_apply (noop_flag is
//                          intentionally not consulted by the kernel so that
//                          infs/nans propagate).
// step:                    1-based step count used for bias correction.
// mode:                    cast to adamMode_t (0 = L2, 1 = AdamW).
// bias_correction:         1 to apply 1 - beta^step correction, else disabled.
void multi_tensor_adam_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  const float lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  const int step,
  const int mode,
  const int bias_correction,
  const float weight_decay)
{
  using namespace at;

  // Handle bias correction mode
  // When disabled, corrections of 1.0 make the unbiased moments equal the raw
  // moments inside the kernel.
  float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
  if (bias_correction == 1) {
    bias_correction1 = 1 - std::pow(beta1, step);
    bias_correction2 = 1 - std::pow(beta2, step);
  }

  // Assume single type across p,g,m1,m2 now
  // The macro instantiates the launch for double/float/half based on the
  // first tensor's dtype, binding it to scalar_t_0.
  DISPATCH_DOUBLE_FLOAT_AND_HALF(
    tensor_lists[0][0].scalar_type(), 0, "adam",
    multi_tensor_apply<4>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      AdamFunctor<scalar_t_0>(),
      beta1,
      beta2,
      bias_correction1,
      bias_correction2,
      epsilon,
      lr,
      (adamMode_t) mode,
      weight_decay); )

  // Surface any launch-configuration error from the kernel launches above.
  AT_CUDA_CHECK(cudaGetLastError());

}