scirs2_optimize/derivative_free/
pattern_search.rs1use super::{clip, DerivativeFreeOptimizer, DfOptResult};
17use crate::error::{OptimizeError, OptimizeResult};
18use scirs2_core::ndarray::Array1;
19
/// The set of poll directions generated around the current iterate.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so the pattern kind can
/// be used as a key in hashed collections (e.g. caching per-pattern results).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PatternType {
    /// The 2n signed coordinate directions `±e_i`.
    Axes,
    /// The coordinate directions plus, in low dimension, normalized pairwise diagonals.
    CompassRose,
    /// The n + 1 directions `e_1, …, e_n` and `-(1/√n)·(1, …, 1)`.
    Simplex,
}
30
31#[derive(Debug, Clone)]
33pub struct PatternSearchOptions {
34 pub lower: Option<Vec<f64>>,
36 pub upper: Option<Vec<f64>>,
38 pub delta_init: f64,
40 pub delta_min: f64,
42 pub expand_factor: f64,
44 pub contract_factor: f64,
46 pub max_fev: usize,
48 pub max_iter: usize,
50 pub f_tol: f64,
52 pub pattern: PatternType,
54 pub opportunistic: bool,
56}
57
58impl Default for PatternSearchOptions {
59 fn default() -> Self {
60 PatternSearchOptions {
61 lower: None,
62 upper: None,
63 delta_init: 1.0,
64 delta_min: 1e-7,
65 expand_factor: 2.0,
66 contract_factor: 0.5,
67 max_fev: 50000,
68 max_iter: 10000,
69 f_tol: 1e-10,
70 pattern: PatternType::Axes,
71 opportunistic: true,
72 }
73 }
74}
75
76pub struct PatternSearchSolver {
78 pub options: PatternSearchOptions,
79}
80
81impl PatternSearchSolver {
82 pub fn new() -> Self {
84 PatternSearchSolver {
85 options: PatternSearchOptions::default(),
86 }
87 }
88
89 pub fn with_options(options: PatternSearchOptions) -> Self {
91 PatternSearchSolver { options }
92 }
93
94 fn get_bounds(&self, n: usize) -> (Vec<f64>, Vec<f64>) {
96 let lo = match &self.options.lower {
97 Some(l) => l.clone(),
98 None => vec![f64::NEG_INFINITY; n],
99 };
100 let hi = match &self.options.upper {
101 Some(u) => u.clone(),
102 None => vec![f64::INFINITY; n],
103 };
104 (lo, hi)
105 }
106
107 fn project(&self, x: &[f64], lo: &[f64], hi: &[f64]) -> Vec<f64> {
109 x.iter()
110 .zip(lo.iter().zip(hi.iter()))
111 .map(|(&xi, (&l, &h))| clip(xi, l, h))
112 .collect()
113 }
114
115 fn generate_pattern(&self, n: usize) -> Vec<Vec<f64>> {
117 match self.options.pattern {
118 PatternType::Axes => {
119 let mut dirs = Vec::with_capacity(2 * n);
120 for i in 0..n {
121 let mut d = vec![0.0; n];
122 d[i] = 1.0;
123 dirs.push(d.clone());
124 d[i] = -1.0;
125 dirs.push(d);
126 }
127 dirs
128 }
129 PatternType::CompassRose => {
130 let mut dirs = Vec::new();
131 for i in 0..n {
133 let mut d = vec![0.0; n];
134 d[i] = 1.0;
135 dirs.push(d.clone());
136 d[i] = -1.0;
137 dirs.push(d);
138 }
139 if n <= 8 {
141 for i in 0..n {
142 for j in (i + 1)..n {
143 let mut d = vec![0.0; n];
144 d[i] = 1.0;
145 d[j] = -1.0;
146 let scale = 1.0 / (2.0_f64).sqrt();
147 let ds: Vec<f64> = d.iter().map(|v| v * scale).collect();
148 dirs.push(ds.clone());
149 let ds2: Vec<f64> = ds.iter().map(|v| -v).collect();
150 dirs.push(ds2);
151 }
152 }
153 }
154 dirs
155 }
156 PatternType::Simplex => {
157 let mut dirs = Vec::with_capacity(n + 1);
158 let neg_sum_scale = 1.0 / (n as f64).sqrt();
159 let neg_d = vec![-neg_sum_scale; n];
160 for i in 0..n {
161 let mut d = vec![0.0; n];
162 d[i] = 1.0;
163 dirs.push(d);
164 }
165 dirs.push(neg_d);
166 dirs
167 }
168 }
169 }
170}
171
172impl Default for PatternSearchSolver {
173 fn default() -> Self {
174 PatternSearchSolver::new()
175 }
176}
177
178impl DerivativeFreeOptimizer for PatternSearchSolver {
179 fn minimize<F>(&self, func: F, x0: &[f64]) -> OptimizeResult<DfOptResult>
180 where
181 F: Fn(&[f64]) -> f64,
182 {
183 let n = x0.len();
184 if n == 0 {
185 return Err(OptimizeError::InvalidInput(
186 "x0 must be non-empty".to_string(),
187 ));
188 }
189
190 let (lo, hi) = self.get_bounds(n);
191 let mut x = self.project(x0, &lo, &hi);
192 let mut delta = self.options.delta_init;
193
194 let mut nfev = 0usize;
195 let mut nit = 0usize;
196 let mut fx = {
197 nfev += 1;
198 func(&x)
199 };
200
201 let dirs = self.generate_pattern(n);
202
203 loop {
204 if nit >= self.options.max_iter || nfev >= self.options.max_fev {
205 break;
206 }
207
208 if delta < self.options.delta_min {
209 return Ok(DfOptResult {
210 x: Array1::from_vec(x),
211 fun: fx,
212 nfev,
213 nit,
214 success: true,
215 message: "Converged: mesh size below tolerance".to_string(),
216 });
217 }
218
219 let mut improved = false;
220 let mut best_x = x.clone();
221 let mut best_f = fx;
222
223 'poll: for dir in &dirs {
224 if nfev >= self.options.max_fev {
225 break 'poll;
226 }
227
228 let xtrial: Vec<f64> = x
230 .iter()
231 .zip(dir.iter())
232 .map(|(&xi, &di)| xi + delta * di)
233 .collect();
234 let xtrial = self.project(&xtrial, &lo, &hi);
235
236 nfev += 1;
237 let ftrial = func(&xtrial);
238
239 if ftrial < best_f - self.options.f_tol {
240 best_f = ftrial;
241 best_x = xtrial;
242 improved = true;
243 if self.options.opportunistic {
244 break 'poll;
245 }
246 }
247 }
248
249 if improved {
250 x = best_x;
251 fx = best_f;
252 delta = (delta * self.options.expand_factor).min(self.options.delta_init * 100.0);
254 } else {
255 delta *= self.options.contract_factor;
257 }
258
259 nit += 1;
260 }
261
262 let success = delta < self.options.delta_min * 10.0;
263 Ok(DfOptResult {
264 x: Array1::from_vec(x),
265 fun: fx,
266 nfev,
267 nit,
268 success,
269 message: if success {
270 "Converged".to_string()
271 } else {
272 "Maximum iterations/evaluations reached".to_string()
273 },
274 })
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use std::cell::Cell;

    #[test]
    fn test_pattern_search_quadratic() {
        let solver = PatternSearchSolver::new();
        let result = solver
            .minimize(
                |x: &[f64]| (x[0] - 2.0).powi(2) + (x[1] - 3.0).powi(2),
                &[0.0, 0.0],
            )
            .expect("optimization failed");
        assert_abs_diff_eq!(result.x[0], 2.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.x[1], 3.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.fun, 0.0, epsilon = 1e-4);
    }

    #[test]
    fn test_pattern_search_bounded() {
        let opts = PatternSearchOptions {
            lower: Some(vec![0.0, 0.0]),
            upper: Some(vec![5.0, 5.0]),
            delta_min: 1e-8,
            ..Default::default()
        };
        let solver = PatternSearchSolver::with_options(opts);
        let result = solver
            .minimize(
                |x: &[f64]| (x[0] + 2.0).powi(2) + (x[1] + 2.0).powi(2),
                &[1.0, 1.0],
            )
            .expect("optimization failed");
        assert_abs_diff_eq!(result.x[0], 0.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.x[1], 0.0, epsilon = 1e-3);
    }

    #[test]
    fn test_pattern_search_compass_rose() {
        let opts = PatternSearchOptions {
            pattern: PatternType::CompassRose,
            delta_min: 1e-7,
            max_fev: 100000,
            ..Default::default()
        };
        let solver = PatternSearchSolver::with_options(opts);
        let result = solver
            .minimize(
                |x: &[f64]| (x[0] - 1.0).powi(2) + (x[1] - 2.0).powi(2),
                &[0.0, 0.0],
            )
            .expect("optimization failed");
        assert_abs_diff_eq!(result.x[0], 1.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.x[1], 2.0, epsilon = 1e-3);
    }

    #[test]
    fn test_pattern_search_non_opportunistic() {
        let opts = PatternSearchOptions {
            opportunistic: false,
            delta_min: 1e-6,
            max_fev: 200000,
            ..Default::default()
        };
        let solver = PatternSearchSolver::with_options(opts);
        let result = solver
            .minimize(|x: &[f64]| x[0].powi(2) + x[1].powi(2), &[5.0, 5.0])
            .expect("optimization failed");
        assert_abs_diff_eq!(result.fun, 0.0, epsilon = 1e-4);
    }

    #[test]
    fn test_pattern_search_1d() {
        let solver = PatternSearchSolver::new();
        let result = solver
            .minimize(|x: &[f64]| (x[0] - 7.0).powi(2), &[0.0])
            .expect("optimization failed");
        assert_abs_diff_eq!(result.x[0], 7.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.fun, 0.0, epsilon = 1e-4);
    }

    #[test]
    fn test_pattern_search_simplex_pattern() {
        let opts = PatternSearchOptions {
            pattern: PatternType::Simplex,
            delta_min: 1e-7,
            ..Default::default()
        };
        let solver = PatternSearchSolver::with_options(opts);
        let result = solver
            .minimize(
                |x: &[f64]| (x[0] - 1.5).powi(2) + (x[1] + 0.5).powi(2),
                &[0.0, 0.0],
            )
            .expect("optimization failed");
        assert_abs_diff_eq!(result.x[0], 1.5, epsilon = 1e-3);
        assert_abs_diff_eq!(result.x[1], -0.5, epsilon = 1e-3);
    }

    /// An empty starting point is rejected instead of being optimized.
    #[test]
    fn test_pattern_search_empty_x0_is_error() {
        let solver = PatternSearchSolver::new();
        assert!(solver.minimize(|_x: &[f64]| 0.0, &[]).is_err());
    }

    /// The evaluation budget is honoured, and `nfev` matches the real call count.
    #[test]
    fn test_pattern_search_respects_max_fev() {
        let calls = Cell::new(0usize);
        let opts = PatternSearchOptions {
            max_fev: 20,
            ..Default::default()
        };
        let solver = PatternSearchSolver::with_options(opts);
        let result = solver
            .minimize(
                |x: &[f64]| {
                    calls.set(calls.get() + 1);
                    (x[0] - 100.0).powi(2) + (x[1] + 100.0).powi(2)
                },
                &[0.0, 0.0],
            )
            .expect("optimization failed");
        assert!(result.nfev <= 20);
        assert_eq!(result.nfev, calls.get());
        assert!(!result.success);
    }
}