Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Path: blob/main/apex/csrc/multi_tensor_adagrad.cu
Views: 792
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include <vector>

#include "multi_tensor_apply.cuh"
#include "type_shim.h"

// Thread-block size handed to multi_tensor_apply by the host launcher below.
#define BLOCK_SIZE 1024
// Number of elements each thread handles per iteration of the chunk loop.
#define ILP 4

// Selects how weight decay enters the update.
typedef enum {
  ADAGRAD_MODE_0 = 0, // L2 regularization mode.
  ADAGRAD_MODE_1 = 1, // AdamW-style weight decay.

} adagradMode_t;

// All arithmetic inside the functor is carried out in this type, whatever the
// storage type T of the tensors is.
using MATH_T = float;

// Device functor applying one Adagrad step to a single chunk of one tensor.
//
// The tensor/chunk handled by a block is looked up from `tl` via blockIdx.x.
// Address list 0 is read only (g); lists 1 (p) and 2 (h) are read and written
// back in place. Per element, with all math done in MATH_T:
//
//   ADAGRAD_MODE_0:  g' = g + weight_decay * p
//                    h  = h + g' * g'
//                    p  = p - lr * (g' / (sqrt(h) + epsilon))
//
//   otherwise:       h  = h + g * g
//                    p  = p - lr * (g / (sqrt(h) + epsilon) + weight_decay * p)
//
// `noop_gmem` is accepted as part of the functor signature but is not read
// here.
template <typename T> struct AdagradFunctor {
  __device__ __forceinline__ void
  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<3> &tl,
             const float epsilon, const float lr, adagradMode_t mode,
             const float weight_decay) {
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    // g is only ever read, so it is held through a pointer-to-const.
    const T *g = (const T *)tl.addresses[0][tensor_loc];
    g += chunk_idx * chunk_size;

    T *p = (T *)tl.addresses[1][tensor_loc];
    p += chunk_idx * chunk_size;

    T *h = (T *)tl.addresses[2][tensor_loc];
    h += chunk_idx * chunk_size;

    // From here on n is the number of elements left from the start of this
    // chunk to the end of the tensor; the last chunk may be shorter than
    // chunk_size, hence the double bound (i < n && i < chunk_size) below.
    n -= chunk_idx * chunk_size;

    // see note in multi_tensor_scale_kernel.cu
    for (int i_start = 0; i_start < n && i_start < chunk_size;
         i_start += blockDim.x * ILP) {
      // Per-thread staging arrays, ILP elements each.
      MATH_T r_g[ILP];
      MATH_T r_p[ILP];
      MATH_T r_h[ILP];

      // Load phase: element ii of this thread sits ii * blockDim.x past
      // i_start + threadIdx.x. Out-of-range slots are zero-filled so the math
      // phase can run unconditionally; they are never stored back.
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        int i = i_start + threadIdx.x + ii * blockDim.x;
        if (i < n && i < chunk_size) {
          r_g[ii] = g[i];
          r_p[ii] = p[i];
          r_h[ii] = h[i];
        } else {
          r_g[ii] = MATH_T(0);
          r_p[ii] = MATH_T(0);
          r_h[ii] = MATH_T(0);
        }
      }

      // Math phase.
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        if (mode == ADAGRAD_MODE_0) { // L2
          // Decay is folded into the gradient before it is accumulated in h.
          r_g[ii] = r_g[ii] + weight_decay * r_p[ii];
          r_h[ii] = r_h[ii] + r_g[ii] * r_g[ii];
          r_p[ii] = r_p[ii] - lr * (r_g[ii] / (sqrtf(r_h[ii]) + epsilon));
        } else { // AdamW-style
          // Decay is applied to the parameter directly and does not enter h.
          r_h[ii] = r_h[ii] + r_g[ii] * r_g[ii];
          r_p[ii] = r_p[ii] - lr * (r_g[ii] / (sqrtf(r_h[ii]) + epsilon) +
                                    weight_decay * r_p[ii]);
        }
      }

      // Store phase: same index mapping and bounds as the load phase. Only p
      // and h are written; g is left untouched.
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        int i = i_start + threadIdx.x + ii * blockDim.x;
        if (i < n && i < chunk_size) {
          p[i] = r_p[ii];
          h[i] = r_h[ii];
        }
      }
    }
  }
};

// Host entry point: runs AdagradFunctor over every tensor in `tensor_lists`
// through multi_tensor_apply<3>.
//
// chunk_size    forwarded to multi_tensor_apply and on to the functor.
// noop_flag     forwarded to multi_tensor_apply; the functor does not read it.
// tensor_lists  three parallel lists, consumed by the functor as
//               [0] = g (read only), [1] = p (updated), [2] = h (updated).
// lr, epsilon, weight_decay
//               scalars of the update rule (see AdagradFunctor).
// mode          cast to adagradMode_t; 0 selects the L2 branch, any other
//               value takes the AdamW-style branch.
//
// Raises a c10::Error (via TORCH_CHECK) if the first tensor list is missing or
// empty, and via AT_CUDA_CHECK if the CUDA runtime reports an error after the
// launch.
void multi_tensor_adagrad_cuda(
    int chunk_size, at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
    const float epsilon, const int mode, const float weight_decay) {
  using namespace at;

  // The dispatch below reads tensor_lists[0][0]; indexing an empty vector is
  // undefined behaviour, so fail with a clear message instead.
  TORCH_CHECK(!tensor_lists.empty() && !tensor_lists[0].empty(),
              "multi_tensor_adagrad_cuda: tensor_lists[0] must contain at "
              "least one tensor");

  // Assume single type across p,g,h now
  DISPATCH_DOUBLE_FLOAT_AND_HALF(
      tensor_lists[0][0].scalar_type(), 0, "adagrad",
      multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                            AdagradFunctor<scalar_t_0>(), epsilon, lr,
                            (adagradMode_t)mode, weight_decay);)

  AT_CUDA_CHECK(cudaGetLastError());
}