Skip to main content

synaptic_retrieval/
ensemble.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use synaptic_core::SynapseError;
6
7use crate::{Document, Retriever};
8
/// Smoothing constant `k` of Reciprocal Rank Fusion, used as
/// `weight / (k + position)` for a 1-based result position.
///
/// 60 is the conventional value from the original RRF paper; larger values
/// flatten the gap between top-ranked and lower-ranked documents.
const RRF_K: f64 = 60_f64;
11
12/// A retriever that combines results from multiple retrievers using
13/// Reciprocal Rank Fusion (RRF) with configurable weights.
14pub struct EnsembleRetriever {
15    retrievers: Vec<(Arc<dyn Retriever>, f64)>,
16}
17
18impl EnsembleRetriever {
19    /// Create a new EnsembleRetriever with weighted retrievers.
20    ///
21    /// Each tuple is `(retriever, weight)`. The weight scales the RRF score
22    /// contribution of that retriever.
23    pub fn new(retrievers: Vec<(Arc<dyn Retriever>, f64)>) -> Self {
24        Self { retrievers }
25    }
26}
27
28#[async_trait]
29impl Retriever for EnsembleRetriever {
30    async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<Document>, SynapseError> {
31        // Map from doc.id -> (rrf_score, Document)
32        let mut scores: HashMap<String, (f64, Document)> = HashMap::new();
33
34        for (retriever, weight) in &self.retrievers {
35            let docs = retriever.retrieve(query, top_k).await?;
36
37            for (rank, doc) in docs.iter().enumerate() {
38                // RRF score contribution: weight / (k + rank)
39                // rank is 0-based, so rank 0 = position 1
40                let rrf_score = weight / (RRF_K + (rank + 1) as f64);
41
42                scores
43                    .entry(doc.id.clone())
44                    .and_modify(|(existing_score, _)| {
45                        *existing_score += rrf_score;
46                    })
47                    .or_insert_with(|| (rrf_score, doc.clone()));
48            }
49        }
50
51        // Sort by RRF score descending
52        let mut sorted: Vec<(f64, Document)> = scores.into_values().collect();
53        sorted.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
54
55        Ok(sorted.into_iter().take(top_k).map(|(_, doc)| doc).collect())
56    }
57}