sparkl2d_kernels/cuda/
prefix_sum.rs1use cuda_std::{shared_array, thread};
2
3#[cfg_attr(target_os = "cuda", cuda_std::kernel)]
4pub unsafe fn add_data_grp(data: *mut u32, data_len: u32, rhs: *mut u32) {
5 let id = thread::index_1d();
6 let bid = thread::block_idx_x();
7
8 if id < data_len {
9 *data.add(id as usize) += *rhs.add(bid as usize);
10 }
11}
12
13#[cfg_attr(target_os = "cuda", cuda_std::kernel)]
17pub unsafe fn prefix_sum_512(data: *mut u32, data_len: u32, aux: *mut u32) {
18 const THREADS: usize = 512;
19 let thread_id = thread::thread_idx_x();
20 let block_id = thread::block_idx_x();
21 let shared = shared_array![u32; THREADS];
22
23 if block_id * THREADS as u32 >= data_len {
24 return;
25 }
26
27 let data_block_len = data_len as usize - block_id as usize * THREADS;
28 let shared_len = data_block_len.next_power_of_two().min(THREADS).max(1);
29 let elt_id = (thread_id + block_id * THREADS as u32) as usize;
30
31 prefix_sum(
32 data,
33 data_len as usize,
34 elt_id,
35 aux,
36 shared,
37 shared_len,
38 thread_id,
39 block_id,
40 )
41}
42
43unsafe fn prefix_sum(
47 data: *mut u32,
48 data_len: usize,
49 elt_id: usize,
50 aux: *mut u32,
51 shared: *mut u32,
52 shared_len: usize,
53 thread_id: u32,
54 block_id: u32,
55) {
56 let bid = block_id as usize;
57 let tid = thread_id as usize;
58
59 *shared.add(tid) = if elt_id < data_len {
61 *data.add(elt_id)
62 } else {
63 0
64 };
65
66 let mut d = shared_len / 2;
68 let mut offset = 1;
69 while d > 0 {
70 thread::sync_threads();
71 if tid < d {
72 let ia = tid * 2 * offset + offset - 1;
73 let ib = (tid * 2 + 1) * offset + offset - 1;
74
75 let sum = *shared.add(ia) + *shared.add(ib);
76 *shared.add(ib) = sum;
77 }
78
79 d /= 2;
80 offset *= 2;
81 }
82
83 if tid == 0 {
84 let total_sum = *shared.add(shared_len - 1);
85 *aux.add(bid) = total_sum;
86 *shared.add(shared_len - 1) = 0;
87 }
88
89 let mut d = 1;
91 let mut offset = shared_len / 2;
92
93 while d < shared_len {
94 thread::sync_threads();
95 if tid < d {
96 let ia = tid * 2 * offset + offset - 1;
97 let ib = (tid * 2 + 1) * offset + offset - 1;
98
99 let a = *shared.add(ia);
100 let b = *shared.add(ib);
101
102 *shared.add(ia) = b;
103 *shared.add(ib) = a + b;
104 }
105
106 d *= 2;
107 offset /= 2;
108 }
109
110 thread::sync_threads();
112 if elt_id < data_len as usize {
113 *data.add(elt_id) = *shared.add(tid);
114 };
115}