1use crate::context_uri::ContextUri;
2use crate::types::{
3 LayerType, MemoryError, MemoryResult, NodeType, NodeVisit, RetrievalResult, RetrievalStep,
4 RetrievalTrajectory,
5};
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::Instant;
9use tandem_providers::ProviderRegistry;
10
/// Prompt prefix sent to the provider for intent analysis.
///
/// The user's query is appended directly after the trailing `Query: ` marker,
/// and the reply is parsed as a JSON array of `{ "intent", "keywords" }`
/// objects (see `SearchCondition`).
const INTENT_ANALYSIS_PROMPT: &str = r#"Given the following user query, generate 2-4 specific search conditions that cover different aspects of this query.

Return a JSON array of objects with this format:
[{"intent": "brief description of search intent", "keywords": ["keyword1", "keyword2", "keyword3"]}]

Query: "#;
17
/// Retrieval driver that analyses a query, collects initial candidates under a
/// root URI, optionally drills down into sub-locations, and returns ranked
/// results together with the trajectory that produced them.
pub struct RecursiveRetrieval {
    /// Provider registry used to run the intent-analysis completion.
    providers: Arc<ProviderRegistry>,
    /// Depth limit: compared against the root URI's segment count and the
    /// current drill-down depth to decide whether to descend further.
    max_depth: usize,
    /// Number of initial candidates kept; the final result list is capped at
    /// twice this value.
    top_k: usize,
}
23
24impl RecursiveRetrieval {
25 pub fn new(providers: Arc<ProviderRegistry>) -> Self {
26 Self {
27 providers,
28 max_depth: 5,
29 top_k: 5,
30 }
31 }
32
33 pub fn with_config(providers: Arc<ProviderRegistry>, max_depth: usize, top_k: usize) -> Self {
34 Self {
35 providers,
36 max_depth,
37 top_k,
38 }
39 }
40
41 pub async fn retrieve(
42 &self,
43 query: &str,
44 root_uri: &str,
45 ) -> MemoryResult<Vec<RetrievalResult>> {
46 let start = Instant::now();
47 let trajectory_id = uuid::Uuid::new_v4().to_string();
48
49 let mut trajectory = RetrievalTrajectory {
50 id: trajectory_id,
51 query: query.to_string(),
52 root_uri: root_uri.to_string(),
53 steps: Vec::new(),
54 visited_nodes: Vec::new(),
55 total_duration_ms: 0,
56 };
57
58 let sub_queries = self.analyze_intent(query).await?;
59 trajectory.steps.push(RetrievalStep {
60 step_type: "intent_analysis".to_string(),
61 description: format!("Generated {} search conditions", sub_queries.len()),
62 layer_accessed: None,
63 nodes_evaluated: sub_queries.len(),
64 scores: HashMap::new(),
65 });
66
67 let initial_results = self
68 .vector_search_initial(query, &sub_queries, root_uri)
69 .await?;
70 trajectory.steps.push(RetrievalStep {
71 step_type: "initial_vector_search".to_string(),
72 description: format!("Found {} initial candidates", initial_results.len()),
73 layer_accessed: Some(LayerType::L0),
74 nodes_evaluated: initial_results.len(),
75 scores: initial_results
76 .iter()
77 .map(|(uri, score)| (uri.clone(), *score))
78 .collect(),
79 });
80
81 for (uri, score) in &initial_results {
82 trajectory.visited_nodes.push(NodeVisit {
83 uri: uri.clone(),
84 node_type: NodeType::File,
85 score: *score,
86 depth: 0,
87 layer_loaded: Some(LayerType::L0),
88 });
89 }
90
91 let mut all_candidates: HashMap<String, f64> = initial_results.into_iter().collect();
92
93 let parsed_root =
94 ContextUri::parse(root_uri).map_err(|e| MemoryError::InvalidConfig(e.message))?;
95
96 if parsed_root.segments.len() < self.max_depth {
97 let refined = self
98 .recursive_drill_down(query, root_uri, 1, &mut all_candidates)
99 .await?;
100 trajectory.steps.push(RetrievalStep {
101 step_type: "recursive_drill_down".to_string(),
102 description: format!("Drilled down into {} nodes", refined),
103 layer_accessed: Some(LayerType::L1),
104 nodes_evaluated: refined,
105 scores: HashMap::new(),
106 });
107 }
108
109 trajectory.total_duration_ms = start.elapsed().as_millis() as u64;
110
111 let results = self.aggregate_results(all_candidates, trajectory);
112
113 Ok(results)
114 }
115
116 async fn analyze_intent(&self, query: &str) -> MemoryResult<Vec<SearchCondition>> {
117 let prompt = format!("{}{}", INTENT_ANALYSIS_PROMPT, query);
118
119 let response = match self.providers.complete_cheapest(&prompt, None, None).await {
120 Ok(r) => r,
121 Err(e) => {
122 tracing::warn!("Intent analysis LLM failed, using keyword fallback: {}", e);
123 return Ok(vec![SearchCondition {
124 intent: query.to_string(),
125 keywords: query.split_whitespace().map(String::from).collect(),
126 }]);
127 }
128 };
129
130 match serde_json::from_str::<Vec<SearchCondition>>(&response) {
131 Ok(conditions) => Ok(conditions),
132 Err(_) => {
133 tracing::warn!("Failed to parse intent analysis response, using keyword fallback");
134 Ok(vec![SearchCondition {
135 intent: query.to_string(),
136 keywords: query.split_whitespace().map(String::from).collect(),
137 }])
138 }
139 }
140 }
141
142 async fn vector_search_initial(
143 &self,
144 query: &str,
145 _sub_queries: &[SearchCondition],
146 _root_uri: &str,
147 ) -> MemoryResult<Vec<(String, f64)>> {
148 let query_lower = query.to_lowercase();
149 let keywords: Vec<&str> = query_lower.split_whitespace().collect();
150
151 let mut results: HashMap<String, f64> = HashMap::new();
152
153 results.insert(
154 format!("{}/concept1.md", _root_uri),
155 calculate_keyword_score(&keywords, &[_root_uri]),
156 );
157 results.insert(
158 format!("{}/concept2.md", _root_uri),
159 calculate_keyword_score(&keywords, &[_root_uri]),
160 );
161
162 let mut sorted: Vec<(String, f64)> = results.into_iter().collect();
163 sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
164 sorted.truncate(self.top_k);
165
166 Ok(sorted)
167 }
168
169 async fn recursive_drill_down(
170 &self,
171 query: &str,
172 parent_uri: &str,
173 depth: usize,
174 candidates: &mut HashMap<String, f64>,
175 ) -> MemoryResult<usize> {
176 if depth >= self.max_depth {
177 return Ok(0);
178 }
179
180 let query_lower = query.to_lowercase();
181 let keywords: Vec<&str> = query_lower.split_whitespace().collect();
182
183 let subdirs = vec![
184 format!("{}/subdir1", parent_uri),
185 format!("{}/subdir2", parent_uri),
186 ];
187
188 let mut drilled = 0;
189 for subdir in subdirs {
190 let score = calculate_keyword_score(&keywords, &[&subdir]);
191
192 candidates.insert(format!("{}/file1.md", subdir), score * 0.9);
193 candidates.insert(format!("{}/file2.md", subdir), score * 0.85);
194
195 if depth < self.max_depth - 1 {
196 let sub_subdirs =
197 vec![format!("{}/nested1", subdir), format!("{}/nested2", subdir)];
198 for sub_subdir in sub_subdirs {
199 let nested_score = score * 0.8;
200 candidates.insert(format!("{}/deep.md", sub_subdir), nested_score);
201 drilled += 1;
202 }
203 }
204 drilled += 1;
205 }
206
207 Ok(drilled)
208 }
209
210 fn aggregate_results(
211 &self,
212 candidates: HashMap<String, f64>,
213 trajectory: RetrievalTrajectory,
214 ) -> Vec<RetrievalResult> {
215 let mut sorted: Vec<(String, f64)> = candidates.into_iter().collect();
216 sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
217 sorted.truncate(self.top_k * 2);
218
219 sorted
220 .into_iter()
221 .map(|(uri, score)| RetrievalResult {
222 node_id: uuid::Uuid::new_v4().to_string(),
223 uri: uri.clone(),
224 content: format!("Content from {}", uri),
225 layer_type: LayerType::L1,
226 score,
227 trajectory: RetrievalTrajectory {
228 id: trajectory.id.clone(),
229 query: trajectory.query.clone(),
230 root_uri: trajectory.root_uri.clone(),
231 steps: trajectory.steps.clone(),
232 visited_nodes: trajectory.visited_nodes.clone(),
233 total_duration_ms: trajectory.total_duration_ms,
234 },
235 })
236 .collect()
237 }
238}
239
/// One search condition produced by intent analysis, deserialized from the
/// JSON objects requested in `INTENT_ANALYSIS_PROMPT`.
#[derive(Debug, Clone, serde::Deserialize)]
// The fields are populated (by deserialization or the keyword fallback) but
// nothing in this file reads them back yet, hence the lint suppression.
#[allow(dead_code)]
struct SearchCondition {
    /// Short description of what this condition is looking for.
    intent: String,
    /// Keywords associated with the intent.
    keywords: Vec<String>,
}
246
/// Returns the fraction of `keywords` that occur as a substring of at least
/// one entry in `text_parts`.
///
/// Matching is case-insensitive: both the keywords and the text are
/// lower-cased before comparison. The result lies in `0.0..=1.0`, and an empty
/// keyword list scores `0.0`.
fn calculate_keyword_score(keywords: &[&str], text_parts: &[&str]) -> f64 {
    if keywords.is_empty() {
        return 0.0;
    }

    let haystacks: Vec<String> = text_parts.iter().map(|part| part.to_lowercase()).collect();

    let matched = keywords
        .iter()
        .filter(|keyword| {
            // Fold the keyword too, so mixed-case keywords can still match.
            let needle = keyword.to_lowercase();
            haystacks.iter().any(|part| part.contains(needle.as_str()))
        })
        .count();

    matched as f64 / keywords.len() as f64
}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// Every keyword occurs in some text part, so the score is exactly 1.0.
    #[test]
    fn test_keyword_score() {
        let keywords = vec!["memory", "context", "retrieval"];
        let text = vec!["tandem-memory-context", "retrieval system"];

        let score = calculate_keyword_score(&keywords, &text);
        assert!((score - 1.0).abs() < f64::EPSILON);
    }

    /// Only one of the two keywords occurs, so the score is one half.
    #[test]
    fn test_keyword_score_partial_match() {
        let keywords = vec!["memory", "missing"];
        let text = vec!["tandem-memory-context"];

        let score = calculate_keyword_score(&keywords, &text);
        assert!((score - 0.5).abs() < f64::EPSILON);
    }

    /// An empty keyword list must not divide by zero and scores 0.0.
    #[test]
    fn test_keyword_score_empty_keywords() {
        let keywords: Vec<&str> = Vec::new();
        let text = vec!["tandem-memory-context"];

        assert_eq!(calculate_keyword_score(&keywords, &text), 0.0);
    }

    /// Text is lower-cased before matching, so upper-case text still matches.
    #[test]
    fn test_keyword_score_ignores_text_case() {
        let keywords = vec!["memory"];
        let text = vec!["Tandem-MEMORY-Context"];

        let score = calculate_keyword_score(&keywords, &text);
        assert!((score - 1.0).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn test_retrieval_creates_trajectory() {
        // TODO: placeholder that currently passes vacuously. Build a
        // `RecursiveRetrieval` over a test `ProviderRegistry`, call
        // `retrieve`, and assert on the trajectory attached to the results.
    }
}