1use std::collections::HashMap;
17use std::time::{SystemTime, UNIX_EPOCH};
18
19use crate::error::MemoryError;
20use crate::graph::store::GraphStore;
21use crate::graph::types::{Edge, EdgeType, evolved_weight};
22
/// A graph entity reached by spreading activation.
#[derive(Debug, Clone)]
pub struct ActivatedNode {
    /// Entity ID; the same ID space as the `seeds` keys and the edge endpoints.
    pub entity_id: i64,
    /// Activation score. Seeds keep the match score supplied by the caller;
    /// scores produced by spreading are capped at `1.0`.
    pub activation: f32,
    /// Hop at which the node's current (highest) score was assigned:
    /// `0` for seeds, `hop + 1` for nodes reached while spreading.
    pub depth: u32,
}
33
/// A graph edge whose two endpoints were both activated at or above the
/// activation threshold.
#[derive(Debug, Clone)]
pub struct ActivatedFact {
    /// The edge as returned by the graph store.
    pub edge: Edge,
    /// The higher of the two endpoint activation scores.
    pub activation_score: f32,
}
42
/// Tuning parameters for [`SpreadingActivation`].
#[derive(Debug, Clone)]
pub struct SpreadingActivationParams {
    /// Decay factor multiplied into every per-hop spread contribution.
    pub decay_lambda: f32,
    /// Maximum number of propagation rounds (hops) starting from the seeds.
    pub max_hops: u32,
    /// Minimum score for a seed to be used, for a spread contribution to be
    /// applied, and for a node to propagate further or be returned.
    pub activation_threshold: f32,
    /// Nodes whose current or pending score has reached this value receive no
    /// further activation.
    pub inhibition_threshold: f32,
    /// Upper bound on tracked nodes; the lowest-scoring nodes are pruned after
    /// each hop.
    pub max_activated_nodes: usize,
    /// Rate of the recency weight `1 / (1 + age_days * rate)` applied per edge
    /// based on its `valid_from`; values `<= 0.0` disable recency weighting.
    pub temporal_decay_rate: f64,
    /// NOTE(review): not read by `SpreadingActivation` in this module;
    /// presumably consumed by seed selection elsewhere — confirm.
    pub seed_structural_weight: f32,
    /// NOTE(review): not read by `SpreadingActivation` in this module;
    /// presumably a per-community seed limit used elsewhere — confirm.
    pub seed_community_cap: usize,
}
58
/// Spreading-activation retrieval over the entity graph held by a [`GraphStore`].
pub struct SpreadingActivation {
    /// Parameters fixed at construction and used by every `spread` call.
    params: SpreadingActivationParams,
}
63
64impl SpreadingActivation {
65 #[must_use]
70 pub fn new(params: SpreadingActivationParams) -> Self {
71 Self { params }
72 }
73
74 #[allow(clippy::too_many_lines)]
90 pub async fn spread(
91 &self,
92 store: &GraphStore,
93 seeds: HashMap<i64, f32>,
94 edge_types: &[EdgeType],
95 ) -> Result<(Vec<ActivatedNode>, Vec<ActivatedFact>), MemoryError> {
96 if seeds.is_empty() {
97 return Ok((Vec::new(), Vec::new()));
98 }
99
100 let now_secs: i64 = SystemTime::now()
103 .duration_since(UNIX_EPOCH)
104 .map(|d| d.as_secs().cast_signed())
105 .unwrap_or(0);
106
107 let mut activation: HashMap<i64, (f32, u32)> = HashMap::new();
109
110 let mut seed_count = 0usize;
113 for (entity_id, match_score) in &seeds {
114 if *match_score < self.params.activation_threshold {
115 tracing::debug!(
116 entity_id,
117 score = match_score,
118 threshold = self.params.activation_threshold,
119 "spreading activation: seed below threshold, skipping"
120 );
121 continue;
122 }
123 activation.insert(*entity_id, (*match_score, 0));
124 seed_count += 1;
125 }
126
127 tracing::debug!(
128 seeds = seed_count,
129 "spreading activation: initialized seeds"
130 );
131
132 let mut activated_facts: Vec<ActivatedFact> = Vec::new();
134
135 for hop in 0..self.params.max_hops {
137 let active_nodes: Vec<(i64, f32)> = activation
139 .iter()
140 .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
141 .map(|(&id, &(score, _))| (id, score))
142 .collect();
143
144 if active_nodes.is_empty() {
145 break;
146 }
147
148 let node_ids: Vec<i64> = active_nodes.iter().map(|(id, _)| *id).collect();
149
150 let edges = store.edges_for_entities(&node_ids, edge_types).await?;
152 let edge_count = edges.len();
153
154 let mut next_activation: HashMap<i64, (f32, u32)> = HashMap::new();
155
156 for edge in &edges {
157 for &(active_id, node_score) in &active_nodes {
160 let neighbor = if edge.source_entity_id == active_id {
161 edge.target_entity_id
162 } else if edge.target_entity_id == active_id {
163 edge.source_entity_id
164 } else {
165 continue;
166 };
167
168 let current_score = activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
173 let next_score = next_activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
174 if current_score >= self.params.inhibition_threshold
175 || next_score >= self.params.inhibition_threshold
176 {
177 continue;
178 }
179
180 let recency = self.recency_weight(&edge.valid_from, now_secs);
181 let edge_weight = evolved_weight(edge.retrieval_count, edge.confidence);
182 let spread_value =
183 node_score * self.params.decay_lambda * edge_weight * recency;
184
185 if spread_value < self.params.activation_threshold {
186 continue;
187 }
188
189 let depth_at_max = hop + 1;
193 let entry = next_activation
194 .entry(neighbor)
195 .or_insert((0.0, depth_at_max));
196 let new_score = (entry.0 + spread_value).min(1.0);
197 if new_score > entry.0 {
198 entry.0 = new_score;
199 entry.1 = depth_at_max;
200 }
201 }
202 }
203
204 for (node_id, (new_score, new_depth)) in next_activation {
206 let entry = activation.entry(node_id).or_insert((0.0, new_depth));
207 if new_score > entry.0 {
208 entry.0 = new_score;
209 entry.1 = new_depth;
210 }
211 }
212
213 let pruned_count = if activation.len() > self.params.max_activated_nodes {
216 let before = activation.len();
217 let mut entries: Vec<(i64, (f32, u32))> = activation.drain().collect();
218 entries.sort_by(|(_, (a, _)), (_, (b, _))| b.total_cmp(a));
219 entries.truncate(self.params.max_activated_nodes);
220 activation = entries.into_iter().collect();
221 before - self.params.max_activated_nodes
222 } else {
223 0
224 };
225
226 tracing::debug!(
227 hop,
228 active_nodes = active_nodes.len(),
229 edges_fetched = edge_count,
230 after_merge = activation.len(),
231 pruned = pruned_count,
232 "spreading activation: hop complete"
233 );
234
235 for edge in edges {
237 let src_score = activation
239 .get(&edge.source_entity_id)
240 .map_or(0.0, |&(s, _)| s);
241 let tgt_score = activation
242 .get(&edge.target_entity_id)
243 .map_or(0.0, |&(s, _)| s);
244 if src_score >= self.params.activation_threshold
245 && tgt_score >= self.params.activation_threshold
246 {
247 let activation_score = src_score.max(tgt_score);
248 activated_facts.push(ActivatedFact {
249 edge,
250 activation_score,
251 });
252 }
253 }
254 }
255
256 let mut result: Vec<ActivatedNode> = activation
258 .into_iter()
259 .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
260 .map(|(entity_id, (activation, depth))| ActivatedNode {
261 entity_id,
262 activation,
263 depth,
264 })
265 .collect();
266 result.sort_by(|a, b| b.activation.total_cmp(&a.activation));
267
268 tracing::info!(
269 activated = result.len(),
270 facts = activated_facts.len(),
271 "spreading activation: complete"
272 );
273
274 Ok((result, activated_facts))
275 }
276
277 #[allow(clippy::cast_precision_loss)]
283 fn recency_weight(&self, valid_from: &str, now_secs: i64) -> f32 {
284 if self.params.temporal_decay_rate <= 0.0 {
285 return 1.0;
286 }
287 let Some(valid_from_secs) = parse_sqlite_datetime_to_unix(valid_from) else {
288 return 1.0;
289 };
290 let age_secs = (now_secs - valid_from_secs).max(0);
291 let age_days = age_secs as f64 / 86_400.0;
292 let weight = 1.0_f64 / (1.0 + age_days * self.params.temporal_decay_rate);
293 #[allow(clippy::cast_possible_truncation)]
295 let w = weight as f32;
296 w
297 }
298}
299
/// Parses a SQLite datetime string (`YYYY-MM-DD HH:MM:SS`) into Unix seconds,
/// interpreting it as UTC (no offset is applied).
///
/// Only the first 19 bytes are inspected. Trailing content such as fractional
/// seconds is ignored, and the separator bytes are not checked, so the ISO-8601
/// `T` form is accepted too.
///
/// Returns `None` when the input is too short, when a field contains anything but
/// ASCII digits (including signs or a multibyte character), or when a field is out
/// of range (month 1–12, day 1–31, hour 0–23, minute 0–59, second 0–60). The day is
/// not checked against the month length. Never panics.
///
/// The day count follows Howard Hinnant's `days_from_civil` algorithm
/// (proleptic Gregorian calendar).
#[must_use]
fn parse_sqlite_datetime_to_unix(s: &str) -> Option<i64> {
    /// Parses `s[range]` as an unsigned decimal field of ASCII digits only.
    /// `str::get` returns `None` instead of panicking when the range is out of
    /// bounds or splits a multibyte character.
    fn field(s: &str, range: std::ops::Range<usize>) -> Option<i64> {
        let digits = s.get(range)?;
        if !digits.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }
        digits.parse().ok()
    }

    let year = field(s, 0..4)?;
    let month = field(s, 5..7)?;
    let day = field(s, 8..10)?;
    let hour = field(s, 11..13)?;
    let min = field(s, 14..16)?;
    let sec = field(s, 17..19)?;

    if !(1..=12).contains(&month)
        || !(1..=31).contains(&day)
        || hour > 23
        || min > 59
        || sec > 60
    {
        return None;
    }

    // Shift the year so it starts in March; the leap day then falls at year end.
    let (y, m) = if month <= 2 {
        (year - 1, month + 9)
    } else {
        (year, month - 3)
    };
    let era = y.div_euclid(400);
    let yoe = y - era * 400; // year of era, 0..=399
    let doy = (153 * m + 2) / 5 + day - 1; // day of (March-based) year
    let doe = yoe * 365 + yoe / 4 - yoe / 100 + doy; // day of era
    let days = era * 146_097 + doe - 719_468; // days since 1970-01-01

    Some(days * 86_400 + hour * 3_600 + min * 60 + sec)
}
331
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphStore;
    use crate::graph::types::EntityType;
    use crate::sqlite::SqliteStore;

    /// Fresh in-memory SQLite-backed graph store.
    async fn setup_store() -> GraphStore {
        let store = SqliteStore::new(":memory:").await.unwrap();
        GraphStore::new(store.pool().clone())
    }

    /// Baseline parameters shared by the tests; temporal decay is disabled.
    fn default_params() -> SpreadingActivationParams {
        SpreadingActivationParams {
            decay_lambda: 0.85,
            max_hops: 3,
            activation_threshold: 0.1,
            inhibition_threshold: 0.8,
            max_activated_nodes: 50,
            temporal_decay_rate: 0.0,
            seed_structural_weight: 0.4,
            seed_community_cap: 3,
        }
    }

    /// A seed with no edges is returned on its own and yields no facts.
    #[tokio::test]
    async fn spread_empty_graph_no_edges_no_facts() {
        let store = setup_store().await;
        let sa = SpreadingActivation::new(default_params());
        let seeds = HashMap::from([(1_i64, 1.0_f32)]);
        let (nodes, facts) = sa.spread(&store, seeds, &[]).await.unwrap();
        assert_eq!(nodes.len(), 1, "seed must be in activated nodes");
        assert_eq!(nodes[0].entity_id, 1);
        assert!((nodes[0].activation - 1.0).abs() < 1e-6);
        assert!(facts.is_empty(), "expected no activated facts on empty graph");
    }

    /// No seeds means no work and empty results.
    #[tokio::test]
    async fn spread_empty_seeds_returns_empty() {
        let store = setup_store().await;
        let sa = SpreadingActivation::new(default_params());
        let (nodes, facts) = sa.spread(&store, HashMap::new(), &[]).await.unwrap();
        assert!(nodes.is_empty());
        assert!(facts.is_empty());
    }

    /// An existing entity without edges is returned at depth 0 with its seed score.
    #[tokio::test]
    async fn spread_single_seed_no_edges_returns_seed() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();

        let sa = SpreadingActivation::new(default_params());
        let seeds = HashMap::from([(alice, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].entity_id, alice);
        assert_eq!(nodes[0].depth, 0);
        assert!((nodes[0].activation - 1.0).abs() < 1e-6);
    }

    /// Activation reaches every node of A-B-C and decreases with distance.
    #[tokio::test]
    async fn spread_linear_chain_all_activated_with_decay() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 3;
        cfg.decay_lambda = 0.9;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&a), "A (seed) must be activated");
        assert!(ids.contains(&b), "B (hop 1) must be activated");
        assert!(ids.contains(&c), "C (hop 2) must be activated");

        let score_a = nodes.iter().find(|n| n.entity_id == a).unwrap().activation;
        let score_b = nodes.iter().find(|n| n.entity_id == b).unwrap().activation;
        let score_c = nodes.iter().find(|n| n.entity_id == c).unwrap().activation;
        assert!(
            score_a > score_b,
            "seed A should have higher activation than hop-1 B"
        );
        assert!(
            score_b > score_c,
            "hop-1 B should have higher activation than hop-2 C"
        );
    }

    /// Both traversed edges of A-B-C are reported as activated facts.
    #[tokio::test]
    async fn spread_linear_chain_reports_traversed_edges_as_facts() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 3;
        cfg.decay_lambda = 0.9;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (_, facts) = sa.spread(&store, seeds, &[]).await.unwrap();

        let has_fact = |s: i64, t: i64| {
            facts
                .iter()
                .any(|f| f.edge.source_entity_id == s && f.edge.target_entity_id == t)
        };
        assert!(has_fact(a, b), "edge A-B must be reported as a fact");
        assert!(has_fact(b, c), "edge B-C must be reported as a fact");
    }

    /// With `max_hops = 1`, only direct neighbours of the seed are reached.
    #[tokio::test]
    async fn spread_linear_chain_max_hops_limits_reach() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 1;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&a), "A must be activated (seed)");
        assert!(ids.contains(&b), "B must be activated (hop 1)");
        assert!(!ids.contains(&c), "C must NOT be activated with max_hops=1");
    }

    /// In the diamond A->{B,C}->D, D is reached at depth 2.
    #[tokio::test]
    async fn spread_diamond_graph_convergence() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap();
        let d = store
            .upsert_entity("D", "D", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "rel", "A-B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(a, c, "rel", "A-C", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, d, "rel", "B-D", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(c, d, "rel", "C-D", 1.0, None)
            .await
            .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 3;
        cfg.decay_lambda = 0.9;
        cfg.inhibition_threshold = 0.95;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&d), "D must be activated via diamond paths");

        let node_d = nodes.iter().find(|n| n.entity_id == d).unwrap();
        assert_eq!(node_d.depth, 2, "D should be at depth 2");
    }

    /// A hub with bidirectional edges to its leaves never exceeds activation 1.0.
    #[tokio::test]
    async fn spread_inhibition_prevents_runaway() {
        let store = setup_store().await;
        let hub = store
            .upsert_entity("Hub", "Hub", EntityType::Concept, None)
            .await
            .unwrap();

        for i in 0..5 {
            let leaf = store
                .upsert_entity(
                    &format!("Leaf{i}"),
                    &format!("Leaf{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap();
            store
                .insert_edge(hub, leaf, "has", &format!("Hub has Leaf{i}"), 1.0, None)
                .await
                .unwrap();
            store
                .insert_edge(
                    leaf,
                    hub,
                    "part_of",
                    &format!("Leaf{i} part_of Hub"),
                    1.0,
                    None,
                )
                .await
                .unwrap();
        }

        let mut cfg = default_params();
        cfg.inhibition_threshold = 0.8;
        cfg.max_hops = 3;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(hub, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let hub_node = nodes.iter().find(|n| n.entity_id == hub);
        assert!(hub_node.is_some(), "hub must be in results");
        assert!(
            hub_node.unwrap().activation <= 1.0,
            "activation must not exceed 1.0"
        );
    }

    /// The number of returned nodes never exceeds `max_activated_nodes`.
    #[tokio::test]
    async fn spread_max_activated_nodes_cap_enforced() {
        let store = setup_store().await;
        let root = store
            .upsert_entity("Root", "Root", EntityType::Person, None)
            .await
            .unwrap();

        for i in 0..20 {
            let leaf = store
                .upsert_entity(
                    &format!("Node{i}"),
                    &format!("Node{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap();
            store
                .insert_edge(root, leaf, "has", &format!("Root has Node{i}"), 0.9, None)
                .await
                .unwrap();
        }

        let max_nodes = 5;
        let cfg = SpreadingActivationParams {
            max_activated_nodes: max_nodes,
            max_hops: 2,
            ..default_params()
        };
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(root, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        assert!(
            nodes.len() <= max_nodes,
            "activation must be capped at {max_nodes} nodes, got {}",
            nodes.len()
        );
    }

    /// With temporal decay enabled, an edge dated 1970 activates its target less
    /// than a freshly inserted edge.
    #[tokio::test]
    async fn spread_temporal_decay_recency_effect() {
        let store = setup_store().await;
        let src = store
            .upsert_entity("Src", "Src", EntityType::Person, None)
            .await
            .unwrap();
        let recent = store
            .upsert_entity("Recent", "Recent", EntityType::Tool, None)
            .await
            .unwrap();
        let old = store
            .upsert_entity("Old", "Old", EntityType::Tool, None)
            .await
            .unwrap();

        store
            .insert_edge(src, recent, "uses", "Src uses Recent", 1.0, None)
            .await
            .unwrap();

        sqlx::query(
            "INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from)
             VALUES (?1, ?2, 'uses', 'Src uses Old', 1.0, '1970-01-01 00:00:00')",
        )
        .bind(src)
        .bind(old)
        .execute(store.pool())
        .await
        .unwrap();

        let mut cfg = default_params();
        cfg.max_hops = 2;
        let sa = SpreadingActivation::new(SpreadingActivationParams {
            temporal_decay_rate: 0.5,
            ..cfg
        });
        let seeds = HashMap::from([(src, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();

        let score_recent = nodes
            .iter()
            .find(|n| n.entity_id == recent)
            .map_or(0.0, |n| n.activation);
        let score_old = nodes
            .iter()
            .find(|n| n.entity_id == old)
            .map_or(0.0, |n| n.activation);

        assert!(
            score_recent > score_old,
            "recent edge ({score_recent}) must produce higher activation than old edge ({score_old})"
        );
    }

    /// Restricting `edge_types` to semantic edges keeps causal neighbours inactive.
    #[tokio::test]
    async fn spread_edge_type_filter_excludes_other_types() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b_semantic = store
            .upsert_entity("BSemantic", "BSemantic", EntityType::Tool, None)
            .await
            .unwrap();
        let c_causal = store
            .upsert_entity("CCausal", "CCausal", EntityType::Concept, None)
            .await
            .unwrap();

        store
            .insert_edge(a, b_semantic, "uses", "A uses BSemantic", 1.0, None)
            .await
            .unwrap();

        sqlx::query(
            "INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from, edge_type)
             VALUES (?1, ?2, 'caused', 'A caused CCausal', 1.0, datetime('now'), 'causal')",
        )
        .bind(a)
        .bind(c_causal)
        .execute(store.pool())
        .await
        .unwrap();

        let cfg = default_params();
        let sa = SpreadingActivation::new(cfg);

        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa
            .spread(&store, seeds, &[EdgeType::Semantic])
            .await
            .unwrap();

        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(
            ids.contains(&b_semantic),
            "BSemantic must be activated via semantic edge"
        );
        assert!(
            !ids.contains(&c_causal),
            "CCausal must NOT be activated when filtering to semantic only"
        );
    }

    /// A hundred seeds are handled without a store error.
    #[tokio::test]
    async fn spread_large_seed_list() {
        let store = setup_store().await;
        let mut seeds = HashMap::new();

        for i in 0..100i64 {
            let id = store
                .upsert_entity(
                    &format!("Entity{i}"),
                    &format!("entity{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap();
            seeds.insert(id, 1.0_f32);
        }

        let cfg = default_params();
        let sa = SpreadingActivation::new(cfg);
        let result = sa.spread(&store, seeds, &[]).await;
        assert!(
            result.is_ok(),
            "large seed list must not error: {:?}",
            result.err()
        );
    }

    /// Known SQLite datetimes map to the expected Unix seconds; short input is rejected.
    #[test]
    fn parse_sqlite_datetime_known_values() {
        assert_eq!(parse_sqlite_datetime_to_unix("1970-01-01 00:00:00"), Some(0));
        assert_eq!(parse_sqlite_datetime_to_unix("1970-01-02 03:04:05"), Some(97_445));
        assert_eq!(
            parse_sqlite_datetime_to_unix("2000-01-01 00:00:00"),
            Some(946_684_800)
        );
        assert_eq!(
            parse_sqlite_datetime_to_unix("2024-02-29 00:00:00"),
            Some(1_709_164_800)
        );
        assert_eq!(parse_sqlite_datetime_to_unix("2024-02-29"), None);
        assert_eq!(parse_sqlite_datetime_to_unix(""), None);
    }
}