//! TF-IDF vectorizer module root (`tf_idf_vectorizer/vectorizer/mod.rs`).
1pub mod corpus;
2pub mod tfidf;
3pub mod term;
4pub mod serde;
5pub mod evaluate;
6
7use std::cmp::Ordering;
8use std::sync::Arc;
9use std::hash::Hash;
10
11use half::f16;
12use num_traits::Num;
13
14use crate::utils::datastruct::map::index_map::{InsertResult, RemoveResult};
15use crate::utils::datastruct::vector::{TFVector, TFVectorTrait, IDFVector};
16use crate::{DefaultTFIDFEngine, TFIDFEngine, TermFrequency};
17use crate::utils::datastruct::map::IndexMap;
18use crate::Corpus;
19
/// TF-IDF Vectorizer
///
/// The top-level struct of this crate, providing the main TF-IDF vectorizer features.
///
/// It converts a document collection into TF-IDF vectors and supports similarity
/// computation and search functionality.
///
/// ### Internals
/// - Corpus vocabulary
/// - Sparse TF vectors per document
/// - term index mapping
/// - Cached IDF vector
/// - Pluggable TF-IDF engine
/// - Inverted document index
///
/// ### Type Parameters
/// - `N`: Vector parameter type (e.g., `f32`, `f64`, `u16`)
/// - `K`: Document key type (e.g., `String`, `usize`)
/// - `E`: TF-IDF calculation engine
///
/// ### Notes
/// - Requires an `Arc<Corpus>` on construction
/// - `Corpus` can be shared across multiple vectorizers
///
/// ### Serialization
/// Supported.  
/// Serialized data includes the `Corpus` reference.
///
/// For corpus-independent storage, use [`TFIDFData`].
#[derive(Debug, Clone)]
pub struct TFIDFVectorizer<N = f16, K = String, E = DefaultTFIDFEngine>
where
    N: Num + Copy + Into<f64> + Send + Sync,
    E: TFIDFEngine<N> + Send + Sync,
    K: Clone + Send + Sync + Eq + std::hash::Hash,
{
    /// Document's TF Vector, keyed by document key `K`
    pub documents: IndexMap<K, TFVector<N>>,
    /// TF Vector's term dimension sample and reverse index
    /// (term -> indices into `documents` of the documents using it).
    /// Key is never changed and unused terms are not removed,
    /// so TF vector dimensions stay stable across updates.
    pub term_dim_rev_index: IndexMap<Box<str>, Vec<u32>>,
    /// Corpus reference (shared; see `set_corpus_ref` / `update_idf`)
    pub corpus_ref: Arc<Corpus>,
    /// IDF Vector cached from the corpus; rebuilt when the corpus generation changes
    pub idf_cache: IDFVector,
    // Ties the engine type `E` to the struct without storing a value of it.
    _marker: std::marker::PhantomData<E>,
}
67
impl <N, K, E> TFIDFVectorizer<N, K, E>
where
    N: Num + Copy + Into<f64> + Send + Sync,
    E: TFIDFEngine<N> + Send + Sync,
    K: Clone + Send + Sync + Eq + Hash,
{
    /// Create a new TFIDFVectorizer instance
    ///
    /// The IDF cache is computed immediately from the given corpus
    /// (with an empty term table at this point).
    pub fn new(corpus_ref: Arc<Corpus>) -> Self {
        let mut instance = Self {
            documents: IndexMap::new(),
            term_dim_rev_index: IndexMap::new(),
            corpus_ref,
            idf_cache: IDFVector::new(),
            _marker: std::marker::PhantomData,
        };
        instance.re_calc_idf();
        instance
    }

    /// set corpus reference
    /// and recalculate idf
    pub fn set_corpus_ref(&mut self, corpus_ref: Arc<Corpus>) {
        self.corpus_ref = corpus_ref;
        self.re_calc_idf();
    }

    /// Recalculate the IDF if the corpus has changed since the last build.
    pub fn update_idf(&mut self) {
        // The corpus generation number acts as a change counter; recompute
        // only when it differs from the one cached at the last IDF build.
        if self.corpus_ref.get_gen_num() != self.idf_cache.latest_entropy {
            self.re_calc_idf();
        }
        // No change -> nothing to do.
    }

    /// Recalculate the IDF vector from the corpus and refresh the cache
    /// metadata (generation number and document count).
    fn re_calc_idf(&mut self) {
        self.idf_cache.latest_entropy = self.corpus_ref.get_gen_num();
        self.idf_cache.doc_num = self.corpus_ref.get_doc_num();
        self.idf_cache.idf_vec = E::idf_vec(&self.corpus_ref, self.term_dim_rev_index.keys());
    }
}
109
110impl <N, K, E> TFIDFVectorizer<N, K, E>
111where
112    N: Num + Copy + Into<f64> + Send + Sync,
113    E: TFIDFEngine<N> + Send + Sync,
114    K: PartialEq + Clone + Send + Sync + Eq + Hash
115{
116    /// Add a document
117    /// The immediately referenced Corpus is also updated
118    pub fn add_doc(&mut self, key: K, doc: &TermFrequency) {
119        // 新語彙を追加 (O(|doc_vocab|))
120        for tok in doc.term_set(){
121            self.term_dim_rev_index
122            .entry_mut(tok.into_boxed_str())
123                .or_insert_with(Vec::new);
124        }
125        let tf_vec= E::tf_vec(doc, self.term_dim_rev_index.as_index_set());
126        // 追加した単語
127        let mut added_terms = Vec::new();
128        // 削除した単語
129        let mut removed_terms = Vec::new();
130        match self.documents.insert(key, tf_vec) {
131            InsertResult::New { index: id } => { 
132                self.documents.get_with_index(id).expect("unreachable").as_ind_slice().iter().for_each(|&idx| {
133                    self.term_dim_rev_index.get_with_index_mut(idx as usize).expect("unreachable").push(id as u32);
134                    added_terms.push(idx);
135                });
136            }
137            InsertResult::Override { old_value: old_tf, old_key: _, index: id } => {
138                let old_tf_ind_iter = old_tf.as_ind_slice().iter();
139                let new_tf_ind_iter = self.documents.get_with_index(id).expect("unreachable").as_ind_slice().iter();
140                let mut old_it = old_tf_ind_iter.fuse();
141                let mut new_it = new_tf_ind_iter.fuse();
142                let mut old_next = old_it.next();
143                let mut new_next = new_it.next();
144                while let (Some(old_idx), Some(new_idx)) = (old_next, new_next) {
145                    match old_idx.cmp(new_idx) {
146                        Ordering::Equal => {
147                            // 両方に存在 -> 何もしない
148                            old_next = old_it.next();
149                            new_next = new_it.next();
150                        }
151                        Ordering::Less => {
152                            // old にのみ存在 -> 削除
153                            let doc_keys = self.term_dim_rev_index.get_with_index_mut(*old_idx as usize).expect("unreachable");
154                            doc_keys.iter().position(|k| *k == id as u32).map(|pos| {
155                                doc_keys.swap_remove(pos);
156                            });
157                            removed_terms.push(*old_idx);
158                            old_next = old_it.next();
159                        }
160                        Ordering::Greater => {
161                            // new にのみ存在 -> 追加
162                            let doc_keys = self.term_dim_rev_index.get_with_index_mut(*new_idx as usize).expect("unreachable");
163                            doc_keys.push(id as u32);
164                            added_terms.push(*new_idx);
165                            new_next = new_it.next();
166                        }
167                    }
168                }
169            }
170        }
171        // コーパスも更新
172        let added_terms_str = added_terms.iter()
173            .filter_map(|&idx| self.term_dim_rev_index.get_key_with_index(idx as usize).cloned())
174            .collect::<Vec<Box<str>>>();
175        self.corpus_ref.add_set(&added_terms_str);
176        let removed_terms_str = removed_terms.iter()
177            .filter_map(|&idx| self.term_dim_rev_index.get_key_with_index(idx as usize).cloned())
178            .collect::<Vec<Box<str>>>();
179        self.corpus_ref.sub_set(&removed_terms_str);
180    }
181
182    pub fn del_doc(&mut self, key: &K)
183    where
184        K: PartialEq,
185    {
186        match self.documents.swap_remove(key) {
187            RemoveResult::Removed { old_value: tf_vec, old_key: _, index: id } => {
188                // 逆Indexから削除
189                let terms_idx = tf_vec.as_ind_slice();
190                terms_idx.iter().for_each(|&idx| {
191                    let doc_keys = self.term_dim_rev_index.get_with_index_mut(idx as usize).expect("unreachable");
192                    doc_keys.iter().position(|k| *k == id as u32).map(|pos| {
193                        doc_keys.swap_remove(pos);
194                    });
195                });
196                // コーパスからも削除
197                let terms = terms_idx.iter()
198                    .filter_map(|&idx| self.term_dim_rev_index.get_key_with_index(idx as usize).cloned())
199                    .collect::<Vec<Box<str>>>();
200                self.corpus_ref.sub_set(&terms);
201            }
202            RemoveResult::None => {}
203        }
204    }
205
206    /// Get TFVector by document ID
207    pub fn get_tf(&self, key: &K) -> Option<&TFVector<N>>
208    where
209        K: Eq + Hash,
210    {
211        self.documents.get(key)
212    }
213
214    /// Get TermFrequency by document ID
215    /// If quantized, there may be some error
216    /// Words not included in the corpus are ignored
217    pub fn get_tf_into_term_freq(&self, key: &K) -> Option<TermFrequency>
218    {
219        if let Some(tf_vec) = self.get_tf(key) {
220            let mut term_freq = TermFrequency::new();
221            tf_vec.raw_iter().for_each(|(idx, val)| {
222                let idx = idx as usize;
223                if let Some(term) = self.term_dim_rev_index.get_key_with_index(idx) {
224                    let term_num = E::tf_denorm(val);
225                    term_freq.set_term_count(term, term_num as u64);
226                } // out of range is ignored
227            });
228            Some(term_freq)
229        } else {
230            None
231        }
232    }
233
234    /// Check if a document with the given ID exists
235    pub fn contains_doc(&self, key: &K) -> bool
236    where
237        K: PartialEq,
238    {
239        self.documents.contains_key(key)
240    }
241
242    /// Check if the term exists in the term dimension sample
243    pub fn contains_term(&self, term: &str) -> bool {
244        self.term_dim_rev_index.contains_key(&Box::<str>::from(term))
245    }
246
247    /// Check if all terms in the given TermFrequency exist in the term dimension sample
248    pub fn contains_terms_from_freq(&self, freq: &TermFrequency) -> bool {
249        freq.term_set_ref_str().iter().all(|tok| self.contains_term(tok))
250    }
251
252    pub fn doc_num(&self) -> usize {
253        self.documents.len()
254    }
255}