Skip to main content

sochdb_query/
unified_fusion.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Unified Hybrid Fusion with Mandatory Pre-Filtering (Task 7)
19//!
20//! This module implements hybrid retrieval (vector + BM25) that **never**
21//! post-filters. The key insight is:
22//!
23//! > Both vector and BM25 executors receive the **same** AllowedSet,
24//! > produce candidates **guaranteed** within it, then fusion merges by doc_id.
25//!
26//! ## Anti-Pattern (What We Avoid)
27//!
28//! ```text
29//! BAD: vector_search() → candidates → filter → too few
30//!      bm25_search() → candidates → filter → inconsistent
31//!      fusion(unfiltered_v, unfiltered_b) → filter at end → broken!
32//! ```
33//!
34//! ## Correct Pattern
35//!
36//! ```text
37//! GOOD: compute AllowedSet from FilterIR
38//!       vector_search(query, allowed_set) → filtered_v
39//!       bm25_search(query, allowed_set) → filtered_b
40//!       fusion(filtered_v, filtered_b) → already correct!
41//! ```
42//!
43//! ## Fusion Cost
44//!
45//! With pre-filtered candidates:
46//! - Fusion is O(k_v + k_b) with hash-join or two-pointer merge
47//! - Total work is proportional to constrained candidate sizes
48//! - No wasted scoring on disallowed documents
49
50use std::collections::HashMap;
51use std::sync::Arc;
52
53use crate::candidate_gate::AllowedSet;
54use crate::filter_ir::{AuthScope, FilterIR};
55use crate::filtered_vector_search::ScoredResult;
56use crate::namespace::NamespaceScope;
57
58// ============================================================================
59// Fusion Configuration
60// ============================================================================
61
62/// Fusion method
63#[derive(Debug, Clone, Copy, PartialEq)]
64pub enum FusionMethod {
65    /// Reciprocal Rank Fusion: score = Σ w_i / (k + rank_i)
66    Rrf { k: f32 },
67    
68    /// Linear combination of normalized scores
69    Linear { vector_weight: f32, bm25_weight: f32 },
70    
71    /// Take max score across modalities
72    Max,
73    
74    /// Cascade: use one modality to filter, other to rank
75    Cascade { primary: Modality },
76}
77
/// Search modality that produced a candidate list.
///
/// The enum is fieldless, so it derives `Eq` and `Hash` in addition to
/// `PartialEq`; this lets it be used as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Modality {
    /// Dense vector (embedding) search.
    Vector,
    /// BM25 text search.
    Bm25,
}
84
85impl Default for FusionMethod {
86    fn default() -> Self {
87        Self::Rrf { k: 60.0 }
88    }
89}
90
91/// Configuration for hybrid fusion
92#[derive(Debug, Clone)]
93pub struct FusionConfig {
94    /// Fusion method
95    pub method: FusionMethod,
96    
97    /// Number of candidates to retrieve from each modality
98    pub candidates_per_modality: usize,
99    
100    /// Final result limit
101    pub final_k: usize,
102    
103    /// Minimum score threshold (after fusion)
104    pub min_score: Option<f32>,
105}
106
107impl Default for FusionConfig {
108    fn default() -> Self {
109        Self {
110            method: FusionMethod::default(),
111            candidates_per_modality: 100,
112            final_k: 10,
113            min_score: None,
114        }
115    }
116}
117
118// ============================================================================
119// Unified Hybrid Query
120// ============================================================================
121
122/// A hybrid query that enforces pre-filtering
123#[derive(Debug, Clone)]
124pub struct UnifiedHybridQuery {
125    /// Namespace scope (mandatory)
126    pub namespace: NamespaceScope,
127    
128    /// Vector query (optional)
129    pub vector_query: Option<VectorQuerySpec>,
130    
131    /// BM25 query (optional)
132    pub bm25_query: Option<Bm25QuerySpec>,
133    
134    /// User-provided filter
135    pub filter: FilterIR,
136    
137    /// Fusion configuration
138    pub fusion_config: FusionConfig,
139}
140
/// Parameters of the vector half of a hybrid query.
#[derive(Clone, Debug)]
pub struct VectorQuerySpec {
    /// Embedding of the query.
    pub embedding: Vec<f32>,
    /// `ef_search` setting for HNSW.
    pub ef_search: usize,
}
149
/// Parameters of the BM25 half of a hybrid query.
#[derive(Clone, Debug)]
pub struct Bm25QuerySpec {
    /// Raw query text (tokenized downstream).
    pub text: String,
    /// Names of the fields to search.
    pub fields: Vec<String>,
}
158
159impl UnifiedHybridQuery {
160    /// Create a new hybrid query (namespace is mandatory)
161    pub fn new(namespace: NamespaceScope) -> Self {
162        Self {
163            namespace,
164            vector_query: None,
165            bm25_query: None,
166            filter: FilterIR::all(),
167            fusion_config: FusionConfig::default(),
168        }
169    }
170    
171    /// Add vector search
172    pub fn with_vector(mut self, embedding: Vec<f32>) -> Self {
173        self.vector_query = Some(VectorQuerySpec {
174            embedding,
175            ef_search: 100,
176        });
177        self
178    }
179    
180    /// Add BM25 search
181    pub fn with_bm25(mut self, text: impl Into<String>) -> Self {
182        self.bm25_query = Some(Bm25QuerySpec {
183            text: text.into(),
184            fields: vec!["content".to_string()],
185        });
186        self
187    }
188    
189    /// Add filter
190    pub fn with_filter(mut self, filter: FilterIR) -> Self {
191        self.filter = filter;
192        self
193    }
194    
195    /// Set fusion config
196    pub fn with_fusion(mut self, config: FusionConfig) -> Self {
197        self.fusion_config = config;
198        self
199    }
200    
201    /// Compute the complete effective filter
202    ///
203    /// This combines namespace scope + user filter. Auth scope is added later.
204    pub fn effective_filter(&self) -> FilterIR {
205        self.namespace.to_filter_ir().and(self.filter.clone())
206    }
207}
208
209// ============================================================================
210// Filtered Candidates
211// ============================================================================
212
213/// Candidates from a single modality (already filtered)
214#[derive(Debug)]
215pub struct FilteredCandidates {
216    /// Modality source
217    pub modality: Modality,
218    /// Scored results (doc_id, score)
219    pub results: Vec<ScoredResult>,
220    /// Whether the allowed set was applied
221    pub filtered: bool,
222}
223
224impl FilteredCandidates {
225    /// Create from vector search results
226    pub fn from_vector(results: Vec<ScoredResult>) -> Self {
227        Self {
228            modality: Modality::Vector,
229            results,
230            filtered: true,
231        }
232    }
233    
234    /// Create from BM25 results
235    pub fn from_bm25(results: Vec<ScoredResult>) -> Self {
236        Self {
237            modality: Modality::Bm25,
238            results,
239            filtered: true,
240        }
241    }
242}
243
244// ============================================================================
245// Fusion Engine
246// ============================================================================
247
248/// The fusion engine that combines candidates from multiple modalities
249pub struct FusionEngine {
250    config: FusionConfig,
251}
252
253impl FusionEngine {
254    /// Create a new fusion engine
255    pub fn new(config: FusionConfig) -> Self {
256        Self { config }
257    }
258    
259    /// Fuse candidates from vector and BM25 search
260    ///
261    /// INVARIANT: Both candidate sets are already filtered to AllowedSet.
262    /// This function does NOT apply any additional filtering.
263    pub fn fuse(
264        &self,
265        vector_candidates: Option<FilteredCandidates>,
266        bm25_candidates: Option<FilteredCandidates>,
267    ) -> FusionResult {
268        // Validate that candidates are pre-filtered
269        if let Some(ref vc) = vector_candidates {
270            debug_assert!(vc.filtered, "Vector candidates must be pre-filtered!");
271        }
272        if let Some(ref bc) = bm25_candidates {
273            debug_assert!(bc.filtered, "BM25 candidates must be pre-filtered!");
274        }
275        
276        match self.config.method {
277            FusionMethod::Rrf { k } => self.fuse_rrf(vector_candidates, bm25_candidates, k),
278            FusionMethod::Linear { vector_weight, bm25_weight } => {
279                self.fuse_linear(vector_candidates, bm25_candidates, vector_weight, bm25_weight)
280            }
281            FusionMethod::Max => self.fuse_max(vector_candidates, bm25_candidates),
282            FusionMethod::Cascade { primary } => {
283                self.fuse_cascade(vector_candidates, bm25_candidates, primary)
284            }
285        }
286    }
287    
288    /// Reciprocal Rank Fusion
289    ///
290    /// score(d) = Σ w_i / (k + rank_i(d))
291    fn fuse_rrf(
292        &self,
293        vector: Option<FilteredCandidates>,
294        bm25: Option<FilteredCandidates>,
295        k: f32,
296    ) -> FusionResult {
297        let mut scores: HashMap<u64, f32> = HashMap::new();
298        
299        // Add vector ranks
300        if let Some(vc) = vector {
301            for (rank, result) in vc.results.iter().enumerate() {
302                let rrf_score = 1.0 / (k + rank as f32 + 1.0);
303                *scores.entry(result.doc_id).or_insert(0.0) += rrf_score;
304            }
305        }
306        
307        // Add BM25 ranks
308        if let Some(bc) = bm25 {
309            for (rank, result) in bc.results.iter().enumerate() {
310                let rrf_score = 1.0 / (k + rank as f32 + 1.0);
311                *scores.entry(result.doc_id).or_insert(0.0) += rrf_score;
312            }
313        }
314        
315        self.collect_top_k(scores)
316    }
317    
318    /// Linear combination fusion
319    fn fuse_linear(
320        &self,
321        vector: Option<FilteredCandidates>,
322        bm25: Option<FilteredCandidates>,
323        vector_weight: f32,
324        bm25_weight: f32,
325    ) -> FusionResult {
326        let mut scores: HashMap<u64, f32> = HashMap::new();
327        
328        // Normalize and add vector scores
329        if let Some(vc) = vector {
330            let normalized = self.normalize_scores(&vc.results);
331            for (doc_id, score) in normalized {
332                *scores.entry(doc_id).or_insert(0.0) += score * vector_weight;
333            }
334        }
335        
336        // Normalize and add BM25 scores
337        if let Some(bc) = bm25 {
338            let normalized = self.normalize_scores(&bc.results);
339            for (doc_id, score) in normalized {
340                *scores.entry(doc_id).or_insert(0.0) += score * bm25_weight;
341            }
342        }
343        
344        self.collect_top_k(scores)
345    }
346    
347    /// Max-score fusion
348    fn fuse_max(
349        &self,
350        vector: Option<FilteredCandidates>,
351        bm25: Option<FilteredCandidates>,
352    ) -> FusionResult {
353        let mut scores: HashMap<u64, f32> = HashMap::new();
354        
355        if let Some(vc) = vector {
356            let normalized = self.normalize_scores(&vc.results);
357            for (doc_id, score) in normalized {
358                let entry = scores.entry(doc_id).or_insert(0.0);
359                *entry = entry.max(score);
360            }
361        }
362        
363        if let Some(bc) = bm25 {
364            let normalized = self.normalize_scores(&bc.results);
365            for (doc_id, score) in normalized {
366                let entry = scores.entry(doc_id).or_insert(0.0);
367                *entry = entry.max(score);
368            }
369        }
370        
371        self.collect_top_k(scores)
372    }
373    
374    /// Cascade fusion: use primary modality to filter, secondary to rank
375    fn fuse_cascade(
376        &self,
377        vector: Option<FilteredCandidates>,
378        bm25: Option<FilteredCandidates>,
379        primary: Modality,
380    ) -> FusionResult {
381        let (primary_candidates, secondary_candidates) = match primary {
382            Modality::Vector => (vector, bm25),
383            Modality::Bm25 => (bm25, vector),
384        };
385        
386        // Get primary doc IDs
387        let primary_ids: std::collections::HashSet<u64> = primary_candidates
388            .as_ref()
389            .map(|c| c.results.iter().map(|r| r.doc_id).collect())
390            .unwrap_or_default();
391        
392        // Score by secondary, but only docs in primary
393        let mut scores: HashMap<u64, f32> = HashMap::new();
394        
395        if let Some(sc) = secondary_candidates {
396            for result in &sc.results {
397                if primary_ids.contains(&result.doc_id) {
398                    scores.insert(result.doc_id, result.score);
399                }
400            }
401        }
402        
403        // If secondary doesn't score some docs, use primary order
404        if let Some(pc) = primary_candidates {
405            for (rank, result) in pc.results.iter().enumerate() {
406                scores.entry(result.doc_id).or_insert(-(rank as f32));
407            }
408        }
409        
410        self.collect_top_k(scores)
411    }
412    
413    /// Normalize scores to [0, 1] using min-max normalization
414    fn normalize_scores(&self, results: &[ScoredResult]) -> Vec<(u64, f32)> {
415        if results.is_empty() {
416            return vec![];
417        }
418        
419        let min = results.iter().map(|r| r.score).fold(f32::INFINITY, f32::min);
420        let max = results.iter().map(|r| r.score).fold(f32::NEG_INFINITY, f32::max);
421        let range = max - min;
422        
423        if range == 0.0 {
424            return results.iter().map(|r| (r.doc_id, 1.0)).collect();
425        }
426        
427        results.iter()
428            .map(|r| (r.doc_id, (r.score - min) / range))
429            .collect()
430    }
431    
432    /// Collect top-k results from score map
433    fn collect_top_k(&self, scores: HashMap<u64, f32>) -> FusionResult {
434        let mut results: Vec<ScoredResult> = scores
435            .into_iter()
436            .map(|(doc_id, score)| ScoredResult::new(doc_id, score))
437            .collect();
438        
439        // Sort by score descending
440        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
441        
442        // Apply min_score filter
443        if let Some(min) = self.config.min_score {
444            results.retain(|r| r.score >= min);
445        }
446        
447        // Truncate to k
448        results.truncate(self.config.final_k);
449        
450        FusionResult {
451            results,
452            method: self.config.method,
453        }
454    }
455}
456
457/// Result of fusion
458#[derive(Debug)]
459pub struct FusionResult {
460    /// Final ranked results
461    pub results: Vec<ScoredResult>,
462    /// Method used
463    pub method: FusionMethod,
464}
465
466// ============================================================================
467// Unified Hybrid Executor
468// ============================================================================
469
470/// Trait for vector search executor
471pub trait VectorExecutor {
472    fn search(&self, query: &[f32], k: usize, allowed: &AllowedSet) -> Vec<ScoredResult>;
473}
474
475/// Trait for BM25 executor
476pub trait Bm25Executor {
477    fn search(&self, query: &str, k: usize, allowed: &AllowedSet) -> Vec<ScoredResult>;
478}
479
480/// The unified hybrid executor
481///
482/// This is the main entry point that enforces the "no post-filtering" contract.
483pub struct UnifiedHybridExecutor<V: VectorExecutor, B: Bm25Executor> {
484    vector_executor: Arc<V>,
485    bm25_executor: Arc<B>,
486    fusion_engine: FusionEngine,
487}
488
489impl<V: VectorExecutor, B: Bm25Executor> UnifiedHybridExecutor<V, B> {
490    /// Create a new executor
491    pub fn new(
492        vector_executor: Arc<V>,
493        bm25_executor: Arc<B>,
494        fusion_config: FusionConfig,
495    ) -> Self {
496        Self {
497            vector_executor,
498            bm25_executor,
499            fusion_engine: FusionEngine::new(fusion_config),
500        }
501    }
502    
503    /// Execute a hybrid query with mandatory pre-filtering
504    ///
505    /// # Contract
506    ///
507    /// 1. Computes `effective_filter = auth_scope ∧ query_filter`
508    /// 2. Converts to `AllowedSet` (via metadata index)
509    /// 3. Passes SAME `AllowedSet` to BOTH vector and BM25 executors
510    /// 4. Fuses already-filtered results
511    ///
512    /// NO POST-FILTERING occurs in this function.
513    pub fn execute(
514        &self,
515        query: &UnifiedHybridQuery,
516        _auth_scope: &AuthScope,
517        allowed_set: &AllowedSet, // Pre-computed from FilterIR + AuthScope
518    ) -> FusionResult {
519        // Short-circuit if empty
520        if allowed_set.is_empty() {
521            return FusionResult {
522                results: vec![],
523                method: self.fusion_engine.config.method,
524            };
525        }
526        
527        let k = self.fusion_engine.config.candidates_per_modality;
528        
529        // Vector search (with AllowedSet)
530        let vector_candidates = query.vector_query.as_ref().map(|vq| {
531            let results = self.vector_executor.search(&vq.embedding, k, allowed_set);
532            FilteredCandidates::from_vector(results)
533        });
534        
535        // BM25 search (with SAME AllowedSet)
536        let bm25_candidates = query.bm25_query.as_ref().map(|bq| {
537            let results = self.bm25_executor.search(&bq.text, k, allowed_set);
538            FilteredCandidates::from_bm25(results)
539        });
540        
541        // Fuse (both are already filtered - no post-filtering!)
542        self.fusion_engine.fuse(vector_candidates, bm25_candidates)
543    }
544}
545
546// ============================================================================
547// Tests
548// ============================================================================
549
#[cfg(test)]
mod tests {
    use super::*;

    /// Engine with the given method and threshold, `final_k = 5`.
    fn engine_with(method: FusionMethod, min_score: Option<f32>) -> FusionEngine {
        FusionEngine::new(FusionConfig {
            method,
            candidates_per_modality: 10,
            final_k: 5,
            min_score,
        })
    }

    /// Document ids of a fusion result, in ranked order.
    fn ranked_ids(result: &FusionResult) -> Vec<u64> {
        result.results.iter().map(|r| r.doc_id).collect()
    }

    #[test]
    fn test_rrf_fusion() {
        let engine = engine_with(FusionMethod::Rrf { k: 60.0 }, None);

        let vector = FilteredCandidates::from_vector(vec![
            ScoredResult::new(1, 0.9),
            ScoredResult::new(2, 0.8),
            ScoredResult::new(3, 0.7),
        ]);

        let bm25 = FilteredCandidates::from_bm25(vec![
            ScoredResult::new(2, 5.0), // doc 2 is in both
            ScoredResult::new(4, 4.0),
            ScoredResult::new(1, 3.0), // doc 1 is in both
        ]);

        let result = engine.fuse(Some(vector), Some(bm25));

        // doc 2: 1/62 + 1/61 (rank 2 in vector, rank 1 in BM25)
        // doc 1: 1/61 + 1/63 (rank 1 in vector, rank 3 in BM25)
        // doc 4: 1/62        (rank 2 in BM25 only)
        // doc 3: 1/63        (rank 3 in vector only)
        assert_eq!(ranked_ids(&result), vec![2, 1, 4, 3]);
        assert_eq!(result.method, FusionMethod::Rrf { k: 60.0 });
    }

    #[test]
    fn test_linear_fusion() {
        let engine = engine_with(
            FusionMethod::Linear {
                vector_weight: 0.6,
                bm25_weight: 0.4,
            },
            None,
        );

        let vector = FilteredCandidates::from_vector(vec![
            ScoredResult::new(1, 1.0),
            ScoredResult::new(2, 0.5),
        ]);

        let bm25 = FilteredCandidates::from_bm25(vec![
            ScoredResult::new(2, 10.0), // Different scale
            ScoredResult::new(3, 5.0),
        ]);

        let result = engine.fuse(Some(vector), Some(bm25));

        // After per-modality min-max normalization:
        //   doc 1: 1.0 * 0.6             = 0.6
        //   doc 2: 0.0 * 0.6 + 1.0 * 0.4 = 0.4
        //   doc 3: 0.0 * 0.4             = 0.0
        assert_eq!(ranked_ids(&result), vec![1, 2, 3]);
        assert!((result.results[0].score - 0.6).abs() < 0.001);
        assert!((result.results[1].score - 0.4).abs() < 0.001);
        assert!(result.results[2].score.abs() < 0.001);
    }

    #[test]
    fn test_min_score_threshold() {
        let engine = engine_with(
            FusionMethod::Linear {
                vector_weight: 0.6,
                bm25_weight: 0.4,
            },
            Some(0.5),
        );

        let vector = FilteredCandidates::from_vector(vec![
            ScoredResult::new(1, 1.0),
            ScoredResult::new(2, 0.5),
        ]);

        let bm25 = FilteredCandidates::from_bm25(vec![
            ScoredResult::new(2, 10.0),
            ScoredResult::new(3, 5.0),
        ]);

        // Only doc 1 (0.6) clears the 0.5 threshold.
        let result = engine.fuse(Some(vector), Some(bm25));
        assert_eq!(ranked_ids(&result), vec![1]);
    }

    #[test]
    fn test_cascade_fusion() {
        let engine = engine_with(
            FusionMethod::Cascade {
                primary: Modality::Vector,
            },
            None,
        );

        let vector = FilteredCandidates::from_vector(vec![
            ScoredResult::new(1, 0.9),
            ScoredResult::new(2, 0.8),
            ScoredResult::new(3, 0.7),
        ]);

        let bm25 = FilteredCandidates::from_bm25(vec![
            ScoredResult::new(3, 5.0),
            ScoredResult::new(9, 4.0), // not in primary: must be excluded
            ScoredResult::new(1, 3.0),
        ]);

        // Docs 3 and 1 are ranked by BM25 score; doc 2 has no BM25 score and
        // falls back to its (negated) position in the primary list.
        let result = engine.fuse(Some(vector), Some(bm25));
        assert_eq!(ranked_ids(&result), vec![3, 1, 2]);
    }

    #[test]
    fn test_empty_allowed_set() {
        let config = FusionConfig::default();
        let engine = FusionEngine::new(config);

        // With no candidates from either modality the result is empty.
        let result = engine.fuse(None, None);
        assert!(result.results.is_empty());
    }

    #[test]
    fn test_score_normalization() {
        let config = FusionConfig::default();
        let engine = FusionEngine::new(config);

        let results = vec![
            ScoredResult::new(1, 100.0),
            ScoredResult::new(2, 50.0),
            ScoredResult::new(3, 0.0),
        ];

        let normalized = engine.normalize_scores(&results);

        // Should be normalized to [0, 1]
        assert_eq!(normalized.len(), 3);
        let scores: HashMap<u64, f32> = normalized.into_iter().collect();
        assert!((scores[&1] - 1.0).abs() < 0.001);
        assert!((scores[&2] - 0.5).abs() < 0.001);
        assert!((scores[&3] - 0.0).abs() < 0.001);

        // Degenerate inputs: nothing in, nothing out; equal scores map to 1.0.
        assert!(engine.normalize_scores(&[]).is_empty());
        let flat = engine.normalize_scores(&[ScoredResult::new(7, 3.0), ScoredResult::new(8, 3.0)]);
        assert!(flat.iter().all(|&(_, s)| (s - 1.0).abs() < 0.001));
    }

    #[test]
    fn test_no_post_filter_invariant() {
        // This test verifies the core invariant:
        // result-set ⊆ allowed-set
        //
        // If this invariant is violated, it indicates a security issue.

        let allowed: std::collections::HashSet<u64> = [1, 2, 3, 5, 8].into_iter().collect();
        let allowed_set = AllowedSet::from_iter(allowed.iter().copied());

        // Simulate filtered candidates (these should already respect AllowedSet)
        let vector = FilteredCandidates::from_vector(vec![
            ScoredResult::new(1, 0.9), // in allowed set
            ScoredResult::new(2, 0.8), // in allowed set
            ScoredResult::new(5, 0.7), // in allowed set
        ]);

        let bm25 = FilteredCandidates::from_bm25(vec![
            ScoredResult::new(2, 5.0), // in allowed set
            ScoredResult::new(3, 4.0), // in allowed set
            ScoredResult::new(8, 3.0), // in allowed set
        ]);

        let config = FusionConfig::default();
        let engine = FusionEngine::new(config);
        let result = engine.fuse(Some(vector), Some(bm25));

        // Fusion must not invent or lose documents here: five distinct ids in,
        // five out (final_k defaults to 10).
        assert_eq!(result.results.len(), 5);

        // INVARIANT: Every result doc_id must be in the allowed set
        let verdict = verify_no_post_filter_invariant(&result, &allowed_set);
        assert!(verdict.is_valid());
        verdict.assert_valid();
    }
}
687
688// ============================================================================
689// Invariant Verification
690// ============================================================================
691
692/// Verify that a fusion result respects the no-post-filtering invariant
693/// 
694/// This function should be used in tests and optionally in debug builds
695/// to verify that the security invariant holds.
696///
697/// # Invariant
698///
699/// `∀ doc ∈ result: doc.id ∈ allowed_set`
700///
701/// This is the "monotone property" from the architecture document.
702pub fn verify_no_post_filter_invariant(
703    result: &FusionResult,
704    allowed_set: &AllowedSet,
705) -> InvariantVerification {
706    let mut violations = Vec::new();
707    
708    for doc in &result.results {
709        if !allowed_set.contains(doc.doc_id) {
710            violations.push(doc.doc_id);
711        }
712    }
713    
714    if violations.is_empty() {
715        InvariantVerification::Valid
716    } else {
717        InvariantVerification::Violated { doc_ids: violations }
718    }
719}
720
/// Outcome of checking the no-post-filtering invariant.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum InvariantVerification {
    /// The invariant holds.
    Valid,
    /// The invariant is violated; these doc IDs should not be in the results.
    Violated { doc_ids: Vec<u64> },
}

impl InvariantVerification {
    /// Returns `true` when the invariant holds.
    pub fn is_valid(&self) -> bool {
        match self {
            Self::Valid => true,
            Self::Violated { .. } => false,
        }
    }

    /// Panic if the invariant is violated (for testing).
    ///
    /// # Panics
    ///
    /// Panics with the number and list of offending doc IDs when `self` is
    /// `Violated`.
    pub fn assert_valid(&self) {
        if let Self::Violated { doc_ids } = self {
            panic!(
                "NO-POST-FILTER INVARIANT VIOLATED: {} docs not in allowed set: {:?}",
                doc_ids.len(),
                doc_ids
            );
        }
    }
}
749}