tf_idf_vectorizer/vectorizer/
index.rs1use std::ops::{AddAssign, MulAssign};
2
3use num::Num;
4use serde::{Deserialize, Serialize};
5
6use crate::utils::{math::vector::ZeroSpVec, normalizer::{IntoNormalizer, NormalizedBounded, NormalizedMultiply}};
7use rayon::prelude::*;
8
9use super::token::TokenFrequency;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct Index<N>
15where N: Num + Into<f64> + AddAssign + MulAssign + NormalizedMultiply + Copy + NormalizedBounded {
16 matrix: Vec<ZeroSpVec<N>>,
17 doc_id: Vec<String>,
18 corpus_token_freq: TokenFrequency,
19}
20
21impl<N> Index<N>
22where N: Num + Into<f64> + AddAssign + MulAssign + NormalizedMultiply + Copy + NormalizedBounded, f64: IntoNormalizer<N> {
23 pub fn new() -> Self {
25 Self {
26 matrix: Vec::new(),
27 doc_id: Vec::new(),
28 corpus_token_freq: TokenFrequency::new(),
29 }
30 }
31
32 pub fn doc_num(&self) -> usize {
34 self.matrix.len()
35 }
36
37 pub fn token_num(&self) -> usize {
40 self.corpus_token_freq.token_num()
41 }
42
43 pub fn add_doc(&mut self, doc_id: String, tokens: &[&str]) {
49 let mut doc_tf = TokenFrequency::new();
51 doc_tf.add_tokens(tokens);
52
53 let old_corpus_token_num = self.corpus_token_freq.token_num();
55 self.corpus_token_freq.add_tokens(tokens);
56 let added_corpus_token_num = self.corpus_token_freq.token_num() - old_corpus_token_num;
57 self.doc_id.push(doc_id);
58
59 let mut vec: ZeroSpVec<N> = ZeroSpVec::with_capacity(doc_tf.token_num());
61 for token in self.corpus_token_freq.token_set_ref_str().iter() {
62 let tf_val: N = doc_tf.tf_token(token); vec.push(tf_val);
64 }
65
66 if added_corpus_token_num > 0 {
67 for other_tf in self.matrix.iter_mut() {
69 other_tf.add_dim(added_corpus_token_num);
70 }
71 }
72 vec.shrink_to_fit();
73 self.matrix.push(vec);
75 }
76
77 pub fn generate_query_mask(&self, tokens: &[&str]) -> ZeroSpVec<N> {
86 let mut query_tf = TokenFrequency::new();
88 query_tf.add_tokens(tokens);
89
90 let mut vec: ZeroSpVec<N> = ZeroSpVec::with_capacity(query_tf.token_num());
92 for token in self.corpus_token_freq.token_set_ref_str().iter() {
93 let tf_val: N = if query_tf.contains_token(token) { N::max_normalized() } else { N::zero() }; vec.push(tf_val);
95 }
96 vec
97 }
98
99 pub fn generate_query(&self, tokens: &[&str]) -> ZeroSpVec<N> {
107 let mut query_tf = TokenFrequency::new();
109 query_tf.add_tokens(tokens);
110
111 let mut vec: ZeroSpVec<N> = ZeroSpVec::with_capacity(query_tf.token_num());
113 for token in self.corpus_token_freq.token_set_ref_str().iter() {
114 let tf_val: N = query_tf.tf_token(token); vec.push(tf_val);
116 }
117 vec
118 }
119
120 pub fn search_cosine_similarity(&self, query: &ZeroSpVec<N>) -> Vec<(String, f64)> {
128 let mut result = Vec::new();
129
130 let idf_vec =
131 self.corpus_token_freq.
132 idf_vector_ref_str::<N>(self.matrix.len() as u64)
133 .into_iter()
134 .map(|(_, idf)| idf)
135 .collect::<Vec<N>>();
136
137 let idf_query: ZeroSpVec<N> = query.hadamard_normalized_vec(&idf_vec);
139
140 for (i, doc_vec) in self.matrix.iter().enumerate() {
142 let tf_idf_doc_vec = doc_vec.hadamard_normalized_vec(&idf_vec);
143 let similarity = tf_idf_doc_vec.cosine_similarity_normalized::<f64>(&idf_query);
144 if similarity != 0.0 {
145 result.push((self.doc_id[i].clone(), similarity));
146 }
147 }
148
149 result.sort_by(|a, b| b.1.total_cmp(&a.1));
151 result
152 }
153
154 pub fn search_cosine_similarity_parallel(&self, query: &ZeroSpVec<N>, thread_count: usize) -> Vec<(String, f64)>
165 where
166 N: Send + Sync,
167 {
168 let idf_vec =
169 self.corpus_token_freq.
170 idf_vector_ref_str::<N>(self.matrix.len() as u64)
171 .into_iter()
172 .map(|(_, idf)| idf)
173 .collect::<Vec<N>>();
174
175 let idf_query: ZeroSpVec<N> = query.hadamard_normalized_vec(&idf_vec);
177
178 let pool = rayon::ThreadPoolBuilder::new()
179 .num_threads(thread_count)
180 .build()
181 .expect("Failed to build thread pool");
182
183 let mut result: Vec<(String, f64)> = pool.install(|| {
184 self.matrix
185 .par_iter()
186 .enumerate()
187 .filter_map(|(i, doc_vec)| {
188 let tf_idf_doc_vec = doc_vec.hadamard_normalized_vec(&idf_vec);
189 let similarity = tf_idf_doc_vec.cosine_similarity_normalized::<f64>(&idf_query);
190 if similarity != 0.0 {
191 Some((self.doc_id[i].clone(), similarity))
192 } else {
193 None
194 }
195 })
196 .collect()
197 });
198
199 result.sort_by(|a, b| b.1.total_cmp(&a.1));
201 result
202 }
203}