1use std::collections::hash_map::RandomState;
2use std::hash::{BuildHasher, Hash, Hasher};
3use std::mem;
4
/// Number of slots (and control bytes) stored in every group.
const GROUP_SIZE: usize = 8;
/// Fraction of the capacity that may be occupied before `insert` grows the
/// table (seven slots out of every eight).
const MAX_LOAD_FACTOR: f64 = 7.0 / 8.0;

/// Control byte of a slot that has not been used since the group was created.
/// Lookups stop probing at the first group containing such a byte.
const EMPTY: u8 = u8::MAX;
/// Control byte written by `remove` (a tombstone). Like `EMPTY` it has the high
/// bit set, which occupied slots never do because `h2` yields only seven bits.
const DELETED: u8 = 1 << 7;
11
/// Extracts the low seven bits of `hash`; this is the tag stored in a slot's
/// control byte. The result is always below `0x80`, so its high bit is clear.
fn h2(hash: u64) -> u8 {
    const TAG_MASK: u8 = 0x7F;
    // Truncating to `u8` keeps the low eight bits; the mask then drops bit 7.
    (hash as u8) & TAG_MASK
}
17
18struct Group<K, V> {
22 ctrl: [u8; GROUP_SIZE],
24 slots: [Option<(K, V)>; GROUP_SIZE],
25}
26
27impl<K, V> Group<K, V> {
28 fn new() -> Self {
29 Group {
30 ctrl: [EMPTY; GROUP_SIZE],
31 slots: std::array::from_fn(|_| None),
32 }
33 }
34
35 fn match_h2(&self, needle: u8) -> u8 {
41 let mut mask = 0u8;
42 for i in 0..GROUP_SIZE {
43 if self.ctrl[i] == needle {
44 mask |= 1 << i;
45 }
46 }
47 mask
48 }
49
50 fn match_empty(&self) -> u8 {
52 let mut mask = 0u8;
53 for i in 0..GROUP_SIZE {
54 if self.ctrl[i] == EMPTY {
55 mask |= 1 << i;
56 }
57 }
58 mask
59 }
60
61 fn match_empty_or_deleted(&self) -> u8 {
63 let mut mask = 0u8;
64 for i in 0..GROUP_SIZE {
65 if self.ctrl[i] & 0x80 != 0 {
68 mask |= 1 << i;
69 }
70 }
71 mask
72 }
73}
74
/// A hash map using open addressing over fixed-size groups of slots, where
/// each slot is paired with a one-byte control tag that is checked before the
/// stored key is compared.
pub struct SwissTable<K, V> {
    /// Slot groups; probing walks this vector cyclically starting at the
    /// group selected by the key's hash.
    groups: Vec<Group<K, V>>,
    /// Number of groups; `new` and `grow` keep it equal to `groups.len()`.
    num_groups: usize,
    /// Number of key/value pairs currently stored.
    len: usize,
    /// Hasher factory used to hash every key.
    hash_builder: RandomState,
}
81
82impl<K, V> SwissTable<K, V>
83where
84 K: Eq + Hash,
85 V: Clone,
86{
87 pub fn new() -> Self {
88 let num_groups = 1;
89 SwissTable {
90 groups: vec![Group::new()],
91 num_groups,
92 len: 0,
93 hash_builder: RandomState::new(),
94 }
95 }
96
97 pub fn len(&self) -> usize {
98 self.len
99 }
100
101 pub fn is_empty(&self) -> bool {
102 self.len == 0
103 }
104
105 pub fn capacity(&self) -> usize {
106 self.num_groups * GROUP_SIZE
107 }
108
109 fn hash_key(&self, key: &K) -> u64 {
110 let mut hasher = self.hash_builder.build_hasher();
111 key.hash(&mut hasher);
112 hasher.finish()
113 }
114
115 fn find(&self, key: &K) -> Option<(usize, usize)> {
116 let hash = self.hash_key(key);
117 let h2_val = h2(hash);
118 let start_group = (hash >> 7) as usize % self.num_groups;
119
120 for i in 0..self.num_groups {
121 let gi = (start_group + i) % self.num_groups;
122 let group = &self.groups[gi];
123
124 let mut candidates = group.match_h2(h2_val);
125 while candidates != 0 {
126 let slot = candidates.trailing_zeros() as usize;
127 candidates &= candidates - 1;
128
129 if let Some((ref k, _)) = group.slots[slot] {
130 if k == key {
131 return Some((gi, slot));
132 }
133 }
134 }
135
136 if group.match_empty() != 0 {
137 return None;
138 }
139 }
140 None
141 }
142
143 pub fn get(&self, key: &K) -> Option<&V> {
144 let (gi, si) = self.find(key)?;
145 self.groups[gi].slots[si].as_ref().map(|(_, v)| v)
146 }
147
148 pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
149 let (gi, si) = self.find(key)?;
150 self.groups[gi].slots[si].as_mut().map(|(_, v)| v)
151 }
152
153 pub fn contains_key(&self, key: &K) -> bool {
154 self.find(key).is_some()
155 }
156
157 pub fn insert(&mut self, key: K, value: V) -> Option<V> {
158 if self.len >= (self.capacity() as f64 * MAX_LOAD_FACTOR) as usize {
159 self.grow();
160 }
161
162 let hash = self.hash_key(&key);
163 let h2_val = h2(hash);
164 let start_group = (hash >> 7) as usize % self.num_groups;
165
166 for i in 0..self.num_groups {
167 let gi = (start_group + i) % self.num_groups;
168 let group = &mut self.groups[gi];
169
170 let mut candidates = group.match_h2(h2_val);
171 while candidates != 0 {
172 let slot = candidates.trailing_zeros() as usize;
173 candidates &= candidates - 1;
174
175 if let Some((ref k, ref mut v)) = group.slots[slot] {
176 if *k == key {
177 let old = mem::replace(v, value);
178 return Some(old);
179 }
180 }
181 }
182
183 let available = group.match_empty_or_deleted();
184 if available != 0 {
185 let slot = available.trailing_zeros() as usize;
186 group.ctrl[slot] = h2_val;
187 group.slots[slot] = Some((key, value));
188 self.len += 1;
189 return None;
190 }
191 }
192
193 unreachable!("table should have grown before reaching here");
194 }
195
196 pub fn remove(&mut self, key: &K) -> Option<V> {
197 let (gi, si) = self.find(key)?;
198 let group = &mut self.groups[gi];
199 group.ctrl[si] = DELETED;
200 let (_, v) = group.slots[si].take().unwrap();
201 self.len -= 1;
202 Some(v)
203 }
204
205 fn grow(&mut self) {
206 let new_num_groups = self.num_groups * 2;
207 let old_groups = mem::replace(
208 &mut self.groups,
209 (0..new_num_groups).map(|_| Group::new()).collect(),
210 );
211 let old_len = self.len;
212 self.num_groups = new_num_groups;
213 self.len = 0;
214
215 for group in old_groups {
216 for slot in group.slots {
217 if let Some((k, v)) = slot {
218 self.insert(k, v);
219 }
220 }
221 }
222
223 debug_assert_eq!(self.len, old_len);
224 }
225}
226
227impl<K, V> Default for SwissTable<K, V>
228where
229 K: Eq + Hash,
230 V: Clone,
231{
232 fn default() -> Self {
233 Self::new()
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built table is empty and finds nothing.
    #[test]
    fn empty_table() {
        let table = SwissTable::<String, i32>::new();
        assert!(table.is_empty());
        assert_eq!(table.len(), 0);
        assert_eq!(table.get(&String::from("hello")), None);
    }

    /// Inserted entries can be read back; unknown keys yield `None`.
    #[test]
    fn insert_and_get() {
        let entries = [("foo", 1), ("bar", 2), ("baz", 3)];

        let mut table = SwissTable::new();
        for &(name, number) in &entries {
            table.insert(name, number);
        }

        for (name, number) in &entries {
            assert_eq!(table.get(name), Some(number));
        }
        assert_eq!(table.get(&"missing"), None);
        assert_eq!(table.len(), 3);
    }

    /// Re-inserting a key replaces the value and hands back the old one.
    #[test]
    fn insert_returns_old_value() {
        let mut table = SwissTable::new();

        let first = table.insert("key", 10);
        assert_eq!(first, None);

        let second = table.insert("key", 20);
        assert_eq!(second, Some(10));

        assert_eq!(table.len(), 1);
        assert_eq!(table.get(&"key"), Some(&20));
    }

    /// Removing a key returns its value once and makes the key unreachable.
    #[test]
    fn remove() {
        let mut table = SwissTable::new();
        for (name, number) in [("a", 1), ("b", 2), ("c", 3)] {
            table.insert(name, number);
        }

        let removed = table.remove(&"b");
        assert_eq!(removed, Some(2));
        assert_eq!(table.get(&"b"), None);
        assert!(!table.contains_key(&"b"));
        assert_eq!(table.len(), 2);

        let removed_again = table.remove(&"b");
        assert_eq!(removed_again, None);
    }

    /// Entries that survive a removal stay reachable afterwards.
    #[test]
    fn remove_doesnt_break_probe_chain() {
        let mut table = SwissTable::new();
        for i in 0..20 {
            table.insert(i, i * 100);
        }

        let removed = [5, 10, 15];
        for key in &removed {
            table.remove(key);
        }

        for i in 0..20 {
            let expected = if removed.contains(&i) {
                None
            } else {
                Some(i * 100)
            };
            assert_eq!(table.get(&i).copied(), expected);
        }
    }

    /// `contains_key` distinguishes stored keys from missing ones.
    #[test]
    fn contains_key() {
        let mut table = SwissTable::new();
        table.insert("present", 42);

        let hit = table.contains_key(&"present");
        let miss = table.contains_key(&"absent");

        assert!(hit);
        assert!(!miss);
    }

    /// Values can be modified in place through `get_mut`.
    #[test]
    fn get_mut() {
        let mut table = SwissTable::new();
        table.insert("counter", 0);

        if let Some(counter) = table.get_mut(&"counter") {
            *counter += 10;
        }

        assert_eq!(table.get(&"counter"), Some(&10));
    }

    /// Growing past the initial capacity keeps every entry retrievable.
    #[test]
    fn grow_under_load() {
        let mut table = SwissTable::new();
        for i in 0..100 {
            table.insert(i, i.to_string());
        }

        assert_eq!(table.len(), 100);

        for i in 0..100 {
            let expected = i.to_string();
            assert_eq!(table.get(&i), Some(&expected));
        }
    }

    /// Owned `String` keys work as well as borrowed ones.
    #[test]
    fn string_keys() {
        let key_for = |i: i32| format!("key_{}", i);

        let mut table = SwissTable::new();
        for i in 0..50 {
            table.insert(key_for(i), i);
        }

        for i in 0..50 {
            assert_eq!(table.get(&key_for(i)), Some(&i));
        }
    }

    /// A removed key can be inserted again and is counted once.
    #[test]
    fn insert_after_remove() {
        let mut table = SwissTable::new();
        table.insert("reuse", 1);
        table.remove(&"reuse");
        table.insert("reuse", 2);

        assert_eq!(table.len(), 1);
        assert_eq!(table.get(&"reuse"), Some(&2));
    }

    /// Bulk insert followed by removing every even key leaves exactly the odd keys.
    #[test]
    fn large_scale_insert_remove() {
        let mut table = SwissTable::new();

        for i in 0..1000 {
            table.insert(i, i);
        }
        assert_eq!(table.len(), 1000);

        for even in (0..1000).step_by(2) {
            table.remove(&even);
        }
        assert_eq!(table.len(), 500);

        for i in 0..1000 {
            let expected = if i % 2 == 0 { None } else { Some(&i) };
            assert_eq!(table.get(&i), expected);
        }
    }
}