tf_idf_vectorizer/vectorizer/
mod.rs1pub mod corpus;
2pub mod tfidf;
3pub mod term;
4pub mod serde;
5pub mod evaluate;
6
7use std::{rc::Rc, sync::Arc};
8use std::hash::Hash;
9
10use half::f16;
11use num_traits::Num;
12use ::serde::{Deserialize, Serialize};
13
14use crate::utils::datastruct::vector::{TFVector, TFVectorTrait};
15use crate::{DefaultTFIDFEngine, TFIDFEngine, TermFrequency};
16use crate::utils::datastruct::map::IndexMap;
17use crate::Corpus;
18
/// Shared, reference-counted handle to a document key.
///
/// A single key allocation is stored both as the key in the vectorizer's
/// document map and in the reverse-index entry of every term the document
/// contains, so cloning it only bumps a reference count.
///
/// NOTE(review): `Rc` is neither `Send` nor `Sync`, so any type that stores a
/// `KeyRc` is not `Send`/`Sync` either, regardless of the bounds placed on `K`.
pub type KeyRc<K> = std::rc::Rc<K>;
20
21#[derive(Debug, Clone)]
51pub struct TFIDFVectorizer<N = f16, K = String, E = DefaultTFIDFEngine>
52where
53 N: Num + Copy + Into<f64> + Send + Sync,
54 E: TFIDFEngine<N> + Send + Sync,
55 K: Clone + Send + Sync + Eq + std::hash::Hash,
56{
57 pub documents: IndexMap<KeyRc<K>, TFVector<N>>,
59 pub term_dim_rev_index: IndexMap<Box<str>, Vec<KeyRc<K>>>,
61 pub corpus_ref: Arc<Corpus>,
63 pub idf_cache: IDFVector,
65 _marker: std::marker::PhantomData<E>,
66}
67
68#[derive(Debug, Serialize, Deserialize, Clone)]
69pub struct IDFVector
70{
71 pub idf_vec: Vec<f32>,
73 pub latest_entropy: u64,
75 pub doc_num: u64,
77}
78
79impl IDFVector
80{
81 pub fn new() -> Self {
82 Self {
83 idf_vec: Vec::new(),
84 latest_entropy: 0,
85 doc_num: 0,
86 }
87 }
88}
89
90impl <N, K, E> TFIDFVectorizer<N, K, E>
91where
92 N: Num + Copy + Into<f64> + Send + Sync,
93 E: TFIDFEngine<N> + Send + Sync,
94 K: Clone + Send + Sync + Eq + Hash,
95{
96 pub fn new(corpus_ref: Arc<Corpus>) -> Self {
98 let mut instance = Self {
99 documents: IndexMap::new(),
100 term_dim_rev_index: IndexMap::new(),
101 corpus_ref,
102 idf_cache: IDFVector::new(),
103 _marker: std::marker::PhantomData,
104 };
105 instance.re_calc_idf();
106 instance
107 }
108
109 pub fn set_corpus_ref(&mut self, corpus_ref: Arc<Corpus>) {
112 self.corpus_ref = corpus_ref;
113 self.re_calc_idf();
114 }
115
116 pub fn update_idf(&mut self) {
118 if self.corpus_ref.get_gen_num() != self.idf_cache.latest_entropy {
119 self.re_calc_idf();
120 }
121 }
123
124 fn re_calc_idf(&mut self) {
126 self.idf_cache.latest_entropy = self.corpus_ref.get_gen_num();
127 self.idf_cache.doc_num = self.corpus_ref.get_doc_num();
128 self.idf_cache.idf_vec = E::idf_vec(&self.corpus_ref, self.term_dim_rev_index.keys());
129 }
130}
131
132impl <N, K, E> TFIDFVectorizer<N, K, E>
133where
134 N: Num + Copy + Into<f64> + Send + Sync,
135 E: TFIDFEngine<N> + Send + Sync,
136 K: PartialEq + Clone + Send + Sync + Eq + Hash
137{
138 pub fn add_doc(&mut self, key: K, doc: &TermFrequency) {
141 let key_rc = KeyRc::new(key);
143 if self.documents.contains_key(&key_rc) {
144 self.del_doc(&key_rc);
145 }
146 self.add_corpus(doc);
148 for tok in doc.term_set(){
150 self.term_dim_rev_index
151 .entry_mut(tok.into_boxed_str())
152 .or_insert_with(Vec::new)
153 .push(Rc::clone(&key_rc)); }
155
156 let tf_vec= E::tf_vec(doc, self.term_dim_rev_index.as_index_set());
157 self.documents.insert(key_rc, tf_vec);
158 }
159
160 pub fn del_doc(&mut self, key: &K)
161 where
162 K: PartialEq,
163 {
164 let rc_key = KeyRc::new(key.clone());
165 if let Some(tf_vec) = self.documents.get(&rc_key) {
166 let terms = tf_vec.raw_iter()
167 .filter_map(|(idx, _)| {
168 let idx = idx as usize;
169 let doc_keys = self.term_dim_rev_index.get_with_index_mut(idx);
170 if let Some(doc_keys) = doc_keys {
171 let rc_key = KeyRc::new(key.clone());
173 doc_keys.retain(|k| *k != rc_key);
174 }
175 let term = self.term_dim_rev_index.get_key_with_index(idx).cloned();
176 term
177 }).collect::<Vec<Box<str>>>();
178 self.documents.swap_remove(&rc_key);
180 self.corpus_ref.sub_set(&terms);
182 }
183 }
184
185 pub fn get_tf(&self, key: &K) -> Option<&TFVector<N>>
187 where
188 K: Eq + Hash,
189 {
190 let rc_key = KeyRc::new(key.clone());
191 self.documents.get(&rc_key)
192 }
193
194 pub fn get_tf_into_term_freq(&self, key: &K) -> Option<TermFrequency>
198 {
199 if let Some(tf_vec) = self.get_tf(key) {
200 let mut term_freq = TermFrequency::new();
201 tf_vec.raw_iter().for_each(|(idx, val)| {
202 let idx = idx as usize;
203 if let Some(term) = self.term_dim_rev_index.get_key_with_index(idx) {
204 let term_num = E::tf_denorm(val);
205 term_freq.set_term_count(term, term_num as u64);
206 } });
208 Some(term_freq)
209 } else {
210 None
211 }
212 }
213
214 pub fn contains_doc(&self, key: &K) -> bool
216 where
217 K: PartialEq,
218 {
219 let rc_key = KeyRc::new(key.clone());
220 self.documents.contains_key(&rc_key)
221 }
222
223 pub fn contains_term(&self, term: &str) -> bool {
225 self.term_dim_rev_index.contains_key(&Box::<str>::from(term))
226 }
227
228 pub fn contains_terms_from_freq(&self, freq: &TermFrequency) -> bool {
230 freq.term_set_ref_str().iter().all(|tok| self.contains_term(tok))
231 }
232
233 pub fn doc_num(&self) -> usize {
234 self.documents.len()
235 }
236
237 fn add_corpus(&self, doc: &TermFrequency) {
240 self.corpus_ref.add_set(&doc.term_set_ref_str());
242 }
243}