scirs2_optimize/global/
multi_start.rs1use crate::error::OptimizeError;
7use crate::unconstrained::{
8 minimize, Bounds as UnconstrainedBounds, Method as UnconstrainedMethod, OptimizeResult, Options,
9};
10use scirs2_core::ndarray::{Array1, ArrayView1};
11use scirs2_core::parallel_ops::*;
12use scirs2_core::random::prelude::SliceRandom;
13use scirs2_core::random::rngs::StdRng;
14use scirs2_core::random::{Rng, SeedableRng};
15
/// Configuration for a multi-start global optimization run.
#[derive(Debug, Clone)]
pub struct MultiStartOptions {
    /// Number of starting points to generate; a local optimization is run
    /// from each of them. Defaults to 10.
    pub n_starts: usize,
    /// Local method passed to `minimize` for every starting point.
    /// Defaults to BFGS.
    pub local_method: UnconstrainedMethod,
    /// When `true`, the local optimizations are driven by a parallel
    /// iterator; otherwise they run sequentially. Defaults to `true`.
    pub parallel: bool,
    /// Seed for the RNG that places starting points. When `None`, a seed is
    /// drawn from `scirs2_core::random::rng()` when the solver is created, so
    /// such runs are not reproducible.
    pub seed: Option<u64>,
    /// How the starting points are placed inside the bounds.
    /// Defaults to `StartingPointStrategy::Random`.
    pub strategy: StartingPointStrategy,
}
30
31impl Default for MultiStartOptions {
32 fn default() -> Self {
33 Self {
34 n_starts: 10,
35 local_method: UnconstrainedMethod::BFGS,
36 parallel: true,
37 seed: None,
38 strategy: StartingPointStrategy::Random,
39 }
40 }
41}
42
/// Strategy used to place the starting points inside the bounds.
///
/// The default is [`StartingPointStrategy::Random`], matching the default
/// multi-start options.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum StartingPointStrategy {
    /// Independent uniform samples inside the bounds.
    #[default]
    Random,
    /// Latin hypercube sampling: every dimension is split into `n_starts`
    /// equal strata and each stratum receives exactly one point.
    LatinHypercube,
    /// Halton low-discrepancy sequence.
    /// NOTE(review): check `MultiStart::generate_halton_points` — depending on
    /// the version of the generator it may fall back to uniform random sampling.
    Halton,
    /// Sobol low-discrepancy sequence.
    /// NOTE(review): `MultiStart::generate_sobol_points` currently falls back to
    /// uniform random sampling.
    Sobol,
    /// Points on a regular grid spanning the bounds, truncated to `n_starts`
    /// points.
    Grid,
}
57
/// Box constraints: one `(lower, upper)` pair per dimension. The problem
/// dimensionality is the length of this vector.
pub type Bounds = Vec<(f64, f64)>;
60
61pub struct MultiStart<F>
63where
64 F: Fn(&ArrayView1<f64>) -> f64 + Clone + Send + Sync,
65{
66 func: F,
67 bounds: Bounds,
68 options: MultiStartOptions,
69 ndim: usize,
70 rng: StdRng,
71}
72
73impl<F> MultiStart<F>
74where
75 F: Fn(&ArrayView1<f64>) -> f64 + Clone + Send + Sync,
76{
77 pub fn new(func: F, bounds: Bounds, options: MultiStartOptions) -> Self {
79 let ndim = bounds.len();
80 let seed = options
81 .seed
82 .unwrap_or_else(|| scirs2_core::random::rng().random_range(0..u64::MAX));
83 let rng = StdRng::seed_from_u64(seed);
84
85 Self {
86 func,
87 bounds,
88 options,
89 ndim,
90 rng,
91 }
92 }
93
94 fn generate_starting_points(&mut self) -> Vec<Array1<f64>> {
96 match self.options.strategy {
97 StartingPointStrategy::Random => self.generate_random_points(),
98 StartingPointStrategy::LatinHypercube => self.generate_latin_hypercube_points(),
99 StartingPointStrategy::Halton => self.generate_halton_points(),
100 StartingPointStrategy::Sobol => self.generate_sobol_points(),
101 StartingPointStrategy::Grid => self.generate_grid_points(),
102 }
103 }
104
105 fn generate_random_points(&mut self) -> Vec<Array1<f64>> {
107 let mut points = Vec::with_capacity(self.options.n_starts);
108
109 for _ in 0..self.options.n_starts {
110 let mut point = Array1::zeros(self.ndim);
111 for j in 0..self.ndim {
112 let (lb, ub) = self.bounds[j];
113 point[j] = self.rng.gen_range(lb..ub);
114 }
115 points.push(point);
116 }
117
118 points
119 }
120
121 fn generate_latin_hypercube_points(&mut self) -> Vec<Array1<f64>> {
123 let mut points = Vec::with_capacity(self.options.n_starts);
124 let n = self.options.n_starts;
125
126 for i in 0..n {
128 let mut point = Array1::zeros(self.ndim);
129
130 for j in 0..self.ndim {
131 let (lb, ub) = self.bounds[j];
132 let segment_size = (ub - lb) / n as f64;
133
134 let offset = self.rng.gen_range(0.0..1.0);
136 point[j] = lb + (i as f64 + offset) * segment_size;
137 }
138
139 points.push(point);
140 }
141
142 for j in 0..self.ndim {
144 let mut indices: Vec<usize> = (0..n).collect();
145 indices.shuffle(&mut self.rng);
146
147 for (i, &idx) in indices.iter().enumerate() {
148 let temp = points[i][j];
149 points[i][j] = points[idx][j];
150 points[idx][j] = temp;
151 }
152 }
153
154 points
155 }
156
157 fn generate_halton_points(&mut self) -> Vec<Array1<f64>> {
159 self.generate_random_points()
161 }
162
163 fn generate_sobol_points(&mut self) -> Vec<Array1<f64>> {
165 self.generate_random_points()
167 }
168
169 fn generate_grid_points(&self) -> Vec<Array1<f64>> {
171 let points_per_dim = (self.options.n_starts as f64)
172 .powf(1.0 / self.ndim as f64)
173 .ceil() as usize;
174 let mut points = Vec::new();
175
176 let mut current = vec![0usize; self.ndim];
178 loop {
179 let mut point = Array1::zeros(self.ndim);
180
181 for j in 0..self.ndim {
182 let (lb, ub) = self.bounds[j];
183 let step = (ub - lb) / (points_per_dim - 1).max(1) as f64;
184 point[j] = lb + current[j] as f64 * step;
185 }
186
187 points.push(point);
188
189 let mut carry = true;
191 for j in current.iter_mut() {
192 if carry {
193 *j += 1;
194 if *j >= points_per_dim {
195 *j = 0;
196 } else {
197 carry = false;
198 }
199 }
200 }
201
202 if carry || points.len() >= self.options.n_starts {
203 break;
204 }
205 }
206
207 points.truncate(self.options.n_starts);
208 points
209 }
210
211 fn optimize_single(&self, x0: Array1<f64>) -> OptimizeResult<f64> {
213 let bounds = Some(
214 UnconstrainedBounds::from_vecs(
215 self.bounds.iter().map(|&(lb, _)| Some(lb)).collect(),
216 self.bounds.iter().map(|&(_, ub)| Some(ub)).collect(),
217 )
218 .unwrap(),
219 );
220
221 let options = Options {
222 bounds,
223 ..Default::default()
224 };
225
226 let func = self.func.clone();
227
228 minimize(
229 move |x: &ArrayView1<f64>| func(x),
230 &x0.to_vec(),
231 self.options.local_method,
232 Some(options),
233 )
234 .unwrap_or_else(|_| {
235 OptimizeResult {
237 x: x0,
238 fun: f64::INFINITY,
239 success: false,
240 ..Default::default()
241 }
242 })
243 }
244
245 pub fn run(&mut self) -> OptimizeResult<f64> {
247 let starting_points = self.generate_starting_points();
248
249 let results = if self.options.parallel {
250 starting_points
252 .into_par_iter()
253 .map(|x0| self.optimize_single(x0))
254 .collect::<Vec<_>>()
255 } else {
256 starting_points
258 .into_iter()
259 .map(|x0| self.optimize_single(x0))
260 .collect::<Vec<_>>()
261 };
262
263 let best_result = results
265 .into_iter()
266 .filter(|r| r.success)
267 .min_by(|a, b| {
268 a.fun
269 .partial_cmp(&b.fun)
270 .unwrap_or(std::cmp::Ordering::Equal)
271 })
272 .unwrap_or_else(|| OptimizeResult {
273 x: Array1::zeros(self.ndim),
274 fun: f64::INFINITY,
275 success: false,
276 message: "All optimization attempts failed".to_string(),
277 ..Default::default()
278 });
279
280 OptimizeResult {
281 x: best_result.x,
282 fun: best_result.fun,
283 nit: self.options.n_starts,
284 success: best_result.success,
285 message: format!(
286 "Multi-start optimization with {} starts",
287 self.options.n_starts
288 ),
289 ..Default::default()
290 }
291 }
292}
293
294#[allow(dead_code)]
296pub fn multi_start<F>(
297 func: F,
298 bounds: Bounds,
299 options: Option<MultiStartOptions>,
300) -> Result<OptimizeResult<f64>, OptimizeError>
301where
302 F: Fn(&ArrayView1<f64>) -> f64 + Clone + Send + Sync,
303{
304 let options = options.unwrap_or_default();
305 let mut solver = MultiStart::new(func, bounds, options);
306 Ok(solver.run())
307}