1use std::collections::{HashMap, HashSet};
5
6use crate::error::MemoryError;
7
8use super::store::GraphStore;
9use super::types::GraphFact;
10
11pub async fn graph_recall(
24 store: &GraphStore,
25 _embeddings: Option<&crate::embedding_store::EmbeddingStore>,
26 _provider: &zeph_llm::any::AnyProvider,
27 query: &str,
28 limit: usize,
29 max_hops: u32,
30) -> Result<Vec<GraphFact>, MemoryError> {
31 const MAX_WORDS: usize = 5;
33
34 if limit == 0 {
35 return Ok(Vec::new());
36 }
37
38 let filtered: Vec<&str> = query
41 .split_whitespace()
42 .filter(|w| w.len() >= 3)
43 .take(MAX_WORDS)
44 .collect();
45 let words: Vec<&str> = if filtered.is_empty() && !query.is_empty() {
46 vec![query]
47 } else {
48 filtered
49 };
50
51 let mut entity_scores: HashMap<i64, f32> = HashMap::new();
52
53 for word in &words {
54 let matches = store.find_entities_fuzzy(word, limit * 2).await?;
55 for entity in matches {
56 entity_scores
57 .entry(entity.id)
58 .and_modify(|s| *s = s.max(1.0))
59 .or_insert(1.0);
60 }
61 }
62
63 if entity_scores.is_empty() {
64 return Ok(Vec::new());
65 }
66
67 let mut all_facts: Vec<GraphFact> = Vec::new();
69
70 for (seed_id, seed_score) in &entity_scores {
71 let (entities, edges, depth_map) = store.bfs_with_depth(*seed_id, max_hops).await?;
72
73 let name_map: HashMap<i64, &str> = entities
76 .iter()
77 .map(|e| (e.id, e.canonical_name.as_str()))
78 .collect();
79
80 for edge in &edges {
81 let Some(&hop_distance) = depth_map
82 .get(&edge.source_entity_id)
83 .or_else(|| depth_map.get(&edge.target_entity_id))
84 else {
85 continue;
86 };
87
88 let entity_name = name_map
89 .get(&edge.source_entity_id)
90 .copied()
91 .unwrap_or_default();
92 let target_name = name_map
93 .get(&edge.target_entity_id)
94 .copied()
95 .unwrap_or_default();
96
97 if entity_name.is_empty() || target_name.is_empty() {
98 continue;
99 }
100
101 all_facts.push(GraphFact {
102 entity_name: entity_name.to_owned(),
103 relation: edge.relation.clone(),
104 target_name: target_name.to_owned(),
105 fact: edge.fact.clone(),
106 entity_match_score: *seed_score,
107 hop_distance,
108 confidence: edge.confidence,
109 });
110 }
111 }
112
113 all_facts.sort_by(|a, b| b.composite_score().total_cmp(&a.composite_score()));
116
117 let mut seen: HashSet<(String, String, String)> = HashSet::new();
118 all_facts.retain(|f| {
119 seen.insert((
120 f.entity_name.clone(),
121 f.relation.clone(),
122 f.target_name.clone(),
123 ))
124 });
125
126 all_facts.truncate(limit);
128
129 Ok(all_facts)
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::store::GraphStore;
    use crate::graph::types::EntityType;
    use crate::sqlite::SqliteStore;
    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;

    /// Creates a `GraphStore` backed by a fresh in-memory SQLite database.
    async fn setup_store() -> GraphStore {
        let store = SqliteStore::new(":memory:").await.unwrap();
        GraphStore::new(store.pool().clone())
    }

    /// Builds a mock LLM provider to satisfy `graph_recall`'s `_provider` parameter.
    fn mock_provider() -> AnyProvider {
        AnyProvider::Mock(MockProvider::default())
    }

    #[tokio::test]
    async fn graph_recall_empty_graph_returns_empty() {
        let store = setup_store().await;
        let provider = mock_provider();
        let result = graph_recall(&store, None, &provider, "anything", 10, 2)
            .await
            .unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn graph_recall_zero_limit_returns_empty() {
        let store = setup_store().await;
        let provider = mock_provider();
        let result = graph_recall(&store, None, &provider, "user", 0, 2)
            .await
            .unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn graph_recall_fuzzy_match_returns_facts() {
        let store = setup_store().await;
        let user_id = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();
        let tool_id = store
            .upsert_entity("neovim", "neovim", EntityType::Tool, None)
            .await
            .unwrap();
        store
            .insert_edge(user_id, tool_id, "uses", "Alice uses neovim", 0.9, None)
            .await
            .unwrap();

        let provider = mock_provider();
        let result = graph_recall(&store, None, &provider, "Ali neovim", 10, 2)
            .await
            .unwrap();
        assert!(!result.is_empty());
        // Only one edge exists, so the top fact must be exactly that edge.
        assert_eq!(result[0].entity_name, "Alice");
        assert_eq!(result[0].relation, "uses");
        assert_eq!(result[0].target_name, "neovim");
    }

    #[tokio::test]
    async fn graph_recall_respects_max_hops() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("Alpha", "Alpha", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("Beta", "Beta", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("Gamma", "Gamma", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "knows", "Alpha knows Beta", 0.8, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "Beta knows Gamma", 0.8, None)
            .await
            .unwrap();

        let provider = mock_provider();
        let result = graph_recall(&store, None, &provider, "Alp", 10, 1)
            .await
            .unwrap();
        // Guard against a vacuous pass: the Alpha->Beta edge is within one hop.
        assert!(!result.is_empty(), "expected at least the Alpha->Beta fact");
        assert!(result.iter().all(|f| f.hop_distance <= 1));
    }

    #[tokio::test]
    async fn graph_recall_deduplicates_facts() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();
        let bob = store
            .upsert_entity("Bob", "Bob", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(alice, bob, "knows", "Alice knows Bob", 0.9, None)
            .await
            .unwrap();

        let provider = mock_provider();
        // Both words seed a traversal that reaches the same single edge.
        let result = graph_recall(&store, None, &provider, "Ali Bob", 10, 2)
            .await
            .unwrap();

        let mut seen = std::collections::HashSet::new();
        for f in &result {
            let key = (&f.entity_name, &f.relation, &f.target_name);
            assert!(seen.insert(key), "duplicate fact found: {f:?}");
        }
        assert_eq!(result.len(), 1, "the single edge must appear exactly once");
    }

    #[tokio::test]
    async fn graph_recall_sorts_by_composite_score() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("Alpha", "Alpha", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("Beta", "Beta", EntityType::Tool, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("AlphaGadget", "AlphaGadget", EntityType::Tool, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "uses", "Alpha uses Beta", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(a, c, "mentions", "Alpha mentions AlphaGadget", 0.1, None)
            .await
            .unwrap();

        let provider = mock_provider();
        let result = graph_recall(&store, None, &provider, "Alp", 10, 2)
            .await
            .unwrap();

        assert!(result.len() >= 2);
        // Every adjacent pair, not just the first, must be in descending order.
        for pair in result.windows(2) {
            let (s0, s1) = (pair[0].composite_score(), pair[1].composite_score());
            assert!(s0 >= s1, "expected sorted desc: {s0} >= {s1}");
        }
    }

    #[tokio::test]
    async fn graph_recall_limit_truncates() {
        let store = setup_store().await;
        let root = store
            .upsert_entity("Root", "Root", EntityType::Person, None)
            .await
            .unwrap();
        for i in 0..10 {
            let target = store
                .upsert_entity(
                    &format!("Target{i}"),
                    &format!("Target{i}"),
                    EntityType::Tool,
                    None,
                )
                .await
                .unwrap();
            store
                .insert_edge(
                    root,
                    target,
                    "has",
                    &format!("Root has Target{i}"),
                    0.8,
                    None,
                )
                .await
                .unwrap();
        }

        let provider = mock_provider();
        let result = graph_recall(&store, None, &provider, "Roo", 3, 2)
            .await
            .unwrap();
        // Ten distinct facts are reachable, so truncation must yield exactly `limit`.
        assert_eq!(result.len(), 3);
    }
}