scirs2_optimize/reinforcement_learning/mod.rs

//! Reinforcement Learning Optimization Module
//!
//! This module implements optimization algorithms based on reinforcement learning
//! principles, where optimization strategies are learned through interaction
//! with the objective function environment.
//!
//! # Key Features
//!
//! - **Policy Gradient Optimization**: Learn optimization policies using policy gradients
//! - **Q-Learning for Optimization**: Value-based approach to optimization strategy learning
//! - **Actor-Critic Methods**: Combined policy and value learning for optimization
//! - **Bandit-based Optimization**: Multi-armed bandit approaches for hyperparameter tuning
//! - **Evolutionary Strategies**: Population-based RL optimization
//! - **Meta-Learning**: Learning to optimize across different problem classes
//!
//! # Applications
//!
//! - Automatic hyperparameter tuning
//! - Adaptive optimization algorithms
//! - Black-box optimization
//! - Neural architecture search
//! - AutoML optimization pipelines
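//!
//! # Example
//!
//! A minimal sketch of how the shared building blocks in this module fit together
//! (the concrete RL optimizers live in the submodules). The crate path
//! `scirs2_optimize` is assumed here, and the block is marked `ignore` because it
//! is illustrative rather than a guaranteed doctest.
//!
//! ```ignore
//! use ndarray::{array, Array1, ArrayView1};
//! use scirs2_optimize::reinforcement_learning::{
//!     utils, ImprovementReward, OptimizationAction, RLOptimizationConfig, RewardFunction,
//! };
//!
//! let config = RLOptimizationConfig::default();
//! let objective = |x: &ArrayView1<f64>| x.iter().map(|v| v * v).sum::<f64>();
//!
//! // Build an initial state, apply a hand-picked action, and form the next state.
//! let state = utils::create_state(array![1.0, 2.0], &objective, 0, None);
//! let best = state.parameters.clone();
//! let mut momentum = Array1::zeros(best.len());
//! let action = OptimizationAction::RandomPerturbation { magnitude: 0.1 };
//! let new_params = utils::apply_action(&state, &action, &best, &mut momentum);
//! let next_state = utils::create_state(new_params, &objective, 1, Some(&state));
//!
//! // Score the transition and check the shared termination criteria.
//! let reward = ImprovementReward::default().compute_reward(&state, &action, &next_state);
//! let done = utils::should_terminate(&next_state, config.max_steps_per_episode);
//! ```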

use crate::error::OptimizeResult;
use crate::result::OptimizeResults;
use ndarray::{Array1, ArrayView1};
use rand::Rng;

pub mod actor_critic;
pub mod bandit_optimization;
pub mod evolutionary_strategies;
pub mod meta_learning;
pub mod policy_gradient;
pub mod q_learning_optimization;

#[allow(ambiguous_glob_reexports)]
pub use actor_critic::*;
#[allow(ambiguous_glob_reexports)]
pub use bandit_optimization::*;
#[allow(ambiguous_glob_reexports)]
pub use evolutionary_strategies::*;
#[allow(ambiguous_glob_reexports)]
pub use meta_learning::*;
#[allow(ambiguous_glob_reexports)]
pub use policy_gradient::*;
#[allow(ambiguous_glob_reexports)]
pub use q_learning_optimization::*;

/// Configuration for reinforcement learning optimization
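///
/// All fields have `Default` values; override selected fields with struct-update
/// syntax. The concrete optimizers in the submodules decide how the exploration
/// fields are applied, but a typical schedule is
/// `eps = (eps * exploration_decay).max(min_exploration_rate)` after each episode.
///
/// A usage sketch:
///
/// ```ignore
/// let config = RLOptimizationConfig {
///     num_episodes: 200,
///     exploration_rate: 0.2,
///     ..Default::default()
/// };
/// ```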
#[derive(Debug, Clone)]
pub struct RLOptimizationConfig {
    /// Number of episodes for training
    pub num_episodes: usize,
    /// Maximum steps per episode
    pub max_steps_per_episode: usize,
    /// Learning rate for policy/value updates
    pub learning_rate: f64,
    /// Discount factor for future rewards
    pub discount_factor: f64,
    /// Exploration parameter (epsilon for epsilon-greedy)
    pub exploration_rate: f64,
    /// Decay rate for exploration
    pub exploration_decay: f64,
    /// Minimum exploration rate
    pub min_exploration_rate: f64,
    /// Batch size for experience replay
    pub batch_size: usize,
    /// Memory buffer size
    pub memory_size: usize,
    /// Whether to use experience replay
    pub use_experience_replay: bool,
}

impl Default for RLOptimizationConfig {
    fn default() -> Self {
        Self {
            num_episodes: 1000,
            max_steps_per_episode: 100,
            learning_rate: 0.001,
            discount_factor: 0.99,
            exploration_rate: 0.1,
            exploration_decay: 0.995,
            min_exploration_rate: 0.01,
            batch_size: 32,
            memory_size: 10000,
            use_experience_replay: true,
        }
    }
}

/// State representation for optimization RL
#[derive(Debug, Clone)]
pub struct OptimizationState {
    /// Current parameter values
    pub parameters: Array1<f64>,
    /// Current objective value
    pub objective_value: f64,
    /// Gradient information (if available)
    pub gradient: Option<Array1<f64>>,
    /// Step number in episode
    pub step: usize,
    /// History of recent objective values
    pub objective_history: Vec<f64>,
    /// Convergence indicators
    pub convergence_metrics: ConvergenceMetrics,
}

/// Convergence metrics for RL state
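///
/// `relative_objective_change` is computed in [`utils::create_state`] as
/// `|f_prev - f_new| / (|f_prev| + 1e-12)`, and is `f64::INFINITY` for the first
/// state of an episode (no previous state to compare against).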
#[derive(Debug, Clone)]
pub struct ConvergenceMetrics {
    /// Relative change in objective
    pub relative_objective_change: f64,
    /// Gradient norm (if available)
    pub gradient_norm: Option<f64>,
    /// Parameter change norm
    pub parameter_change_norm: f64,
    /// Number of steps since last improvement
    pub steps_since_improvement: usize,
}

/// Action space for optimization RL
#[derive(Debug, Clone)]
pub enum OptimizationAction {
    /// Gradient-based step with learning rate
    GradientStep { learning_rate: f64 },
    /// Random perturbation with magnitude
    RandomPerturbation { magnitude: f64 },
    /// Momentum update with coefficient
    MomentumUpdate { momentum: f64 },
    /// Adaptive learning rate adjustment
    AdaptiveLearningRate { factor: f64 },
    /// Reset to best known solution
    ResetToBest,
    /// Early termination
    Terminate,
}

/// Experience tuple for RL
#[derive(Debug, Clone)]
pub struct Experience {
    /// State before action
    pub state: OptimizationState,
    /// Action taken
    pub action: OptimizationAction,
    /// Reward received
    pub reward: f64,
    /// Next state
    pub next_state: OptimizationState,
    /// Whether episode terminated
    pub done: bool,
}

/// Trait for RL-based optimizers
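///
/// A skeletal implementation sketch. The `EpsilonGreedyOptimizer` name and its
/// action-selection rule are purely illustrative (not part of this crate), and the
/// episode/training loops are elided with `todo!()`:
///
/// ```ignore
/// // Assuming the items from this module (and `rand::Rng`) are in scope.
/// struct EpsilonGreedyOptimizer {
///     config: RLOptimizationConfig,
/// }
///
/// impl RLOptimizer for EpsilonGreedyOptimizer {
///     fn config(&self) -> &RLOptimizationConfig {
///         &self.config
///     }
///
///     fn select_action(&mut self, _state: &OptimizationState) -> OptimizationAction {
///         // Explore with probability `exploration_rate`, otherwise take a small step.
///         if rand::rng().random::<f64>() < self.config.exploration_rate {
///             OptimizationAction::RandomPerturbation { magnitude: 0.1 }
///         } else {
///             OptimizationAction::GradientStep {
///                 learning_rate: self.config.learning_rate,
///             }
///         }
///     }
///
///     fn update(&mut self, _experience: &Experience) -> OptimizeResult<()> {
///         Ok(())
///     }
///
///     fn run_episode<F>(
///         &mut self,
///         _objective: &F,
///         _initial_params: &ArrayView1<f64>,
///     ) -> OptimizeResult<OptimizeResults<f64>>
///     where
///         F: Fn(&ArrayView1<f64>) -> f64,
///     {
///         // Would loop over utils::create_state / apply_action / should_terminate.
///         todo!()
///     }
///
///     fn train<F>(
///         &mut self,
///         _objective: &F,
///         _initial_params: &ArrayView1<f64>,
///     ) -> OptimizeResult<OptimizeResults<f64>>
///     where
///         F: Fn(&ArrayView1<f64>) -> f64,
///     {
///         // Would run `config.num_episodes` episodes and keep the best result.
///         todo!()
///     }
///
///     fn reset(&mut self) {}
/// }
/// ```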
pub trait RLOptimizer {
    /// Configuration
    fn config(&self) -> &RLOptimizationConfig;

    /// Select action given current state
    fn select_action(&mut self, state: &OptimizationState) -> OptimizationAction;

    /// Update policy/value function based on experience
    fn update(&mut self, experience: &Experience) -> OptimizeResult<()>;

    /// Run optimization episode
    fn run_episode<F>(
        &mut self,
        objective: &F,
        initial_params: &ArrayView1<f64>,
    ) -> OptimizeResult<OptimizeResults<f64>>
    where
        F: Fn(&ArrayView1<f64>) -> f64;

    /// Train the RL optimizer
    fn train<F>(
        &mut self,
        objective: &F,
        initial_params: &ArrayView1<f64>,
    ) -> OptimizeResult<OptimizeResults<f64>>
    where
        F: Fn(&ArrayView1<f64>) -> f64;

    /// Reset optimizer state
    fn reset(&mut self);
}

/// Reward function for optimization RL
pub trait RewardFunction {
    /// Compute reward based on state transition
    fn compute_reward(
        &self,
        prev_state: &OptimizationState,
        action: &OptimizationAction,
        new_state: &OptimizationState,
    ) -> f64;
}

/// Simple improvement-based reward function
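///
/// The reward is `improvement_scale * (f_prev - f_new) - step_penalty`, plus
/// `convergence_bonus` when the relative objective change drops below `1e-6`.
/// With the defaults, a step that reduces the objective from 4.0 to 1.0 (and is
/// not yet converged) scores `10.0 * 3.0 - 0.01 = 29.99`.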
#[derive(Debug, Clone)]
pub struct ImprovementReward {
    /// Scaling factor for objective improvement
    pub improvement_scale: f64,
    /// Penalty for taking steps
    pub step_penalty: f64,
    /// Bonus for convergence
    pub convergence_bonus: f64,
}

impl Default for ImprovementReward {
    fn default() -> Self {
        Self {
            improvement_scale: 10.0,
            step_penalty: 0.01,
            convergence_bonus: 1.0,
        }
    }
}

impl RewardFunction for ImprovementReward {
    fn compute_reward(
        &self,
        prev_state: &OptimizationState,
        _action: &OptimizationAction,
        new_state: &OptimizationState,
    ) -> f64 {
        // Reward for objective improvement
        let improvement = prev_state.objective_value - new_state.objective_value;
        let improvement_reward = self.improvement_scale * improvement;

        // Penalty for taking steps (encourages efficiency)
        let step_penalty = -self.step_penalty;

        // Bonus for convergence
        let convergence_bonus = if new_state.convergence_metrics.relative_objective_change < 1e-6 {
            self.convergence_bonus
        } else {
            0.0
        };

        improvement_reward + step_penalty + convergence_bonus
    }
}

/// Experience replay buffer
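///
/// Acts as a fixed-capacity circular buffer: once `max_size` experiences have been
/// stored, the oldest entries are overwritten in insertion order, and
/// `sample_batch` draws uniformly at random with replacement. A usage sketch
/// (`experience` is assumed to be an [`Experience`] built elsewhere):
///
/// ```ignore
/// let mut buffer = ExperienceBuffer::new(1000);
/// buffer.add(experience);
/// let batch = buffer.sample_batch(32);
/// ```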
#[derive(Debug, Clone)]
pub struct ExperienceBuffer {
    /// Buffer for experiences
    pub buffer: Vec<Experience>,
    /// Maximum buffer size
    pub max_size: usize,
    /// Current position (for circular buffer)
    pub position: usize,
}

impl ExperienceBuffer {
    /// Create new experience buffer
    pub fn new(max_size: usize) -> Self {
        Self {
            buffer: Vec::with_capacity(max_size),
            max_size,
            position: 0,
        }
    }

    /// Add experience to buffer
    pub fn add(&mut self, experience: Experience) {
        if self.buffer.len() < self.max_size {
            self.buffer.push(experience);
        } else {
            self.buffer[self.position] = experience;
            self.position = (self.position + 1) % self.max_size;
        }
    }

    /// Sample batch of experiences
    pub fn sample_batch(&self, batch_size: usize) -> Vec<Experience> {
        let mut batch = Vec::with_capacity(batch_size);
        for _ in 0..batch_size.min(self.buffer.len()) {
            let idx = rand::rng().random_range(0..self.buffer.len());
            batch.push(self.buffer[idx].clone());
        }
        batch
    }

    /// Get buffer size
    pub fn size(&self) -> usize {
        self.buffer.len()
    }
}

/// Utility functions for RL optimization
pub mod utils {
    use super::*;

    /// Create optimization state from parameters and objective
    pub fn create_state<F>(
        parameters: Array1<f64>,
        objective: &F,
        step: usize,
        prev_state: Option<&OptimizationState>,
    ) -> OptimizationState
    where
        F: Fn(&ArrayView1<f64>) -> f64,
    {
        let objective_value = objective(&parameters.view());

        // Compute convergence metrics
        let convergence_metrics = if let Some(prev) = prev_state {
            let relative_change = (prev.objective_value - objective_value).abs()
                / (prev.objective_value.abs() + 1e-12);

            // Ensure parameter arrays have the same shape before computing difference
            let param_change = if parameters.len() == prev.parameters.len() {
                (&parameters - &prev.parameters)
                    .mapv(|x| x * x)
                    .sum()
                    .sqrt()
            } else {
                // If shapes don't match, use parameter norm as fallback
                parameters.mapv(|x| x * x).sum().sqrt()
            };
            let steps_since_improvement = if objective_value < prev.objective_value {
                0
            } else {
                prev.convergence_metrics.steps_since_improvement + 1
            };

            ConvergenceMetrics {
                relative_objective_change: relative_change,
                gradient_norm: None,
                parameter_change_norm: param_change,
                steps_since_improvement,
            }
        } else {
            ConvergenceMetrics {
                relative_objective_change: f64::INFINITY,
                gradient_norm: None,
                parameter_change_norm: 0.0,
                steps_since_improvement: 0,
            }
        };

        // Update objective history
        let mut objective_history = prev_state
            .map(|s| s.objective_history.clone())
            .unwrap_or_default();
        objective_history.push(objective_value);
        if objective_history.len() > 10 {
            objective_history.remove(0);
        }

        OptimizationState {
            parameters,
            objective_value,
            gradient: None, // Would be computed if needed
            step,
            objective_history,
            convergence_metrics,
        }
    }

    /// Apply action to current state
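    ///
    /// The `momentum` buffer is carried across calls and is re-initialized to zeros
    /// if its length does not match `state.parameters`; `best_params` is only used by
    /// `OptimizationAction::ResetToBest`. A usage sketch (assuming `state` and `best`
    /// already exist):
    ///
    /// ```ignore
    /// let mut momentum = Array1::zeros(state.parameters.len());
    /// let action = OptimizationAction::MomentumUpdate { momentum: 0.9 };
    /// let next_params = utils::apply_action(&state, &action, &best, &mut momentum);
    /// ```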
    pub fn apply_action(
        state: &OptimizationState,
        action: &OptimizationAction,
        best_params: &Array1<f64>,
        momentum: &mut Array1<f64>,
    ) -> Array1<f64> {
        match action {
            OptimizationAction::GradientStep { learning_rate } => {
                // Simplified: no true gradient is computed here; a random
                // direction scaled by the learning rate stands in for it.
                let mut new_params = state.parameters.clone();

                for i in 0..new_params.len() {
                    let step = (rand::rng().random::<f64>() - 0.5) * learning_rate;
                    new_params[i] += step;
                }

                new_params
            }
            OptimizationAction::RandomPerturbation { magnitude } => {
                let mut new_params = state.parameters.clone();
                for i in 0..new_params.len() {
                    let perturbation = (rand::rng().random::<f64>() - 0.5) * 2.0 * magnitude;
                    new_params[i] += perturbation;
                }
                new_params
            }
            OptimizationAction::MomentumUpdate {
                momentum: momentum_coeff,
            } => {
                // Ensure momentum has the same shape as parameters
                if momentum.len() != state.parameters.len() {
                    *momentum = Array1::zeros(state.parameters.len());
                }

                // Update momentum (simplified)
                for i in 0..momentum.len().min(state.parameters.len()) {
                    let gradient_estimate = (rand::rng().random::<f64>() - 0.5) * 0.1;
                    momentum[i] =
                        momentum_coeff * momentum[i] + (1.0 - momentum_coeff) * gradient_estimate;
                }

                &state.parameters - &*momentum
            }
            OptimizationAction::AdaptiveLearningRate { factor: _factor } => {
                // Adaptive step (simplified)
                let step_size = 0.01 / (1.0 + state.step as f64 * 0.01);
                let direction = Array1::from(vec![step_size; state.parameters.len()]);
                &state.parameters - &direction
            }
            OptimizationAction::ResetToBest => best_params.clone(),
            OptimizationAction::Terminate => state.parameters.clone(),
        }
    }

    /// Check if optimization should terminate
    pub fn should_terminate(state: &OptimizationState, max_steps: usize) -> bool {
        state.step >= max_steps
            || state.convergence_metrics.relative_objective_change < 1e-8
            || state.convergence_metrics.steps_since_improvement > 50
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_optimization_state_creation() {
        let params = Array1::from(vec![1.0, 2.0]);
        let objective = |x: &ArrayView1<f64>| x[0].powi(2) + x[1].powi(2);

        let state = utils::create_state(params, &objective, 0, None);

        assert_eq!(state.parameters.len(), 2);
        assert_eq!(state.objective_value, 5.0);
        assert_eq!(state.step, 0);
    }

    #[test]
    fn test_experience_buffer() {
        let mut buffer = ExperienceBuffer::new(5);

        let params = Array1::from(vec![1.0]);
        let objective = |x: &ArrayView1<f64>| x[0].powi(2);
        let state = utils::create_state(params.clone(), &objective, 0, None);

        let experience = Experience {
            state: state.clone(),
            action: OptimizationAction::GradientStep {
                learning_rate: 0.01,
            },
            reward: 1.0,
            next_state: state,
            done: false,
        };

        buffer.add(experience);
        assert_eq!(buffer.size(), 1);

        let batch = buffer.sample_batch(1);
        assert_eq!(batch.len(), 1);
    }

    #[test]
    fn test_improvement_reward() {
        let reward_fn = ImprovementReward::default();

        let params1 = Array1::from(vec![2.0]);
        let params2 = Array1::from(vec![1.0]);
        let objective = |x: &ArrayView1<f64>| x[0].powi(2);

        let state1 = utils::create_state(params1, &objective, 0, None);
        let state2 = utils::create_state(params2, &objective, 1, Some(&state1));

        let action = OptimizationAction::GradientStep { learning_rate: 0.1 };
        let reward = reward_fn.compute_reward(&state1, &action, &state2);

        // Should get positive reward for improvement (4.0 -> 1.0)
        assert!(reward > 0.0);
    }

    #[test]
    fn test_action_application() {
        let params = Array1::from(vec![1.0, 2.0]);
        let objective = |x: &ArrayView1<f64>| x[0].powi(2) + x[1].powi(2);
        let state = utils::create_state(params.clone(), &objective, 0, None);
        let mut momentum = Array1::zeros(2);

        let action = OptimizationAction::RandomPerturbation { magnitude: 0.1 };
        let new_params = utils::apply_action(&state, &action, &params, &mut momentum);

        assert_eq!(new_params.len(), 2);
        // Parameters should have changed due to perturbation
        assert!(new_params != state.parameters);
    }

    #[test]
    fn test_termination_condition() {
        let params = Array1::from(vec![1.0]);
        let objective = |x: &ArrayView1<f64>| x[0].powi(2);
        let state = utils::create_state(params, &objective, 100, None);

        // Should terminate due to max steps
        assert!(utils::should_terminate(&state, 50));
    }

    #[test]
    fn test_convergence_metrics() {
        let params1 = Array1::from(vec![2.0]);
        let params2 = Array1::from(vec![1.9]);
        let objective = |x: &ArrayView1<f64>| x[0].powi(2);

        let state1 = utils::create_state(params1, &objective, 0, None);
        let state2 = utils::create_state(params2, &objective, 1, Some(&state1));

        assert!(state2.convergence_metrics.relative_objective_change > 0.0);
        assert!(state2.convergence_metrics.parameter_change_norm > 0.0);
        assert_eq!(state2.convergence_metrics.steps_since_improvement, 0); // Improvement occurred
    }
}