tf_idf_vectorizer/vectorizer/mod.rs

pub mod corpus;
pub mod tfidf;
pub mod term;
pub mod serde;
pub mod evaluate;

use std::{rc::Rc, sync::Arc};
use std::hash::Hash;

use half::f16;
use num_traits::Num;
use ::serde::{Deserialize, Serialize};

use crate::utils::datastruct::vector::{TFVector, TFVectorTrait};
use crate::{DefaultTFIDFEngine, TFIDFEngine, TermFrequency};
use crate::utils::datastruct::map::IndexMap;
use crate::Corpus;

pub type KeyRc<K> = Rc<K>;

/// TF-IDF Vectorizer
///
/// The top-level struct of this crate, providing the main TF-IDF vectorizer features.
///
/// It converts a document collection into TF-IDF vectors and supports similarity
/// computation and search functionality.
///
/// ### Internals
/// - Corpus vocabulary
/// - Sparse TF vectors per document
/// - Term index mapping
/// - Cached IDF vector
/// - Pluggable TF-IDF engine
/// - Inverted document index
///
/// ### Type Parameters
/// - `N`: Numeric element type of the TF vectors (e.g., `f32`, `f64`, `u16`)
/// - `K`: Document key type (e.g., `String`, `usize`)
/// - `E`: TF-IDF calculation engine
///
/// ### Notes
/// - Requires an `Arc<Corpus>` on construction
/// - `Corpus` can be shared across multiple vectorizers
///
/// ### Serialization
/// Supported.  
/// Serialized data includes the `Corpus` reference.
///
/// For corpus-independent storage, use [`TFIDFData`].
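///
/// ### Example
///
/// A minimal usage sketch; `Corpus::new()` and the exact `TermFrequency`
/// construction are assumptions about the rest of the crate's API.
///
/// ```rust,ignore
/// use std::sync::Arc;
///
/// // Shared corpus (constructor assumed).
/// let corpus = Arc::new(Corpus::new());
///
/// // Describe a document as raw term counts.
/// let mut doc = TermFrequency::new();
/// doc.set_term_count("rust", 3);
/// doc.set_term_count("tfidf", 1);
///
/// // Index the document under a key, then read its TF vector back.
/// let mut vectorizer: TFIDFVectorizer = TFIDFVectorizer::new(Arc::clone(&corpus));
/// vectorizer.add_doc("doc-1".to_string(), &doc);
/// assert!(vectorizer.contains_doc(&"doc-1".to_string()));
/// assert!(vectorizer.get_tf(&"doc-1".to_string()).is_some());
/// ```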
#[derive(Debug, Clone)]
pub struct TFIDFVectorizer<N = f16, K = String, E = DefaultTFIDFEngine>
where
    N: Num + Copy + Into<f64> + Send + Sync,
    E: TFIDFEngine<N> + Send + Sync,
    K: Clone + Send + Sync + Eq + Hash,
{
    /// Per-document TF vectors, keyed by document key
    pub documents: IndexMap<KeyRc<K>, TFVector<N>>,
    /// Term dimension index: maps each term to the documents containing it;
    /// the insertion order of the keys defines the vector dimensions
    pub term_dim_rev_index: IndexMap<Box<str>, Vec<KeyRc<K>>>,
    /// Corpus reference
    pub corpus_ref: Arc<Corpus>,
    /// Cached IDF vector
    pub idf_cache: IDFVector,
    _marker: std::marker::PhantomData<E>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct IDFVector
{
    /// IDF vector; stored densely because most entries are populated
    pub idf_vec: Vec<f32>,
    /// Corpus generation number at the last IDF recalculation
    pub latest_entropy: u64,
    /// Document count at the last IDF recalculation
    pub doc_num: u64,
}

impl IDFVector
{
    pub fn new() -> Self {
        Self {
            idf_vec: Vec::new(),
            latest_entropy: 0,
            doc_num: 0,
        }
    }
}

impl<N, K, E> TFIDFVectorizer<N, K, E>
where
    N: Num + Copy + Into<f64> + Send + Sync,
    E: TFIDFEngine<N> + Send + Sync,
    K: Clone + Send + Sync + Eq + Hash,
{
    /// Create a new TFIDFVectorizer instance
    pub fn new(corpus_ref: Arc<Corpus>) -> Self {
        let mut instance = Self {
            documents: IndexMap::new(),
            term_dim_rev_index: IndexMap::new(),
            corpus_ref,
            idf_cache: IDFVector::new(),
            _marker: std::marker::PhantomData,
        };
        instance.re_calc_idf();
        instance
    }

    /// Set the corpus reference and recalculate the IDF cache
    pub fn set_corpus_ref(&mut self, corpus_ref: Arc<Corpus>) {
        self.corpus_ref = corpus_ref;
        self.re_calc_idf();
    }

    /// Recalculate the IDF if the Corpus has changed
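    ///
    /// A sketch of the intended call pattern when the shared corpus is updated
    /// elsewhere (`other_vectorizer`, sharing the same `Arc<Corpus>`, is hypothetical):
    ///
    /// ```rust,ignore
    /// // Another vectorizer sharing the corpus adds documents...
    /// other_vectorizer.add_doc("doc-9".to_string(), &doc);
    /// // ...so refresh this vectorizer's cached IDF before querying.
    /// vectorizer.update_idf();
    /// ```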
    pub fn update_idf(&mut self) {
        if self.corpus_ref.get_gen_num() != self.idf_cache.latest_entropy {
            self.re_calc_idf();
        }
        // If nothing changed, do nothing
    }

    /// Recalculate the IDF from the Corpus
    fn re_calc_idf(&mut self) {
        self.idf_cache.latest_entropy = self.corpus_ref.get_gen_num();
        self.idf_cache.doc_num = self.corpus_ref.get_doc_num();
        self.idf_cache.idf_vec = E::idf_vec(&self.corpus_ref, self.term_dim_rev_index.keys());
    }
}

impl<N, K, E> TFIDFVectorizer<N, K, E>
where
    N: Num + Copy + Into<f64> + Send + Sync,
    E: TFIDFEngine<N> + Send + Sync,
    K: Clone + Send + Sync + Eq + Hash,
{
    /// Add a document under the given key.
    /// If the key already exists, the previous entry is removed first.
    /// The referenced Corpus is updated immediately.
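    ///
    /// A sketch of the replace-on-re-add behaviour (`first_version` and
    /// `second_version` stand for `TermFrequency` values built by the caller):
    ///
    /// ```rust,ignore
    /// vectorizer.add_doc("doc-1".to_string(), &first_version);
    /// // Re-adding under the same key removes the old entry first.
    /// vectorizer.add_doc("doc-1".to_string(), &second_version);
    /// assert_eq!(vectorizer.doc_num(), 1);
    /// ```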
    pub fn add_doc(&mut self, key: K, doc: &TermFrequency) {
        // Create the shared key
        let key_rc = KeyRc::new(key);
        if self.documents.contains_key(&key_rc) {
            self.del_doc(&key_rc);
        }
        // Add the document's tokens to the corpus
        self.add_corpus(doc);
        // Incrementally add new vocabulary (O(|doc_vocab|))
        for tok in doc.term_set() {
            self.term_dim_rev_index
                .entry_mut(tok.into_boxed_str())
                .or_insert_with(Vec::new)
                .push(Rc::clone(&key_rc)); // add to the reverse index
        }

        let tf_vec = E::tf_vec(doc, self.term_dim_rev_index.as_index_set());
        self.documents.insert(key_rc, tf_vec);
    }

    /// Remove a document by key.
    /// The document's terms are also removed from the reverse index and the corpus.
    pub fn del_doc(&mut self, key: &K)
    where
        K: PartialEq,
    {
        let rc_key = KeyRc::new(key.clone());
        if let Some(tf_vec) = self.documents.get(&rc_key) {
            let terms = tf_vec.raw_iter()
                .filter_map(|(idx, _)| {
                    let idx = idx as usize;
                    let doc_keys = self.term_dim_rev_index.get_with_index_mut(idx);
                    if let Some(doc_keys) = doc_keys {
                        // Remove this document from the reverse index
                        let rc_key = KeyRc::new(key.clone());
                        doc_keys.retain(|k| *k != rc_key);
                    }
                    // Collect the term so it can be subtracted from the corpus
                    self.term_dim_rev_index.get_key_with_index(idx).cloned()
                }).collect::<Vec<Box<str>>>();
            // Remove the document itself
            self.documents.swap_remove(&rc_key);
            // Remove its terms from the corpus as well
            self.corpus_ref.sub_set(&terms);
        }
    }

    /// Get the TFVector for a document key
    pub fn get_tf(&self, key: &K) -> Option<&TFVector<N>>
    where
        K: Eq + Hash,
    {
        let rc_key = KeyRc::new(key.clone());
        self.documents.get(&rc_key)
    }

    /// Reconstruct a TermFrequency for a document key.
    /// If the stored values are quantized, the counts may contain rounding error.
    /// Terms not present in the term index are ignored.
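    ///
    /// A sketch of the round trip (counts may deviate from the original when the
    /// stored `N` is a quantized type such as `f16`):
    ///
    /// ```rust,ignore
    /// let mut doc = TermFrequency::new();
    /// doc.set_term_count("rust", 3);
    /// vectorizer.add_doc("doc-2".to_string(), &doc);
    ///
    /// // Reconstruct approximate term counts from the stored TF vector.
    /// let restored = vectorizer.get_tf_into_term_freq(&"doc-2".to_string());
    /// assert!(restored.is_some());
    /// ```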
    pub fn get_tf_into_term_freq(&self, key: &K) -> Option<TermFrequency>
    {
        if let Some(tf_vec) = self.get_tf(key) {
            let mut term_freq = TermFrequency::new();
            tf_vec.raw_iter().for_each(|(idx, val)| {
                let idx = idx as usize;
                if let Some(term) = self.term_dim_rev_index.get_key_with_index(idx) {
                    let term_num = E::tf_denorm(val);
                    term_freq.set_term_count(term, term_num as u64);
                } // indices outside the term index are ignored
            });
            Some(term_freq)
        } else {
            None
        }
    }

    /// Check if a document with the given key exists
    pub fn contains_doc(&self, key: &K) -> bool
    where
        K: PartialEq,
    {
        let rc_key = KeyRc::new(key.clone());
        self.documents.contains_key(&rc_key)
    }

    /// Check if the term exists in the term dimension index
    pub fn contains_term(&self, term: &str) -> bool {
        self.term_dim_rev_index.contains_key(&Box::<str>::from(term))
    }

    /// Check if all terms in the given TermFrequency exist in the term dimension index
    pub fn contains_terms_from_freq(&self, freq: &TermFrequency) -> bool {
        freq.term_set_ref_str().iter().all(|tok| self.contains_term(tok))
    }

    /// Number of indexed documents
    pub fn doc_num(&self) -> usize {
        self.documents.len()
    }

    /// Add the document's terms to the referenced corpus
    /// (the shared corpus is updated in place)
    fn add_corpus(&self, doc: &TermFrequency) {
        self.corpus_ref.add_set(&doc.term_set_ref_str());
    }
}