tf_idf_vectorizer/vectorizer/tfidf.rs
1use half::f16;
2use num_traits::{Num, Pow};
3
4use crate::{Corpus, TermFrequency, utils::datastruct::{map::IndexSet, vector::{TFVector, TFVectorTrait}}};
5
6
7
8pub trait TFIDFEngine<N>: Send + Sync
9where
10 N: Num + Copy
11{
12 /// Method to generate the IDF vector
13 /// # Arguments
14 /// * `corpus` - The corpus
15 /// * `term_dim_sample` - term dimension sample
16 /// # Returns
17 /// * `Vec<N>` - The IDF vector
18 /// * `denormalize_num` - Value for denormalization
19 fn idf_vec(corpus: &Corpus, term_dim_sample: &Vec<Box<str>>) -> Vec<f32> {
20 let mut idf_vec = Vec::with_capacity(term_dim_sample.len());
21 let doc_num = corpus.get_doc_num() as f64;
22 for term in term_dim_sample.iter() {
23 let doc_freq = corpus.get_term_count(term);
24 idf_vec.push((doc_num / (doc_freq as f64 + 1.0)) as f32);
25 }
26 idf_vec
27 }
28 /// Method to generate the TF vector
29 /// # Arguments
30 /// * `freq` - term frequency
31 /// * `term_dim_sample` - term dimension sample
32 /// # Returns
33 /// * `(ZeroSpVec<N>, f64)` - TF vector and value for denormalization
34 fn tf_vec(freq: &TermFrequency, term_dim_sample: &IndexSet<Box<str>>) -> TFVector<N>;
35
36 fn tf_denorm(val: N) -> u32;
37}
38
/// Default TF-IDF engine.
///
/// Implements `TFIDFEngine` for the element types `f16`, `f32`, `u32` and `u16`.
/// (Earlier documentation also listed `f64` and `u8`; this module provides no
/// implementation for those types.)
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultTFIDFEngine;

impl DefaultTFIDFEngine {
    /// Creates the engine; equivalent to `DefaultTFIDFEngine::default()`.
    pub fn new() -> Self {
        DefaultTFIDFEngine
    }
}
48
49impl TFIDFEngine<f16> for DefaultTFIDFEngine {
50 // fn idf_vec(corpus: &Corpus, term_dim_sample: &Vec<Box<str>>) -> (Vec<f16>, f64) {
51 // let mut idf_vec = Vec::with_capacity(term_dim_sample.len());
52 // let doc_num = corpus.get_doc_num() as f64;
53 // for term in term_dim_sample.iter() {
54 // let doc_freq = corpus.get_term_count(term);
55 // idf_vec.push(f16::from_f64(doc_num / (doc_freq as f64 + 1.0)));
56 // }
57 // (idf_vec, 1.0)
58 // }
59
60 fn tf_vec(freq: &TermFrequency, term_dim_sample: &IndexSet<Box<str>>) -> TFVector<f16> {
61 // Build sparse TF vector: only non-zero entries are stored
62 let term_sum = freq.term_sum() as u32;
63 let len = freq.term_num();
64 let mut ind_vec: Vec<u32> = Vec::with_capacity(len);
65 let mut val_vec: Vec<f16> = Vec::with_capacity(len);
66 for (term, count) in freq.iter() {
67 let count = (count as f32).sqrt();
68 if let Some(idx) = term_dim_sample.get_index(term) {
69 ind_vec.push(idx as u32);
70 val_vec.push(f16::from_f32(count));
71 }
72 }
73 unsafe { TFVector::from_vec(ind_vec, val_vec, len as u32, term_sum) }
74 }
75
76 fn tf_denorm(val: f16) -> u32 {
77 val.to_f32().pow(2) as u32
78 }
79}
80
81impl TFIDFEngine<f32> for DefaultTFIDFEngine
82{
83 // fn idf_vec(corpus: &Corpus, term_dim_sample: &Vec<Box<str>>) -> (Vec<f32>, f64) {
84 // let mut idf_vec = Vec::with_capacity(term_dim_sample.len());
85 // let doc_num = corpus.get_doc_num() as f64;
86 // for term in term_dim_sample.iter() {
87 // let doc_freq = corpus.get_term_count(term);
88 // idf_vec.push((doc_num / (doc_freq as f64 + 1.0)) as f32);
89 // }
90 // (idf_vec, 1.0)
91 // }
92
93 fn tf_vec(freq: &TermFrequency, term_dim_sample: &IndexSet<Box<str>>) -> TFVector<f32> {
94 // Build sparse TF vector: only non-zero entries are stored
95 let term_sum = freq.term_sum() as u32;
96 let len = freq.term_num();
97 let mut ind_vec: Vec<u32> = Vec::with_capacity(len);
98 let mut val_vec: Vec<f32> = Vec::with_capacity(len);
99 for (term, count) in freq.iter() {
100 if let Some(idx) = term_dim_sample.get_index(term) {
101 ind_vec.push(idx as u32);
102 val_vec.push(count as f32);
103 }
104 }
105 unsafe { TFVector::from_vec(ind_vec, val_vec, len as u32, term_sum) }
106 }
107
108 fn tf_denorm(val: f32) -> u32 {
109 val as u32
110 }
111}
112
113impl TFIDFEngine<u32> for DefaultTFIDFEngine
114{
115 // fn idf_vec(corpus: &Corpus, term_dim_sample: &Vec<Box<str>>) -> (Vec<u32>, f64) {
116 // let mut idf_vec = Vec::with_capacity(term_dim_sample.len());
117 // let doc_num = corpus.get_doc_num() as f64;
118 // for term in term_dim_sample.iter() {
119 // let doc_freq = corpus.get_term_count(term);
120 // idf_vec.push(doc_num / (doc_freq as f64 + 1.0));
121 // }
122 // let max = idf_vec
123 // .iter()
124 // .max_by(|a, b| a.total_cmp(b))
125 // .copied()
126 // .unwrap_or(1.0);
127 // (
128 // idf_vec
129 // .into_iter()
130 // .map(|idf| (idf / max * u32::MAX as f64).ceil() as u32)
131 // .collect(),
132 // max
133 // )
134 // }
135
136 fn tf_vec(freq: &TermFrequency, term_dim_sample: &IndexSet<Box<str>>) -> TFVector<u32> {
137 // Build sparse TF vector: only non-zero entries are stored
138 let term_sum = freq.term_sum() as u32;
139 let len = freq.term_num();
140 let mut ind_vec: Vec<u32> = Vec::with_capacity(len);
141 let mut val_vec: Vec<u32> = Vec::with_capacity(len);
142 for (term, count) in freq.iter() {
143 if let Some(idx) = term_dim_sample.get_index(term) {
144 ind_vec.push(idx as u32);
145 val_vec.push(count as u32);
146 }
147 }
148 unsafe { TFVector::from_vec(ind_vec, val_vec, len as u32, term_sum) }
149 }
150
151 fn tf_denorm(val: u32) -> u32 {
152 val
153 }
154}
155
156impl TFIDFEngine<u16> for DefaultTFIDFEngine
157{
158 // fn idf_vec(corpus: &Corpus, term_dim_sample: &Vec<Box<str>>) -> (Vec<u16>, f64) {
159 // let mut idf_vec = Vec::with_capacity(term_dim_sample.len());
160 // let doc_num = corpus.get_doc_num() as f64;
161 // for term in term_dim_sample.iter() {
162 // let doc_freq = corpus.get_term_count(term);
163 // idf_vec.push(doc_num / (doc_freq as f64 + 1.0));
164 // }
165 // let max = idf_vec
166 // .iter()
167 // .max_by(|a, b| a.total_cmp(b))
168 // .copied()
169 // .unwrap_or(1.0);
170 // (
171 // idf_vec
172 // .into_iter()
173 // .map(|idf| (idf / max * u16::MAX as f64).ceil() as u16)
174 // .collect(),
175 // max
176 // )
177 // }
178
179 fn tf_vec(freq: &TermFrequency, term_dim_sample: &IndexSet<Box<str>>) -> TFVector<u16> {
180 // Build sparse TF vector: only non-zero entries are stored
181 let term_sum = freq.term_sum() as u32;
182 let len = freq.term_num();
183 let mut ind_vec: Vec<u32> = Vec::with_capacity(len);
184 let mut val_vec: Vec<u16> = Vec::with_capacity(len);
185 for (term, count) in freq.iter() {
186 if let Some(idx) = term_dim_sample.get_index(term) {
187 ind_vec.push(idx as u32);
188 val_vec.push(count as u16);
189 }
190 }
191 unsafe { TFVector::from_vec(ind_vec, val_vec, len as u32, term_sum) }
192 }
193
194 fn tf_denorm(val: u16) -> u32 {
195 val as u32
196 }
197}
198
199// impl<K> TFIDFEngine<u8, K> for DefaultTFIDFEngine
200// {
201// // fn idf_vec(corpus: &Corpus, term_dim_sample: &Vec<Box<str>>) -> (Vec<u8>, f64) {
202// // let mut idf_vec = Vec::with_capacity(term_dim_sample.len());
203// // let doc_num = corpus.get_doc_num() as f64;
204// // for term in term_dim_sample.iter() {
205// // let doc_freq = corpus.get_term_count(term);
206// // idf_vec.push(doc_num / (doc_freq as f64 + 1.0));
207// // }
208// // let max = idf_vec
209// // .iter()
210// // .max_by(|a, b| a.total_cmp(b))
211// // .copied()
212// // .unwrap_or(1.0);
213// // (
214// // idf_vec
215// // .into_iter()
216// // .map(|idf| (idf / max * u8::MAX as f64).ceil() as u8)
217// // .collect(),
218// // max
219// // )
220// // }
221
222// fn tf_vec(freq: &TermFrequency, term_dim_sample: &IndexSet<Box<str>>) -> (ZeroSpVec<u8>, f32) {
223// // Build sparse TF vector without allocating dense Vec
224// let total_count_f64 = freq.term_sum() as f64;
225// if total_count_f64 == 0.0 { return (ZeroSpVec::new(), total_count_f64 as f32); }
226// // Use f32 intermediates for u8 to reduce cost and memory
227// let total_count = total_count_f64 as f32;
228// let mut max_val = 0.0f32;
229// let inv_total = 1.0f32 / total_count;
230// let raw = freq.iter().filter_map(|(term, count)| {
231// let idx = term_dim_sample.get_index(term)?;
232// let v = (count as f32) * inv_total;
233// max_val = max_val.max(v);
234// Some((idx, v))
235// }).collect::<Vec<_>>();
236// let len = term_dim_sample.len();
237// if max_val == 0.0 { return (ZeroSpVec::new(), total_count_f64 as f32); }
238// let mul_norm = (u8::MAX as f32) / max_val; // == (1/max_val) * u8::MAX
239// let vec_u8 = raw.into_iter()
240// .map(|(idx, v)| {
241// let q = (v * mul_norm).ceil() as u8;
242// (idx, q)
243// })
244// .collect::<Vec<_>>();
245// (unsafe { ZeroSpVec::from_sparse_iter(vec_u8.into_iter(), len) }, total_count_f64 as f32)
246// }
247// }