ruvector_mincut/snn/
causal.rs

1//! # Layer 3: Causal Discovery via Spike Timing
2//!
3//! Uses spike-timing cross-correlation to infer causal relationships in graph events.
4//!
5//! ## Key Insight
6//!
7//! Spike train cross-correlation with asymmetric temporal windows naturally encodes
8//! Granger-like causality:
9//!
10//! ```text
11//! Neuron A:  ──●────────●────────●────────
12//! Neuron B:  ────────●────────●────────●──
13//!            │←─Δt─→│
14//!
15//! If Δt consistently positive → A causes B
16//! STDP learning rule naturally encodes this!
17//! ```
18//!
19//! After learning, W_AB reflects causal strength A→B
20//!
21//! ## MinCut Application
22//!
23//! MinCut on the causal graph reveals optimal intervention points -
24//! minimum changes needed to affect outcomes.
25
26use super::{
27    neuron::{LIFNeuron, NeuronConfig, SpikeTrain},
28    synapse::{AsymmetricSTDP, STDPConfig, Synapse, SynapseMatrix},
29    SimTime, Spike,
30};
31use crate::graph::{DynamicGraph, EdgeId, VertexId};
32use std::collections::{HashMap, HashSet, VecDeque};
33
/// Configuration for causal discovery.
///
/// Passed by value to [`CausalDiscoverySNN::new`]; `impl Default` supplies the stock
/// values.
#[derive(Debug, Clone)]
pub struct CausalConfig {
    /// Number of event types (neurons) allocated by `CausalDiscoverySNN::new`.
    ///
    /// Only the six [`GraphEventType`] variants are mapped to neurons (indices 0–5), so
    /// neurons past index 5 never receive events, and with fewer than six neurons the
    /// higher-indexed event types are dropped by `observe_event`.
    pub num_event_types: usize,
    /// Threshold for causal relationship detection: `extract_causal_graph` reports a
    /// pair when `|weight|` exceeds it, while `direct_causes` / `direct_effects` require
    /// the (signed) weight itself to exceed it.
    pub causal_threshold: f64,
    /// Time window for causality (ms); each neuron's spike-train history window is
    /// sized to 10× this value.
    pub time_window: f64,
    /// Asymmetric STDP configuration; cloned into the SNN and applied on every
    /// observed event.
    pub stdp: AsymmetricSTDP,
    /// Learning rate for causal weight updates.
    ///
    /// NOTE(review): not read anywhere in this file — weight updates go through `stdp`
    /// instead; confirm whether this field is used elsewhere or is dead.
    pub learning_rate: f64,
    /// Decay rate for causal weights, applied once per call to `decay_weights`.
    pub decay_rate: f64,
}
50
51impl Default for CausalConfig {
52    fn default() -> Self {
53        Self {
54            num_event_types: 100,
55            causal_threshold: 0.1,
56            time_window: 50.0,
57            stdp: AsymmetricSTDP::default(),
58            learning_rate: 0.01,
59            decay_rate: 0.001,
60        }
61    }
62}
63
/// Type (sign) of a causal relationship between two event types.
///
/// `CausalDiscoverySNN::extract_causal_graph` labels positive learned weights as
/// `Causes` and negative ones as `Prevents`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CausalRelation {
    /// A causes B (positive influence)
    Causes,
    /// A prevents B (negative influence)
    Prevents,
    /// No significant causal relationship (never produced by the weight extraction in
    /// this file; skipped when `summary` counts relations)
    None,
}
74
/// A directed causal relationship `source → target`
#[derive(Debug, Clone)]
pub struct CausalEdge {
    /// Source event type (neuron index)
    pub source: usize,
    /// Target event type (neuron index)
    pub target: usize,
    /// Causal strength (absolute value; the sign is carried by `relation`)
    pub strength: f64,
    /// Type of relationship
    pub relation: CausalRelation,
}
87
/// Directed graph representing causal relationships
///
/// Every edge is stored twice: in `edges` (insertion order) and in `adjacency`
/// (indexed by source, for `edges_from`). Node indices are not bounds-checked against
/// `num_nodes`.
#[derive(Debug, Clone)]
pub struct CausalGraph {
    /// Number of nodes (event types)
    pub num_nodes: usize,
    /// Causal edges, in insertion order
    edges: Vec<CausalEdge>,
    /// Adjacency list (source → `(target, strength, relation)` tuples)
    adjacency: HashMap<usize, Vec<(usize, f64, CausalRelation)>>,
}
98
99impl CausalGraph {
100    /// Create a new empty causal graph
101    pub fn new(num_nodes: usize) -> Self {
102        Self {
103            num_nodes,
104            edges: Vec::new(),
105            adjacency: HashMap::new(),
106        }
107    }
108
109    /// Add a causal edge
110    pub fn add_edge(
111        &mut self,
112        source: usize,
113        target: usize,
114        strength: f64,
115        relation: CausalRelation,
116    ) {
117        self.edges.push(CausalEdge {
118            source,
119            target,
120            strength,
121            relation,
122        });
123
124        self.adjacency
125            .entry(source)
126            .or_insert_with(Vec::new)
127            .push((target, strength, relation));
128    }
129
130    /// Get edges from a node
131    pub fn edges_from(&self, source: usize) -> &[(usize, f64, CausalRelation)] {
132        self.adjacency
133            .get(&source)
134            .map(|v| v.as_slice())
135            .unwrap_or(&[])
136    }
137
138    /// Get all edges
139    pub fn edges(&self) -> &[CausalEdge] {
140        &self.edges
141    }
142
143    /// Maximum nodes for transitive closure (O(n³) algorithm)
144    const MAX_CLOSURE_NODES: usize = 500;
145
146    /// Compute transitive closure (indirect causation)
147    ///
148    /// Uses Floyd-Warshall algorithm with O(n³) complexity.
149    /// Limited to MAX_CLOSURE_NODES to prevent DoS.
150    pub fn transitive_closure(&self) -> Self {
151        let mut closed = Self::new(self.num_nodes);
152
153        // Resource limit: skip if too many nodes (O(n³) would be too slow)
154        if self.num_nodes > Self::MAX_CLOSURE_NODES {
155            // Just copy direct edges without transitive closure
156            for edge in &self.edges {
157                closed.add_edge(edge.source, edge.target, edge.strength, edge.relation);
158            }
159            return closed;
160        }
161
162        // Copy direct edges
163        for edge in &self.edges {
164            closed.add_edge(edge.source, edge.target, edge.strength, edge.relation);
165        }
166
167        // Floyd-Warshall-like algorithm for transitive closure
168        for k in 0..self.num_nodes {
169            for i in 0..self.num_nodes {
170                for j in 0..self.num_nodes {
171                    if i == j || i == k || j == k {
172                        continue;
173                    }
174
175                    // Check if path i→k→j exists
176                    let ik_strength = self
177                        .adjacency
178                        .get(&i)
179                        .and_then(|edges| edges.iter().find(|(t, _, _)| *t == k))
180                        .map(|(_, s, _)| *s);
181
182                    let kj_strength = self
183                        .adjacency
184                        .get(&k)
185                        .and_then(|edges| edges.iter().find(|(t, _, _)| *t == j))
186                        .map(|(_, s, _)| *s);
187
188                    if let (Some(s1), Some(s2)) = (ik_strength, kj_strength) {
189                        let indirect_strength = s1 * s2;
190
191                        // Only add if stronger than existing direct path
192                        let existing = closed
193                            .adjacency
194                            .get(&i)
195                            .and_then(|edges| edges.iter().find(|(t, _, _)| *t == j))
196                            .map(|(_, s, _)| *s)
197                            .unwrap_or(0.0);
198
199                        if indirect_strength > existing {
200                            closed.add_edge(i, j, indirect_strength, CausalRelation::Causes);
201                        }
202                    }
203                }
204            }
205        }
206
207        closed
208    }
209
210    /// Find nodes reachable from a source
211    pub fn reachable_from(&self, source: usize) -> HashSet<usize> {
212        let mut visited = HashSet::new();
213        let mut queue = VecDeque::new();
214
215        queue.push_back(source);
216        visited.insert(source);
217
218        while let Some(node) = queue.pop_front() {
219            for (target, _, _) in self.edges_from(node) {
220                if visited.insert(*target) {
221                    queue.push_back(*target);
222                }
223            }
224        }
225
226        visited
227    }
228
229    /// Convert to undirected graph for mincut analysis
230    pub fn to_undirected(&self) -> DynamicGraph {
231        let graph = DynamicGraph::new();
232
233        for edge in &self.edges {
234            if !graph.has_edge(edge.source as u64, edge.target as u64) {
235                let _ = graph.insert_edge(edge.source as u64, edge.target as u64, edge.strength);
236            }
237        }
238
239        graph
240    }
241}
242
/// Graph event that can be observed
///
/// Only `event_type` is consulted by [`CausalDiscoverySNN`] (it selects which neuron
/// spikes); the other fields are not read in this file.
#[derive(Debug, Clone)]
pub struct GraphEvent {
    /// Type of event
    pub event_type: GraphEventType,
    /// Associated vertex (if applicable)
    pub vertex: Option<VertexId>,
    /// Associated edge as a vertex pair (if applicable)
    pub edge: Option<(VertexId, VertexId)>,
    /// Event metadata (free-form numeric payload)
    pub data: f64,
}
255
/// Types of graph events
///
/// `CausalDiscoverySNN::new` gives each variant its own neuron, using the declaration
/// order below as the neuron index (0–5).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GraphEventType {
    /// Edge was added
    EdgeInsert,
    /// Edge was removed
    EdgeDelete,
    /// Edge weight changed
    WeightChange,
    /// MinCut value changed
    MinCutChange,
    /// Component split
    ComponentSplit,
    /// Component merged
    ComponentMerge,
}
272
/// Causal discovery using spiking neural network
///
/// Each [`GraphEventType`] is represented by one LIF neuron. Observed events become
/// spikes, asymmetric STDP adjusts the all-to-all synapse weights, and the sign and
/// magnitude of those weights are read back as causal relationships.
pub struct CausalDiscoverySNN {
    /// One neuron per graph event type
    event_neurons: Vec<LIFNeuron>,
    /// Spike trains for each neuron (same indexing as `event_neurons`)
    spike_trains: Vec<SpikeTrain>,
    /// Synaptic weights encode discovered causal strength; weight (i, j) is read as
    /// the influence of event type i on event type j
    synapses: SynapseMatrix,
    /// Asymmetric STDP for causality detection (a clone of `config.stdp`)
    stdp: AsymmetricSTDP,
    /// Configuration
    config: CausalConfig,
    /// Current simulation time: the timestamp of the most recently observed event
    time: SimTime,
    /// Event type → neuron index mapping
    event_type_map: HashMap<GraphEventType, usize>,
    /// Reverse mapping (neuron index → event type)
    index_to_event: HashMap<usize, GraphEventType>,
}
292
293impl CausalDiscoverySNN {
294    /// Create a new causal discovery SNN
295    pub fn new(config: CausalConfig) -> Self {
296        let n = config.num_event_types;
297
298        // Create event neurons
299        let neuron_config = NeuronConfig {
300            tau_membrane: 10.0, // Fast response
301            threshold: 0.5,
302            ..NeuronConfig::default()
303        };
304
305        let event_neurons: Vec<_> = (0..n)
306            .map(|i| LIFNeuron::with_config(i, neuron_config.clone()))
307            .collect();
308
309        let spike_trains: Vec<_> = (0..n)
310            .map(|i| SpikeTrain::with_window(i, config.time_window * 10.0))
311            .collect();
312
313        // Fully connected synapses
314        let mut synapses = SynapseMatrix::new(n, n);
315        for i in 0..n {
316            for j in 0..n {
317                if i != j {
318                    synapses.add_synapse(i, j, 0.0); // Start with zero weights
319                }
320            }
321        }
322
323        // Initialize event type mapping
324        let event_type_map: HashMap<_, _> = [
325            (GraphEventType::EdgeInsert, 0),
326            (GraphEventType::EdgeDelete, 1),
327            (GraphEventType::WeightChange, 2),
328            (GraphEventType::MinCutChange, 3),
329            (GraphEventType::ComponentSplit, 4),
330            (GraphEventType::ComponentMerge, 5),
331        ]
332        .iter()
333        .cloned()
334        .collect();
335
336        let index_to_event: HashMap<_, _> = event_type_map.iter().map(|(k, v)| (*v, *k)).collect();
337
338        Self {
339            event_neurons,
340            spike_trains,
341            synapses,
342            stdp: config.stdp.clone(),
343            config,
344            time: 0.0,
345            event_type_map,
346            index_to_event,
347        }
348    }
349
350    /// Convert graph event to neuron index
351    fn event_to_neuron(&self, event: &GraphEvent) -> usize {
352        self.event_type_map
353            .get(&event.event_type)
354            .copied()
355            .unwrap_or(0)
356    }
357
358    /// Observe a graph event
359    pub fn observe_event(&mut self, event: GraphEvent, timestamp: SimTime) {
360        self.time = timestamp;
361
362        // Convert graph event to spike
363        let neuron_id = self.event_to_neuron(&event);
364
365        if neuron_id < self.event_neurons.len() {
366            // Record spike
367            self.event_neurons[neuron_id].inject_spike(timestamp);
368            self.spike_trains[neuron_id].record_spike(timestamp);
369
370            // STDP update: causal relationships emerge in weights
371            self.stdp
372                .update_weights(&mut self.synapses, neuron_id, timestamp);
373        }
374    }
375
376    /// Process a batch of events
377    pub fn observe_events(&mut self, events: &[GraphEvent], timestamps: &[SimTime]) {
378        for (event, &ts) in events.iter().zip(timestamps.iter()) {
379            self.observe_event(event.clone(), ts);
380        }
381    }
382
383    /// Decay all synaptic weights toward baseline
384    ///
385    /// Applies exponential decay: w' = w * (1 - decay_rate) + baseline * decay_rate
386    pub fn decay_weights(&mut self) {
387        let decay = self.config.decay_rate;
388        let baseline = 0.5; // Neutral weight
389        let n = self.config.num_event_types;
390
391        // Iterate through all possible synapse pairs
392        for i in 0..n {
393            for j in 0..n {
394                if let Some(synapse) = self.synapses.get_synapse_mut(i, j) {
395                    // Exponential decay toward baseline
396                    synapse.weight = synapse.weight * (1.0 - decay) + baseline * decay;
397                }
398            }
399        }
400    }
401
402    /// Extract causal graph from learned weights
403    pub fn extract_causal_graph(&self) -> CausalGraph {
404        let n = self.config.num_event_types;
405        let mut graph = CausalGraph::new(n);
406
407        for ((i, j), synapse) in self.synapses.iter() {
408            let w = synapse.weight;
409
410            if w.abs() > self.config.causal_threshold {
411                let strength = w.abs();
412                let relation = if w > 0.0 {
413                    CausalRelation::Causes
414                } else {
415                    CausalRelation::Prevents
416                };
417
418                graph.add_edge(*i, *j, strength, relation);
419            }
420        }
421
422        graph
423    }
424
425    /// Find optimal intervention points using MinCut on causal graph
426    pub fn optimal_intervention_points(
427        &self,
428        controllable: &[usize],
429        targets: &[usize],
430    ) -> Vec<usize> {
431        let causal = self.extract_causal_graph();
432        let undirected = causal.to_undirected();
433
434        // Simple heuristic: find nodes on paths from controllable to targets
435        let mut intervention_points = Vec::new();
436        let controllable_set: HashSet<_> = controllable.iter().cloned().collect();
437        let target_set: HashSet<_> = targets.iter().cloned().collect();
438
439        for edge in causal.edges() {
440            // If edge connects controllable region to target region
441            if controllable_set.contains(&edge.source) || target_set.contains(&edge.target) {
442                intervention_points.push(edge.source);
443            }
444        }
445
446        intervention_points.sort();
447        intervention_points.dedup();
448        intervention_points
449    }
450
451    /// Get causal strength between two event types
452    pub fn causal_strength(&self, from: GraphEventType, to: GraphEventType) -> f64 {
453        let i = self.event_type_map.get(&from).copied().unwrap_or(0);
454        let j = self.event_type_map.get(&to).copied().unwrap_or(0);
455
456        self.synapses.weight(i, j)
457    }
458
459    /// Get all direct causes of an event type
460    pub fn direct_causes(&self, event_type: GraphEventType) -> Vec<(GraphEventType, f64)> {
461        let j = self.event_type_map.get(&event_type).copied().unwrap_or(0);
462        let mut causes = Vec::new();
463
464        for i in 0..self.config.num_event_types {
465            if i != j {
466                let w = self.synapses.weight(i, j);
467                if w > self.config.causal_threshold {
468                    if let Some(&event) = self.index_to_event.get(&i) {
469                        causes.push((event, w));
470                    }
471                }
472            }
473        }
474
475        causes.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
476        causes
477    }
478
479    /// Get all direct effects of an event type
480    pub fn direct_effects(&self, event_type: GraphEventType) -> Vec<(GraphEventType, f64)> {
481        let i = self.event_type_map.get(&event_type).copied().unwrap_or(0);
482        let mut effects = Vec::new();
483
484        for j in 0..self.config.num_event_types {
485            if i != j {
486                let w = self.synapses.weight(i, j);
487                if w > self.config.causal_threshold {
488                    if let Some(&event) = self.index_to_event.get(&j) {
489                        effects.push((event, w));
490                    }
491                }
492            }
493        }
494
495        effects.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
496        effects
497    }
498
499    /// Reset the SNN
500    pub fn reset(&mut self) {
501        self.time = 0.0;
502
503        for neuron in &mut self.event_neurons {
504            neuron.reset();
505        }
506
507        for train in &mut self.spike_trains {
508            train.clear();
509        }
510
511        // Reset weights to zero
512        for i in 0..self.config.num_event_types {
513            for j in 0..self.config.num_event_types {
514                if i != j {
515                    self.synapses.set_weight(i, j, 0.0);
516                }
517            }
518        }
519    }
520
521    /// Get summary statistics
522    pub fn summary(&self) -> CausalSummary {
523        let causal = self.extract_causal_graph();
524
525        let mut total_strength = 0.0;
526        let mut causes_count = 0;
527        let mut prevents_count = 0;
528
529        for edge in causal.edges() {
530            total_strength += edge.strength;
531            match edge.relation {
532                CausalRelation::Causes => causes_count += 1,
533                CausalRelation::Prevents => prevents_count += 1,
534                CausalRelation::None => {}
535            }
536        }
537
538        CausalSummary {
539            num_relationships: causal.edges().len(),
540            causes_count,
541            prevents_count,
542            avg_strength: total_strength / causal.edges().len().max(1) as f64,
543            time_elapsed: self.time,
544        }
545    }
546}
547
/// Summary of causal discovery, as produced by `CausalDiscoverySNN::summary`
#[derive(Debug, Clone)]
pub struct CausalSummary {
    /// Total number of discovered relationships (edges in the extracted causal graph)
    pub num_relationships: usize,
    /// Number of positive causal relationships
    pub causes_count: usize,
    /// Number of preventive relationships
    pub prevents_count: usize,
    /// Average causal strength (0.0 when no relationships were found)
    pub avg_strength: f64,
    /// Time elapsed in observation: the SNN's current time, i.e. the timestamp of the
    /// most recently observed event (not a duration measured from the first event)
    pub time_elapsed: SimTime,
}
562
#[cfg(test)]
mod tests {
    use super::*;

    /// Strength of the first `from → to` edge in `graph`, if any.
    fn edge_strength(graph: &CausalGraph, from: usize, to: usize) -> Option<f64> {
        graph
            .edges_from(from)
            .iter()
            .find(|(t, _, _)| *t == to)
            .map(|(_, s, _)| *s)
    }

    #[test]
    fn test_causal_graph() {
        let mut graph = CausalGraph::new(5);
        graph.add_edge(0, 1, 0.8, CausalRelation::Causes);
        graph.add_edge(1, 2, 0.6, CausalRelation::Causes);

        assert_eq!(graph.edges().len(), 2);
        assert_eq!(edge_strength(&graph, 0, 1), Some(0.8));
        assert!(graph.edges_from(4).is_empty());

        let reachable = graph.reachable_from(0);
        assert!(reachable.contains(&0));
        assert!(reachable.contains(&1));
        assert!(reachable.contains(&2));
        assert!(!reachable.contains(&3));
    }

    #[test]
    fn test_causal_discovery_snn() {
        let config = CausalConfig::default();
        let mut snn = CausalDiscoverySNN::new(config);

        // Observe events with consistent temporal ordering
        for i in 0..10 {
            let t = i as f64 * 10.0;

            // Edge insert always followed by mincut change
            snn.observe_event(
                GraphEvent {
                    event_type: GraphEventType::EdgeInsert,
                    vertex: None,
                    edge: Some((0, 1)),
                    data: 1.0,
                },
                t,
            );

            snn.observe_event(
                GraphEvent {
                    event_type: GraphEventType::MinCutChange,
                    vertex: None,
                    edge: None,
                    data: 0.5,
                },
                t + 5.0,
            );
        }

        let summary = snn.summary();
        // `time_elapsed` is the timestamp of the last observed event (9 * 10 + 5).
        assert!((summary.time_elapsed - 95.0).abs() < 1e-9);
        assert_eq!(
            summary.num_relationships,
            summary.causes_count + summary.prevents_count
        );
    }

    #[test]
    fn test_transitive_closure() {
        let mut graph = CausalGraph::new(4);
        graph.add_edge(0, 1, 0.8, CausalRelation::Causes);
        graph.add_edge(1, 2, 0.6, CausalRelation::Causes);
        graph.add_edge(2, 3, 0.5, CausalRelation::Causes);

        let closed = graph.transitive_closure();

        // Direct edges survive and indirect edges are added with product strength.
        assert!(closed.edges().len() >= 3);
        assert_eq!(edge_strength(&closed, 0, 1), Some(0.8));
        let s02 = edge_strength(&closed, 0, 2).expect("0→1→2 implies 0→2");
        assert!((s02 - 0.8 * 0.6).abs() < 1e-12);
    }

    #[test]
    fn test_intervention_points() {
        let config = CausalConfig::default();
        let snn = CausalDiscoverySNN::new(config);

        // Nothing observed → no learned causality → nothing to intervene on.
        let interventions = snn.optimal_intervention_points(&[0, 1], &[3, 4]);
        assert!(interventions.is_empty());
    }
}