-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathTHCTensorSort.cu
66 lines (55 loc) · 2.1 KB
/
THCTensorSort.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include <THC/THCTensorSort.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
// Launches the fillSliceWithIndex kernel (declared via THC/THCTensorSort.cuh)
// over every slice of `t` taken along dimension `dim`.
// NOTE(review): judging by the kernel's name it writes each element's position
// within its slice into `t` — confirm against the kernel definition.
//
// Launch configuration:
//   * grid   - derived from the number of slices by THC_getGridFromTiles;
//   * block  - 1-D, min(slice length, device maxThreadsPerBlock) threads;
//   * stream - the current CUDA stream, with no dynamic shared memory.
// Nothing is synchronized here; launch errors are surfaced through
// C10_CUDA_KERNEL_LAUNCH_CHECK().
//
// Parameters:
//   state - THC state handle forwarded to the THCudaLongTensor helpers.
//   t     - tensor to fill (element type int64_t, per the TensorInfo used below).
//   dim   - dimension along which slices are taken.
// Errors:
//   THArgCheck (argument 2) if `t` has more than MAX_CUTORCH_DIMS dimensions;
//   THError if the slice count cannot be mapped onto a CUDA grid.
// A tensor with no elements is left alone and no kernel is launched.
void THCudaLongTensor_fillSliceWithIndex(THCState* state,
                                         THCudaLongTensor* t,
                                         int dim) {
  const int64_t rank = THCudaLongTensor_nDimensionLegacyNoScalars(state, t);
  THArgCheck(rank <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);

  const ptrdiff_t totalElements = THCudaLongTensor_nElement(state, t);
  if (totalElements <= 0) {
    // Empty tensor: nothing to fill, so skip the launch entirely.
    return;
  }

  const int64_t sliceLen = THCudaLongTensor_sizeLegacyNoScalars(state, t, dim);
  const ptrdiff_t sliceCount = totalElements / sliceLen;

  dim3 grid;
  if (!THC_getGridFromTiles(sliceCount, grid)) {
    // Triggered by too many slices (the grid is sized from sliceCount),
    // despite the message's wording.
    THError("Slice to fill with indices is too large");
  }

  // Block width: the slice length, capped at what the device permits.
  const int64_t threadCap =
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
  dim3 block(sliceLen > threadCap ? threadCap : sliceLen);

  // Launches fillSliceWithIndex<INDEX_T, DIM_V> on the current stream and
  // checks the launch. INFO is the TensorInfo describing `t`; COLLAPSED is the
  // index returned by collapseDims(dim), whose stride is forwarded to the kernel.
#define THC_FILL_SLICE_INDEX_LAUNCH(INDEX_T, DIM_V, INFO, COLLAPSED) \
  fillSliceWithIndex<INDEX_T, DIM_V>                                \
      <<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(      \
          (INFO), sliceCount, sliceLen, (INFO).strides[(COLLAPSED)]); \
  C10_CUDA_KERNEL_LAUNCH_CHECK()

  if (THCTensor_canUse32BitIndexMath(state, t)) {
    // 32-bit indexing: choose a dimension-specialized instantiation where one
    // exists (-2 for contiguous layouts, 1 or 2 for those collapsed ranks).
    TensorInfo<int64_t, uint32_t> info32 =
        getTensorInfo<int64_t, THCudaLongTensor, unsigned int>(state, t);
    info32.reduceDim(dim);
    const int collapsed32 = info32.collapseDims(dim);

    if (info32.isContiguous()) {
      THC_FILL_SLICE_INDEX_LAUNCH(unsigned int, -2, info32, collapsed32);
    } else if (info32.dims == 1) {
      THC_FILL_SLICE_INDEX_LAUNCH(unsigned int, 1, info32, collapsed32);
    } else if (info32.dims == 2) {
      THC_FILL_SLICE_INDEX_LAUNCH(unsigned int, 2, info32, collapsed32);
    } else {
      THC_FILL_SLICE_INDEX_LAUNCH(unsigned int, -1, info32, collapsed32);
    }
  } else {
    // 64-bit indexing: only the catch-all (-1) instantiation is used.
    TensorInfo<int64_t, uint64_t> info64 =
        getTensorInfo<int64_t, THCudaLongTensor, uint64_t>(state, t);
    info64.reduceDim(dim);
    const int collapsed64 = info64.collapseDims(dim);
    THC_FILL_SLICE_INDEX_LAUNCH(uint64_t, -1, info64, collapsed64);
  }
#undef THC_FILL_SLICE_INDEX_LAUNCH
}