1use super::{
27 neuron::{LIFNeuron, NeuronConfig, SpikeTrain},
28 synapse::{AsymmetricSTDP, STDPConfig, Synapse, SynapseMatrix},
29 SimTime, Spike,
30};
31use crate::graph::{DynamicGraph, EdgeId, VertexId};
32use std::collections::{HashMap, HashSet, VecDeque};
33
34#[derive(Debug, Clone)]
36pub struct CausalConfig {
37 pub num_event_types: usize,
39 pub causal_threshold: f64,
41 pub time_window: f64,
43 pub stdp: AsymmetricSTDP,
45 pub learning_rate: f64,
47 pub decay_rate: f64,
49}
50
51impl Default for CausalConfig {
52 fn default() -> Self {
53 Self {
54 num_event_types: 100,
55 causal_threshold: 0.1,
56 time_window: 50.0,
57 stdp: AsymmetricSTDP::default(),
58 learning_rate: 0.01,
59 decay_rate: 0.001,
60 }
61 }
62}
63
/// Kind of influence a causal edge represents.
///
/// `Eq` and `Hash` are derived (the enum is fieldless) so relations can be used
/// as set/map keys, e.g. when grouping edges by kind.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CausalRelation {
    /// The source event promotes the target event (positive synaptic weight).
    Causes,
    /// The source event suppresses the target event (negative synaptic weight).
    Prevents,
    /// No directional influence.
    None,
}
74
75#[derive(Debug, Clone)]
77pub struct CausalEdge {
78 pub source: usize,
80 pub target: usize,
82 pub strength: f64,
84 pub relation: CausalRelation,
86}
87
88#[derive(Debug, Clone)]
90pub struct CausalGraph {
91 pub num_nodes: usize,
93 edges: Vec<CausalEdge>,
95 adjacency: HashMap<usize, Vec<(usize, f64, CausalRelation)>>,
97}
98
99impl CausalGraph {
100 pub fn new(num_nodes: usize) -> Self {
102 Self {
103 num_nodes,
104 edges: Vec::new(),
105 adjacency: HashMap::new(),
106 }
107 }
108
109 pub fn add_edge(
111 &mut self,
112 source: usize,
113 target: usize,
114 strength: f64,
115 relation: CausalRelation,
116 ) {
117 self.edges.push(CausalEdge {
118 source,
119 target,
120 strength,
121 relation,
122 });
123
124 self.adjacency
125 .entry(source)
126 .or_insert_with(Vec::new)
127 .push((target, strength, relation));
128 }
129
130 pub fn edges_from(&self, source: usize) -> &[(usize, f64, CausalRelation)] {
132 self.adjacency
133 .get(&source)
134 .map(|v| v.as_slice())
135 .unwrap_or(&[])
136 }
137
138 pub fn edges(&self) -> &[CausalEdge] {
140 &self.edges
141 }
142
143 const MAX_CLOSURE_NODES: usize = 500;
145
146 pub fn transitive_closure(&self) -> Self {
151 let mut closed = Self::new(self.num_nodes);
152
153 if self.num_nodes > Self::MAX_CLOSURE_NODES {
155 for edge in &self.edges {
157 closed.add_edge(edge.source, edge.target, edge.strength, edge.relation);
158 }
159 return closed;
160 }
161
162 for edge in &self.edges {
164 closed.add_edge(edge.source, edge.target, edge.strength, edge.relation);
165 }
166
167 for k in 0..self.num_nodes {
169 for i in 0..self.num_nodes {
170 for j in 0..self.num_nodes {
171 if i == j || i == k || j == k {
172 continue;
173 }
174
175 let ik_strength = self
177 .adjacency
178 .get(&i)
179 .and_then(|edges| edges.iter().find(|(t, _, _)| *t == k))
180 .map(|(_, s, _)| *s);
181
182 let kj_strength = self
183 .adjacency
184 .get(&k)
185 .and_then(|edges| edges.iter().find(|(t, _, _)| *t == j))
186 .map(|(_, s, _)| *s);
187
188 if let (Some(s1), Some(s2)) = (ik_strength, kj_strength) {
189 let indirect_strength = s1 * s2;
190
191 let existing = closed
193 .adjacency
194 .get(&i)
195 .and_then(|edges| edges.iter().find(|(t, _, _)| *t == j))
196 .map(|(_, s, _)| *s)
197 .unwrap_or(0.0);
198
199 if indirect_strength > existing {
200 closed.add_edge(i, j, indirect_strength, CausalRelation::Causes);
201 }
202 }
203 }
204 }
205 }
206
207 closed
208 }
209
210 pub fn reachable_from(&self, source: usize) -> HashSet<usize> {
212 let mut visited = HashSet::new();
213 let mut queue = VecDeque::new();
214
215 queue.push_back(source);
216 visited.insert(source);
217
218 while let Some(node) = queue.pop_front() {
219 for (target, _, _) in self.edges_from(node) {
220 if visited.insert(*target) {
221 queue.push_back(*target);
222 }
223 }
224 }
225
226 visited
227 }
228
229 pub fn to_undirected(&self) -> DynamicGraph {
231 let graph = DynamicGraph::new();
232
233 for edge in &self.edges {
234 if !graph.has_edge(edge.source as u64, edge.target as u64) {
235 let _ = graph.insert_edge(edge.source as u64, edge.target as u64, edge.strength);
236 }
237 }
238
239 graph
240 }
241}
242
243#[derive(Debug, Clone)]
245pub struct GraphEvent {
246 pub event_type: GraphEventType,
248 pub vertex: Option<VertexId>,
250 pub edge: Option<(VertexId, VertexId)>,
252 pub data: f64,
254}
255
/// Category of a [`GraphEvent`]; each category maps to one event neuron.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum GraphEventType {
    /// Edge insertion.
    EdgeInsert,
    /// Edge deletion.
    EdgeDelete,
    /// Edge weight change.
    WeightChange,
    /// Minimum-cut value change.
    MinCutChange,
    /// A connected component split.
    ComponentSplit,
    /// Two connected components merged.
    ComponentMerge,
}
272
273pub struct CausalDiscoverySNN {
275 event_neurons: Vec<LIFNeuron>,
277 spike_trains: Vec<SpikeTrain>,
279 synapses: SynapseMatrix,
281 stdp: AsymmetricSTDP,
283 config: CausalConfig,
285 time: SimTime,
287 event_type_map: HashMap<GraphEventType, usize>,
289 index_to_event: HashMap<usize, GraphEventType>,
291}
292
293impl CausalDiscoverySNN {
294 pub fn new(config: CausalConfig) -> Self {
296 let n = config.num_event_types;
297
298 let neuron_config = NeuronConfig {
300 tau_membrane: 10.0, threshold: 0.5,
302 ..NeuronConfig::default()
303 };
304
305 let event_neurons: Vec<_> = (0..n)
306 .map(|i| LIFNeuron::with_config(i, neuron_config.clone()))
307 .collect();
308
309 let spike_trains: Vec<_> = (0..n)
310 .map(|i| SpikeTrain::with_window(i, config.time_window * 10.0))
311 .collect();
312
313 let mut synapses = SynapseMatrix::new(n, n);
315 for i in 0..n {
316 for j in 0..n {
317 if i != j {
318 synapses.add_synapse(i, j, 0.0); }
320 }
321 }
322
323 let event_type_map: HashMap<_, _> = [
325 (GraphEventType::EdgeInsert, 0),
326 (GraphEventType::EdgeDelete, 1),
327 (GraphEventType::WeightChange, 2),
328 (GraphEventType::MinCutChange, 3),
329 (GraphEventType::ComponentSplit, 4),
330 (GraphEventType::ComponentMerge, 5),
331 ]
332 .iter()
333 .cloned()
334 .collect();
335
336 let index_to_event: HashMap<_, _> = event_type_map.iter().map(|(k, v)| (*v, *k)).collect();
337
338 Self {
339 event_neurons,
340 spike_trains,
341 synapses,
342 stdp: config.stdp.clone(),
343 config,
344 time: 0.0,
345 event_type_map,
346 index_to_event,
347 }
348 }
349
350 fn event_to_neuron(&self, event: &GraphEvent) -> usize {
352 self.event_type_map
353 .get(&event.event_type)
354 .copied()
355 .unwrap_or(0)
356 }
357
358 pub fn observe_event(&mut self, event: GraphEvent, timestamp: SimTime) {
360 self.time = timestamp;
361
362 let neuron_id = self.event_to_neuron(&event);
364
365 if neuron_id < self.event_neurons.len() {
366 self.event_neurons[neuron_id].inject_spike(timestamp);
368 self.spike_trains[neuron_id].record_spike(timestamp);
369
370 self.stdp
372 .update_weights(&mut self.synapses, neuron_id, timestamp);
373 }
374 }
375
376 pub fn observe_events(&mut self, events: &[GraphEvent], timestamps: &[SimTime]) {
378 for (event, &ts) in events.iter().zip(timestamps.iter()) {
379 self.observe_event(event.clone(), ts);
380 }
381 }
382
383 pub fn decay_weights(&mut self) {
387 let decay = self.config.decay_rate;
388 let baseline = 0.5; let n = self.config.num_event_types;
390
391 for i in 0..n {
393 for j in 0..n {
394 if let Some(synapse) = self.synapses.get_synapse_mut(i, j) {
395 synapse.weight = synapse.weight * (1.0 - decay) + baseline * decay;
397 }
398 }
399 }
400 }
401
402 pub fn extract_causal_graph(&self) -> CausalGraph {
404 let n = self.config.num_event_types;
405 let mut graph = CausalGraph::new(n);
406
407 for ((i, j), synapse) in self.synapses.iter() {
408 let w = synapse.weight;
409
410 if w.abs() > self.config.causal_threshold {
411 let strength = w.abs();
412 let relation = if w > 0.0 {
413 CausalRelation::Causes
414 } else {
415 CausalRelation::Prevents
416 };
417
418 graph.add_edge(*i, *j, strength, relation);
419 }
420 }
421
422 graph
423 }
424
425 pub fn optimal_intervention_points(
427 &self,
428 controllable: &[usize],
429 targets: &[usize],
430 ) -> Vec<usize> {
431 let causal = self.extract_causal_graph();
432 let undirected = causal.to_undirected();
433
434 let mut intervention_points = Vec::new();
436 let controllable_set: HashSet<_> = controllable.iter().cloned().collect();
437 let target_set: HashSet<_> = targets.iter().cloned().collect();
438
439 for edge in causal.edges() {
440 if controllable_set.contains(&edge.source) || target_set.contains(&edge.target) {
442 intervention_points.push(edge.source);
443 }
444 }
445
446 intervention_points.sort();
447 intervention_points.dedup();
448 intervention_points
449 }
450
451 pub fn causal_strength(&self, from: GraphEventType, to: GraphEventType) -> f64 {
453 let i = self.event_type_map.get(&from).copied().unwrap_or(0);
454 let j = self.event_type_map.get(&to).copied().unwrap_or(0);
455
456 self.synapses.weight(i, j)
457 }
458
459 pub fn direct_causes(&self, event_type: GraphEventType) -> Vec<(GraphEventType, f64)> {
461 let j = self.event_type_map.get(&event_type).copied().unwrap_or(0);
462 let mut causes = Vec::new();
463
464 for i in 0..self.config.num_event_types {
465 if i != j {
466 let w = self.synapses.weight(i, j);
467 if w > self.config.causal_threshold {
468 if let Some(&event) = self.index_to_event.get(&i) {
469 causes.push((event, w));
470 }
471 }
472 }
473 }
474
475 causes.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
476 causes
477 }
478
479 pub fn direct_effects(&self, event_type: GraphEventType) -> Vec<(GraphEventType, f64)> {
481 let i = self.event_type_map.get(&event_type).copied().unwrap_or(0);
482 let mut effects = Vec::new();
483
484 for j in 0..self.config.num_event_types {
485 if i != j {
486 let w = self.synapses.weight(i, j);
487 if w > self.config.causal_threshold {
488 if let Some(&event) = self.index_to_event.get(&j) {
489 effects.push((event, w));
490 }
491 }
492 }
493 }
494
495 effects.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
496 effects
497 }
498
499 pub fn reset(&mut self) {
501 self.time = 0.0;
502
503 for neuron in &mut self.event_neurons {
504 neuron.reset();
505 }
506
507 for train in &mut self.spike_trains {
508 train.clear();
509 }
510
511 for i in 0..self.config.num_event_types {
513 for j in 0..self.config.num_event_types {
514 if i != j {
515 self.synapses.set_weight(i, j, 0.0);
516 }
517 }
518 }
519 }
520
521 pub fn summary(&self) -> CausalSummary {
523 let causal = self.extract_causal_graph();
524
525 let mut total_strength = 0.0;
526 let mut causes_count = 0;
527 let mut prevents_count = 0;
528
529 for edge in causal.edges() {
530 total_strength += edge.strength;
531 match edge.relation {
532 CausalRelation::Causes => causes_count += 1,
533 CausalRelation::Prevents => prevents_count += 1,
534 CausalRelation::None => {}
535 }
536 }
537
538 CausalSummary {
539 num_relationships: causal.edges().len(),
540 causes_count,
541 prevents_count,
542 avg_strength: total_strength / causal.edges().len().max(1) as f64,
543 time_elapsed: self.time,
544 }
545 }
546}
547
548#[derive(Debug, Clone)]
550pub struct CausalSummary {
551 pub num_relationships: usize,
553 pub causes_count: usize,
555 pub prevents_count: usize,
557 pub avg_strength: f64,
559 pub time_elapsed: SimTime,
561}
562
#[cfg(test)]
mod tests {
    use super::*;

    /// Strength of the first `source -> target` edge in `graph`, if any.
    fn edge_strength(graph: &CausalGraph, source: usize, target: usize) -> Option<f64> {
        graph
            .edges_from(source)
            .iter()
            .find(|(t, _, _)| *t == target)
            .map(|(_, s, _)| *s)
    }

    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-12
    }

    #[test]
    fn test_causal_graph() {
        let mut graph = CausalGraph::new(5);
        graph.add_edge(0, 1, 0.8, CausalRelation::Causes);
        graph.add_edge(1, 2, 0.6, CausalRelation::Causes);

        assert_eq!(graph.edges().len(), 2);
        assert!(graph.edges_from(3).is_empty());

        let reachable = graph.reachable_from(0);
        assert!(reachable.contains(&1));
        assert!(reachable.contains(&2));
        // The start node is reachable from itself; unconnected nodes are not.
        assert!(reachable.contains(&0));
        assert!(!reachable.contains(&3));
    }

    #[test]
    fn test_causal_discovery_snn() {
        let config = CausalConfig::default();
        let mut snn = CausalDiscoverySNN::new(config);

        for i in 0..10 {
            let t = i as f64 * 10.0;

            snn.observe_event(
                GraphEvent {
                    event_type: GraphEventType::EdgeInsert,
                    vertex: None,
                    edge: Some((0, 1)),
                    data: 1.0,
                },
                t,
            );

            snn.observe_event(
                GraphEvent {
                    event_type: GraphEventType::MinCutChange,
                    vertex: None,
                    edge: None,
                    data: 0.5,
                },
                t + 5.0,
            );
        }

        let summary = snn.summary();
        assert!(summary.time_elapsed > 0.0);
        // `time_elapsed` is the timestamp of the last observed event: 9 * 10 + 5.
        assert_eq!(summary.time_elapsed, 95.0);
        assert_eq!(
            summary.causes_count + summary.prevents_count,
            summary.num_relationships
        );
    }

    #[test]
    fn test_transitive_closure() {
        let mut graph = CausalGraph::new(4);
        graph.add_edge(0, 1, 0.8, CausalRelation::Causes);
        graph.add_edge(1, 2, 0.6, CausalRelation::Causes);
        graph.add_edge(2, 3, 0.5, CausalRelation::Causes);

        let closed = graph.transitive_closure();

        // Three direct edges plus at least the two two-hop edges.
        assert!(closed.edges().len() >= 5);
        assert_eq!(edge_strength(&closed, 0, 1), Some(0.8));
        assert!(edge_strength(&closed, 0, 2).map_or(false, |s| approx_eq(s, 0.8 * 0.6)));
        assert!(edge_strength(&closed, 1, 3).map_or(false, |s| approx_eq(s, 0.6 * 0.5)));
        // No edge points backwards along the chain.
        assert_eq!(edge_strength(&closed, 3, 0), None);
    }

    #[test]
    fn test_intervention_points() {
        let config = CausalConfig::default();
        let snn = CausalDiscoverySNN::new(config);

        // A fresh network has all weights at 0.0, below the causal threshold, so
        // there are no causal edges and therefore no intervention points.
        let interventions = snn.optimal_intervention_points(&[0, 1], &[3, 4]);
        assert!(interventions.is_empty());
    }
}