tf_idf_vectorizer/vectorizer/
tfidf.rs

1use half::f16;
2use num_traits::{Num, Pow};
3
4use crate::{Corpus, TermFrequency, utils::datastruct::{map::IndexSet, vector::{TFVector, TFVectorTrait}}};
5
6
7/// TF-IDF Calculation Engine Trait
8///
9/// Defines the behavior of a TF-IDF calculation engine.
10///
11/// Custom engines can be implemented and plugged into
12/// [`TFIDFVectorizer`].
13///
14/// A default implementation, [`DefaultTFIDFEngine`], is provided.
15///
16/// ### Supported Numeric Types
17/// - `f16`
18/// - `f32`
19/// - `u16`
20/// - `u32`
21pub trait TFIDFEngine<N>: Send + Sync
22where
23    N: Num + Copy
24{
25    /// Method to generate the IDF vector
26    /// # Arguments
27    /// * `corpus` - The corpus
28    /// * `term_dim_sample` - term dimension sample
29    /// # Returns
30    /// * `Vec<N>` - The IDF vector
31    /// * `denormalize_num` - Value for denormalization
32    fn idf_vec(corpus: &Corpus, term_dim_sample: &Vec<Box<str>>) -> Vec<f32> {
33        let mut idf_vec = Vec::with_capacity(term_dim_sample.len());
34        let doc_num = corpus.get_doc_num() as f64;
35        for term in term_dim_sample.iter() {
36            let doc_freq = corpus.get_term_count(term);
37            idf_vec.push((doc_num / (doc_freq as f64 + 1.0)) as f32);
38        }
39        idf_vec
40    }
41    /// Method to generate the TF vector
42    /// # Arguments
43    /// * `freq` - term frequency
44    /// * `term_dim_sample` - term dimension sample
45    /// # Returns
46    /// * `(ZeroSpVec<N>, f64)` - TF vector and value for denormalization
47    fn tf_vec(freq: &TermFrequency, term_dim_sample: &IndexSet<Box<str>>) -> TFVector<N>;
48
49    fn tf_denorm(val: N) -> u32;
50}
51
/// Default TF-IDF engine.
///
/// A field-less unit struct: it carries no state and exists only as the
/// type on which the engine behavior is implemented.
#[derive(Debug, Default)]
pub struct DefaultTFIDFEngine;
impl DefaultTFIDFEngine {
    /// Creates the default engine. Equivalent to `DefaultTFIDFEngine::default()`.
    pub fn new() -> Self {
        Self
    }
}
60
61impl TFIDFEngine<f16> for DefaultTFIDFEngine {
62    // fn idf_vec(corpus: &Corpus, term_dim_sample: &Vec<Box<str>>) -> (Vec<f16>, f64) {
63    //     let mut idf_vec = Vec::with_capacity(term_dim_sample.len());
64    //     let doc_num = corpus.get_doc_num() as f64;
65    //     for term in term_dim_sample.iter() {
66    //         let doc_freq = corpus.get_term_count(term);
67    //         idf_vec.push(f16::from_f64(doc_num / (doc_freq as f64 + 1.0)));
68    //     }
69    //     (idf_vec, 1.0)
70    // }
71    #[inline]
72    fn tf_vec(freq: &TermFrequency, term_dim_sample: &IndexSet<Box<str>>) -> TFVector<f16> {
73        // Build sparse TF vector: only non-zero entries are stored
74        let term_sum = freq.term_sum() as u32;
75        let len = freq.term_num();
76        let mut ind_vec: Vec<u32> = Vec::with_capacity(len);
77        let mut val_vec: Vec<f16> = Vec::with_capacity(len);
78        for (term, count) in freq.iter() {
79            let count = (count as f32).sqrt();
80            if let Some(idx) = term_dim_sample.get_index(term) {
81                ind_vec.push(idx as u32);
82                val_vec.push(f16::from_f32(count));
83            }
84        }
85        unsafe { TFVector::from_vec(ind_vec, val_vec, len as u32, term_sum) }
86    }
87
88    #[inline(always)]
89    fn tf_denorm(val: f16) -> u32 {
90        val.to_f32().pow(2) as u32
91    }
92}
93
94impl TFIDFEngine<f32> for DefaultTFIDFEngine
95{
96    // fn idf_vec(corpus: &Corpus, term_dim_sample: &Vec<Box<str>>) -> (Vec<f32>, f64) {
97    //     let mut idf_vec = Vec::with_capacity(term_dim_sample.len());
98    //     let doc_num = corpus.get_doc_num() as f64;
99    //     for term in term_dim_sample.iter() {
100    //         let doc_freq = corpus.get_term_count(term);
101    //         idf_vec.push((doc_num / (doc_freq as f64 + 1.0)) as f32);
102    //     }
103    //     (idf_vec, 1.0)
104    // }
105    #[inline]
106    fn tf_vec(freq: &TermFrequency, term_dim_sample: &IndexSet<Box<str>>) -> TFVector<f32> {
107        // Build sparse TF vector: only non-zero entries are stored
108        let term_sum = freq.term_sum() as u32;
109        let len = freq.term_num();
110        let mut ind_vec: Vec<u32> = Vec::with_capacity(len);
111        let mut val_vec: Vec<f32> = Vec::with_capacity(len);
112        for (term, count) in freq.iter() {
113            if let Some(idx) = term_dim_sample.get_index(term) {
114                ind_vec.push(idx as u32);
115                val_vec.push(count as f32);
116            }
117        }
118        unsafe { TFVector::from_vec(ind_vec, val_vec, len as u32, term_sum) }
119    }
120
121    #[inline(always)]
122    fn tf_denorm(val: f32) -> u32 {
123        val as u32
124    }
125}
126
127impl TFIDFEngine<u32> for DefaultTFIDFEngine
128{
129    // fn idf_vec(corpus: &Corpus, term_dim_sample: &Vec<Box<str>>) -> (Vec<u32>, f64) {
130    //     let mut idf_vec = Vec::with_capacity(term_dim_sample.len());
131    //     let doc_num = corpus.get_doc_num() as f64;
132    //     for term in term_dim_sample.iter() {
133    //         let doc_freq = corpus.get_term_count(term);
134    //         idf_vec.push(doc_num / (doc_freq as f64 + 1.0));
135    //     }
136    //     let max = idf_vec
137    //         .iter()
138    //         .max_by(|a, b| a.total_cmp(b))
139    //         .copied()
140    //         .unwrap_or(1.0);
141    //     (
142    //     idf_vec
143    //         .into_iter()
144    //         .map(|idf| (idf / max * u32::MAX as f64).ceil() as u32)
145    //         .collect(),
146    //     max
147    //     )
148    // }
149    #[inline]
150    fn tf_vec(freq: &TermFrequency, term_dim_sample: &IndexSet<Box<str>>) -> TFVector<u32> {
151        // Build sparse TF vector: only non-zero entries are stored
152        let term_sum = freq.term_sum() as u32;
153        let len = freq.term_num();
154        let mut ind_vec: Vec<u32> = Vec::with_capacity(len);
155        let mut val_vec: Vec<u32> = Vec::with_capacity(len);
156        for (term, count) in freq.iter() {
157            if let Some(idx) = term_dim_sample.get_index(term) {
158                ind_vec.push(idx as u32);
159                val_vec.push(count as u32);
160            }
161        }
162        unsafe { TFVector::from_vec(ind_vec, val_vec, len as u32, term_sum) }
163    }
164
165    #[inline(always)]
166    fn tf_denorm(val: u32) -> u32 {
167        val
168    }
169}
170
171impl TFIDFEngine<u16> for DefaultTFIDFEngine
172{
173    // fn idf_vec(corpus: &Corpus, term_dim_sample: &Vec<Box<str>>) -> (Vec<u16>, f64) {
174    //     let mut idf_vec = Vec::with_capacity(term_dim_sample.len());
175    //     let doc_num = corpus.get_doc_num() as f64;
176    //     for term in term_dim_sample.iter() {
177    //         let doc_freq = corpus.get_term_count(term);
178    //         idf_vec.push(doc_num / (doc_freq as f64 + 1.0));
179    //     }
180    //     let max = idf_vec
181    //         .iter()
182    //         .max_by(|a, b| a.total_cmp(b))
183    //         .copied()
184    //         .unwrap_or(1.0);
185    //     (
186    //     idf_vec
187    //         .into_iter()
188    //         .map(|idf| (idf / max * u16::MAX as f64).ceil() as u16)
189    //         .collect(),
190    //     max
191    //     )
192    // }
193    #[inline]
194    fn tf_vec(freq: &TermFrequency, term_dim_sample: &IndexSet<Box<str>>) -> TFVector<u16> {
195        // Build sparse TF vector: only non-zero entries are stored
196        let term_sum = freq.term_sum() as u32;
197        let len = freq.term_num();
198        let mut ind_vec: Vec<u32> = Vec::with_capacity(len);
199        let mut val_vec: Vec<u16> = Vec::with_capacity(len);
200        for (term, count) in freq.iter() {
201            if let Some(idx) = term_dim_sample.get_index(term) {
202                ind_vec.push(idx as u32);
203                val_vec.push(count as u16);
204            }
205        }
206        unsafe { TFVector::from_vec(ind_vec, val_vec, len as u32, term_sum) }
207    }
208
209    #[inline(always)]
210    fn tf_denorm(val: u16) -> u32 {
211        val as u32
212    }
213}
214
215// impl<K> TFIDFEngine<u8, K> for DefaultTFIDFEngine
216// {
217//     // fn idf_vec(corpus: &Corpus, term_dim_sample: &Vec<Box<str>>) -> (Vec<u8>, f64) {
218//     //     let mut idf_vec = Vec::with_capacity(term_dim_sample.len());
219//     //     let doc_num = corpus.get_doc_num() as f64;
220//     //     for term in term_dim_sample.iter() {
221//     //         let doc_freq = corpus.get_term_count(term);
222//     //         idf_vec.push(doc_num / (doc_freq as f64 + 1.0));
223//     //     }
224//     //     let max = idf_vec
225//     //         .iter()
226//     //         .max_by(|a, b| a.total_cmp(b))
227//     //         .copied()
228//     //         .unwrap_or(1.0);
229//     //     (
230//     //     idf_vec
231//     //         .into_iter()
232//     //         .map(|idf| (idf / max * u8::MAX as f64).ceil() as u8)
233//     //         .collect(),
234//     //     max
235//     //     )
236//     // }
237
238//     fn tf_vec(freq: &TermFrequency, term_dim_sample: &IndexSet<Box<str>>) -> (ZeroSpVec<u8>, f32) {
239//         // Build sparse TF vector without allocating dense Vec
240//         let total_count_f64 = freq.term_sum() as f64;
241//         if total_count_f64 == 0.0 { return (ZeroSpVec::new(), total_count_f64 as f32); }
242//         // Use f32 intermediates for u8 to reduce cost and memory
243//         let total_count = total_count_f64 as f32;
244//         let mut max_val = 0.0f32;
245//         let inv_total = 1.0f32 / total_count;
246//         let raw = freq.iter().filter_map(|(term, count)| {
247//             let idx = term_dim_sample.get_index(term)?;
248//             let v = (count as f32) * inv_total;
249//             max_val = max_val.max(v);
250//             Some((idx, v))
251//         }).collect::<Vec<_>>();
252//         let len = term_dim_sample.len();
253//         if max_val == 0.0 { return (ZeroSpVec::new(), total_count_f64 as f32); }
254//         let mul_norm = (u8::MAX as f32) / max_val; // == (1/max_val) * u8::MAX
255//         let vec_u8 = raw.into_iter()
256//             .map(|(idx, v)| {
257//                 let q = (v * mul_norm).ceil() as u8;
258//                 (idx, q)
259//             })
260//             .collect::<Vec<_>>();
261//         (unsafe { ZeroSpVec::from_sparse_iter(vec_u8.into_iter(), len) }, total_count_f64 as f32)
262//     }
263// }