scirs2_optimize/neuromorphic/
spiking_networks.rs1use super::{NeuromorphicConfig, SpikeEvent};
7use scirs2_core::error::CoreResult as Result;
8use scirs2_core::ndarray::{Array1, ArrayView1};
9use scirs2_core::random::Rng;
10use std::collections::VecDeque;
11
12#[derive(Debug, Clone)]
14pub struct SpikingNeuralNetwork {
15 pub config: NeuromorphicConfig,
17 pub neurons: Vec<SpikingNeuron>,
19 pub synapses: Vec<Vec<Synapse>>,
21 pub current_time: f64,
23 pub spike_history: VecDeque<SpikeEvent>,
25 pub population_activity: Array1<f64>,
27}
28
/// State of a single leaky integrate-and-fire neuron with spike-triggered
/// adaptation and optional additive noise.
#[derive(Clone, Debug)]
pub struct SpikingNeuron {
    /// Current membrane potential.
    pub membrane_potential: f64,
    /// Potential the membrane relaxes toward, and is reset to after a spike.
    pub resting_potential: f64,
    /// Firing threshold; reaching it emits a spike.
    pub threshold: f64,
    /// Membrane time constant.
    pub tau_membrane: f64,
    /// Time after a spike during which the neuron ignores input.
    pub refractory_period: f64,
    /// Time of the most recent spike, or `None` if the neuron never fired.
    pub last_spike_time: Option<f64>,
    /// Persistent input current (e.g. set when encoding parameters).
    pub input_current: f64,
    /// Adaptation current subtracted from the drive; grows with each spike.
    pub adaptation_current: f64,
    /// Half-width of the uniform noise added to the input current.
    pub noise_amplitude: f64,
}
51
/// A directed, delayed synapse between two neurons with short-term
/// plasticity factors and STDP eligibility traces.
#[derive(Clone, Debug)]
pub struct Synapse {
    /// Index of the presynaptic neuron.
    pub source: usize,
    /// Index of the postsynaptic neuron.
    pub target: usize,
    /// Synaptic efficacy (modified by STDP).
    pub weight: f64,
    /// Transmission delay between a presynaptic spike and its delivery.
    pub delay: f64,
    /// Short-term facilitation factor (1.0 at rest).
    pub facilitation: f64,
    /// Short-term depression factor (1.0 at rest).
    pub depression: f64,
    /// Decaying trace of presynaptic spikes.
    pub pre_trace: f64,
    /// Decaying trace of postsynaptic spikes.
    pub post_trace: f64,
}
70
71impl SpikingNeuron {
72 pub fn new(config: &NeuromorphicConfig) -> Self {
74 Self {
75 membrane_potential: 0.0,
76 resting_potential: 0.0,
77 threshold: config.spike_threshold,
78 tau_membrane: 0.020, refractory_period: config.refractory_period,
80 last_spike_time: None,
81 input_current: 0.0,
82 adaptation_current: 0.0,
83 noise_amplitude: config.noise_level,
84 }
85 }
86
87 pub fn update(&mut self, dt: f64, external_current: f64, current_time: f64) -> Option<f64> {
89 if let Some(last_spike) = self.last_spike_time {
91 if (current_time - last_spike) < self.refractory_period {
92 return None; }
94 }
95
96 let noise = if self.noise_amplitude > 0.0 {
98 let mut rng = scirs2_core::random::rng();
99 (rng.random::<f64>() - 0.5) * 2.0 * self.noise_amplitude
100 } else {
101 0.0
102 };
103
104 let total_current = external_current + self.input_current - self.adaptation_current + noise;
106 let dv_dt = (-(self.membrane_potential - self.resting_potential) + total_current)
107 / self.tau_membrane;
108
109 self.membrane_potential += dv_dt * dt;
110
111 if self.membrane_potential >= self.threshold {
113 self.fire_spike();
114 Some(0.0) } else {
116 None
117 }
118 }
119
120 fn fire_spike(&mut self) {
122 self.membrane_potential = self.resting_potential;
123 self.last_spike_time = Some(0.0); self.adaptation_current += 0.1; }
128
129 pub fn decay_adaptation(&mut self, dt: f64) {
131 let tau_adaptation = 0.1; self.adaptation_current *= (-dt / tau_adaptation).exp();
133 }
134}
135
136impl Synapse {
137 pub fn new(source: usize, target: usize, weight: f64, delay: f64) -> Self {
139 Self {
140 source,
141 target,
142 weight,
143 delay,
144 facilitation: 1.0,
145 depression: 1.0,
146 pre_trace: 0.0,
147 post_trace: 0.0,
148 }
149 }
150
151 pub fn compute_current(&self, pre_spike: bool) -> f64 {
153 if pre_spike {
154 self.weight * self.facilitation * self.depression
155 } else {
156 0.0
157 }
158 }
159
160 pub fn update_stp(&mut self, dt: f64, pre_spike: bool) {
162 let tau_facilitation = 0.050; let tau_depression = 0.100; self.facilitation += (1.0 - self.facilitation) * dt / tau_facilitation;
167 self.depression += (1.0 - self.depression) * dt / tau_depression;
168
169 if pre_spike {
170 self.facilitation = (self.facilitation * 1.2).min(3.0); self.depression *= 0.8; }
173 }
174
175 pub fn update_stdp_traces(&mut self, dt: f64, pre_spike: bool, post_spike: bool) {
177 let tau_stdp = 0.020; self.pre_trace *= (-dt / tau_stdp).exp();
181 self.post_trace *= (-dt / tau_stdp).exp();
182
183 if pre_spike {
185 self.pre_trace += 1.0;
186 }
187 if post_spike {
188 self.post_trace += 1.0;
189 }
190 }
191
192 pub fn apply_stdp(&mut self, learning_rate: f64, pre_spike: bool, post_spike: bool) {
194 let mut weight_change = 0.0;
195
196 if pre_spike && self.post_trace > 0.0 {
197 weight_change += learning_rate * self.post_trace;
199 }
200
201 if post_spike && self.pre_trace > 0.0 {
202 weight_change -= learning_rate * 0.5 * self.pre_trace;
204 }
205
206 self.weight += weight_change;
207 self.weight = self.weight.max(-1.0).min(1.0); }
209}
210
211impl SpikingNeuralNetwork {
212 pub fn new(config: NeuromorphicConfig, num_parameters: usize) -> Self {
214 let mut neurons = Vec::with_capacity(config.num_neurons);
215 for _ in 0..config.num_neurons {
216 neurons.push(SpikingNeuron::new(&config));
217 }
218
219 let mut synapses = vec![Vec::new(); config.num_neurons];
221 let connection_probability = 0.1; let mut rng = scirs2_core::random::rng();
223
224 for i in 0..config.num_neurons {
225 for j in 0..config.num_neurons {
226 if i != j && rng.random::<f64>() < connection_probability {
227 let weight = (rng.random::<f64>() - 0.5) * 0.2;
228 let delay = rng.random::<f64>() * 0.005; synapses[i].push(Synapse::new(i, j, weight, delay));
230 }
231 }
232 }
233
234 let num_neurons = config.num_neurons;
235 Self {
236 config,
237 neurons,
238 synapses,
239 current_time: 0.0,
240 spike_history: VecDeque::with_capacity(10000),
241 population_activity: Array1::zeros(num_neurons),
242 }
243 }
244
245 pub fn encode_parameters(&mut self, parameters: &ArrayView1<f64>) {
247 let neurons_per_param = self.config.num_neurons / parameters.len();
248
249 for (param_idx, ¶m_val) in parameters.iter().enumerate() {
250 let start_idx = param_idx * neurons_per_param;
251 let end_idx = ((param_idx + 1) * neurons_per_param).min(self.config.num_neurons);
252
253 let input_current = (param_val + 1.0) * 5.0; for neuron_idx in start_idx..end_idx {
257 self.neurons[neuron_idx].input_current = input_current;
258 }
259 }
260 }
261
262 pub fn decode_parameters(&self, num_parameters: usize) -> Array1<f64> {
264 let mut decoded = Array1::zeros(num_parameters);
265 let neurons_per_param = self.config.num_neurons / num_parameters;
266
267 for param_idx in 0..num_parameters {
268 let start_idx = param_idx * neurons_per_param;
269 let end_idx = ((param_idx + 1) * neurons_per_param).min(self.config.num_neurons);
270
271 let mut activity_sum = 0.0;
273 for neuron_idx in start_idx..end_idx {
274 activity_sum += self.population_activity[neuron_idx];
275 }
276
277 if end_idx > start_idx {
278 decoded[param_idx] = (activity_sum / (end_idx - start_idx) as f64) - 1.0;
279 }
280 }
281
282 decoded
283 }
284
285 pub fn simulate_step(&mut self, objective_feedback: f64) -> Result<Vec<usize>> {
287 let mut spiked_neurons = Vec::new();
288
289 let inputs: Vec<(f64, f64)> = (0..self.neurons.len())
291 .map(|neuron_idx| {
292 let synaptic_input = self.compute_synaptic_input(neuron_idx);
293 let feedback_input = self.compute_feedback_input(neuron_idx, objective_feedback);
294 (synaptic_input, feedback_input)
295 })
296 .collect();
297
298 for (neuron_idx, neuron) in self.neurons.iter_mut().enumerate() {
300 let (synaptic_input, feedback_input) = inputs[neuron_idx];
301 let total_input = synaptic_input + feedback_input;
302
303 if let Some(_spike_time) = neuron.update(self.config.dt, total_input, self.current_time)
305 {
306 spiked_neurons.push(neuron_idx);
307 neuron.last_spike_time = Some(self.current_time);
308
309 self.spike_history.push_back(SpikeEvent {
311 time: self.current_time,
312 neuron_id: neuron_idx,
313 weight: 1.0,
314 });
315
316 self.population_activity[neuron_idx] = 1.0;
318 } else {
319 self.population_activity[neuron_idx] *= 0.95;
321 }
322
323 neuron.decay_adaptation(self.config.dt);
325 }
326
327 self.update_synapses(&spiked_neurons)?;
329
330 self.cleanup_spike_history();
332
333 self.current_time += self.config.dt;
334
335 Ok(spiked_neurons)
336 }
337
338 fn compute_synaptic_input(&self, target_neuron: usize) -> f64 {
340 let mut total_input = 0.0;
341
342 for source_neuron in 0..self.config.num_neurons {
344 for synapse in &self.synapses[source_neuron] {
345 if synapse.target == target_neuron {
346 if let Some(last_spike) = self.neurons[source_neuron].last_spike_time {
348 let time_since_spike = self.current_time - last_spike;
349 if time_since_spike >= synapse.delay
350 && time_since_spike < synapse.delay + self.config.dt
351 {
352 total_input += synapse.compute_current(true);
353 }
354 }
355 }
356 }
357 }
358
359 total_input
360 }
361
362 fn compute_feedback_input(&self, neuron_idx: usize, objective_feedback: f64) -> f64 {
364 let feedback_strength = 1.0;
366 let normalized_feedback = -objective_feedback; let phase = neuron_idx as f64 / self.config.num_neurons as f64 * 2.0 * std::f64::consts::PI;
370 feedback_strength * normalized_feedback * (phase.sin() + 1.0) * 0.5
371 }
372
373 fn update_synapses(&mut self, spiked_neurons: &[usize]) -> Result<()> {
375 for source_neuron in 0..self.config.num_neurons {
376 let source_spiked = spiked_neurons.contains(&source_neuron);
377
378 for synapse in &mut self.synapses[source_neuron] {
379 let target_spiked = spiked_neurons.contains(&synapse.target);
380
381 synapse.update_stp(self.config.dt, source_spiked);
383
384 synapse.update_stdp_traces(self.config.dt, source_spiked, target_spiked);
386
387 synapse.apply_stdp(self.config.learning_rate, source_spiked, target_spiked);
389 }
390 }
391
392 Ok(())
393 }
394
395 fn cleanup_spike_history(&mut self) {
397 let cutoff_time = self.current_time - 0.1; while let Some(spike) = self.spike_history.front() {
399 if spike.time < cutoff_time {
400 self.spike_history.pop_front();
401 } else {
402 break;
403 }
404 }
405 }
406
407 pub fn get_firing_rates(&self, window_duration: f64) -> Array1<f64> {
409 let mut rates = Array1::zeros(self.config.num_neurons);
410 let start_time = self.current_time - window_duration;
411
412 for spike in &self.spike_history {
413 if spike.time >= start_time {
414 rates[spike.neuron_id] += 1.0;
415 }
416 }
417
418 rates /= window_duration;
420 rates
421 }
422
423 pub fn reset(&mut self) {
425 self.current_time = 0.0;
426 self.spike_history.clear();
427 self.population_activity.fill(0.0);
428
429 for neuron in &mut self.neurons {
430 neuron.membrane_potential = neuron.resting_potential;
431 neuron.last_spike_time = None;
432 neuron.input_current = 0.0;
433 neuron.adaptation_current = 0.0;
434 }
435
436 for synapse_group in &mut self.synapses {
438 for synapse in synapse_group {
439 synapse.facilitation = 1.0;
440 synapse.depression = 1.0;
441 synapse.pre_trace = 0.0;
442 synapse.post_trace = 0.0;
443 }
444 }
445 }
446}
447
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_spiking_neuron_creation() {
        let config = NeuromorphicConfig::default();
        let neuron = SpikingNeuron::new(&config);

        assert_eq!(neuron.membrane_potential, 0.0);
        assert_eq!(neuron.threshold, config.spike_threshold);
        assert!(neuron.last_spike_time.is_none());
    }

    #[test]
    fn test_neuron_spike() {
        let config = NeuromorphicConfig::default();
        let mut neuron = SpikingNeuron::new(&config);

        // A very strong drive crosses threshold within a single step.
        let spike_time = neuron.update(0.001, 50.0, 0.0);
        assert!(spike_time.is_some());
        assert_eq!(neuron.membrane_potential, neuron.resting_potential);
    }

    #[test]
    fn test_synapse_creation() {
        let synapse = Synapse::new(0, 1, 0.5, 0.002);

        assert_eq!(synapse.source, 0);
        assert_eq!(synapse.target, 1);
        assert_eq!(synapse.weight, 0.5);
        assert_eq!(synapse.delay, 0.002);
    }

    #[test]
    fn test_synapse_current() {
        let mut synapse = Synapse::new(0, 1, 0.5, 0.001);

        // No presynaptic spike, no current.
        assert_eq!(synapse.compute_current(false), 0.0);

        let current = synapse.compute_current(true);
        assert!(current > 0.0);

        // A presynaptic spike changes the short-term plasticity factors.
        synapse.update_stp(0.001, true);
        let current_after_stp = synapse.compute_current(true);
        assert!(current_after_stp != current);
    }

    #[test]
    fn test_spiking_network_creation() {
        let config = NeuromorphicConfig::default();
        // Read the size from the config instead of assuming a particular default.
        let expected_neurons = config.num_neurons;
        let network = SpikingNeuralNetwork::new(config, 3);

        assert_eq!(network.neurons.len(), expected_neurons);
        assert_eq!(network.synapses.len(), expected_neurons);
        assert_eq!(network.population_activity.len(), expected_neurons);
        assert_eq!(network.current_time, 0.0);
    }

    #[test]
    fn test_parameter_encoding() {
        let config = NeuromorphicConfig::default();
        let mut network = SpikingNeuralNetwork::new(config, 2);

        let params = Array1::from(vec![0.5, -0.3]);
        network.encode_parameters(&params.view());

        assert!(network.neurons.iter().any(|n| n.input_current != 0.0));
    }

    #[test]
    fn test_parameter_decoding_at_rest() {
        let config = NeuromorphicConfig {
            num_neurons: 10,
            ..Default::default()
        };
        let network = SpikingNeuralNetwork::new(config, 2);

        // With zero activity every group decodes to 0.0 - 1.0.
        let decoded = network.decode_parameters(2);
        assert_eq!(decoded.len(), 2);
        assert!(decoded.iter().all(|&x| x == -1.0));
    }

    #[test]
    fn test_network_simulation() {
        let config = NeuromorphicConfig {
            num_neurons: 10,
            ..Default::default()
        };
        let mut network = SpikingNeuralNetwork::new(config, 2);

        for _ in 0..10 {
            let _spiked = network.simulate_step(1.0).unwrap();
        }

        assert!(network.current_time > 0.0);
    }

    #[test]
    fn test_firing_rates() {
        let config = NeuromorphicConfig {
            num_neurons: 5,
            ..Default::default()
        };
        let mut network = SpikingNeuralNetwork::new(config, 1);

        for neuron in &mut network.neurons {
            neuron.input_current = 20.0;
        }

        for _ in 0..100 {
            network.simulate_step(0.0).unwrap();
        }

        let rates = network.get_firing_rates(0.1);
        assert!(rates.iter().any(|&r| r > 0.0));
    }

    #[test]
    fn test_network_reset() {
        let config = NeuromorphicConfig::default();
        let mut network = SpikingNeuralNetwork::new(config, 2);

        for _ in 0..10 {
            network.simulate_step(1.0).unwrap();
        }

        let time_before_reset = network.current_time;
        assert!(time_before_reset > 0.0);
        network.reset();

        assert_eq!(network.current_time, 0.0);
        assert!(network.spike_history.is_empty());
        assert!(network.population_activity.iter().all(|&x| x == 0.0));
        assert!(network.neurons.iter().all(|n| n.last_spike_time.is_none()));
    }
}