tf_idf_vectorizer/vectorizer/
mod.rs

1
2use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
3use token::TokenFrequency;
4
5pub mod index;
6pub mod token;
7pub mod analyzer;
8
9pub struct TFIDFVectorizer {
10    pub corpus: TokenFrequency,
11    doc_num: u64,
12}
13
14impl TFIDFVectorizer {
15    pub fn new() -> Self {
16        Self {
17            corpus: TokenFrequency::new(),
18            doc_num: 0,
19        }
20    }
21
22    pub fn doc_num(&self) -> u64 {
23        self.doc_num
24    }
25
26    pub fn add_corpus(&mut self, tokens: &[&str]) {
27        // TFの計算
28        let mut doc_tf = TokenFrequency::new();
29        doc_tf.add_tokens(tokens);
30
31        // corpus_token_freqに追加
32        self.corpus.add_tokens(tokens);
33 
34        self.doc_num += 1;
35    }
36
37    pub fn tf_idf_vector(&self, tokens: &[&str]) -> Vec<(&str, f64)> {
38        // TFの計算
39        let mut doc_tf = TokenFrequency::new();
40        doc_tf.add_tokens(tokens);
41
42        let mut result: Vec<(&str, f64)> = Vec::new();
43        // corpus_token_freqに追加
44        let idf_vec: Vec<(&str, f64)> = self.corpus.idf_vector_ref_str(self.doc_num as u64);
45        for (added_token, idf) in idf_vec.iter() {
46            let tf: f64 = doc_tf.tf_token(added_token);
47            if tf != 0.0 {
48                let tf_idf = tf * idf;
49                result.push((*added_token, tf_idf));
50            }
51        }
52        result.sort_by(|a, b| b.1.total_cmp(&a.1));
53        result
54    }
55    pub fn tf_idf_vector_parallel(&self, tokens: &[&str], thread_count: usize) -> Vec<(&str, f64)> {
56        // TFの計算
57        let mut doc_tf = TokenFrequency::new();
58        doc_tf.add_tokens(tokens);
59
60        let idf_vec: Vec<(&str, f64)> = self.corpus.idf_vector_ref_str(self.doc_num as u64);
61
62        // カスタムスレッドプールを作成し、スレッド数を指定
63        let pool = rayon::ThreadPoolBuilder::new().num_threads(thread_count).build().unwrap();
64        let mut result: Vec<(&str, f64)> = pool.install(|| {
65            idf_vec
66                .par_iter()
67                .filter_map(|(added_token, idf)| {
68                    let tf: f64 = doc_tf.tf_token(added_token);
69                    if tf != 0.0 {
70                        Some((*added_token, tf * idf))
71                    } else {
72                        None
73                    }
74                })
75                .collect()
76        });
77        result.sort_by(|a, b| b.1.total_cmp(&a.1));
78        result
79    }
80}