1use std::collections::HashMap;
51use std::sync::Arc;
52
53use crate::candidate_gate::AllowedSet;
54use crate::filter_ir::{AuthScope, FilterIR};
55use crate::filtered_vector_search::ScoredResult;
56use crate::namespace::NamespaceScope;
57
58#[derive(Debug, Clone, Copy, PartialEq)]
64pub enum FusionMethod {
65 Rrf { k: f32 },
67
68 Linear { vector_weight: f32, bm25_weight: f32 },
70
71 Max,
73
74 Cascade { primary: Modality },
76}
77
/// Retrieval modality that produced a candidate list.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Modality {
    /// Embedding (vector) search.
    Vector,
    /// BM25 text search.
    Bm25,
}
84
85impl Default for FusionMethod {
86 fn default() -> Self {
87 Self::Rrf { k: 60.0 }
88 }
89}
90
91#[derive(Debug, Clone)]
93pub struct FusionConfig {
94 pub method: FusionMethod,
96
97 pub candidates_per_modality: usize,
99
100 pub final_k: usize,
102
103 pub min_score: Option<f32>,
105}
106
107impl Default for FusionConfig {
108 fn default() -> Self {
109 Self {
110 method: FusionMethod::default(),
111 candidates_per_modality: 100,
112 final_k: 10,
113 min_score: None,
114 }
115 }
116}
117
118#[derive(Debug, Clone)]
124pub struct UnifiedHybridQuery {
125 pub namespace: NamespaceScope,
127
128 pub vector_query: Option<VectorQuerySpec>,
130
131 pub bm25_query: Option<Bm25QuerySpec>,
133
134 pub filter: FilterIR,
136
137 pub fusion_config: FusionConfig,
139}
140
/// Vector half of a hybrid query.
#[derive(Debug, Clone)]
pub struct VectorQuerySpec {
    /// Query embedding handed to the vector executor.
    pub embedding: Vec<f32>,
    /// NOTE(review): presumably an ANN search-breadth knob (the name suggests
    /// HNSW `ef`); the visible code never forwards it to the executor — confirm.
    pub ef_search: usize,
}
149
/// BM25 (text) half of a hybrid query.
#[derive(Debug, Clone)]
pub struct Bm25QuerySpec {
    /// Query text handed to the BM25 executor.
    pub text: String,
    /// Names of the fields to search. The visible code never forwards this to
    /// the executor — NOTE(review): confirm where it is consumed.
    pub fields: Vec<String>,
}
158
159impl UnifiedHybridQuery {
160 pub fn new(namespace: NamespaceScope) -> Self {
162 Self {
163 namespace,
164 vector_query: None,
165 bm25_query: None,
166 filter: FilterIR::all(),
167 fusion_config: FusionConfig::default(),
168 }
169 }
170
171 pub fn with_vector(mut self, embedding: Vec<f32>) -> Self {
173 self.vector_query = Some(VectorQuerySpec {
174 embedding,
175 ef_search: 100,
176 });
177 self
178 }
179
180 pub fn with_bm25(mut self, text: impl Into<String>) -> Self {
182 self.bm25_query = Some(Bm25QuerySpec {
183 text: text.into(),
184 fields: vec!["content".to_string()],
185 });
186 self
187 }
188
189 pub fn with_filter(mut self, filter: FilterIR) -> Self {
191 self.filter = filter;
192 self
193 }
194
195 pub fn with_fusion(mut self, config: FusionConfig) -> Self {
197 self.fusion_config = config;
198 self
199 }
200
201 pub fn effective_filter(&self) -> FilterIR {
205 self.namespace.to_filter_ir().and(self.filter.clone())
206 }
207}
208
209#[derive(Debug)]
215pub struct FilteredCandidates {
216 pub modality: Modality,
218 pub results: Vec<ScoredResult>,
220 pub filtered: bool,
222}
223
224impl FilteredCandidates {
225 pub fn from_vector(results: Vec<ScoredResult>) -> Self {
227 Self {
228 modality: Modality::Vector,
229 results,
230 filtered: true,
231 }
232 }
233
234 pub fn from_bm25(results: Vec<ScoredResult>) -> Self {
236 Self {
237 modality: Modality::Bm25,
238 results,
239 filtered: true,
240 }
241 }
242}
243
244pub struct FusionEngine {
250 config: FusionConfig,
251}
252
253impl FusionEngine {
254 pub fn new(config: FusionConfig) -> Self {
256 Self { config }
257 }
258
259 pub fn fuse(
264 &self,
265 vector_candidates: Option<FilteredCandidates>,
266 bm25_candidates: Option<FilteredCandidates>,
267 ) -> FusionResult {
268 if let Some(ref vc) = vector_candidates {
270 debug_assert!(vc.filtered, "Vector candidates must be pre-filtered!");
271 }
272 if let Some(ref bc) = bm25_candidates {
273 debug_assert!(bc.filtered, "BM25 candidates must be pre-filtered!");
274 }
275
276 match self.config.method {
277 FusionMethod::Rrf { k } => self.fuse_rrf(vector_candidates, bm25_candidates, k),
278 FusionMethod::Linear { vector_weight, bm25_weight } => {
279 self.fuse_linear(vector_candidates, bm25_candidates, vector_weight, bm25_weight)
280 }
281 FusionMethod::Max => self.fuse_max(vector_candidates, bm25_candidates),
282 FusionMethod::Cascade { primary } => {
283 self.fuse_cascade(vector_candidates, bm25_candidates, primary)
284 }
285 }
286 }
287
288 fn fuse_rrf(
292 &self,
293 vector: Option<FilteredCandidates>,
294 bm25: Option<FilteredCandidates>,
295 k: f32,
296 ) -> FusionResult {
297 let mut scores: HashMap<u64, f32> = HashMap::new();
298
299 if let Some(vc) = vector {
301 for (rank, result) in vc.results.iter().enumerate() {
302 let rrf_score = 1.0 / (k + rank as f32 + 1.0);
303 *scores.entry(result.doc_id).or_insert(0.0) += rrf_score;
304 }
305 }
306
307 if let Some(bc) = bm25 {
309 for (rank, result) in bc.results.iter().enumerate() {
310 let rrf_score = 1.0 / (k + rank as f32 + 1.0);
311 *scores.entry(result.doc_id).or_insert(0.0) += rrf_score;
312 }
313 }
314
315 self.collect_top_k(scores)
316 }
317
318 fn fuse_linear(
320 &self,
321 vector: Option<FilteredCandidates>,
322 bm25: Option<FilteredCandidates>,
323 vector_weight: f32,
324 bm25_weight: f32,
325 ) -> FusionResult {
326 let mut scores: HashMap<u64, f32> = HashMap::new();
327
328 if let Some(vc) = vector {
330 let normalized = self.normalize_scores(&vc.results);
331 for (doc_id, score) in normalized {
332 *scores.entry(doc_id).or_insert(0.0) += score * vector_weight;
333 }
334 }
335
336 if let Some(bc) = bm25 {
338 let normalized = self.normalize_scores(&bc.results);
339 for (doc_id, score) in normalized {
340 *scores.entry(doc_id).or_insert(0.0) += score * bm25_weight;
341 }
342 }
343
344 self.collect_top_k(scores)
345 }
346
347 fn fuse_max(
349 &self,
350 vector: Option<FilteredCandidates>,
351 bm25: Option<FilteredCandidates>,
352 ) -> FusionResult {
353 let mut scores: HashMap<u64, f32> = HashMap::new();
354
355 if let Some(vc) = vector {
356 let normalized = self.normalize_scores(&vc.results);
357 for (doc_id, score) in normalized {
358 let entry = scores.entry(doc_id).or_insert(0.0);
359 *entry = entry.max(score);
360 }
361 }
362
363 if let Some(bc) = bm25 {
364 let normalized = self.normalize_scores(&bc.results);
365 for (doc_id, score) in normalized {
366 let entry = scores.entry(doc_id).or_insert(0.0);
367 *entry = entry.max(score);
368 }
369 }
370
371 self.collect_top_k(scores)
372 }
373
374 fn fuse_cascade(
376 &self,
377 vector: Option<FilteredCandidates>,
378 bm25: Option<FilteredCandidates>,
379 primary: Modality,
380 ) -> FusionResult {
381 let (primary_candidates, secondary_candidates) = match primary {
382 Modality::Vector => (vector, bm25),
383 Modality::Bm25 => (bm25, vector),
384 };
385
386 let primary_ids: std::collections::HashSet<u64> = primary_candidates
388 .as_ref()
389 .map(|c| c.results.iter().map(|r| r.doc_id).collect())
390 .unwrap_or_default();
391
392 let mut scores: HashMap<u64, f32> = HashMap::new();
394
395 if let Some(sc) = secondary_candidates {
396 for result in &sc.results {
397 if primary_ids.contains(&result.doc_id) {
398 scores.insert(result.doc_id, result.score);
399 }
400 }
401 }
402
403 if let Some(pc) = primary_candidates {
405 for (rank, result) in pc.results.iter().enumerate() {
406 scores.entry(result.doc_id).or_insert(-(rank as f32));
407 }
408 }
409
410 self.collect_top_k(scores)
411 }
412
413 fn normalize_scores(&self, results: &[ScoredResult]) -> Vec<(u64, f32)> {
415 if results.is_empty() {
416 return vec![];
417 }
418
419 let min = results.iter().map(|r| r.score).fold(f32::INFINITY, f32::min);
420 let max = results.iter().map(|r| r.score).fold(f32::NEG_INFINITY, f32::max);
421 let range = max - min;
422
423 if range == 0.0 {
424 return results.iter().map(|r| (r.doc_id, 1.0)).collect();
425 }
426
427 results.iter()
428 .map(|r| (r.doc_id, (r.score - min) / range))
429 .collect()
430 }
431
432 fn collect_top_k(&self, scores: HashMap<u64, f32>) -> FusionResult {
434 let mut results: Vec<ScoredResult> = scores
435 .into_iter()
436 .map(|(doc_id, score)| ScoredResult::new(doc_id, score))
437 .collect();
438
439 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
441
442 if let Some(min) = self.config.min_score {
444 results.retain(|r| r.score >= min);
445 }
446
447 results.truncate(self.config.final_k);
449
450 FusionResult {
451 results,
452 method: self.config.method,
453 }
454 }
455}
456
457#[derive(Debug)]
459pub struct FusionResult {
460 pub results: Vec<ScoredResult>,
462 pub method: FusionMethod,
464}
465
466pub trait VectorExecutor {
472 fn search(&self, query: &[f32], k: usize, allowed: &AllowedSet) -> Vec<ScoredResult>;
473}
474
475pub trait Bm25Executor {
477 fn search(&self, query: &str, k: usize, allowed: &AllowedSet) -> Vec<ScoredResult>;
478}
479
480pub struct UnifiedHybridExecutor<V: VectorExecutor, B: Bm25Executor> {
484 vector_executor: Arc<V>,
485 bm25_executor: Arc<B>,
486 fusion_engine: FusionEngine,
487}
488
489impl<V: VectorExecutor, B: Bm25Executor> UnifiedHybridExecutor<V, B> {
490 pub fn new(
492 vector_executor: Arc<V>,
493 bm25_executor: Arc<B>,
494 fusion_config: FusionConfig,
495 ) -> Self {
496 Self {
497 vector_executor,
498 bm25_executor,
499 fusion_engine: FusionEngine::new(fusion_config),
500 }
501 }
502
503 pub fn execute(
514 &self,
515 query: &UnifiedHybridQuery,
516 _auth_scope: &AuthScope,
517 allowed_set: &AllowedSet, ) -> FusionResult {
519 if allowed_set.is_empty() {
521 return FusionResult {
522 results: vec![],
523 method: self.fusion_engine.config.method,
524 };
525 }
526
527 let k = self.fusion_engine.config.candidates_per_modality;
528
529 let vector_candidates = query.vector_query.as_ref().map(|vq| {
531 let results = self.vector_executor.search(&vq.embedding, k, allowed_set);
532 FilteredCandidates::from_vector(results)
533 });
534
535 let bm25_candidates = query.bm25_query.as_ref().map(|bq| {
537 let results = self.bm25_executor.search(&bq.text, k, allowed_set);
538 FilteredCandidates::from_bm25(results)
539 });
540
541 self.fusion_engine.fuse(vector_candidates, bm25_candidates)
543 }
544}
545
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an engine with the given method and result limits.
    fn engine(method: FusionMethod, final_k: usize, min_score: Option<f32>) -> FusionEngine {
        FusionEngine::new(FusionConfig {
            method,
            candidates_per_modality: 10,
            final_k,
            min_score,
        })
    }

    /// Candidate lists shared by the RRF tests.
    ///
    /// With `k = 60` the fused scores are:
    /// doc 2 = 1/62 + 1/61, doc 1 = 1/61 + 1/63, doc 4 = 1/62, doc 3 = 1/63.
    fn rrf_inputs() -> (FilteredCandidates, FilteredCandidates) {
        let vector = FilteredCandidates::from_vector(vec![
            ScoredResult::new(1, 0.9),
            ScoredResult::new(2, 0.8),
            ScoredResult::new(3, 0.7),
        ]);
        let bm25 = FilteredCandidates::from_bm25(vec![
            ScoredResult::new(2, 5.0),
            ScoredResult::new(4, 4.0),
            ScoredResult::new(1, 3.0),
        ]);
        (vector, bm25)
    }

    /// Doc ids of a fusion result, in ranked order.
    fn ranked_ids(result: &FusionResult) -> Vec<u64> {
        result.results.iter().map(|r| r.doc_id).collect()
    }

    #[test]
    fn test_rrf_fusion() {
        let engine = engine(FusionMethod::Rrf { k: 60.0 }, 5, None);
        let (vector, bm25) = rrf_inputs();

        let result = engine.fuse(Some(vector), Some(bm25));

        assert!(!result.results.is_empty());
        assert_eq!(result.method, FusionMethod::Rrf { k: 60.0 });

        // Documents found by both modalities outrank single-modality ones.
        let top_ids = ranked_ids(&result);
        assert!(top_ids.contains(&1));
        assert!(top_ids.contains(&2));
        assert_eq!(top_ids, vec![2, 1, 4, 3]);

        assert!((result.results[0].score - (1.0 / 62.0 + 1.0 / 61.0)).abs() < 1e-6);
        assert!((result.results[1].score - (1.0 / 61.0 + 1.0 / 63.0)).abs() < 1e-6);
    }

    #[test]
    fn test_min_score_and_final_k() {
        // `final_k` cuts the ranked list.
        let (vector, bm25) = rrf_inputs();
        let top_one = engine(FusionMethod::Rrf { k: 60.0 }, 1, None).fuse(Some(vector), Some(bm25));
        assert_eq!(ranked_ids(&top_one), vec![2]);

        // `min_score` drops the single-modality documents (scores ~0.016).
        let (vector, bm25) = rrf_inputs();
        let thresholded =
            engine(FusionMethod::Rrf { k: 60.0 }, 5, Some(0.02)).fuse(Some(vector), Some(bm25));
        assert_eq!(ranked_ids(&thresholded), vec![2, 1]);
    }

    #[test]
    fn test_linear_fusion() {
        let engine = engine(
            FusionMethod::Linear {
                vector_weight: 0.6,
                bm25_weight: 0.4,
            },
            5,
            None,
        );

        let vector = FilteredCandidates::from_vector(vec![
            ScoredResult::new(1, 1.0),
            ScoredResult::new(2, 0.5),
        ]);

        let bm25 = FilteredCandidates::from_bm25(vec![
            ScoredResult::new(2, 10.0),
            ScoredResult::new(3, 5.0),
        ]);

        let result = engine.fuse(Some(vector), Some(bm25));

        assert!(!result.results.is_empty());

        // Normalised: vector {1: 1.0, 2: 0.0}, bm25 {2: 1.0, 3: 0.0}.
        assert_eq!(ranked_ids(&result), vec![1, 2, 3]);
        assert!((result.results[0].score - 0.6).abs() < 1e-6);
        assert!((result.results[1].score - 0.4).abs() < 1e-6);
        assert!(result.results[2].score.abs() < 1e-6);
    }

    /// Despite its name this exercises fusion with no candidates from either
    /// modality; the empty-allowed-set short-circuit lives in
    /// `UnifiedHybridExecutor::execute`.
    #[test]
    fn test_empty_allowed_set() {
        let config = FusionConfig::default();
        let engine = FusionEngine::new(config);

        let result = engine.fuse(None, None);
        assert!(result.results.is_empty());
    }

    #[test]
    fn test_score_normalization() {
        let config = FusionConfig::default();
        let engine = FusionEngine::new(config);

        let results = vec![
            ScoredResult::new(1, 100.0),
            ScoredResult::new(2, 50.0),
            ScoredResult::new(3, 0.0),
        ];

        let normalized = engine.normalize_scores(&results);

        assert_eq!(normalized.len(), 3);
        let scores: HashMap<u64, f32> = normalized.into_iter().collect();
        assert!((scores[&1] - 1.0).abs() < 0.001);
        assert!((scores[&2] - 0.5).abs() < 0.001);
        assert!((scores[&3] - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_no_post_filter_invariant() {
        let allowed: std::collections::HashSet<u64> = [1, 2, 3, 5, 8].into_iter().collect();
        let allowed_set = AllowedSet::from_iter(allowed.iter().copied());

        // Every candidate below is a member of the allowed set.
        let vector = FilteredCandidates::from_vector(vec![
            ScoredResult::new(1, 0.9),
            ScoredResult::new(2, 0.8),
            ScoredResult::new(5, 0.7),
        ]);

        let bm25 = FilteredCandidates::from_bm25(vec![
            ScoredResult::new(2, 5.0),
            ScoredResult::new(3, 4.0),
            ScoredResult::new(8, 3.0),
        ]);

        let config = FusionConfig::default();
        let engine = FusionEngine::new(config);
        let result = engine.fuse(Some(vector), Some(bm25));

        for doc in &result.results {
            assert!(
                allowed_set.contains(doc.doc_id),
                "INVARIANT VIOLATION: doc_id {} not in allowed set",
                doc.doc_id
            );
        }

        // The public verification helper must agree with the manual check.
        assert!(verify_no_post_filter_invariant(&result, &allowed_set).is_valid());
    }
}
687
688pub fn verify_no_post_filter_invariant(
703 result: &FusionResult,
704 allowed_set: &AllowedSet,
705) -> InvariantVerification {
706 let mut violations = Vec::new();
707
708 for doc in &result.results {
709 if !allowed_set.contains(doc.doc_id) {
710 violations.push(doc.doc_id);
711 }
712 }
713
714 if violations.is_empty() {
715 InvariantVerification::Valid
716 } else {
717 InvariantVerification::Violated { doc_ids: violations }
718 }
719}
720
/// Outcome of checking a fusion result against an allowed set.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InvariantVerification {
    /// Every result was in the allowed set.
    Valid,
    /// Some results were not in the allowed set; `doc_ids` lists them.
    Violated { doc_ids: Vec<u64> },
}

impl InvariantVerification {
    /// Returns `true` for [`InvariantVerification::Valid`].
    pub fn is_valid(&self) -> bool {
        match self {
            Self::Valid => true,
            Self::Violated { .. } => false,
        }
    }

    /// Does nothing when valid.
    ///
    /// # Panics
    ///
    /// Panics when the invariant was violated, reporting the number of
    /// offending documents and their ids.
    pub fn assert_valid(&self) {
        if let Self::Violated { doc_ids } = self {
            panic!(
                "NO-POST-FILTER INVARIANT VIOLATED: {} docs not in allowed set: {:?}",
                doc_ids.len(),
                doc_ids
            );
        }
    }
}