// scirs2_optimize/reinforcement_learning/bandit_optimization.rs
use crate::error::OptimizeResult;
use crate::result::OptimizeResults;
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::random::{rng, Rng};
9#[derive(Debug, Clone)]
14pub struct BanditOptimizer {
15 pub num_arms: usize,
17 pub arm_rewards: Array1<f64>,
19 pub arm_counts: Array1<usize>,
21}
22
23impl BanditOptimizer {
24 pub fn new(num_arms: usize) -> Self {
26 Self {
27 num_arms,
28 arm_rewards: Array1::zeros(num_arms),
29 arm_counts: Array1::zeros(num_arms),
30 }
31 }
32
33 pub fn select_arm(&self) -> usize {
35 let total_counts: usize = self.arm_counts.sum();
36 if total_counts == 0 {
37 return scirs2_core::random::rng().random_range(0..self.num_arms);
38 }
39
40 let mut best_arm = 0;
41 let mut best_ucb = f64::NEG_INFINITY;
42
43 for arm in 0..self.num_arms {
44 if self.arm_counts[arm] == 0 {
45 return arm; }
47
48 let average_reward = self.arm_rewards[arm] / self.arm_counts[arm] as f64;
49 let confidence_interval =
50 (2.0 * (total_counts as f64).ln() / self.arm_counts[arm] as f64).sqrt();
51 let ucb = average_reward + confidence_interval;
52
53 if ucb > best_ucb {
54 best_ucb = ucb;
55 best_arm = arm;
56 }
57 }
58
59 best_arm
60 }
61
62 pub fn update_arm(&mut self, arm: usize, reward: f64) {
64 if arm < self.num_arms {
65 self.arm_rewards[arm] += reward;
66 self.arm_counts[arm] += 1;
67 }
68 }
69}
70
71#[allow(dead_code)]
73pub fn bandit_optimize<F>(
74 objective: F,
75 initial_params: &ArrayView1<f64>,
76 num_nit: usize,
77) -> OptimizeResult<OptimizeResults<f64>>
78where
79 F: Fn(&ArrayView1<f64>) -> f64,
80{
81 let mut bandit = BanditOptimizer::new(3); let mut params = initial_params.to_owned();
83 let mut best_obj = objective(initial_params);
84
85 for _iter in 0..num_nit {
86 let arm = bandit.select_arm();
87
88 let step_size = match arm {
90 0 => 0.01, 1 => 0.1, _ => 0.001, };
94
95 for i in 0..params.len() {
97 params[i] += (scirs2_core::random::rng().random::<f64>() - 0.5) * step_size;
98 }
99
100 let new_obj = objective(¶ms.view());
101 let reward = if new_obj < best_obj { 1.0 } else { 0.0 };
102
103 bandit.update_arm(arm, reward);
104
105 if new_obj < best_obj {
106 best_obj = new_obj;
107 }
108 }
109
110 Ok(OptimizeResults::<f64> {
111 x: params,
112 fun: best_obj,
113 success: true,
114 nit: num_nit,
115 message: "Bandit optimization completed".to_string(),
116 jac: None,
117 hess: None,
118 constr: None,
119 nfev: num_nit * 3, njev: 0,
121 nhev: 0,
122 maxcv: 0,
123 status: 0,
124 })
125}
126
/// Empty placeholder; intentionally does nothing.
#[allow(dead_code)]
pub fn placeholder() {}