tf_idf_vectorizer/vectorizer/
index.rs1use std::ops::{AddAssign, MulAssign};
2
3use num::Num;
4use serde::{Deserialize, Serialize};
5
6use crate::utils::{math::vector::ZeroSpVec, normalizer::{IntoNormalizer, NormalizedBounded, NormalizedMultiply}};
7use rayon::prelude::*;
8
9use super::token::TokenFrequency;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct Index<N>
15where N: Num + Into<f64> + AddAssign + MulAssign + NormalizedMultiply + Copy + NormalizedBounded {
16 matrix: Vec<ZeroSpVec<N>>,
17 doc_id: Vec<String>,
18 corpus_token_freq: TokenFrequency,
19}
20
21impl<N> Index<N>
22where N: Num + Into<f64> + AddAssign + MulAssign + NormalizedMultiply + Copy + NormalizedBounded, f64: IntoNormalizer<N> {
23 pub fn new() -> Self {
24 Self {
25 matrix: Vec::new(),
26 doc_id: Vec::new(),
27 corpus_token_freq: TokenFrequency::new(),
28 }
29 }
30
31 pub fn doc_num(&self) -> usize {
33 self.matrix.len()
34 }
35
36 pub fn token_num(&self) -> usize {
39 self.corpus_token_freq.token_num()
40 }
41
42 pub fn add_doc(&mut self, doc_id: String, tokens: &[&str]) {
44 let mut doc_tf = TokenFrequency::new();
46 doc_tf.add_tokens(tokens);
47
48 let old_corpus_token_num = self.corpus_token_freq.token_num();
50 self.corpus_token_freq.add_tokens(tokens);
51 let added_corpus_token_num = self.corpus_token_freq.token_num() - old_corpus_token_num;
52 self.doc_id.push(doc_id);
53
54 let mut vec: ZeroSpVec<N> = ZeroSpVec::with_capacity(doc_tf.token_num());
56 for token in self.corpus_token_freq.token_set_ref_str().iter() {
57 let tf_val: N = doc_tf.tf_token(token); vec.push(tf_val);
59 }
60
61 if added_corpus_token_num > 0 {
62 for other_tf in self.matrix.iter_mut() {
64 other_tf.add_dim(added_corpus_token_num);
65 }
66 }
67 vec.shrink_to_fit();
68 self.matrix.push(vec);
70 }
71
72 pub fn generate_query_mask(&self, tokens: &[&str]) -> ZeroSpVec<N> {
73 let mut query_tf = TokenFrequency::new();
75 query_tf.add_tokens(tokens);
76
77 let mut vec: ZeroSpVec<N> = ZeroSpVec::with_capacity(query_tf.token_num());
79 for token in self.corpus_token_freq.token_set_ref_str().iter() {
80 let tf_val: N = if query_tf.contains_token(token) { N::max_normalized() } else { N::zero() }; vec.push(tf_val);
82 }
83 vec
84 }
85
86 pub fn generate_query(&self, tokens: &[&str]) -> ZeroSpVec<N> {
87 let mut query_tf = TokenFrequency::new();
89 query_tf.add_tokens(tokens);
90
91 let mut vec: ZeroSpVec<N> = ZeroSpVec::with_capacity(query_tf.token_num());
93 for token in self.corpus_token_freq.token_set_ref_str().iter() {
94 let tf_val: N = query_tf.tf_token(token); vec.push(tf_val);
96 }
97 vec
98 }
99
100 pub fn search_cosine_similarity(&self, query: &ZeroSpVec<N>) -> Vec<(String, f64)> {
101 let mut result = Vec::new();
102
103 let idf_query: ZeroSpVec<N> = query.hadamard_normalized_vec(
105 &self.corpus_token_freq.
106 idf_vector_ref_str::<N>(self.matrix.len() as u64).into_iter().map(|(_, idf)| idf).collect::<Vec<N>>()
107 );
108
109 for (i, doc_vec) in self.matrix.iter().enumerate() {
111 let similarity = doc_vec.cosine_similarity_normalized::<f64>(&idf_query);
112 if similarity != 0.0 {
113 result.push((self.doc_id[i].clone(), similarity));
114 }
115 }
116
117 result.sort_by(|a, b| b.1.total_cmp(&a.1));
119 result
120 }
121
122 pub fn search_cosine_similarity_parallel(&self, query: &ZeroSpVec<N>, thread_count: usize) -> Vec<(String, f64)>
123 where
124 N: Send + Sync,
125 {
126 let idf_query: ZeroSpVec<N> = query.hadamard_normalized_vec(
128 &self
129 .corpus_token_freq
130 .idf_vector_ref_str::<N>(self.matrix.len() as u64)
131 .into_iter()
132 .map(|(_, idf)| idf)
133 .collect::<Vec<N>>(),
134 );
135
136 let pool = rayon::ThreadPoolBuilder::new()
137 .num_threads(thread_count)
138 .build()
139 .expect("Failed to build thread pool");
140
141 let mut result: Vec<(String, f64)> = pool.install(|| {
142 self.matrix
143 .par_iter()
144 .enumerate()
145 .filter_map(|(i, doc_vec)| {
146 let similarity = doc_vec.cosine_similarity_normalized::<f64>(&idf_query);
147 if similarity != 0.0 {
148 Some((self.doc_id[i].clone(), similarity))
149 } else {
150 None
151 }
152 })
153 .collect()
154 });
155
156 result.sort_by(|a, b| b.1.total_cmp(&a.1));
158 result
159 }
160}