1use std::collections::{HashMap, HashSet};
17use std::sync::OnceLock;
18use std::time::{Instant, SystemTime, UNIX_EPOCH};
19#[allow(unused_imports)]
20use zeph_db::sql;
21
22use crate::embedding_store::EmbeddingStore;
23use crate::error::MemoryError;
24use crate::graph::store::GraphStore;
25use crate::graph::types::{Edge, EdgeType, edge_type_weight, evolved_weight};
26
/// A graph entity reached by [`SpreadingActivation::spread`].
#[derive(Debug, Clone)]
pub struct ActivatedNode {
    /// Id of the activated entity.
    pub entity_id: i64,
    /// Activation score. Seeds start at the match score they were given (not
    /// clamped); scores accumulated through propagation are capped at `1.0`.
    pub activation: f32,
    /// `0` for a seed whose score was never raised by propagation; otherwise the
    /// 1-based hop that produced the node's current score.
    pub depth: u32,
}
37
/// An edge whose two endpoints were both at or above the activation threshold
/// during [`SpreadingActivation::spread`].
#[derive(Debug, Clone)]
pub struct ActivatedFact {
    /// The activated edge.
    pub edge: Edge,
    /// The higher of the two endpoint activation scores when the fact was collected.
    pub activation_score: f32,
    /// Implicit-conflict flag. `spread` in this module always sets `false`;
    /// presumably filled in by a later conflict-detection pass — confirm.
    pub is_implicit_conflict: bool,
    /// Id of a conflicting candidate, if any. Always `None` when produced by
    /// `spread` in this module.
    pub conflict_candidate_id: Option<i64>,
}
50
51pub use zeph_common::memory::SpreadingActivationParams;
52
/// A fact returned by [`hela_spreading_recall`].
#[derive(Debug, Clone)]
pub struct HelaFact {
    /// Edge through which the entity was reached on its best path, or
    /// `Edge::synthetic_anchor(..)` when the anchor had no neighbours.
    pub edge: Edge,
    /// Ranking score. For graph-reached entities this is
    /// `path_weight * cosine`; for the isolated-anchor fallback it is derived
    /// from the anchor's ANN score.
    pub score: f32,
    /// Hops from the anchor (`0` for the isolated-anchor fallback).
    pub depth: u32,
    /// Product of edge weights along the best path from the anchor
    /// (`1.0` for the isolated-anchor fallback).
    pub path_weight: f32,
    /// Query-to-entity cosine similarity floored at `0.0`. Always `Some` when
    /// produced by `hela_spreading_recall`.
    pub cosine: Option<f32>,
}
78
/// Tuning knobs for [`hela_spreading_recall`].
#[derive(Debug, Clone)]
pub struct HelaSpreadParams {
    /// Number of breadth-first hops from the anchor; clamped to `1..=6` by
    /// `hela_spreading_recall`.
    pub spread_depth: u32,
    /// Edge types passed to the edge fetch; an empty list applies no filter.
    pub edge_types: Vec<EdgeType>,
    /// Cap on the number of entities (anchor included) recorded during expansion.
    pub max_visited: usize,
    /// Wall-clock budget per step (anchor ANN search, each hop's edge fetch, the
    /// vector batch fetch). The check runs after the step completes; an
    /// over-budget step makes the recall return an empty result. `None` disables it.
    pub step_budget: Option<std::time::Duration>,
    /// Timeout for embedding the query; exceeding it yields
    /// `MemoryError::Timeout`. `None` waits indefinitely.
    pub embed_timeout: Option<std::time::Duration>,
}
104
105impl Default for HelaSpreadParams {
106 fn default() -> Self {
107 Self {
108 spread_depth: 2,
109 edge_types: Vec::new(),
110 max_visited: 200,
111 step_budget: Some(std::time::Duration::from_millis(8)),
112 embed_timeout: Some(std::time::Duration::from_secs(5)),
113 }
114 }
115}
116
/// Name of the vector collection whose anchor search failed with an error that
/// looked like a vector-dimension mismatch. Being a `OnceLock`, it is set at most
/// once per process and never cleared; while it names `zeph_graph_entities`,
/// `hela_spreading_recall` returns an empty result without searching.
static HELA_DIM_MISMATCH: OnceLock<String> = OnceLock::new();
126
/// Cosine similarity of `a` and `b`.
///
/// The dot product pairs elements with `zip`, so for slices of different length
/// only the shared prefix contributes, while each norm still spans its whole
/// slice. The denominator is floored at `f32::EPSILON`, so a zero-norm input
/// yields `0.0` rather than NaN.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let l2 = |v: &[f32]| -> f32 { v.iter().map(|x| x * x).sum::<f32>().sqrt() };
    let numerator: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let denominator = f32::max(l2(a) * l2(b), f32::EPSILON);
    numerator / denominator
}
137
138#[tracing::instrument(
165 name = "memory.graph.hela_spread",
166 skip_all,
167 fields(
168 depth = params.spread_depth,
169 limit,
170 anchor_id = tracing::field::Empty,
171 visited = tracing::field::Empty,
172 scored = tracing::field::Empty,
173 fallback = tracing::field::Empty,
174 )
175)]
176#[allow(clippy::too_many_arguments, clippy::too_many_lines)] pub async fn hela_spreading_recall(
178 store: &GraphStore,
179 embeddings: &EmbeddingStore,
180 provider: &zeph_llm::any::AnyProvider,
181 query: &str,
182 limit: usize,
183 params: &HelaSpreadParams,
184 hebbian_enabled: bool,
185 hebbian_lr: f32,
186) -> Result<Vec<HelaFact>, MemoryError> {
187 use zeph_llm::LlmProvider as _;
188
189 const ENTITY_COLLECTION: &str = "zeph_graph_entities";
190
191 if limit == 0 {
192 return Ok(Vec::new());
193 }
194
195 if HELA_DIM_MISMATCH.get().map(String::as_str) == Some(ENTITY_COLLECTION) {
198 tracing::debug!("hela: dim mismatch previously detected for collection, skipping");
199 return Ok(Vec::new());
200 }
201
202 let q_vec = if let Some(timeout) = params.embed_timeout {
204 tokio::time::timeout(timeout, provider.embed(query))
205 .await
206 .map_err(|_| {
207 tracing::warn!(timeout_ms = timeout.as_millis(), "hela: embed timed out");
208 MemoryError::Timeout("hela embed".into())
209 })??
210 } else {
211 provider.embed(query).await?
212 };
213
214 let t_anchor = Instant::now();
216 let anchor_results = match embeddings
217 .search_collection(ENTITY_COLLECTION, &q_vec, 1, None)
218 .await
219 {
220 Ok(r) => r,
221 Err(e) => {
222 let msg = e.to_string();
223 if msg.contains("wrong vector dimension")
224 || msg.contains("InvalidArgument")
225 || msg.contains("dimension")
226 {
227 let _ = HELA_DIM_MISMATCH.set(ENTITY_COLLECTION.to_owned());
228 tracing::warn!(
229 collection = ENTITY_COLLECTION,
230 error = %e,
231 "hela: vector dimension mismatch — HL-F5 disabled for this collection"
232 );
233 return Ok(Vec::new());
234 }
235 return Err(e);
236 }
237 };
238
239 if params.step_budget.is_some_and(|b| t_anchor.elapsed() > b) {
240 tracing::warn!(
241 elapsed_ms = t_anchor.elapsed().as_millis(),
242 "hela: anchor ANN over budget"
243 );
244 return Ok(Vec::new());
245 }
246
247 let Some(anchor_point) = anchor_results.first() else {
248 tracing::debug!("hela: no anchor found, returning empty");
249 return Ok(Vec::new());
250 };
251 let Some(anchor_entity_id) = anchor_point
252 .payload
253 .get("entity_id")
254 .and_then(serde_json::Value::as_i64)
255 else {
256 tracing::warn!("hela: anchor point missing entity_id payload");
257 return Ok(Vec::new());
258 };
259 let anchor_cosine = anchor_point.score;
260
261 tracing::Span::current().record("anchor_id", anchor_entity_id);
262 tracing::debug!(anchor_entity_id, anchor_cosine, "hela: anchor resolved");
263
264 let spread_depth = params.spread_depth.clamp(1, 6);
265
266 let mut visited: HashMap<i64, (u32, f32, Option<i64>)> = HashMap::new();
269 visited.insert(anchor_entity_id, (0, 1.0, None));
270
271 let mut edge_cache: HashMap<i64, Edge> = HashMap::new();
275 let mut frontier: Vec<i64> = vec![anchor_entity_id];
276
277 for hop in 0..spread_depth {
278 if frontier.is_empty() {
279 break;
280 }
281
282 tracing::debug!(hop, frontier_size = frontier.len(), "hela: starting hop");
283
284 let t_step = Instant::now();
285 let edges = store
286 .edges_for_entities(&frontier, ¶ms.edge_types)
287 .await?;
288 if params.step_budget.is_some_and(|b| t_step.elapsed() > b) {
289 tracing::warn!(
290 hop,
291 elapsed_ms = t_step.elapsed().as_millis(),
292 "hela: edge-fetch over budget"
293 );
294 return Ok(Vec::new());
295 }
296
297 let mut next_frontier: Vec<i64> = Vec::new();
298 let mut next_frontier_set: HashSet<i64> = HashSet::new();
299
300 for edge in &edges {
301 edge_cache.entry(edge.id).or_insert_with(|| edge.clone());
303
304 for &src_id in &frontier {
305 let neighbor = if edge.source_entity_id == src_id {
306 edge.target_entity_id
307 } else if edge.target_entity_id == src_id {
308 edge.source_entity_id
309 } else {
310 continue;
311 };
312
313 let parent_pw = visited.get(&src_id).map_or(1.0, |&(_, pw, _)| pw);
314 let new_pw = parent_pw * edge.weight;
315
316 let entry = visited
321 .entry(neighbor)
322 .or_insert((hop + 1, 0.0_f32, Some(edge.id)));
323 if new_pw > entry.1
325 || ((new_pw - entry.1).abs() < f32::EPSILON && hop + 1 < entry.0)
326 {
327 *entry = (hop + 1, new_pw, Some(edge.id));
328 if next_frontier_set.insert(neighbor) {
329 next_frontier.push(neighbor);
330 }
331 }
332
333 if visited.len() >= params.max_visited {
334 break;
335 }
336 }
337
338 if visited.len() >= params.max_visited {
339 break;
340 }
341 }
342
343 tracing::debug!(
344 hop,
345 edges_fetched = edges.len(),
346 visited = visited.len(),
347 next_frontier = next_frontier.len(),
348 "hela: hop complete"
349 );
350
351 frontier = next_frontier;
352 if visited.len() >= params.max_visited {
353 break;
354 }
355 }
356
357 if visited.len() == 1 {
360 tracing::Span::current().record("fallback", true);
361 tracing::debug!(
362 anchor_entity_id,
363 anchor_cosine,
364 "hela: anchor isolated, falling back to pure ANN"
365 );
366 let fact = HelaFact {
367 edge: Edge::synthetic_anchor(anchor_entity_id),
368 score: anchor_cosine,
369 depth: 0,
370 path_weight: 1.0,
371 cosine: Some(anchor_cosine.clamp(0.0, 1.0)),
372 };
373 return Ok(vec![fact]);
374 }
375
376 let entity_ids: Vec<i64> = visited.keys().copied().collect();
378 let point_id_map = store.qdrant_point_ids_for_entities(&entity_ids).await?;
379 let point_ids: Vec<String> = point_id_map.values().cloned().collect();
380
381 let t_vec = Instant::now();
382 let vec_map = embeddings
383 .get_vectors_from_collection(ENTITY_COLLECTION, &point_ids)
384 .await?;
385 if params.step_budget.is_some_and(|b| t_vec.elapsed() > b) {
386 tracing::warn!(
387 elapsed_ms = t_vec.elapsed().as_millis(),
388 "hela: vectors-batch over budget"
389 );
390 return Ok(Vec::new());
391 }
392
393 let mut facts: Vec<HelaFact> = Vec::with_capacity(visited.len().saturating_sub(1));
398 for (&entity_id, &(depth, path_weight, edge_id_opt)) in &visited {
399 if entity_id == anchor_entity_id {
400 continue;
401 }
402 let Some(edge_id) = edge_id_opt else {
403 continue;
404 };
405 let Some(point_id) = point_id_map.get(&entity_id) else {
406 continue;
407 };
408 let Some(node_vec) = vec_map.get(point_id) else {
409 continue;
410 };
411 if node_vec.len() != q_vec.len() {
412 continue;
414 }
415 let cosine_clamped = cosine(&q_vec, node_vec).max(0.0);
416 let fact_score = path_weight * cosine_clamped;
417 let Some(edge) = edge_cache.get(&edge_id).cloned() else {
418 continue;
419 };
420 facts.push(HelaFact {
421 edge,
422 score: fact_score,
423 depth,
424 path_weight,
425 cosine: Some(cosine_clamped),
426 });
427 }
428
429 facts.sort_by(|a, b| b.score.total_cmp(&a.score));
431 facts.truncate(limit);
432
433 if hebbian_enabled {
438 let edge_ids: Vec<i64> = facts
439 .iter()
440 .map(|f| f.edge.id)
441 .filter(|&id| id != 0) .collect();
443 if !edge_ids.is_empty()
444 && let Err(e) = store.apply_hebbian_increment(&edge_ids, hebbian_lr).await
445 {
446 tracing::warn!(error = %e, "hela: hebbian increment failed");
447 }
448 }
449
450 tracing::Span::current().record("visited", visited.len());
451 tracing::Span::current().record("scored", facts.len());
452
453 Ok(facts)
454}
455
/// Spreading-activation engine over the entity graph; see
/// [`SpreadingActivation::spread`].
pub struct SpreadingActivation {
    /// Decay, threshold, hop/node-limit and recency settings used by every spread.
    params: SpreadingActivationParams,
}
462
463impl SpreadingActivation {
464 #[must_use]
469 pub fn new(params: SpreadingActivationParams) -> Self {
470 Self { params }
471 }
472
473 pub async fn spread(
489 &self,
490 store: &GraphStore,
491 seeds: HashMap<i64, f32>,
492 edge_types: &[EdgeType],
493 ) -> Result<(Vec<ActivatedNode>, Vec<ActivatedFact>), MemoryError> {
494 if seeds.is_empty() {
495 return Ok((Vec::new(), Vec::new()));
496 }
497
498 let now_secs: i64 = SystemTime::now()
501 .duration_since(UNIX_EPOCH)
502 .map_or(0, |d| d.as_secs().cast_signed());
503
504 let mut activation = self.initialize_seeds(&seeds);
505 let mut activated_facts: Vec<ActivatedFact> = Vec::new();
506
507 for hop in 0..self.params.max_hops {
508 let active_nodes: Vec<(i64, f32)> = activation
509 .iter()
510 .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
511 .map(|(&id, &(score, _))| (id, score))
512 .collect();
513
514 if active_nodes.is_empty() {
515 break;
516 }
517
518 let node_ids: Vec<i64> = active_nodes.iter().map(|(id, _)| *id).collect();
519 let edges = store.edges_for_entities(&node_ids, edge_types).await?;
520 let edge_count = edges.len();
521
522 let next_activation =
523 self.propagate_one_hop(hop, &active_nodes, &edges, &activation, now_secs);
524
525 let pruned_count = self.merge_and_prune(&mut activation, next_activation);
526
527 tracing::debug!(
528 hop,
529 active_nodes = active_nodes.len(),
530 edges_fetched = edge_count,
531 after_merge = activation.len(),
532 pruned = pruned_count,
533 "spreading activation: hop complete"
534 );
535
536 self.collect_activated_facts(&edges, &activation, &mut activated_facts);
537 }
538
539 let result = self.finalize(activation);
540
541 tracing::info!(
542 activated = result.len(),
543 facts = activated_facts.len(),
544 "spreading activation: complete"
545 );
546
547 Ok((result, activated_facts))
548 }
549
550 fn initialize_seeds(&self, seeds: &HashMap<i64, f32>) -> HashMap<i64, (f32, u32)> {
552 let mut activation: HashMap<i64, (f32, u32)> = HashMap::new();
553 let mut seed_count = 0usize;
554 for (entity_id, match_score) in seeds {
556 if *match_score < self.params.activation_threshold {
557 tracing::debug!(
558 entity_id,
559 score = match_score,
560 threshold = self.params.activation_threshold,
561 "spreading activation: seed below threshold, skipping"
562 );
563 continue;
564 }
565 activation.insert(*entity_id, (*match_score, 0));
566 seed_count += 1;
567 }
568 tracing::debug!(
569 seeds = seed_count,
570 "spreading activation: initialized seeds"
571 );
572 activation
573 }
574
575 fn propagate_one_hop(
579 &self,
580 hop: u32,
581 active_nodes: &[(i64, f32)],
582 edges: &[Edge],
583 activation: &HashMap<i64, (f32, u32)>,
584 now_secs: i64,
585 ) -> HashMap<i64, (f32, u32)> {
586 let mut next_activation: HashMap<i64, (f32, u32)> = HashMap::new();
587
588 for edge in edges {
589 for &(active_id, node_score) in active_nodes {
590 let neighbor = if edge.source_entity_id == active_id {
591 edge.target_entity_id
592 } else if edge.target_entity_id == active_id {
593 edge.source_entity_id
594 } else {
595 continue;
596 };
597
598 let current_score = activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
603 let next_score = next_activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
604 if current_score >= self.params.inhibition_threshold
605 || next_score >= self.params.inhibition_threshold
606 {
607 continue;
608 }
609
610 let recency = self.recency_weight(&edge.valid_from, now_secs);
611 let blended = self.params.alpha * edge.confidence_fast
613 + (1.0 - self.params.alpha) * edge.confidence_slow;
614 let edge_weight = evolved_weight(edge.retrieval_count, blended);
615 let type_w = edge_type_weight(edge.edge_type);
616 let spread_value =
617 node_score * self.params.decay_lambda * edge_weight * recency * type_w;
618
619 if spread_value < self.params.activation_threshold {
620 continue;
621 }
622
623 let depth_at_max = hop + 1;
626 let entry = next_activation
627 .entry(neighbor)
628 .or_insert((0.0, depth_at_max));
629 let new_score = (entry.0 + spread_value).min(1.0);
630 if new_score > entry.0 {
631 entry.0 = new_score;
632 entry.1 = depth_at_max;
633 }
634 }
635 }
636
637 next_activation
638 }
639
640 fn merge_and_prune(
644 &self,
645 activation: &mut HashMap<i64, (f32, u32)>,
646 next_activation: HashMap<i64, (f32, u32)>,
647 ) -> usize {
648 for (node_id, (new_score, new_depth)) in next_activation {
649 let entry = activation.entry(node_id).or_insert((0.0, new_depth));
650 if new_score > entry.0 {
651 entry.0 = new_score;
652 entry.1 = new_depth;
653 }
654 }
655
656 if activation.len() > self.params.max_activated_nodes {
657 let before = activation.len();
658 let mut entries: Vec<(i64, (f32, u32))> = activation.drain().collect();
659 entries.sort_by(|(_, (a, _)), (_, (b, _))| b.total_cmp(a));
660 entries.truncate(self.params.max_activated_nodes);
661 *activation = entries.into_iter().collect();
662 before - self.params.max_activated_nodes
663 } else {
664 0
665 }
666 }
667
668 fn collect_activated_facts(
670 &self,
671 edges: &[Edge],
672 activation: &HashMap<i64, (f32, u32)>,
673 activated_facts: &mut Vec<ActivatedFact>,
674 ) {
675 for edge in edges {
676 let src_score = activation
677 .get(&edge.source_entity_id)
678 .map_or(0.0, |&(s, _)| s);
679 let tgt_score = activation
680 .get(&edge.target_entity_id)
681 .map_or(0.0, |&(s, _)| s);
682 if src_score >= self.params.activation_threshold
683 && tgt_score >= self.params.activation_threshold
684 {
685 let activation_score = src_score.max(tgt_score);
686 activated_facts.push(ActivatedFact {
687 edge: edge.clone(),
688 activation_score,
689 is_implicit_conflict: false,
690 conflict_candidate_id: None,
691 });
692 }
693 }
694 }
695
696 fn finalize(&self, activation: HashMap<i64, (f32, u32)>) -> Vec<ActivatedNode> {
698 let mut result: Vec<ActivatedNode> = activation
699 .into_iter()
700 .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
701 .map(|(entity_id, (activation, depth))| ActivatedNode {
702 entity_id,
703 activation,
704 depth,
705 })
706 .collect();
707 result.sort_by(|a, b| b.activation.total_cmp(&a.activation));
708 result
709 }
710
711 #[allow(clippy::cast_precision_loss)]
717 fn recency_weight(&self, valid_from: &str, now_secs: i64) -> f32 {
718 if self.params.temporal_decay_rate <= 0.0 {
719 return 1.0;
720 }
721 let Some(valid_from_secs) = parse_sqlite_datetime_to_unix(valid_from) else {
722 return 1.0;
723 };
724 let age_secs = (now_secs - valid_from_secs).max(0);
725 let age_days = age_secs as f64 / 86_400.0;
726 let weight = 1.0_f64 / (1.0 + age_days * self.params.temporal_decay_rate);
727 #[allow(clippy::cast_possible_truncation)]
729 let w = weight as f32;
730 w
731 }
732}
733
/// Parses the `YYYY-MM-DD HH:MM:SS` prefix of a SQLite timestamp into Unix
/// seconds, treating it as UTC.
///
/// Only the first 19 bytes are read. Anything after them (fractional seconds, a
/// zone suffix) is ignored, and the separator characters are not inspected, so
/// an ISO `T` separator also works.
///
/// Returns `None` in any of these cases:
/// - the input is too short;
/// - a numeric field does not parse;
/// - a field lies outside its calendar/clock range;
/// - a field boundary falls inside a multi-byte UTF-8 character. Plain byte
///   slicing would panic there.
///
/// The date-to-day conversion is Howard Hinnant's `days_from_civil` algorithm.
#[must_use]
fn parse_sqlite_datetime_to_unix(s: &str) -> Option<i64> {
    if s.len() < 19 {
        return None;
    }
    // `str::get` yields `None` rather than panicking on non-char-boundary ranges,
    // which matters because the column content is not guaranteed to be ASCII.
    let field = |range: std::ops::Range<usize>| -> Option<i64> { s.get(range)?.parse().ok() };

    let year = field(0..4)?;
    let month = field(5..7)?;
    let day = field(8..10)?;
    let hour = field(11..13)?;
    let minute = field(14..16)?;
    let second = field(17..19)?;

    // Reject values the civil-date arithmetic would silently turn into garbage.
    // 60 is allowed for a leap second.
    if !(1..=12).contains(&month)
        || !(1..=31).contains(&day)
        || !(0..=23).contains(&hour)
        || !(0..=59).contains(&minute)
        || !(0..=60).contains(&second)
    {
        return None;
    }

    // Shift the year to start in March so the leap day falls at the end.
    let (y, m) = if month <= 2 {
        (year - 1, month + 9)
    } else {
        (year, month - 3)
    };
    let era = y.div_euclid(400);
    let yoe = y - era * 400;
    let doy = (153 * m + 2) / 5 + day - 1;
    let doe = yoe * 365 + yoe / 4 - yoe / 100 + doy;
    let days = era * 146_097 + doe - 719_468;

    Some(days * 86_400 + hour * 3_600 + minute * 60 + second)
}
765
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphStore;
    use crate::graph::types::EntityType;
    use crate::store::SqliteStore;

    /// Fresh graph store on top of an in-memory SQLite database.
    async fn setup_store() -> GraphStore {
        let store = SqliteStore::new(":memory:").await.unwrap();
        GraphStore::new(store.pool().clone())
    }

    /// Baseline parameters shared by the spreading-activation tests.
    fn default_params() -> SpreadingActivationParams {
        SpreadingActivationParams {
            decay_lambda: 0.85,
            max_hops: 3,
            activation_threshold: 0.1,
            inhibition_threshold: 0.8,
            max_activated_nodes: 50,
            temporal_decay_rate: 0.0,
            seed_structural_weight: 0.4,
            seed_community_cap: 3,
            alpha: 0.3,
        }
    }

    #[tokio::test]
    async fn spread_empty_graph_no_edges_no_facts() {
        let store = setup_store().await;
        let sa = SpreadingActivation::new(default_params());
        let seeds = HashMap::from([(1_i64, 1.0_f32)]);
        let (nodes, facts) = sa.spread(&store, seeds, &[]).await.unwrap();
        assert_eq!(nodes.len(), 1, "seed must be in activated nodes");
        assert_eq!(nodes[0].entity_id, 1);
        assert!((nodes[0].activation - 1.0).abs() < 1e-6);
        assert!(
            facts.is_empty(),
            "expected no activated facts on empty graph"
        );
    }

    #[tokio::test]
    async fn spread_empty_seeds_returns_empty() {
        let store = setup_store().await;
        let sa = SpreadingActivation::new(default_params());
        let (nodes, facts) = sa.spread(&store, HashMap::new(), &[]).await.unwrap();
        assert!(nodes.is_empty());
        assert!(facts.is_empty());
    }

    #[tokio::test]
    async fn spread_single_seed_no_edges_returns_seed() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap()
            .0;

        let sa = SpreadingActivation::new(default_params());
        let seeds = HashMap::from([(alice, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].entity_id, alice);
        assert_eq!(nodes[0].depth, 0);
        assert!((nodes[0].activation - 1.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn spread_linear_chain_all_activated_with_decay() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        store
            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 3;
        cfg.decay_lambda = 0.9;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&a), "A (seed) must be activated");
        assert!(ids.contains(&b), "B (hop 1) must be activated");
        assert!(ids.contains(&c), "C (hop 2) must be activated");

        let score_a = nodes.iter().find(|n| n.entity_id == a).unwrap().activation;
        let score_b = nodes.iter().find(|n| n.entity_id == b).unwrap().activation;
        let score_c = nodes.iter().find(|n| n.entity_id == c).unwrap().activation;
        assert!(
            score_a > score_b,
            "seed A should have higher activation than hop-1 B"
        );
        assert!(
            score_b > score_c,
            "hop-1 B should have higher activation than hop-2 C"
        );
    }

    #[tokio::test]
    async fn spread_linear_chain_max_hops_limits_reach() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        store
            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 1;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&a), "A must be activated (seed)");
        assert!(ids.contains(&b), "B must be activated (hop 1)");
        assert!(!ids.contains(&c), "C must NOT be activated with max_hops=1");
    }

    /// A -> {B, C} -> D: D is reachable through two paths and first reached at hop 2.
    #[tokio::test]
    async fn spread_diamond_graph_convergence() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let d = store
            .upsert_entity("D", "D", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        store
            .insert_edge(a, b, "rel", "A-B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(a, c, "rel", "A-C", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, d, "rel", "B-D", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(c, d, "rel", "C-D", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 3;
        cfg.decay_lambda = 0.9;
        cfg.inhibition_threshold = 0.95;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&d), "D must be activated via diamond paths");

        let node_d = nodes.iter().find(|n| n.entity_id == d).unwrap();
        assert_eq!(node_d.depth, 2, "D should be at depth 2");
    }

    /// A hub with bidirectional edges to five leaves must stay capped at 1.0.
    #[tokio::test]
    async fn spread_inhibition_prevents_runaway() {
        let store = setup_store().await;
        let hub = store
            .upsert_entity("Hub", "Hub", EntityType::Concept, None)
            .await
            .unwrap()
            .0;

        for i in 0..5 {
            let leaf = store
                .upsert_entity(
                    &format!("Leaf{i}"),
                    &format!("Leaf{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap()
                .0;
            store
                .insert_edge(hub, leaf, "has", &format!("Hub has Leaf{i}"), 1.0, None)
                .await
                .unwrap();
            store
                .insert_edge(
                    leaf,
                    hub,
                    "part_of",
                    &format!("Leaf{i} part_of Hub"),
                    1.0,
                    None,
                )
                .await
                .unwrap();
        }

        let mut cfg = default_params();
        cfg.inhibition_threshold = 0.8;
        cfg.max_hops = 3;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(hub, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let hub_node = nodes.iter().find(|n| n.entity_id == hub);
        assert!(hub_node.is_some(), "hub must be in results");
        assert!(
            hub_node.unwrap().activation <= 1.0,
            "activation must not exceed 1.0"
        );
    }

    #[tokio::test]
    async fn spread_max_activated_nodes_cap_enforced() {
        let store = setup_store().await;
        let root = store
            .upsert_entity("Root", "Root", EntityType::Person, None)
            .await
            .unwrap()
            .0;

        for i in 0..20 {
            let leaf = store
                .upsert_entity(
                    &format!("Node{i}"),
                    &format!("Node{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap()
                .0;
            store
                .insert_edge(root, leaf, "has", &format!("Root has Node{i}"), 0.9, None)
                .await
                .unwrap();
        }

        let max_nodes = 5;
        let cfg = SpreadingActivationParams {
            max_activated_nodes: max_nodes,
            max_hops: 2,
            ..default_params()
        };
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(root, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        assert!(
            nodes.len() <= max_nodes,
            "activation must be capped at {max_nodes} nodes, got {}",
            nodes.len()
        );
    }

    /// An edge valid since 1970 must spread less activation than a fresh one.
    #[tokio::test]
    async fn spread_temporal_decay_recency_effect() {
        let store = setup_store().await;
        let src = store
            .upsert_entity("Src", "Src", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let recent = store
            .upsert_entity("Recent", "Recent", EntityType::Tool, None)
            .await
            .unwrap()
            .0;
        let old = store
            .upsert_entity("Old", "Old", EntityType::Tool, None)
            .await
            .unwrap()
            .0;

        store
            .insert_edge(src, recent, "uses", "Src uses Recent", 1.0, None)
            .await
            .unwrap();

        zeph_db::query(
            sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from)
             VALUES (?1, ?2, 'uses', 'Src uses Old', 1.0, '1970-01-01 00:00:00')"),
        )
        .bind(src)
        .bind(old)
        .execute(store.pool())
        .await
        .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 2;
        let sa = SpreadingActivation::new(SpreadingActivationParams {
            temporal_decay_rate: 0.5,
            ..cfg
        });
        let seeds = HashMap::from([(src, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let score_recent = nodes
            .iter()
            .find(|n| n.entity_id == recent)
            .map_or(0.0, |n| n.activation);
        let score_old = nodes
            .iter()
            .find(|n| n.entity_id == old)
            .map_or(0.0, |n| n.activation);

        assert!(
            score_recent > score_old,
            "recent edge ({score_recent}) must produce higher activation than old edge ({score_old})"
        );
    }

    #[tokio::test]
    async fn spread_edge_type_filter_excludes_other_types() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let b_semantic = store
            .upsert_entity("BSemantic", "BSemantic", EntityType::Tool, None)
            .await
            .unwrap()
            .0;
        let c_causal = store
            .upsert_entity("CCausal", "CCausal", EntityType::Concept, None)
            .await
            .unwrap()
            .0;

        store
            .insert_edge(a, b_semantic, "uses", "A uses BSemantic", 1.0, None)
            .await
            .unwrap();

        zeph_db::query(
            sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from, edge_type)
             VALUES (?1, ?2, 'caused', 'A caused CCausal', 1.0, datetime('now'), 'causal')"),
        )
        .bind(a)
        .bind(c_causal)
        .execute(store.pool())
        .await
        .unwrap();

        let cfg = default_params();
        let sa = SpreadingActivation::new(cfg);

        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa
            .spread(&store, seeds, &[EdgeType::Semantic])
            .await
            .unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(
            ids.contains(&b_semantic),
            "BSemantic must be activated via semantic edge"
        );
        assert!(
            !ids.contains(&c_causal),
            "CCausal must NOT be activated when filtering to semantic only"
        );
    }

    #[tokio::test]
    async fn spread_large_seed_list() {
        let store = setup_store().await;
        let mut seeds = HashMap::new();

        for i in 0..100i64 {
            let id = store
                .upsert_entity(
                    &format!("Entity{i}"),
                    &format!("entity{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap()
                .0;
            seeds.insert(id, 1.0_f32);
        }

        let cfg = default_params();
        let sa = SpreadingActivation::new(cfg);
        let result = sa.spread(&store, seeds, &[]).await;
        assert!(
            result.is_ok(),
            "large seed list must not error: {:?}",
            result.err()
        );
    }

    #[test]
    fn parse_sqlite_datetime_known_values() {
        assert_eq!(parse_sqlite_datetime_to_unix("1970-01-01 00:00:00"), Some(0));
        assert_eq!(
            parse_sqlite_datetime_to_unix("2024-01-01 00:00:00"),
            Some(1_704_067_200)
        );
        assert_eq!(
            parse_sqlite_datetime_to_unix("2000-02-29 12:34:56"),
            Some(951_827_696)
        );
        assert_eq!(parse_sqlite_datetime_to_unix("2024-01-01"), None);
    }

    #[test]
    fn recency_weight_halves_after_one_day_at_rate_one() {
        let sa = SpreadingActivation::new(SpreadingActivationParams {
            temporal_decay_rate: 1.0,
            ..default_params()
        });
        let w = sa.recency_weight("1970-01-01 00:00:00", 86_400);
        assert!((w - 0.5).abs() < 1e-6, "expected 0.5, got {w}");
        assert!((sa.recency_weight("garbage", 86_400) - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn hela_cosine_identical_vectors() {
        let v = vec![1.0_f32, 0.0, 0.0];
        assert!(
            (cosine(&v, &v) - 1.0).abs() < 1e-6,
            "identical vectors → cosine 1.0"
        );
    }

    #[test]
    fn hela_cosine_orthogonal_vectors() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![0.0_f32, 1.0];
        assert!(
            cosine(&a, &b).abs() < 1e-6,
            "orthogonal vectors → cosine 0.0"
        );
    }

    #[test]
    fn hela_cosine_anti_correlated() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![-1.0_f32, 0.0];
        assert!(
            cosine(&a, &b) < 0.0,
            "anti-correlated vectors → negative cosine"
        );
    }

    #[test]
    fn hela_cosine_zero_vector_no_panic() {
        let a = vec![0.0_f32, 0.0];
        let b = vec![1.0_f32, 0.0];
        let result = cosine(&a, &b);
        assert!(
            result.is_finite(),
            "zero-norm vector must yield finite cosine"
        );
    }

    #[test]
    fn hela_spread_params_default_depth_is_two() {
        let p = HelaSpreadParams::default();
        assert_eq!(p.spread_depth, 2);
        assert!(p.step_budget.is_some());
        assert!(p.edge_types.is_empty());
        assert_eq!(p.max_visited, 200);
    }

    #[test]
    fn hela_spread_params_default_embed_timeout_is_some() {
        let p = HelaSpreadParams::default();
        assert!(
            p.embed_timeout.is_some(),
            "default embed_timeout must be Some (5 s)"
        );
    }

    /// An embed slower than `embed_timeout` must surface as `MemoryError::Timeout`.
    #[tokio::test]
    async fn hela_spreading_recall_embed_timeout_returns_error() {
        use std::time::Duration;
        use zeph_llm::any::AnyProvider;
        use zeph_llm::mock::MockProvider;

        use crate::embedding_store::EmbeddingStore;
        use crate::error::MemoryError;
        use crate::in_memory_store::InMemoryVectorStore;

        let store = setup_store().await;

        let mock = MockProvider::default().with_embed_delay(500);
        let provider = AnyProvider::Mock(mock);

        let sqlite = crate::store::SqliteStore::with_pool_size(":memory:", 1)
            .await
            .unwrap();
        let embeddings =
            EmbeddingStore::with_store(Box::new(InMemoryVectorStore::new()), sqlite.pool().clone());

        let params = HelaSpreadParams {
            embed_timeout: Some(Duration::from_millis(50)),
            ..Default::default()
        };

        let result = hela_spreading_recall(
            &store,
            &embeddings,
            &provider,
            "test query",
            5,
            &params,
            false,
            0.0,
        )
        .await;

        assert!(
            matches!(result, Err(MemoryError::Timeout(_))),
            "expected Err(MemoryError::Timeout), got {result:?}"
        );
    }

    /// With `embed_timeout: None` the embed is awaited directly, so no
    /// `MemoryError::Timeout` can be produced.
    #[tokio::test]
    async fn hela_spreading_recall_no_timeout_does_not_wrap() {
        use zeph_llm::any::AnyProvider;
        use zeph_llm::mock::MockProvider;

        use crate::embedding_store::EmbeddingStore;
        use crate::in_memory_store::InMemoryVectorStore;

        let store = setup_store().await;

        let mock = MockProvider::default().with_embed_delay(0);
        let provider = AnyProvider::Mock(mock);

        let sqlite = crate::store::SqliteStore::with_pool_size(":memory:", 1)
            .await
            .unwrap();
        let embeddings =
            EmbeddingStore::with_store(Box::new(InMemoryVectorStore::new()), sqlite.pool().clone());

        let params = HelaSpreadParams {
            embed_timeout: None,
            ..Default::default()
        };

        let result = hela_spreading_recall(
            &store,
            &embeddings,
            &provider,
            "test query",
            5,
            &params,
            false,
            0.0,
        )
        .await;

        assert!(
            !matches!(result, Err(crate::error::MemoryError::Timeout(_))),
            "embed_timeout: None must not produce a Timeout error, got {result:?}"
        );
    }

    #[test]
    fn hela_synthetic_anchor_edge_id_is_zero() {
        let edge = Edge::synthetic_anchor(42);
        assert_eq!(
            edge.id, 0,
            "synthetic anchor must have id = 0 to be excluded from Hebbian"
        );
        assert_eq!(edge.source_entity_id, 42);
        assert_eq!(edge.target_entity_id, 42);
    }

    #[test]
    fn hela_negative_cosine_clamped_to_zero_in_score() {
        let anti = vec![-1.0_f32, 0.0];
        let query = vec![1.0_f32, 0.0];
        let cosine_raw = cosine(&query, &anti);
        assert!(cosine_raw < 0.0);
        let clamped = cosine_raw.max(0.0);
        let fact_score = 0.9_f32 * clamped;
        assert!(
            fact_score < f32::EPSILON,
            "anti-correlated score must be 0.0"
        );
    }

    #[test]
    fn hela_path_weight_multiplicative() {
        let w1 = 0.8_f32;
        let w2 = 0.5_f32;
        let expected = w1 * w2;
        assert!((expected - 0.4).abs() < 1e-6);
    }

    #[test]
    fn hela_max_path_weight_on_multipath() {
        let pw_a = 0.9_f32;
        let pw_b = 0.3_f32;
        let kept = pw_a.max(pw_b);
        assert!(
            (kept - 0.9).abs() < 1e-6,
            "multi-path resolution must keep maximum path_weight"
        );
    }

    #[test]
    fn hela_fact_score_formula() {
        let path_weight = 0.8_f32;
        let cosine_clamped = 0.75_f32;
        let expected = path_weight * cosine_clamped;
        assert!((expected - 0.6).abs() < 1e-5);
    }

    /// The spread must use the alpha-blended fast/slow confidence rather than the
    /// raw confidence of the last insert.
    #[tokio::test]
    async fn synapse_blend_uses_alpha_not_raw_confidence() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let gs = GraphStore::new(store.pool().clone()).with_benna_rates(0.5, 0.05);

        let alpha = 0.3_f32;
        let params = SpreadingActivationParams {
            alpha,
            ..default_params()
        };

        let src = gs
            .upsert_entity("Blend_src", "Blend_src", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let tgt = gs
            .upsert_entity("Blend_tgt", "Blend_tgt", EntityType::Concept, None)
            .await
            .unwrap()
            .0;

        gs.insert_edge_typed(
            src,
            tgt,
            "knows",
            "Blend_src knows Blend_tgt",
            0.5,
            None,
            crate::graph::types::EdgeType::Semantic,
        )
        .await
        .unwrap();
        gs.insert_edge_typed(
            src,
            tgt,
            "knows",
            "Blend_src knows Blend_tgt",
            0.9,
            None,
            crate::graph::types::EdgeType::Semantic,
        )
        .await
        .unwrap();

        let sa = SpreadingActivation::new(params);
        let seeds = HashMap::from([(src, 1.0_f32)]);
        let (_nodes, facts) = sa.spread(&gs, seeds, &[]).await.unwrap();

        assert!(!facts.is_empty(), "spread must return at least one fact");
        let raw_score = facts[0].activation_score;
        assert!(
            (raw_score - 0.9_f32).abs() > 0.05,
            "blend score {raw_score} must differ from raw confidence 0.9 — alpha blend not applied"
        );
    }
}