1use std::iter::{Cloned, Filter};
2use std::mem;
3
4use super::{Addr, MemoryArena};
5use crate::fastcpy::fast_short_slice_copy;
6use crate::memory_arena::store;
7
8pub fn compute_table_memory_size(capacity: usize) -> usize {
12 capacity * mem::size_of::<KeyValue>()
13}
14
// Hash stored in each bucket when the `compare_hash_only` feature is disabled:
// a 32-bit hash, with candidate keys additionally compared byte by byte on lookup.
#[cfg(not(feature = "compare_hash_only"))]
type HashType = u32;

// Hash stored in each bucket when the `compare_hash_only` feature is enabled:
// lookups then accept an entry on hash equality alone, without comparing key bytes.
// NOTE(review): presumably widened to 64 bits to make such collisions less likely.
#[cfg(feature = "compare_hash_only")]
type HashType = u64;
20
/// One bucket of the hash table.
///
/// The bucket does not hold the key or the value themselves, only the arena address
/// where they were written, together with the hash of the key.
#[derive(Copy, Clone)]
struct KeyValue {
    /// Arena address of the entry. A null address marks the bucket as empty.
    key_value_addr: Addr,
    /// Hash of the key, kept so lookups can skip non-matching buckets and so the
    /// table can be rebuilt without rehashing the keys.
    hash: HashType,
}
29
30impl Default for KeyValue {
31 fn default() -> Self {
32 KeyValue {
33 key_value_addr: Addr::null_pointer(),
34 hash: 0,
35 }
36 }
37}
38
39impl KeyValue {
40 #[inline]
41 fn is_empty(&self) -> bool {
42 self.key_value_addr.is_null()
43 }
44 #[inline]
45 fn is_not_empty_ref(&self) -> bool {
46 !self.key_value_addr.is_null()
47 }
48}
49
/// Hash map with `&[u8]` keys whose keys and values live in a `MemoryArena`.
///
/// The map itself only owns a table of buckets (arena address + key hash). The arena
/// is not owned by the map: it is passed to every method that needs to read or write
/// entries.
pub struct SharedArenaHashMap {
    /// Bucket table. Its length is a power of two.
    table: Vec<KeyValue>,
    /// `table.len() - 1`, used to turn a probe position into a bucket index.
    mask: usize,
    /// Number of entries currently stored.
    len: usize,
}
68
/// State of a linear probing sequence over the bucket table.
struct LinearProbing {
    /// Current, unmasked probe position.
    pos: usize,
    /// Bit mask applied to `pos` to obtain a bucket index.
    mask: usize,
}
73
74impl LinearProbing {
75 #[inline]
76 fn compute(hash: HashType, mask: usize) -> LinearProbing {
77 LinearProbing {
78 pos: hash as usize,
79 mask,
80 }
81 }
82
83 #[inline]
84 fn next_probe(&mut self) -> usize {
85 self.pos = self.pos.wrapping_add(1);
87 self.pos & self.mask
88 }
89}
90
/// Iterator over copies of the table's buckets, filtered by a `fn(&KeyValue) -> bool` predicate.
type IterNonEmpty<'a> = Filter<Cloned<std::slice::Iter<'a, KeyValue>>, fn(&KeyValue) -> bool>;
92
/// Iterator over the entries of a [`SharedArenaHashMap`].
///
/// Yields, for each entry, the key bytes and the arena address of its value.
pub struct Iter<'a> {
    /// Map being iterated, used to decode entries.
    hashmap: &'a SharedArenaHashMap,
    /// Arena the entries are read from.
    memory_arena: &'a MemoryArena,
    /// Underlying iterator over the buckets.
    inner: IterNonEmpty<'a>,
}
98
99impl<'a> Iterator for Iter<'a> {
100 type Item = (&'a [u8], Addr);
101
102 fn next(&mut self) -> Option<Self::Item> {
103 self.inner.next().map(move |kv| {
104 let (key, offset): (&'a [u8], Addr) = self
105 .hashmap
106 .get_key_value(kv.key_value_addr, self.memory_arena);
107 (key, offset)
108 })
109 }
110}
111
/// Returns the greatest power of two that is lower than or equal to `n`.
///
/// # Panics
///
/// Panics if `n == 0`.
fn compute_previous_power_of_two(n: usize) -> usize {
    assert!(n > 0);
    // Index of the most significant set bit of `n`.
    let highest_set_bit = usize::BITS - 1 - n.leading_zeros();
    1usize << highest_set_bit
}
121
122impl Default for SharedArenaHashMap {
123 fn default() -> Self {
124 SharedArenaHashMap::with_capacity(4)
125 }
126}
127
128impl SharedArenaHashMap {
129 pub fn with_capacity(table_size: usize) -> SharedArenaHashMap {
130 let table_size_power_of_2 = compute_previous_power_of_two(table_size);
131 let table = vec![KeyValue::default(); table_size_power_of_2];
132
133 SharedArenaHashMap {
134 table,
135 mask: table_size_power_of_2 - 1,
136 len: 0,
137 }
138 }
139
140 #[inline]
141 #[cfg(not(feature = "compare_hash_only"))]
142 fn get_hash(&self, key: &[u8]) -> HashType {
143 murmurhash32::murmurhash2(key)
144 }
145
146 #[inline]
147 #[cfg(feature = "compare_hash_only")]
148 fn get_hash(&self, key: &[u8]) -> HashType {
149 use std::hash::Hasher;
151 let mut hasher = ahash::AHasher::default();
152 hasher.write(key);
153 hasher.finish() as HashType
154 }
155
156 #[inline]
157 fn probe(&self, hash: HashType) -> LinearProbing {
158 LinearProbing::compute(hash, self.mask)
159 }
160
161 #[inline]
162 pub fn mem_usage(&self) -> usize {
163 self.table.len() * mem::size_of::<KeyValue>()
164 }
165
166 #[inline]
167 fn is_saturated(&self) -> bool {
168 self.table.len() <= self.len * 2
169 }
170
171 #[inline]
172 fn get_key_value<'a>(&'a self, addr: Addr, memory_arena: &'a MemoryArena) -> (&'a [u8], Addr) {
173 let data = memory_arena.slice_from(addr);
174 let key_bytes_len_bytes = unsafe { data.get_unchecked(..2) };
175 let key_bytes_len = u16::from_le_bytes(key_bytes_len_bytes.try_into().unwrap());
176 let key_bytes: &[u8] = unsafe { data.get_unchecked(2..2 + key_bytes_len as usize) };
177 (key_bytes, addr.offset(2 + key_bytes_len as u32))
178 }
179
180 #[inline]
181 #[cfg(not(feature = "compare_hash_only"))]
182 fn get_value_addr_if_key_match(
183 &self,
184 target_key: &[u8],
185 addr: Addr,
186 memory_arena: &MemoryArena,
187 ) -> Option<Addr> {
188 use crate::fastcmp::fast_short_slice_compare;
189
190 let (stored_key, value_addr) = self.get_key_value(addr, memory_arena);
191 if fast_short_slice_compare(stored_key, target_key) {
192 Some(value_addr)
193 } else {
194 None
195 }
196 }
197 #[inline]
198 #[cfg(feature = "compare_hash_only")]
199 fn get_value_addr_if_key_match(
200 &self,
201 _target_key: &[u8],
202 addr: Addr,
203 memory_arena: &MemoryArena,
204 ) -> Option<Addr> {
205 let data = memory_arena.slice_from(addr);
208 let key_bytes_len_bytes = &data[..2];
209 let key_bytes_len = u16::from_le_bytes(key_bytes_len_bytes.try_into().unwrap());
210 let value_addr = addr.offset(2 + key_bytes_len as u32);
211
212 Some(value_addr)
213 }
214
215 #[inline]
216 fn set_bucket(&mut self, hash: HashType, key_value_addr: Addr, bucket: usize) {
217 self.len += 1;
218
219 self.table[bucket] = KeyValue {
220 key_value_addr,
221 hash,
222 };
223 }
224
225 #[inline]
226 pub fn is_empty(&self) -> bool {
227 self.len() == 0
228 }
229
230 #[inline]
231 pub fn len(&self) -> usize {
232 self.len
233 }
234
235 #[inline]
236 pub fn iter<'a>(&'a self, memory_arena: &'a MemoryArena) -> Iter<'a> {
237 Iter {
238 inner: self
239 .table
240 .iter()
241 .cloned()
242 .filter(KeyValue::is_not_empty_ref),
243 hashmap: self,
244 memory_arena,
245 }
246 }
247
248 fn resize(&mut self) {
249 let new_len = (self.table.len() * 2).max(1 << 3);
250 let mask = new_len - 1;
251 self.mask = mask;
252 let new_table = vec![KeyValue::default(); new_len];
253 let old_table = mem::replace(&mut self.table, new_table);
254 for key_value in old_table.into_iter().filter(KeyValue::is_not_empty_ref) {
255 let mut probe = LinearProbing::compute(key_value.hash, mask);
256 loop {
257 let bucket = probe.next_probe();
258 if self.table[bucket].is_empty() {
259 self.table[bucket] = key_value;
260 break;
261 }
262 }
263 }
264 }
265
266 #[inline]
268 pub fn get<V>(&self, key: &[u8], memory_arena: &MemoryArena) -> Option<V>
269 where V: Copy + 'static {
270 let hash = self.get_hash(key);
271 let mut probe = self.probe(hash);
272 loop {
273 let bucket = probe.next_probe();
274 let kv: KeyValue = self.table[bucket];
275 if kv.is_empty() {
276 return None;
277 } else if kv.hash == hash
278 && let Some(val_addr) =
279 self.get_value_addr_if_key_match(key, kv.key_value_addr, memory_arena)
280 {
281 let v = memory_arena.read(val_addr);
282 return Some(v);
283 }
284 }
285 }
286
287 #[inline]
300 pub fn mutate_or_create<V>(
301 &mut self,
302 key: &[u8],
303 memory_arena: &mut MemoryArena,
304 mut updater: impl FnMut(Option<V>) -> V,
305 ) -> V
306 where
307 V: Copy + 'static,
308 {
309 if self.is_saturated() {
310 self.resize();
311 }
312 let key = &key[..std::cmp::min(key.len(), u16::MAX as usize)];
314 let hash = self.get_hash(key);
315 let mut probe = self.probe(hash);
316 let mut bucket = probe.next_probe();
317 let mut kv: KeyValue = self.table[bucket];
318 loop {
319 if kv.is_empty() {
320 let val = updater(None);
322 let num_bytes = std::mem::size_of::<u16>() + key.len() + std::mem::size_of::<V>();
323 let key_addr = memory_arena.allocate_space(num_bytes);
324 {
325 let data = memory_arena.slice_mut(key_addr, num_bytes);
326 let key_len_bytes: [u8; 2] = (key.len() as u16).to_le_bytes();
327 data[..2].copy_from_slice(&key_len_bytes);
328 let stop = 2 + key.len();
329 fast_short_slice_copy(key, &mut data[2..stop]);
330 store(&mut data[stop..], val);
331 }
332
333 self.set_bucket(hash, key_addr, bucket);
334 return val;
335 }
336 if kv.hash == hash
337 && let Some(val_addr) =
338 self.get_value_addr_if_key_match(key, kv.key_value_addr, memory_arena)
339 {
340 let v = memory_arena.read(val_addr);
341 let new_v = updater(Some(v));
342 memory_arena.write_at(val_addr, new_v);
343 return new_v;
344 }
345 bucket = probe.next_probe();
347 kv = self.table[bucket];
348 }
349 }
350}
351
#[cfg(test)]
mod tests {

    use std::collections::HashMap;

    use super::{SharedArenaHashMap, compute_previous_power_of_two};
    use crate::MemoryArena;

    /// Inserting two keys and updating one of them leaves two entries holding the
    /// latest values.
    #[test]
    fn test_hash_map() {
        let mut memory_arena = MemoryArena::default();
        let mut hash_map: SharedArenaHashMap = SharedArenaHashMap::default();
        hash_map.mutate_or_create(b"abc", &mut memory_arena, |opt_val: Option<u32>| {
            assert_eq!(opt_val, None);
            3u32
        });
        hash_map.mutate_or_create(b"abcd", &mut memory_arena, |opt_val: Option<u32>| {
            assert_eq!(opt_val, None);
            4u32
        });
        hash_map.mutate_or_create(b"abc", &mut memory_arena, |opt_val: Option<u32>| {
            assert_eq!(opt_val, Some(3u32));
            5u32
        });
        assert_eq!(hash_map.len(), 2);
        assert_eq!(hash_map.get::<u32>(b"abc", &memory_arena), Some(5u32));
        assert_eq!(hash_map.get::<u32>(b"abcd", &memory_arena), Some(4u32));

        let mut vanilla_hash_map = HashMap::new();
        let iter_values = hash_map.iter(&memory_arena);
        for (key, addr) in iter_values {
            let val: u32 = memory_arena.read(addr);
            vanilla_hash_map.insert(key.to_owned(), val);
        }
        assert_eq!(vanilla_hash_map.len(), 2);
        assert_eq!(vanilla_hash_map.get(&b"abc"[..]), Some(&5u32));
        assert_eq!(vanilla_hash_map.get(&b"abcd"[..]), Some(&4u32));
    }

    /// A key one byte longer than `u16::MAX` is truncated on insertion and therefore
    /// lands on the entry of its `u16::MAX`-byte prefix.
    #[test]
    fn test_long_key_truncation() {
        let mut memory_arena = MemoryArena::default();
        let mut hash_map: SharedArenaHashMap = SharedArenaHashMap::default();
        let key1 = (0..u16::MAX as usize).map(|i| i as u8).collect::<Vec<_>>();
        hash_map.mutate_or_create(&key1, &mut memory_arena, |opt_val: Option<u32>| {
            assert_eq!(opt_val, None);
            4u32
        });
        let key2 = (0..u16::MAX as usize + 1)
            .map(|i| i as u8)
            .collect::<Vec<_>>();
        hash_map.mutate_or_create(&key2, &mut memory_arena, |opt_val: Option<u32>| {
            assert_eq!(opt_val, Some(4));
            3u32
        });
        assert_eq!(hash_map.get::<u32>(&key1, &memory_arena), Some(3u32));

        let mut vanilla_hash_map = HashMap::new();
        let iter_values = hash_map.iter(&memory_arena);
        for (key, addr) in iter_values {
            let val: u32 = memory_arena.read(addr);
            vanilla_hash_map.insert(key.to_owned(), val);
            assert_eq!(key.len(), key1[..].len());
            assert_eq!(key, &key1[..])
        }
        assert_eq!(vanilla_hash_map.len(), 1);
    }

    /// Looking up a key in a freshly created map finds nothing.
    #[test]
    fn test_empty_hashmap() {
        let memory_arena = MemoryArena::default();
        let hash_map: SharedArenaHashMap = SharedArenaHashMap::default();
        assert!(hash_map.is_empty());
        assert_eq!(hash_map.get::<u32>(b"abc", &memory_arena), None);
    }

    #[test]
    fn test_compute_previous_power_of_two() {
        assert_eq!(compute_previous_power_of_two(8), 8);
        assert_eq!(compute_previous_power_of_two(9), 8);
        assert_eq!(compute_previous_power_of_two(7), 4);
        // Expressed with `usize` limits so the test also builds on targets where
        // `usize` is narrower than 64 bits.
        assert_eq!(
            compute_previous_power_of_two(usize::MAX),
            1 << (usize::BITS - 1)
        );
    }

    /// Every inserted term comes back exactly once when iterating, across several
    /// table resizes.
    #[test]
    fn test_many_terms() {
        let mut memory_arena = MemoryArena::default();
        let mut terms: Vec<String> = (0..20_000).map(|val| val.to_string()).collect();
        let mut hash_map: SharedArenaHashMap = SharedArenaHashMap::default();
        for term in terms.iter() {
            hash_map.mutate_or_create(
                term.as_bytes(),
                &mut memory_arena,
                |_opt_val: Option<u32>| 5u32,
            );
        }
        assert_eq!(hash_map.len(), terms.len());

        let mut terms_back: Vec<String> = hash_map
            .iter(&memory_arena)
            .map(|(bytes, _)| String::from_utf8(bytes.to_vec()).unwrap())
            .collect();
        terms_back.sort();
        terms.sort();

        // Comparing the whole vectors also catches missing or extra terms.
        assert_eq!(terms, terms_back);
    }
}