zigzag_alloc/collections/hash_map.rs

//! Swiss-table inspired open-addressing hash map.
//!
//! [`ExHashMap`] is a high-performance hash map that uses a separate
//! *control byte* array to enable parallel SIMD key probing — the same
//! technique used by Google's Abseil `flat_hash_map` and Rust's
//! `hashbrown`.
//!
//! ## Algorithm Overview
//!
//! Each slot has a one-byte *control byte* stored in a contiguous `ctrl`
//! array.  The control byte is either:
//!
//! * [`CTRL_EMPTY`] (`0x80`) — the slot is vacant.
//! * A 7-bit hash tag `h2(hash)` — the slot is occupied by a key whose
//!   top 7 hash bits (bits `[63:57]`) match `h2`.
//!
//! Lookups load a 16-byte [`Group`] from `ctrl` and use SIMD to find all
//! slots whose tag matches the query in a single instruction, drastically
//! reducing branch mispredictions and cache misses compared to traditional
//! chaining or linear probing.
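//!
//! For example, a lookup whose tag is `0x47` compares one group of 16 control
//! bytes against `0x47` in a single SIMD comparison, then walks the resulting
//! bitmask of candidate slots, confirming each with a full key comparison.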
//!
//! ## Load Factor
//!
//! The table grows when `len * 8 >= cap * 7` (87.5 % load factor).
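//! For example, with `cap = 16` growth is triggered on the insert that sees
//! `len == 14`, since `14 * 8 = 112 >= 112 = 16 * 7`.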
//!
//! ## Context Trait
//!
//! Hashing and equality are provided by the [`HashContext<K>`] trait rather
//! than being hard-coded via `Hash` / `Eq`.  This allows callers to choose
//! domain-specific hash functions without wrapper types.
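//!
//! ## Example
//!
//! A minimal usage sketch (not run as a doc-test). `IdentityCtx` and
//! `some_allocator()` are hypothetical stand-ins, and the `HashContext`
//! methods are assumed to be the `hash` / `eq` pair this module calls
//! internally.
//!
//! ```ignore
//! struct IdentityCtx;
//!
//! impl HashContext<u64> for IdentityCtx {
//!     fn hash(&self, key: &u64) -> u64 { *key }         // assumed signature
//!     fn eq(&self, a: &u64, b: &u64) -> bool { a == b }  // assumed signature
//! }
//!
//! let alloc = some_allocator();                // any `&dyn Allocator`
//! let mut map = ExHashMap::new(&alloc, IdentityCtx);
//! map.insert(1u64, "one");
//! assert_eq!(map.get(&1), Some(&"one"));
//! assert_eq!(map.remove(&1), Some("one"));
//! ```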

use core::{alloc::Layout, marker::PhantomData, mem::MaybeUninit, ptr::NonNull};

use crate::alloc::allocator::Allocator;
use crate::simd::{self, Group, GROUP_WIDTH, CTRL_EMPTY};
use super::HashContext;

/// Returns the hash bits used to select the initial probe position; the
/// caller masks the value with `cap - 1`.
#[inline]
fn h1(hash: u64) -> usize { hash as usize }

/// Returns the 7-bit control tag stored in the `ctrl` array for an occupied slot.
///
/// Uses bits `[63:57]` of the hash so that `h1` and `h2` cover different
/// parts of the hash value, reducing correlation.
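///
/// For example, with the illustrative hash `0x8F00_0000_0000_0005`, `h2` keeps
/// bits `[63:57]`: `(0x8F00_0000_0000_0005 >> 57) & 0x7F == 0x47`, while `h1`
/// passes the full value through for the caller to mask with `cap - 1`.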
#[inline]
fn h2(hash: u64) -> u8 { ((hash >> 57) as u8) & 0x7F }

/// A single key-value slot in the hash map.
///
/// Slots are `MaybeUninit` because the map controls initialisation via the
/// parallel `ctrl` array; a slot is initialised if and only if its
/// corresponding control byte is not [`CTRL_EMPTY`].
struct Slot<K, V> {
    key: MaybeUninit<K>,
    val: MaybeUninit<V>,
}

/// A high-performance open-addressing hash map with SIMD probing.
///
/// # Type Parameters
///
/// * `'a` — Lifetime of the allocator reference.
/// * `K`  — Key type.
/// * `V`  — Value type.
/// * `C`  — [`HashContext<K>`] providing hash and equality functions.
///
/// # Memory Layout
///
/// Two separate heap allocations:
/// 1. **`ctrl`** — `cap + GROUP_WIDTH` control bytes.  The extra `GROUP_WIDTH`
///    bytes at the end are mirror copies of the first `GROUP_WIDTH` control
///    bytes, enabling SIMD group loads at the table boundary without
///    out-of-bounds access.
/// 2. **`data`** — `cap` [`Slot<K, V>`] entries.
///
/// Both allocations are freed on drop.
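///
/// For example, with `cap = 16` and the 16-byte group width described in the
/// module docs, `ctrl` occupies 32 bytes and `ctrl[16..32]` always mirrors
/// `ctrl[0..16]`, so a group load starting at slot 15 reads `ctrl[15..31]`
/// without going out of bounds.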
pub struct ExHashMap<'a, K, V, C: HashContext<K>> {
    /// Pointer to the control-byte array (`cap + GROUP_WIDTH` bytes).
    ctrl:    NonNull<u8>,
    /// Pointer to the slot array (`cap` entries).
    data:    NonNull<Slot<K, V>>,
    /// Current table capacity (always a power of two when non-zero).
    cap:     usize,
    /// Number of occupied slots.
    len:     usize,
    /// Allocator reference used for all internal allocations.
    alloc:   &'a dyn Allocator,
    /// Hashing and equality context.
    ctx:     C,
    _marker: PhantomData<(K, V)>,
}

impl<'a, K, V, C: HashContext<K>> ExHashMap<'a, K, V, C> {
    /// Creates a new, empty map that will allocate through `alloc`.
    ///
    /// No memory is allocated until the first insertion.
    pub fn new(alloc: &'a dyn Allocator, ctx: C) -> Self {
        Self {
            ctrl:    NonNull::dangling(),
            data:    NonNull::dangling(),
            cap:     0,
            len:     0,
            alloc,
            ctx,
            _marker: PhantomData,
        }
    }

    /// Returns the number of key-value pairs in the map.
    #[inline] pub fn len(&self)      -> usize { self.len }
    /// Returns the current table capacity.
    #[inline] pub fn capacity(&self) -> usize { self.cap }
    /// Returns `true` if the map contains no entries.
    #[inline] pub fn is_empty(&self) -> bool  { self.len == 0 }

    /// Returns a reference to the value associated with `key`, or `None`.
    pub fn get(&self, key: &K) -> Option<&V> {
        let (idx, _) = self.find(key)?;
        // SAFETY: `find` returns an index whose control byte is not
        // `CTRL_EMPTY`, so the slot at `idx` is fully initialised.
        Some(unsafe { (*self.data.as_ptr().add(idx)).val.assume_init_ref() })
    }

    /// Returns a mutable reference to the value associated with `key`, or `None`.
    pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
        let (idx, _) = self.find(key)?;
        // SAFETY: Same as `get`; unique access via `&mut self`.
        Some(unsafe { (*self.data.as_ptr().add(idx)).val.assume_init_mut() })
    }

    /// Returns `true` if the map contains an entry for `key`.
    pub fn contains_key(&self, key: &K) -> bool { self.find(key).is_some() }

    /// Inserts `key`→`val`, returning the previous value for `key` if it existed.
    ///
    /// # Panics
    ///
    /// Panics if the allocator fails during table growth.
    pub fn insert(&mut self, key: K, val: V) -> Option<V> {
        self.try_insert(key, val).unwrap_or_else(|_| panic!("HashMap: OOM"))
    }

    /// Attempts to insert `key`→`val`.
    ///
    /// Returns:
    /// * `Ok(None)` — new entry inserted.
    /// * `Ok(Some(old))` — key already existed; old value returned.
    /// * `Err((key, val))` — allocation failed; inputs returned to caller.
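    ///
    /// A hedged sketch of non-panicking use (not run as a doc-test):
    ///
    /// ```ignore
    /// match map.try_insert(key, value) {
    ///     Ok(prev)          => { /* inserted; `prev` is the replaced value, if any */ }
    ///     Err((key, value)) => { /* allocation failed; both inputs handed back */ }
    /// }
    /// ```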
    pub fn try_insert(&mut self, key: K, val: V) -> Result<Option<V>, (K, V)> {
        // Grow when there is no table yet or the 87.5 % load threshold is reached.
        if self.cap == 0 || self.len * 8 >= self.cap * 7 {
            if !self.try_grow() { return Err((key, val)); }
        }
        let hash = self.ctx.hash(&key);

        // If the key already exists, overwrite its value.
        if let Some((idx, _)) = self.find(&key) {
            // SAFETY: Slot at `idx` is fully initialised (found by `find`).
            let old = unsafe { (*self.data.as_ptr().add(idx)).val.assume_init_read() };
            unsafe { (*self.data.as_ptr().add(idx)).val = MaybeUninit::new(val) };
            drop(key); // keep the stored key; discard the caller's duplicate
            return Ok(Some(old));
        }

        let slot = self.find_empty_slot(hash);
        unsafe {
            // SAFETY: `slot` is an empty slot within the allocated `data` array.
            (*self.data.as_ptr().add(slot)).key = MaybeUninit::new(key);
            (*self.data.as_ptr().add(slot)).val = MaybeUninit::new(val);
            // SAFETY: `slot < cap`; `set_ctrl` maintains the mirror invariant.
            self.set_ctrl(slot, h2(hash));
        }
        self.len += 1;
        Ok(None)
    }

    /// Removes the entry for `key` and returns its value, or `None`.
    ///
    /// Uses backward-shift deletion to preserve the probing invariant without
    /// a tombstone mechanism.
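    ///
    /// For example, if keys `A` and `B` both hash to slot 3 and occupy slots 3
    /// and 4, removing `A` shifts `B` back into slot 3 and empties slot 4, so a
    /// later lookup of `B` still begins at a non-empty ideal slot.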
    pub fn remove(&mut self, key: &K) -> Option<V> {
        let (idx, _) = self.find(key)?;

        // SAFETY: `find` guarantees the slot at `idx` is fully initialised.
        let val = unsafe {
            let s = self.data.as_ptr().add(idx);
            let v = (*s).val.assume_init_read();
            (*s).key.assume_init_drop();
            v
        };
        self.len -= 1;

        // Backward-shift deletion: slide subsequent elements back until we hit
        // an empty slot or an element that is already at its ideal position.
        let mask = self.cap - 1;
        let mut cur = idx;
        loop {
            let nxt      = (cur + 1) & mask;
            let nxt_ctrl = unsafe { *self.ctrl.as_ptr().add(nxt) };
            if nxt_ctrl == CTRL_EMPTY { break; }

            let nxt_ideal = {
                // SAFETY: The slot at `nxt` is occupied (ctrl != CTRL_EMPTY).
                let k = unsafe { (*self.data.as_ptr().add(nxt)).key.assume_init_ref() };
                h1(self.ctx.hash(k)) & mask
            };
            if is_between(cur, nxt_ideal, nxt) { break; }

            unsafe {
                // SAFETY: Both slots are within `0..cap` and `nxt` is occupied.
                let src  = self.data.as_ptr().add(nxt);
                let dst  = self.data.as_ptr().add(cur);
                let k    = (*src).key.assume_init_read();
                let v    = (*src).val.assume_init_read();
                (*dst).key = MaybeUninit::new(k);
                (*dst).val = MaybeUninit::new(v);
                self.set_ctrl(cur, nxt_ctrl);
                self.set_ctrl(nxt, CTRL_EMPTY);
            }
            cur = nxt;
        }
        // SAFETY: `cur` is within `0..cap`.
        unsafe { self.set_ctrl(cur, CTRL_EMPTY) };
        Some(val)
    }

    /// Iterates over all key-value pairs, calling `f` for each.
    ///
    /// The iteration order is unspecified.
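    ///
    /// A hedged usage sketch (not run as a doc-test):
    ///
    /// ```ignore
    /// let mut n = 0;
    /// map.for_each(|_key, _val| n += 1);
    /// assert_eq!(n, map.len());
    /// ```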
    pub fn for_each<F: FnMut(&K, &V)>(&self, mut f: F) {
        for i in 0..self.cap {
            // SAFETY: Control byte `!= CTRL_EMPTY` means the slot is occupied
            // and fully initialised.
            if unsafe { *self.ctrl.as_ptr().add(i) } != CTRL_EMPTY {
                let s = unsafe { &*self.data.as_ptr().add(i) };
                f(unsafe { s.key.assume_init_ref() }, unsafe { s.val.assume_init_ref() });
            }
        }
    }

    /// Iterates over all key-value pairs with mutable access to values.
    pub fn for_each_mut<F: FnMut(&K, &mut V)>(&mut self, mut f: F) {
        for i in 0..self.cap {
            if unsafe { *self.ctrl.as_ptr().add(i) } != CTRL_EMPTY {
                let s = unsafe { &mut *self.data.as_ptr().add(i) };
                f(unsafe { s.key.assume_init_ref() }, unsafe { s.val.assume_init_mut() });
            }
        }
    }

    /// Searches for `key` using SIMD group probing.
    ///
    /// Returns `Some((slot_index, hash))` on a hit, `None` on a miss.
    ///
    /// # Algorithm
    ///
    /// Starting at `h1(hash) & mask`, loads 16 control bytes at a time and
    /// uses SIMD to find slots whose tag matches `h2(hash)`.  Stops once a
    /// probed group contains an empty slot: the key cannot appear beyond that
    /// point, because `insert` only probes past full groups and the
    /// backward-shift deletion in `remove` restores this invariant.
    fn find(&self, key: &K) -> Option<(usize, u64)> {
        if self.cap == 0 { return None; }
        let hash = self.ctx.hash(key);
        let tag  = h2(hash);
        let mask = self.cap - 1;
        let mut pos = h1(hash) & mask;
        loop {
            // SAFETY: `ctrl` is allocated with `cap + GROUP_WIDTH` bytes;
            // `pos < cap`, so loading a 16-byte group is always within bounds.
            let group = unsafe { Group::load(self.ctrl.as_ptr().add(pos)) };

            for bit in unsafe { group.match_byte(tag) } {
                let idx = (pos + bit) & mask;
                // SAFETY: `idx < cap`; control byte is non-empty so slot is initialised.
                if self.ctx.eq(
                    unsafe { (*self.data.as_ptr().add(idx)).key.assume_init_ref() },
                    key,
                ) {
                    return Some((idx, hash));
                }
            }

            if unsafe { group.match_empty().any() } { return None; }
            pos = (pos + GROUP_WIDTH) & mask;
        }
    }

    /// Finds the first empty slot for a key with the given `hash`.
    ///
    /// Assumes the table is not full (guaranteed by the load-factor check in
    /// `try_insert`).
    fn find_empty_slot(&self, hash: u64) -> usize {
        let mask = self.cap - 1;
        let mut pos = h1(hash) & mask;
        loop {
            // SAFETY: Same group-load safety as `find`.
            let group = unsafe { Group::load(self.ctrl.as_ptr().add(pos)) };
            if let Some(bit) = unsafe { group.match_empty().lowest() } {
                return (pos + bit) & mask;
            }
            pos = (pos + GROUP_WIDTH) & mask;
        }
    }

    /// Sets the control byte at `idx` to `val` and updates the mirror copy.
    ///
    /// The first `GROUP_WIDTH` control bytes are mirrored at offsets
    /// `[cap, cap + GROUP_WIDTH)` so that SIMD group loads at positions near
    /// the end of the table correctly wrap around to the beginning.
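    ///
    /// For example, with `cap = 32` and the 16-byte group width from the
    /// module docs, writing slot 5 also writes mirror byte `32 + 5 = 37`, so a
    /// group load starting at slot 29 (reading bytes `29..45`) observes the
    /// update.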
    ///
    /// # Safety
    ///
    /// * `idx` must be in `0..cap`.
    /// * `ctrl` must be allocated for at least `cap + GROUP_WIDTH` bytes.
    #[inline]
    unsafe fn set_ctrl(&mut self, idx: usize, val: u8) {
        unsafe {
            // SAFETY: `idx < cap < cap + GROUP_WIDTH` — always within the allocation.
            *self.ctrl.as_ptr().add(idx) = val;
            if idx < GROUP_WIDTH {
                // SAFETY: `cap + idx < cap + GROUP_WIDTH` — within the mirror region.
                *self.ctrl.as_ptr().add(self.cap + idx) = val;
            }
        }
    }

    /// Doubles the table capacity, rehashing all existing entries.
    ///
    /// Returns `true` on success, `false` if any allocation failed (in which
    /// case the map state is unchanged).
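    ///
    /// Capacities start at `GROUP_WIDTH` and double from there, so `cap` stays
    /// a power of two and `cap - 1` remains a valid probe mask.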
    #[cold]
    fn try_grow(&mut self) -> bool {
        let new_cap = if self.cap == 0 { GROUP_WIDTH } else { self.cap * 2 };

        let ctrl_layout = match Layout::array::<u8>(new_cap + GROUP_WIDTH) {
            Ok(l) => l, Err(_) => return false,
        };
        let data_layout = match Layout::array::<Slot<K, V>>(new_cap) {
            Ok(l) => l, Err(_) => return false,
        };

        // SAFETY: Both layouts have non-zero sizes (new_cap >= GROUP_WIDTH > 0).
        let new_ctrl = match unsafe { self.alloc.alloc(ctrl_layout) } {
            Some(p) => p, None => return false,
        };
        let new_data = match unsafe { self.alloc.alloc(data_layout) } {
            Some(p) => p.cast::<Slot<K, V>>(),
            None => {
                // SAFETY: `new_ctrl` was just allocated from `self.alloc` with
                // `ctrl_layout`; releasing it before returning is correct.
                unsafe { self.alloc.dealloc(new_ctrl, ctrl_layout) };
                return false;
            }
        };

        // SAFETY: `new_ctrl` is valid for `new_cap + GROUP_WIDTH` bytes.
        unsafe { simd::fill_bytes(new_ctrl.as_ptr(), CTRL_EMPTY, new_cap + GROUP_WIDTH) };

        let old_ctrl = self.ctrl;
        let old_data = self.data;
        let old_cap  = self.cap;

        self.ctrl = new_ctrl;
        self.data = new_data;
        self.cap  = new_cap;
        self.len  = 0;

        // Rehash all occupied entries from the old table.
        for i in 0..old_cap {
            // SAFETY: `old_ctrl` is valid for `old_cap + GROUP_WIDTH` bytes.
            let c = unsafe { *old_ctrl.as_ptr().add(i) };
            if c != CTRL_EMPTY {
                // SAFETY: Control byte is non-empty, so the slot is initialised.
                let k    = unsafe { (*old_data.as_ptr().add(i)).key.assume_init_read() };
                let v    = unsafe { (*old_data.as_ptr().add(i)).val.assume_init_read() };
                let hash = self.ctx.hash(&k);
                let slot = self.find_empty_slot(hash);
                unsafe {
                    (*self.data.as_ptr().add(slot)).key = MaybeUninit::new(k);
                    (*self.data.as_ptr().add(slot)).val = MaybeUninit::new(v);
                    self.set_ctrl(slot, h2(hash));
                }
                self.len += 1;
            }
        }

        if old_cap > 0 {
            // SAFETY: `old_ctrl` and `old_data` were allocated from `self.alloc`
            // with the corresponding layouts; releasing them now is correct.
            unsafe {
                self.alloc.dealloc(old_ctrl, Layout::array::<u8>(old_cap + GROUP_WIDTH).unwrap());
                self.alloc.dealloc(old_data.cast(), Layout::array::<Slot<K, V>>(old_cap).unwrap());
            }
        }
        true
    }
}

impl<K, V, C: HashContext<K>> Drop for ExHashMap<'_, K, V, C> {
    /// Drops all live key-value pairs and releases both backing allocations.
    fn drop(&mut self) {
        if self.cap == 0 { return; }
        for i in 0..self.cap {
            // SAFETY: Control byte `!= CTRL_EMPTY` means the slot is occupied.
            if unsafe { *self.ctrl.as_ptr().add(i) } != CTRL_EMPTY {
                unsafe {
                    (*self.data.as_ptr().add(i)).key.assume_init_drop();
                    (*self.data.as_ptr().add(i)).val.assume_init_drop();
                }
            }
        }
        // SAFETY: Both pointers were obtained from `self.alloc` with these
        // exact layouts during `try_grow`.
        unsafe {
            self.alloc.dealloc(self.ctrl, Layout::array::<u8>(self.cap + GROUP_WIDTH).unwrap());
            self.alloc.dealloc(self.data.cast(), Layout::array::<Slot<K, V>>(self.cap).unwrap());
        }
    }
}

/// Returns `true` if `ideal` lies cyclically in the range `(cur, nxt]` of a
/// power-of-two table (positions compared modulo `cap`).
///
/// Used by the backward-shift deletion algorithm: when the ideal position of
/// the element at `nxt` falls in `(cur, nxt]`, moving it back to `cur` would
/// place it before the start of its probe sequence, so the shift must stop.
#[inline]
fn is_between(cur: usize, ideal: usize, nxt: usize) -> bool {
    if cur < nxt {
        cur < ideal && ideal <= nxt
    } else {
        // The range wraps around the end of the table.
        ideal > cur || ideal <= nxt
    }
}