Path: blob/main/apex/csrc/multi_tensor_lamb.cu
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "type_shim.h"
#include "multi_tensor_apply.cuh"

#define BLOCK_SIZE 512
#define ILP 4

template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

typedef enum{
  MOMENT_MODE_0 = 0, // L2 regularization mode
  MOMENT_MODE_1 = 1  // Decoupled weight decay mode
} adamMode_t;

std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::optional<bool> per_tensor_python);

using MATH_T = float;

template<typename T>
struct LAMBStage1Functor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<4>& tl,
    const float beta1,
    const float beta2,
    const float beta3,
    const float beta1_correction,
    const float beta2_correction,
    const float epsilon,
    adamMode_t mode,
    const float decay,
    const float* global_grad_norm,
    const float max_global_grad_norm)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    float clipped_global_grad_norm = (*global_grad_norm) > max_global_grad_norm ? (*global_grad_norm) / max_global_grad_norm : 1.0f;

    T* g = (T*)tl.addresses[0][tensor_loc];
    g += chunk_idx*chunk_size;

    T* p = (T*)tl.addresses[1][tensor_loc];
    p += chunk_idx*chunk_size;

    T* m = (T*)tl.addresses[2][tensor_loc];
    m += chunk_idx*chunk_size;

    T* v = (T*)tl.addresses[3][tensor_loc];
    v += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    MATH_T r_g[ILP];
    MATH_T r_p[ILP];
    MATH_T r_m[ILP];
    MATH_T r_v[ILP];
    // to make things simple, we put aligned case in a different code path
    if(n % ILP == 0 &&
       chunk_size % ILP == 0 &&
       is_aligned(g) &&
       is_aligned(p) &&
       is_aligned(m) &&
       is_aligned(v))
    {
      T l_g[ILP];
      T l_p[ILP];
      T l_m[ILP];
      T l_v[ILP];
      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
      {
        // load
        load_store(l_g, g, 0, i_start);
        if (decay != 0)
          load_store(l_p, p, 0, i_start);
        load_store(l_m, m, 0, i_start);
        load_store(l_v, v, 0, i_start);
        // unpack
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          r_g[ii] = l_g[ii];
          if (decay == 0) {
            r_p[ii] = MATH_T(0);
          }
          else {
            r_p[ii] = l_p[ii];
          }
          r_m[ii] = l_m[ii];
          r_v[ii] = l_v[ii];
        }
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          if (mode == MOMENT_MODE_0) {
            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
            // L2 on scaled grad
            scaled_grad = scaled_grad + decay*r_p[ii];
            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            r_p[ii] = next_m_unbiased / denom;
          }
          else {
            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
          }
        }
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          l_p[ii] = r_p[ii];
          l_m[ii] = r_m[ii];
          l_v[ii] = r_v[ii];
        }
        // store
        load_store(g, l_p, i_start, 0);
        load_store(m, l_m, i_start, 0);
        load_store(v, l_v, i_start, 0);
      }
    }
    else
    {
      // see note in multi_tensor_scale_kernel.cu
      for(int i_start = 0;
          i_start < n && i_start < chunk_size;
          i_start += blockDim.x*ILP)
      {
        MATH_T r_g[ILP];
        MATH_T r_p[ILP];
        MATH_T r_m[ILP];
        MATH_T r_v[ILP];
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            r_g[ii] = g[i];
            // special ?optimization? for lamb stage 1
            if (decay == 0) {
              r_p[ii] = MATH_T(0);
            }
            else {
              r_p[ii] = p[i];
            }
            r_m[ii] = m[i];
            r_v[ii] = v[i];
          } else {
            r_g[ii] = MATH_T(0);
            r_p[ii] = MATH_T(0);
            r_m[ii] = MATH_T(0);
            r_v[ii] = MATH_T(0);
          }
        }
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          if (mode == MOMENT_MODE_0) {
            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
            // L2 on scaled grad
            scaled_grad = scaled_grad + decay*r_p[ii];
            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            r_p[ii] = next_m_unbiased / denom;
          }
          else {
            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
          }
        }
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            g[i] = r_p[ii];
            m[i] = r_m[ii];
            v[i] = r_v[ii];
          }
        }
      }
    }
  }
};
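// A hedged summary of what Stage 1 computes per element, derived from the code
// above (not from separate documentation). With clip = max(1, ||grad||_global / max_global_grad_norm):
//   scaled_grad = g / clip
//   MOMENT_MODE_0 (L2 regularization):      scaled_grad += decay * p
//   m = beta1 * m + beta3 * scaled_grad
//   v = beta2 * v + (1 - beta2) * scaled_grad^2
//   update = (m / beta1_correction) / (sqrt(v / beta2_correction) + epsilon)
//   MOMENT_MODE_1 (decoupled weight decay): update += decay * p
// The update is written back into g so that Stage 2 can read it as 'update';
// m and v are updated in place.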

// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
// It computes new parameter value.
template<typename T>
struct LAMBStage2Functor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<2>& tl,
    const float* per_tensor_param_norm,
    const float* per_tensor_update_norm,
    const float learning_rate,
    const float decay,
    bool use_nvlamb)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    MATH_T ratio = learning_rate;
    // nvlamb: apply adaptive learning rate to all parameters
    // otherwise, only apply to those with non-zero weight decay
    if (use_nvlamb || (decay != 0.0))
    {
      float param_norm = per_tensor_param_norm[tensor_num];
      float update_norm = per_tensor_update_norm[tensor_num];
      ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
    }

    T* update = (T*)tl.addresses[0][tensor_loc];
    update += chunk_idx*chunk_size;

    T* p = (T*)tl.addresses[1][tensor_loc];
    p += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    // to make things simple, we put aligned case in a different code path
    if(n % ILP == 0 &&
       chunk_size % ILP == 0 &&
       is_aligned(p) &&
       is_aligned(update))
    {
      T r_p[ILP];
      T r_update[ILP];
      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
      {
        // load
        load_store(r_p, p, 0, i_start);
        load_store(r_update, update, 0, i_start);
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * static_cast<MATH_T>(r_update[ii]));
        }
        load_store(p, r_p, i_start, 0);
      }
    }
    else
    {
      for(int i_start = 0;
          i_start < n && i_start < chunk_size;
          i_start += blockDim.x*ILP)
      {
        MATH_T r_p[ILP];
        MATH_T r_update[ILP];
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            r_p[ii] = p[i];
            r_update[ii] = update[i];
          }
        }
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          r_p[ii] = r_p[ii] - (ratio * r_update[ii]);
        }
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            p[i] = r_p[ii];
          }
        }
      }
    }
  }
};
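// A hedged summary of what Stage 2 computes per tensor (again derived from the
// code above): a trust-ratio scaled step
//   ratio = lr * (||p|| / ||update||)  if (use_nvlamb || decay != 0) and both norms are non-zero
//   ratio = lr                         otherwise
//   p    -= ratio * update
// where 'update' is the value Stage 1 left in the gradient buffer.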

void multi_tensor_lamb_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  const float lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  const int step,
  const int bias_correction,
  const float weight_decay,
  const int grad_averaging,
  const int mode,
  at::Tensor global_grad_norm,
  const float max_grad_norm,
  at::optional<bool> use_nvlamb_python)
{
  using namespace at;
  // Master weights and 32-bit momentum (potentially changing) are not handled here,
  // so we assume every tensor is of the same type.
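  // Layout of tensor_lists, as inferred from how the functors above index
  // tl.addresses (a sketch of the assumed calling convention, not a checked contract):
  //   tensor_lists[0] : gradients   (overwritten in place with the Stage 1 update)
  //   tensor_lists[1] : parameters  (updated in place by Stage 2)
  //   tensor_lists[2] : first moment m
  //   tensor_lists[3] : second moment v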
  bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;

  // Handle bias correction mode
  float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
  if (bias_correction == 1) {
    bias_correction1 = 1 - std::pow(beta1, step);
    bias_correction2 = 1 - std::pow(beta2, step);
  }

  // Handle grad averaging mode
  float beta3 = 1.0f;
  if (grad_averaging == 1) beta3 = 1 - beta1;

  std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin()+1);
  std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2);

  // Compute per-tensor param norm
  auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);

  // We now in-place modify grad to store the update before computing its norm.
  // Generally this is not an issue since people modify grad in the step() method all the time.
  // We could also grab a list of empty tensors to avoid this, but I'd like to save space/CPU code.
  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
      multi_tensor_apply<4>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        LAMBStage1Functor<scalar_t_0>(),
        beta1,
        beta2,
        beta3, // 1-beta1 or 1, depending on averaging mode
        bias_correction1,
        bias_correction2,
        epsilon,
        (adamMode_t) mode,
        weight_decay,
        global_grad_norm.DATA_PTR<float>(),
        max_grad_norm); )

  // Compute update norms
  auto update_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, true);

  std::vector<std::vector<at::Tensor>> grad_param_list(tensor_lists.begin(), tensor_lists.begin()+2);

  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
      multi_tensor_apply<2>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        grad_param_list,
        LAMBStage2Functor<scalar_t_0>(),
        std::get<1>(param_norm_tuple).DATA_PTR<float>(),
        std::get<1>(update_norm_tuple).DATA_PTR<float>(),
        lr,
        weight_decay,
        use_nvlamb); )

  AT_CUDA_CHECK(cudaGetLastError());
}
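For context, multi_tensor_lamb_cuda is not called directly from this file; in upstream apex it is exposed to Python through a pybind11 frontend (csrc/amp_C_frontend.cpp) and driven by the FusedLAMB optimizer. The sketch below shows roughly how such a binding could look; the module setup, forward declaration, and docstring are illustrative assumptions, not taken from this file.

// Sketch only: a minimal pybind11 binding for the entry point defined above,
// assuming a standard PyTorch C++ extension build.
#include <torch/extension.h>

// Forward declaration matching the definition in multi_tensor_lamb.cu.
void multi_tensor_lamb_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  const float lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  const int step,
  const int bias_correction,
  const float weight_decay,
  const int grad_averaging,
  const int mode,
  at::Tensor global_grad_norm,
  const float max_grad_norm,
  at::optional<bool> use_nvlamb_python);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Expose the fused LAMB update over lists of [grad, param, m, v] tensors.
  m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda,
        "Fused LAMB update (hypothetical binding sketch)");
}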