ruvector_mincut/snn/network.rs

//! # Spiking Neural Network Architecture
//!
//! Provides a layered spiking network architecture for integration with MinCut algorithms.
//!
//! ## Network Types
//!
//! - **Feedforward**: Input → Hidden → Output
//! - **Recurrent**: With lateral connections
//! - **Graph-coupled**: Topology mirrors graph structure
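//!
//! ## Example
//!
//! A minimal usage sketch (assuming the default three-layer config below):
//!
//! ```ignore
//! let mut net = SpikingNetwork::new(NetworkConfig::default());
//! net.inject_current(&vec![5.0; net.layer_size(0)]);
//! let output_spikes = net.step();
//! ```
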
use super::{
    neuron::{NeuronConfig, NeuronPopulation},
    synapse::{SynapseMatrix, STDPConfig},
    SimTime, Spike,
};
use crate::graph::DynamicGraph;
use std::collections::VecDeque;

/// Configuration for a single network layer
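///
/// A builder-style construction sketch:
///
/// ```ignore
/// let layer = LayerConfig::new(64)
///     .with_recurrence()
///     .with_neuron_config(NeuronConfig::default());
/// ```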
#[derive(Debug, Clone)]
pub struct LayerConfig {
    /// Number of neurons in this layer
    pub size: usize,
    /// Neuron configuration
    pub neuron_config: NeuronConfig,
    /// Whether this layer has recurrent (lateral) connections
    pub recurrent: bool,
}

impl LayerConfig {
    /// Create a new layer config
    pub fn new(size: usize) -> Self {
        Self {
            size,
            neuron_config: NeuronConfig::default(),
            recurrent: false,
        }
    }

    /// Enable recurrent connections
    pub fn with_recurrence(mut self) -> Self {
        self.recurrent = true;
        self
    }

    /// Set custom neuron configuration
    pub fn with_neuron_config(mut self, config: NeuronConfig) -> Self {
        self.neuron_config = config;
        self
    }
}

/// Configuration for the full network
#[derive(Debug, Clone)]
pub struct NetworkConfig {
    /// Layer configurations (input to output)
    pub layers: Vec<LayerConfig>,
    /// STDP configuration for all synapses
    pub stdp_config: STDPConfig,
    /// Simulation time step (same units as `SimTime`)
    pub dt: f64,
    /// Enable winner-take-all lateral inhibition
    pub winner_take_all: bool,
    /// WTA inhibition strength
    pub wta_strength: f64,
}

impl Default for NetworkConfig {
    fn default() -> Self {
        Self {
            layers: vec![
                LayerConfig::new(100), // Input
                LayerConfig::new(50),  // Hidden
                LayerConfig::new(10),  // Output
            ],
            stdp_config: STDPConfig::default(),
            dt: 1.0,
            winner_take_all: false,
            wta_strength: 0.8,
        }
    }
}

/// A spiking neural network
#[derive(Debug, Clone)]
pub struct SpikingNetwork {
    /// Configuration
    pub config: NetworkConfig,
    /// Neurons organized by layer
    layers: Vec<NeuronPopulation>,
    /// Feedforward weight matrices (layer i → layer i+1)
    feedforward_weights: Vec<SynapseMatrix>,
    /// Recurrent weight matrices (within layer)
    recurrent_weights: Vec<Option<SynapseMatrix>>,
    /// Current simulation time
    time: SimTime,
    /// Spike buffer reserved for delayed transmission (not yet drained by `step`)
    spike_buffer: VecDeque<(Spike, usize, SimTime)>, // (spike, target_layer, arrival_time)
    /// Global inhibition state (for WTA)
    global_inhibition: f64,
}

impl SpikingNetwork {
    /// Create a new spiking network from configuration
    pub fn new(config: NetworkConfig) -> Self {
        let mut layers = Vec::new();
        let mut feedforward_weights = Vec::new();
        let mut recurrent_weights = Vec::new();

        for (i, layer_config) in config.layers.iter().enumerate() {
            // Create neuron population
            let population = NeuronPopulation::with_config(
                layer_config.size,
                layer_config.neuron_config.clone(),
            );
            layers.push(population);

            // Create feedforward weights to next layer
            if i + 1 < config.layers.len() {
                let next_size = config.layers[i + 1].size;
                let mut weights = SynapseMatrix::with_config(
                    layer_config.size,
                    next_size,
                    config.stdp_config.clone(),
                );

                // Initialize with random weights (dense, all-to-all)
                for pre in 0..layer_config.size {
                    for post in 0..next_size {
                        let weight = rand_weight();
                        weights.add_synapse(pre, post, weight);
                    }
                }

                feedforward_weights.push(weights);
            }

            // Create recurrent weights if enabled
            if layer_config.recurrent {
                let mut weights = SynapseMatrix::with_config(
                    layer_config.size,
                    layer_config.size,
                    config.stdp_config.clone(),
                );

                // Sparse random recurrent connections (~10% density, halved weight)
                for pre in 0..layer_config.size {
                    for post in 0..layer_config.size {
                        if pre != post && rand_bool(0.1) {
                            weights.add_synapse(pre, post, rand_weight() * 0.5);
                        }
                    }
                }

                recurrent_weights.push(Some(weights));
            } else {
                recurrent_weights.push(None);
            }
        }

        Self {
            config,
            layers,
            feedforward_weights,
            recurrent_weights,
            time: 0.0,
            spike_buffer: VecDeque::new(),
            global_inhibition: 0.0,
        }
    }

    /// Create network with topology matching a graph
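    ///
    /// Sketch (assuming a populated `DynamicGraph`): edges become symmetric
    /// recurrent synapses in a single layer sized to the vertex count.
    ///
    /// ```ignore
    /// let net = SpikingNetwork::from_graph(&graph, NetworkConfig::default());
    /// assert_eq!(net.num_layers(), 1);
    /// ```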
    pub fn from_graph(graph: &DynamicGraph, config: NetworkConfig) -> Self {
        let n = graph.num_vertices();

        // Single layer matching graph topology
        let mut network_config = config;
        network_config.layers = vec![LayerConfig::new(n).with_recurrence()];

        let mut network = Self::new(network_config);

        // Copy graph edges as recurrent connections
        if let Some(ref mut recurrent) = network.recurrent_weights[0] {
            let vertices: Vec<_> = graph.vertices();
            let vertex_to_idx: std::collections::HashMap<_, _> = vertices
                .iter()
                .enumerate()
                .map(|(i, &v)| (v, i))
                .collect();

            for edge in graph.edges() {
                if let (Some(&pre), Some(&post)) = (
                    vertex_to_idx.get(&edge.source),
                    vertex_to_idx.get(&edge.target),
                ) {
                    recurrent.set_weight(pre, post, edge.weight);
                    recurrent.set_weight(post, pre, edge.weight); // Undirected
                }
            }
        }

        network
    }

    /// Reset network state
    pub fn reset(&mut self) {
        self.time = 0.0;
        self.spike_buffer.clear();
        self.global_inhibition = 0.0;

        for layer in &mut self.layers {
            layer.reset();
        }
    }

    /// Get number of layers
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }

    /// Get layer size (0 if the layer index is out of range)
    pub fn layer_size(&self, layer: usize) -> usize {
        self.layers.get(layer).map(|l| l.size()).unwrap_or(0)
    }

    /// Get current simulation time
    pub fn current_time(&self) -> SimTime {
        self.time
    }

    /// Inject current to input layer
    pub fn inject_current(&mut self, currents: &[f64]) {
        if !self.layers.is_empty() {
            let input_layer = &mut self.layers[0];
            let n = currents.len().min(input_layer.size());

            for (i, neuron) in input_layer.neurons.iter_mut().take(n).enumerate() {
                // Scale injected current before adding it to the membrane potential
                neuron.set_membrane_potential(
                    neuron.membrane_potential() + currents[i] * 0.1
                );
            }
        }
    }

    /// Run one integration step.
    ///
    /// Returns the spikes emitted by the output layer during this step.
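    ///
    /// Typical driver loop (a sketch; `input` is caller-provided):
    ///
    /// ```ignore
    /// for _ in 0..100 {
    ///     net.inject_current(&input);
    ///     let output_spikes = net.step();
    /// }
    /// ```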
    pub fn step(&mut self) -> Vec<Spike> {
        let dt = self.config.dt;
        self.time += dt;

        // Collect all spikes from this timestep
        let mut all_spikes: Vec<Vec<Spike>> = Vec::new();

        // Process each layer
        for layer_idx in 0..self.layers.len() {
            // Calculate input currents for this layer
            let mut currents = vec![0.0; self.layers[layer_idx].size()];

            // Add feedforward input from previous layer (sparse iteration)
            if layer_idx > 0 {
                let weights = &self.feedforward_weights[layer_idx - 1];
                // Collect pre-activations once (rectified membrane potentials)
                let pre_activations: Vec<f64> = self.layers[layer_idx - 1]
                    .neurons
                    .iter()
                    .map(|n| n.membrane_potential().max(0.0))
                    .collect();
                // Use sparse weighted sum computation
                let ff_currents = weights.compute_weighted_sums(&pre_activations);
                for (j, &c) in ff_currents.iter().enumerate() {
                    currents[j] += c;
                }
            }

            // Add recurrent input (sparse iteration)
            if let Some(ref weights) = self.recurrent_weights[layer_idx] {
                // Collect activations
                let activations: Vec<f64> = self.layers[layer_idx]
                    .neurons
                    .iter()
                    .map(|n| n.membrane_potential().max(0.0))
                    .collect();
                // Use sparse weighted sum computation
                let rec_currents = weights.compute_weighted_sums(&activations);
                for (j, &c) in rec_currents.iter().enumerate() {
                    currents[j] += c;
                }
            }

            // Apply winner-take-all inhibition (output layer only)
            if self.config.winner_take_all && layer_idx == self.layers.len() - 1 {
                let max_v = self.layers[layer_idx]
                    .neurons
                    .iter()
                    .map(|n| n.membrane_potential())
                    .fold(f64::NEG_INFINITY, f64::max);

                for (i, neuron) in self.layers[layer_idx].neurons.iter().enumerate() {
                    if neuron.membrane_potential() < max_v {
                        currents[i] -= self.config.wta_strength * self.global_inhibition;
                    }
                }
            }

            // Update neurons
            let spikes = self.layers[layer_idx].step(&currents, dt);
            all_spikes.push(spikes.clone());

            // Update global inhibition
            if !spikes.is_empty() {
                self.global_inhibition = (self.global_inhibition + 0.1).min(1.0);
            } else {
                self.global_inhibition *= 0.95;
            }

            // STDP updates for feedforward weights
            if layer_idx > 0 {
                for spike in &spikes {
                    self.feedforward_weights[layer_idx - 1].on_post_spike(spike.neuron_id, self.time);
                }
            }

            if layer_idx + 1 < self.layers.len() {
                for spike in &spikes {
                    self.feedforward_weights[layer_idx].on_pre_spike(spike.neuron_id, self.time);
                }
            }
        }

        // Return output layer spikes
        all_spikes.last().cloned().unwrap_or_default()
    }

    /// Run until a decision is made (an output neuron spikes) or `max_steps` elapses
    pub fn run_until_decision(&mut self, max_steps: usize) -> Vec<Spike> {
        for _ in 0..max_steps {
            let spikes = self.step();
            if !spikes.is_empty() {
                return spikes;
            }
        }
        Vec::new()
    }

    /// Get population firing rate for a layer
    pub fn layer_rate(&self, layer: usize, window: f64) -> f64 {
        self.layers
            .get(layer)
            .map(|l| l.population_rate(window))
            .unwrap_or(0.0)
    }

    /// Get global synchrony (mean per-layer synchrony over a fixed 10.0 time window)
    pub fn global_synchrony(&self) -> f64 {
        let mut total_sync = 0.0;
        let mut count = 0;

        for layer in &self.layers {
            total_sync += layer.synchrony(10.0);
            count += 1;
        }

        if count > 0 {
            total_sync / count as f64
        } else {
            0.0
        }
    }

    /// Get synchrony matrix (pairwise correlation)
    pub fn synchrony_matrix(&self) -> Vec<Vec<f64>> {
        // Single-layer synchrony for simplicity: computed over the first layer only
        let Some(layer) = self.layers.first() else {
            return Vec::new();
        };
        let n = layer.size();
        let mut matrix = vec![vec![0.0; n]; n];

        for i in 0..n {
            for j in (i + 1)..n {
                let corr = layer.spike_trains[i].cross_correlation(
                    &layer.spike_trains[j],
                    50.0,
                    5.0,
                );
                let sync = corr.iter().sum::<f64>() / corr.len() as f64;
                matrix[i][j] = sync;
                matrix[j][i] = sync;
            }
            matrix[i][i] = 1.0;
        }

        matrix
    }

    /// Get output layer activities
    pub fn get_output(&self) -> Vec<f64> {
        self.layers
            .last()
            .map(|l| l.neurons.iter().map(|n| n.membrane_potential()).collect())
            .unwrap_or_default()
    }

    /// Apply reward signal for R-STDP
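    ///
    /// Sketch of a reward-modulated loop (the `evaluate` closure is hypothetical):
    ///
    /// ```ignore
    /// let spikes = net.run_until_decision(1_000);
    /// let reward = evaluate(&spikes); // task-specific scalar reward
    /// net.apply_reward(reward);
    /// ```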
    pub fn apply_reward(&mut self, reward: f64) {
        for weights in &mut self.feedforward_weights {
            weights.apply_reward(reward);
        }
        for weights in self.recurrent_weights.iter_mut().flatten() {
            weights.apply_reward(reward);
        }
    }

    /// Get low-activity regions (for search skip optimization)
    ///
    /// Returned IDs pack layer and neuron as `layer_idx * 1000 + neuron_idx`,
    /// which assumes fewer than 1000 neurons per layer.
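    ///
    /// Decoding sketch for a packed id:
    ///
    /// ```ignore
    /// let (layer, neuron) = (id / 1000, id % 1000);
    /// ```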
    pub fn low_activity_regions(&self) -> Vec<usize> {
        let mut low_activity = Vec::new();
        let threshold = 0.001;

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            for (neuron_idx, train) in layer.spike_trains.iter().enumerate() {
                if train.spike_rate(100.0) < threshold {
                    low_activity.push(layer_idx * 1000 + neuron_idx);
                }
            }
        }

        low_activity
    }

    /// Sync first-layer recurrent weights back to graph edge weights
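    ///
    /// Round-trip sketch (assumes `net` was built with `from_graph`):
    ///
    /// ```ignore
    /// let mut net = SpikingNetwork::from_graph(&graph, NetworkConfig::default());
    /// // ... run simulation, apply rewards ...
    /// net.sync_to_graph(&mut graph);
    /// ```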
    pub fn sync_to_graph(&self, graph: &mut DynamicGraph) {
        if let Some(recurrent) = self.recurrent_weights.first().and_then(|r| r.as_ref()) {
            let vertices: Vec<_> = graph.vertices();

            for ((pre, post), synapse) in recurrent.iter() {
                if *pre < vertices.len() && *post < vertices.len() {
                    let u = vertices[*pre];
                    let v = vertices[*post];
                    if graph.has_edge(u, v) {
                        let _ = graph.update_edge_weight(u, v, synapse.weight);
                    }
                }
            }
        }
    }
}

// Thread-safe PRNG for weight initialization using atomic CAS
use std::sync::atomic::{AtomicU64, Ordering};
static RNG_STATE: AtomicU64 = AtomicU64::new(0x853c49e6748fea9b);

fn rand_u64() -> u64 {
    // Use a compare_exchange loop so the LCG state advances atomically
    loop {
        let current = RNG_STATE.load(Ordering::Relaxed);
        let next = current.wrapping_mul(0x5851f42d4c957f2d).wrapping_add(0x14057b7ef767814f);
        match RNG_STATE.compare_exchange_weak(current, next, Ordering::Relaxed, Ordering::Relaxed) {
            Ok(_) => return next,
            Err(_) => continue, // Retry on contention
        }
    }
}

/// Uniform random weight in [0.25, 0.75]
fn rand_weight() -> f64 {
    (rand_u64() as f64) / (u64::MAX as f64) * 0.5 + 0.25
}

/// Bernoulli draw with probability `p`
fn rand_bool(p: f64) -> bool {
    (rand_u64() as f64) / (u64::MAX as f64) < p
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_network_creation() {
        let config = NetworkConfig::default();
        let network = SpikingNetwork::new(config);

        assert_eq!(network.num_layers(), 3);
        assert_eq!(network.layer_size(0), 100);
        assert_eq!(network.layer_size(1), 50);
        assert_eq!(network.layer_size(2), 10);
    }

    #[test]
    fn test_network_step() {
        let config = NetworkConfig::default();
        let mut network = SpikingNetwork::new(config);

        // Inject strong current
        let currents = vec![5.0; 100];
        network.inject_current(&currents);

        // Run several steps; spiking is not guaranteed here, so assert only on time
        for _ in 0..100 {
            network.step();
        }

        assert!(network.current_time() > 0.0);
    }

    #[test]
    fn test_graph_network() {
        use crate::graph::DynamicGraph;

        let graph = DynamicGraph::new();
        graph.insert_edge(0, 1, 1.0).unwrap();
        graph.insert_edge(1, 2, 1.0).unwrap();
        graph.insert_edge(2, 0, 1.0).unwrap();

        let config = NetworkConfig::default();
        let network = SpikingNetwork::from_graph(&graph, config);

        assert_eq!(network.num_layers(), 1);
        assert_eq!(network.layer_size(0), 3);
    }

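    #[test]
    fn test_reset() {
        // Sketch: after stepping, reset() should return the clock to zero
        // and clear transient state.
        let mut network = SpikingNetwork::new(NetworkConfig::default());
        network.inject_current(&vec![1.0; 100]);
        network.step();
        assert!(network.current_time() > 0.0);

        network.reset();
        assert_eq!(network.current_time(), 0.0);
    }
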
    #[test]
    fn test_synchrony_matrix() {
        let mut config = NetworkConfig::default();
        config.layers = vec![LayerConfig::new(5)];

        let mut network = SpikingNetwork::new(config);

        // Run a bit
        let currents = vec![2.0; 5];
        for _ in 0..50 {
            network.inject_current(&currents);
            network.step();
        }

        let sync = network.synchrony_matrix();
        assert_eq!(sync.len(), 5);
        assert_eq!(sync[0].len(), 5);

        // Diagonal should be 1
        for i in 0..5 {
            assert!((sync[i][i] - 1.0).abs() < 0.001);
        }
    }
547}