Path: blob/main/apex/csrc/syncbn.cpp
#include <torch/extension.h>
#include <ATen/ATen.h>

#include <vector>

// returns {mean,biased_var}
// implemented using welford
std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input);

// reduces array of mean/var across processes
// returns global {mean,inv_std,biased_var}
// implemented using welford
std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes,
                                              const at::Tensor var_biased_feature_nodes,
                                              const at::Tensor numel,
                                              const float eps);

// elementwise BN operation, returns output
// input/weight/shift should have identical data type;
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
at::Tensor batchnorm_forward_CUDA(const at::Tensor input,
                                  const at::Tensor mean,
                                  const at::Tensor inv_std,
                                  const at::optional<at::Tensor> weight,
                                  const at::optional<at::Tensor> shift);

// backward BN operation, returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// implemented using kahan summation
std::vector<at::Tensor> reduce_bn_CUDA(const at::Tensor grad_output,
                                       const at::Tensor input,
                                       const at::Tensor mean,
                                       const at::Tensor inv_std,
                                       const at::optional<at::Tensor> weight);

// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision could be fp16/fp32;
// mean/inv_std/sum_dy/sum_dy_xmu precision is fp32
at::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output,
                                   const at::Tensor input,
                                   const at::Tensor mean,
                                   const at::Tensor inv_std,
                                   const at::optional<at::Tensor> weight,
                                   const at::Tensor sum_dy,
                                   const at::Tensor sum_dy_xmu,
                                   const at::Tensor count);

// returns {mean, biased_var}
// implemented using welford
// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input);

// elementwise BN operation, returns output
// input/weight/shift should have identical data type;
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
                                         const at::optional<at::Tensor> z,
                                         const at::Tensor mean,
                                         const at::Tensor inv_std,
                                         const at::optional<at::Tensor> weight,
                                         const at::optional<at::Tensor> shift,
                                         const bool fuse_relu);

// backward BN operation, returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
std::vector<at::Tensor> reduce_bn_c_last_CUDA(const at::Tensor grad_output,
                                              const at::Tensor input,
                                              const at::Tensor mean,
                                              const at::Tensor inv_std,
                                              const at::optional<at::Tensor> weight);

// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision could be fp16/fp32;
// mean/inv_std/sum_dy/sum_dy_xmu precision is fp32
// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
                                          const at::Tensor input,
                                          const at::Tensor mean,
                                          const at::Tensor inv_std,
                                          const at::optional<at::Tensor> weight,
                                          const at::Tensor sum_dy,
                                          const at::Tensor sum_dy_xmu,
                                          const at::Tensor count);

at::Tensor relu_backward_c_last_CUDA(const at::Tensor grad_output,
                                     const at::Tensor input,
                                     const at::optional<at::Tensor> z,
                                     const at::Tensor mean,
                                     const at::Tensor inv_std,
                                     const at::optional<at::Tensor> weight,
                                     const at::optional<at::Tensor> shift);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("welford_mean_var", &welford_mean_var_CUDA, "welford mean variance");
  m.def("welford_parallel", &welford_parallel_CUDA, "welford parallel reduce mean variance");
  m.def("batchnorm_forward", &batchnorm_forward_CUDA, "batchnorm forward");
  m.def("reduce_bn", &reduce_bn_CUDA, "batchnorm backward reduce grad sum and bias/weight grad");
  m.def("batchnorm_backward", &batchnorm_backward_CUDA, "batchnorm backward dgrad");
  m.def("welford_mean_var_c_last", &welford_mean_var_c_last_CUDA, "welford mean variance nhwc");
  m.def("batchnorm_forward_c_last", &batchnorm_forward_c_last_CUDA, "batchnorm forward nhwc");
  m.def("reduce_bn_c_last", &reduce_bn_c_last_CUDA, "batchnorm backwards reduce grad sum and bias/weight grad nhwc");
  m.def("batchnorm_backward_c_last", &batchnorm_backward_c_last_CUDA, "batchnorm backward dgrad nhwc");
  m.def("relu_bw_c_last", &relu_backward_c_last_CUDA, "relu_bw_c_last");
}
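Several declarations above note they are "implemented using welford". For readers unfamiliar with it, the following is a minimal host-side C++ sketch of the textbook Welford update for a single channel; it is illustrative only, not the kernel code, and the function name welford_mean_biased_var is made up here. The actual CUDA kernels parallelize this per channel across threads and blocks.

#include <cstddef>
#include <utility>
#include <vector>

// Textbook Welford update: one pass, numerically stable running mean/variance.
std::pair<float, float> welford_mean_biased_var(const std::vector<float>& x) {
  float mean = 0.f;
  float m2 = 0.f;      // sum of squared deviations from the running mean
  std::size_t n = 0;
  for (float v : x) {
    ++n;
    float delta = v - mean;
    mean += delta / static_cast<float>(n);
    m2 += delta * (v - mean);   // uses the already-updated mean
  }
  float biased_var = (n > 0) ? m2 / static_cast<float>(n) : 0.f;
  return {mean, biased_var};    // matches the {mean, biased_var} convention above
}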
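welford_parallel combines per-process mean/variance into global statistics. A plausible way to picture that reduction is the standard pairwise combination of (count, mean, biased_var) triples (Chan et al.); the sketch below is an assumption about the shape of that step, not the actual implementation, and the Moments/combine names are hypothetical. The real function additionally converts the combined variance to inv_std using eps.

// Illustrative pairwise merge of two per-process statistics triples.
struct Moments { float count, mean, biased_var; };

Moments combine(const Moments& a, const Moments& b) {
  float n = a.count + b.count;
  float delta = b.mean - a.mean;
  float mean = a.mean + delta * (b.count / n);
  // M2 = count * biased_var; merge the M2 values, then renormalize.
  float m2 = a.biased_var * a.count + b.biased_var * b.count
           + delta * delta * (a.count * b.count / n);
  return {n, mean, m2 / n};
}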
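reduce_bn is documented as using Kahan summation to accumulate sum_dy and sum_dy_xmu. For reference, this is the compensated-summation technique in its plain serial form; the kahan_sum name is illustrative, and the kernels apply the same idea inside their reductions rather than in a loop like this.

#include <vector>

// Compensated (Kahan) summation: tracks the rounding error of each add
// and feeds it back into the next one.
float kahan_sum(const std::vector<float>& x) {
  float sum = 0.f;
  float c = 0.f;                 // running compensation for lost low-order bits
  for (float v : x) {
    float y = v - c;             // remove the error carried from the last step
    float t = sum + y;           // low-order bits of y may be lost here
    c = (t - sum) - y;           // recover exactly what was lost
    sum = t;
  }
  return sum;
}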