ruvector_mincut/snn/
network.rs

1//! # Spiking Neural Network Architecture
2//!
3//! Provides layered spiking network architecture for integration with MinCut algorithms.
4//!
5//! ## Network Types
6//!
7//! - **Feedforward**: Input → Hidden → Output
8//! - **Recurrent**: With lateral connections
9//! - **Graph-coupled**: Topology mirrors graph structure
10
11use super::{
12    neuron::{LIFNeuron, NeuronConfig, NeuronPopulation, SpikeTrain},
13    synapse::{STDPConfig, Synapse, SynapseMatrix},
14    SimTime, Spike, Vector,
15};
16use crate::graph::DynamicGraph;
17use rayon::prelude::*;
18use std::collections::VecDeque;
19
/// Configuration for a single network layer.
///
/// Created with [`LayerConfig::new`] and refined with the `with_*` builder
/// methods. `SpikingNetwork::new` reads one `LayerConfig` per layer.
#[derive(Debug, Clone)]
pub struct LayerConfig {
    /// Number of neurons in this layer
    pub size: usize,
    /// Neuron configuration, handed to `NeuronPopulation::with_config` when
    /// the layer is built
    pub neuron_config: NeuronConfig,
    /// Whether this layer has recurrent (lateral) connections; when `true`,
    /// `SpikingNetwork::new` creates a within-layer weight matrix for it
    pub recurrent: bool,
}
30
31impl LayerConfig {
32    /// Create a new layer config
33    pub fn new(size: usize) -> Self {
34        Self {
35            size,
36            neuron_config: NeuronConfig::default(),
37            recurrent: false,
38        }
39    }
40
41    /// Enable recurrent connections
42    pub fn with_recurrence(mut self) -> Self {
43        self.recurrent = true;
44        self
45    }
46
47    /// Set custom neuron configuration
48    pub fn with_neuron_config(mut self, config: NeuronConfig) -> Self {
49        self.neuron_config = config;
50        self
51    }
52}
53
/// Configuration for the full network.
///
/// The `Default` implementation describes a 100 → 50 → 10 feedforward
/// network with `dt = 1.0` and winner-take-all disabled.
#[derive(Debug, Clone)]
pub struct NetworkConfig {
    /// Layer configurations (input to output)
    pub layers: Vec<LayerConfig>,
    /// STDP configuration for all synapses (cloned into every feedforward
    /// and recurrent weight matrix)
    pub stdp_config: STDPConfig,
    /// Time step for simulation; added to the simulation clock on every
    /// `SpikingNetwork::step` and passed to each layer's update
    pub dt: f64,
    /// Enable winner-take-all lateral inhibition (applied to the last layer
    /// only)
    pub winner_take_all: bool,
    /// WTA inhibition strength; multiplied by the network's global
    /// inhibition level and subtracted from the input current of every
    /// last-layer neuron whose membrane potential is below the layer maximum
    pub wta_strength: f64,
}
68
69impl Default for NetworkConfig {
70    fn default() -> Self {
71        Self {
72            layers: vec![
73                LayerConfig::new(100), // Input
74                LayerConfig::new(50),  // Hidden
75                LayerConfig::new(10),  // Output
76            ],
77            stdp_config: STDPConfig::default(),
78            dt: 1.0,
79            winner_take_all: false,
80            wta_strength: 0.8,
81        }
82    }
83}
84
/// A spiking neural network.
///
/// Holds one neuron population per configured layer, a feedforward weight
/// matrix between each pair of consecutive layers, and an optional recurrent
/// weight matrix per layer.
#[derive(Debug, Clone)]
pub struct SpikingNetwork {
    /// Configuration
    pub config: NetworkConfig,
    /// Neurons organized by layer
    layers: Vec<NeuronPopulation>,
    /// Feedforward weight matrices (layer i → layer i+1); entry `i` connects
    /// layer `i` to layer `i + 1`, so there is one fewer entry than layers
    feedforward_weights: Vec<SynapseMatrix>,
    /// Recurrent weight matrices (within layer); one entry per layer, `None`
    /// for layers without recurrence
    recurrent_weights: Vec<Option<SynapseMatrix>>,
    /// Current simulation time; advanced by `config.dt` on every `step`
    time: SimTime,
    /// Spike buffer for delayed transmission.
    /// NOTE(review): within this file the buffer is only created empty and
    /// cleared by `reset`; nothing here enqueues or delivers spikes.
    spike_buffer: VecDeque<(Spike, usize, SimTime)>, // (spike, target_layer, arrival_time)
    /// Global inhibition state (for WTA); raised by 0.1 (capped at 1.0) each
    /// time a layer emits spikes during `step`, multiplied by 0.95 otherwise
    global_inhibition: f64,
}
103
104impl SpikingNetwork {
105    /// Create a new spiking network from configuration
106    pub fn new(config: NetworkConfig) -> Self {
107        let mut layers = Vec::new();
108        let mut feedforward_weights = Vec::new();
109        let mut recurrent_weights = Vec::new();
110
111        for (i, layer_config) in config.layers.iter().enumerate() {
112            // Create neuron population
113            let population = NeuronPopulation::with_config(
114                layer_config.size,
115                layer_config.neuron_config.clone(),
116            );
117            layers.push(population);
118
119            // Create feedforward weights to next layer
120            if i + 1 < config.layers.len() {
121                let next_size = config.layers[i + 1].size;
122                let mut weights = SynapseMatrix::with_config(
123                    layer_config.size,
124                    next_size,
125                    config.stdp_config.clone(),
126                );
127
128                // Initialize with random weights
129                for pre in 0..layer_config.size {
130                    for post in 0..next_size {
131                        let weight = rand_weight();
132                        weights.add_synapse(pre, post, weight);
133                    }
134                }
135
136                feedforward_weights.push(weights);
137            }
138
139            // Create recurrent weights if enabled
140            if layer_config.recurrent {
141                let mut weights = SynapseMatrix::with_config(
142                    layer_config.size,
143                    layer_config.size,
144                    config.stdp_config.clone(),
145                );
146
147                // Sparse random recurrent connections
148                for pre in 0..layer_config.size {
149                    for post in 0..layer_config.size {
150                        if pre != post && rand_bool(0.1) {
151                            weights.add_synapse(pre, post, rand_weight() * 0.5);
152                        }
153                    }
154                }
155
156                recurrent_weights.push(Some(weights));
157            } else {
158                recurrent_weights.push(None);
159            }
160        }
161
162        Self {
163            config,
164            layers,
165            feedforward_weights,
166            recurrent_weights,
167            time: 0.0,
168            spike_buffer: VecDeque::new(),
169            global_inhibition: 0.0,
170        }
171    }
172
173    /// Create network with topology matching a graph
174    pub fn from_graph(graph: &DynamicGraph, config: NetworkConfig) -> Self {
175        let n = graph.num_vertices();
176
177        // Single layer matching graph topology
178        let mut network_config = config.clone();
179        network_config.layers = vec![LayerConfig::new(n).with_recurrence()];
180
181        let mut network = Self::new(network_config);
182
183        // Copy graph edges as recurrent connections
184        if let Some(ref mut recurrent) = network.recurrent_weights[0] {
185            let vertices: Vec<_> = graph.vertices();
186            let vertex_to_idx: std::collections::HashMap<_, _> =
187                vertices.iter().enumerate().map(|(i, &v)| (v, i)).collect();
188
189            for edge in graph.edges() {
190                if let (Some(&pre), Some(&post)) = (
191                    vertex_to_idx.get(&edge.source),
192                    vertex_to_idx.get(&edge.target),
193                ) {
194                    recurrent.set_weight(pre, post, edge.weight);
195                    recurrent.set_weight(post, pre, edge.weight); // Undirected
196                }
197            }
198        }
199
200        network
201    }
202
203    /// Reset network state
204    pub fn reset(&mut self) {
205        self.time = 0.0;
206        self.spike_buffer.clear();
207        self.global_inhibition = 0.0;
208
209        for layer in &mut self.layers {
210            layer.reset();
211        }
212    }
213
    /// Get number of layers (one per entry of the configuration's `layers`
    /// list the network was built from).
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }
218
219    /// Get layer size
220    pub fn layer_size(&self, layer: usize) -> usize {
221        self.layers.get(layer).map(|l| l.size()).unwrap_or(0)
222    }
223
    /// Get current simulation time: starts at `0.0`, grows by `config.dt` on
    /// every call to `step`, and returns to `0.0` on `reset`.
    pub fn current_time(&self) -> SimTime {
        self.time
    }
228
229    /// Inject current to input layer
230    pub fn inject_current(&mut self, currents: &[f64]) {
231        if !self.layers.is_empty() {
232            let input_layer = &mut self.layers[0];
233            let n = currents.len().min(input_layer.size());
234
235            for (i, neuron) in input_layer.neurons.iter_mut().take(n).enumerate() {
236                neuron.set_membrane_potential(neuron.membrane_potential() + currents[i] * 0.1);
237            }
238        }
239    }
240
241    /// Run one integration step
242    /// Returns spikes from output layer
243    pub fn step(&mut self) -> Vec<Spike> {
244        let dt = self.config.dt;
245        self.time += dt;
246
247        // Collect all spikes from this timestep
248        let mut all_spikes: Vec<Vec<Spike>> = Vec::new();
249
250        // Process each layer
251        for layer_idx in 0..self.layers.len() {
252            // Calculate input currents for this layer
253            let mut currents = vec![0.0; self.layers[layer_idx].size()];
254
255            // Add feedforward input from previous layer (sparse iteration)
256            if layer_idx > 0 {
257                let weights = &self.feedforward_weights[layer_idx - 1];
258                // Collect pre-activations once
259                let pre_activations: Vec<f64> = self.layers[layer_idx - 1]
260                    .neurons
261                    .iter()
262                    .map(|n| n.membrane_potential().max(0.0))
263                    .collect();
264                // Use sparse weighted sum computation
265                let ff_currents = weights.compute_weighted_sums(&pre_activations);
266                for (j, &c) in ff_currents.iter().enumerate() {
267                    currents[j] += c;
268                }
269            }
270
271            // Add recurrent input (sparse iteration)
272            if let Some(ref weights) = self.recurrent_weights[layer_idx] {
273                // Collect activations
274                let activations: Vec<f64> = self.layers[layer_idx]
275                    .neurons
276                    .iter()
277                    .map(|n| n.membrane_potential().max(0.0))
278                    .collect();
279                // Use sparse weighted sum computation
280                let rec_currents = weights.compute_weighted_sums(&activations);
281                for (j, &c) in rec_currents.iter().enumerate() {
282                    currents[j] += c;
283                }
284            }
285
286            // Apply winner-take-all inhibition
287            if self.config.winner_take_all && layer_idx == self.layers.len() - 1 {
288                let max_v = self.layers[layer_idx]
289                    .neurons
290                    .iter()
291                    .map(|n| n.membrane_potential())
292                    .fold(f64::NEG_INFINITY, f64::max);
293
294                for (i, neuron) in self.layers[layer_idx].neurons.iter().enumerate() {
295                    if neuron.membrane_potential() < max_v {
296                        currents[i] -= self.config.wta_strength * self.global_inhibition;
297                    }
298                }
299            }
300
301            // Update neurons
302            let spikes = self.layers[layer_idx].step(&currents, dt);
303            all_spikes.push(spikes.clone());
304
305            // Update global inhibition
306            if !spikes.is_empty() {
307                self.global_inhibition = (self.global_inhibition + 0.1).min(1.0);
308            } else {
309                self.global_inhibition *= 0.95;
310            }
311
312            // STDP updates for feedforward weights
313            if layer_idx > 0 {
314                for spike in &spikes {
315                    self.feedforward_weights[layer_idx - 1]
316                        .on_post_spike(spike.neuron_id, self.time);
317                }
318            }
319
320            if layer_idx + 1 < self.layers.len() {
321                for spike in &spikes {
322                    self.feedforward_weights[layer_idx].on_pre_spike(spike.neuron_id, self.time);
323                }
324            }
325        }
326
327        // Return output layer spikes
328        all_spikes.last().cloned().unwrap_or_default()
329    }
330
331    /// Run until a decision is made (output neuron spikes)
332    pub fn run_until_decision(&mut self, max_steps: usize) -> Vec<Spike> {
333        for _ in 0..max_steps {
334            let spikes = self.step();
335            if !spikes.is_empty() {
336                return spikes;
337            }
338        }
339        Vec::new()
340    }
341
342    /// Get population firing rate for a layer
343    pub fn layer_rate(&self, layer: usize, window: f64) -> f64 {
344        self.layers
345            .get(layer)
346            .map(|l| l.population_rate(window))
347            .unwrap_or(0.0)
348    }
349
350    /// Get global synchrony
351    pub fn global_synchrony(&self) -> f64 {
352        let mut total_sync = 0.0;
353        let mut count = 0;
354
355        for layer in &self.layers {
356            total_sync += layer.synchrony(10.0);
357            count += 1;
358        }
359
360        if count > 0 {
361            total_sync / count as f64
362        } else {
363            0.0
364        }
365    }
366
367    /// Get synchrony matrix (pairwise correlation)
368    pub fn synchrony_matrix(&self) -> Vec<Vec<f64>> {
369        // Single layer synchrony for simplicity
370        let layer = &self.layers[0];
371        let n = layer.size();
372        let mut matrix = vec![vec![0.0; n]; n];
373
374        for i in 0..n {
375            for j in (i + 1)..n {
376                let corr =
377                    layer.spike_trains[i].cross_correlation(&layer.spike_trains[j], 50.0, 5.0);
378                let sync = corr.iter().sum::<f64>() / corr.len() as f64;
379                matrix[i][j] = sync;
380                matrix[j][i] = sync;
381            }
382            matrix[i][i] = 1.0;
383        }
384
385        matrix
386    }
387
388    /// Get output layer activities
389    pub fn get_output(&self) -> Vec<f64> {
390        self.layers
391            .last()
392            .map(|l| l.neurons.iter().map(|n| n.membrane_potential()).collect())
393            .unwrap_or_default()
394    }
395
396    /// Apply reward signal for R-STDP
397    pub fn apply_reward(&mut self, reward: f64) {
398        for weights in &mut self.feedforward_weights {
399            weights.apply_reward(reward);
400        }
401        for weights in &mut self.recurrent_weights {
402            if let Some(w) = weights {
403                w.apply_reward(reward);
404            }
405        }
406    }
407
408    /// Get low-activity regions (for search skip optimization)
409    pub fn low_activity_regions(&self) -> Vec<usize> {
410        let mut low_activity = Vec::new();
411        let threshold = 0.001;
412
413        for (layer_idx, layer) in self.layers.iter().enumerate() {
414            for (neuron_idx, train) in layer.spike_trains.iter().enumerate() {
415                if train.spike_rate(100.0) < threshold {
416                    low_activity.push(layer_idx * 1000 + neuron_idx);
417                }
418            }
419        }
420
421        low_activity
422    }
423
424    /// Sync first layer weights back to graph
425    pub fn sync_to_graph(&self, graph: &mut DynamicGraph) {
426        if let Some(ref recurrent) = self.recurrent_weights.first().and_then(|r| r.as_ref()) {
427            let vertices: Vec<_> = graph.vertices();
428
429            for ((pre, post), synapse) in recurrent.iter() {
430                if *pre < vertices.len() && *post < vertices.len() {
431                    let u = vertices[*pre];
432                    let v = vertices[*post];
433                    if graph.has_edge(u, v) {
434                        let _ = graph.update_edge_weight(u, v, synapse.weight);
435                    }
436                }
437            }
438        }
439    }
440}
441
// Process-wide PRNG for weight initialization. The state lives in a single
// atomic word that is advanced with an atomic read-modify-write, so
// concurrent callers never observe or produce the same state twice.
use std::sync::atomic::{AtomicU64, Ordering};
static RNG_STATE: AtomicU64 = AtomicU64::new(0x853c49e6748fea9b);

/// One step of the 64-bit linear congruential generator (wrapping
/// arithmetic).
fn lcg_next(state: u64) -> u64 {
    state
        .wrapping_mul(0x5851f42d4c957f2d)
        .wrapping_add(0x14057b7ef767814f)
}

/// Advance the shared generator and return the new state.
fn rand_u64() -> u64 {
    // `fetch_update` retries internally on contention and hands back the
    // state it replaced; the value stored is therefore `lcg_next` of it.
    let advance = |state| Some(lcg_next(state));
    match RNG_STATE.fetch_update(Ordering::Relaxed, Ordering::Relaxed, advance) {
        // The closure never returns `None`, so `Err` cannot occur; both arms
        // carry the previous state.
        Ok(previous) | Err(previous) => lcg_next(previous),
    }
}
459
460fn rand_weight() -> f64 {
461    (rand_u64() as f64) / (u64::MAX as f64) * 0.5 + 0.25
462}
463
464fn rand_bool(p: f64) -> bool {
465    (rand_u64() as f64) / (u64::MAX as f64) < p
466}
467
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration yields a 100-50-10 three-layer network.
    #[test]
    fn test_network_creation() {
        let config = NetworkConfig::default();
        let network = SpikingNetwork::new(config);

        assert_eq!(network.num_layers(), 3);
        assert_eq!(network.layer_size(0), 100);
        assert_eq!(network.layer_size(1), 50);
        assert_eq!(network.layer_size(2), 10);
    }

    /// Stepping advances the simulation clock by exactly `dt` per step.
    #[test]
    fn test_network_step() {
        let config = NetworkConfig::default();
        let dt = config.dt;
        let mut network = SpikingNetwork::new(config);

        // Inject strong current
        let currents = vec![5.0; 100];
        network.inject_current(&currents);

        // Run several steps; whether the output layer spikes depends on the
        // neuron dynamics, so only the clock is asserted here.
        let steps = 100;
        for _ in 0..steps {
            network.step();
        }

        let expected_time = steps as f64 * dt;
        assert!((network.current_time() - expected_time).abs() < 1e-9);
    }

    /// A triangle graph produces a single layer with one neuron per vertex.
    #[test]
    fn test_graph_network() {
        use crate::graph::DynamicGraph;

        let graph = DynamicGraph::new();
        graph.insert_edge(0, 1, 1.0).unwrap();
        graph.insert_edge(1, 2, 1.0).unwrap();
        graph.insert_edge(2, 0, 1.0).unwrap();

        let config = NetworkConfig::default();
        let network = SpikingNetwork::from_graph(&graph, config);

        assert_eq!(network.num_layers(), 1);
        assert_eq!(network.layer_size(0), 3);
    }

    /// The synchrony matrix is square with a unit diagonal.
    #[test]
    fn test_synchrony_matrix() {
        let mut config = NetworkConfig::default();
        config.layers = vec![LayerConfig::new(5)];

        let mut network = SpikingNetwork::new(config);

        // Run a bit
        let currents = vec![2.0; 5];
        for _ in 0..50 {
            network.inject_current(&currents);
            network.step();
        }

        let sync = network.synchrony_matrix();
        assert_eq!(sync.len(), 5);
        assert_eq!(sync[0].len(), 5);

        // Diagonal should be 1
        for i in 0..5 {
            assert!((sync[i][i] - 1.0).abs() < 0.001);
        }
    }
}