1use std::collections::HashMap;
17use std::time::{SystemTime, UNIX_EPOCH};
18
19use crate::error::MemoryError;
20use crate::graph::store::GraphStore;
21use crate::graph::types::{Edge, EdgeType};
22
/// A graph entity reached (or seeded) by spreading activation.
#[derive(Debug, Clone)]
pub struct ActivatedNode {
    /// Id of the activated entity.
    pub entity_id: i64,
    /// Activation score retained for this entity. Seed scores are used as given;
    /// scores produced by propagation are capped at `1.0`.
    pub activation: f32,
    /// Hop that assigned the retained activation score (`0` for seeds).
    pub depth: u32,
}
33
/// A graph edge whose two endpoints were both active (score at or above
/// `activation_threshold`) after a hop of spreading activation.
#[derive(Debug, Clone)]
pub struct ActivatedFact {
    /// The activated edge.
    pub edge: Edge,
    /// The larger of the two endpoints' activation scores, as observed when the
    /// edge was collected.
    pub activation_score: f32,
}
42
/// Tuning parameters for [`SpreadingActivation`].
#[derive(Debug, Clone)]
pub struct SpreadingActivationParams {
    /// Multiplier applied to a node's score each time it spreads across one edge.
    pub decay_lambda: f32,
    /// Maximum number of propagation hops.
    pub max_hops: u32,
    /// Minimum score for a seed or node to count as active, and the minimum size of
    /// a single spread contribution for it to be kept.
    pub activation_threshold: f32,
    /// Nodes whose score is at or above this value receive no further activation.
    pub inhibition_threshold: f32,
    /// Maximum number of nodes retained after each hop; the highest-scoring survive.
    pub max_activated_nodes: usize,
    /// Per-day rate of the hyperbolic recency weight `1 / (1 + age_days * rate)`
    /// applied to an edge based on its `valid_from`; a value `<= 0` disables it.
    pub temporal_decay_rate: f64,
}
54
55pub struct SpreadingActivation {
57 params: SpreadingActivationParams,
58}
59
60impl SpreadingActivation {
61 #[must_use]
66 pub fn new(params: SpreadingActivationParams) -> Self {
67 Self { params }
68 }
69
70 #[allow(clippy::too_many_lines)]
86 pub async fn spread(
87 &self,
88 store: &GraphStore,
89 seeds: HashMap<i64, f32>,
90 edge_types: &[EdgeType],
91 ) -> Result<(Vec<ActivatedNode>, Vec<ActivatedFact>), MemoryError> {
92 if seeds.is_empty() {
93 return Ok((Vec::new(), Vec::new()));
94 }
95
96 let now_secs: i64 = SystemTime::now()
99 .duration_since(UNIX_EPOCH)
100 .map(|d| d.as_secs().cast_signed())
101 .unwrap_or(0);
102
103 let mut activation: HashMap<i64, (f32, u32)> = HashMap::new();
105
106 let mut seed_count = 0usize;
109 for (entity_id, match_score) in &seeds {
110 if *match_score < self.params.activation_threshold {
111 tracing::debug!(
112 entity_id,
113 score = match_score,
114 threshold = self.params.activation_threshold,
115 "spreading activation: seed below threshold, skipping"
116 );
117 continue;
118 }
119 activation.insert(*entity_id, (*match_score, 0));
120 seed_count += 1;
121 }
122
123 tracing::debug!(
124 seeds = seed_count,
125 "spreading activation: initialized seeds"
126 );
127
128 let mut activated_facts: Vec<ActivatedFact> = Vec::new();
130
131 for hop in 0..self.params.max_hops {
133 let active_nodes: Vec<(i64, f32)> = activation
135 .iter()
136 .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
137 .map(|(&id, &(score, _))| (id, score))
138 .collect();
139
140 if active_nodes.is_empty() {
141 break;
142 }
143
144 let node_ids: Vec<i64> = active_nodes.iter().map(|(id, _)| *id).collect();
145
146 let edges = store.edges_for_entities(&node_ids, edge_types).await?;
148 let edge_count = edges.len();
149
150 let mut next_activation: HashMap<i64, (f32, u32)> = HashMap::new();
151
152 for edge in &edges {
153 for &(active_id, node_score) in &active_nodes {
156 let neighbor = if edge.source_entity_id == active_id {
157 edge.target_entity_id
158 } else if edge.target_entity_id == active_id {
159 edge.source_entity_id
160 } else {
161 continue;
162 };
163
164 let current_score = activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
169 let next_score = next_activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
170 if current_score >= self.params.inhibition_threshold
171 || next_score >= self.params.inhibition_threshold
172 {
173 continue;
174 }
175
176 let recency = self.recency_weight(&edge.valid_from, now_secs);
177 let spread_value =
178 node_score * self.params.decay_lambda * edge.confidence * recency;
179
180 if spread_value < self.params.activation_threshold {
181 continue;
182 }
183
184 let depth_at_max = hop + 1;
188 let entry = next_activation
189 .entry(neighbor)
190 .or_insert((0.0, depth_at_max));
191 let new_score = (entry.0 + spread_value).min(1.0);
192 if new_score > entry.0 {
193 entry.0 = new_score;
194 entry.1 = depth_at_max;
195 }
196 }
197 }
198
199 for (node_id, (new_score, new_depth)) in next_activation {
201 let entry = activation.entry(node_id).or_insert((0.0, new_depth));
202 if new_score > entry.0 {
203 entry.0 = new_score;
204 entry.1 = new_depth;
205 }
206 }
207
208 let pruned_count = if activation.len() > self.params.max_activated_nodes {
211 let before = activation.len();
212 let mut entries: Vec<(i64, (f32, u32))> = activation.drain().collect();
213 entries.sort_by(|(_, (a, _)), (_, (b, _))| b.total_cmp(a));
214 entries.truncate(self.params.max_activated_nodes);
215 activation = entries.into_iter().collect();
216 before - self.params.max_activated_nodes
217 } else {
218 0
219 };
220
221 tracing::debug!(
222 hop,
223 active_nodes = active_nodes.len(),
224 edges_fetched = edge_count,
225 after_merge = activation.len(),
226 pruned = pruned_count,
227 "spreading activation: hop complete"
228 );
229
230 for edge in edges {
232 let src_score = activation
234 .get(&edge.source_entity_id)
235 .map_or(0.0, |&(s, _)| s);
236 let tgt_score = activation
237 .get(&edge.target_entity_id)
238 .map_or(0.0, |&(s, _)| s);
239 if src_score >= self.params.activation_threshold
240 && tgt_score >= self.params.activation_threshold
241 {
242 let activation_score = src_score.max(tgt_score);
243 activated_facts.push(ActivatedFact {
244 edge,
245 activation_score,
246 });
247 }
248 }
249 }
250
251 let mut result: Vec<ActivatedNode> = activation
253 .into_iter()
254 .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
255 .map(|(entity_id, (activation, depth))| ActivatedNode {
256 entity_id,
257 activation,
258 depth,
259 })
260 .collect();
261 result.sort_by(|a, b| b.activation.total_cmp(&a.activation));
262
263 tracing::info!(
264 activated = result.len(),
265 facts = activated_facts.len(),
266 "spreading activation: complete"
267 );
268
269 Ok((result, activated_facts))
270 }
271
272 #[allow(clippy::cast_precision_loss)]
278 fn recency_weight(&self, valid_from: &str, now_secs: i64) -> f32 {
279 if self.params.temporal_decay_rate <= 0.0 {
280 return 1.0;
281 }
282 let Some(valid_from_secs) = parse_sqlite_datetime_to_unix(valid_from) else {
283 return 1.0;
284 };
285 let age_secs = (now_secs - valid_from_secs).max(0);
286 let age_days = age_secs as f64 / 86_400.0;
287 let weight = 1.0_f64 / (1.0 + age_days * self.params.temporal_decay_rate);
288 #[allow(clippy::cast_possible_truncation)]
290 let w = weight as f32;
291 w
292 }
293}
294
/// Parses a SQLite-style UTC timestamp (`YYYY-MM-DD HH:MM:SS`) into Unix seconds.
///
/// Only the first 19 bytes are inspected. The separator characters are not checked,
/// so `YYYY-MM-DDTHH:MM:SS` is accepted too. Anything after the seconds field
/// (fractional seconds, `Z`, a UTC offset) is ignored and no offset is applied.
///
/// Returns `None` in any of these cases:
/// - the string is shorter than 19 bytes;
/// - a numeric field is not made solely of ASCII digits. This also covers ranges that
///   would split a multi-byte UTF-8 character, which previously panicked;
/// - a field is out of range: month 1–12, day 1–31, hour 0–23, minute 0–59,
///   second 0–60.
#[must_use]
fn parse_sqlite_datetime_to_unix(s: &str) -> Option<i64> {
    /// Parses `s[range]` as an unsigned decimal consisting only of ASCII digits.
    fn field(s: &str, range: std::ops::Range<usize>) -> Option<i64> {
        let digits = s.get(range)?;
        if !digits.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }
        digits.parse().ok()
    }

    if s.len() < 19 {
        return None;
    }
    let year = field(s, 0..4)?;
    let month = field(s, 5..7)?;
    let day = field(s, 8..10)?;
    let hour = field(s, 11..13)?;
    let min = field(s, 14..16)?;
    let sec = field(s, 17..19)?;

    if !(1..=12).contains(&month)
        || !(1..=31).contains(&day)
        || hour > 23
        || min > 59
        || sec > 60
    {
        return None;
    }

    // Days-from-civil (Howard Hinnant): shift the year to start in March so the leap
    // day falls at the end, then count days in 400-year eras relative to 1970-01-01.
    let (y, m) = if month <= 2 {
        (year - 1, month + 9)
    } else {
        (year, month - 3)
    };
    let era = y.div_euclid(400);
    let yoe = y - era * 400;
    let doy = (153 * m + 2) / 5 + day - 1;
    let doe = yoe * 365 + yoe / 4 - yoe / 100 + doy;
    let days = era * 146_097 + doe - 719_468;

    Some(days * 86_400 + hour * 3_600 + min * 60 + sec)
}
326
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphStore;
    use crate::graph::types::EntityType;
    use crate::sqlite::SqliteStore;

    async fn setup_store() -> GraphStore {
        let store = SqliteStore::new(":memory:").await.unwrap();
        GraphStore::new(store.pool().clone())
    }

    fn default_params() -> SpreadingActivationParams {
        SpreadingActivationParams {
            decay_lambda: 0.85,
            max_hops: 3,
            activation_threshold: 0.1,
            inhibition_threshold: 0.8,
            max_activated_nodes: 50,
            temporal_decay_rate: 0.0,
        }
    }

    /// The date parser is hand-rolled; pin it against known Unix timestamps.
    #[test]
    fn parse_sqlite_datetime_known_values() {
        assert_eq!(parse_sqlite_datetime_to_unix("1970-01-01 00:00:00"), Some(0));
        assert_eq!(
            parse_sqlite_datetime_to_unix("2000-03-01 00:00:00"),
            Some(951_868_800)
        );
        assert_eq!(
            parse_sqlite_datetime_to_unix("2024-01-01T00:00:00Z"),
            Some(1_704_067_200)
        );
        assert_eq!(
            parse_sqlite_datetime_to_unix("2024-02-29 12:34:56"),
            Some(1_709_210_096)
        );
        assert_eq!(parse_sqlite_datetime_to_unix("2024-01-01"), None);
    }

    /// Recency weight is `1 / (1 + age_days * rate)`, neutral when disabled or unparseable.
    #[test]
    fn recency_weight_hyperbolic_decay() {
        let sa = SpreadingActivation::new(SpreadingActivationParams {
            temporal_decay_rate: 1.0,
            ..default_params()
        });
        // One day old at rate 1.0 -> 1 / (1 + 1).
        assert!((sa.recency_weight("1970-01-01 00:00:00", 86_400) - 0.5).abs() < 1e-6);
        // Future timestamps are clamped to age 0.
        assert!((sa.recency_weight("1970-01-02 00:00:00", 0) - 1.0).abs() < 1e-6);
        // Unparseable timestamps are not penalised.
        assert!((sa.recency_weight("garbage", 86_400) - 1.0).abs() < 1e-6);

        let disabled = SpreadingActivation::new(default_params());
        assert!(
            (disabled.recency_weight("1970-01-01 00:00:00", 86_400 * 365) - 1.0).abs() < 1e-6
        );
    }

    /// A seed with no edges is returned as-is and produces no facts.
    #[tokio::test]
    async fn spread_empty_graph_no_edges_no_facts() {
        let store = setup_store().await;
        let sa = SpreadingActivation::new(default_params());
        let seeds = HashMap::from([(1_i64, 1.0_f32)]);
        let (nodes, facts) = sa.spread(&store, seeds, &[]).await.unwrap();
        assert_eq!(nodes.len(), 1, "seed must be in activated nodes");
        assert_eq!(nodes[0].entity_id, 1);
        assert!((nodes[0].activation - 1.0).abs() < 1e-6);
        assert!(
            facts.is_empty(),
            "expected no activated facts on empty graph"
        );
    }

    /// No seeds short-circuits to empty results.
    #[tokio::test]
    async fn spread_empty_seeds_returns_empty() {
        let store = setup_store().await;
        let sa = SpreadingActivation::new(default_params());
        let (nodes, facts) = sa.spread(&store, HashMap::new(), &[]).await.unwrap();
        assert!(nodes.is_empty());
        assert!(facts.is_empty());
    }

    /// A real entity without edges is returned at depth 0 with its seed score.
    #[tokio::test]
    async fn spread_single_seed_no_edges_returns_seed() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();

        let sa = SpreadingActivation::new(default_params());
        let seeds = HashMap::from([(alice, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].entity_id, alice);
        assert_eq!(nodes[0].depth, 0);
        assert!((nodes[0].activation - 1.0).abs() < 1e-6);
    }

    /// A -> B -> C: all activated, with activation decaying along the chain.
    #[tokio::test]
    async fn spread_linear_chain_all_activated_with_decay() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 3;
        cfg.decay_lambda = 0.9;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, facts) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&a), "A (seed) must be activated");
        assert!(ids.contains(&b), "B (hop 1) must be activated");
        assert!(ids.contains(&c), "C (hop 2) must be activated");

        let score_a = nodes.iter().find(|n| n.entity_id == a).unwrap().activation;
        let score_b = nodes.iter().find(|n| n.entity_id == b).unwrap().activation;
        let score_c = nodes.iter().find(|n| n.entity_id == c).unwrap().activation;
        assert!(
            score_a > score_b,
            "seed A should have higher activation than hop-1 B"
        );
        assert!(
            score_b > score_c,
            "hop-1 B should have higher activation than hop-2 C"
        );

        assert!(
            facts
                .iter()
                .any(|f| f.edge.source_entity_id == a && f.edge.target_entity_id == b),
            "A-B fact must be collected"
        );
        assert!(
            facts
                .iter()
                .any(|f| f.edge.source_entity_id == b && f.edge.target_entity_id == c),
            "B-C fact must be collected"
        );
    }

    /// With `max_hops = 1`, only direct neighbours of the seed are reached.
    #[tokio::test]
    async fn spread_linear_chain_max_hops_limits_reach() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 1;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, facts) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&a), "A must be activated (seed)");
        assert!(ids.contains(&b), "B must be activated (hop 1)");
        assert!(!ids.contains(&c), "C must NOT be activated with max_hops=1");
        assert!(
            facts.iter().all(|f| f.edge.target_entity_id != c),
            "B-C fact must NOT be collected with max_hops=1"
        );
    }

    /// A -> {B, C} -> D: D is reached via both paths and sits at depth 2.
    #[tokio::test]
    async fn spread_diamond_graph_convergence() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap();
        let d = store
            .upsert_entity("D", "D", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "rel", "A-B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(a, c, "rel", "A-C", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, d, "rel", "B-D", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(c, d, "rel", "C-D", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 3;
        cfg.decay_lambda = 0.9;
        cfg.inhibition_threshold = 0.95;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&d), "D must be activated via diamond paths");

        let node_d = nodes.iter().find(|n| n.entity_id == d).unwrap();
        assert_eq!(node_d.depth, 2, "D should be at depth 2");
    }

    /// A hub with bidirectional leaf edges must not push activation above 1.0.
    #[tokio::test]
    async fn spread_inhibition_prevents_runaway() {
        let store = setup_store().await;
        let hub = store
            .upsert_entity("Hub", "Hub", EntityType::Concept, None)
            .await
            .unwrap();

        for i in 0..5 {
            let leaf = store
                .upsert_entity(
                    &format!("Leaf{i}"),
                    &format!("Leaf{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap();
            store
                .insert_edge(hub, leaf, "has", &format!("Hub has Leaf{i}"), 1.0, None)
                .await
                .unwrap();
            store
                .insert_edge(
                    leaf,
                    hub,
                    "part_of",
                    &format!("Leaf{i} part_of Hub"),
                    1.0,
                    None,
                )
                .await
                .unwrap();
        }

        let mut cfg = default_params();
        cfg.inhibition_threshold = 0.8;
        cfg.max_hops = 3;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(hub, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let hub_node = nodes.iter().find(|n| n.entity_id == hub);
        assert!(hub_node.is_some(), "hub must be in results");
        assert!(
            hub_node.unwrap().activation <= 1.0,
            "activation must not exceed 1.0"
        );
    }

    /// The node cap limits how many activated nodes are returned.
    #[tokio::test]
    async fn spread_max_activated_nodes_cap_enforced() {
        let store = setup_store().await;
        let root = store
            .upsert_entity("Root", "Root", EntityType::Person, None)
            .await
            .unwrap();

        for i in 0..20 {
            let leaf = store
                .upsert_entity(
                    &format!("Node{i}"),
                    &format!("Node{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap();
            store
                .insert_edge(root, leaf, "has", &format!("Root has Node{i}"), 0.9, None)
                .await
                .unwrap();
        }

        let max_nodes = 5;
        let cfg = SpreadingActivationParams {
            max_activated_nodes: max_nodes,
            max_hops: 2,
            ..default_params()
        };
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(root, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        assert!(
            nodes.len() <= max_nodes,
            "activation must be capped at {max_nodes} nodes, got {}",
            nodes.len()
        );
    }

    /// With temporal decay enabled, a recent edge outweighs an ancient one.
    #[tokio::test]
    async fn spread_temporal_decay_recency_effect() {
        let store = setup_store().await;
        let src = store
            .upsert_entity("Src", "Src", EntityType::Person, None)
            .await
            .unwrap();
        let recent = store
            .upsert_entity("Recent", "Recent", EntityType::Tool, None)
            .await
            .unwrap();
        let old = store
            .upsert_entity("Old", "Old", EntityType::Tool, None)
            .await
            .unwrap();

        store
            .insert_edge(src, recent, "uses", "Src uses Recent", 1.0, None)
            .await
            .unwrap();

        sqlx::query(
            "INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from)
             VALUES (?1, ?2, 'uses', 'Src uses Old', 1.0, '1970-01-01 00:00:00')",
        )
        .bind(src)
        .bind(old)
        .execute(store.pool())
        .await
        .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 2;
        let sa = SpreadingActivation::new(SpreadingActivationParams {
            temporal_decay_rate: 0.5,
            ..cfg
        });
        let seeds = HashMap::from([(src, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let score_recent = nodes
            .iter()
            .find(|n| n.entity_id == recent)
            .map_or(0.0, |n| n.activation);
        let score_old = nodes
            .iter()
            .find(|n| n.entity_id == old)
            .map_or(0.0, |n| n.activation);

        assert!(
            score_recent > score_old,
            "recent edge ({score_recent}) must produce higher activation than old edge ({score_old})"
        );
    }

    /// Filtering on `EdgeType::Semantic` excludes neighbours reached via causal edges.
    #[tokio::test]
    async fn spread_edge_type_filter_excludes_other_types() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b_semantic = store
            .upsert_entity("BSemantic", "BSemantic", EntityType::Tool, None)
            .await
            .unwrap();
        let c_causal = store
            .upsert_entity("CCausal", "CCausal", EntityType::Concept, None)
            .await
            .unwrap();

        store
            .insert_edge(a, b_semantic, "uses", "A uses BSemantic", 1.0, None)
            .await
            .unwrap();

        sqlx::query(
            "INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from, edge_type)
             VALUES (?1, ?2, 'caused', 'A caused CCausal', 1.0, datetime('now'), 'causal')",
        )
        .bind(a)
        .bind(c_causal)
        .execute(store.pool())
        .await
        .unwrap();

        let cfg = default_params();
        let sa = SpreadingActivation::new(cfg);

        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa
            .spread(&store, seeds, &[EdgeType::Semantic])
            .await
            .unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(
            ids.contains(&b_semantic),
            "BSemantic must be activated via semantic edge"
        );
        assert!(
            !ids.contains(&c_causal),
            "CCausal must NOT be activated when filtering to semantic only"
        );
    }

    /// A large seed set must not trip query-size limits or other errors.
    #[tokio::test]
    async fn spread_large_seed_list() {
        let store = setup_store().await;
        let mut seeds = HashMap::new();

        for i in 0..100i64 {
            let id = store
                .upsert_entity(
                    &format!("Entity{i}"),
                    &format!("entity{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap();
            seeds.insert(id, 1.0_f32);
        }

        let cfg = default_params();
        let sa = SpreadingActivation::new(cfg);
        let result = sa.spread(&store, seeds, &[]).await;
        assert!(
            result.is_ok(),
            "large seed list must not error: {:?}",
            result.err()
        );
    }
}