Skip to main content

scirs2_neural/snn/
snn_layer.rs

//! Spiking Neural Network Layers and Networks
//!
//! Provides:
//! - `SpikingLayer` — a layer of LIF neurons with exponential synapses
//! - `SpikingNetwork` — multi-layer SNN with full simulation loop
//! - Spike-record statistics (`SpikingNetwork::count_spikes`,
//!   `SpikingNetwork::mean_firing_rates`)

9use crate::error::{NeuralError, Result};
10use crate::snn::neuron_models::{LIFConfig, LIFNeuron};
11use crate::snn::synapse::ExponentialSynapse;
12
13// ---------------------------------------------------------------------------
14// SpikingLayer
15// ---------------------------------------------------------------------------
16
17/// A single layer of LIF neurons, each receiving inputs through exponential synapses.
18///
19/// Connectivity: dense (n_in × n_out) with independent weights per synapse.
20#[derive(Debug)]
21pub struct SpikingLayer {
22    /// Output neurons
23    pub neurons: Vec<LIFNeuron>,
24    /// Synapses: `synapses[j][i]` is the synapse from input `i` to output `j`
25    pub synapses: Vec<Vec<ExponentialSynapse>>,
26    /// Number of input channels
27    pub n_in: usize,
28    /// Number of output neurons
29    pub n_out: usize,
30}
31
32impl SpikingLayer {
33    /// Create a new spiking layer with initialised LIF neurons and AMPA synapses.
34    ///
35    /// All synaptic weights are initialised to `init_weight / n_in` to keep
36    /// the total input approximately constant regardless of fan-in.
37    ///
38    /// # Arguments
39    /// * `n_in`        — number of input spike channels
40    /// * `n_out`       — number of output neurons
41    /// * `config`      — LIF configuration
42    /// * `init_weight` — baseline total excitatory drive (before fan-in scaling)
43    ///
44    /// # Errors
45    /// Returns an error if `n_in == 0` or `n_out == 0`.
46    pub fn new(n_in: usize, n_out: usize, config: &LIFConfig, init_weight: f32) -> Result<Self> {
47        if n_in == 0 {
48            return Err(NeuralError::InvalidArgument("n_in must be > 0".into()));
49        }
50        if n_out == 0 {
51            return Err(NeuralError::InvalidArgument("n_out must be > 0".into()));
52        }
53
54        let w = init_weight / n_in as f32;
55        let neurons: Vec<LIFNeuron> = (0..n_out).map(|_| LIFNeuron::new(config)).collect();
56        let synapses: Vec<Vec<ExponentialSynapse>> = (0..n_out)
57            .map(|_| (0..n_in).map(|_| ExponentialSynapse::ampa(w)).collect())
58            .collect();
59
60        Ok(Self {
61            neurons,
62            synapses,
63            n_in,
64            n_out,
65        })
66    }
67
68    /// Create a layer with explicit synaptic weights.
69    ///
70    /// # Arguments
71    /// * `weights` — 2-D weight matrix, shape [n_out][n_in]
72    /// * `config`  — LIF neuron configuration
73    ///
74    /// # Errors
75    /// Returns an error if the weight matrix is jagged.
76    pub fn from_weights(weights: &[Vec<f32>], config: &LIFConfig) -> Result<Self> {
77        let n_out = weights.len();
78        if n_out == 0 {
79            return Err(NeuralError::InvalidArgument(
80                "weights must be non-empty".into(),
81            ));
82        }
83        let n_in = weights[0].len();
84        if n_in == 0 {
85            return Err(NeuralError::InvalidArgument(
86                "inner weight dimension must be > 0".into(),
87            ));
88        }
89        for (j, row) in weights.iter().enumerate() {
90            if row.len() != n_in {
91                return Err(NeuralError::DimensionMismatch(format!(
92                    "row {j} has {} weights, expected {n_in}",
93                    row.len()
94                )));
95            }
96        }
97
98        let neurons: Vec<LIFNeuron> = (0..n_out).map(|_| LIFNeuron::new(config)).collect();
99        let synapses: Vec<Vec<ExponentialSynapse>> = weights
100            .iter()
101            .map(|row| row.iter().map(|&w| ExponentialSynapse::ampa(w)).collect())
102            .collect();
103
104        Ok(Self {
105            neurons,
106            synapses,
107            n_in,
108            n_out,
109        })
110    }
111
112    /// Run the layer forward for one time step.
113    ///
114    /// Each output neuron accumulates current from all active input synapses
115    /// and integrates it via its LIF dynamics.
116    ///
117    /// # Arguments
118    /// * `input_spikes` — boolean spike vector of length `n_in`
119    /// * `dt`           — time step (ms)
120    ///
121    /// # Returns
122    /// Boolean spike vector of length `n_out`.
123    ///
124    /// # Errors
125    /// Returns an error if `input_spikes.len() != n_in`.
126    pub fn forward(&mut self, input_spikes: &[bool], dt: f32) -> Result<Vec<bool>> {
127        if input_spikes.len() != self.n_in {
128            return Err(NeuralError::DimensionMismatch(format!(
129                "input spike length {} != n_in {}",
130                input_spikes.len(),
131                self.n_in
132            )));
133        }
134
135        let mut output_spikes = vec![false; self.n_out];
136
137        for (j, (neuron, syn_row)) in self
138            .neurons
139            .iter_mut()
140            .zip(self.synapses.iter_mut())
141            .enumerate()
142        {
143            let mut total_current = 0.0_f32;
144            for (syn, &spike) in syn_row.iter_mut().zip(input_spikes.iter()) {
145                let g = syn.update(spike, dt);
146                // Use fixed post-synaptic potential at rest for current calculation
147                total_current += g * neuron.r_m;
148            }
149            output_spikes[j] = neuron.step(total_current, dt);
150        }
151
152        Ok(output_spikes)
153    }
154
155    /// Reset all neurons and synapses to their initial state.
156    pub fn reset(&mut self) {
157        for neuron in self.neurons.iter_mut() {
158            neuron.reset();
159        }
160        for syn_row in self.synapses.iter_mut() {
161            for syn in syn_row.iter_mut() {
162                syn.g = 0.0;
163            }
164        }
165    }
166
167    /// Get the weight matrix as a 2D vector [n_out][n_in].
168    pub fn weights(&self) -> Vec<Vec<f32>> {
169        self.synapses
170            .iter()
171            .map(|row| row.iter().map(|s| s.weight).collect())
172            .collect()
173    }
174
175    /// Set the weight of a specific synapse.
176    ///
177    /// # Errors
178    /// Returns an error if indices are out of bounds.
179    pub fn set_weight(&mut self, out_idx: usize, in_idx: usize, weight: f32) -> Result<()> {
180        if out_idx >= self.n_out {
181            return Err(NeuralError::InvalidArgument(format!(
182                "out_idx {out_idx} >= n_out {}",
183                self.n_out
184            )));
185        }
186        if in_idx >= self.n_in {
187            return Err(NeuralError::InvalidArgument(format!(
188                "in_idx {in_idx} >= n_in {}",
189                self.n_in
190            )));
191        }
192        self.synapses[out_idx][in_idx].weight = weight;
193        Ok(())
194    }
195}
196
197// ---------------------------------------------------------------------------
198// SpikingNetwork
199// ---------------------------------------------------------------------------
200
201/// Multi-layer spiking neural network.
202///
203/// Layers are stacked sequentially; the output spike train of layer k is
204/// the input to layer k+1.
205#[derive(Debug)]
206pub struct SpikingNetwork {
207    /// Ordered list of spiking layers
208    pub layers: Vec<SpikingLayer>,
209    /// Simulation time step (ms)
210    pub dt: f32,
211}
212
213impl SpikingNetwork {
214    /// Create a spiking network by stacking layers.
215    ///
216    /// # Arguments
217    /// * `layer_sizes` — `[n0, n1, …, nL]` where n0 is input size and nL is output size
218    /// * `config`      — LIF configuration applied to all layers
219    /// * `init_weight` — initial total weight per output neuron
220    /// * `dt`          — simulation time step (ms)
221    ///
222    /// # Errors
223    /// Returns an error if fewer than 2 sizes are given.
224    pub fn new(
225        layer_sizes: &[usize],
226        config: &LIFConfig,
227        init_weight: f32,
228        dt: f32,
229    ) -> Result<Self> {
230        if layer_sizes.len() < 2 {
231            return Err(NeuralError::InvalidArchitecture(
232                "At least 2 layer sizes required".into(),
233            ));
234        }
235        let mut layers = Vec::with_capacity(layer_sizes.len() - 1);
236        for window in layer_sizes.windows(2) {
237            let n_in = window[0];
238            let n_out = window[1];
239            layers.push(SpikingLayer::new(n_in, n_out, config, init_weight)?);
240        }
241        Ok(Self { layers, dt })
242    }
243
244    /// Simulate the network for `T` time steps given input spike trains.
245    ///
246    /// # Arguments
247    /// * `input_spikes` — spike trains for the input layer: `input_spikes[t]` is
248    ///   the input spike vector at time step `t`
249    /// * `t_steps`      — number of simulation steps (must equal `input_spikes.len()`)
250    ///
251    /// # Returns
252    /// Spike trains for every layer at every time step:
253    /// `result[t][layer][neuron]`
254    ///
255    /// # Errors
256    /// Returns an error on dimension mismatches.
257    pub fn simulate(
258        &mut self,
259        input_spikes: &[Vec<bool>],
260        t_steps: usize,
261    ) -> Result<Vec<Vec<Vec<bool>>>> {
262        if input_spikes.len() != t_steps {
263            return Err(NeuralError::DimensionMismatch(format!(
264                "input_spikes has {} time steps, expected {t_steps}",
265                input_spikes.len()
266            )));
267        }
268
269        let n_layers = self.layers.len();
270        // result[t][layer] = spike vector
271        let mut result: Vec<Vec<Vec<bool>>> = Vec::with_capacity(t_steps);
272
273        for input_t in input_spikes.iter().take(t_steps) {
274            let mut layer_spikes: Vec<Vec<bool>> = Vec::with_capacity(n_layers);
275            let mut current_input = input_t.clone();
276
277            for layer in self.layers.iter_mut() {
278                let out = layer.forward(&current_input, self.dt)?;
279                layer_spikes.push(out.clone());
280                current_input = out;
281            }
282
283            result.push(layer_spikes);
284        }
285
286        Ok(result)
287    }
288
289    /// Reset all layers to their initial state.
290    pub fn reset(&mut self) {
291        for layer in self.layers.iter_mut() {
292            layer.reset();
293        }
294    }
295
296    /// Count the total number of spikes across all layers and time steps.
297    pub fn count_spikes(spike_record: &[Vec<Vec<bool>>]) -> usize {
298        spike_record
299            .iter()
300            .flat_map(|t| t.iter())
301            .flat_map(|l| l.iter())
302            .filter(|&&s| s)
303            .count()
304    }
305
306    /// Compute mean firing rate (spikes / neuron / time step) for each layer.
307    pub fn mean_firing_rates(spike_record: &[Vec<Vec<bool>>]) -> Vec<f32> {
308        let t_steps = spike_record.len();
309        if t_steps == 0 {
310            return Vec::new();
311        }
312        let n_layers = spike_record[0].len();
313        let mut rates = vec![0.0_f32; n_layers];
314
315        for t in spike_record.iter() {
316            for (l, layer_spikes) in t.iter().enumerate() {
317                let n = layer_spikes.len() as f32;
318                if n > 0.0 {
319                    let fired: f32 = layer_spikes.iter().filter(|&&s| s).count() as f32;
320                    rates[l] += fired / n;
321                }
322            }
323        }
324        for r in rates.iter_mut() {
325            *r /= t_steps as f32;
326        }
327        rates
328    }
329}
330
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    fn default_config() -> LIFConfig {
        LIFConfig {
            v_rest: -65.0,
            v_thresh: -50.0,
            v_reset: -65.0,
            tau_m: 20.0,
            r_m: 10.0,
            t_ref: 2.0,
        }
    }

    #[test]
    fn spiking_layer_silent_input_silent_output() {
        let mut layer =
            SpikingLayer::new(5, 3, &default_config(), 1.0).expect("operation should succeed");
        for _ in 0..100 {
            let out = layer
                .forward(&[false; 5], 0.1)
                .expect("operation should succeed");
            assert!(out.iter().all(|&s| !s), "no input → no output");
        }
    }

    #[test]
    fn spiking_layer_strong_input_fires() {
        let mut layer =
            SpikingLayer::new(4, 2, &default_config(), 100.0).expect("operation should succeed");
        let mut any_fired = false;
        for _ in 0..500 {
            let out = layer
                .forward(&[true; 4], 0.5)
                .expect("operation should succeed");
            if out.iter().any(|&s| s) {
                any_fired = true;
                break;
            }
        }
        assert!(
            any_fired,
            "Strong input should cause at least one output spike"
        );
    }

    #[test]
    fn spiking_layer_dimension_mismatch() {
        let mut layer =
            SpikingLayer::new(4, 2, &default_config(), 1.0).expect("operation should succeed");
        let result = layer.forward(&[false; 3], 0.1);
        assert!(result.is_err());
    }

    #[test]
    fn spiking_layer_new_rejects_zero_sizes() {
        let config = default_config();
        assert!(SpikingLayer::new(0, 3, &config, 1.0).is_err());
        assert!(SpikingLayer::new(3, 0, &config, 1.0).is_err());
    }

    #[test]
    fn spiking_layer_new_scales_weights_by_fan_in() {
        let layer =
            SpikingLayer::new(4, 2, &default_config(), 2.0).expect("operation should succeed");
        let weights = layer.weights();
        assert_eq!(weights.len(), 2);
        for row in weights {
            assert_eq!(row.len(), 4);
            for w in row {
                assert!((w - 0.5).abs() < 1e-6, "expected 2.0 / 4 per synapse");
            }
        }
    }

    #[test]
    fn spiking_layer_set_weight() {
        let mut layer =
            SpikingLayer::new(3, 2, &default_config(), 1.0).expect("operation should succeed");
        layer
            .set_weight(1, 2, 5.0)
            .expect("operation should succeed");
        assert!((layer.synapses[1][2].weight - 5.0).abs() < 1e-6);
    }

    #[test]
    fn spiking_layer_set_weight_out_of_bounds() {
        let mut layer =
            SpikingLayer::new(3, 2, &default_config(), 1.0).expect("operation should succeed");
        assert!(layer.set_weight(2, 0, 1.0).is_err(), "out_idx == n_out");
        assert!(layer.set_weight(0, 3, 1.0).is_err(), "in_idx == n_in");
    }

    #[test]
    fn spiking_layer_reset_clears_conductance_keeps_weights() {
        let mut layer =
            SpikingLayer::new(4, 2, &default_config(), 100.0).expect("operation should succeed");
        for _ in 0..10 {
            layer
                .forward(&[true; 4], 0.5)
                .expect("operation should succeed");
        }
        let weights_before_reset = layer.weights();
        layer.reset();
        assert!(layer.synapses.iter().flatten().all(|s| s.g == 0.0));
        assert_eq!(layer.weights(), weights_before_reset);
    }

    #[test]
    fn spiking_network_creates_and_simulates() {
        let config = default_config();
        let mut net =
            SpikingNetwork::new(&[4, 3, 2], &config, 5.0, 0.1).expect("operation should succeed");
        let input: Vec<Vec<bool>> = (0..50).map(|_| vec![true, false, true, false]).collect();
        let result = net.simulate(&input, 50).expect("operation should succeed");
        assert_eq!(result.len(), 50);
        assert_eq!(result[0].len(), 2); // 2 layers
        assert_eq!(result[0][0].len(), 3); // first hidden layer has 3 neurons
        assert_eq!(result[0][1].len(), 2); // output layer has 2 neurons
    }

    #[test]
    fn spiking_network_new_rejects_bad_sizes() {
        let config = default_config();
        assert!(SpikingNetwork::new(&[], &config, 1.0, 0.1).is_err());
        assert!(SpikingNetwork::new(&[4], &config, 1.0, 0.1).is_err());
        assert!(SpikingNetwork::new(&[4, 0, 2], &config, 1.0, 0.1).is_err());
    }

    #[test]
    fn spiking_network_spike_count_statistics() {
        let config = default_config();
        let mut net =
            SpikingNetwork::new(&[2, 3], &config, 20.0, 1.0).expect("operation should succeed");
        let input: Vec<Vec<bool>> = (0..100).map(|_| vec![true, true]).collect();
        let record = net.simulate(&input, 100).expect("operation should succeed");
        let total = SpikingNetwork::count_spikes(&record);
        let rates = SpikingNetwork::mean_firing_rates(&record);
        assert!(total > 0, "Some spikes expected");
        assert_eq!(rates.len(), 1);
    }

    #[test]
    fn spike_statistics_on_handmade_record() {
        // record[t][layer][neuron]
        let record = vec![
            vec![vec![true, false], vec![false]],
            vec![vec![true, true], vec![true]],
        ];
        assert_eq!(SpikingNetwork::count_spikes(&record), 4);
        let rates = SpikingNetwork::mean_firing_rates(&record);
        assert_eq!(rates.len(), 2);
        assert!((rates[0] - 0.75).abs() < 1e-6); // (1/2 + 2/2) / 2
        assert!((rates[1] - 0.5).abs() < 1e-6); // (0/1 + 1/1) / 2
        assert!(SpikingNetwork::mean_firing_rates(&[]).is_empty());
    }

    #[test]
    fn spiking_network_rejects_bad_input_length() {
        let config = default_config();
        let mut net =
            SpikingNetwork::new(&[2, 3], &config, 1.0, 0.1).expect("operation should succeed");
        // 5 steps provided but t_steps=3
        let input: Vec<Vec<bool>> = vec![vec![true, false]; 5];
        assert!(net.simulate(&input, 3).is_err());
    }

    #[test]
    fn from_weights_roundtrip() {
        let weights = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let layer = SpikingLayer::from_weights(&weights, &default_config())
            .expect("operation should succeed");
        assert_eq!(layer.n_out, 2);
        assert_eq!(layer.n_in, 3);
        let recovered = layer.weights();
        // Check the shape first: zip() alone would silently accept a truncated matrix.
        assert_eq!(recovered.len(), weights.len());
        for (r, expected) in recovered.iter().zip(weights.iter()) {
            assert_eq!(r.len(), expected.len());
            for (&got, &exp) in r.iter().zip(expected.iter()) {
                assert!((got - exp).abs() < 1e-6);
            }
        }
    }

    #[test]
    fn from_weights_rejects_empty_and_jagged() {
        let config = default_config();
        let no_rows: Vec<Vec<f32>> = Vec::new();
        assert!(SpikingLayer::from_weights(&no_rows, &config).is_err());
        let empty_row: Vec<Vec<f32>> = vec![Vec::new()];
        assert!(SpikingLayer::from_weights(&empty_row, &config).is_err());
        let jagged: Vec<Vec<f32>> = vec![vec![1.0, 2.0], vec![3.0]];
        assert!(SpikingLayer::from_weights(&jagged, &config).is_err());
    }
}
448}