Skip to main content

sochdb_vector/
hybrid.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Hybrid Search with RRF Fusion (Task 4)
19//!
20//! This module combines vector similarity search (ANN) with lexical search (BM25)
21//! using Reciprocal Rank Fusion (RRF) for score combination.
22//!
23//! ## RRF Algorithm
24//!
25//! ```text
26//! RRF_score(d) = Σ weight_i / (k + rank_i(d))
27//! ```
28//!
29//! Where:
30//! - `k` is typically 60 (robust default)
31//! - `rank_i(d)` is the rank of document d in result list i (1-indexed)
32//! - `weight_i` is the weight for result list i
33//!
34//! ## Pipeline
35//!
36//! ```text
37//!                    Query
38//!                      │
39//!           ┌──────────┴──────────┐
40//!           │                     │
41//!           ▼                     ▼
42//!    ┌─────────────┐       ┌─────────────┐
43//!    │   Vector    │       │   Lexical   │
44//!    │   Search    │       │   Search    │
45//!    │   (HNSW)    │       │   (BM25)    │
46//!    └──────┬──────┘       └──────┬──────┘
47//!           │                     │
48//!           │  [(id, score), ...]│  [(id, score), ...]
49//!           │                     │
50//!           └──────────┬──────────┘
51//!                      │
52//!                      ▼
53//!               ┌─────────────┐
54//!               │  RRF Fusion │
55//!               └──────┬──────┘
56//!                      │
57//!                      ▼
58//!               ┌─────────────┐
59//!               │   Filter    │
60//!               │  (optional) │
61//!               └──────┬──────┘
62//!                      │
63//!                      ▼
64//!               ┌─────────────┐
65//!               │   Top-K     │
66//!               └─────────────┘
67//! ```
68
69use std::collections::HashMap;
70
71// ============================================================================
72// Types
73// ============================================================================
74
/// Document identifier.
///
/// Search backends return hits keyed by this id, and fusion uses it to match
/// up the same document across different result lists.
pub type DocId = u64;
77
78/// Search result with score
79#[derive(Debug, Clone)]
80pub struct SearchResult {
81    /// Document ID
82    pub doc_id: DocId,
83
84    /// Combined score (from fusion)
85    pub score: f32,
86
87    /// Component scores for debugging
88    pub component_scores: Option<ComponentScores>,
89}
90
/// Where a document stood in each individual result list before fusion.
///
/// A field is `None` when the document did not appear in the corresponding
/// list.
#[derive(Debug, Clone)]
pub struct ComponentScores {
    /// Raw score the vector backend reported for the document.
    pub vector_score: Option<f32>,

    /// Position of the document in the vector result list (1-indexed).
    pub vector_rank: Option<usize>,

    /// Raw score the lexical (BM25) backend reported for the document.
    pub lexical_score: Option<f32>,

    /// Position of the document in the lexical result list (1-indexed).
    pub lexical_rank: Option<usize>,
}
106
107// ============================================================================
108// RRF Configuration
109// ============================================================================
110
/// Tuning parameters for Reciprocal Rank Fusion (RRF).
///
/// A document at 1-indexed rank `r` of a result list contributes
/// `weight / (k + r)` to its fused score.
#[derive(Debug, Clone, Copy)]
pub struct RRFConfig {
    /// Smoothing constant `k` (typically 60).
    /// Larger values flatten the rank curve, so lower-ranked results count
    /// for relatively more.
    pub k: f32,

    /// Multiplier applied to contributions from the vector result list.
    pub vector_weight: f32,

    /// Multiplier applied to contributions from the lexical result list.
    pub lexical_weight: f32,
}

impl RRFConfig {
    /// Smoothing constant shared by every preset below.
    const DEFAULT_K: f32 = 60.0;

    /// Build a configuration with the given weights and `k = 60`.
    pub fn with_weights(vector_weight: f32, lexical_weight: f32) -> Self {
        Self {
            k: Self::DEFAULT_K,
            vector_weight,
            lexical_weight,
        }
    }

    /// Preset favouring vector (semantic) hits: weights 0.7 / 0.3.
    pub fn semantic_focused() -> Self {
        Self::with_weights(0.7, 0.3)
    }

    /// Preset favouring lexical (keyword) hits: weights 0.3 / 0.7.
    pub fn keyword_focused() -> Self {
        Self::with_weights(0.3, 0.7)
    }

    /// Preset weighting both lists equally: weights 1.0 / 1.0.
    pub fn balanced() -> Self {
        Self::with_weights(1.0, 1.0)
    }
}

impl Default for RRFConfig {
    /// Same as [`RRFConfig::balanced`]: `k = 60` and both weights `1.0`.
    fn default() -> Self {
        Self::balanced()
    }
}
168
169// ============================================================================
170// RRF Fusion
171// ============================================================================
172
173/// Reciprocal Rank Fusion combiner
174pub struct RRFFusion {
175    config: RRFConfig,
176}
177
178impl RRFFusion {
179    /// Create a new RRF fusion combiner
180    pub fn new(config: RRFConfig) -> Self {
181        Self { config }
182    }
183
184    /// Fuse vector and lexical search results
185    ///
186    /// # Arguments
187    /// * `vector_results` - Results from vector search, sorted by score descending
188    /// * `lexical_results` - Results from lexical search, sorted by score descending
189    /// * `limit` - Maximum number of results to return
190    /// * `keep_details` - Whether to include component scores
191    ///
192    /// # Returns
193    /// Fused results sorted by combined score descending
194    pub fn fuse(
195        &self,
196        vector_results: &[(DocId, f32)],
197        lexical_results: &[(DocId, f32)],
198        limit: usize,
199        keep_details: bool,
200    ) -> Vec<SearchResult> {
201        let k = self.config.k;
202
203        // Build rank maps (1-indexed ranks)
204        let mut doc_scores: HashMap<DocId, FusionState> = HashMap::new();
205
206        // Add vector results
207        for (rank, &(doc_id, score)) in vector_results.iter().enumerate() {
208            let rrf_score = self.config.vector_weight / (k + (rank + 1) as f32);
209
210            let state = doc_scores.entry(doc_id).or_default();
211            state.rrf_score += rrf_score;
212            state.vector_score = Some(score);
213            state.vector_rank = Some(rank + 1);
214        }
215
216        // Add lexical results
217        for (rank, &(doc_id, score)) in lexical_results.iter().enumerate() {
218            let rrf_score = self.config.lexical_weight / (k + (rank + 1) as f32);
219
220            let state = doc_scores.entry(doc_id).or_default();
221            state.rrf_score += rrf_score;
222            state.lexical_score = Some(score);
223            state.lexical_rank = Some(rank + 1);
224        }
225
226        // Convert to results and sort
227        let mut results: Vec<SearchResult> = doc_scores
228            .into_iter()
229            .map(|(doc_id, state)| SearchResult {
230                doc_id,
231                score: state.rrf_score,
232                component_scores: if keep_details {
233                    Some(ComponentScores {
234                        vector_score: state.vector_score,
235                        vector_rank: state.vector_rank,
236                        lexical_score: state.lexical_score,
237                        lexical_rank: state.lexical_rank,
238                    })
239                } else {
240                    None
241                },
242            })
243            .collect();
244
245        // Sort by RRF score descending
246        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
247
248        // Limit
249        results.truncate(limit);
250
251        results
252    }
253
254    /// Fuse multiple result lists with custom weights
255    ///
256    /// CROSS-CRATE INVARIANT: the RRF kernel here — `weight / (k + rank)` with
257    /// `rank` **1-indexed** — must stay numerically identical to
258    /// `sochdb_query::fuse_rrf_weighted`. The two crates are siblings (neither
259    /// depends on the other), so the formula is pinned independently in each by
260    /// a golden test (`test_fuse_multi_rrf_formula_golden` here and
261    /// `test_fuse_rrf_weighted_is_1_indexed_and_weighted` there); divergence on
262    /// either side fails that crate's test.
263    pub fn fuse_multi(
264        &self,
265        result_lists: &[(&[(DocId, f32)], f32)], // (results, weight)
266        limit: usize,
267    ) -> Vec<SearchResult> {
268        let k = self.config.k;
269        let mut doc_scores: HashMap<DocId, f32> = HashMap::new();
270
271        for (results, weight) in result_lists {
272            for (rank, &(doc_id, _score)) in results.iter().enumerate() {
273                let rrf_score = *weight / (k + (rank + 1) as f32);
274                *doc_scores.entry(doc_id).or_default() += rrf_score;
275            }
276        }
277
278        let mut results: Vec<SearchResult> = doc_scores
279            .into_iter()
280            .map(|(doc_id, score)| SearchResult {
281                doc_id,
282                score,
283                component_scores: None,
284            })
285            .collect();
286
287        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
288        results.truncate(limit);
289
290        results
291    }
292}
293
294/// Internal state for fusion
295#[derive(Default)]
296struct FusionState {
297    rrf_score: f32,
298    vector_score: Option<f32>,
299    vector_rank: Option<usize>,
300    lexical_score: Option<f32>,
301    lexical_rank: Option<usize>,
302}
303
304impl Default for RRFFusion {
305    fn default() -> Self {
306        Self::new(RRFConfig::default())
307    }
308}
309
310// ============================================================================
311// Hybrid Search Engine
312// ============================================================================
313
314/// A combined vector + lexical search engine
315pub struct HybridSearchEngine<V, L> {
316    /// Vector search backend
317    vector_search: V,
318
319    /// Lexical search backend
320    lexical_search: L,
321
322    /// RRF fusion config
323    fusion_config: RRFConfig,
324
325    /// Over-fetch factor for better fusion results
326    overfetch_factor: f32,
327}
328
329/// Trait for vector search backends
330pub trait VectorSearchBackend {
331    /// Search for similar vectors
332    fn search(&self, query: &[f32], k: usize) -> Vec<(DocId, f32)>;
333}
334
335/// Trait for lexical search backends
336pub trait LexicalSearchBackend {
337    /// Search by text query
338    fn search(&self, query: &str, k: usize) -> Vec<(DocId, f32)>;
339}
340
341impl<V, L> HybridSearchEngine<V, L>
342where
343    V: VectorSearchBackend,
344    L: LexicalSearchBackend,
345{
346    /// Create a new hybrid search engine
347    pub fn new(vector_search: V, lexical_search: L) -> Self {
348        Self {
349            vector_search,
350            lexical_search,
351            fusion_config: RRFConfig::default(),
352            overfetch_factor: 2.0,
353        }
354    }
355
356    /// Set fusion configuration
357    pub fn with_fusion_config(mut self, config: RRFConfig) -> Self {
358        self.fusion_config = config;
359        self
360    }
361
362    /// Set over-fetch factor
363    pub fn with_overfetch(mut self, factor: f32) -> Self {
364        self.overfetch_factor = factor.max(1.0);
365        self
366    }
367
368    /// Perform hybrid search
369    pub fn search(
370        &self,
371        vector_query: Option<&[f32]>,
372        text_query: Option<&str>,
373        limit: usize,
374    ) -> Vec<SearchResult> {
375        let fetch_k = (limit as f32 * self.overfetch_factor) as usize;
376
377        // Get vector results
378        let vector_results = match vector_query {
379            Some(q) => self.vector_search.search(q, fetch_k),
380            None => Vec::new(),
381        };
382
383        // Get lexical results
384        let lexical_results = match text_query {
385            Some(q) => self.lexical_search.search(q, fetch_k),
386            None => Vec::new(),
387        };
388
389        // If only one type of search, return directly
390        if vector_results.is_empty() {
391            return lexical_results
392                .into_iter()
393                .take(limit)
394                .map(|(doc_id, score)| SearchResult {
395                    doc_id,
396                    score,
397                    component_scores: None,
398                })
399                .collect();
400        }
401
402        if lexical_results.is_empty() {
403            return vector_results
404                .into_iter()
405                .take(limit)
406                .map(|(doc_id, score)| SearchResult {
407                    doc_id,
408                    score,
409                    component_scores: None,
410                })
411                .collect();
412        }
413
414        // Fuse results
415        let fusion = RRFFusion::new(self.fusion_config);
416        fusion.fuse(&vector_results, &lexical_results, limit, false)
417    }
418
419    /// Perform hybrid search with detailed scores
420    pub fn search_detailed(
421        &self,
422        vector_query: Option<&[f32]>,
423        text_query: Option<&str>,
424        limit: usize,
425    ) -> Vec<SearchResult> {
426        let fetch_k = (limit as f32 * self.overfetch_factor) as usize;
427
428        let vector_results = vector_query
429            .map(|q| self.vector_search.search(q, fetch_k))
430            .unwrap_or_default();
431
432        let lexical_results = text_query
433            .map(|q| self.lexical_search.search(q, fetch_k))
434            .unwrap_or_default();
435
436        let fusion = RRFFusion::new(self.fusion_config);
437        fusion.fuse(&vector_results, &lexical_results, limit, true)
438    }
439}
440
441// ============================================================================
442// Filter Integration
443// ============================================================================
444
445/// Post-filter results by metadata predicate
446pub fn filter_results<F>(
447    results: Vec<SearchResult>,
448    predicate: F,
449    limit: usize,
450) -> Vec<SearchResult>
451where
452    F: Fn(DocId) -> bool,
453{
454    results
455        .into_iter()
456        .filter(|r| predicate(r.doc_id))
457        .take(limit)
458        .collect()
459}
460
461// ============================================================================
462// Tests
463// ============================================================================
464
#[cfg(test)]
mod tests {
    use super::*;

    /// Fusing two overlapping lists ranks the best shared document first.
    #[test]
    fn test_rrf_fusion_basic() {
        let fusion = RRFFusion::default();

        let vector_results = vec![(1, 0.95), (2, 0.90), (3, 0.85)];

        let lexical_results = vec![
            (2, 5.0), // Shared with vector
            (4, 4.5),
            (3, 4.0), // Shared with vector
        ];

        let results = fusion.fuse(&vector_results, &lexical_results, 10, false);

        // Four distinct documents across both lists.
        assert_eq!(results.len(), 4);

        // Doc 2 appears in both lists (vector rank 2, lexical rank 1):
        // 1/62 + 1/61 is strictly the largest score, so it must come first.
        assert_eq!(results[0].doc_id, 2);

        // Check that scores are computed
        for r in &results {
            assert!(r.score > 0.0);
        }

        // Output is ordered best-first.
        for pair in results.windows(2) {
            assert!(pair[0].score >= pair[1].score);
        }
    }

    /// `keep_details = true` reports each list's rank and raw score.
    #[test]
    fn test_rrf_fusion_with_details() {
        let fusion = RRFFusion::default();

        let vector_results = vec![(1, 0.9), (2, 0.8)];
        let lexical_results = vec![(2, 5.0), (3, 4.0)];

        let results = fusion.fuse(&vector_results, &lexical_results, 10, true);

        // Find doc 2 (appears in both)
        let doc2 = results.iter().find(|r| r.doc_id == 2).unwrap();
        let scores = doc2.component_scores.as_ref().unwrap();

        assert_eq!(scores.vector_rank, Some(2)); // Rank 2 in vector results
        assert_eq!(scores.lexical_rank, Some(1)); // Rank 1 in lexical results
        assert_eq!(scores.vector_score, Some(0.8));
        assert_eq!(scores.lexical_score, Some(5.0));
    }

    /// A document present in both lists outranks one present in only one.
    #[test]
    fn test_rrf_ranking() {
        let fusion = RRFFusion::default();

        // Doc 1: rank 1 in vector, not in lexical
        // Doc 2: rank 2 in vector, rank 1 in lexical
        // Doc 2 should win because it appears in both
        let vector_results = vec![(1, 0.95), (2, 0.90)];
        let lexical_results = vec![(2, 5.0)];

        let results = fusion.fuse(&vector_results, &lexical_results, 10, false);

        assert_eq!(results[0].doc_id, 2); // Doc 2 should be first
    }

    /// List weights decide between documents at the same rank.
    #[test]
    fn test_rrf_weights() {
        // Heavy lexical weight
        let config = RRFConfig::keyword_focused();
        let fusion = RRFFusion::new(config);

        // Doc 1: rank 1 in vector only
        // Doc 2: rank 1 in lexical only
        let vector_results = vec![(1, 0.95)];
        let lexical_results = vec![(2, 5.0)];

        let results = fusion.fuse(&vector_results, &lexical_results, 10, false);

        // Doc 2 should win with keyword-focused config
        assert_eq!(results[0].doc_id, 2);
    }

    /// `fuse_multi` merges any number of lists without losing documents.
    #[test]
    fn test_fuse_multi() {
        let fusion = RRFFusion::default();

        let list1: Vec<(DocId, f32)> = vec![(1, 0.9), (2, 0.8)];
        let list2: Vec<(DocId, f32)> = vec![(2, 0.9), (3, 0.8)];
        let list3: Vec<(DocId, f32)> = vec![(3, 0.9), (1, 0.8)];

        let results = fusion.fuse_multi(&[(&list1, 1.0), (&list2, 1.0), (&list3, 1.0)], 10);

        // All docs should appear, each exactly once. (All three tie on score,
        // so only membership is asserted, not order.)
        assert_eq!(results.len(), 3);
        let doc_ids: Vec<_> = results.iter().map(|r| r.doc_id).collect();
        assert!(doc_ids.contains(&1));
        assert!(doc_ids.contains(&2));
        assert!(doc_ids.contains(&3));
    }

    #[test]
    fn test_fuse_multi_rrf_formula_golden() {
        // CROSS-CRATE INVARIANT: pin the exact RRF kernel — weight / (k + rank),
        // rank 1-indexed — to the SAME golden numbers asserted by sochdb-query's
        // `fuse_rrf_weighted` test. The crates are siblings (no dependency
        // edge), so each pins the shared formula independently; if either kernel
        // drifts, its own golden test fails.
        let k = 60.0_f32;
        let fusion = RRFFusion::new(RRFConfig {
            k,
            vector_weight: 1.0,
            lexical_weight: 1.0,
        });

        // Single weighted list: rank-1 => weight/(k+1), rank-2 => weight/(k+2).
        let docs: Vec<(DocId, f32)> = vec![(7, 0.9), (8, 0.5)];
        let single = fusion.fuse_multi(&[(&docs, 2.0)], 10);
        let s7 = single.iter().find(|r| r.doc_id == 7).unwrap().score;
        let s8 = single.iter().find(|r| r.doc_id == 8).unwrap().score;
        assert!(
            (s7 - 2.0 / (k + 1.0)).abs() < 1e-6,
            "rank-1 must be 1-indexed weighted"
        );
        assert!(
            (s8 - 2.0 / (k + 2.0)).abs() < 1e-6,
            "rank-2 must be 1-indexed weighted"
        );
        assert!(s7 > s8, "earlier rank must score higher");

        // A doc present in two weighted lists accumulates both contributions.
        let la: Vec<(DocId, f32)> = vec![(1, 0.0)];
        let lb: Vec<(DocId, f32)> = vec![(1, 0.0)];
        let merged = fusion.fuse_multi(&[(&la, 1.0), (&lb, 3.0)], 10);
        let s1 = merged.iter().find(|r| r.doc_id == 1).unwrap().score;
        let expected = 1.0 / (k + 1.0) + 3.0 / (k + 1.0);
        assert!(
            (s1 - expected).abs() < 1e-6,
            "weights must sum across lists"
        );
    }

    /// `filter_results` keeps matching hits in their original order.
    #[test]
    fn test_filter_results() {
        let results = vec![
            SearchResult {
                doc_id: 1,
                score: 0.9,
                component_scores: None,
            },
            SearchResult {
                doc_id: 2,
                score: 0.8,
                component_scores: None,
            },
            SearchResult {
                doc_id: 3,
                score: 0.7,
                component_scores: None,
            },
            SearchResult {
                doc_id: 4,
                score: 0.6,
                component_scores: None,
            },
        ];

        // Filter to only even doc IDs
        let filtered = filter_results(results, |id| id % 2 == 0, 10);

        assert_eq!(filtered.len(), 2);
        assert_eq!(filtered[0].doc_id, 2);
        assert_eq!(filtered[1].doc_id, 4);
    }
}
633}