Skip to main content

sochdb_query/
unified_fusion.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Unified Hybrid Fusion with Mandatory Pre-Filtering (Task 7)
19//!
20//! This module implements hybrid retrieval (vector + BM25) that **never**
21//! post-filters. The key insight is:
22//!
23//! > Both vector and BM25 executors receive the **same** AllowedSet,
24//! > produce candidates **guaranteed** within it, then fusion merges by doc_id.
25//!
26//! ## Anti-Pattern (What We Avoid)
27//!
28//! ```text
29//! BAD: vector_search() → candidates → filter → too few
30//!      bm25_search() → candidates → filter → inconsistent
31//!      fusion(unfiltered_v, unfiltered_b) → filter at end → broken!
32//! ```
33//!
34//! ## Correct Pattern
35//!
36//! ```text
37//! GOOD: compute AllowedSet from FilterIR
38//!       vector_search(query, allowed_set) → filtered_v
39//!       bm25_search(query, allowed_set) → filtered_b
40//!       fusion(filtered_v, filtered_b) → already correct!
41//! ```
42//!
43//! ## Fusion Cost
44//!
45//! With pre-filtered candidates:
46//! - Fusion is O(k_v + k_b) with hash-join or two-pointer merge
47//! - Total work is proportional to constrained candidate sizes
48//! - No wasted scoring on disallowed documents
49
50use std::collections::HashMap;
51use std::sync::Arc;
52
53use crate::candidate_gate::AllowedSet;
54use crate::filter_ir::{AuthScope, FilterIR};
55use crate::filtered_vector_search::ScoredResult;
56use crate::grep_executor::GrepMode;
57use crate::namespace::NamespaceScope;
58
59// ============================================================================
60// Fusion Configuration
61// ============================================================================
62
/// Fusion method
///
/// Selects how `FusionEngine` merges the per-modality candidate lists into a
/// single ranking. The variant is matched in `FusionEngine::fuse_multi`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FusionMethod {
    /// Reciprocal Rank Fusion: score = Σ wᵢ / (k + rankᵢ), rank 1-indexed.
    Rrf {
        /// Offset added to every 1-indexed rank in the denominator.
        k: f32,
        /// Weight of the vector lane's contribution.
        vector_weight: f32,
        /// Weight of the BM25 lane's contribution.
        bm25_weight: f32,
    },

    /// Linear combination of normalized scores
    ///
    /// Each lane's scores are min-max normalized, multiplied by the lane's
    /// weight, and summed per document.
    Linear {
        /// Weight of the vector lane's normalized score.
        vector_weight: f32,
        /// Weight of the BM25 lane's normalized score.
        bm25_weight: f32,
    },

    /// Take max score across modalities
    ///
    /// The maximum is taken over each lane's min-max normalized score; lane
    /// weights are not applied by this method.
    Max,

    /// Cascade: use one modality to filter, other to rank
    ///
    /// `primary` names the filtering modality. Only documents returned by the
    /// primary lane can appear in the output.
    Cascade { primary: Modality },
}
85
/// Search modality: identifies which retrieval lane produced a candidate list.
///
/// This is a fieldless enum, so it derives `Eq` and `Hash` alongside
/// `PartialEq`. That makes a modality usable as a `HashMap`/`HashSet` key
/// (for example to index lanes by modality) without changing any existing
/// comparison behaviour.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Modality {
    /// Dense-embedding (vector) lane.
    Vector,
    /// BM25 lexical lane.
    Bm25,
    /// Trigram-accelerated regex (grep) lane.
    Grep,
}
94
95impl Default for FusionMethod {
96    fn default() -> Self {
97        Self::Rrf {
98            k: 60.0,
99            vector_weight: 1.0,
100            bm25_weight: 1.0,
101        }
102    }
103}
104
/// Configuration for hybrid fusion
///
/// Owned by `FusionEngine`. `candidates_per_modality` is read by
/// `UnifiedHybridExecutor::execute`; the other fields are read by the engine
/// when it scores and trims the fused ranking.
#[derive(Debug, Clone)]
pub struct FusionConfig {
    /// Fusion method
    pub method: FusionMethod,

    /// Number of candidates to retrieve from each modality
    ///
    /// Passed as `k` to the vector, BM25 and grep-rank lanes. The grep gate
    /// lane is called with `k = 0` instead.
    pub candidates_per_modality: usize,

    /// Final result limit
    ///
    /// The fused ranking is truncated to this many entries.
    pub final_k: usize,

    /// Minimum score threshold (after fusion)
    ///
    /// When `Some`, results whose fused score is below the threshold are
    /// dropped before the ranking is truncated to `final_k`.
    pub min_score: Option<f32>,
}
120
121impl Default for FusionConfig {
122    fn default() -> Self {
123        Self {
124            method: FusionMethod::default(),
125            candidates_per_modality: 100,
126            final_k: 10,
127            min_score: None,
128        }
129    }
130}
131
132// ============================================================================
133// Unified Hybrid Query
134// ============================================================================
135
/// A hybrid query that enforces pre-filtering
///
/// Built with `UnifiedHybridQuery::new` plus the `with_*` builder methods.
/// Every lane is optional; a lane left as `None` is not run by
/// `UnifiedHybridExecutor::execute`.
#[derive(Debug, Clone)]
pub struct UnifiedHybridQuery {
    /// Namespace scope (mandatory)
    pub namespace: NamespaceScope,

    /// Vector query (optional)
    pub vector_query: Option<VectorQuerySpec>,

    /// BM25 query (optional)
    pub bm25_query: Option<Bm25QuerySpec>,

    /// Grep (regex) query (optional)
    ///
    /// Only consulted when the executor also has a grep lane executor attached.
    pub grep_query: Option<GrepQuerySpec>,

    /// User-provided filter
    ///
    /// Combined with the namespace scope by `effective_filter`.
    pub filter: FilterIR,

    /// Fusion configuration
    ///
    /// NOTE(review): `UnifiedHybridExecutor::execute` reads the config held by
    /// its own `FusionEngine` and does not consult this field — confirm whether
    /// that is intended.
    pub fusion_config: FusionConfig,
}
157
/// Vector query specification
///
/// Created by `UnifiedHybridQuery::with_vector`.
#[derive(Debug, Clone)]
pub struct VectorQuerySpec {
    /// Query embedding
    pub embedding: Vec<f32>,
    /// ef_search for HNSW
    ///
    /// `with_vector` sets this to `100`.
    /// NOTE(review): `UnifiedHybridExecutor::execute` forwards only
    /// `embedding` to the vector executor, so this value is not propagated by
    /// the code in this module — confirm where it is consumed.
    pub ef_search: usize,
}
166
/// BM25 query specification
///
/// Created by `UnifiedHybridQuery::with_bm25`.
#[derive(Debug, Clone)]
pub struct Bm25QuerySpec {
    /// Query text (will be tokenized)
    pub text: String,
    /// Fields to search
    ///
    /// `with_bm25` sets this to `["content"]`.
    /// NOTE(review): `UnifiedHybridExecutor::execute` forwards only `text` to
    /// the BM25 executor, so this list is not propagated by the code in this
    /// module — confirm where it is consumed.
    pub fields: Vec<String>,
}
175
/// Grep (regex) query specification for the third lane.
///
/// The [`GrepMode`] determines how the lane participates in fusion:
/// - [`GrepMode::Rank`] contributes a ranked list weighted by `weight`.
/// - [`GrepMode::Gate`] narrows the `AllowedSet` *before* the vector and BM25
///   lanes run (a cascade); `weight` is unused in that case.
#[derive(Debug, Clone)]
pub struct GrepQuerySpec {
    /// Regular expression pattern.
    ///
    /// Handed to the grep lane executor as-is; this module does not parse it.
    pub pattern: String,
    /// How the lane is consumed by fusion.
    pub mode: GrepMode,
    /// Fusion weight (used only for [`GrepMode::Rank`]).
    ///
    /// Becomes the weight of the grep lane passed to N-ary fusion.
    pub weight: f32,
}
191
192impl UnifiedHybridQuery {
193    /// Create a new hybrid query (namespace is mandatory)
194    pub fn new(namespace: NamespaceScope) -> Self {
195        Self {
196            namespace,
197            vector_query: None,
198            bm25_query: None,
199            grep_query: None,
200            filter: FilterIR::all(),
201            fusion_config: FusionConfig::default(),
202        }
203    }
204
205    /// Add vector search
206    pub fn with_vector(mut self, embedding: Vec<f32>) -> Self {
207        self.vector_query = Some(VectorQuerySpec {
208            embedding,
209            ef_search: 100,
210        });
211        self
212    }
213
214    /// Add BM25 search
215    pub fn with_bm25(mut self, text: impl Into<String>) -> Self {
216        self.bm25_query = Some(Bm25QuerySpec {
217            text: text.into(),
218            fields: vec!["content".to_string()],
219        });
220        self
221    }
222
223    /// Add a grep (regex) lane with the default weight of `1.0`.
224    pub fn with_grep(mut self, pattern: impl Into<String>, mode: GrepMode) -> Self {
225        self.grep_query = Some(GrepQuerySpec {
226            pattern: pattern.into(),
227            mode,
228            weight: 1.0,
229        });
230        self
231    }
232
233    /// Add a grep (regex) lane with an explicit fusion weight.
234    pub fn with_grep_weighted(
235        mut self,
236        pattern: impl Into<String>,
237        mode: GrepMode,
238        weight: f32,
239    ) -> Self {
240        self.grep_query = Some(GrepQuerySpec {
241            pattern: pattern.into(),
242            mode,
243            weight,
244        });
245        self
246    }
247
248    /// Add filter
249    pub fn with_filter(mut self, filter: FilterIR) -> Self {
250        self.filter = filter;
251        self
252    }
253
254    /// Set fusion config
255    pub fn with_fusion(mut self, config: FusionConfig) -> Self {
256        self.fusion_config = config;
257        self
258    }
259
260    /// Compute the complete effective filter
261    ///
262    /// This combines namespace scope + user filter. Auth scope is added later.
263    pub fn effective_filter(&self) -> FilterIR {
264        self.namespace.to_filter_ir().and(self.filter.clone())
265    }
266}
267
268// ============================================================================
269// Filtered Candidates
270// ============================================================================
271
/// Candidates from a single modality (already filtered)
///
/// This is the unit handed to `FusionEngine`. The engine never re-checks
/// membership in the `AllowedSet`; it only `debug_assert!`s the `filtered`
/// flag.
#[derive(Debug)]
pub struct FilteredCandidates {
    /// Modality source
    pub modality: Modality,
    /// Scored results (doc_id, score)
    ///
    /// RRF and cascade read the position in this list as the rank, so the
    /// list is expected to be ordered best-first.
    pub results: Vec<ScoredResult>,
    /// Whether the allowed set was applied
    ///
    /// Every constructor on this type sets it to `true`; the flag is a
    /// declaration by the producer, not a verified property.
    pub filtered: bool,
}
282
283impl FilteredCandidates {
284    /// Create from vector search results
285    pub fn from_vector(results: Vec<ScoredResult>) -> Self {
286        Self {
287            modality: Modality::Vector,
288            results,
289            filtered: true,
290        }
291    }
292
293    /// Create from BM25 results
294    pub fn from_bm25(results: Vec<ScoredResult>) -> Self {
295        Self {
296            modality: Modality::Bm25,
297            results,
298            filtered: true,
299        }
300    }
301
302    /// Create from grep (regex) results
303    pub fn from_grep(results: Vec<ScoredResult>) -> Self {
304        Self {
305            modality: Modality::Grep,
306            results,
307            filtered: true,
308        }
309    }
310}
311
312// ============================================================================
313// Canonical Document Identity + RRF Kernel
314// ============================================================================
315
/// Canonical document identity consumed by the fusion kernel.
///
/// `DocId` is a newtype over `u64` rather than a bare integer, so that a
/// retrieval-space document id cannot be silently mixed up with a raw vector
/// offset, a record id, or a rank. The fusion kernel keys exclusively on this
/// type; executors convert their native ids into a `DocId` at the boundary.
///
/// ## Identity-space contract (Task 1)
///
/// `DocId` is the **retrieval-universe** identity, a dense `u64` shared by
/// every lane that participates in pre-filtered fusion:
///
/// | Lane | Native key | Relationship to `DocId` |
/// |------|-----------|--------------------------|
/// | BM25 / inverted index | `u64` doc id | identical (`DocId(id)`) |
/// | Grep / trigram lane | `DocId = u64` alias | identical |
/// | `AllowedSet` membership | `u64` | identical (`AllowedSet::contains(d.get())`) |
/// | HNSW vector graph | `u128` storage id | **mapped** at the boundary |
///
/// The first three rows share one key space, which is why an `AllowedSet`
/// produced by one lane can gate every other lane with an O(1) membership
/// test on `DocId.0`. The HNSW graph keys on a wider `u128` *storage* id; the
/// vector executor narrows that to a retrieval `DocId` when it emits
/// candidates (and accepts an `allowed(u128)` predicate over the same
/// mapping — see `HnswIndex::search_allowed`). Threading the newtype down
/// into the `u128` storage layer is deliberately **out of scope** here: it
/// touches the durable id space and yields no behavioral change while this
/// kernel keys on `DocId`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct DocId(pub u64);

impl DocId {
    /// Returns the wrapped retrieval-universe id.
    #[inline]
    pub const fn get(self) -> u64 {
        let DocId(raw) = self;
        raw
    }
}

impl From<u64> for DocId {
    /// Wraps a raw retrieval-universe id.
    fn from(id: u64) -> Self {
        Self(id)
    }
}

impl From<DocId> for u64 {
    /// Unwraps a `DocId` back into its raw `u64`.
    fn from(d: DocId) -> Self {
        d.get()
    }
}
365
/// A ranked candidate list paired with the weight it contributes to fusion.
///
/// Results must be ordered best-first; the element at index 0 is treated as
/// rank 1.
///
/// The list is borrowed, not owned: `fuse_rrf_weighted` reads only each
/// result's `doc_id` and its position, never its score.
pub struct RankedList<'a> {
    /// Candidates ordered best-first.
    pub results: &'a [ScoredResult],
    /// Fusion weight for this list.
    ///
    /// Used as the numerator of every contribution this list makes.
    pub weight: f32,
}
376
/// A modality's pre-filtered candidates paired with the weight it contributes
/// to N-ary fusion. This is the unit consumed by [`FusionEngine::fuse_multi`],
/// allowing an arbitrary number of lanes (vector, BM25, grep, …) to be fused by
/// a single call.
pub struct WeightedLane {
    /// Pre-filtered candidates for one modality.
    pub candidates: FilteredCandidates,
    /// Fusion weight for this lane.
    ///
    /// Applied by the RRF and Linear methods. The Max and Cascade paths in
    /// `fuse_multi` do not read it.
    pub weight: f32,
}
387
388/// The single, canonical Reciprocal Rank Fusion kernel: **weighted, 1-indexed,
389/// N-ary**.
390///
391/// ```text
392/// score(d) = Σᵢ  weightᵢ / (k + rankᵢ(d))
393/// ```
394///
395/// where `rankᵢ(d)` is the **1-indexed** position of document `d` in list `i`
396/// (the top result has rank 1) and `weightᵢ` is that list's weight. Every
397/// higher-level fusion path funnels through this function so the weighting and
398/// the rank offset can never diverge across the codebase again.
399pub fn fuse_rrf_weighted(lists: &[RankedList<'_>], k: f32) -> HashMap<DocId, f32> {
400    let mut scores: HashMap<DocId, f32> = HashMap::new();
401    for list in lists {
402        for (rank, result) in list.results.iter().enumerate() {
403            let contribution = list.weight / (k + (rank as f32 + 1.0));
404            *scores.entry(DocId(result.doc_id)).or_insert(0.0) += contribution;
405        }
406    }
407    scores
408}
409
410// ============================================================================
411// Fusion Engine
412// ============================================================================
413
/// The fusion engine that combines candidates from multiple modalities
///
/// It holds only configuration; every fusion call takes its candidate lists
/// as arguments and returns a fresh `FusionResult`.
pub struct FusionEngine {
    // Method, result limit and score threshold applied by every fusion call.
    config: FusionConfig,
}
418
419impl FusionEngine {
420    /// Create a new fusion engine
421    pub fn new(config: FusionConfig) -> Self {
422        Self { config }
423    }
424
425    /// Fuse candidates from vector and BM25 search
426    ///
427    /// INVARIANT: Both candidate sets are already filtered to AllowedSet.
428    /// This function does NOT apply any additional filtering.
429    ///
430    /// This is the two-lane convenience over [`FusionEngine::fuse_multi`]: it
431    /// builds weighted lanes from the configured method weights and delegates,
432    /// so the two-lane and N-ary paths share exactly one scoring implementation.
433    pub fn fuse(
434        &self,
435        vector_candidates: Option<FilteredCandidates>,
436        bm25_candidates: Option<FilteredCandidates>,
437    ) -> FusionResult {
438        // Validate that candidates are pre-filtered
439        if let Some(ref vc) = vector_candidates {
440            debug_assert!(vc.filtered, "Vector candidates must be pre-filtered!");
441        }
442        if let Some(ref bc) = bm25_candidates {
443            debug_assert!(bc.filtered, "BM25 candidates must be pre-filtered!");
444        }
445
446        // Cascade is intrinsically two-modality (primary filters, secondary
447        // ranks) and keeps its dedicated path.
448        if let FusionMethod::Cascade { primary } = self.config.method {
449            return self.fuse_cascade(vector_candidates, bm25_candidates, primary);
450        }
451
452        let (vector_weight, bm25_weight) = self.method_weights();
453        let mut lanes: Vec<WeightedLane> = Vec::with_capacity(2);
454        if let Some(vc) = vector_candidates {
455            lanes.push(WeightedLane {
456                candidates: vc,
457                weight: vector_weight,
458            });
459        }
460        if let Some(bc) = bm25_candidates {
461            lanes.push(WeightedLane {
462                candidates: bc,
463                weight: bm25_weight,
464            });
465        }
466        self.fuse_multi(lanes)
467    }
468
469    /// The per-modality weights implied by the configured fusion method.
470    ///
471    /// RRF and Linear carry explicit vector/BM25 weights; Max and Cascade do
472    /// not weight their inputs, so they report a neutral `1.0`.
473    pub(crate) fn method_weights(&self) -> (f32, f32) {
474        match self.config.method {
475            FusionMethod::Rrf {
476                vector_weight,
477                bm25_weight,
478                ..
479            } => (vector_weight, bm25_weight),
480            FusionMethod::Linear {
481                vector_weight,
482                bm25_weight,
483            } => (vector_weight, bm25_weight),
484            FusionMethod::Max | FusionMethod::Cascade { .. } => (1.0, 1.0),
485        }
486    }
487
488    /// N-ary fusion across any number of pre-filtered modality lanes.
489    ///
490    /// This is the canonical multi-lane path: vector, BM25, and grep (or any
491    /// future lane) are fused by a single call. RRF funnels through
492    /// [`fuse_rrf_weighted`]; Linear and Max combine per-lane normalized scores
493    /// weighted by each lane's weight.
494    ///
495    /// INVARIANT: every lane is already filtered to the AllowedSet. No
496    /// additional filtering happens here.
497    pub fn fuse_multi(&self, lanes: Vec<WeightedLane>) -> FusionResult {
498        for lane in &lanes {
499            debug_assert!(
500                lane.candidates.filtered,
501                "Fusion lanes must be pre-filtered!"
502            );
503        }
504
505        match self.config.method {
506            FusionMethod::Rrf { k, .. } => {
507                let ranked: Vec<RankedList<'_>> = lanes
508                    .iter()
509                    .map(|lane| RankedList {
510                        results: &lane.candidates.results,
511                        weight: lane.weight,
512                    })
513                    .collect();
514                let scores = fuse_rrf_weighted(&ranked, k)
515                    .into_iter()
516                    .map(|(doc, score)| (doc.0, score))
517                    .collect();
518                self.collect_top_k(scores)
519            }
520            FusionMethod::Linear { .. } => {
521                let mut scores: HashMap<u64, f32> = HashMap::new();
522                for lane in &lanes {
523                    for (doc_id, score) in self.normalize_scores(&lane.candidates.results) {
524                        *scores.entry(doc_id).or_insert(0.0) += score * lane.weight;
525                    }
526                }
527                self.collect_top_k(scores)
528            }
529            FusionMethod::Max => {
530                let mut scores: HashMap<u64, f32> = HashMap::new();
531                for lane in &lanes {
532                    for (doc_id, score) in self.normalize_scores(&lane.candidates.results) {
533                        let entry = scores.entry(doc_id).or_insert(0.0);
534                        *entry = entry.max(score);
535                    }
536                }
537                self.collect_top_k(scores)
538            }
539            FusionMethod::Cascade { primary } => {
540                // Cascade is two-modality: reconstruct the vector and BM25 lanes
541                // by modality and apply the primary/secondary logic. A grep Rank
542                // lane is not part of a cascade and is ignored here (grep's
543                // cascade shape is Gate, applied before fusion in `execute`).
544                let mut vector = None;
545                let mut bm25 = None;
546                for lane in lanes {
547                    match lane.candidates.modality {
548                        Modality::Vector => vector = Some(lane.candidates),
549                        Modality::Bm25 => bm25 = Some(lane.candidates),
550                        Modality::Grep => {}
551                    }
552                }
553                self.fuse_cascade(vector, bm25, primary)
554            }
555        }
556    }
557
558    /// Cascade fusion: use primary modality to filter, secondary to rank
559    fn fuse_cascade(
560        &self,
561        vector: Option<FilteredCandidates>,
562        bm25: Option<FilteredCandidates>,
563        primary: Modality,
564    ) -> FusionResult {
565        let (primary_candidates, secondary_candidates) = match primary {
566            Modality::Vector => (vector, bm25),
567            Modality::Bm25 => (bm25, vector),
568            // Grep is not a cascade ranking modality (its cascade shape is the
569            // Gate, applied to the AllowedSet before fusion). Fall back to a
570            // vector-primary cascade so the method stays total.
571            Modality::Grep => (vector, bm25),
572        };
573
574        // Get primary doc IDs
575        let primary_ids: std::collections::HashSet<u64> = primary_candidates
576            .as_ref()
577            .map(|c| c.results.iter().map(|r| r.doc_id).collect())
578            .unwrap_or_default();
579
580        // Score by secondary, but only docs in primary
581        let mut scores: HashMap<u64, f32> = HashMap::new();
582
583        if let Some(sc) = secondary_candidates {
584            for result in &sc.results {
585                if primary_ids.contains(&result.doc_id) {
586                    scores.insert(result.doc_id, result.score);
587                }
588            }
589        }
590
591        // If secondary doesn't score some docs, use primary order
592        if let Some(pc) = primary_candidates {
593            for (rank, result) in pc.results.iter().enumerate() {
594                scores.entry(result.doc_id).or_insert(-(rank as f32));
595            }
596        }
597
598        self.collect_top_k(scores)
599    }
600
601    /// Normalize scores to [0, 1] using min-max normalization
602    fn normalize_scores(&self, results: &[ScoredResult]) -> Vec<(u64, f32)> {
603        if results.is_empty() {
604            return vec![];
605        }
606
607        let min = results
608            .iter()
609            .map(|r| r.score)
610            .fold(f32::INFINITY, f32::min);
611        let max = results
612            .iter()
613            .map(|r| r.score)
614            .fold(f32::NEG_INFINITY, f32::max);
615        let range = max - min;
616
617        if range == 0.0 {
618            return results.iter().map(|r| (r.doc_id, 1.0)).collect();
619        }
620
621        results
622            .iter()
623            .map(|r| (r.doc_id, (r.score - min) / range))
624            .collect()
625    }
626
627    /// Collect top-k results from score map
628    fn collect_top_k(&self, scores: HashMap<u64, f32>) -> FusionResult {
629        let mut results: Vec<ScoredResult> = scores
630            .into_iter()
631            .map(|(doc_id, score)| ScoredResult::new(doc_id, score))
632            .collect();
633
634        // Sort by score descending
635        results.sort_by(|a, b| {
636            b.score
637                .partial_cmp(&a.score)
638                .unwrap_or(std::cmp::Ordering::Equal)
639        });
640
641        // Apply min_score filter
642        if let Some(min) = self.config.min_score {
643            results.retain(|r| r.score >= min);
644        }
645
646        // Truncate to k
647        results.truncate(self.config.final_k);
648
649        FusionResult {
650            results,
651            method: self.config.method,
652        }
653    }
654}
655
/// Result of fusion
#[derive(Debug)]
pub struct FusionResult {
    /// Final ranked results
    ///
    /// Sorted by fused score, highest first, after the optional `min_score`
    /// threshold and truncation to `final_k`. May be empty.
    pub results: Vec<ScoredResult>,
    /// Method used
    ///
    /// Copied from the engine's `FusionConfig`.
    pub method: FusionMethod,
}
664
665// ============================================================================
666// Unified Hybrid Executor
667// ============================================================================
668
/// Trait for vector search executor
///
/// `UnifiedHybridExecutor::execute` wraps the returned list with
/// `FilteredCandidates::from_vector`, which marks it as filtered without
/// re-checking membership.
/// NOTE(review): implementations are therefore presumably required to return
/// only ids inside `allowed`, ordered best-first — confirm against the
/// implementors.
pub trait VectorExecutor {
    /// Search with the `query` embedding, restricted to `allowed`.
    ///
    /// `k` is the candidate budget; `execute` passes the configured
    /// `candidates_per_modality`.
    fn search(&self, query: &[f32], k: usize, allowed: &AllowedSet) -> Vec<ScoredResult>;
}
673
/// Trait for BM25 executor
///
/// `UnifiedHybridExecutor::execute` wraps the returned list with
/// `FilteredCandidates::from_bm25`, which marks it as filtered without
/// re-checking membership.
/// NOTE(review): implementations are therefore presumably required to return
/// only ids inside `allowed`, ordered best-first — confirm against the
/// implementors.
pub trait Bm25Executor {
    /// Search with the `query` text, restricted to `allowed`.
    ///
    /// `k` is the candidate budget; `execute` passes the configured
    /// `candidates_per_modality`.
    fn search(&self, query: &str, k: usize, allowed: &AllowedSet) -> Vec<ScoredResult>;
}
678
/// Trait for the grep (trigram-accelerated regex) lane.
///
/// Implementations MUST honor `allowed` — every returned id must be a member
/// (the same `result ⊆ allowed` contract the other lanes enforce). For
/// [`GrepMode::Rank`] the scores rank documents by match density; for
/// [`GrepMode::Gate`] only the returned doc-ids matter (they form the cascade
/// gate) and the scores are not meaningful.
pub trait GrepLaneExecutor {
    /// Run `pattern` over the documents in `allowed`.
    ///
    /// `k` is the candidate budget. `UnifiedHybridExecutor::execute` passes
    /// `0` for [`GrepMode::Gate`], which it treats as "unlimited" because the
    /// gate must be the full match set, and `candidates_per_modality` for
    /// [`GrepMode::Rank`].
    fn grep(
        &self,
        pattern: &str,
        k: usize,
        allowed: &AllowedSet,
        mode: GrepMode,
    ) -> Vec<ScoredResult>;
}
695
/// The unified hybrid executor
///
/// This is the main entry point that enforces the "no post-filtering" contract.
///
/// The vector and BM25 executors are generic parameters. The grep lane is an
/// optional trait object attached with `with_grep_executor`.
pub struct UnifiedHybridExecutor<V: VectorExecutor, B: Bm25Executor> {
    // Vector lane.
    vector_executor: Arc<V>,
    // BM25 lane.
    bm25_executor: Arc<B>,
    // Optional grep lane; `None` means grep queries are ignored.
    grep_executor: Option<Arc<dyn GrepLaneExecutor>>,
    // Fusion engine; its config also supplies the per-lane candidate budget.
    fusion_engine: FusionEngine,
}
705
706impl<V: VectorExecutor, B: Bm25Executor> UnifiedHybridExecutor<V, B> {
707    /// Create a new executor
708    pub fn new(
709        vector_executor: Arc<V>,
710        bm25_executor: Arc<B>,
711        fusion_config: FusionConfig,
712    ) -> Self {
713        Self {
714            vector_executor,
715            bm25_executor,
716            grep_executor: None,
717            fusion_engine: FusionEngine::new(fusion_config),
718        }
719    }
720
721    /// Attach a grep lane executor, enabling three-lane fusion.
722    ///
723    /// Without one, any `grep_query` on a [`UnifiedHybridQuery`] is ignored and
724    /// the executor behaves exactly as the two-lane vector+BM25 path.
725    pub fn with_grep_executor(mut self, grep_executor: Arc<dyn GrepLaneExecutor>) -> Self {
726        self.grep_executor = Some(grep_executor);
727        self
728    }
729
730    /// Execute a hybrid query with mandatory pre-filtering
731    ///
732    /// # Contract
733    ///
734    /// 1. Computes `effective_filter = auth_scope ∧ query_filter`
735    /// 2. Converts to `AllowedSet` (via metadata index)
736    /// 3. A grep `Gate` lane (if present) narrows that `AllowedSet` *first*
737    /// 4. Passes the SAME `AllowedSet` to the vector, BM25 and grep-`Rank` lanes
738    /// 5. Fuses all already-filtered lanes with one N-ary `fuse_multi` call
739    ///
740    /// NO POST-FILTERING occurs in this function.
741    pub fn execute(
742        &self,
743        query: &UnifiedHybridQuery,
744        _auth_scope: &AuthScope,
745        allowed_set: &AllowedSet, // Pre-computed from FilterIR + AuthScope
746    ) -> FusionResult {
747        // Short-circuit if empty
748        if allowed_set.is_empty() {
749            return FusionResult {
750                results: vec![],
751                method: self.fusion_engine.config.method,
752            };
753        }
754
755        let k = self.fusion_engine.config.candidates_per_modality;
756
757        // ---- Lane 3 (grep) planning -------------------------------------
758        // Gate: run grep first and intersect its matches into the AllowedSet so
759        //       the other lanes only ever see grep-approved documents.
760        // Rank: run grep as an additional ranked lane alongside vector/BM25.
761        let mut grep_rank: Option<FilteredCandidates> = None;
762        let mut grep_weight = 1.0_f32;
763        let mut gated: Option<AllowedSet> = None;
764        if let (Some(gq), Some(grep)) = (query.grep_query.as_ref(), self.grep_executor.as_ref()) {
765            match gq.mode {
766                GrepMode::Gate => {
767                    // `k = 0` = unlimited: the gate must be the full match set.
768                    let hits = grep.grep(&gq.pattern, 0, allowed_set, GrepMode::Gate);
769                    gated = Some(AllowedSet::from_iter(hits.into_iter().map(|r| r.doc_id)));
770                }
771                GrepMode::Rank => {
772                    let hits = grep.grep(&gq.pattern, k, allowed_set, GrepMode::Rank);
773                    grep_rank = Some(FilteredCandidates::from_grep(hits));
774                    grep_weight = gq.weight;
775                }
776            }
777        }
778
779        // The effective AllowedSet every other lane is gated by.
780        let effective_allowed: &AllowedSet = gated.as_ref().unwrap_or(allowed_set);
781        if effective_allowed.is_empty() {
782            return FusionResult {
783                results: vec![],
784                method: self.fusion_engine.config.method,
785            };
786        }
787
788        // Vector search (with the effective AllowedSet)
789        let vector_candidates = query.vector_query.as_ref().map(|vq| {
790            let results = self
791                .vector_executor
792                .search(&vq.embedding, k, effective_allowed);
793            FilteredCandidates::from_vector(results)
794        });
795
796        // BM25 search (with the SAME effective AllowedSet)
797        let bm25_candidates = query.bm25_query.as_ref().map(|bq| {
798            let results = self.bm25_executor.search(&bq.text, k, effective_allowed);
799            FilteredCandidates::from_bm25(results)
800        });
801
802        // ---- Fuse all lanes with a single N-ary call --------------------
803        let (vector_weight, bm25_weight) = self.fusion_engine.method_weights();
804        let mut lanes: Vec<WeightedLane> = Vec::with_capacity(3);
805        if let Some(vc) = vector_candidates {
806            lanes.push(WeightedLane {
807                candidates: vc,
808                weight: vector_weight,
809            });
810        }
811        if let Some(bc) = bm25_candidates {
812            lanes.push(WeightedLane {
813                candidates: bc,
814                weight: bm25_weight,
815            });
816        }
817        if let Some(gc) = grep_rank {
818            lanes.push(WeightedLane {
819                candidates: gc,
820                weight: grep_weight,
821            });
822        }
823
824        self.fusion_engine.fuse_multi(lanes)
825    }
826}
827
828// ============================================================================
829// Tests
830// ============================================================================
831
832#[cfg(test)]
833mod tests {
834    use super::*;
835
836    #[test]
837    fn test_rrf_fusion() {
838        let config = FusionConfig {
839            method: FusionMethod::Rrf {
840                k: 60.0,
841                vector_weight: 1.0,
842                bm25_weight: 1.0,
843            },
844            candidates_per_modality: 10,
845            final_k: 5,
846            min_score: None,
847        };
848
849        let engine = FusionEngine::new(config);
850
851        let vector = FilteredCandidates::from_vector(vec![
852            ScoredResult::new(1, 0.9),
853            ScoredResult::new(2, 0.8),
854            ScoredResult::new(3, 0.7),
855        ]);
856
857        let bm25 = FilteredCandidates::from_bm25(vec![
858            ScoredResult::new(2, 5.0), // doc 2 is in both
859            ScoredResult::new(4, 4.0),
860            ScoredResult::new(1, 3.0), // doc 1 is in both
861        ]);
862
863        let result = engine.fuse(Some(vector), Some(bm25));
864
865        // Doc 2 should score highest (rank 2 in vector, rank 1 in BM25)
866        // Doc 1 should also score well (rank 1 in vector, rank 3 in BM25)
867        assert!(!result.results.is_empty());
868
869        // Docs 1 and 2 should be near the top
870        let top_ids: Vec<u64> = result.results.iter().map(|r| r.doc_id).collect();
871        assert!(top_ids.contains(&1));
872        assert!(top_ids.contains(&2));
873    }
874
875    #[test]
876    fn test_fuse_rrf_weighted_is_1_indexed_and_weighted() {
877        // Single list: the rank-1 document must score weight / (k + 1), proving
878        // the kernel is 1-indexed (top result is rank 1, not 0) and honors the
879        // per-list weight.
880        let k = 60.0_f32;
881        let docs = [ScoredResult::new(7, 0.9), ScoredResult::new(8, 0.5)];
882        let scores = fuse_rrf_weighted(
883            &[RankedList {
884                results: &docs,
885                weight: 2.0,
886            }],
887            k,
888        );
889
890        let s7 = scores[&DocId(7)];
891        let s8 = scores[&DocId(8)];
892        assert!(
893            (s7 - 2.0 / (k + 1.0)).abs() < 1e-6,
894            "rank-1 must use 1-indexed weighted score"
895        );
896        assert!(
897            (s8 - 2.0 / (k + 2.0)).abs() < 1e-6,
898            "rank-2 must use 1-indexed weighted score"
899        );
900        assert!(s7 > s8, "earlier rank must score higher");
901
902        // A document present in two weighted lists accumulates both contributions.
903        let list_a = [ScoredResult::new(1, 0.0)];
904        let list_b = [ScoredResult::new(1, 0.0)];
905        let merged = fuse_rrf_weighted(
906            &[
907                RankedList {
908                    results: &list_a,
909                    weight: 1.0,
910                },
911                RankedList {
912                    results: &list_b,
913                    weight: 3.0,
914                },
915            ],
916            k,
917        );
918        let expected = 1.0 / (k + 1.0) + 3.0 / (k + 1.0);
919        assert!(
920            (merged[&DocId(1)] - expected).abs() < 1e-6,
921            "weights must sum across lists"
922        );
923    }
924
925    #[test]
926    fn test_linear_fusion() {
927        let config = FusionConfig {
928            method: FusionMethod::Linear {
929                vector_weight: 0.6,
930                bm25_weight: 0.4,
931            },
932            candidates_per_modality: 10,
933            final_k: 5,
934            min_score: None,
935        };
936
937        let engine = FusionEngine::new(config);
938
939        let vector = FilteredCandidates::from_vector(vec![
940            ScoredResult::new(1, 1.0),
941            ScoredResult::new(2, 0.5),
942        ]);
943
944        let bm25 = FilteredCandidates::from_bm25(vec![
945            ScoredResult::new(2, 10.0), // Different scale
946            ScoredResult::new(3, 5.0),
947        ]);
948
949        let result = engine.fuse(Some(vector), Some(bm25));
950
951        // After normalization, doc 2 should benefit from both
952        assert!(!result.results.is_empty());
953    }
954
955    #[test]
956    fn test_empty_allowed_set() {
957        let config = FusionConfig::default();
958        let engine = FusionEngine::new(config);
959
960        // No candidates = empty result
961        let result = engine.fuse(None, None);
962        assert!(result.results.is_empty());
963    }
964
965    #[test]
966    fn test_score_normalization() {
967        let config = FusionConfig::default();
968        let engine = FusionEngine::new(config);
969
970        let results = vec![
971            ScoredResult::new(1, 100.0),
972            ScoredResult::new(2, 50.0),
973            ScoredResult::new(3, 0.0),
974        ];
975
976        let normalized = engine.normalize_scores(&results);
977
978        // Should be normalized to [0, 1]
979        assert_eq!(normalized.len(), 3);
980        let scores: HashMap<u64, f32> = normalized.into_iter().collect();
981        assert!((scores[&1] - 1.0).abs() < 0.001);
982        assert!((scores[&2] - 0.5).abs() < 0.001);
983        assert!((scores[&3] - 0.0).abs() < 0.001);
984    }
985
986    #[test]
987    fn test_no_post_filter_invariant() {
988        // This test verifies the core invariant:
989        // result-set ⊆ allowed-set
990        //
991        // If this invariant is violated, it indicates a security issue.
992
993        let allowed: std::collections::HashSet<u64> = [1, 2, 3, 5, 8].into_iter().collect();
994        let allowed_set = AllowedSet::from_iter(allowed.iter().copied());
995
996        // Simulate filtered candidates (these should already respect AllowedSet)
997        let vector = FilteredCandidates::from_vector(vec![
998            ScoredResult::new(1, 0.9), // in allowed set
999            ScoredResult::new(2, 0.8), // in allowed set
1000            ScoredResult::new(5, 0.7), // in allowed set
1001        ]);
1002
1003        let bm25 = FilteredCandidates::from_bm25(vec![
1004            ScoredResult::new(2, 5.0), // in allowed set
1005            ScoredResult::new(3, 4.0), // in allowed set
1006            ScoredResult::new(8, 3.0), // in allowed set
1007        ]);
1008
1009        let config = FusionConfig::default();
1010        let engine = FusionEngine::new(config);
1011        let result = engine.fuse(Some(vector), Some(bm25));
1012
1013        // INVARIANT: Every result doc_id must be in the allowed set
1014        for doc in &result.results {
1015            assert!(
1016                allowed_set.contains(doc.doc_id),
1017                "INVARIANT VIOLATION: doc_id {} not in allowed set",
1018                doc.doc_id
1019            );
1020        }
1021    }
1022
1023    // ---- Three-lane fusion: grep (Task 5) wired into hybrid (Task 7) -------
1024
1025    use crate::grep_executor::GrepMode;
1026    use crate::namespace::Namespace;
1027    use crate::trigram_index::TrigramIndex;
1028
1029    struct MockVector(Vec<ScoredResult>);
1030    impl VectorExecutor for MockVector {
1031        fn search(&self, _q: &[f32], k: usize, allowed: &AllowedSet) -> Vec<ScoredResult> {
1032            self.0
1033                .iter()
1034                .filter(|r| allowed.contains(r.doc_id))
1035                .take(k)
1036                .cloned()
1037                .collect()
1038        }
1039    }
1040
1041    struct MockBm25(Vec<ScoredResult>);
1042    impl Bm25Executor for MockBm25 {
1043        fn search(&self, _q: &str, k: usize, allowed: &AllowedSet) -> Vec<ScoredResult> {
1044            self.0
1045                .iter()
1046                .filter(|r| allowed.contains(r.doc_id))
1047                .take(k)
1048                .cloned()
1049                .collect()
1050        }
1051    }
1052
1053    /// Grep lane backed by the real trigram index + grep executor — this proves
1054    /// the wiring drives the actual Task 5 machinery, not a stub.
1055    struct RealGrep {
1056        index: TrigramIndex,
1057    }
1058    impl GrepLaneExecutor for RealGrep {
1059        fn grep(
1060            &self,
1061            pattern: &str,
1062            k: usize,
1063            allowed: &AllowedSet,
1064            mode: GrepMode,
1065        ) -> Vec<ScoredResult> {
1066            let exec = crate::grep_executor::GrepExecutor::new(&self.index);
1067            match exec.search(pattern, allowed, k, mode) {
1068                Ok(results) => results
1069                    .hits
1070                    .into_iter()
1071                    .map(|h| ScoredResult::new(h.doc_id, h.score))
1072                    .collect(),
1073                Err(_) => Vec::new(),
1074            }
1075        }
1076    }
1077
1078    fn test_query() -> UnifiedHybridQuery {
1079        UnifiedHybridQuery::new(NamespaceScope::single(Namespace::new("test").unwrap()))
1080    }
1081
1082    fn grep_index() -> TrigramIndex {
1083        let mut idx = TrigramIndex::new();
1084        idx.insert(1, "fn alpha() { compute_idf() }");
1085        idx.insert(2, "fn beta() { unrelated helper }");
1086        idx.insert(3, "fn gamma() { compute_idf() twice compute_idf() }");
1087        idx.insert(4, "struct Config { compute_idf: bool }");
1088        idx
1089    }
1090
1091    #[test]
1092    fn test_three_lane_rank_fusion_respects_allowed_set() {
1093        // Vector + BM25 favor docs {1,2,4}; grep(Rank) for "compute_idf" finds
1094        // {1,3,4}. Fusing all three must (a) stay within the allowed set and
1095        // (b) surface doc 3, which only the grep lane contributes.
1096        let vector = MockVector(vec![
1097            ScoredResult::new(2, 0.9),
1098            ScoredResult::new(1, 0.8),
1099            ScoredResult::new(4, 0.2),
1100        ]);
1101        let bm25 = MockBm25(vec![ScoredResult::new(2, 5.0), ScoredResult::new(1, 3.0)]);
1102        let grep = RealGrep {
1103            index: grep_index(),
1104        };
1105
1106        let allowed = AllowedSet::from_iter([1, 2, 3, 4]);
1107        let executor =
1108            UnifiedHybridExecutor::new(Arc::new(vector), Arc::new(bm25), FusionConfig::default())
1109                .with_grep_executor(Arc::new(grep));
1110
1111        let query = test_query()
1112            .with_vector(vec![0.0; 4])
1113            .with_bm25("anything")
1114            .with_grep("compute_idf", GrepMode::Rank);
1115
1116        let result = executor.execute(&query, &AuthScope::for_namespace("test"), &allowed);
1117
1118        assert!(!result.results.is_empty());
1119        for r in &result.results {
1120            assert!(
1121                allowed.contains(r.doc_id),
1122                "result {} escaped allowed set",
1123                r.doc_id
1124            );
1125        }
1126        let ids: Vec<u64> = result.results.iter().map(|r| r.doc_id).collect();
1127        assert!(
1128            ids.contains(&3),
1129            "grep-only doc 3 should appear via the third lane, got {ids:?}"
1130        );
1131    }
1132
1133    #[test]
1134    fn test_grep_gate_narrows_before_other_lanes() {
1135        // grep(Gate) for "compute_idf" matches {1,3,4}. Even though the vector
1136        // lane ranks doc 2 first, the gate must exclude it entirely (cascade).
1137        let vector = MockVector(vec![
1138            ScoredResult::new(2, 0.9),
1139            ScoredResult::new(1, 0.8),
1140            ScoredResult::new(4, 0.7),
1141            ScoredResult::new(3, 0.6),
1142        ]);
1143        let bm25 = MockBm25(vec![ScoredResult::new(2, 5.0)]);
1144        let grep = RealGrep {
1145            index: grep_index(),
1146        };
1147
1148        let allowed = AllowedSet::from_iter([1, 2, 3, 4]);
1149        let executor =
1150            UnifiedHybridExecutor::new(Arc::new(vector), Arc::new(bm25), FusionConfig::default())
1151                .with_grep_executor(Arc::new(grep));
1152
1153        let query = test_query()
1154            .with_vector(vec![0.0; 4])
1155            .with_bm25("anything")
1156            .with_grep("compute_idf", GrepMode::Gate);
1157
1158        let result = executor.execute(&query, &AuthScope::for_namespace("test"), &allowed);
1159
1160        assert!(!result.results.is_empty());
1161        let gate: std::collections::HashSet<u64> = [1, 3, 4].into_iter().collect();
1162        for r in &result.results {
1163            assert!(
1164                gate.contains(&r.doc_id),
1165                "doc {} not in grep gate {{1,3,4}}",
1166                r.doc_id
1167            );
1168        }
1169        assert!(
1170            !result.results.iter().any(|r| r.doc_id == 2),
1171            "doc 2 (no compute_idf) must be gated out"
1172        );
1173    }
1174
1175    #[test]
1176    fn test_grep_query_ignored_without_grep_executor() {
1177        // A grep_query with no configured grep executor degrades cleanly to the
1178        // two-lane vector+BM25 path (no panic, grep simply absent).
1179        let vector = MockVector(vec![ScoredResult::new(1, 0.9)]);
1180        let bm25 = MockBm25(vec![ScoredResult::new(2, 5.0)]);
1181        let allowed = AllowedSet::from_iter([1, 2, 3, 4]);
1182        let executor =
1183            UnifiedHybridExecutor::new(Arc::new(vector), Arc::new(bm25), FusionConfig::default());
1184
1185        let query = test_query()
1186            .with_vector(vec![0.0; 4])
1187            .with_bm25("anything")
1188            .with_grep("compute_idf", GrepMode::Gate);
1189
1190        let result = executor.execute(&query, &AuthScope::for_namespace("test"), &allowed);
1191        let ids: Vec<u64> = result.results.iter().map(|r| r.doc_id).collect();
1192        assert!(
1193            ids.contains(&1) && ids.contains(&2),
1194            "without a grep executor both lanes survive, got {ids:?}"
1195        );
1196    }
1197}
1198
1199// ============================================================================
1200// Invariant Verification
1201// ============================================================================
1202
1203/// Verify that a fusion result respects the no-post-filtering invariant
1204///
1205/// This function should be used in tests and optionally in debug builds
1206/// to verify that the security invariant holds.
1207///
1208/// # Invariant
1209///
1210/// `∀ doc ∈ result: doc.id ∈ allowed_set`
1211///
1212/// This is the "monotone property" from the architecture document.
1213pub fn verify_no_post_filter_invariant(
1214    result: &FusionResult,
1215    allowed_set: &AllowedSet,
1216) -> InvariantVerification {
1217    let mut violations = Vec::new();
1218
1219    for doc in &result.results {
1220        if !allowed_set.contains(doc.doc_id) {
1221            violations.push(doc.doc_id);
1222        }
1223    }
1224
1225    if violations.is_empty() {
1226        InvariantVerification::Valid
1227    } else {
1228        InvariantVerification::Violated {
1229            doc_ids: violations,
1230        }
1231    }
1232}
1233
/// Outcome of checking a fusion result against its allowed set.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InvariantVerification {
    /// The invariant holds: no offending documents were found.
    Valid,
    /// The invariant is violated.
    Violated {
        /// IDs of the documents that should not be in the results.
        doc_ids: Vec<u64>,
    },
}

impl InvariantVerification {
    /// Returns `true` when the invariant holds, `false` when it was violated.
    pub fn is_valid(&self) -> bool {
        match self {
            Self::Valid => true,
            Self::Violated { .. } => false,
        }
    }

    /// Test helper: does nothing when the invariant holds.
    ///
    /// # Panics
    ///
    /// Panics when the invariant is violated; the message reports how many
    /// documents are outside the allowed set and lists their IDs.
    pub fn assert_valid(&self) {
        if let Self::Violated { doc_ids } = self {
            panic!(
                "NO-POST-FILTER INVARIANT VIOLATED: {} docs not in allowed set: {:?}",
                doc_ids.len(),
                doc_ids
            );
        }
    }
}