tf_idf_vectorizer/vectorizer/
mod.rs1pub mod corpus;
2pub mod tfidf;
3pub mod term;
4pub mod serde;
5pub mod evaluate;
6
7use std::cmp::Ordering;
8use std::sync::Arc;
9use std::hash::Hash;
10
11use half::f16;
12use num_traits::Num;
13
14use crate::utils::datastruct::map::index_map::{EntryMut, InsertResult, RemoveResult};
15use crate::utils::datastruct::vector::{TFVector, TFVectorTrait, IDFVector};
16use crate::{DefaultTFIDFEngine, TFIDFEngine, TermFrequency};
17use crate::utils::datastruct::map::IndexMap;
18use crate::Corpus;
19
20#[derive(Debug, Clone)]
50pub struct TFIDFVectorizer<N = f16, K = String, E = DefaultTFIDFEngine>
51where
52 N: Num + Copy + Into<f64> + Send + Sync,
53 E: TFIDFEngine<N> + Send + Sync,
54 K: Clone + Send + Sync + Eq + std::hash::Hash,
55{
56 pub documents: IndexMap<K, TFVector<N>>,
58 pub term_dim_rev_index: IndexMap<Box<str>, Vec<u32>>,
61 pub corpus_ref: Arc<Corpus>,
63 pub idf_cache: IDFVector,
65 _marker: std::marker::PhantomData<E>,
66}
67
68impl <N, K, E> TFIDFVectorizer<N, K, E>
69where
70 N: Num + Copy + Into<f64> + Send + Sync,
71 E: TFIDFEngine<N> + Send + Sync,
72 K: Clone + Send + Sync + Eq + Hash,
73{
74 pub fn new(corpus_ref: Arc<Corpus>) -> Self {
76 let mut instance = Self {
77 documents: IndexMap::new(),
78 term_dim_rev_index: IndexMap::new(),
79 corpus_ref,
80 idf_cache: IDFVector::new(),
81 _marker: std::marker::PhantomData,
82 };
83 instance.re_calc_idf();
84 instance
85 }
86
87 pub fn set_corpus_ref(&mut self, corpus_ref: Arc<Corpus>) {
90 self.corpus_ref = corpus_ref;
91 self.re_calc_idf();
92 }
93
94 pub fn update_idf(&mut self) {
96 if self.corpus_ref.get_gen_num() != self.idf_cache.latest_entropy {
97 self.re_calc_idf();
98 }
99 }
101
102 fn re_calc_idf(&mut self) {
104 self.idf_cache.latest_entropy = self.corpus_ref.get_gen_num();
105 self.idf_cache.doc_num = self.corpus_ref.get_doc_num();
106 self.idf_cache.idf_vec = E::idf_vec(&self.corpus_ref, self.term_dim_rev_index.keys());
107 }
108}
109
110impl <N, K, E> TFIDFVectorizer<N, K, E>
111where
112 N: Num + Copy + Into<f64> + Send + Sync,
113 E: TFIDFEngine<N> + Send + Sync,
114 K: PartialEq + Clone + Send + Sync + Eq + Hash
115{
116 pub fn add_doc(&mut self, key: K, doc: &TermFrequency) {
119 for tok in doc.term_set(){
121 self.term_dim_rev_index
122 .entry_mut(tok.into_boxed_str())
123 .or_insert_with(Vec::new);
124 }
125 let tf_vec= E::tf_vec(doc, self.term_dim_rev_index.as_index_set());
126 let (new_terms, old_terms) = self.add_tf_vec(key, tf_vec);
127 if old_terms.is_empty() && new_terms.is_empty() {
129 return;
131 }
132 if old_terms.is_empty() {
133 let add_terms: Vec<&Box<str>> = new_terms.iter()
135 .filter_map(|&idx| self.term_dim_rev_index.get_key_with_index(idx as usize))
136 .collect();
137 self.corpus_ref.add_set(&add_terms);
138 return;
139 }
140 let mut new_terms_iter = new_terms.into_iter().fuse();
141 let mut old_terms_iter = old_terms.into_iter().fuse();
142 let mut new_term_next = new_terms_iter.next();
143 let mut old_term_next = old_terms_iter.next();
144 let mut add_terms = Vec::new();
145 let mut del_terms = Vec::new();
146 while let (Some(new_idx), Some(old_idx)) = (new_term_next, old_term_next) {
147 match new_idx.cmp(&old_idx) {
148 Ordering::Less => {
149 let term = self.term_dim_rev_index.get_key_with_index(new_idx as usize).expect("unreachable");
151 add_terms.push(term);
152 new_term_next = new_terms_iter.next();
153 }
154 Ordering::Greater => {
155 let term = self.term_dim_rev_index.get_key_with_index(old_idx as usize).expect("unreachable");
157 del_terms.push(term);
158 old_term_next = old_terms_iter.next();
159 }
160 Ordering::Equal => {
161 new_term_next = new_terms_iter.next();
163 old_term_next = old_terms_iter.next();
164 }
165 }
166 }
167 while let Some(new_idx) = new_term_next {
168 let term = self.term_dim_rev_index.get_key_with_index(new_idx as usize).expect("unreachable");
170 add_terms.push(term);
171 new_term_next = new_terms_iter.next();
172 }
173 while let Some(old_idx) = old_term_next {
174 let term = self.term_dim_rev_index.get_key_with_index(old_idx as usize).expect("unreachable");
176 del_terms.push(term);
177 old_term_next = old_terms_iter.next();
178 }
179 self.corpus_ref.add_set(&add_terms);
180 self.corpus_ref.sub_set(&del_terms);
181 }
182
183 fn add_tf_vec(&mut self, key: K, tf_vec: TFVector<N>) -> (Vec<u32>, Vec<u32>) {
184 let new_tf_terms_ind: Vec<u32> = tf_vec.as_ind_slice().to_vec();
185 match self.documents.insert(key, tf_vec) {
186 InsertResult::New { index: id } => {
187 self.documents.get_with_index(id).expect("unreachable").as_ind_slice().iter().for_each(|&idx| {
188 self.term_dim_rev_index.get_with_index_mut(idx as usize).expect("unreachable").push(id as u32);
189 });
190 (new_tf_terms_ind, Vec::new())
191 }
192 InsertResult::Override { old_value: old_tf, old_key: _, index: id } => {
193 let old_tf_ind_iter = old_tf.as_ind_slice().iter();
194 let new_tf_ind_iter = self.documents.get_with_index(id).expect("unreachable").as_ind_slice().iter();
195 let mut old_it = old_tf_ind_iter.fuse();
196 let mut new_it = new_tf_ind_iter.fuse();
197 let mut old_next = old_it.next();
198 let mut new_next = new_it.next();
199 while let (Some(old_idx), Some(new_idx)) = (old_next, new_next) {
200 match old_idx.cmp(new_idx) {
201 Ordering::Equal => {
202 old_next = old_it.next();
204 new_next = new_it.next();
205 }
206 Ordering::Less => {
207 let doc_keys = self.term_dim_rev_index.get_with_index_mut(*old_idx as usize).expect("unreachable");
209 doc_keys.iter().position(|k| *k == id as u32).map(|pos| {
210 doc_keys.swap_remove(pos);
211 });
212 old_next = old_it.next();
213 }
214 Ordering::Greater => {
215 let doc_keys = self.term_dim_rev_index.get_with_index_mut(*new_idx as usize).expect("unreachable");
217 doc_keys.push(id as u32);
218 new_next = new_it.next();
219 }
220 }
221 }
222 (new_tf_terms_ind, old_tf.as_ind_slice().to_vec())
223 }
224 }
225 }
226
227 pub fn del_doc(&mut self, key: &K)
228 where
229 K: PartialEq,
230 {
231 match self.documents.swap_remove(key) {
232 RemoveResult::Removed { old_value: tf_vec, old_key: _, index: id } => {
233 let terms_idx = tf_vec.as_ind_slice();
235 terms_idx.iter().for_each(|&idx| {
236 let doc_keys = self.term_dim_rev_index.get_with_index_mut(idx as usize).expect("unreachable");
237 doc_keys.iter().position(|k| *k == id as u32).map(|pos| {
238 doc_keys.swap_remove(pos);
239 });
240 });
241 let swap_doc_id = self.documents.len() as u32;
243 if swap_doc_id != id as u32 {
244 self.documents.get_with_index(id).expect("unreachable").as_ind_slice().iter().for_each(|&idx| {
245 let doc_keys = self.term_dim_rev_index.get_with_index_mut(idx as usize).expect("unreachable");
246 doc_keys.iter().position(|k| *k == swap_doc_id).map(|pos| {
247 doc_keys[pos] = id as u32;
248 });
249 });
250 }
251 let terms = terms_idx.iter()
253 .filter_map(|&idx| self.term_dim_rev_index.get_key_with_index(idx as usize))
254 .collect::<Vec<&Box<str>>>();
255 self.corpus_ref.sub_set(&terms);
256 }
257 RemoveResult::None => {}
258 }
259 }
260
261 pub fn get_tf(&self, key: &K) -> Option<&TFVector<N>>
263 where
264 K: Eq + Hash,
265 {
266 self.documents.get(key)
267 }
268
269 pub fn get_tf_into_term_freq(&self, key: &K) -> Option<TermFrequency>
273 {
274 if let Some(tf_vec) = self.get_tf(key) {
275 let mut term_freq = TermFrequency::new();
276 tf_vec.raw_iter().for_each(|(idx, val)| {
277 let idx = idx as usize;
278 if let Some(term) = self.term_dim_rev_index.get_key_with_index(idx) {
279 let term_num = E::tf_denorm(val);
280 term_freq.set_term_count(term, term_num as u64);
281 } });
283 Some(term_freq)
284 } else {
285 None
286 }
287 }
288
289 pub fn contains_doc(&self, key: &K) -> bool
291 where
292 K: PartialEq,
293 {
294 self.documents.contains_key(key)
295 }
296
297 pub fn contains_term(&self, term: &str) -> bool {
299 self.term_dim_rev_index.contains_key(&Box::<str>::from(term))
300 }
301
302 pub fn contains_terms_from_freq(&self, freq: &TermFrequency) -> bool {
304 freq.term_set_ref_str().iter().all(|tok| self.contains_term(tok))
305 }
306
307 pub fn doc_num(&self) -> usize {
308 self.documents.len()
309 }
310
311 pub fn merge(&mut self, other: Self)
313 where
314 K: Eq + Hash,
315 {
316 let perm_idxs: Vec<u32> = other.term_dim_rev_index.into_iter().map(|(term, _)| {
318 match self.term_dim_rev_index.entry_mut(term) {
319 EntryMut::Occupied { index, ..} => index as u32,
320 EntryMut::Vacant { key, map } => {
321 match map.insert(key, Vec::new()) {
322 InsertResult::New { index } => index as u32,
323 InsertResult::Override { .. } => unreachable!(),
324 }
325 },
326 }
327 }).collect();
328 other.documents.into_iter().for_each(|(key, mut tf_vec)| {
330 tf_vec.perm(&perm_idxs);
331 let (_, old_tf_terms_ind) = self.add_tf_vec(key, tf_vec);
332 let del_terms = old_tf_terms_ind.into_iter().map(|old_idx| {
334 self.term_dim_rev_index.get_key_with_index(old_idx as usize).expect("unreachable")
335 }).collect::<Vec<&Box<str>>>();
336 self.corpus_ref.sub_set(&del_terms);
337 });
338 }
339}