1use std::collections::HashMap;
17use std::sync::OnceLock;
18use std::time::{Instant, SystemTime, UNIX_EPOCH};
19#[allow(unused_imports)]
20use zeph_db::sql;
21
22use crate::embedding_store::EmbeddingStore;
23use crate::error::MemoryError;
24use crate::graph::store::GraphStore;
25use crate::graph::types::{Edge, EdgeType, edge_type_weight, evolved_weight};
26
/// An entity reached by [`SpreadingActivation::spread`].
#[derive(Debug, Clone)]
pub struct ActivatedNode {
    /// Graph entity id.
    pub entity_id: i64,
    /// Final activation score; nodes below the activation threshold are not returned.
    pub activation: f32,
    /// Hop at which the node last reached its highest activation (seeds are `0`).
    pub depth: u32,
}
37
/// An edge whose two endpoints were both at or above the activation threshold
/// after a spreading-activation hop.
#[derive(Debug, Clone)]
pub struct ActivatedFact {
    /// The activated edge, as returned by the graph store.
    pub edge: Edge,
    /// The larger of the two endpoint activations when the edge was collected.
    pub activation_score: f32,
}
46
47pub use zeph_common::memory::SpreadingActivationParams;
48
/// One result of [`hela_spreading_recall`].
#[derive(Debug, Clone)]
pub struct HelaFact {
    /// Edge that ended the best path from the anchor to the scored entity. In the
    /// isolated-anchor fallback this is `Edge::synthetic_anchor` (id `0`).
    pub edge: Edge,
    /// Ranking score: `path_weight * cosine` for traversed entities, or the anchor's
    /// raw ANN score in the isolated-anchor fallback.
    pub score: f32,
    /// Hop count of the best path from the anchor (`0` for the fallback fact).
    pub depth: u32,
    /// Product of `edge.weight` along the best path from the anchor (`1.0` for the
    /// fallback fact).
    pub path_weight: f32,
    /// Query/entity cosine similarity floored at `0.0` (clamped to `[0, 1]` in the
    /// fallback).
    pub cosine: Option<f32>,
}
74
/// Tuning knobs for [`hela_spreading_recall`].
#[derive(Debug, Clone)]
pub struct HelaSpreadParams {
    /// Maximum number of hops from the anchor; clamped to `1..=6` when used.
    pub spread_depth: u32,
    /// Edge types to traverse, passed through to `GraphStore::edges_for_entities`.
    /// NOTE(review): an empty list is presumed to mean "no type filter" — confirm
    /// against the store implementation.
    pub edge_types: Vec<EdgeType>,
    /// Expansion stops once this many entities (anchor included) have been visited.
    pub max_visited: usize,
    /// Wall-clock budget applied separately to the anchor search, each edge fetch and
    /// the vector batch; exceeding it makes the recall return an empty result.
    /// `None` disables the checks.
    pub step_budget: Option<std::time::Duration>,
    /// Timeout for embedding the query; on expiry the recall fails with
    /// `MemoryError::Timeout`. `None` applies no timeout.
    pub embed_timeout: Option<std::time::Duration>,
}
100
101impl Default for HelaSpreadParams {
102 fn default() -> Self {
103 Self {
104 spread_depth: 2,
105 edge_types: Vec::new(),
106 max_visited: 200,
107 step_budget: Some(std::time::Duration::from_millis(8)),
108 embed_timeout: Some(std::time::Duration::from_secs(5)),
109 }
110 }
111}
112
/// Process-wide record of an entity collection whose vectors did not match the query
/// embedding's dimension. `hela_spreading_recall` sets it when a search error looks
/// like a dimension mismatch, and afterwards returns an empty result for that
/// collection without querying it. Being a `OnceLock`, it is never cleared for the
/// life of the process.
static HELA_DIM_MISMATCH: OnceLock<String> = OnceLock::new();
122
/// Cosine similarity of `a` and `b`.
///
/// The dot product pairs elements with `zip`, so only the common prefix contributes
/// when lengths differ, while each norm covers its whole slice. Callers compare
/// lengths first. The norm product is floored at `f32::EPSILON`, so a zero vector
/// yields `0.0` instead of `NaN`.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    let numerator: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let denominator = f32::max(l2_norm(a) * l2_norm(b), f32::EPSILON);
    numerator / denominator
}
133
134#[tracing::instrument(
161 name = "memory.graph.hela_spread",
162 skip_all,
163 fields(
164 depth = params.spread_depth,
165 limit,
166 anchor_id = tracing::field::Empty,
167 visited = tracing::field::Empty,
168 scored = tracing::field::Empty,
169 fallback = tracing::field::Empty,
170 )
171)]
172#[allow(clippy::too_many_arguments, clippy::too_many_lines)] pub async fn hela_spreading_recall(
174 store: &GraphStore,
175 embeddings: &EmbeddingStore,
176 provider: &zeph_llm::any::AnyProvider,
177 query: &str,
178 limit: usize,
179 params: &HelaSpreadParams,
180 hebbian_enabled: bool,
181 hebbian_lr: f32,
182) -> Result<Vec<HelaFact>, MemoryError> {
183 use zeph_llm::LlmProvider as _;
184
185 const ENTITY_COLLECTION: &str = "zeph_graph_entities";
186
187 if limit == 0 {
188 return Ok(Vec::new());
189 }
190
191 if HELA_DIM_MISMATCH.get().map(String::as_str) == Some(ENTITY_COLLECTION) {
194 tracing::debug!("hela: dim mismatch previously detected for collection, skipping");
195 return Ok(Vec::new());
196 }
197
198 let q_vec = if let Some(timeout) = params.embed_timeout {
200 tokio::time::timeout(timeout, provider.embed(query))
201 .await
202 .map_err(|_| {
203 tracing::warn!(timeout_ms = timeout.as_millis(), "hela: embed timed out");
204 MemoryError::Timeout("hela embed".into())
205 })??
206 } else {
207 provider.embed(query).await?
208 };
209
210 let t_anchor = Instant::now();
212 let anchor_results = match embeddings
213 .search_collection(ENTITY_COLLECTION, &q_vec, 1, None)
214 .await
215 {
216 Ok(r) => r,
217 Err(e) => {
218 let msg = e.to_string();
219 if msg.contains("wrong vector dimension")
220 || msg.contains("InvalidArgument")
221 || msg.contains("dimension")
222 {
223 let _ = HELA_DIM_MISMATCH.set(ENTITY_COLLECTION.to_owned());
224 tracing::warn!(
225 collection = ENTITY_COLLECTION,
226 error = %e,
227 "hela: vector dimension mismatch — HL-F5 disabled for this collection"
228 );
229 return Ok(Vec::new());
230 }
231 return Err(e);
232 }
233 };
234
235 if params.step_budget.is_some_and(|b| t_anchor.elapsed() > b) {
236 tracing::warn!(
237 elapsed_ms = t_anchor.elapsed().as_millis(),
238 "hela: anchor ANN over budget"
239 );
240 return Ok(Vec::new());
241 }
242
243 let Some(anchor_point) = anchor_results.first() else {
244 tracing::debug!("hela: no anchor found, returning empty");
245 return Ok(Vec::new());
246 };
247 let Some(anchor_entity_id) = anchor_point
248 .payload
249 .get("entity_id")
250 .and_then(serde_json::Value::as_i64)
251 else {
252 tracing::warn!("hela: anchor point missing entity_id payload");
253 return Ok(Vec::new());
254 };
255 let anchor_cosine = anchor_point.score;
256
257 tracing::Span::current().record("anchor_id", anchor_entity_id);
258 tracing::debug!(anchor_entity_id, anchor_cosine, "hela: anchor resolved");
259
260 let spread_depth = params.spread_depth.clamp(1, 6);
261
262 let mut visited: HashMap<i64, (u32, f32, Option<i64>)> = HashMap::new();
265 visited.insert(anchor_entity_id, (0, 1.0, None));
266
267 let mut edge_cache: HashMap<i64, Edge> = HashMap::new();
271 let mut frontier: Vec<i64> = vec![anchor_entity_id];
272
273 for hop in 0..spread_depth {
274 if frontier.is_empty() {
275 break;
276 }
277
278 tracing::debug!(hop, frontier_size = frontier.len(), "hela: starting hop");
279
280 let t_step = Instant::now();
281 let edges = store
282 .edges_for_entities(&frontier, ¶ms.edge_types)
283 .await?;
284 if params.step_budget.is_some_and(|b| t_step.elapsed() > b) {
285 tracing::warn!(
286 hop,
287 elapsed_ms = t_step.elapsed().as_millis(),
288 "hela: edge-fetch over budget"
289 );
290 return Ok(Vec::new());
291 }
292
293 let mut next_frontier: Vec<i64> = Vec::new();
294
295 for edge in &edges {
296 edge_cache.entry(edge.id).or_insert_with(|| edge.clone());
298
299 for &src_id in &frontier {
300 let neighbor = if edge.source_entity_id == src_id {
301 edge.target_entity_id
302 } else if edge.target_entity_id == src_id {
303 edge.source_entity_id
304 } else {
305 continue;
306 };
307
308 let parent_pw = visited.get(&src_id).map_or(1.0, |&(_, pw, _)| pw);
309 let new_pw = parent_pw * edge.weight;
310
311 let entry = visited
316 .entry(neighbor)
317 .or_insert((hop + 1, 0.0_f32, Some(edge.id)));
318 if new_pw > entry.1
320 || ((new_pw - entry.1).abs() < f32::EPSILON && hop + 1 < entry.0)
321 {
322 *entry = (hop + 1, new_pw, Some(edge.id));
323 if !next_frontier.contains(&neighbor) {
324 next_frontier.push(neighbor);
325 }
326 }
327
328 if visited.len() >= params.max_visited {
329 break;
330 }
331 }
332
333 if visited.len() >= params.max_visited {
334 break;
335 }
336 }
337
338 tracing::debug!(
339 hop,
340 edges_fetched = edges.len(),
341 visited = visited.len(),
342 next_frontier = next_frontier.len(),
343 "hela: hop complete"
344 );
345
346 frontier = next_frontier;
347 if visited.len() >= params.max_visited {
348 break;
349 }
350 }
351
352 if visited.len() == 1 {
355 tracing::Span::current().record("fallback", true);
356 tracing::debug!(
357 anchor_entity_id,
358 anchor_cosine,
359 "hela: anchor isolated, falling back to pure ANN"
360 );
361 let fact = HelaFact {
362 edge: Edge::synthetic_anchor(anchor_entity_id),
363 score: anchor_cosine,
364 depth: 0,
365 path_weight: 1.0,
366 cosine: Some(anchor_cosine.clamp(0.0, 1.0)),
367 };
368 return Ok(vec![fact]);
369 }
370
371 let entity_ids: Vec<i64> = visited.keys().copied().collect();
373 let point_id_map = store.qdrant_point_ids_for_entities(&entity_ids).await?;
374 let point_ids: Vec<String> = point_id_map.values().cloned().collect();
375
376 let t_vec = Instant::now();
377 let vec_map = embeddings
378 .get_vectors_from_collection(ENTITY_COLLECTION, &point_ids)
379 .await?;
380 if params.step_budget.is_some_and(|b| t_vec.elapsed() > b) {
381 tracing::warn!(
382 elapsed_ms = t_vec.elapsed().as_millis(),
383 "hela: vectors-batch over budget"
384 );
385 return Ok(Vec::new());
386 }
387
388 let mut facts: Vec<HelaFact> = Vec::with_capacity(visited.len().saturating_sub(1));
393 for (&entity_id, &(depth, path_weight, edge_id_opt)) in &visited {
394 if entity_id == anchor_entity_id {
395 continue;
396 }
397 let Some(edge_id) = edge_id_opt else {
398 continue;
399 };
400 let Some(point_id) = point_id_map.get(&entity_id) else {
401 continue;
402 };
403 let Some(node_vec) = vec_map.get(point_id) else {
404 continue;
405 };
406 if node_vec.len() != q_vec.len() {
407 continue;
409 }
410 let cosine_clamped = cosine(&q_vec, node_vec).max(0.0);
411 let fact_score = path_weight * cosine_clamped;
412 let Some(edge) = edge_cache.get(&edge_id).cloned() else {
413 continue;
414 };
415 facts.push(HelaFact {
416 edge,
417 score: fact_score,
418 depth,
419 path_weight,
420 cosine: Some(cosine_clamped),
421 });
422 }
423
424 facts.sort_by(|a, b| b.score.total_cmp(&a.score));
426 facts.truncate(limit);
427
428 if hebbian_enabled {
433 let edge_ids: Vec<i64> = facts
434 .iter()
435 .map(|f| f.edge.id)
436 .filter(|&id| id != 0) .collect();
438 if !edge_ids.is_empty()
439 && let Err(e) = store.apply_hebbian_increment(&edge_ids, hebbian_lr).await
440 {
441 tracing::warn!(error = %e, "hela: hebbian increment failed");
442 }
443 }
444
445 tracing::Span::current().record("visited", visited.len());
446 tracing::Span::current().record("scored", facts.len());
447
448 Ok(facts)
449}
450
/// Spreading-activation engine over the entity graph, configured by
/// [`SpreadingActivationParams`]; see [`SpreadingActivation::spread`].
pub struct SpreadingActivation {
    params: SpreadingActivationParams,
}
457
458impl SpreadingActivation {
459 #[must_use]
464 pub fn new(params: SpreadingActivationParams) -> Self {
465 Self { params }
466 }
467
468 pub async fn spread(
484 &self,
485 store: &GraphStore,
486 seeds: HashMap<i64, f32>,
487 edge_types: &[EdgeType],
488 ) -> Result<(Vec<ActivatedNode>, Vec<ActivatedFact>), MemoryError> {
489 if seeds.is_empty() {
490 return Ok((Vec::new(), Vec::new()));
491 }
492
493 let now_secs: i64 = SystemTime::now()
496 .duration_since(UNIX_EPOCH)
497 .map_or(0, |d| d.as_secs().cast_signed());
498
499 let mut activation = self.initialize_seeds(&seeds);
500 let mut activated_facts: Vec<ActivatedFact> = Vec::new();
501
502 for hop in 0..self.params.max_hops {
503 let active_nodes: Vec<(i64, f32)> = activation
504 .iter()
505 .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
506 .map(|(&id, &(score, _))| (id, score))
507 .collect();
508
509 if active_nodes.is_empty() {
510 break;
511 }
512
513 let node_ids: Vec<i64> = active_nodes.iter().map(|(id, _)| *id).collect();
514 let edges = store.edges_for_entities(&node_ids, edge_types).await?;
515 let edge_count = edges.len();
516
517 let next_activation =
518 self.propagate_one_hop(hop, &active_nodes, &edges, &activation, now_secs);
519
520 let pruned_count = self.merge_and_prune(&mut activation, next_activation);
521
522 tracing::debug!(
523 hop,
524 active_nodes = active_nodes.len(),
525 edges_fetched = edge_count,
526 after_merge = activation.len(),
527 pruned = pruned_count,
528 "spreading activation: hop complete"
529 );
530
531 self.collect_activated_facts(&edges, &activation, &mut activated_facts);
532 }
533
534 let result = self.finalize(activation);
535
536 tracing::info!(
537 activated = result.len(),
538 facts = activated_facts.len(),
539 "spreading activation: complete"
540 );
541
542 Ok((result, activated_facts))
543 }
544
545 fn initialize_seeds(&self, seeds: &HashMap<i64, f32>) -> HashMap<i64, (f32, u32)> {
547 let mut activation: HashMap<i64, (f32, u32)> = HashMap::new();
548 let mut seed_count = 0usize;
549 for (entity_id, match_score) in seeds {
551 if *match_score < self.params.activation_threshold {
552 tracing::debug!(
553 entity_id,
554 score = match_score,
555 threshold = self.params.activation_threshold,
556 "spreading activation: seed below threshold, skipping"
557 );
558 continue;
559 }
560 activation.insert(*entity_id, (*match_score, 0));
561 seed_count += 1;
562 }
563 tracing::debug!(
564 seeds = seed_count,
565 "spreading activation: initialized seeds"
566 );
567 activation
568 }
569
570 fn propagate_one_hop(
574 &self,
575 hop: u32,
576 active_nodes: &[(i64, f32)],
577 edges: &[Edge],
578 activation: &HashMap<i64, (f32, u32)>,
579 now_secs: i64,
580 ) -> HashMap<i64, (f32, u32)> {
581 let mut next_activation: HashMap<i64, (f32, u32)> = HashMap::new();
582
583 for edge in edges {
584 for &(active_id, node_score) in active_nodes {
585 let neighbor = if edge.source_entity_id == active_id {
586 edge.target_entity_id
587 } else if edge.target_entity_id == active_id {
588 edge.source_entity_id
589 } else {
590 continue;
591 };
592
593 let current_score = activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
598 let next_score = next_activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
599 if current_score >= self.params.inhibition_threshold
600 || next_score >= self.params.inhibition_threshold
601 {
602 continue;
603 }
604
605 let recency = self.recency_weight(&edge.valid_from, now_secs);
606 let edge_weight = evolved_weight(edge.retrieval_count, edge.confidence);
607 let type_w = edge_type_weight(edge.edge_type);
608 let spread_value =
609 node_score * self.params.decay_lambda * edge_weight * recency * type_w;
610
611 if spread_value < self.params.activation_threshold {
612 continue;
613 }
614
615 let depth_at_max = hop + 1;
618 let entry = next_activation
619 .entry(neighbor)
620 .or_insert((0.0, depth_at_max));
621 let new_score = (entry.0 + spread_value).min(1.0);
622 if new_score > entry.0 {
623 entry.0 = new_score;
624 entry.1 = depth_at_max;
625 }
626 }
627 }
628
629 next_activation
630 }
631
632 fn merge_and_prune(
636 &self,
637 activation: &mut HashMap<i64, (f32, u32)>,
638 next_activation: HashMap<i64, (f32, u32)>,
639 ) -> usize {
640 for (node_id, (new_score, new_depth)) in next_activation {
641 let entry = activation.entry(node_id).or_insert((0.0, new_depth));
642 if new_score > entry.0 {
643 entry.0 = new_score;
644 entry.1 = new_depth;
645 }
646 }
647
648 if activation.len() > self.params.max_activated_nodes {
649 let before = activation.len();
650 let mut entries: Vec<(i64, (f32, u32))> = activation.drain().collect();
651 entries.sort_by(|(_, (a, _)), (_, (b, _))| b.total_cmp(a));
652 entries.truncate(self.params.max_activated_nodes);
653 *activation = entries.into_iter().collect();
654 before - self.params.max_activated_nodes
655 } else {
656 0
657 }
658 }
659
660 fn collect_activated_facts(
662 &self,
663 edges: &[Edge],
664 activation: &HashMap<i64, (f32, u32)>,
665 activated_facts: &mut Vec<ActivatedFact>,
666 ) {
667 for edge in edges {
668 let src_score = activation
669 .get(&edge.source_entity_id)
670 .map_or(0.0, |&(s, _)| s);
671 let tgt_score = activation
672 .get(&edge.target_entity_id)
673 .map_or(0.0, |&(s, _)| s);
674 if src_score >= self.params.activation_threshold
675 && tgt_score >= self.params.activation_threshold
676 {
677 let activation_score = src_score.max(tgt_score);
678 activated_facts.push(ActivatedFact {
679 edge: edge.clone(),
680 activation_score,
681 });
682 }
683 }
684 }
685
686 fn finalize(&self, activation: HashMap<i64, (f32, u32)>) -> Vec<ActivatedNode> {
688 let mut result: Vec<ActivatedNode> = activation
689 .into_iter()
690 .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
691 .map(|(entity_id, (activation, depth))| ActivatedNode {
692 entity_id,
693 activation,
694 depth,
695 })
696 .collect();
697 result.sort_by(|a, b| b.activation.total_cmp(&a.activation));
698 result
699 }
700
701 #[allow(clippy::cast_precision_loss)]
707 fn recency_weight(&self, valid_from: &str, now_secs: i64) -> f32 {
708 if self.params.temporal_decay_rate <= 0.0 {
709 return 1.0;
710 }
711 let Some(valid_from_secs) = parse_sqlite_datetime_to_unix(valid_from) else {
712 return 1.0;
713 };
714 let age_secs = (now_secs - valid_from_secs).max(0);
715 let age_days = age_secs as f64 / 86_400.0;
716 let weight = 1.0_f64 / (1.0 + age_days * self.params.temporal_decay_rate);
717 #[allow(clippy::cast_possible_truncation)]
719 let w = weight as f32;
720 w
721 }
722}
723
/// Parses the leading `YYYY-MM-DD HH:MM:SS` part of a SQLite datetime string into
/// Unix seconds, treating it as UTC.
///
/// Only the first 19 bytes are read, and the separator positions are not checked.
/// So both `2024-01-01 12:00:00` and ISO-8601 `2024-01-01T12:00:00Z` are accepted,
/// and trailing fractional seconds or zone suffixes are ignored.
///
/// Returns `None` in these cases:
/// - the input is shorter than 19 bytes;
/// - a numeric field holds anything but ASCII digits (signs, spaces, multi-byte
///   characters);
/// - a field is out of range: month `1..=12`, day `1..=31`, hour `0..=23`, minute
///   `0..=59`, second `0..=60`.
#[must_use]
fn parse_sqlite_datetime_to_unix(s: &str) -> Option<i64> {
    // Work on bytes: `&s[a..b]` panics when an index is not a char boundary, and the
    // value comes from the database, so it is not guaranteed to be ASCII.
    let bytes = s.as_bytes();
    if bytes.len() < 19 {
        return None;
    }

    // Reads `bytes[start..end]` as an unsigned decimal number.
    let field = |start: usize, end: usize| -> Option<i64> {
        let digits = &bytes[start..end];
        if !digits.iter().all(u8::is_ascii_digit) {
            return None;
        }
        Some(
            digits
                .iter()
                .fold(0_i64, |acc, &d| acc * 10 + i64::from(d - b'0')),
        )
    };

    let year = field(0, 4)?;
    let month = field(5, 7)?;
    let day = field(8, 10)?;
    let hour = field(11, 13)?;
    let min = field(14, 16)?;
    let sec = field(17, 19)?;

    // The arithmetic below would silently fold out-of-range values into some other
    // date. `60` is allowed for a leap second.
    if !(1..=12).contains(&month)
        || !(1..=31).contains(&day)
        || hour > 23
        || min > 59
        || sec > 60
    {
        return None;
    }

    // Howard Hinnant's days-from-civil: shift the year to start in March so the leap
    // day is the last day of the shifted year.
    let (y, m) = if month <= 2 {
        (year - 1, month + 9)
    } else {
        (year, month - 3)
    };
    let era = y.div_euclid(400);
    let yoe = y - era * 400;
    let doy = (153 * m + 2) / 5 + day - 1;
    let doe = yoe * 365 + yoe / 4 - yoe / 100 + doy;
    // 719_468 = days from 0000-03-01 to 1970-01-01.
    let days = era * 146_097 + doe - 719_468;

    Some(days * 86_400 + hour * 3_600 + min * 60 + sec)
}
755
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphStore;
    use crate::graph::types::EntityType;
    use crate::store::SqliteStore;

    async fn setup_store() -> GraphStore {
        let store = SqliteStore::new(":memory:").await.unwrap();
        GraphStore::new(store.pool().clone())
    }

    fn default_params() -> SpreadingActivationParams {
        SpreadingActivationParams {
            decay_lambda: 0.85,
            max_hops: 3,
            activation_threshold: 0.1,
            inhibition_threshold: 0.8,
            max_activated_nodes: 50,
            temporal_decay_rate: 0.0,
            seed_structural_weight: 0.4,
            seed_community_cap: 3,
        }
    }

    #[tokio::test]
    async fn spread_empty_graph_no_edges_no_facts() {
        let store = setup_store().await;
        let sa = SpreadingActivation::new(default_params());
        let seeds = HashMap::from([(1_i64, 1.0_f32)]);
        let (nodes, facts) = sa.spread(&store, seeds, &[]).await.unwrap();
        assert_eq!(nodes.len(), 1, "seed must be in activated nodes");
        assert_eq!(nodes[0].entity_id, 1);
        assert!((nodes[0].activation - 1.0).abs() < 1e-6);
        assert!(
            facts.is_empty(),
            "expected no activated facts on empty graph"
        );
    }

    #[tokio::test]
    async fn spread_empty_seeds_returns_empty() {
        let store = setup_store().await;
        let sa = SpreadingActivation::new(default_params());
        let (nodes, facts) = sa.spread(&store, HashMap::new(), &[]).await.unwrap();
        assert!(nodes.is_empty());
        assert!(facts.is_empty());
    }

    #[tokio::test]
    async fn spread_single_seed_no_edges_returns_seed() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap()
            .0;

        let sa = SpreadingActivation::new(default_params());
        let seeds = HashMap::from([(alice, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].entity_id, alice);
        assert_eq!(nodes[0].depth, 0);
        assert!((nodes[0].activation - 1.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn spread_linear_chain_all_activated_with_decay() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        store
            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 3;
        cfg.decay_lambda = 0.9;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&a), "A (seed) must be activated");
        assert!(ids.contains(&b), "B (hop 1) must be activated");
        assert!(ids.contains(&c), "C (hop 2) must be activated");

        let score_a = nodes.iter().find(|n| n.entity_id == a).unwrap().activation;
        let score_b = nodes.iter().find(|n| n.entity_id == b).unwrap().activation;
        let score_c = nodes.iter().find(|n| n.entity_id == c).unwrap().activation;
        assert!(
            score_a > score_b,
            "seed A should have higher activation than hop-1 B"
        );
        assert!(
            score_b > score_c,
            "hop-1 B should have higher activation than hop-2 C"
        );
    }

    #[tokio::test]
    async fn spread_linear_chain_max_hops_limits_reach() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        store
            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 1;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&a), "A must be activated (seed)");
        assert!(ids.contains(&b), "B must be activated (hop 1)");
        assert!(!ids.contains(&c), "C must NOT be activated with max_hops=1");
    }

    #[tokio::test]
    async fn spread_diamond_graph_convergence() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let d = store
            .upsert_entity("D", "D", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        store
            .insert_edge(a, b, "rel", "A-B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(a, c, "rel", "A-C", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, d, "rel", "B-D", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(c, d, "rel", "C-D", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 3;
        cfg.decay_lambda = 0.9;
        // Raised so D can accumulate from both diamond paths.
        cfg.inhibition_threshold = 0.95;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&d), "D must be activated via diamond paths");

        let node_d = nodes.iter().find(|n| n.entity_id == d).unwrap();
        assert_eq!(node_d.depth, 2, "D should be at depth 2");
    }

    #[tokio::test]
    async fn spread_inhibition_prevents_runaway() {
        let store = setup_store().await;
        let hub = store
            .upsert_entity("Hub", "Hub", EntityType::Concept, None)
            .await
            .unwrap()
            .0;

        for i in 0..5 {
            let leaf = store
                .upsert_entity(
                    &format!("Leaf{i}"),
                    &format!("Leaf{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap()
                .0;
            store
                .insert_edge(hub, leaf, "has", &format!("Hub has Leaf{i}"), 1.0, None)
                .await
                .unwrap();
            // Reverse edge so activation can flow back into the hub.
            store
                .insert_edge(
                    leaf,
                    hub,
                    "part_of",
                    &format!("Leaf{i} part_of Hub"),
                    1.0,
                    None,
                )
                .await
                .unwrap();
        }

        let mut cfg = default_params();
        cfg.inhibition_threshold = 0.8;
        cfg.max_hops = 3;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(hub, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let hub_node = nodes.iter().find(|n| n.entity_id == hub);
        assert!(hub_node.is_some(), "hub must be in results");
        assert!(
            hub_node.unwrap().activation <= 1.0,
            "activation must not exceed 1.0"
        );
    }

    #[tokio::test]
    async fn spread_max_activated_nodes_cap_enforced() {
        let store = setup_store().await;
        let root = store
            .upsert_entity("Root", "Root", EntityType::Person, None)
            .await
            .unwrap()
            .0;

        for i in 0..20 {
            let leaf = store
                .upsert_entity(
                    &format!("Node{i}"),
                    &format!("Node{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap()
                .0;
            store
                .insert_edge(root, leaf, "has", &format!("Root has Node{i}"), 0.9, None)
                .await
                .unwrap();
        }

        let max_nodes = 5;
        let cfg = SpreadingActivationParams {
            max_activated_nodes: max_nodes,
            max_hops: 2,
            ..default_params()
        };
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(root, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        assert!(
            nodes.len() <= max_nodes,
            "activation must be capped at {max_nodes} nodes, got {}",
            nodes.len()
        );
    }

    #[tokio::test]
    async fn spread_temporal_decay_recency_effect() {
        let store = setup_store().await;
        let src = store
            .upsert_entity("Src", "Src", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let recent = store
            .upsert_entity("Recent", "Recent", EntityType::Tool, None)
            .await
            .unwrap()
            .0;
        let old = store
            .upsert_entity("Old", "Old", EntityType::Tool, None)
            .await
            .unwrap()
            .0;

        store
            .insert_edge(src, recent, "uses", "Src uses Recent", 1.0, None)
            .await
            .unwrap();

        // Insert the old edge directly so `valid_from` can be pinned to the epoch.
        zeph_db::query(
            sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from)
             VALUES (?1, ?2, 'uses', 'Src uses Old', 1.0, '1970-01-01 00:00:00')"),
        )
        .bind(src)
        .bind(old)
        .execute(store.pool())
        .await
        .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 2;
        let sa = SpreadingActivation::new(SpreadingActivationParams {
            temporal_decay_rate: 0.5,
            ..cfg
        });
        let seeds = HashMap::from([(src, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let score_recent = nodes
            .iter()
            .find(|n| n.entity_id == recent)
            .map_or(0.0, |n| n.activation);
        let score_old = nodes
            .iter()
            .find(|n| n.entity_id == old)
            .map_or(0.0, |n| n.activation);

        assert!(
            score_recent > score_old,
            "recent edge ({score_recent}) must produce higher activation than old edge ({score_old})"
        );
    }

    #[tokio::test]
    async fn spread_edge_type_filter_excludes_other_types() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap()
            .0;
        let b_semantic = store
            .upsert_entity("BSemantic", "BSemantic", EntityType::Tool, None)
            .await
            .unwrap()
            .0;
        let c_causal = store
            .upsert_entity("CCausal", "CCausal", EntityType::Concept, None)
            .await
            .unwrap()
            .0;

        store
            .insert_edge(a, b_semantic, "uses", "A uses BSemantic", 1.0, None)
            .await
            .unwrap();

        // Insert the causal edge directly so `edge_type` can be set explicitly.
        zeph_db::query(
            sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from, edge_type)
             VALUES (?1, ?2, 'caused', 'A caused CCausal', 1.0, datetime('now'), 'causal')"),
        )
        .bind(a)
        .bind(c_causal)
        .execute(store.pool())
        .await
        .unwrap();

        let cfg = default_params();
        let sa = SpreadingActivation::new(cfg);

        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa
            .spread(&store, seeds, &[EdgeType::Semantic])
            .await
            .unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(
            ids.contains(&b_semantic),
            "BSemantic must be activated via semantic edge"
        );
        assert!(
            !ids.contains(&c_causal),
            "CCausal must NOT be activated when filtering to semantic only"
        );
    }

    #[tokio::test]
    async fn spread_large_seed_list() {
        let store = setup_store().await;
        let mut seeds = HashMap::new();

        for i in 0..100i64 {
            let id = store
                .upsert_entity(
                    &format!("Entity{i}"),
                    &format!("entity{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap()
                .0;
            seeds.insert(id, 1.0_f32);
        }

        let cfg = default_params();
        let sa = SpreadingActivation::new(cfg);
        let result = sa.spread(&store, seeds, &[]).await;
        assert!(
            result.is_ok(),
            "large seed list must not error: {:?}",
            result.err()
        );
    }

    #[test]
    fn hela_cosine_identical_vectors() {
        let v = vec![1.0_f32, 0.0, 0.0];
        assert!(
            (cosine(&v, &v) - 1.0).abs() < 1e-6,
            "identical vectors → cosine 1.0"
        );
    }

    #[test]
    fn hela_cosine_orthogonal_vectors() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![0.0_f32, 1.0];
        assert!(
            cosine(&a, &b).abs() < 1e-6,
            "orthogonal vectors → cosine 0.0"
        );
    }

    #[test]
    fn hela_cosine_anti_correlated() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![-1.0_f32, 0.0];
        assert!(
            cosine(&a, &b) < 0.0,
            "anti-correlated vectors → negative cosine"
        );
    }

    #[test]
    fn hela_cosine_zero_vector_no_panic() {
        let a = vec![0.0_f32, 0.0];
        let b = vec![1.0_f32, 0.0];
        let result = cosine(&a, &b);
        assert!(
            result.is_finite(),
            "zero-norm vector must yield finite cosine"
        );
    }

    #[test]
    fn hela_spread_params_default_depth_is_two() {
        let p = HelaSpreadParams::default();
        assert_eq!(p.spread_depth, 2);
        assert!(p.step_budget.is_some());
        assert!(p.edge_types.is_empty());
        assert_eq!(p.max_visited, 200);
    }

    #[test]
    fn hela_spread_params_default_embed_timeout_is_some() {
        let p = HelaSpreadParams::default();
        assert!(
            p.embed_timeout.is_some(),
            "default embed_timeout must be Some (5 s)"
        );
    }

    #[tokio::test]
    async fn hela_spreading_recall_embed_timeout_returns_error() {
        use std::time::Duration;
        use zeph_llm::any::AnyProvider;
        use zeph_llm::mock::MockProvider;

        use crate::embedding_store::EmbeddingStore;
        use crate::error::MemoryError;
        use crate::in_memory_store::InMemoryVectorStore;

        let store = setup_store().await;

        // Embedding takes 500 ms, far beyond the 50 ms timeout below.
        let mock = MockProvider::default().with_embed_delay(500);
        let provider = AnyProvider::Mock(mock);

        let sqlite = crate::store::SqliteStore::with_pool_size(":memory:", 1)
            .await
            .unwrap();
        let embeddings =
            EmbeddingStore::with_store(Box::new(InMemoryVectorStore::new()), sqlite.pool().clone());

        let params = HelaSpreadParams {
            embed_timeout: Some(Duration::from_millis(50)),
            ..Default::default()
        };

        let result = hela_spreading_recall(
            &store,
            &embeddings,
            &provider,
            "test query",
            5,
            &params,
            false,
            0.0,
        )
        .await;

        assert!(
            matches!(result, Err(MemoryError::Timeout(_))),
            "expected Err(MemoryError::Timeout), got {result:?}"
        );
    }

    #[tokio::test]
    async fn hela_spreading_recall_no_timeout_does_not_wrap() {
        use zeph_llm::any::AnyProvider;
        use zeph_llm::mock::MockProvider;

        use crate::embedding_store::EmbeddingStore;
        use crate::in_memory_store::InMemoryVectorStore;

        let store = setup_store().await;

        let mock = MockProvider::default().with_embed_delay(0);
        let provider = AnyProvider::Mock(mock);

        let sqlite = crate::store::SqliteStore::with_pool_size(":memory:", 1)
            .await
            .unwrap();
        let embeddings =
            EmbeddingStore::with_store(Box::new(InMemoryVectorStore::new()), sqlite.pool().clone());

        let params = HelaSpreadParams {
            embed_timeout: None,
            ..Default::default()
        };

        // Whatever the outcome, it must not be a timeout when none is configured.
        let result = hela_spreading_recall(
            &store,
            &embeddings,
            &provider,
            "test query",
            5,
            &params,
            false,
            0.0,
        )
        .await;

        assert!(
            !matches!(result, Err(crate::error::MemoryError::Timeout(_))),
            "embed_timeout: None must not produce a Timeout error, got {result:?}"
        );
    }

    #[test]
    fn hela_synthetic_anchor_edge_id_is_zero() {
        let edge = Edge::synthetic_anchor(42);
        assert_eq!(
            edge.id, 0,
            "synthetic anchor must have id = 0 to be excluded from Hebbian"
        );
        assert_eq!(edge.source_entity_id, 42);
        assert_eq!(edge.target_entity_id, 42);
    }

    // The tests below pin the scoring formulas used by `hela_spreading_recall`.

    #[test]
    fn hela_negative_cosine_clamped_to_zero_in_score() {
        let anti = vec![-1.0_f32, 0.0];
        let query = vec![1.0_f32, 0.0];
        let cosine_raw = cosine(&query, &anti);
        assert!(cosine_raw < 0.0);
        let clamped = cosine_raw.max(0.0);
        let fact_score = 0.9_f32 * clamped;
        assert!(
            fact_score < f32::EPSILON,
            "anti-correlated score must be 0.0"
        );
    }

    #[test]
    fn hela_path_weight_multiplicative() {
        let w1 = 0.8_f32;
        let w2 = 0.5_f32;
        let expected = w1 * w2;
        assert!((expected - 0.4).abs() < 1e-6);
    }

    #[test]
    fn hela_max_path_weight_on_multipath() {
        let pw_a = 0.9_f32;
        let pw_b = 0.3_f32;
        let kept = pw_a.max(pw_b);
        assert!(
            (kept - 0.9).abs() < 1e-6,
            "multi-path resolution must keep maximum path_weight"
        );
    }

    #[test]
    fn hela_fact_score_formula() {
        let path_weight = 0.8_f32;
        let cosine_clamped = 0.75_f32;
        let expected = path_weight * cosine_clamped;
        assert!((expected - 0.6).abs() < 1e-5);
    }

    #[test]
    fn parse_sqlite_datetime_unix_epoch() {
        assert_eq!(parse_sqlite_datetime_to_unix("1970-01-01 00:00:00"), Some(0));
        assert_eq!(parse_sqlite_datetime_to_unix("1969-12-31 23:59:59"), Some(-1));
    }

    #[test]
    fn parse_sqlite_datetime_leap_day_and_iso_suffix() {
        assert_eq!(
            parse_sqlite_datetime_to_unix("2024-03-01 00:00:00"),
            Some(1_709_251_200)
        );
        assert_eq!(
            parse_sqlite_datetime_to_unix("2024-01-01T12:34:56Z"),
            Some(1_704_112_496)
        );
    }

    #[test]
    fn parse_sqlite_datetime_too_short_is_none() {
        assert_eq!(parse_sqlite_datetime_to_unix("2024-01-01"), None);
        assert_eq!(parse_sqlite_datetime_to_unix(""), None);
    }
}