// scirs2_optimize/neuromorphic/spiking_networks.rs

//! Spiking Neural Network Optimization
//!
//! This module implements optimization algorithms based on spiking neural networks,
//! which process information using discrete spike events rather than continuous signals.

use std::collections::VecDeque;

use scirs2_core::error::CoreResult as Result;
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::random::Rng;

use super::{NeuromorphicConfig, SpikeEvent};

12/// Spiking neural network for optimization
13#[derive(Debug, Clone)]
14pub struct SpikingNeuralNetwork {
15    /// Network configuration
16    pub config: NeuromorphicConfig,
17    /// Neuron states
18    pub neurons: Vec<SpikingNeuron>,
19    /// Synaptic connections
20    pub synapses: Vec<Vec<Synapse>>,
21    /// Current simulation time
22    pub current_time: f64,
23    /// Spike history buffer
24    pub spike_history: VecDeque<SpikeEvent>,
25    /// Population activity monitor
26    pub population_activity: Array1<f64>,
27}
28
/// Spiking neuron model (Leaky Integrate-and-Fire)
///
/// Time-valued fields are plain `f64`s; the constructor's `0.020 // 20ms`
/// constant suggests they are expressed in seconds (NOTE(review): confirm
/// against `NeuromorphicConfig`).
#[derive(Debug, Clone)]
pub struct SpikingNeuron {
    /// Membrane potential (the integrated state variable)
    pub membrane_potential: f64,
    /// Resting potential the membrane leaks towards and is reset to after a spike
    pub resting_potential: f64,
    /// Spike threshold: reaching it triggers a spike
    pub threshold: f64,
    /// Membrane time constant of the leak
    pub tau_membrane: f64,
    /// Refractory period during which the neuron ignores all input after a spike
    pub refractory_period: f64,
    /// Time of last spike, `None` if the neuron has not spiked yet
    pub last_spike_time: Option<f64>,
    /// Input current applied on every update in addition to the external current
    pub input_current: f64,
    /// Adaptation current, subtracted from the drive and increased by each spike
    pub adaptation_current: f64,
    /// Noise level: half-width of the random current added per update (0 disables noise)
    pub noise_amplitude: f64,
}

/// Synaptic connection between neurons
///
/// Carries the static wiring (source, target, delay), the plastic weight and
/// the state variables of short-term plasticity and STDP.
#[derive(Debug, Clone)]
pub struct Synapse {
    /// Source (pre-synaptic) neuron index
    pub source: usize,
    /// Target (post-synaptic) neuron index
    pub target: usize,
    /// Synaptic weight
    pub weight: f64,
    /// Synaptic delay between a source spike and its effect on the target
    pub delay: f64,
    /// Short-term plasticity: facilitation factor multiplied into the current (1.0 = neutral)
    pub facilitation: f64,
    /// Short-term plasticity: depression factor multiplied into the current (1.0 = neutral)
    pub depression: f64,
    /// STDP trace of recent pre-synaptic spikes
    pub pre_trace: f64,
    /// STDP trace of recent post-synaptic spikes
    pub post_trace: f64,
}

71impl SpikingNeuron {
72    /// Create a new LIF neuron
73    pub fn new(config: &NeuromorphicConfig) -> Self {
74        Self {
75            membrane_potential: 0.0,
76            resting_potential: 0.0,
77            threshold: config.spike_threshold,
78            tau_membrane: 0.020, // 20ms membrane time constant
79            refractory_period: config.refractory_period,
80            last_spike_time: None,
81            input_current: 0.0,
82            adaptation_current: 0.0,
83            noise_amplitude: config.noise_level,
84        }
85    }
86
87    /// Update neuron state for one time step
88    pub fn update(&mut self, dt: f64, external_current: f64, current_time: f64) -> Option<f64> {
89        // Check if in refractory period
90        if let Some(last_spike) = self.last_spike_time {
91            if (current_time - last_spike) < self.refractory_period {
92                return None; // Still refractory
93            }
94        }
95
96        // Add noise
97        let noise = if self.noise_amplitude > 0.0 {
98            let mut rng = scirs2_core::random::rng();
99            (rng.random::<f64>() - 0.5) * 2.0 * self.noise_amplitude
100        } else {
101            0.0
102        };
103
104        // Leaky integrate-and-fire dynamics
105        let total_current = external_current + self.input_current - self.adaptation_current + noise;
106        let dv_dt = (-(self.membrane_potential - self.resting_potential) + total_current)
107            / self.tau_membrane;
108
109        self.membrane_potential += dv_dt * dt;
110
111        // Check for spike
112        if self.membrane_potential >= self.threshold {
113            self.fire_spike();
114            Some(0.0) // Return spike time (relative to current time)
115        } else {
116            None
117        }
118    }
119
120    /// Fire a spike and reset membrane potential
121    fn fire_spike(&mut self) {
122        self.membrane_potential = self.resting_potential;
123        self.last_spike_time = Some(0.0); // Will be updated by caller
124
125        // Spike-triggered adaptation
126        self.adaptation_current += 0.1; // Simple adaptation increment
127    }
128
129    /// Decay adaptation current
130    pub fn decay_adaptation(&mut self, dt: f64) {
131        let tau_adaptation = 0.1; // 100ms adaptation time constant
132        self.adaptation_current *= (-dt / tau_adaptation).exp();
133    }
134}
135
136impl Synapse {
137    /// Create a new synapse
138    pub fn new(source: usize, target: usize, weight: f64, delay: f64) -> Self {
139        Self {
140            source,
141            target,
142            weight,
143            delay,
144            facilitation: 1.0,
145            depression: 1.0,
146            pre_trace: 0.0,
147            post_trace: 0.0,
148        }
149    }
150
151    /// Compute synaptic current
152    pub fn compute_current(&self, pre_spike: bool) -> f64 {
153        if pre_spike {
154            self.weight * self.facilitation * self.depression
155        } else {
156            0.0
157        }
158    }
159
160    /// Update short-term plasticity
161    pub fn update_stp(&mut self, dt: f64, pre_spike: bool) {
162        let tau_facilitation = 0.050; // 50ms
163        let tau_depression = 0.100; // 100ms
164
165        // Decay
166        self.facilitation += (1.0 - self.facilitation) * dt / tau_facilitation;
167        self.depression += (1.0 - self.depression) * dt / tau_depression;
168
169        if pre_spike {
170            self.facilitation = (self.facilitation * 1.2).min(3.0); // Facilitate
171            self.depression *= 0.8; // Depress
172        }
173    }
174
175    /// Update STDP traces
176    pub fn update_stdp_traces(&mut self, dt: f64, pre_spike: bool, post_spike: bool) {
177        let tau_stdp = 0.020; // 20ms STDP time constant
178
179        // Decay traces
180        self.pre_trace *= (-dt / tau_stdp).exp();
181        self.post_trace *= (-dt / tau_stdp).exp();
182
183        // Update traces on spikes
184        if pre_spike {
185            self.pre_trace += 1.0;
186        }
187        if post_spike {
188            self.post_trace += 1.0;
189        }
190    }
191
192    /// Apply STDP weight update
193    pub fn apply_stdp(&mut self, learning_rate: f64, pre_spike: bool, post_spike: bool) {
194        let mut weight_change = 0.0;
195
196        if pre_spike && self.post_trace > 0.0 {
197            // Pre-before-post: potentiation
198            weight_change += learning_rate * self.post_trace;
199        }
200
201        if post_spike && self.pre_trace > 0.0 {
202            // Post-before-pre: depression
203            weight_change -= learning_rate * 0.5 * self.pre_trace;
204        }
205
206        self.weight += weight_change;
207        self.weight = self.weight.max(-1.0).min(1.0); // Bound weights
208    }
209}
210
211impl SpikingNeuralNetwork {
212    /// Create a new spiking neural network
213    pub fn new(config: NeuromorphicConfig, num_parameters: usize) -> Self {
214        let mut neurons = Vec::with_capacity(config.num_neurons);
215        for _ in 0..config.num_neurons {
216            neurons.push(SpikingNeuron::new(&config));
217        }
218
219        // Create random connectivity
220        let mut synapses = vec![Vec::new(); config.num_neurons];
221        let connection_probability = 0.1; // 10% connection probability
222        let mut rng = scirs2_core::random::rng();
223
224        for i in 0..config.num_neurons {
225            for j in 0..config.num_neurons {
226                if i != j && rng.random::<f64>() < connection_probability {
227                    let weight = (rng.random::<f64>() - 0.5) * 0.2;
228                    let delay = rng.random::<f64>() * 0.005; // 0-5ms delay
229                    synapses[i].push(Synapse::new(i, j, weight, delay));
230                }
231            }
232        }
233
234        let num_neurons = config.num_neurons;
235        Self {
236            config,
237            neurons,
238            synapses,
239            current_time: 0.0,
240            spike_history: VecDeque::with_capacity(10000),
241            population_activity: Array1::zeros(num_neurons),
242        }
243    }
244
245    /// Encode parameters as spike trains
246    pub fn encode_parameters(&mut self, parameters: &ArrayView1<f64>) {
247        let neurons_per_param = self.config.num_neurons / parameters.len();
248
249        for (param_idx, &param_val) in parameters.iter().enumerate() {
250            let start_idx = param_idx * neurons_per_param;
251            let end_idx = ((param_idx + 1) * neurons_per_param).min(self.config.num_neurons);
252
253            // Rate coding: parameter value determines input current
254            let input_current = (param_val + 1.0) * 5.0; // Scale to reasonable range
255
256            for neuron_idx in start_idx..end_idx {
257                self.neurons[neuron_idx].input_current = input_current;
258            }
259        }
260    }
261
262    /// Decode parameters from population activity
263    pub fn decode_parameters(&self, num_parameters: usize) -> Array1<f64> {
264        let mut decoded = Array1::zeros(num_parameters);
265        let neurons_per_param = self.config.num_neurons / num_parameters;
266
267        for param_idx in 0..num_parameters {
268            let start_idx = param_idx * neurons_per_param;
269            let end_idx = ((param_idx + 1) * neurons_per_param).min(self.config.num_neurons);
270
271            // Average population activity
272            let mut activity_sum = 0.0;
273            for neuron_idx in start_idx..end_idx {
274                activity_sum += self.population_activity[neuron_idx];
275            }
276
277            if end_idx > start_idx {
278                decoded[param_idx] = (activity_sum / (end_idx - start_idx) as f64) - 1.0;
279            }
280        }
281
282        decoded
283    }
284
285    /// Simulate one time step
286    pub fn simulate_step(&mut self, objective_feedback: f64) -> Result<Vec<usize>> {
287        let mut spiked_neurons = Vec::new();
288
289        // Collect inputs for all neurons first to avoid borrow checker issues
290        let inputs: Vec<(f64, f64)> = (0..self.neurons.len())
291            .map(|neuron_idx| {
292                let synaptic_input = self.compute_synaptic_input(neuron_idx);
293                let feedback_input = self.compute_feedback_input(neuron_idx, objective_feedback);
294                (synaptic_input, feedback_input)
295            })
296            .collect();
297
298        // Update all neurons
299        for (neuron_idx, neuron) in self.neurons.iter_mut().enumerate() {
300            let (synaptic_input, feedback_input) = inputs[neuron_idx];
301            let total_input = synaptic_input + feedback_input;
302
303            // Update neuron
304            if let Some(_spike_time) = neuron.update(self.config.dt, total_input, self.current_time)
305            {
306                spiked_neurons.push(neuron_idx);
307                neuron.last_spike_time = Some(self.current_time);
308
309                // Record spike
310                self.spike_history.push_back(SpikeEvent {
311                    time: self.current_time,
312                    neuron_id: neuron_idx,
313                    weight: 1.0,
314                });
315
316                // Update population activity
317                self.population_activity[neuron_idx] = 1.0;
318            } else {
319                // Decay population activity
320                self.population_activity[neuron_idx] *= 0.95;
321            }
322
323            // Decay adaptation
324            neuron.decay_adaptation(self.config.dt);
325        }
326
327        // Update synapses
328        self.update_synapses(&spiked_neurons)?;
329
330        // Cleanup old spikes
331        self.cleanup_spike_history();
332
333        self.current_time += self.config.dt;
334
335        Ok(spiked_neurons)
336    }
337
338    /// Compute synaptic input for a neuron
339    fn compute_synaptic_input(&self, target_neuron: usize) -> f64 {
340        let mut total_input = 0.0;
341
342        // Check all neurons for connections to target
343        for source_neuron in 0..self.config.num_neurons {
344            for synapse in &self.synapses[source_neuron] {
345                if synapse.target == target_neuron {
346                    // Check if source neuron spiked recently (within delay)
347                    if let Some(last_spike) = self.neurons[source_neuron].last_spike_time {
348                        let time_since_spike = self.current_time - last_spike;
349                        if time_since_spike >= synapse.delay
350                            && time_since_spike < synapse.delay + self.config.dt
351                        {
352                            total_input += synapse.compute_current(true);
353                        }
354                    }
355                }
356            }
357        }
358
359        total_input
360    }
361
362    /// Compute objective-based feedback input
363    fn compute_feedback_input(&self, neuron_idx: usize, objective_feedback: f64) -> f64 {
364        // Simple feedback scheme: better objective values give positive input
365        let feedback_strength = 1.0;
366        let normalized_feedback = -objective_feedback; // Assume minimization
367
368        // Different neurons get different phases of feedback
369        let phase = neuron_idx as f64 / self.config.num_neurons as f64 * 2.0 * std::f64::consts::PI;
370        feedback_strength * normalized_feedback * (phase.sin() + 1.0) * 0.5
371    }
372
373    /// Update synaptic plasticity
374    fn update_synapses(&mut self, spiked_neurons: &[usize]) -> Result<()> {
375        for source_neuron in 0..self.config.num_neurons {
376            let source_spiked = spiked_neurons.contains(&source_neuron);
377
378            for synapse in &mut self.synapses[source_neuron] {
379                let target_spiked = spiked_neurons.contains(&synapse.target);
380
381                // Update short-term plasticity
382                synapse.update_stp(self.config.dt, source_spiked);
383
384                // Update STDP traces
385                synapse.update_stdp_traces(self.config.dt, source_spiked, target_spiked);
386
387                // Apply STDP weight updates
388                synapse.apply_stdp(self.config.learning_rate, source_spiked, target_spiked);
389            }
390        }
391
392        Ok(())
393    }
394
395    /// Remove old spikes from history
396    fn cleanup_spike_history(&mut self) {
397        let cutoff_time = self.current_time - 0.1; // Keep 100ms of history
398        while let Some(spike) = self.spike_history.front() {
399            if spike.time < cutoff_time {
400                self.spike_history.pop_front();
401            } else {
402                break;
403            }
404        }
405    }
406
407    /// Get firing rates over recent window
408    pub fn get_firing_rates(&self, window_duration: f64) -> Array1<f64> {
409        let mut rates = Array1::zeros(self.config.num_neurons);
410        let start_time = self.current_time - window_duration;
411
412        for spike in &self.spike_history {
413            if spike.time >= start_time {
414                rates[spike.neuron_id] += 1.0;
415            }
416        }
417
418        // Convert to Hz
419        rates /= window_duration;
420        rates
421    }
422
423    /// Reset network state
424    pub fn reset(&mut self) {
425        self.current_time = 0.0;
426        self.spike_history.clear();
427        self.population_activity.fill(0.0);
428
429        for neuron in &mut self.neurons {
430            neuron.membrane_potential = neuron.resting_potential;
431            neuron.last_spike_time = None;
432            neuron.input_current = 0.0;
433            neuron.adaptation_current = 0.0;
434        }
435
436        // Reset synaptic state
437        for synapse_group in &mut self.synapses {
438            for synapse in synapse_group {
439                synapse.facilitation = 1.0;
440                synapse.depression = 1.0;
441                synapse.pre_trace = 0.0;
442                synapse.post_trace = 0.0;
443            }
444        }
445    }
446}
447
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh neuron rests at zero, uses the configured threshold and has no spike recorded
    #[test]
    fn test_spiking_neuron_creation() {
        let config = NeuromorphicConfig::default();
        let neuron = SpikingNeuron::new(&config);

        assert_eq!(neuron.membrane_potential, 0.0);
        assert_eq!(neuron.threshold, config.spike_threshold);
        assert!(neuron.last_spike_time.is_none());
    }

    /// A strong input current makes the neuron spike and resets its potential
    #[test]
    fn test_neuron_spike() {
        let config = NeuromorphicConfig::default();
        let mut neuron = SpikingNeuron::new(&config);

        // Apply strong input to cause spike
        let spike_time = neuron.update(0.001, 50.0, 0.0);
        assert!(spike_time.is_some());
        assert_eq!(neuron.membrane_potential, neuron.resting_potential);
    }

    /// The constructor stores the wiring and weight it was given
    #[test]
    fn test_synapse_creation() {
        let synapse = Synapse::new(0, 1, 0.5, 0.002);

        assert_eq!(synapse.source, 0);
        assert_eq!(synapse.target, 1);
        assert_eq!(synapse.weight, 0.5);
        assert_eq!(synapse.delay, 0.002);
    }

    /// Current flows only on a pre-synaptic spike and is modulated by short-term plasticity
    #[test]
    fn test_synapse_current() {
        let mut synapse = Synapse::new(0, 1, 0.5, 0.001);

        // No current without spike
        assert_eq!(synapse.compute_current(false), 0.0);

        // Current with spike
        let current = synapse.compute_current(true);
        assert!(current > 0.0);

        // Test short-term plasticity
        synapse.update_stp(0.001, true);
        let current_after_stp = synapse.compute_current(true);
        assert!(current_after_stp != current); // Should change due to plasticity
    }

    /// The network allocates one neuron and one synapse list per configured neuron
    #[test]
    fn test_spiking_network_creation() {
        let config = NeuromorphicConfig::default();
        let expected_neurons = config.num_neurons;
        let network = SpikingNeuralNetwork::new(config, 3);

        assert_eq!(network.neurons.len(), expected_neurons);
        assert_eq!(network.synapses.len(), expected_neurons);
        assert_eq!(network.current_time, 0.0);
    }

    /// Encoding a parameter vector drives at least some neurons with a non-zero current
    #[test]
    fn test_parameter_encoding() {
        let config = NeuromorphicConfig::default();
        let mut network = SpikingNeuralNetwork::new(config, 2);

        let params = Array1::from(vec![0.5, -0.3]);
        network.encode_parameters(&params.view());

        // Check that some neurons received input
        assert!(network.neurons.iter().any(|n| n.input_current != 0.0));
    }

    /// Stepping the simulation succeeds and advances the clock
    #[test]
    fn test_network_simulation() {
        let config = NeuromorphicConfig {
            num_neurons: 10,
            ..Default::default()
        };
        let mut network = SpikingNeuralNetwork::new(config, 2);

        // Simulate a few steps; each must complete without error
        for _ in 0..10 {
            let _spiked = network.simulate_step(1.0).unwrap();
        }

        assert!(network.current_time > 0.0);
    }

    /// Strongly driven neurons produce a non-zero firing rate
    #[test]
    fn test_firing_rates() {
        let config = NeuromorphicConfig {
            num_neurons: 5,
            ..Default::default()
        };
        let mut network = SpikingNeuralNetwork::new(config, 1);

        // Force some spikes by setting high input
        for neuron in &mut network.neurons {
            neuron.input_current = 20.0;
        }

        // Simulate to generate spikes
        for _ in 0..100 {
            network.simulate_step(0.0).unwrap();
        }

        let rates = network.get_firing_rates(0.1);
        assert!(rates.iter().any(|&r| r > 0.0)); // Should have some firing
    }

    /// Resetting rewinds the clock and clears spike history and population activity
    #[test]
    fn test_network_reset() {
        let config = NeuromorphicConfig::default();
        let mut network = SpikingNeuralNetwork::new(config, 2);

        // Simulate to change state
        for _ in 0..10 {
            network.simulate_step(1.0).unwrap();
        }

        // Make sure there is state to reset in the first place
        assert!(network.current_time > 0.0);
        network.reset();

        assert_eq!(network.current_time, 0.0);
        assert!(network.spike_history.is_empty());
        assert!(network.population_activity.iter().all(|&x| x == 0.0));
    }
}