scirs2_optimize/reinforcement_learning/q_learning_optimization.rs

//! Q-Learning for Optimization
//!
//! A value-based reinforcement learning approach to learning optimization strategies.

use super::{utils, OptimizationAction, OptimizationState, RLOptimizationConfig, RLOptimizer};
use crate::error::{OptimizeError, OptimizeResult};
use crate::result::OptimizeResults;
use ndarray::{Array1, ArrayView1};
use rand::Rng;
use std::collections::HashMap;

/// Q-learning optimizer for optimization problems
#[derive(Debug, Clone)]
pub struct QLearningOptimizer {
    /// Configuration
    config: RLOptimizationConfig,
    /// Q-table (simplified state-action values)
    q_table: HashMap<String, f64>,
    /// Current exploration rate
    exploration_rate: f64,
    /// Best solution found
    best_params: Array1<f64>,
    /// Best objective value
    best_objective: f64,
}

impl QLearningOptimizer {
    /// Create a new Q-learning optimizer for a problem with `num_params` parameters
    pub fn new(config: RLOptimizationConfig, num_params: usize) -> Self {
        let exploration_rate = config.exploration_rate;
        Self {
            config,
            q_table: HashMap::new(),
            exploration_rate,
            best_params: Array1::zeros(num_params),
            best_objective: f64::INFINITY,
        }
    }

    /// Convert a state-action pair to a discretized string key for the Q-table
    /// (e.g. objective 3.7 at step 25 with a gradient step maps to `"37_2__0"`)
    fn state_action_key(&self, state: &OptimizationState, action: &OptimizationAction) -> String {
        // Simplified state representation: bucket the objective value and step count
        let obj_bucket = (state.objective_value * 10.0) as i32;
        let step_bucket = state.step / 10;
        let action_id = match action {
            OptimizationAction::GradientStep { .. } => 0,
            OptimizationAction::RandomPerturbation { .. } => 1,
            OptimizationAction::MomentumUpdate { .. } => 2,
            OptimizationAction::AdaptiveLearningRate { .. } => 3,
            OptimizationAction::ResetToBest => 4,
            OptimizationAction::Terminate => 5,
        };

        format!("{}_{}__{}", obj_bucket, step_bucket, action_id)
    }

    /// Get Q-value for state-action pair
    fn get_q_value(&self, state: &OptimizationState, action: &OptimizationAction) -> f64 {
        let key = self.state_action_key(state, action);
        *self.q_table.get(&key).unwrap_or(&0.0)
    }

    /// Update Q-value for state-action pair
    fn update_q_value(
        &mut self,
        state: &OptimizationState,
        action: &OptimizationAction,
        new_value: f64,
    ) {
        let key = self.state_action_key(state, action);
        self.q_table.insert(key, new_value);
    }

    /// Get all possible actions
    fn get_possible_actions(&self) -> Vec<OptimizationAction> {
        vec![
            OptimizationAction::GradientStep {
                learning_rate: 0.01,
            },
            OptimizationAction::RandomPerturbation { magnitude: 0.1 },
            OptimizationAction::MomentumUpdate { momentum: 0.9 },
            OptimizationAction::AdaptiveLearningRate { factor: 0.5 },
            OptimizationAction::ResetToBest,
            OptimizationAction::Terminate,
        ]
    }
}

impl RLOptimizer for QLearningOptimizer {
    fn config(&self) -> &RLOptimizationConfig {
        &self.config
    }

    fn select_action(&mut self, state: &OptimizationState) -> OptimizationAction {
        // Epsilon-greedy action selection
        let mut rng = rand::rng();
        if rng.random_range(0.0..1.0) < self.exploration_rate {
            // Random action
            let actions = self.get_possible_actions();
            let idx = rng.random_range(0..actions.len());
            actions[idx].clone()
        } else {
            // Greedy action
            let actions = self.get_possible_actions();
            let mut best_action = actions[0].clone();
            let mut best_q = self.get_q_value(state, &best_action);

            for action in &actions[1..] {
                let q_value = self.get_q_value(state, action);
                if q_value > best_q {
                    best_q = q_value;
                    best_action = action.clone();
                }
            }

            best_action
        }
    }

    fn update(&mut self, experience: &super::Experience) -> Result<(), OptimizeError> {
        // Q-learning update: Q(s,a) = Q(s,a) + α[r + γ max_a' Q(s',a') - Q(s,a)],
        // where α is the configured learning rate and γ the discount factor
        let current_q = self.get_q_value(&experience.state, &experience.action);

        let max_next_q = if experience.done {
            0.0
        } else {
            let actions = self.get_possible_actions();
            actions
                .iter()
                .map(|a| self.get_q_value(&experience.next_state, a))
                .fold(f64::NEG_INFINITY, f64::max)
        };

        let target = experience.reward + self.config.discount_factor * max_next_q;
        let new_q = current_q + self.config.learning_rate * (target - current_q);

        self.update_q_value(&experience.state, &experience.action, new_q);

        Ok(())
    }

    fn run_episode<F>(
        &mut self,
        objective: &F,
        initial_params: &ArrayView1<f64>,
    ) -> OptimizeResult<OptimizeResults<f64>>
    where
        F: Fn(&ArrayView1<f64>) -> f64,
    {
        let mut current_params = initial_params.to_owned();
        let mut current_state = utils::create_state(current_params.clone(), objective, 0, None);
        let mut momentum = Array1::zeros(initial_params.len());

        for step in 0..self.config.max_steps_per_episode {
            let action = self.select_action(&current_state);
            let new_params =
                utils::apply_action(&current_state, &action, &self.best_params, &mut momentum);
            let new_state =
                utils::create_state(new_params, objective, step + 1, Some(&current_state));

            // Simple reward: improvement in objective
            let reward = current_state.objective_value - new_state.objective_value;

            let experience = super::Experience {
                state: current_state.clone(),
                action: action.clone(),
                reward,
                next_state: new_state.clone(),
                done: utils::should_terminate(&new_state, self.config.max_steps_per_episode),
            };

            self.update(&experience)?;

            if new_state.objective_value < self.best_objective {
                self.best_objective = new_state.objective_value;
                self.best_params = new_state.parameters.clone();
            }

            current_state = new_state;
            current_params = current_state.parameters.clone();

            if utils::should_terminate(&current_state, self.config.max_steps_per_episode)
                || matches!(action, OptimizationAction::Terminate)
            {
                break;
            }
        }

        // Decay exploration rate
        self.exploration_rate = (self.exploration_rate * self.config.exploration_decay)
            .max(self.config.min_exploration_rate);

        Ok(OptimizeResults::<f64> {
            x: current_params,
            fun: current_state.objective_value,
            success: current_state.convergence_metrics.relative_objective_change < 1e-6,
            nit: current_state.step,
            message: "Q-learning episode completed".to_string(),
            jac: None,
            hess: None,
            constr: None,
            nfev: current_state.step,
            njev: 0,
            nhev: 0,
            maxcv: 0,
            status: if current_state.convergence_metrics.relative_objective_change < 1e-6 {
                0
            } else {
                1
            },
        })
    }

    fn train<F>(
        &mut self,
        objective: &F,
        initial_params: &ArrayView1<f64>,
    ) -> OptimizeResult<OptimizeResults<f64>>
    where
        F: Fn(&ArrayView1<f64>) -> f64,
    {
        let mut best_result = OptimizeResults::<f64> {
            x: initial_params.to_owned(),
            fun: f64::INFINITY,
            success: false,
            nit: 0,
            nfev: 0,
            njev: 0,
            nhev: 0,
            maxcv: 0,
            status: 0,
            message: "Training not completed".to_string(),
            jac: None,
            hess: None,
            constr: None,
        };

        for _episode in 0..self.config.num_episodes {
            let result = self.run_episode(objective, initial_params)?;

            if result.fun < best_result.fun {
                best_result = result;
            }
        }

        best_result.x = self.best_params.clone();
        best_result.fun = self.best_objective;
        best_result.message = "Q-learning training completed".to_string();

        Ok(best_result)
    }

    fn reset(&mut self) {
        self.q_table.clear();
        self.exploration_rate = self.config.exploration_rate;
        self.best_objective = f64::INFINITY;
        self.best_params.fill(0.0);
    }
}

#[allow(dead_code)]
pub fn placeholder() {}
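
// A minimal usage sketch, added for illustration: it exercises `train` on a small
// convex quadratic. `RLOptimizationConfig::default()` is an assumption (this file
// only reads individual config fields); if the config type has no `Default` impl,
// construct it explicitly instead.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn q_learning_runs_on_simple_quadratic() {
        // f(x) = sum(x_i^2), minimized at the origin.
        let objective = |x: &ArrayView1<f64>| x.iter().map(|v| v * v).sum::<f64>();

        let initial = Array1::from_vec(vec![1.0, -2.0]);
        let config = RLOptimizationConfig::default(); // assumed Default impl
        let mut optimizer = QLearningOptimizer::new(config, initial.len());

        let result = optimizer
            .train(&objective, &initial.view())
            .expect("training should not error");

        // Assuming the default config runs at least one step, the tracked best
        // objective is finite; the quadratic is bounded below by zero.
        assert!(result.fun.is_finite());
        assert!(result.fun >= 0.0);
    }
}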