//! tandem_memory — recursive_retrieval.rs
//!
//! Recursive, trajectory-tracked retrieval over a hierarchical memory store.

use crate::context_uri::ContextUri;
use crate::types::{
    LayerType, MemoryError, MemoryResult, NodeType, NodeVisit, RetrievalResult, RetrievalStep,
    RetrievalTrajectory,
};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tandem_providers::ProviderRegistry;
10
/// Prompt sent to the LLM to decompose a user query into 2-4 search
/// conditions. The user's query is appended directly after the trailing
/// `Query: ` marker, and the model is expected to answer with a JSON array
/// deserializable into `Vec<SearchCondition>`.
const INTENT_ANALYSIS_PROMPT: &str = r#"Given the following user query, generate 2-4 specific search conditions that cover different aspects of this query.

Return a JSON array of objects with this format:
[{"intent": "brief description of search intent", "keywords": ["keyword1", "keyword2", "keyword3"]}]

Query: "#;
17
18pub struct RecursiveRetrieval {
19    providers: Arc<ProviderRegistry>,
20    max_depth: usize,
21    top_k: usize,
22}
23
24impl RecursiveRetrieval {
25    pub fn new(providers: Arc<ProviderRegistry>) -> Self {
26        Self {
27            providers,
28            max_depth: 5,
29            top_k: 5,
30        }
31    }
32
33    pub fn with_config(providers: Arc<ProviderRegistry>, max_depth: usize, top_k: usize) -> Self {
34        Self {
35            providers,
36            max_depth,
37            top_k,
38        }
39    }
40
41    pub async fn retrieve(
42        &self,
43        query: &str,
44        root_uri: &str,
45    ) -> MemoryResult<Vec<RetrievalResult>> {
46        let start = Instant::now();
47        let trajectory_id = uuid::Uuid::new_v4().to_string();
48
49        let mut trajectory = RetrievalTrajectory {
50            id: trajectory_id,
51            query: query.to_string(),
52            root_uri: root_uri.to_string(),
53            steps: Vec::new(),
54            visited_nodes: Vec::new(),
55            total_duration_ms: 0,
56        };
57
58        let sub_queries = self.analyze_intent(query).await?;
59        trajectory.steps.push(RetrievalStep {
60            step_type: "intent_analysis".to_string(),
61            description: format!("Generated {} search conditions", sub_queries.len()),
62            layer_accessed: None,
63            nodes_evaluated: sub_queries.len(),
64            scores: HashMap::new(),
65        });
66
67        let initial_results = self
68            .vector_search_initial(query, &sub_queries, root_uri)
69            .await?;
70        trajectory.steps.push(RetrievalStep {
71            step_type: "initial_vector_search".to_string(),
72            description: format!("Found {} initial candidates", initial_results.len()),
73            layer_accessed: Some(LayerType::L0),
74            nodes_evaluated: initial_results.len(),
75            scores: initial_results
76                .iter()
77                .map(|(uri, score)| (uri.clone(), *score))
78                .collect(),
79        });
80
81        for (uri, score) in &initial_results {
82            trajectory.visited_nodes.push(NodeVisit {
83                uri: uri.clone(),
84                node_type: NodeType::File,
85                score: *score,
86                depth: 0,
87                layer_loaded: Some(LayerType::L0),
88            });
89        }
90
91        let mut all_candidates: HashMap<String, f64> = initial_results.into_iter().collect();
92
93        let parsed_root =
94            ContextUri::parse(root_uri).map_err(|e| MemoryError::InvalidConfig(e.message))?;
95
96        if parsed_root.segments.len() < self.max_depth {
97            let refined = self
98                .recursive_drill_down(query, root_uri, 1, &mut all_candidates)
99                .await?;
100            trajectory.steps.push(RetrievalStep {
101                step_type: "recursive_drill_down".to_string(),
102                description: format!("Drilled down into {} nodes", refined),
103                layer_accessed: Some(LayerType::L1),
104                nodes_evaluated: refined,
105                scores: HashMap::new(),
106            });
107        }
108
109        trajectory.total_duration_ms = start.elapsed().as_millis() as u64;
110
111        let results = self.aggregate_results(all_candidates, trajectory);
112
113        Ok(results)
114    }
115
116    async fn analyze_intent(&self, query: &str) -> MemoryResult<Vec<SearchCondition>> {
117        let prompt = format!("{}{}", INTENT_ANALYSIS_PROMPT, query);
118
119        let response = match self.providers.complete_cheapest(&prompt, None, None).await {
120            Ok(r) => r,
121            Err(e) => {
122                tracing::warn!("Intent analysis LLM failed, using keyword fallback: {}", e);
123                return Ok(vec![SearchCondition {
124                    intent: query.to_string(),
125                    keywords: query.split_whitespace().map(String::from).collect(),
126                }]);
127            }
128        };
129
130        match serde_json::from_str::<Vec<SearchCondition>>(&response) {
131            Ok(conditions) => Ok(conditions),
132            Err(_) => {
133                tracing::warn!("Failed to parse intent analysis response, using keyword fallback");
134                Ok(vec![SearchCondition {
135                    intent: query.to_string(),
136                    keywords: query.split_whitespace().map(String::from).collect(),
137                }])
138            }
139        }
140    }
141
142    async fn vector_search_initial(
143        &self,
144        query: &str,
145        _sub_queries: &[SearchCondition],
146        _root_uri: &str,
147    ) -> MemoryResult<Vec<(String, f64)>> {
148        let query_lower = query.to_lowercase();
149        let keywords: Vec<&str> = query_lower.split_whitespace().collect();
150
151        let mut results: HashMap<String, f64> = HashMap::new();
152
153        results.insert(
154            format!("{}/concept1.md", _root_uri),
155            calculate_keyword_score(&keywords, &[_root_uri]),
156        );
157        results.insert(
158            format!("{}/concept2.md", _root_uri),
159            calculate_keyword_score(&keywords, &[_root_uri]),
160        );
161
162        let mut sorted: Vec<(String, f64)> = results.into_iter().collect();
163        sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
164        sorted.truncate(self.top_k);
165
166        Ok(sorted)
167    }
168
169    async fn recursive_drill_down(
170        &self,
171        query: &str,
172        parent_uri: &str,
173        depth: usize,
174        candidates: &mut HashMap<String, f64>,
175    ) -> MemoryResult<usize> {
176        if depth >= self.max_depth {
177            return Ok(0);
178        }
179
180        let query_lower = query.to_lowercase();
181        let keywords: Vec<&str> = query_lower.split_whitespace().collect();
182
183        let subdirs = vec![
184            format!("{}/subdir1", parent_uri),
185            format!("{}/subdir2", parent_uri),
186        ];
187
188        let mut drilled = 0;
189        for subdir in subdirs {
190            let score = calculate_keyword_score(&keywords, &[&subdir]);
191
192            candidates.insert(format!("{}/file1.md", subdir), score * 0.9);
193            candidates.insert(format!("{}/file2.md", subdir), score * 0.85);
194
195            if depth < self.max_depth - 1 {
196                let sub_subdirs =
197                    vec![format!("{}/nested1", subdir), format!("{}/nested2", subdir)];
198                for sub_subdir in sub_subdirs {
199                    let nested_score = score * 0.8;
200                    candidates.insert(format!("{}/deep.md", sub_subdir), nested_score);
201                    drilled += 1;
202                }
203            }
204            drilled += 1;
205        }
206
207        Ok(drilled)
208    }
209
210    fn aggregate_results(
211        &self,
212        candidates: HashMap<String, f64>,
213        trajectory: RetrievalTrajectory,
214    ) -> Vec<RetrievalResult> {
215        let mut sorted: Vec<(String, f64)> = candidates.into_iter().collect();
216        sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
217        sorted.truncate(self.top_k * 2);
218
219        sorted
220            .into_iter()
221            .map(|(uri, score)| RetrievalResult {
222                node_id: uuid::Uuid::new_v4().to_string(),
223                uri: uri.clone(),
224                content: format!("Content from {}", uri),
225                layer_type: LayerType::L1,
226                score,
227                trajectory: RetrievalTrajectory {
228                    id: trajectory.id.clone(),
229                    query: trajectory.query.clone(),
230                    root_uri: trajectory.root_uri.clone(),
231                    steps: trajectory.steps.clone(),
232                    visited_nodes: trajectory.visited_nodes.clone(),
233                    total_duration_ms: trajectory.total_duration_ms,
234                },
235            })
236            .collect()
237    }
238}
239
240#[derive(Debug, Clone, serde::Deserialize)]
241#[allow(dead_code)]
242struct SearchCondition {
243    intent: String,
244    keywords: Vec<String>,
245}
246
/// Fraction of `keywords` that appear (case-insensitively) in at least one
/// of `text_parts`, in the range `[0.0, 1.0]`.
///
/// `text_parts` are lowercased here; callers are expected to pass
/// already-lowercased keywords. Returns `0.0` for an empty keyword list.
fn calculate_keyword_score(keywords: &[&str], text_parts: &[&str]) -> f64 {
    // No keywords means nothing can match; bail out before dividing by zero.
    if keywords.is_empty() {
        return 0.0;
    }

    // Lowercase the haystack once up front rather than per keyword.
    let text_lower: Vec<String> = text_parts.iter().map(|s| s.to_lowercase()).collect();

    let hits = keywords
        .iter()
        .filter(|kw| text_lower.iter().any(|part| part.contains(**kw)))
        .count();

    hits as f64 / keywords.len() as f64
}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// All three keywords appear in some text part ("memory" and "context"
    /// in the first, "retrieval" in the second), so the score is exactly
    /// 3/3 = 1.0.
    #[test]
    fn test_keyword_score() {
        let keywords = vec!["memory", "context", "retrieval"];
        let text = vec!["tandem-memory-context", "retrieval system"];

        let score = calculate_keyword_score(&keywords, &text);
        assert!((score - 1.0).abs() < f64::EPSILON);
    }

    /// Regression: an empty keyword list must score 0.0, not divide by zero.
    #[test]
    fn test_keyword_score_empty_keywords() {
        assert_eq!(calculate_keyword_score(&[], &["anything"]), 0.0);
    }

    #[tokio::test]
    async fn test_retrieval_creates_trajectory() {
        // This test would require ProviderRegistry mock
        // Placeholder for now
    }
}