GitHub Repository: ai-forever/sber-swap
Path: blob/main/apex/csrc/flatten_unflatten.cpp

#include <torch/extension.h>
#include <torch/csrc/utils/tensor_flatten.h>
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_flatten.h

// Concatenate a list of dense tensors into a single contiguous 1-D tensor.
at::Tensor flatten(std::vector<at::Tensor> tensors)
{
    return torch::utils::flatten_dense_tensors(tensors);
}

// Split a flat tensor back into tensors with the same sizes as `tensors`.
// Only the sizes of the template tensors are used, not their values.
std::vector<at::Tensor> unflatten(at::Tensor flat, std::vector<at::Tensor> tensors)
{
    return torch::utils::unflatten_dense_tensors(flat, tensors);
}

// Expose both helpers to Python under the module name chosen at build time.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("flatten", &flatten, "Flatten dense tensors");
    m.def("unflatten", &unflatten, "Unflatten dense tensors");
}
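The two wrappers above simply delegate to PyTorch's `torch::utils::flatten_dense_tensors` and `torch::utils::unflatten_dense_tensors`. Roughly (based on the linked header, which may differ between PyTorch versions), flattening makes each tensor contiguous, views it as 1-D, and concatenates the pieces; unflattening cuts consecutive `numel()`-sized slices from the flat tensor and reshapes each to the matching template's sizes. The standalone sketch below models that round trip without LibTorch. `DenseTensor`, `numel`, and `scale_all` are hypothetical illustration names, not part of apex or PyTorch, and unlike PyTorch (which returns views sharing storage with `flat`) this sketch copies the data.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical stand-in for a dense, contiguous, row-major float tensor.
// This is NOT at::Tensor; it only models a shape plus an element buffer.
struct DenseTensor {
    std::vector<std::int64_t> shape;
    std::vector<float> data;
};

// Number of elements implied by a shape (product of dims; 1 for a scalar).
std::int64_t numel(const std::vector<std::int64_t>& shape) {
    return std::accumulate(shape.begin(), shape.end(), std::int64_t{1},
                           std::multiplies<std::int64_t>());
}

// Sketch of flatten: append every tensor's elements, in order, to one 1-D buffer.
DenseTensor flatten(const std::vector<DenseTensor>& tensors) {
    DenseTensor flat;
    for (const auto& t : tensors) {
        // Precondition of this sketch: each buffer matches its shape.
        assert(static_cast<std::size_t>(numel(t.shape)) == t.data.size());
        flat.data.insert(flat.data.end(), t.data.begin(), t.data.end());
    }
    flat.shape = {static_cast<std::int64_t>(flat.data.size())};
    return flat;
}

// Sketch of unflatten: cut consecutive slices from the flat buffer, taking each
// slice's length and shape from the corresponding template tensor.
std::vector<DenseTensor> unflatten(const DenseTensor& flat,
                                   const std::vector<DenseTensor>& tensors) {
    std::vector<DenseTensor> outputs;
    outputs.reserve(tensors.size());
    std::size_t offset = 0;
    for (const auto& t : tensors) {
        const auto n = static_cast<std::size_t>(numel(t.shape));
        if (offset + n > flat.data.size()) {
            throw std::invalid_argument("flat buffer is too small for the templates");
        }
        DenseTensor out;
        out.shape = t.shape;  // only the template's shape is used
        auto first = flat.data.begin() + static_cast<std::ptrdiff_t>(offset);
        out.data.assign(first, first + static_cast<std::ptrdiff_t>(n));
        outputs.push_back(std::move(out));
        offset += n;
    }
    return outputs;
}

// Hypothetical usage: apply one operation to many tensors at once by
// flattening, working on the single buffer, then unflattening.
std::vector<DenseTensor> scale_all(const std::vector<DenseTensor>& tensors,
                                   float factor) {
    DenseTensor flat = flatten(tensors);
    for (float& x : flat.data) x *= factor;
    return unflatten(flat, tensors);
}
```

Working on one flat buffer lets a single loop (or, in the real extension, a single kernel launch or collective call) cover many small tensors, at the cost of one copy in and one copy out.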