tf_idf_vectorizer/vectorizer/
mod.rs1pub mod corpus;
2pub mod tfidf;
3pub mod term;
4pub mod serde;
5pub mod evaluate;
6
7use std::{rc::Rc, sync::Arc};
8use std::hash::Hash;
9
10use half::f16;
11use num_traits::Num;
12use ::serde::{Deserialize, Serialize};
13
14use crate::utils::datastruct::vector::{TFVector, TFVectorTrait};
15use crate::{DefaultTFIDFEngine, TFIDFEngine, TermFrequency};
16use crate::utils::datastruct::map::IndexMap;
17use crate::Corpus;
18
/// Shared, reference-counted handle to a document key.
///
/// One allocation per document key is shared between `TFIDFVectorizer::documents` and every
/// posting list in `TFIDFVectorizer::term_dim_rev_index`, so registering a document under many
/// terms only bumps a reference count instead of cloning `K` each time.
///
/// NOTE(review): `std::rc::Rc` is neither `Send` nor `Sync`, so any type holding a `KeyRc` is
/// not thread-safe even when `K: Send + Sync`.
pub type KeyRc<K> = std::rc::Rc<K>;

21#[derive(Debug, Clone)]
22pub struct TFIDFVectorizer<N = f16, K = String, E = DefaultTFIDFEngine>
23where
24 N: Num + Copy + Into<f64> + Send + Sync,
25 E: TFIDFEngine<N> + Send + Sync,
26 K: Clone + Send + Sync + Eq + std::hash::Hash,
27{
28 pub documents: IndexMap<KeyRc<K>, TFVector<N>>,
30 pub term_dim_rev_index: IndexMap<Box<str>, Vec<KeyRc<K>>>,
32 pub corpus_ref: Arc<Corpus>,
34 pub idf_cache: IDFVector,
36 _marker: std::marker::PhantomData<E>,
37}
38
39#[derive(Debug, Serialize, Deserialize, Clone)]
40pub struct IDFVector
41{
42 pub idf_vec: Vec<f32>,
44 pub latest_entropy: u64,
46 pub doc_num: u64,
48}
49
50impl IDFVector
51{
52 pub fn new() -> Self {
53 Self {
54 idf_vec: Vec::new(),
55 latest_entropy: 0,
56 doc_num: 0,
57 }
58 }
59}
60
61impl <N, K, E> TFIDFVectorizer<N, K, E>
62where
63 N: Num + Copy + Into<f64> + Send + Sync,
64 E: TFIDFEngine<N> + Send + Sync,
65 K: Clone + Send + Sync + Eq + Hash,
66{
67 pub fn new(corpus_ref: Arc<Corpus>) -> Self {
69 let mut instance = Self {
70 documents: IndexMap::new(),
71 term_dim_rev_index: IndexMap::new(),
72 corpus_ref,
73 idf_cache: IDFVector::new(),
74 _marker: std::marker::PhantomData,
75 };
76 instance.re_calc_idf();
77 instance
78 }
79
80 pub fn set_corpus_ref(&mut self, corpus_ref: Arc<Corpus>) {
83 self.corpus_ref = corpus_ref;
84 self.re_calc_idf();
85 }
86
87 pub fn update_idf(&mut self) {
89 if self.corpus_ref.get_gen_num() != self.idf_cache.latest_entropy {
90 self.re_calc_idf();
91 }
92 }
94
95 fn re_calc_idf(&mut self) {
97 self.idf_cache.latest_entropy = self.corpus_ref.get_gen_num();
98 self.idf_cache.doc_num = self.corpus_ref.get_doc_num();
99 self.idf_cache.idf_vec = E::idf_vec(&self.corpus_ref, self.term_dim_rev_index.keys());
100 }
101}
102
103impl <N, K, E> TFIDFVectorizer<N, K, E>
104where
105 N: Num + Copy + Into<f64> + Send + Sync,
106 E: TFIDFEngine<N> + Send + Sync,
107 K: PartialEq + Clone + Send + Sync + Eq + Hash
108{
109 pub fn add_doc(&mut self, key: K, doc: &TermFrequency) {
112 let key_rc = KeyRc::new(key);
114 if self.documents.contains_key(&key_rc) {
115 self.del_doc(&key_rc);
116 }
117 self.add_corpus(doc);
119 for tok in doc.term_set(){
121 self.term_dim_rev_index
122 .entry_mut(tok.into_boxed_str())
123 .or_insert_with(Vec::new)
124 .push(Rc::clone(&key_rc)); }
126
127 let tf_vec= E::tf_vec(doc, self.term_dim_rev_index.as_index_set());
128 self.documents.insert(key_rc, tf_vec);
129 }
130
131 pub fn del_doc(&mut self, key: &K)
132 where
133 K: PartialEq,
134 {
135 let rc_key = KeyRc::new(key.clone());
136 if let Some(tf_vec) = self.documents.get(&rc_key) {
137 let terms = tf_vec.raw_iter()
138 .filter_map(|(idx, _)| {
139 let idx = idx as usize;
140 let doc_keys = self.term_dim_rev_index.get_with_index_mut(idx);
141 if let Some(doc_keys) = doc_keys {
142 let rc_key = KeyRc::new(key.clone());
144 doc_keys.retain(|k| *k != rc_key);
145 }
146 let term = self.term_dim_rev_index.get_key_with_index(idx).cloned();
147 term
148 }).collect::<Vec<Box<str>>>();
149 self.documents.swap_remove(&rc_key);
151 self.corpus_ref.sub_set(&terms);
153 }
154 }
155
156 pub fn get_tf(&self, key: &K) -> Option<&TFVector<N>>
158 where
159 K: Eq + Hash,
160 {
161 let rc_key = KeyRc::new(key.clone());
162 self.documents.get(&rc_key)
163 }
164
165 pub fn get_tf_into_term_freq(&self, key: &K) -> Option<TermFrequency>
169 {
170 if let Some(tf_vec) = self.get_tf(key) {
171 let mut term_freq = TermFrequency::new();
172 tf_vec.raw_iter().for_each(|(idx, val)| {
173 let idx = idx as usize;
174 if let Some(term) = self.term_dim_rev_index.get_key_with_index(idx) {
175 let term_num = E::tf_denorm(val);
176 term_freq.set_term_count(term, term_num as u64);
177 } });
179 Some(term_freq)
180 } else {
181 None
182 }
183 }
184
185 pub fn contains_doc(&self, key: &K) -> bool
187 where
188 K: PartialEq,
189 {
190 let rc_key = KeyRc::new(key.clone());
191 self.documents.contains_key(&rc_key)
192 }
193
194 pub fn contains_term(&self, term: &str) -> bool {
196 self.term_dim_rev_index.contains_key(&Box::<str>::from(term))
197 }
198
199 pub fn contains_terms_from_freq(&self, freq: &TermFrequency) -> bool {
201 freq.term_set_ref_str().iter().all(|tok| self.contains_term(tok))
202 }
203
204 pub fn doc_num(&self) -> usize {
205 self.documents.len()
206 }
207
208 fn add_corpus(&self, doc: &TermFrequency) {
211 self.corpus_ref.add_set(&doc.term_set_ref_str());
213 }
214}