1use crate::error::{NeuralError, Result};
10use crate::snn::neuron_models::{LIFConfig, LIFNeuron};
11use crate::snn::synapse::ExponentialSynapse;
12
/// A fully connected layer of [`LIFNeuron`]s driven through [`ExponentialSynapse`]s.
///
/// The constructors in this file build one neuron per output and, for each
/// neuron, one synapse per input, so `synapses` forms an `n_out` × `n_in` grid.
#[derive(Debug)]
pub struct SpikingLayer {
    /// Output neurons; the constructors create exactly `n_out` of them.
    pub neurons: Vec<LIFNeuron>,
    /// Synapses indexed as `synapses[out_idx][in_idx]`.
    pub synapses: Vec<Vec<ExponentialSynapse>>,
    /// Number of input spike flags `forward` expects per step.
    pub n_in: usize,
    /// Number of output neurons (length of the vector `forward` returns).
    pub n_out: usize,
}
31
32impl SpikingLayer {
33 pub fn new(n_in: usize, n_out: usize, config: &LIFConfig, init_weight: f32) -> Result<Self> {
47 if n_in == 0 {
48 return Err(NeuralError::InvalidArgument("n_in must be > 0".into()));
49 }
50 if n_out == 0 {
51 return Err(NeuralError::InvalidArgument("n_out must be > 0".into()));
52 }
53
54 let w = init_weight / n_in as f32;
55 let neurons: Vec<LIFNeuron> = (0..n_out).map(|_| LIFNeuron::new(config)).collect();
56 let synapses: Vec<Vec<ExponentialSynapse>> = (0..n_out)
57 .map(|_| (0..n_in).map(|_| ExponentialSynapse::ampa(w)).collect())
58 .collect();
59
60 Ok(Self {
61 neurons,
62 synapses,
63 n_in,
64 n_out,
65 })
66 }
67
68 pub fn from_weights(weights: &[Vec<f32>], config: &LIFConfig) -> Result<Self> {
77 let n_out = weights.len();
78 if n_out == 0 {
79 return Err(NeuralError::InvalidArgument(
80 "weights must be non-empty".into(),
81 ));
82 }
83 let n_in = weights[0].len();
84 if n_in == 0 {
85 return Err(NeuralError::InvalidArgument(
86 "inner weight dimension must be > 0".into(),
87 ));
88 }
89 for (j, row) in weights.iter().enumerate() {
90 if row.len() != n_in {
91 return Err(NeuralError::DimensionMismatch(format!(
92 "row {j} has {} weights, expected {n_in}",
93 row.len()
94 )));
95 }
96 }
97
98 let neurons: Vec<LIFNeuron> = (0..n_out).map(|_| LIFNeuron::new(config)).collect();
99 let synapses: Vec<Vec<ExponentialSynapse>> = weights
100 .iter()
101 .map(|row| row.iter().map(|&w| ExponentialSynapse::ampa(w)).collect())
102 .collect();
103
104 Ok(Self {
105 neurons,
106 synapses,
107 n_in,
108 n_out,
109 })
110 }
111
112 pub fn forward(&mut self, input_spikes: &[bool], dt: f32) -> Result<Vec<bool>> {
127 if input_spikes.len() != self.n_in {
128 return Err(NeuralError::DimensionMismatch(format!(
129 "input spike length {} != n_in {}",
130 input_spikes.len(),
131 self.n_in
132 )));
133 }
134
135 let mut output_spikes = vec![false; self.n_out];
136
137 for (j, (neuron, syn_row)) in self
138 .neurons
139 .iter_mut()
140 .zip(self.synapses.iter_mut())
141 .enumerate()
142 {
143 let mut total_current = 0.0_f32;
144 for (syn, &spike) in syn_row.iter_mut().zip(input_spikes.iter()) {
145 let g = syn.update(spike, dt);
146 total_current += g * neuron.r_m;
148 }
149 output_spikes[j] = neuron.step(total_current, dt);
150 }
151
152 Ok(output_spikes)
153 }
154
155 pub fn reset(&mut self) {
157 for neuron in self.neurons.iter_mut() {
158 neuron.reset();
159 }
160 for syn_row in self.synapses.iter_mut() {
161 for syn in syn_row.iter_mut() {
162 syn.g = 0.0;
163 }
164 }
165 }
166
167 pub fn weights(&self) -> Vec<Vec<f32>> {
169 self.synapses
170 .iter()
171 .map(|row| row.iter().map(|s| s.weight).collect())
172 .collect()
173 }
174
175 pub fn set_weight(&mut self, out_idx: usize, in_idx: usize, weight: f32) -> Result<()> {
180 if out_idx >= self.n_out {
181 return Err(NeuralError::InvalidArgument(format!(
182 "out_idx {out_idx} >= n_out {}",
183 self.n_out
184 )));
185 }
186 if in_idx >= self.n_in {
187 return Err(NeuralError::InvalidArgument(format!(
188 "in_idx {in_idx} >= n_in {}",
189 self.n_in
190 )));
191 }
192 self.synapses[out_idx][in_idx].weight = weight;
193 Ok(())
194 }
195}
196
/// A feed-forward stack of [`SpikingLayer`]s simulated with a fixed time step.
///
/// During `simulate`, the spikes produced by one layer are the input of the
/// next layer within the same time step.
#[derive(Debug)]
pub struct SpikingNetwork {
    /// Layers in evaluation order (first element receives the external input).
    pub layers: Vec<SpikingLayer>,
    /// Time step passed to every layer's `forward` call.
    pub dt: f32,
}
212
213impl SpikingNetwork {
214 pub fn new(
225 layer_sizes: &[usize],
226 config: &LIFConfig,
227 init_weight: f32,
228 dt: f32,
229 ) -> Result<Self> {
230 if layer_sizes.len() < 2 {
231 return Err(NeuralError::InvalidArchitecture(
232 "At least 2 layer sizes required".into(),
233 ));
234 }
235 let mut layers = Vec::with_capacity(layer_sizes.len() - 1);
236 for window in layer_sizes.windows(2) {
237 let n_in = window[0];
238 let n_out = window[1];
239 layers.push(SpikingLayer::new(n_in, n_out, config, init_weight)?);
240 }
241 Ok(Self { layers, dt })
242 }
243
244 pub fn simulate(
258 &mut self,
259 input_spikes: &[Vec<bool>],
260 t_steps: usize,
261 ) -> Result<Vec<Vec<Vec<bool>>>> {
262 if input_spikes.len() != t_steps {
263 return Err(NeuralError::DimensionMismatch(format!(
264 "input_spikes has {} time steps, expected {t_steps}",
265 input_spikes.len()
266 )));
267 }
268
269 let n_layers = self.layers.len();
270 let mut result: Vec<Vec<Vec<bool>>> = Vec::with_capacity(t_steps);
272
273 for input_t in input_spikes.iter().take(t_steps) {
274 let mut layer_spikes: Vec<Vec<bool>> = Vec::with_capacity(n_layers);
275 let mut current_input = input_t.clone();
276
277 for layer in self.layers.iter_mut() {
278 let out = layer.forward(¤t_input, self.dt)?;
279 layer_spikes.push(out.clone());
280 current_input = out;
281 }
282
283 result.push(layer_spikes);
284 }
285
286 Ok(result)
287 }
288
289 pub fn reset(&mut self) {
291 for layer in self.layers.iter_mut() {
292 layer.reset();
293 }
294 }
295
/// Counts every `true` flag in a spike record shaped `[time][layer][neuron]`.
///
/// An empty record, or one containing only empty steps or layers, yields `0`.
pub fn count_spikes(spike_record: &[Vec<Vec<bool>>]) -> usize {
    let mut total = 0;
    for time_step in spike_record {
        for layer in time_step {
            total += layer.iter().filter(|&&fired| fired).count();
        }
    }
    total
}
305
/// Computes, per layer, the fraction of neurons that fired, averaged over all
/// time steps of a spike record shaped `[time][layer][neuron]`.
///
/// For each step and layer the fraction `fired / neurons` is accumulated, and
/// the sum is divided by the total number of time steps. A layer that is
/// empty or absent in a given step contributes `0` for that step.
///
/// # Returns
///
/// One rate per layer. The length is the largest layer count found in any
/// time step, so a ragged record (steps with differing layer counts) is
/// handled without panicking. An empty record yields an empty vector.
pub fn mean_firing_rates(spike_record: &[Vec<Vec<bool>>]) -> Vec<f32> {
    let t_steps = spike_record.len();
    if t_steps == 0 {
        return Vec::new();
    }

    // Size by the widest step rather than by the first one: indexing with a
    // layer position taken from a later, wider step would otherwise go out of
    // bounds.
    let n_layers = spike_record.iter().map(Vec::len).max().unwrap_or(0);
    let mut rates = vec![0.0_f32; n_layers];

    for step in spike_record {
        for (rate, layer_spikes) in rates.iter_mut().zip(step) {
            // Skip empty layers to avoid a 0/0 division producing NaN.
            if layer_spikes.is_empty() {
                continue;
            }
            let fired = layer_spikes.iter().filter(|&&s| s).count() as f32;
            *rate += fired / layer_spikes.len() as f32;
        }
    }

    let steps = t_steps as f32;
    for rate in &mut rates {
        *rate /= steps;
    }
    rates
}
329}
330
331#[cfg(test)]
336mod tests {
337 use super::*;
338
339 fn default_config() -> LIFConfig {
340 LIFConfig {
341 v_rest: -65.0,
342 v_thresh: -50.0,
343 v_reset: -65.0,
344 tau_m: 20.0,
345 r_m: 10.0,
346 t_ref: 2.0,
347 }
348 }
349
/// A layer that never receives an input spike must never emit one.
#[test]
fn spiking_layer_silent_input_silent_output() {
    let config = default_config();
    let mut layer = SpikingLayer::new(5, 3, &config, 1.0).expect("operation should succeed");
    let silent = [false; 5];

    for _step in 0..100 {
        let out = layer
            .forward(&silent, 0.1)
            .expect("operation should succeed");
        assert!(!out.iter().any(|&fired| fired), "no input → no output");
    }
}
361
/// With large weights and all inputs spiking, at least one output neuron
/// must fire within 500 steps.
#[test]
fn spiking_layer_strong_input_fires() {
    let config = default_config();
    let mut layer = SpikingLayer::new(4, 2, &config, 100.0).expect("operation should succeed");
    let drive = [true; 4];

    // `any` stops stepping the layer as soon as one step produces a spike.
    let any_fired = (0..500).any(|_| {
        layer
            .forward(&drive, 0.5)
            .expect("operation should succeed")
            .iter()
            .any(|&fired| fired)
    });

    assert!(
        any_fired,
        "Strong input should cause at least one output spike"
    );
}
381
/// Feeding three spike flags to a four-input layer must be rejected.
#[test]
fn spiking_layer_dimension_mismatch() {
    let config = default_config();
    let mut layer = SpikingLayer::new(4, 2, &config, 1.0).expect("operation should succeed");
    let too_short = [false; 3];
    assert!(layer.forward(&too_short, 0.1).is_err());
}
389
/// `set_weight` must store the new value at `synapses[out_idx][in_idx]`.
#[test]
fn spiking_layer_set_weight() {
    let config = default_config();
    let mut layer = SpikingLayer::new(3, 2, &config, 1.0).expect("operation should succeed");

    layer
        .set_weight(1, 2, 5.0)
        .expect("operation should succeed");

    let stored = layer.synapses[1][2].weight;
    assert!((stored - 5.0).abs() < 1e-6);
}
399
/// A 4-3-2 network simulated for 50 steps must return a record shaped
/// `[50 steps][2 layers][3 and 2 neurons]`.
#[test]
fn spiking_network_creates_and_simulates() {
    let config = default_config();
    let mut net =
        SpikingNetwork::new(&[4, 3, 2], &config, 5.0, 0.1).expect("operation should succeed");

    let input: Vec<Vec<bool>> = vec![vec![true, false, true, false]; 50];
    let result = net.simulate(&input, 50).expect("operation should succeed");

    // One entry per time step.
    assert_eq!(result.len(), 50);

    let first_step = &result[0];
    // One entry per layer, each as wide as that layer's output.
    assert_eq!(first_step.len(), 2);
    assert_eq!(first_step[0].len(), 3);
    assert_eq!(first_step[1].len(), 2);
}
412
/// A strongly driven single-layer network must produce spikes, and the rate
/// summary must contain one entry for its single layer.
#[test]
fn spiking_network_spike_count_statistics() {
    let config = default_config();
    let mut net =
        SpikingNetwork::new(&[2, 3], &config, 20.0, 1.0).expect("operation should succeed");

    let input: Vec<Vec<bool>> = vec![vec![true, true]; 100];
    let record = net.simulate(&input, 100).expect("operation should succeed");

    let total = SpikingNetwork::count_spikes(&record);
    let rates = SpikingNetwork::mean_firing_rates(&record);

    assert!(total > 0, "Some spikes expected");
    assert_eq!(rates.len(), 1);
}
425
/// Passing five input steps while announcing three must be rejected.
#[test]
fn spiking_network_rejects_bad_input_length() {
    let config = default_config();
    let mut net =
        SpikingNetwork::new(&[2, 3], &config, 1.0, 0.1).expect("operation should succeed");

    let five_steps: Vec<Vec<bool>> = (0..5).map(|_| vec![true, false]).collect();
    let outcome = net.simulate(&five_steps, 3);
    assert!(outcome.is_err());
}
435
/// Weights passed to `from_weights` must come back unchanged from `weights()`.
#[test]
fn from_weights_roundtrip() {
    let weights = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
    let config = default_config();
    let layer =
        SpikingLayer::from_weights(&weights, &config).expect("operation should succeed");

    let recovered = layer.weights();
    for (got_row, want_row) in recovered.iter().zip(&weights) {
        for (got, want) in got_row.iter().zip(want_row) {
            assert!((got - want).abs() < 1e-6);
        }
    }
}
448}