scirs2_optimize/global/
multi_start.rs1use crate::error::OptimizeError;
7use crate::unconstrained::{
8 minimize, Bounds as UnconstrainedBounds, Method as UnconstrainedMethod, OptimizeResult, Options,
9};
10use ndarray::{Array1, ArrayView1};
11use rand::prelude::*;
12use rand::rngs::StdRng;
13use rayon::prelude::*;
14
/// Configuration for a multi-start optimization run.
#[derive(Debug, Clone)]
pub struct MultiStartOptions {
    /// Number of starting points to generate, i.e. the number of local searches.
    pub n_starts: usize,
    /// Local optimization method handed to `minimize` for every starting point.
    pub local_method: UnconstrainedMethod,
    /// When `true`, the local searches are run through a rayon parallel iterator;
    /// otherwise they run sequentially.
    pub parallel: bool,
    /// Seed for the internal `StdRng`; when `None` a random seed is drawn.
    pub seed: Option<u64>,
    /// Strategy that selects which starting-point generator is used.
    pub strategy: StartingPointStrategy,
}
29
30impl Default for MultiStartOptions {
31 fn default() -> Self {
32 Self {
33 n_starts: 10,
34 local_method: UnconstrainedMethod::BFGS,
35 parallel: true,
36 seed: None,
37 strategy: StartingPointStrategy::Random,
38 }
39 }
40}
41
/// Strategy used to choose the starting points of the local searches.
///
/// The enum is a plain fieldless tag, so it is `Copy` and comparable; this lets
/// callers pass it by value and test which strategy is configured.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StartingPointStrategy {
    /// Independent uniform samples inside the bounds.
    Random,
    /// Latin hypercube sampling: one sample per stratum in every dimension.
    LatinHypercube,
    /// Requests a Halton sequence; see `MultiStart::generate_halton_points`
    /// for what is actually generated.
    Halton,
    /// Requests a Sobol sequence; see `MultiStart::generate_sobol_points`
    /// for what is actually generated.
    Sobol,
    /// Points laid out on a regular grid spanning the bounds.
    Grid,
}
56
/// Box constraints of the search space: one `(lower, upper)` pair per dimension.
pub type Bounds = Vec<(f64, f64)>;
59
/// Multi-start solver: runs a local optimizer from several starting points
/// and keeps the best successful result.
pub struct MultiStart<F>
where
    F: Fn(&ArrayView1<f64>) -> f64 + Clone + Send + Sync,
{
    /// Objective function to minimize.
    func: F,
    /// `(lower, upper)` bound for each dimension.
    bounds: Bounds,
    /// Run configuration (number of starts, local method, strategy, ...).
    options: MultiStartOptions,
    /// Problem dimensionality; set to `bounds.len()` by `new`.
    ndim: usize,
    /// Random number generator used by the starting-point generators.
    rng: StdRng,
}
71
72impl<F> MultiStart<F>
73where
74 F: Fn(&ArrayView1<f64>) -> f64 + Clone + Send + Sync,
75{
76 pub fn new(func: F, bounds: Bounds, options: MultiStartOptions) -> Self {
78 let ndim = bounds.len();
79 let seed = options.seed.unwrap_or_else(rand::random);
80 let rng = StdRng::seed_from_u64(seed);
81
82 Self {
83 func,
84 bounds,
85 options,
86 ndim,
87 rng,
88 }
89 }
90
91 fn generate_starting_points(&mut self) -> Vec<Array1<f64>> {
93 match self.options.strategy {
94 StartingPointStrategy::Random => self.generate_random_points(),
95 StartingPointStrategy::LatinHypercube => self.generate_latin_hypercube_points(),
96 StartingPointStrategy::Halton => self.generate_halton_points(),
97 StartingPointStrategy::Sobol => self.generate_sobol_points(),
98 StartingPointStrategy::Grid => self.generate_grid_points(),
99 }
100 }
101
102 fn generate_random_points(&mut self) -> Vec<Array1<f64>> {
104 let mut points = Vec::with_capacity(self.options.n_starts);
105
106 for _ in 0..self.options.n_starts {
107 let mut point = Array1::zeros(self.ndim);
108 for j in 0..self.ndim {
109 let (lb, ub) = self.bounds[j];
110 point[j] = self.rng.random_range(lb..ub);
111 }
112 points.push(point);
113 }
114
115 points
116 }
117
118 fn generate_latin_hypercube_points(&mut self) -> Vec<Array1<f64>> {
120 let mut points = Vec::with_capacity(self.options.n_starts);
121 let n = self.options.n_starts;
122
123 for i in 0..n {
125 let mut point = Array1::zeros(self.ndim);
126
127 for j in 0..self.ndim {
128 let (lb, ub) = self.bounds[j];
129 let segment_size = (ub - lb) / n as f64;
130
131 let offset = self.rng.random::<f64>();
133 point[j] = lb + (i as f64 + offset) * segment_size;
134 }
135
136 points.push(point);
137 }
138
139 for j in 0..self.ndim {
141 let mut indices: Vec<usize> = (0..n).collect();
142 indices.shuffle(&mut self.rng);
143
144 for (i, &idx) in indices.iter().enumerate() {
145 let temp = points[i][j];
146 points[i][j] = points[idx][j];
147 points[idx][j] = temp;
148 }
149 }
150
151 points
152 }
153
154 fn generate_halton_points(&mut self) -> Vec<Array1<f64>> {
156 self.generate_random_points()
158 }
159
160 fn generate_sobol_points(&mut self) -> Vec<Array1<f64>> {
162 self.generate_random_points()
164 }
165
166 fn generate_grid_points(&mut self) -> Vec<Array1<f64>> {
168 let points_per_dim = (self.options.n_starts as f64)
169 .powf(1.0 / self.ndim as f64)
170 .ceil() as usize;
171 let mut points = Vec::new();
172
173 let mut current = vec![0usize; self.ndim];
175 loop {
176 let mut point = Array1::zeros(self.ndim);
177
178 for j in 0..self.ndim {
179 let (lb, ub) = self.bounds[j];
180 let step = (ub - lb) / (points_per_dim - 1).max(1) as f64;
181 point[j] = lb + current[j] as f64 * step;
182 }
183
184 points.push(point);
185
186 let mut carry = true;
188 for j in current.iter_mut() {
189 if carry {
190 *j += 1;
191 if *j >= points_per_dim {
192 *j = 0;
193 } else {
194 carry = false;
195 }
196 }
197 }
198
199 if carry || points.len() >= self.options.n_starts {
200 break;
201 }
202 }
203
204 points.truncate(self.options.n_starts);
205 points
206 }
207
208 fn optimize_single(&self, x0: Array1<f64>) -> OptimizeResult<f64> {
210 let bounds = Some(
211 UnconstrainedBounds::from_vecs(
212 self.bounds.iter().map(|&(lb, _)| Some(lb)).collect(),
213 self.bounds.iter().map(|&(_, ub)| Some(ub)).collect(),
214 )
215 .unwrap(),
216 );
217
218 let options = Options {
219 bounds,
220 ..Default::default()
221 };
222
223 let func = self.func.clone();
224
225 minimize(
226 move |x: &ArrayView1<f64>| func(x),
227 &x0.to_vec(),
228 self.options.local_method,
229 Some(options),
230 )
231 .unwrap_or_else(|_| {
232 OptimizeResult {
234 x: x0,
235 fun: f64::INFINITY,
236 success: false,
237 ..Default::default()
238 }
239 })
240 }
241
242 pub fn run(&mut self) -> OptimizeResult<f64> {
244 let starting_points = self.generate_starting_points();
245
246 let results = if self.options.parallel {
247 starting_points
249 .into_par_iter()
250 .map(|x0| self.optimize_single(x0))
251 .collect::<Vec<_>>()
252 } else {
253 starting_points
255 .into_iter()
256 .map(|x0| self.optimize_single(x0))
257 .collect::<Vec<_>>()
258 };
259
260 let best_result = results
262 .into_iter()
263 .filter(|r| r.success)
264 .min_by(|a, b| {
265 a.fun
266 .partial_cmp(&b.fun)
267 .unwrap_or(std::cmp::Ordering::Equal)
268 })
269 .unwrap_or_else(|| OptimizeResult {
270 x: Array1::zeros(self.ndim),
271 fun: f64::INFINITY,
272 success: false,
273 message: "All optimization attempts failed".to_string(),
274 ..Default::default()
275 });
276
277 OptimizeResult {
278 x: best_result.x,
279 fun: best_result.fun,
280 nit: self.options.n_starts,
281 iterations: self.options.n_starts,
282 success: best_result.success,
283 message: format!(
284 "Multi-start optimization with {} starts",
285 self.options.n_starts
286 ),
287 ..Default::default()
288 }
289 }
290}
291
292pub fn multi_start<F>(
294 func: F,
295 bounds: Bounds,
296 options: Option<MultiStartOptions>,
297) -> Result<OptimizeResult<f64>, OptimizeError>
298where
299 F: Fn(&ArrayView1<f64>) -> f64 + Clone + Send + Sync,
300{
301 let options = options.unwrap_or_default();
302 let mut solver = MultiStart::new(func, bounds, options);
303 Ok(solver.run())
304}