tf_idf_vectorizer/utils/datastruct/map/
mod.rs

1use hashbrown::HashTable;
2use std::borrow::Borrow;
3use std::fmt::Debug;
4use std::hash::Hasher;
5
6pub mod serde;
7
8/// IndexMap
9/// 連続領域を保証するHashMap
10/// 
11/// # Safety
12/// table, hashes は table_* メソッドが責任をもつこと 更新とか
13/// 
14/// いじんないじんないじんな いじったならあらゆるUnitTest書いて通せ
15#[derive(Clone, Debug)]
16pub struct IndexMap<K, V, S = ahash::RandomState> {
17    values: Vec<V>,
18    index_set: IndexSet<K, S>,
19}
20
21impl<K, V, S> IndexMap<K, V, S>
22where
23    K: Eq + std::hash::Hash + Clone,
24    S: std::hash::BuildHasher,
25{
26    pub fn with_hasher(hash_builder: S) -> Self {
27        IndexMap {
28            values: Vec::new(),
29            index_set: IndexSet::with_hasher(hash_builder),
30        }
31    }
32
33    pub fn with_capacity(capacity: usize) -> Self
34    where
35        S: Default,
36    {
37        IndexMap {
38            values: Vec::with_capacity(capacity),
39            index_set: IndexSet::with_capacity(capacity),
40        }
41    }
42
43    pub fn new() -> Self
44    where
45        S: Default,
46    {
47        IndexMap {
48            values: Vec::new(),
49            index_set: IndexSet::new(),
50        }
51    }
52
53    pub fn with_hasher_capacity(hash_builder: S, capacity: usize) -> Self {
54        IndexMap {
55            values: Vec::with_capacity(capacity),
56            index_set: IndexSet::with_hasher_capacity(hash_builder, capacity),
57        }
58    }
59
60    pub fn len(&self) -> usize {
61        self.values.len()
62    }
63
64    pub fn iter_values(&self) -> std::slice::Iter<'_, V> {
65        self.values.iter()
66    }
67
68    pub fn iter_keys(&self) -> std::slice::Iter<'_, K> {
69        self.index_set.keys.iter()
70    }
71
72    pub fn iter(&self) -> IndexMapIter<'_, K, V, S> {
73        IndexMapIter {
74            map: self,
75            index: 0,
76        }
77    }
78
79    pub fn values(&self) -> &Vec<V> {
80        &self.values
81    }
82
83    pub fn keys(&self) -> &Vec<K> {
84        &self.index_set.keys()
85    }
86
87    pub fn as_index_set(&self) -> &IndexSet<K, S> {
88        &self.index_set
89    }
90
91    pub fn get<Q: ?Sized>(&self, key: &Q) -> Option<&V> 
92    where 
93        K: Borrow<Q>,
94        Q: std::hash::Hash + Eq,
95    {
96        if let Some(idx) = self.index_set.table_get(key) {
97            unsafe {
98                Some(self.values.get_unchecked(idx))
99            }
100        } else {
101            None
102        }
103    }
104
105    pub fn get_mut<Q: ?Sized>(&mut self, key: &Q) -> Option<&mut V> 
106    where 
107        K: Borrow<Q>,
108        Q: std::hash::Hash + Eq,
109    {
110        if let Some(idx) = self.index_set.table_get(key) {
111            unsafe {
112                Some(self.values.get_unchecked_mut(idx))
113            }
114        } else {
115            None
116        }
117    }
118
119    pub fn get_with_index(&self, index: usize) -> Option<&V> {
120        self.values.get(index)
121    }
122
123    pub fn get_with_index_mut(&mut self, index: usize) -> Option<&mut V> {
124        self.values.get_mut(index)
125    }
126
127    pub fn get_key_with_index(&self, index: usize) -> Option<&K> {
128        self.index_set.get_with_index(index)
129    }
130
131    pub fn get_key_value_with_index(&self, index: usize) -> Option<(&K, &V)> {
132        if index < self.len() {
133            unsafe {
134                Some((
135                    self.index_set.keys.get_unchecked(index),
136                    self.values.get_unchecked(index),
137                ))
138            }
139        } else {
140            None
141        }
142    }
143
144    pub fn get_index<Q: ?Sized>(&self, key: &Q) -> Option<usize> 
145    where 
146        K: Borrow<Q>,
147        Q: std::hash::Hash + Eq,
148    {
149        self.index_set.table_get(key)
150    }
151
152    pub fn contains_key<Q: ?Sized>(&self, key: &Q) -> bool 
153    where 
154        K: Borrow<Q>,
155        Q: std::hash::Hash + Eq,
156    {
157        self.index_set.contains_key(key)
158    }
159
160    pub fn insert(&mut self, key: K, value: V) -> Option<InsertResult<K, V>> {
161        if let Some(idx) = self.index_set.table_get(&key) {
162            // K が Rc の場合を考慮して すべて差し替える
163            unsafe {
164                self.index_set.table_override(&key, &idx);
165            }
166            let old_value = Some(std::mem::replace(&mut self.values[idx], value));
167            let old_key = Some(std::mem::replace(&mut self.index_set.keys[idx], key.clone()));
168            Some(InsertResult {
169                old_value: old_value.unwrap(),
170                old_key:  old_key.unwrap(),
171            })
172        } else {
173            // New key, insert entry
174            let idx = self.values.len();
175            unsafe {
176                self.index_set.table_append(&key, &idx);
177            }
178            self.index_set.keys.push(key.clone());
179            self.values.push(value);
180            None
181        }
182    }
183
184    pub fn entry_mut<'a>(&'a mut self, key: K) -> EntryMut<'a, K, V, S> {
185        if let Some(idx) = self.index_set.table_get(&key) {
186            unsafe {
187                EntryMut::Occupied {
188                    key: key,
189                    value: self.values.get_unchecked_mut(idx),
190                    index: idx,
191                }
192            }
193        } else {
194            EntryMut::Vacant { key , map: self }
195        }
196    }
197
198    pub fn swap_remove<Q: ?Sized>(&mut self, key: &Q) -> Option<V> 
199    where 
200        K: Borrow<Q>,
201        Q: std::hash::Hash + Eq,
202    {
203        if let Some(idx) = self.index_set.table_get(key) {
204            let last_idx = self.values.len() - 1;
205            if idx == last_idx {
206                // 最後の要素を削除する場合
207                unsafe {
208                    self.index_set.table_swap_remove(key);
209                }
210                self.index_set.keys.pop();
211                return Some(self.values.pop().unwrap());
212            } else {
213                let last_idx_key = self.index_set.keys[last_idx].clone();
214                unsafe {
215                    // keyとの整合性があるうちに削除予定のをtableから消す ここでhashesがswap_removeされる
216                    // last_idxの要素がswapで移動してくる
217                    self.index_set.table_swap_remove(key);
218                    // 移動させられた要素のtableを再登録
219                    // 登録されていた前のidxに対するkeyはまだ整合性が取れているので問題ない
220                    self.index_set.table_override(last_idx_key.borrow(), &idx);
221                }
222                // swap_remove ここで実際にtableのidxとvalues, keys, hashesの整合性が回復
223                let value = self.values.swap_remove(idx);
224                self.index_set.keys.swap_remove(idx);
225                Some(value)
226            }
227        } else {
228            None
229        }
230    }
231
232    pub fn from_kv_vec(k_vec: Vec<K>, v_vec: Vec<V>) -> Self
233    where
234        S: std::hash::BuildHasher + Default,
235    {
236        let hash_builder = S::default();
237        let mut map = IndexMap::with_hasher(hash_builder);
238        for (k, v) in k_vec.into_iter().zip(v_vec.into_iter()) {
239            let idx = map.values.len();
240            unsafe {
241                map.index_set.table_append(&k, &idx);
242            }
243            map.index_set.keys.push(k);
244            map.values.push(v);
245        }
246        map
247    }
248}
249
250pub struct IndexMapIter<'a, K, V, S> {
251    pub map: &'a IndexMap<K, V, S>,
252    pub index: usize,
253}
254
255impl <'a, K, V, S> Iterator for IndexMapIter<'a, K, V, S> 
256where 
257    K: Eq + std::hash::Hash + Clone,
258    S: std::hash::BuildHasher,
259{
260    type Item = (&'a K, &'a V);
261
262    fn next(&mut self) -> Option<Self::Item> {
263        if self.index < self.map.len() {
264            unsafe {
265                let k = self.map.index_set.keys.get_unchecked(self.index);
266                let v = self.map.values.get_unchecked(self.index);
267                self.index += 1;
268                Some((k, v))
269            }
270        } else {
271            None
272        }
273    } 
274    
275    fn size_hint(&self) -> (usize, Option<usize>) {
276        let remaining = self.map.len() - self.index;
277        (remaining, Some(remaining))
278    }
279
280    fn nth(&mut self, n: usize) -> Option<Self::Item> {
281        self.index += n;
282        self.next()
283    }
284}
285
286pub enum EntryMut<'a, K, V, S> {
287    Occupied { key: K, value: &'a mut V, index: usize },
288    Vacant { key: K , map: &'a mut IndexMap<K, V, S> },
289}
290
291impl<'a, K, V, S> EntryMut<'a, K, V, S>
292where 
293    K: Eq + std::hash::Hash + Clone,
294    S: std::hash::BuildHasher,
295{
296    pub fn is_occupied(&self) -> bool {
297        matches!(self, EntryMut::Occupied { .. })
298    }
299
300    pub fn or_insert_with<F>(self, value: F) -> &'a mut V
301    where
302        F: FnOnce() -> V,
303        K: Clone,
304    {
305        match self {
306            EntryMut::Occupied { value: v, .. } => v,
307            EntryMut::Vacant { key, map } => {
308                map.insert(key.clone(), value());
309                map.get_mut(&key).unwrap()
310            }
311        }
312    }
313}
314
/// Previous entry displaced by [`IndexMap::insert`] when the key was already
/// present.
#[derive(Debug, PartialEq)]
pub struct InsertResult<K, V> {
    /// Value that was stored under the key before the insert.
    pub old_value: V,
    /// Key object that was stored before the insert.
    pub old_key: K,
}
320
321#[derive(Clone, Debug)]
322pub struct IndexSet<K, S = ahash::RandomState> {
323    keys: Vec<K>,
324    hashes: Vec<u64>,
325    table: HashTable<usize>,
326    hash_builder: S,
327}
328
329impl<K, S> IndexSet<K, S>
330where
331    K: Eq + std::hash::Hash + Clone,
332    S: std::hash::BuildHasher,
333{
334    pub fn with_hasher(hash_builder: S) -> Self {
335        IndexSet {
336            keys: Vec::new(),
337            hashes: Vec::new(),
338            table: HashTable::new(),
339            hash_builder,
340        }
341    }
342
343    pub fn new() -> Self
344    where
345        S: Default,
346    {
347        IndexSet {
348            keys: Vec::new(),
349            hashes: Vec::new(),
350            table: HashTable::new(),
351            hash_builder: S::default(),
352        }
353    }
354
355    pub fn with_capacity(capacity: usize) -> Self
356    where
357        S: Default,
358    {
359        IndexSet {
360            keys: Vec::with_capacity(capacity),
361            hashes: Vec::with_capacity(capacity),
362            table: HashTable::with_capacity(capacity),
363            hash_builder: S::default(),
364        }
365    }
366
367    pub fn with_hasher_capacity(hash_builder: S, capacity: usize) -> Self {
368        IndexSet {
369            keys: Vec::with_capacity(capacity),
370            hashes: Vec::with_capacity(capacity),
371            table: HashTable::with_capacity(capacity),
372            hash_builder,
373        }
374    }
375
376    /// hash util
377    fn hash_key<Q: ?Sized>(&self, key: &Q) -> u64 
378    where 
379        K: Borrow<Q>,
380        Q: std::hash::Hash + Eq,
381    {
382        let mut hasher = self.hash_builder.build_hasher();
383        key.hash(&mut hasher);
384        hasher.finish()
385    }
386
387    /// override
388    /// 完全な整合性が必要
389    /// keyに対するidxを更新し、更新したidxを返す
390    /// 存在しない場合はNone
391    unsafe fn table_override<Q: ?Sized>(&mut self, key: &Q, idx: &usize) -> Option<usize> 
392    where 
393        K: Borrow<Q>,
394        Q: std::hash::Hash + Eq,
395    {
396        let hash = self.hash_key(key);
397        match self.table.find_entry(hash, |&i| self.keys[i].borrow() == key) {
398            Ok(mut occ) => {
399                // idxの上書きだけ
400                *occ.get_mut() = *idx;
401                Some(*idx)
402            }
403            Err(_) => {
404                None
405            }
406        }
407    }
408
409    /// append
410    /// 完全な整合性が必要
411    /// hashesとtableを更新する
412    unsafe fn table_append<Q: ?Sized>(&mut self, key: &Q, idx: &usize) 
413    where 
414        K: Borrow<Q>,
415        Q: std::hash::Hash + Eq,
416    {
417        let hash = self.hash_key(key);
418        self.hashes.push(hash);
419        self.table.insert_unique(
420            hash,
421            *idx,
422            |&i| self.hashes[i]
423        );
424    }
425
426    /// get
427    /// とくに注意なし 不可変参照なので
428    fn table_get<Q: ?Sized>(&self, key: &Q) -> Option<usize> 
429    where 
430        K: Borrow<Q>,
431        Q: std::hash::Hash + Eq,
432    {
433        let hash = self.hash_key(key);
434        self.table.find(
435            hash, 
436            |&i| self.keys[i].borrow() == key
437        ).copied()
438    }
439
440    /// remove
441    /// 完全な整合性が必要
442    /// hashesはswap_removeされます
443    unsafe fn table_swap_remove<Q: ?Sized>(&mut self, key: &Q) -> Option<usize> 
444    where 
445        K: Borrow<Q>,
446        Q: std::hash::Hash + Eq,
447    {
448        let hash = self.hash_key(key);
449        if let Ok(entry) = self.table.find_entry(
450            hash,
451            |&i| self.keys[i].borrow() == key
452        ) {
453            let (odl_idx, _) = entry.remove();
454            self.hashes.swap_remove(odl_idx);
455            Some(odl_idx)
456        } else {
457            None
458        }
459    }
460
461    pub fn contains_key<Q: ?Sized>(&self, key: &Q) -> bool 
462    where 
463        K: Borrow<Q>,
464        Q: std::hash::Hash + Eq,
465    {
466        self.table_get(key).is_some()
467    }
468
469    pub fn len(&self) -> usize {
470        self.keys.len()
471    }
472
473    pub fn get_index<Q: ?Sized>(&self, key: &Q) -> Option<usize> 
474    where 
475        K: Borrow<Q>,
476        Q: std::hash::Hash + Eq,
477    {
478        self.table_get(key)
479    }
480
481    pub fn iter(&self) -> std::slice::Iter<'_, K> {
482        self.keys.iter()
483    }
484
485    pub fn keys(&self) -> &Vec<K> {
486        &self.keys
487    }
488
489    pub fn get_with_index(&self, index: usize) -> Option<&K> {
490        self.keys.get(index)
491    }
492}
493
494
495
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    // Start with key/value types that are easy to compare.
    type M = IndexMap<u64, i64>;

    /// Asserts that `values`, `keys`, `hashes` and the hash table of `map`
    /// agree with each other.
    fn assert_internal_invariants(map: &M) {
        // All parallel vectors must have the same length.
        assert_eq!(map.values.len(), map.index_set.keys.len(), "values/keys len mismatch");
        assert_eq!(map.values.len(), map.index_set.hashes.len(), "values/hashes len mismatch");

        // The idx returned by table_get must match the key's position in keys/values.
        for (i, k) in map.index_set.keys.iter().enumerate() {
            let idx = map.index_set.table_get(k).expect("table_get must find existing key");
            assert_eq!(idx, i, "table idx mismatch for key");
        }

        // Check the reverse direction as well:
        // contains/get must be consistent for every stored key.
        for i in 0..map.len() {
            let k = &map.index_set.keys[i];
            assert!(map.contains_key(k), "contains_key false for existing key");
            let v = map.get(k).expect("get must return for existing key");
            assert_eq!(*v, map.values[i], "get value mismatch");
        }

        // There must be no duplicate keys.
        // O(n^2), but that is fine for a unit test.
        for i in 0..map.index_set.keys.len() {
            for j in (i + 1)..map.index_set.keys.len() {
                assert!(map.index_set.keys[i] != map.index_set.keys[j], "duplicate keys detected");
            }
        }
    }

    /// Asserts that `map` holds exactly the same entries as the `HashMap` oracle.
    fn assert_equals_oracle(map: &M, oracle: &HashMap<u64, i64>) {
        assert_eq!(map.len(), oracle.len(), "len mismatch");

        // Every oracle key must be present with the same value.
        for (k, v) in oracle.iter() {
            let got = map.get(k).copied();
            assert_eq!(got, Some(*v), "value mismatch for key={k}");
        }

        // The map must not hold extra keys (checked against the oracle).
        for (k, v) in map.iter() {
            assert_eq!(oracle.get(k).copied(), Some(*v), "extra/mismatch entry key={k}");
        }
    }

    #[test]
    fn basic_insert_get_overwrite() {
        let mut m = M::new();

        assert_eq!(m.insert(1, 10), None);
        assert_eq!(m.insert(2, 20), None);
        assert_eq!(m.get(&1).copied(), Some(10));
        assert_eq!(m.get(&2).copied(), Some(20));

        // overwrite
        let old = m.insert(1, 99).unwrap();
        assert_eq!(old.old_key, 1);
        assert_eq!(old.old_value, 10);
        assert_eq!(m.get(&1).copied(), Some(99));

        assert_internal_invariants(&m);
    }

    #[test]
    fn swap_remove_last_and_middle() {
        let mut m = M::new();
        for i in 0..10 {
            m.insert(i, (i as i64) * 10);
        }

        // last remove
        let v = m.swap_remove(&9);
        assert_eq!(v, Some(90));
        assert!(m.get(&9).is_none());

        // middle remove
        let v = m.swap_remove(&3);
        assert_eq!(v, Some(30));
        assert!(m.get(&3).is_none());

        assert_internal_invariants(&m);
    }

    #[test]
    fn entry_or_insert_with_works() {
        let mut m = M::new();

        let v = m.entry_mut(7).or_insert_with(|| 123);
        assert_eq!(*v, 123);

        // The second call returns a reference to the existing value.
        let v2 = m.entry_mut(7).or_insert_with(|| 999);
        assert_eq!(*v2, 123);

        assert_internal_invariants(&m);
    }

    #[test]
    fn compare_with_std_hashmap_small_scripted() {
        let mut m = M::new();
        let mut o = HashMap::<u64, i64>::new();

        // A fixed scenario that mixes several operations.
        for i in 0..50u64 {
            m.insert(i, i as i64);
            o.insert(i, i as i64);
        }

        for i in 0..50u64 {
            if i % 3 == 0 {
                let a = m.swap_remove(&i);
                let b = o.remove(&i);
                assert_eq!(a, b);
            }
        }

        for i in 0..50u64 {
            if i % 5 == 0 {
                m.insert(i, (i as i64) * 100);
                o.insert(i, (i as i64) * 100);
            }
        }

        assert_internal_invariants(&m);
        assert_equals_oracle(&m, &o);
    }

    #[test]
    fn randomized_ops_compare_with_oracle() {
        use rand::{rngs::StdRng, Rng, SeedableRng};

        let mut rng = StdRng::seed_from_u64(0xC0FFEE);
        let mut m = M::new();
        let mut o = HashMap::<u64, i64>::new();

        // Exercise a fair number of collisions and removals.
        const STEPS: usize = 30_000;
        const KEY_SPACE: u64 = 2_000;

        for _ in 0..STEPS {
            let op = rng.gen_range(0..100);
            let k = rng.gen_range(0..KEY_SPACE);
            match op {
                // insert (weighted more heavily)
                0..=59 => {
                    let v = rng.gen_range(-1_000_000..=1_000_000);
                    let a = m.insert(k, v);
                    let b = o.insert(k, v);

                    match (a, b) {
                        (None, None) => {}
                        (Some(ir), Some(old)) => {
                            assert_eq!(ir.old_key, k);
                            assert_eq!(ir.old_value, old);
                        }
                        _ => panic!("insert mismatch"),
                    }
                }
                // swap_remove
                60..=79 => {
                    let a = m.swap_remove(&k);
                    let b = o.remove(&k);
                    assert_eq!(a, b);
                }
                // get
                80..=94 => {
                    let a = m.get(&k).copied();
                    let b = o.get(&k).copied();
                    assert_eq!(a, b);
                }
                // contains
                _ => {
                    let a = m.contains_key(&k);
                    let b = o.contains_key(&k);
                    assert_eq!(a, b);
                }
            }

            // Occasionally check internal consistency (expensive, so sampled).
            if rng.gen_ratio(1, 200) {
                assert_internal_invariants(&m);
                assert_equals_oracle(&m, &o);
            }
        }

        // Must always match at the end.
        assert_internal_invariants(&m);
        assert_equals_oracle(&m, &o);
    }

    #[test]
    fn empty_map_basics() {
        let m = M::new();

        assert_eq!(m.len(), 0);
        assert!(m.get(&123).is_none());
        assert!(!m.contains_key(&123));
        // Length consistency holds even for an empty map.
        assert_eq!(m.values.len(), 0);
        assert_eq!(m.index_set.keys.len(), 0);
        assert_eq!(m.index_set.hashes.len(), 0);
    }

    #[test]
    fn swap_remove_single_element_roundtrip() {
        let mut m = M::new();
        m.insert(42, -7);
        assert_internal_invariants(&m);

        let v = m.swap_remove(&42);
        assert_eq!(v, Some(-7));
        assert_eq!(m.len(), 0);
        assert!(m.get(&42).is_none());
        assert!(!m.contains_key(&42));

        assert_internal_invariants(&m);
    }

    #[test]
    fn remove_then_reinsert_same_key() {
        let mut m = M::new();

        m.insert(1, 10);
        m.insert(2, 20);
        m.insert(3, 30);
        assert_internal_invariants(&m);

        assert_eq!(m.swap_remove(&2), Some(20));
        assert!(m.get(&2).is_none());
        assert_internal_invariants(&m);

        // Re-inserting the same key must not corrupt the table.
        assert_eq!(m.insert(2, 200), None);
        assert_eq!(m.get(&2).copied(), Some(200));
        assert_internal_invariants(&m);
    }

    #[test]
    fn from_kv_vec_builds_valid_map() {
        let keys = vec![1u64, 2u64, 3u64, 10u64];
        let values = vec![10i64, 20i64, 30i64, 100i64];

        let m = M::from_kv_vec(keys.clone(), values.clone());
        assert_eq!(m.len(), 4);

        // Order and contents must match the inputs.
        assert_eq!(m.index_set.keys, keys);
        assert_eq!(m.values, values);

        assert_internal_invariants(&m);
    }

    #[test]
    fn iter_order_matches_internal_storage_even_after_removes() {
        let mut m = M::new();
        for i in 0..8u64 {
            m.insert(i, (i as i64) + 100);
        }
        assert_internal_invariants(&m);

        // After removing a few entries, iter must still agree with keys/values
        // even though the internal order has changed.
        assert_eq!(m.swap_remove(&0), Some(100));
        assert_eq!(m.swap_remove(&5), Some(105));
        assert_internal_invariants(&m);

        let collected: Vec<(u64, i64)> = m.iter().map(|(k, v)| (*k, *v)).collect();
        let expected: Vec<(u64, i64)> = m.index_set.keys.iter().copied().zip(m.values.iter().copied()).collect();
        assert_eq!(collected, expected);
    }

    // Repeatedly takes the entry for the same key and pushes into its Vec; the
    // map must keep a single entry holding every pushed element.
    // NOTE(review): the test name says 23257 but ITER_NUM is 223259 — confirm
    // which number the original issue refers to.
    #[test]
    fn num_23257_issue() {
        const ITER_NUM: u64 = 223259;
        let mut map: IndexMap<u64, Vec<u64>> = IndexMap::new();
        for i in 0..ITER_NUM {
            map.entry_mut(0).or_insert_with(Vec::new).push(i);
        }
        assert_eq!(map.len(), 1);
        assert_eq!(map.get(&0).unwrap().len() as u64, ITER_NUM);
    }
}
783}