// swiss_table/lib.rs

1use std::collections::hash_map::RandomState;
2use std::hash::{BuildHasher, Hash, Hasher};
3use std::mem;
4
/// Slots per group. Group scans return one bit per slot in a `u8` mask,
/// so this must not exceed 8.
const GROUP_SIZE: usize = 0x8;
/// Fraction of slots (live entries / capacity) at which `insert` doubles the table.
const MAX_LOAD_FACTOR: f64 = 7.0 / 8.0;

// Control-byte encodings. Both sentinels have the top bit set, which a 7-bit h2
// fingerprint never has, so one bit test tells "free" slots from occupied ones.
/// Free slot; a lookup that reaches a group containing one stops probing there.
const EMPTY: u8 = 0xFF;
/// Tombstone left by a removal; lookups keep probing past it.
const DELETED: u8 = 0x80;
11
/// Returns the 7-bit fingerprint ("h2") that is stored in a slot's control byte.
///
/// Only the low 7 bits of `hash` survive, so the result is always below `0x80`.
/// It therefore never equals the `EMPTY`/`DELETED` sentinels, both of which have
/// the high bit set.
fn h2(hash: u64) -> u8 {
    let low_byte = hash as u8;
    low_byte & 0x7F
}
17
18/// A group is 8 control bytes + 8 slots.
19/// The control bytes summarize the state
20/// of the group so we can do fast parallel matching.
21struct Group<K, V> {
22    // Each byte is either EMPTY, DELETED, or an h2 fingerprint
23    ctrl: [u8; GROUP_SIZE],
24    slots: [Option<(K, V)>; GROUP_SIZE],
25}
26
27impl<K, V> Group<K, V> {
28    fn new() -> Self {
29        Group {
30            ctrl: [EMPTY; GROUP_SIZE],
31            slots: std::array::from_fn(|_| None),
32        }
33    }
34
35    /// SIMD-style parallel match.
36    /// In a real implementation, this would be a single SIMD instruction
37    /// Just simulating it here: return a bitmask of which slots have ctrl == needle.
38    ///
39    /// e.g., if slots 1 and 5 match, returns 0b00100010
40    fn match_h2(&self, needle: u8) -> u8 {
41        let mut mask = 0u8;
42        for i in 0..GROUP_SIZE {
43            if self.ctrl[i] == needle {
44                mask |= 1 << i;
45            }
46        }
47        mask
48    }
49
50    /// Return bitmask of empty slots.
51    fn match_empty(&self) -> u8 {
52        let mut mask = 0u8;
53        for i in 0..GROUP_SIZE {
54            if self.ctrl[i] == EMPTY {
55                mask |= 1 << i;
56            }
57        }
58        mask
59    }
60
61    /// Return bitmask of empty OR deleted slots (available for insertion).
62    fn match_empty_or_deleted(&self) -> u8 {
63        let mut mask = 0u8;
64        for i in 0..GROUP_SIZE {
65            // Both EMPTY and DELETED have the high bit set.
66            // which will never happen for a valid h2 value.
67            if self.ctrl[i] & 0x80 != 0 {
68                mask |= 1 << i;
69            }
70        }
71        mask
72    }
73}
74
75pub struct SwissTable<K, V> {
76    groups: Vec<Group<K, V>>,
77    num_groups: usize,
78    len: usize,
79    hash_builder: RandomState,
80}
81
82impl<K, V> SwissTable<K, V>
83where
84    K: Eq + Hash,
85    V: Clone,
86{
87    pub fn new() -> Self {
88        let num_groups = 1;
89        SwissTable {
90            groups: vec![Group::new()],
91            num_groups,
92            len: 0,
93            hash_builder: RandomState::new(),
94        }
95    }
96
97    pub fn len(&self) -> usize {
98        self.len
99    }
100
101    pub fn is_empty(&self) -> bool {
102        self.len == 0
103    }
104
105    pub fn capacity(&self) -> usize {
106        self.num_groups * GROUP_SIZE
107    }
108
109    fn hash_key(&self, key: &K) -> u64 {
110        let mut hasher = self.hash_builder.build_hasher();
111        key.hash(&mut hasher);
112        hasher.finish()
113    }
114
115    fn find(&self, key: &K) -> Option<(usize, usize)> {
116        let hash = self.hash_key(key);
117        let h2_val = h2(hash);
118        let start_group = (hash >> 7) as usize % self.num_groups;
119
120        for i in 0..self.num_groups {
121            let gi = (start_group + i) % self.num_groups;
122            let group = &self.groups[gi];
123
124            let mut candidates = group.match_h2(h2_val);
125            while candidates != 0 {
126                let slot = candidates.trailing_zeros() as usize;
127                candidates &= candidates - 1;
128
129                if let Some((ref k, _)) = group.slots[slot] {
130                    if k == key {
131                        return Some((gi, slot));
132                    }
133                }
134            }
135
136            if group.match_empty() != 0 {
137                return None;
138            }
139        }
140        None
141    }
142
143    pub fn get(&self, key: &K) -> Option<&V> {
144        let (gi, si) = self.find(key)?;
145        self.groups[gi].slots[si].as_ref().map(|(_, v)| v)
146    }
147
148    pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
149        let (gi, si) = self.find(key)?;
150        self.groups[gi].slots[si].as_mut().map(|(_, v)| v)
151    }
152
153    pub fn contains_key(&self, key: &K) -> bool {
154        self.find(key).is_some()
155    }
156
157    pub fn insert(&mut self, key: K, value: V) -> Option<V> {
158        if self.len >= (self.capacity() as f64 * MAX_LOAD_FACTOR) as usize {
159            self.grow();
160        }
161
162        let hash = self.hash_key(&key);
163        let h2_val = h2(hash);
164        let start_group = (hash >> 7) as usize % self.num_groups;
165
166        for i in 0..self.num_groups {
167            let gi = (start_group + i) % self.num_groups;
168            let group = &mut self.groups[gi];
169
170            let mut candidates = group.match_h2(h2_val);
171            while candidates != 0 {
172                let slot = candidates.trailing_zeros() as usize;
173                candidates &= candidates - 1;
174
175                if let Some((ref k, ref mut v)) = group.slots[slot] {
176                    if *k == key {
177                        let old = mem::replace(v, value);
178                        return Some(old);
179                    }
180                }
181            }
182
183            let available = group.match_empty_or_deleted();
184            if available != 0 {
185                let slot = available.trailing_zeros() as usize;
186                group.ctrl[slot] = h2_val;
187                group.slots[slot] = Some((key, value));
188                self.len += 1;
189                return None;
190            }
191        }
192
193        unreachable!("table should have grown before reaching here");
194    }
195
196    pub fn remove(&mut self, key: &K) -> Option<V> {
197        let (gi, si) = self.find(key)?;
198        let group = &mut self.groups[gi];
199        group.ctrl[si] = DELETED;
200        let (_, v) = group.slots[si].take().unwrap();
201        self.len -= 1;
202        Some(v)
203    }
204
205    fn grow(&mut self) {
206        let new_num_groups = self.num_groups * 2;
207        let old_groups = mem::replace(
208            &mut self.groups,
209            (0..new_num_groups).map(|_| Group::new()).collect(),
210        );
211        let old_len = self.len;
212        self.num_groups = new_num_groups;
213        self.len = 0;
214
215        for group in old_groups {
216            for slot in group.slots {
217                if let Some((k, v)) = slot {
218                    self.insert(k, v);
219                }
220            }
221        }
222
223        debug_assert_eq!(self.len, old_len);
224    }
225}
226
227impl<K, V> Default for SwissTable<K, V>
228where
229    K: Eq + Hash,
230    V: Clone,
231{
232    fn default() -> Self {
233        Self::new()
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_table() {
        let t: SwissTable<String, i32> = SwissTable::new();
        assert!(t.is_empty());
        assert_eq!(t.len(), 0);
        let probe = "hello".to_string();
        assert_eq!(t.get(&probe), None);
    }

    #[test]
    fn insert_and_get() {
        let mut t = SwissTable::new();
        let entries = [("foo", 1), ("bar", 2), ("baz", 3)];
        for &(k, v) in &entries {
            t.insert(k, v);
        }

        for (k, v) in &entries {
            assert_eq!(t.get(k), Some(v));
        }
        assert_eq!(t.get(&"missing"), None);
        assert_eq!(t.len(), entries.len());
    }

    #[test]
    fn insert_returns_old_value() {
        let mut t = SwissTable::new();
        let first = t.insert("key", 10);
        let second = t.insert("key", 20);
        assert_eq!(first, None);
        assert_eq!(second, Some(10));
        assert_eq!(t.get(&"key"), Some(&20));
        assert_eq!(t.len(), 1);
    }

    #[test]
    fn remove() {
        let mut t = SwissTable::new();
        for &(k, v) in &[("a", 1), ("b", 2), ("c", 3)] {
            t.insert(k, v);
        }

        assert_eq!(t.remove(&"b"), Some(2));
        assert_eq!(t.get(&"b"), None);
        assert!(!t.contains_key(&"b"));
        assert_eq!(t.len(), 2);

        // Removing the same key again finds nothing.
        assert_eq!(t.remove(&"b"), None);
    }

    #[test]
    fn remove_doesnt_break_probe_chain() {
        let removed = [5, 10, 15];
        let mut t = SwissTable::new();
        for n in 0..20 {
            t.insert(n, n * 100);
        }
        for n in &removed {
            t.remove(n);
        }

        for n in 0..20 {
            let expected = if removed.contains(&n) { None } else { Some(n * 100) };
            assert_eq!(t.get(&n), expected.as_ref());
        }
    }

    #[test]
    fn contains_key() {
        let mut t = SwissTable::new();
        t.insert("present", 42);
        assert!(t.contains_key(&"present"));
        assert!(!t.contains_key(&"absent"));
    }

    #[test]
    fn get_mut() {
        let mut t = SwissTable::new();
        t.insert("counter", 0);
        if let Some(count) = t.get_mut(&"counter") {
            *count += 10;
        }
        assert_eq!(t.get(&"counter"), Some(&10));
    }

    #[test]
    fn grow_under_load() {
        let mut t = SwissTable::new();
        for n in 0..100 {
            t.insert(n, n.to_string());
        }
        assert_eq!(t.len(), 100);

        for n in 0..100 {
            let expected = n.to_string();
            assert_eq!(t.get(&n), Some(&expected));
        }
    }

    #[test]
    fn string_keys() {
        let mut t = SwissTable::new();
        for n in 0..50 {
            t.insert(format!("key_{}", n), n);
        }

        for n in 0..50 {
            let lookup = format!("key_{}", n);
            assert_eq!(t.get(&lookup), Some(&n));
        }
    }

    #[test]
    fn insert_after_remove() {
        let mut t = SwissTable::new();
        t.insert("reuse", 1);
        t.remove(&"reuse");
        t.insert("reuse", 2);
        assert_eq!(t.get(&"reuse"), Some(&2));
        assert_eq!(t.len(), 1);
    }

    #[test]
    fn large_scale_insert_remove() {
        let mut t = SwissTable::new();
        for n in 0..1000 {
            t.insert(n, n);
        }
        assert_eq!(t.len(), 1000);

        for n in (0..1000).filter(|n| n % 2 == 0) {
            t.remove(&n);
        }
        assert_eq!(t.len(), 500);

        for n in 0..1000 {
            let expected = if n % 2 == 0 { None } else { Some(&n) };
            assert_eq!(t.get(&n), expected);
        }
    }
}