1use std::collections::HashMap;
17use std::sync::OnceLock;
18use std::time::{Instant, SystemTime, UNIX_EPOCH};
19#[allow(unused_imports)]
20use zeph_db::sql;
21
22use crate::embedding_store::EmbeddingStore;
23use crate::error::MemoryError;
24use crate::graph::store::GraphStore;
25use crate::graph::types::{Edge, EdgeType, edge_type_weight, evolved_weight};
26
/// An entity that ended a [`SpreadingActivation::spread`] run with an activation at or above
/// the configured threshold.
#[derive(Debug, Clone)]
pub struct ActivatedNode {
    /// Identifier of the activated entity (the same id space as the seed map keys).
    pub entity_id: i64,
    /// Final activation score of the entity.
    pub activation: f32,
    /// Hop at which the stored score was recorded; seeds are recorded at depth `0`.
    pub depth: u32,
}
37
/// An edge whose two endpoints were both activated during spreading activation.
#[derive(Debug, Clone)]
pub struct ActivatedFact {
    /// The graph edge carrying the fact.
    pub edge: Edge,
    /// Activation assigned to the fact; the spreading engine in this file uses the larger of
    /// the two endpoint activations.
    pub activation_score: f32,
    /// Conflict marker. The spreading engine in this file always emits `false`;
    /// presumably set by a later processing stage — TODO confirm.
    pub is_implicit_conflict: bool,
    /// Identifier of a conflicting candidate, if any. The spreading engine in this file
    /// always emits `None`.
    pub conflict_candidate_id: Option<i64>,
}
50
51pub use zeph_common::memory::SpreadingActivationParams;
52
/// A scored fact produced by [`hela_spreading_recall`].
#[derive(Debug, Clone)]
pub struct HelaFact {
    /// Edge through which the entity was reached (a synthetic anchor edge in the
    /// isolated-anchor fallback).
    pub edge: Edge,
    /// Ranking score: `path_weight * cosine` for spread results, or the raw anchor
    /// similarity score in the isolated-anchor fallback.
    pub score: f32,
    /// Hop count from the anchor entity (`0` for the fallback fact).
    pub depth: u32,
    /// Product of the edge weights along the best path found from the anchor.
    pub path_weight: f32,
    /// Query-to-entity cosine similarity after clamping negatives to `0.0`.
    pub cosine: Option<f32>,
}
78
/// Tuning knobs for [`hela_spreading_recall`].
#[derive(Debug, Clone)]
pub struct HelaSpreadParams {
    /// Requested number of hops; clamped to `1..=6` when the spread runs.
    pub spread_depth: u32,
    /// Edge-type filter forwarded to the graph store when fetching edges.
    pub edge_types: Vec<EdgeType>,
    /// Spreading stops once this many entities (anchor included) have been visited.
    pub max_visited: usize,
    /// Optional time budget applied separately to the anchor search, each edge fetch and the
    /// vector batch fetch; exceeding it makes the recall return an empty result.
    pub step_budget: Option<std::time::Duration>,
    /// Optional timeout for embedding the query; `None` waits indefinitely.
    pub embed_timeout: Option<std::time::Duration>,
}
104
105impl Default for HelaSpreadParams {
106 fn default() -> Self {
107 Self {
108 spread_depth: 2,
109 edge_types: Vec::new(),
110 max_visited: 200,
111 step_budget: Some(std::time::Duration::from_millis(8)),
112 embed_timeout: Some(std::time::Duration::from_secs(5)),
113 }
114 }
115}
116
/// Process-wide latch holding the name of a collection for which a vector dimension mismatch
/// was detected. Once set it is never cleared, and `hela_spreading_recall` returns an empty
/// result for that collection from then on.
static HELA_DIM_MISMATCH: OnceLock<String> = OnceLock::new();
126
/// Cosine similarity between `a` and `b`.
///
/// Returns `0.0` when either vector has a zero (or non-comparable, e.g. NaN) norm, so the
/// result is always defined for degenerate inputs. Callers are expected to pass slices of
/// equal length; if the lengths differ, the dot product only covers the common prefix while
/// each norm still covers its whole slice.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    if norm_a > 0.0 && norm_b > 0.0 {
        // Divide by each norm in turn rather than by their product: the product of two small
        // norms can drop below `f32::EPSILON` (or underflow to zero), and flooring it would
        // shrink the similarity of small-magnitude vectors that point in the same direction.
        (dot / norm_a) / norm_b
    } else {
        0.0
    }
}
137
138#[tracing::instrument(
165 name = "memory.graph.hela_spread",
166 skip_all,
167 fields(
168 depth = params.spread_depth,
169 limit,
170 anchor_id = tracing::field::Empty,
171 visited = tracing::field::Empty,
172 scored = tracing::field::Empty,
173 fallback = tracing::field::Empty,
174 )
175)]
176#[allow(clippy::too_many_arguments, clippy::too_many_lines)] pub async fn hela_spreading_recall(
178 store: &GraphStore,
179 embeddings: &EmbeddingStore,
180 provider: &zeph_llm::any::AnyProvider,
181 query: &str,
182 limit: usize,
183 params: &HelaSpreadParams,
184 hebbian_enabled: bool,
185 hebbian_lr: f32,
186) -> Result<Vec<HelaFact>, MemoryError> {
187 use zeph_llm::LlmProvider as _;
188
189 const ENTITY_COLLECTION: &str = "zeph_graph_entities";
190
191 if limit == 0 {
192 return Ok(Vec::new());
193 }
194
195 if HELA_DIM_MISMATCH.get().map(String::as_str) == Some(ENTITY_COLLECTION) {
198 tracing::debug!("hela: dim mismatch previously detected for collection, skipping");
199 return Ok(Vec::new());
200 }
201
202 let q_vec = if let Some(timeout) = params.embed_timeout {
204 tokio::time::timeout(timeout, provider.embed(query))
205 .await
206 .map_err(|_| {
207 tracing::warn!(timeout_ms = timeout.as_millis(), "hela: embed timed out");
208 MemoryError::Timeout("hela embed".into())
209 })??
210 } else {
211 provider.embed(query).await?
212 };
213
214 let t_anchor = Instant::now();
216 let anchor_results = match embeddings
217 .search_collection(ENTITY_COLLECTION, &q_vec, 1, None)
218 .await
219 {
220 Ok(r) => r,
221 Err(e) => {
222 let msg = e.to_string();
223 if msg.contains("wrong vector dimension")
224 || msg.contains("InvalidArgument")
225 || msg.contains("dimension")
226 {
227 let _ = HELA_DIM_MISMATCH.set(ENTITY_COLLECTION.to_owned());
228 tracing::warn!(
229 collection = ENTITY_COLLECTION,
230 error = %e,
231 "hela: vector dimension mismatch — HL-F5 disabled for this collection"
232 );
233 return Ok(Vec::new());
234 }
235 return Err(e);
236 }
237 };
238
239 if params.step_budget.is_some_and(|b| t_anchor.elapsed() > b) {
240 tracing::warn!(
241 elapsed_ms = t_anchor.elapsed().as_millis(),
242 "hela: anchor ANN over budget"
243 );
244 return Ok(Vec::new());
245 }
246
247 let Some(anchor_point) = anchor_results.first() else {
248 tracing::debug!("hela: no anchor found, returning empty");
249 return Ok(Vec::new());
250 };
251 let Some(anchor_entity_id) = anchor_point
252 .payload
253 .get("entity_id")
254 .and_then(serde_json::Value::as_i64)
255 else {
256 tracing::warn!("hela: anchor point missing entity_id payload");
257 return Ok(Vec::new());
258 };
259 let anchor_cosine = anchor_point.score;
260
261 tracing::Span::current().record("anchor_id", anchor_entity_id);
262 tracing::debug!(anchor_entity_id, anchor_cosine, "hela: anchor resolved");
263
264 let spread_depth = params.spread_depth.clamp(1, 6);
265
266 let mut visited: HashMap<i64, (u32, f32, Option<i64>)> = HashMap::new();
269 visited.insert(anchor_entity_id, (0, 1.0, None));
270
271 let mut edge_cache: HashMap<i64, Edge> = HashMap::new();
275 let mut frontier: Vec<i64> = vec![anchor_entity_id];
276
277 for hop in 0..spread_depth {
278 if frontier.is_empty() {
279 break;
280 }
281
282 tracing::debug!(hop, frontier_size = frontier.len(), "hela: starting hop");
283
284 let t_step = Instant::now();
285 let edges = store
286 .edges_for_entities(&frontier, ¶ms.edge_types)
287 .await?;
288 if params.step_budget.is_some_and(|b| t_step.elapsed() > b) {
289 tracing::warn!(
290 hop,
291 elapsed_ms = t_step.elapsed().as_millis(),
292 "hela: edge-fetch over budget"
293 );
294 return Ok(Vec::new());
295 }
296
297 let mut next_frontier: Vec<i64> = Vec::new();
298
299 for edge in &edges {
300 edge_cache.entry(edge.id).or_insert_with(|| edge.clone());
302
303 for &src_id in &frontier {
304 let neighbor = if edge.source_entity_id == src_id {
305 edge.target_entity_id
306 } else if edge.target_entity_id == src_id {
307 edge.source_entity_id
308 } else {
309 continue;
310 };
311
312 let parent_pw = visited.get(&src_id).map_or(1.0, |&(_, pw, _)| pw);
313 let new_pw = parent_pw * edge.weight;
314
315 let entry = visited
320 .entry(neighbor)
321 .or_insert((hop + 1, 0.0_f32, Some(edge.id)));
322 if new_pw > entry.1
324 || ((new_pw - entry.1).abs() < f32::EPSILON && hop + 1 < entry.0)
325 {
326 *entry = (hop + 1, new_pw, Some(edge.id));
327 if !next_frontier.contains(&neighbor) {
328 next_frontier.push(neighbor);
329 }
330 }
331
332 if visited.len() >= params.max_visited {
333 break;
334 }
335 }
336
337 if visited.len() >= params.max_visited {
338 break;
339 }
340 }
341
342 tracing::debug!(
343 hop,
344 edges_fetched = edges.len(),
345 visited = visited.len(),
346 next_frontier = next_frontier.len(),
347 "hela: hop complete"
348 );
349
350 frontier = next_frontier;
351 if visited.len() >= params.max_visited {
352 break;
353 }
354 }
355
356 if visited.len() == 1 {
359 tracing::Span::current().record("fallback", true);
360 tracing::debug!(
361 anchor_entity_id,
362 anchor_cosine,
363 "hela: anchor isolated, falling back to pure ANN"
364 );
365 let fact = HelaFact {
366 edge: Edge::synthetic_anchor(anchor_entity_id),
367 score: anchor_cosine,
368 depth: 0,
369 path_weight: 1.0,
370 cosine: Some(anchor_cosine.clamp(0.0, 1.0)),
371 };
372 return Ok(vec![fact]);
373 }
374
375 let entity_ids: Vec<i64> = visited.keys().copied().collect();
377 let point_id_map = store.qdrant_point_ids_for_entities(&entity_ids).await?;
378 let point_ids: Vec<String> = point_id_map.values().cloned().collect();
379
380 let t_vec = Instant::now();
381 let vec_map = embeddings
382 .get_vectors_from_collection(ENTITY_COLLECTION, &point_ids)
383 .await?;
384 if params.step_budget.is_some_and(|b| t_vec.elapsed() > b) {
385 tracing::warn!(
386 elapsed_ms = t_vec.elapsed().as_millis(),
387 "hela: vectors-batch over budget"
388 );
389 return Ok(Vec::new());
390 }
391
392 let mut facts: Vec<HelaFact> = Vec::with_capacity(visited.len().saturating_sub(1));
397 for (&entity_id, &(depth, path_weight, edge_id_opt)) in &visited {
398 if entity_id == anchor_entity_id {
399 continue;
400 }
401 let Some(edge_id) = edge_id_opt else {
402 continue;
403 };
404 let Some(point_id) = point_id_map.get(&entity_id) else {
405 continue;
406 };
407 let Some(node_vec) = vec_map.get(point_id) else {
408 continue;
409 };
410 if node_vec.len() != q_vec.len() {
411 continue;
413 }
414 let cosine_clamped = cosine(&q_vec, node_vec).max(0.0);
415 let fact_score = path_weight * cosine_clamped;
416 let Some(edge) = edge_cache.get(&edge_id).cloned() else {
417 continue;
418 };
419 facts.push(HelaFact {
420 edge,
421 score: fact_score,
422 depth,
423 path_weight,
424 cosine: Some(cosine_clamped),
425 });
426 }
427
428 facts.sort_by(|a, b| b.score.total_cmp(&a.score));
430 facts.truncate(limit);
431
432 if hebbian_enabled {
437 let edge_ids: Vec<i64> = facts
438 .iter()
439 .map(|f| f.edge.id)
440 .filter(|&id| id != 0) .collect();
442 if !edge_ids.is_empty()
443 && let Err(e) = store.apply_hebbian_increment(&edge_ids, hebbian_lr).await
444 {
445 tracing::warn!(error = %e, "hela: hebbian increment failed");
446 }
447 }
448
449 tracing::Span::current().record("visited", visited.len());
450 tracing::Span::current().record("scored", facts.len());
451
452 Ok(facts)
453}
454
/// Spreading-activation engine over the entity graph.
///
/// Holds only its configuration; the graph store and the seeds are supplied per call to
/// `spread`.
pub struct SpreadingActivation {
    /// Thresholds, decay factors and caps that govern every `spread` run.
    params: SpreadingActivationParams,
}
461
462impl SpreadingActivation {
463 #[must_use]
468 pub fn new(params: SpreadingActivationParams) -> Self {
469 Self { params }
470 }
471
472 pub async fn spread(
488 &self,
489 store: &GraphStore,
490 seeds: HashMap<i64, f32>,
491 edge_types: &[EdgeType],
492 ) -> Result<(Vec<ActivatedNode>, Vec<ActivatedFact>), MemoryError> {
493 if seeds.is_empty() {
494 return Ok((Vec::new(), Vec::new()));
495 }
496
497 let now_secs: i64 = SystemTime::now()
500 .duration_since(UNIX_EPOCH)
501 .map_or(0, |d| d.as_secs().cast_signed());
502
503 let mut activation = self.initialize_seeds(&seeds);
504 let mut activated_facts: Vec<ActivatedFact> = Vec::new();
505
506 for hop in 0..self.params.max_hops {
507 let active_nodes: Vec<(i64, f32)> = activation
508 .iter()
509 .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
510 .map(|(&id, &(score, _))| (id, score))
511 .collect();
512
513 if active_nodes.is_empty() {
514 break;
515 }
516
517 let node_ids: Vec<i64> = active_nodes.iter().map(|(id, _)| *id).collect();
518 let edges = store.edges_for_entities(&node_ids, edge_types).await?;
519 let edge_count = edges.len();
520
521 let next_activation =
522 self.propagate_one_hop(hop, &active_nodes, &edges, &activation, now_secs);
523
524 let pruned_count = self.merge_and_prune(&mut activation, next_activation);
525
526 tracing::debug!(
527 hop,
528 active_nodes = active_nodes.len(),
529 edges_fetched = edge_count,
530 after_merge = activation.len(),
531 pruned = pruned_count,
532 "spreading activation: hop complete"
533 );
534
535 self.collect_activated_facts(&edges, &activation, &mut activated_facts);
536 }
537
538 let result = self.finalize(activation);
539
540 tracing::info!(
541 activated = result.len(),
542 facts = activated_facts.len(),
543 "spreading activation: complete"
544 );
545
546 Ok((result, activated_facts))
547 }
548
549 fn initialize_seeds(&self, seeds: &HashMap<i64, f32>) -> HashMap<i64, (f32, u32)> {
551 let mut activation: HashMap<i64, (f32, u32)> = HashMap::new();
552 let mut seed_count = 0usize;
553 for (entity_id, match_score) in seeds {
555 if *match_score < self.params.activation_threshold {
556 tracing::debug!(
557 entity_id,
558 score = match_score,
559 threshold = self.params.activation_threshold,
560 "spreading activation: seed below threshold, skipping"
561 );
562 continue;
563 }
564 activation.insert(*entity_id, (*match_score, 0));
565 seed_count += 1;
566 }
567 tracing::debug!(
568 seeds = seed_count,
569 "spreading activation: initialized seeds"
570 );
571 activation
572 }
573
574 fn propagate_one_hop(
578 &self,
579 hop: u32,
580 active_nodes: &[(i64, f32)],
581 edges: &[Edge],
582 activation: &HashMap<i64, (f32, u32)>,
583 now_secs: i64,
584 ) -> HashMap<i64, (f32, u32)> {
585 let mut next_activation: HashMap<i64, (f32, u32)> = HashMap::new();
586
587 for edge in edges {
588 for &(active_id, node_score) in active_nodes {
589 let neighbor = if edge.source_entity_id == active_id {
590 edge.target_entity_id
591 } else if edge.target_entity_id == active_id {
592 edge.source_entity_id
593 } else {
594 continue;
595 };
596
597 let current_score = activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
602 let next_score = next_activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
603 if current_score >= self.params.inhibition_threshold
604 || next_score >= self.params.inhibition_threshold
605 {
606 continue;
607 }
608
609 let recency = self.recency_weight(&edge.valid_from, now_secs);
610 let edge_weight = evolved_weight(edge.retrieval_count, edge.confidence);
611 let type_w = edge_type_weight(edge.edge_type);
612 let spread_value =
613 node_score * self.params.decay_lambda * edge_weight * recency * type_w;
614
615 if spread_value < self.params.activation_threshold {
616 continue;
617 }
618
619 let depth_at_max = hop + 1;
622 let entry = next_activation
623 .entry(neighbor)
624 .or_insert((0.0, depth_at_max));
625 let new_score = (entry.0 + spread_value).min(1.0);
626 if new_score > entry.0 {
627 entry.0 = new_score;
628 entry.1 = depth_at_max;
629 }
630 }
631 }
632
633 next_activation
634 }
635
636 fn merge_and_prune(
640 &self,
641 activation: &mut HashMap<i64, (f32, u32)>,
642 next_activation: HashMap<i64, (f32, u32)>,
643 ) -> usize {
644 for (node_id, (new_score, new_depth)) in next_activation {
645 let entry = activation.entry(node_id).or_insert((0.0, new_depth));
646 if new_score > entry.0 {
647 entry.0 = new_score;
648 entry.1 = new_depth;
649 }
650 }
651
652 if activation.len() > self.params.max_activated_nodes {
653 let before = activation.len();
654 let mut entries: Vec<(i64, (f32, u32))> = activation.drain().collect();
655 entries.sort_by(|(_, (a, _)), (_, (b, _))| b.total_cmp(a));
656 entries.truncate(self.params.max_activated_nodes);
657 *activation = entries.into_iter().collect();
658 before - self.params.max_activated_nodes
659 } else {
660 0
661 }
662 }
663
664 fn collect_activated_facts(
666 &self,
667 edges: &[Edge],
668 activation: &HashMap<i64, (f32, u32)>,
669 activated_facts: &mut Vec<ActivatedFact>,
670 ) {
671 for edge in edges {
672 let src_score = activation
673 .get(&edge.source_entity_id)
674 .map_or(0.0, |&(s, _)| s);
675 let tgt_score = activation
676 .get(&edge.target_entity_id)
677 .map_or(0.0, |&(s, _)| s);
678 if src_score >= self.params.activation_threshold
679 && tgt_score >= self.params.activation_threshold
680 {
681 let activation_score = src_score.max(tgt_score);
682 activated_facts.push(ActivatedFact {
683 edge: edge.clone(),
684 activation_score,
685 is_implicit_conflict: false,
686 conflict_candidate_id: None,
687 });
688 }
689 }
690 }
691
692 fn finalize(&self, activation: HashMap<i64, (f32, u32)>) -> Vec<ActivatedNode> {
694 let mut result: Vec<ActivatedNode> = activation
695 .into_iter()
696 .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
697 .map(|(entity_id, (activation, depth))| ActivatedNode {
698 entity_id,
699 activation,
700 depth,
701 })
702 .collect();
703 result.sort_by(|a, b| b.activation.total_cmp(&a.activation));
704 result
705 }
706
707 #[allow(clippy::cast_precision_loss)]
713 fn recency_weight(&self, valid_from: &str, now_secs: i64) -> f32 {
714 if self.params.temporal_decay_rate <= 0.0 {
715 return 1.0;
716 }
717 let Some(valid_from_secs) = parse_sqlite_datetime_to_unix(valid_from) else {
718 return 1.0;
719 };
720 let age_secs = (now_secs - valid_from_secs).max(0);
721 let age_days = age_secs as f64 / 86_400.0;
722 let weight = 1.0_f64 / (1.0 + age_days * self.params.temporal_decay_rate);
723 #[allow(clippy::cast_possible_truncation)]
725 let w = weight as f32;
726 w
727 }
728}
729
/// Parses a `YYYY-MM-DD HH:MM:SS` timestamp (any trailing characters are ignored) into
/// seconds since the Unix epoch, treating the value as UTC.
///
/// Only the digit positions are inspected; the separator characters are not checked.
/// Returns `None` when the input is shorter than 19 bytes, when a field is not a valid
/// integer, when a field boundary falls inside a multi-byte character, or when the month,
/// day, hour, minute or second is outside its calendar range.
#[must_use]
fn parse_sqlite_datetime_to_unix(s: &str) -> Option<i64> {
    // Checked slicing: `s[a..b]` would panic on short input or when a boundary splits a
    // multi-byte character, whereas `get` simply yields `None`.
    let field = |range: std::ops::Range<usize>| -> Option<i64> { s.get(range)?.parse().ok() };

    let year = field(0..4)?;
    let month = field(5..7)?;
    let day = field(8..10)?;
    let hour = field(11..13)?;
    let min = field(14..16)?;
    let sec = field(17..19)?;

    // The day-count arithmetic below is only meaningful for real calendar fields.
    let in_range = (1..=12).contains(&month)
        && (1..=31).contains(&day)
        && (0..24).contains(&hour)
        && (0..60).contains(&min)
        && (0..60).contains(&sec);
    if !in_range {
        return None;
    }

    // Days-from-civil computation: shift the year so it starts in March, which puts the
    // leap day at the end of the shifted year.
    let (y, m) = if month <= 2 {
        (year - 1, month + 9)
    } else {
        (year, month - 3)
    };
    let era = y.div_euclid(400);
    let yoe = y - era * 400; // year within the 400-year era
    let doy = (153 * m + 2) / 5 + day - 1; // day within the shifted year
    let doe = yoe * 365 + yoe / 4 - yoe / 100 + doy; // day within the era
    let days = era * 146_097 + doe - 719_468; // days relative to 1970-01-01

    Some(days * 86_400 + hour * 3_600 + min * 60 + sec)
}
761
/// Unit tests for spreading activation and the HELA recall helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphStore;
    use crate::graph::types::EntityType;
    use crate::store::SqliteStore;

    /// Creates a graph store backed by a fresh in-memory SQLite database.
    async fn setup_store() -> GraphStore {
        let store = SqliteStore::new(":memory:").await.unwrap();
        GraphStore::new(store.pool().clone())
    }

    /// Baseline parameters shared by the tests; temporal decay is disabled.
    fn default_params() -> SpreadingActivationParams {
        SpreadingActivationParams {
            decay_lambda: 0.85,
            max_hops: 3,
            activation_threshold: 0.1,
            inhibition_threshold: 0.8,
            max_activated_nodes: 50,
            temporal_decay_rate: 0.0,
            seed_structural_weight: 0.4,
            seed_community_cap: 3,
        }
    }

    /// A seed with no edges in the store is returned as-is and yields no facts.
    #[tokio::test]
    async fn spread_empty_graph_no_edges_no_facts() {
        let store = setup_store().await;
        let sa = SpreadingActivation::new(default_params());
        let seeds = HashMap::from([(1_i64, 1.0_f32)]);
        let (nodes, facts) = sa.spread(&store, seeds, &[]).await.unwrap();
        assert_eq!(nodes.len(), 1, "seed must be in activated nodes");
        assert_eq!(nodes[0].entity_id, 1);
        assert!((nodes[0].activation - 1.0).abs() < 1e-6);
        assert!(
            facts.is_empty(),
            "expected no activated facts on empty graph"
        );
    }

    /// An empty seed map short-circuits to empty results.
    #[tokio::test]
    async fn spread_empty_seeds_returns_empty() {
        let store = setup_store().await;
        let sa = SpreadingActivation::new(default_params());
        let (nodes, facts) = sa.spread(&store, HashMap::new(), &[]).await.unwrap();
        assert!(nodes.is_empty());
        assert!(facts.is_empty());
    }

    /// A stored entity without edges comes back alone, at depth 0 with its seed score.
    #[tokio::test]
    async fn spread_single_seed_no_edges_returns_seed() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap()
            .0;

        let sa = SpreadingActivation::new(default_params());
        let seeds = HashMap::from([(alice, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].entity_id, alice);
        assert_eq!(nodes[0].depth, 0);
        assert!((nodes[0].activation - 1.0).abs() < 1e-6);
    }

    /// On the chain A - B - C every node is activated and activation decays with distance.
    #[tokio::test]
    async fn spread_linear_chain_all_activated_with_decay() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        store
            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 3;
        cfg.decay_lambda = 0.9;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&a), "A (seed) must be activated");
        assert!(ids.contains(&b), "B (hop 1) must be activated");
        assert!(ids.contains(&c), "C (hop 2) must be activated");

        let score_a = nodes.iter().find(|n| n.entity_id == a).unwrap().activation;
        let score_b = nodes.iter().find(|n| n.entity_id == b).unwrap().activation;
        let score_c = nodes.iter().find(|n| n.entity_id == c).unwrap().activation;
        assert!(
            score_a > score_b,
            "seed A should have higher activation than hop-1 B"
        );
        assert!(
            score_b > score_c,
            "hop-1 B should have higher activation than hop-2 C"
        );
    }

    /// With `max_hops = 1` activation stops after the seed's direct neighbours.
    #[tokio::test]
    async fn spread_linear_chain_max_hops_limits_reach() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        store
            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 1;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&a), "A must be activated (seed)");
        assert!(ids.contains(&b), "B must be activated (hop 1)");
        assert!(!ids.contains(&c), "C must NOT be activated with max_hops=1");
    }

    /// In the diamond A -> {B, C} -> D, node D is reached and recorded at depth 2.
    #[tokio::test]
    async fn spread_diamond_graph_convergence() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let d = store
            .upsert_entity("D", "D", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        store
            .insert_edge(a, b, "rel", "A-B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(a, c, "rel", "A-C", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, d, "rel", "B-D", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(c, d, "rel", "C-D", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 3;
        cfg.decay_lambda = 0.9;
        cfg.inhibition_threshold = 0.95;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&d), "D must be activated via diamond paths");

        let node_d = nodes.iter().find(|n| n.entity_id == d).unwrap();
        assert_eq!(node_d.depth, 2, "D should be at depth 2");
    }

    /// A hub with bidirectional edges to five leaves never exceeds an activation of 1.0.
    #[tokio::test]
    async fn spread_inhibition_prevents_runaway() {
        let store = setup_store().await;
        let hub = store
            .upsert_entity("Hub", "Hub", EntityType::Concept, None)
            .await
            .unwrap()
            .0;

        for i in 0..5 {
            let leaf = store
                .upsert_entity(
                    &format!("Leaf{i}"),
                    &format!("Leaf{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap()
                .0;
            store
                .insert_edge(hub, leaf, "has", &format!("Hub has Leaf{i}"), 1.0, None)
                .await
                .unwrap();
            store
                .insert_edge(
                    leaf,
                    hub,
                    "part_of",
                    &format!("Leaf{i} part_of Hub"),
                    1.0,
                    None,
                )
                .await
                .unwrap();
        }

        let mut cfg = default_params();
        cfg.inhibition_threshold = 0.8;
        cfg.max_hops = 3;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(hub, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let hub_node = nodes.iter().find(|n| n.entity_id == hub);
        assert!(hub_node.is_some(), "hub must be in results");
        assert!(
            hub_node.unwrap().activation <= 1.0,
            "activation must not exceed 1.0"
        );
    }

    /// The result never contains more than `max_activated_nodes` entries.
    #[tokio::test]
    async fn spread_max_activated_nodes_cap_enforced() {
        let store = setup_store().await;
        let root = store
            .upsert_entity("Root", "Root", EntityType::Person, None)
            .await
            .unwrap()
            .0;

        for i in 0..20 {
            let leaf = store
                .upsert_entity(
                    &format!("Node{i}"),
                    &format!("Node{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap()
                .0;
            store
                .insert_edge(root, leaf, "has", &format!("Root has Node{i}"), 0.9, None)
                .await
                .unwrap();
        }

        let max_nodes = 5;
        let cfg = SpreadingActivationParams {
            max_activated_nodes: max_nodes,
            max_hops: 2,
            ..default_params()
        };
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(root, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        assert!(
            nodes.len() <= max_nodes,
            "activation must be capped at {max_nodes} nodes, got {}",
            nodes.len()
        );
    }

    /// With temporal decay enabled, a fresh edge activates its target more strongly than an
    /// edge whose `valid_from` is the Unix epoch.
    #[tokio::test]
    async fn spread_temporal_decay_recency_effect() {
        let store = setup_store().await;
        let src = store
            .upsert_entity("Src", "Src", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let recent = store
            .upsert_entity("Recent", "Recent", EntityType::Tool, None)
            .await
            .unwrap()
            .0;
        let old = store
            .upsert_entity("Old", "Old", EntityType::Tool, None)
            .await
            .unwrap()
            .0;

        store
            .insert_edge(src, recent, "uses", "Src uses Recent", 1.0, None)
            .await
            .unwrap();

        // Inserted with raw SQL so that `valid_from` can be forced far into the past.
        zeph_db::query(
            sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from)
                  VALUES (?1, ?2, 'uses', 'Src uses Old', 1.0, '1970-01-01 00:00:00')"),
        )
        .bind(src)
        .bind(old)
        .execute(store.pool())
        .await
        .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 2;
        let sa = SpreadingActivation::new(SpreadingActivationParams {
            temporal_decay_rate: 0.5,
            ..cfg
        });
        let seeds = HashMap::from([(src, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let score_recent = nodes
            .iter()
            .find(|n| n.entity_id == recent)
            .map_or(0.0, |n| n.activation);
        let score_old = nodes
            .iter()
            .find(|n| n.entity_id == old)
            .map_or(0.0, |n| n.activation);

        assert!(
            score_recent > score_old,
            "recent edge ({score_recent}) must produce higher activation than old edge ({score_old})"
        );
    }

    /// Restricting the spread to semantic edges keeps a causally linked node out of the result.
    #[tokio::test]
    async fn spread_edge_type_filter_excludes_other_types() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let b_semantic = store
            .upsert_entity("BSemantic", "BSemantic", EntityType::Tool, None)
            .await
            .unwrap()
            .0;
        let c_causal = store
            .upsert_entity("CCausal", "CCausal", EntityType::Concept, None)
            .await
            .unwrap()
            .0;

        store
            .insert_edge(a, b_semantic, "uses", "A uses BSemantic", 1.0, None)
            .await
            .unwrap();

        // Inserted with raw SQL so that `edge_type` can be set to 'causal' explicitly.
        zeph_db::query(
            sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from, edge_type)
                  VALUES (?1, ?2, 'caused', 'A caused CCausal', 1.0, datetime('now'), 'causal')"),
        )
        .bind(a)
        .bind(c_causal)
        .execute(store.pool())
        .await
        .unwrap();

        let cfg = default_params();
        let sa = SpreadingActivation::new(cfg);

        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa
            .spread(&store, seeds, &[EdgeType::Semantic])
            .await
            .unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(
            ids.contains(&b_semantic),
            "BSemantic must be activated via semantic edge"
        );
        assert!(
            !ids.contains(&c_causal),
            "CCausal must NOT be activated when filtering to semantic only"
        );
    }

    /// One hundred seeds are handled without an error.
    #[tokio::test]
    async fn spread_large_seed_list() {
        let store = setup_store().await;
        let mut seeds = HashMap::new();

        for i in 0..100i64 {
            let id = store
                .upsert_entity(
                    &format!("Entity{i}"),
                    &format!("entity{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap()
                .0;
            seeds.insert(id, 1.0_f32);
        }

        let cfg = default_params();
        let sa = SpreadingActivation::new(cfg);
        let result = sa.spread(&store, seeds, &[]).await;
        assert!(
            result.is_ok(),
            "large seed list must not error: {:?}",
            result.err()
        );
    }

    /// Identical vectors have cosine 1.0.
    #[test]
    fn hela_cosine_identical_vectors() {
        let v = vec![1.0_f32, 0.0, 0.0];
        assert!(
            (cosine(&v, &v) - 1.0).abs() < 1e-6,
            "identical vectors → cosine 1.0"
        );
    }

    /// Orthogonal vectors have cosine 0.0.
    #[test]
    fn hela_cosine_orthogonal_vectors() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![0.0_f32, 1.0];
        assert!(
            cosine(&a, &b).abs() < 1e-6,
            "orthogonal vectors → cosine 0.0"
        );
    }

    /// Opposite vectors have a negative cosine.
    #[test]
    fn hela_cosine_anti_correlated() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![-1.0_f32, 0.0];
        assert!(
            cosine(&a, &b) < 0.0,
            "anti-correlated vectors → negative cosine"
        );
    }

    /// A zero vector must produce a finite value rather than NaN or infinity.
    #[test]
    fn hela_cosine_zero_vector_no_panic() {
        let a = vec![0.0_f32, 0.0];
        let b = vec![1.0_f32, 0.0];
        let result = cosine(&a, &b);
        assert!(
            result.is_finite(),
            "zero-norm vector must yield finite cosine"
        );
    }

    /// Pins the default depth, step budget, edge filter and visit cap.
    #[test]
    fn hela_spread_params_default_depth_is_two() {
        let p = HelaSpreadParams::default();
        assert_eq!(p.spread_depth, 2);
        assert!(p.step_budget.is_some());
        assert!(p.edge_types.is_empty());
        assert_eq!(p.max_visited, 200);
    }

    /// The default configuration bounds the embedding call with a timeout.
    #[test]
    fn hela_spread_params_default_embed_timeout_is_some() {
        let p = HelaSpreadParams::default();
        assert!(
            p.embed_timeout.is_some(),
            "default embed_timeout must be Some (5 s)"
        );
    }

    /// An embedding call slower than `embed_timeout` surfaces as `MemoryError::Timeout`.
    #[tokio::test]
    async fn hela_spreading_recall_embed_timeout_returns_error() {
        use std::time::Duration;
        use zeph_llm::any::AnyProvider;
        use zeph_llm::mock::MockProvider;

        use crate::embedding_store::EmbeddingStore;
        use crate::error::MemoryError;
        use crate::in_memory_store::InMemoryVectorStore;

        let store = setup_store().await;

        // The mock delay (500) is well above the 50 ms timeout configured below.
        let mock = MockProvider::default().with_embed_delay(500);
        let provider = AnyProvider::Mock(mock);

        let sqlite = crate::store::SqliteStore::with_pool_size(":memory:", 1)
            .await
            .unwrap();
        let embeddings =
            EmbeddingStore::with_store(Box::new(InMemoryVectorStore::new()), sqlite.pool().clone());

        let params = HelaSpreadParams {
            embed_timeout: Some(Duration::from_millis(50)),
            ..Default::default()
        };

        let result = hela_spreading_recall(
            &store,
            &embeddings,
            &provider,
            "test query",
            5,
            &params,
            false,
            0.0,
        )
        .await;

        assert!(
            matches!(result, Err(MemoryError::Timeout(_))),
            "expected Err(MemoryError::Timeout), got {result:?}"
        );
    }

    /// With `embed_timeout: None` the recall never reports a timeout error.
    #[tokio::test]
    async fn hela_spreading_recall_no_timeout_does_not_wrap() {
        use zeph_llm::any::AnyProvider;
        use zeph_llm::mock::MockProvider;

        use crate::embedding_store::EmbeddingStore;
        use crate::in_memory_store::InMemoryVectorStore;

        let store = setup_store().await;

        let mock = MockProvider::default().with_embed_delay(0);
        let provider = AnyProvider::Mock(mock);

        let sqlite = crate::store::SqliteStore::with_pool_size(":memory:", 1)
            .await
            .unwrap();
        let embeddings =
            EmbeddingStore::with_store(Box::new(InMemoryVectorStore::new()), sqlite.pool().clone());

        let params = HelaSpreadParams {
            embed_timeout: None,
            ..Default::default()
        };

        let result = hela_spreading_recall(
            &store,
            &embeddings,
            &provider,
            "test query",
            5,
            &params,
            false,
            0.0,
        )
        .await;

        // Any other outcome (success or a different error) is acceptable here.
        assert!(
            !matches!(result, Err(crate::error::MemoryError::Timeout(_))),
            "embed_timeout: None must not produce a Timeout error, got {result:?}"
        );
    }

    /// The synthetic anchor edge is a self-loop with id 0.
    #[test]
    fn hela_synthetic_anchor_edge_id_is_zero() {
        let edge = Edge::synthetic_anchor(42);
        assert_eq!(
            edge.id, 0,
            "synthetic anchor must have id = 0 to be excluded from Hebbian"
        );
        assert_eq!(edge.source_entity_id, 42);
        assert_eq!(edge.target_entity_id, 42);
    }

    /// Clamping a negative cosine to zero yields a zero fact score.
    #[test]
    fn hela_negative_cosine_clamped_to_zero_in_score() {
        let anti = vec![-1.0_f32, 0.0];
        let query = vec![1.0_f32, 0.0];
        let cosine_raw = cosine(&query, &anti);
        assert!(cosine_raw < 0.0);
        let clamped = cosine_raw.max(0.0);
        let fact_score = 0.9_f32 * clamped;
        assert!(
            fact_score < f32::EPSILON,
            "anti-correlated score must be 0.0"
        );
    }

    /// Illustrates the multiplicative path-weight rule with plain arithmetic.
    #[test]
    fn hela_path_weight_multiplicative() {
        let w1 = 0.8_f32;
        let w2 = 0.5_f32;
        let expected = w1 * w2;
        assert!((expected - 0.4).abs() < 1e-6);
    }

    /// Illustrates the keep-the-maximum rule for multi-path resolution with plain arithmetic.
    #[test]
    fn hela_max_path_weight_on_multipath() {
        let pw_a = 0.9_f32;
        let pw_b = 0.3_f32;
        let kept = pw_a.max(pw_b);
        assert!(
            (kept - 0.9).abs() < 1e-6,
            "multi-path resolution must keep maximum path_weight"
        );
    }

    /// Illustrates the `path_weight * cosine` score formula with plain arithmetic.
    #[test]
    fn hela_fact_score_formula() {
        let path_weight = 0.8_f32;
        let cosine_clamped = 0.75_f32;
        let expected = path_weight * cosine_clamped;
        assert!((expected - 0.6).abs() < 1e-5);
    }

    /// The date parser maps well-known instants correctly and rejects truncated input.
    #[test]
    fn parse_sqlite_datetime_known_values() {
        assert_eq!(
            parse_sqlite_datetime_to_unix("1970-01-01 00:00:00"),
            Some(0)
        );
        assert_eq!(
            parse_sqlite_datetime_to_unix("2000-01-01 00:00:00"),
            Some(946_684_800)
        );
        assert_eq!(parse_sqlite_datetime_to_unix("1970-01-01"), None);
    }
}