1use std::collections::HashMap;
48use std::sync::Arc;
49
50use crate::candidate_gate::AllowedSet;
51use crate::filter_ir::{AuthScope, FilterIR};
52use crate::filtered_vector_search::ScoredResult;
53use crate::namespace::NamespaceScope;
54
55#[derive(Debug, Clone, Copy, PartialEq)]
61pub enum FusionMethod {
62 Rrf { k: f32 },
64
65 Linear { vector_weight: f32, bm25_weight: f32 },
67
68 Max,
70
71 Cascade { primary: Modality },
73}
74
75#[derive(Debug, Clone, Copy, PartialEq)]
77pub enum Modality {
78 Vector,
79 Bm25,
80}
81
82impl Default for FusionMethod {
83 fn default() -> Self {
84 Self::Rrf { k: 60.0 }
85 }
86}
87
88#[derive(Debug, Clone)]
90pub struct FusionConfig {
91 pub method: FusionMethod,
93
94 pub candidates_per_modality: usize,
96
97 pub final_k: usize,
99
100 pub min_score: Option<f32>,
102}
103
104impl Default for FusionConfig {
105 fn default() -> Self {
106 Self {
107 method: FusionMethod::default(),
108 candidates_per_modality: 100,
109 final_k: 10,
110 min_score: None,
111 }
112 }
113}
114
115#[derive(Debug, Clone)]
121pub struct UnifiedHybridQuery {
122 pub namespace: NamespaceScope,
124
125 pub vector_query: Option<VectorQuerySpec>,
127
128 pub bm25_query: Option<Bm25QuerySpec>,
130
131 pub filter: FilterIR,
133
134 pub fusion_config: FusionConfig,
136}
137
138#[derive(Debug, Clone)]
140pub struct VectorQuerySpec {
141 pub embedding: Vec<f32>,
143 pub ef_search: usize,
145}
146
147#[derive(Debug, Clone)]
149pub struct Bm25QuerySpec {
150 pub text: String,
152 pub fields: Vec<String>,
154}
155
156impl UnifiedHybridQuery {
157 pub fn new(namespace: NamespaceScope) -> Self {
159 Self {
160 namespace,
161 vector_query: None,
162 bm25_query: None,
163 filter: FilterIR::all(),
164 fusion_config: FusionConfig::default(),
165 }
166 }
167
168 pub fn with_vector(mut self, embedding: Vec<f32>) -> Self {
170 self.vector_query = Some(VectorQuerySpec {
171 embedding,
172 ef_search: 100,
173 });
174 self
175 }
176
177 pub fn with_bm25(mut self, text: impl Into<String>) -> Self {
179 self.bm25_query = Some(Bm25QuerySpec {
180 text: text.into(),
181 fields: vec!["content".to_string()],
182 });
183 self
184 }
185
186 pub fn with_filter(mut self, filter: FilterIR) -> Self {
188 self.filter = filter;
189 self
190 }
191
192 pub fn with_fusion(mut self, config: FusionConfig) -> Self {
194 self.fusion_config = config;
195 self
196 }
197
198 pub fn effective_filter(&self) -> FilterIR {
202 self.namespace.to_filter_ir().and(self.filter.clone())
203 }
204}
205
206#[derive(Debug)]
212pub struct FilteredCandidates {
213 pub modality: Modality,
215 pub results: Vec<ScoredResult>,
217 pub filtered: bool,
219}
220
221impl FilteredCandidates {
222 pub fn from_vector(results: Vec<ScoredResult>) -> Self {
224 Self {
225 modality: Modality::Vector,
226 results,
227 filtered: true,
228 }
229 }
230
231 pub fn from_bm25(results: Vec<ScoredResult>) -> Self {
233 Self {
234 modality: Modality::Bm25,
235 results,
236 filtered: true,
237 }
238 }
239}
240
241pub struct FusionEngine {
247 config: FusionConfig,
248}
249
250impl FusionEngine {
251 pub fn new(config: FusionConfig) -> Self {
253 Self { config }
254 }
255
256 pub fn fuse(
261 &self,
262 vector_candidates: Option<FilteredCandidates>,
263 bm25_candidates: Option<FilteredCandidates>,
264 ) -> FusionResult {
265 if let Some(ref vc) = vector_candidates {
267 debug_assert!(vc.filtered, "Vector candidates must be pre-filtered!");
268 }
269 if let Some(ref bc) = bm25_candidates {
270 debug_assert!(bc.filtered, "BM25 candidates must be pre-filtered!");
271 }
272
273 match self.config.method {
274 FusionMethod::Rrf { k } => self.fuse_rrf(vector_candidates, bm25_candidates, k),
275 FusionMethod::Linear { vector_weight, bm25_weight } => {
276 self.fuse_linear(vector_candidates, bm25_candidates, vector_weight, bm25_weight)
277 }
278 FusionMethod::Max => self.fuse_max(vector_candidates, bm25_candidates),
279 FusionMethod::Cascade { primary } => {
280 self.fuse_cascade(vector_candidates, bm25_candidates, primary)
281 }
282 }
283 }
284
285 fn fuse_rrf(
289 &self,
290 vector: Option<FilteredCandidates>,
291 bm25: Option<FilteredCandidates>,
292 k: f32,
293 ) -> FusionResult {
294 let mut scores: HashMap<u64, f32> = HashMap::new();
295
296 if let Some(vc) = vector {
298 for (rank, result) in vc.results.iter().enumerate() {
299 let rrf_score = 1.0 / (k + rank as f32 + 1.0);
300 *scores.entry(result.doc_id).or_insert(0.0) += rrf_score;
301 }
302 }
303
304 if let Some(bc) = bm25 {
306 for (rank, result) in bc.results.iter().enumerate() {
307 let rrf_score = 1.0 / (k + rank as f32 + 1.0);
308 *scores.entry(result.doc_id).or_insert(0.0) += rrf_score;
309 }
310 }
311
312 self.collect_top_k(scores)
313 }
314
315 fn fuse_linear(
317 &self,
318 vector: Option<FilteredCandidates>,
319 bm25: Option<FilteredCandidates>,
320 vector_weight: f32,
321 bm25_weight: f32,
322 ) -> FusionResult {
323 let mut scores: HashMap<u64, f32> = HashMap::new();
324
325 if let Some(vc) = vector {
327 let normalized = self.normalize_scores(&vc.results);
328 for (doc_id, score) in normalized {
329 *scores.entry(doc_id).or_insert(0.0) += score * vector_weight;
330 }
331 }
332
333 if let Some(bc) = bm25 {
335 let normalized = self.normalize_scores(&bc.results);
336 for (doc_id, score) in normalized {
337 *scores.entry(doc_id).or_insert(0.0) += score * bm25_weight;
338 }
339 }
340
341 self.collect_top_k(scores)
342 }
343
344 fn fuse_max(
346 &self,
347 vector: Option<FilteredCandidates>,
348 bm25: Option<FilteredCandidates>,
349 ) -> FusionResult {
350 let mut scores: HashMap<u64, f32> = HashMap::new();
351
352 if let Some(vc) = vector {
353 let normalized = self.normalize_scores(&vc.results);
354 for (doc_id, score) in normalized {
355 let entry = scores.entry(doc_id).or_insert(0.0);
356 *entry = entry.max(score);
357 }
358 }
359
360 if let Some(bc) = bm25 {
361 let normalized = self.normalize_scores(&bc.results);
362 for (doc_id, score) in normalized {
363 let entry = scores.entry(doc_id).or_insert(0.0);
364 *entry = entry.max(score);
365 }
366 }
367
368 self.collect_top_k(scores)
369 }
370
371 fn fuse_cascade(
373 &self,
374 vector: Option<FilteredCandidates>,
375 bm25: Option<FilteredCandidates>,
376 primary: Modality,
377 ) -> FusionResult {
378 let (primary_candidates, secondary_candidates) = match primary {
379 Modality::Vector => (vector, bm25),
380 Modality::Bm25 => (bm25, vector),
381 };
382
383 let primary_ids: std::collections::HashSet<u64> = primary_candidates
385 .as_ref()
386 .map(|c| c.results.iter().map(|r| r.doc_id).collect())
387 .unwrap_or_default();
388
389 let mut scores: HashMap<u64, f32> = HashMap::new();
391
392 if let Some(sc) = secondary_candidates {
393 for result in &sc.results {
394 if primary_ids.contains(&result.doc_id) {
395 scores.insert(result.doc_id, result.score);
396 }
397 }
398 }
399
400 if let Some(pc) = primary_candidates {
402 for (rank, result) in pc.results.iter().enumerate() {
403 scores.entry(result.doc_id).or_insert(-(rank as f32));
404 }
405 }
406
407 self.collect_top_k(scores)
408 }
409
410 fn normalize_scores(&self, results: &[ScoredResult]) -> Vec<(u64, f32)> {
412 if results.is_empty() {
413 return vec![];
414 }
415
416 let min = results.iter().map(|r| r.score).fold(f32::INFINITY, f32::min);
417 let max = results.iter().map(|r| r.score).fold(f32::NEG_INFINITY, f32::max);
418 let range = max - min;
419
420 if range == 0.0 {
421 return results.iter().map(|r| (r.doc_id, 1.0)).collect();
422 }
423
424 results.iter()
425 .map(|r| (r.doc_id, (r.score - min) / range))
426 .collect()
427 }
428
429 fn collect_top_k(&self, scores: HashMap<u64, f32>) -> FusionResult {
431 let mut results: Vec<ScoredResult> = scores
432 .into_iter()
433 .map(|(doc_id, score)| ScoredResult::new(doc_id, score))
434 .collect();
435
436 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
438
439 if let Some(min) = self.config.min_score {
441 results.retain(|r| r.score >= min);
442 }
443
444 results.truncate(self.config.final_k);
446
447 FusionResult {
448 results,
449 method: self.config.method,
450 }
451 }
452}
453
454#[derive(Debug)]
456pub struct FusionResult {
457 pub results: Vec<ScoredResult>,
459 pub method: FusionMethod,
461}
462
463pub trait VectorExecutor {
469 fn search(&self, query: &[f32], k: usize, allowed: &AllowedSet) -> Vec<ScoredResult>;
470}
471
472pub trait Bm25Executor {
474 fn search(&self, query: &str, k: usize, allowed: &AllowedSet) -> Vec<ScoredResult>;
475}
476
477pub struct UnifiedHybridExecutor<V: VectorExecutor, B: Bm25Executor> {
481 vector_executor: Arc<V>,
482 bm25_executor: Arc<B>,
483 fusion_engine: FusionEngine,
484}
485
486impl<V: VectorExecutor, B: Bm25Executor> UnifiedHybridExecutor<V, B> {
487 pub fn new(
489 vector_executor: Arc<V>,
490 bm25_executor: Arc<B>,
491 fusion_config: FusionConfig,
492 ) -> Self {
493 Self {
494 vector_executor,
495 bm25_executor,
496 fusion_engine: FusionEngine::new(fusion_config),
497 }
498 }
499
500 pub fn execute(
511 &self,
512 query: &UnifiedHybridQuery,
513 _auth_scope: &AuthScope,
514 allowed_set: &AllowedSet, ) -> FusionResult {
516 if allowed_set.is_empty() {
518 return FusionResult {
519 results: vec![],
520 method: self.fusion_engine.config.method,
521 };
522 }
523
524 let k = self.fusion_engine.config.candidates_per_modality;
525
526 let vector_candidates = query.vector_query.as_ref().map(|vq| {
528 let results = self.vector_executor.search(&vq.embedding, k, allowed_set);
529 FilteredCandidates::from_vector(results)
530 });
531
532 let bm25_candidates = query.bm25_query.as_ref().map(|bq| {
534 let results = self.bm25_executor.search(&bq.text, k, allowed_set);
535 FilteredCandidates::from_bm25(results)
536 });
537
538 self.fusion_engine.fuse(vector_candidates, bm25_candidates)
540 }
541}
542
543#[cfg(test)]
548mod tests {
549 use super::*;
550
551 #[test]
552 fn test_rrf_fusion() {
553 let config = FusionConfig {
554 method: FusionMethod::Rrf { k: 60.0 },
555 candidates_per_modality: 10,
556 final_k: 5,
557 min_score: None,
558 };
559
560 let engine = FusionEngine::new(config);
561
562 let vector = FilteredCandidates::from_vector(vec![
563 ScoredResult::new(1, 0.9),
564 ScoredResult::new(2, 0.8),
565 ScoredResult::new(3, 0.7),
566 ]);
567
568 let bm25 = FilteredCandidates::from_bm25(vec![
569 ScoredResult::new(2, 5.0), ScoredResult::new(4, 4.0),
571 ScoredResult::new(1, 3.0), ]);
573
574 let result = engine.fuse(Some(vector), Some(bm25));
575
576 assert!(!result.results.is_empty());
579
580 let top_ids: Vec<u64> = result.results.iter().map(|r| r.doc_id).collect();
582 assert!(top_ids.contains(&1));
583 assert!(top_ids.contains(&2));
584 }
585
586 #[test]
587 fn test_linear_fusion() {
588 let config = FusionConfig {
589 method: FusionMethod::Linear {
590 vector_weight: 0.6,
591 bm25_weight: 0.4
592 },
593 candidates_per_modality: 10,
594 final_k: 5,
595 min_score: None,
596 };
597
598 let engine = FusionEngine::new(config);
599
600 let vector = FilteredCandidates::from_vector(vec![
601 ScoredResult::new(1, 1.0),
602 ScoredResult::new(2, 0.5),
603 ]);
604
605 let bm25 = FilteredCandidates::from_bm25(vec![
606 ScoredResult::new(2, 10.0), ScoredResult::new(3, 5.0),
608 ]);
609
610 let result = engine.fuse(Some(vector), Some(bm25));
611
612 assert!(!result.results.is_empty());
614 }
615
616 #[test]
617 fn test_empty_allowed_set() {
618 let config = FusionConfig::default();
619 let engine = FusionEngine::new(config);
620
621 let result = engine.fuse(None, None);
623 assert!(result.results.is_empty());
624 }
625
626 #[test]
627 fn test_score_normalization() {
628 let config = FusionConfig::default();
629 let engine = FusionEngine::new(config);
630
631 let results = vec![
632 ScoredResult::new(1, 100.0),
633 ScoredResult::new(2, 50.0),
634 ScoredResult::new(3, 0.0),
635 ];
636
637 let normalized = engine.normalize_scores(&results);
638
639 assert_eq!(normalized.len(), 3);
641 let scores: HashMap<u64, f32> = normalized.into_iter().collect();
642 assert!((scores[&1] - 1.0).abs() < 0.001);
643 assert!((scores[&2] - 0.5).abs() < 0.001);
644 assert!((scores[&3] - 0.0).abs() < 0.001);
645 }
646
647 #[test]
648 fn test_no_post_filter_invariant() {
649 let allowed: std::collections::HashSet<u64> = [1, 2, 3, 5, 8].into_iter().collect();
655 let allowed_set = AllowedSet::from_iter(allowed.iter().copied());
656
657 let vector = FilteredCandidates::from_vector(vec![
659 ScoredResult::new(1, 0.9), ScoredResult::new(2, 0.8), ScoredResult::new(5, 0.7), ]);
663
664 let bm25 = FilteredCandidates::from_bm25(vec![
665 ScoredResult::new(2, 5.0), ScoredResult::new(3, 4.0), ScoredResult::new(8, 3.0), ]);
669
670 let config = FusionConfig::default();
671 let engine = FusionEngine::new(config);
672 let result = engine.fuse(Some(vector), Some(bm25));
673
674 for doc in &result.results {
676 assert!(
677 allowed_set.contains(doc.doc_id),
678 "INVARIANT VIOLATION: doc_id {} not in allowed set",
679 doc.doc_id
680 );
681 }
682 }
683}
684
685pub fn verify_no_post_filter_invariant(
700 result: &FusionResult,
701 allowed_set: &AllowedSet,
702) -> InvariantVerification {
703 let mut violations = Vec::new();
704
705 for doc in &result.results {
706 if !allowed_set.contains(doc.doc_id) {
707 violations.push(doc.doc_id);
708 }
709 }
710
711 if violations.is_empty() {
712 InvariantVerification::Valid
713 } else {
714 InvariantVerification::Violated { doc_ids: violations }
715 }
716}
717
718#[derive(Debug, Clone, PartialEq, Eq)]
720pub enum InvariantVerification {
721 Valid,
723 Violated { doc_ids: Vec<u64> },
725}
726
727impl InvariantVerification {
728 pub fn is_valid(&self) -> bool {
730 matches!(self, Self::Valid)
731 }
732
733 pub fn assert_valid(&self) {
735 match self {
736 Self::Valid => {}
737 Self::Violated { doc_ids } => {
738 panic!(
739 "NO-POST-FILTER INVARIANT VIOLATED: {} docs not in allowed set: {:?}",
740 doc_ids.len(),
741 doc_ids
742 );
743 }
744 }
745 }
746}