1use std::collections::{HashMap, HashSet};
5use std::time::{SystemTime, UNIX_EPOCH};
6#[allow(unused_imports)]
7use zeph_db::sql;
8
9use crate::embedding_store::EmbeddingStore;
10use crate::error::MemoryError;
11
12use super::activation::{ActivatedFact, SpreadingActivation, SpreadingActivationParams};
13use super::store::GraphStore;
14use super::types::{EdgeType, GraphFact};
15
16#[allow(clippy::too_many_arguments, clippy::too_many_lines)] pub async fn graph_recall(
40 store: &GraphStore,
41 embeddings: Option<&crate::embedding_store::EmbeddingStore>,
42 provider: &zeph_llm::any::AnyProvider,
43 query: &str,
44 limit: usize,
45 max_hops: u32,
46 at_timestamp: Option<&str>,
47 temporal_decay_rate: f64,
48 edge_types: &[EdgeType],
49 hebbian_enabled: bool,
50 hebbian_lr: f32,
51) -> Result<Vec<GraphFact>, MemoryError> {
52 const DEFAULT_STRUCTURAL_WEIGHT: f32 = 0.4;
54 const DEFAULT_COMMUNITY_CAP: usize = 3;
55
56 if limit == 0 {
57 return Ok(Vec::new());
58 }
59
60 let entity_scores = find_seed_entities(
62 store,
63 embeddings,
64 provider,
65 query,
66 limit,
67 DEFAULT_STRUCTURAL_WEIGHT,
68 DEFAULT_COMMUNITY_CAP,
69 )
70 .await?;
71
72 if entity_scores.is_empty() {
73 return Ok(Vec::new());
74 }
75
76 let now_secs: i64 = SystemTime::now()
78 .duration_since(UNIX_EPOCH)
79 .map_or(0, |d| d.as_secs().cast_signed());
80
81 let mut all_facts: Vec<GraphFact> = Vec::new();
83
84 for (seed_id, seed_score) in &entity_scores {
85 let (entities, edges, depth_map) = if let Some(ts) = at_timestamp {
86 store.bfs_at_timestamp(*seed_id, max_hops, ts).await?
87 } else if !edge_types.is_empty() {
88 store.bfs_typed(*seed_id, max_hops, edge_types).await?
89 } else {
90 store.bfs_with_depth(*seed_id, max_hops).await?
91 };
92
93 let name_map: HashMap<i64, &str> = entities
96 .iter()
97 .map(|e| (e.id, e.canonical_name.as_str()))
98 .collect();
99
100 let traversed_edge_ids: Vec<i64> = edges.iter().map(|e| e.id).collect();
102
103 for edge in &edges {
104 let Some(&hop_distance) = depth_map
105 .get(&edge.source_entity_id)
106 .or_else(|| depth_map.get(&edge.target_entity_id))
107 else {
108 continue;
109 };
110
111 let entity_name = name_map
112 .get(&edge.source_entity_id)
113 .copied()
114 .unwrap_or_default();
115 let target_name = name_map
116 .get(&edge.target_entity_id)
117 .copied()
118 .unwrap_or_default();
119
120 if entity_name.is_empty() || target_name.is_empty() {
121 continue;
122 }
123
124 all_facts.push(GraphFact {
125 entity_name: entity_name.to_owned(),
126 relation: edge.relation.clone(),
127 target_name: target_name.to_owned(),
128 fact: edge.fact.clone(),
129 entity_match_score: *seed_score,
130 hop_distance,
131 confidence: edge.confidence,
132 valid_from: Some(edge.valid_from.clone()),
133 edge_type: edge.edge_type,
134 retrieval_count: edge.retrieval_count,
135 edge_id: Some(edge.id),
136 });
137 }
138
139 if !traversed_edge_ids.is_empty()
141 && let Err(e) = store.record_edge_retrieval(&traversed_edge_ids).await
142 {
143 tracing::warn!(error = %e, "graph_recall: failed to record edge retrieval");
144 }
145 if hebbian_enabled
147 && !traversed_edge_ids.is_empty()
148 && let Err(e) = store
149 .apply_hebbian_increment(&traversed_edge_ids, hebbian_lr)
150 .await
151 {
152 tracing::warn!(error = %e, "graph_recall: hebbian increment failed");
153 }
154 }
155
156 let mut scored: Vec<(f32, GraphFact)> = all_facts
160 .into_iter()
161 .map(|f| {
162 let s = f.score_with_decay(temporal_decay_rate, now_secs);
163 (s, f)
164 })
165 .collect();
166 scored.sort_by(|(sa, _), (sb, _)| sb.total_cmp(sa));
167 let mut all_facts: Vec<GraphFact> = scored.into_iter().map(|(_, f)| f).collect();
168
169 let mut seen: HashSet<(String, String, String, EdgeType)> = HashSet::new();
173 all_facts.retain(|f| {
174 seen.insert((
175 f.entity_name.clone(),
176 f.relation.clone(),
177 f.target_name.clone(),
178 f.edge_type,
179 ))
180 });
181
182 all_facts.truncate(limit);
184
185 Ok(all_facts)
186}
187
188async fn seed_embedding_fallback(
207 store: &GraphStore,
208 emb_store: &EmbeddingStore,
209 provider: &zeph_llm::any::AnyProvider,
210 query: &str,
211 limit: usize,
212 fts_map: &mut HashMap<i64, (super::types::Entity, f32)>,
213) -> bool {
214 use zeph_llm::LlmProvider as _;
215 const ENTITY_COLLECTION: &str = "zeph_graph_entities";
216 let embedding = match provider.embed(query).await {
217 Ok(v) => v,
218 Err(e) => {
219 tracing::warn!(error = %e, "seed fallback: embed() failed, returning empty seeds");
220 return false;
221 }
222 };
223 match emb_store
224 .search_collection(ENTITY_COLLECTION, &embedding, limit, None)
225 .await
226 {
227 Ok(results) => {
228 for result in results {
229 if let Some(entity_id) = result
230 .payload
231 .get("entity_id")
232 .and_then(serde_json::Value::as_i64)
233 && let Ok(Some(entity)) = store.find_entity_by_id(entity_id).await
234 {
235 fts_map.insert(entity_id, (entity, result.score));
236 }
237 }
238 }
239 Err(e) => {
240 tracing::warn!(error = %e, "seed fallback: embedding search failed");
241 }
242 }
243 true
244}
245
246pub(crate) async fn find_seed_entities(
247 store: &GraphStore,
248 embeddings: Option<&EmbeddingStore>,
249 provider: &zeph_llm::any::AnyProvider,
250 query: &str,
251 limit: usize,
252 structural_weight: f32,
253 community_cap: usize,
254) -> Result<HashMap<i64, f32>, MemoryError> {
255 use crate::graph::types::ScoredEntity;
256
257 const MAX_WORDS: usize = 5;
258
259 let filtered: Vec<&str> = query
260 .split_whitespace()
261 .filter(|w| w.len() >= 3)
262 .take(MAX_WORDS)
263 .collect();
264 let words: Vec<&str> = if filtered.is_empty() && !query.is_empty() {
265 vec![query]
266 } else {
267 filtered
268 };
269
270 let mut fts_map: HashMap<i64, (super::types::Entity, f32)> = HashMap::new();
272 for word in &words {
273 let ranked = store.find_entities_ranked(word, limit * 2).await?;
274 for (entity, fts_score) in ranked {
275 fts_map
276 .entry(entity.id)
277 .and_modify(|(_, s)| *s = s.max(fts_score))
278 .or_insert((entity, fts_score));
279 }
280 }
281
282 if fts_map.is_empty()
284 && let Some(emb_store) = embeddings
285 && !seed_embedding_fallback(store, emb_store, provider, query, limit, &mut fts_map).await
286 {
287 return Ok(HashMap::new());
288 }
289
290 if fts_map.is_empty() {
291 return Ok(HashMap::new());
292 }
293
294 let entity_ids: Vec<i64> = fts_map.keys().copied().collect();
295
296 let structural_scores = store.entity_structural_scores(&entity_ids).await?;
298
299 let community_ids = store.entity_community_ids(&entity_ids).await?;
301
302 let fts_weight = 1.0 - structural_weight;
304 let mut scored: Vec<ScoredEntity> = fts_map
305 .into_values()
306 .map(|(entity, fts_score)| {
307 let struct_score = structural_scores.get(&entity.id).copied().unwrap_or(0.0);
308 let community_id = community_ids.get(&entity.id).copied();
309 ScoredEntity {
310 entity,
311 fts_score,
312 structural_score: struct_score,
313 community_id,
314 }
315 })
316 .collect();
317
318 scored.sort_by(|a, b| {
320 let score_a = a.fts_score * fts_weight + a.structural_score * structural_weight;
321 let score_b = b.fts_score * fts_weight + b.structural_score * structural_weight;
322 score_b.total_cmp(&score_a)
323 });
324
325 let capped: Vec<&ScoredEntity> = if community_cap == 0 {
327 scored.iter().collect()
328 } else {
329 let mut community_counts: HashMap<i64, usize> = HashMap::new();
330 let mut result: Vec<&ScoredEntity> = Vec::new();
331 for se in &scored {
332 match se.community_id {
333 Some(cid) => {
334 let count = community_counts.entry(cid).or_insert(0);
335 if *count < community_cap {
336 *count += 1;
337 result.push(se);
338 }
339 }
340 None => {
341 result.push(se);
343 }
344 }
345 }
346 result
347 };
348
349 let selected: Vec<&ScoredEntity> = if capped.is_empty() && !scored.is_empty() {
351 scored.iter().take(limit).collect()
352 } else {
353 capped.into_iter().take(limit).collect()
354 };
355
356 let entity_scores: HashMap<i64, f32> = selected
357 .into_iter()
358 .map(|se| {
359 let hybrid = se.fts_score * fts_weight + se.structural_score * structural_weight;
360 (se.entity.id, hybrid.clamp(0.1, 1.0))
362 })
363 .collect();
364
365 Ok(entity_scores)
366}
367
368#[allow(clippy::too_many_arguments)] pub async fn graph_recall_activated(
384 store: &GraphStore,
385 embeddings: Option<&EmbeddingStore>,
386 provider: &zeph_llm::any::AnyProvider,
387 query: &str,
388 limit: usize,
389 params: SpreadingActivationParams,
390 edge_types: &[EdgeType],
391 hebbian_enabled: bool,
392 hebbian_lr: f32,
393) -> Result<Vec<ActivatedFact>, MemoryError> {
394 if limit == 0 {
395 return Ok(Vec::new());
396 }
397
398 let entity_scores = find_seed_entities(
399 store,
400 embeddings,
401 provider,
402 query,
403 limit,
404 params.seed_structural_weight,
405 params.seed_community_cap,
406 )
407 .await?;
408
409 if entity_scores.is_empty() {
410 return Ok(Vec::new());
411 }
412
413 tracing::debug!(
414 seeds = entity_scores.len(),
415 "spreading activation: starting recall"
416 );
417
418 let sa = SpreadingActivation::new(params);
419 let (_, mut facts) = sa.spread(store, entity_scores, edge_types).await?;
420
421 let edge_ids: Vec<i64> = facts.iter().map(|f| f.edge.id).collect();
423 if !edge_ids.is_empty()
424 && let Err(e) = store.record_edge_retrieval(&edge_ids).await
425 {
426 tracing::warn!(error = %e, "graph_recall_activated: failed to record edge retrieval");
427 }
428 if hebbian_enabled
430 && !edge_ids.is_empty()
431 && let Err(e) = store.apply_hebbian_increment(&edge_ids, hebbian_lr).await
432 {
433 tracing::warn!(error = %e, "graph_recall_activated: hebbian increment failed");
434 }
435
436 facts.sort_by(|a, b| b.activation_score.total_cmp(&a.activation_score));
438
439 let mut seen: HashSet<(i64, String, i64, EdgeType)> = HashSet::new();
441 facts.retain(|f| {
442 seen.insert((
443 f.edge.source_entity_id,
444 f.edge.relation.clone(),
445 f.edge.target_entity_id,
446 f.edge.edge_type,
447 ))
448 });
449
450 facts.truncate(limit);
451
452 tracing::debug!(
453 result_count = facts.len(),
454 "spreading activation: recall complete"
455 );
456
457 Ok(facts)
458}
459
// Integration-style tests for `graph_recall`. They run against an in-memory SQLite
// `GraphStore` with a mock LLM provider, and pass `embeddings = None`, so seeds come only
// from the FTS path of `find_seed_entities`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::store::GraphStore;
    use crate::graph::types::EntityType;
    use crate::store::SqliteStore;
    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;

    /// Creates a `GraphStore` on a fresh in-memory SQLite database.
    async fn setup_store() -> GraphStore {
        let store = SqliteStore::new(":memory:").await.unwrap();
        GraphStore::new(store.pool().clone())
    }

    /// Mock provider. With `embeddings = None`, `graph_recall` never reaches the
    /// embedding fallback, so its `embed()` is not exercised by these tests.
    fn mock_provider() -> AnyProvider {
        AnyProvider::Mock(MockProvider::default())
    }

    /// An empty graph yields no seeds and therefore no facts.
    #[tokio::test]
    async fn graph_recall_empty_graph_returns_empty() {
        let store = setup_store().await;
        let provider = mock_provider();
        let result = graph_recall(
            &store,
            None,
            &provider,
            "anything",
            10,
            2,
            None,
            0.0,
            &[],
            false,
            0.0,
        )
        .await
        .unwrap();
        assert!(result.is_empty());
    }

    /// `limit == 0` short-circuits before any lookup.
    #[tokio::test]
    async fn graph_recall_zero_limit_returns_empty() {
        let store = setup_store().await;
        let provider = mock_provider();
        let result = graph_recall(
            &store,
            None,
            &provider,
            "user",
            0,
            2,
            None,
            0.0,
            &[],
            false,
            0.0,
        )
        .await
        .unwrap();
        assert!(result.is_empty());
    }

    /// The query word "Ali" is only a prefix of "Alice". The test expects the `uses` edge
    /// to be found anyway; presumably this relies on `find_entities_ranked` doing prefix
    /// matching (TODO confirm).
    #[tokio::test]
    async fn graph_recall_fuzzy_match_returns_facts() {
        let store = setup_store().await;
        let user_id = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();
        let tool_id = store
            .upsert_entity("neovim", "neovim", EntityType::Tool, None)
            .await
            .unwrap();
        store
            .insert_edge(user_id, tool_id, "uses", "Alice uses neovim", 0.9, None)
            .await
            .unwrap();

        let provider = mock_provider();
        let result = graph_recall(
            &store,
            None,
            &provider,
            "Ali neovim",
            10,
            2,
            None,
            0.0,
            &[],
            false,
            0.0,
        )
        .await
        .unwrap();
        assert!(!result.is_empty());
        assert_eq!(result[0].relation, "uses");
    }

    /// Chain Alpha -> Beta -> Gamma recalled with `max_hops = 1`: no fact may report a
    /// hop distance above 1.
    #[tokio::test]
    async fn graph_recall_respects_max_hops() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("Alpha", "Alpha", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("Beta", "Beta", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("Gamma", "Gamma", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "knows", "Alpha knows Beta", 0.8, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "Beta knows Gamma", 0.8, None)
            .await
            .unwrap();

        let provider = mock_provider();
        let result = graph_recall(
            &store,
            None,
            &provider,
            "Alp",
            10,
            1,
            None,
            0.0,
            &[],
            false,
            0.0,
        )
        .await
        .unwrap();
        assert!(result.iter().all(|f| f.hop_distance <= 1));
    }

    /// The query can seed both endpoints of the single edge, so the edge may be reached
    /// from two traversals. The output must still contain each
    /// (entity, relation, target) at most once.
    #[tokio::test]
    async fn graph_recall_deduplicates_facts() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();
        let bob = store
            .upsert_entity("Bob", "Bob", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(alice, bob, "knows", "Alice knows Bob", 0.9, None)
            .await
            .unwrap();

        let provider = mock_provider();
        let result = graph_recall(
            &store,
            None,
            &provider,
            "Ali Bob",
            10,
            2,
            None,
            0.0,
            &[],
            false,
            0.0,
        )
        .await
        .unwrap();

        let mut seen = std::collections::HashSet::new();
        for f in &result {
            let key = (&f.entity_name, &f.relation, &f.target_name);
            assert!(seen.insert(key), "duplicate fact found: {f:?}");
        }
    }

    /// Two edges from Alpha with confidence 1.0 and 0.1: the results must come back in
    /// non-increasing `composite_score` order. `graph_recall` sorts by `score_with_decay`;
    /// this assumes that with rate 0.0 it orders like `composite_score` (TODO confirm).
    #[tokio::test]
    async fn graph_recall_sorts_by_composite_score() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("Alpha", "Alpha", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("Beta", "Beta", EntityType::Tool, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("AlphaGadget", "AlphaGadget", EntityType::Tool, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "uses", "Alpha uses Beta", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(a, c, "mentions", "Alpha mentions AlphaGadget", 0.1, None)
            .await
            .unwrap();

        let provider = mock_provider();
        let result = graph_recall(
            &store,
            None,
            &provider,
            "Alp",
            10,
            2,
            None,
            0.0,
            &[],
            false,
            0.0,
        )
        .await
        .unwrap();

        assert!(result.len() >= 2);
        let s0 = result[0].composite_score();
        let s1 = result[1].composite_score();
        assert!(s0 >= s1, "expected sorted desc: {s0} >= {s1}");
    }

    /// Ten outgoing edges from one root, recalled with `limit = 3`: at most 3 facts.
    #[tokio::test]
    async fn graph_recall_limit_truncates() {
        let store = setup_store().await;
        let root = store
            .upsert_entity("Root", "Root", EntityType::Person, None)
            .await
            .unwrap();
        for i in 0..10 {
            let target = store
                .upsert_entity(
                    &format!("Target{i}"),
                    &format!("Target{i}"),
                    EntityType::Tool,
                    None,
                )
                .await
                .unwrap();
            store
                .insert_edge(
                    root,
                    target,
                    "has",
                    &format!("Root has Target{i}"),
                    0.8,
                    None,
                )
                .await
                .unwrap();
        }

        let provider = mock_provider();
        let result = graph_recall(
            &store,
            None,
            &provider,
            "Roo",
            3,
            2,
            None,
            0.0,
            &[],
            false,
            0.0,
        )
        .await
        .unwrap();
        assert!(result.len() <= 3);
    }

    /// An edge whose `valid_from` (2100) lies after the requested `at_timestamp` (2026)
    /// must not be returned by the point-in-time traversal.
    #[tokio::test]
    async fn graph_recall_at_timestamp_excludes_future_edges() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();
        let bob = store
            .upsert_entity("Bob", "Bob", EntityType::Person, None)
            .await
            .unwrap();
        // Raw insert so `valid_from` can be set explicitly.
        zeph_db::query(
            sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from)
                  VALUES (?1, ?2, 'knows', 'Alice knows Bob', 0.9, '2100-01-01 00:00:00')"),
        )
        .bind(alice)
        .bind(bob)
        .execute(store.pool())
        .await
        .unwrap();

        let provider = mock_provider();
        let result = graph_recall(
            &store,
            None,
            &provider,
            "Ali",
            10,
            2,
            Some("2026-01-01 00:00:00"),
            0.0,
            &[],
            false,
            0.0,
        )
        .await
        .unwrap();
        assert!(result.is_empty(), "future edge should be excluded");
    }

    /// An edge valid only in 2020 (`valid_to`/`expired_at` = 2021) is hidden from a
    /// current recall but visible to a recall at 2020-06-01.
    #[tokio::test]
    async fn graph_recall_at_timestamp_excludes_invalidated_edges() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();
        let carol = store
            .upsert_entity("Carol", "Carol", EntityType::Person, None)
            .await
            .unwrap();
        // Raw insert so the validity window and expiry can be set explicitly.
        zeph_db::query(
            sql!("INSERT INTO graph_edges
                  (source_entity_id, target_entity_id, relation, fact, confidence, valid_from, valid_to, expired_at)
                  VALUES (?1, ?2, 'manages', 'Alice manages Carol', 0.8,
                  '2020-01-01 00:00:00', '2021-01-01 00:00:00', '2021-01-01 00:00:00')"),
        )
        .bind(alice)
        .bind(carol)
        .execute(store.pool())
        .await
        .unwrap();

        let provider = mock_provider();

        let result_current = graph_recall(
            &store,
            None,
            &provider,
            "Ali",
            10,
            2,
            None,
            0.0,
            &[],
            false,
            0.0,
        )
        .await
        .unwrap();
        assert!(
            result_current.is_empty(),
            "expired edge should be invisible at current time"
        );

        let result_historical = graph_recall(
            &store,
            None,
            &provider,
            "Ali",
            10,
            2,
            Some("2020-06-01 00:00:00"),
            0.0,
            &[],
            false,
            0.0,
        )
        .await
        .unwrap();
        assert!(
            !result_historical.is_empty(),
            "edge should be visible within its validity window"
        );
    }

    /// SA-INV-10: five matching entities share one community, while `graph_recall`
    /// caps seeds at 3 per community. The cap must never leave zero seeds, so the recall
    /// must still return facts.
    #[tokio::test]
    async fn graph_recall_community_cap_guard_non_empty() {
        let store = setup_store().await;
        let mut entity_ids = Vec::new();
        for i in 0..5usize {
            let id = store
                .upsert_entity(
                    &format!("Entity{i}"),
                    &format!("entity{i}"),
                    crate::graph::types::EntityType::Concept,
                    None,
                )
                .await
                .unwrap();
            entity_ids.push(id);
        }

        let community_id = store
            .upsert_community("TestComm", "test", &entity_ids, Some("fp"))
            .await
            .unwrap();
        let _ = community_id;

        // Hub connected to every community member so each seed has edges to traverse.
        let hub = store
            .upsert_entity("Hub", "hub", crate::graph::types::EntityType::Concept, None)
            .await
            .unwrap();
        for &target in &entity_ids {
            store
                .insert_edge(hub, target, "has", "Hub has entity", 0.9, None)
                .await
                .unwrap();
        }

        let provider = mock_provider();
        let result = graph_recall(
            &store,
            None,
            &provider,
            "entity",
            10,
            2,
            None,
            0.0,
            &[],
            false,
            0.0,
        )
        .await
        .unwrap();
        assert!(
            !result.is_empty(),
            "SA-INV-10: community cap must not zero out all seeds"
        );
    }

    /// A query matching no entity, with no embedding store available, must yield an
    /// empty result rather than an error.
    #[tokio::test]
    async fn graph_recall_no_fts_match_no_embeddings_returns_empty() {
        let store = setup_store().await;
        let a = store
            .upsert_entity(
                "Zephyr",
                "zephyr",
                crate::graph::types::EntityType::Concept,
                None,
            )
            .await
            .unwrap();
        let b = store
            .upsert_entity(
                "Concept",
                "concept",
                crate::graph::types::EntityType::Concept,
                None,
            )
            .await
            .unwrap();
        store
            .insert_edge(a, b, "rel", "Zephyr rel Concept", 0.9, None)
            .await
            .unwrap();

        let provider = mock_provider();
        let result = graph_recall(
            &store,
            None,
            &provider,
            "xyzzyquuxfrob",
            10,
            2,
            None,
            0.0,
            &[],
            false,
            0.0,
        )
        .await
        .unwrap();
        assert!(
            result.is_empty(),
            "must return empty (not error) when FTS5 returns 0 and no embeddings available"
        );
    }

    /// With `temporal_decay_rate = 0.0`, results still come back in non-increasing
    /// `composite_score` order. This uses the same fixture as
    /// `graph_recall_sorts_by_composite_score`.
    #[tokio::test]
    async fn graph_recall_temporal_decay_preserves_order_with_zero_rate() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("Alpha", "Alpha", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("Beta", "Beta", EntityType::Tool, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("AlphaGadget", "AlphaGadget", EntityType::Tool, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "uses", "Alpha uses Beta", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(a, c, "mentions", "Alpha mentions AlphaGadget", 0.1, None)
            .await
            .unwrap();

        let provider = mock_provider();
        let result = graph_recall(
            &store,
            None,
            &provider,
            "Alp",
            10,
            2,
            None,
            0.0,
            &[],
            false,
            0.0,
        )
        .await
        .unwrap();
        assert!(result.len() >= 2);
        let s0 = result[0].composite_score();
        let s1 = result[1].composite_score();
        assert!(s0 >= s1, "expected sorted desc: {s0} >= {s1}");
    }

    /// With `hebbian_enabled = true` (lr 0.5), recalling the edge raises its stored
    /// weight. The test also checks that a freshly inserted edge starts at weight 1.0.
    #[tokio::test]
    async fn test_graph_recall_hebbian_enabled_increments_weight() {
        let store = setup_store().await;
        let provider = mock_provider();

        let user = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();
        let tool = store
            .upsert_entity("Vim", "Vim", EntityType::Tool, None)
            .await
            .unwrap();
        let eid = store
            .insert_edge(user, tool, "uses", "Alice uses Vim", 0.9, None)
            .await
            .unwrap();

        let weight_before: f64 = sqlx::query_scalar("SELECT weight FROM graph_edges WHERE id = ?")
            .bind(eid)
            .fetch_one(store.pool())
            .await
            .unwrap();
        assert!((weight_before - 1.0).abs() < 1e-6);

        let _ = graph_recall(
            &store,
            None,
            &provider,
            "Alice Vim",
            10,
            2,
            None,
            0.0,
            &[],
            true,
            0.5,
        )
        .await
        .unwrap();

        let weight_after: f64 = sqlx::query_scalar("SELECT weight FROM graph_edges WHERE id = ?")
            .bind(eid)
            .fetch_one(store.pool())
            .await
            .unwrap();
        assert!(
            weight_after > weight_before,
            "weight must increase after hebbian recall, before={weight_before} after={weight_after}"
        );
    }

    /// With `hebbian_enabled = false`, a non-zero learning rate must leave the weight
    /// at 1.0.
    #[tokio::test]
    async fn test_graph_recall_hebbian_disabled_no_weight_change() {
        let store = setup_store().await;
        let provider = mock_provider();

        let user = store
            .upsert_entity("Bob", "Bob", EntityType::Person, None)
            .await
            .unwrap();
        let tool = store
            .upsert_entity("Emacs", "Emacs", EntityType::Tool, None)
            .await
            .unwrap();
        let eid = store
            .insert_edge(user, tool, "uses", "Bob uses Emacs", 0.9, None)
            .await
            .unwrap();

        let _ = graph_recall(
            &store,
            None,
            &provider,
            "Bob Emacs",
            10,
            2,
            None,
            0.0,
            &[],
            false,
            0.5,
        )
        .await
        .unwrap();

        let weight_after: f64 = sqlx::query_scalar("SELECT weight FROM graph_edges WHERE id = ?")
            .bind(eid)
            .fetch_one(store.pool())
            .await
            .unwrap();
        assert!(
            (weight_after - 1.0).abs() < 1e-6,
            "weight must remain 1.0 when hebbian is disabled, got {weight_after}"
        );
    }
}