tf_idf_vectorizer/vectorizer/
mod.rs1pub mod corpus;
2pub mod tfidf;
3pub mod token;
4pub mod serde;
5pub mod evaluate;
6
7use std::{rc::Rc, sync::Arc};
8use std::hash::Hash;
9
10use half::f16;
11use num_traits::Num;
12use ::serde::{Deserialize, Serialize};
13
14use crate::utils::datastruct::map::IndexMap;
15use crate::{utils::{datastruct::{vector::{ZeroSpVec, ZeroSpVecTrait}}, normalizer::DeNormalizer}, vectorizer::{corpus::Corpus, tfidf::{DefaultTFIDFEngine, TFIDFEngine}, token::TokenFrequency}};
16
/// Shared-ownership wrapper for document keys so a single allocation can be
/// referenced from both `documents` and the reverse index.
///
/// NOTE(review): `Rc` is neither `Send` nor `Sync`, so any type holding a
/// `KeyRc<K>` cannot cross threads even when `K: Send + Sync` — confirm
/// whether `Arc` was intended here.
pub type KeyRc<K> = Rc<K>;
18
/// Sparse TF-IDF vectorizer parameterized over the TF weight type `N`, the
/// document key type `K`, and the scoring engine `E`.
///
/// NOTE(review): keys are stored as `KeyRc<K>` (= `Rc<K>`), which is neither
/// `Send` nor `Sync`, so this struct cannot actually be shared across threads
/// despite the `Send + Sync` bounds on the type parameters — TODO confirm.
#[derive(Debug, Clone)]
pub struct TFIDFVectorizer<N = f16, K = String, E = DefaultTFIDFEngine>
where
    N: Num + Copy + Into<f64> + Send + Sync,
    E: TFIDFEngine<N, K> + Send + Sync,
    K: Clone + Send + Sync + Eq + std::hash::Hash,
{
    // Per-document TF vectors, keyed by the shared document key.
    pub documents: IndexMap<KeyRc<K>, TFVector<N>>,
    // Reverse index: token -> keys of documents containing it. The map's
    // positional indices double as the TF-vector dimension ids (see the
    // `get_key_with_index` / `get_with_index_mut` lookups below).
    pub token_dim_rev_index: IndexMap<Box<str>, Vec<KeyRc<K>>>,
    // Shared corpus statistics from which IDF values are derived.
    pub corpus_ref: Arc<Corpus>,
    // Cached IDF values, refreshed when the corpus generation changes.
    pub idf_cache: IDFVector,
    // Zero-sized marker tying the engine type `E` to the struct.
    _marker: std::marker::PhantomData<E>,
}
36
/// A single document's term-frequency vector plus the metadata needed to
/// recover raw token counts.
///
/// `#[repr(align(64))]` pins the struct to a 64-byte boundary; together with
/// the size assertions below this keeps one `TFVector` per 64-byte unit
/// (one cache line on most mainstream CPUs).
#[derive(Debug, Serialize, Deserialize, Clone)]
#[repr(align(64))]
pub struct TFVector<N>
where
    N: Num + Copy,
{
    // Sparse TF weights, indexed by token dimension id.
    pub tf_vec: ZeroSpVec<N>,
    // Document token total; run through `denormalize(denormalize_num)` when
    // reconstructing counts (see `get_tf_into_token_freq`).
    pub token_sum: u64,
    // Factor used to denormalize stored values back toward raw counts.
    pub denormalize_num: f64,
}
53
// Compile-time layout checks: `TFVector<f32>` must occupy exactly 64 bytes
// (matching its `#[repr(align(64))]`), which requires the sparse-vector
// header to stay at 48 bytes so the two scalar fields fit alongside it.
// The consts exist only to feed the assertions, hence `dead_code`.
#[allow(dead_code)]
const TF_VECTOR_SIZE: usize = core::mem::size_of::<TFVector<f32>>();
static_assertions::const_assert!(TF_VECTOR_SIZE == 64);
#[allow(dead_code)]
const ZSV_SIZE: usize = core::mem::size_of::<ZeroSpVec<f32>>();
static_assertions::const_assert!(ZSV_SIZE == 48);
64
impl<N> TFVector<N>
where
    N: Num + Copy,
{
    /// Releases excess heap capacity held by the sparse TF vector.
    pub fn shrink_to_fit(&mut self) {
        self.tf_vec.shrink_to_fit();
    }
}
73
/// Cached IDF data shared across all documents of a vectorizer.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct IDFVector
{
    // IDF weight per token dimension.
    pub idf_vec: Vec<f32>,
    // Factor used to denormalize the stored IDF values.
    pub denormalize_num: f64,
    // Corpus generation counter this cache was computed from; compared
    // against `Corpus::get_gen_num()` to detect staleness.
    // NOTE(review): despite the name, this holds a generation number, not
    // an entropy value.
    pub latest_entropy: u64,
    // Number of corpus documents at the time the cache was computed.
    pub doc_num: u64,
}
86
87impl IDFVector
88{
89 pub fn new() -> Self {
90 Self {
91 idf_vec: Vec::new(),
92 denormalize_num: 1.0,
93 latest_entropy: 0,
94 doc_num: 0,
95 }
96 }
97}
98
99impl <N, K, E> TFIDFVectorizer<N, K, E>
100where
101 N: Num + Copy + Into<f64> + Send + Sync,
102 E: TFIDFEngine<N, K> + Send + Sync,
103 K: Clone + Send + Sync + Eq + Hash,
104{
105 pub fn new(corpus_ref: Arc<Corpus>) -> Self {
107 let mut instance = Self {
108 documents: IndexMap::new(),
109 token_dim_rev_index: IndexMap::new(),
110 corpus_ref,
111 idf_cache: IDFVector::new(),
112 _marker: std::marker::PhantomData,
113 };
114 instance.re_calc_idf();
115 instance
116 }
117
118 pub fn set_corpus_ref(&mut self, corpus_ref: Arc<Corpus>) {
121 self.corpus_ref = corpus_ref;
122 self.re_calc_idf();
123 }
124
125 pub fn update_idf(&mut self) {
127 if self.corpus_ref.get_gen_num() != self.idf_cache.latest_entropy {
128 self.re_calc_idf();
129 }
130 }
132
133 fn re_calc_idf(&mut self) {
135 self.idf_cache.latest_entropy = self.corpus_ref.get_gen_num();
136 self.idf_cache.doc_num = self.corpus_ref.get_doc_num();
137 (self.idf_cache.idf_vec, self.idf_cache.denormalize_num) = E::idf_vec(&self.corpus_ref, self.token_dim_rev_index.keys());
138 }
139}
140
141impl <N, K, E> TFIDFVectorizer<N, K, E>
142where
143 N: Num + Copy + Into<f64> + Send + Sync,
144 E: TFIDFEngine<N, K> + Send + Sync,
145 K: PartialEq + Clone + Send + Sync + Eq + Hash
146{
147 pub fn add_doc(&mut self, key: K, doc: &TokenFrequency) {
150 let key_rc = KeyRc::new(key);
152 if self.documents.contains_key(&key_rc) {
153 self.del_doc(&key_rc);
154 }
155 let token_sum = doc.token_sum();
156 self.add_corpus(doc);
158 for tok in doc.token_set(){
160 self.token_dim_rev_index
161 .entry_mut(tok.into_boxed_str())
162 .or_insert_with(Vec::new)
163 .push(Rc::clone(&key_rc)); }
165
166 let (mut tf_vec, denormalize_num) = E::tf_vec(doc, self.token_dim_rev_index.as_index_set());
167 tf_vec.shrink_to_fit();
168 let mut doc = TFVector {
169 tf_vec,
170 token_sum,
171 denormalize_num,
172 };
173 doc.shrink_to_fit();
174 self.documents.insert(key_rc, doc);
175 }
176
177 pub fn del_doc(&mut self, key: &K)
178 where
179 K: PartialEq,
180 {
181 let rc_key = KeyRc::new(key.clone());
182 if let Some(doc) = self.documents.get(&rc_key) {
183 let tokens = doc.tf_vec.raw_iter()
184 .filter_map(|(idx, _)| {
185 let doc_keys = self.token_dim_rev_index.get_with_index_mut(idx);
186 if let Some(doc_keys) = doc_keys {
187 let rc_key = KeyRc::new(key.clone());
189 doc_keys.retain(|k| *k != rc_key);
190 }
191 let token = self.token_dim_rev_index.get_key_with_index(idx).cloned();
192 token
193 }).collect::<Vec<Box<str>>>();
194 self.documents.swap_remove(&rc_key);
196 self.corpus_ref.sub_set(&tokens);
198 }
199 }
200
201 pub fn get_tf(&self, key: &K) -> Option<&TFVector<N>>
203 where
204 K: Eq + Hash,
205 {
206 let rc_key = KeyRc::new(key.clone());
207 self.documents.get(&rc_key)
208 }
209
210 pub fn get_tf_into_token_freq(&self, key: &K) -> Option<TokenFrequency>
214 {
215 if let Some(tf_vec) = self.get_tf(key) {
216 let mut token_freq = TokenFrequency::new();
217 tf_vec.tf_vec.raw_iter().for_each(|(idx, val)| {
218 if let Some(token) = self.token_dim_rev_index.get_key_with_index(idx) {
219 let val_f64 = (*val).into();
220 let token_num = tf_vec.token_sum.denormalize(tf_vec.denormalize_num) as f64 * val_f64;
221 token_freq.set_token_count(token, token_num as u64);
222 } });
224 Some(token_freq)
225 } else {
226 None
227 }
228 }
229
230 pub fn contains_doc(&self, key: &K) -> bool
232 where
233 K: PartialEq,
234 {
235 let rc_key = KeyRc::new(key.clone());
236 self.documents.contains_key(&rc_key)
237 }
238
239 pub fn contains_token(&self, token: &str) -> bool {
241 self.token_dim_rev_index.contains_key(&Box::<str>::from(token))
242 }
243
244 pub fn contains_tokens_from_freq(&self, freq: &TokenFrequency) -> bool {
246 freq.token_set_ref_str().iter().all(|tok| self.contains_token(tok))
247 }
248
249 pub fn doc_num(&self) -> usize {
250 self.documents.len()
251 }
252
253 fn add_corpus(&self, doc: &TokenFrequency) {
256 self.corpus_ref.add_set(&doc.token_set_ref_str());
258 }
259}