// sshash_lib/partitioned_mphf.rs
1//! Partitioned Minimal Perfect Hash Function
2//!
3//! Wraps multiple PHast MPHFs, one per partition, with hash-based partition
4//! selection matching the C++ `partitioned_phf` design. Keys are assigned to
5//! partitions via Lemire fast-range reduction on an independent rapidhash,
6//! and each partition's MPHF is built independently (in parallel via rayon).
7//!
8//! For datasets with fewer than `AVG_PARTITION_SIZE` keys (including all our
9//! test indices), only a single partition is created and the `get()` fast path
10//! skips the partition hash entirely.
11
12use crate::mphf_config::{Mphf, MphfHasher, build_mphf_from_vec, mphf_hasher, read_mphf};
13use rayon::prelude::*;
14use std::hash::Hash;
15use std::io::{self, Read, Write};
16
/// Average number of keys per partition, matching C++ PTHash `partitioned_phf`.
/// The build produces `ceil(num_keys / AVG_PARTITION_SIZE)` partitions, so any
/// input with at most this many keys collapses to a single partition.
const AVG_PARTITION_SIZE: usize = 3_000_000;

/// Seed used for partition selection, chosen to differ from PHast's internal
/// seeds (0, 1, 2, ...) to ensure statistical independence.
/// The same constant must be used at build and query time so a key hashes to
/// the same partition in both.
const PARTITION_HASH_SEED: u64 = 0xC6A4_A793_5BD1_E995;
23
/// A partitioned minimal perfect hash function.
///
/// Splits keys into partitions by hash range (Lemire fast-range reduction),
/// builds one PHast MPHF per partition, and routes queries transparently.
/// Global indices are `offsets[partition] + inner_mphf.get(key)`.
pub struct PartitionedMphf {
    /// One PHast MPHF per partition. Empty when built from zero keys.
    inners: Vec<Mphf>,
    /// Cumulative key counts: `offsets[i]` = total keys in partitions 0..i.
    /// Length = `num_partitions + 1` (a lone `[0]` for the empty build).
    offsets: Vec<usize>,
    /// Number of partitions. 0 for an empty build, 1 for the monolithic
    /// fast path, >1 only when built with `partitioned = true` and enough keys.
    num_partitions: u32,
    /// Total number of keys across all partitions; also the "not found"
    /// sentinel returned by `get()`.
    num_keys: usize,
    /// Hasher for partition selection (seeded with `PARTITION_HASH_SEED`
    /// at query time); only consulted when `num_partitions > 1`.
    hasher: MphfHasher,
}
42
43impl PartitionedMphf {
44    /// Build a partitioned MPHF from an owned Vec of keys.
45    ///
46    /// If `partitioned` is false (or there are fewer than `AVG_PARTITION_SIZE`
47    /// keys), a single partition is used — equivalent to a monolithic MPHF with
48    /// zero query overhead.
49    pub fn build_from_vec<K: Hash + Clone + Send + Sync>(keys: Vec<K>, partitioned: bool) -> Self {
50        let num_keys = keys.len();
51        if num_keys == 0 {
52            return Self {
53                inners: Vec::new(),
54                offsets: vec![0],
55                num_partitions: 0,
56                num_keys: 0,
57                hasher: mphf_hasher(),
58            };
59        }
60
61        let num_partitions = if partitioned {
62            num_keys.div_ceil(AVG_PARTITION_SIZE).max(1)
63        } else {
64            1
65        };
66
67        if num_partitions == 1 {
68            // Single partition: build directly, no partitioning overhead
69            let mphf = build_mphf_from_vec(keys);
70            return Self {
71                inners: vec![mphf],
72                offsets: vec![0, num_keys],
73                num_partitions: 1,
74                num_keys,
75                hasher: mphf_hasher(),
76            };
77        }
78
79        // Multi-partition: hash-and-partition
80        let hasher = mphf_hasher();
81        let np = num_partitions as u128;
82
83        // Assign keys to partitions
84        let mut partition_keys: Vec<Vec<K>> = (0..num_partitions).map(|_| Vec::new()).collect();
85        for key in keys {
86            let hash = hasher.hash_one_with_seed(&key, PARTITION_HASH_SEED);
87            let p = ((hash as u128 * np) >> 64) as usize;
88            partition_keys[p].push(key);
89        }
90
91        // Compute cumulative offsets
92        let mut offsets = Vec::with_capacity(num_partitions + 1);
93        offsets.push(0);
94        for pk in &partition_keys {
95            let prev = *offsets.last().unwrap();
96            offsets.push(prev + pk.len());
97        }
98
99        // Build inner MPHFs in parallel (each single-threaded internally)
100        let inners: Vec<Mphf> = partition_keys
101            .into_par_iter()
102            .map(|pk| {
103                if pk.is_empty() {
104                    // Empty partition — build a trivial MPHF from empty vec
105                    build_mphf_from_vec(pk)
106                } else {
107                    build_mphf_from_vec(pk)
108                }
109            })
110            .collect();
111
112        Self {
113            inners,
114            offsets,
115            num_partitions: num_partitions as u32,
116            num_keys,
117            hasher,
118        }
119    }
120
121    /// Build a partitioned MPHF from a slice of keys (clones into Vec).
122    pub fn build_from_slice<K: Hash + Clone + Send + Sync>(keys: &[K], partitioned: bool) -> Self {
123        Self::build_from_vec(keys.to_vec(), partitioned)
124    }
125
126    /// Look up a key and return its global index in [0, num_keys).
127    ///
128    /// For keys NOT in the build set, returns `num_keys` (out-of-range sentinel).
129    /// The COMBINE-lab ph fork returns `usize::MAX` for keys that exhaust all
130    /// levels without matching, which we map to `num_keys`.
131    #[inline]
132    pub fn get<K: Hash + ?Sized>(&self, key: &K) -> usize {
133        if self.num_partitions == 1 {
134            // Fast path: skip partition hash entirely.
135            let idx = self.inners[0].get(key);
136            if idx == usize::MAX { return self.num_keys; }
137            return idx;
138        }
139        let p = self.partition_for(key);
140        let idx = self.inners[p].get(key);
141        if idx == usize::MAX { return self.num_keys; }
142        self.offsets[p] + idx
143    }
144
145    /// Total number of keys.
146    pub fn num_keys(&self) -> usize {
147        self.num_keys
148    }
149
150    /// Number of partitions.
151    pub fn num_partitions(&self) -> u32 {
152        self.num_partitions
153    }
154
155    /// Compute which partition a key belongs to (Lemire fast-range reduction).
156    #[inline]
157    fn partition_for<K: Hash + ?Sized>(&self, key: &K) -> usize {
158        let hash = self.hasher.hash_one_with_seed(key, PARTITION_HASH_SEED);
159        ((hash as u128 * self.num_partitions as u128) >> 64) as usize
160    }
161
162    /// Serialize to a writer.
163    pub fn write_to<W: Write>(&self, writer: &mut W) -> io::Result<()> {
164        // Header
165        writer.write_all(&self.num_partitions.to_le_bytes())?;
166        writer.write_all(&(self.num_keys as u64).to_le_bytes())?;
167
168        // Offsets (num_partitions + 1 entries)
169        for &off in &self.offsets {
170            writer.write_all(&(off as u64).to_le_bytes())?;
171        }
172
173        // Inner MPHFs
174        for mphf in &self.inners {
175            mphf.write(writer)?;
176        }
177
178        Ok(())
179    }
180
181    /// Deserialize from a reader.
182    pub fn read_from(reader: &mut dyn Read) -> io::Result<Self> {
183        let mut buf4 = [0u8; 4];
184        let mut buf8 = [0u8; 8];
185
186        reader.read_exact(&mut buf4)?;
187        let num_partitions = u32::from_le_bytes(buf4);
188
189        reader.read_exact(&mut buf8)?;
190        let num_keys = u64::from_le_bytes(buf8) as usize;
191
192        let num_offsets = num_partitions as usize + 1;
193        let mut offsets = Vec::with_capacity(num_offsets);
194        for _ in 0..num_offsets {
195            reader.read_exact(&mut buf8)?;
196            offsets.push(u64::from_le_bytes(buf8) as usize);
197        }
198
199        let mut inners = Vec::with_capacity(num_partitions as usize);
200        for _ in 0..num_partitions {
201            inners.push(read_mphf(reader)?);
202        }
203
204        Ok(Self {
205            inners,
206            offsets,
207            num_partitions,
208            num_keys,
209            hasher: mphf_hasher(),
210        })
211    }
212
213    /// Estimate serialized byte size (for container offset table pre-allocation).
214    pub fn write_bytes(&self) -> usize {
215        let header = 4 + 8; // num_partitions + num_keys
216        let offsets = (self.offsets.len()) * 8;
217        let mphfs: usize = self.inners.iter().map(|m| m.write_bytes()).sum();
218        header + offsets + mphfs
219    }
220}
221
/// Extension trait to hash with a specific seed.
///
/// PHast's `BuildRapidHash` implements `BuildSeededHasher`, so we use
/// `build_hasher(seed)` to get a seeded hasher for partition selection.
/// Kept private: it only exists to support `partition_for` and the build.
trait HashOneWithSeed {
    /// One-shot hash of `key` using a hasher constructed with `seed`.
    fn hash_one_with_seed<K: Hash + ?Sized>(&self, key: &K, seed: u64) -> u64;
}
229
230impl HashOneWithSeed for MphfHasher {
231    #[inline]
232    fn hash_one_with_seed<K: Hash + ?Sized>(&self, key: &K, seed: u64) -> u64 {
233        use ph::BuildSeededHasher;
234        use std::hash::Hasher;
235        let mut hasher = self.build_hasher(seed);
236        key.hash(&mut hasher);
237        hasher.finish()
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `pmphf` is minimal perfect over `keys`: every key maps to
    /// a unique index in `[0, keys.len())`.
    fn assert_minimal_perfect(pmphf: &PartitionedMphf, keys: &[u64]) {
        let n = keys.len();
        let mut indices: Vec<usize> = keys.iter().map(|k| pmphf.get(k)).collect();
        // sort_unstable: primitives, no stability needed, no allocation.
        indices.sort_unstable();
        indices.dedup();
        assert_eq!(indices.len(), n, "all indices must be unique");
        assert!(indices.iter().all(|&i| i < n), "all indices must be in [0, {n})");
    }

    #[test]
    fn test_partition_count_math() {
        // < AVG_PARTITION_SIZE → 1 partition
        assert_eq!(1_000_000usize.div_ceil(AVG_PARTITION_SIZE).max(1), 1);
        // Exactly AVG_PARTITION_SIZE → 1 partition
        assert_eq!(AVG_PARTITION_SIZE.div_ceil(AVG_PARTITION_SIZE).max(1), 1);
        // Just over → 2 partitions
        assert_eq!((AVG_PARTITION_SIZE + 1).div_ceil(AVG_PARTITION_SIZE).max(1), 2);
        // 10M → 4 partitions
        assert_eq!(10_000_000usize.div_ceil(AVG_PARTITION_SIZE).max(1), 4);
    }

    #[test]
    fn test_single_partition_roundtrip() {
        let keys: Vec<u64> = (0..1000).collect();
        let pmphf = PartitionedMphf::build_from_vec(keys.clone(), true);

        // 1000 keys < AVG_PARTITION_SIZE, so one partition even when
        // partitioning is requested.
        assert_eq!(pmphf.num_partitions(), 1);
        assert_eq!(pmphf.num_keys(), 1000);
        assert_minimal_perfect(&pmphf, &keys);
    }

    #[test]
    fn test_monolithic_flag() {
        let keys: Vec<u64> = (0..100).collect();
        let pmphf = PartitionedMphf::build_from_vec(keys.clone(), false);

        assert_eq!(pmphf.num_partitions(), 1);
        assert_eq!(pmphf.num_keys(), 100);
        assert_minimal_perfect(&pmphf, &keys);
    }

    #[test]
    fn test_serialization_roundtrip() {
        let keys: Vec<u64> = (0..500).collect();
        let pmphf = PartitionedMphf::build_from_vec(keys.clone(), true);

        let mut buf = Vec::new();
        pmphf.write_to(&mut buf).unwrap();

        let pmphf2 = PartitionedMphf::read_from(&mut buf.as_slice()).unwrap();

        assert_eq!(pmphf.num_partitions(), pmphf2.num_partitions());
        assert_eq!(pmphf.num_keys(), pmphf2.num_keys());

        // The deserialized MPHF must agree with the original on every key.
        for key in &keys {
            assert_eq!(pmphf.get(key), pmphf2.get(key));
        }
    }

    #[test]
    fn test_empty() {
        let keys: Vec<u64> = Vec::new();
        let pmphf = PartitionedMphf::build_from_vec(keys, true);
        assert_eq!(pmphf.num_partitions(), 0);
        assert_eq!(pmphf.num_keys(), 0);
    }

    #[test]
    fn test_write_bytes_sanity() {
        let keys: Vec<u64> = (0..100).collect();
        let pmphf = PartitionedMphf::build_from_vec(keys, true);

        let mut buf = Vec::new();
        pmphf.write_to(&mut buf).unwrap();

        // write_bytes() is an estimate (PHast's write_bytes() is approximate),
        // so only check both sizes are positive rather than equal.
        let actual = buf.len();
        let estimate = pmphf.write_bytes();
        assert!(estimate > 0, "estimate ({estimate}) should be positive");
        assert!(actual > 0, "serialized output ({actual} bytes) should be non-empty");
    }
}