// tf_idf_vectorizer/vectorizer/index.rs

use std::cmp::Ordering;
use std::collections::HashMap;
use std::str;

use fst::{Map, MapBuilder, Streamer};
use serde::Serialize;
use vec_plus::vec::{default_sparse_vec::DefaultSparseVec, vec_trait::Math};

use super::token::TokenFrequency;

12#[derive(Clone, Debug,)]
13pub struct Index<IdType>
14where
15    IdType: Clone + Eq + std::hash::Hash + Serialize + std::fmt::Debug,
16{
17    // doc_id -> (圧縮ベクトル, 文書の総トークン数)
18    // 圧縮ベクトル: インデックス順にトークンの TF を保持
19    pub index: HashMap<IdType, (DefaultSparseVec<u16>, u64 /* token num */)>,
20    pub avg_tokens_len: u64,  // 全文書の平均トークン長
21    pub max_tokens_len: u64,  // 全文書の最大トークン長
22    pub idf: Map<Vec<u8>>,    // fst::Map 形式の IDF
23    pub total_doc_count: u64, // 文書総数
24}
25
26impl<IdType> Index<IdType>
27where
28    IdType: Clone + Eq + std::hash::Hash + Serialize + std::fmt::Debug,
29{
30    // ---------------------------------------------------------------------------------------------
31    // コンストラクタ
32    // ---------------------------------------------------------------------------------------------
33    pub fn new_with_set(
34        index: HashMap<IdType, (DefaultSparseVec<u16>, u64)>,
35        idf: Map<Vec<u8>>,
36        avg_tokens_len: u64,
37        max_tokens_len: u64,
38        total_doc_count: u64,
39    ) -> Self {
40        Self {
41            index,
42            idf,
43            avg_tokens_len,
44            max_tokens_len,
45            total_doc_count,
46        }
47    }
48
49    pub fn get_index(&self) -> &HashMap<IdType, (DefaultSparseVec<u16>, u64)> {
50        &self.index
51    }
52
53    // ---------------------------------------------------------------------------------------------
54    // 公開メソッド: 検索 (Cosine Similarity)
55    // ---------------------------------------------------------------------------------------------
56
57    /// 単純なコサイン類似度検索
58    pub fn search_cos_similarity(&self, query: &[&str], n: usize) -> Vec<(&IdType, f64)> {
59        // クエリの CsVec を作成
60        let query_csvec = self.build_query_csvec(query);
61
62        // 類似度スコアを計算
63        let mut similarities = self
64            .index
65            .iter()
66            .filter_map(|(id, (doc_vec, _doc_len))| {
67                let similarity = Self::cos_similarity(doc_vec, &query_csvec);
68                (similarity > 0.0).then(|| (id, similarity))
69            })
70            .collect::<Vec<_>>();
71
72        // スコア降順でソートして上位 n 件を返す
73        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
74        similarities.truncate(n);
75
76        similarities
77    }
78
79    /// 文書長を正規化するパラメータを入れたコサイン類似度検索
80    pub fn search_cos_similarity_tuned(&self, query: &[&str], n: usize, b: f64) -> Vec<(&IdType, f64)> {
81        let query_csvec = self.build_query_csvec(query);
82
83        let max_for_len_norm = self.max_tokens_len as f64 / self.avg_tokens_len as f64;
84
85        let mut similarities = self
86            .index
87            .iter()
88            .filter_map(|(id, (doc_vec, doc_len))| {
89                // 0.5 + ( (doc_len / avg_tokens_len) / max_for_len_norm - 0.5 ) * b
90                let len_norm = 0.5
91                    + (((*doc_len as f64 / self.avg_tokens_len as f64) / max_for_len_norm) - 0.5) * b;
92
93                let similarity = Self::cos_similarity(doc_vec, &query_csvec) * len_norm;
94                (similarity > 0.0).then(|| (id, similarity))
95            })
96            .collect::<Vec<_>>();
97
98        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
99        similarities.truncate(n);
100
101        similarities
102    }
103
104    // ---------------------------------------------------------------------------------------------
105    // 公開メソッド: BM25 x TF-IDF 検索
106    // ---------------------------------------------------------------------------------------------
107
108    pub fn search_bm25_tfidf(&self, query: &[&str], n: usize, k1: f64, b: f64) -> Vec<(&IdType, f64)> {
109        let query_csvec = self.build_query_csvec(query);
110
111        let mut similarities = self
112            .index
113            .iter()
114            .filter_map(|(id, (doc_vec, doc_len))| {
115                let score = Self::bm25_with_csvec_optimized(
116                    &query_csvec,
117                    doc_vec,
118                    *doc_len,
119                    self.avg_tokens_len as f64,
120                    k1,
121                    b,
122                );
123                (score > 0.0).then(|| (id, score))
124            })
125            .collect::<Vec<_>>();
126
127        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
128        similarities.truncate(n);
129
130        similarities
131    }
132
133    // ---------------------------------------------------------------------------------------------
134    // 公開メソッド: Index の合成 
135    // ---------------------------------------------------------------------------------------------
136
137    pub fn synthesize_index(&mut self, mut other: Self /* otherを優先 */) {
138        let new_max_token_len = self.max_tokens_len.max(other.max_tokens_len);
139        let new_total_doc_count = self.total_doc_count + other.total_doc_count;
140        // 加重平均で平均トークン長を再計算
141        let sum_self = self.avg_tokens_len as u128 * self.total_doc_count as u128;
142        let sum_other = other.avg_tokens_len as u128 * other.total_doc_count as u128;
143        let new_avg_token_len =
144            ((sum_self + sum_other) / new_total_doc_count as u128) as u64;
145
146        // その他の初期計算
147        let this_max_idf = (1.0 + self.total_doc_count as f64 / (2.0)).ln() as f32;
148        let other_max_idf = (1.0 + other.total_doc_count as f64 / (2.0)).ln() as f32;
149        let combined_max_idf = (1.0 + new_total_doc_count as f64 / (2.0)).ln() as f32;
150
151        //  値の準備
152        let mut builder = MapBuilder::memory();
153        let mut this_stream = self.idf.stream();
154        let mut other_stream = other.idf.stream();
155        let mut new_index_index: usize = 0;
156        let mut this_new_csvec_index_vec: Vec<usize> = Vec::new();
157        let mut other_new_csvec_index_vec: Vec<usize> = Vec::new();
158
159        let mut next_this = this_stream.next();
160        let mut next_other = other_stream.next();
161        //  両方のidfを合成, csvecのindexを再計算
162        while next_this != None && next_other != None {
163            let (this_token, this_idf ) = next_this.unwrap();
164            let (other_token, other_idf) = next_other.unwrap();
165            if this_token < other_token {
166                builder.insert(this_token, this_idf).unwrap();
167                next_this = this_stream.next();
168                    this_new_csvec_index_vec.push(new_index_index);
169                // otherのindexはそのまま
170
171            } else if this_token == other_token {
172                builder.insert(this_token, Self::synthesize_idf(this_idf, other_idf, 
173                    self.total_doc_count, 
174                    other.total_doc_count, 
175                    new_total_doc_count, 
176                    this_max_idf, 
177                    other_max_idf, 
178                    combined_max_idf
179                )).unwrap();
180                next_this = this_stream.next();
181                    this_new_csvec_index_vec.push(new_index_index);
182                next_other = other_stream.next();
183                    other_new_csvec_index_vec.push(new_index_index);
184            } else {
185                builder.insert(other_token, other_idf).unwrap();
186                // thisのindexはそのまま
187
188                next_other = other_stream.next();
189                    other_new_csvec_index_vec.push(new_index_index);
190            }
191            new_index_index += 1;
192        }
193        if next_this != None {
194            loop {
195                let (this_token, this_idf) = next_this.unwrap();
196                builder.insert(this_token, this_idf).unwrap();
197                next_this = this_stream.next();
198                    this_new_csvec_index_vec.push(new_index_index);
199                new_index_index += 1;
200                if next_this == None {
201                    break;
202                }
203            }
204        } else if next_other != None {
205            loop {
206                let (other_token, other_idf) = next_other.unwrap();
207                builder.insert(other_token, other_idf).unwrap();
208                next_other = other_stream.next();
209                    other_new_csvec_index_vec.push(new_index_index);
210                new_index_index += 1;
211                if next_other == None {
212                    break;
213                }
214            }
215        }
216        let new_idf = builder.into_map();
217
218        //  csvecのindexを合成
219        self.index.iter_mut().for_each(|(_id, (csvec, _))| {
220            let indices = csvec.as_mut_slice_ind();
221            for indice in indices {
222                *indice = this_new_csvec_index_vec[*indice];
223            }
224        });
225
226        other.index.iter_mut().for_each(|(_id, (csvec, _))| {
227            let indices = csvec.as_mut_slice_ind();
228            for indice in indices {
229                *indice = other_new_csvec_index_vec[*indice];
230            }
231        });
232
233        //  インデックスの合成
234        self.index.extend(other.index);
235        self.avg_tokens_len = new_avg_token_len;
236        self.max_tokens_len = new_max_token_len;
237        self.idf = new_idf;
238        self.total_doc_count = new_total_doc_count;
239    }
240
241    // ---------------------------------------------------------------------------------------------
242    //  プライベート:IDF の合成
243    // ---------------------------------------------------------------------------------------------
244    #[inline(always)]
245    fn synthesize_idf(
246        this_idf: u64,
247        other_idf: u64,
248        this_doc_count: u64,
249        other_doc_count: u64,
250        total_doc_count: u64,
251        this_max_idf: f32,
252        other_max_idf: f32,
253        combined_max_idf: f32,
254    ) -> u64 {
255        const MAX_U16: f32 = 65535.0;
256
257        let a = (this_idf as f32 * this_max_idf / MAX_U16).exp();
258        let b = (other_idf as f32 * other_max_idf / MAX_U16).exp();
259
260        let denominator = (this_doc_count as f32) / a + (other_doc_count as f32) / b - 2.0;
261        let inner = 1.0 + (total_doc_count as f32) / denominator;
262        ((inner.ln() / combined_max_idf) * MAX_U16).round() as u64
263    }
264
265    // ---------------------------------------------------------------------------------------------
266    // BM25 実装 (公開: ほかで呼び出したい場合のみ pub)
267    // ---------------------------------------------------------------------------------------------
268    pub fn bm25_with_csvec_optimized(
269        query_vec: &DefaultSparseVec<u16>, // クエリのTF-IDFベクトル(u16)
270        doc_vec: &DefaultSparseVec<u16>,   // 文書のTF-IDFベクトル(u16)
271        doc_len: u64,           // 文書のトークン数
272        avg_doc_len: f64,       // 平均文書長
273        k1: f64,                // BM25のパラメータ
274        b: f64,                 // 文書長補正のパラメータ
275    ) -> f64 {
276        // 文書長補正を計算
277        let len_norm = 1.0 - b + b * (doc_len as f64 / avg_doc_len);
278
279        // 定数の事前計算
280        const MAX_U16_AS_F64: f64 = 1.0 / (u16::MAX as f64); // 1 / 65535.0
281        let k1_len_norm = k1 * len_norm;
282
283        // クエリと文書のインデックスおよびデータ配列を直接取得
284        let (query_indices, query_data) = (query_vec.as_slice_ind(), query_vec.as_slice_val());
285        let (doc_indices, doc_data) = (doc_vec.as_slice_ind(), doc_vec.as_slice_val());
286
287        let (mut q, mut d) = (0, 0);
288        let (q_len, d_len) = (query_vec.nnz(), doc_vec.nnz());
289
290        let mut score = 0.0;
291        while q < q_len && d < d_len {
292            let q_idx = query_indices[q];
293            let d_idx = doc_indices[d];
294
295            if q_idx == d_idx {
296                let tf_f = (doc_data[d] as f64) * MAX_U16_AS_F64;
297                let idf_f = (query_data[q] as f64) * MAX_U16_AS_F64;
298
299                let numerator = tf_f * (k1 + 1.0);
300                let denominator = tf_f + k1_len_norm;
301                score += idf_f * (numerator / denominator);
302
303                q += 1;
304                d += 1;
305            } else if q_idx < d_idx {
306                q += 1;
307            } else {
308                d += 1;
309            }
310        }
311        score
312    }
313
314    // ---------------------------------------------------------------------------------------------
315    // プライベート: クエリ(&str のスライス)を CsVec<u16> に変換 (IDF を用いた TF-IDF)
316    // ---------------------------------------------------------------------------------------------
317    fn build_query_csvec(&self, query: &[&str]) -> DefaultSparseVec<u16> {
318        // 1) クエリトークン頻度を作成
319        let mut freq = TokenFrequency::new();
320        freq.add_tokens(query);
321
322        // 2) IDF からクエリの TF-IDF (u16) を生成
323        let query_tfidf_map: HashMap<String, u16> = freq.get_tfidf_hashmap_fst_parallel(&self.idf);
324
325        // 3) IDF の順序でソートされた Vec<u16> を作る
326        let mut sorted_tfidf = Vec::new();
327        let mut stream = self.idf.stream();
328        while let Some((token_bytes, _)) = stream.next() {
329            // トークンは bytes -> &str へ変換
330            let token_str = str::from_utf8(token_bytes).unwrap_or("");
331            let tfidf = query_tfidf_map.get(token_str).copied().unwrap_or(0);
332            sorted_tfidf.push(tfidf);
333        }
334
335        // 4) CsVec に変換して返す
336        DefaultSparseVec::from(sorted_tfidf)
337    }
338
339    // ---------------------------------------------------------------------------------------------
340    // プライベート: コサイン類似度
341    // ---------------------------------------------------------------------------------------------
342    fn cos_similarity(vec_a: &DefaultSparseVec<u16>, vec_b: &DefaultSparseVec<u16>) -> f64 {
343        // 内積
344        let dot_product = vec_a.u64_dot(vec_b) as f64;
345
346        // ノルム(ベクトルの長さ)
347        let norm_a = (vec_a.u64_dot(vec_a) as f64).sqrt();
348        let norm_b = (vec_b.u64_dot(vec_b) as f64).sqrt();
349
350        // コサイン類似度を返す
351        if norm_a > 0.0 && norm_b > 0.0 {
352            dot_product / (norm_a * norm_b)
353        } else {
354            0.0
355        }
356    }
357
358    // // ---------------------------------------------------------------------------------------------
359    // // 公開メソッド: エクスポート
360    // // ---------------------------------------------------------------------------------------------
361    // pub fn export(&self, path: &Path) -> io::Result<()> {
362    //     let mut file = File::create(path)?;
363
364    //     // avg_tokens_len, max_tokens_len, total_doc_count を書き込み
365    //     file.write_all(&self.avg_tokens_len.to_le_bytes())?;
366    //     file.write_all(&self.max_tokens_len.to_le_bytes())?;
367    //     file.write_all(&self.total_doc_count.to_le_bytes())?;
368
369    //     // idf をバイト列として書き込み
370    //     let idf_bytes = self.idf.as_fst().to_vec();
371    //     let idf_len = idf_bytes.len() as u64;
372    //     file.write_all(&idf_len.to_le_bytes())?;
373    //     file.write_all(&idf_bytes)?;
374
375    //     // index のエントリ数を書き込み
376    //     let index_len = self.index.len() as u64;
377    //     file.write_all(&index_len.to_le_bytes())?;
378
379    //     // index の内容を書き込み
380    //     for (doc_id, (sparse_vec, token_num)) in &self.index {
381    //         // doc_id を文字列として書き込み
382    //         let doc_id_str = doc_id.to_string();
383    //         let doc_id_bytes = doc_id_str.as_bytes();
384    //         let doc_id_len = doc_id_bytes.len() as u64;
385    //         file.write_all(&doc_id_len.to_le_bytes())?;
386    //         file.write_all(doc_id_bytes)?;
387
388    //         // sparse_vec の内容を書き込み
389    //         let sparse_vec_len = sparse_vec.len() as u64;
390    //         file.write_all(&sparse_vec_len.to_le_bytes())?;
391    //         for (index, &value) in sparse_vec.iter() {
392    //             file.write_all(&value.to_le_bytes())?;
393    //         }
394
395    //         // token_num を書き込み
396    //         file.write_all(&token_num.to_le_bytes())?;
397    //     }
398
399    //     Ok(())
400    // }
401
402    // // ---------------------------------------------------------------------------------------------
403    // // 公開メソッド: インポート
404    // // ---------------------------------------------------------------------------------------------
405    // pub fn import(path: &Path) -> io::Result<Self> {
406    //     let mut file = File::open(path)?;
407
408    //     // avg_tokens_len, max_tokens_len, total_doc_count を読み込み
409    //     let avg_tokens_len = Self::read_u64(&mut file)?;
410    //     let max_tokens_len = Self::read_u64(&mut file)?;
411    //     let total_doc_count = Self::read_u64(&mut file)?;
412
413    //     // idf を読み込み
414    //     let idf_len = Self::read_u64(&mut file)?;
415    //     let mut idf_bytes = vec![0u8; idf_len as usize];
416    //     file.read_exact(&mut idf_bytes)?;
417    //     let idf = Map::new(idf_bytes)?;
418
419    //     // index のエントリ数を読み込み
420    //     let index_len = Self::read_u64(&mut file)?;
421
422    //     // index の内容を読み込み
423    //     let mut index = HashMap::new();
424    //     for _ in 0..index_len {
425    //         // doc_id を読み込み
426    //         let doc_id_len = Self::read_u64(&mut file)?;
427    //         let mut doc_id_bytes = vec![0u8; doc_id_len as usize];
428    //         file.read_exact(&mut doc_id_bytes)?;
429    //         let doc_id = String::from_utf8(doc_id_bytes)
430    //             .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
431    //         let doc_id: IdType = doc_id.into();
432
433    //         // sparse_vec を読み込み
434    //         let sparse_vec_len = Self::read_u64(&mut file)?;
435    //         let mut sparse_vec = Vec::new();
436    //         for _ in 0..sparse_vec_len {
437    //             sparse_vec.push(Self::read_u16(&mut file)?);
438    //         }
439
440    //         // token_num を読み込み
441    //         let token_num = Self::read_u64(&mut file)?;
442
443    //         index.insert(doc_id, (sparse_vec, token_num));
444    //     }
445
446    //     Ok(Self {
447    //         index,
448    //         avg_tokens_len,
449    //         max_tokens_len,
450    //         idf,
451    //         total_doc_count,
452    //     })
453    // }
454
455    // /// ユーティリティ関数:u64を読み取る
456    // fn read_u64<R: Read>(reader: &mut R) -> io::Result<u64> {
457    //     let mut buffer = [0u8; 8];
458    //     reader.read_exact(&mut buffer)?;
459    //     Ok(u64::from_le_bytes(buffer))
460    // }
461
462    // /// ユーティリティ関数:u16を読み取る
463    // fn read_u16<R: Read>(reader: &mut R) -> io::Result<u16> {
464    //     let mut buffer = [0u8; 2];
465    //     reader.read_exact(&mut buffer)?;
466    //     Ok(u16::from_le_bytes(buffer))
467    // }
468}