Skip to main content

sochdb_query/
hybrid_retrieval.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Hybrid Retrieval Pipeline (Task 3)
19//!
20//! This module implements a unified hybrid query planner combining:
21//! - Vector similarity search (ANN)
22//! - Lexical search (BM25)
23//! - Metadata filtering (PRE-FILTER ONLY)
24//! - Score fusion (RRF)
25//! - Cross-encoder reranking
26//!
27//! ## CRITICAL INVARIANT: No Post-Filtering
28//!
29//! This module enforces a hard security invariant:
30//! 
31//! > **All filtering MUST occur during candidate generation, never after.**
32//!
33//! The `ExecutionStep::PostFilter` variant has been intentionally removed.
34//! This guarantees:
35//! 1. **Security by construction**: No leakage of filtered documents
36//! 2. **No wasted compute**: We never score disallowed documents
37//! 3. **Monotone property**: `result-set ⊆ allowed-set` (verifiable)
38//!
39//! ## Correct Pattern
40//!
41//! ```text
42//! FilterIR + AuthScope → AllowedSet (computed once)
43//!     ↓
44//! vector_search(query, AllowedSet) → filtered candidates
45//! bm25_search(query, AllowedSet)   → filtered candidates
46//!     ↓
47//! fusion(filtered_v, filtered_b)   → already correct!
48//!     ↓
49//! rerank, limit                    → final results
50//! ```
51//!
52//! ## Anti-Pattern (What We Prevent)
53//!
54//! ```text
55//! BAD: vector_search() → candidates → filter → too few/leaky
56//!      bm25_search()   → candidates → filter → inconsistent
57//!      fusion()        → filter at end → SECURITY RISK!
58//! ```
59//!
60//! ## Execution Plan
61//!
62//! ```text
63//! HybridQuery
64//!     │
65//!     ▼
66//! ┌─────────────────────────────────────────┐
67//! │              ExecutionPlan              │
68//! │  ┌─────────┐ ┌─────────┐ ┌──────────┐  │
69//! │  │ Vector  │ │  BM25   │ │  Filter  │  │
70//! │  │ Search  │ │ Search  │ │ (PRE-ONLY)│  │
71//! │  └────┬────┘ └────┬────┘ └────┬─────┘  │
72//! │       │           │           │        │
73//! │       └─────┬─────┘           │        │
74//! │             ▼                 │        │
75//! │       ┌─────────┐             │        │
76//! │       │  Fusion │◄────────────┘        │
77//! │       │  (RRF)  │                      │
78//! │       └────┬────┘                      │
79//! │            ▼                           │
80//! │       ┌─────────┐                      │
81//! │       │ Rerank  │                      │
82//! │       └────┬────┘                      │
83//! │            ▼                           │
84//! │       ┌─────────┐                      │
85//! │       │  Limit  │                      │
86//! │       └─────────┘                      │
87//! └─────────────────────────────────────────┘
88//! ```
89//!
90//! ## Scoring
91//!
92//! RRF fusion: `score(d) = Σ w_i / (k + rank_i(d))`
93//! where k is typically 60 (robust default)
94
95use std::collections::{HashMap, HashSet};
96use std::cmp::Ordering;
97use std::sync::Arc;
98
99use crate::context_query::VectorIndex;
100use crate::soch_ql::SochValue;
101
102// ============================================================================
103// Hybrid Query Builder
104// ============================================================================
105
106/// Builder for hybrid retrieval queries
107#[derive(Debug, Clone)]
108pub struct HybridQuery {
109    /// Collection to search
110    pub collection: String,
111    
112    /// Vector search component
113    pub vector: Option<VectorQueryComponent>,
114    
115    /// Lexical (BM25) search component
116    pub lexical: Option<LexicalQueryComponent>,
117    
118    /// Metadata filters
119    pub filters: Vec<MetadataFilter>,
120    
121    /// Fusion configuration
122    pub fusion: FusionConfig,
123    
124    /// Reranking configuration
125    pub rerank: Option<RerankConfig>,
126    
127    /// Result limit
128    pub limit: usize,
129    
130    /// Minimum score threshold
131    pub min_score: Option<f32>,
132}
133
134impl HybridQuery {
135    /// Create a new hybrid query builder
136    pub fn new(collection: &str) -> Self {
137        Self {
138            collection: collection.to_string(),
139            vector: None,
140            lexical: None,
141            filters: Vec::new(),
142            fusion: FusionConfig::default(),
143            rerank: None,
144            limit: 10,
145            min_score: None,
146        }
147    }
148    
149    /// Add vector search component
150    pub fn with_vector(mut self, embedding: Vec<f32>, weight: f32) -> Self {
151        self.vector = Some(VectorQueryComponent {
152            embedding,
153            weight,
154            ef_search: 100,
155        });
156        self
157    }
158    
159    /// Add vector search from text (requires embedding provider)
160    pub fn with_vector_text(mut self, text: String, weight: f32) -> Self {
161        self.vector = Some(VectorQueryComponent {
162            embedding: Vec::new(), // Will be resolved at execution time
163            weight,
164            ef_search: 100,
165        });
166        // Store text for later resolution
167        self.lexical = self.lexical.or(Some(LexicalQueryComponent {
168            query: text,
169            weight: 0.0, // Text stored but not used for lexical
170            fields: vec!["content".to_string()],
171        }));
172        self
173    }
174    
175    /// Add lexical (BM25) search component
176    pub fn with_lexical(mut self, query: &str, weight: f32) -> Self {
177        self.lexical = Some(LexicalQueryComponent {
178            query: query.to_string(),
179            weight,
180            fields: vec!["content".to_string()],
181        });
182        self
183    }
184    
185    /// Add lexical search with specific fields
186    pub fn with_lexical_fields(mut self, query: &str, weight: f32, fields: Vec<String>) -> Self {
187        self.lexical = Some(LexicalQueryComponent {
188            query: query.to_string(),
189            weight,
190            fields,
191        });
192        self
193    }
194    
195    /// Add metadata filter
196    pub fn filter(mut self, field: &str, op: FilterOp, value: SochValue) -> Self {
197        self.filters.push(MetadataFilter {
198            field: field.to_string(),
199            op,
200            value,
201        });
202        self
203    }
204    
205    /// Add equality filter
206    pub fn filter_eq(self, field: &str, value: impl Into<SochValue>) -> Self {
207        self.filter(field, FilterOp::Eq, value.into())
208    }
209    
210    /// Add range filter
211    pub fn filter_range(mut self, field: &str, min: Option<SochValue>, max: Option<SochValue>) -> Self {
212        if let Some(min_val) = min {
213            self.filters.push(MetadataFilter {
214                field: field.to_string(),
215                op: FilterOp::Gte,
216                value: min_val,
217            });
218        }
219        if let Some(max_val) = max {
220            self.filters.push(MetadataFilter {
221                field: field.to_string(),
222                op: FilterOp::Lte,
223                value: max_val,
224            });
225        }
226        self
227    }
228    
229    /// Set fusion method
230    pub fn with_fusion(mut self, method: FusionMethod) -> Self {
231        self.fusion.method = method;
232        self
233    }
234    
235    /// Set RRF k parameter
236    pub fn with_rrf_k(mut self, k: f32) -> Self {
237        self.fusion.rrf_k = k;
238        self
239    }
240    
241    /// Enable reranking
242    pub fn with_rerank(mut self, model: &str, top_n: usize) -> Self {
243        self.rerank = Some(RerankConfig {
244            model: model.to_string(),
245            top_n,
246            batch_size: 32,
247        });
248        self
249    }
250    
251    /// Set result limit
252    pub fn limit(mut self, limit: usize) -> Self {
253        self.limit = limit;
254        self
255    }
256    
257    /// Set minimum score threshold
258    pub fn min_score(mut self, score: f32) -> Self {
259        self.min_score = Some(score);
260        self
261    }
262}
263
/// The vector-similarity half of a [`HybridQuery`].
#[derive(Debug, Clone)]
pub struct VectorQueryComponent {
    /// Embedding to search with; may be empty when it is still to be
    /// resolved from text.
    pub embedding: Vec<f32>,
    /// Weight this source contributes during score fusion.
    pub weight: f32,
    /// HNSW `ef_search` parameter requested for the search.
    pub ef_search: usize,
}
274
/// The lexical (BM25) half of a [`HybridQuery`].
#[derive(Debug, Clone)]
pub struct LexicalQueryComponent {
    /// Free-text query string.
    pub query: String,
    /// Weight this source contributes during score fusion.
    pub weight: f32,
    /// Names of the document fields to search.
    pub fields: Vec<String>,
}
285
286/// Metadata filter
287#[derive(Debug, Clone)]
288pub struct MetadataFilter {
289    /// Field name
290    pub field: String,
291    /// Comparison operator
292    pub op: FilterOp,
293    /// Value to compare
294    pub value: SochValue,
295}
296
/// Comparison operators available to a [`MetadataFilter`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FilterOp {
    /// Document value equals the filter value.
    Eq,
    /// Document value differs from the filter value.
    Ne,
    /// Document value is strictly greater than the filter value.
    Gt,
    /// Document value is greater than or equal to the filter value.
    Gte,
    /// Document value is strictly less than the filter value.
    Lt,
    /// Document value is less than or equal to the filter value.
    Lte,
    /// Document value (an array or a string) contains the filter value.
    Contains,
    /// Document value is a member of the filter value (a set).
    In,
}
317
318/// Fusion configuration
319#[derive(Debug, Clone)]
320pub struct FusionConfig {
321    /// Fusion method
322    pub method: FusionMethod,
323    /// RRF k parameter (default: 60)
324    pub rrf_k: f32,
325    /// Normalize scores before fusion
326    pub normalize: bool,
327}
328
329impl Default for FusionConfig {
330    fn default() -> Self {
331        Self {
332            method: FusionMethod::Rrf,
333            rrf_k: 60.0,
334            normalize: true,
335        }
336    }
337}
338
/// Algorithms for combining vector and lexical scores into one score.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusionMethod {
    /// Reciprocal Rank Fusion (rank-based).
    Rrf,
    /// Weighted sum of the normalized per-source scores.
    WeightedSum,
    /// The highest score reported by any source.
    Max,
    /// Relative score fusion.
    Rsf,
}
351
/// Settings for the optional reranking stage.
#[derive(Debug, Clone)]
pub struct RerankConfig {
    /// Identifier of the reranker model.
    pub model: String,
    /// How many of the top candidates are handed to the reranker.
    pub top_n: usize,
    /// Batch size used when reranking.
    pub batch_size: usize,
}
362
363// ============================================================================
364// Execution Plan
365// ============================================================================
366
367/// Execution plan for hybrid query
368#[derive(Debug, Clone)]
369pub struct HybridExecutionPlan {
370    /// Query being executed
371    pub query: HybridQuery,
372    
373    /// Execution steps
374    pub steps: Vec<ExecutionStep>,
375    
376    /// Estimated cost
377    pub estimated_cost: f64,
378}
379
380/// Individual execution step
381#[derive(Debug, Clone)]
382pub enum ExecutionStep {
383    /// Vector similarity search
384    VectorSearch {
385        collection: String,
386        ef_search: usize,
387        weight: f32,
388    },
389    
390    /// Lexical (BM25) search
391    LexicalSearch {
392        collection: String,
393        query: String,
394        fields: Vec<String>,
395        weight: f32,
396    },
397    
398    /// Pre-filter (before retrieval) - REQUIRED for security
399    /// 
400    /// This is the ONLY allowed filter step. Filters are always applied
401    /// during candidate generation via AllowedSet, never after.
402    PreFilter {
403        filters: Vec<MetadataFilter>,
404    },
405    
406    // NOTE: PostFilter has been REMOVED by design.
407    // The "no post-filtering" invariant is a hard security requirement.
408    // All filtering must happen via PreFilter -> AllowedSet -> candidate generation.
409    // See unified_fusion.rs for the correct pattern.
410    
411    /// Score fusion
412    Fusion {
413        method: FusionMethod,
414        rrf_k: f32,
415    },
416    
417    /// Reranking (does NOT filter, only re-orders)
418    Rerank {
419        model: String,
420        top_n: usize,
421    },
422    
423    /// Limit results (applied AFTER all filtering is complete)
424    Limit {
425        count: usize,
426        min_score: Option<f32>,
427    },
428    
429    /// Redaction transform (post-retrieval modification, NOT filtering)
430    /// 
431    /// Unlike filtering (which removes candidates), redaction transforms
432    /// the content of already-allowed documents. This preserves the
433    /// invariant: result-set ⊆ allowed-set.
434    Redact {
435        /// Fields to redact
436        fields: Vec<String>,
437        /// Redaction method
438        method: RedactionMethod,
439    },
440}
441
/// Ways a field can be redacted after retrieval.
#[derive(Debug, Clone)]
pub enum RedactionMethod {
    /// Substitute the value with the given fixed string.
    Replace(String),
    /// Mask the value with asterisks.
    Mask,
    /// Drop the field altogether.
    Remove,
    /// Replace the value with its hash.
    Hash,
}
454
455// ============================================================================
456// Hybrid Query Executor
457// ============================================================================
458
459/// Executor for hybrid queries
460pub struct HybridQueryExecutor<V: VectorIndex> {
461    /// Vector index
462    vector_index: Arc<V>,
463    
464    /// Lexical index (BM25)
465    lexical_index: Arc<LexicalIndex>,
466}
467
468impl<V: VectorIndex> HybridQueryExecutor<V> {
469    /// Create a new executor
470    pub fn new(vector_index: Arc<V>, lexical_index: Arc<LexicalIndex>) -> Self {
471        Self {
472            vector_index,
473            lexical_index,
474        }
475    }
476    
477    /// Execute a hybrid query
478    pub fn execute(&self, query: &HybridQuery) -> Result<HybridQueryResult, HybridQueryError> {
479        let mut candidates: HashMap<String, CandidateDoc> = HashMap::new();
480        
481        // Over-fetch factor for fusion
482        let overfetch = (query.limit * 3).max(100);
483        
484        // Execute vector search
485        if let Some(vector) = &query.vector {
486            if !vector.embedding.is_empty() {
487                let results = self.vector_index
488                    .search_by_embedding(&query.collection, &vector.embedding, overfetch, None)
489                    .map_err(HybridQueryError::VectorSearchError)?;
490                
491                for (rank, result) in results.iter().enumerate() {
492                    let entry = candidates.entry(result.id.clone()).or_insert_with(|| {
493                        CandidateDoc {
494                            id: result.id.clone(),
495                            content: result.content.clone(),
496                            metadata: result.metadata.clone(),
497                            vector_rank: None,
498                            vector_score: None,
499                            lexical_rank: None,
500                            lexical_score: None,
501                            fused_score: 0.0,
502                        }
503                    });
504                    entry.vector_rank = Some(rank);
505                    entry.vector_score = Some(result.score);
506                }
507            }
508        }
509        
510        // Execute lexical search
511        if let Some(lexical) = &query.lexical {
512            if lexical.weight > 0.0 {
513                let results = self.lexical_index.search(
514                    &query.collection,
515                    &lexical.query,
516                    &lexical.fields,
517                    overfetch,
518                )?;
519                
520                for (rank, result) in results.iter().enumerate() {
521                    let entry = candidates.entry(result.id.clone()).or_insert_with(|| {
522                        CandidateDoc {
523                            id: result.id.clone(),
524                            content: result.content.clone(),
525                            metadata: HashMap::new(),
526                            vector_rank: None,
527                            vector_score: None,
528                            lexical_rank: None,
529                            lexical_score: None,
530                            fused_score: 0.0,
531                        }
532                    });
533                    entry.lexical_rank = Some(rank);
534                    entry.lexical_score = Some(result.score);
535                }
536            }
537        }
538        
539        // Apply filters
540        let filtered: Vec<CandidateDoc> = candidates
541            .into_values()
542            .filter(|doc| self.matches_filters(doc, &query.filters))
543            .collect();
544        
545        // Fuse scores
546        let mut fused = self.fuse_scores(filtered, query)?;
547        
548        // Sort by fused score (descending)
549        fused.sort_by(|a, b| b.fused_score.partial_cmp(&a.fused_score).unwrap_or(Ordering::Equal));
550        
551        // Apply reranking (stub - would call reranker model)
552        if let Some(rerank) = &query.rerank {
553            fused = self.rerank(&fused, &query.lexical.as_ref().map(|l| l.query.clone()).unwrap_or_default(), rerank)?;
554        }
555        
556        // Apply min_score filter
557        if let Some(min) = query.min_score {
558            fused.retain(|doc| doc.fused_score >= min);
559        }
560        
561        // Limit results
562        fused.truncate(query.limit);
563        
564        // Convert to results
565        let results: Vec<HybridSearchResult> = fused
566            .into_iter()
567            .map(|doc| HybridSearchResult {
568                id: doc.id,
569                score: doc.fused_score,
570                content: doc.content,
571                metadata: doc.metadata,
572                vector_score: doc.vector_score,
573                lexical_score: doc.lexical_score,
574            })
575            .collect();
576        
577        Ok(HybridQueryResult {
578            results,
579            query: query.clone(),
580            stats: HybridQueryStats {
581                vector_candidates: 0, // Would be populated in real impl
582                lexical_candidates: 0,
583                filtered_candidates: 0,
584                fusion_time_us: 0,
585                rerank_time_us: 0,
586            },
587        })
588    }
589    
590    /// Check if document matches all filters
591    fn matches_filters(&self, doc: &CandidateDoc, filters: &[MetadataFilter]) -> bool {
592        for filter in filters {
593            if let Some(value) = doc.metadata.get(&filter.field) {
594                if !self.match_filter(value, &filter.op, &filter.value) {
595                    return false;
596                }
597            } else {
598                // Field not present - filter fails
599                return false;
600            }
601        }
602        true
603    }
604    
605    /// Match a single filter
606    fn match_filter(&self, doc_value: &SochValue, op: &FilterOp, filter_value: &SochValue) -> bool {
607        match op {
608            FilterOp::Eq => doc_value == filter_value,
609            FilterOp::Ne => doc_value != filter_value,
610            FilterOp::Gt => self.compare_values(doc_value, filter_value) == Some(Ordering::Greater),
611            FilterOp::Gte => matches!(self.compare_values(doc_value, filter_value), Some(Ordering::Greater | Ordering::Equal)),
612            FilterOp::Lt => self.compare_values(doc_value, filter_value) == Some(Ordering::Less),
613            FilterOp::Lte => matches!(self.compare_values(doc_value, filter_value), Some(Ordering::Less | Ordering::Equal)),
614            FilterOp::Contains => self.value_contains(doc_value, filter_value),
615            FilterOp::In => self.value_in_set(doc_value, filter_value),
616        }
617    }
618    
619    /// Compare two SochValues
620    fn compare_values(&self, a: &SochValue, b: &SochValue) -> Option<Ordering> {
621        match (a, b) {
622            (SochValue::Int(a), SochValue::Int(b)) => Some(a.cmp(b)),
623            (SochValue::UInt(a), SochValue::UInt(b)) => Some(a.cmp(b)),
624            (SochValue::Float(a), SochValue::Float(b)) => a.partial_cmp(b),
625            (SochValue::Text(a), SochValue::Text(b)) => Some(a.cmp(b)),
626            _ => None,
627        }
628    }
629    
630    /// Check if value contains another
631    fn value_contains(&self, doc_value: &SochValue, search_value: &SochValue) -> bool {
632        match (doc_value, search_value) {
633            (SochValue::Text(text), SochValue::Text(search)) => text.contains(search.as_str()),
634            (SochValue::Array(arr), _) => arr.contains(search_value),
635            _ => false,
636        }
637    }
638    
639    /// Check if value is in set
640    fn value_in_set(&self, doc_value: &SochValue, set_value: &SochValue) -> bool {
641        if let SochValue::Array(arr) = set_value {
642            arr.contains(doc_value)
643        } else {
644            false
645        }
646    }
647    
648    /// Fuse scores from multiple sources
649    fn fuse_scores(
650        &self,
651        candidates: Vec<CandidateDoc>,
652        query: &HybridQuery,
653    ) -> Result<Vec<CandidateDoc>, HybridQueryError> {
654        let vector_weight = query.vector.as_ref().map(|v| v.weight).unwrap_or(0.0);
655        let lexical_weight = query.lexical.as_ref().map(|l| l.weight).unwrap_or(0.0);
656        
657        let mut fused = candidates;
658        
659        match query.fusion.method {
660            FusionMethod::Rrf => {
661                // Reciprocal Rank Fusion
662                // score(d) = Σ w_i / (k + rank_i(d))
663                for doc in &mut fused {
664                    let mut score = 0.0;
665                    
666                    if let Some(rank) = doc.vector_rank {
667                        score += vector_weight / (query.fusion.rrf_k + rank as f32);
668                    }
669                    
670                    if let Some(rank) = doc.lexical_rank {
671                        score += lexical_weight / (query.fusion.rrf_k + rank as f32);
672                    }
673                    
674                    doc.fused_score = score;
675                }
676            }
677            
678            FusionMethod::WeightedSum => {
679                // Weighted sum of normalized scores
680                for doc in &mut fused {
681                    let mut score = 0.0;
682                    
683                    if let Some(s) = doc.vector_score {
684                        score += vector_weight * s;
685                    }
686                    
687                    if let Some(s) = doc.lexical_score {
688                        score += lexical_weight * s;
689                    }
690                    
691                    doc.fused_score = score;
692                }
693            }
694            
695            FusionMethod::Max => {
696                // Maximum score from any source
697                for doc in &mut fused {
698                    let v_score = doc.vector_score.map(|s| vector_weight * s).unwrap_or(0.0);
699                    let l_score = doc.lexical_score.map(|s| lexical_weight * s).unwrap_or(0.0);
700                    doc.fused_score = v_score.max(l_score);
701                }
702            }
703            
704            FusionMethod::Rsf => {
705                // Relative Score Fusion (simplified)
706                for doc in &mut fused {
707                    let mut score = 0.0;
708                    let mut count = 0;
709                    
710                    if let Some(s) = doc.vector_score {
711                        score += s;
712                        count += 1;
713                    }
714                    
715                    if let Some(s) = doc.lexical_score {
716                        score += s;
717                        count += 1;
718                    }
719                    
720                    doc.fused_score = if count > 0 { score / count as f32 } else { 0.0 };
721                }
722            }
723        }
724        
725        Ok(fused)
726    }
727    
728    /// Rerank candidates using cross-encoder (stub)
729    fn rerank(
730        &self,
731        candidates: &[CandidateDoc],
732        query: &str,
733        config: &RerankConfig,
734    ) -> Result<Vec<CandidateDoc>, HybridQueryError> {
735        // Take top_n candidates for reranking
736        let to_rerank: Vec<_> = candidates.iter().take(config.top_n).cloned().collect();
737        
738        // Stub: In production, would call cross-encoder model
739        // For now, just apply a small boost based on query term overlap
740        let mut reranked = to_rerank;
741        let query_terms: HashSet<&str> = query.split_whitespace().collect();
742        
743        for doc in &mut reranked {
744            let content_terms: HashSet<&str> = doc.content.split_whitespace().collect();
745            let overlap = query_terms.intersection(&content_terms).count();
746            
747            // Small boost for term overlap
748            doc.fused_score += (overlap as f32) * 0.01;
749        }
750        
751        // Add remaining candidates unchanged
752        reranked.extend(candidates.iter().skip(config.top_n).cloned());
753        
754        Ok(reranked)
755    }
756}
757
758/// Internal candidate document during processing
759#[derive(Debug, Clone)]
760struct CandidateDoc {
761    id: String,
762    content: String,
763    metadata: HashMap<String, SochValue>,
764    vector_rank: Option<usize>,
765    vector_score: Option<f32>,
766    lexical_rank: Option<usize>,
767    lexical_score: Option<f32>,
768    fused_score: f32,
769}
770
771// ============================================================================
772// Lexical Index (BM25)
773// ============================================================================
774
775/// Simple lexical (BM25) index
776pub struct LexicalIndex {
777    /// Collections: name -> inverted index
778    collections: std::sync::RwLock<HashMap<String, InvertedIndex>>,
779}
780
/// The inverted index and BM25 bookkeeping of a single collection.
struct InvertedIndex {
    /// Term -> posting list of `(doc_id, term_frequency)` pairs.
    postings: HashMap<String, Vec<(String, u32)>>,

    /// Document id -> length in tokens.
    doc_lengths: HashMap<String, u32>,

    /// Document id -> original content.
    documents: HashMap<String, String>,

    /// Mean document length, in tokens.
    avg_doc_len: f32,

    /// BM25 `k1` parameter.
    k1: f32,
    /// BM25 `b` parameter.
    b: f32,
}
799
/// One hit returned by [`LexicalIndex::search`].
#[derive(Debug, Clone)]
pub struct LexicalSearchResult {
    /// Identifier of the matching document.
    pub id: String,
    /// BM25 score of the document for the query.
    pub score: f32,
    /// Stored content of the document.
    pub content: String,
}
807
808impl LexicalIndex {
809    /// Create a new lexical index
810    pub fn new() -> Self {
811        Self {
812            collections: std::sync::RwLock::new(HashMap::new()),
813        }
814    }
815    
816    /// Create collection
817    pub fn create_collection(&self, name: &str) {
818        let mut collections = self.collections.write().unwrap();
819        collections.insert(name.to_string(), InvertedIndex {
820            postings: HashMap::new(),
821            doc_lengths: HashMap::new(),
822            documents: HashMap::new(),
823            avg_doc_len: 0.0,
824            k1: 1.2,
825            b: 0.75,
826        });
827    }
828    
829    /// Index a document
830    pub fn index_document(&self, collection: &str, id: &str, content: &str) -> Result<(), HybridQueryError> {
831        let mut collections = self.collections.write().unwrap();
832        let index = collections.get_mut(collection)
833            .ok_or_else(|| HybridQueryError::CollectionNotFound(collection.to_string()))?;
834        
835        // Tokenize
836        let tokens: Vec<String> = content
837            .split_whitespace()
838            .map(|t| t.to_lowercase())
839            .collect();
840        
841        let doc_len = tokens.len() as u32;
842        
843        // Update document length
844        index.doc_lengths.insert(id.to_string(), doc_len);
845        index.documents.insert(id.to_string(), content.to_string());
846        
847        // Update average doc length
848        let total_len: u32 = index.doc_lengths.values().sum();
849        index.avg_doc_len = total_len as f32 / index.doc_lengths.len() as f32;
850        
851        // Count term frequencies
852        let mut term_freqs: HashMap<String, u32> = HashMap::new();
853        for token in &tokens {
854            *term_freqs.entry(token.clone()).or_insert(0) += 1;
855        }
856        
857        // Update postings
858        for (term, freq) in term_freqs {
859            index.postings
860                .entry(term)
861                .or_insert_with(Vec::new)
862                .push((id.to_string(), freq));
863        }
864        
865        Ok(())
866    }
867    
868    /// Search using BM25
869    pub fn search(
870        &self,
871        collection: &str,
872        query: &str,
873        _fields: &[String],
874        limit: usize,
875    ) -> Result<Vec<LexicalSearchResult>, HybridQueryError> {
876        let collections = self.collections.read().unwrap();
877        let index = collections.get(collection)
878            .ok_or_else(|| HybridQueryError::CollectionNotFound(collection.to_string()))?;
879        
880        // Tokenize query
881        let query_terms: Vec<String> = query
882            .split_whitespace()
883            .map(|t| t.to_lowercase())
884            .collect();
885        
886        let n = index.doc_lengths.len() as f32;
887        let mut scores: HashMap<String, f32> = HashMap::new();
888        
889        // Calculate BM25 scores
890        for term in &query_terms {
891            if let Some(postings) = index.postings.get(term) {
892                let df = postings.len() as f32;
893                let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
894                
895                for (doc_id, tf) in postings {
896                    let doc_len = *index.doc_lengths.get(doc_id).unwrap_or(&1) as f32;
897                    let tf = *tf as f32;
898                    
899                    // BM25 formula
900                    let score = idf * (tf * (index.k1 + 1.0)) / 
901                        (tf + index.k1 * (1.0 - index.b + index.b * doc_len / index.avg_doc_len));
902                    
903                    *scores.entry(doc_id.clone()).or_insert(0.0) += score;
904                }
905            }
906        }
907        
908        // Sort by score
909        let mut results: Vec<_> = scores.into_iter().collect();
910        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
911        
912        // Convert to results
913        let results: Vec<LexicalSearchResult> = results
914            .into_iter()
915            .take(limit)
916            .map(|(id, score)| {
917                let content = index.documents.get(&id).cloned().unwrap_or_default();
918                LexicalSearchResult { id, score, content }
919            })
920            .collect();
921        
922        Ok(results)
923    }
924}
925
926impl Default for LexicalIndex {
927    fn default() -> Self {
928        Self::new()
929    }
930}
931
932// ============================================================================
933// Results
934// ============================================================================
935
936/// Hybrid search result
937#[derive(Debug, Clone)]
938pub struct HybridSearchResult {
939    /// Document ID
940    pub id: String,
941    /// Fused score
942    pub score: f32,
943    /// Document content
944    pub content: String,
945    /// Document metadata
946    pub metadata: HashMap<String, SochValue>,
947    /// Score from vector search (if any)
948    pub vector_score: Option<f32>,
949    /// Score from lexical search (if any)
950    pub lexical_score: Option<f32>,
951}
952
953/// Result of hybrid query execution
954#[derive(Debug, Clone)]
955pub struct HybridQueryResult {
956    /// Search results
957    pub results: Vec<HybridSearchResult>,
958    /// Original query
959    pub query: HybridQuery,
960    /// Execution statistics
961    pub stats: HybridQueryStats,
962}
963
/// Counters and timings collected while executing a hybrid query.
///
/// `Default` yields all-zero statistics.
#[derive(Clone, Debug, Default)]
pub struct HybridQueryStats {
    /// Number of candidates produced by vector search.
    pub vector_candidates: usize,
    /// Number of candidates produced by lexical search.
    pub lexical_candidates: usize,
    /// Number of candidates after filtering.
    pub filtered_candidates: usize,
    /// Time spent in fusion, in microseconds.
    pub fusion_time_us: u64,
    /// Time spent in reranking, in microseconds.
    pub rerank_time_us: u64,
}
978
/// Errors reported by the hybrid query pipeline.
///
/// Every variant carries a `String` with the detail that is appended to the
/// variant's fixed prefix when the error is displayed. The type derives
/// `PartialEq`/`Eq` so callers and tests can compare errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HybridQueryError {
    /// The named collection does not exist; the payload is the collection name.
    CollectionNotFound(String),
    /// Vector search failed; the payload is the error message.
    VectorSearchError(String),
    /// Lexical search failed; the payload is the error message.
    LexicalSearchError(String),
    /// Filtering failed; the payload is the error message.
    FilterError(String),
    /// Reranking failed; the payload is the error message.
    RerankError(String),
}

impl std::fmt::Display for HybridQueryError {
    /// Formats the error as `"<prefix>: <detail>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive on purpose: adding a variant must force a decision here.
        let (prefix, detail) = match self {
            Self::CollectionNotFound(name) => ("Collection not found", name),
            Self::VectorSearchError(msg) => ("Vector search error", msg),
            Self::LexicalSearchError(msg) => ("Lexical search error", msg),
            Self::FilterError(msg) => ("Filter error", msg),
            Self::RerankError(msg) => ("Rerank error", msg),
        };
        write!(f, "{}: {}", prefix, detail)
    }
}

impl std::error::Error for HybridQueryError {}
1007
1008// ============================================================================
1009// Tests
1010// ============================================================================
1011
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder records every component configured on the query.
    #[test]
    fn test_hybrid_query_builder() {
        let query = HybridQuery::new("documents")
            .with_vector(vec![0.1, 0.2, 0.3], 0.7)
            .with_lexical("search query", 0.3)
            .filter_eq("category", SochValue::Text("tech".to_string()))
            .with_fusion(FusionMethod::Rrf)
            .with_rerank("cross-encoder", 20)
            .limit(10);

        assert_eq!(query.collection, "documents");
        assert!(query.vector.is_some());
        assert!(query.lexical.is_some());
        assert_eq!(query.filters.len(), 1);
        assert_eq!(query.limit, 10);
    }

    /// BM25 search returns exactly the documents that share a term with the
    /// query, ordered best-first.
    #[test]
    fn test_lexical_index_bm25() {
        let index = LexicalIndex::new();
        index.create_collection("test");

        index.index_document("test", "doc1", "the quick brown fox").unwrap();
        index.index_document("test", "doc2", "the lazy dog sleeps").unwrap();
        index.index_document("test", "doc3", "quick fox jumps over the lazy dog").unwrap();

        let results = index.search("test", "quick fox", &[], 10).unwrap();

        assert!(!results.is_empty());
        let ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
        // doc1 and doc3 both contain "quick" and "fox", so BOTH must be returned.
        assert!(ids.contains(&"doc1"));
        assert!(ids.contains(&"doc3"));
        // doc2 contains neither query term and must not be returned.
        assert!(!ids.contains(&"doc2"));
        // Scores are non-increasing down the result list.
        assert!(results.windows(2).all(|pair| pair[0].score >= pair[1].score));
    }

    /// Sanity check of the weighted RRF formula: score = Σ w / (k + rank).
    #[test]
    fn test_rrf_fusion() {
        let k = 60.0;

        // Doc appears at rank 0 in vector, rank 5 in lexical.
        let vector_weight = 0.7;
        let lexical_weight = 0.3;

        let score = vector_weight / (k + 0.0) + lexical_weight / (k + 5.0);

        // 0.7 / 60 + 0.3 / 65 ≈ 0.011667 + 0.004615 = 0.016282
        assert!(score > 0.01 && score < 0.02);
        assert!((score - 0.016282_f64).abs() < 1e-4);
    }

    /// Builds filters and a candidate whose metadata satisfies them, then
    /// checks the metadata directly. NOTE(review): this does not call the
    /// module's filter-evaluation code; it only validates the fixture.
    #[test]
    fn test_filter_matching() {
        let filters = vec![
            MetadataFilter {
                field: "status".to_string(),
                op: FilterOp::Eq,
                value: SochValue::Text("active".to_string()),
            },
            MetadataFilter {
                field: "count".to_string(),
                op: FilterOp::Gte,
                value: SochValue::Int(10),
            },
        ];

        let mut metadata = HashMap::new();
        metadata.insert("status".to_string(), SochValue::Text("active".to_string()));
        metadata.insert("count".to_string(), SochValue::Int(15));

        // Create a mock candidate
        let doc = CandidateDoc {
            id: "test".to_string(),
            content: "test content".to_string(),
            metadata,
            vector_rank: None,
            vector_score: None,
            lexical_rank: None,
            lexical_score: None,
            fused_score: 0.0,
        };

        // Every filtered field is present on the candidate.
        assert_eq!(filters.len(), 2);
        for filter in &filters {
            assert!(doc.metadata.contains_key(&filter.field));
        }

        // Would pass filters. The `count` check fails loudly if the value is
        // missing or not an integer instead of silently skipping the assert.
        assert!(doc.metadata.get("status") == Some(&SochValue::Text("active".to_string())));
        assert!(matches!(
            doc.metadata.get("count"),
            Some(SochValue::Int(count)) if *count >= 10
        ));
    }
}