1use std::collections::{HashMap, HashSet};
5use std::time::{SystemTime, UNIX_EPOCH};
6#[allow(unused_imports)]
7use zeph_db::sql;
8
9use crate::embedding_store::EmbeddingStore;
10use crate::error::MemoryError;
11
12use super::activation::{ActivatedFact, SpreadingActivation, SpreadingActivationParams};
13use super::store::GraphStore;
14use super::types::{EdgeType, GraphFact};
15
16#[allow(clippy::too_many_arguments)]
39pub async fn graph_recall(
40 store: &GraphStore,
41 embeddings: Option<&crate::embedding_store::EmbeddingStore>,
42 provider: &zeph_llm::any::AnyProvider,
43 query: &str,
44 limit: usize,
45 max_hops: u32,
46 at_timestamp: Option<&str>,
47 temporal_decay_rate: f64,
48 edge_types: &[EdgeType],
49) -> Result<Vec<GraphFact>, MemoryError> {
50 const DEFAULT_STRUCTURAL_WEIGHT: f32 = 0.4;
52 const DEFAULT_COMMUNITY_CAP: usize = 3;
53
54 if limit == 0 {
55 return Ok(Vec::new());
56 }
57
58 let entity_scores = find_seed_entities(
60 store,
61 embeddings,
62 provider,
63 query,
64 limit,
65 DEFAULT_STRUCTURAL_WEIGHT,
66 DEFAULT_COMMUNITY_CAP,
67 )
68 .await?;
69
70 if entity_scores.is_empty() {
71 return Ok(Vec::new());
72 }
73
74 let now_secs: i64 = SystemTime::now()
76 .duration_since(UNIX_EPOCH)
77 .map(|d| d.as_secs().cast_signed())
78 .unwrap_or(0);
79
80 let mut all_facts: Vec<GraphFact> = Vec::new();
82
83 for (seed_id, seed_score) in &entity_scores {
84 let (entities, edges, depth_map) = if let Some(ts) = at_timestamp {
85 store.bfs_at_timestamp(*seed_id, max_hops, ts).await?
86 } else if !edge_types.is_empty() {
87 store.bfs_typed(*seed_id, max_hops, edge_types).await?
88 } else {
89 store.bfs_with_depth(*seed_id, max_hops).await?
90 };
91
92 let name_map: HashMap<i64, &str> = entities
95 .iter()
96 .map(|e| (e.id, e.canonical_name.as_str()))
97 .collect();
98
99 let traversed_edge_ids: Vec<i64> = edges.iter().map(|e| e.id).collect();
101
102 for edge in &edges {
103 let Some(&hop_distance) = depth_map
104 .get(&edge.source_entity_id)
105 .or_else(|| depth_map.get(&edge.target_entity_id))
106 else {
107 continue;
108 };
109
110 let entity_name = name_map
111 .get(&edge.source_entity_id)
112 .copied()
113 .unwrap_or_default();
114 let target_name = name_map
115 .get(&edge.target_entity_id)
116 .copied()
117 .unwrap_or_default();
118
119 if entity_name.is_empty() || target_name.is_empty() {
120 continue;
121 }
122
123 all_facts.push(GraphFact {
124 entity_name: entity_name.to_owned(),
125 relation: edge.relation.clone(),
126 target_name: target_name.to_owned(),
127 fact: edge.fact.clone(),
128 entity_match_score: *seed_score,
129 hop_distance,
130 confidence: edge.confidence,
131 valid_from: Some(edge.valid_from.clone()),
132 edge_type: edge.edge_type,
133 retrieval_count: edge.retrieval_count,
134 });
135 }
136
137 if !traversed_edge_ids.is_empty()
139 && let Err(e) = store.record_edge_retrieval(&traversed_edge_ids).await
140 {
141 tracing::warn!(error = %e, "graph_recall: failed to record edge retrieval");
142 }
143 }
144
145 let mut scored: Vec<(f32, GraphFact)> = all_facts
149 .into_iter()
150 .map(|f| {
151 let s = f.score_with_decay(temporal_decay_rate, now_secs);
152 (s, f)
153 })
154 .collect();
155 scored.sort_by(|(sa, _), (sb, _)| sb.total_cmp(sa));
156 let mut all_facts: Vec<GraphFact> = scored.into_iter().map(|(_, f)| f).collect();
157
158 let mut seen: HashSet<(String, String, String, EdgeType)> = HashSet::new();
162 all_facts.retain(|f| {
163 seen.insert((
164 f.entity_name.clone(),
165 f.relation.clone(),
166 f.target_name.clone(),
167 f.edge_type,
168 ))
169 });
170
171 all_facts.truncate(limit);
173
174 Ok(all_facts)
175}
176
177async fn seed_embedding_fallback(
196 store: &GraphStore,
197 emb_store: &EmbeddingStore,
198 provider: &zeph_llm::any::AnyProvider,
199 query: &str,
200 limit: usize,
201 fts_map: &mut HashMap<i64, (super::types::Entity, f32)>,
202) -> bool {
203 use zeph_llm::LlmProvider as _;
204 const ENTITY_COLLECTION: &str = "zeph_graph_entities";
205 let embedding = match provider.embed(query).await {
206 Ok(v) => v,
207 Err(e) => {
208 tracing::warn!(error = %e, "seed fallback: embed() failed, returning empty seeds");
209 return false;
210 }
211 };
212 match emb_store
213 .search_collection(ENTITY_COLLECTION, &embedding, limit, None)
214 .await
215 {
216 Ok(results) => {
217 for result in results {
218 if let Some(entity_id) = result
219 .payload
220 .get("entity_id")
221 .and_then(serde_json::Value::as_i64)
222 && let Ok(Some(entity)) = store.find_entity_by_id(entity_id).await
223 {
224 fts_map.insert(entity_id, (entity, result.score));
225 }
226 }
227 }
228 Err(e) => {
229 tracing::warn!(error = %e, "seed fallback: embedding search failed");
230 }
231 }
232 true
233}
234
235async fn find_seed_entities(
236 store: &GraphStore,
237 embeddings: Option<&EmbeddingStore>,
238 provider: &zeph_llm::any::AnyProvider,
239 query: &str,
240 limit: usize,
241 structural_weight: f32,
242 community_cap: usize,
243) -> Result<HashMap<i64, f32>, MemoryError> {
244 use crate::graph::types::ScoredEntity;
245
246 const MAX_WORDS: usize = 5;
247
248 let filtered: Vec<&str> = query
249 .split_whitespace()
250 .filter(|w| w.len() >= 3)
251 .take(MAX_WORDS)
252 .collect();
253 let words: Vec<&str> = if filtered.is_empty() && !query.is_empty() {
254 vec![query]
255 } else {
256 filtered
257 };
258
259 let mut fts_map: HashMap<i64, (super::types::Entity, f32)> = HashMap::new();
261 for word in &words {
262 let ranked = store.find_entities_ranked(word, limit * 2).await?;
263 for (entity, fts_score) in ranked {
264 fts_map
265 .entry(entity.id)
266 .and_modify(|(_, s)| *s = s.max(fts_score))
267 .or_insert((entity, fts_score));
268 }
269 }
270
271 if fts_map.is_empty()
273 && let Some(emb_store) = embeddings
274 && !seed_embedding_fallback(store, emb_store, provider, query, limit, &mut fts_map).await
275 {
276 return Ok(HashMap::new());
277 }
278
279 if fts_map.is_empty() {
280 return Ok(HashMap::new());
281 }
282
283 let entity_ids: Vec<i64> = fts_map.keys().copied().collect();
284
285 let structural_scores = store.entity_structural_scores(&entity_ids).await?;
287
288 let community_ids = store.entity_community_ids(&entity_ids).await?;
290
291 let fts_weight = 1.0 - structural_weight;
293 let mut scored: Vec<ScoredEntity> = fts_map
294 .into_values()
295 .map(|(entity, fts_score)| {
296 let struct_score = structural_scores.get(&entity.id).copied().unwrap_or(0.0);
297 let community_id = community_ids.get(&entity.id).copied();
298 ScoredEntity {
299 entity,
300 fts_score,
301 structural_score: struct_score,
302 community_id,
303 }
304 })
305 .collect();
306
307 scored.sort_by(|a, b| {
309 let score_a = a.fts_score * fts_weight + a.structural_score * structural_weight;
310 let score_b = b.fts_score * fts_weight + b.structural_score * structural_weight;
311 score_b.total_cmp(&score_a)
312 });
313
314 let capped: Vec<&ScoredEntity> = if community_cap == 0 {
316 scored.iter().collect()
317 } else {
318 let mut community_counts: HashMap<i64, usize> = HashMap::new();
319 let mut result: Vec<&ScoredEntity> = Vec::new();
320 for se in &scored {
321 match se.community_id {
322 Some(cid) => {
323 let count = community_counts.entry(cid).or_insert(0);
324 if *count < community_cap {
325 *count += 1;
326 result.push(se);
327 }
328 }
329 None => {
330 result.push(se);
332 }
333 }
334 }
335 result
336 };
337
338 let selected: Vec<&ScoredEntity> = if capped.is_empty() && !scored.is_empty() {
340 scored.iter().take(limit).collect()
341 } else {
342 capped.into_iter().take(limit).collect()
343 };
344
345 let entity_scores: HashMap<i64, f32> = selected
346 .into_iter()
347 .map(|se| {
348 let hybrid = se.fts_score * fts_weight + se.structural_score * structural_weight;
349 (se.entity.id, hybrid.clamp(0.1, 1.0))
351 })
352 .collect();
353
354 Ok(entity_scores)
355}
356
357pub async fn graph_recall_activated(
372 store: &GraphStore,
373 embeddings: Option<&EmbeddingStore>,
374 provider: &zeph_llm::any::AnyProvider,
375 query: &str,
376 limit: usize,
377 params: SpreadingActivationParams,
378 edge_types: &[EdgeType],
379) -> Result<Vec<ActivatedFact>, MemoryError> {
380 if limit == 0 {
381 return Ok(Vec::new());
382 }
383
384 let entity_scores = find_seed_entities(
385 store,
386 embeddings,
387 provider,
388 query,
389 limit,
390 params.seed_structural_weight,
391 params.seed_community_cap,
392 )
393 .await?;
394
395 if entity_scores.is_empty() {
396 return Ok(Vec::new());
397 }
398
399 tracing::debug!(
400 seeds = entity_scores.len(),
401 "spreading activation: starting recall"
402 );
403
404 let sa = SpreadingActivation::new(params);
405 let (_, mut facts) = sa.spread(store, entity_scores, edge_types).await?;
406
407 let edge_ids: Vec<i64> = facts.iter().map(|f| f.edge.id).collect();
409 if !edge_ids.is_empty()
410 && let Err(e) = store.record_edge_retrieval(&edge_ids).await
411 {
412 tracing::warn!(error = %e, "graph_recall_activated: failed to record edge retrieval");
413 }
414
415 facts.sort_by(|a, b| b.activation_score.total_cmp(&a.activation_score));
417
418 let mut seen: HashSet<(i64, String, i64, EdgeType)> = HashSet::new();
420 facts.retain(|f| {
421 seen.insert((
422 f.edge.source_entity_id,
423 f.edge.relation.clone(),
424 f.edge.target_entity_id,
425 f.edge.edge_type,
426 ))
427 });
428
429 facts.truncate(limit);
430
431 tracing::debug!(
432 result_count = facts.len(),
433 "spreading activation: recall complete"
434 );
435
436 Ok(facts)
437}
438
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::store::GraphStore;
    use crate::graph::types::EntityType;
    use crate::store::SqliteStore;
    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;

    /// Graph store backed by a fresh in-memory SQLite database.
    async fn setup_store() -> GraphStore {
        let store = SqliteStore::new(":memory:").await.unwrap();
        GraphStore::new(store.pool().clone())
    }

    fn mock_provider() -> AnyProvider {
        AnyProvider::Mock(MockProvider::default())
    }

    /// `graph_recall` with no embedding store, no temporal decay and no edge-type filter.
    /// This is the configuration every test below uses.
    async fn recall(
        store: &GraphStore,
        query: &str,
        limit: usize,
        max_hops: u32,
        at_timestamp: Option<&str>,
    ) -> Vec<GraphFact> {
        let provider = mock_provider();
        graph_recall(store, None, &provider, query, limit, max_hops, at_timestamp, 0.0, &[])
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn graph_recall_empty_graph_returns_empty() {
        let store = setup_store().await;
        let result = recall(&store, "anything", 10, 2, None).await;
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn graph_recall_zero_limit_returns_empty() {
        let store = setup_store().await;
        let result = recall(&store, "user", 0, 2, None).await;
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn graph_recall_fuzzy_match_returns_facts() {
        let store = setup_store().await;
        let user_id = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();
        let tool_id = store
            .upsert_entity("neovim", "neovim", EntityType::Tool, None)
            .await
            .unwrap();
        store
            .insert_edge(user_id, tool_id, "uses", "Alice uses neovim", 0.9, None)
            .await
            .unwrap();

        // "Ali" is a prefix of "Alice": fuzzy matching must still find the seed.
        let result = recall(&store, "Ali neovim", 10, 2, None).await;
        assert!(!result.is_empty());
        assert_eq!(result[0].relation, "uses");
    }

    #[tokio::test]
    async fn graph_recall_respects_max_hops() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("Alpha", "Alpha", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("Beta", "Beta", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("Gamma", "Gamma", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "knows", "Alpha knows Beta", 0.8, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "Beta knows Gamma", 0.8, None)
            .await
            .unwrap();

        let result = recall(&store, "Alp", 10, 1, None).await;
        assert!(result.iter().all(|f| f.hop_distance <= 1));
    }

    #[tokio::test]
    async fn graph_recall_deduplicates_facts() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();
        let bob = store
            .upsert_entity("Bob", "Bob", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(alice, bob, "knows", "Alice knows Bob", 0.9, None)
            .await
            .unwrap();

        // Both "Ali" and "Bob" seed a BFS that reaches the same edge.
        let result = recall(&store, "Ali Bob", 10, 2, None).await;

        let mut seen = std::collections::HashSet::new();
        for f in &result {
            let key = (&f.entity_name, &f.relation, &f.target_name);
            assert!(seen.insert(key), "duplicate fact found: {f:?}");
        }
    }

    #[tokio::test]
    async fn graph_recall_sorts_by_composite_score() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("Alpha", "Alpha", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("Beta", "Beta", EntityType::Tool, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("AlphaGadget", "AlphaGadget", EntityType::Tool, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "uses", "Alpha uses Beta", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(a, c, "mentions", "Alpha mentions AlphaGadget", 0.1, None)
            .await
            .unwrap();

        let result = recall(&store, "Alp", 10, 2, None).await;

        assert!(result.len() >= 2);
        let s0 = result[0].composite_score();
        let s1 = result[1].composite_score();
        assert!(s0 >= s1, "expected sorted desc: {s0} >= {s1}");
    }

    #[tokio::test]
    async fn graph_recall_limit_truncates() {
        let store = setup_store().await;
        let root = store
            .upsert_entity("Root", "Root", EntityType::Person, None)
            .await
            .unwrap();
        for i in 0..10 {
            let target = store
                .upsert_entity(
                    &format!("Target{i}"),
                    &format!("Target{i}"),
                    EntityType::Tool,
                    None,
                )
                .await
                .unwrap();
            store
                .insert_edge(
                    root,
                    target,
                    "has",
                    &format!("Root has Target{i}"),
                    0.8,
                    None,
                )
                .await
                .unwrap();
        }

        let result = recall(&store, "Roo", 3, 2, None).await;
        assert!(result.len() <= 3);
    }

    #[tokio::test]
    async fn graph_recall_at_timestamp_excludes_future_edges() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();
        let bob = store
            .upsert_entity("Bob", "Bob", EntityType::Person, None)
            .await
            .unwrap();
        zeph_db::query(
            sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from)
             VALUES (?1, ?2, 'knows', 'Alice knows Bob', 0.9, '2100-01-01 00:00:00')"),
        )
        .bind(alice)
        .bind(bob)
        .execute(store.pool())
        .await
        .unwrap();

        let result = recall(&store, "Ali", 10, 2, Some("2026-01-01 00:00:00")).await;
        assert!(result.is_empty(), "future edge should be excluded");
    }

    #[tokio::test]
    async fn graph_recall_at_timestamp_excludes_invalidated_edges() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();
        let carol = store
            .upsert_entity("Carol", "Carol", EntityType::Person, None)
            .await
            .unwrap();
        zeph_db::query(
            sql!("INSERT INTO graph_edges
             (source_entity_id, target_entity_id, relation, fact, confidence, valid_from, valid_to, expired_at)
             VALUES (?1, ?2, 'manages', 'Alice manages Carol', 0.8,
                     '2020-01-01 00:00:00', '2021-01-01 00:00:00', '2021-01-01 00:00:00')"),
        )
        .bind(alice)
        .bind(carol)
        .execute(store.pool())
        .await
        .unwrap();

        let result_current = recall(&store, "Ali", 10, 2, None).await;
        assert!(
            result_current.is_empty(),
            "expired edge should be invisible at current time"
        );

        let result_historical = recall(&store, "Ali", 10, 2, Some("2020-06-01 00:00:00")).await;
        assert!(
            !result_historical.is_empty(),
            "edge should be visible within its validity window"
        );
    }

    #[tokio::test]
    async fn graph_recall_community_cap_guard_non_empty() {
        let store = setup_store().await;
        let mut entity_ids = Vec::new();
        for i in 0..5usize {
            let id = store
                .upsert_entity(
                    &format!("Entity{i}"),
                    &format!("entity{i}"),
                    crate::graph::types::EntityType::Concept,
                    None,
                )
                .await
                .unwrap();
            entity_ids.push(id);
        }

        // Put every matching entity in one community so the per-community cap applies to all.
        store
            .upsert_community("TestComm", "test", &entity_ids, Some("fp"))
            .await
            .unwrap();

        let hub = store
            .upsert_entity("Hub", "hub", crate::graph::types::EntityType::Concept, None)
            .await
            .unwrap();
        for &target in &entity_ids {
            store
                .insert_edge(hub, target, "has", "Hub has entity", 0.9, None)
                .await
                .unwrap();
        }

        let result = recall(&store, "entity", 10, 2, None).await;
        assert!(
            !result.is_empty(),
            "SA-INV-10: community cap must not zero out all seeds"
        );
    }

    #[tokio::test]
    async fn graph_recall_no_fts_match_no_embeddings_returns_empty() {
        let store = setup_store().await;
        let a = store
            .upsert_entity(
                "Zephyr",
                "zephyr",
                crate::graph::types::EntityType::Concept,
                None,
            )
            .await
            .unwrap();
        let b = store
            .upsert_entity(
                "Concept",
                "concept",
                crate::graph::types::EntityType::Concept,
                None,
            )
            .await
            .unwrap();
        store
            .insert_edge(a, b, "rel", "Zephyr rel Concept", 0.9, None)
            .await
            .unwrap();

        let result = recall(&store, "xyzzyquuxfrob", 10, 2, None).await;
        assert!(
            result.is_empty(),
            "must return empty (not error) when FTS5 returns 0 and no embeddings available"
        );
    }

    #[tokio::test]
    async fn graph_recall_temporal_decay_preserves_order_with_zero_rate() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("Alpha", "Alpha", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("Beta", "Beta", EntityType::Tool, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("AlphaGadget", "AlphaGadget", EntityType::Tool, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "uses", "Alpha uses Beta", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(a, c, "mentions", "Alpha mentions AlphaGadget", 0.1, None)
            .await
            .unwrap();

        // `recall` passes a decay rate of 0.0, so ranking must remain in composite-score order.
        let result = recall(&store, "Alp", 10, 2, None).await;
        assert!(result.len() >= 2);
        let s0 = result[0].composite_score();
        let s1 = result[1].composite_score();
        assert!(s0 >= s1, "expected sorted desc: {s0} >= {s1}");
    }
}