scirs2_optimize/reinforcement_learning/
actor_critic.rs

1//! Actor-Critic Methods for Optimization
2//!
3//! Advanced implementation of actor-critic algorithms that combine policy learning (actor)
4//! with value function estimation (critic) for sophisticated optimization strategies.
5
6use super::{
7    utils, Experience, ExperienceBuffer, ImprovementReward, OptimizationAction, OptimizationState,
8    RLOptimizationConfig, RLOptimizer, RewardFunction,
9};
10use crate::error::{OptimizeError, OptimizeResult};
11use crate::result::OptimizeResults;
12use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
13use scirs2_core::random::{rng, Rng};
14// use std::collections::VecDeque; // Unused import
15
16/// Actor network for policy learning
17#[derive(Debug, Clone)]
18pub struct ActorNetwork {
19    /// Hidden layer weights
20    pub hidden_weights: Array2<f64>,
21    /// Hidden layer biases
22    pub hidden_bias: Array1<f64>,
23    /// Output layer weights
24    pub output_weights: Array2<f64>,
25    /// Output layer biases
26    pub output_bias: Array1<f64>,
27    /// Network architecture
28    pub _input_size: usize,
29    pub hidden_size: usize,
30    pub output_size: usize,
31    /// Activation function type
32    pub activation: ActivationType,
33}
34
35/// Critic network for value function estimation
36#[derive(Debug, Clone)]
37pub struct CriticNetwork {
38    /// Hidden layer weights
39    pub hidden_weights: Array2<f64>,
40    /// Hidden layer biases
41    pub hidden_bias: Array1<f64>,
42    /// Output layer weights (single value output)
43    pub output_weights: Array1<f64>,
44    /// Output bias
45    pub output_bias: f64,
46    /// Network architecture
47    pub _input_size: usize,
48    pub hidden_size: usize,
49    /// Activation function type
50    pub activation: ActivationType,
51}
52
/// Element-wise nonlinearities available to the actor and critic networks.
#[derive(Debug, Clone, Copy)]
pub enum ActivationType {
    /// Hyperbolic tangent, `tanh(x)`.
    Tanh,
    /// Rectified linear unit, `max(x, 0)`.
    ReLU,
    /// Logistic sigmoid, `1 / (1 + e^-x)`.
    Sigmoid,
    /// Leaky ReLU with a fixed negative-side slope of `0.01`.
    LeakyReLU,
    /// Exponential linear unit with `alpha = 1`: `x` for `x > 0`, else `e^x - 1`.
    ELU,
}
62
63impl ActivationType {
64    fn apply(&self, x: f64) -> f64 {
65        match self {
66            ActivationType::Tanh => x.tanh(),
67            ActivationType::ReLU => x.max(0.0),
68            ActivationType::Sigmoid => 1.0 / (1.0 + (-x).exp()),
69            ActivationType::LeakyReLU => {
70                if x > 0.0 {
71                    x
72                } else {
73                    0.01 * x
74                }
75            }
76            ActivationType::ELU => {
77                if x > 0.0 {
78                    x
79                } else {
80                    x.exp() - 1.0
81                }
82            }
83        }
84    }
85
86    fn derivative(&self, x: f64) -> f64 {
87        match self {
88            ActivationType::Tanh => {
89                let t = x.tanh();
90                1.0 - t * t
91            }
92            ActivationType::ReLU => {
93                if x > 0.0 {
94                    1.0
95                } else {
96                    0.0
97                }
98            }
99            ActivationType::Sigmoid => {
100                let s = 1.0 / (1.0 + (-x).exp());
101                s * (1.0 - s)
102            }
103            ActivationType::LeakyReLU => {
104                if x > 0.0 {
105                    1.0
106                } else {
107                    0.01
108                }
109            }
110            ActivationType::ELU => {
111                if x > 0.0 {
112                    1.0
113                } else {
114                    x.exp()
115                }
116            }
117        }
118    }
119}
120
121impl ActorNetwork {
122    /// Create new actor network
123    pub fn new(
124        input_size: usize,
125        hidden_size: usize,
126        output_size: usize,
127        activation: ActivationType,
128    ) -> Self {
129        let xavier_scale = (2.0 / (input_size + hidden_size) as f64).sqrt();
130
131        Self {
132            hidden_weights: Array2::from_shape_fn((hidden_size, input_size), |_| {
133                (scirs2_core::random::rng().random::<f64>() - 0.5) * 2.0 * xavier_scale
134            }),
135            hidden_bias: Array1::zeros(hidden_size),
136            output_weights: Array2::from_shape_fn((output_size, hidden_size), |_| {
137                (scirs2_core::random::rng().random::<f64>() - 0.5) * 2.0 * xavier_scale
138            }),
139            output_bias: Array1::zeros(output_size),
140            _input_size: input_size,
141            hidden_size,
142            output_size,
143            activation,
144        }
145    }
146
147    /// Forward pass through actor network
148    pub fn forward(&self, input: &ArrayView1<f64>) -> (Array1<f64>, Array1<f64>, Array1<f64>) {
149        // Hidden layer
150        let mut hidden_raw = Array1::zeros(self.hidden_size);
151        for i in 0..self.hidden_size {
152            for j in 0..self._input_size.min(input.len()) {
153                hidden_raw[i] += self.hidden_weights[[i, j]] * input[j];
154            }
155            hidden_raw[i] += self.hidden_bias[i];
156        }
157
158        let hidden_activated = hidden_raw.mapv(|x| self.activation.apply(x));
159
160        // Output layer
161        let mut output_raw = Array1::zeros(self.output_size);
162        for i in 0..self.output_size {
163            for j in 0..self.hidden_size {
164                output_raw[i] += self.output_weights[[i, j]] * hidden_activated[j];
165            }
166            output_raw[i] += self.output_bias[i];
167        }
168
169        let output_activated = output_raw.mapv(|x| self.activation.apply(x));
170
171        (hidden_raw, hidden_activated, output_activated)
172    }
173
174    /// Compute action probabilities with temperature scaling
175    pub fn action_probabilities(
176        &self,
177        policy_output: &ArrayView1<f64>,
178        temperature: f64,
179    ) -> Array1<f64> {
180        let scaled_output = policy_output.mapv(|x| x / temperature);
181        let max_val = scaled_output.fold(-f64::INFINITY, |a, &b| a.max(b));
182        let exp_output = scaled_output.mapv(|x| (x - max_val).exp());
183        let sum_exp = exp_output.sum();
184
185        if sum_exp > 0.0 {
186            exp_output / sum_exp
187        } else {
188            Array1::from_elem(policy_output.len(), 1.0 / policy_output.len() as f64)
189        }
190    }
191
192    /// Backward pass and update weights
193    pub fn backward_and_update(
194        &mut self,
195        input: &ArrayView1<f64>,
196        hidden_raw: &Array1<f64>,
197        hidden_activated: &Array1<f64>,
198        output_gradient: &ArrayView1<f64>,
199        learning_rate: f64,
200    ) {
201        // Output layer gradients
202        let output_raw_gradient = output_gradient.mapv(|g| g); // Assuming linear output
203
204        // Hidden layer gradients
205        let mut hidden_gradient: Array1<f64> = Array1::zeros(self.hidden_size);
206        for j in 0..self.hidden_size {
207            for i in 0..self.output_size {
208                hidden_gradient[j] += output_raw_gradient[i] * self.output_weights[[i, j]];
209            }
210            hidden_gradient[j] *= self.activation.derivative(hidden_raw[j]);
211        }
212
213        // Update output layer weights and biases
214        for i in 0..self.output_size {
215            for j in 0..self.hidden_size {
216                self.output_weights[[i, j]] -=
217                    learning_rate * output_raw_gradient[i] * hidden_activated[j];
218            }
219            self.output_bias[i] -= learning_rate * output_raw_gradient[i];
220        }
221
222        // Update hidden layer weights and biases
223        for i in 0..self.hidden_size {
224            for j in 0..self._input_size.min(input.len()) {
225                self.hidden_weights[[i, j]] -= learning_rate * hidden_gradient[i] * input[j];
226            }
227            self.hidden_bias[i] -= learning_rate * hidden_gradient[i];
228        }
229    }
230}
231
232impl CriticNetwork {
233    /// Create new critic network
234    pub fn new(input_size: usize, hidden_size: usize, activation: ActivationType) -> Self {
235        let xavier_scale = (2.0 / (input_size + hidden_size) as f64).sqrt();
236
237        Self {
238            hidden_weights: Array2::from_shape_fn((hidden_size, input_size), |_| {
239                (scirs2_core::random::rng().random::<f64>() - 0.5) * 2.0 * xavier_scale
240            }),
241            hidden_bias: Array1::zeros(hidden_size),
242            output_weights: Array1::from_shape_fn(hidden_size, |_| {
243                (scirs2_core::random::rng().random::<f64>() - 0.5) * 2.0 * xavier_scale
244            }),
245            output_bias: 0.0,
246            _input_size: input_size,
247            hidden_size,
248            activation,
249        }
250    }
251
252    /// Forward pass through critic network
253    pub fn forward(&self, input: &ArrayView1<f64>) -> (Array1<f64>, Array1<f64>, f64) {
254        // Hidden layer
255        let mut hidden_raw = Array1::zeros(self.hidden_size);
256        for i in 0..self.hidden_size {
257            for j in 0..self._input_size.min(input.len()) {
258                hidden_raw[i] += self.hidden_weights[[i, j]] * input[j];
259            }
260            hidden_raw[i] += self.hidden_bias[i];
261        }
262
263        let hidden_activated = hidden_raw.mapv(|x| self.activation.apply(x));
264
265        // Output layer (single value)
266        let mut output = 0.0;
267        for j in 0..self.hidden_size {
268            output += self.output_weights[j] * hidden_activated[j];
269        }
270        output += self.output_bias;
271
272        (hidden_raw, hidden_activated, output)
273    }
274
275    /// Backward pass and update weights
276    pub fn backward_and_update(
277        &mut self,
278        input: &ArrayView1<f64>,
279        hidden_raw: &Array1<f64>,
280        hidden_activated: &Array1<f64>,
281        target_value: f64,
282        predicted_value: f64,
283        learning_rate: f64,
284    ) {
285        let value_error = target_value - predicted_value;
286
287        // Hidden layer gradients
288        let mut hidden_gradient: Array1<f64> = Array1::zeros(self.hidden_size);
289        for j in 0..self.hidden_size {
290            hidden_gradient[j] =
291                value_error * self.output_weights[j] * self.activation.derivative(hidden_raw[j]);
292        }
293
294        // Update output layer weights and bias
295        for j in 0..self.hidden_size {
296            self.output_weights[j] += learning_rate * value_error * hidden_activated[j];
297        }
298        self.output_bias += learning_rate * value_error;
299
300        // Update hidden layer weights and biases
301        for i in 0..self.hidden_size {
302            for j in 0..self._input_size.min(input.len()) {
303                self.hidden_weights[[i, j]] += learning_rate * hidden_gradient[i] * input[j];
304            }
305            self.hidden_bias[i] += learning_rate * hidden_gradient[i];
306        }
307    }
308}
309
310/// Advantage Actor-Critic (A2C) optimizer
311#[derive(Debug, Clone)]
312pub struct AdvantageActorCriticOptimizer {
313    /// Configuration
314    config: RLOptimizationConfig,
315    /// Actor network
316    actor: ActorNetwork,
317    /// Critic network
318    critic: CriticNetwork,
319    /// Experience buffer
320    experience_buffer: ExperienceBuffer,
321    /// Reward function
322    reward_function: ImprovementReward,
323    /// Best solution found
324    best_params: Array1<f64>,
325    /// Best objective value
326    best_objective: f64,
327    /// Temperature for exploration
328    temperature: f64,
329    /// Baseline for variance reduction
330    baseline: f64,
331    /// Training statistics
332    training_stats: A2CTrainingStats,
333    /// Entropy coefficient for exploration
334    entropy_coeff: f64,
335    /// Value function coefficient
336    value_coeff: f64,
337}
338
/// Diagnostics gathered while training an A2C optimizer.
///
/// The `avg_*` fields are exponential moving averages (weight 0.1 on each new
/// batch mean), and the counters are cumulative.
#[derive(Debug, Clone)]
pub struct A2CTrainingStats {
    /// Smoothed policy (actor) loss.
    pub avg_actor_loss: f64,
    /// Smoothed squared TD error of the critic.
    pub avg_critic_loss: f64,
    /// Smoothed advantage estimate.
    pub avg_advantage: f64,
    /// Smoothed entropy of the action distribution.
    pub avg_entropy: f64,
    /// Number of finished episodes.
    pub episodes_completed: usize,
    /// Total optimization steps taken over all episodes.
    pub total_steps: usize,
}
355
356impl Default for A2CTrainingStats {
357    fn default() -> Self {
358        Self {
359            avg_actor_loss: 0.0,
360            avg_critic_loss: 0.0,
361            avg_advantage: 0.0,
362            avg_entropy: 0.0,
363            episodes_completed: 0,
364            total_steps: 0,
365        }
366    }
367}
368
369impl AdvantageActorCriticOptimizer {
370    /// Create new A2C optimizer
371    pub fn new(
372        config: RLOptimizationConfig,
373        state_size: usize,
374        action_size: usize,
375        hidden_size: usize,
376    ) -> Self {
377        let memory_size = config.memory_size;
378        Self {
379            config,
380            actor: ActorNetwork::new(state_size, hidden_size, action_size, ActivationType::Tanh),
381            critic: CriticNetwork::new(state_size, hidden_size, ActivationType::ReLU),
382            experience_buffer: ExperienceBuffer::new(memory_size),
383            reward_function: ImprovementReward::default(),
384            best_params: Array1::zeros(state_size),
385            best_objective: f64::INFINITY,
386            temperature: 1.0,
387            baseline: 0.0,
388            training_stats: A2CTrainingStats::default(),
389            entropy_coeff: 0.01,
390            value_coeff: 0.5,
391        }
392    }
393
394    /// Extract state features for networks
395    fn extract_state_features(&self, state: &OptimizationState) -> Array1<f64> {
396        let mut features = Vec::new();
397
398        // Parameter values (normalized)
399        for &param in state.parameters.iter() {
400            features.push(param.tanh());
401        }
402
403        // Objective value (log-normalized)
404        let log_obj = (state.objective_value.abs() + 1e-8).ln();
405        features.push(log_obj.tanh());
406
407        // Convergence metrics
408        features.push(
409            state
410                .convergence_metrics
411                .relative_objective_change
412                .ln()
413                .tanh(),
414        );
415        features.push(state.convergence_metrics.parameter_change_norm.tanh());
416        features.push((state.convergence_metrics.steps_since_improvement as f64 / 10.0).tanh());
417
418        // History features
419        if state.objective_history.len() >= 2 {
420            let recent_change = state.objective_history[state.objective_history.len() - 1]
421                - state.objective_history[state.objective_history.len() - 2];
422            features.push(recent_change.tanh());
423
424            let trend = if state.objective_history.len() >= 3 {
425                let slope = (state.objective_history[state.objective_history.len() - 1]
426                    - state.objective_history[0])
427                    / state.objective_history.len() as f64;
428                slope.tanh()
429            } else {
430                0.0
431            };
432            features.push(trend);
433        } else {
434            features.push(0.0);
435            features.push(0.0);
436        }
437
438        // Step information
439        features.push((state.step as f64 / self.config.max_steps_per_episode as f64).tanh());
440
441        Array1::from(features)
442    }
443
444    /// Select action using actor network with exploration
445    fn select_action_with_exploration(
446        &mut self,
447        state: &OptimizationState,
448    ) -> (OptimizationAction, Array1<f64>) {
449        let state_features = self.extract_state_features(state);
450        let (_, _, policy_output) = self.actor.forward(&state_features.view());
451
452        // Add exploration noise
453        let exploration_noise = if self.training_stats.episodes_completed
454            < self.config.num_episodes / 2
455        {
456            0.1 * (1.0
457                - self.training_stats.episodes_completed as f64 / self.config.num_episodes as f64)
458        } else {
459            0.01
460        };
461
462        let noisy_output = policy_output
463            .mapv(|x| x + (scirs2_core::random::rng().random::<f64>() - 0.5) * exploration_noise);
464        let action_probs = self
465            .actor
466            .action_probabilities(&noisy_output.view(), self.temperature);
467
468        // Sample action based on probabilities
469        let cumulative_probs: Vec<f64> = action_probs
470            .iter()
471            .scan(0.0, |acc, &p| {
472                *acc += p;
473                Some(*acc)
474            })
475            .collect();
476
477        let rand_val = scirs2_core::random::rng().random::<f64>();
478        let action_idx = cumulative_probs
479            .iter()
480            .position(|&cp| rand_val <= cp)
481            .unwrap_or(action_probs.len() - 1);
482
483        let action = self.decode_action_from_index(action_idx, &noisy_output);
484
485        (action, action_probs)
486    }
487
488    /// Decode action from action index and policy output
489    fn decode_action_from_index(
490        &self,
491        action_idx: usize,
492        policy_output: &Array1<f64>,
493    ) -> OptimizationAction {
494        let magnitude_factor = 1.0 + policy_output.get(1).unwrap_or(&0.0).abs();
495
496        match action_idx {
497            0 => OptimizationAction::GradientStep {
498                learning_rate: 0.001 * magnitude_factor,
499            },
500            1 => OptimizationAction::RandomPerturbation {
501                magnitude: 0.01 * magnitude_factor,
502            },
503            2 => OptimizationAction::MomentumUpdate {
504                momentum: 0.9 * (1.0 + policy_output.get(2).unwrap_or(&0.0) * 0.1),
505            },
506            3 => OptimizationAction::AdaptiveLearningRate {
507                factor: 0.5 + 0.5 * policy_output.get(3).unwrap_or(&0.0).tanh(),
508            },
509            4 => OptimizationAction::ResetToBest,
510            _ => OptimizationAction::Terminate,
511        }
512    }
513
514    /// Compute advantage function
515    fn compute_advantage(
516        &self,
517        reward: f64,
518        current_value: f64,
519        next_value: f64,
520        done: bool,
521    ) -> f64 {
522        let target = if done {
523            reward
524        } else {
525            reward + self.config.discount_factor * next_value
526        };
527        target - current_value
528    }
529
530    /// Update actor and critic networks using A2C algorithm
531    fn update_networks(&mut self, experiences: &[Experience]) -> Result<(), OptimizeError> {
532        let mut total_actor_loss = 0.0;
533        let mut total_critic_loss = 0.0;
534        let mut total_advantage = 0.0;
535        let mut total_entropy = 0.0;
536
537        for experience in experiences {
538            let state_features = self.extract_state_features(&experience.state);
539            let next_state_features = self.extract_state_features(&experience.next_state);
540
541            // Forward pass through critic for current and next state
542            let (hidden_raw, hidden_activated, current_value) =
543                self.critic.forward(&state_features.view());
544            let (_, _, next_value) = self.critic.forward(&next_state_features.view());
545
546            // Compute advantage
547            let advantage = self.compute_advantage(
548                experience.reward,
549                current_value,
550                next_value,
551                experience.done,
552            );
553
554            // Compute target value for critic
555            let target_value = if experience.done {
556                experience.reward
557            } else {
558                experience.reward + self.config.discount_factor * next_value
559            };
560
561            // Update critic network
562            self.critic.backward_and_update(
563                &state_features.view(),
564                &hidden_raw,
565                &hidden_activated,
566                target_value,
567                current_value,
568                self.config.learning_rate * self.value_coeff,
569            );
570
571            // Forward pass through actor
572            let (actor_hidden_raw, actor_hidden_activated, policy_output) =
573                self.actor.forward(&state_features.view());
574
575            let action_probs = self
576                .actor
577                .action_probabilities(&policy_output.view(), self.temperature);
578
579            // Compute entropy for exploration
580            let entropy = -action_probs
581                .iter()
582                .filter(|&&p| p > 1e-8)
583                .map(|&p| p * p.ln())
584                .sum::<f64>();
585
586            // Compute policy gradient (simplified REINFORCE with baseline)
587            let action_idx = self.get_action_index(&experience.action);
588            let log_prob = action_probs.get(action_idx).unwrap_or(&1e-8).ln();
589            let policy_loss = -log_prob * (advantage - self.baseline);
590
591            // Actor gradient (simplified)
592            let mut actor_gradient = Array1::zeros(policy_output.len());
593            if action_idx < actor_gradient.len() {
594                actor_gradient[action_idx] =
595                    -(advantage - self.baseline) / (action_probs[action_idx] + 1e-8);
596                // Add entropy bonus
597                actor_gradient[action_idx] += self.entropy_coeff * (1.0 + log_prob);
598            }
599
600            // Update actor network
601            self.actor.backward_and_update(
602                &state_features.view(),
603                &actor_hidden_raw,
604                &actor_hidden_activated,
605                &actor_gradient.view(),
606                self.config.learning_rate,
607            );
608
609            // Update statistics
610            total_actor_loss += policy_loss;
611            total_critic_loss += (target_value - current_value).powi(2);
612            total_advantage += advantage;
613            total_entropy += entropy;
614        }
615
616        // Update baseline
617        if !experiences.is_empty() {
618            self.baseline =
619                0.9 * self.baseline + 0.1 * (total_advantage / experiences.len() as f64);
620
621            // Update training statistics
622            let num_exp = experiences.len() as f64;
623            self.training_stats.avg_actor_loss =
624                0.9 * self.training_stats.avg_actor_loss + 0.1 * (total_actor_loss / num_exp);
625            self.training_stats.avg_critic_loss =
626                0.9 * self.training_stats.avg_critic_loss + 0.1 * (total_critic_loss / num_exp);
627            self.training_stats.avg_advantage =
628                0.9 * self.training_stats.avg_advantage + 0.1 * (total_advantage / num_exp);
629            self.training_stats.avg_entropy =
630                0.9 * self.training_stats.avg_entropy + 0.1 * (total_entropy / num_exp);
631        }
632
633        Ok(())
634    }
635
636    /// Get action index for gradient computation
637    fn get_action_index(&self, action: &OptimizationAction) -> usize {
638        match action {
639            OptimizationAction::GradientStep { .. } => 0,
640            OptimizationAction::RandomPerturbation { .. } => 1,
641            OptimizationAction::MomentumUpdate { .. } => 2,
642            OptimizationAction::AdaptiveLearningRate { .. } => 3,
643            OptimizationAction::ResetToBest => 4,
644            OptimizationAction::Terminate => 5,
645        }
646    }
647
648    /// Get training statistics
649    pub fn get_training_stats(&self) -> &A2CTrainingStats {
650        &self.training_stats
651    }
652
653    /// Adjust exploration parameters
654    fn adjust_exploration(&mut self) {
655        // Decay temperature for exploration
656        self.temperature = (self.temperature * 0.999).max(0.1);
657
658        // Adjust entropy coefficient
659        self.entropy_coeff = (self.entropy_coeff * 0.9995).max(0.001);
660    }
661}
662
663impl RLOptimizer for AdvantageActorCriticOptimizer {
664    fn config(&self) -> &RLOptimizationConfig {
665        &self.config
666    }
667
668    fn select_action(&mut self, state: &OptimizationState) -> OptimizationAction {
669        let (action, _) = self.select_action_with_exploration(state);
670        action
671    }
672
673    fn update(&mut self, experience: &Experience) -> Result<(), OptimizeError> {
674        self.experience_buffer.add(experience.clone());
675
676        // Update networks when we have enough experiences
677        if self.experience_buffer.size() >= self.config.batch_size {
678            let batch = self.experience_buffer.sample_batch(self.config.batch_size);
679            self.update_networks(&batch)?;
680        }
681
682        Ok(())
683    }
684
685    fn run_episode<F>(
686        &mut self,
687        objective: &F,
688        initial_params: &ArrayView1<f64>,
689    ) -> OptimizeResult<OptimizeResults<f64>>
690    where
691        F: Fn(&ArrayView1<f64>) -> f64,
692    {
693        let mut current_params = initial_params.to_owned();
694        let mut current_state = utils::create_state(current_params.clone(), objective, 0, None);
695        let mut momentum = Array1::zeros(initial_params.len());
696        let mut total_reward = 0.0;
697
698        for step in 0..self.config.max_steps_per_episode {
699            // Select action
700            let (action, _) = self.select_action_with_exploration(&current_state);
701
702            // Apply action
703            let new_params =
704                utils::apply_action(&current_state, &action, &self.best_params, &mut momentum);
705            let new_state =
706                utils::create_state(new_params, objective, step + 1, Some(&current_state));
707
708            // Compute reward
709            let reward = self
710                .reward_function
711                .compute_reward(&current_state, &action, &new_state);
712            total_reward += reward;
713
714            // Create experience
715            let experience = Experience {
716                state: current_state.clone(),
717                action: action.clone(),
718                reward,
719                next_state: new_state.clone(),
720                done: utils::should_terminate(&new_state, self.config.max_steps_per_episode),
721            };
722
723            // Update networks
724            self.update(&experience)?;
725
726            // Update best solution
727            if new_state.objective_value < self.best_objective {
728                self.best_objective = new_state.objective_value;
729                self.best_params = new_state.parameters.clone();
730            }
731
732            current_state = new_state;
733            current_params = current_state.parameters.clone();
734
735            // Check termination
736            if utils::should_terminate(&current_state, self.config.max_steps_per_episode)
737                || matches!(action, OptimizationAction::Terminate)
738            {
739                break;
740            }
741        }
742
743        self.training_stats.episodes_completed += 1;
744        self.training_stats.total_steps += current_state.step;
745
746        // Adjust exploration parameters
747        self.adjust_exploration();
748
749        Ok(OptimizeResults::<f64> {
750            x: current_params,
751            fun: current_state.objective_value,
752            success: current_state.convergence_metrics.relative_objective_change < 1e-6,
753            nit: current_state.step,
754            message: format!("A2C episode completed, total reward: {:.4}", total_reward),
755            jac: None,
756            hess: None,
757            constr: None,
758            nfev: current_state.step,
759            njev: 0,
760            nhev: 0,
761            maxcv: 0,
762            status: if current_state.convergence_metrics.relative_objective_change < 1e-6 {
763                0
764            } else {
765                1
766            },
767        })
768    }
769
770    fn train<F>(
771        &mut self,
772        objective: &F,
773        initial_params: &ArrayView1<f64>,
774    ) -> OptimizeResult<OptimizeResults<f64>>
775    where
776        F: Fn(&ArrayView1<f64>) -> f64,
777    {
778        let mut best_result = OptimizeResults::<f64> {
779            x: initial_params.to_owned(),
780            fun: f64::INFINITY,
781            success: false,
782            nit: 0,
783            message: "Training not completed".to_string(),
784            jac: None,
785            hess: None,
786            constr: None,
787            nfev: 0,
788            njev: 0,
789            nhev: 0,
790            maxcv: 0,
791            status: 1, // Failure status by default
792        };
793
794        for episode in 0..self.config.num_episodes {
795            let result = self.run_episode(objective, initial_params)?;
796
797            if result.fun < best_result.fun {
798                best_result = result;
799            }
800
801            // Periodic logging (every 100 episodes)
802            if (episode + 1) % 100 == 0 {
803                println!("Episode {}: Best objective = {:.6}, Avg advantage = {:.4}, Temperature = {:.4}",
804                    episode + 1, best_result.fun, self.training_stats.avg_advantage, self.temperature);
805            }
806        }
807
808        best_result.x = self.best_params.clone();
809        best_result.fun = self.best_objective;
810        best_result.message = format!(
811            "A2C training completed: {} episodes, {} total steps, final best = {:.6}",
812            self.training_stats.episodes_completed,
813            self.training_stats.total_steps,
814            self.best_objective
815        );
816
817        Ok(best_result)
818    }
819
820    fn reset(&mut self) {
821        self.best_objective = f64::INFINITY;
822        self.best_params.fill(0.0);
823        self.training_stats = A2CTrainingStats::default();
824        self.temperature = 1.0;
825        self.baseline = 0.0;
826        self.entropy_coeff = 0.01;
827    }
828}
829
830/// Convenience function for Actor-Critic optimization
831#[allow(dead_code)]
832pub fn actor_critic_optimize<F>(
833    objective: F,
834    initial_params: &ArrayView1<f64>,
835    config: Option<RLOptimizationConfig>,
836    hidden_size: Option<usize>,
837) -> OptimizeResult<OptimizeResults<f64>>
838where
839    F: Fn(&ArrayView1<f64>) -> f64,
840{
841    let config = config.unwrap_or_default();
842    let hidden_size = hidden_size.unwrap_or(64);
843    let state_size = initial_params.len() + 8; // Additional features
844    let action_size = 6; // Number of different action types
845
846    let mut optimizer =
847        AdvantageActorCriticOptimizer::new(config, state_size, action_size, hidden_size);
848    optimizer.train(&objective, initial_params)
849}
850
#[cfg(test)]
mod tests {
    //! Unit tests for the actor/critic networks and the A2C optimizer.

    use super::*;

    #[test]
    fn test_actor_network_creation() {
        let actor = ActorNetwork::new(10, 20, 6, ActivationType::Tanh);
        assert_eq!(actor._input_size, 10);
        assert_eq!(actor.hidden_size, 20);
        assert_eq!(actor.output_size, 6);
    }

    #[test]
    fn test_critic_network_creation() {
        let critic = CriticNetwork::new(10, 20, ActivationType::ReLU);
        assert_eq!(critic._input_size, 10);
        assert_eq!(critic.hidden_size, 20);
    }

    #[test]
    fn test_actor_forward_pass() {
        let actor = ActorNetwork::new(5, 10, 3, ActivationType::Tanh);
        let input = Array1::from(vec![0.1, 0.2, 0.3, 0.4, 0.5]);

        let (hidden_raw, hidden_activated, output) = actor.forward(&input.view());

        assert_eq!(hidden_raw.len(), 10);
        assert_eq!(hidden_activated.len(), 10);
        assert_eq!(output.len(), 3);
        assert!(output.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_critic_forward_pass() {
        let critic = CriticNetwork::new(5, 10, ActivationType::ReLU);
        let input = Array1::from(vec![0.1, 0.2, 0.3, 0.4, 0.5]);

        let (hidden_raw, hidden_activated, value) = critic.forward(&input.view());

        assert_eq!(hidden_raw.len(), 10);
        assert_eq!(hidden_activated.len(), 10);
        assert!(value.is_finite());
    }

    #[test]
    fn test_activation_functions() {
        assert!((ActivationType::Tanh.apply(0.0) - 0.0).abs() < 1e-10);
        assert!((ActivationType::ReLU.apply(-1.0) - 0.0).abs() < 1e-10);
        assert_eq!(ActivationType::ReLU.apply(1.0), 1.0);
        assert!((ActivationType::Sigmoid.apply(0.0) - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_action_probabilities() {
        let actor = ActorNetwork::new(3, 5, 4, ActivationType::Tanh);
        let output = Array1::from(vec![1.0, 2.0, 0.5, -1.0]);

        let probs = actor.action_probabilities(&output.view(), 1.0);

        assert_eq!(probs.len(), 4);
        assert!((probs.sum() - 1.0).abs() < 1e-6);
        assert!(probs.iter().all(|p| (0.0..=1.0).contains(p)));
    }

    #[test]
    fn test_a2c_optimizer_creation() {
        let config = RLOptimizationConfig::default();
        let optimizer = AdvantageActorCriticOptimizer::new(config, 10, 6, 20);

        assert_eq!(optimizer.actor._input_size, 10);
        assert_eq!(optimizer.actor.output_size, 6);
        assert_eq!(optimizer.critic._input_size, 10);
    }

    #[test]
    fn test_advantage_computation() {
        let config = RLOptimizationConfig::default();
        let optimizer = AdvantageActorCriticOptimizer::new(config, 5, 6, 10);

        let advantage = optimizer.compute_advantage(1.0, 2.0, 3.0, false);
        // NOTE(review): 0.99 assumes the default config's discount factor is 0.99;
        // update this if RLOptimizationConfig::default() changes.
        let expected = 1.0 + 0.99 * 3.0 - 2.0; // reward + gamma * next_value - current_value

        assert!((advantage - expected).abs() < 1e-6);
    }

    #[test]
    fn test_action_index_mapping() {
        let config = RLOptimizationConfig::default();
        let optimizer = AdvantageActorCriticOptimizer::new(config, 5, 6, 10);

        // A fixed-size array suffices; no heap allocation is needed just to iterate.
        let actions = [
            OptimizationAction::GradientStep {
                learning_rate: 0.01,
            },
            OptimizationAction::RandomPerturbation { magnitude: 0.1 },
            OptimizationAction::MomentumUpdate { momentum: 0.9 },
            OptimizationAction::AdaptiveLearningRate { factor: 0.5 },
            OptimizationAction::ResetToBest,
            OptimizationAction::Terminate,
        ];

        for (expected_idx, action) in actions.iter().enumerate() {
            assert_eq!(optimizer.get_action_index(action), expected_idx);
        }
    }

    #[test]
    fn test_reset_restores_initial_state() {
        let config = RLOptimizationConfig::default();
        let mut optimizer = AdvantageActorCriticOptimizer::new(config, 5, 6, 10);

        // Perturb every field that `reset` is responsible for.
        optimizer.best_objective = 1.0;
        optimizer.temperature = 0.3;
        optimizer.baseline = 2.0;
        optimizer.entropy_coeff = 0.5;
        optimizer.training_stats.episodes_completed = 3;
        optimizer.best_params.fill(7.0);

        optimizer.reset();

        assert!(optimizer.best_objective.is_infinite());
        assert_eq!(optimizer.temperature, 1.0);
        assert_eq!(optimizer.baseline, 0.0);
        assert_eq!(optimizer.entropy_coeff, 0.01);
        assert_eq!(optimizer.training_stats.episodes_completed, 0);
        assert!(optimizer.best_params.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_basic_a2c_optimization() {
        let config = RLOptimizationConfig {
            num_episodes: 10,
            max_steps_per_episode: 20,
            learning_rate: 0.01,
            ..Default::default()
        };

        let objective = |x: &ArrayView1<f64>| x[0].powi(2) + x[1].powi(2);
        let initial = Array1::from(vec![2.0, 2.0]);

        let result = actor_critic_optimize(objective, &initial.view(), Some(config), Some(16))
            .expect("A2C optimization on a 2-D quadratic should not fail");

        // Should make some progress
        let initial_obj = objective(&initial.view());
        assert!(result.fun <= initial_obj);
        assert!(result.nit > 0);
    }
}
978
/// No-op stub kept so the module is never reported as unused.
///
/// It takes no arguments, does no work, and has no side effects.
#[allow(dead_code)]
pub fn placeholder() {}