GitHub Repository: ai-forever/sber-swap
Path: blob/main/apex/csrc/syncbn.cpp

#include <torch/extension.h>
#include <ATen/ATen.h>

#include <vector>

// returns {mean, biased_var}
// implemented using Welford's algorithm
std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input);
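
// Illustrative sketch (not used by the extension): the serial Welford update that the
// CUDA kernel above parallelizes. The helper name and types below are hypothetical.
struct WelfordStats {
  double mean;
  double biased_var;  // sum of squared deviations divided by n
};

inline WelfordStats welford_mean_var_reference(const std::vector<double>& xs) {
  double mean = 0.0;
  double m2 = 0.0;  // running sum of squared deviations from the current mean
  long long n = 0;
  for (double x : xs) {
    ++n;
    const double delta = x - mean;
    mean += delta / n;         // incremental mean update
    m2 += delta * (x - mean);  // uses both old and new mean for numerical stability
  }
  return {mean, n > 0 ? m2 / static_cast<double>(n) : 0.0};
}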

// reduces array of mean/var across processes
// returns global {mean, inv_std, biased_var}
// implemented using Welford's algorithm
std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes,
                                              const at::Tensor var_biased_feature_nodes,
                                              const at::Tensor numel,
                                              const float eps);
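
// Illustrative sketch (not part of the extension): how two partial {mean, biased_var, n}
// statistics can be merged into global statistics, which is, in effect, what
// welford_parallel_CUDA does across processes before deriving
// inv_std = 1 / sqrt(var + eps). The helper name is hypothetical.
inline void merge_mean_var_reference(double mean_a, double var_a, double n_a,
                                     double mean_b, double var_b, double n_b,
                                     double& mean_out, double& var_out) {
  const double n = n_a + n_b;
  const double delta = mean_b - mean_a;
  mean_out = mean_a + delta * n_b / n;
  // m2 = var * n; combine both m2 terms plus a correction for the differing means
  const double m2 = var_a * n_a + var_b * n_b + delta * delta * n_a * n_b / n;
  var_out = m2 / n;
}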

// elementwise BN operation, returns output
// input/weight/shift should have identical data types;
// mean/inv_std have the promoted data type (dtype == fp16 ? fp32 : dtype)
at::Tensor batchnorm_forward_CUDA(const at::Tensor input,
                                  const at::Tensor mean,
                                  const at::Tensor inv_std,
                                  const at::optional<at::Tensor> weight,
                                  const at::optional<at::Tensor> shift);
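
// Per-element view of the forward op above (illustrative, hypothetical helper);
// when weight/shift are absent the kernel presumably treats them as 1 and 0.
inline float batchnorm_forward_element(float x, float mean, float inv_std,
                                       float weight, float shift) {
  return (x - mean) * inv_std * weight + shift;  // y = gamma * x_hat + beta
}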

// backward BN operation, returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data types;
// mean/inv_std have the promoted data type (dtype == fp16 ? fp32 : dtype)
// implemented using Kahan summation
std::vector<at::Tensor> reduce_bn_CUDA(const at::Tensor grad_output,
                                       const at::Tensor input,
                                       const at::Tensor mean,
                                       const at::Tensor inv_std,
                                       const at::optional<at::Tensor> weight);
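
// Illustrative sketch of Kahan (compensated) summation, the technique the reduction
// above uses to keep sums like sum_dy accurate at low precision. Hypothetical helper.
inline double kahan_sum_reference(const std::vector<double>& xs) {
  double sum = 0.0;
  double c = 0.0;  // running compensation for lost low-order bits
  for (double x : xs) {
    const double y = x - c;
    const double t = sum + y;
    c = (t - sum) - y;  // recovers the part of y that did not fit into t
    sum = t;
  }
  return sum;
}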

// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision can be fp16 or fp32;
// mean/inv_std/sum_dy/sum_dy_xmu precision is fp32
at::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output,
                                   const at::Tensor input,
                                   const at::Tensor mean,
                                   const at::Tensor inv_std,
                                   const at::optional<at::Tensor> weight,
                                   const at::Tensor sum_dy,
                                   const at::Tensor sum_dy_xmu,
                                   const at::Tensor count);
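
// Per-element view of the backward op above (illustrative, hypothetical helper).
// It assumes `count` reduces to the total number of elements each channel statistic
// was computed over, and that weight defaults to 1 when absent.
inline float batchnorm_backward_element(float grad_output, float x, float mean,
                                        float inv_std, float weight,
                                        float sum_dy, float sum_dy_xmu, float count) {
  // Standard batch-norm input gradient:
  // dx = gamma * inv_std * (dy - sum_dy / N - (x - mean) * inv_std^2 * sum_dy_xmu / N)
  const float mean_dy = sum_dy / count;
  const float mean_dy_xmu = sum_dy_xmu / count;
  return weight * inv_std *
         (grad_output - mean_dy - (x - mean) * inv_std * inv_std * mean_dy_xmu);
}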

// returns {mean, biased_var}
// implemented using Welford's algorithm
// expects data in channels-last format (channel as the fastest dimension) and applies CUDNN_BATCHNORM_SPATIAL
std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input);

// elementwise BN operation, returns output
// input/weight/shift should have identical data types;
// mean/inv_std have the promoted data type (dtype == fp16 ? fp32 : dtype)
// expects data in channels-last format (channel as the fastest dimension) and applies CUDNN_BATCHNORM_SPATIAL
at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
                                         const at::optional<at::Tensor> z,
                                         const at::Tensor mean,
                                         const at::Tensor inv_std,
                                         const at::optional<at::Tensor> weight,
                                         const at::optional<at::Tensor> shift,
                                         const bool fuse_relu);
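
// Assumption inferred from the signature above: `z` appears to be an optional residual
// tensor added to the normalized output, and `fuse_relu` applies a ReLU to that sum,
// i.e. roughly
//   v   = batchnorm(x) + (z present ? z : 0)
//   out = fuse_relu ? max(v, 0) : v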

// backward BN operation, returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data types;
// mean/inv_std have the promoted data type (dtype == fp16 ? fp32 : dtype)
// expects data in channels-last format (channel as the fastest dimension) and applies CUDNN_BATCHNORM_SPATIAL
std::vector<at::Tensor> reduce_bn_c_last_CUDA(const at::Tensor grad_output,
                                              const at::Tensor input,
                                              const at::Tensor mean,
                                              const at::Tensor inv_std,
                                              const at::optional<at::Tensor> weight);

// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision can be fp16 or fp32;
// mean/inv_std/sum_dy/sum_dy_xmu precision is fp32
// expects data in channels-last format (channel as the fastest dimension) and applies CUDNN_BATCHNORM_SPATIAL
at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
                                          const at::Tensor input,
                                          const at::Tensor mean,
                                          const at::Tensor inv_std,
                                          const at::optional<at::Tensor> weight,
                                          const at::Tensor sum_dy,
                                          const at::Tensor sum_dy_xmu,
                                          const at::Tensor count);

// backward of the fused ReLU on the channels-last path (counterpart to
// batchnorm_forward_c_last with fuse_relu); returns the masked grad_output
at::Tensor relu_backward_c_last_CUDA(const at::Tensor grad_output,
                                     const at::Tensor input,
                                     const at::optional<at::Tensor> z,
                                     const at::Tensor mean,
                                     const at::Tensor inv_std,
                                     const at::optional<at::Tensor> weight,
                                     const at::optional<at::Tensor> shift);


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("welford_mean_var", &welford_mean_var_CUDA, "welford mean variance");
  m.def("welford_parallel", &welford_parallel_CUDA, "welford parallel reduce mean variance");
  m.def("batchnorm_forward", &batchnorm_forward_CUDA, "batchnorm forward");
  m.def("reduce_bn", &reduce_bn_CUDA, "batchnorm backward reduce grad sum and bias/weight grad");
  m.def("batchnorm_backward", &batchnorm_backward_CUDA, "batchnorm backward dgrad");
  m.def("welford_mean_var_c_last", &welford_mean_var_c_last_CUDA, "welford mean variance nhwc");
  m.def("batchnorm_forward_c_last", &batchnorm_forward_c_last_CUDA, "batchnorm forward nhwc");
  m.def("reduce_bn_c_last", &reduce_bn_c_last_CUDA, "batchnorm backward reduce grad sum and bias/weight grad nhwc");
  m.def("batchnorm_backward_c_last", &batchnorm_backward_c_last_CUDA, "batchnorm backward dgrad nhwc");
  m.def("relu_bw_c_last", &relu_backward_c_last_CUDA, "relu_bw_c_last");
}
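
// Typical synchronized-BN forward sequence built from these bindings (a sketch; the
// orchestration lives on the Python side, e.g. in apex.parallel.SyncBatchNorm):
//   1. welford_mean_var(_c_last)  -> local per-channel {mean, biased_var}
//   2. gather the per-process stats (e.g. via torch.distributed) and call
//      welford_parallel            -> global {mean, inv_std, biased_var}
//   3. batchnorm_forward(_c_last) -> normalized output
// The backward pass mirrors this: reduce_bn(_c_last) computes the per-channel sums and
// weight/bias gradients, then batchnorm_backward(_c_last) produces grad_input.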