1use std::collections::{HashMap, HashSet};
5use std::time::{SystemTime, UNIX_EPOCH};
6#[allow(unused_imports)]
7use zeph_db::sql;
8
9use crate::embedding_store::EmbeddingStore;
10use crate::error::MemoryError;
11
12use super::activation::{ActivatedFact, SpreadingActivation, SpreadingActivationParams};
13use super::store::GraphStore;
14use super::types::{EdgeType, GraphFact};
15
16#[allow(clippy::too_many_arguments, clippy::too_many_lines)] pub async fn graph_recall(
40 store: &GraphStore,
41 embeddings: Option<&crate::embedding_store::EmbeddingStore>,
42 provider: &zeph_llm::any::AnyProvider,
43 query: &str,
44 limit: usize,
45 max_hops: u32,
46 at_timestamp: Option<&str>,
47 temporal_decay_rate: f64,
48 edge_types: &[EdgeType],
49 hebbian_enabled: bool,
50 hebbian_lr: f32,
51 embed_timeout: std::time::Duration,
52) -> Result<Vec<GraphFact>, MemoryError> {
53 const DEFAULT_STRUCTURAL_WEIGHT: f32 = 0.4;
55 const DEFAULT_COMMUNITY_CAP: usize = 3;
56
57 if limit == 0 {
58 return Ok(Vec::new());
59 }
60
61 let entity_scores = find_seed_entities(
63 store,
64 embeddings,
65 provider,
66 query,
67 limit,
68 DEFAULT_STRUCTURAL_WEIGHT,
69 DEFAULT_COMMUNITY_CAP,
70 embed_timeout,
71 )
72 .await?;
73
74 if entity_scores.is_empty() {
75 return Ok(Vec::new());
76 }
77
78 let now_secs: i64 = SystemTime::now()
80 .duration_since(UNIX_EPOCH)
81 .map_or(0, |d| d.as_secs().cast_signed());
82
83 let mut all_facts: Vec<GraphFact> = Vec::new();
85
86 for (seed_id, seed_score) in &entity_scores {
87 let (entities, edges, depth_map) = if let Some(ts) = at_timestamp {
88 store.bfs_at_timestamp(*seed_id, max_hops, ts).await?
89 } else if !edge_types.is_empty() {
90 store.bfs_typed(*seed_id, max_hops, edge_types).await?
91 } else {
92 store.bfs_with_depth(*seed_id, max_hops).await?
93 };
94
95 let name_map: HashMap<i64, &str> = entities
98 .iter()
99 .map(|e| (e.id.0, e.canonical_name.as_str()))
100 .collect();
101
102 let traversed_edge_ids: Vec<i64> = edges.iter().map(|e| e.id).collect();
104
105 for edge in &edges {
106 let Some(&hop_distance) = depth_map
107 .get(&edge.source_entity_id)
108 .or_else(|| depth_map.get(&edge.target_entity_id))
109 else {
110 continue;
111 };
112
113 let entity_name = name_map
114 .get(&edge.source_entity_id)
115 .copied()
116 .unwrap_or_default();
117 let target_name = name_map
118 .get(&edge.target_entity_id)
119 .copied()
120 .unwrap_or_default();
121
122 if entity_name.is_empty() || target_name.is_empty() {
123 continue;
124 }
125
126 all_facts.push(GraphFact {
127 entity_name: entity_name.to_owned(),
128 relation: edge.relation.clone(),
129 target_name: target_name.to_owned(),
130 fact: edge.fact.clone(),
131 entity_match_score: *seed_score,
132 hop_distance,
133 confidence: edge.confidence,
134 valid_from: Some(edge.valid_from.clone()),
135 edge_type: edge.edge_type,
136 retrieval_count: edge.retrieval_count,
137 edge_id: Some(edge.id),
138 });
139 }
140
141 if !traversed_edge_ids.is_empty()
143 && let Err(e) = store.record_edge_retrieval(&traversed_edge_ids).await
144 {
145 tracing::warn!(error = %e, "graph_recall: failed to record edge retrieval");
146 }
147 if hebbian_enabled
149 && !traversed_edge_ids.is_empty()
150 && let Err(e) = store
151 .apply_hebbian_increment(&traversed_edge_ids, hebbian_lr)
152 .await
153 {
154 tracing::warn!(error = %e, "graph_recall: hebbian increment failed");
155 }
156 }
157
158 let mut scored: Vec<(f32, GraphFact)> = all_facts
162 .into_iter()
163 .map(|f| {
164 let s = f.score_with_decay(temporal_decay_rate, now_secs);
165 (s, f)
166 })
167 .collect();
168 scored.sort_by(|(sa, _), (sb, _)| sb.total_cmp(sa));
169 let mut all_facts: Vec<GraphFact> = scored.into_iter().map(|(_, f)| f).collect();
170
171 let mut seen: HashSet<(String, String, String, EdgeType)> = HashSet::new();
175 all_facts.retain(|f| {
176 seen.insert((
177 f.entity_name.clone(),
178 f.relation.clone(),
179 f.target_name.clone(),
180 f.edge_type,
181 ))
182 });
183
184 all_facts.truncate(limit);
186
187 Ok(all_facts)
188}
189
190async fn seed_embedding_fallback(
209 store: &GraphStore,
210 emb_store: &EmbeddingStore,
211 provider: &zeph_llm::any::AnyProvider,
212 query: &str,
213 limit: usize,
214 fts_map: &mut HashMap<i64, (super::types::Entity, f32)>,
215 embed_timeout: std::time::Duration,
216) -> bool {
217 use zeph_llm::LlmProvider as _;
218 const ENTITY_COLLECTION: &str = "zeph_graph_entities";
219 let embedding = match tokio::time::timeout(embed_timeout, provider.embed(query)).await {
220 Ok(Ok(v)) => v,
221 Ok(Err(e)) => {
222 tracing::warn!(error = %e, "seed fallback: embed() failed, returning empty seeds");
223 return false;
224 }
225 Err(_) => {
226 tracing::warn!("seed fallback: embed() timed out, returning empty seeds");
227 return false;
228 }
229 };
230 match emb_store
231 .search_collection(ENTITY_COLLECTION, &embedding, limit, None)
232 .await
233 {
234 Ok(results) => {
235 for result in results {
236 if let Some(entity_id) = result
237 .payload
238 .get("entity_id")
239 .and_then(serde_json::Value::as_i64)
240 && let Ok(Some(entity)) = store.find_entity_by_id(entity_id).await
241 {
242 fts_map.insert(entity_id, (entity, result.score));
243 }
244 }
245 }
246 Err(e) => {
247 tracing::warn!(error = %e, "seed fallback: embedding search failed");
248 }
249 }
250 true
251}
252
253#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
254pub(crate) async fn find_seed_entities(
255 store: &GraphStore,
256 embeddings: Option<&EmbeddingStore>,
257 provider: &zeph_llm::any::AnyProvider,
258 query: &str,
259 limit: usize,
260 structural_weight: f32,
261 community_cap: usize,
262 embed_timeout: std::time::Duration,
263) -> Result<HashMap<i64, f32>, MemoryError> {
264 use crate::graph::types::ScoredEntity;
265
266 const MAX_WORDS: usize = 5;
267
268 let filtered: Vec<&str> = query
269 .split_whitespace()
270 .filter(|w| w.len() >= 3)
271 .take(MAX_WORDS)
272 .collect();
273 let words: Vec<&str> = if filtered.is_empty() && !query.is_empty() {
274 vec![query]
275 } else {
276 filtered
277 };
278
279 let mut fts_map: HashMap<i64, (super::types::Entity, f32)> = HashMap::new();
281 for word in &words {
282 let ranked = store.find_entities_ranked(word, limit * 2).await?;
283 for (entity, fts_score) in ranked {
284 fts_map
285 .entry(entity.id.0)
286 .and_modify(|(_, s)| *s = s.max(fts_score))
287 .or_insert((entity, fts_score));
288 }
289 }
290
291 if fts_map.is_empty()
293 && let Some(emb_store) = embeddings
294 && !seed_embedding_fallback(
295 store,
296 emb_store,
297 provider,
298 query,
299 limit,
300 &mut fts_map,
301 embed_timeout,
302 )
303 .await
304 {
305 return Ok(HashMap::new());
306 }
307
308 if fts_map.is_empty() {
309 return Ok(HashMap::new());
310 }
311
312 let entity_ids: Vec<i64> = fts_map.keys().copied().collect();
313
314 let structural_scores = store.entity_structural_scores(&entity_ids).await?;
316
317 #[cfg(any(feature = "sqlite", feature = "postgres"))]
319 let community_ids = store.entity_community_ids(&entity_ids).await?;
320 #[cfg(not(any(feature = "sqlite", feature = "postgres")))]
321 let community_ids: HashMap<i64, i64> = HashMap::new();
322
323 let fts_weight = 1.0 - structural_weight;
325 let mut scored: Vec<ScoredEntity> = fts_map
326 .into_values()
327 .map(|(entity, fts_score)| {
328 let struct_score = structural_scores.get(&entity.id.0).copied().unwrap_or(0.0);
329 let community_id = community_ids.get(&entity.id.0).copied();
330 ScoredEntity {
331 entity,
332 fts_score,
333 structural_score: struct_score,
334 community_id,
335 }
336 })
337 .collect();
338
339 scored.sort_by(|a, b| {
341 let score_a = a.fts_score * fts_weight + a.structural_score * structural_weight;
342 let score_b = b.fts_score * fts_weight + b.structural_score * structural_weight;
343 score_b.total_cmp(&score_a)
344 });
345
346 let capped: Vec<&ScoredEntity> = if community_cap == 0 {
348 scored.iter().collect()
349 } else {
350 let mut community_counts: HashMap<i64, usize> = HashMap::new();
351 let mut result: Vec<&ScoredEntity> = Vec::new();
352 for se in &scored {
353 match se.community_id {
354 Some(cid) => {
355 let count = community_counts.entry(cid).or_insert(0);
356 if *count < community_cap {
357 *count += 1;
358 result.push(se);
359 }
360 }
361 None => {
362 result.push(se);
364 }
365 }
366 }
367 result
368 };
369
370 let selected: Vec<&ScoredEntity> = if capped.is_empty() && !scored.is_empty() {
372 scored.iter().take(limit).collect()
373 } else {
374 capped.into_iter().take(limit).collect()
375 };
376
377 let entity_scores: HashMap<i64, f32> = selected
378 .into_iter()
379 .map(|se| {
380 let hybrid = se.fts_score * fts_weight + se.structural_score * structural_weight;
381 (se.entity.id.0, hybrid.clamp(0.1, 1.0))
383 })
384 .collect();
385
386 Ok(entity_scores)
387}
388
389#[allow(clippy::too_many_arguments)] pub async fn graph_recall_activated(
405 store: &GraphStore,
406 embeddings: Option<&EmbeddingStore>,
407 provider: &zeph_llm::any::AnyProvider,
408 query: &str,
409 limit: usize,
410 params: SpreadingActivationParams,
411 edge_types: &[EdgeType],
412 hebbian_enabled: bool,
413 hebbian_lr: f32,
414 embed_timeout: std::time::Duration,
415) -> Result<Vec<ActivatedFact>, MemoryError> {
416 if limit == 0 {
417 return Ok(Vec::new());
418 }
419
420 let entity_scores = find_seed_entities(
421 store,
422 embeddings,
423 provider,
424 query,
425 limit,
426 params.seed_structural_weight,
427 params.seed_community_cap,
428 embed_timeout,
429 )
430 .await?;
431
432 if entity_scores.is_empty() {
433 return Ok(Vec::new());
434 }
435
436 tracing::debug!(
437 seeds = entity_scores.len(),
438 "spreading activation: starting recall"
439 );
440
441 let sa = SpreadingActivation::new(params);
442 let (_, mut facts) = sa.spread(store, entity_scores, edge_types).await?;
443
444 let edge_ids: Vec<i64> = facts.iter().map(|f| f.edge.id).collect();
446 if !edge_ids.is_empty()
447 && let Err(e) = store.record_edge_retrieval(&edge_ids).await
448 {
449 tracing::warn!(error = %e, "graph_recall_activated: failed to record edge retrieval");
450 }
451 if hebbian_enabled
453 && !edge_ids.is_empty()
454 && let Err(e) = store.apply_hebbian_increment(&edge_ids, hebbian_lr).await
455 {
456 tracing::warn!(error = %e, "graph_recall_activated: hebbian increment failed");
457 }
458
459 facts.sort_by(|a, b| b.activation_score.total_cmp(&a.activation_score));
461
462 let mut seen: HashSet<(i64, String, i64, EdgeType)> = HashSet::new();
464 facts.retain(|f| {
465 seen.insert((
466 f.edge.source_entity_id,
467 f.edge.relation.clone(),
468 f.edge.target_entity_id,
469 f.edge.edge_type,
470 ))
471 });
472
473 facts.truncate(limit);
474
475 tracing::debug!(
476 result_count = facts.len(),
477 "spreading activation: recall complete"
478 );
479
480 Ok(facts)
481}
482
#[cfg(test)]
mod tests {
    //! Tests for `graph_recall` and `seed_embedding_fallback` against an in-memory SQLite
    //! graph store. Every `graph_recall` call here passes `embeddings = None`, so seeds come
    //! from FTS lookup only.
    //!
    //! `graph_recall` positional arguments, in order: store, embeddings, provider, query,
    //! limit, max_hops, at_timestamp, temporal_decay_rate, edge_types, hebbian_enabled,
    //! hebbian_lr, embed_timeout.

    use super::*;
    use crate::graph::store::GraphStore;
    use crate::graph::types::EntityType;
    use crate::store::SqliteStore;
    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;

    /// Creates a `GraphStore` backed by a fresh in-memory SQLite database.
    async fn setup_store() -> GraphStore {
        let store = SqliteStore::new(":memory:").await.unwrap();
        GraphStore::new(store.pool().clone())
    }

    /// Default mock provider. With `embeddings = None`, `graph_recall` never reaches it.
    fn mock_provider() -> AnyProvider {
        AnyProvider::Mock(MockProvider::default())
    }

    /// With no entities there are no seeds, so recall is empty.
    #[tokio::test]
    async fn graph_recall_empty_graph_returns_empty() {
        let store = setup_store().await;
        let provider = mock_provider();
        let result = graph_recall(
            &store,
            None,
            &provider,
            "anything",
            10,
            2,
            None,
            0.0,
            &[],
            false,
            0.0,
            std::time::Duration::from_secs(5),
        )
        .await
        .unwrap();
        assert!(result.is_empty());
    }

    /// `limit == 0` short-circuits to an empty result.
    #[tokio::test]
    async fn graph_recall_zero_limit_returns_empty() {
        let store = setup_store().await;
        let provider = mock_provider();
        let result = graph_recall(
            &store,
            None,
            &provider,
            "user",
            0,
            2,
            None,
            0.0,
            &[],
            false,
            0.0,
            std::time::Duration::from_secs(5),
        )
        .await
        .unwrap();
        assert!(result.is_empty());
    }

    /// A partial word (`Ali`, a prefix of `Alice`) plus an exact one (`neovim`) should still
    /// surface the `uses` edge as the top fact.
    #[tokio::test]
    async fn graph_recall_fuzzy_match_returns_facts() {
        let store = setup_store().await;
        let user_id = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let tool_id = store
            .upsert_entity("neovim", "neovim", EntityType::Tool, None)
            .await
            .unwrap()
            .0;
        store
            .insert_edge(user_id, tool_id, "uses", "Alice uses neovim", 0.9, None)
            .await
            .unwrap();

        let provider = mock_provider();
        let result = graph_recall(
            &store,
            None,
            &provider,
            "Ali neovim",
            10,
            2,
            None,
            0.0,
            &[],
            false,
            0.0,
            std::time::Duration::from_secs(5),
        )
        .await
        .unwrap();
        assert!(!result.is_empty());
        assert_eq!(result[0].relation, "uses");
    }

    /// Chain Alpha -> Beta -> Gamma with `max_hops = 1`: no returned fact may report a hop
    /// distance above 1.
    #[tokio::test]
    async fn graph_recall_respects_max_hops() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("Alpha", "Alpha", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let b = store
            .upsert_entity("Beta", "Beta", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let c = store
            .upsert_entity("Gamma", "Gamma", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        store
            .insert_edge(a, b, "knows", "Alpha knows Beta", 0.8, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "Beta knows Gamma", 0.8, None)
            .await
            .unwrap();

        let provider = mock_provider();
        let result = graph_recall(
            &store,
            None,
            &provider,
            "Alp",
            10,
            1,
            None,
            0.0,
            &[],
            false,
            0.0,
            std::time::Duration::from_secs(5),
        )
        .await
        .unwrap();
        assert!(result.iter().all(|f| f.hop_distance <= 1));
    }

    /// The query can seed both endpoints of the single edge, so the edge may be reached from
    /// two seeds; the output must still contain each (entity, relation, target) only once.
    #[tokio::test]
    async fn graph_recall_deduplicates_facts() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let bob = store
            .upsert_entity("Bob", "Bob", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        store
            .insert_edge(alice, bob, "knows", "Alice knows Bob", 0.9, None)
            .await
            .unwrap();

        let provider = mock_provider();
        let result = graph_recall(
            &store,
            None,
            &provider,
            "Ali Bob",
            10,
            2,
            None,
            0.0,
            &[],
            false,
            0.0,
            std::time::Duration::from_secs(5),
        )
        .await
        .unwrap();

        let mut seen = std::collections::HashSet::new();
        for f in &result {
            let key = (&f.entity_name, &f.relation, &f.target_name);
            assert!(seen.insert(key), "duplicate fact found: {f:?}");
        }
    }

    /// Two edges from the same seed with confidence 1.0 and 0.1: the first two results must
    /// be in non-increasing `composite_score` order.
    #[tokio::test]
    async fn graph_recall_sorts_by_composite_score() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("Alpha", "Alpha", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let b = store
            .upsert_entity("Beta", "Beta", EntityType::Tool, None)
            .await
            .unwrap()
            .0;
        let c = store
            .upsert_entity("AlphaGadget", "AlphaGadget", EntityType::Tool, None)
            .await
            .unwrap()
            .0;
        // High-confidence edge.
        store
            .insert_edge(a, b, "uses", "Alpha uses Beta", 1.0, None)
            .await
            .unwrap();
        // Low-confidence edge.
        store
            .insert_edge(a, c, "mentions", "Alpha mentions AlphaGadget", 0.1, None)
            .await
            .unwrap();

        let provider = mock_provider();
        let result = graph_recall(
            &store,
            None,
            &provider,
            "Alp",
            10,
            2,
            None,
            0.0,
            &[],
            false,
            0.0,
            std::time::Duration::from_secs(5),
        )
        .await
        .unwrap();

        assert!(result.len() >= 2);
        let s0 = result[0].composite_score();
        let s1 = result[1].composite_score();
        assert!(s0 >= s1, "expected sorted desc: {s0} >= {s1}");
    }

    /// Ten edges fan out from one seed; `limit = 3` must cap the result at three facts.
    #[tokio::test]
    async fn graph_recall_limit_truncates() {
        let store = setup_store().await;
        let root = store
            .upsert_entity("Root", "Root", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        for i in 0..10 {
            let target = store
                .upsert_entity(
                    &format!("Target{i}"),
                    &format!("Target{i}"),
                    EntityType::Tool,
                    None,
                )
                .await
                .unwrap()
                .0;
            store
                .insert_edge(
                    root,
                    target,
                    "has",
                    &format!("Root has Target{i}"),
                    0.8,
                    None,
                )
                .await
                .unwrap();
        }

        let provider = mock_provider();
        let result = graph_recall(
            &store,
            None,
            &provider,
            "Roo",
            3,
            2,
            None,
            0.0,
            &[],
            false,
            0.0,
            std::time::Duration::from_secs(5),
        )
        .await
        .unwrap();
        assert!(result.len() <= 3);
    }

    /// An edge whose `valid_from` (2100) is after the requested timestamp (2026) must not be
    /// returned by a point-in-time recall.
    #[tokio::test]
    async fn graph_recall_at_timestamp_excludes_future_edges() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let bob = store
            .upsert_entity("Bob", "Bob", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        // Raw insert so `valid_from` can be set explicitly.
        zeph_db::query(
            sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from)
              VALUES (?1, ?2, 'knows', 'Alice knows Bob', 0.9, '2100-01-01 00:00:00')"),
        )
        .bind(alice)
        .bind(bob)
        .execute(store.pool())
        .await
        .unwrap();

        let provider = mock_provider();
        let result = graph_recall(
            &store,
            None,
            &provider,
            "Ali",
            10,
            2,
            Some("2026-01-01 00:00:00"),
            0.0,
            &[],
            false,
            0.0,
            std::time::Duration::from_secs(5),
        )
        .await
        .unwrap();
        assert!(result.is_empty(), "future edge should be excluded");
    }

    /// An edge valid 2020-01-01..2021-01-01 (and marked expired) is invisible to a current
    /// recall but visible to a recall at 2020-06-01.
    #[tokio::test]
    async fn graph_recall_at_timestamp_excludes_invalidated_edges() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let carol = store
            .upsert_entity("Carol", "Carol", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        // Raw insert so the validity window and `expired_at` can be set explicitly.
        zeph_db::query(
            sql!("INSERT INTO graph_edges
             (source_entity_id, target_entity_id, relation, fact, confidence, valid_from, valid_to, expired_at)
             VALUES (?1, ?2, 'manages', 'Alice manages Carol', 0.8,
             '2020-01-01 00:00:00', '2021-01-01 00:00:00', '2021-01-01 00:00:00')"),
        )
        .bind(alice)
        .bind(carol)
        .execute(store.pool())
        .await
        .unwrap();

        let provider = mock_provider();

        // Current view (no timestamp).
        let result_current = graph_recall(
            &store,
            None,
            &provider,
            "Ali",
            10,
            2,
            None,
            0.0,
            &[],
            false,
            0.0,
            std::time::Duration::from_secs(5),
        )
        .await
        .unwrap();
        assert!(
            result_current.is_empty(),
            "expired edge should be invisible at current time"
        );

        // Historical view inside the validity window.
        let result_historical = graph_recall(
            &store,
            None,
            &provider,
            "Ali",
            10,
            2,
            Some("2020-06-01 00:00:00"),
            0.0,
            &[],
            false,
            0.0,
            std::time::Duration::from_secs(5),
        )
        .await
        .unwrap();
        assert!(
            !result_historical.is_empty(),
            "edge should be visible within its validity window"
        );
    }

    /// SA-INV-10: five entities in one community all match the query. The community cap
    /// (3 on the BFS path) may trim seeds but must never leave none.
    #[tokio::test]
    async fn graph_recall_community_cap_guard_non_empty() {
        let store = setup_store().await;
        let mut entity_ids = Vec::new();
        // Five entities whose canonical names all contain "entity".
        for i in 0..5usize {
            let id = store
                .upsert_entity(
                    &format!("Entity{i}"),
                    &format!("entity{i}"),
                    crate::graph::types::EntityType::Concept,
                    None,
                )
                .await
                .unwrap()
                .0;
            entity_ids.push(id);
        }

        // Put all five into a single community; only the membership matters here.
        let community_id = store
            .upsert_community("TestComm", "test", &entity_ids, Some("fp"))
            .await
            .unwrap();
        let _ = community_id;

        // A hub connected to every community member gives each seed an edge to return.
        let hub = store
            .upsert_entity("Hub", "hub", crate::graph::types::EntityType::Concept, None)
            .await
            .unwrap()
            .0;
        for &target in &entity_ids {
            store
                .insert_edge(hub, target, "has", "Hub has entity", 0.9, None)
                .await
                .unwrap();
        }

        let provider = mock_provider();
        let result = graph_recall(
            &store,
            None,
            &provider,
            "entity",
            10,
            2,
            None,
            0.0,
            &[],
            false,
            0.0,
            std::time::Duration::from_secs(5),
        )
        .await
        .unwrap();
        assert!(
            !result.is_empty(),
            "SA-INV-10: community cap must not zero out all seeds"
        );
    }

    /// A query with no FTS hits and no embedding store must yield `Ok(empty)`, not an error.
    #[tokio::test]
    async fn graph_recall_no_fts_match_no_embeddings_returns_empty() {
        let store = setup_store().await;
        let a = store
            .upsert_entity(
                "Zephyr",
                "zephyr",
                crate::graph::types::EntityType::Concept,
                None,
            )
            .await
            .unwrap()
            .0;
        let b = store
            .upsert_entity(
                "Concept",
                "concept",
                crate::graph::types::EntityType::Concept,
                None,
            )
            .await
            .unwrap()
            .0;
        store
            .insert_edge(a, b, "rel", "Zephyr rel Concept", 0.9, None)
            .await
            .unwrap();

        let provider = mock_provider();
        let result = graph_recall(
            &store,
            None,
            &provider,
            "xyzzyquuxfrob",
            10,
            2,
            None,
            0.0,
            &[],
            false,
            0.0,
            std::time::Duration::from_secs(5),
        )
        .await
        .unwrap();
        assert!(
            result.is_empty(),
            "must return empty (not error) when FTS5 returns 0 and no embeddings available"
        );
    }

    /// With a decay rate of 0.0 the results must still be in non-increasing
    /// `composite_score` order.
    ///
    /// NOTE(review): fixture and arguments are identical to
    /// `graph_recall_sorts_by_composite_score`. Consider a non-zero `temporal_decay_rate` so
    /// this test actually exercises decay.
    #[tokio::test]
    async fn graph_recall_temporal_decay_preserves_order_with_zero_rate() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("Alpha", "Alpha", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let b = store
            .upsert_entity("Beta", "Beta", EntityType::Tool, None)
            .await
            .unwrap()
            .0;
        let c = store
            .upsert_entity("AlphaGadget", "AlphaGadget", EntityType::Tool, None)
            .await
            .unwrap()
            .0;
        store
            .insert_edge(a, b, "uses", "Alpha uses Beta", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(a, c, "mentions", "Alpha mentions AlphaGadget", 0.1, None)
            .await
            .unwrap();

        let provider = mock_provider();
        let result = graph_recall(
            &store,
            None,
            &provider,
            "Alp",
            10,
            2,
            None,
            0.0,
            &[],
            false,
            0.0,
            std::time::Duration::from_secs(5),
        )
        .await
        .unwrap();
        assert!(result.len() >= 2);
        let s0 = result[0].composite_score();
        let s1 = result[1].composite_score();
        assert!(s0 >= s1, "expected sorted desc: {s0} >= {s1}");
    }

    /// With Hebbian learning enabled (lr 0.5), a recall that traverses the edge must raise its
    /// `weight` above the initial 1.0.
    #[tokio::test]
    async fn test_graph_recall_hebbian_enabled_increments_weight() {
        let store = setup_store().await;
        let provider = mock_provider();

        let user = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let tool = store
            .upsert_entity("Vim", "Vim", EntityType::Tool, None)
            .await
            .unwrap()
            .0;
        let eid = store
            .insert_edge(user, tool, "uses", "Alice uses Vim", 0.9, None)
            .await
            .unwrap();

        // A freshly inserted edge starts at weight 1.0.
        let weight_before: f64 = sqlx::query_scalar("SELECT weight FROM graph_edges WHERE id = ?")
            .bind(eid)
            .fetch_one(store.pool())
            .await
            .unwrap();
        assert!((weight_before - 1.0).abs() < 1e-6);

        let _ = graph_recall(
            &store,
            None,
            &provider,
            "Alice Vim",
            10,
            2,
            None,
            0.0,
            &[],
            true,
            0.5,
            std::time::Duration::from_secs(5),
        )
        .await
        .unwrap();

        let weight_after: f64 = sqlx::query_scalar("SELECT weight FROM graph_edges WHERE id = ?")
            .bind(eid)
            .fetch_one(store.pool())
            .await
            .unwrap();
        assert!(
            weight_after > weight_before,
            "weight must increase after hebbian recall, before={weight_before} after={weight_after}"
        );
    }

    /// With Hebbian learning disabled, a recall must leave the edge weight at 1.0 even though
    /// a non-zero learning rate is passed.
    #[tokio::test]
    async fn test_graph_recall_hebbian_disabled_no_weight_change() {
        let store = setup_store().await;
        let provider = mock_provider();

        let user = store
            .upsert_entity("Bob", "Bob", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let tool = store
            .upsert_entity("Emacs", "Emacs", EntityType::Tool, None)
            .await
            .unwrap()
            .0;
        let eid = store
            .insert_edge(user, tool, "uses", "Bob uses Emacs", 0.9, None)
            .await
            .unwrap();

        let _ = graph_recall(
            &store,
            None,
            &provider,
            "Bob Emacs",
            10,
            2,
            None,
            0.0,
            &[],
            false,
            0.5,
            std::time::Duration::from_secs(5),
        )
        .await
        .unwrap();

        let weight_after: f64 = sqlx::query_scalar("SELECT weight FROM graph_edges WHERE id = ?")
            .bind(eid)
            .fetch_one(store.pool())
            .await
            .unwrap();
        assert!(
            (weight_after - 1.0).abs() < 1e-6,
            "weight must remain 1.0 when hebbian is disabled, got {weight_after}"
        );
    }

    /// An embed call slower than `embed_timeout` makes the fallback return `false` and leave
    /// `fts_map` untouched. The tokio clock is paused and advanced 6s past a 5s timeout.
    /// The mock's embed delay of 10_000 is presumably milliseconds (confirm against
    /// `MockProvider::with_embed_delay`).
    #[tokio::test]
    async fn seed_embedding_fallback_embed_timeout_returns_false() {
        let store = setup_store().await;
        // The embedding store shares the graph store's pool.
        let emb_store_pool = store.pool().clone();

        tokio::time::pause();
        let slow = zeph_llm::any::AnyProvider::Mock(
            zeph_llm::mock::MockProvider::default().with_embed_delay(10_000),
        );

        let mut fts_map = std::collections::HashMap::new();

        let emb_store = EmbeddingStore::with_store(
            Box::new(crate::in_memory_store::InMemoryVectorStore::new()),
            emb_store_pool,
        );

        let fut = seed_embedding_fallback(
            &store,
            &emb_store,
            &slow,
            "query",
            5,
            &mut fts_map,
            std::time::Duration::from_secs(5),
        );
        // Drive the fallback while advancing the paused clock past the timeout.
        let (result, ()) = tokio::join!(fut, async {
            tokio::time::advance(std::time::Duration::from_secs(6)).await;
        });

        assert!(
            !result,
            "seed_embedding_fallback must return false on embed timeout"
        );
        assert!(
            fts_map.is_empty(),
            "fts_map must remain empty when embed timed out"
        );
    }
}