tf_idf_vectorizer/vectorizer/
tfidf.rs

use ahash::RandomState;
use indexmap::IndexSet;

use crate::{utils::math::vector::{ZeroSpVec, ZeroSpVecTrait}, vectorizer::{corpus::Corpus, token::TokenFrequency}};

pub trait TFIDFEngine<N>: Send + Sync
where
    N: num::Num + Copy,
{
    /// Method to generate the IDF vector.
    /// # Arguments
    /// * `corpus` - The corpus
    /// * `token_dim_sample` - Token dimension sample (the token-to-dimension mapping)
    /// # Returns
    /// * `(Vec<N>, f64)` - IDF vector and a value the caller can use for denormalization
    fn idf_vec(corpus: &Corpus, token_dim_sample: &IndexSet<String, RandomState>) -> (Vec<N>, f64);
    /// Method to generate the TF vector.
    /// # Arguments
    /// * `freq` - Token frequency of the document
    /// * `token_dim_sample` - Token dimension sample (the token-to-dimension mapping)
    /// # Returns
    /// * `(ZeroSpVec<N>, f64)` - TF vector and a value the caller can use for denormalization
    fn tf_vec(freq: &TokenFrequency, token_dim_sample: &IndexSet<String, RandomState>) -> (ZeroSpVec<N>, f64);
}

/// Default TF-IDF engine.
/// Supports the `f32`, `f64`, `u32`, `u16`, and `u8` numeric types.
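///
/// # Example
/// A rough usage sketch (not compiled as a doctest): `corpus`, `freq`, and
/// `token_dims` are assumed to already exist; how they are built is outside
/// the scope of this sketch.
/// ```ignore
/// // Pick the numeric type by naming the trait instantiation explicitly.
/// let (idf, idf_denorm): (Vec<f32>, f64) =
///     <DefaultTFIDFEngine as TFIDFEngine<f32>>::idf_vec(&corpus, &token_dims);
/// let (tf, tf_denorm) = <DefaultTFIDFEngine as TFIDFEngine<f32>>::tf_vec(&freq, &token_dims);
/// ```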
#[derive(Debug)]
pub struct DefaultTFIDFEngine;
impl DefaultTFIDFEngine {
    pub fn new() -> Self {
        DefaultTFIDFEngine
    }
}

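// Floating-point engines (f32 / f64): tf_vec stores raw relative frequencies
// (count / total token count) with no quantization and returns the document's
// token total as the denormalization value; idf_vec returns 1.0 because the
// IDF values are stored as-is.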
impl TFIDFEngine<f32> for DefaultTFIDFEngine
{
    fn idf_vec(corpus: &Corpus, token_dim_sample: &IndexSet<String, RandomState>) -> (Vec<f32>, f64) {
        let mut idf_vec = Vec::with_capacity(token_dim_sample.len());
        let doc_num = corpus.get_doc_num() as f64;
        for token in token_dim_sample {
            let doc_freq = corpus.get_token_count(token);
            idf_vec.push((doc_num / (doc_freq as f64 + 1.0)) as f32);
        }
        (idf_vec, 1.0)
    }

    fn tf_vec(freq: &TokenFrequency, token_dim_sample: &IndexSet<String, RandomState>) -> (ZeroSpVec<f32>, f64) {
        // Build sparse TF vector: only non-zero entries are stored
        let total_count = freq.token_sum() as f32;
        if total_count == 0.0 { return (ZeroSpVec::new(), total_count.into()); }
        let mut raw: Vec<(usize, f32)> = Vec::with_capacity(freq.token_num());
        let len = token_dim_sample.len();
        let inv_total = 1.0f32 / total_count;
        for (idx, token) in token_dim_sample.iter().enumerate() {
            let count = freq.token_count(token) as f32;
            if count == 0.0 { continue; }
            raw.push((idx, count * inv_total));
        }
        (unsafe { ZeroSpVec::from_raw_iter(raw.into_iter(), len) }, total_count.into())
    }
}

impl TFIDFEngine<f64> for DefaultTFIDFEngine
{
    fn idf_vec(corpus: &Corpus, token_dim_sample: &IndexSet<String, RandomState>) -> (Vec<f64>, f64) {
        let mut idf_vec = Vec::with_capacity(token_dim_sample.len());
        let doc_num = corpus.get_doc_num() as f64;
        for token in token_dim_sample {
            let doc_freq = corpus.get_token_count(token);
            idf_vec.push(doc_num / (doc_freq as f64 + 1.0));
        }
        (idf_vec, 1.0)
    }

    fn tf_vec(freq: &TokenFrequency, token_dim_sample: &IndexSet<String, RandomState>) -> (ZeroSpVec<f64>, f64) {
        // Build sparse TF vector: only non-zero entries are stored
        let total_count = freq.token_sum() as f64;
        if total_count == 0.0 { return (ZeroSpVec::new(), total_count.into()); }
        let mut raw: Vec<(usize, f64)> = Vec::with_capacity(freq.token_num());
        let len = token_dim_sample.len();
        let inv_total = 1.0f64 / total_count;
        for (idx, token) in token_dim_sample.iter().enumerate() {
            let count = freq.token_count(token) as f64;
            if count == 0.0 { continue; }
            raw.push((idx, count * inv_total));
        }
        (unsafe { ZeroSpVec::from_raw_iter(raw.into_iter(), len) }, total_count.into())
    }
}

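// Integer engines (u32 / u16 / u8): tf_vec first computes the relative
// frequencies in floating point, then scales them so the largest value in the
// document maps to the integer type's MAX (rounding up with `ceil`), and
// returns the document's token total as the denormalization value.
// For example, a token occurring 3 times in a 12-token document has tf = 0.25;
// if that is the document's maximum, it quantizes to u32::MAX.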
impl TFIDFEngine<u32> for DefaultTFIDFEngine
{
    fn idf_vec(corpus: &Corpus, token_dim_sample: &IndexSet<String, RandomState>) -> (Vec<u32>, f64) {
        let mut idf_vec = Vec::with_capacity(token_dim_sample.len());
        let doc_num = corpus.get_doc_num() as f64;
        for token in token_dim_sample {
            let doc_freq = corpus.get_token_count(token);
            idf_vec.push((doc_num / (doc_freq as f64 + 1.0)) as u32);
        }
        (idf_vec, 1.0)
    }

    fn tf_vec(freq: &TokenFrequency, token_dim_sample: &IndexSet<String, RandomState>) -> (ZeroSpVec<u32>, f64) {
        // Build sparse TF vector without allocating dense Vec
        let total_count = freq.token_sum() as f64;
        if total_count == 0.0 { return (ZeroSpVec::new(), total_count); }
        let mut raw: Vec<(usize, f64)> = Vec::with_capacity(freq.token_num());
        let mut max_val = 0.0f64;
        let inv_total = 1.0f64 / total_count;
        for (idx, token) in token_dim_sample.iter().enumerate() {
            let count = freq.token_count(token) as f64;
            if count == 0.0 { continue; }
            let v = count * inv_total;
            if v > max_val { max_val = v; }
            raw.push((idx, v));
        }
        let len = token_dim_sample.len();
        if max_val == 0.0 { return (ZeroSpVec::new(), total_count); }
        let mut vec_u32: Vec<(usize, u32)> = Vec::with_capacity(raw.len());
        let mul_norm = (u32::MAX as f64) / max_val; // == (1/max_val) * u32::MAX
        for (idx, v) in raw.into_iter() {
            let q = (v * mul_norm).ceil() as u32;
            vec_u32.push((idx, q));
        }
        (unsafe { ZeroSpVec::from_raw_iter(vec_u32.into_iter(), len) }, total_count)
    }
}

impl TFIDFEngine<u16> for DefaultTFIDFEngine
{
    fn idf_vec(corpus: &Corpus, token_dim_sample: &IndexSet<String, RandomState>) -> (Vec<u16>, f64) {
        let mut idf_vec = Vec::with_capacity(token_dim_sample.len());
        let doc_num = corpus.get_doc_num() as f64;
        for token in token_dim_sample {
            let doc_freq = corpus.get_token_count(token);
            idf_vec.push(doc_num / (doc_freq as f64 + 1.0));
        }
        let max = idf_vec
            .iter()
            .max_by(|a, b| a.total_cmp(b))
            .copied()
            .unwrap_or(1.0);
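        // Scale so the largest IDF maps to u16::MAX; `max` is returned so the
        // caller can approximately recover idf ≈ quantized / u16::MAX * max.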
        (
        idf_vec
            .into_iter()
            .map(|idf| (idf / max * u16::MAX as f64).ceil() as u16)
            .collect(),
        max
        )
    }

    fn tf_vec(freq: &TokenFrequency, token_dim_sample: &IndexSet<String, RandomState>) -> (ZeroSpVec<u16>, f64) {
        // Build sparse TF vector without allocating a dense Vec<f64>
        let total_count = freq.token_sum() as f64;
        if total_count == 0.0 { return (ZeroSpVec::new(), total_count); }
        // First pass: compute raw tf values and track max
        let mut raw: Vec<(usize, f32)> = Vec::with_capacity(freq.token_num());
        let mut max_val = 0.0f32;
        let div_total = (1.0 / total_count) as f32;
        for (idx, token) in token_dim_sample.iter().enumerate() {
            let count = freq.token_count(token);
            if count == 0 { continue; }
            let v = count as f32 * div_total;
            if v > max_val { max_val = v; }
            raw.push((idx, v));
        }
        let len = token_dim_sample.len();
        if max_val == 0.0 { return (ZeroSpVec::new(), total_count); }
        // Second pass: normalize into quantized u16 and build sparse vector
        let mut vec_u16: Vec<(usize, u16)> = Vec::with_capacity(raw.len());
        let norm_div_max = (u16::MAX as f32) / max_val; // == (1/max_val) * u16::MAX
        for (idx, v) in raw.into_iter() {
            let q = (v * norm_div_max).ceil() as u16;
            vec_u16.push((idx, q));
        }
        (unsafe { ZeroSpVec::from_raw_iter(vec_u16.into_iter(), len) }, total_count)
    }
}

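// u8 engine: same quantization scheme as the u16 engine, scaled to u8::MAX.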
impl TFIDFEngine<u8> for DefaultTFIDFEngine
{
    fn idf_vec(corpus: &Corpus, token_dim_sample: &IndexSet<String, RandomState>) -> (Vec<u8>, f64) {
        let mut idf_vec = Vec::with_capacity(token_dim_sample.len());
        let doc_num = corpus.get_doc_num() as f64;
        for token in token_dim_sample {
            let doc_freq = corpus.get_token_count(token);
            idf_vec.push(doc_num / (doc_freq as f64 + 1.0));
        }
        let max = idf_vec
            .iter()
            .max_by(|a, b| a.total_cmp(b))
            .copied()
            .unwrap_or(1.0);
        (
        idf_vec
            .into_iter()
            .map(|idf| (idf / max * u8::MAX as f64).ceil() as u8)
            .collect(),
        max
        )
    }

    fn tf_vec(freq: &TokenFrequency, token_dim_sample: &IndexSet<String, RandomState>) -> (ZeroSpVec<u8>, f64) {
        // Build sparse TF vector without allocating dense Vec
        let total_count_f64 = freq.token_sum() as f64;
        if total_count_f64 == 0.0 { return (ZeroSpVec::new(), total_count_f64); }
        // Use f32 intermediates for u8 to reduce cost and memory
        let total_count = total_count_f64 as f32;
        let mut raw: Vec<(usize, f32)> = Vec::with_capacity(freq.token_num());
        let mut max_val = 0.0f32;
        let inv_total = 1.0f32 / total_count;
        for (idx, token) in token_dim_sample.iter().enumerate() {
            let count = freq.token_count(token) as f32;
            if count == 0.0 { continue; }
            let v = count * inv_total;
            if v > max_val { max_val = v; }
            raw.push((idx, v));
        }
        let len = token_dim_sample.len();
        if max_val == 0.0 { return (ZeroSpVec::new(), total_count_f64); }
        let mut vec_u8: Vec<(usize, u8)> = Vec::with_capacity(raw.len());
        let mul_norm = (u8::MAX as f32) / max_val; // == (1/max_val) * u8::MAX
        for (idx, v) in raw.into_iter() {
            let q = (v * mul_norm).ceil() as u8;
            vec_u8.push((idx, q));
        }
        (unsafe { ZeroSpVec::from_raw_iter(vec_u8.into_iter(), len) }, total_count_f64)
    }
}