1use std::collections::{HashMap, HashSet};
5use std::time::{SystemTime, UNIX_EPOCH};
6#[allow(unused_imports)]
7use zeph_db::sql;
8
9use crate::embedding_store::EmbeddingStore;
10use crate::error::MemoryError;
11
12use super::activation::{ActivatedFact, SpreadingActivation, SpreadingActivationParams};
13use super::store::GraphStore;
14use super::types::{EdgeType, GraphFact};
15
16#[allow(clippy::too_many_arguments)]
39pub async fn graph_recall(
40 store: &GraphStore,
41 embeddings: Option<&crate::embedding_store::EmbeddingStore>,
42 provider: &zeph_llm::any::AnyProvider,
43 query: &str,
44 limit: usize,
45 max_hops: u32,
46 at_timestamp: Option<&str>,
47 temporal_decay_rate: f64,
48 edge_types: &[EdgeType],
49) -> Result<Vec<GraphFact>, MemoryError> {
50 const DEFAULT_STRUCTURAL_WEIGHT: f32 = 0.4;
52 const DEFAULT_COMMUNITY_CAP: usize = 3;
53
54 if limit == 0 {
55 return Ok(Vec::new());
56 }
57
58 let entity_scores = find_seed_entities(
60 store,
61 embeddings,
62 provider,
63 query,
64 limit,
65 DEFAULT_STRUCTURAL_WEIGHT,
66 DEFAULT_COMMUNITY_CAP,
67 )
68 .await?;
69
70 if entity_scores.is_empty() {
71 return Ok(Vec::new());
72 }
73
74 let now_secs: i64 = SystemTime::now()
76 .duration_since(UNIX_EPOCH)
77 .map_or(0, |d| d.as_secs().cast_signed());
78
79 let mut all_facts: Vec<GraphFact> = Vec::new();
81
82 for (seed_id, seed_score) in &entity_scores {
83 let (entities, edges, depth_map) = if let Some(ts) = at_timestamp {
84 store.bfs_at_timestamp(*seed_id, max_hops, ts).await?
85 } else if !edge_types.is_empty() {
86 store.bfs_typed(*seed_id, max_hops, edge_types).await?
87 } else {
88 store.bfs_with_depth(*seed_id, max_hops).await?
89 };
90
91 let name_map: HashMap<i64, &str> = entities
94 .iter()
95 .map(|e| (e.id, e.canonical_name.as_str()))
96 .collect();
97
98 let traversed_edge_ids: Vec<i64> = edges.iter().map(|e| e.id).collect();
100
101 for edge in &edges {
102 let Some(&hop_distance) = depth_map
103 .get(&edge.source_entity_id)
104 .or_else(|| depth_map.get(&edge.target_entity_id))
105 else {
106 continue;
107 };
108
109 let entity_name = name_map
110 .get(&edge.source_entity_id)
111 .copied()
112 .unwrap_or_default();
113 let target_name = name_map
114 .get(&edge.target_entity_id)
115 .copied()
116 .unwrap_or_default();
117
118 if entity_name.is_empty() || target_name.is_empty() {
119 continue;
120 }
121
122 all_facts.push(GraphFact {
123 entity_name: entity_name.to_owned(),
124 relation: edge.relation.clone(),
125 target_name: target_name.to_owned(),
126 fact: edge.fact.clone(),
127 entity_match_score: *seed_score,
128 hop_distance,
129 confidence: edge.confidence,
130 valid_from: Some(edge.valid_from.clone()),
131 edge_type: edge.edge_type,
132 retrieval_count: edge.retrieval_count,
133 });
134 }
135
136 if !traversed_edge_ids.is_empty()
138 && let Err(e) = store.record_edge_retrieval(&traversed_edge_ids).await
139 {
140 tracing::warn!(error = %e, "graph_recall: failed to record edge retrieval");
141 }
142 }
143
144 let mut scored: Vec<(f32, GraphFact)> = all_facts
148 .into_iter()
149 .map(|f| {
150 let s = f.score_with_decay(temporal_decay_rate, now_secs);
151 (s, f)
152 })
153 .collect();
154 scored.sort_by(|(sa, _), (sb, _)| sb.total_cmp(sa));
155 let mut all_facts: Vec<GraphFact> = scored.into_iter().map(|(_, f)| f).collect();
156
157 let mut seen: HashSet<(String, String, String, EdgeType)> = HashSet::new();
161 all_facts.retain(|f| {
162 seen.insert((
163 f.entity_name.clone(),
164 f.relation.clone(),
165 f.target_name.clone(),
166 f.edge_type,
167 ))
168 });
169
170 all_facts.truncate(limit);
172
173 Ok(all_facts)
174}
175
176async fn seed_embedding_fallback(
195 store: &GraphStore,
196 emb_store: &EmbeddingStore,
197 provider: &zeph_llm::any::AnyProvider,
198 query: &str,
199 limit: usize,
200 fts_map: &mut HashMap<i64, (super::types::Entity, f32)>,
201) -> bool {
202 use zeph_llm::LlmProvider as _;
203 const ENTITY_COLLECTION: &str = "zeph_graph_entities";
204 let embedding = match provider.embed(query).await {
205 Ok(v) => v,
206 Err(e) => {
207 tracing::warn!(error = %e, "seed fallback: embed() failed, returning empty seeds");
208 return false;
209 }
210 };
211 match emb_store
212 .search_collection(ENTITY_COLLECTION, &embedding, limit, None)
213 .await
214 {
215 Ok(results) => {
216 for result in results {
217 if let Some(entity_id) = result
218 .payload
219 .get("entity_id")
220 .and_then(serde_json::Value::as_i64)
221 && let Ok(Some(entity)) = store.find_entity_by_id(entity_id).await
222 {
223 fts_map.insert(entity_id, (entity, result.score));
224 }
225 }
226 }
227 Err(e) => {
228 tracing::warn!(error = %e, "seed fallback: embedding search failed");
229 }
230 }
231 true
232}
233
234async fn find_seed_entities(
235 store: &GraphStore,
236 embeddings: Option<&EmbeddingStore>,
237 provider: &zeph_llm::any::AnyProvider,
238 query: &str,
239 limit: usize,
240 structural_weight: f32,
241 community_cap: usize,
242) -> Result<HashMap<i64, f32>, MemoryError> {
243 use crate::graph::types::ScoredEntity;
244
245 const MAX_WORDS: usize = 5;
246
247 let filtered: Vec<&str> = query
248 .split_whitespace()
249 .filter(|w| w.len() >= 3)
250 .take(MAX_WORDS)
251 .collect();
252 let words: Vec<&str> = if filtered.is_empty() && !query.is_empty() {
253 vec![query]
254 } else {
255 filtered
256 };
257
258 let mut fts_map: HashMap<i64, (super::types::Entity, f32)> = HashMap::new();
260 for word in &words {
261 let ranked = store.find_entities_ranked(word, limit * 2).await?;
262 for (entity, fts_score) in ranked {
263 fts_map
264 .entry(entity.id)
265 .and_modify(|(_, s)| *s = s.max(fts_score))
266 .or_insert((entity, fts_score));
267 }
268 }
269
270 if fts_map.is_empty()
272 && let Some(emb_store) = embeddings
273 && !seed_embedding_fallback(store, emb_store, provider, query, limit, &mut fts_map).await
274 {
275 return Ok(HashMap::new());
276 }
277
278 if fts_map.is_empty() {
279 return Ok(HashMap::new());
280 }
281
282 let entity_ids: Vec<i64> = fts_map.keys().copied().collect();
283
284 let structural_scores = store.entity_structural_scores(&entity_ids).await?;
286
287 let community_ids = store.entity_community_ids(&entity_ids).await?;
289
290 let fts_weight = 1.0 - structural_weight;
292 let mut scored: Vec<ScoredEntity> = fts_map
293 .into_values()
294 .map(|(entity, fts_score)| {
295 let struct_score = structural_scores.get(&entity.id).copied().unwrap_or(0.0);
296 let community_id = community_ids.get(&entity.id).copied();
297 ScoredEntity {
298 entity,
299 fts_score,
300 structural_score: struct_score,
301 community_id,
302 }
303 })
304 .collect();
305
306 scored.sort_by(|a, b| {
308 let score_a = a.fts_score * fts_weight + a.structural_score * structural_weight;
309 let score_b = b.fts_score * fts_weight + b.structural_score * structural_weight;
310 score_b.total_cmp(&score_a)
311 });
312
313 let capped: Vec<&ScoredEntity> = if community_cap == 0 {
315 scored.iter().collect()
316 } else {
317 let mut community_counts: HashMap<i64, usize> = HashMap::new();
318 let mut result: Vec<&ScoredEntity> = Vec::new();
319 for se in &scored {
320 match se.community_id {
321 Some(cid) => {
322 let count = community_counts.entry(cid).or_insert(0);
323 if *count < community_cap {
324 *count += 1;
325 result.push(se);
326 }
327 }
328 None => {
329 result.push(se);
331 }
332 }
333 }
334 result
335 };
336
337 let selected: Vec<&ScoredEntity> = if capped.is_empty() && !scored.is_empty() {
339 scored.iter().take(limit).collect()
340 } else {
341 capped.into_iter().take(limit).collect()
342 };
343
344 let entity_scores: HashMap<i64, f32> = selected
345 .into_iter()
346 .map(|se| {
347 let hybrid = se.fts_score * fts_weight + se.structural_score * structural_weight;
348 (se.entity.id, hybrid.clamp(0.1, 1.0))
350 })
351 .collect();
352
353 Ok(entity_scores)
354}
355
356pub async fn graph_recall_activated(
371 store: &GraphStore,
372 embeddings: Option<&EmbeddingStore>,
373 provider: &zeph_llm::any::AnyProvider,
374 query: &str,
375 limit: usize,
376 params: SpreadingActivationParams,
377 edge_types: &[EdgeType],
378) -> Result<Vec<ActivatedFact>, MemoryError> {
379 if limit == 0 {
380 return Ok(Vec::new());
381 }
382
383 let entity_scores = find_seed_entities(
384 store,
385 embeddings,
386 provider,
387 query,
388 limit,
389 params.seed_structural_weight,
390 params.seed_community_cap,
391 )
392 .await?;
393
394 if entity_scores.is_empty() {
395 return Ok(Vec::new());
396 }
397
398 tracing::debug!(
399 seeds = entity_scores.len(),
400 "spreading activation: starting recall"
401 );
402
403 let sa = SpreadingActivation::new(params);
404 let (_, mut facts) = sa.spread(store, entity_scores, edge_types).await?;
405
406 let edge_ids: Vec<i64> = facts.iter().map(|f| f.edge.id).collect();
408 if !edge_ids.is_empty()
409 && let Err(e) = store.record_edge_retrieval(&edge_ids).await
410 {
411 tracing::warn!(error = %e, "graph_recall_activated: failed to record edge retrieval");
412 }
413
414 facts.sort_by(|a, b| b.activation_score.total_cmp(&a.activation_score));
416
417 let mut seen: HashSet<(i64, String, i64, EdgeType)> = HashSet::new();
419 facts.retain(|f| {
420 seen.insert((
421 f.edge.source_entity_id,
422 f.edge.relation.clone(),
423 f.edge.target_entity_id,
424 f.edge.edge_type,
425 ))
426 });
427
428 facts.truncate(limit);
429
430 tracing::debug!(
431 result_count = facts.len(),
432 "spreading activation: recall complete"
433 );
434
435 Ok(facts)
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::store::GraphStore;
    use crate::graph::types::EntityType;
    use crate::store::SqliteStore;
    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;

    /// Build a `GraphStore` over a fresh in-memory SQLite database.
    async fn setup_store() -> GraphStore {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        GraphStore::new(sqlite.pool().clone())
    }

    /// Mock LLM provider. `graph_recall` only forwards it to the embedding fallback, which
    /// does not run here because every call passes `embeddings = None`.
    fn mock_provider() -> AnyProvider {
        AnyProvider::Mock(MockProvider::default())
    }

    /// Upsert an entity with no summary and return its id.
    async fn add_entity(store: &GraphStore, name: &str, canonical: &str, kind: EntityType) -> i64 {
        store
            .upsert_entity(name, canonical, kind, None)
            .await
            .unwrap()
    }

    /// Run `graph_recall` with no embedding store, zero temporal decay, and no edge-type filter.
    async fn recall(
        store: &GraphStore,
        query: &str,
        limit: usize,
        max_hops: u32,
        at_timestamp: Option<&str>,
    ) -> Vec<GraphFact> {
        let provider = mock_provider();
        graph_recall(
            store,
            None,
            &provider,
            query,
            limit,
            max_hops,
            at_timestamp,
            0.0,
            &[],
        )
        .await
        .unwrap()
    }

    /// Fixture: Alpha -uses(1.0)-> Beta and Alpha -mentions(0.1)-> AlphaGadget.
    async fn seed_alpha_graph(store: &GraphStore) {
        let alpha = add_entity(store, "Alpha", "Alpha", EntityType::Person).await;
        let beta = add_entity(store, "Beta", "Beta", EntityType::Tool).await;
        let gadget = add_entity(store, "AlphaGadget", "AlphaGadget", EntityType::Tool).await;
        store
            .insert_edge(alpha, beta, "uses", "Alpha uses Beta", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(alpha, gadget, "mentions", "Alpha mentions AlphaGadget", 0.1, None)
            .await
            .unwrap();
    }

    /// At least two facts were returned and the first two are in descending composite order.
    fn assert_top_two_sorted(result: &[GraphFact]) {
        assert!(result.len() >= 2);
        let s0 = result[0].composite_score();
        let s1 = result[1].composite_score();
        assert!(s0 >= s1, "expected sorted desc: {s0} >= {s1}");
    }

    #[tokio::test]
    async fn graph_recall_empty_graph_returns_empty() {
        let store = setup_store().await;
        let result = recall(&store, "anything", 10, 2, None).await;
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn graph_recall_zero_limit_returns_empty() {
        let store = setup_store().await;
        let result = recall(&store, "user", 0, 2, None).await;
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn graph_recall_fuzzy_match_returns_facts() {
        let store = setup_store().await;
        let alice = add_entity(&store, "Alice", "Alice", EntityType::Person).await;
        let neovim = add_entity(&store, "neovim", "neovim", EntityType::Tool).await;
        store
            .insert_edge(alice, neovim, "uses", "Alice uses neovim", 0.9, None)
            .await
            .unwrap();

        let result = recall(&store, "Ali neovim", 10, 2, None).await;
        assert!(!result.is_empty());
        assert_eq!(result[0].relation, "uses");
    }

    #[tokio::test]
    async fn graph_recall_respects_max_hops() {
        let store = setup_store().await;
        let alpha = add_entity(&store, "Alpha", "Alpha", EntityType::Person).await;
        let beta = add_entity(&store, "Beta", "Beta", EntityType::Person).await;
        let gamma = add_entity(&store, "Gamma", "Gamma", EntityType::Person).await;
        store
            .insert_edge(alpha, beta, "knows", "Alpha knows Beta", 0.8, None)
            .await
            .unwrap();
        store
            .insert_edge(beta, gamma, "knows", "Beta knows Gamma", 0.8, None)
            .await
            .unwrap();

        let result = recall(&store, "Alp", 10, 1, None).await;
        assert!(result.iter().all(|f| f.hop_distance <= 1));
    }

    #[tokio::test]
    async fn graph_recall_deduplicates_facts() {
        let store = setup_store().await;
        let alice = add_entity(&store, "Alice", "Alice", EntityType::Person).await;
        let bob = add_entity(&store, "Bob", "Bob", EntityType::Person).await;
        store
            .insert_edge(alice, bob, "knows", "Alice knows Bob", 0.9, None)
            .await
            .unwrap();

        let result = recall(&store, "Ali Bob", 10, 2, None).await;

        let mut seen = HashSet::new();
        for f in &result {
            let key = (&f.entity_name, &f.relation, &f.target_name);
            assert!(seen.insert(key), "duplicate fact found: {f:?}");
        }
    }

    #[tokio::test]
    async fn graph_recall_sorts_by_composite_score() {
        let store = setup_store().await;
        seed_alpha_graph(&store).await;

        let result = recall(&store, "Alp", 10, 2, None).await;
        assert_top_two_sorted(&result);
    }

    #[tokio::test]
    async fn graph_recall_limit_truncates() {
        let store = setup_store().await;
        let root = add_entity(&store, "Root", "Root", EntityType::Person).await;
        for i in 0..10 {
            let name = format!("Target{i}");
            let target = add_entity(&store, &name, &name, EntityType::Tool).await;
            store
                .insert_edge(root, target, "has", &format!("Root has {name}"), 0.8, None)
                .await
                .unwrap();
        }

        let result = recall(&store, "Roo", 3, 2, None).await;
        assert!(result.len() <= 3);
    }

    #[tokio::test]
    async fn graph_recall_at_timestamp_excludes_future_edges() {
        let store = setup_store().await;
        let alice = add_entity(&store, "Alice", "Alice", EntityType::Person).await;
        let bob = add_entity(&store, "Bob", "Bob", EntityType::Person).await;
        zeph_db::query(sql!(
            "INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from)
                 VALUES (?1, ?2, 'knows', 'Alice knows Bob', 0.9, '2100-01-01 00:00:00')"
        ))
        .bind(alice)
        .bind(bob)
        .execute(store.pool())
        .await
        .unwrap();

        let result = recall(&store, "Ali", 10, 2, Some("2026-01-01 00:00:00")).await;
        assert!(result.is_empty(), "future edge should be excluded");
    }

    #[tokio::test]
    async fn graph_recall_at_timestamp_excludes_invalidated_edges() {
        let store = setup_store().await;
        let alice = add_entity(&store, "Alice", "Alice", EntityType::Person).await;
        let carol = add_entity(&store, "Carol", "Carol", EntityType::Person).await;
        zeph_db::query(sql!(
            "INSERT INTO graph_edges
                 (source_entity_id, target_entity_id, relation, fact, confidence, valid_from, valid_to, expired_at)
                 VALUES (?1, ?2, 'manages', 'Alice manages Carol', 0.8,
                 '2020-01-01 00:00:00', '2021-01-01 00:00:00', '2021-01-01 00:00:00')"
        ))
        .bind(alice)
        .bind(carol)
        .execute(store.pool())
        .await
        .unwrap();

        let result_current = recall(&store, "Ali", 10, 2, None).await;
        assert!(
            result_current.is_empty(),
            "expired edge should be invisible at current time"
        );

        let result_historical = recall(&store, "Ali", 10, 2, Some("2020-06-01 00:00:00")).await;
        assert!(
            !result_historical.is_empty(),
            "edge should be visible within its validity window"
        );
    }

    #[tokio::test]
    async fn graph_recall_community_cap_guard_non_empty() {
        let store = setup_store().await;

        // Five members of a single community, all matching the query term "entity".
        let mut members = Vec::new();
        for i in 0..5usize {
            let id = add_entity(
                &store,
                &format!("Entity{i}"),
                &format!("entity{i}"),
                EntityType::Concept,
            )
            .await;
            members.push(id);
        }
        store
            .upsert_community("TestComm", "test", &members, Some("fp"))
            .await
            .unwrap();

        let hub = add_entity(&store, "Hub", "hub", EntityType::Concept).await;
        for &member in &members {
            store
                .insert_edge(hub, member, "has", "Hub has entity", 0.9, None)
                .await
                .unwrap();
        }

        let result = recall(&store, "entity", 10, 2, None).await;
        assert!(
            !result.is_empty(),
            "SA-INV-10: community cap must not zero out all seeds"
        );
    }

    #[tokio::test]
    async fn graph_recall_no_fts_match_no_embeddings_returns_empty() {
        let store = setup_store().await;
        let zephyr = add_entity(&store, "Zephyr", "zephyr", EntityType::Concept).await;
        let concept = add_entity(&store, "Concept", "concept", EntityType::Concept).await;
        store
            .insert_edge(zephyr, concept, "rel", "Zephyr rel Concept", 0.9, None)
            .await
            .unwrap();

        let result = recall(&store, "xyzzyquuxfrob", 10, 2, None).await;
        assert!(
            result.is_empty(),
            "must return empty (not error) when FTS5 returns 0 and no embeddings available"
        );
    }

    #[tokio::test]
    async fn graph_recall_temporal_decay_preserves_order_with_zero_rate() {
        let store = setup_store().await;
        seed_alpha_graph(&store).await;

        let result = recall(&store, "Alp", 10, 2, None).await;
        assert_top_two_sorted(&result);
    }
}