Path: blob/main/apex/csrc/amp_C_frontend.cpp
#include <torch/extension.h>

// Forward declarations of the fused multi-tensor CUDA kernels
// (implemented in the accompanying .cu sources in apex/csrc).

void multi_tensor_scale_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    float scale);

void multi_tensor_sgd_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    float wd,
    float momentum,
    float dampening,
    float lr,
    bool nesterov,
    bool first_run,
    bool wd_after_momentum,
    float scale);

void multi_tensor_axpby_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    float a,
    float b,
    int arg_to_check);

std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::optional<bool> per_tensor_python);

void multi_tensor_lamb_stage1_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor per_tensor_decay,
    const int step,
    const float beta1,
    const float beta2,
    const float epsilon,
    at::Tensor global_grad_norm,
    const float max_global_grad_norm);

void multi_tensor_lamb_stage2_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor per_tensor_param_norm,
    at::Tensor per_tensor_update_norm,
    const float lr,
    const float weight_decay,
    at::optional<bool> use_nvlamb_python);

void multi_tensor_adam_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    const float lr,
    const float beta1,
    const float beta2,
    const float epsilon,
    const int step,
    const int mode,
    const int bias_correction,
    const float weight_decay);

void multi_tensor_adagrad_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    const float lr,
    const float epsilon,
    const int mode,
    const float weight_decay);

void multi_tensor_novograd_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor grad_norms,
    const float lr,
    const float beta1,
    const float beta2,
    const float epsilon,
    const int step,
    const int bias_correction,
    const float weight_decay,
    const int grad_averaging,
    const int mode,
    const int norm_type);

void multi_tensor_lamb_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    const float lr,
    const float beta1,
    const float beta2,
    const float epsilon,
    const int step,
    const int bias_correction,
    const float weight_decay,
    const int grad_averaging,
    const int mode,
    at::Tensor global_grad_norm,
    const float max_grad_norm,
    at::optional<bool> use_nvlamb_python);

// Expose the fused kernels to Python as the amp_C extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
        "Fused overflow check + scale for a list of contiguous tensors");
  m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
        "Fused SGD optimizer for list of contiguous tensors");
  m.def("multi_tensor_axpby", &multi_tensor_axpby_cuda,
        "out = a*x + b*y for a list of contiguous tensors");
  m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
        "Computes L2 norm for a list of contiguous tensors");
  m.def("multi_tensor_lamb_stage1_cuda", &multi_tensor_lamb_stage1_cuda,
        "Computes update part of LAMB optimizer");
  m.def("multi_tensor_lamb_stage2_cuda", &multi_tensor_lamb_stage2_cuda,
        "Completes application of gradient to parameters for LAMB optimizer");
  m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
        "Compute and apply gradient update to parameters for Adam optimizer");
  m.def("multi_tensor_adagrad", &multi_tensor_adagrad_cuda,
        "Compute and apply gradient update to parameters for Adagrad optimizer");
  m.def("multi_tensor_novograd", &multi_tensor_novograd_cuda,
        "Compute and apply gradient update to parameters for Novograd optimizer");
  m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda,
        "Computes and applies update for LAMB optimizer");
}