Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Path: blob/main/apex/csrc/mlp.cpp
Views: 797
#include <torch/extension.h>
#include <torch/torch.h>
#include <vector>

#include <stdio.h>

// ---------------------------------------------------------------------------
// Forward declarations of the CUDA-side implementations (defined elsewhere,
// presumably in the companion .cu file — TODO confirm against the build).
// ---------------------------------------------------------------------------

// Size, in ELEMENTS, of the scratch buffer the forward pass uses to stash
// per-layer activations for reuse during backward.
size_t get_mlp_reserved_space(int64_t batch_size, int num_layers, const int* output_features);

// Size, in BYTES, of the temporary workspace the backward pass needs.
template <typename T>
size_t get_mlp_bp_workspace_in_bytes(int batch_size, int num_layers, const int* output_features);

// Fused MLP forward. Returns 0 on success, nonzero on failure.
template <typename T>
int mlp_fp(
    T* X,
    int input_features,
    int batch_size,
    T** WPtr,
    int num_layers,
    int* output_features,
    T** BPtr,
    T* Y,
    T* reserved_space,
    int use_bias,
    int activation,
    void* lt_workspace);

// Fused MLP backward. Returns 0 on success, nonzero on failure.
template <typename T>
int mlp_bp(
    T* X,
    T* Y,
    int input_features,
    int batch_size,
    T** WPtr,
    int num_layers,
    int* output_features,
    T* dY,
    T* reserved_space,
    T* work_space,
    T* dX,
    T** dwPtr,
    T** dbPtr,
    bool requires_grad,
    int use_bias,
    int activation);

// Forward pass entry point exposed to Python.
//
// inputs layout: (input, W1..Wn) or, when use_bias != 0,
// (input, W1..Wn, b1..bn). Each Wi is (out_features, in_features).
// Returns {output, reserved_space}; reserved_space must be handed back to
// mlp_backward unchanged.
std::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::vector<at::Tensor> inputs) {
  auto num_layers = inputs.size() - 1;
  if (use_bias) {
    // inputs contains (input, weights, biases)
    num_layers /= 2;
  }
  auto batch_size = inputs[0].size(0);
  auto input_features = inputs[0].size(1);

  // Row count of each weight matrix is that layer's output width.
  std::vector<int> output_features;
  for (size_t i = 0; i < num_layers; i++) {
    output_features.push_back(inputs[i + 1].size(0));
  }

  auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());

  // Create output / workspace tensors (same dtype & device as the input).
  auto out = at::empty({batch_size, output_features.back()}, inputs[0].type());
  auto reserved_space = at::empty({static_cast<int64_t>(reserved_size)}, inputs[0].type());
  // NOTE: 1 << 22 is an ELEMENT count, so this allocates 4M * sizeof(scalar_t)
  // bytes — at least 4 MB even for half precision, which satisfies the
  // cublasLt workspace minimum the comment in the original intended.
  auto lt_workspace = at::empty({1 << 22}, inputs[0].type());

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_forward", [&] {
    std::vector<scalar_t*> w_ptr;
    std::vector<scalar_t*> b_ptr;
    for (size_t i = 0; i < num_layers; i++) {
      w_ptr.push_back(inputs[i + 1].data_ptr<scalar_t>());
      if (use_bias) {
        // Biases follow all weights: inputs[1 + num_layers .. 2*num_layers].
        b_ptr.push_back(inputs[i + 1 + num_layers].data_ptr<scalar_t>());
      }
    }
    auto result = mlp_fp<scalar_t>(
        inputs[0].data_ptr<scalar_t>(),
        input_features,
        batch_size,
        w_ptr.data(),
        num_layers,
        output_features.data(),
        b_ptr.data(),
        out.data_ptr<scalar_t>(),
        reserved_space.data_ptr<scalar_t>(),
        use_bias,
        activation,
        (void*)(lt_workspace.data_ptr<scalar_t>()));
    // FIX: the return code was previously assigned and silently discarded,
    // hiding GEMM/cublas failures. Surface them to the Python caller.
    TORCH_CHECK(result == 0, "mlp_fp failed with error code ", result);
  });

  return {out, reserved_space};
}

// Backward pass entry point exposed to Python.
//
// grad_o:        gradient w.r.t. the forward output.
// fprop_outputs: the {output, reserved_space} pair returned by mlp_forward.
// inputs:        same layout as mlp_forward's inputs.
// Returns gradients matching inputs element-for-element:
// (dX, dW1..dWn[, db1..dbn]).
std::vector<at::Tensor> mlp_backward(
    int use_bias,
    int activation,
    at::Tensor grad_o,
    std::vector<at::Tensor> fprop_outputs,
    std::vector<at::Tensor> inputs) {
  auto num_layers = inputs.size() - 1;
  if (use_bias) {
    // inputs contains (input, weights, biases)
    num_layers /= 2;
  }

  auto batch_size = inputs[0].size(0);
  auto input_features = inputs[0].size(1);

  bool requires_grad = inputs[0].requires_grad();

  std::vector<int> output_features;
  for (size_t i = 0; i < num_layers; i++) {
    output_features.push_back(inputs[i + 1].size(0));
  }

  // Create outputs, one gradient tensor per input, same shapes.
  std::vector<at::Tensor> outputs;
  for (size_t i = 0; i < inputs.size(); i++) {
    outputs.push_back(at::empty(inputs[i].sizes(), inputs[i].type()));
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_backward", [&] {
    std::vector<scalar_t*> w_ptr;
    for (size_t i = 0; i < num_layers; i++) {
      w_ptr.push_back(inputs[i + 1].data_ptr<scalar_t>());
    }
    std::vector<scalar_t*> outputs_ptr;
    for (size_t i = 0; i < inputs.size(); i++) {
      outputs_ptr.push_back(outputs[i].data_ptr<scalar_t>());
    }

    auto work_size =
        get_mlp_bp_workspace_in_bytes<scalar_t>(batch_size, num_layers, output_features.data());

    // FIX: round UP when converting bytes -> elements. The previous
    // `work_size / sizeof(scalar_t)` truncated, silently under-allocating
    // the workspace whenever work_size was not a multiple of sizeof(scalar_t).
    auto work_elems =
        static_cast<int64_t>((work_size + sizeof(scalar_t) - 1) / sizeof(scalar_t));
    auto work_space = at::empty({work_elems}, inputs[0].type());

    auto result = mlp_bp<scalar_t>(
        inputs[0].data_ptr<scalar_t>(),
        fprop_outputs[0].data_ptr<scalar_t>(),
        input_features,
        batch_size,
        w_ptr.data(),
        num_layers,
        output_features.data(),
        grad_o.contiguous().data_ptr<scalar_t>(),
        fprop_outputs[1].data_ptr<scalar_t>(),
        work_space.data_ptr<scalar_t>(),
        outputs_ptr[0],            // dX
        outputs_ptr.data() + 1,    // dW1..dWn
        outputs_ptr.data() + 1 + num_layers,  // db1..dbn (unused if no bias)
        requires_grad,
        use_bias,
        activation);
    // FIX: the return code was previously assigned and silently discarded.
    TORCH_CHECK(result == 0, "mlp_bp failed with error code ", result);
  });

  return outputs;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &mlp_forward, "MLP forward");
  m.def("backward", &mlp_backward, "MLP backward");
}