// tf_idf_vectorizer/vectorizer/mod.rs
pub mod corpus;
2pub mod tfidf;
3pub mod token;
4pub mod serde;
5pub mod evaluate;
6
7use std::{rc::Rc, sync::Arc};
8use std::hash::Hash;
9
10use num_traits::Num;
11use ::serde::{Deserialize, Serialize};
12
13use crate::utils::datastruct::map::IndexMap;
14use crate::{utils::{datastruct::{vector::{ZeroSpVec, ZeroSpVecTrait}}, normalizer::DeNormalizer}, vectorizer::{corpus::Corpus, tfidf::{DefaultTFIDFEngine, TFIDFEngine}, token::TokenFrequency}};
15
/// Shared-ownership wrapper for document keys, so the same key can live in
/// both `documents` and the token reverse index without cloning `K`.
///
/// NOTE(review): `Rc` makes the vectorizer `!Send`/`!Sync` even though the
/// struct bounds require `K: Send + Sync` — confirm whether `Arc` was
/// intended here.
pub type KeyRc<K> = Rc<K>;
17
/// TF-IDF vectorizer: stores one sparse term-frequency vector per document,
/// a token -> document-keys reverse index, and a cached IDF vector computed
/// from a shared corpus by the engine `E`.
#[derive(Debug, Clone)]
pub struct TFIDFVectorizer<N = f32, K = String, E = DefaultTFIDFEngine>
where
    N: Num + Copy + Into<f64> + Send + Sync,
    E: TFIDFEngine<N, K> + Send + Sync,
    K: Clone + Send + Sync + Eq + std::hash::Hash,
{
    /// Document key -> that document's TF vector.
    pub documents: IndexMap<KeyRc<K>, TFVector<N>>,
    /// Token -> keys of the documents containing it. The token's position in
    /// this map appears to double as its vector dimension (see the
    /// `get_with_index_mut` / `get_key_with_index` lookups) — keep insertion
    /// order stable.
    pub token_dim_rev_index: IndexMap<Box<str>, Vec<KeyRc<K>>>,
    /// Shared corpus used to (re)compute IDF values.
    pub corpus_ref: Arc<Corpus>,
    /// Cached IDF vector; refreshed when the corpus generation changes.
    pub idf_cache: IDFVector<N>,
    /// Zero-sized marker binding this vectorizer to its TF-IDF engine `E`.
    _marker: std::marker::PhantomData<E>,
}
35
/// Sparse term-frequency vector for a single document.
///
/// `#[repr(align(64))]` rounds the struct up to one cache line; the const
/// asserts below pin `size_of::<TFVector<f32>>()` to exactly 64 bytes.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[repr(align(64))]
pub struct TFVector<N>
where
    N: Num + Copy,
{
    /// Sparse TF values, indexed by token dimension.
    pub tf_vec: ZeroSpVec<N>,
    /// Token total for the source document (presumably the sum of all token
    /// counts, used to rebuild counts in `get_tf_into_token_freq`).
    pub token_sum: u64,
    /// Factor used to undo the normalization applied to the stored TF values.
    pub denormalize_num: f64,
}
52
// Compile-time layout guards: keep `TFVector<f32>` exactly one 64-byte cache
// line (see `#[repr(align(64))]` on the struct) and pin `ZeroSpVec<f32>` to
// its expected 48 bytes. If either assert fires, a field was added or
// resized — revisit the layout before silently accepting the new size.
#[allow(dead_code)]
const TF_VECTOR_SIZE: usize = core::mem::size_of::<TFVector<f32>>();
static_assertions::const_assert!(TF_VECTOR_SIZE == 64);
#[allow(dead_code)]
const ZSV_SIZE: usize = core::mem::size_of::<ZeroSpVec<f32>>();
static_assertions::const_assert!(ZSV_SIZE == 48);
63
impl<N> TFVector<N>
where
    N: Num + Copy,
{
    /// Releases any excess capacity held by the underlying sparse vector.
    pub fn shrink_to_fit(&mut self) {
        self.tf_vec.shrink_to_fit();
    }
}
72
/// Cached IDF vector shared across all documents of a vectorizer.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct IDFVector<N>
where
    N: Num,
{
    /// IDF value per token dimension.
    pub idf_vec: Vec<N>,
    /// Factor used to undo the normalization applied to the stored IDF values.
    pub denormalize_num: f64,
    /// Corpus generation stamp this cache was computed from; compared against
    /// `Corpus::get_gen_num()` to detect staleness (despite the name, it is a
    /// generation counter, not an entropy measure).
    pub latest_entropy: u64,
    /// Number of documents in the corpus when the cache was built.
    pub doc_num: u64,
}
87
88impl <N> IDFVector<N>
89where
90 N: Num,
91{
92 pub fn new() -> Self {
93 Self {
94 idf_vec: Vec::new(),
95 denormalize_num: 1.0,
96 latest_entropy: 0,
97 doc_num: 0,
98 }
99 }
100}
101
102impl <N, K, E> TFIDFVectorizer<N, K, E>
103where
104 N: Num + Copy + Into<f64> + Send + Sync,
105 E: TFIDFEngine<N, K> + Send + Sync,
106 K: Clone + Send + Sync + Eq + Hash,
107{
108 pub fn new(corpus_ref: Arc<Corpus>) -> Self {
110 let mut instance = Self {
111 documents: IndexMap::new(),
112 token_dim_rev_index: IndexMap::new(),
113 corpus_ref,
114 idf_cache: IDFVector::new(),
115 _marker: std::marker::PhantomData,
116 };
117 instance.re_calc_idf();
118 instance
119 }
120
121 pub fn set_corpus_ref(&mut self, corpus_ref: Arc<Corpus>) {
124 self.corpus_ref = corpus_ref;
125 self.re_calc_idf();
126 }
127
128 pub fn update_idf(&mut self) {
130 if self.corpus_ref.get_gen_num() != self.idf_cache.latest_entropy {
131 self.re_calc_idf();
132 }
133 }
135
136 fn re_calc_idf(&mut self) {
138 self.idf_cache.latest_entropy = self.corpus_ref.get_gen_num();
139 self.idf_cache.doc_num = self.corpus_ref.get_doc_num();
140 (self.idf_cache.idf_vec, self.idf_cache.denormalize_num) = E::idf_vec(&self.corpus_ref, self.token_dim_rev_index.keys());
141 }
142}
143
144impl <N, K, E> TFIDFVectorizer<N, K, E>
145where
146 N: Num + Copy + Into<f64> + Send + Sync,
147 E: TFIDFEngine<N, K> + Send + Sync,
148 K: PartialEq + Clone + Send + Sync + Eq + Hash
149{
150 pub fn add_doc(&mut self, key: K, doc: &TokenFrequency) {
153 let key_rc = KeyRc::new(key);
155 if self.documents.contains_key(&key_rc) {
156 self.del_doc(&key_rc);
157 }
158 let token_sum = doc.token_sum();
159 self.add_corpus(doc);
161 for tok in doc.token_set_ref_str() {
163 self.token_dim_rev_index
164 .entry_mut(&Box::from(tok))
165 .or_insert_with(Vec::new)
166 .push(Rc::clone(&key_rc)); }
168
169 let (tf_vec, denormalize_num) = E::tf_vec(doc, self.token_dim_rev_index.keys());
170 let mut doc = TFVector {
171 tf_vec,
172 token_sum,
173 denormalize_num,
174 };
175 doc.shrink_to_fit();
176 self.documents.insert(&key_rc, doc);
177 }
178
179 pub fn del_doc(&mut self, key: &K)
180 where
181 K: PartialEq,
182 {
183 let rc_key = KeyRc::new(key.clone());
184 if let Some(doc) = self.documents.get(&rc_key) {
185 let tokens = doc.tf_vec.raw_iter()
186 .filter_map(|(idx, _)| {
187 let doc_keys = self.token_dim_rev_index.get_with_index_mut(idx);
188 if let Some(doc_keys) = doc_keys {
189 let rc_key = KeyRc::new(key.clone());
191 doc_keys.retain(|k| *k == rc_key);
192 }
193 let token = self.token_dim_rev_index.get_key_with_index(idx).cloned();
194 token
195 }).collect::<Vec<Box<str>>>();
196 self.documents.swap_remove(&rc_key);
198 self.corpus_ref.sub_set(&tokens);
200 }
201 }
202
203 pub fn get_tf(&self, key: &K) -> Option<&TFVector<N>>
205 where
206 K: Eq + Hash,
207 {
208 let rc_key = KeyRc::new(key.clone());
209 self.documents.get(&rc_key)
210 }
211
212 pub fn get_tf_into_token_freq(&self, key: &K) -> Option<TokenFrequency>
216 {
217 if let Some(tf_vec) = self.get_tf(key) {
218 let mut token_freq = TokenFrequency::new();
219 tf_vec.tf_vec.raw_iter().for_each(|(idx, val)| {
220 if let Some(token) = self.token_dim_rev_index.get_key_with_index(idx) {
221 let val_f64: f64 = (*val).into();
222 let token_num: f64 = tf_vec.token_sum.denormalize(tf_vec.denormalize_num) * val_f64;
223 token_freq.set_token_count(token, token_num as u64);
224 } });
226 Some(token_freq)
227 } else {
228 None
229 }
230 }
231
232 pub fn contains_doc(&self, key: &K) -> bool
234 where
235 K: PartialEq,
236 {
237 let rc_key = KeyRc::new(key.clone());
238 self.documents.contains_key(&rc_key)
239 }
240
241 pub fn contains_token(&self, token: &str) -> bool {
243 self.token_dim_rev_index.contains_key(&Box::<str>::from(token))
244 }
245
246 pub fn contains_tokens_from_freq(&self, freq: &TokenFrequency) -> bool {
248 freq.token_set_ref_str().iter().all(|tok| self.contains_token(tok))
249 }
250
251 pub fn doc_num(&self) -> usize {
252 self.documents.len()
253 }
254
255 fn add_corpus(&mut self, doc: &TokenFrequency) {
258 self.corpus_ref.add_set(&doc.token_set_ref_str());
260 }
261}