1use std::collections::{HashMap, HashSet};
5use std::time::{SystemTime, UNIX_EPOCH};
6
7use crate::error::MemoryError;
8
9use super::activation::{ActivatedFact, SpreadingActivation, SpreadingActivationParams};
10use super::store::GraphStore;
11use super::types::{EdgeType, GraphFact};
12
13#[allow(clippy::too_many_arguments)]
36pub async fn graph_recall(
37 store: &GraphStore,
38 _embeddings: Option<&crate::embedding_store::EmbeddingStore>,
39 _provider: &zeph_llm::any::AnyProvider,
40 query: &str,
41 limit: usize,
42 max_hops: u32,
43 at_timestamp: Option<&str>,
44 temporal_decay_rate: f64,
45 edge_types: &[EdgeType],
46) -> Result<Vec<GraphFact>, MemoryError> {
47 if limit == 0 {
48 return Ok(Vec::new());
49 }
50
51 let entity_scores = find_seed_entities(store, query, limit).await?;
53
54 if entity_scores.is_empty() {
55 return Ok(Vec::new());
56 }
57
58 let now_secs: i64 = SystemTime::now()
60 .duration_since(UNIX_EPOCH)
61 .map(|d| d.as_secs().cast_signed())
62 .unwrap_or(0);
63
64 let mut all_facts: Vec<GraphFact> = Vec::new();
66
67 for (seed_id, seed_score) in &entity_scores {
68 let (entities, edges, depth_map) = if let Some(ts) = at_timestamp {
69 store.bfs_at_timestamp(*seed_id, max_hops, ts).await?
70 } else if !edge_types.is_empty() {
71 store.bfs_typed(*seed_id, max_hops, edge_types).await?
72 } else {
73 store.bfs_with_depth(*seed_id, max_hops).await?
74 };
75
76 let name_map: HashMap<i64, &str> = entities
79 .iter()
80 .map(|e| (e.id, e.canonical_name.as_str()))
81 .collect();
82
83 for edge in &edges {
84 let Some(&hop_distance) = depth_map
85 .get(&edge.source_entity_id)
86 .or_else(|| depth_map.get(&edge.target_entity_id))
87 else {
88 continue;
89 };
90
91 let entity_name = name_map
92 .get(&edge.source_entity_id)
93 .copied()
94 .unwrap_or_default();
95 let target_name = name_map
96 .get(&edge.target_entity_id)
97 .copied()
98 .unwrap_or_default();
99
100 if entity_name.is_empty() || target_name.is_empty() {
101 continue;
102 }
103
104 all_facts.push(GraphFact {
105 entity_name: entity_name.to_owned(),
106 relation: edge.relation.clone(),
107 target_name: target_name.to_owned(),
108 fact: edge.fact.clone(),
109 entity_match_score: *seed_score,
110 hop_distance,
111 confidence: edge.confidence,
112 valid_from: Some(edge.valid_from.clone()),
113 edge_type: edge.edge_type,
114 });
115 }
116 }
117
118 let mut scored: Vec<(f32, GraphFact)> = all_facts
122 .into_iter()
123 .map(|f| {
124 let s = f.score_with_decay(temporal_decay_rate, now_secs);
125 (s, f)
126 })
127 .collect();
128 scored.sort_by(|(sa, _), (sb, _)| sb.total_cmp(sa));
129 let mut all_facts: Vec<GraphFact> = scored.into_iter().map(|(_, f)| f).collect();
130
131 let mut seen: HashSet<(String, String, String, EdgeType)> = HashSet::new();
135 all_facts.retain(|f| {
136 seen.insert((
137 f.entity_name.clone(),
138 f.relation.clone(),
139 f.target_name.clone(),
140 f.edge_type,
141 ))
142 });
143
144 all_facts.truncate(limit);
146
147 Ok(all_facts)
148}
149
150async fn find_seed_entities(
161 store: &GraphStore,
162 query: &str,
163 limit: usize,
164) -> Result<HashMap<i64, f32>, MemoryError> {
165 const MAX_WORDS: usize = 5;
166
167 let filtered: Vec<&str> = query
168 .split_whitespace()
169 .filter(|w| w.len() >= 3)
170 .take(MAX_WORDS)
171 .collect();
172 let words: Vec<&str> = if filtered.is_empty() && !query.is_empty() {
173 vec![query]
174 } else {
175 filtered
176 };
177
178 let mut entity_scores: HashMap<i64, f32> = HashMap::new();
179 for word in &words {
180 let matches = store.find_entities_fuzzy(word, limit * 2).await?;
181 for entity in matches {
182 entity_scores
183 .entry(entity.id)
184 .and_modify(|s| *s = s.max(1.0))
185 .or_insert(1.0);
186 }
187 }
188
189 Ok(entity_scores)
190}
191
192pub async fn graph_recall_activated(
207 store: &GraphStore,
208 query: &str,
209 limit: usize,
210 params: SpreadingActivationParams,
211 edge_types: &[EdgeType],
212) -> Result<Vec<ActivatedFact>, MemoryError> {
213 if limit == 0 {
214 return Ok(Vec::new());
215 }
216
217 let entity_scores = find_seed_entities(store, query, limit).await?;
218
219 if entity_scores.is_empty() {
220 return Ok(Vec::new());
221 }
222
223 tracing::debug!(
224 seeds = entity_scores.len(),
225 "spreading activation: starting recall"
226 );
227
228 let sa = SpreadingActivation::new(params);
229 let (_, mut facts) = sa.spread(store, entity_scores, edge_types).await?;
230
231 facts.sort_by(|a, b| b.activation_score.total_cmp(&a.activation_score));
233
234 let mut seen: HashSet<(i64, String, i64, EdgeType)> = HashSet::new();
236 facts.retain(|f| {
237 seen.insert((
238 f.edge.source_entity_id,
239 f.edge.relation.clone(),
240 f.edge.target_entity_id,
241 f.edge.edge_type,
242 ))
243 });
244
245 facts.truncate(limit);
246
247 tracing::debug!(
248 result_count = facts.len(),
249 "spreading activation: recall complete"
250 );
251
252 Ok(facts)
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::store::GraphStore;
    use crate::graph::types::EntityType;
    use crate::sqlite::SqliteStore;
    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;

    /// Creates a `GraphStore` backed by a fresh in-memory SQLite database.
    async fn setup_store() -> GraphStore {
        let store = SqliteStore::new(":memory:").await.unwrap();
        GraphStore::new(store.pool().clone())
    }

    /// Returns a mock LLM provider to satisfy `graph_recall`'s signature.
    fn mock_provider() -> AnyProvider {
        AnyProvider::Mock(MockProvider::default())
    }

    /// Inserts `Alpha -uses(1.0)-> Beta` and `Alpha -mentions(0.1)-> AlphaGadget`.
    ///
    /// The query "Alp" is intended to match both `Alpha` and `AlphaGadget`, and the
    /// two edges carry very different confidences. This fixture is shared by the
    /// ordering tests.
    async fn insert_alpha_graph(store: &GraphStore) {
        let a = store
            .upsert_entity("Alpha", "Alpha", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("Beta", "Beta", EntityType::Tool, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("AlphaGadget", "AlphaGadget", EntityType::Tool, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "uses", "Alpha uses Beta", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(a, c, "mentions", "Alpha mentions AlphaGadget", 0.1, None)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn graph_recall_empty_graph_returns_empty() {
        let store = setup_store().await;
        let provider = mock_provider();
        let result = graph_recall(&store, None, &provider, "anything", 10, 2, None, 0.0, &[])
            .await
            .unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn graph_recall_zero_limit_returns_empty() {
        let store = setup_store().await;
        let provider = mock_provider();
        let result = graph_recall(&store, None, &provider, "user", 0, 2, None, 0.0, &[])
            .await
            .unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn graph_recall_fuzzy_match_returns_facts() {
        let store = setup_store().await;
        let user_id = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();
        let tool_id = store
            .upsert_entity("neovim", "neovim", EntityType::Tool, None)
            .await
            .unwrap();
        store
            .insert_edge(user_id, tool_id, "uses", "Alice uses neovim", 0.9, None)
            .await
            .unwrap();

        let provider = mock_provider();
        let result = graph_recall(&store, None, &provider, "Ali neovim", 10, 2, None, 0.0, &[])
            .await
            .unwrap();
        assert!(!result.is_empty());
        assert_eq!(result[0].relation, "uses");
    }

    #[tokio::test]
    async fn graph_recall_respects_max_hops() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("Alpha", "Alpha", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("Beta", "Beta", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("Gamma", "Gamma", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "knows", "Alpha knows Beta", 0.8, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "Beta knows Gamma", 0.8, None)
            .await
            .unwrap();

        let provider = mock_provider();
        let result = graph_recall(&store, None, &provider, "Alp", 10, 1, None, 0.0, &[])
            .await
            .unwrap();
        assert!(result.iter().all(|f| f.hop_distance <= 1));
    }

    #[tokio::test]
    async fn graph_recall_deduplicates_facts() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();
        let bob = store
            .upsert_entity("Bob", "Bob", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(alice, bob, "knows", "Alice knows Bob", 0.9, None)
            .await
            .unwrap();

        let provider = mock_provider();
        // Both words match a seed, so the same edge is reached from two seeds.
        let result = graph_recall(&store, None, &provider, "Ali Bob", 10, 2, None, 0.0, &[])
            .await
            .unwrap();

        let mut seen = std::collections::HashSet::new();
        for f in &result {
            let key = (&f.entity_name, &f.relation, &f.target_name);
            assert!(seen.insert(key), "duplicate fact found: {f:?}");
        }
    }

    #[tokio::test]
    async fn graph_recall_sorts_by_composite_score() {
        let store = setup_store().await;
        insert_alpha_graph(&store).await;

        let provider = mock_provider();
        let result = graph_recall(&store, None, &provider, "Alp", 10, 2, None, 0.0, &[])
            .await
            .unwrap();

        assert!(result.len() >= 2);
        let s0 = result[0].composite_score();
        let s1 = result[1].composite_score();
        assert!(s0 >= s1, "expected sorted desc: {s0} >= {s1}");
    }

    #[tokio::test]
    async fn graph_recall_limit_truncates() {
        let store = setup_store().await;
        let root = store
            .upsert_entity("Root", "Root", EntityType::Person, None)
            .await
            .unwrap();
        for i in 0..10 {
            let target = store
                .upsert_entity(
                    &format!("Target{i}"),
                    &format!("Target{i}"),
                    EntityType::Tool,
                    None,
                )
                .await
                .unwrap();
            store
                .insert_edge(
                    root,
                    target,
                    "has",
                    &format!("Root has Target{i}"),
                    0.8,
                    None,
                )
                .await
                .unwrap();
        }

        let provider = mock_provider();
        let result = graph_recall(&store, None, &provider, "Roo", 3, 2, None, 0.0, &[])
            .await
            .unwrap();
        assert!(result.len() <= 3);
    }

    #[tokio::test]
    async fn graph_recall_at_timestamp_excludes_future_edges() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();
        let bob = store
            .upsert_entity("Bob", "Bob", EntityType::Person, None)
            .await
            .unwrap();
        // Raw insert so the edge's `valid_from` can be placed in the future.
        sqlx::query(
            "INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from)
             VALUES (?1, ?2, 'knows', 'Alice knows Bob', 0.9, '2100-01-01 00:00:00')",
        )
        .bind(alice)
        .bind(bob)
        .execute(store.pool())
        .await
        .unwrap();

        let provider = mock_provider();
        let result = graph_recall(
            &store,
            None,
            &provider,
            "Ali",
            10,
            2,
            Some("2026-01-01 00:00:00"),
            0.0,
            &[],
        )
        .await
        .unwrap();
        assert!(result.is_empty(), "future edge should be excluded");
    }

    #[tokio::test]
    async fn graph_recall_at_timestamp_excludes_invalidated_edges() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();
        let carol = store
            .upsert_entity("Carol", "Carol", EntityType::Person, None)
            .await
            .unwrap();
        // Raw insert of an edge that was valid only during 2020 and has expired.
        sqlx::query(
            "INSERT INTO graph_edges
             (source_entity_id, target_entity_id, relation, fact, confidence, valid_from, valid_to, expired_at)
             VALUES (?1, ?2, 'manages', 'Alice manages Carol', 0.8,
             '2020-01-01 00:00:00', '2021-01-01 00:00:00', '2021-01-01 00:00:00')",
        )
        .bind(alice)
        .bind(carol)
        .execute(store.pool())
        .await
        .unwrap();

        let provider = mock_provider();

        let result_current = graph_recall(&store, None, &provider, "Ali", 10, 2, None, 0.0, &[])
            .await
            .unwrap();
        assert!(
            result_current.is_empty(),
            "expired edge should be invisible at current time"
        );

        let result_historical = graph_recall(
            &store,
            None,
            &provider,
            "Ali",
            10,
            2,
            Some("2020-06-01 00:00:00"),
            0.0,
            &[],
        )
        .await
        .unwrap();
        assert!(
            !result_historical.is_empty(),
            "edge should be visible within its validity window"
        );
    }

    #[tokio::test]
    async fn graph_recall_temporal_decay_preserves_order_with_zero_rate() {
        let store = setup_store().await;
        insert_alpha_graph(&store).await;

        let provider = mock_provider();
        let result = graph_recall(&store, None, &provider, "Alp", 10, 2, None, 0.0, &[])
            .await
            .unwrap();
        assert!(result.len() >= 2);
        // With a zero decay rate, every adjacent pair must follow composite-score order.
        for pair in result.windows(2) {
            let s0 = pair[0].composite_score();
            let s1 = pair[1].composite_score();
            assert!(s0 >= s1, "expected sorted desc: {s0} >= {s1}");
        }
    }
}