1use std::collections::HashMap;
60use std::sync::Arc;
61
62use crate::candidate_gate::AllowedSet;
63use crate::filtered_vector_search::ScoredResult;
64
/// Okapi BM25 tuning knobs together with the corpus statistics they need.
#[derive(Clone, Debug)]
pub struct Bm25Params {
    /// Term-frequency saturation parameter: larger values let repeated
    /// occurrences keep adding to the score for longer.
    pub k1: f32,
    /// Strength of document-length normalisation (`0.0` disables it).
    pub b: f32,
    /// Average document length, in the same units as the document lengths
    /// passed to `term_score`.
    pub avgdl: f32,
    /// Total number of documents in the corpus (`N` in the IDF formula).
    pub total_docs: u64,
}
81
82impl Default for Bm25Params {
83 fn default() -> Self {
84 Self {
85 k1: 1.2,
86 b: 0.75,
87 avgdl: 100.0,
88 total_docs: 1_000_000,
89 }
90 }
91}
92
93impl Bm25Params {
94 pub fn idf(&self, doc_freq: u64) -> f32 {
96 let n = self.total_docs as f32;
97 let df = doc_freq as f32;
98 ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
99 }
100
101 pub fn term_score(&self, tf: f32, doc_len: f32, idf: f32) -> f32 {
103 let numerator = tf * (self.k1 + 1.0);
104 let denominator = tf + self.k1 * (1.0 - self.b + self.b * doc_len / self.avgdl);
105 idf * numerator / denominator
106 }
107}
108
/// A term's postings stored column-wise: document `doc_ids[i]` contains the
/// term `term_freqs[i]` times.
#[derive(Clone, Debug)]
pub struct PostingList {
    /// The term these postings belong to.
    pub term: String,
    /// Documents containing the term; parallel to `term_freqs`.
    pub doc_ids: Vec<u64>,
    /// Per-document occurrence counts; parallel to `doc_ids`.
    pub term_freqs: Vec<u32>,
    /// Number of documents containing the term (feeds the IDF computation).
    pub doc_freq: u64,
}
125
126impl PostingList {
127 pub fn new(term: impl Into<String>, entries: Vec<(u64, u32)>) -> Self {
129 let term = term.into();
130 let doc_freq = entries.len() as u64;
131 let mut doc_ids = Vec::with_capacity(entries.len());
132 let mut term_freqs = Vec::with_capacity(entries.len());
133
134 for (doc_id, tf) in entries {
135 doc_ids.push(doc_id);
136 term_freqs.push(tf);
137 }
138
139 Self {
140 term,
141 doc_ids,
142 term_freqs,
143 doc_freq,
144 }
145 }
146
147 pub fn intersect_with_allowed(&self, allowed: &AllowedSet) -> Vec<(u64, u32)> {
149 match allowed {
150 AllowedSet::All => {
151 self.doc_ids.iter()
152 .zip(self.term_freqs.iter())
153 .map(|(&id, &tf)| (id, tf))
154 .collect()
155 }
156 AllowedSet::None => vec![],
157 _ => {
158 self.doc_ids.iter()
159 .zip(self.term_freqs.iter())
160 .filter(|&(&id, _)| allowed.contains(id))
161 .map(|(&id, &tf)| (id, tf))
162 .collect()
163 }
164 }
165 }
166}
167
168pub trait InvertedIndex: Send + Sync {
174 fn get_posting_list(&self, term: &str) -> Option<PostingList>;
176
177 fn get_doc_length(&self, doc_id: u64) -> Option<u32>;
179
180 fn get_params(&self) -> &Bm25Params;
182}
183
184pub struct FilteredBm25Executor<I: InvertedIndex> {
190 index: Arc<I>,
191}
192
193impl<I: InvertedIndex> FilteredBm25Executor<I> {
194 pub fn new(index: Arc<I>) -> Self {
196 Self { index }
197 }
198
199 pub fn search(
210 &self,
211 query: &str,
212 k: usize,
213 allowed: &AllowedSet,
214 ) -> Vec<ScoredResult> {
215 if allowed.is_empty() {
217 return vec![];
218 }
219
220 let terms: Vec<&str> = query
222 .split_whitespace()
223 .filter(|t| t.len() >= 2) .collect();
225
226 if terms.is_empty() {
227 return vec![];
228 }
229
230 let mut posting_lists: Vec<PostingList> = terms
232 .iter()
233 .filter_map(|t| self.index.get_posting_list(t))
234 .collect();
235
236 posting_lists.sort_by_key(|pl| pl.doc_freq);
238
239 let candidates = self.progressive_intersection(&posting_lists, allowed);
241
242 if candidates.is_empty() {
243 return vec![];
244 }
245
246 let params = self.index.get_params();
248 let scores = self.score_candidates(&candidates, &posting_lists, params);
249
250 self.top_k(scores, k)
252 }
253
254 fn progressive_intersection(
258 &self,
259 posting_lists: &[PostingList],
260 allowed: &AllowedSet,
261 ) -> HashMap<u64, Vec<u32>> {
262 if posting_lists.is_empty() {
263 return HashMap::new();
264 }
265
266 let first = &posting_lists[0];
268 let mut candidates: HashMap<u64, Vec<u32>> = first
269 .intersect_with_allowed(allowed)
270 .into_iter()
271 .map(|(id, tf)| (id, vec![tf]))
272 .collect();
273
274 for (_term_idx, posting_list) in posting_lists.iter().enumerate().skip(1) {
276 let term_postings: HashMap<u64, u32> = posting_list
278 .doc_ids.iter()
279 .zip(posting_list.term_freqs.iter())
280 .map(|(&id, &tf)| (id, tf))
281 .collect();
282
283 candidates.retain(|doc_id, tfs| {
285 if let Some(&tf) = term_postings.get(doc_id) {
286 tfs.push(tf);
287 true
288 } else {
289 false
290 }
291 });
292
293 if candidates.is_empty() {
295 break;
296 }
297 }
298
299 candidates
300 }
301
302 fn score_candidates(
304 &self,
305 candidates: &HashMap<u64, Vec<u32>>,
306 posting_lists: &[PostingList],
307 params: &Bm25Params,
308 ) -> Vec<ScoredResult> {
309 let idfs: Vec<f32> = posting_lists
311 .iter()
312 .map(|pl| params.idf(pl.doc_freq))
313 .collect();
314
315 candidates
316 .iter()
317 .filter_map(|(&doc_id, tfs)| {
318 let doc_len = self.index.get_doc_length(doc_id)? as f32;
319
320 let score: f32 = tfs.iter()
321 .zip(idfs.iter())
322 .map(|(&tf, &idf)| params.term_score(tf as f32, doc_len, idf))
323 .sum();
324
325 Some(ScoredResult::new(doc_id, score))
326 })
327 .collect()
328 }
329
330 fn top_k(&self, mut scores: Vec<ScoredResult>, k: usize) -> Vec<ScoredResult> {
332 scores.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
334 scores.truncate(k);
335 scores
336 }
337}
338
339pub struct DisjunctiveBm25Executor<I: InvertedIndex> {
345 index: Arc<I>,
346}
347
348impl<I: InvertedIndex> DisjunctiveBm25Executor<I> {
349 pub fn new(index: Arc<I>) -> Self {
351 Self { index }
352 }
353
354 pub fn search(
356 &self,
357 query: &str,
358 k: usize,
359 allowed: &AllowedSet,
360 ) -> Vec<ScoredResult> {
361 if allowed.is_empty() {
362 return vec![];
363 }
364
365 let terms: Vec<&str> = query.split_whitespace().collect();
366 if terms.is_empty() {
367 return vec![];
368 }
369
370 let posting_lists: Vec<PostingList> = terms
372 .iter()
373 .filter_map(|t| self.index.get_posting_list(t))
374 .collect();
375
376 let params = self.index.get_params();
377
378 let mut scores: HashMap<u64, f32> = HashMap::new();
380
381 for posting_list in &posting_lists {
382 let idf = params.idf(posting_list.doc_freq);
383
384 for (&doc_id, &tf) in posting_list.doc_ids.iter().zip(posting_list.term_freqs.iter()) {
386 if !allowed.contains(doc_id) {
387 continue;
388 }
389
390 if let Some(doc_len) = self.index.get_doc_length(doc_id) {
391 let term_score = params.term_score(tf as f32, doc_len as f32, idf);
392 *scores.entry(doc_id).or_insert(0.0) += term_score;
393 }
394 }
395 }
396
397 let mut results: Vec<ScoredResult> = scores
399 .into_iter()
400 .map(|(id, score)| ScoredResult::new(id, score))
401 .collect();
402
403 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
404 results.truncate(k);
405 results
406 }
407}
408
/// Where one term occurs inside a single document.
#[derive(Clone, Debug)]
pub struct PositionalPosting {
    /// The document containing the term.
    pub doc_id: u64,
    /// Token positions of the term within the document. The phrase executor
    /// binary-searches this list, so it is expected to be sorted ascending.
    pub positions: Vec<u32>,
}
421
422pub trait PositionalIndex: InvertedIndex {
424 fn get_positional_posting(&self, term: &str) -> Option<Vec<PositionalPosting>>;
426}
427
428pub struct FilteredPhraseExecutor<I: PositionalIndex> {
430 index: Arc<I>,
431}
432
433impl<I: PositionalIndex> FilteredPhraseExecutor<I> {
434 pub fn new(index: Arc<I>) -> Self {
436 Self { index }
437 }
438
439 pub fn search(
443 &self,
444 phrase: &[&str],
445 k: usize,
446 allowed: &AllowedSet,
447 ) -> Vec<ScoredResult> {
448 if phrase.is_empty() || allowed.is_empty() {
449 return vec![];
450 }
451
452 let mut positional_postings: Vec<Vec<PositionalPosting>> = vec![];
454 for term in phrase {
455 match self.index.get_positional_posting(term) {
456 Some(postings) => positional_postings.push(postings),
457 None => return vec![], }
459 }
460
461 let candidates = self.find_phrase_matches(&positional_postings, allowed);
463
464 let params = self.index.get_params();
466 let results: Vec<ScoredResult> = candidates
467 .into_iter()
468 .filter_map(|(doc_id, phrase_freq)| {
469 let doc_len = self.index.get_doc_length(doc_id)? as f32;
470 let min_df = positional_postings.iter()
472 .map(|pp| pp.len() as u64)
473 .min()
474 .unwrap_or(1);
475 let idf = params.idf(min_df);
476 let score = params.term_score(phrase_freq as f32, doc_len, idf);
477 Some(ScoredResult::new(doc_id, score))
478 })
479 .collect();
480
481 let mut results = results;
482 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
483 results.truncate(k);
484 results
485 }
486
487 fn find_phrase_matches(
489 &self,
490 positional_postings: &[Vec<PositionalPosting>],
491 allowed: &AllowedSet,
492 ) -> Vec<(u64, u32)> {
493 if positional_postings.is_empty() {
494 return vec![];
495 }
496
497 let indexed: Vec<HashMap<u64, &Vec<u32>>> = positional_postings
499 .iter()
500 .map(|postings| {
501 postings.iter()
502 .filter(|p| allowed.contains(p.doc_id))
503 .map(|p| (p.doc_id, &p.positions))
504 .collect()
505 })
506 .collect();
507
508 let first_docs: std::collections::HashSet<u64> = indexed[0].keys().copied().collect();
510
511 let candidate_docs: Vec<u64> = first_docs
513 .into_iter()
514 .filter(|doc_id| indexed.iter().all(|idx| idx.contains_key(doc_id)))
515 .collect();
516
517 let mut matches = vec![];
519
520 for doc_id in candidate_docs {
521 let mut phrase_count = 0u32;
522
523 let first_positions = indexed[0].get(&doc_id).unwrap();
525
526 'outer: for &start_pos in first_positions.iter() {
527 for (term_idx, term_positions) in indexed.iter().enumerate().skip(1) {
529 let expected_pos = start_pos + term_idx as u32;
530 let positions = term_positions.get(&doc_id).unwrap();
531
532 if positions.binary_search(&expected_pos).is_err() {
534 continue 'outer;
535 }
536 }
537
538 phrase_count += 1;
540 }
541
542 if phrase_count > 0 {
543 matches.push((doc_id, phrase_count));
544 }
545 }
546
547 matches
548 }
549}
550
#[cfg(test)]
mod tests {
    use super::*;
    use crate::candidate_gate::AllowedSet;

    /// In-memory `InvertedIndex`: three terms over five documents, each of
    /// length 100.
    struct MockIndex {
        postings: HashMap<String, PostingList>,
        doc_lengths: HashMap<u64, u32>,
        params: Bm25Params,
    }

    impl MockIndex {
        fn new() -> Self {
            let mut postings = HashMap::new();
            postings.insert(
                "rust".to_string(),
                PostingList::new("rust", vec![(1, 3), (2, 1), (3, 2), (5, 1)]),
            );
            postings.insert(
                "database".to_string(),
                PostingList::new("database", vec![(1, 1), (3, 4), (4, 1)]),
            );
            postings.insert(
                "vector".to_string(),
                PostingList::new("vector", vec![(1, 2), (2, 3), (4, 1), (5, 2)]),
            );

            let doc_lengths = (1..=5).map(|doc_id| (doc_id, 100)).collect();

            Self {
                postings,
                doc_lengths,
                params: Bm25Params {
                    k1: 1.2,
                    b: 0.75,
                    avgdl: 100.0,
                    total_docs: 1000,
                },
            }
        }
    }

    impl InvertedIndex for MockIndex {
        fn get_posting_list(&self, term: &str) -> Option<PostingList> {
            self.postings.get(term).cloned()
        }

        fn get_doc_length(&self, doc_id: u64) -> Option<u32> {
            self.doc_lengths.get(&doc_id).copied()
        }

        fn get_params(&self) -> &Bm25Params {
            &self.params
        }
    }

    /// Doc ids of `results`, sorted, for order-independent comparisons.
    fn sorted_ids(results: &[ScoredResult]) -> Vec<u64> {
        let mut ids: Vec<u64> = results.iter().map(|r| r.doc_id).collect();
        ids.sort_unstable();
        ids
    }

    #[test]
    fn test_conjunctive_search() {
        let executor = FilteredBm25Executor::new(Arc::new(MockIndex::new()));

        // Only docs 1 and 3 contain both "rust" and "database".
        let results = executor.search("rust database", 10, &AllowedSet::All);

        assert_eq!(sorted_ids(&results), vec![1, 3]);
    }

    #[test]
    fn test_filter_pushdown() {
        let executor = FilteredBm25Executor::new(Arc::new(MockIndex::new()));

        // Doc 3 also matches both terms but is excluded by the filter.
        let allowed = AllowedSet::SortedVec(Arc::new(vec![1]));
        let results = executor.search("rust database", 10, &allowed);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].doc_id, 1);
    }

    #[test]
    fn test_empty_allowed_set() {
        let executor = FilteredBm25Executor::new(Arc::new(MockIndex::new()));

        let results = executor.search("rust", 10, &AllowedSet::None);
        assert!(results.is_empty());
    }

    #[test]
    fn test_disjunctive_search() {
        let executor = DisjunctiveBm25Executor::new(Arc::new(MockIndex::new()));

        // "rust" covers {1, 2, 3, 5} and "database" covers {1, 3, 4}.
        let results = executor.search("rust database", 10, &AllowedSet::All);

        assert_eq!(sorted_ids(&results), vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_disjunctive_respects_k() {
        let executor = DisjunctiveBm25Executor::new(Arc::new(MockIndex::new()));

        let results = executor.search("rust database", 2, &AllowedSet::All);

        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_intersect_with_allowed() {
        let pl = PostingList::new("t", vec![(1, 2), (3, 1), (4, 5)]);

        assert_eq!(pl.intersect_with_allowed(&AllowedSet::All).len(), 3);
        assert!(pl.intersect_with_allowed(&AllowedSet::None).is_empty());

        let allowed = AllowedSet::SortedVec(Arc::new(vec![3, 4]));
        assert_eq!(pl.intersect_with_allowed(&allowed), vec![(3, 1), (4, 5)]);
    }

    #[test]
    fn test_term_ordering_by_df() {
        let rare = PostingList::new("rare", vec![(1, 1), (2, 1)]);
        let common = PostingList::new("common", vec![(1, 1), (2, 1), (3, 1), (4, 1), (5, 1)]);
        let mut lists = vec![common, rare];

        lists.sort_by_key(|pl| pl.doc_freq);

        assert_eq!(lists[0].term, "rare");
        assert_eq!(lists[1].term, "common");
    }

    #[test]
    fn test_bm25_scoring() {
        let params = Bm25Params::default();

        // Rarer terms carry more weight.
        let idf_rare = params.idf(10);
        let idf_common = params.idf(100_000);
        assert!(idf_rare > idf_common);

        // More occurrences score higher at equal document length.
        let score_tf_1 = params.term_score(1.0, 100.0, idf_rare);
        let score_tf_5 = params.term_score(5.0, 100.0, idf_rare);
        assert!(score_tf_5 > score_tf_1);
    }
}