scirs2_optimize/reinforcement_learning/q_learning_optimization.rs

//! Q-learning based optimizer: treats optimization as a sequential decision
//! process in which a tabular Q-function selects the next optimization action
//! (gradient step, random perturbation, momentum update, and so on).

use super::{utils, OptimizationAction, OptimizationState, RLOptimizationConfig, RLOptimizer};
use crate::error::{OptimizeError, OptimizeResult};
use crate::result::OptimizeResults;
use ndarray::{Array1, ArrayView1};
use rand::{rng, Rng};
use std::collections::HashMap;

/// Tabular Q-learning optimizer over a discretized (state, action) space.
#[derive(Debug, Clone)]
pub struct QLearningOptimizer {
    /// Shared RL optimization settings.
    config: RLOptimizationConfig,
    /// Q-table keyed by a discretized state-action string.
    q_table: HashMap<String, f64>,
    /// Current epsilon for epsilon-greedy action selection.
    exploration_rate: f64,
    /// Best parameters observed so far across all episodes.
    best_params: Array1<f64>,
    /// Best objective value observed so far across all episodes.
    best_objective: f64,
}

impl QLearningOptimizer {
    /// Creates a new optimizer for a problem with `numparams` parameters.
    pub fn new(config: RLOptimizationConfig, numparams: usize) -> Self {
        let exploration_rate = config.exploration_rate;
        Self {
            config,
            q_table: HashMap::new(),
            exploration_rate,
            best_params: Array1::zeros(numparams),
            best_objective: f64::INFINITY,
        }
    }
    /// Discretizes a (state, action) pair into a string key for the Q-table.
    fn state_action_key(&self, state: &OptimizationState, action: &OptimizationAction) -> String {
        // Bucket the objective value and step count so similar states share entries.
        let obj_bucket = (state.objective_value * 10.0) as i32;
        let step_bucket = state.step / 10;
        let action_id = match action {
            OptimizationAction::GradientStep { .. } => 0,
            OptimizationAction::RandomPerturbation { .. } => 1,
            OptimizationAction::MomentumUpdate { .. } => 2,
            OptimizationAction::AdaptiveLearningRate { .. } => 3,
            OptimizationAction::ResetToBest => 4,
            OptimizationAction::Terminate => 5,
        };

        format!("{}_{}__{}", obj_bucket, step_bucket, action_id)
    }

    /// Returns the learned Q-value for `(state, action)`, defaulting to 0.0 for unseen pairs.
    fn get_q_value(&self, state: &OptimizationState, action: &OptimizationAction) -> f64 {
        let key = self.state_action_key(state, action);
        *self.q_table.get(&key).unwrap_or(&0.0)
    }

    /// Stores a new Q-value for `(state, action)`.
    fn update_q_value(
        &mut self,
        state: &OptimizationState,
        action: &OptimizationAction,
        new_value: f64,
    ) {
        let key = self.state_action_key(state, action);
        self.q_table.insert(key, new_value);
    }

    /// Fixed candidate action set with hard-coded hyperparameters.
    fn get_possible_actions(&self) -> Vec<OptimizationAction> {
        vec![
            OptimizationAction::GradientStep {
                learning_rate: 0.01,
            },
            OptimizationAction::RandomPerturbation { magnitude: 0.1 },
            OptimizationAction::MomentumUpdate { momentum: 0.9 },
            OptimizationAction::AdaptiveLearningRate { factor: 0.5 },
            OptimizationAction::ResetToBest,
            OptimizationAction::Terminate,
        ]
    }
}

impl RLOptimizer for QLearningOptimizer {
    fn config(&self) -> &RLOptimizationConfig {
        &self.config
    }

    /// Epsilon-greedy action selection: explore with probability
    /// `exploration_rate`, otherwise exploit the current Q-value estimates.
    fn select_action(&mut self, state: &OptimizationState) -> OptimizationAction {
        let mut rng = rng();
        if rng.random_range(0.0..1.0) < self.exploration_rate {
            // Explore: pick a uniformly random action.
            let actions = self.get_possible_actions();
            let idx = rng.random_range(0..actions.len());
            actions[idx].clone()
        } else {
            // Exploit: pick the action with the highest estimated Q-value.
            let actions = self.get_possible_actions();
            let mut best_action = actions[0].clone();
            let mut best_q = self.get_q_value(state, &best_action);

            for action in &actions[1..] {
                let q_value = self.get_q_value(state, action);
                if q_value > best_q {
                    best_q = q_value;
                    best_action = action.clone();
                }
            }

            best_action
        }
    }

    fn update(&mut self, experience: &super::Experience) -> Result<(), OptimizeError> {
        let current_q = self.get_q_value(&experience.state, &experience.action);

        // Bootstrap from the best action in the next state (zero if terminal).
        let max_next_q = if experience.done {
            0.0
        } else {
            let actions = self.get_possible_actions();
            actions
                .iter()
                .map(|a| self.get_q_value(&experience.next_state, a))
                .fold(f64::NEG_INFINITY, f64::max)
        };

        // Standard Q-learning update:
        // Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
        let target = experience.reward + self.config.discount_factor * max_next_q;
        let new_q = current_q + self.config.learning_rate * (target - current_q);

        self.update_q_value(&experience.state, &experience.action, new_q);

        Ok(())
    }

    fn run_episode<F>(
        &mut self,
        objective: &F,
        initial_params: &ArrayView1<f64>,
    ) -> OptimizeResult<OptimizeResults<f64>>
    where
        F: Fn(&ArrayView1<f64>) -> f64,
    {
        let mut current_params = initial_params.to_owned();
        let mut current_state = utils::create_state(current_params.clone(), objective, 0, None);
        let mut momentum = Array1::zeros(initial_params.len());

        for step in 0..self.config.max_steps_per_episode {
            let action = self.select_action(&current_state);
            let new_params =
                utils::apply_action(&current_state, &action, &self.best_params, &mut momentum);
            let new_state =
                utils::create_state(new_params, objective, step + 1, Some(&current_state));

            // Reward is the decrease in objective value achieved by the action.
            let reward = current_state.objective_value - new_state.objective_value;

            let experience = super::Experience {
                state: current_state.clone(),
                action: action.clone(),
                reward,
                next_state: new_state.clone(),
                done: utils::should_terminate(&new_state, self.config.max_steps_per_episode),
            };

            self.update(&experience)?;

            // Track the best parameters seen across all steps and episodes.
            if new_state.objective_value < self.best_objective {
                self.best_objective = new_state.objective_value;
                self.best_params = new_state.parameters.clone();
            }

            current_state = new_state;
            current_params = current_state.parameters.clone();

            if utils::should_terminate(&current_state, self.config.max_steps_per_episode)
                || matches!(action, OptimizationAction::Terminate)
            {
                break;
            }
        }

        // Decay exploration after each episode, but never below the configured floor.
        self.exploration_rate = (self.exploration_rate * self.config.exploration_decay)
            .max(self.config.min_exploration_rate);

        let converged = current_state.convergence_metrics.relative_objective_change < 1e-6;

        Ok(OptimizeResults::<f64> {
            x: current_params,
            fun: current_state.objective_value,
            success: converged,
            nit: current_state.step,
            message: "Q-learning episode completed".to_string(),
            jac: None,
            hess: None,
            constr: None,
            nfev: current_state.step,
            njev: 0,
            nhev: 0,
            maxcv: 0,
            status: if converged { 0 } else { 1 },
        })
    }

    fn train<F>(
        &mut self,
        objective: &F,
        initial_params: &ArrayView1<f64>,
    ) -> OptimizeResult<OptimizeResults<f64>>
    where
        F: Fn(&ArrayView1<f64>) -> f64,
    {
        let mut best_result = OptimizeResults::<f64> {
            x: initial_params.to_owned(),
            fun: f64::INFINITY,
            success: false,
            nit: 0,
            nfev: 0,
            njev: 0,
            nhev: 0,
            maxcv: 0,
            status: 0,
            message: "Training not completed".to_string(),
            jac: None,
            hess: None,
            constr: None,
        };

        // Run the configured number of episodes, keeping the best episode result.
        for _episode in 0..self.config.num_episodes {
            let result = self.run_episode(objective, initial_params)?;

            if result.fun < best_result.fun {
                best_result = result;
            }
        }

        // Report the best point seen across all episodes, not just the last one.
        best_result.x = self.best_params.clone();
        best_result.fun = self.best_objective;
        best_result.message = "Q-learning training completed".to_string();

        Ok(best_result)
    }

    fn reset(&mut self) {
        self.q_table.clear();
        self.exploration_rate = self.config.exploration_rate;
        self.best_objective = f64::INFINITY;
        self.best_params.fill(0.0);
    }
}

#[allow(dead_code)]
pub fn placeholder() {}
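
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the module's API): drives the
// optimizer against a simple quadratic objective. `example_quadratic_run` is a
// hypothetical helper name; it assumes the caller supplies an already
// constructed `RLOptimizationConfig` as defined in the parent module.
// ---------------------------------------------------------------------------
#[allow(dead_code)]
fn example_quadratic_run(config: RLOptimizationConfig) -> OptimizeResult<OptimizeResults<f64>> {
    use ndarray::array;

    // Convex test objective: f(x) = sum(x_i^2), minimized at the origin.
    let objective = |x: &ArrayView1<f64>| x.iter().map(|v| v * v).sum::<f64>();

    let x0 = array![1.0, -2.0, 0.5];
    let mut optimizer = QLearningOptimizer::new(config, x0.len());

    // Runs `config.num_episodes` episodes and returns the best result found.
    optimizer.train(&objective, &x0.view())
}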