1use std::collections::HashMap;
17use std::sync::OnceLock;
18use std::time::{Instant, SystemTime, UNIX_EPOCH};
19#[allow(unused_imports)]
20use zeph_db::sql;
21
22use crate::embedding_store::EmbeddingStore;
23use crate::error::MemoryError;
24use crate::graph::store::GraphStore;
25use crate::graph::types::{Edge, EdgeType, edge_type_weight, evolved_weight};
26
/// An entity reached by [`SpreadingActivation::spread`], with its final score.
#[derive(Debug, Clone)]
pub struct ActivatedNode {
    /// Identifier of the graph entity.
    pub entity_id: i64,
    /// Final activation score; `spread` only returns nodes whose score is at
    /// least `activation_threshold`.
    pub activation: f32,
    /// Hop index stored alongside the score (`0` for seeds, `hop + 1` for nodes
    /// reached while spreading).
    pub depth: u32,
}
37
/// An edge whose two endpoints were both active at the end of a hop.
#[derive(Debug, Clone)]
pub struct ActivatedFact {
    /// The graph edge carrying the fact.
    pub edge: Edge,
    /// The larger of the source and target activation scores at collection time.
    pub activation_score: f32,
}
46
/// Tunables for [`SpreadingActivation`].
#[derive(Debug, Clone)]
pub struct SpreadingActivationParams {
    /// Multiplier applied to a node's score each time it spreads across an edge.
    pub decay_lambda: f32,
    /// Maximum number of spreading hops.
    pub max_hops: u32,
    /// Minimum score for a seed to be accepted, for a node to spread, and for a
    /// spread contribution to be counted.
    pub activation_threshold: f32,
    /// Nodes whose score already reaches this value receive no further activation.
    pub inhibition_threshold: f32,
    /// Upper bound on tracked nodes; after each hop the lowest-scoring entries
    /// beyond this count are dropped.
    pub max_activated_nodes: usize,
    /// Rate `r` in the recency weight `1 / (1 + age_days * r)`; values `<= 0`
    /// disable temporal decay.
    pub temporal_decay_rate: f64,
    /// NOTE(review): not read anywhere in this file's visible code — presumably
    /// consumed by the seed-selection caller; confirm.
    pub seed_structural_weight: f32,
    /// NOTE(review): not read anywhere in this file's visible code — presumably
    /// a per-community seed limit applied by the caller; confirm.
    pub seed_community_cap: usize,
}
62
/// A fact returned by [`hela_spreading_recall`].
#[derive(Debug, Clone)]
pub struct HelaFact {
    /// Edge through which the entity was reached (a synthetic anchor edge in the
    /// isolated-anchor fallback).
    pub edge: Edge,
    /// Ranking score: `path_weight * max(cosine, 0)`, or the raw anchor ANN score
    /// in the fallback.
    pub score: f32,
    /// Hop count from the anchor (`0` for the fallback fact).
    pub depth: u32,
    /// Product of `edge.weight` along the best path from the anchor.
    pub path_weight: f32,
    /// Query/entity cosine similarity clamped to be non-negative; every fact
    /// built in this file sets it to `Some`.
    pub cosine: Option<f32>,
}
88
/// Tunables for [`hela_spreading_recall`].
#[derive(Debug, Clone)]
pub struct HelaSpreadParams {
    /// Requested number of hops; clamped to `1..=6` before use.
    pub spread_depth: u32,
    /// Edge types forwarded to `GraphStore::edges_for_entities` on every hop.
    /// NOTE(review): an empty list presumably means "no filter" — confirm
    /// against `edges_for_entities`.
    pub edge_types: Vec<EdgeType>,
    /// Expansion stops once this many entities (anchor included) were visited.
    pub max_visited: usize,
    /// Time budget checked after the anchor search, each edge fetch and the
    /// vector fetch; exceeding it makes the recall return no facts. `None`
    /// disables the checks.
    pub step_budget: Option<std::time::Duration>,
}
111
112impl Default for HelaSpreadParams {
113 fn default() -> Self {
114 Self {
115 spread_depth: 2,
116 edge_types: Vec::new(),
117 max_visited: 200,
118 step_budget: Some(std::time::Duration::from_millis(8)),
119 }
120 }
121}
122
/// Name of the collection for which a vector-dimension mismatch was detected.
///
/// Set at most once by [`hela_spreading_recall`]; while it equals the entity
/// collection name, later calls return early with no facts.
static HELA_DIM_MISMATCH: OnceLock<String> = OnceLock::new();
132
/// Cosine similarity of `a` and `b`.
///
/// The dot product pairs elements up to the shorter slice, while each norm is
/// taken over the whole slice. The denominator is floored at `f32::EPSILON`,
/// so a zero-norm input yields a finite value instead of NaN.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let l2_norm = |v: &[f32]| -> f32 { v.iter().map(|c| c * c).sum::<f32>().sqrt() };
    let numerator = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    numerator / (l2_norm(a) * l2_norm(b)).max(f32::EPSILON)
}
143
144#[tracing::instrument(
171 name = "memory.graph.hela_spread",
172 skip_all,
173 fields(
174 depth = params.spread_depth,
175 limit,
176 anchor_id = tracing::field::Empty,
177 visited = tracing::field::Empty,
178 scored = tracing::field::Empty,
179 fallback = tracing::field::Empty,
180 )
181)]
182#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
183pub async fn hela_spreading_recall(
184 store: &GraphStore,
185 embeddings: &EmbeddingStore,
186 provider: &zeph_llm::any::AnyProvider,
187 query: &str,
188 limit: usize,
189 params: &HelaSpreadParams,
190 hebbian_enabled: bool,
191 hebbian_lr: f32,
192) -> Result<Vec<HelaFact>, MemoryError> {
193 use zeph_llm::LlmProvider as _;
194
195 const ENTITY_COLLECTION: &str = "zeph_graph_entities";
196
197 if limit == 0 {
198 return Ok(Vec::new());
199 }
200
201 if HELA_DIM_MISMATCH.get().map(String::as_str) == Some(ENTITY_COLLECTION) {
204 tracing::debug!("hela: dim mismatch previously detected for collection, skipping");
205 return Ok(Vec::new());
206 }
207
208 let q_vec = provider.embed(query).await?;
210
211 let t_anchor = Instant::now();
213 let anchor_results = match embeddings
214 .search_collection(ENTITY_COLLECTION, &q_vec, 1, None)
215 .await
216 {
217 Ok(r) => r,
218 Err(e) => {
219 let msg = e.to_string();
220 if msg.contains("wrong vector dimension")
221 || msg.contains("InvalidArgument")
222 || msg.contains("dimension")
223 {
224 let _ = HELA_DIM_MISMATCH.set(ENTITY_COLLECTION.to_owned());
225 tracing::warn!(
226 collection = ENTITY_COLLECTION,
227 error = %e,
228 "hela: vector dimension mismatch — HL-F5 disabled for this collection"
229 );
230 return Ok(Vec::new());
231 }
232 return Err(e);
233 }
234 };
235
236 if params.step_budget.is_some_and(|b| t_anchor.elapsed() > b) {
237 tracing::warn!(
238 elapsed_ms = t_anchor.elapsed().as_millis(),
239 "hela: anchor ANN over budget"
240 );
241 return Ok(Vec::new());
242 }
243
244 let Some(anchor_point) = anchor_results.first() else {
245 tracing::debug!("hela: no anchor found, returning empty");
246 return Ok(Vec::new());
247 };
248 let Some(anchor_entity_id) = anchor_point
249 .payload
250 .get("entity_id")
251 .and_then(serde_json::Value::as_i64)
252 else {
253 tracing::warn!("hela: anchor point missing entity_id payload");
254 return Ok(Vec::new());
255 };
256 let anchor_cosine = anchor_point.score;
257
258 tracing::Span::current().record("anchor_id", anchor_entity_id);
259 tracing::debug!(anchor_entity_id, anchor_cosine, "hela: anchor resolved");
260
261 let spread_depth = params.spread_depth.clamp(1, 6);
262
263 let mut visited: HashMap<i64, (u32, f32, Option<i64>)> = HashMap::new();
266 visited.insert(anchor_entity_id, (0, 1.0, None));
267
268 let mut edge_cache: HashMap<i64, Edge> = HashMap::new();
272 let mut frontier: Vec<i64> = vec![anchor_entity_id];
273
274 for hop in 0..spread_depth {
275 if frontier.is_empty() {
276 break;
277 }
278
279 tracing::debug!(hop, frontier_size = frontier.len(), "hela: starting hop");
280
281 let t_step = Instant::now();
282 let edges = store
283 .edges_for_entities(&frontier, ¶ms.edge_types)
284 .await?;
285 if params.step_budget.is_some_and(|b| t_step.elapsed() > b) {
286 tracing::warn!(
287 hop,
288 elapsed_ms = t_step.elapsed().as_millis(),
289 "hela: edge-fetch over budget"
290 );
291 return Ok(Vec::new());
292 }
293
294 let mut next_frontier: Vec<i64> = Vec::new();
295
296 for edge in &edges {
297 edge_cache.entry(edge.id).or_insert_with(|| edge.clone());
299
300 for &src_id in &frontier {
301 let neighbor = if edge.source_entity_id == src_id {
302 edge.target_entity_id
303 } else if edge.target_entity_id == src_id {
304 edge.source_entity_id
305 } else {
306 continue;
307 };
308
309 let parent_pw = visited.get(&src_id).map_or(1.0, |&(_, pw, _)| pw);
310 let new_pw = parent_pw * edge.weight;
311
312 let entry = visited
317 .entry(neighbor)
318 .or_insert((hop + 1, 0.0_f32, Some(edge.id)));
319 if new_pw > entry.1
321 || ((new_pw - entry.1).abs() < f32::EPSILON && hop + 1 < entry.0)
322 {
323 *entry = (hop + 1, new_pw, Some(edge.id));
324 if !next_frontier.contains(&neighbor) {
325 next_frontier.push(neighbor);
326 }
327 }
328
329 if visited.len() >= params.max_visited {
330 break;
331 }
332 }
333
334 if visited.len() >= params.max_visited {
335 break;
336 }
337 }
338
339 tracing::debug!(
340 hop,
341 edges_fetched = edges.len(),
342 visited = visited.len(),
343 next_frontier = next_frontier.len(),
344 "hela: hop complete"
345 );
346
347 frontier = next_frontier;
348 if visited.len() >= params.max_visited {
349 break;
350 }
351 }
352
353 if visited.len() == 1 {
356 tracing::Span::current().record("fallback", true);
357 tracing::debug!(
358 anchor_entity_id,
359 anchor_cosine,
360 "hela: anchor isolated, falling back to pure ANN"
361 );
362 let fact = HelaFact {
363 edge: Edge::synthetic_anchor(anchor_entity_id),
364 score: anchor_cosine,
365 depth: 0,
366 path_weight: 1.0,
367 cosine: Some(anchor_cosine.clamp(0.0, 1.0)),
368 };
369 return Ok(vec![fact]);
370 }
371
372 let entity_ids: Vec<i64> = visited.keys().copied().collect();
374 let point_id_map = store.qdrant_point_ids_for_entities(&entity_ids).await?;
375 let point_ids: Vec<String> = point_id_map.values().cloned().collect();
376
377 let t_vec = Instant::now();
378 let vec_map = embeddings
379 .get_vectors_from_collection(ENTITY_COLLECTION, &point_ids)
380 .await?;
381 if params.step_budget.is_some_and(|b| t_vec.elapsed() > b) {
382 tracing::warn!(
383 elapsed_ms = t_vec.elapsed().as_millis(),
384 "hela: vectors-batch over budget"
385 );
386 return Ok(Vec::new());
387 }
388
389 let mut facts: Vec<HelaFact> = Vec::with_capacity(visited.len().saturating_sub(1));
394 for (&entity_id, &(depth, path_weight, edge_id_opt)) in &visited {
395 if entity_id == anchor_entity_id {
396 continue;
397 }
398 let Some(edge_id) = edge_id_opt else {
399 continue;
400 };
401 let Some(point_id) = point_id_map.get(&entity_id) else {
402 continue;
403 };
404 let Some(node_vec) = vec_map.get(point_id) else {
405 continue;
406 };
407 if node_vec.len() != q_vec.len() {
408 continue;
410 }
411 let cosine_clamped = cosine(&q_vec, node_vec).max(0.0);
412 let fact_score = path_weight * cosine_clamped;
413 let Some(edge) = edge_cache.get(&edge_id).cloned() else {
414 continue;
415 };
416 facts.push(HelaFact {
417 edge,
418 score: fact_score,
419 depth,
420 path_weight,
421 cosine: Some(cosine_clamped),
422 });
423 }
424
425 facts.sort_by(|a, b| b.score.total_cmp(&a.score));
427 facts.truncate(limit);
428
429 if hebbian_enabled {
434 let edge_ids: Vec<i64> = facts
435 .iter()
436 .map(|f| f.edge.id)
437 .filter(|&id| id != 0) .collect();
439 if !edge_ids.is_empty()
440 && let Err(e) = store.apply_hebbian_increment(&edge_ids, hebbian_lr).await
441 {
442 tracing::warn!(error = %e, "hela: hebbian increment failed");
443 }
444 }
445
446 tracing::Span::current().record("visited", visited.len());
447 tracing::Span::current().record("scored", facts.len());
448
449 Ok(facts)
450}
451
/// Spreading-activation engine over the entity graph, configured once via
/// [`SpreadingActivationParams`].
pub struct SpreadingActivation {
    /// Tunables used by every `spread` call.
    params: SpreadingActivationParams,
}
458
459impl SpreadingActivation {
    /// Creates an engine that uses `params` for every [`spread`](Self::spread) call.
    #[must_use]
    pub fn new(params: SpreadingActivationParams) -> Self {
        Self { params }
    }
468
    /// Spreads activation from `seeds` across the graph for up to `max_hops` hops.
    ///
    /// `seeds` maps entity id to initial score; seeds below
    /// `activation_threshold` are dropped. On each hop every node at or above
    /// the threshold pushes `score * decay_lambda * edge_weight * recency * type_weight`
    /// to its neighbours (edges are followed in both directions). Contributions
    /// to the same neighbour are summed and capped at `1.0`. Neighbours already
    /// at or above `inhibition_threshold` receive nothing.
    ///
    /// Returns the activated nodes sorted by score descending, plus the facts
    /// (edges with both endpoints active) gathered after each hop. An edge that
    /// stays active over several hops is pushed once per hop.
    ///
    /// # Errors
    ///
    /// Propagates errors from `GraphStore::edges_for_entities`.
    #[allow(clippy::too_many_lines)]
    pub async fn spread(
        &self,
        store: &GraphStore,
        seeds: HashMap<i64, f32>,
        edge_types: &[EdgeType],
    ) -> Result<(Vec<ActivatedNode>, Vec<ActivatedFact>), MemoryError> {
        if seeds.is_empty() {
            return Ok((Vec::new(), Vec::new()));
        }

        // Wall-clock "now" for recency weighting; falls back to 0 if the clock
        // is before the Unix epoch.
        let now_secs: i64 = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_or(0, |d| d.as_secs().cast_signed());

        // entity id -> (activation score, depth at which that score was set).
        let mut activation: HashMap<i64, (f32, u32)> = HashMap::new();

        let mut seed_count = 0usize;
        for (entity_id, match_score) in &seeds {
            if *match_score < self.params.activation_threshold {
                tracing::debug!(
                    entity_id,
                    score = match_score,
                    threshold = self.params.activation_threshold,
                    "spreading activation: seed below threshold, skipping"
                );
                continue;
            }
            activation.insert(*entity_id, (*match_score, 0));
            seed_count += 1;
        }

        tracing::debug!(
            seeds = seed_count,
            "spreading activation: initialized seeds"
        );

        let mut activated_facts: Vec<ActivatedFact> = Vec::new();

        for hop in 0..self.params.max_hops {
            // Every node currently at or above the threshold spreads, including
            // nodes activated on earlier hops.
            let active_nodes: Vec<(i64, f32)> = activation
                .iter()
                .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
                .map(|(&id, &(score, _))| (id, score))
                .collect();

            if active_nodes.is_empty() {
                break;
            }

            let node_ids: Vec<i64> = active_nodes.iter().map(|(id, _)| *id).collect();

            let edges = store.edges_for_entities(&node_ids, edge_types).await?;
            let edge_count = edges.len();

            // Contributions of this hop, merged into `activation` afterwards.
            let mut next_activation: HashMap<i64, (f32, u32)> = HashMap::new();

            for edge in &edges {
                for &(active_id, node_score) in &active_nodes {
                    // Edges are followed in both directions.
                    let neighbor = if edge.source_entity_id == active_id {
                        edge.target_entity_id
                    } else if edge.target_entity_id == active_id {
                        edge.source_entity_id
                    } else {
                        continue;
                    };

                    // Inhibition: a neighbour already at or above the threshold,
                    // from earlier hops or within this hop, receives nothing more.
                    let current_score = activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
                    let next_score = next_activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
                    if current_score >= self.params.inhibition_threshold
                        || next_score >= self.params.inhibition_threshold
                    {
                        continue;
                    }

                    let recency = self.recency_weight(&edge.valid_from, now_secs);
                    let edge_weight = evolved_weight(edge.retrieval_count, edge.confidence);
                    let type_w = edge_type_weight(edge.edge_type);
                    let spread_value =
                        node_score * self.params.decay_lambda * edge_weight * recency * type_w;

                    if spread_value < self.params.activation_threshold {
                        continue;
                    }

                    // Sum contributions from all paths in this hop, capped at 1.0.
                    let depth_at_max = hop + 1;
                    let entry = next_activation
                        .entry(neighbor)
                        .or_insert((0.0, depth_at_max));
                    let new_score = (entry.0 + spread_value).min(1.0);
                    if new_score > entry.0 {
                        entry.0 = new_score;
                        entry.1 = depth_at_max;
                    }
                }
            }

            // Merge: a node keeps the higher of its existing score and this
            // hop's score (together with the matching depth).
            for (node_id, (new_score, new_depth)) in next_activation {
                let entry = activation.entry(node_id).or_insert((0.0, new_depth));
                if new_score > entry.0 {
                    entry.0 = new_score;
                    entry.1 = new_depth;
                }
            }

            // Prune down to the `max_activated_nodes` highest-scoring entries.
            let pruned_count = if activation.len() > self.params.max_activated_nodes {
                let before = activation.len();
                let mut entries: Vec<(i64, (f32, u32))> = activation.drain().collect();
                entries.sort_by(|(_, (a, _)), (_, (b, _))| b.total_cmp(a));
                entries.truncate(self.params.max_activated_nodes);
                activation = entries.into_iter().collect();
                before - self.params.max_activated_nodes
            } else {
                0
            };

            tracing::debug!(
                hop,
                active_nodes = active_nodes.len(),
                edges_fetched = edge_count,
                after_merge = activation.len(),
                pruned = pruned_count,
                "spreading activation: hop complete"
            );

            // Collect facts: edges whose two endpoints are both active after
            // the merge and prune above.
            for edge in edges {
                let src_score = activation
                    .get(&edge.source_entity_id)
                    .map_or(0.0, |&(s, _)| s);
                let tgt_score = activation
                    .get(&edge.target_entity_id)
                    .map_or(0.0, |&(s, _)| s);
                if src_score >= self.params.activation_threshold
                    && tgt_score >= self.params.activation_threshold
                {
                    let activation_score = src_score.max(tgt_score);
                    activated_facts.push(ActivatedFact {
                        edge,
                        activation_score,
                    });
                }
            }
        }

        // Final node list: threshold filter, then highest activation first.
        let mut result: Vec<ActivatedNode> = activation
            .into_iter()
            .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
            .map(|(entity_id, (activation, depth))| ActivatedNode {
                entity_id,
                activation,
                depth,
            })
            .collect();
        result.sort_by(|a, b| b.activation.total_cmp(&a.activation));

        tracing::info!(
            activated = result.len(),
            facts = activated_facts.len(),
            "spreading activation: complete"
        );

        Ok((result, activated_facts))
    }
671
672 #[allow(clippy::cast_precision_loss)]
678 fn recency_weight(&self, valid_from: &str, now_secs: i64) -> f32 {
679 if self.params.temporal_decay_rate <= 0.0 {
680 return 1.0;
681 }
682 let Some(valid_from_secs) = parse_sqlite_datetime_to_unix(valid_from) else {
683 return 1.0;
684 };
685 let age_secs = (now_secs - valid_from_secs).max(0);
686 let age_days = age_secs as f64 / 86_400.0;
687 let weight = 1.0_f64 / (1.0 + age_days * self.params.temporal_decay_rate);
688 #[allow(clippy::cast_possible_truncation)]
690 let w = weight as f32;
691 w
692 }
693}
694
/// Converts a `YYYY-MM-DD HH:MM:SS` timestamp into seconds since the Unix epoch.
///
/// Only the first 19 bytes are read, so trailing text such as fractional
/// seconds is ignored. Separator characters are not checked. No time-zone
/// offset is applied: the fields are taken as-is.
///
/// Returns `None` when the input is shorter than 19 bytes, when a field is not
/// made solely of ASCII digits, or when a field is out of range (month
/// `1..=12`, day `1..=31`, hour `0..=23`, minute and second `0..=59`). Input
/// containing multi-byte characters inside a field yields `None` rather than a
/// slicing panic.
#[must_use]
fn parse_sqlite_datetime_to_unix(s: &str) -> Option<i64> {
    if s.len() < 19 {
        return None;
    }

    // Reads one fixed-width field. `str::get` returns `None` instead of
    // panicking when a bound is not a UTF-8 character boundary.
    let field = |range: std::ops::Range<usize>, min: i64, max: i64| -> Option<i64> {
        let text = s.get(range)?;
        if !text.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }
        let value: i64 = text.parse().ok()?;
        (min..=max).contains(&value).then_some(value)
    };

    let year = field(0..4, 0, 9_999)?;
    let month = field(5..7, 1, 12)?;
    let day = field(8..10, 1, 31)?;
    let hour = field(11..13, 0, 23)?;
    let min = field(14..16, 0, 59)?;
    let sec = field(17..19, 0, 59)?;

    // Days-from-civil computation: shift the year so that it starts in March,
    // which places the leap day at the end of the shifted year.
    let (y, m) = if month <= 2 {
        (year - 1, month + 9)
    } else {
        (year, month - 3)
    };
    let era = y.div_euclid(400);
    let yoe = y - era * 400;
    let doy = (153 * m + 2) / 5 + day - 1;
    let doe = yoe * 365 + yoe / 4 - yoe / 100 + doy;
    // 719_468 is the day offset between 0000-03-01 and 1970-01-01.
    let days = era * 146_097 + doe - 719_468;

    Some(days * 86_400 + hour * 3_600 + min * 60 + sec)
}
726
727#[cfg(test)]
728mod tests {
729 use super::*;
730 use crate::graph::GraphStore;
731 use crate::graph::types::EntityType;
732 use crate::store::SqliteStore;
733
734 async fn setup_store() -> GraphStore {
735 let store = SqliteStore::new(":memory:").await.unwrap();
736 GraphStore::new(store.pool().clone())
737 }
738
    /// Baseline parameters shared by the tests; individual tests override fields.
    fn default_params() -> SpreadingActivationParams {
        SpreadingActivationParams {
            decay_lambda: 0.85,
            max_hops: 3,
            activation_threshold: 0.1,
            inhibition_threshold: 0.8,
            max_activated_nodes: 50,
            // 0.0 makes `recency_weight` return 1.0, i.e. temporal decay is off.
            temporal_decay_rate: 0.0,
            seed_structural_weight: 0.4,
            seed_community_cap: 3,
        }
    }
751
752 #[tokio::test]
755 async fn spread_empty_graph_no_edges_no_facts() {
756 let store = setup_store().await;
757 let sa = SpreadingActivation::new(default_params());
758 let seeds = HashMap::from([(1_i64, 1.0_f32)]);
759 let (nodes, facts) = sa.spread(&store, seeds, &[]).await.unwrap();
760 assert_eq!(nodes.len(), 1, "seed must be in activated nodes");
762 assert_eq!(nodes[0].entity_id, 1);
763 assert!((nodes[0].activation - 1.0).abs() < 1e-6);
764 assert!(
766 facts.is_empty(),
767 "expected no activated facts on empty graph"
768 );
769 }
770
771 #[tokio::test]
773 async fn spread_empty_seeds_returns_empty() {
774 let store = setup_store().await;
775 let sa = SpreadingActivation::new(default_params());
776 let (nodes, facts) = sa.spread(&store, HashMap::new(), &[]).await.unwrap();
777 assert!(nodes.is_empty());
778 assert!(facts.is_empty());
779 }
780
781 #[tokio::test]
783 async fn spread_single_seed_no_edges_returns_seed() {
784 let store = setup_store().await;
785 let alice = store
786 .upsert_entity("Alice", "Alice", EntityType::Person, None)
787 .await
788 .unwrap();
789
790 let sa = SpreadingActivation::new(default_params());
791 let seeds = HashMap::from([(alice, 1.0_f32)]);
792 let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
793 assert_eq!(nodes.len(), 1);
794 assert_eq!(nodes[0].entity_id, alice);
795 assert_eq!(nodes[0].depth, 0);
796 assert!((nodes[0].activation - 1.0).abs() < 1e-6);
797 }
798
799 #[tokio::test]
801 async fn spread_linear_chain_all_activated_with_decay() {
802 let store = setup_store().await;
803 let a = store
804 .upsert_entity("A", "A", EntityType::Person, None)
805 .await
806 .unwrap();
807 let b = store
808 .upsert_entity("B", "B", EntityType::Person, None)
809 .await
810 .unwrap();
811 let c = store
812 .upsert_entity("C", "C", EntityType::Person, None)
813 .await
814 .unwrap();
815 store
816 .insert_edge(a, b, "knows", "A knows B", 1.0, None)
817 .await
818 .unwrap();
819 store
820 .insert_edge(b, c, "knows", "B knows C", 1.0, None)
821 .await
822 .unwrap();
823
824 let mut cfg = default_params();
825 cfg.max_hops = 3;
826 cfg.decay_lambda = 0.9;
827 let sa = SpreadingActivation::new(cfg);
828 let seeds = HashMap::from([(a, 1.0_f32)]);
829 let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
830
831 let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
832 assert!(ids.contains(&a), "A (seed) must be activated");
833 assert!(ids.contains(&b), "B (hop 1) must be activated");
834 assert!(ids.contains(&c), "C (hop 2) must be activated");
835
836 let score_a = nodes.iter().find(|n| n.entity_id == a).unwrap().activation;
838 let score_b = nodes.iter().find(|n| n.entity_id == b).unwrap().activation;
839 let score_c = nodes.iter().find(|n| n.entity_id == c).unwrap().activation;
840 assert!(
841 score_a > score_b,
842 "seed A should have higher activation than hop-1 B"
843 );
844 assert!(
845 score_b > score_c,
846 "hop-1 B should have higher activation than hop-2 C"
847 );
848 }
849
850 #[tokio::test]
852 async fn spread_linear_chain_max_hops_limits_reach() {
853 let store = setup_store().await;
854 let a = store
855 .upsert_entity("A", "A", EntityType::Person, None)
856 .await
857 .unwrap();
858 let b = store
859 .upsert_entity("B", "B", EntityType::Person, None)
860 .await
861 .unwrap();
862 let c = store
863 .upsert_entity("C", "C", EntityType::Person, None)
864 .await
865 .unwrap();
866 store
867 .insert_edge(a, b, "knows", "A knows B", 1.0, None)
868 .await
869 .unwrap();
870 store
871 .insert_edge(b, c, "knows", "B knows C", 1.0, None)
872 .await
873 .unwrap();
874
875 let mut cfg = default_params();
876 cfg.max_hops = 1;
877 let sa = SpreadingActivation::new(cfg);
878 let seeds = HashMap::from([(a, 1.0_f32)]);
879 let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
880
881 let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
882 assert!(ids.contains(&a), "A must be activated (seed)");
883 assert!(ids.contains(&b), "B must be activated (hop 1)");
884 assert!(!ids.contains(&c), "C must NOT be activated with max_hops=1");
885 }
886
887 #[tokio::test]
891 async fn spread_diamond_graph_convergence() {
892 let store = setup_store().await;
893 let a = store
894 .upsert_entity("A", "A", EntityType::Person, None)
895 .await
896 .unwrap();
897 let b = store
898 .upsert_entity("B", "B", EntityType::Person, None)
899 .await
900 .unwrap();
901 let c = store
902 .upsert_entity("C", "C", EntityType::Person, None)
903 .await
904 .unwrap();
905 let d = store
906 .upsert_entity("D", "D", EntityType::Person, None)
907 .await
908 .unwrap();
909 store
910 .insert_edge(a, b, "rel", "A-B", 1.0, None)
911 .await
912 .unwrap();
913 store
914 .insert_edge(a, c, "rel", "A-C", 1.0, None)
915 .await
916 .unwrap();
917 store
918 .insert_edge(b, d, "rel", "B-D", 1.0, None)
919 .await
920 .unwrap();
921 store
922 .insert_edge(c, d, "rel", "C-D", 1.0, None)
923 .await
924 .unwrap();
925
926 let mut cfg = default_params();
927 cfg.max_hops = 3;
928 cfg.decay_lambda = 0.9;
929 cfg.inhibition_threshold = 0.95; let sa = SpreadingActivation::new(cfg);
931 let seeds = HashMap::from([(a, 1.0_f32)]);
932 let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
933
934 let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
935 assert!(ids.contains(&d), "D must be activated via diamond paths");
936
937 let node_d = nodes.iter().find(|n| n.entity_id == d).unwrap();
939 assert_eq!(node_d.depth, 2, "D should be at depth 2");
940 }
941
942 #[tokio::test]
944 async fn spread_inhibition_prevents_runaway() {
945 let store = setup_store().await;
946 let hub = store
948 .upsert_entity("Hub", "Hub", EntityType::Concept, None)
949 .await
950 .unwrap();
951
952 for i in 0..5 {
953 let leaf = store
954 .upsert_entity(
955 &format!("Leaf{i}"),
956 &format!("Leaf{i}"),
957 EntityType::Concept,
958 None,
959 )
960 .await
961 .unwrap();
962 store
963 .insert_edge(hub, leaf, "has", &format!("Hub has Leaf{i}"), 1.0, None)
964 .await
965 .unwrap();
966 store
968 .insert_edge(
969 leaf,
970 hub,
971 "part_of",
972 &format!("Leaf{i} part_of Hub"),
973 1.0,
974 None,
975 )
976 .await
977 .unwrap();
978 }
979
980 let mut cfg = default_params();
982 cfg.inhibition_threshold = 0.8;
983 cfg.max_hops = 3;
984 let sa = SpreadingActivation::new(cfg);
985 let seeds = HashMap::from([(hub, 1.0_f32)]);
986 let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
987
988 let hub_node = nodes.iter().find(|n| n.entity_id == hub);
990 assert!(hub_node.is_some(), "hub must be in results");
991 assert!(
992 hub_node.unwrap().activation <= 1.0,
993 "activation must not exceed 1.0"
994 );
995 }
996
997 #[tokio::test]
999 async fn spread_max_activated_nodes_cap_enforced() {
1000 let store = setup_store().await;
1001 let root = store
1002 .upsert_entity("Root", "Root", EntityType::Person, None)
1003 .await
1004 .unwrap();
1005
1006 for i in 0..20 {
1008 let leaf = store
1009 .upsert_entity(
1010 &format!("Node{i}"),
1011 &format!("Node{i}"),
1012 EntityType::Concept,
1013 None,
1014 )
1015 .await
1016 .unwrap();
1017 store
1018 .insert_edge(root, leaf, "has", &format!("Root has Node{i}"), 0.9, None)
1019 .await
1020 .unwrap();
1021 }
1022
1023 let max_nodes = 5;
1024 let cfg = SpreadingActivationParams {
1025 max_activated_nodes: max_nodes,
1026 max_hops: 2,
1027 ..default_params()
1028 };
1029 let sa = SpreadingActivation::new(cfg);
1030 let seeds = HashMap::from([(root, 1.0_f32)]);
1031 let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
1032
1033 assert!(
1034 nodes.len() <= max_nodes,
1035 "activation must be capped at {max_nodes} nodes, got {}",
1036 nodes.len()
1037 );
1038 }
1039
1040 #[tokio::test]
1042 async fn spread_temporal_decay_recency_effect() {
1043 let store = setup_store().await;
1044 let src = store
1045 .upsert_entity("Src", "Src", EntityType::Person, None)
1046 .await
1047 .unwrap();
1048 let recent = store
1049 .upsert_entity("Recent", "Recent", EntityType::Tool, None)
1050 .await
1051 .unwrap();
1052 let old = store
1053 .upsert_entity("Old", "Old", EntityType::Tool, None)
1054 .await
1055 .unwrap();
1056
1057 store
1059 .insert_edge(src, recent, "uses", "Src uses Recent", 1.0, None)
1060 .await
1061 .unwrap();
1062
1063 zeph_db::query(
1065 sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from)
1066 VALUES (?1, ?2, 'uses', 'Src uses Old', 1.0, '1970-01-01 00:00:00')"),
1067 )
1068 .bind(src)
1069 .bind(old)
1070 .execute(store.pool())
1071 .await
1072 .unwrap();
1073
1074 let mut cfg = default_params();
1075 cfg.max_hops = 2;
1076 let sa = SpreadingActivation::new(SpreadingActivationParams {
1078 temporal_decay_rate: 0.5,
1079 ..cfg
1080 });
1081 let seeds = HashMap::from([(src, 1.0_f32)]);
1082 let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
1083
1084 let score_recent = nodes
1085 .iter()
1086 .find(|n| n.entity_id == recent)
1087 .map_or(0.0, |n| n.activation);
1088 let score_old = nodes
1089 .iter()
1090 .find(|n| n.entity_id == old)
1091 .map_or(0.0, |n| n.activation);
1092
1093 assert!(
1094 score_recent > score_old,
1095 "recent edge ({score_recent}) must produce higher activation than old edge ({score_old})"
1096 );
1097 }
1098
1099 #[tokio::test]
1101 async fn spread_edge_type_filter_excludes_other_types() {
1102 let store = setup_store().await;
1103 let a = store
1104 .upsert_entity("A", "A", EntityType::Person, None)
1105 .await
1106 .unwrap();
1107 let b_semantic = store
1108 .upsert_entity("BSemantic", "BSemantic", EntityType::Tool, None)
1109 .await
1110 .unwrap();
1111 let c_causal = store
1112 .upsert_entity("CCausal", "CCausal", EntityType::Concept, None)
1113 .await
1114 .unwrap();
1115
1116 store
1118 .insert_edge(a, b_semantic, "uses", "A uses BSemantic", 1.0, None)
1119 .await
1120 .unwrap();
1121
1122 zeph_db::query(
1124 sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from, edge_type)
1125 VALUES (?1, ?2, 'caused', 'A caused CCausal', 1.0, datetime('now'), 'causal')"),
1126 )
1127 .bind(a)
1128 .bind(c_causal)
1129 .execute(store.pool())
1130 .await
1131 .unwrap();
1132
1133 let cfg = default_params();
1134 let sa = SpreadingActivation::new(cfg);
1135
1136 let seeds = HashMap::from([(a, 1.0_f32)]);
1138 let (nodes, _) = sa
1139 .spread(&store, seeds, &[EdgeType::Semantic])
1140 .await
1141 .unwrap();
1142
1143 let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
1144 assert!(
1145 ids.contains(&b_semantic),
1146 "BSemantic must be activated via semantic edge"
1147 );
1148 assert!(
1149 !ids.contains(&c_causal),
1150 "CCausal must NOT be activated when filtering to semantic only"
1151 );
1152 }
1153
1154 #[tokio::test]
1156 async fn spread_large_seed_list() {
1157 let store = setup_store().await;
1158 let mut seeds = HashMap::new();
1159
1160 for i in 0..100i64 {
1162 let id = store
1163 .upsert_entity(
1164 &format!("Entity{i}"),
1165 &format!("entity{i}"),
1166 EntityType::Concept,
1167 None,
1168 )
1169 .await
1170 .unwrap();
1171 seeds.insert(id, 1.0_f32);
1172 }
1173
1174 let cfg = default_params();
1175 let sa = SpreadingActivation::new(cfg);
1176 let result = sa.spread(&store, seeds, &[]).await;
1178 assert!(
1179 result.is_ok(),
1180 "large seed list must not error: {:?}",
1181 result.err()
1182 );
1183 }
1184
1185 #[test]
1188 fn hela_cosine_identical_vectors() {
1189 let v = vec![1.0_f32, 0.0, 0.0];
1190 assert!(
1191 (cosine(&v, &v) - 1.0).abs() < 1e-6,
1192 "identical vectors → cosine 1.0"
1193 );
1194 }
1195
1196 #[test]
1197 fn hela_cosine_orthogonal_vectors() {
1198 let a = vec![1.0_f32, 0.0];
1199 let b = vec![0.0_f32, 1.0];
1200 assert!(
1201 cosine(&a, &b).abs() < 1e-6,
1202 "orthogonal vectors → cosine 0.0"
1203 );
1204 }
1205
1206 #[test]
1207 fn hela_cosine_anti_correlated() {
1208 let a = vec![1.0_f32, 0.0];
1209 let b = vec![-1.0_f32, 0.0];
1210 assert!(
1211 cosine(&a, &b) < 0.0,
1212 "anti-correlated vectors → negative cosine"
1213 );
1214 }
1215
1216 #[test]
1217 fn hela_cosine_zero_vector_no_panic() {
1218 let a = vec![0.0_f32, 0.0];
1219 let b = vec![1.0_f32, 0.0];
1220 let result = cosine(&a, &b);
1222 assert!(
1223 result.is_finite(),
1224 "zero-norm vector must yield finite cosine"
1225 );
1226 }
1227
1228 #[test]
1229 fn hela_spread_params_default_depth_is_two() {
1230 let p = HelaSpreadParams::default();
1231 assert_eq!(p.spread_depth, 2);
1232 assert!(p.step_budget.is_some());
1233 assert!(p.edge_types.is_empty());
1234 assert_eq!(p.max_visited, 200);
1235 }
1236
1237 #[test]
1238 fn hela_synthetic_anchor_edge_id_is_zero() {
1239 let edge = Edge::synthetic_anchor(42);
1240 assert_eq!(
1241 edge.id, 0,
1242 "synthetic anchor must have id = 0 to be excluded from Hebbian"
1243 );
1244 assert_eq!(edge.source_entity_id, 42);
1245 assert_eq!(edge.target_entity_id, 42);
1246 }
1247
1248 #[test]
1249 fn hela_negative_cosine_clamped_to_zero_in_score() {
1250 let anti = vec![-1.0_f32, 0.0];
1252 let query = vec![1.0_f32, 0.0];
1253 let cosine_raw = cosine(&query, &anti);
1254 assert!(cosine_raw < 0.0);
1255 let clamped = cosine_raw.max(0.0);
1256 let fact_score = 0.9_f32 * clamped;
1257 assert!(
1258 fact_score < f32::EPSILON,
1259 "anti-correlated score must be 0.0"
1260 );
1261 }
1262
1263 #[test]
1264 fn hela_path_weight_multiplicative() {
1265 let w1 = 0.8_f32;
1267 let w2 = 0.5_f32;
1268 let expected = w1 * w2;
1269 assert!((expected - 0.4).abs() < 1e-6);
1270 }
1271
1272 #[test]
1273 fn hela_max_path_weight_on_multipath() {
1274 let pw_a = 0.9_f32; let pw_b = 0.3_f32; let kept = pw_a.max(pw_b);
1278 assert!(
1279 (kept - 0.9).abs() < 1e-6,
1280 "multi-path resolution must keep maximum path_weight"
1281 );
1282 }
1283
1284 #[test]
1285 fn hela_fact_score_formula() {
1286 let path_weight = 0.8_f32;
1287 let cosine_clamped = 0.75_f32;
1288 let expected = path_weight * cosine_clamped;
1289 assert!((expected - 0.6).abs() < 1e-5);
1291 }
1292}