sparkl2d_kernels/cuda/hashmap.rs

use crate::cuda::atomic::AtomicInt;
use crate::{BlockHeaderId, BlockVirtualId, DevicePointer};
use cuda_std::thread;

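/// Key value marking an unoccupied slot.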
const EMPTY: u64 = u64::MAX;

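/// Bit-mixing hash for 64-bit block ids, based on the shift/multiply
/// constants of the MurmurHash3 32-bit finalizer; the multiplications wrap
/// on overflow.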
fn hash(mut key: u64) -> u64 {
    key ^= key >> 16;
    key = key.wrapping_mul(0x85ebca6b);
    key ^= key >> 13;
    key = key.wrapping_mul(0xc2b2ae35);
    key ^= key >> 16;
    key
}

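/// A single key/value slot of the grid hash map.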
#[cfg_attr(not(target_os = "cuda"), derive(cust::DeviceCopy))]
#[derive(Copy, Clone, Debug, PartialEq, Eq, bytemuck::Zeroable)]
#[repr(C)]
pub struct GridHashMapEntry {
    pub key: BlockVirtualId,
    pub value: BlockHeaderId,
}

impl GridHashMapEntry {
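    /// An unoccupied entry: the key is the `EMPTY` sentinel.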
    pub fn free() -> Self {
        Self {
            key: BlockVirtualId(EMPTY),
            value: BlockHeaderId(0),
        }
    }
}

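/// An open-addressing hash map with linear probing, stored in GPU global
/// memory, mapping virtual block ids to block header ids.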
#[cfg_attr(not(target_os = "cuda"), derive(cust::DeviceCopy))]
#[derive(Clone, Copy)]
#[repr(C)]
pub struct GridHashMap {
    entries: DevicePointer<GridHashMapEntry>,
    capacity: u32,
}

impl GridHashMap {
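    /// Builds a map over `capacity` preallocated entries.
    ///
    /// Probing masks the hash with `capacity - 1`, so `capacity` must be a
    /// power of two, and the entries are expected to be reset (see
    /// `reset_hashmap`) before the map is used.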
    pub unsafe fn from_raw_parts(entries: DevicePointer<GridHashMapEntry>, capacity: u32) -> Self {
        Self { entries, capacity }
    }

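    /// Inserts the value produced by `value()` for `key` if the key is not
    /// already present.
    ///
    /// A slot is claimed with an atomic compare-and-swap on the key; the
    /// value is then written with a plain store, so lookups are only
    /// reliable once the inserting kernel has completed. If `capacity - 1`
    /// probes all find other keys, the insertion is silently dropped.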
    pub fn insert_nonexistant_with(
        &mut self,
        key: BlockVirtualId,
        mut value: impl FnMut() -> BlockHeaderId,
    ) {
        let mut slot = hash(key.0) & (self.capacity as u64 - 1);

        for _ in 0..self.capacity - 1 {
            let entry = unsafe { &mut *self.entries.as_mut_ptr().add(slot as usize) };
            let prev = unsafe { entry.key.0.global_atomic_cas(EMPTY, key.0) };
            if prev == EMPTY {
                entry.value = value();
                break;
            } else if prev == key.0 {
                break;
            }

            slot = (slot + 1) & (self.capacity as u64 - 1);
        }
    }

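    /// Looks up `key` by linear probing, returning its block header id, or
    /// `None` as soon as an empty slot is reached.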
    pub fn get(&self, key: BlockVirtualId) -> Option<BlockHeaderId> {
        let mut slot = hash(key.0) & (self.capacity as u64 - 1);

        loop {
            let entry = unsafe { *self.entries.as_ptr().add(slot as usize) };
            if entry.key == key {
                return Some(entry.value);
            }

            if entry.key.0 == EMPTY {
                return None;
            }

            slot = (slot + 1) & (self.capacity as u64 - 1);
        }
    }
}

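/// Resets every slot of the hash map to the free state; one thread clears
/// one entry, so the kernel must be launched with at least `capacity`
/// threads.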
#[cfg_attr(target_os = "cuda", cuda_std::kernel)]
pub unsafe fn reset_hashmap(grid: GridHashMap) {
    let id = thread::index();
    if (id as u32) < grid.capacity {
        *grid.entries.as_mut_ptr().add(id as usize) = GridHashMapEntry::free();
    }
}