// tf_idf_vectorizer/vectorizer/mod.rs
pub mod corpus;
2pub mod tfidf;
3pub mod term;
4pub mod serde;
5pub mod evaluate;
6
7use std::cmp::Ordering;
8use std::sync::Arc;
9use std::hash::Hash;
10
11use half::f16;
12use num_traits::Num;
13
14use crate::utils::datastruct::map::index_map::{InsertResult, RemoveResult};
15use crate::utils::datastruct::vector::{TFVector, TFVectorTrait, IDFVector};
16use crate::{DefaultTFIDFEngine, TFIDFEngine, TermFrequency};
17use crate::utils::datastruct::map::IndexMap;
18use crate::Corpus;
19
/// Incrementally maintained TF-IDF vector store.
///
/// Type parameters:
/// * `N` — scalar type used to store term-frequency components (default `f16`).
/// * `K` — document key type (default `String`).
/// * `E` — engine that computes TF vectors and IDF values (default
///   `DefaultTFIDFEngine`); carried only at the type level via `PhantomData`.
#[derive(Debug, Clone)]
pub struct TFIDFVectorizer<N = f16, K = String, E = DefaultTFIDFEngine>
where
    N: Num + Copy + Into<f64> + Send + Sync,
    E: TFIDFEngine<N> + Send + Sync,
    K: Clone + Send + Sync + Eq + std::hash::Hash,
{
    // Document key -> that document's sparse TF vector.
    pub documents: IndexMap<K, TFVector<N>>,
    // Term -> posting list of document indices (as `u32`) that contain it.
    // The term's position in this map doubles as its dimension index in the
    // TF vectors (see `get_with_index` usage in `add_doc`/`del_doc`).
    pub term_dim_rev_index: IndexMap<Box<str>, Vec<u32>>,
    // Shared corpus statistics; term add/sub updates flow through this handle.
    pub corpus_ref: Arc<Corpus>,
    // Cached IDF values, refreshed when the corpus generation number changes.
    pub idf_cache: IDFVector,
    // Binds the engine type `E` to the struct without storing a value.
    _marker: std::marker::PhantomData<E>,
}
67
68impl <N, K, E> TFIDFVectorizer<N, K, E>
69where
70 N: Num + Copy + Into<f64> + Send + Sync,
71 E: TFIDFEngine<N> + Send + Sync,
72 K: Clone + Send + Sync + Eq + Hash,
73{
74 pub fn new(corpus_ref: Arc<Corpus>) -> Self {
76 let mut instance = Self {
77 documents: IndexMap::new(),
78 term_dim_rev_index: IndexMap::new(),
79 corpus_ref,
80 idf_cache: IDFVector::new(),
81 _marker: std::marker::PhantomData,
82 };
83 instance.re_calc_idf();
84 instance
85 }
86
87 pub fn set_corpus_ref(&mut self, corpus_ref: Arc<Corpus>) {
90 self.corpus_ref = corpus_ref;
91 self.re_calc_idf();
92 }
93
94 pub fn update_idf(&mut self) {
96 if self.corpus_ref.get_gen_num() != self.idf_cache.latest_entropy {
97 self.re_calc_idf();
98 }
99 }
101
102 fn re_calc_idf(&mut self) {
104 self.idf_cache.latest_entropy = self.corpus_ref.get_gen_num();
105 self.idf_cache.doc_num = self.corpus_ref.get_doc_num();
106 self.idf_cache.idf_vec = E::idf_vec(&self.corpus_ref, self.term_dim_rev_index.keys());
107 }
108}
109
110impl <N, K, E> TFIDFVectorizer<N, K, E>
111where
112 N: Num + Copy + Into<f64> + Send + Sync,
113 E: TFIDFEngine<N> + Send + Sync,
114 K: PartialEq + Clone + Send + Sync + Eq + Hash
115{
116 pub fn add_doc(&mut self, key: K, doc: &TermFrequency) {
119 for tok in doc.term_set(){
121 self.term_dim_rev_index
122 .entry_mut(tok.into_boxed_str())
123 .or_insert_with(Vec::new);
124 }
125 let tf_vec= E::tf_vec(doc, self.term_dim_rev_index.as_index_set());
126 let mut added_terms = Vec::new();
128 let mut removed_terms = Vec::new();
130 match self.documents.insert(key, tf_vec) {
131 InsertResult::New { index: id } => {
132 self.documents.get_with_index(id).expect("unreachable").as_ind_slice().iter().for_each(|&idx| {
133 self.term_dim_rev_index.get_with_index_mut(idx as usize).expect("unreachable").push(id as u32);
134 added_terms.push(idx);
135 });
136 }
137 InsertResult::Override { old_value: old_tf, old_key: _, index: id } => {
138 let old_tf_ind_iter = old_tf.as_ind_slice().iter();
139 let new_tf_ind_iter = self.documents.get_with_index(id).expect("unreachable").as_ind_slice().iter();
140 let mut old_it = old_tf_ind_iter.fuse();
141 let mut new_it = new_tf_ind_iter.fuse();
142 let mut old_next = old_it.next();
143 let mut new_next = new_it.next();
144 while let (Some(old_idx), Some(new_idx)) = (old_next, new_next) {
145 match old_idx.cmp(new_idx) {
146 Ordering::Equal => {
147 old_next = old_it.next();
149 new_next = new_it.next();
150 }
151 Ordering::Less => {
152 let doc_keys = self.term_dim_rev_index.get_with_index_mut(*old_idx as usize).expect("unreachable");
154 doc_keys.iter().position(|k| *k == id as u32).map(|pos| {
155 doc_keys.swap_remove(pos);
156 });
157 removed_terms.push(*old_idx);
158 old_next = old_it.next();
159 }
160 Ordering::Greater => {
161 let doc_keys = self.term_dim_rev_index.get_with_index_mut(*new_idx as usize).expect("unreachable");
163 doc_keys.push(id as u32);
164 added_terms.push(*new_idx);
165 new_next = new_it.next();
166 }
167 }
168 }
169 }
170 }
171 let added_terms_str = added_terms.iter()
173 .filter_map(|&idx| self.term_dim_rev_index.get_key_with_index(idx as usize).cloned())
174 .collect::<Vec<Box<str>>>();
175 self.corpus_ref.add_set(&added_terms_str);
176 let removed_terms_str = removed_terms.iter()
177 .filter_map(|&idx| self.term_dim_rev_index.get_key_with_index(idx as usize).cloned())
178 .collect::<Vec<Box<str>>>();
179 self.corpus_ref.sub_set(&removed_terms_str);
180 }
181
182 pub fn del_doc(&mut self, key: &K)
183 where
184 K: PartialEq,
185 {
186 match self.documents.swap_remove(key) {
187 RemoveResult::Removed { old_value: tf_vec, old_key: _, index: id } => {
188 let terms_idx = tf_vec.as_ind_slice();
190 terms_idx.iter().for_each(|&idx| {
191 let doc_keys = self.term_dim_rev_index.get_with_index_mut(idx as usize).expect("unreachable");
192 doc_keys.iter().position(|k| *k == id as u32).map(|pos| {
193 doc_keys.swap_remove(pos);
194 });
195 });
196 let terms = terms_idx.iter()
198 .filter_map(|&idx| self.term_dim_rev_index.get_key_with_index(idx as usize).cloned())
199 .collect::<Vec<Box<str>>>();
200 self.corpus_ref.sub_set(&terms);
201 }
202 RemoveResult::None => {}
203 }
204 }
205
206 pub fn get_tf(&self, key: &K) -> Option<&TFVector<N>>
208 where
209 K: Eq + Hash,
210 {
211 self.documents.get(key)
212 }
213
214 pub fn get_tf_into_term_freq(&self, key: &K) -> Option<TermFrequency>
218 {
219 if let Some(tf_vec) = self.get_tf(key) {
220 let mut term_freq = TermFrequency::new();
221 tf_vec.raw_iter().for_each(|(idx, val)| {
222 let idx = idx as usize;
223 if let Some(term) = self.term_dim_rev_index.get_key_with_index(idx) {
224 let term_num = E::tf_denorm(val);
225 term_freq.set_term_count(term, term_num as u64);
226 } });
228 Some(term_freq)
229 } else {
230 None
231 }
232 }
233
234 pub fn contains_doc(&self, key: &K) -> bool
236 where
237 K: PartialEq,
238 {
239 self.documents.contains_key(key)
240 }
241
242 pub fn contains_term(&self, term: &str) -> bool {
244 self.term_dim_rev_index.contains_key(&Box::<str>::from(term))
245 }
246
247 pub fn contains_terms_from_freq(&self, freq: &TermFrequency) -> bool {
249 freq.term_set_ref_str().iter().all(|tok| self.contains_term(tok))
250 }
251
252 pub fn doc_num(&self) -> usize {
253 self.documents.len()
254 }
255}