1use std::collections::{HashMap, HashSet};
5use std::time::{SystemTime, UNIX_EPOCH};
6
7use crate::error::MemoryError;
8
9use super::store::GraphStore;
10use super::types::GraphFact;
11
12#[allow(clippy::too_many_arguments)]
33pub async fn graph_recall(
34 store: &GraphStore,
35 _embeddings: Option<&crate::embedding_store::EmbeddingStore>,
36 _provider: &zeph_llm::any::AnyProvider,
37 query: &str,
38 limit: usize,
39 max_hops: u32,
40 at_timestamp: Option<&str>,
41 temporal_decay_rate: f64,
42) -> Result<Vec<GraphFact>, MemoryError> {
43 const MAX_WORDS: usize = 5;
45
46 if limit == 0 {
47 return Ok(Vec::new());
48 }
49
50 let filtered: Vec<&str> = query
53 .split_whitespace()
54 .filter(|w| w.len() >= 3)
55 .take(MAX_WORDS)
56 .collect();
57 let words: Vec<&str> = if filtered.is_empty() && !query.is_empty() {
58 vec![query]
59 } else {
60 filtered
61 };
62
63 let mut entity_scores: HashMap<i64, f32> = HashMap::new();
64
65 for word in &words {
66 let matches = store.find_entities_fuzzy(word, limit * 2).await?;
67 for entity in matches {
68 entity_scores
69 .entry(entity.id)
70 .and_modify(|s| *s = s.max(1.0))
71 .or_insert(1.0);
72 }
73 }
74
75 if entity_scores.is_empty() {
76 return Ok(Vec::new());
77 }
78
79 let now_secs: i64 = SystemTime::now()
81 .duration_since(UNIX_EPOCH)
82 .map(|d| d.as_secs().cast_signed())
83 .unwrap_or(0);
84
85 let mut all_facts: Vec<GraphFact> = Vec::new();
87
88 for (seed_id, seed_score) in &entity_scores {
89 let (entities, edges, depth_map) = if let Some(ts) = at_timestamp {
90 store.bfs_at_timestamp(*seed_id, max_hops, ts).await?
91 } else {
92 store.bfs_with_depth(*seed_id, max_hops).await?
93 };
94
95 let name_map: HashMap<i64, &str> = entities
98 .iter()
99 .map(|e| (e.id, e.canonical_name.as_str()))
100 .collect();
101
102 for edge in &edges {
103 let Some(&hop_distance) = depth_map
104 .get(&edge.source_entity_id)
105 .or_else(|| depth_map.get(&edge.target_entity_id))
106 else {
107 continue;
108 };
109
110 let entity_name = name_map
111 .get(&edge.source_entity_id)
112 .copied()
113 .unwrap_or_default();
114 let target_name = name_map
115 .get(&edge.target_entity_id)
116 .copied()
117 .unwrap_or_default();
118
119 if entity_name.is_empty() || target_name.is_empty() {
120 continue;
121 }
122
123 all_facts.push(GraphFact {
124 entity_name: entity_name.to_owned(),
125 relation: edge.relation.clone(),
126 target_name: target_name.to_owned(),
127 fact: edge.fact.clone(),
128 entity_match_score: *seed_score,
129 hop_distance,
130 confidence: edge.confidence,
131 valid_from: Some(edge.valid_from.clone()),
132 });
133 }
134 }
135
136 let mut scored: Vec<(f32, GraphFact)> = all_facts
140 .into_iter()
141 .map(|f| {
142 let s = f.score_with_decay(temporal_decay_rate, now_secs);
143 (s, f)
144 })
145 .collect();
146 scored.sort_by(|(sa, _), (sb, _)| sb.total_cmp(sa));
147 let mut all_facts: Vec<GraphFact> = scored.into_iter().map(|(_, f)| f).collect();
148
149 let mut seen: HashSet<(String, String, String)> = HashSet::new();
150 all_facts.retain(|f| {
151 seen.insert((
152 f.entity_name.clone(),
153 f.relation.clone(),
154 f.target_name.clone(),
155 ))
156 });
157
158 all_facts.truncate(limit);
160
161 Ok(all_facts)
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::store::GraphStore;
    use crate::graph::types::{EntityType, GraphFact};
    use crate::sqlite::SqliteStore;
    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;

    /// Opens a fresh in-memory SQLite database and wraps its pool in a `GraphStore`.
    async fn setup_store() -> GraphStore {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        GraphStore::new(sqlite.pool().clone())
    }

    /// Builds the mock LLM provider passed to `graph_recall`.
    fn mock_provider() -> AnyProvider {
        AnyProvider::Mock(MockProvider::default())
    }

    /// Upserts an entity whose name and canonical name are both `name`, returning its id.
    async fn add_entity(store: &GraphStore, name: &str, kind: EntityType) -> i64 {
        store.upsert_entity(name, name, kind, None).await.unwrap()
    }

    /// Runs `graph_recall` with no embeddings, a mock provider and zero temporal decay.
    async fn recall(
        store: &GraphStore,
        query: &str,
        limit: usize,
        max_hops: u32,
        at_timestamp: Option<&str>,
    ) -> Vec<GraphFact> {
        let provider = mock_provider();
        graph_recall(store, None, &provider, query, limit, max_hops, at_timestamp, 0.0)
            .await
            .unwrap()
    }

    /// Seeds `Alpha --uses(1.0)--> Beta` and `Alpha --mentions(0.1)--> AlphaGadget`.
    async fn seed_alpha_graph(store: &GraphStore) {
        let alpha = add_entity(store, "Alpha", EntityType::Person).await;
        let beta = add_entity(store, "Beta", EntityType::Tool).await;
        let gadget = add_entity(store, "AlphaGadget", EntityType::Tool).await;
        store
            .insert_edge(alpha, beta, "uses", "Alpha uses Beta", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(alpha, gadget, "mentions", "Alpha mentions AlphaGadget", 0.1, None)
            .await
            .unwrap();
    }

    /// Asserts at least two facts exist and the first two are in descending composite-score order.
    fn assert_top_two_descending(facts: &[GraphFact]) {
        assert!(facts.len() >= 2);
        let s0 = facts[0].composite_score();
        let s1 = facts[1].composite_score();
        assert!(s0 >= s1, "expected sorted desc: {s0} >= {s1}");
    }

    #[tokio::test]
    async fn graph_recall_empty_graph_returns_empty() {
        let store = setup_store().await;
        let facts = recall(&store, "anything", 10, 2, None).await;
        assert!(facts.is_empty());
    }

    #[tokio::test]
    async fn graph_recall_zero_limit_returns_empty() {
        let store = setup_store().await;
        let facts = recall(&store, "user", 0, 2, None).await;
        assert!(facts.is_empty());
    }

    #[tokio::test]
    async fn graph_recall_fuzzy_match_returns_facts() {
        let store = setup_store().await;
        let alice = add_entity(&store, "Alice", EntityType::Person).await;
        let editor = add_entity(&store, "neovim", EntityType::Tool).await;
        store
            .insert_edge(alice, editor, "uses", "Alice uses neovim", 0.9, None)
            .await
            .unwrap();

        let facts = recall(&store, "Ali neovim", 10, 2, None).await;
        assert!(!facts.is_empty());
        assert_eq!(facts[0].relation, "uses");
    }

    #[tokio::test]
    async fn graph_recall_respects_max_hops() {
        let store = setup_store().await;
        let alpha = add_entity(&store, "Alpha", EntityType::Person).await;
        let beta = add_entity(&store, "Beta", EntityType::Person).await;
        let gamma = add_entity(&store, "Gamma", EntityType::Person).await;
        store
            .insert_edge(alpha, beta, "knows", "Alpha knows Beta", 0.8, None)
            .await
            .unwrap();
        store
            .insert_edge(beta, gamma, "knows", "Beta knows Gamma", 0.8, None)
            .await
            .unwrap();

        let facts = recall(&store, "Alp", 10, 1, None).await;
        assert!(facts.iter().all(|f| f.hop_distance <= 1));
    }

    #[tokio::test]
    async fn graph_recall_deduplicates_facts() {
        let store = setup_store().await;
        let alice = add_entity(&store, "Alice", EntityType::Person).await;
        let bob = add_entity(&store, "Bob", EntityType::Person).await;
        store
            .insert_edge(alice, bob, "knows", "Alice knows Bob", 0.9, None)
            .await
            .unwrap();

        let facts = recall(&store, "Ali Bob", 10, 2, None).await;

        let mut seen = std::collections::HashSet::new();
        for f in &facts {
            assert!(
                seen.insert((&f.entity_name, &f.relation, &f.target_name)),
                "duplicate fact found: {f:?}"
            );
        }
    }

    #[tokio::test]
    async fn graph_recall_sorts_by_composite_score() {
        let store = setup_store().await;
        seed_alpha_graph(&store).await;

        let facts = recall(&store, "Alp", 10, 2, None).await;
        assert_top_two_descending(&facts);
    }

    #[tokio::test]
    async fn graph_recall_limit_truncates() {
        let store = setup_store().await;
        let root = add_entity(&store, "Root", EntityType::Person).await;
        for i in 0..10 {
            let target_name = format!("Target{i}");
            let target = add_entity(&store, &target_name, EntityType::Tool).await;
            store
                .insert_edge(root, target, "has", &format!("Root has Target{i}"), 0.8, None)
                .await
                .unwrap();
        }

        let facts = recall(&store, "Roo", 3, 2, None).await;
        assert!(facts.len() <= 3);
    }

    #[tokio::test]
    async fn graph_recall_at_timestamp_excludes_future_edges() {
        let store = setup_store().await;
        let alice = add_entity(&store, "Alice", EntityType::Person).await;
        let bob = add_entity(&store, "Bob", EntityType::Person).await;
        sqlx::query(
            "INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from)
             VALUES (?1, ?2, 'knows', 'Alice knows Bob', 0.9, '2100-01-01 00:00:00')",
        )
        .bind(alice)
        .bind(bob)
        .execute(store.pool())
        .await
        .unwrap();

        let facts = recall(&store, "Ali", 10, 2, Some("2026-01-01 00:00:00")).await;
        assert!(facts.is_empty(), "future edge should be excluded");
    }

    #[tokio::test]
    async fn graph_recall_at_timestamp_excludes_invalidated_edges() {
        let store = setup_store().await;
        let alice = add_entity(&store, "Alice", EntityType::Person).await;
        let carol = add_entity(&store, "Carol", EntityType::Person).await;
        sqlx::query(
            "INSERT INTO graph_edges
             (source_entity_id, target_entity_id, relation, fact, confidence, valid_from, valid_to, expired_at)
             VALUES (?1, ?2, 'manages', 'Alice manages Carol', 0.8,
             '2020-01-01 00:00:00', '2021-01-01 00:00:00', '2021-01-01 00:00:00')",
        )
        .bind(alice)
        .bind(carol)
        .execute(store.pool())
        .await
        .unwrap();

        let current = recall(&store, "Ali", 10, 2, None).await;
        assert!(
            current.is_empty(),
            "expired edge should be invisible at current time"
        );

        let historical = recall(&store, "Ali", 10, 2, Some("2020-06-01 00:00:00")).await;
        assert!(
            !historical.is_empty(),
            "edge should be visible within its validity window"
        );
    }

    #[tokio::test]
    async fn graph_recall_temporal_decay_preserves_order_with_zero_rate() {
        let store = setup_store().await;
        seed_alpha_graph(&store).await;

        let facts = recall(&store, "Alp", 10, 2, None).await;
        assert_top_two_descending(&facts);
    }
}