1use super::{
27 neuron::{LIFNeuron, NeuronConfig, SpikeTrain},
28 synapse::{Synapse, SynapseMatrix, STDPConfig, AsymmetricSTDP},
29 SimTime, Spike,
30};
31use crate::graph::{DynamicGraph, VertexId, EdgeId};
32use std::collections::{HashMap, HashSet, VecDeque};
33
34#[derive(Debug, Clone)]
36pub struct CausalConfig {
37 pub num_event_types: usize,
39 pub causal_threshold: f64,
41 pub time_window: f64,
43 pub stdp: AsymmetricSTDP,
45 pub learning_rate: f64,
47 pub decay_rate: f64,
49}
50
51impl Default for CausalConfig {
52 fn default() -> Self {
53 Self {
54 num_event_types: 100,
55 causal_threshold: 0.1,
56 time_window: 50.0,
57 stdp: AsymmetricSTDP::default(),
58 learning_rate: 0.01,
59 decay_rate: 0.001,
60 }
61 }
62}
63
/// Direction/sign of a learned link between two event types.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CausalRelation {
    /// Positive link: reported for synapses with a positive weight.
    Causes,
    /// Negative link: reported for synapses with a negative weight.
    Prevents,
    /// No signed relationship.
    None,
}
74
75#[derive(Debug, Clone)]
77pub struct CausalEdge {
78 pub source: usize,
80 pub target: usize,
82 pub strength: f64,
84 pub relation: CausalRelation,
86}
87
88#[derive(Debug, Clone)]
90pub struct CausalGraph {
91 pub num_nodes: usize,
93 edges: Vec<CausalEdge>,
95 adjacency: HashMap<usize, Vec<(usize, f64, CausalRelation)>>,
97}
98
99impl CausalGraph {
100 pub fn new(num_nodes: usize) -> Self {
102 Self {
103 num_nodes,
104 edges: Vec::new(),
105 adjacency: HashMap::new(),
106 }
107 }
108
109 pub fn add_edge(&mut self, source: usize, target: usize, strength: f64, relation: CausalRelation) {
111 self.edges.push(CausalEdge {
112 source,
113 target,
114 strength,
115 relation,
116 });
117
118 self.adjacency
119 .entry(source)
120 .or_insert_with(Vec::new)
121 .push((target, strength, relation));
122 }
123
124 pub fn edges_from(&self, source: usize) -> &[(usize, f64, CausalRelation)] {
126 self.adjacency.get(&source).map(|v| v.as_slice()).unwrap_or(&[])
127 }
128
129 pub fn edges(&self) -> &[CausalEdge] {
131 &self.edges
132 }
133
134 const MAX_CLOSURE_NODES: usize = 500;
136
137 pub fn transitive_closure(&self) -> Self {
142 let mut closed = Self::new(self.num_nodes);
143
144 if self.num_nodes > Self::MAX_CLOSURE_NODES {
146 for edge in &self.edges {
148 closed.add_edge(edge.source, edge.target, edge.strength, edge.relation);
149 }
150 return closed;
151 }
152
153 for edge in &self.edges {
155 closed.add_edge(edge.source, edge.target, edge.strength, edge.relation);
156 }
157
158 for k in 0..self.num_nodes {
160 for i in 0..self.num_nodes {
161 for j in 0..self.num_nodes {
162 if i == j || i == k || j == k {
163 continue;
164 }
165
166 let ik_strength = self.adjacency.get(&i)
168 .and_then(|edges| edges.iter().find(|(t, _, _)| *t == k))
169 .map(|(_, s, _)| *s);
170
171 let kj_strength = self.adjacency.get(&k)
172 .and_then(|edges| edges.iter().find(|(t, _, _)| *t == j))
173 .map(|(_, s, _)| *s);
174
175 if let (Some(s1), Some(s2)) = (ik_strength, kj_strength) {
176 let indirect_strength = s1 * s2;
177
178 let existing = closed.adjacency.get(&i)
180 .and_then(|edges| edges.iter().find(|(t, _, _)| *t == j))
181 .map(|(_, s, _)| *s)
182 .unwrap_or(0.0);
183
184 if indirect_strength > existing {
185 closed.add_edge(i, j, indirect_strength, CausalRelation::Causes);
186 }
187 }
188 }
189 }
190 }
191
192 closed
193 }
194
195 pub fn reachable_from(&self, source: usize) -> HashSet<usize> {
197 let mut visited = HashSet::new();
198 let mut queue = VecDeque::new();
199
200 queue.push_back(source);
201 visited.insert(source);
202
203 while let Some(node) = queue.pop_front() {
204 for (target, _, _) in self.edges_from(node) {
205 if visited.insert(*target) {
206 queue.push_back(*target);
207 }
208 }
209 }
210
211 visited
212 }
213
214 pub fn to_undirected(&self) -> DynamicGraph {
216 let graph = DynamicGraph::new();
217
218 for edge in &self.edges {
219 if !graph.has_edge(edge.source as u64, edge.target as u64) {
220 let _ = graph.insert_edge(
221 edge.source as u64,
222 edge.target as u64,
223 edge.strength,
224 );
225 }
226 }
227
228 graph
229 }
230}
231
232#[derive(Debug, Clone)]
234pub struct GraphEvent {
235 pub event_type: GraphEventType,
237 pub vertex: Option<VertexId>,
239 pub edge: Option<(VertexId, VertexId)>,
241 pub data: f64,
243}
244
/// Kinds of graph change; `CausalDiscoverySNN::new` assigns each one its own event neuron.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GraphEventType {
    /// Edge insertion (neuron index 0).
    EdgeInsert,
    /// Edge deletion (neuron index 1).
    EdgeDelete,
    /// Edge weight change (neuron index 2).
    WeightChange,
    /// Minimum-cut change (neuron index 3).
    MinCutChange,
    /// Component split (neuron index 4).
    ComponentSplit,
    /// Component merge (neuron index 5).
    ComponentMerge,
}
261
262pub struct CausalDiscoverySNN {
264 event_neurons: Vec<LIFNeuron>,
266 spike_trains: Vec<SpikeTrain>,
268 synapses: SynapseMatrix,
270 stdp: AsymmetricSTDP,
272 config: CausalConfig,
274 time: SimTime,
276 event_type_map: HashMap<GraphEventType, usize>,
278 index_to_event: HashMap<usize, GraphEventType>,
280}
281
282impl CausalDiscoverySNN {
283 pub fn new(config: CausalConfig) -> Self {
285 let n = config.num_event_types;
286
287 let neuron_config = NeuronConfig {
289 tau_membrane: 10.0, threshold: 0.5,
291 ..NeuronConfig::default()
292 };
293
294 let event_neurons: Vec<_> = (0..n)
295 .map(|i| LIFNeuron::with_config(i, neuron_config.clone()))
296 .collect();
297
298 let spike_trains: Vec<_> = (0..n)
299 .map(|i| SpikeTrain::with_window(i, config.time_window * 10.0))
300 .collect();
301
302 let mut synapses = SynapseMatrix::new(n, n);
304 for i in 0..n {
305 for j in 0..n {
306 if i != j {
307 synapses.add_synapse(i, j, 0.0); }
309 }
310 }
311
312 let event_type_map: HashMap<_, _> = [
314 (GraphEventType::EdgeInsert, 0),
315 (GraphEventType::EdgeDelete, 1),
316 (GraphEventType::WeightChange, 2),
317 (GraphEventType::MinCutChange, 3),
318 (GraphEventType::ComponentSplit, 4),
319 (GraphEventType::ComponentMerge, 5),
320 ].iter().cloned().collect();
321
322 let index_to_event: HashMap<_, _> = event_type_map.iter()
323 .map(|(k, v)| (*v, *k))
324 .collect();
325
326 Self {
327 event_neurons,
328 spike_trains,
329 synapses,
330 stdp: config.stdp.clone(),
331 config,
332 time: 0.0,
333 event_type_map,
334 index_to_event,
335 }
336 }
337
338 fn event_to_neuron(&self, event: &GraphEvent) -> usize {
340 self.event_type_map.get(&event.event_type).copied().unwrap_or(0)
341 }
342
343 pub fn observe_event(&mut self, event: GraphEvent, timestamp: SimTime) {
345 self.time = timestamp;
346
347 let neuron_id = self.event_to_neuron(&event);
349
350 if neuron_id < self.event_neurons.len() {
351 self.event_neurons[neuron_id].inject_spike(timestamp);
353 self.spike_trains[neuron_id].record_spike(timestamp);
354
355 self.stdp.update_weights(&mut self.synapses, neuron_id, timestamp);
357 }
358 }
359
360 pub fn observe_events(&mut self, events: &[GraphEvent], timestamps: &[SimTime]) {
362 for (event, &ts) in events.iter().zip(timestamps.iter()) {
363 self.observe_event(event.clone(), ts);
364 }
365 }
366
367 pub fn decay_weights(&mut self) {
371 let decay = self.config.decay_rate;
372 let baseline = 0.5; let n = self.config.num_event_types;
374
375 for i in 0..n {
377 for j in 0..n {
378 if let Some(synapse) = self.synapses.get_synapse_mut(i, j) {
379 synapse.weight = synapse.weight * (1.0 - decay) + baseline * decay;
381 }
382 }
383 }
384 }
385
386 pub fn extract_causal_graph(&self) -> CausalGraph {
388 let n = self.config.num_event_types;
389 let mut graph = CausalGraph::new(n);
390
391 for ((i, j), synapse) in self.synapses.iter() {
392 let w = synapse.weight;
393
394 if w.abs() > self.config.causal_threshold {
395 let strength = w.abs();
396 let relation = if w > 0.0 {
397 CausalRelation::Causes
398 } else {
399 CausalRelation::Prevents
400 };
401
402 graph.add_edge(*i, *j, strength, relation);
403 }
404 }
405
406 graph
407 }
408
409 pub fn optimal_intervention_points(
411 &self,
412 controllable: &[usize],
413 targets: &[usize],
414 ) -> Vec<usize> {
415 let causal = self.extract_causal_graph();
416 let undirected = causal.to_undirected();
417
418 let mut intervention_points = Vec::new();
420 let controllable_set: HashSet<_> = controllable.iter().cloned().collect();
421 let target_set: HashSet<_> = targets.iter().cloned().collect();
422
423 for edge in causal.edges() {
424 if controllable_set.contains(&edge.source) ||
426 target_set.contains(&edge.target) {
427 intervention_points.push(edge.source);
428 }
429 }
430
431 intervention_points.sort();
432 intervention_points.dedup();
433 intervention_points
434 }
435
436 pub fn causal_strength(&self, from: GraphEventType, to: GraphEventType) -> f64 {
438 let i = self.event_type_map.get(&from).copied().unwrap_or(0);
439 let j = self.event_type_map.get(&to).copied().unwrap_or(0);
440
441 self.synapses.weight(i, j)
442 }
443
444 pub fn direct_causes(&self, event_type: GraphEventType) -> Vec<(GraphEventType, f64)> {
446 let j = self.event_type_map.get(&event_type).copied().unwrap_or(0);
447 let mut causes = Vec::new();
448
449 for i in 0..self.config.num_event_types {
450 if i != j {
451 let w = self.synapses.weight(i, j);
452 if w > self.config.causal_threshold {
453 if let Some(&event) = self.index_to_event.get(&i) {
454 causes.push((event, w));
455 }
456 }
457 }
458 }
459
460 causes.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
461 causes
462 }
463
464 pub fn direct_effects(&self, event_type: GraphEventType) -> Vec<(GraphEventType, f64)> {
466 let i = self.event_type_map.get(&event_type).copied().unwrap_or(0);
467 let mut effects = Vec::new();
468
469 for j in 0..self.config.num_event_types {
470 if i != j {
471 let w = self.synapses.weight(i, j);
472 if w > self.config.causal_threshold {
473 if let Some(&event) = self.index_to_event.get(&j) {
474 effects.push((event, w));
475 }
476 }
477 }
478 }
479
480 effects.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
481 effects
482 }
483
484 pub fn reset(&mut self) {
486 self.time = 0.0;
487
488 for neuron in &mut self.event_neurons {
489 neuron.reset();
490 }
491
492 for train in &mut self.spike_trains {
493 train.clear();
494 }
495
496 for i in 0..self.config.num_event_types {
498 for j in 0..self.config.num_event_types {
499 if i != j {
500 self.synapses.set_weight(i, j, 0.0);
501 }
502 }
503 }
504 }
505
506 pub fn summary(&self) -> CausalSummary {
508 let causal = self.extract_causal_graph();
509
510 let mut total_strength = 0.0;
511 let mut causes_count = 0;
512 let mut prevents_count = 0;
513
514 for edge in causal.edges() {
515 total_strength += edge.strength;
516 match edge.relation {
517 CausalRelation::Causes => causes_count += 1,
518 CausalRelation::Prevents => prevents_count += 1,
519 CausalRelation::None => {}
520 }
521 }
522
523 CausalSummary {
524 num_relationships: causal.edges().len(),
525 causes_count,
526 prevents_count,
527 avg_strength: total_strength / causal.edges().len().max(1) as f64,
528 time_elapsed: self.time,
529 }
530 }
531}
532
533#[derive(Debug, Clone)]
535pub struct CausalSummary {
536 pub num_relationships: usize,
538 pub causes_count: usize,
540 pub prevents_count: usize,
542 pub avg_strength: f64,
544 pub time_elapsed: SimTime,
546}
547
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_causal_graph() {
        let mut graph = CausalGraph::new(5);
        graph.add_edge(0, 1, 0.8, CausalRelation::Causes);
        graph.add_edge(1, 2, 0.6, CausalRelation::Causes);

        assert_eq!(graph.edges().len(), 2);

        let reachable = graph.reachable_from(0);
        assert!(reachable.contains(&1));
        assert!(reachable.contains(&2));
    }

    #[test]
    fn test_causal_discovery_snn() {
        let config = CausalConfig::default();
        let mut snn = CausalDiscoverySNN::new(config);

        // Repeatedly observe an edge insertion followed 5 time units later by a min-cut change.
        for i in 0..10 {
            let t = i as f64 * 10.0;

            snn.observe_event(
                GraphEvent {
                    event_type: GraphEventType::EdgeInsert,
                    vertex: None,
                    edge: Some((0, 1)),
                    data: 1.0,
                },
                t,
            );

            snn.observe_event(
                GraphEvent {
                    event_type: GraphEventType::MinCutChange,
                    vertex: None,
                    edge: None,
                    data: 0.5,
                },
                t + 5.0,
            );
        }

        let summary = snn.summary();
        assert!(summary.time_elapsed > 0.0);
    }

    #[test]
    fn test_transitive_closure() {
        let mut graph = CausalGraph::new(4);
        graph.add_edge(0, 1, 0.8, CausalRelation::Causes);
        graph.add_edge(1, 2, 0.6, CausalRelation::Causes);
        graph.add_edge(2, 3, 0.5, CausalRelation::Causes);

        let closed = graph.transitive_closure();

        // The direct edges survive and at least the two-hop links are added.
        assert!(closed.edges().len() >= 5);
        // 0 -> 1 -> 2 yields a derived 0 -> 2 edge of strength 0.8 * 0.6.
        assert!(closed
            .edges_from(0)
            .iter()
            .any(|&(t, s, r)| t == 2 && (s - 0.48).abs() < 1e-9 && r == CausalRelation::Causes));
    }

    #[test]
    fn test_intervention_points() {
        let config = CausalConfig::default();
        let snn = CausalDiscoverySNN::new(config);

        // A fresh network has no weight above threshold, so nothing can be intervened on.
        let interventions = snn.optimal_intervention_points(&[0, 1], &[3, 4]);
        assert!(interventions.is_empty());
    }
}