// Vendored from NVIDIA Apex: csrc/multi_tensor_adam.cu
// (copy carried in the ai-forever/sber-swap repository under apex/csrc/).
1
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>
#include <cmath>

#include "type_shim.h"
#include "multi_tensor_apply.cuh"
12
13
#define BLOCK_SIZE 512
14
#define ILP 4
15
16
// Selects how weight decay participates in the Adam update.
typedef enum
{
  ADAM_MODE_0 = 0, // L2 regularization: decay is folded into the gradient
  ADAM_MODE_1 = 1  // Decoupled weight decay (AdamW): decay applied to the update
} adamMode_t;
20
21
using MATH_T = float;
22
23
// Per-chunk Adam/AdamW update functor, invoked by multi_tensor_apply.
// Tensor-list slots: [0]=grad, [1]=param, [2]=exp_avg (m), [3]=exp_avg_sq (v).
// All arithmetic is performed in MATH_T (float) regardless of storage type T.
// Each block processes one chunk of one tensor; lanes past the chunk/tensor
// end compute on zeros and are never written back.
template<typename T>
struct AdamFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<4>& tl,
    const float beta1,
    const float beta2,
    const float beta1_correction,
    const float beta2_correction,
    const float epsilon,
    const float lr,
    adamMode_t mode,
    const float decay)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];

    // potentially use to pass in list of scalar
    // int tensor_num = tl.start_tensor_this_launch + tensor_loc;

    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    // Offset each base pointer to the start of this block's chunk.
    T* g = (T*)tl.addresses[0][tensor_loc];
    g += chunk_idx*chunk_size;

    T* p = (T*)tl.addresses[1][tensor_loc];
    p += chunk_idx*chunk_size;

    T* m = (T*)tl.addresses[2][tensor_loc];
    m += chunk_idx*chunk_size;

    T* v = (T*)tl.addresses[3][tensor_loc];
    v += chunk_idx*chunk_size;

    // Elements remaining in the tensor from this chunk onward; the loop
    // guard below also caps iteration at chunk_size.
    n -= chunk_idx*chunk_size;

    // see note in multi_tensor_scale_kernel.cu
    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
      MATH_T r_g[ILP];
      MATH_T r_p[ILP];
      MATH_T r_m[ILP];
      MATH_T r_v[ILP];
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          r_g[ii] = g[i];
          r_p[ii] = p[i];
          r_m[ii] = m[i];
          r_v[ii] = v[i];
        } else {
          // Out-of-range lanes: compute on zeros, skipped at write-back.
          r_g[ii] = MATH_T(0);
          r_p[ii] = MATH_T(0);
          r_m[ii] = MATH_T(0);
          r_v[ii] = MATH_T(0);
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        // L2 mode folds weight decay into the gradient before the moments.
        if(mode == ADAM_MODE_0)
          r_g[ii] = r_g[ii] + (decay * r_p[ii]);

        // Moment updates and bias correction, common to both modes.
        r_m[ii] = beta1 * r_m[ii] + (1.0f - beta1) * r_g[ii];
        r_v[ii] = beta2 * r_v[ii] + (1.0f - beta2) * r_g[ii] * r_g[ii];
        MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
        MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
        MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
        MATH_T update = next_m_unbiased / denom;

        // AdamW applies decoupled weight decay to the update instead.
        if(mode == ADAM_MODE_1)
          update = update + (decay * r_p[ii]);

        r_p[ii] = r_p[ii] - (lr * update);
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          p[i] = r_p[ii];
          m[i] = r_m[ii];
          v[i] = r_v[ii];
        }
      }
    }
  }
};
128
129
// Launches a fused Adam/AdamW step over a list of tensors.
//
// tensor_lists layout: [0]=grads, [1]=params, [2]=exp_avgs, [3]=exp_avg_sqs.
// All tensors are assumed to share a single scalar type (double/float/half),
// dispatched from tensor_lists[0][0].
//
// mode: 0 = L2 regularization, 1 = decoupled weight decay (AdamW).
// bias_correction: when 1, divides the moments by (1 - beta^step);
//                  otherwise the correction factors stay 1 (no-op).
// noop_flag is forwarded to multi_tensor_apply (currently unused by the
// functor so that infs/nans propagate — see AdamFunctor).
void multi_tensor_adam_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  const float lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  const int step,
  const int mode,
  const int bias_correction,
  const float weight_decay)
{
  using namespace at;

  // Handle bias correction mode; factors default to 1 (division is a no-op).
  float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
  if (bias_correction == 1) {
    bias_correction1 = 1.0f - std::pow(beta1, step);
    bias_correction2 = 1.0f - std::pow(beta2, step);
  }

  // Assume single type across p,g,m1,m2 now
  DISPATCH_DOUBLE_FLOAT_AND_HALF(
    tensor_lists[0][0].scalar_type(), 0, "adam",
    multi_tensor_apply<4>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      AdamFunctor<scalar_t_0>(),
      beta1,
      beta2,
      bias_correction1,
      bias_correction2,
      epsilon,
      lr,
      (adamMode_t) mode,
      weight_decay); )

  // Catch asynchronous launch errors from the kernel(s) above.
  AT_CUDA_CHECK(cudaGetLastError());
}
172
173