// tf_idf_vectorizer/vectorizer/mod.rs
pub mod corpus;
pub mod tfidf;
pub mod token;
pub mod serde;
pub mod evaluate;

use std::sync::Arc;

use indexmap::IndexSet;
use num::Num;
use ::serde::{Deserialize, Serialize};

use crate::{utils::{math::vector::{ZeroSpVec, ZeroSpVecTrait}, normalizer::DeNormalizer}, vectorizer::{corpus::Corpus, tfidf::{DefaultTFIDFEngine, TFIDFEngine}, token::TokenFrequency}};
use ahash::RandomState;

/// Generic TF-IDF vectorizer backed by a shared [`Corpus`].
///
/// Type parameters:
/// * `N` - numeric type used to store TF/IDF weights (defaults to `f32`).
/// * `K` - document key type (defaults to `String`).
/// * `E` - scoring engine implementing [`TFIDFEngine`] (defaults to
///   [`DefaultTFIDFEngine`]).
#[derive(Debug, Clone)]
pub struct TFIDFVectorizer<N = f32, K = String, E = DefaultTFIDFEngine>
where
    N: Num + Copy,
    E: TFIDFEngine<N>,
{
    /// TF vectors of all indexed documents, in insertion order.
    pub documents: Vec<TFVector<N, K>>,
    /// Token -> dimension-index mapping shared by every TF vector.
    pub token_dim_sample: IndexSet<String, RandomState>,
    /// Shared corpus the IDF cache is computed from.
    pub corpus_ref: Arc<Corpus>,
    /// Cached IDF data derived from `corpus_ref`.
    pub idf: IDFVector<N>,
    // Ties the engine type `E` to the struct without storing a value of it.
    _marker: std::marker::PhantomData<E>,
}

/// A single document's term-frequency vector plus the metadata needed
/// to denormalize its weights back toward raw counts.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TFVector<N, K>
where
    N: Num + Copy,
{
    /// Sparse TF weights, indexed by the owning vectorizer's token dimensions.
    pub tf_vec: ZeroSpVec<N>,
    /// Total token count of the source document.
    pub token_sum: u64,
    /// Factor for undoing the normalization applied to `tf_vec`.
    pub denormalize_num: f64,
    /// Caller-supplied document key.
    pub key: K,
}

impl<N, K> TFVector<N, K>
where
    N: Num + Copy,
{
    /// Releases excess capacity held by the underlying sparse vector.
    pub fn shrink_to_fit(&mut self) {
        self.tf_vec.shrink_to_fit();
    }
}

/// Cached inverse-document-frequency data for one corpus snapshot.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct IDFVector<N>
where
    N: Num,
{
    /// Dense IDF weights, one entry per token dimension.
    pub idf_vec: Vec<N>,
    /// Factor for undoing the normalization applied to `idf_vec`.
    pub denormalize_num: f64,
    /// Corpus generation number this cache was computed from.
    /// NOTE(review): despite the name this stores `Corpus::get_gen_num()`,
    /// not an entropy value — confirm the naming intent.
    pub latest_entropy: u64,
    /// Document count of the corpus when the cache was built.
    pub doc_num: u64,
}

74impl <N> IDFVector<N>
75where
76 N: Num,
77{
78 pub fn new() -> Self {
79 Self {
80 idf_vec: Vec::new(),
81 denormalize_num: 1.0,
82 latest_entropy: 0,
83 doc_num: 0,
84 }
85 }
86}
87
88impl <N, K, E> TFIDFVectorizer<N, K, E>
89where
90 N: Num + Copy,
91 E: TFIDFEngine<N>,
92{
93 pub fn new(corpus_ref: Arc<Corpus>) -> Self {
95 let mut instance = Self {
96 documents: Vec::new(),
97 token_dim_sample: IndexSet::with_hasher(RandomState::new()),
98 corpus_ref,
99 idf: IDFVector::new(),
100 _marker: std::marker::PhantomData,
101 };
102 instance.re_calc_idf();
103 instance
104 }
105
106 pub fn set_corpus_ref(&mut self, corpus_ref: Arc<Corpus>) {
109 self.corpus_ref = corpus_ref;
110 self.re_calc_idf();
111 }
112
113 pub fn update_idf(&mut self) {
115 if self.corpus_ref.get_gen_num() != self.idf.latest_entropy {
116 self.re_calc_idf();
117 }
118 }
120
121 fn re_calc_idf(&mut self) {
123 self.idf.latest_entropy = self.corpus_ref.get_gen_num();
124 self.idf.doc_num = self.corpus_ref.get_doc_num();
125 (self.idf.idf_vec, self.idf.denormalize_num) = E::idf_vec(&self.corpus_ref, &self.token_dim_sample)
126 }
127}
128
129impl <N, K, E> TFIDFVectorizer<N, K, E>
130where
131 N: Num + Copy + Into<f64>,
132 E: TFIDFEngine<N>,
133 K: PartialEq
134{
135 pub fn add_doc(&mut self, doc_id: K, doc: &TokenFrequency) {
138 let token_sum = doc.token_sum();
139 self.add_corpus(doc);
141 for tok in doc.token_set_ref_str() {
143 if !self.token_dim_sample.contains(tok) {
144 self.token_dim_sample.insert(tok.to_string());
145 }
146 }
147
148 let (tf_vec, denormalize_num) = E::tf_vec(doc, &self.token_dim_sample);
149 let mut doc = TFVector {
150 tf_vec,
151 token_sum,
152 denormalize_num,
153 key: doc_id,
154 };
155 doc.shrink_to_fit();
156 self.documents.push(doc);
157 }
158
159 pub fn del_doc(&mut self, doc_id: &K)
160 where
161 K: PartialEq,
162 {
163 if let Some(pos) = self.documents.iter().position(|doc| &doc.key == doc_id) {
164 let doc = &self.documents[pos];
165 let token_set = doc.tf_vec.raw_iter()
166 .filter_map(|(idx, _)| self.token_dim_sample.get_index(idx).map(|s| s.as_str()))
167 .collect::<Vec<&str>>();
168 self.corpus_ref.sub_set(&token_set);
170 self.documents.remove(pos);
172 }
173 }
174
175 pub fn get_tf(&self, key: &K) -> Option<&TFVector<N, K>>
177 where
178 K: PartialEq,
179 {
180 self.documents.iter().find(|doc| &doc.key == key)
181 }
182
183 pub fn get_tf_into_token_freq(&self, key: &K) -> Option<TokenFrequency>
187 {
188 if let Some(tf_vec) = self.get_tf(key) {
189 let mut token_freq = TokenFrequency::new();
190 tf_vec.tf_vec.raw_iter().for_each(|(idx, val)| {
191 if let Some(token) = self.token_dim_sample.get_index(idx) {
192 let val_f64: f64 = (*val).into();
193 let token_num: f64 = tf_vec.token_sum.denormalize(tf_vec.denormalize_num) * val_f64;
194 token_freq.set_token_count(token, token_num as u64);
195 } });
197 Some(token_freq)
198 } else {
199 None
200 }
201 }
202
203 pub fn contains_doc(&self, key: &K) -> bool
205 where
206 K: PartialEq,
207 {
208 self.documents.iter().any(|doc| &doc.key == key)
209 }
210
211 pub fn contains_token(&self, token: &str) -> bool {
213 self.token_dim_sample.contains(token)
214 }
215
216 pub fn contains_tokens_from_freq(&self, freq: &TokenFrequency) -> bool {
218 freq.token_set_ref_str().iter().all(|tok| self.token_dim_sample.contains(*tok))
219 }
220
221 pub fn doc_num(&self) -> usize {
222 self.documents.len()
223 }
224
225 fn add_corpus(&mut self, doc: &TokenFrequency) {
228 self.corpus_ref.add_set(&doc.token_set_ref_str());
230 }
231}