Skip to main content

synaptic_retrieval/
multi_query.rs

1use std::collections::HashSet;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use synaptic_core::{ChatModel, ChatRequest, Message, SynapseError};
6
7use crate::{Document, Retriever};
8
/// A retriever that generates multiple query variants using a ChatModel,
/// runs each through a base retriever, and deduplicates results by document id.
pub struct MultiQueryRetriever {
    /// Underlying retriever that the original query and every variant are run through.
    base: Arc<dyn Retriever>,
    /// Chat model prompted to produce the alternative phrasings of the query.
    model: Arc<dyn ChatModel>,
    /// Number of alternative phrasings requested from the model in the prompt.
    num_queries: usize,
}
16
17impl MultiQueryRetriever {
18    /// Create a new MultiQueryRetriever with default num_queries (3).
19    pub fn new(base: Arc<dyn Retriever>, model: Arc<dyn ChatModel>) -> Self {
20        Self {
21            base,
22            model,
23            num_queries: 3,
24        }
25    }
26
27    /// Create a new MultiQueryRetriever with a custom number of query variants.
28    pub fn with_num_queries(
29        base: Arc<dyn Retriever>,
30        model: Arc<dyn ChatModel>,
31        num_queries: usize,
32    ) -> Self {
33        Self {
34            base,
35            model,
36            num_queries,
37        }
38    }
39
40    /// Generate alternative query variants using the ChatModel.
41    async fn generate_queries(&self, query: &str) -> Result<Vec<String>, SynapseError> {
42        let prompt = format!(
43            "You are an AI language model assistant. Your task is to generate {} \
44             different versions of the given user question to retrieve relevant documents \
45             from a vector database. By generating multiple perspectives on the user question, \
46             your goal is to help the user overcome some of the limitations of distance-based \
47             similarity search. Provide these alternative questions separated by newlines. \
48             Only output the questions, nothing else.\n\nOriginal question: {}",
49            self.num_queries, query
50        );
51
52        let request = ChatRequest::new(vec![Message::human(prompt)]);
53        let response = self.model.chat(request).await?;
54        let content = response.message.content().to_string();
55
56        let queries: Vec<String> = content
57            .lines()
58            .map(|line| line.trim().to_string())
59            .filter(|line| !line.is_empty())
60            .collect();
61
62        Ok(queries)
63    }
64}
65
66#[async_trait]
67impl Retriever for MultiQueryRetriever {
68    async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<Document>, SynapseError> {
69        // Generate alternative queries
70        let alternative_queries = self.generate_queries(query).await?;
71
72        // Collect all queries: original + alternatives
73        let mut all_queries = vec![query.to_string()];
74        all_queries.extend(alternative_queries);
75
76        // Run each query through the base retriever and deduplicate
77        let mut seen_ids = HashSet::new();
78        let mut results = Vec::new();
79
80        for q in &all_queries {
81            let docs = self.base.retrieve(q, top_k).await?;
82            for doc in docs {
83                if seen_ids.insert(doc.id.clone()) {
84                    results.push(doc);
85                }
86            }
87        }
88
89        // Return up to top_k deduplicated results
90        results.truncate(top_k);
91        Ok(results)
92    }
93}