// sparkl2d_kernels/cuda/prefix_sum.rs

1use cuda_std::{shared_array, thread};
2
3#[cfg_attr(target_os = "cuda", cuda_std::kernel)]
4pub unsafe fn add_data_grp(data: *mut u32, data_len: u32, rhs: *mut u32) {
5    let id = thread::index_1d();
6    let bid = thread::block_idx_x();
7
8    if id < data_len {
9        *data.add(id as usize) += *rhs.add(bid as usize);
10    }
11}
12
13// TODO: Benchmark and see if we can do 1024 instead of 512.
14// TODO: use dynamic shared memory instead.
15// TODO: optimize to avoid bank conflicts.
16#[cfg_attr(target_os = "cuda", cuda_std::kernel)]
17pub unsafe fn prefix_sum_512(data: *mut u32, data_len: u32, aux: *mut u32) {
18    const THREADS: usize = 512;
19    let thread_id = thread::thread_idx_x();
20    let block_id = thread::block_idx_x();
21    let shared = shared_array![u32; THREADS];
22
23    if block_id * THREADS as u32 >= data_len {
24        return;
25    }
26
27    let data_block_len = data_len as usize - block_id as usize * THREADS;
28    let shared_len = data_block_len.next_power_of_two().min(THREADS).max(1);
29    let elt_id = (thread_id + block_id * THREADS as u32) as usize;
30
31    prefix_sum(
32        data,
33        data_len as usize,
34        elt_id,
35        aux,
36        shared,
37        shared_len,
38        thread_id,
39        block_id,
40    )
41}
42
// Block-local exclusive prefix sum (Blelloch up-sweep/down-sweep) over the
// first `shared_len` slots of `shared`, writing each thread's result back to
// `data[elt_id]` and the block's total into `aux[block_id]`.
//
// NOTE (contract):
//       `shared` must contain at least `shared_len` elements.
//       `shared_len` must be a power of two.
//       Additionally, every thread writes `shared[thread_id]` during the
//       load below, so `shared` needs one slot per thread of the block,
//       even when `shared_len` is smaller than the block size.
//       Because of the `sync_threads` barriers, this must be reached by
//       all threads of the block (the caller's early-out is block-uniform).
unsafe fn prefix_sum(
    data: *mut u32,
    data_len: usize,
    elt_id: usize,
    aux: *mut u32,
    shared: *mut u32,
    shared_len: usize,
    thread_id: u32,
    block_id: u32,
) {
    let bid = block_id as usize;
    let tid = thread_id as usize;

    // Init the shared memory: load this thread's element, or the additive
    // identity (0) when past the end, so the scan tree is always "full".
    *shared.add(tid) = if elt_id < data_len {
        *data.add(elt_id)
    } else {
        0
    };

    // Up-Sweep (reduce): build partial sums in place. After this loop,
    // `shared[shared_len - 1]` holds the sum of the whole block.
    let mut d = shared_len / 2; // active threads at this tree level
    let mut offset = 1; // distance between the two operands at this level
    while d > 0 {
        // Make the previous level's writes visible before reading them.
        thread::sync_threads();
        if tid < d {
            let ia = tid * 2 * offset + offset - 1;
            let ib = (tid * 2 + 1) * offset + offset - 1;

            let sum = *shared.add(ia) + *shared.add(ib);
            *shared.add(ib) = sum;
        }

        d /= 2;
        offset *= 2;
    }

    // Thread 0 records the block total and clears the root, which is what
    // makes the down-sweep produce an *exclusive* scan. No barrier needed
    // here: thread 0 itself wrote the root in the last up-sweep iteration,
    // and the down-sweep syncs before anyone else reads it.
    if tid == 0 {
        let total_sum = *shared.add(shared_len - 1);
        *aux.add(bid) = total_sum;
        *shared.add(shared_len - 1) = 0;
    }

    // Down-Sweep: traverse the tree back down, swapping and accumulating so
    // each slot ends up with the sum of all elements strictly before it.
    let mut d = 1;
    let mut offset = shared_len / 2;

    while d < shared_len {
        thread::sync_threads();
        if tid < d {
            let ia = tid * 2 * offset + offset - 1;
            let ib = (tid * 2 + 1) * offset + offset - 1;

            let a = *shared.add(ia);
            let b = *shared.add(ib);

            *shared.add(ia) = b;
            *shared.add(ib) = a + b;
        }

        d *= 2;
        offset /= 2;
    }

    // Writeback the result to global memory (in-range threads only).
    thread::sync_threads();
    if elt_id < data_len as usize {
        *data.add(elt_id) = *shared.add(tid);
    };
}