synaptic_retrieval/
ensemble.rs1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use synaptic_core::SynapseError;
6
7use crate::{Document, Retriever};
8
/// Smoothing constant of the reciprocal-rank-fusion score.
///
/// A document at 1-based rank `r` in one retriever's result list contributes
/// `weight / (RRF_K + r)` to its fused score.
const RRF_K: f64 = 60.0;
11
12pub struct EnsembleRetriever {
15 retrievers: Vec<(Arc<dyn Retriever>, f64)>,
16}
17
18impl EnsembleRetriever {
19 pub fn new(retrievers: Vec<(Arc<dyn Retriever>, f64)>) -> Self {
24 Self { retrievers }
25 }
26}
27
28#[async_trait]
29impl Retriever for EnsembleRetriever {
30 async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<Document>, SynapseError> {
31 let mut scores: HashMap<String, (f64, Document)> = HashMap::new();
33
34 for (retriever, weight) in &self.retrievers {
35 let docs = retriever.retrieve(query, top_k).await?;
36
37 for (rank, doc) in docs.iter().enumerate() {
38 let rrf_score = weight / (RRF_K + (rank + 1) as f64);
41
42 scores
43 .entry(doc.id.clone())
44 .and_modify(|(existing_score, _)| {
45 *existing_score += rrf_score;
46 })
47 .or_insert_with(|| (rrf_score, doc.clone()));
48 }
49 }
50
51 let mut sorted: Vec<(f64, Document)> = scores.into_values().collect();
53 sorted.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
54
55 Ok(sorted.into_iter().take(top_k).map(|(_, doc)| doc).collect())
56 }
57}