GitHub Repository: ai-forever/sber-swap
Path: blob/main/apex/csrc/mlp.cpp
#include <torch/extension.h>
#include <torch/torch.h>
#include <vector>

#include <stdio.h>

size_t get_mlp_reserved_space(int64_t batch_size, int num_layers, const int* output_features);

template <typename T>
size_t get_mlp_bp_workspace_in_bytes(int batch_size, int num_layers, const int* output_features);

template <typename T>
int mlp_fp(
    T* X,
    int input_features,
    int batch_size,
    T** WPtr,
    int num_layers,
    int* output_features,
    T** BPtr,
    T* Y,
    T* reserved_space,
    int use_bias,
    int activation,
    void* lt_workspace);

template <typename T>
int mlp_bp(
    T* X,
    T* Y,
    int input_features,
    int batch_size,
    T** WPtr,
    int num_layers,
    int* output_features,
    T* dY,
    T* reserved_space,
    T* work_space,
    T* dX,
    T** dwPtr,
    T** dbPtr,
    bool requires_grad,
    int use_bias,
    int activation);

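// Host-side wrapper for the fused MLP forward pass.
//
// `inputs` is packed as (input, W_1, ..., W_L), or (input, W_1, ..., W_L,
// b_1, ..., b_L) when use_bias is set. The input is [batch_size, input_features]
// and each weight stores its layer's output width in dim 0.
// Returns {output, reserved_space}: output is [batch_size, output_features.back()],
// and reserved_space holds the intermediates that mlp_backward consumes.
//
// `activation` is an integer code forwarded unchanged to the CUDA kernels;
// the 0 = none / 1 = ReLU / 2 = sigmoid mapping is an assumption based on the
// upstream Apex Python wrapper, not something this file defines.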
std::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::vector<at::Tensor> inputs) {

  auto num_layers = inputs.size() - 1;
  if (use_bias) {
    // inputs contains (input, weights, biases)
    num_layers /= 2;
  }
  auto batch_size = inputs[0].size(0);
  auto input_features = inputs[0].size(1);

  std::vector<int> output_features;
  for (int i = 0; i < num_layers; i++) {
    output_features.push_back(inputs[i + 1].size(0));
  }

  auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());

  // create output / workspace tensors
  auto out = at::empty({batch_size, output_features.back()}, inputs[0].type());
  auto reserved_space = at::empty({reserved_size}, inputs[0].type());
  // allocate a fixed cuBLASLt workspace for now; 1 << 22 elements of the input dtype is at least 4 MB
  auto lt_workspace = at::empty({1 << 22}, inputs[0].type());

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_forward", [&] {
    std::vector<scalar_t*> w_ptr;
    std::vector<scalar_t*> b_ptr;
    for (int i = 0; i < num_layers; i++) {
      w_ptr.push_back(inputs[i + 1].data_ptr<scalar_t>());
      if (use_bias) {
        b_ptr.push_back(inputs[i + 1 + num_layers].data_ptr<scalar_t>());
      }
    }
    auto result = mlp_fp<scalar_t>(
        inputs[0].data_ptr<scalar_t>(),
        input_features,
        batch_size,
        w_ptr.data(),
        num_layers,
        output_features.data(),
        b_ptr.data(),
        out.data_ptr<scalar_t>(),
        reserved_space.data_ptr<scalar_t>(),
        use_bias,
        activation,
        (void*) (lt_workspace.data_ptr<scalar_t>()));
  });

  return {out, reserved_space};
}

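// Host-side wrapper for the fused MLP backward pass.
//
// `grad_o` is the gradient w.r.t. the forward output, `fprop_outputs` is the
// {output, reserved_space} pair returned by mlp_forward, and `inputs` is the
// same packed (input, weights[, biases]) list given to the forward pass.
// Returns one gradient tensor per input, in the same order:
//   outputs[0]                       -> gradient of the input (dX)
//   outputs[1 .. num_layers]         -> weight gradients (dW)
//   outputs[1 + num_layers .. end]   -> bias gradients (db), if use_bias is set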
std::vector<at::Tensor> mlp_backward(
    int use_bias,
    int activation,
    at::Tensor grad_o,
    std::vector<at::Tensor> fprop_outputs,
    std::vector<at::Tensor> inputs) {

  auto num_layers = inputs.size() - 1;
  if (use_bias) {
    // inputs contains (input, weights, biases)
    num_layers /= 2;
  }

  auto batch_size = inputs[0].size(0);
  auto input_features = inputs[0].size(1);

  bool requires_grad = inputs[0].requires_grad();

  std::vector<int> output_features;
  for (int i = 0; i < num_layers; i++) {
    output_features.push_back(inputs[i + 1].size(0));
  }
  // create gradient outputs, one per input tensor
  std::vector<at::Tensor> outputs;
  for (int i = 0; i < inputs.size(); i++) {
    outputs.push_back(at::empty(inputs[i].sizes(), inputs[i].type()));  // clone for testing now
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_backward", [&] {
    std::vector<scalar_t*> w_ptr;
    for (int i = 0; i < num_layers; i++) {
      w_ptr.push_back(inputs[i + 1].data_ptr<scalar_t>());
    }
    std::vector<scalar_t*> outputs_ptr;
    for (int i = 0; i < inputs.size(); i++) {
      outputs_ptr.push_back(outputs[i].data_ptr<scalar_t>());
    }

    auto work_size =
        get_mlp_bp_workspace_in_bytes<scalar_t>(batch_size, num_layers, output_features.data());

    // auto work_space = at::empty({work_size*4}, at::kByte);
    auto work_space = at::empty({work_size / sizeof(scalar_t)}, inputs[0].type());

    auto result = mlp_bp<scalar_t>(
        inputs[0].data_ptr<scalar_t>(),
        fprop_outputs[0].data_ptr<scalar_t>(),
        input_features,
        batch_size,
        w_ptr.data(),
        num_layers,
        output_features.data(),
        grad_o.contiguous().data_ptr<scalar_t>(),
        fprop_outputs[1].data_ptr<scalar_t>(),
        work_space.data_ptr<scalar_t>(),
        outputs_ptr[0],
        outputs_ptr.data() + 1,
        outputs_ptr.data() + 1 + num_layers,
        requires_grad,
        use_bias,
        activation);
  });

  return outputs;
}

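// Python bindings. Only the raw forward / backward entry points are exposed;
// the autograd wiring (e.g. a torch.autograd.Function and an nn.Module that
// call them) is presumably provided by the Python wrapper shipped alongside
// this extension, as in upstream Apex's apex.mlp. That wrapper is not part of
// this file.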
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &mlp_forward, "MLP forward");
  m.def("backward", &mlp_backward, "MLP backward");
}
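
// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the original Apex source: one way to call
// mlp_forward directly from C++. It assumes this file is linked against the
// CUDA implementation that defines mlp_fp / mlp_bp (e.g. mlp_cuda.cu), that a
// CUDA device is available, and that activation code 1 selects ReLU (see the
// note above mlp_forward). The function name and tensor shapes are
// hypothetical, chosen only for illustration.
// ---------------------------------------------------------------------------
void mlp_forward_example() {
  auto opts = torch::dtype(torch::kFloat32).device(torch::kCUDA);
  auto x  = torch::randn({128, 64}, opts);   // input: [batch_size, input_features]
  auto w1 = torch::randn({256, 64}, opts);   // layer 1 weight (dim 0 = output width)
  auto b1 = torch::randn({256}, opts);       // layer 1 bias
  auto w2 = torch::randn({16, 256}, opts);   // layer 2 weight
  auto b2 = torch::randn({16}, opts);        // layer 2 bias

  // inputs are packed as (input, all weights, all biases)
  auto result = mlp_forward(/*use_bias=*/1, /*activation=*/1, {x, w1, w2, b1, b2});
  auto y              = result[0];  // [128, 16], output of the last layer
  auto reserved_space = result[1];  // saved intermediates, passed to mlp_backward
}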