1use std::collections::HashMap;
17use std::time::{SystemTime, UNIX_EPOCH};
18#[allow(unused_imports)]
19use zeph_db::sql;
20
21use crate::error::MemoryError;
22use crate::graph::store::GraphStore;
23use crate::graph::types::{Edge, EdgeType, edge_type_weight, evolved_weight};
24
/// An entity reached by spreading activation, together with its score.
#[derive(Debug, Clone)]
pub struct ActivatedNode {
    /// Identifier of the activated entity.
    pub entity_id: i64,
    /// Activation score held by the entity when propagation finished.
    pub activation: f32,
    /// Hop count recorded when the score was last raised: `0` for a seed that
    /// was never overwritten, `hop + 1` for a score assigned during propagation.
    pub depth: u32,
}
35
36#[derive(Debug, Clone)]
38pub struct ActivatedFact {
39 pub edge: Edge,
41 pub activation_score: f32,
43}
44
/// Tunable parameters for [`SpreadingActivation`].
#[derive(Debug, Clone)]
pub struct SpreadingActivationParams {
    /// Multiplier applied to a node's score each time it spreads across an edge.
    pub decay_lambda: f32,
    /// Maximum number of propagation rounds.
    pub max_hops: u32,
    /// Minimum score for a seed to be accepted, for a spread contribution to be
    /// kept, and for a node to be reported in the result.
    pub activation_threshold: f32,
    /// A neighbor whose current or pending score has reached this value receives
    /// no further activation during a hop.
    pub inhibition_threshold: f32,
    /// Upper bound on tracked nodes; after each hop only the highest-scoring
    /// entries up to this count are kept.
    pub max_activated_nodes: usize,
    /// Rate used in the recency weight `1 / (1 + age_days * rate)`; values
    /// `<= 0.0` disable temporal decay.
    pub temporal_decay_rate: f64,
    /// NOTE(review): not read by the code in this file; presumably used when
    /// seeds are selected elsewhere. Confirm against callers.
    pub seed_structural_weight: f32,
    /// NOTE(review): not read by the code in this file; presumably caps seeds
    /// per community during seed selection. Confirm against callers.
    pub seed_community_cap: usize,
}
60
61pub struct SpreadingActivation {
63 params: SpreadingActivationParams,
64}
65
66impl SpreadingActivation {
67 #[must_use]
72 pub fn new(params: SpreadingActivationParams) -> Self {
73 Self { params }
74 }
75
76 #[allow(clippy::too_many_lines)]
92 pub async fn spread(
93 &self,
94 store: &GraphStore,
95 seeds: HashMap<i64, f32>,
96 edge_types: &[EdgeType],
97 ) -> Result<(Vec<ActivatedNode>, Vec<ActivatedFact>), MemoryError> {
98 if seeds.is_empty() {
99 return Ok((Vec::new(), Vec::new()));
100 }
101
102 let now_secs: i64 = SystemTime::now()
105 .duration_since(UNIX_EPOCH)
106 .map_or(0, |d| d.as_secs().cast_signed());
107
108 let mut activation: HashMap<i64, (f32, u32)> = HashMap::new();
110
111 let mut seed_count = 0usize;
114 for (entity_id, match_score) in &seeds {
115 if *match_score < self.params.activation_threshold {
116 tracing::debug!(
117 entity_id,
118 score = match_score,
119 threshold = self.params.activation_threshold,
120 "spreading activation: seed below threshold, skipping"
121 );
122 continue;
123 }
124 activation.insert(*entity_id, (*match_score, 0));
125 seed_count += 1;
126 }
127
128 tracing::debug!(
129 seeds = seed_count,
130 "spreading activation: initialized seeds"
131 );
132
133 let mut activated_facts: Vec<ActivatedFact> = Vec::new();
135
136 for hop in 0..self.params.max_hops {
138 let active_nodes: Vec<(i64, f32)> = activation
140 .iter()
141 .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
142 .map(|(&id, &(score, _))| (id, score))
143 .collect();
144
145 if active_nodes.is_empty() {
146 break;
147 }
148
149 let node_ids: Vec<i64> = active_nodes.iter().map(|(id, _)| *id).collect();
150
151 let edges = store.edges_for_entities(&node_ids, edge_types).await?;
153 let edge_count = edges.len();
154
155 let mut next_activation: HashMap<i64, (f32, u32)> = HashMap::new();
156
157 for edge in &edges {
158 for &(active_id, node_score) in &active_nodes {
161 let neighbor = if edge.source_entity_id == active_id {
162 edge.target_entity_id
163 } else if edge.target_entity_id == active_id {
164 edge.source_entity_id
165 } else {
166 continue;
167 };
168
169 let current_score = activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
174 let next_score = next_activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
175 if current_score >= self.params.inhibition_threshold
176 || next_score >= self.params.inhibition_threshold
177 {
178 continue;
179 }
180
181 let recency = self.recency_weight(&edge.valid_from, now_secs);
182 let edge_weight = evolved_weight(edge.retrieval_count, edge.confidence);
183 let type_w = edge_type_weight(edge.edge_type);
184 let spread_value =
185 node_score * self.params.decay_lambda * edge_weight * recency * type_w;
186
187 if spread_value < self.params.activation_threshold {
188 continue;
189 }
190
191 let depth_at_max = hop + 1;
195 let entry = next_activation
196 .entry(neighbor)
197 .or_insert((0.0, depth_at_max));
198 let new_score = (entry.0 + spread_value).min(1.0);
199 if new_score > entry.0 {
200 entry.0 = new_score;
201 entry.1 = depth_at_max;
202 }
203 }
204 }
205
206 for (node_id, (new_score, new_depth)) in next_activation {
208 let entry = activation.entry(node_id).or_insert((0.0, new_depth));
209 if new_score > entry.0 {
210 entry.0 = new_score;
211 entry.1 = new_depth;
212 }
213 }
214
215 let pruned_count = if activation.len() > self.params.max_activated_nodes {
218 let before = activation.len();
219 let mut entries: Vec<(i64, (f32, u32))> = activation.drain().collect();
220 entries.sort_by(|(_, (a, _)), (_, (b, _))| b.total_cmp(a));
221 entries.truncate(self.params.max_activated_nodes);
222 activation = entries.into_iter().collect();
223 before - self.params.max_activated_nodes
224 } else {
225 0
226 };
227
228 tracing::debug!(
229 hop,
230 active_nodes = active_nodes.len(),
231 edges_fetched = edge_count,
232 after_merge = activation.len(),
233 pruned = pruned_count,
234 "spreading activation: hop complete"
235 );
236
237 for edge in edges {
239 let src_score = activation
241 .get(&edge.source_entity_id)
242 .map_or(0.0, |&(s, _)| s);
243 let tgt_score = activation
244 .get(&edge.target_entity_id)
245 .map_or(0.0, |&(s, _)| s);
246 if src_score >= self.params.activation_threshold
247 && tgt_score >= self.params.activation_threshold
248 {
249 let activation_score = src_score.max(tgt_score);
250 activated_facts.push(ActivatedFact {
251 edge,
252 activation_score,
253 });
254 }
255 }
256 }
257
258 let mut result: Vec<ActivatedNode> = activation
260 .into_iter()
261 .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
262 .map(|(entity_id, (activation, depth))| ActivatedNode {
263 entity_id,
264 activation,
265 depth,
266 })
267 .collect();
268 result.sort_by(|a, b| b.activation.total_cmp(&a.activation));
269
270 tracing::info!(
271 activated = result.len(),
272 facts = activated_facts.len(),
273 "spreading activation: complete"
274 );
275
276 Ok((result, activated_facts))
277 }
278
279 #[allow(clippy::cast_precision_loss)]
285 fn recency_weight(&self, valid_from: &str, now_secs: i64) -> f32 {
286 if self.params.temporal_decay_rate <= 0.0 {
287 return 1.0;
288 }
289 let Some(valid_from_secs) = parse_sqlite_datetime_to_unix(valid_from) else {
290 return 1.0;
291 };
292 let age_secs = (now_secs - valid_from_secs).max(0);
293 let age_days = age_secs as f64 / 86_400.0;
294 let weight = 1.0_f64 / (1.0 + age_days * self.params.temporal_decay_rate);
295 #[allow(clippy::cast_possible_truncation)]
297 let w = weight as f32;
298 w
299 }
300}
301
/// Parses a `YYYY-MM-DD HH:MM:SS` timestamp (interpreted as UTC) into seconds
/// since the Unix epoch.
///
/// Only the first 19 bytes are examined, so a trailing fractional part or
/// suffix is ignored. The separator characters themselves are not checked.
///
/// Returns `None` when the input is shorter than 19 bytes, when a field is not
/// made of ASCII digits, or when a field is out of range for a calendar date
/// or a time of day. Non-ASCII input never panics.
#[must_use]
fn parse_sqlite_datetime_to_unix(s: &str) -> Option<i64> {
    // `str::get` yields `None` for out-of-bounds ranges and for ranges that
    // split a multi-byte character, where plain indexing would panic.
    let field = |range: std::ops::Range<usize>| -> Option<i64> {
        let text = s.get(range)?;
        if text.bytes().all(|b| b.is_ascii_digit()) {
            text.parse().ok()
        } else {
            None
        }
    };

    let year = field(0..4)?;
    let month = field(5..7)?;
    let day = field(8..10)?;
    let hour = field(11..13)?;
    let min = field(14..16)?;
    let sec = field(17..19)?;

    let is_leap = year % 4 == 0 && (year % 100 != 0 || year % 400 == 0);
    let days_in_month = match month {
        1 | 3 | 5 | 7 | 8 | 10 | 12 => 31,
        4 | 6 | 9 | 11 => 30,
        2 if is_leap => 29,
        2 => 28,
        _ => return None,
    };
    if day < 1 || day > days_in_month || hour > 23 || min > 59 || sec > 59 {
        return None;
    }

    // Days-from-civil: shift the year so it starts in March, which puts the
    // leap day at the end of the shifted year.
    let (y, m) = if month <= 2 {
        (year - 1, month + 9)
    } else {
        (year, month - 3)
    };
    let era = y.div_euclid(400);
    let yoe = y - era * 400;
    let doy = (153 * m + 2) / 5 + day - 1;
    let doe = yoe * 365 + yoe / 4 - yoe / 100 + doy;
    // 719_468 is the day offset between 0000-03-01 and 1970-01-01.
    let days = era * 146_097 + doe - 719_468;

    Some(days * 86_400 + hour * 3_600 + min * 60 + sec)
}
333
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphStore;
    use crate::graph::types::EntityType;
    use crate::store::SqliteStore;

    /// Builds a graph store on top of a fresh in-memory SQLite database.
    async fn setup_store() -> GraphStore {
        let store = SqliteStore::new(":memory:").await.unwrap();
        GraphStore::new(store.pool().clone())
    }

    /// Baseline parameters shared by the tests; temporal decay is disabled.
    fn default_params() -> SpreadingActivationParams {
        SpreadingActivationParams {
            decay_lambda: 0.85,
            max_hops: 3,
            activation_threshold: 0.1,
            inhibition_threshold: 0.8,
            max_activated_nodes: 50,
            temporal_decay_rate: 0.0,
            seed_structural_weight: 0.4,
            seed_community_cap: 3,
        }
    }

    // A seed with no edges is still reported, and no facts are produced.
    #[tokio::test]
    async fn spread_empty_graph_no_edges_no_facts() {
        let store = setup_store().await;
        let sa = SpreadingActivation::new(default_params());
        let seeds = HashMap::from([(1_i64, 1.0_f32)]);
        let (nodes, facts) = sa.spread(&store, seeds, &[]).await.unwrap();
        assert_eq!(nodes.len(), 1, "seed must be in activated nodes");
        assert_eq!(nodes[0].entity_id, 1);
        assert!((nodes[0].activation - 1.0).abs() < 1e-6);
        assert!(
            facts.is_empty(),
            "expected no activated facts on empty graph"
        );
    }

    // An empty seed map short-circuits to empty results.
    #[tokio::test]
    async fn spread_empty_seeds_returns_empty() {
        let store = setup_store().await;
        let sa = SpreadingActivation::new(default_params());
        let (nodes, facts) = sa.spread(&store, HashMap::new(), &[]).await.unwrap();
        assert!(nodes.is_empty());
        assert!(facts.is_empty());
    }

    // A lone seed entity comes back at depth 0 with its original score.
    #[tokio::test]
    async fn spread_single_seed_no_edges_returns_seed() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();

        let sa = SpreadingActivation::new(default_params());
        let seeds = HashMap::from([(alice, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].entity_id, alice);
        assert_eq!(nodes[0].depth, 0);
        assert!((nodes[0].activation - 1.0).abs() < 1e-6);
    }

    // Chain A -> B -> C: every node is reached and scores decay per hop.
    #[tokio::test]
    async fn spread_linear_chain_all_activated_with_decay() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 3;
        cfg.decay_lambda = 0.9;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&a), "A (seed) must be activated");
        assert!(ids.contains(&b), "B (hop 1) must be activated");
        assert!(ids.contains(&c), "C (hop 2) must be activated");

        let score_a = nodes.iter().find(|n| n.entity_id == a).unwrap().activation;
        let score_b = nodes.iter().find(|n| n.entity_id == b).unwrap().activation;
        let score_c = nodes.iter().find(|n| n.entity_id == c).unwrap().activation;
        assert!(
            score_a > score_b,
            "seed A should have higher activation than hop-1 B"
        );
        assert!(
            score_b > score_c,
            "hop-1 B should have higher activation than hop-2 C"
        );
    }

    // With max_hops = 1 the second link of the chain is never followed.
    #[tokio::test]
    async fn spread_linear_chain_max_hops_limits_reach() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 1;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&a), "A must be activated (seed)");
        assert!(ids.contains(&b), "B must be activated (hop 1)");
        assert!(!ids.contains(&c), "C must NOT be activated with max_hops=1");
    }

    // Diamond A -> {B, C} -> D: both paths converge on D at depth 2.
    #[tokio::test]
    async fn spread_diamond_graph_convergence() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap();
        let d = store
            .upsert_entity("D", "D", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "rel", "A-B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(a, c, "rel", "A-C", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, d, "rel", "B-D", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(c, d, "rel", "C-D", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 3;
        cfg.decay_lambda = 0.9;
        cfg.inhibition_threshold = 0.95;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&d), "D must be activated via diamond paths");

        let node_d = nodes.iter().find(|n| n.entity_id == d).unwrap();
        assert_eq!(node_d.depth, 2, "D should be at depth 2");
    }

    // Hub with bidirectional edges to five leaves: activation stays bounded.
    #[tokio::test]
    async fn spread_inhibition_prevents_runaway() {
        let store = setup_store().await;
        let hub = store
            .upsert_entity("Hub", "Hub", EntityType::Concept, None)
            .await
            .unwrap();

        for i in 0..5 {
            let leaf = store
                .upsert_entity(
                    &format!("Leaf{i}"),
                    &format!("Leaf{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap();
            store
                .insert_edge(hub, leaf, "has", &format!("Hub has Leaf{i}"), 1.0, None)
                .await
                .unwrap();
            store
                .insert_edge(
                    leaf,
                    hub,
                    "part_of",
                    &format!("Leaf{i} part_of Hub"),
                    1.0,
                    None,
                )
                .await
                .unwrap();
        }

        let mut cfg = default_params();
        cfg.inhibition_threshold = 0.8;
        cfg.max_hops = 3;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(hub, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let hub_node = nodes.iter().find(|n| n.entity_id == hub);
        assert!(hub_node.is_some(), "hub must be in results");
        assert!(
            hub_node.unwrap().activation <= 1.0,
            "activation must not exceed 1.0"
        );
    }

    // Twenty reachable leaves but a cap of five: the result honors the cap.
    #[tokio::test]
    async fn spread_max_activated_nodes_cap_enforced() {
        let store = setup_store().await;
        let root = store
            .upsert_entity("Root", "Root", EntityType::Person, None)
            .await
            .unwrap();

        for i in 0..20 {
            let leaf = store
                .upsert_entity(
                    &format!("Node{i}"),
                    &format!("Node{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap();
            store
                .insert_edge(root, leaf, "has", &format!("Root has Node{i}"), 0.9, None)
                .await
                .unwrap();
        }

        let max_nodes = 5;
        let cfg = SpreadingActivationParams {
            max_activated_nodes: max_nodes,
            max_hops: 2,
            ..default_params()
        };
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(root, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        assert!(
            nodes.len() <= max_nodes,
            "activation must be capped at {max_nodes} nodes, got {}",
            nodes.len()
        );
    }

    // With temporal decay on, an edge dated 1970 contributes less than a fresh one.
    #[tokio::test]
    async fn spread_temporal_decay_recency_effect() {
        let store = setup_store().await;
        let src = store
            .upsert_entity("Src", "Src", EntityType::Person, None)
            .await
            .unwrap();
        let recent = store
            .upsert_entity("Recent", "Recent", EntityType::Tool, None)
            .await
            .unwrap();
        let old = store
            .upsert_entity("Old", "Old", EntityType::Tool, None)
            .await
            .unwrap();

        store
            .insert_edge(src, recent, "uses", "Src uses Recent", 1.0, None)
            .await
            .unwrap();

        // Inserted with raw SQL so that `valid_from` can be set explicitly.
        zeph_db::query(
            sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from)
                  VALUES (?1, ?2, 'uses', 'Src uses Old', 1.0, '1970-01-01 00:00:00')"),
        )
        .bind(src)
        .bind(old)
        .execute(store.pool())
        .await
        .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 2;
        let sa = SpreadingActivation::new(SpreadingActivationParams {
            temporal_decay_rate: 0.5,
            ..cfg
        });
        let seeds = HashMap::from([(src, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let score_recent = nodes
            .iter()
            .find(|n| n.entity_id == recent)
            .map_or(0.0, |n| n.activation);
        let score_old = nodes
            .iter()
            .find(|n| n.entity_id == old)
            .map_or(0.0, |n| n.activation);

        assert!(
            score_recent > score_old,
            "recent edge ({score_recent}) must produce higher activation than old edge ({score_old})"
        );
    }

    // Restricting to semantic edges must keep the causal neighbor out.
    #[tokio::test]
    async fn spread_edge_type_filter_excludes_other_types() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b_semantic = store
            .upsert_entity("BSemantic", "BSemantic", EntityType::Tool, None)
            .await
            .unwrap();
        let c_causal = store
            .upsert_entity("CCausal", "CCausal", EntityType::Concept, None)
            .await
            .unwrap();

        store
            .insert_edge(a, b_semantic, "uses", "A uses BSemantic", 1.0, None)
            .await
            .unwrap();

        // Inserted with raw SQL so that `edge_type` can be set explicitly.
        zeph_db::query(
            sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from, edge_type)
                  VALUES (?1, ?2, 'caused', 'A caused CCausal', 1.0, datetime('now'), 'causal')"),
        )
        .bind(a)
        .bind(c_causal)
        .execute(store.pool())
        .await
        .unwrap();

        let cfg = default_params();
        let sa = SpreadingActivation::new(cfg);

        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa
            .spread(&store, seeds, &[EdgeType::Semantic])
            .await
            .unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(
            ids.contains(&b_semantic),
            "BSemantic must be activated via semantic edge"
        );
        assert!(
            !ids.contains(&c_causal),
            "CCausal must NOT be activated when filtering to semantic only"
        );
    }

    // One hundred seeds at once must not make the edge lookup fail.
    #[tokio::test]
    async fn spread_large_seed_list() {
        let store = setup_store().await;
        let mut seeds = HashMap::new();

        for i in 0..100i64 {
            let id = store
                .upsert_entity(
                    &format!("Entity{i}"),
                    &format!("entity{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap();
            seeds.insert(id, 1.0_f32);
        }

        let cfg = default_params();
        let sa = SpreadingActivation::new(cfg);
        let result = sa.spread(&store, seeds, &[]).await;
        assert!(
            result.is_ok(),
            "large seed list must not error: {:?}",
            result.err()
        );
    }
}