Skip to main content

tf_idf_vectorizer/vectorizer/
mod.rs

1pub mod corpus;
2pub mod tfidf;
3pub mod term;
4pub mod serde;
5pub mod evaluate;
6
7use std::cmp::Ordering;
8use std::sync::Arc;
9use std::hash::Hash;
10
11use half::f16;
12use num_traits::Num;
13
14use crate::utils::datastruct::map::index_map::{EntryMut, InsertResult, RemoveResult};
15use crate::utils::datastruct::vector::{TFVector, TFVectorTrait, IDFVector};
16use crate::{DefaultTFIDFEngine, TFIDFEngine, TermFrequency};
17use crate::utils::datastruct::map::IndexMap;
18use crate::Corpus;
19
20/// TF-IDF Vectorizer
21///
22/// The top-level struct of this crate, providing the main TF-IDF vectorizer features.
23///
24/// It converts a document collection into TF-IDF vectors and supports similarity
25/// computation and search functionality.
26///
27/// ### Internals
28/// - Corpus vocabulary
29/// - Sparse TF vectors per document
30/// - term index mapping
31/// - Cached IDF vector
32/// - Pluggable TF-IDF engine
33/// - Inverted document index
34///
35/// ### Type Parameters
36/// - `N`: Vector parameter type (e.g., `f32`, `f64`, `u16`)
37/// - `K`: Document key type (e.g., `String`, `usize`)
38/// - `E`: TF-IDF calculation engine
39///
40/// ### Notes
41/// - Requires an `Arc<Corpus>` on construction
42/// - `Corpus` can be shared across multiple vectorizers
43///
44/// ### Serialization
45/// Supported.  
46/// Serialized data includes the `Corpus` reference.
47///
48/// For corpus-independent storage, use [`TFIDFData`].
49#[derive(Debug, Clone)]
50pub struct TFIDFVectorizer<N = f16, K = String, E = DefaultTFIDFEngine>
51where
52    N: Num + Copy + Into<f64> + Send + Sync,
53    E: TFIDFEngine<N> + Send + Sync,
54    K: Clone + Send + Sync + Eq + std::hash::Hash,
55{
56    /// Document's TF Vector
57    pub documents: IndexMap<K, TFVector<N>>,
58    /// TF Vector's term dimension sample and reverse index
59    /// Key is never changed and unused terms are not removed
60    pub term_dim_rev_index: IndexMap<Box<str>, Vec<u32>>,
61    /// Corpus reference
62    pub corpus_ref: Arc<Corpus>,
63    /// IDF Vector
64    pub idf_cache: IDFVector,
65    _marker: std::marker::PhantomData<E>,
66}
67
68impl <N, K, E> TFIDFVectorizer<N, K, E>
69where
70    N: Num + Copy + Into<f64> + Send + Sync,
71    E: TFIDFEngine<N> + Send + Sync,
72    K: Clone + Send + Sync + Eq + Hash,
73{
74    /// Create a new TFIDFVectorizer instance
75    pub fn new(corpus_ref: Arc<Corpus>) -> Self {
76        let mut instance = Self {
77            documents: IndexMap::new(),
78            term_dim_rev_index: IndexMap::new(),
79            corpus_ref,
80            idf_cache: IDFVector::new(),
81            _marker: std::marker::PhantomData,
82        };
83        instance.re_calc_idf();
84        instance
85    }
86
87    /// set corpus reference
88    /// and recalculate idf
89    pub fn set_corpus_ref(&mut self, corpus_ref: Arc<Corpus>) {
90        self.corpus_ref = corpus_ref;
91        self.re_calc_idf();
92    }
93
94    /// Corpusに変更があればIDFを再計算する
95    pub fn update_idf(&mut self) {
96        if self.corpus_ref.get_gen_num() != self.idf_cache.latest_entropy {
97            self.re_calc_idf();
98        }
99        // 更新がなければ何もしない
100    }
101
102    /// CorpusからIDFを再計算する
103    fn re_calc_idf(&mut self) {
104        self.idf_cache.latest_entropy = self.corpus_ref.get_gen_num();
105        self.idf_cache.doc_num = self.corpus_ref.get_doc_num();
106        self.idf_cache.idf_vec = E::idf_vec(&self.corpus_ref, self.term_dim_rev_index.keys());
107    }
108}
109
110impl <N, K, E> TFIDFVectorizer<N, K, E>
111where
112    N: Num + Copy + Into<f64> + Send + Sync,
113    E: TFIDFEngine<N> + Send + Sync,
114    K: PartialEq + Clone + Send + Sync + Eq + Hash
115{
116    /// Add a document
117    /// The immediately referenced Corpus is also updated
118    pub fn add_doc(&mut self, key: K, doc: &TermFrequency) {
119        // 新語彙を追加 (O(|doc_vocab|))
120        for tok in doc.term_set(){
121            self.term_dim_rev_index
122            .entry_mut(tok.into_boxed_str())
123                .or_insert_with(Vec::new);
124        }
125        let tf_vec= E::tf_vec(doc, self.term_dim_rev_index.as_index_set());
126        let (new_terms, old_terms) = self.add_tf_vec(key, tf_vec);
127        // コーパスも更新
128        if old_terms.is_empty() && new_terms.is_empty() {
129            // 何も変わっていなければ何もしない
130            return;
131        }
132        if old_terms.is_empty() {
133            // 追加のみ
134            let add_terms: Vec<&Box<str>> = new_terms.iter()
135                .filter_map(|&idx| self.term_dim_rev_index.get_key_with_index(idx as usize))
136                .collect();
137            self.corpus_ref.add_set(&add_terms);
138            return;
139        }
140        let mut new_terms_iter = new_terms.into_iter().fuse();
141        let mut old_terms_iter = old_terms.into_iter().fuse();
142        let mut new_term_next = new_terms_iter.next();
143        let mut old_term_next = old_terms_iter.next();
144        let mut add_terms = Vec::new();
145        let mut del_terms = Vec::new();
146        while let (Some(new_idx), Some(old_idx)) = (new_term_next, old_term_next) {
147            match new_idx.cmp(&old_idx) {
148                Ordering::Less => {
149                    // new にのみ存在 -> 追加
150                    let term = self.term_dim_rev_index.get_key_with_index(new_idx as usize).expect("unreachable");
151                    add_terms.push(term);
152                    new_term_next = new_terms_iter.next();
153                }
154                Ordering::Greater => {
155                    // old にのみ存在 -> 削除
156                    let term = self.term_dim_rev_index.get_key_with_index(old_idx as usize).expect("unreachable");
157                    del_terms.push(term);
158                    old_term_next = old_terms_iter.next();
159                }
160                Ordering::Equal => {
161                    // 両方に存在 -> 何もしない
162                    new_term_next = new_terms_iter.next();
163                    old_term_next = old_terms_iter.next();
164                }
165            }
166        }
167        while let Some(new_idx) = new_term_next {
168            // new にのみ存在 -> 追加
169            let term = self.term_dim_rev_index.get_key_with_index(new_idx as usize).expect("unreachable");
170            add_terms.push(term);
171            new_term_next = new_terms_iter.next();
172        }
173        while let Some(old_idx) = old_term_next {
174            // old にのみ存在 -> 削除
175            let term = self.term_dim_rev_index.get_key_with_index(old_idx as usize).expect("unreachable");
176            del_terms.push(term);
177            old_term_next = old_terms_iter.next();
178        }
179        self.corpus_ref.add_set(&add_terms);
180        self.corpus_ref.sub_set(&del_terms);
181    }
182
183    fn add_tf_vec(&mut self, key: K, tf_vec: TFVector<N>) -> (Vec<u32>, Vec<u32>) {
184        let new_tf_terms_ind: Vec<u32> = tf_vec.as_ind_slice().to_vec();
185        match self.documents.insert(key, tf_vec) {
186            InsertResult::New { index: id } => { 
187                self.documents.get_with_index(id).expect("unreachable").as_ind_slice().iter().for_each(|&idx| {
188                    self.term_dim_rev_index.get_with_index_mut(idx as usize).expect("unreachable").push(id as u32);
189                });
190                (new_tf_terms_ind, Vec::new())
191            }
192            InsertResult::Override { old_value: old_tf, old_key: _, index: id } => {
193                let old_tf_ind_iter = old_tf.as_ind_slice().iter();
194                let new_tf_ind_iter = self.documents.get_with_index(id).expect("unreachable").as_ind_slice().iter();
195                let mut old_it = old_tf_ind_iter.fuse();
196                let mut new_it = new_tf_ind_iter.fuse();
197                let mut old_next = old_it.next();
198                let mut new_next = new_it.next();
199                while let (Some(old_idx), Some(new_idx)) = (old_next, new_next) {
200                    match old_idx.cmp(new_idx) {
201                        Ordering::Equal => {
202                            // 両方に存在 -> 何もしない
203                            old_next = old_it.next();
204                            new_next = new_it.next();
205                        }
206                        Ordering::Less => {
207                            // old にのみ存在 -> 削除
208                            let doc_keys = self.term_dim_rev_index.get_with_index_mut(*old_idx as usize).expect("unreachable");
209                            doc_keys.iter().position(|k| *k == id as u32).map(|pos| {
210                                doc_keys.swap_remove(pos);
211                            });
212                            old_next = old_it.next();
213                        }
214                        Ordering::Greater => {
215                            // new にのみ存在 -> 追加
216                            let doc_keys = self.term_dim_rev_index.get_with_index_mut(*new_idx as usize).expect("unreachable");
217                            doc_keys.push(id as u32);
218                            new_next = new_it.next();
219                        }
220                    }
221                }
222                (new_tf_terms_ind, old_tf.as_ind_slice().to_vec())
223            }
224        }
225    }
226
227    pub fn del_doc(&mut self, key: &K)
228    where
229        K: PartialEq,
230    {
231        match self.documents.swap_remove(key) {
232            RemoveResult::Removed { old_value: tf_vec, old_key: _, index: id } => {
233                // 逆Indexから削除
234                let terms_idx = tf_vec.as_ind_slice();
235                terms_idx.iter().for_each(|&idx| {
236                    let doc_keys = self.term_dim_rev_index.get_with_index_mut(idx as usize).expect("unreachable");
237                    doc_keys.iter().position(|k| *k == id as u32).map(|pos| {
238                        doc_keys.swap_remove(pos);
239                    });
240                });
241                // swap したdocにおいて逆IndexのIDを書き換え
242                let swap_doc_id = self.documents.len() as u32;
243                if swap_doc_id != id as u32 {
244                    self.documents.get_with_index(id).expect("unreachable").as_ind_slice().iter().for_each(|&idx| {
245                        let doc_keys = self.term_dim_rev_index.get_with_index_mut(idx as usize).expect("unreachable");
246                        doc_keys.iter().position(|k| *k == swap_doc_id).map(|pos| {
247                            doc_keys[pos] = id as u32;
248                        });
249                    });
250                }
251                // コーパスからも削除
252                let terms = terms_idx.iter()
253                    .filter_map(|&idx| self.term_dim_rev_index.get_key_with_index(idx as usize))
254                    .collect::<Vec<&Box<str>>>();
255                self.corpus_ref.sub_set(&terms);
256            }
257            RemoveResult::None => {}
258        }
259    }
260
261    /// Get TFVector by document ID
262    pub fn get_tf(&self, key: &K) -> Option<&TFVector<N>>
263    where
264        K: Eq + Hash,
265    {
266        self.documents.get(key)
267    }
268
269    /// Get TermFrequency by document ID
270    /// If quantized, there may be some error
271    /// Words not included in the corpus are ignored
272    pub fn get_tf_into_term_freq(&self, key: &K) -> Option<TermFrequency>
273    {
274        if let Some(tf_vec) = self.get_tf(key) {
275            let mut term_freq = TermFrequency::new();
276            tf_vec.raw_iter().for_each(|(idx, val)| {
277                let idx = idx as usize;
278                if let Some(term) = self.term_dim_rev_index.get_key_with_index(idx) {
279                    let term_num = E::tf_denorm(val);
280                    term_freq.set_term_count(term, term_num as u64);
281                } // out of range is ignored
282            });
283            Some(term_freq)
284        } else {
285            None
286        }
287    }
288
289    /// Check if a document with the given ID exists
290    pub fn contains_doc(&self, key: &K) -> bool
291    where
292        K: PartialEq,
293    {
294        self.documents.contains_key(key)
295    }
296
297    /// Check if the term exists in the term dimension sample
298    pub fn contains_term(&self, term: &str) -> bool {
299        self.term_dim_rev_index.contains_key(&Box::<str>::from(term))
300    }
301
302    /// Check if all terms in the given TermFrequency exist in the term dimension sample
303    pub fn contains_terms_from_freq(&self, freq: &TermFrequency) -> bool {
304        freq.term_set_ref_str().iter().all(|tok| self.contains_term(tok))
305    }
306
307    pub fn doc_num(&self) -> usize {
308        self.documents.len()
309    }
310
311    /// Merge another TFIDFVectorizer into this one
312    pub fn merge(&mut self, other: Self)
313    where
314        K: Eq + Hash,
315    {
316        // termの追加と置換行列を作成
317        let perm_idxs: Vec<u32> = other.term_dim_rev_index.into_iter().map(|(term, _)| {
318            match self.term_dim_rev_index.entry_mut(term) {
319                EntryMut::Occupied { index, ..} => index as u32,
320                EntryMut::Vacant { key, map } => {
321                    match map.insert(key, Vec::new()) {
322                        InsertResult::New { index } => index as u32,
323                        InsertResult::Override { .. } => unreachable!(),
324                    }
325                },
326            }
327        }).collect();
328        // documents のマージ
329        other.documents.into_iter().for_each(|(key, mut tf_vec)| {
330            tf_vec.perm(&perm_idxs);
331            let (_, old_tf_terms_ind) = self.add_tf_vec(key, tf_vec);
332            // コーパスも更新
333            let del_terms = old_tf_terms_ind.into_iter().map(|old_idx| {
334                self.term_dim_rev_index.get_key_with_index(old_idx as usize).expect("unreachable")
335            }).collect::<Vec<&Box<str>>>();
336            self.corpus_ref.sub_set(&del_terms);
337        });
338    }
339}