scirs2_optimize/global/
dual_annealing.rs

1//! Dual Annealing algorithm for global optimization
2//!
3//! A global optimization algorithm combining classical simulated annealing
4//! with a fast simulated annealing (FSA) algorithm for finding the global
5//! minimum of multivariate functions.
6
7use crate::error::OptimizeError;
8use crate::unconstrained::{minimize, Bounds, Method, OptimizeResult, Options};
9use scirs2_core::ndarray::{Array1, ArrayView1};
10use scirs2_core::random::rngs::StdRng;
11#[allow(unused_imports)]
12use scirs2_core::random::{Cauchy, Distribution as RandDistribution};
13use scirs2_core::random::{Rng, SeedableRng};
14
/// Options for the Dual Annealing algorithm.
///
/// The default values match those of SciPy's `scipy.optimize.dual_annealing`.
#[derive(Debug, Clone)]
pub struct DualAnnealingOptions {
    /// Maximum number of global search iterations; each iteration runs one
    /// Markov chain of proposals followed by a temperature update.
    pub maxiter: usize,
    /// Initial temperature of the annealing schedule. The solver also resets
    /// the temperature to this value when it restarts.
    pub initial_temp: f64,
    /// Visiting distribution parameter `q` (expected in `(1, 3]`). Larger
    /// values give shorter visiting jumps.
    pub visit: f64,
    /// Acceptance distribution parameter (SciPy expects a value in `(-1e4, -5]`).
    ///
    /// NOTE(review): the solver does not currently read this field; its
    /// acceptance test is the Metropolis criterion `exp(-dE / T)`.
    pub accept: f64,
    /// Soft limit on the total number of objective evaluations, including
    /// those spent in local searches.
    pub maxfun: usize,
    /// Random seed for reproducibility; `None` draws a fresh random seed.
    pub seed: Option<u64>,
    /// Temperature ratio used as a stopping criterion. The run ends with
    /// "Temperature converged" once the temperature drops below
    /// `restart_temp_ratio * initial_temp`.
    pub restart_temp_ratio: f64,
    /// `(lower, upper)` bounds, one pair per variable of the starting point.
    pub bounds: Vec<(f64, f64)>,
}

impl Default for DualAnnealingOptions {
    fn default() -> Self {
        Self {
            maxiter: 1000,
            initial_temp: 5230.0,
            visit: 2.62,
            accept: -5.0,
            maxfun: 10_000_000,
            seed: None,
            restart_temp_ratio: 2e-5,
            bounds: Vec::new(),
        }
    }
}
50
51/// Dual Annealing solver
52pub struct DualAnnealing<F>
53where
54    F: Fn(&ArrayView1<f64>) -> f64 + Clone,
55{
56    func: F,
57    x0: Array1<f64>,
58    options: DualAnnealingOptions,
59    ndim: usize,
60    rng: StdRng,
61    temperature: f64,
62    markov_chain_length: usize,
63    current_x: Array1<f64>,
64    current_energy: f64,
65    best_x: Array1<f64>,
66    best_energy: f64,
67    nfev: usize,
68    not_improved_counter: usize,
69}
70
71impl<F> DualAnnealing<F>
72where
73    F: Fn(&ArrayView1<f64>) -> f64 + Clone,
74{
75    /// Create new Dual Annealing solver
76    pub fn new(func: F, x0: Array1<f64>, options: DualAnnealingOptions) -> Self {
77        let ndim = x0.len();
78        let seed = options
79            .seed
80            .unwrap_or_else(|| scirs2_core::random::rng().random_range(0..u64::MAX));
81        let rng = StdRng::seed_from_u64(seed);
82
83        let initial_energy = func(&x0.view());
84        let temperature = options.initial_temp;
85
86        Self {
87            func,
88            x0: x0.clone(),
89            options,
90            ndim,
91            rng,
92            temperature,
93            markov_chain_length: 100 * ndim,
94            current_x: x0.clone(),
95            current_energy: initial_energy,
96            best_x: x0.clone(),
97            best_energy: initial_energy,
98            nfev: 1,
99            not_improved_counter: 0,
100        }
101    }
102
103    /// Generate new point using visiting distribution
104    fn generate_new_point(&mut self) -> Array1<f64> {
105        let mut x_new = self.current_x.clone();
106
107        // Using generalized visiting distribution
108        for i in 0..self.ndim {
109            let (lb, ub) = self.options.bounds[i];
110            let y = self.current_x[i];
111
112            // Generate random value using visiting distribution
113            let q = self.options.visit;
114            let mut v;
115
116            // Generate from Power distribution
117            loop {
118                let u: f64 = self.rng.gen_range(0.0..1.0);
119                let u1: f64 = self.rng.gen_range(0.0..1.0);
120                let sign = if u1 < 0.5 { -1.0 } else { 1.0 };
121
122                v = sign * self.temperature * ((1.0 + 1.0 / q).powf(u.abs()) - 1.0);
123
124                // Apply bounds
125                let new_val = y + v;
126                if new_val >= lb && new_val <= ub {
127                    x_new[i] = new_val;
128                    break;
129                }
130            }
131        }
132
133        x_new
134    }
135
136    /// Calculate acceptance probability
137    fn accept_probability(&self, energy_new: f64) -> f64 {
138        if energy_new <= self.current_energy {
139            1.0
140        } else {
141            let delta = energy_new - self.current_energy;
142            (-delta / self.temperature).exp()
143        }
144    }
145
146    /// Perform local search using gradient-based method
147    fn local_search(&self) -> (Array1<f64>, f64, usize) {
148        let result = minimize(
149            |x| (self.func)(x),
150            &self.current_x.to_vec(),
151            Method::LBFGS,
152            Some(Options {
153                bounds: Some(
154                    Bounds::from_vecs(
155                        self.options
156                            .bounds
157                            .iter()
158                            .map(|&(lb, _)| Some(lb))
159                            .collect(),
160                        self.options
161                            .bounds
162                            .iter()
163                            .map(|&(_, ub)| Some(ub))
164                            .collect(),
165                    )
166                    .unwrap(),
167                ),
168                ..Default::default()
169            }),
170        )
171        .unwrap();
172
173        (result.x, result.fun, result.nfev)
174    }
175
176    /// Update temperature using annealing schedule
177    fn update_temperature(&mut self, k: usize) {
178        // Classical annealing schedule
179        self.temperature = self.options.initial_temp / (k as f64).ln_1p();
180    }
181
182    /// Check if restart is needed
183    fn check_restart(&mut self) -> bool {
184        if self.not_improved_counter >= self.markov_chain_length {
185            self.not_improved_counter = 0;
186            self.temperature = self.options.initial_temp;
187            true
188        } else {
189            false
190        }
191    }
192
193    /// Run one iteration of the algorithm
194    fn step(&mut self, iteration: usize) -> bool {
195        let mut improved = false;
196
197        // Global search phase
198        for _ in 0..self.markov_chain_length {
199            let x_new = self.generate_new_point();
200            let energy_new = (self.func)(&x_new.view());
201            self.nfev += 1;
202
203            // Acceptance test
204            let accept_prob = self.accept_probability(energy_new);
205            if self.rng.gen_range(0.0..1.0) < accept_prob {
206                self.current_x = x_new;
207                self.current_energy = energy_new;
208
209                if energy_new < self.best_energy {
210                    self.best_x = self.current_x.clone();
211                    self.best_energy = energy_new;
212                    improved = true;
213                    self.not_improved_counter = 0;
214                }
215            }
216        }
217
218        // Local search phase
219        if iteration.is_multiple_of(10) {
220            // Perform local search periodically
221            let (x_local, energy_local, nfev_local) = self.local_search();
222            self.nfev += nfev_local;
223
224            if energy_local < self.best_energy {
225                self.best_x = x_local;
226                self.best_energy = energy_local;
227                self.current_x = self.best_x.clone();
228                self.current_energy = self.best_energy;
229                improved = true;
230                self.not_improved_counter = 0;
231            }
232        }
233
234        if !improved {
235            self.not_improved_counter += 1;
236        }
237
238        // Update temperature
239        self.update_temperature(iteration + 1);
240
241        // Check for restart
242        self.check_restart();
243
244        improved
245    }
246
247    /// Run the dual annealing algorithm
248    pub fn run(&mut self) -> OptimizeResult<f64> {
249        let mut nit = 0;
250        let mut success = false;
251        let mut message = "Maximum number of iterations reached".to_string();
252
253        for i in 0..self.options.maxiter {
254            let _improved = self.step(i);
255            nit += 1;
256
257            // Check convergence
258            if self.temperature < self.options.restart_temp_ratio * self.options.initial_temp {
259                success = true;
260                message = "Temperature converged".to_string();
261                break;
262            }
263
264            if self.nfev >= self.options.maxfun {
265                message = "Maximum number of function evaluations reached".to_string();
266                break;
267            }
268        }
269
270        // Final local search for polish
271        let (x_final, energy_final, nfev_final) = self.local_search();
272        self.nfev += nfev_final;
273
274        if energy_final < self.best_energy {
275            self.best_x = x_final;
276            self.best_energy = energy_final;
277        }
278
279        OptimizeResult {
280            x: self.best_x.clone(),
281            fun: self.best_energy,
282            nfev: self.nfev,
283            func_evals: self.nfev,
284            nit,
285            success,
286            message,
287            ..Default::default()
288        }
289    }
290}
291
292/// Perform global optimization using dual annealing
293#[allow(dead_code)]
294pub fn dual_annealing<F>(
295    func: F,
296    x0: Array1<f64>,
297    bounds: Vec<(f64, f64)>,
298    options: Option<DualAnnealingOptions>,
299) -> Result<OptimizeResult<f64>, OptimizeError>
300where
301    F: Fn(&ArrayView1<f64>) -> f64 + Clone,
302{
303    let mut options = options.unwrap_or_default();
304
305    // Ensure bounds are set
306    if options.bounds.is_empty() {
307        options.bounds = bounds;
308    }
309
310    let mut solver = DualAnnealing::new(func, x0, options);
311    Ok(solver.run())
312}