use super::{
    utils, Experience, ExperienceBuffer, ImprovementReward, OptimizationAction, OptimizationState,
    RLOptimizationConfig, RLOptimizer, RewardFunction,
};
use crate::error::{OptimizeError, OptimizeResult};
use crate::result::OptimizeResults;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use scirs2_core::random::{rng, Rng};
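
/// Policy (actor) network: a single-hidden-layer MLP that maps state features
/// to action preferences, which are converted to probabilities with a
/// temperature-scaled softmax.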
#[derive(Debug, Clone)]
pub struct ActorNetwork {
    /// Hidden-layer weights (hidden_size x input_size)
    pub hidden_weights: Array2<f64>,
    /// Hidden-layer bias
    pub hidden_bias: Array1<f64>,
    /// Output-layer weights (output_size x hidden_size)
    pub output_weights: Array2<f64>,
    /// Output-layer bias
    pub output_bias: Array1<f64>,
    /// Number of input features
    pub _input_size: usize,
    /// Number of hidden units
    pub hidden_size: usize,
    /// Number of outputs (actions)
    pub output_size: usize,
    /// Activation function applied to both layers
    pub activation: ActivationType,
}

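/// Value (critic) network: a single-hidden-layer MLP that maps state features
/// to a scalar state-value estimate.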
#[derive(Debug, Clone)]
pub struct CriticNetwork {
    /// Hidden-layer weights (hidden_size x input_size)
    pub hidden_weights: Array2<f64>,
    /// Hidden-layer bias
    pub hidden_bias: Array1<f64>,
    /// Weights of the scalar output layer
    pub output_weights: Array1<f64>,
    /// Bias of the scalar output
    pub output_bias: f64,
    /// Number of input features
    pub _input_size: usize,
    /// Number of hidden units
    pub hidden_size: usize,
    /// Activation function applied to the hidden layer
    pub activation: ActivationType,
}

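/// Activation functions available to the actor and critic networks.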
#[derive(Debug, Clone, Copy)]
pub enum ActivationType {
    Tanh,
    ReLU,
    Sigmoid,
    LeakyReLU,
    ELU,
}

impl ActivationType {
    fn apply(&self, x: f64) -> f64 {
        match self {
            ActivationType::Tanh => x.tanh(),
            ActivationType::ReLU => x.max(0.0),
            ActivationType::Sigmoid => 1.0 / (1.0 + (-x).exp()),
            ActivationType::LeakyReLU => {
                if x > 0.0 {
                    x
                } else {
                    0.01 * x
                }
            }
            ActivationType::ELU => {
                if x > 0.0 {
                    x
                } else {
                    x.exp() - 1.0
                }
            }
        }
    }

    fn derivative(&self, x: f64) -> f64 {
        match self {
            ActivationType::Tanh => {
                let t = x.tanh();
                1.0 - t * t
            }
            ActivationType::ReLU => {
                if x > 0.0 {
                    1.0
                } else {
                    0.0
                }
            }
            ActivationType::Sigmoid => {
                let s = 1.0 / (1.0 + (-x).exp());
                s * (1.0 - s)
            }
            ActivationType::LeakyReLU => {
                if x > 0.0 {
                    1.0
                } else {
                    0.01
                }
            }
            ActivationType::ELU => {
                if x > 0.0 {
                    1.0
                } else {
                    x.exp()
                }
            }
        }
    }
}

impl ActorNetwork {
    pub fn new(
        input_size: usize,
        hidden_size: usize,
        output_size: usize,
        activation: ActivationType,
    ) -> Self {
        let xavier_scale = (2.0 / (input_size + hidden_size) as f64).sqrt();

        Self {
            hidden_weights: Array2::from_shape_fn((hidden_size, input_size), |_| {
                (rng().random::<f64>() - 0.5) * 2.0 * xavier_scale
            }),
            hidden_bias: Array1::zeros(hidden_size),
            output_weights: Array2::from_shape_fn((output_size, hidden_size), |_| {
                (rng().random::<f64>() - 0.5) * 2.0 * xavier_scale
            }),
            output_bias: Array1::zeros(output_size),
            _input_size: input_size,
            hidden_size,
            output_size,
            activation,
        }
    }

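    /// Forward pass: returns the hidden pre-activations, the activated hidden
    /// layer, and the activated output layer.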
    pub fn forward(&self, input: &ArrayView1<f64>) -> (Array1<f64>, Array1<f64>, Array1<f64>) {
        let mut hidden_raw = Array1::zeros(self.hidden_size);
        for i in 0..self.hidden_size {
            for j in 0..self._input_size.min(input.len()) {
                hidden_raw[i] += self.hidden_weights[[i, j]] * input[j];
            }
            hidden_raw[i] += self.hidden_bias[i];
        }

        let hidden_activated = hidden_raw.mapv(|x| self.activation.apply(x));

        let mut output_raw = Array1::zeros(self.output_size);
        for i in 0..self.output_size {
            for j in 0..self.hidden_size {
                output_raw[i] += self.output_weights[[i, j]] * hidden_activated[j];
            }
            output_raw[i] += self.output_bias[i];
        }

        let output_activated = output_raw.mapv(|x| self.activation.apply(x));

        (hidden_raw, hidden_activated, output_activated)
    }

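    /// Convert policy outputs into action probabilities with a numerically
    /// stable, temperature-scaled softmax; falls back to a uniform
    /// distribution if the normalizer underflows to zero.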
    pub fn action_probabilities(
        &self,
        policy_output: &ArrayView1<f64>,
        temperature: f64,
    ) -> Array1<f64> {
        let scaled_output = policy_output.mapv(|x| x / temperature);
        let max_val = scaled_output.fold(-f64::INFINITY, |a, &b| a.max(b));
        let exp_output = scaled_output.mapv(|x| (x - max_val).exp());
        let sum_exp = exp_output.sum();

        if sum_exp > 0.0 {
            exp_output / sum_exp
        } else {
            Array1::from_elem(policy_output.len(), 1.0 / policy_output.len() as f64)
        }
    }

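    /// Backpropagate a gradient taken with respect to the raw output layer and
    /// apply a plain gradient-descent step to both layers.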
    pub fn backward_and_update(
        &mut self,
        input: &ArrayView1<f64>,
        hidden_raw: &Array1<f64>,
        hidden_activated: &Array1<f64>,
        output_gradient: &ArrayView1<f64>,
        learning_rate: f64,
    ) {
        // The incoming gradient is treated as the gradient with respect to the raw output layer.
        let output_raw_gradient = output_gradient.mapv(|g| g);

        // Backpropagate through the output weights into the hidden layer.
        let mut hidden_gradient: Array1<f64> = Array1::zeros(self.hidden_size);
        for j in 0..self.hidden_size {
            for i in 0..self.output_size {
                hidden_gradient[j] += output_raw_gradient[i] * self.output_weights[[i, j]];
            }
            hidden_gradient[j] *= self.activation.derivative(hidden_raw[j]);
        }

        // Gradient-descent update of the output layer.
        for i in 0..self.output_size {
            for j in 0..self.hidden_size {
                self.output_weights[[i, j]] -=
                    learning_rate * output_raw_gradient[i] * hidden_activated[j];
            }
            self.output_bias[i] -= learning_rate * output_raw_gradient[i];
        }

        // Gradient-descent update of the hidden layer.
        for i in 0..self.hidden_size {
            for j in 0..self._input_size.min(input.len()) {
                self.hidden_weights[[i, j]] -= learning_rate * hidden_gradient[i] * input[j];
            }
            self.hidden_bias[i] -= learning_rate * hidden_gradient[i];
        }
    }
}

impl CriticNetwork {
    pub fn new(input_size: usize, hidden_size: usize, activation: ActivationType) -> Self {
        let xavier_scale = (2.0 / (input_size + hidden_size) as f64).sqrt();

        Self {
            hidden_weights: Array2::from_shape_fn((hidden_size, input_size), |_| {
                (rng().random::<f64>() - 0.5) * 2.0 * xavier_scale
            }),
            hidden_bias: Array1::zeros(hidden_size),
            output_weights: Array1::from_shape_fn(hidden_size, |_| {
                (rng().random::<f64>() - 0.5) * 2.0 * xavier_scale
            }),
            output_bias: 0.0,
            _input_size: input_size,
            hidden_size,
            activation,
        }
    }

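    /// Forward pass: returns the hidden pre-activations, the activated hidden
    /// layer, and the scalar value estimate.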
    pub fn forward(&self, input: &ArrayView1<f64>) -> (Array1<f64>, Array1<f64>, f64) {
        let mut hidden_raw = Array1::zeros(self.hidden_size);
        for i in 0..self.hidden_size {
            for j in 0..self._input_size.min(input.len()) {
                hidden_raw[i] += self.hidden_weights[[i, j]] * input[j];
            }
            hidden_raw[i] += self.hidden_bias[i];
        }

        let hidden_activated = hidden_raw.mapv(|x| self.activation.apply(x));

        let mut output = 0.0;
        for j in 0..self.hidden_size {
            output += self.output_weights[j] * hidden_activated[j];
        }
        output += self.output_bias;

        (hidden_raw, hidden_activated, output)
    }

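    /// Update the critic toward `target_value` by gradient descent on the
    /// squared value error.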
    pub fn backward_and_update(
        &mut self,
        input: &ArrayView1<f64>,
        hidden_raw: &Array1<f64>,
        hidden_activated: &Array1<f64>,
        target_value: f64,
        predicted_value: f64,
        learning_rate: f64,
    ) {
        let value_error = target_value - predicted_value;

        // Backpropagate the value error into the hidden layer.
        let mut hidden_gradient: Array1<f64> = Array1::zeros(self.hidden_size);
        for j in 0..self.hidden_size {
            hidden_gradient[j] =
                value_error * self.output_weights[j] * self.activation.derivative(hidden_raw[j]);
        }

        // Move the value estimate toward the target (descent on the squared error).
        for j in 0..self.hidden_size {
            self.output_weights[j] += learning_rate * value_error * hidden_activated[j];
        }
        self.output_bias += learning_rate * value_error;

        for i in 0..self.hidden_size {
            for j in 0..self._input_size.min(input.len()) {
                self.hidden_weights[[i, j]] += learning_rate * hidden_gradient[i] * input[j];
            }
            self.hidden_bias[i] += learning_rate * hidden_gradient[i];
        }
    }
}

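/// Advantage Actor-Critic (A2C) optimizer: a policy network selects
/// optimization actions while a value network estimates state values, and both
/// are trained from sampled experience using advantage estimates.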
#[derive(Debug, Clone)]
pub struct AdvantageActorCriticOptimizer {
    /// Configuration for RL-based optimization
    config: RLOptimizationConfig,
    /// Policy network
    actor: ActorNetwork,
    /// Value network
    critic: CriticNetwork,
    /// Buffer of collected experiences
    experience_buffer: ExperienceBuffer,
    /// Reward function based on objective improvement
    reward_function: ImprovementReward,
    /// Best parameters found so far
    best_params: Array1<f64>,
    /// Best objective value found so far
    best_objective: f64,
    /// Softmax temperature for action selection
    temperature: f64,
    /// Moving baseline subtracted from advantages
    baseline: f64,
    /// Accumulated training statistics
    training_stats: A2CTrainingStats,
    /// Weight of the entropy regularization term
    entropy_coeff: f64,
    /// Scaling applied to the critic's learning rate
    value_coeff: f64,
}

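/// Exponentially averaged training statistics collected during A2C updates.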
#[derive(Debug, Clone)]
pub struct A2CTrainingStats {
    pub avg_actor_loss: f64,
    pub avg_critic_loss: f64,
    pub avg_advantage: f64,
    pub avg_entropy: f64,
    pub episodes_completed: usize,
    pub total_steps: usize,
}

impl Default for A2CTrainingStats {
    fn default() -> Self {
        Self {
            avg_actor_loss: 0.0,
            avg_critic_loss: 0.0,
            avg_advantage: 0.0,
            avg_entropy: 0.0,
            episodes_completed: 0,
            total_steps: 0,
        }
    }
}

impl AdvantageActorCriticOptimizer {
    pub fn new(
        config: RLOptimizationConfig,
        state_size: usize,
        action_size: usize,
        hidden_size: usize,
    ) -> Self {
        let memory_size = config.memory_size;
        Self {
            config,
            actor: ActorNetwork::new(state_size, hidden_size, action_size, ActivationType::Tanh),
            critic: CriticNetwork::new(state_size, hidden_size, ActivationType::ReLU),
            experience_buffer: ExperienceBuffer::new(memory_size),
            reward_function: ImprovementReward::default(),
            best_params: Array1::zeros(state_size),
            best_objective: f64::INFINITY,
            temperature: 1.0,
            baseline: 0.0,
            training_stats: A2CTrainingStats::default(),
            entropy_coeff: 0.01,
            value_coeff: 0.5,
        }
    }

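    /// Build a bounded feature vector from the optimization state: squashed
    /// parameters, objective and convergence metrics, recent objective history,
    /// and episode progress (all passed through `tanh`).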
    fn extract_state_features(&self, state: &OptimizationState) -> Array1<f64> {
        let mut features = Vec::new();

        for &param in state.parameters.iter() {
            features.push(param.tanh());
        }

        let log_obj = (state.objective_value.abs() + 1e-8).ln();
        features.push(log_obj.tanh());

        features.push(
            state
                .convergence_metrics
                .relative_objective_change
                .ln()
                .tanh(),
        );
        features.push(state.convergence_metrics.parameter_change_norm.tanh());
        features.push((state.convergence_metrics.steps_since_improvement as f64 / 10.0).tanh());

        if state.objective_history.len() >= 2 {
            let recent_change = state.objective_history[state.objective_history.len() - 1]
                - state.objective_history[state.objective_history.len() - 2];
            features.push(recent_change.tanh());

            let trend = if state.objective_history.len() >= 3 {
                let slope = (state.objective_history[state.objective_history.len() - 1]
                    - state.objective_history[0])
                    / state.objective_history.len() as f64;
                slope.tanh()
            } else {
                0.0
            };
            features.push(trend);
        } else {
            features.push(0.0);
            features.push(0.0);
        }

        features.push((state.step as f64 / self.config.max_steps_per_episode as f64).tanh());

        Array1::from(features)
    }

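    /// Sample an action from the current policy, adding decaying exploration
    /// noise to the policy outputs before the softmax.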
    fn select_action_with_exploration(
        &mut self,
        state: &OptimizationState,
    ) -> (OptimizationAction, Array1<f64>) {
        let state_features = self.extract_state_features(state);
        let (_, _, policy_output) = self.actor.forward(&state_features.view());

        let exploration_noise = if self.training_stats.episodes_completed
            < self.config.num_episodes / 2
        {
            0.1 * (1.0
                - self.training_stats.episodes_completed as f64 / self.config.num_episodes as f64)
        } else {
            0.01
        };

        let noisy_output = policy_output
            .mapv(|x| x + (rng().random::<f64>() - 0.5) * exploration_noise);
        let action_probs = self
            .actor
            .action_probabilities(&noisy_output.view(), self.temperature);

        let cumulative_probs: Vec<f64> = action_probs
            .iter()
            .scan(0.0, |acc, &p| {
                *acc += p;
                Some(*acc)
            })
            .collect();

        let rand_val = rng().random::<f64>();
        let action_idx = cumulative_probs
            .iter()
            .position(|&cp| rand_val <= cp)
            .unwrap_or(action_probs.len() - 1);

        let action = self.decode_action_from_index(action_idx, &noisy_output);

        (action, action_probs)
    }

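    /// Map a sampled action index to a concrete `OptimizationAction`, using the
    /// policy outputs to modulate step magnitudes.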
    fn decode_action_from_index(
        &self,
        action_idx: usize,
        policy_output: &Array1<f64>,
    ) -> OptimizationAction {
        let magnitude_factor = 1.0 + policy_output.get(1).unwrap_or(&0.0).abs();

        match action_idx {
            0 => OptimizationAction::GradientStep {
                learning_rate: 0.001 * magnitude_factor,
            },
            1 => OptimizationAction::RandomPerturbation {
                magnitude: 0.01 * magnitude_factor,
            },
            2 => OptimizationAction::MomentumUpdate {
                momentum: 0.9 * (1.0 + policy_output.get(2).unwrap_or(&0.0) * 0.1),
            },
            3 => OptimizationAction::AdaptiveLearningRate {
                factor: 0.5 + 0.5 * policy_output.get(3).unwrap_or(&0.0).tanh(),
            },
            4 => OptimizationAction::ResetToBest,
            _ => OptimizationAction::Terminate,
        }
    }

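    /// One-step advantage estimate: `r + gamma * V(s') - V(s)`, with the
    /// bootstrap term dropped on terminal transitions.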
    fn compute_advantage(
        &self,
        reward: f64,
        current_value: f64,
        next_value: f64,
        done: bool,
    ) -> f64 {
        let target = if done {
            reward
        } else {
            reward + self.config.discount_factor * next_value
        };
        target - current_value
    }

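    /// Perform one A2C update over a batch of experiences: the critic is
    /// regressed toward TD targets and the actor follows an advantage-weighted
    /// policy gradient with entropy regularization.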
    fn update_networks(&mut self, experiences: &[Experience]) -> Result<(), OptimizeError> {
        let mut total_actor_loss = 0.0;
        let mut total_critic_loss = 0.0;
        let mut total_advantage = 0.0;
        let mut total_entropy = 0.0;

        for experience in experiences {
            let state_features = self.extract_state_features(&experience.state);
            let next_state_features = self.extract_state_features(&experience.next_state);

            let (hidden_raw, hidden_activated, current_value) =
                self.critic.forward(&state_features.view());
            let (_, _, next_value) = self.critic.forward(&next_state_features.view());

            let advantage = self.compute_advantage(
                experience.reward,
                current_value,
                next_value,
                experience.done,
            );

            let target_value = if experience.done {
                experience.reward
            } else {
                experience.reward + self.config.discount_factor * next_value
            };

            // Critic update toward the TD target.
            self.critic.backward_and_update(
                &state_features.view(),
                &hidden_raw,
                &hidden_activated,
                target_value,
                current_value,
                self.config.learning_rate * self.value_coeff,
            );

            let (actor_hidden_raw, actor_hidden_activated, policy_output) =
                self.actor.forward(&state_features.view());

            let action_probs = self
                .actor
                .action_probabilities(&policy_output.view(), self.temperature);

            // Policy entropy, used as an exploration bonus.
            let entropy = -action_probs
                .iter()
                .filter(|&&p| p > 1e-8)
                .map(|&p| p * p.ln())
                .sum::<f64>();

            let action_idx = self.get_action_index(&experience.action);
            let log_prob = action_probs.get(action_idx).unwrap_or(&1e-8).ln();
            let policy_loss = -log_prob * (advantage - self.baseline);

            // Actor gradient: policy-gradient term plus entropy regularization.
            let mut actor_gradient = Array1::zeros(policy_output.len());
            if action_idx < actor_gradient.len() {
                actor_gradient[action_idx] =
                    -(advantage - self.baseline) / (action_probs[action_idx] + 1e-8);
                actor_gradient[action_idx] += self.entropy_coeff * (1.0 + log_prob);
            }

            self.actor.backward_and_update(
                &state_features.view(),
                &actor_hidden_raw,
                &actor_hidden_activated,
                &actor_gradient.view(),
                self.config.learning_rate,
            );

            total_actor_loss += policy_loss;
            total_critic_loss += (target_value - current_value).powi(2);
            total_advantage += advantage;
            total_entropy += entropy;
        }

        if !experiences.is_empty() {
            // Exponential moving averages of the baseline and training statistics.
            self.baseline =
                0.9 * self.baseline + 0.1 * (total_advantage / experiences.len() as f64);

            let num_exp = experiences.len() as f64;
            self.training_stats.avg_actor_loss =
                0.9 * self.training_stats.avg_actor_loss + 0.1 * (total_actor_loss / num_exp);
            self.training_stats.avg_critic_loss =
                0.9 * self.training_stats.avg_critic_loss + 0.1 * (total_critic_loss / num_exp);
            self.training_stats.avg_advantage =
                0.9 * self.training_stats.avg_advantage + 0.1 * (total_advantage / num_exp);
            self.training_stats.avg_entropy =
                0.9 * self.training_stats.avg_entropy + 0.1 * (total_entropy / num_exp);
        }

        Ok(())
    }

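    /// Discrete index of an action, matching the ordering used by
    /// `decode_action_from_index`.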
    fn get_action_index(&self, action: &OptimizationAction) -> usize {
        match action {
            OptimizationAction::GradientStep { .. } => 0,
            OptimizationAction::RandomPerturbation { .. } => 1,
            OptimizationAction::MomentumUpdate { .. } => 2,
            OptimizationAction::AdaptiveLearningRate { .. } => 3,
            OptimizationAction::ResetToBest => 4,
            OptimizationAction::Terminate => 5,
        }
    }

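    /// Access the accumulated training statistics.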
    pub fn get_training_stats(&self) -> &A2CTrainingStats {
        &self.training_stats
    }

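    /// Decay the softmax temperature and entropy coefficient toward their
    /// lower bounds as training progresses.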
    fn adjust_exploration(&mut self) {
        self.temperature = (self.temperature * 0.999).max(0.1);

        self.entropy_coeff = (self.entropy_coeff * 0.9995).max(0.001);
    }
}

impl RLOptimizer for AdvantageActorCriticOptimizer {
    fn config(&self) -> &RLOptimizationConfig {
        &self.config
    }

    fn select_action(&mut self, state: &OptimizationState) -> OptimizationAction {
        let (action, _) = self.select_action_with_exploration(state);
        action
    }

    fn update(&mut self, experience: &Experience) -> Result<(), OptimizeError> {
        self.experience_buffer.add(experience.clone());

        if self.experience_buffer.size() >= self.config.batch_size {
            let batch = self.experience_buffer.sample_batch(self.config.batch_size);
            self.update_networks(&batch)?;
        }

        Ok(())
    }

    fn run_episode<F>(
        &mut self,
        objective: &F,
        initial_params: &ArrayView1<f64>,
    ) -> OptimizeResult<OptimizeResults<f64>>
    where
        F: Fn(&ArrayView1<f64>) -> f64,
    {
        let mut current_params = initial_params.to_owned();
        let mut current_state = utils::create_state(current_params.clone(), objective, 0, None);
        let mut momentum = Array1::zeros(initial_params.len());
        let mut total_reward = 0.0;

        for step in 0..self.config.max_steps_per_episode {
            let (action, _) = self.select_action_with_exploration(&current_state);

            let new_params =
                utils::apply_action(&current_state, &action, &self.best_params, &mut momentum);
            let new_state =
                utils::create_state(new_params, objective, step + 1, Some(&current_state));

            let reward = self
                .reward_function
                .compute_reward(&current_state, &action, &new_state);
            total_reward += reward;

            let experience = Experience {
                state: current_state.clone(),
                action: action.clone(),
                reward,
                next_state: new_state.clone(),
                done: utils::should_terminate(&new_state, self.config.max_steps_per_episode),
            };

            self.update(&experience)?;

            if new_state.objective_value < self.best_objective {
                self.best_objective = new_state.objective_value;
                self.best_params = new_state.parameters.clone();
            }

            current_state = new_state;
            current_params = current_state.parameters.clone();

            if utils::should_terminate(&current_state, self.config.max_steps_per_episode)
                || matches!(action, OptimizationAction::Terminate)
            {
                break;
            }
        }

        self.training_stats.episodes_completed += 1;
        self.training_stats.total_steps += current_state.step;

        self.adjust_exploration();

        Ok(OptimizeResults::<f64> {
            x: current_params,
            fun: current_state.objective_value,
            success: current_state.convergence_metrics.relative_objective_change < 1e-6,
            nit: current_state.step,
            message: format!("A2C episode completed, total reward: {:.4}", total_reward),
            jac: None,
            hess: None,
            constr: None,
            nfev: current_state.step,
            njev: 0,
            nhev: 0,
            maxcv: 0,
            status: if current_state.convergence_metrics.relative_objective_change < 1e-6 {
                0
            } else {
                1
            },
        })
    }

    fn train<F>(
        &mut self,
        objective: &F,
        initial_params: &ArrayView1<f64>,
    ) -> OptimizeResult<OptimizeResults<f64>>
    where
        F: Fn(&ArrayView1<f64>) -> f64,
    {
        let mut best_result = OptimizeResults::<f64> {
            x: initial_params.to_owned(),
            fun: f64::INFINITY,
            success: false,
            nit: 0,
            message: "Training not completed".to_string(),
            jac: None,
            hess: None,
            constr: None,
            nfev: 0,
            njev: 0,
            nhev: 0,
            maxcv: 0,
            status: 1,
        };

        for episode in 0..self.config.num_episodes {
            let result = self.run_episode(objective, initial_params)?;

            if result.fun < best_result.fun {
                best_result = result;
            }

            if (episode + 1) % 100 == 0 {
                println!(
                    "Episode {}: Best objective = {:.6}, Avg advantage = {:.4}, Temperature = {:.4}",
                    episode + 1,
                    best_result.fun,
                    self.training_stats.avg_advantage,
                    self.temperature
                );
            }
        }

        best_result.x = self.best_params.clone();
        best_result.fun = self.best_objective;
        best_result.message = format!(
            "A2C training completed: {} episodes, {} total steps, final best = {:.6}",
            self.training_stats.episodes_completed,
            self.training_stats.total_steps,
            self.best_objective
        );

        Ok(best_result)
    }

    fn reset(&mut self) {
        self.best_objective = f64::INFINITY;
        self.best_params.fill(0.0);
        self.training_stats = A2CTrainingStats::default();
        self.temperature = 1.0;
        self.baseline = 0.0;
        self.entropy_coeff = 0.01;
    }
}

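/// Convenience wrapper: build an [`AdvantageActorCriticOptimizer`] with the
/// given (or default) configuration and hidden-layer size, then train it on
/// `objective` starting from `initial_params`.
///
/// # Example
///
/// A minimal sketch; the `use` path for this function depends on where the
/// module is exposed in the crate, so the example is not compiled here.
///
/// ```ignore
/// use scirs2_core::ndarray::{Array1, ArrayView1};
///
/// let objective = |x: &ArrayView1<f64>| x[0].powi(2) + x[1].powi(2);
/// let initial = Array1::from(vec![2.0, 2.0]);
/// let result = actor_critic_optimize(objective, &initial.view(), None, None).unwrap();
/// println!("best objective: {:.6}", result.fun);
/// ```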
#[allow(dead_code)]
pub fn actor_critic_optimize<F>(
    objective: F,
    initial_params: &ArrayView1<f64>,
    config: Option<RLOptimizationConfig>,
    hidden_size: Option<usize>,
) -> OptimizeResult<OptimizeResults<f64>>
where
    F: Fn(&ArrayView1<f64>) -> f64,
{
    let config = config.unwrap_or_default();
    let hidden_size = hidden_size.unwrap_or(64);
    // State features: the parameter vector plus a fixed number of convergence features.
    let state_size = initial_params.len() + 8;
    // Number of discrete optimization actions the policy can choose from.
    let action_size = 6;
    let mut optimizer =
        AdvantageActorCriticOptimizer::new(config, state_size, action_size, hidden_size);
    optimizer.train(&objective, initial_params)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_actor_network_creation() {
        let actor = ActorNetwork::new(10, 20, 6, ActivationType::Tanh);
        assert_eq!(actor._input_size, 10);
        assert_eq!(actor.hidden_size, 20);
        assert_eq!(actor.output_size, 6);
    }

    #[test]
    fn test_critic_network_creation() {
        let critic = CriticNetwork::new(10, 20, ActivationType::ReLU);
        assert_eq!(critic._input_size, 10);
        assert_eq!(critic.hidden_size, 20);
    }

    #[test]
    fn test_actor_forward_pass() {
        let actor = ActorNetwork::new(5, 10, 3, ActivationType::Tanh);
        let input = Array1::from(vec![0.1, 0.2, 0.3, 0.4, 0.5]);

        let (hidden_raw, hidden_activated, output) = actor.forward(&input.view());

        assert_eq!(hidden_raw.len(), 10);
        assert_eq!(hidden_activated.len(), 10);
        assert_eq!(output.len(), 3);
        assert!(output.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_critic_forward_pass() {
        let critic = CriticNetwork::new(5, 10, ActivationType::ReLU);
        let input = Array1::from(vec![0.1, 0.2, 0.3, 0.4, 0.5]);

        let (hidden_raw, hidden_activated, value) = critic.forward(&input.view());

        assert_eq!(hidden_raw.len(), 10);
        assert_eq!(hidden_activated.len(), 10);
        assert!(value.is_finite());
    }

    #[test]
    fn test_activation_functions() {
        assert!((ActivationType::Tanh.apply(0.0) - 0.0).abs() < 1e-10);
        assert!((ActivationType::ReLU.apply(-1.0) - 0.0).abs() < 1e-10);
        assert!(ActivationType::ReLU.apply(1.0) == 1.0);
        assert!((ActivationType::Sigmoid.apply(0.0) - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_action_probabilities() {
        let actor = ActorNetwork::new(3, 5, 4, ActivationType::Tanh);
        let output = Array1::from(vec![1.0, 2.0, 0.5, -1.0]);

        let probs = actor.action_probabilities(&output.view(), 1.0);

        assert_eq!(probs.len(), 4);
        assert!((probs.sum() - 1.0).abs() < 1e-6);
        assert!(probs.iter().all(|&p| p >= 0.0 && p <= 1.0));
    }

    #[test]
    fn test_a2c_optimizer_creation() {
        let config = RLOptimizationConfig::default();
        let optimizer = AdvantageActorCriticOptimizer::new(config, 10, 6, 20);

        assert_eq!(optimizer.actor._input_size, 10);
        assert_eq!(optimizer.actor.output_size, 6);
        assert_eq!(optimizer.critic._input_size, 10);
    }

    #[test]
    fn test_advantage_computation() {
        let config = RLOptimizationConfig::default();
        let optimizer = AdvantageActorCriticOptimizer::new(config, 5, 6, 10);

        let advantage = optimizer.compute_advantage(1.0, 2.0, 3.0, false);
        // Assumes the default discount factor is 0.99.
        let expected = 1.0 + 0.99 * 3.0 - 2.0;
        assert!((advantage - expected).abs() < 1e-6);
    }

    #[test]
    fn test_action_index_mapping() {
        let config = RLOptimizationConfig::default();
        let optimizer = AdvantageActorCriticOptimizer::new(config, 5, 6, 10);

        let actions = vec![
            OptimizationAction::GradientStep {
                learning_rate: 0.01,
            },
            OptimizationAction::RandomPerturbation { magnitude: 0.1 },
            OptimizationAction::MomentumUpdate { momentum: 0.9 },
            OptimizationAction::AdaptiveLearningRate { factor: 0.5 },
            OptimizationAction::ResetToBest,
            OptimizationAction::Terminate,
        ];

        for (expected_idx, action) in actions.iter().enumerate() {
            assert_eq!(optimizer.get_action_index(action), expected_idx);
        }
    }

    #[test]
    fn test_basic_a2c_optimization() {
        let config = RLOptimizationConfig {
            num_episodes: 10,
            max_steps_per_episode: 20,
            learning_rate: 0.01,
            ..Default::default()
        };

        let objective = |x: &ArrayView1<f64>| x[0].powi(2) + x[1].powi(2);
        let initial = Array1::from(vec![2.0, 2.0]);

        let result =
            actor_critic_optimize(objective, &initial.view(), Some(config), Some(16)).unwrap();

        let initial_obj = objective(&initial.view());
        assert!(result.fun <= initial_obj);
        assert!(result.nit > 0);
    }
}

#[allow(dead_code)]
pub fn placeholder() {}