tf_idf_vectorizer/vectorizer/
mod.rs1pub mod corpus;
2pub mod tfidf;
3pub mod token;
4pub mod serde;
5pub mod evaluate;
6
7use std::sync::Arc;
8
9use indexmap::IndexSet;
10use num::Num;
11use ::serde::{Deserialize, Serialize};
12
13use crate::{utils::{math::vector::{ZeroSpVec, ZeroSpVecTrait}, normalizer::DeNormalizer}, vectorizer::{corpus::Corpus, tfidf::{DefaultTFIDFEngine, TFIDFEngine}, token::TokenFrequency}};
14use ahash::RandomState;
15
16#[derive(Debug, Clone)]
17pub struct TFIDFVectorizer<N = f32, K = String, E = DefaultTFIDFEngine>
18where
19 N: Num + Copy,
20 E: TFIDFEngine<N>,
21{
22 pub documents: Vec<TFVector<N, K>>,
24 pub token_dim_sample: IndexSet<String, RandomState>,
26 pub corpus_ref: Arc<Corpus>,
28 pub idf: IDFVector<N>,
30 _marker: std::marker::PhantomData<E>,
31}
32
33#[derive(Debug, Serialize, Deserialize, Clone)]
34pub struct TFVector<N, K>
35where
36 N: Num + Copy,
37{
38 pub tf_vec: ZeroSpVec<N>,
41 pub token_sum: u64,
43 pub denormalize_num: f64,
46 pub key: K,
48}
49
50impl<N, K> TFVector<N, K>
51where
52 N: Num + Copy,
53{
54 pub fn shrink_to_fit(&mut self) {
55 self.tf_vec.shrink_to_fit();
56 }
57}
58
59#[derive(Debug, Serialize, Deserialize, Clone)]
60pub struct IDFVector<N>
61where
62 N: Num,
63{
64 pub idf_vec: Vec<N>,
66 pub denormalize_num: f64,
68 pub latest_entropy: u64,
70 pub doc_num: u64,
72}
73
74impl <N> IDFVector<N>
75where
76 N: Num,
77{
78 pub fn new() -> Self {
79 Self {
80 idf_vec: Vec::new(),
81 denormalize_num: 1.0,
82 latest_entropy: 0,
83 doc_num: 0,
84 }
85 }
86}
87
88impl <N, K, E> TFIDFVectorizer<N, K, E>
89where
90 N: Num + Copy,
91 E: TFIDFEngine<N>,
92{
93 pub fn new(corpus_ref: Arc<Corpus>) -> Self {
95 let mut instance = Self {
96 documents: Vec::new(),
97 token_dim_sample: IndexSet::with_hasher(RandomState::new()),
98 corpus_ref,
99 idf: IDFVector::new(),
100 _marker: std::marker::PhantomData,
101 };
102 instance.re_calc_idf();
103 instance
104 }
105
106 pub fn set_corpus_ref(&mut self, corpus_ref: Arc<Corpus>) {
109 self.corpus_ref = corpus_ref;
110 self.re_calc_idf();
111 }
112
113 pub fn update_idf(&mut self) {
115 if self.corpus_ref.get_gen_num() != self.idf.latest_entropy {
116 self.re_calc_idf();
117 }
118 }
120
121 fn re_calc_idf(&mut self) {
123 self.idf.latest_entropy = self.corpus_ref.get_gen_num();
124 self.idf.doc_num = self.corpus_ref.get_doc_num();
125 (self.idf.idf_vec, self.idf.denormalize_num) = E::idf_vec(&self.corpus_ref, &self.token_dim_sample)
126 }
127}
128
129impl <N, K, E> TFIDFVectorizer<N, K, E>
130where
131 N: Num + Copy,
132 E: TFIDFEngine<N>,
133{
134 pub fn add_doc(&mut self, doc_id: K, doc: &TokenFrequency) {
137 let token_sum = doc.token_sum();
138 self.add_corpus(doc);
140 for tok in doc.token_set_ref_str() {
142 if !self.token_dim_sample.contains(tok) {
143 self.token_dim_sample.insert(tok.to_string());
144 }
145 }
146
147 let (tf_vec, denormalize_num) = E::tf_vec(doc, &self.token_dim_sample);
148 let mut doc = TFVector {
149 tf_vec,
150 token_sum,
151 denormalize_num,
152 key: doc_id,
153 };
154 doc.shrink_to_fit();
155 self.documents.push(doc);
156 }
157
158 pub fn get_tf(&self, key: &K) -> Option<&TFVector<N, K>>
160 where
161 K: PartialEq,
162 {
163 self.documents.iter().find(|doc| &doc.key == key)
164 }
165
166 pub fn get_tf_into_token_freq(&self, key: &K) -> Option<TokenFrequency>
170 where
171 K: PartialEq,
172 N: Into<f64>,
173 {
174 if let Some(tf_vec) = self.get_tf(key) {
175 let mut token_freq = TokenFrequency::new();
176 tf_vec.tf_vec.raw_iter().for_each(|(idx, val)| {
177 if let Some(token) = self.token_dim_sample.get_index(idx) {
178 let val_f64: f64 = (*val).into();
179 let token_num: f64 = tf_vec.token_sum.denormalize(tf_vec.denormalize_num) * val_f64;
180 token_freq.set_token_count(token, token_num as u64);
181 } });
183 Some(token_freq)
184 } else {
185 None
186 }
187 }
188
189 pub fn contains_doc(&self, key: &K) -> bool
191 where
192 K: PartialEq,
193 {
194 self.documents.iter().any(|doc| &doc.key == key)
195 }
196
197 pub fn contains_token(&self, token: &str) -> bool {
199 self.token_dim_sample.contains(token)
200 }
201
202 pub fn contains_tokens_from_freq(&self, freq: &TokenFrequency) -> bool {
204 freq.token_set_ref_str().iter().all(|tok| self.token_dim_sample.contains(*tok))
205 }
206
207 fn add_corpus(&mut self, doc: &TokenFrequency) {
210 self.corpus_ref.add_set(&doc.token_set_ref_str());
212 }
213}