synaptic_retrieval/
multi_query.rs1use std::collections::HashSet;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use synaptic_core::{ChatModel, ChatRequest, Message, SynapseError};
6
7use crate::{Document, Retriever};
8
9pub struct MultiQueryRetriever {
12 base: Arc<dyn Retriever>,
13 model: Arc<dyn ChatModel>,
14 num_queries: usize,
15}
16
17impl MultiQueryRetriever {
18 pub fn new(base: Arc<dyn Retriever>, model: Arc<dyn ChatModel>) -> Self {
20 Self {
21 base,
22 model,
23 num_queries: 3,
24 }
25 }
26
27 pub fn with_num_queries(
29 base: Arc<dyn Retriever>,
30 model: Arc<dyn ChatModel>,
31 num_queries: usize,
32 ) -> Self {
33 Self {
34 base,
35 model,
36 num_queries,
37 }
38 }
39
40 async fn generate_queries(&self, query: &str) -> Result<Vec<String>, SynapseError> {
42 let prompt = format!(
43 "You are an AI language model assistant. Your task is to generate {} \
44 different versions of the given user question to retrieve relevant documents \
45 from a vector database. By generating multiple perspectives on the user question, \
46 your goal is to help the user overcome some of the limitations of distance-based \
47 similarity search. Provide these alternative questions separated by newlines. \
48 Only output the questions, nothing else.\n\nOriginal question: {}",
49 self.num_queries, query
50 );
51
52 let request = ChatRequest::new(vec![Message::human(prompt)]);
53 let response = self.model.chat(request).await?;
54 let content = response.message.content().to_string();
55
56 let queries: Vec<String> = content
57 .lines()
58 .map(|line| line.trim().to_string())
59 .filter(|line| !line.is_empty())
60 .collect();
61
62 Ok(queries)
63 }
64}
65
66#[async_trait]
67impl Retriever for MultiQueryRetriever {
68 async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<Document>, SynapseError> {
69 let alternative_queries = self.generate_queries(query).await?;
71
72 let mut all_queries = vec![query.to_string()];
74 all_queries.extend(alternative_queries);
75
76 let mut seen_ids = HashSet::new();
78 let mut results = Vec::new();
79
80 for q in &all_queries {
81 let docs = self.base.retrieve(q, top_k).await?;
82 for doc in docs {
83 if seen_ids.insert(doc.id.clone()) {
84 results.push(doc);
85 }
86 }
87 }
88
89 results.truncate(top_k);
91 Ok(results)
92 }
93}