// scirs2_optimize/neuromorphic/stdp_learning.rs

1//! Advanced Spike-Timing Dependent Plasticity (STDP) Learning
2//!
3//! Implementation of cutting-edge STDP-based optimization algorithms with:
4//! - Multi-timescale adaptive plasticity
5//! - Homeostatic mechanisms and synaptic scaling
6//! - Intrinsic plasticity for neurons
7//! - Metaplasticity (plasticity of plasticity)
8//! - Triplet STDP rules
9//! - Calcium-based synaptic dynamics
10
11use scirs2_core::error::CoreResult as Result;
12use scirs2_core::ndarray::{Array1, ArrayView1};
13use scirs2_core::random::Rng;
14use statrs::statistics::Statistics;
15use std::collections::VecDeque;
16
17/// Advanced Multi-Timescale STDP with Metaplasticity
18#[derive(Debug, Clone)]
19pub struct AdvancedAdvancedSTDP {
20    // Basic STDP traces
21    pub pre_trace_fast: f64,
22    pub post_trace_fast: f64,
23    pub pre_trace_slow: f64,
24    pub post_trace_slow: f64,
25
26    // Triplet STDP traces
27    pub pre_trace_triplet: f64,
28    pub post_trace_triplet: f64,
29
30    // Calcium dynamics
31    pub calcium_concentration: f64,
32    pub calcium_threshold_low: f64,
33    pub calcium_threshold_high: f64,
34
35    // Metaplasticity variables
36    pub metaplasticity_factor: f64,
37    pub recent_activity: VecDeque<f64>,
38    pub sliding_threshold: f64,
39
40    // Homeostatic variables
41    pub target_firing_rate: f64,
42    pub current_firing_rate: f64,
43    pub scaling_factor: f64,
44
45    // Time constants
46    pub tau_plus_fast: f64,
47    pub tau_minus_fast: f64,
48    pub tau_plus_slow: f64,
49    pub tau_minus_slow: f64,
50    pub tau_calcium: f64,
51    pub tau_metaplasticity: f64,
52
53    // Learning rates
54    pub eta_ltp: f64,
55    pub eta_ltd: f64,
56    pub eta_triplet: f64,
57    pub eta_homeostatic: f64,
58
59    // BCM-like thresholds
60    pub theta_d: f64,
61    pub theta_p: f64,
62
63    // Spike timing windows
64    pub spike_history_pre: VecDeque<f64>,
65    pub spike_history_post: VecDeque<f64>,
66
67    // Weight bounds
68    pub w_min: f64,
69    pub w_max: f64,
70}
71
72impl AdvancedAdvancedSTDP {
73    /// Create new advanced STDP rule with sophisticated plasticity mechanisms
74    pub fn new(eta_ltp: f64, eta_ltd: f64, target_firing_rate: f64) -> Self {
75        Self {
76            // Initialize traces
77            pre_trace_fast: 0.0,
78            post_trace_fast: 0.0,
79            pre_trace_slow: 0.0,
80            post_trace_slow: 0.0,
81            pre_trace_triplet: 0.0,
82            post_trace_triplet: 0.0,
83
84            // Calcium dynamics
85            calcium_concentration: 0.0,
86            calcium_threshold_low: 0.2,
87            calcium_threshold_high: 0.6,
88
89            // Metaplasticity
90            metaplasticity_factor: 1.0,
91            recent_activity: VecDeque::with_capacity(1000),
92            sliding_threshold: 0.5,
93
94            // Homeostasis
95            target_firing_rate,
96            current_firing_rate: 0.0,
97            scaling_factor: 1.0,
98
99            // Time constants (in seconds)
100            tau_plus_fast: 0.017,      // 17ms fast LTP
101            tau_minus_fast: 0.034,     // 34ms fast LTD
102            tau_plus_slow: 0.688,      // 688ms slow LTP
103            tau_minus_slow: 0.688,     // 688ms slow LTD
104            tau_calcium: 0.048,        // 48ms calcium decay
105            tau_metaplasticity: 100.0, // 100s metaplasticity
106
107            // Learning rates
108            eta_ltp,
109            eta_ltd,
110            eta_triplet: eta_ltp * 0.1,
111            eta_homeostatic: eta_ltp * 0.01,
112
113            // BCM thresholds
114            theta_d: 0.2,
115            theta_p: 0.8,
116
117            // Spike histories
118            spike_history_pre: VecDeque::with_capacity(100),
119            spike_history_post: VecDeque::with_capacity(100),
120
121            // Weight bounds
122            w_min: -2.0,
123            w_max: 2.0,
124        }
125    }
126
127    /// Update all internal states and compute weight change
128    pub fn update_weight_advanced(
129        &mut self,
130        current_weight: f64,
131        pre_spike: bool,
132        post_spike: bool,
133        dt: f64,
134        current_time: f64,
135        objective_improvement: f64,
136    ) -> f64 {
137        // Update calcium concentration based on spikes
138        self.update_calcium(pre_spike, post_spike, dt);
139
140        // Update metaplasticity based on recent activity
141        self.update_metaplasticity(current_time, objective_improvement);
142
143        // Update homeostatic scaling
144        self.update_homeostasis(pre_spike, post_spike, dt);
145
146        // Decay all traces
147        self.decay_traces(dt);
148
149        // Update spike histories
150        if pre_spike {
151            self.spike_history_pre.push_back(current_time);
152            if self.spike_history_pre.len() > 100 {
153                self.spike_history_pre.pop_front();
154            }
155        }
156        if post_spike {
157            self.spike_history_post.push_back(current_time);
158            if self.spike_history_post.len() > 100 {
159                self.spike_history_post.pop_front();
160            }
161        }
162
163        let mut total_weight_change = 0.0;
164
165        // 1. Pairwise STDP with multiple timescales
166        total_weight_change += self.compute_pairwise_stdp(pre_spike, post_spike);
167
168        // 2. Triplet STDP for complex spike patterns
169        total_weight_change += self.compute_triplet_stdp(pre_spike, post_spike);
170
171        // 3. Calcium-based plasticity rules
172        total_weight_change += self.compute_calcium_plasticity(current_weight);
173
174        // 4. BCM-like metaplasticity
175        total_weight_change += self.compute_bcm_plasticity(pre_spike, post_spike, current_weight);
176
177        // 5. Homeostatic synaptic scaling
178        total_weight_change += self.compute_homeostatic_scaling(current_weight);
179
180        // Apply metaplasticity modulation
181        total_weight_change *= self.metaplasticity_factor;
182
183        // Apply global scaling factor
184        total_weight_change *= self.scaling_factor;
185
186        // Apply weight bounds and soft constraints
187        let new_weight = current_weight + total_weight_change;
188        self.apply_weight_constraints(new_weight)
189    }
190
191    fn update_calcium(&mut self, pre_spike: bool, post_spike: bool, dt: f64) {
192        // Decay calcium
193        self.calcium_concentration *= (-dt / self.tau_calcium).exp();
194
195        // Add calcium from spikes
196        if pre_spike {
197            self.calcium_concentration += 0.1;
198        }
199        if post_spike {
200            self.calcium_concentration += 0.2;
201        }
202
203        // Bound calcium concentration
204        self.calcium_concentration = self.calcium_concentration.min(1.0);
205    }
206
207    fn update_metaplasticity(&mut self, current_time: f64, objective_improvement: f64) {
208        // Store recent activity
209        self.recent_activity.push_back(objective_improvement);
210        if self.recent_activity.len() > 1000 {
211            self.recent_activity.pop_front();
212        }
213
214        // Compute activity variance for metaplasticity
215        if self.recent_activity.len() > 10 {
216            let mean: f64 =
217                self.recent_activity.iter().sum::<f64>() / self.recent_activity.len() as f64;
218            let variance: f64 = self
219                .recent_activity
220                .iter()
221                .map(|&x| (x - mean).powi(2))
222                .sum::<f64>()
223                / self.recent_activity.len() as f64;
224
225            // High variance increases plasticity
226            self.metaplasticity_factor = 1.0 + variance.sqrt();
227
228            // Update sliding threshold
229            self.sliding_threshold = 0.9 * self.sliding_threshold + 0.1 * mean.abs();
230        }
231    }
232
233    fn update_homeostasis(&mut self, pre_spike: bool, post_spike: bool, dt: f64) {
234        // Update current firing rate estimate
235        let spike_rate = if post_spike { 1.0 / dt } else { 0.0 };
236        self.current_firing_rate = 0.999 * self.current_firing_rate + 0.001 * spike_rate;
237
238        // Compute homeostatic scaling factor
239        let rate_ratio = self.current_firing_rate / self.target_firing_rate.max(0.1);
240        self.scaling_factor = (2.0 / (1.0 + rate_ratio)).min(2.0).max(0.5);
241    }
242
243    fn decay_traces(&mut self, dt: f64) {
244        // Decay fast traces
245        self.pre_trace_fast *= (-dt / self.tau_plus_fast).exp();
246        self.post_trace_fast *= (-dt / self.tau_minus_fast).exp();
247
248        // Decay slow traces
249        self.pre_trace_slow *= (-dt / self.tau_plus_slow).exp();
250        self.post_trace_slow *= (-dt / self.tau_minus_slow).exp();
251
252        // Decay triplet traces
253        self.pre_trace_triplet *= (-dt / (self.tau_plus_fast * 2.0)).exp();
254        self.post_trace_triplet *= (-dt / (self.tau_minus_fast * 2.0)).exp();
255    }
256
257    fn compute_pairwise_stdp(&mut self, pre_spike: bool, post_spike: bool) -> f64 {
258        let mut weight_change = 0.0;
259
260        if pre_spike {
261            self.pre_trace_fast += 1.0;
262            self.pre_trace_slow += 1.0;
263
264            // LTD: post-before-pre (fast and slow)
265            weight_change -= self.eta_ltd * (self.post_trace_fast + 0.1 * self.post_trace_slow);
266        }
267
268        if post_spike {
269            self.post_trace_fast += 1.0;
270            self.post_trace_slow += 1.0;
271
272            // LTP: pre-before-post (fast and slow)
273            weight_change += self.eta_ltp * (self.pre_trace_fast + 0.1 * self.pre_trace_slow);
274        }
275
276        weight_change
277    }
278
279    fn compute_triplet_stdp(&mut self, pre_spike: bool, post_spike: bool) -> f64 {
280        let mut weight_change = 0.0;
281
282        if pre_spike {
283            self.pre_trace_triplet += 1.0;
284            // Triplet LTD
285            weight_change -= self.eta_triplet * self.post_trace_fast * self.post_trace_triplet;
286        }
287
288        if post_spike {
289            self.post_trace_triplet += 1.0;
290            // Triplet LTP
291            weight_change += self.eta_triplet * self.pre_trace_fast * self.pre_trace_triplet;
292        }
293
294        weight_change
295    }
296
297    fn compute_calcium_plasticity(&self, current_weight: f64) -> f64 {
298        let ca = self.calcium_concentration;
299
300        if ca < self.calcium_threshold_low {
301            // Low calcium: LTD
302            -self.eta_ltd * 0.1 * current_weight.abs()
303        } else if ca > self.calcium_threshold_high {
304            // High calcium: LTP
305            self.eta_ltp * 0.1 * (self.w_max - current_weight.abs())
306        } else {
307            // Intermediate calcium: proportional to calcium level
308            let normalized_ca = (ca - self.calcium_threshold_low)
309                / (self.calcium_threshold_high - self.calcium_threshold_low);
310            self.eta_ltp * 0.05 * (2.0 * normalized_ca - 1.0)
311        }
312    }
313
314    fn compute_bcm_plasticity(
315        &self,
316        pre_spike: bool,
317        post_spike: bool,
318        _current_weight: f64,
319    ) -> f64 {
320        if !pre_spike && !post_spike {
321            return 0.0;
322        }
323
324        let post_activity = if post_spike { 1.0 } else { 0.0 };
325        let pre_activity = if pre_spike { 1.0 } else { 0.0 };
326
327        // BCM rule: Δw ∝ pre * post * (post - θ)
328        let theta = self.sliding_threshold;
329        pre_activity * post_activity * (post_activity - theta) * self.eta_ltp * 0.1
330    }
331
332    fn compute_homeostatic_scaling(&self, current_weight: f64) -> f64 {
333        // Homeostatic synaptic scaling to maintain target activity
334        let rate_error = self.target_firing_rate - self.current_firing_rate;
335        self.eta_homeostatic * rate_error * current_weight * 0.01
336    }
337
338    fn apply_weight_constraints(&self, weight: f64) -> f64 {
339        // Soft bounds with exponential penalty near limits
340        if weight > self.w_max {
341            self.w_max - (weight - self.w_max).exp().recip()
342        } else if weight < self.w_min {
343            self.w_min + (self.w_min - weight).exp().recip()
344        } else {
345            weight
346        }
347    }
348
349    /// Get plasticity statistics for monitoring
350    pub fn get_plasticity_stats(&self) -> PlasticityStats {
351        PlasticityStats {
352            calcium_level: self.calcium_concentration,
353            metaplasticity_factor: self.metaplasticity_factor,
354            scaling_factor: self.scaling_factor,
355            firing_rate_error: self.target_firing_rate - self.current_firing_rate,
356            sliding_threshold: self.sliding_threshold,
357            trace_strength: (self.pre_trace_fast + self.post_trace_fast) / 2.0,
358        }
359    }
360}
361
/// Statistics for monitoring plasticity mechanisms
///
/// Snapshot of an `AdvancedAdvancedSTDP` rule's internal state, as
/// produced by `get_plasticity_stats`.
#[derive(Debug, Clone)]
pub struct PlasticityStats {
    /// Current calcium concentration (kept within [0, 1] by the update rule)
    pub calcium_level: f64,
    /// Activity-variance-driven plasticity gain (starts at 1.0, never below)
    pub metaplasticity_factor: f64,
    /// Homeostatic scaling factor (bounded to [0.5, 2.0] once updated)
    pub scaling_factor: f64,
    /// Target firing rate minus the current rate estimate
    pub firing_rate_error: f64,
    /// BCM-like sliding threshold
    pub sliding_threshold: f64,
    /// Mean of the fast pre- and post-synaptic traces
    pub trace_strength: f64,
}
372
/// Legacy STDP learning rule for backward compatibility
///
/// Minimal pair-based STDP with a single exponential trace per side and
/// symmetric 20 ms time constants.
#[derive(Debug, Clone)]
pub struct STDPLearningRule {
    /// Pre-synaptic eligibility trace (incremented on each pre-spike)
    pub pre_trace: f64,
    /// Post-synaptic eligibility trace (incremented on each post-spike)
    pub post_trace: f64,
    /// Learning rate applied to both LTP and LTD
    pub learning_rate: f64,
    /// LTP time constant in seconds
    pub tau_plus: f64,
    /// LTD time constant in seconds
    pub tau_minus: f64,
}

impl STDPLearningRule {
    /// Create new simple STDP rule with zeroed traces.
    pub fn new(learning_rate: f64) -> Self {
        Self {
            pre_trace: 0.0,
            post_trace: 0.0,
            learning_rate,
            tau_plus: 0.020,  // 20ms
            tau_minus: 0.020, // 20ms
        }
    }

    /// Update synaptic weight based on spike timing.
    ///
    /// Both traces decay exponentially over `dt`; then each spike adds to
    /// its own trace and reads the opposite one: a pre-spike while the
    /// post-trace is non-zero depresses the weight (post-before-pre, LTD),
    /// a post-spike while the pre-trace is non-zero potentiates it
    /// (pre-before-post, LTP). The result is clamped to [-1, 1].
    pub fn update_weight(
        &mut self,
        current_weight: f64,
        pre_spike: bool,
        post_spike: bool,
        dt: f64,
    ) -> f64 {
        // Decay traces
        self.pre_trace *= (-dt / self.tau_plus).exp();
        self.post_trace *= (-dt / self.tau_minus).exp();

        let mut weight_change = 0.0;

        if pre_spike {
            self.pre_trace += 1.0;
            // LTD: post-before-pre
            if self.post_trace > 0.0 {
                weight_change -= self.learning_rate * self.post_trace;
            }
        }

        if post_spike {
            self.post_trace += 1.0;
            // LTP: pre-before-post
            if self.pre_trace > 0.0 {
                weight_change += self.learning_rate * self.pre_trace;
            }
        }

        // clamp replaces the original .max(-1.0).min(1.0) chain; identical result.
        (current_weight + weight_change).clamp(-1.0, 1.0)
    }
}
432
/// Advanced-advanced STDP network for complex optimization problems
///
/// Feed-forward spiking network whose synapses learn via
/// `AdvancedAdvancedSTDP`; parameters are encoded as spike patterns and
/// decoded back from layer firing rates during optimization.
#[derive(Debug, Clone)]
pub struct AdvancedSTDPNetwork {
    /// Network layers (one entry per layer, including the input layer)
    pub layers: Vec<STDPLayer>,
    /// Advanced-advanced STDP rules: one inner `Vec` per inter-layer
    /// connection block, flat row-major (`neuron_idx * prev_size + input_idx`)
    pub advanced_stdp_rules: Vec<Vec<AdvancedAdvancedSTDP>>,
    /// Current parameters being optimized
    pub current_params: Array1<f64>,
    /// Best parameters found so far
    pub best_params: Array1<f64>,
    /// Best objective value (initialized to +infinity)
    pub best_objective: f64,
    /// Iteration counter (completed optimization iterations)
    pub nit: usize,
    /// Network statistics, refreshed once per iteration
    pub network_stats: NetworkStats,
}
451
/// Layer in STDP network
#[derive(Debug, Clone)]
pub struct STDPLayer {
    /// Layer size (number of neurons)
    pub size: usize,
    /// Neuron membrane potentials (leaky integration; reset to 0 when the
    /// threshold of 1.0 is crossed — see `simulate_network_dynamics`)
    pub potentials: Array1<f64>,
    /// Last spike time per neuron (`None` until the neuron first fires)
    pub last_spike_times: Array1<Option<f64>>,
    /// Firing rates (exponential moving average, updated every step)
    pub firing_rates: Array1<f64>,
}
464
/// Network-wide statistics
///
/// All measures start at zero (`Default` is derived: every field is an
/// `f64`, whose default is 0.0 — identical to the previous manual impl)
/// and are refreshed once per optimization iteration.
#[derive(Debug, Clone, Default)]
pub struct NetworkStats {
    /// Average plasticity across all synapses
    pub avg_plasticity: f64,
    /// Network synchrony measure
    pub synchrony: f64,
    /// Energy consumption estimate
    pub energy_consumption: f64,
    /// Convergence measure
    pub convergence: f64,
}
488
489impl AdvancedSTDPNetwork {
490    /// Create new advanced STDP network
491    pub fn new(layer_sizes: Vec<usize>, target_firing_rate: f64, learning_rate: f64) -> Self {
492        let mut layers = Vec::new();
493        let mut advanced_stdp_rules = Vec::new();
494
495        for (layer_idx, &size) in layer_sizes.iter().enumerate() {
496            let layer = STDPLayer {
497                size,
498                potentials: Array1::zeros(size),
499                last_spike_times: Array1::from_vec(vec![None; size]),
500                firing_rates: Array1::zeros(size),
501            };
502            layers.push(layer);
503
504            // Create STDP rules for connections from previous layer
505            if layer_idx > 0 {
506                let prev_size = layer_sizes[layer_idx - 1];
507                let mut layer_rules = Vec::new();
508
509                for _i in 0..size {
510                    for _j in 0..prev_size {
511                        layer_rules.push(AdvancedAdvancedSTDP::new(
512                            learning_rate,
513                            learning_rate * 0.5,
514                            target_firing_rate,
515                        ));
516                    }
517                }
518                advanced_stdp_rules.push(layer_rules);
519            }
520        }
521
522        let input_size = layer_sizes[0];
523
524        Self {
525            layers,
526            advanced_stdp_rules,
527            current_params: Array1::zeros(input_size),
528            best_params: Array1::zeros(input_size),
529            best_objective: f64::INFINITY,
530            nit: 0,
531            network_stats: NetworkStats::default(),
532        }
533    }
534
535    /// Run advanced STDP optimization
536    pub fn optimize<F>(
537        &mut self,
538        objective: F,
539        initial_params: &ArrayView1<f64>,
540        max_nit: usize,
541        dt: f64,
542    ) -> Result<Array1<f64>>
543    where
544        F: Fn(&ArrayView1<f64>) -> f64,
545    {
546        self.current_params = initial_params.to_owned();
547        self.best_params = initial_params.to_owned();
548        self.best_objective = objective(initial_params);
549
550        let mut prev_objective = self.best_objective;
551
552        for iteration in 0..max_nit {
553            let current_time = iteration as f64 * dt;
554
555            // Evaluate current objective
556            let current_objective = objective(&self.current_params.view());
557            let objective_improvement = prev_objective - current_objective;
558
559            // Update best solution
560            if current_objective < self.best_objective {
561                self.best_objective = current_objective;
562                self.best_params = self.current_params.clone();
563            }
564
565            // Encode parameters as spike patterns
566            let spike_patterns =
567                self.encode_parameters_to_spikes(&self.current_params, current_time);
568
569            // Simulate network dynamics
570            let network_spikes =
571                self.simulate_network_dynamics(&spike_patterns, current_time, dt)?;
572
573            // Update synaptic weights using advanced STDP
574            self.update_advanced_stdp_weights(
575                &network_spikes,
576                current_time,
577                dt,
578                objective_improvement,
579            )?;
580
581            // Decode new parameters from network state
582            let param_updates = self.decode_parameters_from_network(current_time);
583
584            // Apply parameter updates with adaptive step size
585            let step_size = self.compute_adaptive_step_size(objective_improvement, iteration);
586            for (i, update) in param_updates.iter().enumerate() {
587                if i < self.current_params.len() {
588                    self.current_params[i] += step_size * update;
589                }
590            }
591
592            // Update network statistics
593            self.update_network_statistics(current_time);
594
595            // Check convergence
596            if objective_improvement.abs() < 1e-8 && iteration > 100 {
597                break;
598            }
599
600            prev_objective = current_objective;
601            self.nit = iteration + 1;
602        }
603
604        Ok(self.best_params.clone())
605    }
606
607    fn encode_parameters_to_spikes(
608        &self,
609        params: &Array1<f64>,
610        _current_time: f64,
611    ) -> Vec<Vec<bool>> {
612        let mut spike_patterns = Vec::new();
613
614        for layer in &self.layers {
615            let mut layer_spikes = vec![false; layer.size];
616
617            // For first layer, use parameter values to determine spike probability
618            for i in 0..layer.size.min(params.len()) {
619                let spike_prob = ((params[i] + 1.0) / 2.0).max(0.0).min(1.0);
620                layer_spikes[i] = scirs2_core::random::rng().random::<f64>() < spike_prob * 0.1;
621            }
622
623            spike_patterns.push(layer_spikes);
624        }
625
626        spike_patterns
627    }
628
629    fn simulate_network_dynamics(
630        &mut self,
631        input_spikes: &[Vec<bool>],
632        current_time: f64,
633        dt: f64,
634    ) -> Result<Vec<Vec<bool>>> {
635        let mut all_spikes = input_spikes.to_vec();
636
637        // Propagate through layers
638        for layer_idx in 1..self.layers.len() {
639            let mut layer_spikes = vec![false; self.layers[layer_idx].size];
640
641            for neuron_idx in 0..self.layers[layer_idx].size {
642                // Compute input from previous layer
643                let mut input_current = 0.0;
644
645                for prev_neuron_idx in 0..self.layers[layer_idx - 1].size {
646                    if all_spikes[layer_idx - 1][prev_neuron_idx] {
647                        // Use synaptic weight (simplified - would normally track weights)
648                        input_current += 0.1;
649                    }
650                }
651
652                // Update membrane potential
653                self.layers[layer_idx].potentials[neuron_idx] +=
654                    dt * (-self.layers[layer_idx].potentials[neuron_idx] + input_current) / 0.02;
655
656                // Check for spike
657                if self.layers[layer_idx].potentials[neuron_idx] > 1.0 {
658                    self.layers[layer_idx].potentials[neuron_idx] = 0.0;
659                    self.layers[layer_idx].last_spike_times[neuron_idx] = Some(current_time);
660                    layer_spikes[neuron_idx] = true;
661                }
662
663                // Update firing rate
664                let spike_rate = if layer_spikes[neuron_idx] {
665                    1.0 / dt
666                } else {
667                    0.0
668                };
669                self.layers[layer_idx].firing_rates[neuron_idx] =
670                    0.99 * self.layers[layer_idx].firing_rates[neuron_idx] + 0.01 * spike_rate;
671            }
672
673            all_spikes.push(layer_spikes);
674        }
675
676        Ok(all_spikes)
677    }
678
679    fn update_advanced_stdp_weights(
680        &mut self,
681        all_spikes: &[Vec<bool>],
682        current_time: f64,
683        dt: f64,
684        objective_improvement: f64,
685    ) -> Result<()> {
686        // Update STDP rules for each layer connection
687        for layer_idx in 0..self.advanced_stdp_rules.len() {
688            let input_spikes = &all_spikes[layer_idx];
689            let output_spikes = &all_spikes[layer_idx + 1];
690
691            for (connection_idx, rule) in self.advanced_stdp_rules[layer_idx].iter_mut().enumerate()
692            {
693                // Calculate neuron and input indices from connection index
694                let _layer_size = self.layers[layer_idx + 1].size;
695                let prev_layer_size = self.layers[layer_idx].size;
696                let neuron_idx = connection_idx / prev_layer_size;
697                let input_idx = connection_idx % prev_layer_size;
698
699                let pre_spike = input_spikes.get(input_idx).copied().unwrap_or(false);
700                let post_spike = output_spikes.get(neuron_idx).copied().unwrap_or(false);
701
702                // Update using advanced STDP
703                let _new_weight = rule.update_weight_advanced(
704                    0.5, // Current weight (simplified)
705                    pre_spike,
706                    post_spike,
707                    dt,
708                    current_time,
709                    objective_improvement,
710                );
711            }
712        }
713
714        Ok(())
715    }
716
717    fn decode_parameters_from_network(&self, current_time: f64) -> Array1<f64> {
718        let mut updates = Array1::zeros(self.current_params.len());
719
720        // Use firing rates from first layer as parameter updates
721        if !self.layers.is_empty() {
722            for (i, &rate) in self.layers[0].firing_rates.iter().enumerate() {
723                if i < updates.len() {
724                    updates[i] = (rate - 5.0) * 0.01; // Center around target rate
725                }
726            }
727        }
728
729        updates
730    }
731
732    fn compute_adaptive_step_size(&self, objective_improvement: f64, iteration: usize) -> f64 {
733        let base_step = 0.01;
734        let improvement_factor = if objective_improvement > 0.0 {
735            1.2
736        } else {
737            0.8
738        };
739        let decay_factor = 1.0 / (1.0 + iteration as f64 * 0.001);
740
741        base_step * improvement_factor * decay_factor
742    }
743
744    fn update_network_statistics(&mut self, current_time: f64) {
745        // Compute average plasticity
746        let mut total_plasticity = 0.0;
747        let mut count = 0;
748
749        for layer_rules in &self.advanced_stdp_rules {
750            for rule in layer_rules {
751                let stats = rule.get_plasticity_stats();
752                total_plasticity += stats.metaplasticity_factor;
753                count += 1;
754            }
755        }
756
757        if count > 0 {
758            self.network_stats.avg_plasticity = total_plasticity / count as f64;
759        }
760
761        // Compute network synchrony (simplified)
762        let mut synchrony = 0.0;
763        for layer in &self.layers {
764            let rate_variance = layer.firing_rates.clone().variance();
765            synchrony += 1.0 / (1.0 + rate_variance);
766        }
767        self.network_stats.synchrony = synchrony / self.layers.len() as f64;
768
769        // Energy consumption estimate
770        let total_spikes: f64 = self
771            .layers
772            .iter()
773            .map(|layer| layer.firing_rates.sum())
774            .sum();
775        self.network_stats.energy_consumption = total_spikes * 1e-12; // Simplified energy model
776    }
777
778    /// Get network performance statistics
779    pub fn get_network_stats(&self) -> &NetworkStats {
780        &self.network_stats
781    }
782}
783
784/// STDP-based parameter optimization
785#[allow(dead_code)]
786pub fn stdp_optimize<F>(
787    objective: F,
788    initial_params: &ArrayView1<f64>,
789    num_nit: usize,
790) -> Result<Array1<f64>>
791where
792    F: Fn(&ArrayView1<f64>) -> f64,
793{
794    let mut params = initial_params.to_owned();
795    let mut stdp_rules: Vec<STDPLearningRule> = (0..params.len())
796        .map(|_| STDPLearningRule::new(0.01))
797        .collect();
798
799    let mut prev_obj = objective(&params.view());
800
801    for _iter in 0..num_nit {
802        let current_obj = objective(&params.view());
803        let improvement = prev_obj - current_obj;
804
805        // More sophisticated spike-based encoding
806        for (i, rule) in stdp_rules.iter_mut().enumerate() {
807            let pre_spike =
808                scirs2_core::random::rng().random::<f64>() < (params[i].abs() * 0.1).min(0.5);
809            let post_spike = improvement > 0.0 && scirs2_core::random::rng().random::<f64>() < 0.2;
810
811            params[i] = rule.update_weight(params[i], pre_spike, post_spike, 0.001);
812        }
813
814        prev_obj = current_obj;
815    }
816
817    Ok(params)
818}
819
820/// Advanced-advanced STDP optimization with full network simulation
821#[allow(dead_code)]
822pub fn advanced_stdp_optimize<F>(
823    objective: F,
824    initial_params: &ArrayView1<f64>,
825    max_nit: usize,
826    network_config: Option<(Vec<usize>, f64, f64)>, // (layer_sizes, target_rate, learning_rate)
827) -> Result<Array1<f64>>
828where
829    F: Fn(&ArrayView1<f64>) -> f64,
830{
831    let (layer_sizes, target_rate, learning_rate) = network_config.unwrap_or_else(|| {
832        let input_size = initial_params.len();
833        (vec![input_size, input_size * 2, input_size], 5.0, 0.01)
834    });
835
836    let mut network = AdvancedSTDPNetwork::new(layer_sizes, target_rate, learning_rate);
837    network.optimize(objective, initial_params, max_nit, 0.001)
838}
839
#[cfg(test)]
mod tests {
    use super::*;

    // Constructor should store the supplied LTP rate and target firing rate.
    #[test]
    fn test_advanced_stdp_creation() {
        let stdp = AdvancedAdvancedSTDP::new(0.01, 0.005, 5.0);
        assert_eq!(stdp.eta_ltp, 0.01);
        assert_eq!(stdp.target_firing_rate, 5.0);
    }

    // A coincident pre/post spike pair must yield a finite weight inside
    // the rule's declared bounds.
    #[test]
    fn test_advanced_stdp_weight_update() {
        let mut stdp = AdvancedAdvancedSTDP::new(0.1, 0.05, 5.0);

        let new_weight = stdp.update_weight_advanced(0.5, true, true, 0.001, 0.0, 0.1);

        assert!(new_weight.is_finite());
        assert!(new_weight >= stdp.w_min && new_weight <= stdp.w_max);
    }

    // Network construction should create one layer per requested size.
    #[test]
    fn test_advanced_stdp_network() {
        let layer_sizes = vec![3, 5, 3];
        let network = AdvancedSTDPNetwork::new(layer_sizes, 5.0, 0.01);

        assert_eq!(network.layers.len(), 3);
        assert_eq!(network.layers[0].size, 3);
        assert_eq!(network.layers[1].size, 5);
        assert_eq!(network.layers[2].size, 3);
    }

    // Freshly constructed rules report sane (non-negative / positive) stats.
    #[test]
    fn test_plasticity_stats() {
        let stdp = AdvancedAdvancedSTDP::new(0.01, 0.005, 5.0);
        let stats = stdp.get_plasticity_stats();

        assert!(stats.calcium_level >= 0.0);
        assert!(stats.metaplasticity_factor > 0.0);
        assert!(stats.scaling_factor > 0.0);
    }

    // Legacy optimizer should not make a simple quadratic objective worse.
    #[test]
    fn test_basic_stdp_optimization() {
        let objective = |x: &ArrayView1<f64>| x[0].powi(2) + x[1].powi(2);
        let initial = Array1::from(vec![1.0, 1.0]);

        let result = stdp_optimize(objective, &initial.view(), 100).unwrap();

        let final_obj = objective(&result.view());
        let initial_obj = objective(&initial.view());
        assert!(final_obj <= initial_obj);
    }

    // Full network optimizer: dimensionality preserved and objective not
    // made much worse (2x tolerance for the stochastic simulation).
    #[test]
    fn test_advanced_stdp_optimization() {
        let objective = |x: &ArrayView1<f64>| (x[0] - 1.0).powi(2) + (x[1] + 0.5).powi(2);
        let initial = Array1::from(vec![0.0, 0.0]);

        let result = advanced_stdp_optimize(
            objective,
            &initial.view(),
            50,
            Some((vec![2, 4, 2], 3.0, 0.05)),
        )
        .unwrap();

        assert_eq!(result.len(), 2);
        let final_obj = objective(&result.view());
        let initial_obj = objective(&initial.view());
        assert!(final_obj <= initial_obj * 2.0); // Allow some tolerance for stochastic method
    }
}
913
/// Intentionally empty: kept so the module always exports at least one
/// item and produces no unused-module warnings.
#[allow(dead_code)]
pub fn placeholder() {}