1use std::collections::HashMap;
57use std::sync::Arc;
58
59use crate::candidate_gate::AllowedSet;
60use crate::filtered_vector_search::ScoredResult;
61
/// Tuning constants and corpus statistics used for BM25 scoring.
#[derive(Debug, Clone)]
pub struct Bm25Params {
    /// Term-frequency saturation constant (`k1` in the BM25 formula).
    pub k1: f32,
    /// Length-normalisation strength (`b` in the BM25 formula).
    pub b: f32,
    /// Average document length that `doc_len` is normalised against.
    pub avgdl: f32,
    /// Number of documents in the corpus, used by [`Bm25Params::idf`].
    pub total_docs: u64,
}

impl Default for Bm25Params {
    /// Returns `k1 = 1.2`, `b = 0.75`, `avgdl = 100.0`, `total_docs = 1_000_000`.
    fn default() -> Self {
        Self {
            k1: 1.2,
            b: 0.75,
            avgdl: 100.0,
            total_docs: 1_000_000,
        }
    }
}

impl Bm25Params {
    /// Inverse document frequency of a term that occurs in `doc_freq` documents.
    ///
    /// Computes `ln((N - df + 0.5) / (df + 0.5) + 1)`.
    ///
    /// `doc_freq` is clamped to `total_docs`: with stale or inconsistent
    /// statistics (`doc_freq > total_docs`) the unclamped formula would go
    /// negative and make a matching term *lower* a document's score.
    pub fn idf(&self, doc_freq: u64) -> f32 {
        let n = self.total_docs as f32;
        let df = doc_freq.min(self.total_docs) as f32;
        ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
    }

    /// BM25 contribution of one term to one document.
    ///
    /// * `tf` - frequency of the term in the document.
    /// * `doc_len` - length of the document.
    /// * `idf` - inverse document frequency of the term (see [`Bm25Params::idf`]).
    ///
    /// Degenerate statistics never yield `NaN`: when `avgdl` is not a positive
    /// number the document is treated as having average length, and when the
    /// denominator is not positive the contribution is `0.0`.
    pub fn term_score(&self, tf: f32, doc_len: f32, idf: f32) -> f32 {
        // `b * doc_len / avgdl`, falling back to a neutral ratio of 1 so an
        // empty or uninitialised corpus (`avgdl == 0`) cannot produce 0/0.
        let length_norm = if self.avgdl > 0.0 {
            self.b * doc_len / self.avgdl
        } else {
            self.b
        };

        let numerator = tf * (self.k1 + 1.0);
        let denominator = tf + self.k1 * (1.0 - self.b + length_norm);

        if denominator > 0.0 {
            idf * numerator / denominator
        } else {
            0.0
        }
    }
}
105
106#[derive(Debug, Clone)]
112pub struct PostingList {
113 pub term: String,
115 pub doc_ids: Vec<u64>,
117 pub term_freqs: Vec<u32>,
119 pub doc_freq: u64,
121}
122
123impl PostingList {
124 pub fn new(term: impl Into<String>, entries: Vec<(u64, u32)>) -> Self {
126 let term = term.into();
127 let doc_freq = entries.len() as u64;
128 let mut doc_ids = Vec::with_capacity(entries.len());
129 let mut term_freqs = Vec::with_capacity(entries.len());
130
131 for (doc_id, tf) in entries {
132 doc_ids.push(doc_id);
133 term_freqs.push(tf);
134 }
135
136 Self {
137 term,
138 doc_ids,
139 term_freqs,
140 doc_freq,
141 }
142 }
143
144 pub fn intersect_with_allowed(&self, allowed: &AllowedSet) -> Vec<(u64, u32)> {
146 match allowed {
147 AllowedSet::All => {
148 self.doc_ids.iter()
149 .zip(self.term_freqs.iter())
150 .map(|(&id, &tf)| (id, tf))
151 .collect()
152 }
153 AllowedSet::None => vec![],
154 _ => {
155 self.doc_ids.iter()
156 .zip(self.term_freqs.iter())
157 .filter(|&(&id, _)| allowed.contains(id))
158 .map(|(&id, &tf)| (id, tf))
159 .collect()
160 }
161 }
162 }
163}
164
165pub trait InvertedIndex: Send + Sync {
171 fn get_posting_list(&self, term: &str) -> Option<PostingList>;
173
174 fn get_doc_length(&self, doc_id: u64) -> Option<u32>;
176
177 fn get_params(&self) -> &Bm25Params;
179}
180
181pub struct FilteredBm25Executor<I: InvertedIndex> {
187 index: Arc<I>,
188}
189
190impl<I: InvertedIndex> FilteredBm25Executor<I> {
191 pub fn new(index: Arc<I>) -> Self {
193 Self { index }
194 }
195
196 pub fn search(
207 &self,
208 query: &str,
209 k: usize,
210 allowed: &AllowedSet,
211 ) -> Vec<ScoredResult> {
212 if allowed.is_empty() {
214 return vec![];
215 }
216
217 let terms: Vec<&str> = query
219 .split_whitespace()
220 .filter(|t| t.len() >= 2) .collect();
222
223 if terms.is_empty() {
224 return vec![];
225 }
226
227 let mut posting_lists: Vec<PostingList> = terms
229 .iter()
230 .filter_map(|t| self.index.get_posting_list(t))
231 .collect();
232
233 posting_lists.sort_by_key(|pl| pl.doc_freq);
235
236 let candidates = self.progressive_intersection(&posting_lists, allowed);
238
239 if candidates.is_empty() {
240 return vec![];
241 }
242
243 let params = self.index.get_params();
245 let scores = self.score_candidates(&candidates, &posting_lists, params);
246
247 self.top_k(scores, k)
249 }
250
251 fn progressive_intersection(
255 &self,
256 posting_lists: &[PostingList],
257 allowed: &AllowedSet,
258 ) -> HashMap<u64, Vec<u32>> {
259 if posting_lists.is_empty() {
260 return HashMap::new();
261 }
262
263 let first = &posting_lists[0];
265 let mut candidates: HashMap<u64, Vec<u32>> = first
266 .intersect_with_allowed(allowed)
267 .into_iter()
268 .map(|(id, tf)| (id, vec![tf]))
269 .collect();
270
271 for (_term_idx, posting_list) in posting_lists.iter().enumerate().skip(1) {
273 let term_postings: HashMap<u64, u32> = posting_list
275 .doc_ids.iter()
276 .zip(posting_list.term_freqs.iter())
277 .map(|(&id, &tf)| (id, tf))
278 .collect();
279
280 candidates.retain(|doc_id, tfs| {
282 if let Some(&tf) = term_postings.get(doc_id) {
283 tfs.push(tf);
284 true
285 } else {
286 false
287 }
288 });
289
290 if candidates.is_empty() {
292 break;
293 }
294 }
295
296 candidates
297 }
298
299 fn score_candidates(
301 &self,
302 candidates: &HashMap<u64, Vec<u32>>,
303 posting_lists: &[PostingList],
304 params: &Bm25Params,
305 ) -> Vec<ScoredResult> {
306 let idfs: Vec<f32> = posting_lists
308 .iter()
309 .map(|pl| params.idf(pl.doc_freq))
310 .collect();
311
312 candidates
313 .iter()
314 .filter_map(|(&doc_id, tfs)| {
315 let doc_len = self.index.get_doc_length(doc_id)? as f32;
316
317 let score: f32 = tfs.iter()
318 .zip(idfs.iter())
319 .map(|(&tf, &idf)| params.term_score(tf as f32, doc_len, idf))
320 .sum();
321
322 Some(ScoredResult::new(doc_id, score))
323 })
324 .collect()
325 }
326
327 fn top_k(&self, mut scores: Vec<ScoredResult>, k: usize) -> Vec<ScoredResult> {
329 scores.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
331 scores.truncate(k);
332 scores
333 }
334}
335
336pub struct DisjunctiveBm25Executor<I: InvertedIndex> {
342 index: Arc<I>,
343}
344
345impl<I: InvertedIndex> DisjunctiveBm25Executor<I> {
346 pub fn new(index: Arc<I>) -> Self {
348 Self { index }
349 }
350
351 pub fn search(
353 &self,
354 query: &str,
355 k: usize,
356 allowed: &AllowedSet,
357 ) -> Vec<ScoredResult> {
358 if allowed.is_empty() {
359 return vec![];
360 }
361
362 let terms: Vec<&str> = query.split_whitespace().collect();
363 if terms.is_empty() {
364 return vec![];
365 }
366
367 let posting_lists: Vec<PostingList> = terms
369 .iter()
370 .filter_map(|t| self.index.get_posting_list(t))
371 .collect();
372
373 let params = self.index.get_params();
374
375 let mut scores: HashMap<u64, f32> = HashMap::new();
377
378 for posting_list in &posting_lists {
379 let idf = params.idf(posting_list.doc_freq);
380
381 for (&doc_id, &tf) in posting_list.doc_ids.iter().zip(posting_list.term_freqs.iter()) {
383 if !allowed.contains(doc_id) {
384 continue;
385 }
386
387 if let Some(doc_len) = self.index.get_doc_length(doc_id) {
388 let term_score = params.term_score(tf as f32, doc_len as f32, idf);
389 *scores.entry(doc_id).or_insert(0.0) += term_score;
390 }
391 }
392 }
393
394 let mut results: Vec<ScoredResult> = scores
396 .into_iter()
397 .map(|(id, score)| ScoredResult::new(id, score))
398 .collect();
399
400 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
401 results.truncate(k);
402 results
403 }
404}
405
/// Occurrences of one term inside one document.
#[derive(Debug, Clone)]
pub struct PositionalPosting {
    /// Document the positions refer to.
    pub doc_id: u64,
    /// Positions of the term within the document.
    ///
    /// `FilteredPhraseExecutor` binary-searches this list, so it has to be
    /// sorted in ascending order for phrase matching to work.
    pub positions: Vec<u32>,
}
418
419pub trait PositionalIndex: InvertedIndex {
421 fn get_positional_posting(&self, term: &str) -> Option<Vec<PositionalPosting>>;
423}
424
425pub struct FilteredPhraseExecutor<I: PositionalIndex> {
427 index: Arc<I>,
428}
429
430impl<I: PositionalIndex> FilteredPhraseExecutor<I> {
431 pub fn new(index: Arc<I>) -> Self {
433 Self { index }
434 }
435
436 pub fn search(
440 &self,
441 phrase: &[&str],
442 k: usize,
443 allowed: &AllowedSet,
444 ) -> Vec<ScoredResult> {
445 if phrase.is_empty() || allowed.is_empty() {
446 return vec![];
447 }
448
449 let mut positional_postings: Vec<Vec<PositionalPosting>> = vec![];
451 for term in phrase {
452 match self.index.get_positional_posting(term) {
453 Some(postings) => positional_postings.push(postings),
454 None => return vec![], }
456 }
457
458 let candidates = self.find_phrase_matches(&positional_postings, allowed);
460
461 let params = self.index.get_params();
463 let results: Vec<ScoredResult> = candidates
464 .into_iter()
465 .filter_map(|(doc_id, phrase_freq)| {
466 let doc_len = self.index.get_doc_length(doc_id)? as f32;
467 let min_df = positional_postings.iter()
469 .map(|pp| pp.len() as u64)
470 .min()
471 .unwrap_or(1);
472 let idf = params.idf(min_df);
473 let score = params.term_score(phrase_freq as f32, doc_len, idf);
474 Some(ScoredResult::new(doc_id, score))
475 })
476 .collect();
477
478 let mut results = results;
479 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
480 results.truncate(k);
481 results
482 }
483
484 fn find_phrase_matches(
486 &self,
487 positional_postings: &[Vec<PositionalPosting>],
488 allowed: &AllowedSet,
489 ) -> Vec<(u64, u32)> {
490 if positional_postings.is_empty() {
491 return vec![];
492 }
493
494 let indexed: Vec<HashMap<u64, &Vec<u32>>> = positional_postings
496 .iter()
497 .map(|postings| {
498 postings.iter()
499 .filter(|p| allowed.contains(p.doc_id))
500 .map(|p| (p.doc_id, &p.positions))
501 .collect()
502 })
503 .collect();
504
505 let first_docs: std::collections::HashSet<u64> = indexed[0].keys().copied().collect();
507
508 let candidate_docs: Vec<u64> = first_docs
510 .into_iter()
511 .filter(|doc_id| indexed.iter().all(|idx| idx.contains_key(doc_id)))
512 .collect();
513
514 let mut matches = vec![];
516
517 for doc_id in candidate_docs {
518 let mut phrase_count = 0u32;
519
520 let first_positions = indexed[0].get(&doc_id).unwrap();
522
523 'outer: for &start_pos in first_positions.iter() {
524 for (term_idx, term_positions) in indexed.iter().enumerate().skip(1) {
526 let expected_pos = start_pos + term_idx as u32;
527 let positions = term_positions.get(&doc_id).unwrap();
528
529 if positions.binary_search(&expected_pos).is_err() {
531 continue 'outer;
532 }
533 }
534
535 phrase_count += 1;
537 }
538
539 if phrase_count > 0 {
540 matches.push((doc_id, phrase_count));
541 }
542 }
543
544 matches
545 }
546}
547
#[cfg(test)]
mod tests {
    use super::*;
    use crate::candidate_gate::AllowedSet;

    /// In-memory index over five documents (ids 1..=5, each of length 100).
    struct MockIndex {
        postings: HashMap<String, PostingList>,
        positional: HashMap<String, Vec<PositionalPosting>>,
        doc_lengths: HashMap<u64, u32>,
        params: Bm25Params,
    }

    impl MockIndex {
        fn new() -> Self {
            let mut postings = HashMap::new();

            // (doc_id, term frequency)
            postings.insert(
                "rust".to_string(),
                PostingList::new("rust", vec![(1, 3), (2, 1), (3, 2), (5, 1)]),
            );
            postings.insert(
                "database".to_string(),
                PostingList::new("database", vec![(1, 1), (3, 4), (4, 1)]),
            );
            postings.insert(
                "vector".to_string(),
                PostingList::new("vector", vec![(1, 2), (2, 3), (4, 1), (5, 2)]),
            );

            // Only doc 1 has "vector" immediately followed by "database".
            let mut positional = HashMap::new();
            positional.insert(
                "vector".to_string(),
                vec![
                    PositionalPosting { doc_id: 1, positions: vec![0, 5] },
                    PositionalPosting { doc_id: 4, positions: vec![3] },
                ],
            );
            positional.insert(
                "database".to_string(),
                vec![
                    PositionalPosting { doc_id: 1, positions: vec![1] },
                    PositionalPosting { doc_id: 4, positions: vec![7] },
                ],
            );

            let doc_lengths = (1..=5).map(|doc_id| (doc_id, 100)).collect();

            Self {
                postings,
                positional,
                doc_lengths,
                params: Bm25Params {
                    k1: 1.2,
                    b: 0.75,
                    avgdl: 100.0,
                    total_docs: 1000,
                },
            }
        }
    }

    impl InvertedIndex for MockIndex {
        fn get_posting_list(&self, term: &str) -> Option<PostingList> {
            self.postings.get(term).cloned()
        }

        fn get_doc_length(&self, doc_id: u64) -> Option<u32> {
            self.doc_lengths.get(&doc_id).copied()
        }

        fn get_params(&self) -> &Bm25Params {
            &self.params
        }
    }

    impl PositionalIndex for MockIndex {
        fn get_positional_posting(&self, term: &str) -> Option<Vec<PositionalPosting>> {
            self.positional.get(term).cloned()
        }
    }

    #[test]
    fn test_conjunctive_search() {
        let index = Arc::new(MockIndex::new());
        let executor = FilteredBm25Executor::new(index);

        let results = executor.search("rust database", 10, &AllowedSet::All);

        // Only docs 1 and 3 contain both terms.
        assert_eq!(results.len(), 2);
        let doc_ids: Vec<u64> = results.iter().map(|r| r.doc_id).collect();
        assert!(doc_ids.contains(&1));
        assert!(doc_ids.contains(&3));

        // Best score first.
        assert!(results[0].score >= results[1].score);
    }

    #[test]
    fn test_filter_pushdown() {
        let index = Arc::new(MockIndex::new());
        let executor = FilteredBm25Executor::new(index);

        let allowed = AllowedSet::SortedVec(Arc::new(vec![1]));

        let results = executor.search("rust database", 10, &allowed);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].doc_id, 1);
    }

    #[test]
    fn test_empty_allowed_set() {
        let index = Arc::new(MockIndex::new());
        let executor = FilteredBm25Executor::new(index);

        let results = executor.search("rust", 10, &AllowedSet::None);
        assert!(results.is_empty());
    }

    #[test]
    fn test_zero_k_returns_nothing() {
        let index = Arc::new(MockIndex::new());
        let executor = FilteredBm25Executor::new(index);

        let results = executor.search("rust", 0, &AllowedSet::All);
        assert!(results.is_empty());
    }

    #[test]
    fn test_disjunctive_search() {
        let index = Arc::new(MockIndex::new());
        let executor = DisjunctiveBm25Executor::new(index);

        let results = executor.search("rust database", 10, &AllowedSet::All);

        // rust: {1, 2, 3, 5}, database: {1, 3, 4} -> union is all five docs.
        let mut doc_ids: Vec<u64> = results.iter().map(|r| r.doc_id).collect();
        doc_ids.sort_unstable();
        assert_eq!(doc_ids, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_phrase_search_requires_adjacent_terms() {
        let index = Arc::new(MockIndex::new());
        let executor = FilteredPhraseExecutor::new(index);

        // Doc 4 contains both terms, but not next to each other.
        let results = executor.search(&["vector", "database"], 10, &AllowedSet::All);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].doc_id, 1);

        // A term without positional postings empties the result.
        let results = executor.search(&["vector", "missing"], 10, &AllowedSet::All);
        assert!(results.is_empty());
    }

    #[test]
    fn test_term_ordering_by_df() {
        let rare = PostingList::new("rare", vec![(1, 1), (2, 1)]);
        let common = PostingList::new("common", vec![(1, 1), (2, 1), (3, 1), (4, 1), (5, 1)]);

        let mut lists = vec![common, rare];
        lists.sort_by_key(|pl| pl.doc_freq);

        assert_eq!(lists[0].term, "rare");
        assert_eq!(lists[1].term, "common");
    }

    #[test]
    fn test_bm25_scoring() {
        let params = Bm25Params::default();

        // Rare terms carry more weight than common ones.
        let idf_rare = params.idf(10);
        let idf_common = params.idf(100_000);
        assert!(idf_rare > idf_common);

        // Higher term frequency gives a higher score.
        let score_tf_1 = params.term_score(1.0, 100.0, idf_rare);
        let score_tf_5 = params.term_score(5.0, 100.0, idf_rare);
        assert!(score_tf_5 > score_tf_1);
    }
}