// GitHub repository: ai-forever/sber-swap
// Path: apex/csrc/multi_tensor_sgd_kernel.cu
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include "multi_tensor_apply.cuh"
#include "compat.h"

#include <assert.h>
#include <cuda_runtime.h>

#define BLOCK_SIZE 512
#define ILP 4

/**
 * Perform fused SGD on multiple buffers
 * N: number of tensors
 * tl[0] : gradients
 * tl[1] : weights
 * tl[2] : momentum buffers
 * tl[3] : fp16 weights (if appropriate)
 * wd : weight_decay (scalar)
 * momentum : momentum (scalar)
 * dampening : momentum dampening (scalar)
 * lr : learning rate (scalar)
 * nesterov : enable nesterov (bool)
 * first_run : necessary for proper momentum handling & init
 * wd_after_momentum : apply weight decay _after_ momentum instead of before
 **/
template<int N, typename T_grad, typename T_weight>
struct SGDFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<N>& tl,
    float wd,
    float momentum,
    float dampening,
    float lr,
    bool nesterov,
    bool first_run,
    bool wd_after_momentum,
    float scale)
  {
    // Early exit if we don't need to do anything
    if (*noop_gmem) return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    T_grad* grad_in = (T_grad*)tl.addresses[0][tensor_loc];
    grad_in += chunk_idx*chunk_size;

    T_weight* weight_in = (T_weight*)tl.addresses[1][tensor_loc];
    weight_in += chunk_idx*chunk_size;

    T_weight* mom_in = (T_weight*)tl.addresses[2][tensor_loc];
    mom_in += chunk_idx*chunk_size;

    at::Half *model_weights_out = nullptr;
    if(N == 4)
    {
      model_weights_out = (at::Half*)tl.addresses[3][tensor_loc];
      model_weights_out += chunk_idx*chunk_size;
    }

    n -= chunk_idx*chunk_size;

    // Non-divergent exit condition for the __syncthreads
    float incoming_grads[ILP];
    float incoming_weights[ILP];
    float incoming_moms[ILP];
    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        incoming_grads[ii] = 0;
        incoming_weights[ii] = 0;
        incoming_moms[ii] = 0;
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          incoming_grads[ii] = static_cast<float>(grad_in[i])*scale;
          incoming_weights[ii] = static_cast<float>(weight_in[i]);
          incoming_moms[ii] = static_cast<float>(mom_in[i]);
        }
      }

      // note for clarification to future michael:
      // From a pure memory dependency perspective, there's likely no point unrolling
      // the write loop, since writes just fire off once their LDGs arrive.
      // Put another way, the STGs are dependent on the LDGs, but not on each other.
      // There is still compute ILP benefit from unrolling the loop though.
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          // apply weight decay before momentum if necessary
          if(wd != 0.f && !wd_after_momentum)
            incoming_grads[ii] += wd * incoming_weights[ii];

          if(momentum != 0.f)
          {
            if(!first_run)
              incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii];
            else // initialize momentums to current incoming grads
              incoming_moms[ii] = incoming_grads[ii];

            if(nesterov)
              incoming_grads[ii] += momentum * incoming_moms[ii];
            else
              incoming_grads[ii] = incoming_moms[ii];
          }

          // Apply WD after momentum if desired
          if(wd != 0.f && wd_after_momentum)
            incoming_grads[ii] += wd * incoming_weights[ii];

          // adjust the weight and write out
          weight_in[i] += (-lr * incoming_grads[ii]);

          // if necessary, write out an fp16 copy of the weights
          if(N == 4)
            model_weights_out[i] = static_cast<at::Half>(weight_in[i]);

          // also write out the new momentum
          if(momentum != 0.f)
            mom_in[i] = incoming_moms[ii];
        }
      }
    }
  }
};
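// ---------------------------------------------------------------------------
// Illustrative reference (not part of the original apex source): the scalar
// update SGDFunctor above applies to each element, written as a plain
// host-side helper. It assumes the same semantics as the kernel: `scale`
// rescales the raw gradient, and on first_run the momentum buffer is seeded
// with the (scaled, optionally weight-decayed) gradient.
// ---------------------------------------------------------------------------
inline void sgd_reference_step(
  float& weight,
  float& mom,
  float grad,
  float wd,
  float momentum,
  float dampening,
  float lr,
  bool nesterov,
  bool first_run,
  bool wd_after_momentum,
  float scale)
{
  grad *= scale;

  // Weight decay before the momentum update (the default ordering).
  if(wd != 0.f && !wd_after_momentum)
    grad += wd * weight;

  if(momentum != 0.f)
  {
    if(!first_run)
      mom = mom * momentum + (1.f - dampening) * grad;
    else
      mom = grad;

    if(nesterov)
      grad += momentum * mom;
    else
      grad = mom;
  }

  // Weight decay after the momentum update, if requested.
  if(wd != 0.f && wd_after_momentum)
    grad += wd * weight;

  weight -= lr * grad;
}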

void multi_tensor_sgd_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float wd,
  float momentum,
  float dampening,
  float lr,
  bool nesterov,
  bool first_run,
  bool wd_after_momentum,
  float scale)
{
  auto num_tensors = tensor_lists.size();
  auto grad_type = tensor_lists[0][0].scalar_type();
  auto weight_type = tensor_lists[1][0].scalar_type();

  if(num_tensors == 4)
    for(int i = 0; i < tensor_lists[3].size(); i++)
      TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
        "Additional output tensors should always be fp16.");

  TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(), "expected noop flag to be on the same device as tensors");

  // We have 4 possibilities to handle here, in terms of
  // grad_type, param_type, momentum_type, requires_fp16_copy:
  // 1. fp16, fp16, fp16, No
  // 2. fp32, fp32, fp32, No
  // 3. fp16, fp32, fp32, Yes
  // 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
  // It's easier to hardcode these possibilities than to use
  // switches etc. to handle the cross-product of cases where
  // we don't want the majority of them.

  // Case 1. fp16, fp16, fp16, No
  if(grad_type == at::ScalarType::Half &&
     weight_type == at::ScalarType::Half &&
     num_tensors == 3)
  {
    multi_tensor_apply<3>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      SGDFunctor<3, at::Half, at::Half>(),
      wd,
      momentum,
      dampening,
      lr,
      nesterov,
      first_run,
      wd_after_momentum,
      scale);
  }
  // Case 2. fp16, fp32, fp32, No
  // else if (grad_type == at::ScalarType::Half &&
  //          weight_type == at::ScalarType::Float &&
  //          num_tensors == 3) {
  //   multi_tensor_apply<3>(
  //     BLOCK_SIZE,
  //     chunk_size,
  //     noop_flag,
  //     tensor_lists,
  //     SGDFunctor<3, at::Half, float>(),
  //     wd,
  //     momentum,
  //     dampening,
  //     lr,
  //     nesterov,
  //     first_run,
  //     wd_after_momentum);
  // }
  // Case 2. fp32, fp32, fp32, No
  else if(grad_type == at::ScalarType::Float &&
          weight_type == at::ScalarType::Float &&
          num_tensors == 3)
  {
    multi_tensor_apply<3>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      SGDFunctor<3, float, float>(),
      wd,
      momentum,
      dampening,
      lr,
      nesterov,
      first_run,
      wd_after_momentum,
      scale);
  }
  // Case 3. fp16, fp32, fp32, Yes
  else if(grad_type == at::ScalarType::Half &&
          weight_type == at::ScalarType::Float &&
          num_tensors == 4)
  {
    multi_tensor_apply<4>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      SGDFunctor<4, at::Half, float>(),
      wd,
      momentum,
      dampening,
      lr,
      nesterov,
      first_run,
      wd_after_momentum,
      scale);
  }
  // Case 4. fp32, fp32, fp32, Yes
  else if(grad_type == at::ScalarType::Float &&
          weight_type == at::ScalarType::Float &&
          num_tensors == 4)
  {
    multi_tensor_apply<4>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      SGDFunctor<4, float, float>(),
      wd,
      momentum,
      dampening,
      lr,
      nesterov,
      first_run,
      wd_after_momentum,
      scale);
  }
  else
  {
    AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
      "gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
  }

  AT_CUDA_CHECK(cudaGetLastError());
}
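
// ---------------------------------------------------------------------------
// Usage sketch (not part of the original apex source): how a host-side caller
// might assemble the tensor lists and launch the fused update for a single
// fp32 parameter. The tensor shapes, chunk size, and guard macro are
// illustrative assumptions only; real callers pass many tensors per list.
// ---------------------------------------------------------------------------
#ifdef MULTI_TENSOR_SGD_EXAMPLE
void multi_tensor_sgd_example()
{
  auto opts = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);

  at::Tensor weight = at::zeros({1024}, opts);
  at::Tensor grad = at::ones({1024}, opts);
  at::Tensor momentum_buffer = at::zeros({1024}, opts);

  // The noop flag must live on the same device as the tensors (checked above).
  at::Tensor noop_flag =
    at::zeros({1}, at::TensorOptions().dtype(at::kInt).device(at::kCUDA));

  // Lists follow the layout documented for SGDFunctor:
  // [0] gradients, [1] weights, [2] momentum buffers.
  std::vector<std::vector<at::Tensor>> tensor_lists = {{grad}, {weight}, {momentum_buffer}};

  multi_tensor_sgd_cuda(
    2048*32,    // chunk_size (illustrative value)
    noop_flag,
    tensor_lists,
    0.f,        // wd
    0.9f,       // momentum
    0.f,        // dampening
    0.01f,      // lr
    false,      // nesterov
    true,       // first_run: seeds the momentum buffers
    false,      // wd_after_momentum
    1.f);       // scale
}
#endif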