1use std::collections::HashMap;
17use std::time::{SystemTime, UNIX_EPOCH};
18#[allow(unused_imports)]
19use zeph_db::sql;
20
21use crate::error::MemoryError;
22use crate::graph::store::GraphStore;
23use crate::graph::types::{Edge, EdgeType, edge_type_weight, evolved_weight};
24
/// An entity reached by spreading activation, together with its final score.
#[derive(Debug, Clone)]
pub struct ActivatedNode {
    /// Id of the activated entity (the same ids used as `seeds` keys and edge endpoints).
    pub entity_id: i64,
    /// Final activation score. Seeds start at the match score they were given (not clamped);
    /// the sum of contributions a node receives within one hop is capped at 1.0.
    pub activation: f32,
    /// Hop at which the node reached its current (highest) score; seeds start at 0.
    pub depth: u32,
}
35
/// A graph edge ("fact") whose two endpoints were both at or above the activation threshold
/// after a propagation hop.
#[derive(Debug, Clone)]
pub struct ActivatedFact {
    /// The edge as returned by the graph store.
    pub edge: Edge,
    /// The larger of the two endpoint activations at the time the edge was reported.
    pub activation_score: f32,
}
44
/// Tuning knobs for [`SpreadingActivation`].
#[derive(Debug, Clone)]
pub struct SpreadingActivationParams {
    /// Multiplicative decay applied to every value pushed across an edge.
    pub decay_lambda: f32,
    /// Maximum number of propagation rounds.
    pub max_hops: u32,
    /// Minimum score for a seed to be used, for a propagated value to be kept, and for a
    /// node (or both endpoints of a fact) to be reported.
    pub activation_threshold: f32,
    /// A neighbour whose current or pending score has reached this value receives no further
    /// activation (lateral inhibition).
    pub inhibition_threshold: f32,
    /// After each hop only this many highest-scoring nodes are kept.
    pub max_activated_nodes: usize,
    /// Per-day rate of the hyperbolic recency weight `1 / (1 + age_days * rate)` applied to
    /// edges; values `<= 0` disable temporal decay.
    pub temporal_decay_rate: f64,
    /// NOTE(review): not read by the code in this module; presumably consumed by seed
    /// selection elsewhere — confirm against callers.
    pub seed_structural_weight: f32,
    /// NOTE(review): not read by the code in this module; presumably a per-community cap used
    /// during seed selection elsewhere — confirm against callers.
    pub seed_community_cap: usize,
}
60
61pub struct SpreadingActivation {
63 params: SpreadingActivationParams,
64}
65
66impl SpreadingActivation {
67 #[must_use]
72 pub fn new(params: SpreadingActivationParams) -> Self {
73 Self { params }
74 }
75
76 #[allow(clippy::too_many_lines)]
92 pub async fn spread(
93 &self,
94 store: &GraphStore,
95 seeds: HashMap<i64, f32>,
96 edge_types: &[EdgeType],
97 ) -> Result<(Vec<ActivatedNode>, Vec<ActivatedFact>), MemoryError> {
98 if seeds.is_empty() {
99 return Ok((Vec::new(), Vec::new()));
100 }
101
102 let now_secs: i64 = SystemTime::now()
105 .duration_since(UNIX_EPOCH)
106 .map(|d| d.as_secs().cast_signed())
107 .unwrap_or(0);
108
109 let mut activation: HashMap<i64, (f32, u32)> = HashMap::new();
111
112 let mut seed_count = 0usize;
115 for (entity_id, match_score) in &seeds {
116 if *match_score < self.params.activation_threshold {
117 tracing::debug!(
118 entity_id,
119 score = match_score,
120 threshold = self.params.activation_threshold,
121 "spreading activation: seed below threshold, skipping"
122 );
123 continue;
124 }
125 activation.insert(*entity_id, (*match_score, 0));
126 seed_count += 1;
127 }
128
129 tracing::debug!(
130 seeds = seed_count,
131 "spreading activation: initialized seeds"
132 );
133
134 let mut activated_facts: Vec<ActivatedFact> = Vec::new();
136
137 for hop in 0..self.params.max_hops {
139 let active_nodes: Vec<(i64, f32)> = activation
141 .iter()
142 .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
143 .map(|(&id, &(score, _))| (id, score))
144 .collect();
145
146 if active_nodes.is_empty() {
147 break;
148 }
149
150 let node_ids: Vec<i64> = active_nodes.iter().map(|(id, _)| *id).collect();
151
152 let edges = store.edges_for_entities(&node_ids, edge_types).await?;
154 let edge_count = edges.len();
155
156 let mut next_activation: HashMap<i64, (f32, u32)> = HashMap::new();
157
158 for edge in &edges {
159 for &(active_id, node_score) in &active_nodes {
162 let neighbor = if edge.source_entity_id == active_id {
163 edge.target_entity_id
164 } else if edge.target_entity_id == active_id {
165 edge.source_entity_id
166 } else {
167 continue;
168 };
169
170 let current_score = activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
175 let next_score = next_activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
176 if current_score >= self.params.inhibition_threshold
177 || next_score >= self.params.inhibition_threshold
178 {
179 continue;
180 }
181
182 let recency = self.recency_weight(&edge.valid_from, now_secs);
183 let edge_weight = evolved_weight(edge.retrieval_count, edge.confidence);
184 let type_w = edge_type_weight(edge.edge_type);
185 let spread_value =
186 node_score * self.params.decay_lambda * edge_weight * recency * type_w;
187
188 if spread_value < self.params.activation_threshold {
189 continue;
190 }
191
192 let depth_at_max = hop + 1;
196 let entry = next_activation
197 .entry(neighbor)
198 .or_insert((0.0, depth_at_max));
199 let new_score = (entry.0 + spread_value).min(1.0);
200 if new_score > entry.0 {
201 entry.0 = new_score;
202 entry.1 = depth_at_max;
203 }
204 }
205 }
206
207 for (node_id, (new_score, new_depth)) in next_activation {
209 let entry = activation.entry(node_id).or_insert((0.0, new_depth));
210 if new_score > entry.0 {
211 entry.0 = new_score;
212 entry.1 = new_depth;
213 }
214 }
215
216 let pruned_count = if activation.len() > self.params.max_activated_nodes {
219 let before = activation.len();
220 let mut entries: Vec<(i64, (f32, u32))> = activation.drain().collect();
221 entries.sort_by(|(_, (a, _)), (_, (b, _))| b.total_cmp(a));
222 entries.truncate(self.params.max_activated_nodes);
223 activation = entries.into_iter().collect();
224 before - self.params.max_activated_nodes
225 } else {
226 0
227 };
228
229 tracing::debug!(
230 hop,
231 active_nodes = active_nodes.len(),
232 edges_fetched = edge_count,
233 after_merge = activation.len(),
234 pruned = pruned_count,
235 "spreading activation: hop complete"
236 );
237
238 for edge in edges {
240 let src_score = activation
242 .get(&edge.source_entity_id)
243 .map_or(0.0, |&(s, _)| s);
244 let tgt_score = activation
245 .get(&edge.target_entity_id)
246 .map_or(0.0, |&(s, _)| s);
247 if src_score >= self.params.activation_threshold
248 && tgt_score >= self.params.activation_threshold
249 {
250 let activation_score = src_score.max(tgt_score);
251 activated_facts.push(ActivatedFact {
252 edge,
253 activation_score,
254 });
255 }
256 }
257 }
258
259 let mut result: Vec<ActivatedNode> = activation
261 .into_iter()
262 .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
263 .map(|(entity_id, (activation, depth))| ActivatedNode {
264 entity_id,
265 activation,
266 depth,
267 })
268 .collect();
269 result.sort_by(|a, b| b.activation.total_cmp(&a.activation));
270
271 tracing::info!(
272 activated = result.len(),
273 facts = activated_facts.len(),
274 "spreading activation: complete"
275 );
276
277 Ok((result, activated_facts))
278 }
279
280 #[allow(clippy::cast_precision_loss)]
286 fn recency_weight(&self, valid_from: &str, now_secs: i64) -> f32 {
287 if self.params.temporal_decay_rate <= 0.0 {
288 return 1.0;
289 }
290 let Some(valid_from_secs) = parse_sqlite_datetime_to_unix(valid_from) else {
291 return 1.0;
292 };
293 let age_secs = (now_secs - valid_from_secs).max(0);
294 let age_days = age_secs as f64 / 86_400.0;
295 let weight = 1.0_f64 / (1.0 + age_days * self.params.temporal_decay_rate);
296 #[allow(clippy::cast_possible_truncation)]
298 let w = weight as f32;
299 w
300 }
301}
302
/// Parses a SQLite `datetime()`-style timestamp (`YYYY-MM-DD HH:MM:SS`) into Unix seconds.
///
/// Only the first 19 bytes are inspected. The separator characters are not checked, so ISO
/// forms such as `YYYY-MM-DDTHH:MM:SS` also parse. Anything after the seconds field (fractional
/// seconds, zone suffix) is ignored, and the value is interpreted as UTC.
///
/// Returns `None` in any of these cases:
/// - the string is too short;
/// - a field is not a number or lies outside its calendar range;
/// - a field boundary would split a multi-byte UTF-8 character.
#[must_use]
fn parse_sqlite_datetime_to_unix(s: &str) -> Option<i64> {
    // `str::get` returns `None` instead of panicking when a bound is out of range or falls
    // inside a multi-byte character (plain `s[a..b]` indexing would panic there).
    let field = |range: std::ops::Range<usize>| -> Option<i64> { s.get(range)?.parse().ok() };

    let year = field(0..4)?;
    let month = field(5..7)?;
    let day = field(8..10)?;
    let hour = field(11..13)?;
    let min = field(14..16)?;
    let sec = field(17..19)?;

    // Reject out-of-range fields instead of folding them into a bogus timestamp
    // (e.g. "-1" parses as a number). Second 60 is tolerated for leap seconds.
    if !(1..=12).contains(&month)
        || !(1..=31).contains(&day)
        || !(0..=23).contains(&hour)
        || !(0..=59).contains(&min)
        || !(0..=60).contains(&sec)
    {
        return None;
    }

    // Days-from-civil (Howard Hinnant): shift the year to start in March so the leap day
    // falls at the end, then count days in 400-year eras relative to 1970-01-01.
    let (y, m) = if month <= 2 {
        (year - 1, month + 9)
    } else {
        (year, month - 3)
    };
    let era = y.div_euclid(400);
    let yoe = y - era * 400;
    let doy = (153 * m + 2) / 5 + day - 1;
    let doe = yoe * 365 + yoe / 4 - yoe / 100 + doy;
    let days = era * 146_097 + doe - 719_468;

    Some(days * 86_400 + hour * 3_600 + min * 60 + sec)
}
334
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphStore;
    use crate::graph::types::EntityType;
    use crate::store::SqliteStore;

    /// Fresh in-memory graph store for a single test.
    async fn setup_store() -> GraphStore {
        let store = SqliteStore::new(":memory:").await.unwrap();
        GraphStore::new(store.pool().clone())
    }

    /// Upserts an entity whose surface and canonical names are both `name`; returns its id.
    async fn entity(store: &GraphStore, name: &str, entity_type: EntityType) -> i64 {
        store
            .upsert_entity(name, name, entity_type, None)
            .await
            .unwrap()
    }

    /// Baseline parameters shared by the tests (temporal decay disabled).
    fn default_params() -> SpreadingActivationParams {
        SpreadingActivationParams {
            decay_lambda: 0.85,
            max_hops: 3,
            activation_threshold: 0.1,
            inhibition_threshold: 0.8,
            max_activated_nodes: 50,
            temporal_decay_rate: 0.0,
            seed_structural_weight: 0.4,
            seed_community_cap: 3,
        }
    }

    /// A seed with no edges is returned as-is and produces no facts.
    #[tokio::test]
    async fn spread_empty_graph_no_edges_no_facts() {
        let store = setup_store().await;
        let sa = SpreadingActivation::new(default_params());
        let seeds = HashMap::from([(1_i64, 1.0_f32)]);
        let (nodes, facts) = sa.spread(&store, seeds, &[]).await.unwrap();
        assert_eq!(nodes.len(), 1, "seed must be in activated nodes");
        assert_eq!(nodes[0].entity_id, 1);
        assert!((nodes[0].activation - 1.0).abs() < 1e-6);
        assert!(
            facts.is_empty(),
            "expected no activated facts on empty graph"
        );
    }

    /// An empty seed map short-circuits to empty results.
    #[tokio::test]
    async fn spread_empty_seeds_returns_empty() {
        let store = setup_store().await;
        let sa = SpreadingActivation::new(default_params());
        let (nodes, facts) = sa.spread(&store, HashMap::new(), &[]).await.unwrap();
        assert!(nodes.is_empty());
        assert!(facts.is_empty());
    }

    /// A real entity without edges is returned at depth 0 with its seed score.
    #[tokio::test]
    async fn spread_single_seed_no_edges_returns_seed() {
        let store = setup_store().await;
        let alice = entity(&store, "Alice", EntityType::Person).await;

        let sa = SpreadingActivation::new(default_params());
        let seeds = HashMap::from([(alice, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].entity_id, alice);
        assert_eq!(nodes[0].depth, 0);
        assert!((nodes[0].activation - 1.0).abs() < 1e-6);
    }

    /// A -> B -> C: every node is reached and activation decays with distance.
    #[tokio::test]
    async fn spread_linear_chain_all_activated_with_decay() {
        let store = setup_store().await;
        let a = entity(&store, "A", EntityType::Person).await;
        let b = entity(&store, "B", EntityType::Person).await;
        let c = entity(&store, "C", EntityType::Person).await;
        store
            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 3;
        cfg.decay_lambda = 0.9;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&a), "A (seed) must be activated");
        assert!(ids.contains(&b), "B (hop 1) must be activated");
        assert!(ids.contains(&c), "C (hop 2) must be activated");

        let score_a = nodes.iter().find(|n| n.entity_id == a).unwrap().activation;
        let score_b = nodes.iter().find(|n| n.entity_id == b).unwrap().activation;
        let score_c = nodes.iter().find(|n| n.entity_id == c).unwrap().activation;
        assert!(
            score_a > score_b,
            "seed A should have higher activation than hop-1 B"
        );
        assert!(
            score_b > score_c,
            "hop-1 B should have higher activation than hop-2 C"
        );
    }

    /// With `max_hops = 1` the chain stops after the first neighbour.
    #[tokio::test]
    async fn spread_linear_chain_max_hops_limits_reach() {
        let store = setup_store().await;
        let a = entity(&store, "A", EntityType::Person).await;
        let b = entity(&store, "B", EntityType::Person).await;
        let c = entity(&store, "C", EntityType::Person).await;
        store
            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 1;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&a), "A must be activated (seed)");
        assert!(ids.contains(&b), "B must be activated (hop 1)");
        assert!(!ids.contains(&c), "C must NOT be activated with max_hops=1");
    }

    /// Diamond A -> {B, C} -> D: D is reached via both paths at depth 2.
    #[tokio::test]
    async fn spread_diamond_graph_convergence() {
        let store = setup_store().await;
        let a = entity(&store, "A", EntityType::Person).await;
        let b = entity(&store, "B", EntityType::Person).await;
        let c = entity(&store, "C", EntityType::Person).await;
        let d = entity(&store, "D", EntityType::Person).await;
        store
            .insert_edge(a, b, "rel", "A-B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(a, c, "rel", "A-C", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, d, "rel", "B-D", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(c, d, "rel", "C-D", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 3;
        cfg.decay_lambda = 0.9;
        // Raised so the hop-1 nodes (0.9) are not inhibited before they reach D.
        cfg.inhibition_threshold = 0.95;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&d), "D must be activated via diamond paths");

        let node_d = nodes.iter().find(|n| n.entity_id == d).unwrap();
        assert_eq!(node_d.depth, 2, "D should be at depth 2");
    }

    /// A hub with bidirectional edges to many leaves never exceeds activation 1.0.
    #[tokio::test]
    async fn spread_inhibition_prevents_runaway() {
        let store = setup_store().await;
        let hub = entity(&store, "Hub", EntityType::Concept).await;

        for i in 0..5 {
            let leaf = entity(&store, &format!("Leaf{i}"), EntityType::Concept).await;
            store
                .insert_edge(hub, leaf, "has", &format!("Hub has Leaf{i}"), 1.0, None)
                .await
                .unwrap();
            store
                .insert_edge(
                    leaf,
                    hub,
                    "part_of",
                    &format!("Leaf{i} part_of Hub"),
                    1.0,
                    None,
                )
                .await
                .unwrap();
        }

        let mut cfg = default_params();
        cfg.inhibition_threshold = 0.8;
        cfg.max_hops = 3;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(hub, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let hub_node = nodes.iter().find(|n| n.entity_id == hub);
        assert!(hub_node.is_some(), "hub must be in results");
        assert!(
            hub_node.unwrap().activation <= 1.0,
            "activation must not exceed 1.0"
        );
    }

    /// A fan-out of 20 leaves is pruned down to `max_activated_nodes`.
    #[tokio::test]
    async fn spread_max_activated_nodes_cap_enforced() {
        let store = setup_store().await;
        let root = entity(&store, "Root", EntityType::Person).await;

        for i in 0..20 {
            let leaf = entity(&store, &format!("Node{i}"), EntityType::Concept).await;
            store
                .insert_edge(root, leaf, "has", &format!("Root has Node{i}"), 0.9, None)
                .await
                .unwrap();
        }

        let max_nodes = 5;
        let cfg = SpreadingActivationParams {
            max_activated_nodes: max_nodes,
            max_hops: 2,
            ..default_params()
        };
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(root, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        assert!(
            nodes.len() <= max_nodes,
            "activation must be capped at {max_nodes} nodes, got {}",
            nodes.len()
        );
    }

    /// With temporal decay enabled, an edge dated 1970 transmits less than a fresh one.
    #[tokio::test]
    async fn spread_temporal_decay_recency_effect() {
        let store = setup_store().await;
        let src = entity(&store, "Src", EntityType::Person).await;
        let recent = entity(&store, "Recent", EntityType::Tool).await;
        let old = entity(&store, "Old", EntityType::Tool).await;

        store
            .insert_edge(src, recent, "uses", "Src uses Recent", 1.0, None)
            .await
            .unwrap();

        // Inserted directly to control `valid_from`.
        zeph_db::query(
            sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from)
             VALUES (?1, ?2, 'uses', 'Src uses Old', 1.0, '1970-01-01 00:00:00')"),
        )
        .bind(src)
        .bind(old)
        .execute(store.pool())
        .await
        .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 2;
        let sa = SpreadingActivation::new(SpreadingActivationParams {
            temporal_decay_rate: 0.5,
            ..cfg
        });
        let seeds = HashMap::from([(src, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let score_recent = nodes
            .iter()
            .find(|n| n.entity_id == recent)
            .map_or(0.0, |n| n.activation);
        let score_old = nodes
            .iter()
            .find(|n| n.entity_id == old)
            .map_or(0.0, |n| n.activation);

        assert!(
            score_recent > score_old,
            "recent edge ({score_recent}) must produce higher activation than old edge ({score_old})"
        );
    }

    /// Restricting `edge_types` to semantic edges ignores the causal edge.
    #[tokio::test]
    async fn spread_edge_type_filter_excludes_other_types() {
        let store = setup_store().await;
        let a = entity(&store, "A", EntityType::Person).await;
        let b_semantic = entity(&store, "BSemantic", EntityType::Tool).await;
        let c_causal = entity(&store, "CCausal", EntityType::Concept).await;

        store
            .insert_edge(a, b_semantic, "uses", "A uses BSemantic", 1.0, None)
            .await
            .unwrap();

        // Inserted directly to set a non-default `edge_type`.
        zeph_db::query(
            sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from, edge_type)
             VALUES (?1, ?2, 'caused', 'A caused CCausal', 1.0, datetime('now'), 'causal')"),
        )
        .bind(a)
        .bind(c_causal)
        .execute(store.pool())
        .await
        .unwrap();

        let cfg = default_params();
        let sa = SpreadingActivation::new(cfg);

        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa
            .spread(&store, seeds, &[EdgeType::Semantic])
            .await
            .unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(
            ids.contains(&b_semantic),
            "BSemantic must be activated via semantic edge"
        );
        assert!(
            !ids.contains(&c_causal),
            "CCausal must NOT be activated when filtering to semantic only"
        );
    }

    /// A hundred edge-less seeds are handled without error.
    #[tokio::test]
    async fn spread_large_seed_list() {
        let store = setup_store().await;
        let mut seeds = HashMap::new();

        for i in 0..100i64 {
            // Surface and canonical names intentionally differ in case here.
            let id = store
                .upsert_entity(
                    &format!("Entity{i}"),
                    &format!("entity{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap();
            seeds.insert(id, 1.0_f32);
        }

        let cfg = default_params();
        let sa = SpreadingActivation::new(cfg);
        let result = sa.spread(&store, seeds, &[]).await;
        assert!(
            result.is_ok(),
            "large seed list must not error: {:?}",
            result.err()
        );
    }

    /// One edge between the seed and an activated neighbour yields exactly one fact,
    /// scored with the larger endpoint activation.
    #[tokio::test]
    async fn spread_single_edge_reports_one_fact() {
        let store = setup_store().await;
        let a = entity(&store, "A", EntityType::Person).await;
        let b = entity(&store, "B", EntityType::Person).await;
        store
            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 1;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (_, facts) = sa.spread(&store, seeds, &[]).await.unwrap();

        assert_eq!(facts.len(), 1, "exactly one fact expected for one edge");
        assert_eq!(facts[0].edge.source_entity_id, a);
        assert_eq!(facts[0].edge.target_entity_id, b);
        assert!(
            (facts[0].activation_score - 1.0).abs() < 1e-6,
            "fact score must be the seed's activation (the larger endpoint)"
        );
    }

    /// Known timestamps convert to the expected Unix seconds; short input is rejected.
    #[test]
    fn parse_sqlite_datetime_known_values() {
        assert_eq!(parse_sqlite_datetime_to_unix("1970-01-01 00:00:00"), Some(0));
        assert_eq!(
            parse_sqlite_datetime_to_unix("2000-03-01 00:00:00"),
            Some(951_868_800)
        );
        assert_eq!(
            parse_sqlite_datetime_to_unix("2024-01-01 12:00:00"),
            Some(1_704_110_400)
        );
        assert_eq!(parse_sqlite_datetime_to_unix("2024-01-01"), None);
    }
}