// tf_idf_vectorizer/vectorizer/mod.rs
mod.rs1pub mod corpus;
2pub mod tfidf;
3pub mod token;
4pub mod serde;
5pub mod evaluate;
6
7use std::{rc::Rc, sync::Arc};
8use std::hash::Hash;
9
10use num_traits::Num;
11use ::serde::{Deserialize, Serialize};
12
13use crate::utils::datastruct::map::IndexMap;
14use crate::{utils::{datastruct::{vector::{ZeroSpVec, ZeroSpVecTrait}}, normalizer::DeNormalizer}, vectorizer::{corpus::Corpus, tfidf::{DefaultTFIDFEngine, TFIDFEngine}, token::TokenFrequency}};
15
/// Shared handle for a document key, so the same key can live in both
/// `documents` and every per-token posting list in `token_dim_rev_index`
/// without cloning the underlying `K`.
///
/// NOTE(review): `Rc` is neither `Send` nor `Sync`, so any type embedding
/// `KeyRc` (e.g. `TFIDFVectorizer`, whose bounds demand `Send + Sync` on its
/// parameters) is itself not shareable across threads. If cross-thread use is
/// intended, `Arc<K>` would be needed here — TODO confirm intent.
pub type KeyRc<K> = Rc<K>;
17
/// TF-IDF vectorizer: stores one term-frequency vector per document and an
/// IDF cache derived from a shared [`Corpus`]. The engine `E` supplies the
/// actual TF/IDF math (see [`TFIDFEngine`]); it is carried only as a type
/// parameter via `PhantomData`.
#[derive(Debug, Clone)]
pub struct TFIDFVectorizer<N = f32, K = String, E = DefaultTFIDFEngine>
where
    N: Num + Copy + Into<f64> + Send + Sync,
    E: TFIDFEngine<N, K> + Send + Sync,
    K: Clone + Send + Sync + Eq + std::hash::Hash,
{
    /// Document key → its sparse TF vector.
    pub documents: IndexMap<KeyRc<K>, TFVector<N>>,
    /// Token → keys of the documents containing it. The map's *insertion
    /// index* doubles as the token's dimension in every TF vector
    /// (see `get_with_index_mut` / `get_key_with_index` usage below).
    pub token_dim_rev_index: IndexMap<Box<str>, Vec<KeyRc<K>>>,
    /// Shared corpus statistics the IDF values are computed from.
    pub corpus_ref: Arc<Corpus>,
    /// Cached IDF vector; refreshed when the corpus generation number
    /// changes (see `update_idf`).
    pub idf_cache: IDFVector,
    /// Zero-sized marker tying the engine type `E` to this instance.
    _marker: std::marker::PhantomData<E>,
}
35
/// Per-document term-frequency vector plus the bookkeeping needed to map the
/// normalized values back to raw counts.
///
/// `repr(align(64))` pins the struct to a 64-byte (cache-line) boundary; the
/// `const_assert` below additionally checks the total size is exactly 64 bytes,
/// so one `TFVector<f32>` occupies exactly one cache line.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[repr(align(64))]
pub struct TFVector<N>
where
    N: Num + Copy,
{
    /// Sparse TF values, indexed by token dimension (zero entries elided).
    pub tf_vec: ZeroSpVec<N>,
    /// Total number of tokens in the source document.
    pub token_sum: u64,
    /// Factor used to de-normalize stored values back to absolute counts
    /// (consumed via the `DeNormalizer` trait in `get_tf_into_token_freq`).
    pub denormalize_num: f64,
}
52
// Compile-time layout guards: `TFVector<f32>` must stay exactly one 64-byte
// cache line, and the embedded `ZeroSpVec<f32>` exactly 48 bytes. If a field
// is added or reordered, these asserts fail the build instead of silently
// degrading the intended memory layout.
#[allow(dead_code)]
const TF_VECTOR_SIZE: usize = core::mem::size_of::<TFVector<f32>>();
static_assertions::const_assert!(TF_VECTOR_SIZE == 64);
#[allow(dead_code)]
const ZSV_SIZE: usize = core::mem::size_of::<ZeroSpVec<f32>>();
static_assertions::const_assert!(ZSV_SIZE == 48);
63
impl<N> TFVector<N>
where
    N: Num + Copy,
{
    /// Releases excess capacity held by the underlying sparse vector.
    /// Useful after construction, since the vector is not mutated afterwards.
    pub fn shrink_to_fit(&mut self) {
        self.tf_vec.shrink_to_fit();
    }
}
72
/// Cached inverse-document-frequency values for the current corpus state.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct IDFVector
{
    /// One IDF value per token dimension (same indexing as
    /// `token_dim_rev_index`).
    pub idf_vec: Vec<f32>,
    /// De-normalization factor returned by the engine's `idf_vec` alongside
    /// the values themselves.
    pub denormalize_num: f64,
    /// Corpus generation number this cache was computed from; compared
    /// against `Corpus::get_gen_num()` to detect staleness.
    /// NOTE(review): despite the name, this stores a generation counter,
    /// not an entropy value — confirm and consider renaming.
    pub latest_entropy: u64,
    /// Number of documents in the corpus at the time of the last refresh.
    pub doc_num: u64,
}
85
86impl IDFVector
87{
88 pub fn new() -> Self {
89 Self {
90 idf_vec: Vec::new(),
91 denormalize_num: 1.0,
92 latest_entropy: 0,
93 doc_num: 0,
94 }
95 }
96}
97
98impl <N, K, E> TFIDFVectorizer<N, K, E>
99where
100 N: Num + Copy + Into<f64> + Send + Sync,
101 E: TFIDFEngine<N, K> + Send + Sync,
102 K: Clone + Send + Sync + Eq + Hash,
103{
104 pub fn new(corpus_ref: Arc<Corpus>) -> Self {
106 let mut instance = Self {
107 documents: IndexMap::new(),
108 token_dim_rev_index: IndexMap::new(),
109 corpus_ref,
110 idf_cache: IDFVector::new(),
111 _marker: std::marker::PhantomData,
112 };
113 instance.re_calc_idf();
114 instance
115 }
116
117 pub fn set_corpus_ref(&mut self, corpus_ref: Arc<Corpus>) {
120 self.corpus_ref = corpus_ref;
121 self.re_calc_idf();
122 }
123
124 pub fn update_idf(&mut self) {
126 if self.corpus_ref.get_gen_num() != self.idf_cache.latest_entropy {
127 self.re_calc_idf();
128 }
129 }
131
132 fn re_calc_idf(&mut self) {
134 self.idf_cache.latest_entropy = self.corpus_ref.get_gen_num();
135 self.idf_cache.doc_num = self.corpus_ref.get_doc_num();
136 (self.idf_cache.idf_vec, self.idf_cache.denormalize_num) = E::idf_vec(&self.corpus_ref, self.token_dim_rev_index.keys());
137 }
138}
139
140impl <N, K, E> TFIDFVectorizer<N, K, E>
141where
142 N: Num + Copy + Into<f64> + Send + Sync,
143 E: TFIDFEngine<N, K> + Send + Sync,
144 K: PartialEq + Clone + Send + Sync + Eq + Hash
145{
146 pub fn add_doc(&mut self, key: K, doc: &TokenFrequency) {
149 let key_rc = KeyRc::new(key);
151 if self.documents.contains_key(&key_rc) {
152 self.del_doc(&key_rc);
153 }
154 let token_sum = doc.token_sum();
155 self.add_corpus(doc);
157 for tok in doc.token_set(){
159 self.token_dim_rev_index
160 .entry_mut(tok.into_boxed_str())
161 .or_insert_with(Vec::new)
162 .push(Rc::clone(&key_rc)); }
164
165 let (mut tf_vec, denormalize_num) = E::tf_vec(doc, self.token_dim_rev_index.as_index_set());
166 tf_vec.shrink_to_fit();
167 let mut doc = TFVector {
168 tf_vec,
169 token_sum,
170 denormalize_num,
171 };
172 doc.shrink_to_fit();
173 self.documents.insert(key_rc, doc);
174 }
175
176 pub fn del_doc(&mut self, key: &K)
177 where
178 K: PartialEq,
179 {
180 let rc_key = KeyRc::new(key.clone());
181 if let Some(doc) = self.documents.get(&rc_key) {
182 let tokens = doc.tf_vec.raw_iter()
183 .filter_map(|(idx, _)| {
184 let doc_keys = self.token_dim_rev_index.get_with_index_mut(idx);
185 if let Some(doc_keys) = doc_keys {
186 let rc_key = KeyRc::new(key.clone());
188 doc_keys.retain(|k| *k != rc_key);
189 }
190 let token = self.token_dim_rev_index.get_key_with_index(idx).cloned();
191 token
192 }).collect::<Vec<Box<str>>>();
193 self.documents.swap_remove(&rc_key);
195 self.corpus_ref.sub_set(&tokens);
197 }
198 }
199
200 pub fn get_tf(&self, key: &K) -> Option<&TFVector<N>>
202 where
203 K: Eq + Hash,
204 {
205 let rc_key = KeyRc::new(key.clone());
206 self.documents.get(&rc_key)
207 }
208
209 pub fn get_tf_into_token_freq(&self, key: &K) -> Option<TokenFrequency>
213 {
214 if let Some(tf_vec) = self.get_tf(key) {
215 let mut token_freq = TokenFrequency::new();
216 tf_vec.tf_vec.raw_iter().for_each(|(idx, val)| {
217 if let Some(token) = self.token_dim_rev_index.get_key_with_index(idx) {
218 let val_f64 = (*val).into();
219 let token_num = tf_vec.token_sum.denormalize(tf_vec.denormalize_num) as f64 * val_f64;
220 token_freq.set_token_count(token, token_num as u64);
221 } });
223 Some(token_freq)
224 } else {
225 None
226 }
227 }
228
229 pub fn contains_doc(&self, key: &K) -> bool
231 where
232 K: PartialEq,
233 {
234 let rc_key = KeyRc::new(key.clone());
235 self.documents.contains_key(&rc_key)
236 }
237
238 pub fn contains_token(&self, token: &str) -> bool {
240 self.token_dim_rev_index.contains_key(&Box::<str>::from(token))
241 }
242
243 pub fn contains_tokens_from_freq(&self, freq: &TokenFrequency) -> bool {
245 freq.token_set_ref_str().iter().all(|tok| self.contains_token(tok))
246 }
247
248 pub fn doc_num(&self) -> usize {
249 self.documents.len()
250 }
251
252 fn add_corpus(&self, doc: &TokenFrequency) {
255 self.corpus_ref.add_set(&doc.token_set_ref_str());
257 }
258}