Skip to main content

scirs2_optimize/derivative_free/
pattern_search.rs

1//! Generalized Pattern Search (GPS) / Pattern Search optimization
2//!
3//! Pattern search methods explore the objective function by evaluating it at
4//! a set of trial points generated by a positive spanning set (the pattern).
5//! The mesh size is reduced when no improvement is found and increased on success.
6//!
//! This implements Generalized Pattern Search (GPS) with a fixed poll pattern
//! (coordinate axes, compass rose, or a minimal simplex basis) and optional box
//! constraints. Unlike full MADS, the poll directions do not become dense.
9//!
10//! # References
11//! - Torczon, V. (1997). "On the convergence of pattern search algorithms."
12//!   SIAM Journal on Optimization, 7(1), 1-25.
13//! - Audet, C. & Dennis, J.E. (2006). "Mesh adaptive direct search algorithms
14//!   for constrained optimization." SIAM Journal on Optimization, 17(1), 188-217.
15
16use super::{clip, DerivativeFreeOptimizer, DfOptResult};
17use crate::error::{OptimizeError, OptimizeResult};
18use scirs2_core::ndarray::Array1;
19
/// Set of poll directions used by the pattern search.
///
/// Each variant spans the search space positively, so a non-stationary point
/// always has at least one descent direction in the pattern.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PatternType {
    /// Coordinate directions `±e_i` (`2n` directions).
    Axes,
    /// Coordinate directions `±e_i` plus, when `n <= 8`, the normalized diagonals
    /// `±(e_i - e_j) / √2` for every `i < j` (`2n + n(n - 1)` directions).
    ///
    /// For `n > 8` the diagonals are omitted to avoid a quadratic number of
    /// evaluations per poll, and the pattern is identical to [`PatternType::Axes`].
    CompassRose,
    /// Minimal positive basis: `e_i` for each coordinate plus the normalized
    /// direction `-(e_1 + … + e_n) / √n` (`n + 1` directions).
    Simplex,
}
30
31/// Options for the Pattern Search algorithm
32#[derive(Debug, Clone)]
33pub struct PatternSearchOptions {
34    /// Lower bounds for each variable
35    pub lower: Option<Vec<f64>>,
36    /// Upper bounds for each variable
37    pub upper: Option<Vec<f64>>,
38    /// Initial mesh size
39    pub delta_init: f64,
40    /// Minimum mesh size (convergence threshold)
41    pub delta_min: f64,
42    /// Mesh size expansion factor (on success)
43    pub expand_factor: f64,
44    /// Mesh size contraction factor (on failure)
45    pub contract_factor: f64,
46    /// Maximum number of function evaluations
47    pub max_fev: usize,
48    /// Maximum number of iterations
49    pub max_iter: usize,
50    /// Absolute tolerance on function value improvement
51    pub f_tol: f64,
52    /// Pattern type
53    pub pattern: PatternType,
54    /// Use opportunistic search (accept first improving point)
55    pub opportunistic: bool,
56}
57
58impl Default for PatternSearchOptions {
59    fn default() -> Self {
60        PatternSearchOptions {
61            lower: None,
62            upper: None,
63            delta_init: 1.0,
64            delta_min: 1e-7,
65            expand_factor: 2.0,
66            contract_factor: 0.5,
67            max_fev: 50000,
68            max_iter: 10000,
69            f_tol: 1e-10,
70            pattern: PatternType::Axes,
71            opportunistic: true,
72        }
73    }
74}
75
76/// Pattern Search optimizer implementing GPS/MADS
77pub struct PatternSearchSolver {
78    pub options: PatternSearchOptions,
79}
80
81impl PatternSearchSolver {
82    /// Create with default options
83    pub fn new() -> Self {
84        PatternSearchSolver {
85            options: PatternSearchOptions::default(),
86        }
87    }
88
89    /// Create with custom options
90    pub fn with_options(options: PatternSearchOptions) -> Self {
91        PatternSearchSolver { options }
92    }
93
94    /// Get effective bounds
95    fn get_bounds(&self, n: usize) -> (Vec<f64>, Vec<f64>) {
96        let lo = match &self.options.lower {
97            Some(l) => l.clone(),
98            None => vec![f64::NEG_INFINITY; n],
99        };
100        let hi = match &self.options.upper {
101            Some(u) => u.clone(),
102            None => vec![f64::INFINITY; n],
103        };
104        (lo, hi)
105    }
106
107    /// Project onto box
108    fn project(&self, x: &[f64], lo: &[f64], hi: &[f64]) -> Vec<f64> {
109        x.iter()
110            .zip(lo.iter().zip(hi.iter()))
111            .map(|(&xi, (&l, &h))| clip(xi, l, h))
112            .collect()
113    }
114
115    /// Generate pattern directions based on pattern type
116    fn generate_pattern(&self, n: usize) -> Vec<Vec<f64>> {
117        match self.options.pattern {
118            PatternType::Axes => {
119                let mut dirs = Vec::with_capacity(2 * n);
120                for i in 0..n {
121                    let mut d = vec![0.0; n];
122                    d[i] = 1.0;
123                    dirs.push(d.clone());
124                    d[i] = -1.0;
125                    dirs.push(d);
126                }
127                dirs
128            }
129            PatternType::CompassRose => {
130                let mut dirs = Vec::new();
131                // Axis directions
132                for i in 0..n {
133                    let mut d = vec![0.0; n];
134                    d[i] = 1.0;
135                    dirs.push(d.clone());
136                    d[i] = -1.0;
137                    dirs.push(d);
138                }
139                // Diagonal directions ei - ej for i < j (limited to avoid combinatorial explosion)
140                if n <= 8 {
141                    for i in 0..n {
142                        for j in (i + 1)..n {
143                            let mut d = vec![0.0; n];
144                            d[i] = 1.0;
145                            d[j] = -1.0;
146                            let scale = 1.0 / (2.0_f64).sqrt();
147                            let ds: Vec<f64> = d.iter().map(|v| v * scale).collect();
148                            dirs.push(ds.clone());
149                            let ds2: Vec<f64> = ds.iter().map(|v| -v).collect();
150                            dirs.push(ds2);
151                        }
152                    }
153                }
154                dirs
155            }
156            PatternType::Simplex => {
157                let mut dirs = Vec::with_capacity(n + 1);
158                let neg_sum_scale = 1.0 / (n as f64).sqrt();
159                let neg_d = vec![-neg_sum_scale; n];
160                for i in 0..n {
161                    let mut d = vec![0.0; n];
162                    d[i] = 1.0;
163                    dirs.push(d);
164                }
165                dirs.push(neg_d);
166                dirs
167            }
168        }
169    }
170}
171
172impl Default for PatternSearchSolver {
173    fn default() -> Self {
174        PatternSearchSolver::new()
175    }
176}
177
178impl DerivativeFreeOptimizer for PatternSearchSolver {
179    fn minimize<F>(&self, func: F, x0: &[f64]) -> OptimizeResult<DfOptResult>
180    where
181        F: Fn(&[f64]) -> f64,
182    {
183        let n = x0.len();
184        if n == 0 {
185            return Err(OptimizeError::InvalidInput(
186                "x0 must be non-empty".to_string(),
187            ));
188        }
189
190        let (lo, hi) = self.get_bounds(n);
191        let mut x = self.project(x0, &lo, &hi);
192        let mut delta = self.options.delta_init;
193
194        let mut nfev = 0usize;
195        let mut nit = 0usize;
196        let mut fx = {
197            nfev += 1;
198            func(&x)
199        };
200
201        let dirs = self.generate_pattern(n);
202
203        loop {
204            if nit >= self.options.max_iter || nfev >= self.options.max_fev {
205                break;
206            }
207
208            if delta < self.options.delta_min {
209                return Ok(DfOptResult {
210                    x: Array1::from_vec(x),
211                    fun: fx,
212                    nfev,
213                    nit,
214                    success: true,
215                    message: "Converged: mesh size below tolerance".to_string(),
216                });
217            }
218
219            let mut improved = false;
220            let mut best_x = x.clone();
221            let mut best_f = fx;
222
223            'poll: for dir in &dirs {
224                if nfev >= self.options.max_fev {
225                    break 'poll;
226                }
227
228                // Trial point: x + delta * dir
229                let xtrial: Vec<f64> = x
230                    .iter()
231                    .zip(dir.iter())
232                    .map(|(&xi, &di)| xi + delta * di)
233                    .collect();
234                let xtrial = self.project(&xtrial, &lo, &hi);
235
236                nfev += 1;
237                let ftrial = func(&xtrial);
238
239                if ftrial < best_f - self.options.f_tol {
240                    best_f = ftrial;
241                    best_x = xtrial;
242                    improved = true;
243                    if self.options.opportunistic {
244                        break 'poll;
245                    }
246                }
247            }
248
249            if improved {
250                x = best_x;
251                fx = best_f;
252                // Expand mesh on success
253                delta = (delta * self.options.expand_factor).min(self.options.delta_init * 100.0);
254            } else {
255                // Contract mesh on failure
256                delta *= self.options.contract_factor;
257            }
258
259            nit += 1;
260        }
261
262        let success = delta < self.options.delta_min * 10.0;
263        Ok(DfOptResult {
264            x: Array1::from_vec(x),
265            fun: fx,
266            nfev,
267            nit,
268            success,
269            message: if success {
270                "Converged".to_string()
271            } else {
272                "Maximum iterations/evaluations reached".to_string()
273            },
274        })
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_pattern_search_quadratic() {
        let solver = PatternSearchSolver::new();
        let result = solver
            .minimize(
                |x: &[f64]| (x[0] - 2.0).powi(2) + (x[1] - 3.0).powi(2),
                &[0.0, 0.0],
            )
            .expect("optimization failed");
        assert_abs_diff_eq!(result.x[0], 2.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.x[1], 3.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.fun, 0.0, epsilon = 1e-4);
    }

    #[test]
    fn test_pattern_search_bounded() {
        let opts = PatternSearchOptions {
            lower: Some(vec![0.0, 0.0]),
            upper: Some(vec![5.0, 5.0]),
            delta_min: 1e-8,
            ..Default::default()
        };
        let solver = PatternSearchSolver::with_options(opts);
        // True min at (-2, -2), bounded min at (0, 0)
        let result = solver
            .minimize(
                |x: &[f64]| (x[0] + 2.0).powi(2) + (x[1] + 2.0).powi(2),
                &[1.0, 1.0],
            )
            .expect("optimization failed");
        assert_abs_diff_eq!(result.x[0], 0.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.x[1], 0.0, epsilon = 1e-3);
    }

    #[test]
    fn test_pattern_search_respects_bounds() {
        let opts = PatternSearchOptions {
            lower: Some(vec![-1.0, -1.0]),
            upper: Some(vec![1.0, 1.0]),
            ..Default::default()
        };
        let solver = PatternSearchSolver::with_options(opts);
        // Unconstrained min at (5, -5) lies outside the box; bounded min is the corner (1, -1).
        let result = solver
            .minimize(
                |x: &[f64]| (x[0] - 5.0).powi(2) + (x[1] + 5.0).powi(2),
                &[0.0, 0.0],
            )
            .expect("optimization failed");
        for &xi in result.x.iter() {
            assert!((-1.0..=1.0).contains(&xi), "iterate left the box: {}", xi);
        }
        assert_abs_diff_eq!(result.x[0], 1.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.x[1], -1.0, epsilon = 1e-3);
    }

    #[test]
    fn test_pattern_search_empty_input_rejected() {
        let solver = PatternSearchSolver::new();
        let result = solver.minimize(|_x: &[f64]| 0.0, &[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_generate_pattern_sizes() {
        let make = |pattern: PatternType| {
            PatternSearchSolver::with_options(PatternSearchOptions {
                pattern,
                ..Default::default()
            })
        };
        assert_eq!(make(PatternType::Axes).generate_pattern(3).len(), 6);
        // 2n axes + n(n-1) signed diagonals for n <= 8.
        assert_eq!(make(PatternType::CompassRose).generate_pattern(3).len(), 12);
        // Diagonals are dropped above n = 8.
        assert_eq!(make(PatternType::CompassRose).generate_pattern(9).len(), 18);
        assert_eq!(make(PatternType::Simplex).generate_pattern(3).len(), 4);
    }

    #[test]
    fn test_pattern_search_compass_rose() {
        let opts = PatternSearchOptions {
            pattern: PatternType::CompassRose,
            delta_min: 1e-7,
            max_fev: 100000,
            ..Default::default()
        };
        let solver = PatternSearchSolver::with_options(opts);
        let result = solver
            .minimize(
                |x: &[f64]| (x[0] - 1.0).powi(2) + (x[1] - 2.0).powi(2),
                &[0.0, 0.0],
            )
            .expect("optimization failed");
        assert_abs_diff_eq!(result.x[0], 1.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.x[1], 2.0, epsilon = 1e-3);
    }

    #[test]
    fn test_pattern_search_non_opportunistic() {
        let opts = PatternSearchOptions {
            opportunistic: false,
            delta_min: 1e-6,
            max_fev: 200000,
            ..Default::default()
        };
        let solver = PatternSearchSolver::with_options(opts);
        let result = solver
            .minimize(|x: &[f64]| x[0].powi(2) + x[1].powi(2), &[5.0, 5.0])
            .expect("optimization failed");
        assert_abs_diff_eq!(result.fun, 0.0, epsilon = 1e-4);
    }

    #[test]
    fn test_pattern_search_1d() {
        let solver = PatternSearchSolver::new();
        let result = solver
            .minimize(|x: &[f64]| (x[0] - 7.0).powi(2), &[0.0])
            .expect("optimization failed");
        assert_abs_diff_eq!(result.x[0], 7.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.fun, 0.0, epsilon = 1e-4);
    }

    #[test]
    fn test_pattern_search_simplex_pattern() {
        let opts = PatternSearchOptions {
            pattern: PatternType::Simplex,
            delta_min: 1e-7,
            ..Default::default()
        };
        let solver = PatternSearchSolver::with_options(opts);
        let result = solver
            .minimize(
                |x: &[f64]| (x[0] - 1.5).powi(2) + (x[1] + 0.5).powi(2),
                &[0.0, 0.0],
            )
            .expect("optimization failed");
        assert_abs_diff_eq!(result.x[0], 1.5, epsilon = 1e-3);
        assert_abs_diff_eq!(result.x[1], -0.5, epsilon = 1e-3);
    }
}