1use std::collections::HashMap;
60use std::sync::Arc;
61
62use crate::candidate_gate::AllowedSet;
63use crate::filtered_vector_search::ScoredResult;
64
/// Parameters and collection statistics for Okapi BM25 scoring.
#[derive(Debug, Clone)]
pub struct Bm25Params {
    /// Term-frequency saturation parameter; larger values let repeated terms keep adding score longer.
    pub k1: f32,
    /// Document-length normalization strength (conventionally in `[0, 1]`; `0` disables it).
    pub b: f32,
    /// Average document length of the collection, in the same units as the `doc_len` passed to
    /// [`Bm25Params::term_score`].
    pub avgdl: f32,
    /// Number of documents in the collection (the `N` in the IDF formula).
    pub total_docs: u64,
}

impl Default for Bm25Params {
    /// Common BM25 defaults (`k1 = 1.2`, `b = 0.75`). The collection statistics (`avgdl`,
    /// `total_docs`) are fixed placeholders; an index should supply real values.
    fn default() -> Self {
        Self {
            k1: 1.2,
            b: 0.75,
            avgdl: 100.0,
            total_docs: 1_000_000,
        }
    }
}

impl Bm25Params {
    /// Inverse document frequency: `ln(1 + (N - df + 0.5) / (df + 0.5))`.
    ///
    /// The argument of `ln` simplifies to `(N + 1) / (df + 0.5)`, which is always positive, so the
    /// result is finite. It is non-negative whenever `doc_freq <= total_docs`.
    pub fn idf(&self, doc_freq: u64) -> f32 {
        let n = self.total_docs as f32;
        let df = doc_freq as f32;
        ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
    }

    /// BM25 contribution of a single term to a single document.
    ///
    /// * `tf` - occurrences of the term in the document.
    /// * `doc_len` - length of the document (same units as `avgdl`).
    /// * `idf` - the term's IDF, usually from [`Bm25Params::idf`].
    ///
    /// Returns `0.0` for `tf <= 0`. A non-positive (or NaN) `avgdl` disables length
    /// normalization instead of producing NaN/inf from a division by zero.
    pub fn term_score(&self, tf: f32, doc_len: f32, idf: f32) -> f32 {
        // A term that does not occur contributes nothing. Returning early also avoids 0/0 when
        // `k1 == 0`.
        if tf <= 0.0 {
            return 0.0;
        }
        // `b * doc_len / avgdl`; with empty/uninitialized stats, fall back to a neutral ratio of 1.
        let length_term = if self.avgdl > 0.0 {
            self.b * doc_len / self.avgdl
        } else {
            self.b
        };
        let numerator = tf * (self.k1 + 1.0);
        let denominator = tf + self.k1 * (1.0 - self.b + length_term);
        idf * numerator / denominator
    }
}
108
109#[derive(Debug, Clone)]
115pub struct PostingList {
116 pub term: String,
118 pub doc_ids: Vec<u64>,
120 pub term_freqs: Vec<u32>,
122 pub doc_freq: u64,
124}
125
126impl PostingList {
127 pub fn new(term: impl Into<String>, entries: Vec<(u64, u32)>) -> Self {
129 let term = term.into();
130 let doc_freq = entries.len() as u64;
131 let mut doc_ids = Vec::with_capacity(entries.len());
132 let mut term_freqs = Vec::with_capacity(entries.len());
133
134 for (doc_id, tf) in entries {
135 doc_ids.push(doc_id);
136 term_freqs.push(tf);
137 }
138
139 Self {
140 term,
141 doc_ids,
142 term_freqs,
143 doc_freq,
144 }
145 }
146
147 pub fn intersect_with_allowed(&self, allowed: &AllowedSet) -> Vec<(u64, u32)> {
149 match allowed {
150 AllowedSet::All => self
151 .doc_ids
152 .iter()
153 .zip(self.term_freqs.iter())
154 .map(|(&id, &tf)| (id, tf))
155 .collect(),
156 AllowedSet::None => vec![],
157 _ => self
158 .doc_ids
159 .iter()
160 .zip(self.term_freqs.iter())
161 .filter(|&(&id, _)| allowed.contains(id))
162 .map(|(&id, &tf)| (id, tf))
163 .collect(),
164 }
165 }
166}
167
/// Read-only access to an inverted index plus the BM25 parameters used to score it.
///
/// Implementations must be `Send + Sync`; the executors in this module hold them behind an `Arc`.
pub trait InvertedIndex: Send + Sync {
    /// Returns the postings for `term` by value, or `None` if the term is not indexed.
    fn get_posting_list(&self, term: &str) -> Option<PostingList>;

    /// Returns the length of `doc_id` (divided by `Bm25Params::avgdl` during scoring),
    /// or `None` if unknown. The executors skip documents for which this returns `None`.
    fn get_doc_length(&self, doc_id: u64) -> Option<u32>;

    /// BM25 parameters and collection statistics for this index.
    fn get_params(&self) -> &Bm25Params;
}
183
184pub struct FilteredBm25Executor<I: InvertedIndex> {
190 index: Arc<I>,
191}
192
193impl<I: InvertedIndex> FilteredBm25Executor<I> {
194 pub fn new(index: Arc<I>) -> Self {
196 Self { index }
197 }
198
199 pub fn search(&self, query: &str, k: usize, allowed: &AllowedSet) -> Vec<ScoredResult> {
210 if allowed.is_empty() {
212 return vec![];
213 }
214
215 let terms: Vec<&str> = query
217 .split_whitespace()
218 .filter(|t| t.len() >= 2) .collect();
220
221 if terms.is_empty() {
222 return vec![];
223 }
224
225 let mut posting_lists: Vec<PostingList> = terms
227 .iter()
228 .filter_map(|t| self.index.get_posting_list(t))
229 .collect();
230
231 posting_lists.sort_by_key(|pl| pl.doc_freq);
233
234 let candidates = self.progressive_intersection(&posting_lists, allowed);
236
237 if candidates.is_empty() {
238 return vec![];
239 }
240
241 let params = self.index.get_params();
243 let scores = self.score_candidates(&candidates, &posting_lists, params);
244
245 self.top_k(scores, k)
247 }
248
249 fn progressive_intersection(
253 &self,
254 posting_lists: &[PostingList],
255 allowed: &AllowedSet,
256 ) -> HashMap<u64, Vec<u32>> {
257 if posting_lists.is_empty() {
258 return HashMap::new();
259 }
260
261 let first = &posting_lists[0];
263 let mut candidates: HashMap<u64, Vec<u32>> = first
264 .intersect_with_allowed(allowed)
265 .into_iter()
266 .map(|(id, tf)| (id, vec![tf]))
267 .collect();
268
269 for (_term_idx, posting_list) in posting_lists.iter().enumerate().skip(1) {
271 let term_postings: HashMap<u64, u32> = posting_list
273 .doc_ids
274 .iter()
275 .zip(posting_list.term_freqs.iter())
276 .map(|(&id, &tf)| (id, tf))
277 .collect();
278
279 candidates.retain(|doc_id, tfs| {
281 if let Some(&tf) = term_postings.get(doc_id) {
282 tfs.push(tf);
283 true
284 } else {
285 false
286 }
287 });
288
289 if candidates.is_empty() {
291 break;
292 }
293 }
294
295 candidates
296 }
297
298 fn score_candidates(
300 &self,
301 candidates: &HashMap<u64, Vec<u32>>,
302 posting_lists: &[PostingList],
303 params: &Bm25Params,
304 ) -> Vec<ScoredResult> {
305 let idfs: Vec<f32> = posting_lists
307 .iter()
308 .map(|pl| params.idf(pl.doc_freq))
309 .collect();
310
311 candidates
312 .iter()
313 .filter_map(|(&doc_id, tfs)| {
314 let doc_len = self.index.get_doc_length(doc_id)? as f32;
315
316 let score: f32 = tfs
317 .iter()
318 .zip(idfs.iter())
319 .map(|(&tf, &idf)| params.term_score(tf as f32, doc_len, idf))
320 .sum();
321
322 Some(ScoredResult::new(doc_id, score))
323 })
324 .collect()
325 }
326
327 fn top_k(&self, mut scores: Vec<ScoredResult>, k: usize) -> Vec<ScoredResult> {
329 scores.sort_by(|a, b| {
331 b.score
332 .partial_cmp(&a.score)
333 .unwrap_or(std::cmp::Ordering::Equal)
334 });
335 scores.truncate(k);
336 scores
337 }
338}
339
340pub struct DisjunctiveBm25Executor<I: InvertedIndex> {
346 index: Arc<I>,
347}
348
349impl<I: InvertedIndex> DisjunctiveBm25Executor<I> {
350 pub fn new(index: Arc<I>) -> Self {
352 Self { index }
353 }
354
355 pub fn search(&self, query: &str, k: usize, allowed: &AllowedSet) -> Vec<ScoredResult> {
357 if allowed.is_empty() {
358 return vec![];
359 }
360
361 let terms: Vec<&str> = query.split_whitespace().collect();
362 if terms.is_empty() {
363 return vec![];
364 }
365
366 let posting_lists: Vec<PostingList> = terms
368 .iter()
369 .filter_map(|t| self.index.get_posting_list(t))
370 .collect();
371
372 let params = self.index.get_params();
373
374 let mut scores: HashMap<u64, f32> = HashMap::new();
376
377 for posting_list in &posting_lists {
378 let idf = params.idf(posting_list.doc_freq);
379
380 for (&doc_id, &tf) in posting_list
382 .doc_ids
383 .iter()
384 .zip(posting_list.term_freqs.iter())
385 {
386 if !allowed.contains(doc_id) {
387 continue;
388 }
389
390 if let Some(doc_len) = self.index.get_doc_length(doc_id) {
391 let term_score = params.term_score(tf as f32, doc_len as f32, idf);
392 *scores.entry(doc_id).or_insert(0.0) += term_score;
393 }
394 }
395 }
396
397 let mut results: Vec<ScoredResult> = scores
399 .into_iter()
400 .map(|(id, score)| ScoredResult::new(id, score))
401 .collect();
402
403 results.sort_by(|a, b| {
404 b.score
405 .partial_cmp(&a.score)
406 .unwrap_or(std::cmp::Ordering::Equal)
407 });
408 results.truncate(k);
409 results
410 }
411}
412
/// Positions at which one term occurs within a single document.
#[derive(Debug, Clone)]
pub struct PositionalPosting {
    /// Document containing the term.
    pub doc_id: u64,
    /// Token offsets of each occurrence. Phrase matching binary-searches this list, so it is
    /// expected to be sorted ascending. NOTE(review): confirm the index guarantees this.
    pub positions: Vec<u32>,
}
425
/// An [`InvertedIndex`] that also stores term positions, enabling phrase queries.
pub trait PositionalIndex: InvertedIndex {
    /// Returns the positional postings for `term`, or `None` if the term is not indexed.
    /// The phrase executor keys these by `doc_id` and treats their count as the term's document
    /// frequency, so presumably there is at most one entry per document. Verify for implementors.
    fn get_positional_posting(&self, term: &str) -> Option<Vec<PositionalPosting>>;
}
431
432pub struct FilteredPhraseExecutor<I: PositionalIndex> {
434 index: Arc<I>,
435}
436
437impl<I: PositionalIndex> FilteredPhraseExecutor<I> {
438 pub fn new(index: Arc<I>) -> Self {
440 Self { index }
441 }
442
443 pub fn search(&self, phrase: &[&str], k: usize, allowed: &AllowedSet) -> Vec<ScoredResult> {
447 if phrase.is_empty() || allowed.is_empty() {
448 return vec![];
449 }
450
451 let mut positional_postings: Vec<Vec<PositionalPosting>> = vec![];
453 for term in phrase {
454 match self.index.get_positional_posting(term) {
455 Some(postings) => positional_postings.push(postings),
456 None => return vec![], }
458 }
459
460 let candidates = self.find_phrase_matches(&positional_postings, allowed);
462
463 let params = self.index.get_params();
465 let results: Vec<ScoredResult> = candidates
466 .into_iter()
467 .filter_map(|(doc_id, phrase_freq)| {
468 let doc_len = self.index.get_doc_length(doc_id)? as f32;
469 let min_df = positional_postings
471 .iter()
472 .map(|pp| pp.len() as u64)
473 .min()
474 .unwrap_or(1);
475 let idf = params.idf(min_df);
476 let score = params.term_score(phrase_freq as f32, doc_len, idf);
477 Some(ScoredResult::new(doc_id, score))
478 })
479 .collect();
480
481 let mut results = results;
482 results.sort_by(|a, b| {
483 b.score
484 .partial_cmp(&a.score)
485 .unwrap_or(std::cmp::Ordering::Equal)
486 });
487 results.truncate(k);
488 results
489 }
490
491 fn find_phrase_matches(
493 &self,
494 positional_postings: &[Vec<PositionalPosting>],
495 allowed: &AllowedSet,
496 ) -> Vec<(u64, u32)> {
497 if positional_postings.is_empty() {
498 return vec![];
499 }
500
501 let indexed: Vec<HashMap<u64, &Vec<u32>>> = positional_postings
503 .iter()
504 .map(|postings| {
505 postings
506 .iter()
507 .filter(|p| allowed.contains(p.doc_id))
508 .map(|p| (p.doc_id, &p.positions))
509 .collect()
510 })
511 .collect();
512
513 let first_docs: std::collections::HashSet<u64> = indexed[0].keys().copied().collect();
515
516 let candidate_docs: Vec<u64> = first_docs
518 .into_iter()
519 .filter(|doc_id| indexed.iter().all(|idx| idx.contains_key(doc_id)))
520 .collect();
521
522 let mut matches = vec![];
524
525 for doc_id in candidate_docs {
526 let mut phrase_count = 0u32;
527
528 let first_positions = indexed[0].get(&doc_id).unwrap();
530
531 'outer: for &start_pos in first_positions.iter() {
532 for (term_idx, term_positions) in indexed.iter().enumerate().skip(1) {
534 let expected_pos = start_pos + term_idx as u32;
535 let positions = term_positions.get(&doc_id).unwrap();
536
537 if positions.binary_search(&expected_pos).is_err() {
539 continue 'outer;
540 }
541 }
542
543 phrase_count += 1;
545 }
546
547 if phrase_count > 0 {
548 matches.push((doc_id, phrase_count));
549 }
550 }
551
552 matches
553 }
554}
555
#[cfg(test)]
mod tests {
    use super::*;
    use crate::candidate_gate::AllowedSet;

    /// In-memory index: three terms over five documents of length 100.
    struct MockIndex {
        postings: HashMap<String, PostingList>,
        doc_lengths: HashMap<u64, u32>,
        params: Bm25Params,
    }

    impl MockIndex {
        fn new() -> Self {
            let mut postings = HashMap::new();
            let mut doc_lengths = HashMap::new();

            postings.insert(
                "rust".to_string(),
                PostingList::new("rust", vec![(1, 3), (2, 1), (3, 2), (5, 1)]),
            );
            postings.insert(
                "database".to_string(),
                PostingList::new("database", vec![(1, 1), (3, 4), (4, 1)]),
            );
            postings.insert(
                "vector".to_string(),
                PostingList::new("vector", vec![(1, 2), (2, 3), (4, 1), (5, 2)]),
            );

            for i in 1..=5 {
                doc_lengths.insert(i, 100);
            }

            Self {
                postings,
                doc_lengths,
                params: Bm25Params {
                    k1: 1.2,
                    b: 0.75,
                    avgdl: 100.0,
                    total_docs: 1000,
                },
            }
        }
    }

    impl InvertedIndex for MockIndex {
        fn get_posting_list(&self, term: &str) -> Option<PostingList> {
            self.postings.get(term).cloned()
        }

        fn get_doc_length(&self, doc_id: u64) -> Option<u32> {
            self.doc_lengths.get(&doc_id).copied()
        }

        fn get_params(&self) -> &Bm25Params {
            &self.params
        }
    }

    /// [`MockIndex`] plus positional postings, for phrase queries.
    struct MockPositionalIndex {
        base: MockIndex,
        positions: HashMap<String, Vec<PositionalPosting>>,
    }

    impl InvertedIndex for MockPositionalIndex {
        fn get_posting_list(&self, term: &str) -> Option<PostingList> {
            self.base.get_posting_list(term)
        }

        fn get_doc_length(&self, doc_id: u64) -> Option<u32> {
            self.base.get_doc_length(doc_id)
        }

        fn get_params(&self) -> &Bm25Params {
            self.base.get_params()
        }
    }

    impl PositionalIndex for MockPositionalIndex {
        fn get_positional_posting(&self, term: &str) -> Option<Vec<PositionalPosting>> {
            self.positions.get(term).cloned()
        }
    }

    #[test]
    fn test_conjunctive_search() {
        let index = Arc::new(MockIndex::new());
        let executor = FilteredBm25Executor::new(index);

        // Only docs 1 and 3 contain both terms.
        let results = executor.search("rust database", 10, &AllowedSet::All);

        assert_eq!(results.len(), 2);
        let doc_ids: Vec<u64> = results.iter().map(|r| r.doc_id).collect();
        assert!(doc_ids.contains(&1));
        assert!(doc_ids.contains(&3));
    }

    #[test]
    fn test_filter_pushdown() {
        let index = Arc::new(MockIndex::new());
        let executor = FilteredBm25Executor::new(index);

        // Doc 3 also matches, but the filter admits only doc 1.
        let allowed = AllowedSet::SortedVec(Arc::new(vec![1]));

        let results = executor.search("rust database", 10, &allowed);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].doc_id, 1);
    }

    #[test]
    fn test_empty_allowed_set() {
        let index = Arc::new(MockIndex::new());
        let executor = FilteredBm25Executor::new(index);

        let results = executor.search("rust", 10, &AllowedSet::None);
        assert!(results.is_empty());
    }

    #[test]
    fn test_disjunctive_search() {
        let index = Arc::new(MockIndex::new());
        let executor = DisjunctiveBm25Executor::new(index);

        // Union of rust {1, 2, 3, 5} and database {1, 3, 4}.
        let results = executor.search("rust database", 10, &AllowedSet::All);

        assert_eq!(results.len(), 5);
    }

    #[test]
    fn test_intersect_with_allowed() {
        let pl = PostingList::new("t", vec![(1, 2), (3, 4), (5, 6)]);

        assert_eq!(
            pl.intersect_with_allowed(&AllowedSet::All),
            vec![(1, 2), (3, 4), (5, 6)]
        );
        assert!(pl.intersect_with_allowed(&AllowedSet::None).is_empty());

        let allowed = AllowedSet::SortedVec(Arc::new(vec![3, 5]));
        assert_eq!(pl.intersect_with_allowed(&allowed), vec![(3, 4), (5, 6)]);
    }

    #[test]
    fn test_phrase_search() {
        let mut positions = HashMap::new();
        positions.insert(
            "new".to_string(),
            vec![
                PositionalPosting { doc_id: 1, positions: vec![0, 5] },
                PositionalPosting { doc_id: 2, positions: vec![3] },
            ],
        );
        positions.insert(
            "york".to_string(),
            vec![
                PositionalPosting { doc_id: 1, positions: vec![1, 9] },
                PositionalPosting { doc_id: 2, positions: vec![7] },
            ],
        );
        let index = Arc::new(MockPositionalIndex {
            base: MockIndex::new(),
            positions,
        });
        let executor = FilteredPhraseExecutor::new(index);

        // Only doc 1 has "york" immediately after "new" (positions 0 -> 1).
        let results = executor.search(&["new", "york"], 10, &AllowedSet::All);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].doc_id, 1);

        // Filtering out doc 1 leaves no phrase match.
        let allowed = AllowedSet::SortedVec(Arc::new(vec![2]));
        assert!(executor.search(&["new", "york"], 10, &allowed).is_empty());

        // An unindexed phrase term can never match.
        assert!(executor
            .search(&["new", "missing"], 10, &AllowedSet::All)
            .is_empty());
    }

    #[test]
    fn test_term_ordering_by_df() {
        let rare = PostingList::new("rare", vec![(1, 1), (2, 1)]);
        let common = PostingList::new("common", vec![(1, 1), (2, 1), (3, 1), (4, 1), (5, 1)]);
        let mut lists = vec![common, rare];

        lists.sort_by_key(|pl| pl.doc_freq);

        assert_eq!(lists[0].term, "rare");
        assert_eq!(lists[1].term, "common");
    }

    #[test]
    fn test_bm25_scoring() {
        let params = Bm25Params::default();

        // Rarer terms carry more weight.
        let idf_rare = params.idf(10);
        let idf_common = params.idf(100_000);

        assert!(idf_rare > idf_common);

        // Higher term frequency scores higher (with saturation).
        let score_tf_1 = params.term_score(1.0, 100.0, idf_rare);
        let score_tf_5 = params.term_score(5.0, 100.0, idf_rare);

        assert!(score_tf_5 > score_tf_1);
    }
}