1use crate::error::OptimizeError;
8use crate::unconstrained::{minimize, Bounds, Method, OptimizeResult, Options};
9use scirs2_core::ndarray::{Array1, ArrayView1};
10use scirs2_core::random::rngs::StdRng;
11#[allow(unused_imports)]
12use scirs2_core::random::{Cauchy, Distribution as RandDistribution};
13use scirs2_core::random::{Rng, SeedableRng};
14
/// Configuration for the [`DualAnnealing`] solver and the [`dual_annealing`] wrapper.
#[derive(Debug, Clone, PartialEq)]
pub struct DualAnnealingOptions {
    /// Maximum number of annealing iterations performed by `DualAnnealing::run`.
    pub maxiter: usize,
    /// Starting temperature; restarts also reset the temperature to this value.
    pub initial_temp: f64,
    /// Shape parameter of the trial-step distribution (used as `q` when drawing
    /// per-coordinate steps; larger values produce shorter steps).
    pub visit: f64,
    /// NOTE(review): not read by `DualAnnealing` in this file — confirm whether it
    /// is meant to drive a generalized acceptance rule or is kept only for API parity.
    pub accept: f64,
    /// Budget of objective evaluations. The run stops once the evaluation count
    /// reaches it; the count may overshoot slightly because checks are not per call.
    pub maxfun: usize,
    /// Seed for the internal RNG. `None` draws a random seed, so runs are not
    /// reproducible.
    pub seed: Option<u64>,
    /// The run stops, reporting success, once the temperature falls below
    /// `restart_temp_ratio * initial_temp`.
    pub restart_temp_ratio: f64,
    /// Per-coordinate `(lower, upper)` box, one entry per coordinate of `x0`.
    /// When empty, `dual_annealing` fills it from its `bounds` argument.
    pub bounds: Vec<(f64, f64)>,
}
35
36impl Default for DualAnnealingOptions {
37 fn default() -> Self {
38 Self {
39 maxiter: 1000,
40 initial_temp: 5230.0,
41 visit: 2.62,
42 accept: -5.0,
43 maxfun: 10000000,
44 seed: None,
45 restart_temp_ratio: 2e-5,
46 bounds: vec![],
47 }
48 }
49}
50
51pub struct DualAnnealing<F>
53where
54 F: Fn(&ArrayView1<f64>) -> f64 + Clone,
55{
56 func: F,
57 x0: Array1<f64>,
58 options: DualAnnealingOptions,
59 ndim: usize,
60 rng: StdRng,
61 temperature: f64,
62 markov_chain_length: usize,
63 current_x: Array1<f64>,
64 current_energy: f64,
65 best_x: Array1<f64>,
66 best_energy: f64,
67 nfev: usize,
68 not_improved_counter: usize,
69}
70
71impl<F> DualAnnealing<F>
72where
73 F: Fn(&ArrayView1<f64>) -> f64 + Clone,
74{
75 pub fn new(func: F, x0: Array1<f64>, options: DualAnnealingOptions) -> Self {
77 let ndim = x0.len();
78 let seed = options
79 .seed
80 .unwrap_or_else(|| scirs2_core::random::rng().random_range(0..u64::MAX));
81 let rng = StdRng::seed_from_u64(seed);
82
83 let initial_energy = func(&x0.view());
84 let temperature = options.initial_temp;
85
86 Self {
87 func,
88 x0: x0.clone(),
89 options,
90 ndim,
91 rng,
92 temperature,
93 markov_chain_length: 100 * ndim,
94 current_x: x0.clone(),
95 current_energy: initial_energy,
96 best_x: x0.clone(),
97 best_energy: initial_energy,
98 nfev: 1,
99 not_improved_counter: 0,
100 }
101 }
102
103 fn generate_new_point(&mut self) -> Array1<f64> {
105 let mut x_new = self.current_x.clone();
106
107 for i in 0..self.ndim {
109 let (lb, ub) = self.options.bounds[i];
110 let y = self.current_x[i];
111
112 let q = self.options.visit;
114 let mut v;
115
116 loop {
118 let u: f64 = self.rng.gen_range(0.0..1.0);
119 let u1: f64 = self.rng.gen_range(0.0..1.0);
120 let sign = if u1 < 0.5 { -1.0 } else { 1.0 };
121
122 v = sign * self.temperature * ((1.0 + 1.0 / q).powf(u.abs()) - 1.0);
123
124 let new_val = y + v;
126 if new_val >= lb && new_val <= ub {
127 x_new[i] = new_val;
128 break;
129 }
130 }
131 }
132
133 x_new
134 }
135
136 fn accept_probability(&self, energy_new: f64) -> f64 {
138 if energy_new <= self.current_energy {
139 1.0
140 } else {
141 let delta = energy_new - self.current_energy;
142 (-delta / self.temperature).exp()
143 }
144 }
145
146 fn local_search(&self) -> (Array1<f64>, f64, usize) {
148 let result = minimize(
149 |x| (self.func)(x),
150 &self.current_x.to_vec(),
151 Method::LBFGS,
152 Some(Options {
153 bounds: Some(
154 Bounds::from_vecs(
155 self.options
156 .bounds
157 .iter()
158 .map(|&(lb, _)| Some(lb))
159 .collect(),
160 self.options
161 .bounds
162 .iter()
163 .map(|&(_, ub)| Some(ub))
164 .collect(),
165 )
166 .unwrap(),
167 ),
168 ..Default::default()
169 }),
170 )
171 .unwrap();
172
173 (result.x, result.fun, result.nfev)
174 }
175
176 fn update_temperature(&mut self, k: usize) {
178 self.temperature = self.options.initial_temp / (k as f64).ln_1p();
180 }
181
182 fn check_restart(&mut self) -> bool {
184 if self.not_improved_counter >= self.markov_chain_length {
185 self.not_improved_counter = 0;
186 self.temperature = self.options.initial_temp;
187 true
188 } else {
189 false
190 }
191 }
192
193 fn step(&mut self, iteration: usize) -> bool {
195 let mut improved = false;
196
197 for _ in 0..self.markov_chain_length {
199 let x_new = self.generate_new_point();
200 let energy_new = (self.func)(&x_new.view());
201 self.nfev += 1;
202
203 let accept_prob = self.accept_probability(energy_new);
205 if self.rng.gen_range(0.0..1.0) < accept_prob {
206 self.current_x = x_new;
207 self.current_energy = energy_new;
208
209 if energy_new < self.best_energy {
210 self.best_x = self.current_x.clone();
211 self.best_energy = energy_new;
212 improved = true;
213 self.not_improved_counter = 0;
214 }
215 }
216 }
217
218 if iteration.is_multiple_of(10) {
220 let (x_local, energy_local, nfev_local) = self.local_search();
222 self.nfev += nfev_local;
223
224 if energy_local < self.best_energy {
225 self.best_x = x_local;
226 self.best_energy = energy_local;
227 self.current_x = self.best_x.clone();
228 self.current_energy = self.best_energy;
229 improved = true;
230 self.not_improved_counter = 0;
231 }
232 }
233
234 if !improved {
235 self.not_improved_counter += 1;
236 }
237
238 self.update_temperature(iteration + 1);
240
241 self.check_restart();
243
244 improved
245 }
246
247 pub fn run(&mut self) -> OptimizeResult<f64> {
249 let mut nit = 0;
250 let mut success = false;
251 let mut message = "Maximum number of iterations reached".to_string();
252
253 for i in 0..self.options.maxiter {
254 let _improved = self.step(i);
255 nit += 1;
256
257 if self.temperature < self.options.restart_temp_ratio * self.options.initial_temp {
259 success = true;
260 message = "Temperature converged".to_string();
261 break;
262 }
263
264 if self.nfev >= self.options.maxfun {
265 message = "Maximum number of function evaluations reached".to_string();
266 break;
267 }
268 }
269
270 let (x_final, energy_final, nfev_final) = self.local_search();
272 self.nfev += nfev_final;
273
274 if energy_final < self.best_energy {
275 self.best_x = x_final;
276 self.best_energy = energy_final;
277 }
278
279 OptimizeResult {
280 x: self.best_x.clone(),
281 fun: self.best_energy,
282 nfev: self.nfev,
283 func_evals: self.nfev,
284 nit,
285 success,
286 message,
287 ..Default::default()
288 }
289 }
290}
291
292#[allow(dead_code)]
294pub fn dual_annealing<F>(
295 func: F,
296 x0: Array1<f64>,
297 bounds: Vec<(f64, f64)>,
298 options: Option<DualAnnealingOptions>,
299) -> Result<OptimizeResult<f64>, OptimizeError>
300where
301 F: Fn(&ArrayView1<f64>) -> f64 + Clone,
302{
303 let mut options = options.unwrap_or_default();
304
305 if options.bounds.is_empty() {
307 options.bounds = bounds;
308 }
309
310 let mut solver = DualAnnealing::new(func, x0, options);
311 Ok(solver.run())
312}