GitHub Repository: ai-forever/sber-swap
Path: blob/main/apex/csrc/amp_C_frontend.cpp
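
This file is the amp_C frontend from the Apex sources bundled in this repository: it forward-declares the fused multi-tensor CUDA kernels and exposes them to Python through pybind11.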
#include <torch/extension.h>

// Forward declarations of the fused multi-tensor kernels; the definitions
// live in the accompanying CUDA sources.

void multi_tensor_scale_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float scale);

void multi_tensor_sgd_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float wd,
  float momentum,
  float dampening,
  float lr,
  bool nesterov,
  bool first_run,
  bool wd_after_momentum,
  float scale);

void multi_tensor_axpby_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float a,
  float b,
  int arg_to_check);

std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::optional<bool> per_tensor_python);

void multi_tensor_lamb_stage1_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor per_tensor_decay,
  const int step,
  const float beta1,
  const float beta2,
  const float epsilon,
  at::Tensor global_grad_norm,
  const float max_global_grad_norm);

void multi_tensor_lamb_stage2_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor per_tensor_param_norm,
  at::Tensor per_tensor_update_norm,
  const float lr,
  const float weight_decay,
  at::optional<bool> use_nvlamb_python);

void multi_tensor_adam_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  const float lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  const int step,
  const int mode,
  const int bias_correction,
  const float weight_decay);

void multi_tensor_adagrad_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  const float lr,
  const float epsilon,
  const int mode,
  const float weight_decay);

void multi_tensor_novograd_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor grad_norms,
  const float lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  const int step,
  const int bias_correction,
  const float weight_decay,
  const int grad_averaging,
  const int mode,
  const int norm_type);

void multi_tensor_lamb_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  const float lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  const int step,
  const int bias_correction,
  const float weight_decay,
  const int grad_averaging,
  const int mode,
  at::Tensor global_grad_norm,
  const float max_grad_norm,
  at::optional<bool> use_nvlamb_python);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
        "Fused overflow check + scale for a list of contiguous tensors");
  m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
        "Fused SGD optimizer for a list of contiguous tensors");
  m.def("multi_tensor_axpby", &multi_tensor_axpby_cuda,
        "out = a*x + b*y for a list of contiguous tensors");
  m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
        "Computes L2 norm for a list of contiguous tensors");
  m.def("multi_tensor_lamb_stage1_cuda", &multi_tensor_lamb_stage1_cuda,
        "Computes update part of LAMB optimizer");
  m.def("multi_tensor_lamb_stage2_cuda", &multi_tensor_lamb_stage2_cuda,
        "Completes application of gradient to parameters for LAMB optimizer");
  m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
        "Computes and applies gradient update to parameters for Adam optimizer");
  m.def("multi_tensor_adagrad", &multi_tensor_adagrad_cuda,
        "Computes and applies gradient update to parameters for Adagrad optimizer");
  m.def("multi_tensor_novograd", &multi_tensor_novograd_cuda,
        "Computes and applies gradient update to parameters for NovoGrad optimizer");
  m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda,
        "Computes and applies update for LAMB optimizer");
}
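
For orientation, below is a minimal sketch of how these bindings are typically driven from Python once the extension is compiled. It assumes the module is built under the name amp_C and that Apex's multi_tensor_applier helper (which supplies chunk_size and forwards the remaining arguments to the bound function) is importable; the tensor sizes and loss scale are made up for illustration.

# Hypothetical usage sketch: assumes the extension was built as `amp_C` and
# that apex.multi_tensor_apply.multi_tensor_applier is available.
import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

# The kernels set this flag to a nonzero value if they encounter inf/NaN.
overflow_buf = torch.zeros(1, dtype=torch.int, device="cuda")

grads = [torch.randn(1024, device="cuda"), torch.randn(2048, device="cuda")]
unscaled = [torch.empty_like(g) for g in grads]

# multi_tensor_scale: tensor_lists = [inputs, outputs]; each output is
# input * scale, with the overflow flag raised on inf/NaN.
multi_tensor_applier(amp_C.multi_tensor_scale,
                     overflow_buf,
                     [grads, unscaled],
                     1.0 / 65536.0)

# multi_tensor_l2norm returns (global_norm, per_tensor_norms) when the
# per-tensor flag is True.
global_norm, per_tensor_norms = multi_tensor_applier(
    amp_C.multi_tensor_l2norm, overflow_buf, [grads], True)

The same calling pattern (an overflow buffer, a list of tensor lists, then the trailing scalar arguments) applies to the fused optimizer entry points declared above.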