tf_idf_vectorizer/vectorizer/index.rs
1
2use std::collections::HashMap;
3use std::str;
4
5use fst::{Map, MapBuilder, Streamer};
6use serde::Serialize;
7use vec_plus::vec::{default_sparse_vec::DefaultSparseVec, vec_trait::Math};
8
9use super::token::TokenFrequency;
10
11
12#[derive(Clone, Debug,)]
13pub struct Index<IdType>
14where
15 IdType: Clone + Eq + std::hash::Hash + Serialize + std::fmt::Debug,
16{
17 // doc_id -> (圧縮ベクトル, 文書の総トークン数)
18 // 圧縮ベクトル: インデックス順にトークンの TF を保持
19 pub index: HashMap<IdType, (DefaultSparseVec<u16>, u64 /* token num */)>,
20 pub avg_tokens_len: u64, // 全文書の平均トークン長
21 pub max_tokens_len: u64, // 全文書の最大トークン長
22 pub idf: Map<Vec<u8>>, // fst::Map 形式の IDF
23 pub total_doc_count: u64, // 文書総数
24}
25
26impl<IdType> Index<IdType>
27where
28 IdType: Clone + Eq + std::hash::Hash + Serialize + std::fmt::Debug,
29{
30 // ---------------------------------------------------------------------------------------------
31 // コンストラクタ
32 // ---------------------------------------------------------------------------------------------
33 pub fn new_with_set(
34 index: HashMap<IdType, (DefaultSparseVec<u16>, u64)>,
35 idf: Map<Vec<u8>>,
36 avg_tokens_len: u64,
37 max_tokens_len: u64,
38 total_doc_count: u64,
39 ) -> Self {
40 Self {
41 index,
42 idf,
43 avg_tokens_len,
44 max_tokens_len,
45 total_doc_count,
46 }
47 }
48
49 pub fn get_index(&self) -> &HashMap<IdType, (DefaultSparseVec<u16>, u64)> {
50 &self.index
51 }
52
53 // ---------------------------------------------------------------------------------------------
54 // 公開メソッド: 検索 (Cosine Similarity)
55 // ---------------------------------------------------------------------------------------------
56
57 /// 単純なコサイン類似度検索
58 pub fn search_cos_similarity(&self, query: &[&str], n: usize) -> Vec<(&IdType, f64)> {
59 // クエリの CsVec を作成
60 let query_csvec = self.build_query_csvec(query);
61
62 // 類似度スコアを計算
63 let mut similarities = self
64 .index
65 .iter()
66 .filter_map(|(id, (doc_vec, _doc_len))| {
67 let similarity = Self::cos_similarity(doc_vec, &query_csvec);
68 (similarity > 0.0).then(|| (id, similarity))
69 })
70 .collect::<Vec<_>>();
71
72 // スコア降順でソートして上位 n 件を返す
73 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
74 similarities.truncate(n);
75
76 similarities
77 }
78
79 /// 文書長を正規化するパラメータを入れたコサイン類似度検索
80 pub fn search_cos_similarity_tuned(&self, query: &[&str], n: usize, b: f64) -> Vec<(&IdType, f64)> {
81 let query_csvec = self.build_query_csvec(query);
82
83 let max_for_len_norm = self.max_tokens_len as f64 / self.avg_tokens_len as f64;
84
85 let mut similarities = self
86 .index
87 .iter()
88 .filter_map(|(id, (doc_vec, doc_len))| {
89 // 0.5 + ( (doc_len / avg_tokens_len) / max_for_len_norm - 0.5 ) * b
90 let len_norm = 0.5
91 + (((*doc_len as f64 / self.avg_tokens_len as f64) / max_for_len_norm) - 0.5) * b;
92
93 let similarity = Self::cos_similarity(doc_vec, &query_csvec) * len_norm;
94 (similarity > 0.0).then(|| (id, similarity))
95 })
96 .collect::<Vec<_>>();
97
98 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
99 similarities.truncate(n);
100
101 similarities
102 }
103
104 // ---------------------------------------------------------------------------------------------
105 // 公開メソッド: BM25 x TF-IDF 検索
106 // ---------------------------------------------------------------------------------------------
107
108 pub fn search_bm25_tfidf(&self, query: &[&str], n: usize, k1: f64, b: f64) -> Vec<(&IdType, f64)> {
109 let query_csvec = self.build_query_csvec(query);
110
111 let mut similarities = self
112 .index
113 .iter()
114 .filter_map(|(id, (doc_vec, doc_len))| {
115 let score = Self::bm25_with_csvec_optimized(
116 &query_csvec,
117 doc_vec,
118 *doc_len,
119 self.avg_tokens_len as f64,
120 k1,
121 b,
122 );
123 (score > 0.0).then(|| (id, score))
124 })
125 .collect::<Vec<_>>();
126
127 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
128 similarities.truncate(n);
129
130 similarities
131 }
132
133 // ---------------------------------------------------------------------------------------------
134 // 公開メソッド: Index の合成
135 // ---------------------------------------------------------------------------------------------
136
137 pub fn synthesize_index(&mut self, mut other: Self /* otherを優先 */) {
138 let new_max_token_len = self.max_tokens_len.max(other.max_tokens_len);
139 let new_total_doc_count = self.total_doc_count + other.total_doc_count;
140 // 加重平均で平均トークン長を再計算
141 let sum_self = self.avg_tokens_len as u128 * self.total_doc_count as u128;
142 let sum_other = other.avg_tokens_len as u128 * other.total_doc_count as u128;
143 let new_avg_token_len =
144 ((sum_self + sum_other) / new_total_doc_count as u128) as u64;
145
146 // その他の初期計算
147 let this_max_idf = (1.0 + self.total_doc_count as f64 / (2.0)).ln() as f32;
148 let other_max_idf = (1.0 + other.total_doc_count as f64 / (2.0)).ln() as f32;
149 let combined_max_idf = (1.0 + new_total_doc_count as f64 / (2.0)).ln() as f32;
150
151 // 値の準備
152 let mut builder = MapBuilder::memory();
153 let mut this_stream = self.idf.stream();
154 let mut other_stream = other.idf.stream();
155 let mut new_index_index: usize = 0;
156 let mut this_new_csvec_index_vec: Vec<usize> = Vec::new();
157 let mut other_new_csvec_index_vec: Vec<usize> = Vec::new();
158
159 let mut next_this = this_stream.next();
160 let mut next_other = other_stream.next();
161 // 両方のidfを合成, csvecのindexを再計算
162 while next_this != None && next_other != None {
163 let (this_token, this_idf ) = next_this.unwrap();
164 let (other_token, other_idf) = next_other.unwrap();
165 if this_token < other_token {
166 builder.insert(this_token, this_idf).unwrap();
167 next_this = this_stream.next();
168 this_new_csvec_index_vec.push(new_index_index);
169 // otherのindexはそのまま
170
171 } else if this_token == other_token {
172 builder.insert(this_token, Self::synthesize_idf(this_idf, other_idf,
173 self.total_doc_count,
174 other.total_doc_count,
175 new_total_doc_count,
176 this_max_idf,
177 other_max_idf,
178 combined_max_idf
179 )).unwrap();
180 next_this = this_stream.next();
181 this_new_csvec_index_vec.push(new_index_index);
182 next_other = other_stream.next();
183 other_new_csvec_index_vec.push(new_index_index);
184 } else {
185 builder.insert(other_token, other_idf).unwrap();
186 // thisのindexはそのまま
187
188 next_other = other_stream.next();
189 other_new_csvec_index_vec.push(new_index_index);
190 }
191 new_index_index += 1;
192 }
193 if next_this != None {
194 loop {
195 let (this_token, this_idf) = next_this.unwrap();
196 builder.insert(this_token, this_idf).unwrap();
197 next_this = this_stream.next();
198 this_new_csvec_index_vec.push(new_index_index);
199 new_index_index += 1;
200 if next_this == None {
201 break;
202 }
203 }
204 } else if next_other != None {
205 loop {
206 let (other_token, other_idf) = next_other.unwrap();
207 builder.insert(other_token, other_idf).unwrap();
208 next_other = other_stream.next();
209 other_new_csvec_index_vec.push(new_index_index);
210 new_index_index += 1;
211 if next_other == None {
212 break;
213 }
214 }
215 }
216 let new_idf = builder.into_map();
217
218 // csvecのindexを合成
219 self.index.iter_mut().for_each(|(_id, (csvec, _))| {
220 let indices = csvec.as_mut_slice_ind();
221 for indice in indices {
222 *indice = this_new_csvec_index_vec[*indice];
223 }
224 });
225
226 other.index.iter_mut().for_each(|(_id, (csvec, _))| {
227 let indices = csvec.as_mut_slice_ind();
228 for indice in indices {
229 *indice = other_new_csvec_index_vec[*indice];
230 }
231 });
232
233 // インデックスの合成
234 self.index.extend(other.index);
235 self.avg_tokens_len = new_avg_token_len;
236 self.max_tokens_len = new_max_token_len;
237 self.idf = new_idf;
238 self.total_doc_count = new_total_doc_count;
239 }
240
241 // ---------------------------------------------------------------------------------------------
242 // プライベート:IDF の合成
243 // ---------------------------------------------------------------------------------------------
244 #[inline(always)]
245 fn synthesize_idf(
246 this_idf: u64,
247 other_idf: u64,
248 this_doc_count: u64,
249 other_doc_count: u64,
250 total_doc_count: u64,
251 this_max_idf: f32,
252 other_max_idf: f32,
253 combined_max_idf: f32,
254 ) -> u64 {
255 const MAX_U16: f32 = 65535.0;
256
257 let a = (this_idf as f32 * this_max_idf / MAX_U16).exp();
258 let b = (other_idf as f32 * other_max_idf / MAX_U16).exp();
259
260 let denominator = (this_doc_count as f32) / a + (other_doc_count as f32) / b - 2.0;
261 let inner = 1.0 + (total_doc_count as f32) / denominator;
262 ((inner.ln() / combined_max_idf) * MAX_U16).round() as u64
263 }
264
265 // ---------------------------------------------------------------------------------------------
266 // BM25 実装 (公開: ほかで呼び出したい場合のみ pub)
267 // ---------------------------------------------------------------------------------------------
268 pub fn bm25_with_csvec_optimized(
269 query_vec: &DefaultSparseVec<u16>, // クエリのTF-IDFベクトル(u16)
270 doc_vec: &DefaultSparseVec<u16>, // 文書のTF-IDFベクトル(u16)
271 doc_len: u64, // 文書のトークン数
272 avg_doc_len: f64, // 平均文書長
273 k1: f64, // BM25のパラメータ
274 b: f64, // 文書長補正のパラメータ
275 ) -> f64 {
276 // 文書長補正を計算
277 let len_norm = 1.0 - b + b * (doc_len as f64 / avg_doc_len);
278
279 // 定数の事前計算
280 const MAX_U16_AS_F64: f64 = 1.0 / (u16::MAX as f64); // 1 / 65535.0
281 let k1_len_norm = k1 * len_norm;
282
283 // クエリと文書のインデックスおよびデータ配列を直接取得
284 let (query_indices, query_data) = (query_vec.as_slice_ind(), query_vec.as_slice_val());
285 let (doc_indices, doc_data) = (doc_vec.as_slice_ind(), doc_vec.as_slice_val());
286
287 let (mut q, mut d) = (0, 0);
288 let (q_len, d_len) = (query_vec.nnz(), doc_vec.nnz());
289
290 let mut score = 0.0;
291 while q < q_len && d < d_len {
292 let q_idx = query_indices[q];
293 let d_idx = doc_indices[d];
294
295 if q_idx == d_idx {
296 let tf_f = (doc_data[d] as f64) * MAX_U16_AS_F64;
297 let idf_f = (query_data[q] as f64) * MAX_U16_AS_F64;
298
299 let numerator = tf_f * (k1 + 1.0);
300 let denominator = tf_f + k1_len_norm;
301 score += idf_f * (numerator / denominator);
302
303 q += 1;
304 d += 1;
305 } else if q_idx < d_idx {
306 q += 1;
307 } else {
308 d += 1;
309 }
310 }
311 score
312 }
313
314 // ---------------------------------------------------------------------------------------------
315 // プライベート: クエリ(&str のスライス)を CsVec<u16> に変換 (IDF を用いた TF-IDF)
316 // ---------------------------------------------------------------------------------------------
317 fn build_query_csvec(&self, query: &[&str]) -> DefaultSparseVec<u16> {
318 // 1) クエリトークン頻度を作成
319 let mut freq = TokenFrequency::new();
320 freq.add_tokens(query);
321
322 // 2) IDF からクエリの TF-IDF (u16) を生成
323 let query_tfidf_map: HashMap<String, u16> = freq.get_tfidf_hashmap_fst_parallel(&self.idf);
324
325 // 3) IDF の順序でソートされた Vec<u16> を作る
326 let mut sorted_tfidf = Vec::new();
327 let mut stream = self.idf.stream();
328 while let Some((token_bytes, _)) = stream.next() {
329 // トークンは bytes -> &str へ変換
330 let token_str = str::from_utf8(token_bytes).unwrap_or("");
331 let tfidf = query_tfidf_map.get(token_str).copied().unwrap_or(0);
332 sorted_tfidf.push(tfidf);
333 }
334
335 // 4) CsVec に変換して返す
336 DefaultSparseVec::from(sorted_tfidf)
337 }
338
339 // ---------------------------------------------------------------------------------------------
340 // プライベート: コサイン類似度
341 // ---------------------------------------------------------------------------------------------
342 fn cos_similarity(vec_a: &DefaultSparseVec<u16>, vec_b: &DefaultSparseVec<u16>) -> f64 {
343 // 内積
344 let dot_product = vec_a.u64_dot(vec_b) as f64;
345
346 // ノルム(ベクトルの長さ)
347 let norm_a = (vec_a.u64_dot(vec_a) as f64).sqrt();
348 let norm_b = (vec_b.u64_dot(vec_b) as f64).sqrt();
349
350 // コサイン類似度を返す
351 if norm_a > 0.0 && norm_b > 0.0 {
352 dot_product / (norm_a * norm_b)
353 } else {
354 0.0
355 }
356 }
357
358 // // ---------------------------------------------------------------------------------------------
359 // // 公開メソッド: エクスポート
360 // // ---------------------------------------------------------------------------------------------
361 // pub fn export(&self, path: &Path) -> io::Result<()> {
362 // let mut file = File::create(path)?;
363
364 // // avg_tokens_len, max_tokens_len, total_doc_count を書き込み
365 // file.write_all(&self.avg_tokens_len.to_le_bytes())?;
366 // file.write_all(&self.max_tokens_len.to_le_bytes())?;
367 // file.write_all(&self.total_doc_count.to_le_bytes())?;
368
369 // // idf をバイト列として書き込み
370 // let idf_bytes = self.idf.as_fst().to_vec();
371 // let idf_len = idf_bytes.len() as u64;
372 // file.write_all(&idf_len.to_le_bytes())?;
373 // file.write_all(&idf_bytes)?;
374
375 // // index のエントリ数を書き込み
376 // let index_len = self.index.len() as u64;
377 // file.write_all(&index_len.to_le_bytes())?;
378
379 // // index の内容を書き込み
380 // for (doc_id, (sparse_vec, token_num)) in &self.index {
381 // // doc_id を文字列として書き込み
382 // let doc_id_str = doc_id.to_string();
383 // let doc_id_bytes = doc_id_str.as_bytes();
384 // let doc_id_len = doc_id_bytes.len() as u64;
385 // file.write_all(&doc_id_len.to_le_bytes())?;
386 // file.write_all(doc_id_bytes)?;
387
388 // // sparse_vec の内容を書き込み
389 // let sparse_vec_len = sparse_vec.len() as u64;
390 // file.write_all(&sparse_vec_len.to_le_bytes())?;
391 // for (index, &value) in sparse_vec.iter() {
392 // file.write_all(&value.to_le_bytes())?;
393 // }
394
395 // // token_num を書き込み
396 // file.write_all(&token_num.to_le_bytes())?;
397 // }
398
399 // Ok(())
400 // }
401
402 // // ---------------------------------------------------------------------------------------------
403 // // 公開メソッド: インポート
404 // // ---------------------------------------------------------------------------------------------
405 // pub fn import(path: &Path) -> io::Result<Self> {
406 // let mut file = File::open(path)?;
407
408 // // avg_tokens_len, max_tokens_len, total_doc_count を読み込み
409 // let avg_tokens_len = Self::read_u64(&mut file)?;
410 // let max_tokens_len = Self::read_u64(&mut file)?;
411 // let total_doc_count = Self::read_u64(&mut file)?;
412
413 // // idf を読み込み
414 // let idf_len = Self::read_u64(&mut file)?;
415 // let mut idf_bytes = vec![0u8; idf_len as usize];
416 // file.read_exact(&mut idf_bytes)?;
417 // let idf = Map::new(idf_bytes)?;
418
419 // // index のエントリ数を読み込み
420 // let index_len = Self::read_u64(&mut file)?;
421
422 // // index の内容を読み込み
423 // let mut index = HashMap::new();
424 // for _ in 0..index_len {
425 // // doc_id を読み込み
426 // let doc_id_len = Self::read_u64(&mut file)?;
427 // let mut doc_id_bytes = vec![0u8; doc_id_len as usize];
428 // file.read_exact(&mut doc_id_bytes)?;
429 // let doc_id = String::from_utf8(doc_id_bytes)
430 // .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
431 // let doc_id: IdType = doc_id.into();
432
433 // // sparse_vec を読み込み
434 // let sparse_vec_len = Self::read_u64(&mut file)?;
435 // let mut sparse_vec = Vec::new();
436 // for _ in 0..sparse_vec_len {
437 // sparse_vec.push(Self::read_u16(&mut file)?);
438 // }
439
440 // // token_num を読み込み
441 // let token_num = Self::read_u64(&mut file)?;
442
443 // index.insert(doc_id, (sparse_vec, token_num));
444 // }
445
446 // Ok(Self {
447 // index,
448 // avg_tokens_len,
449 // max_tokens_len,
450 // idf,
451 // total_doc_count,
452 // })
453 // }
454
455 // /// ユーティリティ関数:u64を読み取る
456 // fn read_u64<R: Read>(reader: &mut R) -> io::Result<u64> {
457 // let mut buffer = [0u8; 8];
458 // reader.read_exact(&mut buffer)?;
459 // Ok(u64::from_le_bytes(buffer))
460 // }
461
462 // /// ユーティリティ関数:u16を読み取る
463 // fn read_u16<R: Read>(reader: &mut R) -> io::Result<u16> {
464 // let mut buffer = [0u8; 2];
465 // reader.read_exact(&mut buffer)?;
466 // Ok(u16::from_le_bytes(buffer))
467 // }
468}