1use std::collections::HashMap;
17use std::sync::OnceLock;
18use std::time::{Instant, SystemTime, UNIX_EPOCH};
19#[allow(unused_imports)]
20use zeph_db::sql;
21
22use crate::embedding_store::EmbeddingStore;
23use crate::error::MemoryError;
24use crate::graph::store::GraphStore;
25use crate::graph::types::{Edge, EdgeType, edge_type_weight, evolved_weight};
26
/// An entity reached by [`SpreadingActivation::spread`].
#[derive(Debug, Clone)]
pub struct ActivatedNode {
    /// Graph entity id.
    pub entity_id: i64,
    /// Highest activation observed for this node: either its seed match score
    /// or a single hop's summed incoming contributions (capped at `1.0`),
    /// whichever was larger.
    pub activation: f32,
    /// Depth at which the current `activation` was set: `0` when it is the
    /// seed score, otherwise `hop + 1` of the hop that produced it.
    pub depth: u32,
}
37
/// An edge whose source and target entities are both activated at or above
/// the activation threshold during [`SpreadingActivation::spread`].
#[derive(Debug, Clone)]
pub struct ActivatedFact {
    /// The graph edge (fact) connecting two activated entities.
    pub edge: Edge,
    /// The larger of the two endpoint activations at the hop the fact was recorded.
    pub activation_score: f32,
}
46
47pub use zeph_common::memory::SpreadingActivationParams;
48
/// A fact produced by [`hela_spreading_recall`].
#[derive(Debug, Clone)]
pub struct HelaFact {
    /// Edge through which the entity was reached on its best path from the
    /// anchor. In the isolated-anchor fallback this is
    /// `Edge::synthetic_anchor(anchor_id)`, whose id is `0`.
    pub edge: Edge,
    /// Ranking score: `path_weight * cosine`, with negative cosine clamped to
    /// `0`. In the fallback it is the anchor's raw ANN score.
    pub score: f32,
    /// Hops from the anchor entity (`0` only for the fallback fact).
    pub depth: u32,
    /// Product of edge weights along the best path from the anchor.
    pub path_weight: f32,
    /// Query/entity cosine similarity, clamped to be non-negative (clamped to
    /// `[0, 1]` in the fallback).
    pub cosine: Option<f32>,
}
74
/// Tuning knobs for [`hela_spreading_recall`].
#[derive(Debug, Clone)]
pub struct HelaSpreadParams {
    /// Maximum number of expansion hops from the anchor; clamped to `1..=6`
    /// at call time.
    pub spread_depth: u32,
    /// Edge types to traverse, passed through to `GraphStore::edges_for_entities`.
    /// NOTE(review): an empty list appears to mean "all types" (the tests call
    /// `spread(.., &[])` and still traverse edges) — confirm in the store.
    pub edge_types: Vec<EdgeType>,
    /// Expansion stops once this many entities (anchor included) are visited.
    pub max_visited: usize,
    /// Wall-clock budget for each I/O step (anchor ANN search, each hop's edge
    /// fetch, the vector batch). Exceeding it makes the call return no facts.
    /// `None` disables the check.
    pub step_budget: Option<std::time::Duration>,
}
97
98impl Default for HelaSpreadParams {
99 fn default() -> Self {
100 Self {
101 spread_depth: 2,
102 edge_types: Vec::new(),
103 max_visited: 200,
104 step_budget: Some(std::time::Duration::from_millis(8)),
105 }
106 }
107}
108
/// Process-wide latch holding the name of the collection for which a vector
/// dimension mismatch was detected. It can be set only once. When it matches
/// the entity collection, [`hela_spreading_recall`] returns early, before
/// embedding the query.
static HELA_DIM_MISMATCH: OnceLock<String> = OnceLock::new();
118
/// Cosine similarity of two embedding vectors.
///
/// Elements are paired with `zip`, so with unequal lengths only the common
/// prefix contributes; callers are expected to compare lengths first. The
/// denominator is floored at `f32::EPSILON`, so a zero-norm input yields zero
/// instead of NaN. The result is not clamped and may be negative.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let l2 = |v: &[f32]| -> f32 { v.iter().map(|x| x * x).sum::<f32>().sqrt() };
    let numerator: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    numerator / (l2(a) * l2(b)).max(f32::EPSILON)
}
129
130#[tracing::instrument(
157 name = "memory.graph.hela_spread",
158 skip_all,
159 fields(
160 depth = params.spread_depth,
161 limit,
162 anchor_id = tracing::field::Empty,
163 visited = tracing::field::Empty,
164 scored = tracing::field::Empty,
165 fallback = tracing::field::Empty,
166 )
167)]
168#[allow(clippy::too_many_arguments, clippy::too_many_lines)] pub async fn hela_spreading_recall(
170 store: &GraphStore,
171 embeddings: &EmbeddingStore,
172 provider: &zeph_llm::any::AnyProvider,
173 query: &str,
174 limit: usize,
175 params: &HelaSpreadParams,
176 hebbian_enabled: bool,
177 hebbian_lr: f32,
178) -> Result<Vec<HelaFact>, MemoryError> {
179 use zeph_llm::LlmProvider as _;
180
181 const ENTITY_COLLECTION: &str = "zeph_graph_entities";
182
183 if limit == 0 {
184 return Ok(Vec::new());
185 }
186
187 if HELA_DIM_MISMATCH.get().map(String::as_str) == Some(ENTITY_COLLECTION) {
190 tracing::debug!("hela: dim mismatch previously detected for collection, skipping");
191 return Ok(Vec::new());
192 }
193
194 let q_vec = provider.embed(query).await?;
196
197 let t_anchor = Instant::now();
199 let anchor_results = match embeddings
200 .search_collection(ENTITY_COLLECTION, &q_vec, 1, None)
201 .await
202 {
203 Ok(r) => r,
204 Err(e) => {
205 let msg = e.to_string();
206 if msg.contains("wrong vector dimension")
207 || msg.contains("InvalidArgument")
208 || msg.contains("dimension")
209 {
210 let _ = HELA_DIM_MISMATCH.set(ENTITY_COLLECTION.to_owned());
211 tracing::warn!(
212 collection = ENTITY_COLLECTION,
213 error = %e,
214 "hela: vector dimension mismatch — HL-F5 disabled for this collection"
215 );
216 return Ok(Vec::new());
217 }
218 return Err(e);
219 }
220 };
221
222 if params.step_budget.is_some_and(|b| t_anchor.elapsed() > b) {
223 tracing::warn!(
224 elapsed_ms = t_anchor.elapsed().as_millis(),
225 "hela: anchor ANN over budget"
226 );
227 return Ok(Vec::new());
228 }
229
230 let Some(anchor_point) = anchor_results.first() else {
231 tracing::debug!("hela: no anchor found, returning empty");
232 return Ok(Vec::new());
233 };
234 let Some(anchor_entity_id) = anchor_point
235 .payload
236 .get("entity_id")
237 .and_then(serde_json::Value::as_i64)
238 else {
239 tracing::warn!("hela: anchor point missing entity_id payload");
240 return Ok(Vec::new());
241 };
242 let anchor_cosine = anchor_point.score;
243
244 tracing::Span::current().record("anchor_id", anchor_entity_id);
245 tracing::debug!(anchor_entity_id, anchor_cosine, "hela: anchor resolved");
246
247 let spread_depth = params.spread_depth.clamp(1, 6);
248
249 let mut visited: HashMap<i64, (u32, f32, Option<i64>)> = HashMap::new();
252 visited.insert(anchor_entity_id, (0, 1.0, None));
253
254 let mut edge_cache: HashMap<i64, Edge> = HashMap::new();
258 let mut frontier: Vec<i64> = vec![anchor_entity_id];
259
260 for hop in 0..spread_depth {
261 if frontier.is_empty() {
262 break;
263 }
264
265 tracing::debug!(hop, frontier_size = frontier.len(), "hela: starting hop");
266
267 let t_step = Instant::now();
268 let edges = store
269 .edges_for_entities(&frontier, ¶ms.edge_types)
270 .await?;
271 if params.step_budget.is_some_and(|b| t_step.elapsed() > b) {
272 tracing::warn!(
273 hop,
274 elapsed_ms = t_step.elapsed().as_millis(),
275 "hela: edge-fetch over budget"
276 );
277 return Ok(Vec::new());
278 }
279
280 let mut next_frontier: Vec<i64> = Vec::new();
281
282 for edge in &edges {
283 edge_cache.entry(edge.id).or_insert_with(|| edge.clone());
285
286 for &src_id in &frontier {
287 let neighbor = if edge.source_entity_id == src_id {
288 edge.target_entity_id
289 } else if edge.target_entity_id == src_id {
290 edge.source_entity_id
291 } else {
292 continue;
293 };
294
295 let parent_pw = visited.get(&src_id).map_or(1.0, |&(_, pw, _)| pw);
296 let new_pw = parent_pw * edge.weight;
297
298 let entry = visited
303 .entry(neighbor)
304 .or_insert((hop + 1, 0.0_f32, Some(edge.id)));
305 if new_pw > entry.1
307 || ((new_pw - entry.1).abs() < f32::EPSILON && hop + 1 < entry.0)
308 {
309 *entry = (hop + 1, new_pw, Some(edge.id));
310 if !next_frontier.contains(&neighbor) {
311 next_frontier.push(neighbor);
312 }
313 }
314
315 if visited.len() >= params.max_visited {
316 break;
317 }
318 }
319
320 if visited.len() >= params.max_visited {
321 break;
322 }
323 }
324
325 tracing::debug!(
326 hop,
327 edges_fetched = edges.len(),
328 visited = visited.len(),
329 next_frontier = next_frontier.len(),
330 "hela: hop complete"
331 );
332
333 frontier = next_frontier;
334 if visited.len() >= params.max_visited {
335 break;
336 }
337 }
338
339 if visited.len() == 1 {
342 tracing::Span::current().record("fallback", true);
343 tracing::debug!(
344 anchor_entity_id,
345 anchor_cosine,
346 "hela: anchor isolated, falling back to pure ANN"
347 );
348 let fact = HelaFact {
349 edge: Edge::synthetic_anchor(anchor_entity_id),
350 score: anchor_cosine,
351 depth: 0,
352 path_weight: 1.0,
353 cosine: Some(anchor_cosine.clamp(0.0, 1.0)),
354 };
355 return Ok(vec![fact]);
356 }
357
358 let entity_ids: Vec<i64> = visited.keys().copied().collect();
360 let point_id_map = store.qdrant_point_ids_for_entities(&entity_ids).await?;
361 let point_ids: Vec<String> = point_id_map.values().cloned().collect();
362
363 let t_vec = Instant::now();
364 let vec_map = embeddings
365 .get_vectors_from_collection(ENTITY_COLLECTION, &point_ids)
366 .await?;
367 if params.step_budget.is_some_and(|b| t_vec.elapsed() > b) {
368 tracing::warn!(
369 elapsed_ms = t_vec.elapsed().as_millis(),
370 "hela: vectors-batch over budget"
371 );
372 return Ok(Vec::new());
373 }
374
375 let mut facts: Vec<HelaFact> = Vec::with_capacity(visited.len().saturating_sub(1));
380 for (&entity_id, &(depth, path_weight, edge_id_opt)) in &visited {
381 if entity_id == anchor_entity_id {
382 continue;
383 }
384 let Some(edge_id) = edge_id_opt else {
385 continue;
386 };
387 let Some(point_id) = point_id_map.get(&entity_id) else {
388 continue;
389 };
390 let Some(node_vec) = vec_map.get(point_id) else {
391 continue;
392 };
393 if node_vec.len() != q_vec.len() {
394 continue;
396 }
397 let cosine_clamped = cosine(&q_vec, node_vec).max(0.0);
398 let fact_score = path_weight * cosine_clamped;
399 let Some(edge) = edge_cache.get(&edge_id).cloned() else {
400 continue;
401 };
402 facts.push(HelaFact {
403 edge,
404 score: fact_score,
405 depth,
406 path_weight,
407 cosine: Some(cosine_clamped),
408 });
409 }
410
411 facts.sort_by(|a, b| b.score.total_cmp(&a.score));
413 facts.truncate(limit);
414
415 if hebbian_enabled {
420 let edge_ids: Vec<i64> = facts
421 .iter()
422 .map(|f| f.edge.id)
423 .filter(|&id| id != 0) .collect();
425 if !edge_ids.is_empty()
426 && let Err(e) = store.apply_hebbian_increment(&edge_ids, hebbian_lr).await
427 {
428 tracing::warn!(error = %e, "hela: hebbian increment failed");
429 }
430 }
431
432 tracing::Span::current().record("visited", visited.len());
433 tracing::Span::current().record("scored", facts.len());
434
435 Ok(facts)
436}
437
/// Multi-seed spreading-activation engine over the entity graph.
pub struct SpreadingActivation {
    /// Decay, thresholds, hop limit, node cap and temporal-decay settings used
    /// by [`SpreadingActivation::spread`].
    params: SpreadingActivationParams,
}
444
445impl SpreadingActivation {
446 #[must_use]
451 pub fn new(params: SpreadingActivationParams) -> Self {
452 Self { params }
453 }
454
455 pub async fn spread(
471 &self,
472 store: &GraphStore,
473 seeds: HashMap<i64, f32>,
474 edge_types: &[EdgeType],
475 ) -> Result<(Vec<ActivatedNode>, Vec<ActivatedFact>), MemoryError> {
476 if seeds.is_empty() {
477 return Ok((Vec::new(), Vec::new()));
478 }
479
480 let now_secs: i64 = SystemTime::now()
483 .duration_since(UNIX_EPOCH)
484 .map_or(0, |d| d.as_secs().cast_signed());
485
486 let mut activation = self.initialize_seeds(&seeds);
487 let mut activated_facts: Vec<ActivatedFact> = Vec::new();
488
489 for hop in 0..self.params.max_hops {
490 let active_nodes: Vec<(i64, f32)> = activation
491 .iter()
492 .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
493 .map(|(&id, &(score, _))| (id, score))
494 .collect();
495
496 if active_nodes.is_empty() {
497 break;
498 }
499
500 let node_ids: Vec<i64> = active_nodes.iter().map(|(id, _)| *id).collect();
501 let edges = store.edges_for_entities(&node_ids, edge_types).await?;
502 let edge_count = edges.len();
503
504 let next_activation =
505 self.propagate_one_hop(hop, &active_nodes, &edges, &activation, now_secs);
506
507 let pruned_count = self.merge_and_prune(&mut activation, next_activation);
508
509 tracing::debug!(
510 hop,
511 active_nodes = active_nodes.len(),
512 edges_fetched = edge_count,
513 after_merge = activation.len(),
514 pruned = pruned_count,
515 "spreading activation: hop complete"
516 );
517
518 self.collect_activated_facts(&edges, &activation, &mut activated_facts);
519 }
520
521 let result = self.finalize(activation);
522
523 tracing::info!(
524 activated = result.len(),
525 facts = activated_facts.len(),
526 "spreading activation: complete"
527 );
528
529 Ok((result, activated_facts))
530 }
531
532 fn initialize_seeds(&self, seeds: &HashMap<i64, f32>) -> HashMap<i64, (f32, u32)> {
534 let mut activation: HashMap<i64, (f32, u32)> = HashMap::new();
535 let mut seed_count = 0usize;
536 for (entity_id, match_score) in seeds {
538 if *match_score < self.params.activation_threshold {
539 tracing::debug!(
540 entity_id,
541 score = match_score,
542 threshold = self.params.activation_threshold,
543 "spreading activation: seed below threshold, skipping"
544 );
545 continue;
546 }
547 activation.insert(*entity_id, (*match_score, 0));
548 seed_count += 1;
549 }
550 tracing::debug!(
551 seeds = seed_count,
552 "spreading activation: initialized seeds"
553 );
554 activation
555 }
556
557 fn propagate_one_hop(
561 &self,
562 hop: u32,
563 active_nodes: &[(i64, f32)],
564 edges: &[Edge],
565 activation: &HashMap<i64, (f32, u32)>,
566 now_secs: i64,
567 ) -> HashMap<i64, (f32, u32)> {
568 let mut next_activation: HashMap<i64, (f32, u32)> = HashMap::new();
569
570 for edge in edges {
571 for &(active_id, node_score) in active_nodes {
572 let neighbor = if edge.source_entity_id == active_id {
573 edge.target_entity_id
574 } else if edge.target_entity_id == active_id {
575 edge.source_entity_id
576 } else {
577 continue;
578 };
579
580 let current_score = activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
585 let next_score = next_activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
586 if current_score >= self.params.inhibition_threshold
587 || next_score >= self.params.inhibition_threshold
588 {
589 continue;
590 }
591
592 let recency = self.recency_weight(&edge.valid_from, now_secs);
593 let edge_weight = evolved_weight(edge.retrieval_count, edge.confidence);
594 let type_w = edge_type_weight(edge.edge_type);
595 let spread_value =
596 node_score * self.params.decay_lambda * edge_weight * recency * type_w;
597
598 if spread_value < self.params.activation_threshold {
599 continue;
600 }
601
602 let depth_at_max = hop + 1;
605 let entry = next_activation
606 .entry(neighbor)
607 .or_insert((0.0, depth_at_max));
608 let new_score = (entry.0 + spread_value).min(1.0);
609 if new_score > entry.0 {
610 entry.0 = new_score;
611 entry.1 = depth_at_max;
612 }
613 }
614 }
615
616 next_activation
617 }
618
619 fn merge_and_prune(
623 &self,
624 activation: &mut HashMap<i64, (f32, u32)>,
625 next_activation: HashMap<i64, (f32, u32)>,
626 ) -> usize {
627 for (node_id, (new_score, new_depth)) in next_activation {
628 let entry = activation.entry(node_id).or_insert((0.0, new_depth));
629 if new_score > entry.0 {
630 entry.0 = new_score;
631 entry.1 = new_depth;
632 }
633 }
634
635 if activation.len() > self.params.max_activated_nodes {
636 let before = activation.len();
637 let mut entries: Vec<(i64, (f32, u32))> = activation.drain().collect();
638 entries.sort_by(|(_, (a, _)), (_, (b, _))| b.total_cmp(a));
639 entries.truncate(self.params.max_activated_nodes);
640 *activation = entries.into_iter().collect();
641 before - self.params.max_activated_nodes
642 } else {
643 0
644 }
645 }
646
647 fn collect_activated_facts(
649 &self,
650 edges: &[Edge],
651 activation: &HashMap<i64, (f32, u32)>,
652 activated_facts: &mut Vec<ActivatedFact>,
653 ) {
654 for edge in edges {
655 let src_score = activation
656 .get(&edge.source_entity_id)
657 .map_or(0.0, |&(s, _)| s);
658 let tgt_score = activation
659 .get(&edge.target_entity_id)
660 .map_or(0.0, |&(s, _)| s);
661 if src_score >= self.params.activation_threshold
662 && tgt_score >= self.params.activation_threshold
663 {
664 let activation_score = src_score.max(tgt_score);
665 activated_facts.push(ActivatedFact {
666 edge: edge.clone(),
667 activation_score,
668 });
669 }
670 }
671 }
672
673 fn finalize(&self, activation: HashMap<i64, (f32, u32)>) -> Vec<ActivatedNode> {
675 let mut result: Vec<ActivatedNode> = activation
676 .into_iter()
677 .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
678 .map(|(entity_id, (activation, depth))| ActivatedNode {
679 entity_id,
680 activation,
681 depth,
682 })
683 .collect();
684 result.sort_by(|a, b| b.activation.total_cmp(&a.activation));
685 result
686 }
687
688 #[allow(clippy::cast_precision_loss)]
694 fn recency_weight(&self, valid_from: &str, now_secs: i64) -> f32 {
695 if self.params.temporal_decay_rate <= 0.0 {
696 return 1.0;
697 }
698 let Some(valid_from_secs) = parse_sqlite_datetime_to_unix(valid_from) else {
699 return 1.0;
700 };
701 let age_secs = (now_secs - valid_from_secs).max(0);
702 let age_days = age_secs as f64 / 86_400.0;
703 let weight = 1.0_f64 / (1.0 + age_days * self.params.temporal_decay_rate);
704 #[allow(clippy::cast_possible_truncation)]
706 let w = weight as f32;
707 w
708 }
709}
710
/// Parses a SQLite-style UTC timestamp `YYYY-MM-DD HH:MM:SS` into Unix seconds.
///
/// Only the first 19 bytes are read, so a `T` separator or a trailing
/// fractional/zone suffix is tolerated. The separator characters themselves
/// are not checked.
///
/// Returns `None` when:
/// - the string is too short;
/// - any field is not made entirely of ASCII digits;
/// - a component is out of range (month 1–12, day 1–31, hour 0–23,
///   minute 0–59, second 0–60).
///
/// The date conversion is Howard Hinnant's `days_from_civil` algorithm for
/// the proleptic Gregorian calendar.
#[must_use]
fn parse_sqlite_datetime_to_unix(s: &str) -> Option<i64> {
    if s.len() < 19 {
        return None;
    }

    // `s[a..b]` panics when `a`/`b` is not a char boundary (non-ASCII input),
    // so use `get`. Requiring ASCII digits also rejects a leading sign, which
    // `i64::from_str` would otherwise accept.
    let field = |start: usize, end: usize| -> Option<i64> {
        let digits = s.get(start..end)?;
        if !digits.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }
        digits.parse().ok()
    };

    let year = field(0, 4)?;
    let month = field(5, 7)?;
    let day = field(8, 10)?;
    let hour = field(11, 13)?;
    let min = field(14, 16)?;
    let sec = field(17, 19)?;

    // Without these checks, e.g. month 13 silently produced a bogus timestamp.
    if !(1..=12).contains(&month)
        || !(1..=31).contains(&day)
        || hour > 23
        || min > 59
        || sec > 60
    {
        return None;
    }

    // Shift the year so it starts in March; the leap day then falls at the
    // end of the shifted year.
    let (y, m) = if month <= 2 {
        (year - 1, month + 9)
    } else {
        (year, month - 3)
    };
    let era = y.div_euclid(400);
    let yoe = y - era * 400; // year of era, 0..=399
    let doy = (153 * m + 2) / 5 + day - 1; // day of (March-based) year
    let doe = yoe * 365 + yoe / 4 - yoe / 100 + doy; // day of era
    let days = era * 146_097 + doe - 719_468; // 719_468 = days from 0000-03-01 to 1970-01-01

    Some(days * 86_400 + hour * 3_600 + min * 60 + sec)
}
742
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphStore;
    use crate::graph::types::EntityType;
    use crate::store::SqliteStore;

    /// Fresh in-memory SQLite-backed graph store for each test.
    async fn setup_store() -> GraphStore {
        let store = SqliteStore::new(":memory:").await.unwrap();
        GraphStore::new(store.pool().clone())
    }

    /// Baseline parameters. Temporal decay is disabled (`0.0`), so recency does
    /// not affect scores unless a test opts in.
    fn default_params() -> SpreadingActivationParams {
        SpreadingActivationParams {
            decay_lambda: 0.85,
            max_hops: 3,
            activation_threshold: 0.1,
            inhibition_threshold: 0.8,
            max_activated_nodes: 50,
            temporal_decay_rate: 0.0,
            seed_structural_weight: 0.4,
            seed_community_cap: 3,
        }
    }

    // No entities or edges are inserted: the seed must come back unchanged and
    // no facts may be produced.
    #[tokio::test]
    async fn spread_empty_graph_no_edges_no_facts() {
        let store = setup_store().await;
        let sa = SpreadingActivation::new(default_params());
        let seeds = HashMap::from([(1_i64, 1.0_f32)]);
        let (nodes, facts) = sa.spread(&store, seeds, &[]).await.unwrap();
        assert_eq!(nodes.len(), 1, "seed must be in activated nodes");
        assert_eq!(nodes[0].entity_id, 1);
        assert!((nodes[0].activation - 1.0).abs() < 1e-6);
        assert!(
            facts.is_empty(),
            "expected no activated facts on empty graph"
        );
    }

    // Empty seed map short-circuits to empty results.
    #[tokio::test]
    async fn spread_empty_seeds_returns_empty() {
        let store = setup_store().await;
        let sa = SpreadingActivation::new(default_params());
        let (nodes, facts) = sa.spread(&store, HashMap::new(), &[]).await.unwrap();
        assert!(nodes.is_empty());
        assert!(facts.is_empty());
    }

    // A real but isolated entity is returned at depth 0 with its seed score.
    #[tokio::test]
    async fn spread_single_seed_no_edges_returns_seed() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();

        let sa = SpreadingActivation::new(default_params());
        let seeds = HashMap::from([(alice, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].entity_id, alice);
        assert_eq!(nodes[0].depth, 0);
        assert!((nodes[0].activation - 1.0).abs() < 1e-6);
    }

    // Chain A-B-C: every node is reached and activation decays along the chain.
    #[tokio::test]
    async fn spread_linear_chain_all_activated_with_decay() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 3;
        cfg.decay_lambda = 0.9;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&a), "A (seed) must be activated");
        assert!(ids.contains(&b), "B (hop 1) must be activated");
        assert!(ids.contains(&c), "C (hop 2) must be activated");

        let score_a = nodes.iter().find(|n| n.entity_id == a).unwrap().activation;
        let score_b = nodes.iter().find(|n| n.entity_id == b).unwrap().activation;
        let score_c = nodes.iter().find(|n| n.entity_id == c).unwrap().activation;
        assert!(
            score_a > score_b,
            "seed A should have higher activation than hop-1 B"
        );
        assert!(
            score_b > score_c,
            "hop-1 B should have higher activation than hop-2 C"
        );
    }

    // With max_hops = 1, only direct neighbours of the seed are reached.
    #[tokio::test]
    async fn spread_linear_chain_max_hops_limits_reach() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 1;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&a), "A must be activated (seed)");
        assert!(ids.contains(&b), "B must be activated (hop 1)");
        assert!(!ids.contains(&c), "C must NOT be activated with max_hops=1");
    }

    // Diamond A->{B,C}->D: D is reached through both branches and reported at
    // depth 2. The inhibition threshold is raised so converging contributions
    // are not cut off early.
    #[tokio::test]
    async fn spread_diamond_graph_convergence() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap();
        let d = store
            .upsert_entity("D", "D", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "rel", "A-B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(a, c, "rel", "A-C", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, d, "rel", "B-D", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(c, d, "rel", "C-D", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 3;
        cfg.decay_lambda = 0.9;
        cfg.inhibition_threshold = 0.95;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&d), "D must be activated via diamond paths");

        let node_d = nodes.iter().find(|n| n.entity_id == d).unwrap();
        assert_eq!(node_d.depth, 2, "D should be at depth 2");
    }

    // Hub with five leaves linked in both directions: feedback into the hub
    // must never push its activation above 1.0.
    #[tokio::test]
    async fn spread_inhibition_prevents_runaway() {
        let store = setup_store().await;
        let hub = store
            .upsert_entity("Hub", "Hub", EntityType::Concept, None)
            .await
            .unwrap();

        for i in 0..5 {
            let leaf = store
                .upsert_entity(
                    &format!("Leaf{i}"),
                    &format!("Leaf{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap();
            store
                .insert_edge(hub, leaf, "has", &format!("Hub has Leaf{i}"), 1.0, None)
                .await
                .unwrap();
            // Reverse edge so activation can flow back into the hub.
            store
                .insert_edge(
                    leaf,
                    hub,
                    "part_of",
                    &format!("Leaf{i} part_of Hub"),
                    1.0,
                    None,
                )
                .await
                .unwrap();
        }

        let mut cfg = default_params();
        cfg.inhibition_threshold = 0.8;
        cfg.max_hops = 3;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(hub, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let hub_node = nodes.iter().find(|n| n.entity_id == hub);
        assert!(hub_node.is_some(), "hub must be in results");
        assert!(
            hub_node.unwrap().activation <= 1.0,
            "activation must not exceed 1.0"
        );
    }

    // Root with 20 leaves and max_activated_nodes = 5: output is capped.
    #[tokio::test]
    async fn spread_max_activated_nodes_cap_enforced() {
        let store = setup_store().await;
        let root = store
            .upsert_entity("Root", "Root", EntityType::Person, None)
            .await
            .unwrap();

        for i in 0..20 {
            let leaf = store
                .upsert_entity(
                    &format!("Node{i}"),
                    &format!("Node{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap();
            store
                .insert_edge(root, leaf, "has", &format!("Root has Node{i}"), 0.9, None)
                .await
                .unwrap();
        }

        let max_nodes = 5;
        let cfg = SpreadingActivationParams {
            max_activated_nodes: max_nodes,
            max_hops: 2,
            ..default_params()
        };
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(root, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        assert!(
            nodes.len() <= max_nodes,
            "activation must be capped at {max_nodes} nodes, got {}",
            nodes.len()
        );
    }

    // With temporal decay enabled, an edge valid since 1970 must propagate
    // less than an edge inserted now (assumes `insert_edge` stamps
    // `valid_from` with the current time — TODO confirm).
    #[tokio::test]
    async fn spread_temporal_decay_recency_effect() {
        let store = setup_store().await;
        let src = store
            .upsert_entity("Src", "Src", EntityType::Person, None)
            .await
            .unwrap();
        let recent = store
            .upsert_entity("Recent", "Recent", EntityType::Tool, None)
            .await
            .unwrap();
        let old = store
            .upsert_entity("Old", "Old", EntityType::Tool, None)
            .await
            .unwrap();

        store
            .insert_edge(src, recent, "uses", "Src uses Recent", 1.0, None)
            .await
            .unwrap();

        // Raw insert so `valid_from` can be pinned to the Unix epoch.
        zeph_db::query(
            sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from)
             VALUES (?1, ?2, 'uses', 'Src uses Old', 1.0, '1970-01-01 00:00:00')"),
        )
        .bind(src)
        .bind(old)
        .execute(store.pool())
        .await
        .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 2;
        let sa = SpreadingActivation::new(SpreadingActivationParams {
            temporal_decay_rate: 0.5,
            ..cfg
        });
        let seeds = HashMap::from([(src, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let score_recent = nodes
            .iter()
            .find(|n| n.entity_id == recent)
            .map_or(0.0, |n| n.activation);
        let score_old = nodes
            .iter()
            .find(|n| n.entity_id == old)
            .map_or(0.0, |n| n.activation);

        assert!(
            score_recent > score_old,
            "recent edge ({score_recent}) must produce higher activation than old edge ({score_old})"
        );
    }

    // Restricting traversal to `EdgeType::Semantic` must skip the causal edge.
    // NOTE(review): assumes `insert_edge` defaults to the semantic edge type.
    #[tokio::test]
    async fn spread_edge_type_filter_excludes_other_types() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b_semantic = store
            .upsert_entity("BSemantic", "BSemantic", EntityType::Tool, None)
            .await
            .unwrap();
        let c_causal = store
            .upsert_entity("CCausal", "CCausal", EntityType::Concept, None)
            .await
            .unwrap();

        store
            .insert_edge(a, b_semantic, "uses", "A uses BSemantic", 1.0, None)
            .await
            .unwrap();

        // Raw insert so the edge can carry an explicit 'causal' edge_type.
        zeph_db::query(
            sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from, edge_type)
             VALUES (?1, ?2, 'caused', 'A caused CCausal', 1.0, datetime('now'), 'causal')"),
        )
        .bind(a)
        .bind(c_causal)
        .execute(store.pool())
        .await
        .unwrap();

        let cfg = default_params();
        let sa = SpreadingActivation::new(cfg);

        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa
            .spread(&store, seeds, &[EdgeType::Semantic])
            .await
            .unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(
            ids.contains(&b_semantic),
            "BSemantic must be activated via semantic edge"
        );
        assert!(
            !ids.contains(&c_causal),
            "CCausal must NOT be activated when filtering to semantic only"
        );
    }

    // 100 disconnected seeds in one call must not error.
    #[tokio::test]
    async fn spread_large_seed_list() {
        let store = setup_store().await;
        let mut seeds = HashMap::new();

        for i in 0..100i64 {
            let id = store
                .upsert_entity(
                    &format!("Entity{i}"),
                    &format!("entity{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap();
            seeds.insert(id, 1.0_f32);
        }

        let cfg = default_params();
        let sa = SpreadingActivation::new(cfg);
        let result = sa.spread(&store, seeds, &[]).await;
        assert!(
            result.is_ok(),
            "large seed list must not error: {:?}",
            result.err()
        );
    }

    // --- HeLa helpers ---
    // NOTE(review): the tests below exercise `cosine`, `HelaSpreadParams` and
    // `Edge::synthetic_anchor` directly. Several only check local arithmetic
    // rather than calling `hela_spreading_recall`.

    #[test]
    fn hela_cosine_identical_vectors() {
        let v = vec![1.0_f32, 0.0, 0.0];
        assert!(
            (cosine(&v, &v) - 1.0).abs() < 1e-6,
            "identical vectors → cosine 1.0"
        );
    }

    #[test]
    fn hela_cosine_orthogonal_vectors() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![0.0_f32, 1.0];
        assert!(
            cosine(&a, &b).abs() < 1e-6,
            "orthogonal vectors → cosine 0.0"
        );
    }

    #[test]
    fn hela_cosine_anti_correlated() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![-1.0_f32, 0.0];
        assert!(
            cosine(&a, &b) < 0.0,
            "anti-correlated vectors → negative cosine"
        );
    }

    // The EPSILON floor on the denominator must prevent NaN for zero vectors.
    #[test]
    fn hela_cosine_zero_vector_no_panic() {
        let a = vec![0.0_f32, 0.0];
        let b = vec![1.0_f32, 0.0];
        let result = cosine(&a, &b);
        assert!(
            result.is_finite(),
            "zero-norm vector must yield finite cosine"
        );
    }

    #[test]
    fn hela_spread_params_default_depth_is_two() {
        let p = HelaSpreadParams::default();
        assert_eq!(p.spread_depth, 2);
        assert!(p.step_budget.is_some());
        assert!(p.edge_types.is_empty());
        assert_eq!(p.max_visited, 200);
    }

    // The fallback edge must have id 0 so the Hebbian update filters it out.
    #[test]
    fn hela_synthetic_anchor_edge_id_is_zero() {
        let edge = Edge::synthetic_anchor(42);
        assert_eq!(
            edge.id, 0,
            "synthetic anchor must have id = 0 to be excluded from Hebbian"
        );
        assert_eq!(edge.source_entity_id, 42);
        assert_eq!(edge.target_entity_id, 42);
    }

    // Mirrors the `max(0.0)` clamp applied to cosine before scoring.
    #[test]
    fn hela_negative_cosine_clamped_to_zero_in_score() {
        let anti = vec![-1.0_f32, 0.0];
        let query = vec![1.0_f32, 0.0];
        let cosine_raw = cosine(&query, &anti);
        assert!(cosine_raw < 0.0);
        let clamped = cosine_raw.max(0.0);
        let fact_score = 0.9_f32 * clamped;
        assert!(
            fact_score < f32::EPSILON,
            "anti-correlated score must be 0.0"
        );
    }

    // Path weight is the product of the edge weights along the path.
    #[test]
    fn hela_path_weight_multiplicative() {
        let w1 = 0.8_f32;
        let w2 = 0.5_f32;
        let expected = w1 * w2;
        assert!((expected - 0.4).abs() < 1e-6);
    }

    // When two paths reach the same entity, the larger path weight is kept.
    #[test]
    fn hela_max_path_weight_on_multipath() {
        let pw_a = 0.9_f32;
        let pw_b = 0.3_f32;
        let kept = pw_a.max(pw_b);
        assert!(
            (kept - 0.9).abs() < 1e-6,
            "multi-path resolution must keep maximum path_weight"
        );
    }

    // score = path_weight * clamped cosine.
    #[test]
    fn hela_fact_score_formula() {
        let path_weight = 0.8_f32;
        let cosine_clamped = 0.75_f32;
        let expected = path_weight * cosine_clamped;
        assert!((expected - 0.6).abs() < 1e-5);
    }
}