tf_idf_vectorizer/utils/datastruct/map/
mod.rs

1use hashbrown::HashTable;
2use std::fmt::Debug;
3use std::hash::Hasher;
4
5pub mod serde;
6
7/// IndexMap
8/// 連続領域を保証するHashMap
9/// 
10/// # Safety
11/// table, hashes は table_* メソッドが責任をもつこと 更新とか
12/// 
13/// いじんないじんないじんな いじったならあらゆるUnitTest書いて通せ
14#[derive(Clone, Debug)]
15pub struct IndexMap<K, V, S = ahash::RandomState> {
16    pub values: Vec<V>,
17    pub keys: Vec<K>,
18    pub hashes: Vec<u64>,
19    pub table: HashTable<usize>,
20    pub hash_builder: S,
21}
22
23impl<K, V, S> IndexMap<K, V, S>
24where
25    K: Eq + std::hash::Hash + Clone,
26    S: std::hash::BuildHasher,
27{
28    pub fn with_hasher(hash_builder: S) -> Self {
29        IndexMap {
30            values: Vec::new(),
31            keys: Vec::new(),
32            hashes: Vec::new(),
33            table: HashTable::new(),
34            hash_builder,
35        }
36    }
37
38    pub fn with_capacity(capacity: usize) -> Self
39    where
40        S: Default,
41    {
42        IndexMap {
43            values: Vec::with_capacity(capacity),
44            keys: Vec::with_capacity(capacity),
45            hashes: Vec::with_capacity(capacity),
46            table: HashTable::with_capacity(capacity),
47            hash_builder: S::default(),
48        }
49    }
50
51    pub fn new() -> Self
52    where
53        S: Default,
54    {
55        IndexMap {
56            values: Vec::new(),
57            keys: Vec::new(),
58            hashes: Vec::new(),
59            table: HashTable::new(),
60            hash_builder: S::default(),
61        }
62    }
63
64    pub fn len(&self) -> usize {
65        self.values.len()
66    }
67
68    pub fn iter_values(&self) -> std::slice::Iter<'_, V> {
69        self.values.iter()
70    }
71
72    pub fn iter_keys(&self) -> std::slice::Iter<'_, K> {
73        self.keys.iter()
74    }
75
76    pub fn iter(&self) -> IndexMapIter<'_, K, V, S> {
77        IndexMapIter {
78            map: self,
79            index: 0,
80        }
81    }
82
83    pub fn values(&self) -> &Vec<V> {
84        &self.values
85    }
86
87    pub fn keys(&self) -> &Vec<K> {
88        &self.keys
89    }
90
91    /// hash util
92    fn hash_key(&self, key: &K) -> u64 {
93        let mut hasher = self.hash_builder.build_hasher();
94        key.hash(&mut hasher);
95        hasher.finish()
96    }
97
98    /// override
99    /// 完全な整合性が必要
100    /// keyに対するidxを更新し、更新したidxを返す
101    /// 存在しない場合はNone
102    unsafe fn table_override(&mut self, key: &K, idx: &usize) -> Option<usize> {
103        let hash = self.hash_key(key);
104        match self.table.find_entry(hash, |&i| self.keys[i] == *key) {
105            Ok(mut occ) => {
106                // idxの上書きだけ
107                *occ.get_mut() = *idx;
108                Some(*idx)
109            }
110            Err(_) => {
111                None
112            }
113        }
114    }
115
116    /// append
117    /// 完全な整合性が必要
118    /// hashesとtableを更新する
119    unsafe fn table_append(&mut self, key: &K, idx: &usize) {
120        let hash = self.hash_key(key);
121        self.hashes.push(hash);
122        self.table.insert_unique(
123            hash,
124            *idx,
125            |&i| self.hashes[i]
126        );
127    }
128
129    /// get
130    /// とくに注意なし 不可変参照なので
131    fn table_get(&self, key: &K) -> Option<usize> {
132        let hash = self.hash_key(key);
133        self.table.find(
134            hash, 
135            |&i| self.keys[i] == *key
136        ).copied()
137    }
138
139    /// remove
140    /// 完全な整合性が必要
141    /// hashesはswap_removeされます
142    unsafe fn table_swap_remove(&mut self, key: &K) -> Option<usize> {
143        let hash = self.hash_key(key);
144        if let Ok(entry) = self.table.find_entry(
145            hash,
146            |&i| self.keys[i] == *key
147        ) {
148            let (odl_idx, _) = entry.remove();
149            self.hashes.swap_remove(odl_idx);
150            Some(odl_idx)
151        } else {
152            None
153        }
154    }
155
156    pub fn get(&self, key: &K) -> Option<&V> {
157        if let Some(idx) = self.table_get(key) {
158            unsafe {
159                Some(self.values.get_unchecked(idx))
160            }
161        } else {
162            None
163        }
164    }
165
166    pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
167        if let Some(idx) = self.table_get(key) {
168            unsafe {
169                Some(self.values.get_unchecked_mut(idx))
170            }
171        } else {
172            None
173        }
174    }
175
176    pub fn get_with_index(&self, index: usize) -> Option<&V> {
177        self.values.get(index)
178    }
179
180    pub fn get_with_index_mut(&mut self, index: usize) -> Option<&mut V> {
181        self.values.get_mut(index)
182    }
183
184    pub fn get_key_with_index(&self, index: usize) -> Option<&K> {
185        self.keys.get(index)
186    }
187
188    pub fn get_key_value_with_index(&self, index: usize) -> Option<(&K, &V)> {
189        if index < self.len() {
190            unsafe {
191                Some((
192                    self.keys.get_unchecked(index),
193                    self.values.get_unchecked(index),
194                ))
195            }
196        } else {
197            None
198        }
199    }
200
201    pub fn get_index(&self, key: &K) -> Option<usize> {
202        self.table_get(key)
203    }
204
205    pub fn contains_key(&self, key: &K) -> bool {
206        self.table_get(key).is_some()
207    }
208
209    pub fn insert(&mut self, key: &K, value: V) -> Option<InsertResult<K, V>> {
210        if let Some(idx) = self.table_get(key) {
211            // K が Rc の場合を考慮して すべて差し替える
212            unsafe {
213                self.table_override(key, &idx);
214            }
215            let old_value = Some(std::mem::replace(&mut self.values[idx], value));
216            let old_key = Some(std::mem::replace(&mut self.keys[idx], key.clone()));
217            Some(InsertResult {
218                old_value: old_value.unwrap(),
219                old_key:  old_key.unwrap(),
220            })
221        } else {
222            // New key, insert entry
223            let idx = self.values.len();
224            unsafe {
225                self.table_append(key, &idx);
226            }
227            self.keys.push(key.clone());
228            self.values.push(value);
229            None
230        }
231    }
232
233    pub fn entry_mut<'a>(&'a mut self, key: &'a K) -> EntryMut<'a, K, V, S> {
234        if let Some(idx) = self.table_get(key) {
235            unsafe {
236                EntryMut::Occupied {
237                    key: self.keys.get_unchecked(idx),
238                    value: self.values.get_unchecked_mut(idx),
239                    index: idx,
240                }
241            }
242        } else {
243            EntryMut::Vacant { key , map: self }
244        }
245    }
246
247    pub fn swap_remove(&mut self, key: &K) -> Option<V> {
248        if let Some(idx) = self.table_get(key) {
249            let last_idx = self.values.len() - 1;
250            if idx == last_idx {
251                // 最後の要素を削除する場合
252                unsafe {
253                    self.table_swap_remove(key);
254                }
255                self.keys.pop();
256                return Some(self.values.pop().unwrap());
257            } else {
258                let last_idx_key = self.keys[last_idx].clone();
259                unsafe {
260                    // keyとの整合性があるうちに削除予定のをtableから消す ここでhashesがswap_removeされる
261                    // last_idxの要素がswapで移動してくる
262                    self.table_swap_remove(key);
263                    // 移動させられた要素のtableを再登録
264                    // 登録されていた前のidxに対するkeyはまだ整合性が取れているので問題ない
265                    self.table_override(&last_idx_key, &idx);
266                }
267                // swap_remove ここで実際にtableのidxとvalues, keys, hashesの整合性が回復
268                let value = self.values.swap_remove(idx);
269                self.keys.swap_remove(idx);
270                Some(value)
271            }
272        } else {
273            None
274        }
275    }
276
277    pub fn from_kv_vec(k_vec: Vec<K>, v_vec: Vec<V>) -> Self
278    where
279        S: std::hash::BuildHasher + Default,
280    {
281        let hash_builder = S::default();
282        let mut map = IndexMap::with_hasher(hash_builder);
283        for (k, v) in k_vec.into_iter().zip(v_vec.into_iter()) {
284            let idx = map.values.len();
285            unsafe {
286                map.table_append(&k, &idx);
287            }
288            map.keys.push(k);
289            map.values.push(v);
290        }
291        map
292    }
293}
294
295pub struct IndexMapIter<'a, K, V, S> {
296    pub map: &'a IndexMap<K, V, S>,
297    pub index: usize,
298}
299
300impl <'a, K, V, S> Iterator for IndexMapIter<'a, K, V, S> 
301where 
302    K: Eq + std::hash::Hash + Clone,
303    S: std::hash::BuildHasher,
304{
305    type Item = (&'a K, &'a V);
306
307    fn next(&mut self) -> Option<Self::Item> {
308        if self.index < self.map.len() {
309            unsafe {
310                let k = self.map.keys.get_unchecked(self.index);
311                let v = self.map.values.get_unchecked(self.index);
312                self.index += 1;
313                Some((k, v))
314            }
315        } else {
316            None
317        }
318    } 
319    
320    fn size_hint(&self) -> (usize, Option<usize>) {
321        let remaining = self.map.len() - self.index;
322        (remaining, Some(remaining))
323    }
324
325    fn nth(&mut self, n: usize) -> Option<Self::Item> {
326        self.index += n;
327        self.next()
328    }
329}
330
331pub enum EntryMut<'a, K, V, S> {
332    Occupied { key: &'a K, value: &'a mut V, index: usize },
333    Vacant { key: &'a K , map: &'a mut IndexMap<K, V, S> },
334}
335
336impl<'a, K, V, S> EntryMut<'a, K, V, S>
337where 
338    K: Eq + std::hash::Hash + Clone,
339    S: std::hash::BuildHasher,
340{
341    pub fn is_occupied(&self) -> bool {
342        matches!(self, EntryMut::Occupied { .. })
343    }
344
345    pub fn or_insert_with<F>(self, value: F) -> &'a mut V
346    where
347        F: FnOnce() -> V,
348        K: Clone,
349    {
350        match self {
351            EntryMut::Occupied { value: v, .. } => v,
352            EntryMut::Vacant { key, map } => {
353                map.insert(key, value());
354                map.get_mut(key).unwrap()
355            }
356        }
357    }
358}
359
/// The key/value pair that was replaced when an insert overwrote an existing entry.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InsertResult<K, V> {
    /// The value previously stored under the key.
    pub old_value: V,
    /// The key instance previously stored in the map.
    pub old_key: K,
}
365
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    // Start with key/value types that are easy to compare.
    type M = IndexMap<u64, i64>;

    /// Checks that `keys`, `values`, `hashes` and the table agree with each other.
    fn assert_internal_invariants(map: &M) {
        // The parallel vectors must all have the same length.
        assert_eq!(map.values.len(), map.keys.len(), "values/keys len mismatch");
        assert_eq!(map.values.len(), map.hashes.len(), "values/hashes len mismatch");

        // Every stored key must resolve through the table to its own position.
        for (pos, key) in map.keys.iter().enumerate() {
            let idx = map.table_get(key).expect("table_get must find existing key");
            assert_eq!(idx, pos, "table idx mismatch for key");
        }

        // The other direction: contains_key / get must agree with the storage.
        for (key, stored) in map.keys.iter().zip(map.values.iter()) {
            assert!(map.contains_key(key), "contains_key false for existing key");
            let v = map.get(key).expect("get must return for existing key");
            assert_eq!(v, stored, "get value mismatch");
        }

        // No key may appear twice. O(n^2), which is fine for a unit test.
        for (pos, a) in map.keys.iter().enumerate() {
            for b in &map.keys[pos + 1..] {
                assert!(a != b, "duplicate keys detected");
            }
        }
    }

    /// Checks that `map` holds exactly the same entries as the reference `HashMap`.
    fn assert_equals_oracle(map: &M, oracle: &HashMap<u64, i64>) {
        assert_eq!(map.len(), oracle.len(), "len mismatch");

        // Every oracle entry must be present with the same value.
        for (k, v) in oracle {
            let got = map.get(k).copied();
            assert_eq!(got, Some(*v), "value mismatch for key={k}");
        }

        // The map must not hold anything the oracle does not know about.
        for (k, v) in map.iter() {
            assert_eq!(oracle.get(k).copied(), Some(*v), "extra/mismatch entry key={k}");
        }
    }

    #[test]
    fn basic_insert_get_overwrite() {
        let mut m = M::new();

        assert_eq!(m.insert(&1, 10), None);
        assert_eq!(m.insert(&2, 20), None);
        assert_eq!(m.get(&1).copied(), Some(10));
        assert_eq!(m.get(&2).copied(), Some(20));

        // Overwriting returns the previous pair.
        let replaced = m.insert(&1, 99).unwrap();
        assert_eq!(replaced.old_key, 1);
        assert_eq!(replaced.old_value, 10);
        assert_eq!(m.get(&1).copied(), Some(99));

        assert_internal_invariants(&m);
    }

    #[test]
    fn swap_remove_last_and_middle() {
        let mut m = M::new();
        for i in 0..10 {
            m.insert(&i, (i as i64) * 10);
        }

        // Remove the last entry.
        let removed_last = m.swap_remove(&9);
        assert_eq!(removed_last, Some(90));
        assert!(m.get(&9).is_none());

        // Remove an entry from the middle.
        let removed_middle = m.swap_remove(&3);
        assert_eq!(removed_middle, Some(30));
        assert!(m.get(&3).is_none());

        assert_internal_invariants(&m);
    }

    #[test]
    fn entry_or_insert_with_works() {
        let mut m = M::new();

        let first = m.entry_mut(&7).or_insert_with(|| 123);
        assert_eq!(*first, 123);

        // The second call hands back the existing value.
        let second = m.entry_mut(&7).or_insert_with(|| 999);
        assert_eq!(*second, 123);

        assert_internal_invariants(&m);
    }

    #[test]
    fn compare_with_std_hashmap_small_scripted() {
        let mut m = M::new();
        let mut o = HashMap::<u64, i64>::new();

        // A fixed scenario mixing the different operations.
        for i in 0..50u64 {
            m.insert(&i, i as i64);
            o.insert(i, i as i64);
        }

        for i in (0..50u64).filter(|i| i % 3 == 0) {
            let a = m.swap_remove(&i);
            let b = o.remove(&i);
            assert_eq!(a, b);
        }

        for i in (0..50u64).filter(|i| i % 5 == 0) {
            m.insert(&i, (i as i64) * 100);
            o.insert(i, (i as i64) * 100);
        }

        assert_internal_invariants(&m);
        assert_equals_oracle(&m, &o);
    }

    #[test]
    fn randomized_ops_compare_with_oracle() {
        use rand::{rngs::StdRng, Rng, SeedableRng};

        let mut rng = StdRng::seed_from_u64(0xC0FFEE);
        let mut m = M::new();
        let mut o = HashMap::<u64, i64>::new();

        // Small key space so collisions and removals are exercised reasonably often.
        const STEPS: usize = 30_000;
        const KEY_SPACE: u64 = 2_000;

        for _ in 0..STEPS {
            let op = rng.gen_range(0..100);
            let k = rng.gen_range(0..KEY_SPACE);
            match op {
                // insert (the most frequent operation)
                0..=59 => {
                    let v = rng.gen_range(-1_000_000..=1_000_000);
                    match (m.insert(&k, v), o.insert(k, v)) {
                        (None, None) => {}
                        (Some(ir), Some(old)) => {
                            assert_eq!(ir.old_key, k);
                            assert_eq!(ir.old_value, old);
                        }
                        _ => panic!("insert mismatch"),
                    }
                }
                // swap_remove
                60..=79 => {
                    let a = m.swap_remove(&k);
                    let b = o.remove(&k);
                    assert_eq!(a, b);
                }
                // get
                80..=94 => {
                    let a = m.get(&k).copied();
                    let b = o.get(&k).copied();
                    assert_eq!(a, b);
                }
                // contains
                _ => {
                    let a = m.contains_key(&k);
                    let b = o.contains_key(&k);
                    assert_eq!(a, b);
                }
            }

            // Occasionally verify internal consistency (expensive, so sampled).
            if rng.gen_ratio(1, 200) {
                assert_internal_invariants(&m);
                assert_equals_oracle(&m, &o);
            }
        }

        // Must always agree at the end.
        assert_internal_invariants(&m);
        assert_equals_oracle(&m, &o);
    }

    #[test]
    fn empty_map_basics() {
        let m = M::new();

        assert_eq!(m.len(), 0);
        assert!(m.get(&123).is_none());
        assert!(!m.contains_key(&123));
        // The length invariant holds for an empty map as well.
        assert_eq!(m.values.len(), 0);
        assert_eq!(m.keys.len(), 0);
        assert_eq!(m.hashes.len(), 0);
    }

    #[test]
    fn swap_remove_single_element_roundtrip() {
        let mut m = M::new();
        m.insert(&42, -7);
        assert_internal_invariants(&m);

        let removed = m.swap_remove(&42);
        assert_eq!(removed, Some(-7));
        assert_eq!(m.len(), 0);
        assert!(m.get(&42).is_none());
        assert!(!m.contains_key(&42));

        assert_internal_invariants(&m);
    }

    #[test]
    fn remove_then_reinsert_same_key() {
        let mut m = M::new();

        for (key, value) in [(1u64, 10i64), (2, 20), (3, 30)] {
            m.insert(&key, value);
        }
        assert_internal_invariants(&m);

        assert_eq!(m.swap_remove(&2), Some(20));
        assert!(m.get(&2).is_none());
        assert_internal_invariants(&m);

        // Re-inserting the same key must not corrupt the table.
        assert_eq!(m.insert(&2, 200), None);
        assert_eq!(m.get(&2).copied(), Some(200));
        assert_internal_invariants(&m);
    }

    #[test]
    fn from_kv_vec_builds_valid_map() {
        let keys = vec![1u64, 2u64, 3u64, 10u64];
        let values = vec![10i64, 20i64, 30i64, 100i64];

        let m = M::from_kv_vec(keys.clone(), values.clone());
        assert_eq!(m.len(), 4);

        // Order and contents match the inputs.
        assert_eq!(m.keys, keys);
        assert_eq!(m.values, values);

        assert_internal_invariants(&m);
    }

    #[test]
    fn iter_order_matches_internal_storage_even_after_removes() {
        let mut m = M::new();
        for i in 0..8u64 {
            m.insert(&i, (i as i64) + 100);
        }
        assert_internal_invariants(&m);

        // Remove a few entries: the storage order changes, but `iter` must still
        // agree with `keys` / `values`.
        assert_eq!(m.swap_remove(&0), Some(100));
        assert_eq!(m.swap_remove(&5), Some(105));
        assert_internal_invariants(&m);

        let collected: Vec<(u64, i64)> = m.iter().map(|(k, v)| (*k, *v)).collect();
        let expected: Vec<(u64, i64)> = m
            .keys
            .iter()
            .zip(m.values.iter())
            .map(|(k, v)| (*k, *v))
            .collect();
        assert_eq!(collected, expected);
    }
}