1use std::iter::{Cloned, Filter};
2use std::mem;
3
4use super::{Addr, MemoryArena};
5use crate::fastcpy::fast_short_slice_copy;
6use crate::memory_arena::store;
7
8pub fn compute_table_memory_size(capacity: usize) -> usize {
12 capacity * mem::size_of::<KeyValue>()
13}
14
/// Hash cached in every bucket when the `compare_hash_only` feature is
/// disabled: the 32-bit value returned by `murmurhash32::murmurhash2`.
#[cfg(not(feature = "compare_hash_only"))]
type HashType = u32;

/// Hash cached in every bucket when the `compare_hash_only` feature is
/// enabled: the 64-bit value produced by `ahash::AHasher`. In that mode a
/// matching hash is treated as a matching key, without comparing key bytes.
#[cfg(feature = "compare_hash_only")]
type HashType = u64;
20
/// A single bucket of the hash table.
///
/// `key_value_addr` is the arena address of the record written by
/// `mutate_or_create` (a 2-byte little-endian key length, the key bytes, then
/// the value). A null address marks the bucket as empty. `hash` caches the
/// hash of the key so probing can compare hashes before touching the arena.
#[derive(Copy, Clone)]
struct KeyValue {
    key_value_addr: Addr,
    hash: HashType,
}
29
30impl Default for KeyValue {
31 fn default() -> Self {
32 KeyValue {
33 key_value_addr: Addr::null_pointer(),
34 hash: 0,
35 }
36 }
37}
38
39impl KeyValue {
40 #[inline]
41 fn is_empty(&self) -> bool {
42 self.key_value_addr.is_null()
43 }
44 #[inline]
45 fn is_not_empty_ref(&self) -> bool {
46 !self.key_value_addr.is_null()
47 }
48}
49
/// Open-addressing hash map keyed by byte strings whose keys and values are
/// stored in an external `MemoryArena` instead of in the map itself.
///
/// The map only owns the bucket table. Every method that reads or writes an
/// entry takes the arena as an argument.
// NOTE(review): presumably the same arena must be passed on every call for a
// given map; nothing in this file enforces it — confirm with callers.
pub struct SharedArenaHashMap {
    /// Bucket table; its length is always a power of two.
    table: Vec<KeyValue>,
    /// `table.len() - 1`, used to wrap probe positions with a bitwise AND.
    mask: usize,
    /// Number of occupied buckets.
    len: usize,
}
68
/// State of a linear probe sequence over the bucket table.
struct LinearProbing {
    /// Current, unmasked position; it starts at the hash value.
    pos: usize,
    /// Table size minus one, applied when a position is turned into a bucket
    /// index.
    mask: usize,
}
73
74impl LinearProbing {
75 #[inline]
76 fn compute(hash: HashType, mask: usize) -> LinearProbing {
77 LinearProbing {
78 pos: hash as usize,
79 mask,
80 }
81 }
82
83 #[inline]
84 fn next_probe(&mut self) -> usize {
85 self.pos = self.pos.wrapping_add(1);
87 self.pos & self.mask
88 }
89}
90
/// Iterator over copies of the table's buckets, filtered by a
/// `fn(&KeyValue) -> bool` predicate (`iter` passes
/// `KeyValue::is_not_empty_ref`, which keeps the occupied ones).
type IterNonEmpty<'a> = Filter<Cloned<std::slice::Iter<'a, KeyValue>>, fn(&KeyValue) -> bool>;
92
/// Iterator over the entries of a `SharedArenaHashMap`.
///
/// Yields, for every occupied bucket, the key bytes and the arena address of
/// the associated value.
pub struct Iter<'a> {
    /// Map being iterated; used to decode each record.
    hashmap: &'a SharedArenaHashMap,
    /// Arena the records are read from.
    memory_arena: &'a MemoryArena,
    /// Occupied buckets still to be visited.
    inner: IterNonEmpty<'a>,
}
98
99impl<'a> Iterator for Iter<'a> {
100 type Item = (&'a [u8], Addr);
101
102 fn next(&mut self) -> Option<Self::Item> {
103 self.inner.next().map(move |kv| {
104 let (key, offset): (&'a [u8], Addr) = self
105 .hashmap
106 .get_key_value(kv.key_value_addr, self.memory_arena);
107 (key, offset)
108 })
109 }
110}
111
/// Returns the largest power of two that is less than or equal to `n`.
///
/// # Panics
///
/// Panics if `n == 0`.
fn compute_previous_power_of_two(n: usize) -> usize {
    assert!(n > 0);
    // Index of the most significant set bit of `n`.
    let highest_bit = usize::BITS - 1 - n.leading_zeros();
    1usize << highest_bit
}
121
122impl Default for SharedArenaHashMap {
123 fn default() -> Self {
124 SharedArenaHashMap::with_capacity(4)
125 }
126}
127
128impl SharedArenaHashMap {
129 pub fn with_capacity(table_size: usize) -> SharedArenaHashMap {
130 let table_size_power_of_2 = compute_previous_power_of_two(table_size);
131 let table = vec![KeyValue::default(); table_size_power_of_2];
132
133 SharedArenaHashMap {
134 table,
135 mask: table_size_power_of_2 - 1,
136 len: 0,
137 }
138 }
139
    /// Hashes `key` with `murmurhash32::murmurhash2`.
    ///
    /// Variant compiled when the `compare_hash_only` feature is disabled.
    #[inline]
    #[cfg(not(feature = "compare_hash_only"))]
    fn get_hash(&self, key: &[u8]) -> HashType {
        murmurhash32::murmurhash2(key)
    }
145
    /// Hashes `key` with a default-constructed `ahash::AHasher`.
    ///
    /// Variant compiled when the `compare_hash_only` feature is enabled.
    #[inline]
    #[cfg(feature = "compare_hash_only")]
    fn get_hash(&self, key: &[u8]) -> HashType {
        use std::hash::Hasher;
        let mut hasher = ahash::AHasher::default();
        hasher.write(key);
        hasher.finish() as HashType
    }
155
    /// Starts a linear probe sequence for `hash` over the current table.
    #[inline]
    fn probe(&self, hash: HashType) -> LinearProbing {
        LinearProbing::compute(hash, self.mask)
    }
160
161 #[inline]
162 pub fn mem_usage(&self) -> usize {
163 self.table.len() * mem::size_of::<KeyValue>()
164 }
165
166 #[inline]
167 fn is_saturated(&self) -> bool {
168 self.table.len() <= self.len * 2
169 }
170
    /// Decodes the record stored at `addr` in `memory_arena`.
    ///
    /// Reads the 2-byte little-endian key length found at `addr` and returns
    /// the key bytes that follow it, together with the address just past the
    /// key, which is where `mutate_or_create` stores the value.
    #[inline]
    fn get_key_value<'a>(&'a self, addr: Addr, memory_arena: &'a MemoryArena) -> (&'a [u8], Addr) {
        let data = memory_arena.slice_from(addr);
        // SAFETY: NOTE(review) — relies on `addr` pointing at a record written
        // by `mutate_or_create`, and on `slice_from` returning at least the
        // whole record (2-byte length prefix + key). Neither is checked here;
        // confirm against `MemoryArena::slice_from`.
        let key_bytes_len_bytes = unsafe { data.get_unchecked(..2) };
        let key_bytes_len = u16::from_le_bytes(key_bytes_len_bytes.try_into().unwrap());
        // SAFETY: same assumption as above — the `key_bytes_len` key bytes
        // directly follow the length prefix inside `data`.
        let key_bytes: &[u8] = unsafe { data.get_unchecked(2..2 + key_bytes_len as usize) };
        (key_bytes, addr.offset(2 + key_bytes_len as u32))
    }
179
180 #[inline]
181 #[cfg(not(feature = "compare_hash_only"))]
182 fn get_value_addr_if_key_match(
183 &self,
184 target_key: &[u8],
185 addr: Addr,
186 memory_arena: &MemoryArena,
187 ) -> Option<Addr> {
188 use crate::fastcmp::fast_short_slice_compare;
189
190 let (stored_key, value_addr) = self.get_key_value(addr, memory_arena);
191 if fast_short_slice_compare(stored_key, target_key) {
192 Some(value_addr)
193 } else {
194 None
195 }
196 }
197 #[inline]
198 #[cfg(feature = "compare_hash_only")]
199 fn get_value_addr_if_key_match(
200 &self,
201 _target_key: &[u8],
202 addr: Addr,
203 memory_arena: &MemoryArena,
204 ) -> Option<Addr> {
205 let data = memory_arena.slice_from(addr);
208 let key_bytes_len_bytes = &data[..2];
209 let key_bytes_len = u16::from_le_bytes(key_bytes_len_bytes.try_into().unwrap());
210 let value_addr = addr.offset(2 + key_bytes_len as u32);
211
212 Some(value_addr)
213 }
214
215 #[inline]
216 fn set_bucket(&mut self, hash: HashType, key_value_addr: Addr, bucket: usize) {
217 self.len += 1;
218
219 self.table[bucket] = KeyValue {
220 key_value_addr,
221 hash,
222 };
223 }
224
225 #[inline]
226 pub fn is_empty(&self) -> bool {
227 self.len() == 0
228 }
229
    /// Returns the number of entries stored in the map.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }
234
235 #[inline]
236 pub fn iter<'a>(&'a self, memory_arena: &'a MemoryArena) -> Iter<'a> {
237 Iter {
238 inner: self
239 .table
240 .iter()
241 .cloned()
242 .filter(KeyValue::is_not_empty_ref),
243 hashmap: self,
244 memory_arena,
245 }
246 }
247
248 fn resize(&mut self) {
249 let new_len = (self.table.len() * 2).max(1 << 3);
250 let mask = new_len - 1;
251 self.mask = mask;
252 let new_table = vec![KeyValue::default(); new_len];
253 let old_table = mem::replace(&mut self.table, new_table);
254 for key_value in old_table.into_iter().filter(KeyValue::is_not_empty_ref) {
255 let mut probe = LinearProbing::compute(key_value.hash, mask);
256 loop {
257 let bucket = probe.next_probe();
258 if self.table[bucket].is_empty() {
259 self.table[bucket] = key_value;
260 break;
261 }
262 }
263 }
264 }
265
266 #[inline]
268 pub fn get<V>(&self, key: &[u8], memory_arena: &MemoryArena) -> Option<V>
269 where V: Copy + 'static {
270 let hash = self.get_hash(key);
271 let mut probe = self.probe(hash);
272 loop {
273 let bucket = probe.next_probe();
274 let kv: KeyValue = self.table[bucket];
275 if kv.is_empty() {
276 return None;
277 } else if kv.hash == hash {
278 if let Some(val_addr) =
279 self.get_value_addr_if_key_match(key, kv.key_value_addr, memory_arena)
280 {
281 let v = memory_arena.read(val_addr);
282 return Some(v);
283 }
284 }
285 }
286 }
287
288 #[inline]
301 pub fn mutate_or_create<V>(
302 &mut self,
303 key: &[u8],
304 memory_arena: &mut MemoryArena,
305 mut updater: impl FnMut(Option<V>) -> V,
306 ) -> V
307 where
308 V: Copy + 'static,
309 {
310 if self.is_saturated() {
311 self.resize();
312 }
313 let key = &key[..std::cmp::min(key.len(), u16::MAX as usize)];
315 let hash = self.get_hash(key);
316 let mut probe = self.probe(hash);
317 let mut bucket = probe.next_probe();
318 let mut kv: KeyValue = self.table[bucket];
319 loop {
320 if kv.is_empty() {
321 let val = updater(None);
323 let num_bytes = std::mem::size_of::<u16>() + key.len() + std::mem::size_of::<V>();
324 let key_addr = memory_arena.allocate_space(num_bytes);
325 {
326 let data = memory_arena.slice_mut(key_addr, num_bytes);
327 let key_len_bytes: [u8; 2] = (key.len() as u16).to_le_bytes();
328 data[..2].copy_from_slice(&key_len_bytes);
329 let stop = 2 + key.len();
330 fast_short_slice_copy(key, &mut data[2..stop]);
331 store(&mut data[stop..], val);
332 }
333
334 self.set_bucket(hash, key_addr, bucket);
335 return val;
336 }
337 if kv.hash == hash {
338 if let Some(val_addr) =
339 self.get_value_addr_if_key_match(key, kv.key_value_addr, memory_arena)
340 {
341 let v = memory_arena.read(val_addr);
342 let new_v = updater(Some(v));
343 memory_arena.write_at(val_addr, new_v);
344 return new_v;
345 }
346 }
347 bucket = probe.next_probe();
349 kv = self.table[bucket];
350 }
351 }
352}
353
354#[cfg(test)]
355mod tests {
356
357 use std::collections::HashMap;
358
359 use super::{SharedArenaHashMap, compute_previous_power_of_two};
360 use crate::MemoryArena;
361
362 #[test]
363 fn test_hash_map() {
364 let mut memory_arena = MemoryArena::default();
365 let mut hash_map: SharedArenaHashMap = SharedArenaHashMap::default();
366 hash_map.mutate_or_create(b"abc", &mut memory_arena, |opt_val: Option<u32>| {
367 assert_eq!(opt_val, None);
368 3u32
369 });
370 hash_map.mutate_or_create(b"abcd", &mut memory_arena, |opt_val: Option<u32>| {
371 assert_eq!(opt_val, None);
372 4u32
373 });
374 hash_map.mutate_or_create(b"abc", &mut memory_arena, |opt_val: Option<u32>| {
375 assert_eq!(opt_val, Some(3u32));
376 5u32
377 });
378 let mut vanilla_hash_map = HashMap::new();
379 let iter_values = hash_map.iter(&memory_arena);
380 for (key, addr) in iter_values {
381 let val: u32 = memory_arena.read(addr);
382 vanilla_hash_map.insert(key.to_owned(), val);
383 }
384 assert_eq!(vanilla_hash_map.len(), 2);
385 }
386
387 #[test]
388 fn test_long_key_truncation() {
389 let mut memory_arena = MemoryArena::default();
391 let mut hash_map: SharedArenaHashMap = SharedArenaHashMap::default();
392 let key1 = (0..u16::MAX as usize).map(|i| i as u8).collect::<Vec<_>>();
393 hash_map.mutate_or_create(&key1, &mut memory_arena, |opt_val: Option<u32>| {
394 assert_eq!(opt_val, None);
395 4u32
396 });
397 let key2 = (0..u16::MAX as usize + 1)
399 .map(|i| i as u8)
400 .collect::<Vec<_>>();
401 hash_map.mutate_or_create(&key2, &mut memory_arena, |opt_val: Option<u32>| {
402 assert_eq!(opt_val, Some(4));
403 3u32
404 });
405 let mut vanilla_hash_map = HashMap::new();
406 let iter_values = hash_map.iter(&memory_arena);
407 for (key, addr) in iter_values {
408 let val: u32 = memory_arena.read(addr);
409 vanilla_hash_map.insert(key.to_owned(), val);
410 assert_eq!(key.len(), key1[..].len());
411 assert_eq!(key, &key1[..])
412 }
413 assert_eq!(vanilla_hash_map.len(), 1); }
415
416 #[test]
417 fn test_empty_hashmap() {
418 let memory_arena = MemoryArena::default();
419 let hash_map: SharedArenaHashMap = SharedArenaHashMap::default();
420 assert_eq!(hash_map.get::<u32>(b"abc", &memory_arena), None);
421 }
422
423 #[test]
424 fn test_compute_previous_power_of_two() {
425 assert_eq!(compute_previous_power_of_two(8), 8);
426 assert_eq!(compute_previous_power_of_two(9), 8);
427 assert_eq!(compute_previous_power_of_two(7), 4);
428 assert_eq!(compute_previous_power_of_two(u64::MAX as usize), 1 << 63);
429 }
430
431 #[test]
432 fn test_many_terms() {
433 let mut memory_arena = MemoryArena::default();
434 let mut terms: Vec<String> = (0..20_000).map(|val| val.to_string()).collect();
435 let mut hash_map: SharedArenaHashMap = SharedArenaHashMap::default();
436 for term in terms.iter() {
437 hash_map.mutate_or_create(
438 term.as_bytes(),
439 &mut memory_arena,
440 |_opt_val: Option<u32>| 5u32,
441 );
442 }
443 let mut terms_back: Vec<String> = hash_map
444 .iter(&memory_arena)
445 .map(|(bytes, _)| String::from_utf8(bytes.to_vec()).unwrap())
446 .collect();
447 terms_back.sort();
448 terms.sort();
449
450 for pos in 0..terms.len() {
451 assert_eq!(terms[pos], terms_back[pos]);
452 }
453 }
454}