tf_idf_vectorizer/vectorizer/
index.rs

1use std::ops::{AddAssign, MulAssign};
2
3use num::Num;
4use serde::{Deserialize, Serialize};
5
6use crate::utils::{math::vector::ZeroSpVec, normalizer::{IntoNormalizer, NormalizedBounded, NormalizedMultiply}};
7use rayon::prelude::*;
8
9use super::token::TokenFrequency;
10
11/// インデックス
12/// ドキュメント単位でインデックスを作成、検索するための構造体です
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct Index<N>
15where N: Num + Into<f64> + AddAssign + MulAssign + NormalizedMultiply + Copy + NormalizedBounded {
16    matrix: Vec<ZeroSpVec<N>>,
17    doc_id: Vec<String>,
18    corpus_token_freq: TokenFrequency,
19}
20
21impl<N> Index<N>
22where N: Num + Into<f64> + AddAssign + MulAssign + NormalizedMultiply + Copy + NormalizedBounded, f64: IntoNormalizer<N> {
23    /// 新しいインデックスを作成するメソッド
24    pub fn new() -> Self {
25        Self {
26            matrix: Vec::new(),
27            doc_id: Vec::new(),
28            corpus_token_freq: TokenFrequency::new(),
29        }
30    }
31
32    /// インデックスのドキュメント数を取得するメソッド
33    pub fn doc_num(&self) -> usize {
34        self.matrix.len()
35    }
36
37    /// インデックスのトークン数を取得するメソッド
38    /// トークン数はユニークなトークンの数を返す
39    pub fn token_num(&self) -> usize {
40        self.corpus_token_freq.token_num()
41    }
42
43    /// インデックスにドキュメントを追加するメソッド
44    /// 
45    /// # Arguments
46    /// * `doc_id` - ドキュメントのID
47    /// * `tokens` - ドキュメントのトークン
48    pub fn add_doc(&mut self, doc_id: String, tokens: &[&str]) {
49        // TFの計算
50        let mut doc_tf = TokenFrequency::new();
51        doc_tf.add_tokens(tokens);
52
53        // corpus_token_freqに追加
54        let old_corpus_token_num = self.corpus_token_freq.token_num();
55        self.corpus_token_freq.add_tokens(tokens);
56        let added_corpus_token_num = self.corpus_token_freq.token_num() - old_corpus_token_num;
57        self.doc_id.push(doc_id);
58
59        // ZeroSpVecを作成
60        let mut vec: ZeroSpVec<N> = ZeroSpVec::with_capacity(doc_tf.token_num());
61        for token in self.corpus_token_freq.token_set_ref_str().iter() {
62            let tf_val: N = doc_tf.tf_token(token); // ここtf_tokenの内部でいちいちvecのmax計算してるのが最適化されるのか?
63            vec.push(tf_val);
64        }
65
66        if added_corpus_token_num > 0 {
67            // 新しいトークンが追加された場合、matrixを拡張する
68            for other_tf in self.matrix.iter_mut() {
69                other_tf.add_dim(added_corpus_token_num);
70            }
71        }
72        vec.shrink_to_fit();
73        // matrixに追加
74        self.matrix.push(vec);
75    }
76
77    /// query vectorを生成するメソッド
78    /// 重要度を考慮せず、トークンの有無だけでベクトルを生成します。
79    /// 
80    /// # Arguments
81    /// * `tokens` - クエリのトークン
82    /// 
83    /// # Returns
84    /// * `ZeroSpVec<N>` - クエリのベクトル
85    pub fn generate_query_mask(&self, tokens: &[&str]) -> ZeroSpVec<N> {
86        // TFの計算
87        let mut query_tf = TokenFrequency::new();
88        query_tf.add_tokens(tokens);
89
90        // ZeroSpVecを作成
91        let mut vec: ZeroSpVec<N> = ZeroSpVec::with_capacity(query_tf.token_num());
92        for token in self.corpus_token_freq.token_set_ref_str().iter() {
93            let tf_val: N = if query_tf.contains_token(token) { N::max_normalized() } else { N::zero() }; // ここtf_tokenの内部でいちいちvecのmax計算してるのが最適化されるのか?
94            vec.push(tf_val);
95        }
96        vec
97    }
98
99    /// query vectorを生成するメソッド
100    /// 
101    /// # Arguments
102    /// * `tokens` - クエリのトークン
103    /// 
104    /// # Returns
105    /// * `ZeroSpVec<N>` - クエリのベクトル
106    pub fn generate_query(&self, tokens: &[&str]) -> ZeroSpVec<N> {
107        // TFの計算
108        let mut query_tf = TokenFrequency::new();
109        query_tf.add_tokens(tokens);
110
111        // ZeroSpVecを作成
112        let mut vec: ZeroSpVec<N> = ZeroSpVec::with_capacity(query_tf.token_num());
113        for token in self.corpus_token_freq.token_set_ref_str().iter() {
114            let tf_val: N = query_tf.tf_token(token); // ここtf_tokenの内部でいちいちvecのmax計算してるのが最適化されるのか?
115            vec.push(tf_val);
116        }
117        vec
118    }
119
120    /// クエリを検索するメソッド
121    /// コサイン類似度を計算し、類似度の高い順にソートして返します。
122    /// # Arguments
123    /// * `query` - クエリのベクトル
124    /// 
125    /// # Returns
126    /// * `Vec<(String, f64)>` - 検索結果のベクトル
127    pub fn search_cosine_similarity(&self, query: &ZeroSpVec<N>) -> Vec<(String, f64)> {
128        let mut result = Vec::new();
129
130        let idf_vec = 
131            self.corpus_token_freq.
132            idf_vector_ref_str::<N>(self.matrix.len() as u64)
133            .into_iter()
134            .map(|(_, idf)| idf)
135            .collect::<Vec<N>>();
136
137        // IDFとqueryを先に乗算
138        let idf_query: ZeroSpVec<N> = query.hadamard_normalized_vec(&idf_vec);
139
140        // ドキュメントベクトルとIDFを掛け算してコサイン類似度を計算
141        for (i, doc_vec) in self.matrix.iter().enumerate() {
142            let tf_idf_doc_vec = doc_vec.hadamard_normalized_vec(&idf_vec);
143            let similarity = tf_idf_doc_vec.cosine_similarity_normalized::<f64>(&idf_query);
144            if similarity != 0.0 {
145                result.push((self.doc_id[i].clone(), similarity));
146            }
147        }
148
149        // 類似度でソート
150        result.sort_by(|a, b| b.1.total_cmp(&a.1));
151        result
152    }
153
154    /// クエリを検索するメソッド
155    /// コサイン類似度を計算し、類似度の高い順にソートして返します。
156    /// 並列処理を使用して、検索を高速化します。
157    /// 
158    /// # Arguments
159    /// * `query` - クエリのベクトル
160    /// * `thread_count` - スレッド数
161    /// 
162    /// # Returns
163    /// * `Vec<(String, f64)>` - 検索結果のベクトル
164    pub fn search_cosine_similarity_parallel(&self, query: &ZeroSpVec<N>, thread_count: usize) -> Vec<(String, f64)>
165    where
166        N: Send + Sync,
167    {
168        let idf_vec = 
169            self.corpus_token_freq.
170            idf_vector_ref_str::<N>(self.matrix.len() as u64)
171            .into_iter()
172            .map(|(_, idf)| idf)
173            .collect::<Vec<N>>();
174
175        // IDFとqueryを先に乗算
176        let idf_query: ZeroSpVec<N> = query.hadamard_normalized_vec(&idf_vec);
177
178        let pool = rayon::ThreadPoolBuilder::new()
179            .num_threads(thread_count)
180            .build()
181            .expect("Failed to build thread pool");
182
183        let mut result: Vec<(String, f64)> = pool.install(|| {
184            self.matrix
185                .par_iter()
186                .enumerate()
187                .filter_map(|(i, doc_vec)| {
188                    let tf_idf_doc_vec = doc_vec.hadamard_normalized_vec(&idf_vec);
189                    let similarity = tf_idf_doc_vec.cosine_similarity_normalized::<f64>(&idf_query);
190                    if similarity != 0.0 {
191                        Some((self.doc_id[i].clone(), similarity))
192                    } else {
193                        None
194                    }
195                })
196                .collect()
197        });
198
199        // 類似度でソート
200        result.sort_by(|a, b| b.1.total_cmp(&a.1));
201        result
202    }
203}