1use std::collections::HashMap;
17use std::sync::OnceLock;
18use std::time::{Instant, SystemTime, UNIX_EPOCH};
19#[allow(unused_imports)]
20use zeph_db::sql;
21
22use crate::embedding_store::EmbeddingStore;
23use crate::error::MemoryError;
24use crate::graph::store::GraphStore;
25use crate::graph::types::{Edge, EdgeType, edge_type_weight, evolved_weight};
26
/// An entity reached by [`SpreadingActivation::spread`].
#[derive(Debug, Clone)]
pub struct ActivatedNode {
    /// Graph entity id.
    pub entity_id: i64,
    /// Final activation. `spread` only returns nodes whose activation is at or
    /// above `activation_threshold`, sorted by this value descending.
    pub activation: f32,
    /// `0` for a seed whose score was never raised by propagation; otherwise
    /// `hop + 1` of the hop that last raised this node's score.
    pub depth: u32,
}
37
/// A graph edge whose two endpoints were both activated during spreading.
#[derive(Debug, Clone)]
pub struct ActivatedFact {
    /// The activated edge.
    pub edge: Edge,
    /// The higher of the two endpoint activations at the time the edge was collected.
    pub activation_score: f32,
}
46
/// Tuning knobs for [`SpreadingActivation`].
#[derive(Debug, Clone)]
pub struct SpreadingActivationParams {
    /// Multiplier applied to every propagated contribution.
    pub decay_lambda: f32,
    /// Maximum number of propagation rounds.
    pub max_hops: u32,
    /// Minimum score for a node to be a seed, to stay active, and to be
    /// returned. Also the minimum size of a single propagated contribution.
    pub activation_threshold: f32,
    /// A neighbour whose current or pending score has reached this value
    /// receives no further input during a hop.
    pub inhibition_threshold: f32,
    /// After each hop only this many highest-scoring nodes are kept.
    pub max_activated_nodes: usize,
    /// Per-day rate for the recency weight `1 / (1 + age_days * rate)`;
    /// `<= 0.0` disables recency decay.
    pub temporal_decay_rate: f64,
    /// Not read in this file. Presumably used by whoever builds the seed map
    /// (TODO confirm).
    pub seed_structural_weight: f32,
    /// Not read in this file. Presumably used by whoever builds the seed map
    /// (TODO confirm).
    pub seed_community_cap: usize,
}
62
/// One ranked result of [`hela_spreading_recall`].
#[derive(Debug, Clone)]
pub struct HelaFact {
    /// Edge over which the entity was reached on its best path. In the
    /// isolated-anchor fallback this is `Edge::synthetic_anchor` (id `0`).
    pub edge: Edge,
    /// Ranking score: `path_weight * max(cosine, 0)`. In the isolated-anchor
    /// fallback it is the anchor's raw ANN score.
    pub score: f32,
    /// Hop count from the anchor at which the best path ended (`0` for the fallback).
    pub depth: u32,
    /// Product of `edge.weight` along the best path from the anchor (`1.0` for the fallback).
    pub path_weight: f32,
    /// Query/entity cosine clamped below at `0.0`. The fallback clamps the
    /// anchor score to `[0, 1]`. Always `Some` when produced in this file.
    pub cosine: Option<f32>,
}
88
/// Tuning knobs for [`hela_spreading_recall`].
#[derive(Debug, Clone)]
pub struct HelaSpreadParams {
    /// Number of BFS hops from the anchor; clamped to `1..=6` at use.
    pub spread_depth: u32,
    /// Edge-type filter passed straight to `GraphStore::edges_for_entities`.
    /// An empty list is what the default uses (presumably "no filter").
    pub edge_types: Vec<EdgeType>,
    /// Expansion stops once this many entities, including the anchor, have been recorded.
    pub max_visited: usize,
    /// Wall-clock budget checked after each I/O step: the anchor ANN search,
    /// each hop's edge fetch, and the vector batch fetch. Exceeding it makes
    /// the recall return an empty result. `None` disables the checks.
    pub step_budget: Option<std::time::Duration>,
}
111
112impl Default for HelaSpreadParams {
113 fn default() -> Self {
114 Self {
115 spread_depth: 2,
116 edge_types: Vec::new(),
117 max_visited: 200,
118 step_budget: Some(std::time::Duration::from_millis(8)),
119 }
120 }
121}
122
/// Name of the collection for which a vector-dimension mismatch was detected.
///
/// Set at most once per process (`OnceLock`). Once it holds the entity
/// collection name, `hela_spreading_recall` returns an empty result without
/// querying the vector store again.
static HELA_DIM_MISMATCH: OnceLock<String> = OnceLock::new();
132
/// Cosine similarity of `a` and `b`.
///
/// The denominator is floored at `f32::EPSILON`, so a zero-norm input yields a
/// finite result (`0.0`) instead of NaN.
///
/// If the lengths differ, `zip` truncates the dot product to the shorter slice
/// while each norm still covers its whole slice. The result is then not a
/// true cosine, so callers must check lengths first.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let l2_norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let numerator = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    numerator / (l2_norm(a) * l2_norm(b)).max(f32::EPSILON)
}
143
144#[tracing::instrument(
171 name = "memory.graph.hela_spread",
172 skip_all,
173 fields(
174 depth = params.spread_depth,
175 limit,
176 anchor_id = tracing::field::Empty,
177 visited = tracing::field::Empty,
178 scored = tracing::field::Empty,
179 fallback = tracing::field::Empty,
180 )
181)]
182#[allow(clippy::too_many_arguments, clippy::too_many_lines)] pub async fn hela_spreading_recall(
184 store: &GraphStore,
185 embeddings: &EmbeddingStore,
186 provider: &zeph_llm::any::AnyProvider,
187 query: &str,
188 limit: usize,
189 params: &HelaSpreadParams,
190 hebbian_enabled: bool,
191 hebbian_lr: f32,
192) -> Result<Vec<HelaFact>, MemoryError> {
193 use zeph_llm::LlmProvider as _;
194
195 const ENTITY_COLLECTION: &str = "zeph_graph_entities";
196
197 if limit == 0 {
198 return Ok(Vec::new());
199 }
200
201 if HELA_DIM_MISMATCH.get().map(String::as_str) == Some(ENTITY_COLLECTION) {
204 tracing::debug!("hela: dim mismatch previously detected for collection, skipping");
205 return Ok(Vec::new());
206 }
207
208 let q_vec = provider.embed(query).await?;
210
211 let t_anchor = Instant::now();
213 let anchor_results = match embeddings
214 .search_collection(ENTITY_COLLECTION, &q_vec, 1, None)
215 .await
216 {
217 Ok(r) => r,
218 Err(e) => {
219 let msg = e.to_string();
220 if msg.contains("wrong vector dimension")
221 || msg.contains("InvalidArgument")
222 || msg.contains("dimension")
223 {
224 let _ = HELA_DIM_MISMATCH.set(ENTITY_COLLECTION.to_owned());
225 tracing::warn!(
226 collection = ENTITY_COLLECTION,
227 error = %e,
228 "hela: vector dimension mismatch — HL-F5 disabled for this collection"
229 );
230 return Ok(Vec::new());
231 }
232 return Err(e);
233 }
234 };
235
236 if params.step_budget.is_some_and(|b| t_anchor.elapsed() > b) {
237 tracing::warn!(
238 elapsed_ms = t_anchor.elapsed().as_millis(),
239 "hela: anchor ANN over budget"
240 );
241 return Ok(Vec::new());
242 }
243
244 let Some(anchor_point) = anchor_results.first() else {
245 tracing::debug!("hela: no anchor found, returning empty");
246 return Ok(Vec::new());
247 };
248 let Some(anchor_entity_id) = anchor_point
249 .payload
250 .get("entity_id")
251 .and_then(serde_json::Value::as_i64)
252 else {
253 tracing::warn!("hela: anchor point missing entity_id payload");
254 return Ok(Vec::new());
255 };
256 let anchor_cosine = anchor_point.score;
257
258 tracing::Span::current().record("anchor_id", anchor_entity_id);
259 tracing::debug!(anchor_entity_id, anchor_cosine, "hela: anchor resolved");
260
261 let spread_depth = params.spread_depth.clamp(1, 6);
262
263 let mut visited: HashMap<i64, (u32, f32, Option<i64>)> = HashMap::new();
266 visited.insert(anchor_entity_id, (0, 1.0, None));
267
268 let mut edge_cache: HashMap<i64, Edge> = HashMap::new();
272 let mut frontier: Vec<i64> = vec![anchor_entity_id];
273
274 for hop in 0..spread_depth {
275 if frontier.is_empty() {
276 break;
277 }
278
279 tracing::debug!(hop, frontier_size = frontier.len(), "hela: starting hop");
280
281 let t_step = Instant::now();
282 let edges = store
283 .edges_for_entities(&frontier, ¶ms.edge_types)
284 .await?;
285 if params.step_budget.is_some_and(|b| t_step.elapsed() > b) {
286 tracing::warn!(
287 hop,
288 elapsed_ms = t_step.elapsed().as_millis(),
289 "hela: edge-fetch over budget"
290 );
291 return Ok(Vec::new());
292 }
293
294 let mut next_frontier: Vec<i64> = Vec::new();
295
296 for edge in &edges {
297 edge_cache.entry(edge.id).or_insert_with(|| edge.clone());
299
300 for &src_id in &frontier {
301 let neighbor = if edge.source_entity_id == src_id {
302 edge.target_entity_id
303 } else if edge.target_entity_id == src_id {
304 edge.source_entity_id
305 } else {
306 continue;
307 };
308
309 let parent_pw = visited.get(&src_id).map_or(1.0, |&(_, pw, _)| pw);
310 let new_pw = parent_pw * edge.weight;
311
312 let entry = visited
317 .entry(neighbor)
318 .or_insert((hop + 1, 0.0_f32, Some(edge.id)));
319 if new_pw > entry.1
321 || ((new_pw - entry.1).abs() < f32::EPSILON && hop + 1 < entry.0)
322 {
323 *entry = (hop + 1, new_pw, Some(edge.id));
324 if !next_frontier.contains(&neighbor) {
325 next_frontier.push(neighbor);
326 }
327 }
328
329 if visited.len() >= params.max_visited {
330 break;
331 }
332 }
333
334 if visited.len() >= params.max_visited {
335 break;
336 }
337 }
338
339 tracing::debug!(
340 hop,
341 edges_fetched = edges.len(),
342 visited = visited.len(),
343 next_frontier = next_frontier.len(),
344 "hela: hop complete"
345 );
346
347 frontier = next_frontier;
348 if visited.len() >= params.max_visited {
349 break;
350 }
351 }
352
353 if visited.len() == 1 {
356 tracing::Span::current().record("fallback", true);
357 tracing::debug!(
358 anchor_entity_id,
359 anchor_cosine,
360 "hela: anchor isolated, falling back to pure ANN"
361 );
362 let fact = HelaFact {
363 edge: Edge::synthetic_anchor(anchor_entity_id),
364 score: anchor_cosine,
365 depth: 0,
366 path_weight: 1.0,
367 cosine: Some(anchor_cosine.clamp(0.0, 1.0)),
368 };
369 return Ok(vec![fact]);
370 }
371
372 let entity_ids: Vec<i64> = visited.keys().copied().collect();
374 let point_id_map = store.qdrant_point_ids_for_entities(&entity_ids).await?;
375 let point_ids: Vec<String> = point_id_map.values().cloned().collect();
376
377 let t_vec = Instant::now();
378 let vec_map = embeddings
379 .get_vectors_from_collection(ENTITY_COLLECTION, &point_ids)
380 .await?;
381 if params.step_budget.is_some_and(|b| t_vec.elapsed() > b) {
382 tracing::warn!(
383 elapsed_ms = t_vec.elapsed().as_millis(),
384 "hela: vectors-batch over budget"
385 );
386 return Ok(Vec::new());
387 }
388
389 let mut facts: Vec<HelaFact> = Vec::with_capacity(visited.len().saturating_sub(1));
394 for (&entity_id, &(depth, path_weight, edge_id_opt)) in &visited {
395 if entity_id == anchor_entity_id {
396 continue;
397 }
398 let Some(edge_id) = edge_id_opt else {
399 continue;
400 };
401 let Some(point_id) = point_id_map.get(&entity_id) else {
402 continue;
403 };
404 let Some(node_vec) = vec_map.get(point_id) else {
405 continue;
406 };
407 if node_vec.len() != q_vec.len() {
408 continue;
410 }
411 let cosine_clamped = cosine(&q_vec, node_vec).max(0.0);
412 let fact_score = path_weight * cosine_clamped;
413 let Some(edge) = edge_cache.get(&edge_id).cloned() else {
414 continue;
415 };
416 facts.push(HelaFact {
417 edge,
418 score: fact_score,
419 depth,
420 path_weight,
421 cosine: Some(cosine_clamped),
422 });
423 }
424
425 facts.sort_by(|a, b| b.score.total_cmp(&a.score));
427 facts.truncate(limit);
428
429 if hebbian_enabled {
434 let edge_ids: Vec<i64> = facts
435 .iter()
436 .map(|f| f.edge.id)
437 .filter(|&id| id != 0) .collect();
439 if !edge_ids.is_empty()
440 && let Err(e) = store.apply_hebbian_increment(&edge_ids, hebbian_lr).await
441 {
442 tracing::warn!(error = %e, "hela: hebbian increment failed");
443 }
444 }
445
446 tracing::Span::current().record("visited", visited.len());
447 tracing::Span::current().record("scored", facts.len());
448
449 Ok(facts)
450}
451
/// Spreading-activation engine over the entity graph, configured by
/// [`SpreadingActivationParams`].
pub struct SpreadingActivation {
    params: SpreadingActivationParams,
}
458
459impl SpreadingActivation {
460 #[must_use]
465 pub fn new(params: SpreadingActivationParams) -> Self {
466 Self { params }
467 }
468
469 pub async fn spread(
485 &self,
486 store: &GraphStore,
487 seeds: HashMap<i64, f32>,
488 edge_types: &[EdgeType],
489 ) -> Result<(Vec<ActivatedNode>, Vec<ActivatedFact>), MemoryError> {
490 if seeds.is_empty() {
491 return Ok((Vec::new(), Vec::new()));
492 }
493
494 let now_secs: i64 = SystemTime::now()
497 .duration_since(UNIX_EPOCH)
498 .map_or(0, |d| d.as_secs().cast_signed());
499
500 let mut activation = self.initialize_seeds(&seeds);
501 let mut activated_facts: Vec<ActivatedFact> = Vec::new();
502
503 for hop in 0..self.params.max_hops {
504 let active_nodes: Vec<(i64, f32)> = activation
505 .iter()
506 .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
507 .map(|(&id, &(score, _))| (id, score))
508 .collect();
509
510 if active_nodes.is_empty() {
511 break;
512 }
513
514 let node_ids: Vec<i64> = active_nodes.iter().map(|(id, _)| *id).collect();
515 let edges = store.edges_for_entities(&node_ids, edge_types).await?;
516 let edge_count = edges.len();
517
518 let next_activation =
519 self.propagate_one_hop(hop, &active_nodes, &edges, &activation, now_secs);
520
521 let pruned_count = self.merge_and_prune(&mut activation, next_activation);
522
523 tracing::debug!(
524 hop,
525 active_nodes = active_nodes.len(),
526 edges_fetched = edge_count,
527 after_merge = activation.len(),
528 pruned = pruned_count,
529 "spreading activation: hop complete"
530 );
531
532 self.collect_activated_facts(&edges, &activation, &mut activated_facts);
533 }
534
535 let result = self.finalize(activation);
536
537 tracing::info!(
538 activated = result.len(),
539 facts = activated_facts.len(),
540 "spreading activation: complete"
541 );
542
543 Ok((result, activated_facts))
544 }
545
546 fn initialize_seeds(&self, seeds: &HashMap<i64, f32>) -> HashMap<i64, (f32, u32)> {
548 let mut activation: HashMap<i64, (f32, u32)> = HashMap::new();
549 let mut seed_count = 0usize;
550 for (entity_id, match_score) in seeds {
552 if *match_score < self.params.activation_threshold {
553 tracing::debug!(
554 entity_id,
555 score = match_score,
556 threshold = self.params.activation_threshold,
557 "spreading activation: seed below threshold, skipping"
558 );
559 continue;
560 }
561 activation.insert(*entity_id, (*match_score, 0));
562 seed_count += 1;
563 }
564 tracing::debug!(
565 seeds = seed_count,
566 "spreading activation: initialized seeds"
567 );
568 activation
569 }
570
571 fn propagate_one_hop(
575 &self,
576 hop: u32,
577 active_nodes: &[(i64, f32)],
578 edges: &[Edge],
579 activation: &HashMap<i64, (f32, u32)>,
580 now_secs: i64,
581 ) -> HashMap<i64, (f32, u32)> {
582 let mut next_activation: HashMap<i64, (f32, u32)> = HashMap::new();
583
584 for edge in edges {
585 for &(active_id, node_score) in active_nodes {
586 let neighbor = if edge.source_entity_id == active_id {
587 edge.target_entity_id
588 } else if edge.target_entity_id == active_id {
589 edge.source_entity_id
590 } else {
591 continue;
592 };
593
594 let current_score = activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
599 let next_score = next_activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
600 if current_score >= self.params.inhibition_threshold
601 || next_score >= self.params.inhibition_threshold
602 {
603 continue;
604 }
605
606 let recency = self.recency_weight(&edge.valid_from, now_secs);
607 let edge_weight = evolved_weight(edge.retrieval_count, edge.confidence);
608 let type_w = edge_type_weight(edge.edge_type);
609 let spread_value =
610 node_score * self.params.decay_lambda * edge_weight * recency * type_w;
611
612 if spread_value < self.params.activation_threshold {
613 continue;
614 }
615
616 let depth_at_max = hop + 1;
619 let entry = next_activation
620 .entry(neighbor)
621 .or_insert((0.0, depth_at_max));
622 let new_score = (entry.0 + spread_value).min(1.0);
623 if new_score > entry.0 {
624 entry.0 = new_score;
625 entry.1 = depth_at_max;
626 }
627 }
628 }
629
630 next_activation
631 }
632
633 fn merge_and_prune(
637 &self,
638 activation: &mut HashMap<i64, (f32, u32)>,
639 next_activation: HashMap<i64, (f32, u32)>,
640 ) -> usize {
641 for (node_id, (new_score, new_depth)) in next_activation {
642 let entry = activation.entry(node_id).or_insert((0.0, new_depth));
643 if new_score > entry.0 {
644 entry.0 = new_score;
645 entry.1 = new_depth;
646 }
647 }
648
649 if activation.len() > self.params.max_activated_nodes {
650 let before = activation.len();
651 let mut entries: Vec<(i64, (f32, u32))> = activation.drain().collect();
652 entries.sort_by(|(_, (a, _)), (_, (b, _))| b.total_cmp(a));
653 entries.truncate(self.params.max_activated_nodes);
654 *activation = entries.into_iter().collect();
655 before - self.params.max_activated_nodes
656 } else {
657 0
658 }
659 }
660
661 fn collect_activated_facts(
663 &self,
664 edges: &[Edge],
665 activation: &HashMap<i64, (f32, u32)>,
666 activated_facts: &mut Vec<ActivatedFact>,
667 ) {
668 for edge in edges {
669 let src_score = activation
670 .get(&edge.source_entity_id)
671 .map_or(0.0, |&(s, _)| s);
672 let tgt_score = activation
673 .get(&edge.target_entity_id)
674 .map_or(0.0, |&(s, _)| s);
675 if src_score >= self.params.activation_threshold
676 && tgt_score >= self.params.activation_threshold
677 {
678 let activation_score = src_score.max(tgt_score);
679 activated_facts.push(ActivatedFact {
680 edge: edge.clone(),
681 activation_score,
682 });
683 }
684 }
685 }
686
687 fn finalize(&self, activation: HashMap<i64, (f32, u32)>) -> Vec<ActivatedNode> {
689 let mut result: Vec<ActivatedNode> = activation
690 .into_iter()
691 .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
692 .map(|(entity_id, (activation, depth))| ActivatedNode {
693 entity_id,
694 activation,
695 depth,
696 })
697 .collect();
698 result.sort_by(|a, b| b.activation.total_cmp(&a.activation));
699 result
700 }
701
702 #[allow(clippy::cast_precision_loss)]
708 fn recency_weight(&self, valid_from: &str, now_secs: i64) -> f32 {
709 if self.params.temporal_decay_rate <= 0.0 {
710 return 1.0;
711 }
712 let Some(valid_from_secs) = parse_sqlite_datetime_to_unix(valid_from) else {
713 return 1.0;
714 };
715 let age_secs = (now_secs - valid_from_secs).max(0);
716 let age_days = age_secs as f64 / 86_400.0;
717 let weight = 1.0_f64 / (1.0 + age_days * self.params.temporal_decay_rate);
718 #[allow(clippy::cast_possible_truncation)]
720 let w = weight as f32;
721 w
722 }
723}
724
/// Parses the leading `YYYY-MM-DD HH:MM:SS` part of a SQLite datetime string
/// into Unix seconds.
///
/// Fields are read from fixed byte offsets: year `0..4`, month `5..7`, day
/// `8..10`, hour `11..13`, minute `14..16`, second `17..19`. The separator
/// characters are not checked, so `T`-separated ISO-8601 text parses too.
/// Anything after the seconds (fractional seconds, a `Z` suffix) is ignored.
/// No timezone adjustment is applied; SQLite's `datetime()` emits UTC.
///
/// Returns `None` in three cases:
/// - the string is shorter than 19 bytes;
/// - a field is not an integer;
/// - a field's byte range does not fall on UTF-8 character boundaries
///   (plain `[]` slicing would panic there).
///
/// Field values are not range-checked.
#[must_use]
fn parse_sqlite_datetime_to_unix(s: &str) -> Option<i64> {
    if s.len() < 19 {
        return None;
    }
    // `str::get` yields `None` instead of panicking on non-boundary ranges.
    let field = |start: usize, end: usize| -> Option<i64> { s.get(start..end)?.parse().ok() };
    let year = field(0, 4)?;
    let month = field(5, 7)?;
    let day = field(8, 10)?;
    let hour = field(11, 13)?;
    let min = field(14, 16)?;
    let sec = field(17, 19)?;

    // Hinnant's days-from-civil. Years are shifted to start in March, so the
    // leap day falls at the end of the (shifted) year.
    let (y, m) = if month <= 2 {
        (year - 1, month + 9)
    } else {
        (year, month - 3)
    };
    let era = y.div_euclid(400);
    let yoe = y - era * 400;
    let doy = (153 * m + 2) / 5 + day - 1;
    let doe = yoe * 365 + yoe / 4 - yoe / 100 + doy;
    // 719_468 = days from 0000-03-01 to 1970-01-01.
    let days = era * 146_097 + doe - 719_468;

    Some(days * 86_400 + hour * 3_600 + min * 60 + sec)
}
756
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphStore;
    use crate::graph::types::EntityType;
    use crate::store::SqliteStore;

    /// Builds a `GraphStore` on a fresh `":memory:"` SQLite database.
    async fn setup_store() -> GraphStore {
        let store = SqliteStore::new(":memory:").await.unwrap();
        GraphStore::new(store.pool().clone())
    }

    /// Baseline spreading-activation parameters for these tests.
    /// `temporal_decay_rate` is `0.0`, which disables recency decay.
    fn default_params() -> SpreadingActivationParams {
        SpreadingActivationParams {
            decay_lambda: 0.85,
            max_hops: 3,
            activation_threshold: 0.1,
            inhibition_threshold: 0.8,
            max_activated_nodes: 50,
            temporal_decay_rate: 0.0,
            seed_structural_weight: 0.4,
            seed_community_cap: 3,
        }
    }

    /// A seed id with no entity or edges in the store is still returned as
    /// the only activated node, and no facts are produced.
    #[tokio::test]
    async fn spread_empty_graph_no_edges_no_facts() {
        let store = setup_store().await;
        let sa = SpreadingActivation::new(default_params());
        let seeds = HashMap::from([(1_i64, 1.0_f32)]);
        let (nodes, facts) = sa.spread(&store, seeds, &[]).await.unwrap();
        assert_eq!(nodes.len(), 1, "seed must be in activated nodes");
        assert_eq!(nodes[0].entity_id, 1);
        assert!((nodes[0].activation - 1.0).abs() < 1e-6);
        assert!(
            facts.is_empty(),
            "expected no activated facts on empty graph"
        );
    }

    /// An empty seed map short-circuits to empty results.
    #[tokio::test]
    async fn spread_empty_seeds_returns_empty() {
        let store = setup_store().await;
        let sa = SpreadingActivation::new(default_params());
        let (nodes, facts) = sa.spread(&store, HashMap::new(), &[]).await.unwrap();
        assert!(nodes.is_empty());
        assert!(facts.is_empty());
    }

    /// A stored entity with no edges comes back unchanged: depth 0, activation 1.0.
    #[tokio::test]
    async fn spread_single_seed_no_edges_returns_seed() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();

        let sa = SpreadingActivation::new(default_params());
        let seeds = HashMap::from([(alice, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].entity_id, alice);
        assert_eq!(nodes[0].depth, 0);
        assert!((nodes[0].activation - 1.0).abs() < 1e-6);
    }

    /// Chain A-B-C seeded at A: all three activate, and activation strictly
    /// decreases with distance from the seed.
    #[tokio::test]
    async fn spread_linear_chain_all_activated_with_decay() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 3;
        cfg.decay_lambda = 0.9;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&a), "A (seed) must be activated");
        assert!(ids.contains(&b), "B (hop 1) must be activated");
        assert!(ids.contains(&c), "C (hop 2) must be activated");

        let score_a = nodes.iter().find(|n| n.entity_id == a).unwrap().activation;
        let score_b = nodes.iter().find(|n| n.entity_id == b).unwrap().activation;
        let score_c = nodes.iter().find(|n| n.entity_id == c).unwrap().activation;
        assert!(
            score_a > score_b,
            "seed A should have higher activation than hop-1 B"
        );
        assert!(
            score_b > score_c,
            "hop-1 B should have higher activation than hop-2 C"
        );
    }

    /// Same chain with `max_hops = 1`: B is reached, C is not.
    #[tokio::test]
    async fn spread_linear_chain_max_hops_limits_reach() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 1;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&a), "A must be activated (seed)");
        assert!(ids.contains(&b), "B must be activated (hop 1)");
        assert!(!ids.contains(&c), "C must NOT be activated with max_hops=1");
    }

    /// Diamond A-B, A-C, B-D, C-D seeded at A: D is activated and reported at
    /// depth 2. Inhibition is raised to 0.95 from the default 0.8, presumably
    /// so D can take input from both B and C.
    #[tokio::test]
    async fn spread_diamond_graph_convergence() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap();
        let d = store
            .upsert_entity("D", "D", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "rel", "A-B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(a, c, "rel", "A-C", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, d, "rel", "B-D", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(c, d, "rel", "C-D", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 3;
        cfg.decay_lambda = 0.9;
        cfg.inhibition_threshold = 0.95;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&d), "D must be activated via diamond paths");

        let node_d = nodes.iter().find(|n| n.entity_id == d).unwrap();
        assert_eq!(node_d.depth, 2, "D should be at depth 2");
    }

    /// Hub with five leaves linked in both directions: the hub's activation
    /// must stay at or below 1.0.
    #[tokio::test]
    async fn spread_inhibition_prevents_runaway() {
        let store = setup_store().await;
        let hub = store
            .upsert_entity("Hub", "Hub", EntityType::Concept, None)
            .await
            .unwrap();

        for i in 0..5 {
            let leaf = store
                .upsert_entity(
                    &format!("Leaf{i}"),
                    &format!("Leaf{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap();
            store
                .insert_edge(hub, leaf, "has", &format!("Hub has Leaf{i}"), 1.0, None)
                .await
                .unwrap();
            // Reverse edge so activation can flow back into the hub.
            store
                .insert_edge(
                    leaf,
                    hub,
                    "part_of",
                    &format!("Leaf{i} part_of Hub"),
                    1.0,
                    None,
                )
                .await
                .unwrap();
        }

        let mut cfg = default_params();
        cfg.inhibition_threshold = 0.8;
        cfg.max_hops = 3;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(hub, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let hub_node = nodes.iter().find(|n| n.entity_id == hub);
        assert!(hub_node.is_some(), "hub must be in results");
        assert!(
            hub_node.unwrap().activation <= 1.0,
            "activation must not exceed 1.0"
        );
    }

    /// Root with 20 leaves and `max_activated_nodes = 5`: at most five nodes
    /// are returned.
    #[tokio::test]
    async fn spread_max_activated_nodes_cap_enforced() {
        let store = setup_store().await;
        let root = store
            .upsert_entity("Root", "Root", EntityType::Person, None)
            .await
            .unwrap();

        for i in 0..20 {
            let leaf = store
                .upsert_entity(
                    &format!("Node{i}"),
                    &format!("Node{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap();
            store
                .insert_edge(root, leaf, "has", &format!("Root has Node{i}"), 0.9, None)
                .await
                .unwrap();
        }

        let max_nodes = 5;
        let cfg = SpreadingActivationParams {
            max_activated_nodes: max_nodes,
            max_hops: 2,
            ..default_params()
        };
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(root, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        assert!(
            nodes.len() <= max_nodes,
            "activation must be capped at {max_nodes} nodes, got {}",
            nodes.len()
        );
    }

    /// With `temporal_decay_rate = 0.5`, an edge whose `valid_from` is the
    /// Unix epoch must contribute less than an edge created by `insert_edge`.
    /// The latter is assumed to be timestamped near "now" (TODO confirm).
    #[tokio::test]
    async fn spread_temporal_decay_recency_effect() {
        let store = setup_store().await;
        let src = store
            .upsert_entity("Src", "Src", EntityType::Person, None)
            .await
            .unwrap();
        let recent = store
            .upsert_entity("Recent", "Recent", EntityType::Tool, None)
            .await
            .unwrap();
        let old = store
            .upsert_entity("Old", "Old", EntityType::Tool, None)
            .await
            .unwrap();

        store
            .insert_edge(src, recent, "uses", "Src uses Recent", 1.0, None)
            .await
            .unwrap();

        // Raw insert so `valid_from` can be pinned to the epoch.
        zeph_db::query(
            sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from)
                 VALUES (?1, ?2, 'uses', 'Src uses Old', 1.0, '1970-01-01 00:00:00')"),
        )
        .bind(src)
        .bind(old)
        .execute(store.pool())
        .await
        .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 2;
        let sa = SpreadingActivation::new(SpreadingActivationParams {
            temporal_decay_rate: 0.5,
            ..cfg
        });
        let seeds = HashMap::from([(src, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let score_recent = nodes
            .iter()
            .find(|n| n.entity_id == recent)
            .map_or(0.0, |n| n.activation);
        let score_old = nodes
            .iter()
            .find(|n| n.entity_id == old)
            .map_or(0.0, |n| n.activation);

        assert!(
            score_recent > score_old,
            "recent edge ({score_recent}) must produce higher activation than old edge ({score_old})"
        );
    }

    /// Filtering to `EdgeType::Semantic` follows the "uses" edge but not the
    /// edge stored with `edge_type = 'causal'`. Assumes `insert_edge` stores
    /// "uses" as a semantic edge.
    #[tokio::test]
    async fn spread_edge_type_filter_excludes_other_types() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b_semantic = store
            .upsert_entity("BSemantic", "BSemantic", EntityType::Tool, None)
            .await
            .unwrap();
        let c_causal = store
            .upsert_entity("CCausal", "CCausal", EntityType::Concept, None)
            .await
            .unwrap();

        store
            .insert_edge(a, b_semantic, "uses", "A uses BSemantic", 1.0, None)
            .await
            .unwrap();

        zeph_db::query(
            sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from, edge_type)
                 VALUES (?1, ?2, 'caused', 'A caused CCausal', 1.0, datetime('now'), 'causal')"),
        )
        .bind(a)
        .bind(c_causal)
        .execute(store.pool())
        .await
        .unwrap();

        let cfg = default_params();
        let sa = SpreadingActivation::new(cfg);

        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa
            .spread(&store, seeds, &[EdgeType::Semantic])
            .await
            .unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(
            ids.contains(&b_semantic),
            "BSemantic must be activated via semantic edge"
        );
        assert!(
            !ids.contains(&c_causal),
            "CCausal must NOT be activated when filtering to semantic only"
        );
    }

    /// 100 seeds and no edges: only checks that `spread` does not error.
    #[tokio::test]
    async fn spread_large_seed_list() {
        let store = setup_store().await;
        let mut seeds = HashMap::new();

        for i in 0..100i64 {
            let id = store
                .upsert_entity(
                    &format!("Entity{i}"),
                    &format!("entity{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap();
            seeds.insert(id, 1.0_f32);
        }

        let cfg = default_params();
        let sa = SpreadingActivation::new(cfg);
        let result = sa.spread(&store, seeds, &[]).await;
        assert!(
            result.is_ok(),
            "large seed list must not error: {:?}",
            result.err()
        );
    }

    /// The cosine of a vector with itself is 1.0.
    #[test]
    fn hela_cosine_identical_vectors() {
        let v = vec![1.0_f32, 0.0, 0.0];
        assert!(
            (cosine(&v, &v) - 1.0).abs() < 1e-6,
            "identical vectors → cosine 1.0"
        );
    }

    /// Orthogonal vectors have cosine 0.0.
    #[test]
    fn hela_cosine_orthogonal_vectors() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![0.0_f32, 1.0];
        assert!(
            cosine(&a, &b).abs() < 1e-6,
            "orthogonal vectors → cosine 0.0"
        );
    }

    /// Opposite vectors have a negative cosine (before any clamping).
    #[test]
    fn hela_cosine_anti_correlated() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![-1.0_f32, 0.0];
        assert!(
            cosine(&a, &b) < 0.0,
            "anti-correlated vectors → negative cosine"
        );
    }

    /// A zero-norm input must not produce NaN or infinity.
    #[test]
    fn hela_cosine_zero_vector_no_panic() {
        let a = vec![0.0_f32, 0.0];
        let b = vec![1.0_f32, 0.0];
        let result = cosine(&a, &b);
        assert!(
            result.is_finite(),
            "zero-norm vector must yield finite cosine"
        );
    }

    /// Pins the `HelaSpreadParams::default()` values.
    #[test]
    fn hela_spread_params_default_depth_is_two() {
        let p = HelaSpreadParams::default();
        assert_eq!(p.spread_depth, 2);
        assert!(p.step_budget.is_some());
        assert!(p.edge_types.is_empty());
        assert_eq!(p.max_visited, 200);
    }

    /// The synthetic anchor edge uses id 0 (the Hebbian filter skips id 0)
    /// and points from the anchor to itself.
    #[test]
    fn hela_synthetic_anchor_edge_id_is_zero() {
        let edge = Edge::synthetic_anchor(42);
        assert_eq!(
            edge.id, 0,
            "synthetic anchor must have id = 0 to be excluded from Hebbian"
        );
        assert_eq!(edge.source_entity_id, 42);
        assert_eq!(edge.target_entity_id, 42);
    }

    /// NOTE(review): this test and the three below recompute the scoring
    /// arithmetic locally and never call `hela_spreading_recall`. They pin the
    /// intended formula, not the implementation.
    #[test]
    fn hela_negative_cosine_clamped_to_zero_in_score() {
        let anti = vec![-1.0_f32, 0.0];
        let query = vec![1.0_f32, 0.0];
        let cosine_raw = cosine(&query, &anti);
        assert!(cosine_raw < 0.0);
        let clamped = cosine_raw.max(0.0);
        let fact_score = 0.9_f32 * clamped;
        assert!(
            fact_score < f32::EPSILON,
            "anti-correlated score must be 0.0"
        );
    }

    /// Path weight is the product of edge weights along the path.
    #[test]
    fn hela_path_weight_multiplicative() {
        let w1 = 0.8_f32;
        let w2 = 0.5_f32;
        let expected = w1 * w2;
        assert!((expected - 0.4).abs() < 1e-6);
    }

    /// When two paths reach the same entity, the larger path weight wins.
    #[test]
    fn hela_max_path_weight_on_multipath() {
        let pw_a = 0.9_f32;
        let pw_b = 0.3_f32;
        let kept = pw_a.max(pw_b);
        assert!(
            (kept - 0.9).abs() < 1e-6,
            "multi-path resolution must keep maximum path_weight"
        );
    }

    /// Fact score = path_weight × clamped cosine.
    #[test]
    fn hela_fact_score_formula() {
        let path_weight = 0.8_f32;
        let cosine_clamped = 0.75_f32;
        let expected = path_weight * cosine_clamped;
        assert!((expected - 0.6).abs() < 1e-5);
    }
}