Path: blob/main/apex/csrc/layer_norm_cuda.cpp
#include <torch/extension.h>
#include <vector>
#include <cassert>
#include <sstream>  // for std::stringstream used in the error paths below
#include "compat.h"

namespace {

// Collapse the input into an effective 2D view: n1 is the product of the
// leading (batch-like) dimensions, n2 the product of the normalized ones.
void compute_n1_n2(
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    int& n1,
    int& n2)
{
  int idiff = input.ndimension() - normalized_shape.size();
  n2 = 1;
  for (int i = 0; i < (int)normalized_shape.size(); ++i) {
    assert( input.sizes()[i+idiff] == normalized_shape[i] );
    n2 *= normalized_shape[i];
  }
  n1 = 1;
  for (int i = 0; i < idiff; ++i) {
    n1 *= input.sizes()[i];
  }
}

// gamma and beta, when defined, must match normalized_shape exactly.
void check_args(
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor gamma,
    at::Tensor beta
    )
{
  TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
  TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}

// normalized_shape must be non-empty and match the trailing dimensions of
// input; on success, n1 and n2 are filled in via compute_n1_n2.
void check_args(
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    int& n1,
    int& n2
    )
{
  int64_t normalized_ndim = normalized_shape.size();

  if (normalized_ndim < 1) {
    std::stringstream ss;
    ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
       << "containing at least one element, but got normalized_shape="
       << normalized_shape;
    throw std::runtime_error(ss.str());
  }

  auto input_shape = input.sizes();
  auto input_ndim = input.dim();

  if (input_ndim < normalized_ndim ||
      !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
    std::stringstream ss;
    ss << "Given normalized_shape=" << normalized_shape
       << ", expected input with shape [*";
    for (auto size : normalized_shape) {
      ss << ", " << size;
    }
    ss << "], but got input of size " << input_shape;
    throw std::runtime_error(ss.str());
  }

  compute_n1_n2(input, normalized_shape, n1, n2);
}

void check_args(
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor gamma,
    at::Tensor beta,
    int& n1,
    int& n2
    )
{
  check_args(input, normalized_shape, n1, n2);
  check_args(normalized_shape, gamma, beta);
}

}  // namespace

// Forward kernel launcher; defined in the companion CUDA translation unit.
void cuda_layer_norm(
    at::Tensor* output,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon);

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// Forward pass without affine parameters. Returns {output, mean, invvar};
// mean/invvar are promoted to fp32 when the input is half precision so the
// saved statistics keep full accuracy for the backward pass.
std::vector<at::Tensor> layer_norm(
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    double epsilon) {
  CHECK_INPUT(input);
  int n1, n2;
  check_args(input, normalized_shape, n1, n2);
  at::Tensor output = at::empty_like(input);
  at::Tensor mean = at::empty({n1}, input.options().dtype(
      input.scalar_type() == at::ScalarType::Half ? at::ScalarType::Float
                                                  : input.scalar_type()));
  at::Tensor invvar = at::empty_like(mean);
  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
      normalized_shape, NULL, NULL, epsilon);
  return {output, mean, invvar};
}

// Forward pass with learnable gamma/beta.
std::vector<at::Tensor> layer_norm_affine(
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor gamma,
    at::Tensor beta,
    double epsilon) {
  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);
  int n1, n2;
  check_args(input, normalized_shape, gamma, beta, n1, n2);
  at::Tensor output = at::empty_like(input);
  at::Tensor mean = at::empty({n1}, input.options().dtype(
      input.scalar_type() == at::ScalarType::Half ? at::ScalarType::Float
                                                  : input.scalar_type()));
  at::Tensor invvar = at::empty_like(mean);
  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
      normalized_shape, &gamma, &beta, epsilon);
  return {output, mean, invvar};
}

// Backward kernel launcher; defined in the companion CUDA translation unit.
void cuda_layer_norm_gradient(
    at::Tensor* dout,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon,
    at::Tensor* grad_input,
    at::Tensor* grad_gamma,
    at::Tensor* grad_beta
    );

// Backward pass without affine parameters: only grad_input is produced.
at::Tensor layer_norm_gradient(
    at::Tensor dout,
    at::Tensor mean,
    at::Tensor invvar,
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    double epsilon) {
  CHECK_INPUT(dout);
  CHECK_INPUT(mean);
  CHECK_INPUT(invvar);
  CHECK_INPUT(input);
  int n1, n2;
  check_args(input, normalized_shape, n1, n2);
  at::Tensor grad_input = at::empty_like(input);
  cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
      normalized_shape, NULL, NULL, epsilon,
      &grad_input, NULL, NULL);
  return grad_input;
}

// Backward pass with affine parameters:
// returns {grad_input, grad_gamma, grad_beta}.
std::vector<at::Tensor> layer_norm_gradient_affine(
    at::Tensor dout,
    at::Tensor mean,
    at::Tensor invvar,
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor gamma,
    at::Tensor beta,
    double epsilon) {
  CHECK_INPUT(dout);
  CHECK_INPUT(mean);
  CHECK_INPUT(invvar);
  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);
  int n1, n2;
  check_args(input, normalized_shape, gamma, beta, n1, n2);
  at::Tensor grad_input = at::empty_like(input);
  at::Tensor grad_gamma = at::empty_like(gamma);
  at::Tensor grad_beta = at::empty_like(beta);
  cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
      normalized_shape, &gamma, &beta, epsilon,
      &grad_input, &grad_gamma, &grad_beta);
  return {grad_input, grad_gamma, grad_beta};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
  m.def("forward", &layer_norm, "LayerNorm forward (CUDA)");
  m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)");
  m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)");
}
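
For orientation, the statistics these bindings allocate and hand back follow the standard layer-norm formulation (a summary of that formulation, not text taken from the kernel itself): each of the n1 rows is normalized over its n2 trailing elements, and invvar is the reciprocal standard deviation that the backward entry points consume alongside mean:

\mu_i = \frac{1}{n_2} \sum_{j=1}^{n_2} x_{ij}, \qquad
\sigma_i^2 = \frac{1}{n_2} \sum_{j=1}^{n_2} (x_{ij} - \mu_i)^2

\mathrm{invvar}_i = \frac{1}{\sqrt{\sigma_i^2 + \epsilon}}, \qquad
y_{ij} = \mathrm{invvar}_i \, (x_{ij} - \mu_i) \, \gamma_j + \beta_j

with \gamma_j = 1 and \beta_j = 0 in the non-affine entry points. Saving mean and invvar in fp32 (see the dtype promotion in layer_norm / layer_norm_affine) keeps these per-row statistics accurate even when the activations are half precision.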
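
As a quick illustration of the n1/n2 flattening that compute_n1_n2 performs, the standalone sketch below repeats the same arithmetic in plain C++ for a hypothetical [batch, seq, hidden] input normalized over its last dimension (the shapes are made up for the example; no torch dependency is needed):

#include <cstdio>
#include <vector>

int main() {
  // Hypothetical shapes: input [batch=4, seq=16, hidden=768],
  // normalized over the last dimension only.
  std::vector<int> input_shape      = {4, 16, 768};
  std::vector<int> normalized_shape = {768};

  // Same arithmetic as compute_n1_n2: the trailing dims form n2,
  // everything in front of them is flattened into n1.
  int idiff = (int)input_shape.size() - (int)normalized_shape.size();
  int n2 = 1;
  for (int i = 0; i < (int)normalized_shape.size(); ++i)
    n2 *= normalized_shape[i];
  int n1 = 1;
  for (int i = 0; i < idiff; ++i)
    n1 *= input_shape[i];

  std::printf("n1 = %d, n2 = %d\n", n1, n2);  // prints: n1 = 64, n2 = 768
  return 0;
}

The bindings above pass only this flattened (n1, n2) pair, plus normalized_shape, down to the CUDA launchers, so the kernels can work in terms of rows of length n2 regardless of the original input rank.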