1use scirs2_core::error::CoreResult as Result;
12use scirs2_core::ndarray::{Array1, ArrayView1};
13use scirs2_core::random::Rng;
14use statrs::statistics::Statistics;
15use std::collections::VecDeque;
16
/// Single-synapse plasticity rule combining several plasticity mechanisms.
///
/// Each call to [`AdvancedAdvancedSTDP::update_weight_advanced`] sums the
/// contributions of
/// * pair-based STDP on fast and slow traces,
/// * a triplet term,
/// * a calcium-threshold term,
/// * a BCM-style term with a sliding threshold, and
/// * a homeostatic term driven by the post-synaptic firing-rate error,
///
/// scales the sum by the metaplasticity and homeostatic gains, and clips the
/// resulting weight to `[w_min, w_max]`.
///
/// All time constants and the `dt` passed to the update must use the same time
/// unit, since traces decay as `exp(-dt / tau)`.
#[derive(Debug, Clone)]
pub struct AdvancedAdvancedSTDP {
    /// Fast presynaptic trace: decays with `tau_plus_fast`, +1 per pre spike.
    pub pre_trace_fast: f64,
    /// Fast postsynaptic trace: decays with `tau_minus_fast`, +1 per post spike.
    pub post_trace_fast: f64,
    /// Slow presynaptic trace: decays with `tau_plus_slow`, +1 per pre spike.
    pub pre_trace_slow: f64,
    /// Slow postsynaptic trace: decays with `tau_minus_slow`, +1 per post spike.
    pub post_trace_slow: f64,

    /// Presynaptic triplet trace: decays with `2 * tau_plus_fast`, +1 per pre spike.
    pub pre_trace_triplet: f64,
    /// Postsynaptic triplet trace: decays with `2 * tau_minus_fast`, +1 per post spike.
    pub post_trace_triplet: f64,

    /// Calcium level: decays with `tau_calcium`, +0.1 per pre spike,
    /// +0.2 per post spike, capped at 1.0.
    pub calcium_concentration: f64,
    /// Below this calcium level the calcium term depresses the weight.
    pub calcium_threshold_low: f64,
    /// Above this calcium level the calcium term potentiates the weight.
    pub calcium_threshold_high: f64,

    /// Gain `1 + std(recent_activity)` applied to the total weight change.
    pub metaplasticity_factor: f64,
    /// Most recent objective improvements (bounded window).
    pub recent_activity: VecDeque<f64>,
    /// BCM threshold: exponential average of `|mean(recent_activity)|`.
    pub sliding_threshold: f64,

    /// Desired post-synaptic firing rate used by the homeostatic terms.
    pub target_firing_rate: f64,
    /// Exponential average of the post-synaptic spike rate (`1/dt` per spike).
    pub current_firing_rate: f64,
    /// Homeostatic gain in `[0.5, 2.0]` applied to the total weight change.
    pub scaling_factor: f64,

    /// Decay time constant of the fast presynaptic trace.
    pub tau_plus_fast: f64,
    /// Decay time constant of the fast postsynaptic trace.
    pub tau_minus_fast: f64,
    /// Decay time constant of the slow presynaptic trace.
    pub tau_plus_slow: f64,
    /// Decay time constant of the slow postsynaptic trace.
    pub tau_minus_slow: f64,
    /// Decay time constant of the calcium level.
    pub tau_calcium: f64,
    /// NOTE(review): not read by any update in this implementation.
    pub tau_metaplasticity: f64,

    /// Potentiation learning rate.
    pub eta_ltp: f64,
    /// Depression learning rate.
    pub eta_ltd: f64,
    /// Learning rate of the triplet term.
    pub eta_triplet: f64,
    /// Learning rate of the homeostatic term.
    pub eta_homeostatic: f64,

    /// NOTE(review): not read by any update in this implementation.
    pub theta_d: f64,
    /// NOTE(review): not read by any update in this implementation.
    pub theta_p: f64,

    /// Times of the most recent presynaptic spikes (bounded, oldest first).
    pub spike_history_pre: VecDeque<f64>,
    /// Times of the most recent postsynaptic spikes (bounded, oldest first).
    pub spike_history_post: VecDeque<f64>,

    /// Lower weight bound enforced after every update.
    pub w_min: f64,
    /// Upper weight bound enforced after every update.
    pub w_max: f64,
}

impl AdvancedAdvancedSTDP {
    /// Number of recent objective improvements kept for metaplasticity statistics.
    const ACTIVITY_WINDOW: usize = 1000;
    /// Number of spike times kept in each spike-history buffer.
    const SPIKE_HISTORY_LEN: usize = 100;
    /// Metaplasticity statistics are computed only once the window holds more samples than this.
    const MIN_ACTIVITY_SAMPLES: usize = 10;

    /// Creates a rule with zeroed traces and default time constants, thresholds and bounds.
    ///
    /// The triplet and homeostatic learning rates are derived from `eta_ltp`
    /// (`0.1 * eta_ltp` and `0.01 * eta_ltp` respectively).
    pub fn new(eta_ltp: f64, eta_ltd: f64, target_firing_rate: f64) -> Self {
        Self {
            pre_trace_fast: 0.0,
            post_trace_fast: 0.0,
            pre_trace_slow: 0.0,
            post_trace_slow: 0.0,
            pre_trace_triplet: 0.0,
            post_trace_triplet: 0.0,

            calcium_concentration: 0.0,
            calcium_threshold_low: 0.2,
            calcium_threshold_high: 0.6,

            metaplasticity_factor: 1.0,
            recent_activity: VecDeque::with_capacity(Self::ACTIVITY_WINDOW),
            sliding_threshold: 0.5,

            target_firing_rate,
            current_firing_rate: 0.0,
            scaling_factor: 1.0,

            tau_plus_fast: 0.017,
            tau_minus_fast: 0.034,
            tau_plus_slow: 0.688,
            tau_minus_slow: 0.688,
            tau_calcium: 0.048,
            tau_metaplasticity: 100.0,

            eta_ltp,
            eta_ltd,
            eta_triplet: eta_ltp * 0.1,
            eta_homeostatic: eta_ltp * 0.01,

            theta_d: 0.2,
            theta_p: 0.8,

            spike_history_pre: VecDeque::with_capacity(Self::SPIKE_HISTORY_LEN),
            spike_history_post: VecDeque::with_capacity(Self::SPIKE_HISTORY_LEN),

            w_min: -2.0,
            w_max: 2.0,
        }
    }

    /// Advances the synapse state by one time step and returns the updated weight.
    ///
    /// * `current_weight` – weight before the update.
    /// * `pre_spike` / `post_spike` – whether the pre-/post-synaptic neuron fired this step.
    /// * `dt` – step length, in the same unit as the time constants.
    /// * `current_time` – simulation time, recorded in the spike histories.
    /// * `objective_improvement` – latest objective improvement, fed into metaplasticity.
    ///
    /// For a non-NaN result and `w_min <= w_max`, the returned weight lies in `[w_min, w_max]`.
    pub fn update_weight_advanced(
        &mut self,
        current_weight: f64,
        pre_spike: bool,
        post_spike: bool,
        dt: f64,
        current_time: f64,
        objective_improvement: f64,
    ) -> f64 {
        // Refresh the calcium, metaplasticity and homeostatic state first.
        // The plasticity terms below then read this step's values.
        self.update_calcium(pre_spike, post_spike, dt);
        self.update_metaplasticity(current_time, objective_improvement);
        self.update_homeostasis(pre_spike, post_spike, dt);
        self.decay_traces(dt);

        if pre_spike {
            Self::push_bounded(&mut self.spike_history_pre, current_time, Self::SPIKE_HISTORY_LEN);
        }
        if post_spike {
            Self::push_bounded(&mut self.spike_history_post, current_time, Self::SPIKE_HISTORY_LEN);
        }

        // The pairwise term registers this step's spikes in the fast traces.
        // The triplet term reads those traces, so the order of these calls matters.
        let pairwise = self.compute_pairwise_stdp(pre_spike, post_spike);
        let triplet = self.compute_triplet_stdp(pre_spike, post_spike);
        let calcium = self.compute_calcium_plasticity(current_weight);
        let bcm = self.compute_bcm_plasticity(pre_spike, post_spike, current_weight);
        let homeostatic = self.compute_homeostatic_scaling(current_weight);

        let raw_change = pairwise + triplet + calcium + bcm + homeostatic;
        let scaled_change = raw_change * self.metaplasticity_factor * self.scaling_factor;

        self.apply_weight_constraints(current_weight + scaled_change)
    }

    /// Appends `value` to `buffer`, dropping the oldest entry once `capacity` is exceeded.
    fn push_bounded(buffer: &mut VecDeque<f64>, value: f64, capacity: usize) {
        buffer.push_back(value);
        if buffer.len() > capacity {
            buffer.pop_front();
        }
    }

    /// Decays the calcium level, adds the spike-evoked influx, and caps it at 1.0.
    fn update_calcium(&mut self, pre_spike: bool, post_spike: bool, dt: f64) {
        self.calcium_concentration *= (-dt / self.tau_calcium).exp();

        if pre_spike {
            self.calcium_concentration += 0.1;
        }
        if post_spike {
            self.calcium_concentration += 0.2;
        }

        self.calcium_concentration = self.calcium_concentration.min(1.0);
    }

    /// Records `objective_improvement` in the activity window.
    ///
    /// Once enough samples exist, also refreshes the metaplasticity gain
    /// (1 + population standard deviation) and the BCM sliding threshold.
    fn update_metaplasticity(&mut self, _current_time: f64, objective_improvement: f64) {
        Self::push_bounded(
            &mut self.recent_activity,
            objective_improvement,
            Self::ACTIVITY_WINDOW,
        );

        if self.recent_activity.len() <= Self::MIN_ACTIVITY_SAMPLES {
            return;
        }

        let n = self.recent_activity.len() as f64;
        let mean = self.recent_activity.iter().sum::<f64>() / n;
        let variance = self
            .recent_activity
            .iter()
            .map(|&x| (x - mean).powi(2))
            .sum::<f64>()
            / n;

        self.metaplasticity_factor = 1.0 + variance.sqrt();
        self.sliding_threshold = 0.9 * self.sliding_threshold + 0.1 * mean.abs();
    }

    /// Updates the post-synaptic rate estimate and the homeostatic gain.
    ///
    /// The gain is `2 / (1 + rate / max(target, 0.1))`, clamped to `[0.5, 2.0]`.
    fn update_homeostasis(&mut self, _pre_spike: bool, post_spike: bool, dt: f64) {
        let spike_rate = if post_spike { 1.0 / dt } else { 0.0 };
        self.current_firing_rate = 0.999 * self.current_firing_rate + 0.001 * spike_rate;

        let rate_ratio = self.current_firing_rate / self.target_firing_rate.max(0.1);
        self.scaling_factor = (2.0 / (1.0 + rate_ratio)).min(2.0).max(0.5);
    }

    /// Exponentially decays every trace over one step of length `dt`.
    fn decay_traces(&mut self, dt: f64) {
        self.pre_trace_fast *= (-dt / self.tau_plus_fast).exp();
        self.post_trace_fast *= (-dt / self.tau_minus_fast).exp();

        self.pre_trace_slow *= (-dt / self.tau_plus_slow).exp();
        self.post_trace_slow *= (-dt / self.tau_minus_slow).exp();

        self.pre_trace_triplet *= (-dt / (self.tau_plus_fast * 2.0)).exp();
        self.post_trace_triplet *= (-dt / (self.tau_minus_fast * 2.0)).exp();
    }

    /// Pair-based STDP on the fast and slow traces.
    ///
    /// A pre spike depresses in proportion to the post traces; a post spike
    /// potentiates in proportion to the pre traces. Both also bump their own traces.
    fn compute_pairwise_stdp(&mut self, pre_spike: bool, post_spike: bool) -> f64 {
        let mut weight_change = 0.0;

        if pre_spike {
            self.pre_trace_fast += 1.0;
            self.pre_trace_slow += 1.0;
            weight_change -= self.eta_ltd * (self.post_trace_fast + 0.1 * self.post_trace_slow);
        }

        if post_spike {
            self.post_trace_fast += 1.0;
            self.post_trace_slow += 1.0;
            weight_change += self.eta_ltp * (self.pre_trace_fast + 0.1 * self.pre_trace_slow);
        }

        weight_change
    }

    /// Triplet term: products of the fast and triplet traces, gated by spikes.
    fn compute_triplet_stdp(&mut self, pre_spike: bool, post_spike: bool) -> f64 {
        let mut weight_change = 0.0;

        if pre_spike {
            self.pre_trace_triplet += 1.0;
            weight_change -= self.eta_triplet * self.post_trace_fast * self.post_trace_triplet;
        }

        if post_spike {
            self.post_trace_triplet += 1.0;
            weight_change += self.eta_triplet * self.pre_trace_fast * self.pre_trace_triplet;
        }

        weight_change
    }

    /// Calcium-threshold term.
    ///
    /// * Low calcium depresses in proportion to `|w|`.
    /// * High calcium potentiates toward `w_max`.
    /// * Between the thresholds the term ramps linearly from slight depression
    ///   to slight potentiation.
    fn compute_calcium_plasticity(&self, current_weight: f64) -> f64 {
        let ca = self.calcium_concentration;

        if ca < self.calcium_threshold_low {
            -self.eta_ltd * 0.1 * current_weight.abs()
        } else if ca > self.calcium_threshold_high {
            self.eta_ltp * 0.1 * (self.w_max - current_weight.abs())
        } else {
            // With equal (or inverted) thresholds the ramp has zero width.
            // Treat that as its midpoint (no change) instead of dividing by zero.
            let span = self.calcium_threshold_high - self.calcium_threshold_low;
            let normalized_ca = if span > 0.0 {
                (ca - self.calcium_threshold_low) / span
            } else {
                0.5
            };
            self.eta_ltp * 0.05 * (2.0 * normalized_ca - 1.0)
        }
    }

    /// BCM-style term. It is non-zero only when pre and post fire together,
    /// and its sign is set by `1 - sliding_threshold`.
    fn compute_bcm_plasticity(
        &self,
        pre_spike: bool,
        post_spike: bool,
        _current_weight: f64,
    ) -> f64 {
        if !pre_spike && !post_spike {
            return 0.0;
        }

        let post_activity = if post_spike { 1.0 } else { 0.0 };
        let pre_activity = if pre_spike { 1.0 } else { 0.0 };

        let theta = self.sliding_threshold;
        pre_activity * post_activity * (post_activity - theta) * self.eta_ltp * 0.1
    }

    /// Homeostatic term proportional to the firing-rate error and the current weight.
    fn compute_homeostatic_scaling(&self, current_weight: f64) -> f64 {
        let rate_error = self.target_firing_rate - self.current_firing_rate;
        self.eta_homeostatic * rate_error * current_weight * 0.01
    }

    /// Clips `weight` to `[w_min, w_max]`.
    ///
    /// The previous "soft" bound `w_max - exp(-(w - w_max))` jumped to about
    /// `w_max - 1` for a tiny overshoot, which was discontinuous and non-monotonic.
    /// A hard clip keeps the mapping continuous and order-preserving.
    /// NaN passes through unchanged.
    fn apply_weight_constraints(&self, weight: f64) -> f64 {
        if weight > self.w_max {
            self.w_max
        } else if weight < self.w_min {
            self.w_min
        } else {
            weight
        }
    }

    /// Returns a snapshot of the rule's internal state.
    pub fn get_plasticity_stats(&self) -> PlasticityStats {
        PlasticityStats {
            calcium_level: self.calcium_concentration,
            metaplasticity_factor: self.metaplasticity_factor,
            scaling_factor: self.scaling_factor,
            firing_rate_error: self.target_firing_rate - self.current_firing_rate,
            sliding_threshold: self.sliding_threshold,
            trace_strength: (self.pre_trace_fast + self.post_trace_fast) / 2.0,
        }
    }
}

/// Snapshot of an [`AdvancedAdvancedSTDP`] rule's internal state.
#[derive(Debug, Clone)]
pub struct PlasticityStats {
    /// Current calcium level.
    pub calcium_level: f64,
    /// Current metaplasticity gain.
    pub metaplasticity_factor: f64,
    /// Current homeostatic gain.
    pub scaling_factor: f64,
    /// `target_firing_rate - current_firing_rate`.
    pub firing_rate_error: f64,
    /// Current BCM sliding threshold.
    pub sliding_threshold: f64,
    /// Mean of the fast pre- and postsynaptic traces.
    pub trace_strength: f64,
}
372
/// Classic pair-based STDP rule with one exponentially decaying trace per side.
///
/// Weights returned by [`STDPLearningRule::update_weight`] are clipped to `[-1.0, 1.0]`.
#[derive(Debug, Clone)]
pub struct STDPLearningRule {
    /// Presynaptic trace; decays with `tau_plus` and jumps by 1 on each pre spike.
    pub pre_trace: f64,
    /// Postsynaptic trace; decays with `tau_minus` and jumps by 1 on each post spike.
    pub post_trace: f64,
    /// Amplitude shared by potentiation and depression.
    pub learning_rate: f64,
    /// Decay time constant of the presynaptic trace (same unit as `dt`).
    pub tau_plus: f64,
    /// Decay time constant of the postsynaptic trace (same unit as `dt`).
    pub tau_minus: f64,
}

impl STDPLearningRule {
    /// Creates a rule with empty traces and both time constants set to `0.020`.
    pub fn new(learning_rate: f64) -> Self {
        const DEFAULT_TAU: f64 = 0.020;
        Self {
            learning_rate,
            pre_trace: 0.0,
            post_trace: 0.0,
            tau_plus: DEFAULT_TAU,
            tau_minus: DEFAULT_TAU,
        }
    }

    /// Advances both traces by `dt`, applies STDP for this step's spikes, and
    /// returns the new weight clipped to `[-1.0, 1.0]`.
    ///
    /// A pre spike depresses by `learning_rate * post_trace`. A post spike
    /// potentiates by `learning_rate * pre_trace`. When both fire in the same
    /// step, the pre spike is registered first.
    pub fn update_weight(
        &mut self,
        current_weight: f64,
        pre_spike: bool,
        post_spike: bool,
        dt: f64,
    ) -> f64 {
        let pre_decay = (-dt / self.tau_plus).exp();
        let post_decay = (-dt / self.tau_minus).exp();
        self.pre_trace *= pre_decay;
        self.post_trace *= post_decay;

        // Pre spike: register it, then depress according to recent post activity.
        let depression = if pre_spike {
            self.pre_trace += 1.0;
            if self.post_trace > 0.0 {
                self.learning_rate * self.post_trace
            } else {
                0.0
            }
        } else {
            0.0
        };

        // Post spike: register it, then potentiate according to recent pre activity.
        let potentiation = if post_spike {
            self.post_trace += 1.0;
            if self.pre_trace > 0.0 {
                self.learning_rate * self.pre_trace
            } else {
                0.0
            }
        } else {
            0.0
        };

        let unclipped = current_weight + (potentiation - depression);
        f64::min(f64::max(unclipped, -1.0), 1.0)
    }
}
432
433#[derive(Debug, Clone)]
435pub struct AdvancedSTDPNetwork {
436 pub layers: Vec<STDPLayer>,
438 pub advanced_stdp_rules: Vec<Vec<AdvancedAdvancedSTDP>>,
440 pub current_params: Array1<f64>,
442 pub best_params: Array1<f64>,
444 pub best_objective: f64,
446 pub nit: usize,
448 pub network_stats: NetworkStats,
450}
451
452#[derive(Debug, Clone)]
454pub struct STDPLayer {
455 pub size: usize,
457 pub potentials: Array1<f64>,
459 pub last_spike_times: Array1<Option<f64>>,
461 pub firing_rates: Array1<f64>,
463}
464
/// Aggregate statistics of an [`AdvancedSTDPNetwork`], refreshed during optimization.
///
/// Every field defaults to `0.0`. The derived `Default` is identical to the
/// previous hand-written implementation.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct NetworkStats {
    /// Mean metaplasticity factor over all synapses.
    pub avg_plasticity: f64,
    /// Mean over layers of `1 / (1 + firing-rate variance)`; 1.0 means uniform rates.
    pub synchrony: f64,
    /// Summed firing rates scaled by `1e-12`, used as a rough activity-cost proxy.
    pub energy_consumption: f64,
    /// Convergence indicator.
    /// NOTE(review): not written by the network code in this module.
    pub convergence: f64,
}
488
489impl AdvancedSTDPNetwork {
490 pub fn new(layer_sizes: Vec<usize>, target_firing_rate: f64, learning_rate: f64) -> Self {
492 let mut layers = Vec::new();
493 let mut advanced_stdp_rules = Vec::new();
494
495 for (layer_idx, &size) in layer_sizes.iter().enumerate() {
496 let layer = STDPLayer {
497 size,
498 potentials: Array1::zeros(size),
499 last_spike_times: Array1::from_vec(vec![None; size]),
500 firing_rates: Array1::zeros(size),
501 };
502 layers.push(layer);
503
504 if layer_idx > 0 {
506 let prev_size = layer_sizes[layer_idx - 1];
507 let mut layer_rules = Vec::new();
508
509 for _i in 0..size {
510 for _j in 0..prev_size {
511 layer_rules.push(AdvancedAdvancedSTDP::new(
512 learning_rate,
513 learning_rate * 0.5,
514 target_firing_rate,
515 ));
516 }
517 }
518 advanced_stdp_rules.push(layer_rules);
519 }
520 }
521
522 let input_size = layer_sizes[0];
523
524 Self {
525 layers,
526 advanced_stdp_rules,
527 current_params: Array1::zeros(input_size),
528 best_params: Array1::zeros(input_size),
529 best_objective: f64::INFINITY,
530 nit: 0,
531 network_stats: NetworkStats::default(),
532 }
533 }
534
535 pub fn optimize<F>(
537 &mut self,
538 objective: F,
539 initial_params: &ArrayView1<f64>,
540 max_nit: usize,
541 dt: f64,
542 ) -> Result<Array1<f64>>
543 where
544 F: Fn(&ArrayView1<f64>) -> f64,
545 {
546 self.current_params = initial_params.to_owned();
547 self.best_params = initial_params.to_owned();
548 self.best_objective = objective(initial_params);
549
550 let mut prev_objective = self.best_objective;
551
552 for iteration in 0..max_nit {
553 let current_time = iteration as f64 * dt;
554
555 let current_objective = objective(&self.current_params.view());
557 let objective_improvement = prev_objective - current_objective;
558
559 if current_objective < self.best_objective {
561 self.best_objective = current_objective;
562 self.best_params = self.current_params.clone();
563 }
564
565 let spike_patterns =
567 self.encode_parameters_to_spikes(&self.current_params, current_time);
568
569 let network_spikes =
571 self.simulate_network_dynamics(&spike_patterns, current_time, dt)?;
572
573 self.update_advanced_stdp_weights(
575 &network_spikes,
576 current_time,
577 dt,
578 objective_improvement,
579 )?;
580
581 let param_updates = self.decode_parameters_from_network(current_time);
583
584 let step_size = self.compute_adaptive_step_size(objective_improvement, iteration);
586 for (i, update) in param_updates.iter().enumerate() {
587 if i < self.current_params.len() {
588 self.current_params[i] += step_size * update;
589 }
590 }
591
592 self.update_network_statistics(current_time);
594
595 if objective_improvement.abs() < 1e-8 && iteration > 100 {
597 break;
598 }
599
600 prev_objective = current_objective;
601 self.nit = iteration + 1;
602 }
603
604 Ok(self.best_params.clone())
605 }
606
607 fn encode_parameters_to_spikes(
608 &self,
609 params: &Array1<f64>,
610 _current_time: f64,
611 ) -> Vec<Vec<bool>> {
612 let mut spike_patterns = Vec::new();
613
614 for layer in &self.layers {
615 let mut layer_spikes = vec![false; layer.size];
616
617 for i in 0..layer.size.min(params.len()) {
619 let spike_prob = ((params[i] + 1.0) / 2.0).max(0.0).min(1.0);
620 layer_spikes[i] = scirs2_core::random::rng().random::<f64>() < spike_prob * 0.1;
621 }
622
623 spike_patterns.push(layer_spikes);
624 }
625
626 spike_patterns
627 }
628
629 fn simulate_network_dynamics(
630 &mut self,
631 input_spikes: &[Vec<bool>],
632 current_time: f64,
633 dt: f64,
634 ) -> Result<Vec<Vec<bool>>> {
635 let mut all_spikes = input_spikes.to_vec();
636
637 for layer_idx in 1..self.layers.len() {
639 let mut layer_spikes = vec![false; self.layers[layer_idx].size];
640
641 for neuron_idx in 0..self.layers[layer_idx].size {
642 let mut input_current = 0.0;
644
645 for prev_neuron_idx in 0..self.layers[layer_idx - 1].size {
646 if all_spikes[layer_idx - 1][prev_neuron_idx] {
647 input_current += 0.1;
649 }
650 }
651
652 self.layers[layer_idx].potentials[neuron_idx] +=
654 dt * (-self.layers[layer_idx].potentials[neuron_idx] + input_current) / 0.02;
655
656 if self.layers[layer_idx].potentials[neuron_idx] > 1.0 {
658 self.layers[layer_idx].potentials[neuron_idx] = 0.0;
659 self.layers[layer_idx].last_spike_times[neuron_idx] = Some(current_time);
660 layer_spikes[neuron_idx] = true;
661 }
662
663 let spike_rate = if layer_spikes[neuron_idx] {
665 1.0 / dt
666 } else {
667 0.0
668 };
669 self.layers[layer_idx].firing_rates[neuron_idx] =
670 0.99 * self.layers[layer_idx].firing_rates[neuron_idx] + 0.01 * spike_rate;
671 }
672
673 all_spikes.push(layer_spikes);
674 }
675
676 Ok(all_spikes)
677 }
678
679 fn update_advanced_stdp_weights(
680 &mut self,
681 all_spikes: &[Vec<bool>],
682 current_time: f64,
683 dt: f64,
684 objective_improvement: f64,
685 ) -> Result<()> {
686 for layer_idx in 0..self.advanced_stdp_rules.len() {
688 let input_spikes = &all_spikes[layer_idx];
689 let output_spikes = &all_spikes[layer_idx + 1];
690
691 for (connection_idx, rule) in self.advanced_stdp_rules[layer_idx].iter_mut().enumerate()
692 {
693 let _layer_size = self.layers[layer_idx + 1].size;
695 let prev_layer_size = self.layers[layer_idx].size;
696 let neuron_idx = connection_idx / prev_layer_size;
697 let input_idx = connection_idx % prev_layer_size;
698
699 let pre_spike = input_spikes.get(input_idx).copied().unwrap_or(false);
700 let post_spike = output_spikes.get(neuron_idx).copied().unwrap_or(false);
701
702 let _new_weight = rule.update_weight_advanced(
704 0.5, pre_spike,
706 post_spike,
707 dt,
708 current_time,
709 objective_improvement,
710 );
711 }
712 }
713
714 Ok(())
715 }
716
717 fn decode_parameters_from_network(&self, current_time: f64) -> Array1<f64> {
718 let mut updates = Array1::zeros(self.current_params.len());
719
720 if !self.layers.is_empty() {
722 for (i, &rate) in self.layers[0].firing_rates.iter().enumerate() {
723 if i < updates.len() {
724 updates[i] = (rate - 5.0) * 0.01; }
726 }
727 }
728
729 updates
730 }
731
732 fn compute_adaptive_step_size(&self, objective_improvement: f64, iteration: usize) -> f64 {
733 let base_step = 0.01;
734 let improvement_factor = if objective_improvement > 0.0 {
735 1.2
736 } else {
737 0.8
738 };
739 let decay_factor = 1.0 / (1.0 + iteration as f64 * 0.001);
740
741 base_step * improvement_factor * decay_factor
742 }
743
744 fn update_network_statistics(&mut self, current_time: f64) {
745 let mut total_plasticity = 0.0;
747 let mut count = 0;
748
749 for layer_rules in &self.advanced_stdp_rules {
750 for rule in layer_rules {
751 let stats = rule.get_plasticity_stats();
752 total_plasticity += stats.metaplasticity_factor;
753 count += 1;
754 }
755 }
756
757 if count > 0 {
758 self.network_stats.avg_plasticity = total_plasticity / count as f64;
759 }
760
761 let mut synchrony = 0.0;
763 for layer in &self.layers {
764 let rate_variance = layer.firing_rates.clone().variance();
765 synchrony += 1.0 / (1.0 + rate_variance);
766 }
767 self.network_stats.synchrony = synchrony / self.layers.len() as f64;
768
769 let total_spikes: f64 = self
771 .layers
772 .iter()
773 .map(|layer| layer.firing_rates.sum())
774 .sum();
775 self.network_stats.energy_consumption = total_spikes * 1e-12; }
777
778 pub fn get_network_stats(&self) -> &NetworkStats {
780 &self.network_stats
781 }
782}
783
784#[allow(dead_code)]
786pub fn stdp_optimize<F>(
787 objective: F,
788 initial_params: &ArrayView1<f64>,
789 num_nit: usize,
790) -> Result<Array1<f64>>
791where
792 F: Fn(&ArrayView1<f64>) -> f64,
793{
794 let mut params = initial_params.to_owned();
795 let mut stdp_rules: Vec<STDPLearningRule> = (0..params.len())
796 .map(|_| STDPLearningRule::new(0.01))
797 .collect();
798
799 let mut prev_obj = objective(¶ms.view());
800
801 for _iter in 0..num_nit {
802 let current_obj = objective(¶ms.view());
803 let improvement = prev_obj - current_obj;
804
805 for (i, rule) in stdp_rules.iter_mut().enumerate() {
807 let pre_spike =
808 scirs2_core::random::rng().random::<f64>() < (params[i].abs() * 0.1).min(0.5);
809 let post_spike = improvement > 0.0 && scirs2_core::random::rng().random::<f64>() < 0.2;
810
811 params[i] = rule.update_weight(params[i], pre_spike, post_spike, 0.001);
812 }
813
814 prev_obj = current_obj;
815 }
816
817 Ok(params)
818}
819
820#[allow(dead_code)]
822pub fn advanced_stdp_optimize<F>(
823 objective: F,
824 initial_params: &ArrayView1<f64>,
825 max_nit: usize,
826 network_config: Option<(Vec<usize>, f64, f64)>, ) -> Result<Array1<f64>>
828where
829 F: Fn(&ArrayView1<f64>) -> f64,
830{
831 let (layer_sizes, target_rate, learning_rate) = network_config.unwrap_or_else(|| {
832 let input_size = initial_params.len();
833 (vec![input_size, input_size * 2, input_size], 5.0, 0.01)
834 });
835
836 let mut network = AdvancedSTDPNetwork::new(layer_sizes, target_rate, learning_rate);
837 network.optimize(objective, initial_params, max_nit, 0.001)
838}
839
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_advanced_stdp_creation() {
        let stdp = AdvancedAdvancedSTDP::new(0.01, 0.005, 5.0);
        assert_eq!(stdp.eta_ltp, 0.01);
        assert_eq!(stdp.eta_ltd, 0.005);
        assert_eq!(stdp.target_firing_rate, 5.0);
    }

    #[test]
    fn test_advanced_stdp_weight_update() {
        let mut stdp = AdvancedAdvancedSTDP::new(0.1, 0.05, 5.0);

        let new_weight = stdp.update_weight_advanced(0.5, true, true, 0.001, 0.0, 0.1);

        assert!(new_weight.is_finite());
        assert!(new_weight >= stdp.w_min && new_weight <= stdp.w_max);
    }

    #[test]
    fn test_advanced_stdp_network() {
        let network = AdvancedSTDPNetwork::new(vec![3, 5, 3], 5.0, 0.01);

        let sizes: Vec<usize> = network.layers.iter().map(|layer| layer.size).collect();
        assert_eq!(sizes, vec![3, 5, 3]);

        // One rule per (post, pre) pair between consecutive layers.
        assert_eq!(network.advanced_stdp_rules.len(), 2);
        assert_eq!(network.advanced_stdp_rules[0].len(), 5 * 3);
        assert_eq!(network.advanced_stdp_rules[1].len(), 3 * 5);
        assert_eq!(network.current_params.len(), 3);
    }

    #[test]
    fn test_plasticity_stats() {
        let stdp = AdvancedAdvancedSTDP::new(0.01, 0.005, 5.0);
        let stats = stdp.get_plasticity_stats();

        assert!(stats.calcium_level >= 0.0);
        assert!(stats.metaplasticity_factor > 0.0);
        assert!(stats.scaling_factor > 0.0);
    }

    #[test]
    fn test_stdp_learning_rule_timing() {
        // Pre then post: potentiation.
        let mut rule = STDPLearningRule::new(0.01);
        let w = rule.update_weight(0.5, true, false, 0.001);
        assert_eq!(w, 0.5);
        let w = rule.update_weight(w, false, true, 0.001);
        assert!(w > 0.5);

        // Post then pre: depression.
        let mut rule = STDPLearningRule::new(0.01);
        let w = rule.update_weight(0.5, false, true, 0.001);
        assert_eq!(w, 0.5);
        let w = rule.update_weight(w, true, false, 0.001);
        assert!(w < 0.5);
    }

    #[test]
    fn test_basic_stdp_optimization() {
        let objective = |x: &ArrayView1<f64>| x[0].powi(2) + x[1].powi(2);
        let initial = Array1::from(vec![1.0, 1.0]);

        let result = stdp_optimize(objective, &initial.view(), 100).unwrap();

        let final_obj = objective(&result.view());
        let initial_obj = objective(&initial.view());
        assert!(final_obj <= initial_obj);
    }

    #[test]
    fn test_advanced_stdp_optimization() {
        let objective = |x: &ArrayView1<f64>| (x[0] - 1.0).powi(2) + (x[1] + 0.5).powi(2);
        let initial = Array1::from(vec![0.0, 0.0]);

        let result = advanced_stdp_optimize(
            objective,
            &initial.view(),
            50,
            Some((vec![2, 4, 2], 3.0, 0.05)),
        )
        .unwrap();

        assert_eq!(result.len(), 2);
        assert!(result.iter().all(|v| v.is_finite()));
        // The optimizer returns the best point it evaluated, and the start is one of them.
        let final_obj = objective(&result.view());
        let initial_obj = objective(&initial.view());
        assert!(final_obj <= initial_obj);
    }
}
913
/// Intentionally empty placeholder function; it performs no work.
#[allow(dead_code)]
pub fn placeholder() {}