scirs2_optimize/global/
multi_start.rs1use crate::error::OptimizeError;
7use crate::unconstrained::{
8    minimize, Bounds as UnconstrainedBounds, Method as UnconstrainedMethod, OptimizeResult, Options,
9};
10use scirs2_core::ndarray::{Array1, ArrayView1};
11use scirs2_core::parallel_ops::*;
12use scirs2_core::random::prelude::SliceRandom;
13use scirs2_core::random::rngs::StdRng;
14use scirs2_core::random::{Rng, SeedableRng};
15
16#[derive(Debug, Clone)]
18pub struct MultiStartOptions {
19    pub n_starts: usize,
21    pub local_method: UnconstrainedMethod,
23    pub parallel: bool,
25    pub seed: Option<u64>,
27    pub strategy: StartingPointStrategy,
29}
30
31impl Default for MultiStartOptions {
32    fn default() -> Self {
33        Self {
34            n_starts: 10,
35            local_method: UnconstrainedMethod::BFGS,
36            parallel: true,
37            seed: None,
38            strategy: StartingPointStrategy::Random,
39        }
40    }
41}
42
/// Strategy used to place the starting points inside the search bounds.
///
/// The enum is fieldless, so it is `Copy` and can be compared and hashed
/// (e.g. `opts.strategy == StartingPointStrategy::Grid`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum StartingPointStrategy {
    /// Independent uniform samples within the bounds (the default).
    #[default]
    Random,
    /// Latin hypercube sampling: each dimension is split into `n_starts` strata
    /// and every stratum receives exactly one point.
    LatinHypercube,
    /// Halton low-discrepancy sequence.
    /// NOTE(review): confirm the generator implements this rather than falling
    /// back to uniform random sampling.
    Halton,
    /// Sobol low-discrepancy sequence.
    /// NOTE: the generator currently falls back to uniform random sampling.
    Sobol,
    /// Regular grid over the bounds, limited to `n_starts` points.
    Grid,
}
57
/// Box constraints: one `(lower, upper)` pair per dimension.
pub type Bounds = Vec<(f64, f64)>;
60
61pub struct MultiStart<F>
63where
64    F: Fn(&ArrayView1<f64>) -> f64 + Clone + Send + Sync,
65{
66    func: F,
67    bounds: Bounds,
68    options: MultiStartOptions,
69    ndim: usize,
70    rng: StdRng,
71}
72
73impl<F> MultiStart<F>
74where
75    F: Fn(&ArrayView1<f64>) -> f64 + Clone + Send + Sync,
76{
77    pub fn new(func: F, bounds: Bounds, options: MultiStartOptions) -> Self {
79        let ndim = bounds.len();
80        let seed = options
81            .seed
82            .unwrap_or_else(|| scirs2_core::random::rng().random_range(0..u64::MAX));
83        let rng = StdRng::seed_from_u64(seed);
84
85        Self {
86            func,
87            bounds,
88            options,
89            ndim,
90            rng,
91        }
92    }
93
94    fn generate_starting_points(&mut self) -> Vec<Array1<f64>> {
96        match self.options.strategy {
97            StartingPointStrategy::Random => self.generate_random_points(),
98            StartingPointStrategy::LatinHypercube => self.generate_latin_hypercube_points(),
99            StartingPointStrategy::Halton => self.generate_halton_points(),
100            StartingPointStrategy::Sobol => self.generate_sobol_points(),
101            StartingPointStrategy::Grid => self.generate_grid_points(),
102        }
103    }
104
105    fn generate_random_points(&mut self) -> Vec<Array1<f64>> {
107        let mut points = Vec::with_capacity(self.options.n_starts);
108
109        for _ in 0..self.options.n_starts {
110            let mut point = Array1::zeros(self.ndim);
111            for j in 0..self.ndim {
112                let (lb, ub) = self.bounds[j];
113                point[j] = self.rng.gen_range(lb..ub);
114            }
115            points.push(point);
116        }
117
118        points
119    }
120
121    fn generate_latin_hypercube_points(&mut self) -> Vec<Array1<f64>> {
123        let mut points = Vec::with_capacity(self.options.n_starts);
124        let n = self.options.n_starts;
125
126        for i in 0..n {
128            let mut point = Array1::zeros(self.ndim);
129
130            for j in 0..self.ndim {
131                let (lb, ub) = self.bounds[j];
132                let segment_size = (ub - lb) / n as f64;
133
134                let offset = self.rng.gen_range(0.0..1.0);
136                point[j] = lb + (i as f64 + offset) * segment_size;
137            }
138
139            points.push(point);
140        }
141
142        for j in 0..self.ndim {
144            let mut indices: Vec<usize> = (0..n).collect();
145            indices.shuffle(&mut self.rng);
146
147            for (i, &idx) in indices.iter().enumerate() {
148                let temp = points[i][j];
149                points[i][j] = points[idx][j];
150                points[idx][j] = temp;
151            }
152        }
153
154        points
155    }
156
157    fn generate_halton_points(&mut self) -> Vec<Array1<f64>> {
159        self.generate_random_points()
161    }
162
163    fn generate_sobol_points(&mut self) -> Vec<Array1<f64>> {
165        self.generate_random_points()
167    }
168
169    fn generate_grid_points(&self) -> Vec<Array1<f64>> {
171        let points_per_dim = (self.options.n_starts as f64)
172            .powf(1.0 / self.ndim as f64)
173            .ceil() as usize;
174        let mut points = Vec::new();
175
176        let mut current = vec![0usize; self.ndim];
178        loop {
179            let mut point = Array1::zeros(self.ndim);
180
181            for j in 0..self.ndim {
182                let (lb, ub) = self.bounds[j];
183                let step = (ub - lb) / (points_per_dim - 1).max(1) as f64;
184                point[j] = lb + current[j] as f64 * step;
185            }
186
187            points.push(point);
188
189            let mut carry = true;
191            for j in current.iter_mut() {
192                if carry {
193                    *j += 1;
194                    if *j >= points_per_dim {
195                        *j = 0;
196                    } else {
197                        carry = false;
198                    }
199                }
200            }
201
202            if carry || points.len() >= self.options.n_starts {
203                break;
204            }
205        }
206
207        points.truncate(self.options.n_starts);
208        points
209    }
210
211    fn optimize_single(&self, x0: Array1<f64>) -> OptimizeResult<f64> {
213        let bounds = Some(
214            UnconstrainedBounds::from_vecs(
215                self.bounds.iter().map(|&(lb, _)| Some(lb)).collect(),
216                self.bounds.iter().map(|&(_, ub)| Some(ub)).collect(),
217            )
218            .unwrap(),
219        );
220
221        let options = Options {
222            bounds,
223            ..Default::default()
224        };
225
226        let func = self.func.clone();
227
228        minimize(
229            move |x: &ArrayView1<f64>| func(x),
230            &x0.to_vec(),
231            self.options.local_method,
232            Some(options),
233        )
234        .unwrap_or_else(|_| {
235            OptimizeResult {
237                x: x0,
238                fun: f64::INFINITY,
239                success: false,
240                ..Default::default()
241            }
242        })
243    }
244
245    pub fn run(&mut self) -> OptimizeResult<f64> {
247        let starting_points = self.generate_starting_points();
248
249        let results = if self.options.parallel {
250            starting_points
252                .into_par_iter()
253                .map(|x0| self.optimize_single(x0))
254                .collect::<Vec<_>>()
255        } else {
256            starting_points
258                .into_iter()
259                .map(|x0| self.optimize_single(x0))
260                .collect::<Vec<_>>()
261        };
262
263        let best_result = results
265            .into_iter()
266            .filter(|r| r.success)
267            .min_by(|a, b| {
268                a.fun
269                    .partial_cmp(&b.fun)
270                    .unwrap_or(std::cmp::Ordering::Equal)
271            })
272            .unwrap_or_else(|| OptimizeResult {
273                x: Array1::zeros(self.ndim),
274                fun: f64::INFINITY,
275                success: false,
276                message: "All optimization attempts failed".to_string(),
277                ..Default::default()
278            });
279
280        OptimizeResult {
281            x: best_result.x,
282            fun: best_result.fun,
283            nit: self.options.n_starts,
284            success: best_result.success,
285            message: format!(
286                "Multi-start optimization with {} starts",
287                self.options.n_starts
288            ),
289            ..Default::default()
290        }
291    }
292}
293
294#[allow(dead_code)]
296pub fn multi_start<F>(
297    func: F,
298    bounds: Bounds,
299    options: Option<MultiStartOptions>,
300) -> Result<OptimizeResult<f64>, OptimizeError>
301where
302    F: Fn(&ArrayView1<f64>) -> f64 + Clone + Send + Sync,
303{
304    let options = options.unwrap_or_default();
305    let mut solver = MultiStart::new(func, bounds, options);
306    Ok(solver.run())
307}