ruvector_mincut/snn/
causal.rs

//! # Layer 3: Causal Discovery via Spike Timing
//!
//! Uses spike-timing cross-correlation to infer causal relationships between graph events.
//!
//! ## Key Insight
//!
//! Spike-train cross-correlation with asymmetric temporal windows naturally encodes
//! Granger-like (temporal-precedence) causality:
//!
//! ```text
//! Neuron A:  ──●────────●────────●────────
//! Neuron B:  ────────●────────●────────●──
//!            │←─Δt─→│
//!
//! If Δt is consistently positive (A fires before B) → A causes B.
//! The STDP learning rule naturally encodes this!
//! ```
//!
//! After learning, the synaptic weight W_AB reflects the causal strength of A→B:
//! a positive weight is read as "A causes B", a negative one as "A prevents B".
//!
//! ## MinCut Application
//!
//! MinCut on the causal graph reveals optimal intervention points —
//! the minimum changes needed to affect outcomes.
25
26use super::{
27    neuron::{LIFNeuron, NeuronConfig, SpikeTrain},
28    synapse::{Synapse, SynapseMatrix, STDPConfig, AsymmetricSTDP},
29    SimTime, Spike,
30};
31use crate::graph::{DynamicGraph, VertexId, EdgeId};
32use std::collections::{HashMap, HashSet, VecDeque};
33
34/// Configuration for causal discovery
35#[derive(Debug, Clone)]
36pub struct CausalConfig {
37    /// Number of event types (neurons)
38    pub num_event_types: usize,
39    /// Threshold for causal relationship detection
40    pub causal_threshold: f64,
41    /// Time window for causality (ms)
42    pub time_window: f64,
43    /// Asymmetric STDP configuration
44    pub stdp: AsymmetricSTDP,
45    /// Learning rate for causal weight updates
46    pub learning_rate: f64,
47    /// Decay rate for causal weights
48    pub decay_rate: f64,
49}
50
51impl Default for CausalConfig {
52    fn default() -> Self {
53        Self {
54            num_event_types: 100,
55            causal_threshold: 0.1,
56            time_window: 50.0,
57            stdp: AsymmetricSTDP::default(),
58            learning_rate: 0.01,
59            decay_rate: 0.001,
60        }
61    }
62}
63
/// Type of causal relationship between two event types.
///
/// Besides `PartialEq`, the enum derives `Eq` and `Hash`: it carries no
/// floating-point payload, so total equality is sound, and relations can be
/// used as `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CausalRelation {
    /// A causes B (positive influence)
    Causes,
    /// A prevents B (negative influence)
    Prevents,
    /// No significant causal relationship
    None,
}
74
75/// A directed causal relationship
76#[derive(Debug, Clone)]
77pub struct CausalEdge {
78    /// Source event type
79    pub source: usize,
80    /// Target event type
81    pub target: usize,
82    /// Causal strength (absolute value)
83    pub strength: f64,
84    /// Type of relationship
85    pub relation: CausalRelation,
86}
87
88/// Directed graph representing causal relationships
89#[derive(Debug, Clone)]
90pub struct CausalGraph {
91    /// Number of nodes (event types)
92    pub num_nodes: usize,
93    /// Causal edges
94    edges: Vec<CausalEdge>,
95    /// Adjacency list (source → targets)
96    adjacency: HashMap<usize, Vec<(usize, f64, CausalRelation)>>,
97}
98
99impl CausalGraph {
100    /// Create a new empty causal graph
101    pub fn new(num_nodes: usize) -> Self {
102        Self {
103            num_nodes,
104            edges: Vec::new(),
105            adjacency: HashMap::new(),
106        }
107    }
108
109    /// Add a causal edge
110    pub fn add_edge(&mut self, source: usize, target: usize, strength: f64, relation: CausalRelation) {
111        self.edges.push(CausalEdge {
112            source,
113            target,
114            strength,
115            relation,
116        });
117
118        self.adjacency
119            .entry(source)
120            .or_insert_with(Vec::new)
121            .push((target, strength, relation));
122    }
123
124    /// Get edges from a node
125    pub fn edges_from(&self, source: usize) -> &[(usize, f64, CausalRelation)] {
126        self.adjacency.get(&source).map(|v| v.as_slice()).unwrap_or(&[])
127    }
128
129    /// Get all edges
130    pub fn edges(&self) -> &[CausalEdge] {
131        &self.edges
132    }
133
134    /// Maximum nodes for transitive closure (O(n³) algorithm)
135    const MAX_CLOSURE_NODES: usize = 500;
136
137    /// Compute transitive closure (indirect causation)
138    ///
139    /// Uses Floyd-Warshall algorithm with O(n³) complexity.
140    /// Limited to MAX_CLOSURE_NODES to prevent DoS.
141    pub fn transitive_closure(&self) -> Self {
142        let mut closed = Self::new(self.num_nodes);
143
144        // Resource limit: skip if too many nodes (O(n³) would be too slow)
145        if self.num_nodes > Self::MAX_CLOSURE_NODES {
146            // Just copy direct edges without transitive closure
147            for edge in &self.edges {
148                closed.add_edge(edge.source, edge.target, edge.strength, edge.relation);
149            }
150            return closed;
151        }
152
153        // Copy direct edges
154        for edge in &self.edges {
155            closed.add_edge(edge.source, edge.target, edge.strength, edge.relation);
156        }
157
158        // Floyd-Warshall-like algorithm for transitive closure
159        for k in 0..self.num_nodes {
160            for i in 0..self.num_nodes {
161                for j in 0..self.num_nodes {
162                    if i == j || i == k || j == k {
163                        continue;
164                    }
165
166                    // Check if path i→k→j exists
167                    let ik_strength = self.adjacency.get(&i)
168                        .and_then(|edges| edges.iter().find(|(t, _, _)| *t == k))
169                        .map(|(_, s, _)| *s);
170
171                    let kj_strength = self.adjacency.get(&k)
172                        .and_then(|edges| edges.iter().find(|(t, _, _)| *t == j))
173                        .map(|(_, s, _)| *s);
174
175                    if let (Some(s1), Some(s2)) = (ik_strength, kj_strength) {
176                        let indirect_strength = s1 * s2;
177
178                        // Only add if stronger than existing direct path
179                        let existing = closed.adjacency.get(&i)
180                            .and_then(|edges| edges.iter().find(|(t, _, _)| *t == j))
181                            .map(|(_, s, _)| *s)
182                            .unwrap_or(0.0);
183
184                        if indirect_strength > existing {
185                            closed.add_edge(i, j, indirect_strength, CausalRelation::Causes);
186                        }
187                    }
188                }
189            }
190        }
191
192        closed
193    }
194
195    /// Find nodes reachable from a source
196    pub fn reachable_from(&self, source: usize) -> HashSet<usize> {
197        let mut visited = HashSet::new();
198        let mut queue = VecDeque::new();
199
200        queue.push_back(source);
201        visited.insert(source);
202
203        while let Some(node) = queue.pop_front() {
204            for (target, _, _) in self.edges_from(node) {
205                if visited.insert(*target) {
206                    queue.push_back(*target);
207                }
208            }
209        }
210
211        visited
212    }
213
214    /// Convert to undirected graph for mincut analysis
215    pub fn to_undirected(&self) -> DynamicGraph {
216        let graph = DynamicGraph::new();
217
218        for edge in &self.edges {
219            if !graph.has_edge(edge.source as u64, edge.target as u64) {
220                let _ = graph.insert_edge(
221                    edge.source as u64,
222                    edge.target as u64,
223                    edge.strength,
224                );
225            }
226        }
227
228        graph
229    }
230}
231
232/// Graph event that can be observed
233#[derive(Debug, Clone)]
234pub struct GraphEvent {
235    /// Type of event
236    pub event_type: GraphEventType,
237    /// Associated vertex (if applicable)
238    pub vertex: Option<VertexId>,
239    /// Associated edge (if applicable)
240    pub edge: Option<(VertexId, VertexId)>,
241    /// Event metadata
242    pub data: f64,
243}
244
/// Kinds of graph events that can be observed.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum GraphEventType {
    /// An edge was inserted.
    EdgeInsert,
    /// An edge was deleted.
    EdgeDelete,
    /// The weight of an edge changed.
    WeightChange,
    /// The minimum-cut value changed.
    MinCutChange,
    /// A component split apart.
    ComponentSplit,
    /// Components merged together.
    ComponentMerge,
}
261
262/// Causal discovery using spiking neural network
263pub struct CausalDiscoverySNN {
264    /// One neuron per graph event type
265    event_neurons: Vec<LIFNeuron>,
266    /// Spike trains for each neuron
267    spike_trains: Vec<SpikeTrain>,
268    /// Synaptic weights encode discovered causal strength
269    synapses: SynapseMatrix,
270    /// Asymmetric STDP for causality detection
271    stdp: AsymmetricSTDP,
272    /// Configuration
273    config: CausalConfig,
274    /// Current simulation time
275    time: SimTime,
276    /// Event type mapping
277    event_type_map: HashMap<GraphEventType, usize>,
278    /// Reverse mapping
279    index_to_event: HashMap<usize, GraphEventType>,
280}
281
282impl CausalDiscoverySNN {
283    /// Create a new causal discovery SNN
284    pub fn new(config: CausalConfig) -> Self {
285        let n = config.num_event_types;
286
287        // Create event neurons
288        let neuron_config = NeuronConfig {
289            tau_membrane: 10.0,  // Fast response
290            threshold: 0.5,
291            ..NeuronConfig::default()
292        };
293
294        let event_neurons: Vec<_> = (0..n)
295            .map(|i| LIFNeuron::with_config(i, neuron_config.clone()))
296            .collect();
297
298        let spike_trains: Vec<_> = (0..n)
299            .map(|i| SpikeTrain::with_window(i, config.time_window * 10.0))
300            .collect();
301
302        // Fully connected synapses
303        let mut synapses = SynapseMatrix::new(n, n);
304        for i in 0..n {
305            for j in 0..n {
306                if i != j {
307                    synapses.add_synapse(i, j, 0.0);  // Start with zero weights
308                }
309            }
310        }
311
312        // Initialize event type mapping
313        let event_type_map: HashMap<_, _> = [
314            (GraphEventType::EdgeInsert, 0),
315            (GraphEventType::EdgeDelete, 1),
316            (GraphEventType::WeightChange, 2),
317            (GraphEventType::MinCutChange, 3),
318            (GraphEventType::ComponentSplit, 4),
319            (GraphEventType::ComponentMerge, 5),
320        ].iter().cloned().collect();
321
322        let index_to_event: HashMap<_, _> = event_type_map.iter()
323            .map(|(k, v)| (*v, *k))
324            .collect();
325
326        Self {
327            event_neurons,
328            spike_trains,
329            synapses,
330            stdp: config.stdp.clone(),
331            config,
332            time: 0.0,
333            event_type_map,
334            index_to_event,
335        }
336    }
337
338    /// Convert graph event to neuron index
339    fn event_to_neuron(&self, event: &GraphEvent) -> usize {
340        self.event_type_map.get(&event.event_type).copied().unwrap_or(0)
341    }
342
343    /// Observe a graph event
344    pub fn observe_event(&mut self, event: GraphEvent, timestamp: SimTime) {
345        self.time = timestamp;
346
347        // Convert graph event to spike
348        let neuron_id = self.event_to_neuron(&event);
349
350        if neuron_id < self.event_neurons.len() {
351            // Record spike
352            self.event_neurons[neuron_id].inject_spike(timestamp);
353            self.spike_trains[neuron_id].record_spike(timestamp);
354
355            // STDP update: causal relationships emerge in weights
356            self.stdp.update_weights(&mut self.synapses, neuron_id, timestamp);
357        }
358    }
359
360    /// Process a batch of events
361    pub fn observe_events(&mut self, events: &[GraphEvent], timestamps: &[SimTime]) {
362        for (event, &ts) in events.iter().zip(timestamps.iter()) {
363            self.observe_event(event.clone(), ts);
364        }
365    }
366
367    /// Decay all synaptic weights toward baseline
368    ///
369    /// Applies exponential decay: w' = w * (1 - decay_rate) + baseline * decay_rate
370    pub fn decay_weights(&mut self) {
371        let decay = self.config.decay_rate;
372        let baseline = 0.5; // Neutral weight
373        let n = self.config.num_event_types;
374
375        // Iterate through all possible synapse pairs
376        for i in 0..n {
377            for j in 0..n {
378                if let Some(synapse) = self.synapses.get_synapse_mut(i, j) {
379                    // Exponential decay toward baseline
380                    synapse.weight = synapse.weight * (1.0 - decay) + baseline * decay;
381                }
382            }
383        }
384    }
385
386    /// Extract causal graph from learned weights
387    pub fn extract_causal_graph(&self) -> CausalGraph {
388        let n = self.config.num_event_types;
389        let mut graph = CausalGraph::new(n);
390
391        for ((i, j), synapse) in self.synapses.iter() {
392            let w = synapse.weight;
393
394            if w.abs() > self.config.causal_threshold {
395                let strength = w.abs();
396                let relation = if w > 0.0 {
397                    CausalRelation::Causes
398                } else {
399                    CausalRelation::Prevents
400                };
401
402                graph.add_edge(*i, *j, strength, relation);
403            }
404        }
405
406        graph
407    }
408
409    /// Find optimal intervention points using MinCut on causal graph
410    pub fn optimal_intervention_points(
411        &self,
412        controllable: &[usize],
413        targets: &[usize],
414    ) -> Vec<usize> {
415        let causal = self.extract_causal_graph();
416        let undirected = causal.to_undirected();
417
418        // Simple heuristic: find nodes on paths from controllable to targets
419        let mut intervention_points = Vec::new();
420        let controllable_set: HashSet<_> = controllable.iter().cloned().collect();
421        let target_set: HashSet<_> = targets.iter().cloned().collect();
422
423        for edge in causal.edges() {
424            // If edge connects controllable region to target region
425            if controllable_set.contains(&edge.source) ||
426               target_set.contains(&edge.target) {
427                intervention_points.push(edge.source);
428            }
429        }
430
431        intervention_points.sort();
432        intervention_points.dedup();
433        intervention_points
434    }
435
436    /// Get causal strength between two event types
437    pub fn causal_strength(&self, from: GraphEventType, to: GraphEventType) -> f64 {
438        let i = self.event_type_map.get(&from).copied().unwrap_or(0);
439        let j = self.event_type_map.get(&to).copied().unwrap_or(0);
440
441        self.synapses.weight(i, j)
442    }
443
444    /// Get all direct causes of an event type
445    pub fn direct_causes(&self, event_type: GraphEventType) -> Vec<(GraphEventType, f64)> {
446        let j = self.event_type_map.get(&event_type).copied().unwrap_or(0);
447        let mut causes = Vec::new();
448
449        for i in 0..self.config.num_event_types {
450            if i != j {
451                let w = self.synapses.weight(i, j);
452                if w > self.config.causal_threshold {
453                    if let Some(&event) = self.index_to_event.get(&i) {
454                        causes.push((event, w));
455                    }
456                }
457            }
458        }
459
460        causes.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
461        causes
462    }
463
464    /// Get all direct effects of an event type
465    pub fn direct_effects(&self, event_type: GraphEventType) -> Vec<(GraphEventType, f64)> {
466        let i = self.event_type_map.get(&event_type).copied().unwrap_or(0);
467        let mut effects = Vec::new();
468
469        for j in 0..self.config.num_event_types {
470            if i != j {
471                let w = self.synapses.weight(i, j);
472                if w > self.config.causal_threshold {
473                    if let Some(&event) = self.index_to_event.get(&j) {
474                        effects.push((event, w));
475                    }
476                }
477            }
478        }
479
480        effects.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
481        effects
482    }
483
484    /// Reset the SNN
485    pub fn reset(&mut self) {
486        self.time = 0.0;
487
488        for neuron in &mut self.event_neurons {
489            neuron.reset();
490        }
491
492        for train in &mut self.spike_trains {
493            train.clear();
494        }
495
496        // Reset weights to zero
497        for i in 0..self.config.num_event_types {
498            for j in 0..self.config.num_event_types {
499                if i != j {
500                    self.synapses.set_weight(i, j, 0.0);
501                }
502            }
503        }
504    }
505
506    /// Get summary statistics
507    pub fn summary(&self) -> CausalSummary {
508        let causal = self.extract_causal_graph();
509
510        let mut total_strength = 0.0;
511        let mut causes_count = 0;
512        let mut prevents_count = 0;
513
514        for edge in causal.edges() {
515            total_strength += edge.strength;
516            match edge.relation {
517                CausalRelation::Causes => causes_count += 1,
518                CausalRelation::Prevents => prevents_count += 1,
519                CausalRelation::None => {}
520            }
521        }
522
523        CausalSummary {
524            num_relationships: causal.edges().len(),
525            causes_count,
526            prevents_count,
527            avg_strength: total_strength / causal.edges().len().max(1) as f64,
528            time_elapsed: self.time,
529        }
530    }
531}
532
533/// Summary of causal discovery
534#[derive(Debug, Clone)]
535pub struct CausalSummary {
536    /// Total number of discovered relationships
537    pub num_relationships: usize,
538    /// Number of positive causal relationships
539    pub causes_count: usize,
540    /// Number of preventive relationships
541    pub prevents_count: usize,
542    /// Average causal strength
543    pub avg_strength: f64,
544    /// Time elapsed in observation
545    pub time_elapsed: SimTime,
546}
547
#[cfg(test)]
mod tests {
    use super::*;

    /// Strength of the first `from → to` edge in `graph`, if any.
    fn strength(graph: &CausalGraph, from: usize, to: usize) -> Option<f64> {
        graph
            .edges_from(from)
            .iter()
            .find(|(t, _, _)| *t == to)
            .map(|(_, s, _)| *s)
    }

    #[test]
    fn test_causal_graph() {
        let mut graph = CausalGraph::new(5);
        graph.add_edge(0, 1, 0.8, CausalRelation::Causes);
        graph.add_edge(1, 2, 0.6, CausalRelation::Causes);

        assert_eq!(graph.edges().len(), 2);
        assert_eq!(graph.edges_from(0).len(), 1);
        assert!(graph.edges_from(4).is_empty());

        let reachable = graph.reachable_from(0);
        assert!(reachable.contains(&0));
        assert!(reachable.contains(&1));
        assert!(reachable.contains(&2));
        assert!(!reachable.contains(&3));
    }

    #[test]
    fn test_causal_discovery_snn() {
        let config = CausalConfig::default();
        let mut snn = CausalDiscoverySNN::new(config);

        // Observe events with consistent temporal ordering
        for i in 0..10 {
            let t = i as f64 * 10.0;

            // Edge insert always followed by mincut change
            snn.observe_event(
                GraphEvent {
                    event_type: GraphEventType::EdgeInsert,
                    vertex: None,
                    edge: Some((0, 1)),
                    data: 1.0,
                },
                t,
            );

            snn.observe_event(
                GraphEvent {
                    event_type: GraphEventType::MinCutChange,
                    vertex: None,
                    edge: None,
                    data: 0.5,
                },
                t + 5.0,
            );
        }

        let summary = snn.summary();
        // The clock tracks the last observed timestamp: 9 * 10 + 5.
        assert!((summary.time_elapsed - 95.0).abs() < 1e-9);
        assert_eq!(
            summary.num_relationships,
            summary.causes_count + summary.prevents_count
        );
    }

    #[test]
    fn test_transitive_closure() {
        let mut graph = CausalGraph::new(4);
        graph.add_edge(0, 1, 0.8, CausalRelation::Causes);
        graph.add_edge(1, 2, 0.6, CausalRelation::Causes);
        graph.add_edge(2, 3, 0.5, CausalRelation::Causes);

        let closed = graph.transitive_closure();

        // Direct edges are kept.
        assert_eq!(strength(&closed, 0, 1), Some(0.8));
        assert_eq!(strength(&closed, 2, 3), Some(0.5));

        // Two-hop indirect edges carry the product of the path strengths.
        let s02 = strength(&closed, 0, 2).expect("0 -> 2 via 1");
        assert!((s02 - 0.8 * 0.6).abs() < 1e-9);
        let s13 = strength(&closed, 1, 3).expect("1 -> 3 via 2");
        assert!((s13 - 0.6 * 0.5).abs() < 1e-9);

        // No edges are invented against the causal direction.
        assert!(strength(&closed, 3, 0).is_none());
        assert!(closed.edges().len() >= 5);
    }

    #[test]
    fn test_intervention_points() {
        let config = CausalConfig::default();
        let snn = CausalDiscoverySNN::new(config);

        // Nothing has been learned (all weights are zero), so no causal edges
        // exist and there is nothing to intervene on.
        let interventions = snn.optimal_intervention_points(&[0, 1], &[3, 4]);
        assert!(interventions.is_empty());
    }
}