CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
ai-forever

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: ai-forever/sber-swap
Path: blob/main/apex/csrc/multi_tensor_axpby_kernel.cu
Views: 792
1
#include <ATen/ATen.h>
2
#include <ATen/AccumulateType.h>
3
#include <ATen/cuda/CUDAContext.h>
4
#include <ATen/cuda/Exceptions.h>
5
// Another possibility:
6
// #include <torch/all.h>
7
8
#include <assert.h>
9
10
#include "type_shim.h"
11
#include "multi_tensor_apply.cuh"
12
13
// Launch/tiling constants for the axpby kernel.
//   BLOCK_SIZE: threads per CUDA block, handed to multi_tensor_apply.
//   ILP:        elements each thread handles per loop iteration; the
//               vectorized path also moves ILP elements per load/store.
constexpr int BLOCK_SIZE = 512;
constexpr int ILP = 4;
15
16
// True when `p` sits on an (ILP * sizeof(T))-byte boundary. The vectorized
// path only calls load_store on pointers for which this holds.
template <typename T>
__device__ __forceinline__ bool is_aligned(T* p)
{
  const uint64_t address = reinterpret_cast<uint64_t>(p);
  const uint64_t vector_bytes = ILP * sizeof(T);
  return (address % vector_bytes) == 0;
}
20
21
// Copies one run of ILP consecutive T elements as a single aggregate object:
//   dst[dst_offset*ILP, dst_offset*ILP + ILP) <- src[src_offset*ILP, src_offset*ILP + ILP)
// Offsets are measured in ILP-element units, not single elements.
// Both addresses must satisfy the ILP*alignof(T) alignment of the aggregate type.
template <typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset)
{
  using VecT = typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type;
  VecT* dst_vec = reinterpret_cast<VecT*>(dst);
  VecT* src_vec = reinterpret_cast<VecT*>(src);
  dst_vec[dst_offset] = src_vec[src_offset];
}
26
27
// Elementwise functor for multi_tensor_apply<3> computing out = a*x + b*y.
//
// Each CUDA block handles one chunk of one tensor, as selected by
// tl.block_to_tensor / tl.block_to_chunk. Tensor-list slots:
//   tl.addresses[0] -> x, tl.addresses[1] -> y, tl.addresses[2] -> out.
// The arithmetic is done in float; each result is converted to out_t on store.
//
// arg_to_check selects which inputs are scanned for inf/nan:
//   -1 -> x and y, 0 -> x only, 1 -> y only, any other value -> nothing.
// If a scanned value is non-finite, *noop_gmem is set to 1. The functor does
// not skip work when the flag is already set, so infs/nans reach `out`.
template<typename x_t, typename y_t, typename out_t>
struct AxpbyFunctor
{
  // Computes one ILP-wide tile r_out[ii] = a*r_x[ii] + b*r_y[ii] (in float)
  // and ANDs the finiteness of the input(s) selected by arg_to_check into
  // `finite`. Shared by the vectorized and the scalar path so they cannot
  // drift apart.
  static __device__ __forceinline__ void axpby_tile(
    out_t* r_out,
    const x_t* r_x,
    const y_t* r_y,
    float a,
    float b,
    int arg_to_check,
    bool& finite)
  {
#pragma unroll
    for(int ii = 0; ii < ILP; ii++)
    {
      r_out[ii] = a*static_cast<float>(r_x[ii]) + b*static_cast<float>(r_y[ii]);
      if(arg_to_check == -1)
        finite = finite && (isfinite(r_x[ii]) && isfinite(r_y[ii]));
      if(arg_to_check == 0)
        finite = finite && isfinite(r_x[ii]);
      if(arg_to_check == 1)
        finite = finite && isfinite(r_y[ii]);
    }
  }

  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<3>& tl,
    float a,
    float b,
    int arg_to_check)
  {
    // I'd like this kernel to propagate infs/nans, so there is deliberately
    // no early exit when the noop flag is already set:
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    // Point x, y and out at the first element of this block's chunk.
    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
    x += chunk_idx*chunk_size;

    y_t* y = (y_t*)tl.addresses[1][tensor_loc];
    y += chunk_idx*chunk_size;

    out_t* out = (out_t*)tl.addresses[2][tensor_loc];
    out += chunk_idx*chunk_size;

    // Elements left in this tensor from the chunk start. This may exceed
    // chunk_size, so both bounds are checked below.
    n -= chunk_idx*chunk_size;

    bool finite = true;
    // load_store reinterprets these tiles as single ILP-wide objects that
    // require ILP*alignof(T) alignment. Declare that alignment explicitly
    // instead of relying on the tiles being kept in registers.
    alignas(ILP*alignof(x_t)) x_t r_x[ILP];
    alignas(ILP*alignof(y_t)) y_t r_y[ILP];
    alignas(ILP*alignof(out_t)) out_t r_out[ILP];

    // to make things simple, we put aligned case in a different code path
    if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x) && is_aligned(y) && is_aligned(out))
    {
      // Vectorized path: i_start counts ILP-element tiles, so each
      // load_store moves ILP elements at once.
      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
      {
        // load
        load_store(r_x, x, 0 , i_start);
        load_store(r_y, y, 0 , i_start);
        axpby_tile(r_out, r_x, r_y, a, b, arg_to_check, finite);
        // store
        load_store(out, r_out, i_start , 0);
      }
    }
    else
    {
      // Scalar path: each thread handles ILP elements spaced blockDim.x apart,
      // with a bounds check per element. Out-of-range slots are zero-filled;
      // zero is finite, and their results are never stored.
      // Non-divergent exit condition for __syncthreads, not necessary here
      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
      {
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          r_x[ii] = 0;
          r_y[ii] = 0;
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            r_x[ii] = x[i];
            r_y[ii] = y[i];
          }
        }
        axpby_tile(r_out, r_x, r_y, a, b, arg_to_check, finite);
        // see note in multi_tensor_scale_kernel.cu
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
            out[i] = r_out[ii];
        }
      }
    }
    if(!finite)
      *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
  }
};
127
128
// Host entry point: computes out[i] = a*x[i] + b*y[i] for every tensor index i
// across three parallel tensor lists, by launching AxpbyFunctor through
// multi_tensor_apply<3>. The arithmetic is done in float.
//
// chunk_size   : elements per chunk, forwarded to multi_tensor_apply.
// noop_flag    : device flag the kernel sets to 1 if a checked input holds inf/nan.
// tensor_lists : exactly three lists {x, y, out}; each must be non-empty. The
//                element type of each list is taken from its first tensor
//                (presumably float or half, per DISPATCH_FLOAT_AND_HALF in
//                type_shim.h — confirm there).
// a, b         : scalar coefficients.
// arg_to_check : -1 checks x and y, 0 checks x, 1 checks y; other values check nothing.
//
// Throws (via TORCH_CHECK) if tensor_lists does not hold three non-empty lists.
void multi_tensor_axpby_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float a,
  float b,
  int arg_to_check)
{
  using namespace at;

  // The dispatch below reads tensor_lists[k][0] for k = 0, 1, 2 before
  // multi_tensor_apply gets a chance to validate its arguments. Check the
  // shape here so that bad input is a clear error rather than an
  // out-of-bounds read.
  TORCH_CHECK(tensor_lists.size() == 3,
              "multi_tensor_axpby_cuda: expected 3 tensor lists (x, y, out), got ",
              tensor_lists.size());
  for(size_t k = 0; k < tensor_lists.size(); k++)
  {
    TORCH_CHECK(!tensor_lists[k].empty(),
                "multi_tensor_axpby_cuda: tensor_lists[", k, "] is empty");
  }

  // x, y and out are each dispatched independently on their own element type,
  // so the output type is NOT fixed to float.
  // If build times suffer, think about where to put this dispatch,
  // and what logic should be moved out of multi_tensor_apply.

  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda",
    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda",
      DISPATCH_FLOAT_AND_HALF(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda",
        multi_tensor_apply<3>(
          BLOCK_SIZE,
          chunk_size,
          noop_flag,
          tensor_lists,
          AxpbyFunctor<scalar_t_0, scalar_t_1, scalar_t_2>(),
          a,
          b,
          arg_to_check); )))

  // Surfaces launch-configuration errors; asynchronous execution errors
  // appear at the next synchronizing call.
  AT_CUDA_CHECK(cudaGetLastError());

  // AT_CUDA_CHECK(cudaDeviceSynchronize());
}
158
159