// GitHub repository: ai-forever/sber-swap
// Path: blob/main/apex/csrc/layer_norm_cuda.cpp
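//
// C++ front end for Apex's fused LayerNorm CUDA kernels: these functions
// validate arguments, allocate outputs, and dispatch to the CUDA
// implementations declared below (defined in the accompanying .cu file),
// and are exposed to Python via pybind11.
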
#include <torch/extension.h>
#include <vector>
#include <cassert>
#include <sstream>  // for std::stringstream used in check_args
#include "compat.h"

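// Helpers local to this translation unit.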
namespace {
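// Factors the input shape into (n1, n2): n2 is the product of the trailing
// dimensions matched by normalized_shape (the length of each row being
// normalized), and n1 is the product of the remaining leading dimensions
// (the number of independent rows).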
void compute_n1_n2(
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    int& n1,
    int& n2)
{
  int idiff = input.ndimension() - normalized_shape.size();
  n2 = 1;
  for (int i = 0; i < (int)normalized_shape.size(); ++i) {
    assert(input.sizes()[i + idiff] == normalized_shape[i]);
    n2 *= normalized_shape[i];
  }
  n1 = 1;
  for (int i = 0; i < idiff; ++i) {
    n1 *= input.sizes()[i];
  }
}

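// When gamma/beta are defined, their shapes must equal normalized_shape.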
void check_args(
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor gamma,
    at::Tensor beta)
{
  TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
  TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}

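// Validates input against normalized_shape (the checks and error messages
// mirror PyTorch's own layer_norm) and computes (n1, n2) on success.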
void check_args(
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    int& n1,
    int& n2)
{
  int64_t normalized_ndim = normalized_shape.size();

  if (normalized_ndim < 1) {
    std::stringstream ss;
    ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
       << "containing at least one element, but got normalized_shape="
       << normalized_shape;
    throw std::runtime_error(ss.str());
  }

  auto input_shape = input.sizes();
  auto input_ndim = input.dim();

  if (input_ndim < normalized_ndim ||
      !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
    std::stringstream ss;
    ss << "Given normalized_shape=" << normalized_shape
       << ", expected input with shape [*";
    for (auto size : normalized_shape) {
      ss << ", " << size;
    }
    ss << "], but got input of size " << input_shape;
    throw std::runtime_error(ss.str());
  }

  compute_n1_n2(input, normalized_shape, n1, n2);
}

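// Combined check: validates the input shape and gamma/beta in one call.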
void check_args(
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor gamma,
    at::Tensor beta,
    int& n1,
    int& n2)
{
  check_args(input, normalized_shape, n1, n2);
  check_args(normalized_shape, gamma, beta);
}
}  // anonymous namespace

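// Forward kernel entry point. The definition lives in the CUDA source
// compiled alongside this file; gamma and beta may be null pointers on the
// non-affine path.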
void cuda_layer_norm(
    at::Tensor* output,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon);

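// Input validation macros. Note: x.type().is_cuda() is the older spelling
// (newer PyTorch prefers x.is_cuda()), presumably kept here for
// compatibility with the pre-1.1 versions this file also supports.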
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

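// Non-affine forward: returns {output, mean, invvar}. The saved statistics
// are allocated in fp32 when the input is fp16, as the dtype selection
// below shows.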
std::vector<at::Tensor> layer_norm(
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    double epsilon) {
  CHECK_INPUT(input);
  int n1, n2;
  check_args(input, normalized_shape, n1, n2);
  at::Tensor output = at::empty_like(input);
  at::Tensor mean = at::empty(
      {n1},
      input.options().dtype(input.scalar_type() == at::ScalarType::Half
                                ? at::ScalarType::Float
                                : input.scalar_type()));
  at::Tensor invvar = at::empty_like(mean);
  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
                  normalized_shape, nullptr, nullptr, epsilon);
  return {output, mean, invvar};
}

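// Affine forward: as above, but also applies the elementwise gamma scale
// and beta shift.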
std::vector<at::Tensor> layer_norm_affine(
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor gamma,
    at::Tensor beta,
    double epsilon) {
  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);
  int n1, n2;
  check_args(input, normalized_shape, gamma, beta, n1, n2);
  at::Tensor output = at::empty_like(input);
  at::Tensor mean = at::empty(
      {n1},
      input.options().dtype(input.scalar_type() == at::ScalarType::Half
                                ? at::ScalarType::Float
                                : input.scalar_type()));
  at::Tensor invvar = at::empty_like(mean);
  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
                  normalized_shape, &gamma, &beta, epsilon);
  return {output, mean, invvar};
}

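// Backward kernel entry point; gamma/beta and grad_gamma/grad_beta may be
// null pointers on the non-affine path.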
void cuda_layer_norm_gradient(
    at::Tensor* dout,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon,
    at::Tensor* grad_input,
    at::Tensor* grad_gamma,
    at::Tensor* grad_beta);

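// Non-affine backward: consumes the mean/invvar saved by the forward pass
// and returns grad_input only.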
at::Tensor layer_norm_gradient(
    at::Tensor dout,
    at::Tensor mean,
    at::Tensor invvar,
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    double epsilon) {
  CHECK_INPUT(dout);
  CHECK_INPUT(mean);
  CHECK_INPUT(invvar);
  CHECK_INPUT(input);
  int n1, n2;
  check_args(input, normalized_shape, n1, n2);
  at::Tensor grad_input = at::empty_like(input);
  cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
                           normalized_shape, nullptr, nullptr, epsilon,
                           &grad_input, nullptr, nullptr);
  return grad_input;
}

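// Affine backward: additionally produces grad_gamma and grad_beta.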
std::vector<at::Tensor> layer_norm_gradient_affine(
    at::Tensor dout,
    at::Tensor mean,
    at::Tensor invvar,
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor gamma,
    at::Tensor beta,
    double epsilon) {
  CHECK_INPUT(dout);
  CHECK_INPUT(mean);
  CHECK_INPUT(invvar);
  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);
  int n1, n2;
  check_args(input, normalized_shape, gamma, beta, n1, n2);
  at::Tensor grad_input = at::empty_like(input);
  at::Tensor grad_gamma = at::empty_like(gamma);
  at::Tensor grad_beta = at::empty_like(beta);
  cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
                           normalized_shape, &gamma, &beta, epsilon,
                           &grad_input, &grad_gamma, &grad_beta);
  return {grad_input, grad_gamma, grad_beta};
}

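// Python bindings. The module name (TORCH_EXTENSION_NAME) is fixed at build
// time; upstream Apex builds this extension as fused_layer_norm_cuda, so a
// typical call from Python looks like
//   fused_layer_norm_cuda.forward_affine(input, normalized_shape, gamma, beta, eps)
// though the name may differ depending on how this repository builds it.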
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
  m.def("forward", &layer_norm, "LayerNorm forward (CUDA)");
  m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)");
  m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)");
}