Skip to main content

tantivy_stacker/
shared_arena_hashmap.rs

1use std::iter::{Cloned, Filter};
2use std::mem;
3
4use super::{Addr, MemoryArena};
5use crate::fastcpy::fast_short_slice_copy;
6use crate::memory_arena::store;
7
8/// Returns the actual memory size in bytes
9/// required to create a table with a given capacity.
10/// required to create a table of size
11pub fn compute_table_memory_size(capacity: usize) -> usize {
12    capacity * mem::size_of::<KeyValue>()
13}
14
/// Type of the key hash cached in every bucket.
///
/// With `compare_hash_only`, entries are matched on the hash alone (the key
/// bytes are never compared), so a wider 64-bit hash is used there.
#[cfg(feature = "compare_hash_only")]
type HashType = u64;

#[cfg(not(feature = "compare_hash_only"))]
type HashType = u32;
20
21/// `KeyValue` is the item stored in the hash table.
22/// The key is actually a `BytesRef` object stored in an external memory arena.
23/// The `value_addr` also points to an address in the memory arena.
24#[derive(Copy, Clone)]
25struct KeyValue {
26    key_value_addr: Addr,
27    hash: HashType,
28}
29
30impl Default for KeyValue {
31    fn default() -> Self {
32        KeyValue {
33            key_value_addr: Addr::null_pointer(),
34            hash: 0,
35        }
36    }
37}
38
39impl KeyValue {
40    #[inline]
41    fn is_empty(&self) -> bool {
42        self.key_value_addr.is_null()
43    }
44    #[inline]
45    fn is_not_empty_ref(&self) -> bool {
46        !self.key_value_addr.is_null()
47    }
48}
49
50/// Customized `HashMap` with `&[u8]` keys
51///
52/// Its main particularity is that rather than storing its
53/// keys in the heap, keys are stored in a memory arena
54/// inline with the values.
55///
56/// The quirky API has the benefit of avoiding
57/// the computation of the hash of the key twice,
58/// or copying the key as long as there is no insert.
59///
60/// SharedArenaHashMap is like ArenaHashMap but gets the memory arena
61/// passed as an argument to the methods.
62/// So one MemoryArena can be shared with multiple SharedArenaHashMap.
63pub struct SharedArenaHashMap {
64    table: Vec<KeyValue>,
65    mask: usize,
66    len: usize,
67}
68
69struct LinearProbing {
70    pos: usize,
71    mask: usize,
72}
73
74impl LinearProbing {
75    #[inline]
76    fn compute(hash: HashType, mask: usize) -> LinearProbing {
77        LinearProbing {
78            pos: hash as usize,
79            mask,
80        }
81    }
82
83    #[inline]
84    fn next_probe(&mut self) -> usize {
85        // Not saving the masked version removes a dependency.
86        self.pos = self.pos.wrapping_add(1);
87        self.pos & self.mask
88    }
89}
90
91type IterNonEmpty<'a> = Filter<Cloned<std::slice::Iter<'a, KeyValue>>, fn(&KeyValue) -> bool>;
92
93pub struct Iter<'a> {
94    hashmap: &'a SharedArenaHashMap,
95    memory_arena: &'a MemoryArena,
96    inner: IterNonEmpty<'a>,
97}
98
99impl<'a> Iterator for Iter<'a> {
100    type Item = (&'a [u8], Addr);
101
102    fn next(&mut self) -> Option<Self::Item> {
103        self.inner.next().map(move |kv| {
104            let (key, offset): (&'a [u8], Addr) = self
105                .hashmap
106                .get_key_value(kv.key_value_addr, self.memory_arena);
107            (key, offset)
108        })
109    }
110}
111
/// Returns the greatest power of two lower than or equal to `n`.
///
/// # Panics
///
/// Panics if `n == 0`, since no power of two is lower than or equal to zero.
fn compute_previous_power_of_two(n: usize) -> usize {
    assert!(n > 0);
    // Index of the most significant set bit, computed at native `usize` width
    // so the result is correct on both 32-bit and 64-bit targets.
    let msb = usize::BITS - 1 - n.leading_zeros();
    1 << msb
}
121
122impl Default for SharedArenaHashMap {
123    fn default() -> Self {
124        SharedArenaHashMap::with_capacity(4)
125    }
126}
127
128impl SharedArenaHashMap {
129    pub fn with_capacity(table_size: usize) -> SharedArenaHashMap {
130        let table_size_power_of_2 = compute_previous_power_of_two(table_size);
131        let table = vec![KeyValue::default(); table_size_power_of_2];
132
133        SharedArenaHashMap {
134            table,
135            mask: table_size_power_of_2 - 1,
136            len: 0,
137        }
138    }
139
140    #[inline]
141    #[cfg(not(feature = "compare_hash_only"))]
142    fn get_hash(&self, key: &[u8]) -> HashType {
143        murmurhash32::murmurhash2(key)
144    }
145
146    #[inline]
147    #[cfg(feature = "compare_hash_only")]
148    fn get_hash(&self, key: &[u8]) -> HashType {
149        /// Since we compare only the hash we need a high quality hash.
150        use std::hash::Hasher;
151        let mut hasher = ahash::AHasher::default();
152        hasher.write(key);
153        hasher.finish() as HashType
154    }
155
156    #[inline]
157    fn probe(&self, hash: HashType) -> LinearProbing {
158        LinearProbing::compute(hash, self.mask)
159    }
160
161    #[inline]
162    pub fn mem_usage(&self) -> usize {
163        self.table.len() * mem::size_of::<KeyValue>()
164    }
165
166    #[inline]
167    fn is_saturated(&self) -> bool {
168        self.table.len() <= self.len * 2
169    }
170
171    #[inline]
172    fn get_key_value<'a>(&'a self, addr: Addr, memory_arena: &'a MemoryArena) -> (&'a [u8], Addr) {
173        let data = memory_arena.slice_from(addr);
174        let key_bytes_len_bytes = unsafe { data.get_unchecked(..2) };
175        let key_bytes_len = u16::from_le_bytes(key_bytes_len_bytes.try_into().unwrap());
176        let key_bytes: &[u8] = unsafe { data.get_unchecked(2..2 + key_bytes_len as usize) };
177        (key_bytes, addr.offset(2 + key_bytes_len as u32))
178    }
179
180    #[inline]
181    #[cfg(not(feature = "compare_hash_only"))]
182    fn get_value_addr_if_key_match(
183        &self,
184        target_key: &[u8],
185        addr: Addr,
186        memory_arena: &MemoryArena,
187    ) -> Option<Addr> {
188        use crate::fastcmp::fast_short_slice_compare;
189
190        let (stored_key, value_addr) = self.get_key_value(addr, memory_arena);
191        if fast_short_slice_compare(stored_key, target_key) {
192            Some(value_addr)
193        } else {
194            None
195        }
196    }
197    #[inline]
198    #[cfg(feature = "compare_hash_only")]
199    fn get_value_addr_if_key_match(
200        &self,
201        _target_key: &[u8],
202        addr: Addr,
203        memory_arena: &MemoryArena,
204    ) -> Option<Addr> {
205        // For the compare_hash_only feature, it would make sense to store the keys at a different
206        // memory location. Here they will just pollute the cache.
207        let data = memory_arena.slice_from(addr);
208        let key_bytes_len_bytes = &data[..2];
209        let key_bytes_len = u16::from_le_bytes(key_bytes_len_bytes.try_into().unwrap());
210        let value_addr = addr.offset(2 + key_bytes_len as u32);
211
212        Some(value_addr)
213    }
214
215    #[inline]
216    fn set_bucket(&mut self, hash: HashType, key_value_addr: Addr, bucket: usize) {
217        self.len += 1;
218
219        self.table[bucket] = KeyValue {
220            key_value_addr,
221            hash,
222        };
223    }
224
225    #[inline]
226    pub fn is_empty(&self) -> bool {
227        self.len() == 0
228    }
229
230    #[inline]
231    pub fn len(&self) -> usize {
232        self.len
233    }
234
235    #[inline]
236    pub fn iter<'a>(&'a self, memory_arena: &'a MemoryArena) -> Iter<'a> {
237        Iter {
238            inner: self
239                .table
240                .iter()
241                .cloned()
242                .filter(KeyValue::is_not_empty_ref),
243            hashmap: self,
244            memory_arena,
245        }
246    }
247
248    fn resize(&mut self) {
249        let new_len = (self.table.len() * 2).max(1 << 3);
250        let mask = new_len - 1;
251        self.mask = mask;
252        let new_table = vec![KeyValue::default(); new_len];
253        let old_table = mem::replace(&mut self.table, new_table);
254        for key_value in old_table.into_iter().filter(KeyValue::is_not_empty_ref) {
255            let mut probe = LinearProbing::compute(key_value.hash, mask);
256            loop {
257                let bucket = probe.next_probe();
258                if self.table[bucket].is_empty() {
259                    self.table[bucket] = key_value;
260                    break;
261                }
262            }
263        }
264    }
265
266    /// Get a value associated to a key.
267    #[inline]
268    pub fn get<V>(&self, key: &[u8], memory_arena: &MemoryArena) -> Option<V>
269    where V: Copy + 'static {
270        let hash = self.get_hash(key);
271        let mut probe = self.probe(hash);
272        loop {
273            let bucket = probe.next_probe();
274            let kv: KeyValue = self.table[bucket];
275            if kv.is_empty() {
276                return None;
277            } else if kv.hash == hash {
278                if let Some(val_addr) =
279                    self.get_value_addr_if_key_match(key, kv.key_value_addr, memory_arena)
280                {
281                    let v = memory_arena.read(val_addr);
282                    return Some(v);
283                }
284            }
285        }
286    }
287
288    /// `update` create a new entry for a given key if it does not exist
289    /// or updates the existing entry.
290    ///
291    /// The actual logic for this update is define in the `updater`
292    /// argument.
293    ///
294    /// If the key is not present, `updater` will receive `None` and
295    /// will be in charge of returning a default value.
296    /// If the key already as an associated value, then it will be passed
297    /// `Some(previous_value)`.
298    ///
299    /// The key will be truncated to u16::MAX bytes.
300    #[inline]
301    pub fn mutate_or_create<V>(
302        &mut self,
303        key: &[u8],
304        memory_arena: &mut MemoryArena,
305        mut updater: impl FnMut(Option<V>) -> V,
306    ) -> V
307    where
308        V: Copy + 'static,
309    {
310        if self.is_saturated() {
311            self.resize();
312        }
313        // Limit the key size to u16::MAX
314        let key = &key[..std::cmp::min(key.len(), u16::MAX as usize)];
315        let hash = self.get_hash(key);
316        let mut probe = self.probe(hash);
317        let mut bucket = probe.next_probe();
318        let mut kv: KeyValue = self.table[bucket];
319        loop {
320            if kv.is_empty() {
321                // The key does not exist yet.
322                let val = updater(None);
323                let num_bytes = std::mem::size_of::<u16>() + key.len() + std::mem::size_of::<V>();
324                let key_addr = memory_arena.allocate_space(num_bytes);
325                {
326                    let data = memory_arena.slice_mut(key_addr, num_bytes);
327                    let key_len_bytes: [u8; 2] = (key.len() as u16).to_le_bytes();
328                    data[..2].copy_from_slice(&key_len_bytes);
329                    let stop = 2 + key.len();
330                    fast_short_slice_copy(key, &mut data[2..stop]);
331                    store(&mut data[stop..], val);
332                }
333
334                self.set_bucket(hash, key_addr, bucket);
335                return val;
336            }
337            if kv.hash == hash {
338                if let Some(val_addr) =
339                    self.get_value_addr_if_key_match(key, kv.key_value_addr, memory_arena)
340                {
341                    let v = memory_arena.read(val_addr);
342                    let new_v = updater(Some(v));
343                    memory_arena.write_at(val_addr, new_v);
344                    return new_v;
345                }
346            }
347            // This allows fetching the next bucket before the loop jmp
348            bucket = probe.next_probe();
349            kv = self.table[bucket];
350        }
351    }
352}
353
#[cfg(test)]
mod tests {

    use std::collections::HashMap;

    use super::{SharedArenaHashMap, compute_previous_power_of_two};
    use crate::MemoryArena;

    #[test]
    fn test_hash_map() {
        let mut memory_arena = MemoryArena::default();
        let mut hash_map: SharedArenaHashMap = SharedArenaHashMap::default();
        hash_map.mutate_or_create(b"abc", &mut memory_arena, |opt_val: Option<u32>| {
            assert_eq!(opt_val, None);
            3u32
        });
        hash_map.mutate_or_create(b"abcd", &mut memory_arena, |opt_val: Option<u32>| {
            assert_eq!(opt_val, None);
            4u32
        });
        hash_map.mutate_or_create(b"abc", &mut memory_arena, |opt_val: Option<u32>| {
            assert_eq!(opt_val, Some(3u32));
            5u32
        });
        assert_eq!(hash_map.len(), 2);
        let mut vanilla_hash_map = HashMap::new();
        let iter_values = hash_map.iter(&memory_arena);
        for (key, addr) in iter_values {
            let val: u32 = memory_arena.read(addr);
            vanilla_hash_map.insert(key.to_owned(), val);
        }
        assert_eq!(vanilla_hash_map.len(), 2);
        // The second update of "abc" must have overwritten its value in place.
        assert_eq!(vanilla_hash_map.get(&b"abc"[..]), Some(&5u32));
        assert_eq!(vanilla_hash_map.get(&b"abcd"[..]), Some(&4u32));
        // Point lookups agree with iteration, and absent keys are not found.
        assert_eq!(hash_map.get::<u32>(b"abc", &memory_arena), Some(5u32));
        assert_eq!(hash_map.get::<u32>(b"abcd", &memory_arena), Some(4u32));
        assert_eq!(hash_map.get::<u32>(b"ab", &memory_arena), None);
    }

    #[test]
    fn test_long_key_truncation() {
        // Keys longer than u16::MAX are truncated.
        let mut memory_arena = MemoryArena::default();
        let mut hash_map: SharedArenaHashMap = SharedArenaHashMap::default();
        let key1 = (0..u16::MAX as usize).map(|i| i as u8).collect::<Vec<_>>();
        hash_map.mutate_or_create(&key1, &mut memory_arena, |opt_val: Option<u32>| {
            assert_eq!(opt_val, None);
            4u32
        });
        // Due to truncation, this key is the same as key1
        let key2 = (0..u16::MAX as usize + 1)
            .map(|i| i as u8)
            .collect::<Vec<_>>();
        hash_map.mutate_or_create(&key2, &mut memory_arena, |opt_val: Option<u32>| {
            assert_eq!(opt_val, Some(4));
            3u32
        });
        let mut vanilla_hash_map = HashMap::new();
        let iter_values = hash_map.iter(&memory_arena);
        for (key, addr) in iter_values {
            let val: u32 = memory_arena.read(addr);
            vanilla_hash_map.insert(key.to_owned(), val);
            assert_eq!(key.len(), key1[..].len());
            assert_eq!(key, &key1[..])
        }
        assert_eq!(vanilla_hash_map.len(), 1); // Both map to the same key
    }

    #[test]
    fn test_empty_hashmap() {
        let memory_arena = MemoryArena::default();
        let hash_map: SharedArenaHashMap = SharedArenaHashMap::default();
        assert_eq!(hash_map.get::<u32>(b"abc", &memory_arena), None);
    }

    #[test]
    fn test_compute_previous_power_of_two() {
        assert_eq!(compute_previous_power_of_two(1), 1);
        assert_eq!(compute_previous_power_of_two(8), 8);
        assert_eq!(compute_previous_power_of_two(9), 8);
        assert_eq!(compute_previous_power_of_two(7), 4);
        // Expressed in terms of `usize` so the test also builds on 32-bit targets.
        assert_eq!(
            compute_previous_power_of_two(usize::MAX),
            1usize << (usize::BITS - 1)
        );
    }

    #[test]
    fn test_many_terms() {
        let mut memory_arena = MemoryArena::default();
        let mut terms: Vec<String> = (0..20_000).map(|val| val.to_string()).collect();
        let mut hash_map: SharedArenaHashMap = SharedArenaHashMap::default();
        for term in terms.iter() {
            hash_map.mutate_or_create(
                term.as_bytes(),
                &mut memory_arena,
                |_opt_val: Option<u32>| 5u32,
            );
        }
        let mut terms_back: Vec<String> = hash_map
            .iter(&memory_arena)
            .map(|(bytes, _)| String::from_utf8(bytes.to_vec()).unwrap())
            .collect();
        terms_back.sort();
        terms.sort();

        // Comparing whole vectors also catches missing or extra terms.
        assert_eq!(terms, terms_back);
    }
}