1use std::{collections::HashMap, num::NonZeroU32};
4
5use crate::hash::Fingerprint;
6
/// Upper bound on the number of shards in a `FingerprintMap`.
///
/// `FingerprintMap::new` clamps requested shard counts to this value and
/// asserts that the result is a power of two, so this must stay a power of two.
const MAX_SHARD_SIZE: u32 = 512;

/// Returns the process-wide default shard count for `FingerprintMap`.
///
/// The value is derived from [`std::thread::available_parallelism`] (falling
/// back to one thread if it cannot be queried). It is computed on the first
/// call and cached in a `OnceLock` for all later calls.
fn default_shard_size() -> NonZeroU32 {
    static ITEM_SHARD_SIZE: std::sync::OnceLock<NonZeroU32> = std::sync::OnceLock::new();

    *ITEM_SHARD_SIZE.get_or_init(|| {
        let thread_cnt = std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(1);
        NonZeroU32::new(shard_size_for_parallelism(thread_cnt))
            .expect("shard size is at least min(2, MAX_SHARD_SIZE), and MAX_SHARD_SIZE is non-zero")
    })
}

/// Maps a thread count to a shard count: twice the next power of two of
/// `thread_cnt`, capped at `MAX_SHARD_SIZE`.
///
/// Uses checked arithmetic so that an absurdly large `thread_cnt` saturates at
/// the cap instead of overflowing (`next_power_of_two`, `* 2`) or being
/// truncated by a `usize -> u32` cast, either of which could yield 0.
fn shard_size_for_parallelism(thread_cnt: usize) -> u32 {
    let cap = MAX_SHARD_SIZE as usize;
    let size = thread_cnt
        .checked_next_power_of_two()
        .and_then(|p| p.checked_mul(2))
        .map_or(cap, |s| s.min(cap));
    // `size <= cap`, and `cap` was widened from a `u32`, so this cast is lossless.
    size as u32
}
33
34type FMapBase<V> = parking_lot::RwLock<HashMap<Fingerprint, V>>;
35
36pub struct FingerprintMap<V> {
46    mask: u32,
47    shards: Vec<parking_lot::RwLock<HashMap<Fingerprint, V>>>,
48}
49
50impl<V> Default for FingerprintMap<V> {
51    fn default() -> Self {
52        Self::new(default_shard_size())
53    }
54}
55
56impl<V> FingerprintMap<V> {
57    pub fn new(shard_size: NonZeroU32) -> Self {
59        let shard_size = shard_size.get().next_power_of_two();
60        let shard_size = shard_size.min(MAX_SHARD_SIZE);
61
62        assert!(
63            shard_size.is_power_of_two(),
64            "shard size must be a power of two"
65        );
66        assert!(shard_size > 0, "shard size must be greater than zero");
67        Self {
68            mask: shard_size - 1,
69            shards: (0..shard_size)
70                .map(|_| parking_lot::RwLock::new(HashMap::new()))
71                .collect(),
72        }
73    }
74
75    pub fn into_items(self) -> impl Iterator<Item = (Fingerprint, V)> {
77        self.shards
78            .into_iter()
79            .flat_map(|shard| shard.into_inner().into_iter())
80    }
81
82    pub fn shard(&self, fg: Fingerprint) -> &FMapBase<V> {
84        let shards = &self.shards;
85        let route_idx = (fg.lower32() & self.mask) as usize;
86
87        debug_assert!(route_idx < shards.len());
89        unsafe { shards.get_unchecked(route_idx) }
92    }
93
94    pub fn as_mut_slice(&mut self) -> &mut [FMapBase<V>] {
96        &mut self.shards
97    }
98
99    pub fn contains_key(&self, fg: &Fingerprint) -> bool {
101        self.shard(*fg).read().contains_key(fg)
102    }
103}
104
#[cfg(test)]
mod tests {
    /// The cached default shard count must be a non-zero power of two.
    #[test]
    fn test_default_shard_size() {
        let shards = super::default_shard_size();
        let size = shards.get();

        eprintln!("size = {size}");

        assert!(size > 0);
        assert!(size.is_power_of_two());
    }
}