// scirs2_optimize/scalar.rs

1//! Scalar optimization algorithms
2//!
3//! This module provides algorithms for minimizing univariate scalar functions.
4//! It is similar to `scipy.optimize.minimize_scalar`.
5
6use crate::error::OptimizeError;
7use std::fmt;
8
/// Methods for scalar optimization
///
/// Selects which algorithm `minimize_scalar` dispatches to.
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare and key on the
/// chosen method (backward-compatible addition).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Method {
    /// Brent method - combines parabolic interpolation with golden section search
    Brent,
    /// Bounded Brent method - Brent within specified bounds
    Bounded,
    /// Golden section search
    Golden,
}
19
/// Options for scalar optimization
///
/// Passed to `minimize_scalar`; `None` falls back to `Options::default()`.
#[derive(Debug, Clone)]
pub struct Options {
    /// Maximum number of iterations before the solver gives up
    pub max_iter: usize,
    /// Absolute tolerance on the abscissa used in the convergence test
    pub xatol: f64,
    /// Relative tolerance
    // NOTE(review): not currently consulted by any solver in this module.
    pub xrtol: f64,
    /// Bracket for the search (optional)
    // Three abscissae (a, b, c); used by the Brent and Golden solvers,
    // ignored by the bounded method.
    pub bracket: Option<(f64, f64, f64)>,
    /// Display convergence messages
    // NOTE(review): not currently consulted by any solver in this module.
    pub disp: bool,
}
34
35impl Default for Options {
36    fn default() -> Self {
37        Options {
38            max_iter: 500,
39            xatol: 1e-5,
40            xrtol: 1.4901161193847656e-8,
41            bracket: None,
42            disp: false,
43        }
44    }
45}
46
47impl fmt::Display for Method {
48    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
49        match self {
50            Method::Brent => write!(f, "Brent"),
51            Method::Bounded => write!(f, "Bounded"),
52            Method::Golden => write!(f, "Golden"),
53        }
54    }
55}
56
/// Result type for scalar optimization
///
/// Returned by `minimize_scalar` and the individual solvers.
#[derive(Debug, Clone)]
pub struct ScalarOptimizeResult {
    /// Found minimum
    // Abscissa of the located minimum.
    pub x: f64,
    /// Function value at the minimum
    // Equal to the objective evaluated at `x`.
    pub fun: f64,
    /// Number of iterations
    pub nit: usize,
    /// Number of function evaluations
    pub function_evals: usize,
    /// Whether the optimization succeeded
    // `true` when the solver converged within `max_iter` iterations.
    pub success: bool,
    /// Message describing the result
    pub message: String,
}
73
74/// Main minimize scalar function
75///
76/// Minimization of scalar function of one variable.
77///
78/// # Arguments
79///
80/// * `fun` - The objective function to be minimized
81/// * `bounds` - Optional bounds as (lower, upper) tuple
82/// * `method` - Optimization method to use
83/// * `options` - Optional algorithm options
84///
85/// # Returns
86///
87/// Returns a `ScalarOptimizeResult` containing the optimization result.
88///
89/// # Examples
90///
91/// ```no_run
92/// use scirs2_optimize::scalar::{minimize_scalar, Method};
93///
94/// fn f(x: f64) -> f64 {
95///     (x - 2.0) * x * (x + 2.0).powi(2)
96/// }
97///
98/// // Using the Brent method
99/// let result = minimize_scalar(f, None, Method::Brent, None)?;
100/// println!("Minimum at x = {}", result.x);
101/// println!("Function value = {}", result.fun);
102///
103/// // Using the bounded method
104/// let bounds = Some((-3.0, -1.0));
105/// let result = minimize_scalar(f, bounds, Method::Bounded, None)?;
106/// println!("Bounded minimum at x = {}", result.x);
107/// # Ok::<(), Box<dyn std::error::Error>>(())
108/// ```
109#[allow(dead_code)]
110pub fn minimize_scalar<F>(
111    fun: F,
112    bounds: Option<(f64, f64)>,
113    method: Method,
114    options: Option<Options>,
115) -> Result<ScalarOptimizeResult, OptimizeError>
116where
117    F: Fn(f64) -> f64,
118{
119    let opts = options.unwrap_or_default();
120
121    match method {
122        Method::Brent => minimize_scalar_brent(fun, opts),
123        Method::Bounded => {
124            if let Some((a, b)) = bounds {
125                minimize_scalar_bounded(fun, a, b, opts)
126            } else {
127                Err(OptimizeError::ValueError(
128                    "Bounds are required for bounded method".to_string(),
129                ))
130            }
131        }
132        Method::Golden => minimize_scalar_golden(fun, opts),
133    }
134}
135
136/// Brent's method for scalar minimization
137#[allow(dead_code)]
138fn minimize_scalar_brent<F>(fun: F, options: Options) -> Result<ScalarOptimizeResult, OptimizeError>
139where
140    F: Fn(f64) -> f64,
141{
142    // Implementation of Brent's method
143    // This combines parabolic interpolation with golden section search
144
145    const GOLDEN: f64 = 0.3819660112501051; // (3 - sqrt(5)) / 2
146    const SQRT_EPS: f64 = 1.4901161193847656e-8;
147
148    // Get initial bracket or use default
149    let (a, b, c) = if let Some(bracket) = options.bracket {
150        bracket
151    } else {
152        // Use simple bracketing strategy
153        let x0 = 0.0;
154        let x1 = 1.0;
155        bracket_minimum(&fun, x0, x1)?
156    };
157
158    let tol = 3.0 * SQRT_EPS;
159    // a, b, c form the bracket, reorder if needed
160    let (bracket_a, bracket_b) = if b < c { (b, c) } else { (c, b) };
161    let (mut a, mut b) = if a < bracket_a {
162        (a, bracket_a)
163    } else {
164        (bracket_a, a)
165    };
166
167    // Initialize
168    let mut v = a + GOLDEN * (b - a);
169    let mut w = v;
170    let mut x = v;
171    let mut fx = fun(x);
172    let mut fv = fx;
173    let mut fw = fx;
174
175    let mut d: f64 = 0.0;
176    let mut e: f64 = 0.0;
177    let mut iter = 0;
178    let mut feval = 1;
179
180    while iter < options.max_iter {
181        let xm = 0.5 * (a + b);
182        let tol1 = tol * x.abs() + options.xatol;
183        let tol2 = 2.0 * tol1;
184
185        // Check for convergence
186        if (x - xm).abs() <= tol2 - 0.5 * (b - a) {
187            return Ok(ScalarOptimizeResult {
188                x,
189                fun: fx,
190                nit: iter,
191                function_evals: feval,
192                success: true,
193                message: "Optimization terminated successfully.".to_string(),
194            });
195        }
196
197        // Fit parabola
198        if e.abs() > tol1 {
199            let r = (x - w) * (fx - fv);
200            let q_temp = (x - v) * (fx - fw);
201            let p_temp = (x - v) * q_temp - (x - w) * r;
202            let mut q_val = 2.0 * (q_temp - r);
203
204            let p_val = if q_val > 0.0 {
205                q_val = -q_val;
206                -p_temp
207            } else {
208                p_temp
209            };
210
211            let etemp = e;
212            e = d;
213
214            // Check if parabolic interpolation is acceptable
215            if p_val.abs() < (0.5 * q_val * etemp).abs()
216                && p_val > q_val * (a - x)
217                && p_val < q_val * (b - x)
218            {
219                d = p_val / q_val;
220                let u = x + d;
221
222                // f(x + d) must not be too close to a or b
223                if (u - a) < tol2 || (b - u) < tol2 {
224                    d = if xm > x { tol1 } else { -tol1 };
225                }
226            } else {
227                // Golden section step
228                e = if x >= xm { a - x } else { b - x };
229                d = GOLDEN * e;
230            }
231        } else {
232            // Golden section step
233            e = if x >= xm { a - x } else { b - x };
234            d = GOLDEN * e;
235        }
236
237        // Evaluate new point
238        let u = if d.abs() >= tol1 {
239            x + d
240        } else {
241            x + if d > 0.0 { tol1 } else { -tol1 }
242        };
243
244        let fu = fun(u);
245        feval += 1;
246
247        // Update bracket
248        if fu <= fx {
249            if u >= x {
250                a = x;
251            } else {
252                b = x;
253            }
254
255            v = w;
256            fv = fw;
257            w = x;
258            fw = fx;
259            x = u;
260            fx = fu;
261        } else {
262            if u < x {
263                a = u;
264            } else {
265                b = u;
266            }
267
268            if fu <= fw || w == x {
269                v = w;
270                fv = fw;
271                w = u;
272                fw = fu;
273            } else if fu <= fv || v == x || v == w {
274                v = u;
275                fv = fu;
276            }
277        }
278
279        iter += 1;
280    }
281
282    Err(OptimizeError::ConvergenceError(
283        "Maximum number of iterations reached".to_string(),
284    ))
285}
286
/// Bounded Brent method for scalar minimization
///
/// Minimizes `fun` over the closed interval `[xmin, xmax]` using Brent's
/// algorithm (parabolic interpolation with golden-section fallback),
/// clamping every trial point to the bounds.
///
/// # Errors
///
/// Returns `OptimizeError::ValueError` when `xmin >= xmax`, and
/// `OptimizeError::ConvergenceError` when `options.max_iter` iterations
/// are exhausted before convergence.
#[allow(dead_code)]
fn minimize_scalar_bounded<F>(
    fun: F,
    xmin: f64,
    xmax: f64,
    options: Options,
) -> Result<ScalarOptimizeResult, OptimizeError>
where
    F: Fn(f64) -> f64,
{
    if xmin >= xmax {
        return Err(OptimizeError::ValueError(
            "Lower bound must be less than upper bound".to_string(),
        ));
    }

    // Bounded version of Brent's method
    // Similar to regular Brent but ensures x stays within [xmin, xmax]

    const GOLDEN: f64 = 0.3819660112501051; // (3 - sqrt(5)) / 2
    const SQRT_EPS: f64 = 1.4901161193847656e-8; // sqrt(f64::EPSILON)

    let tol = 3.0 * SQRT_EPS;
    let (mut a, mut b) = (xmin, xmax);

    // Initial points: x is the current best, w the second best, v the
    // previous value of w; all start at the golden-section point.
    let mut v = a + GOLDEN * (b - a);
    let mut w = v;
    let mut x = v;
    let mut fx = fun(x);
    let mut fv = fx;
    let mut fw = fx;

    let mut d: f64 = 0.0; // most recent step size
    let mut e: f64 = 0.0; // step size from the iteration before last
    let mut iter = 0;
    let mut feval = 1;

    while iter < options.max_iter {
        let xm = 0.5 * (a + b);
        // Tolerance scaled by |x| plus the user's absolute tolerance.
        let tol1 = tol * x.abs() + options.xatol;
        let tol2 = 2.0 * tol1;

        // Check for convergence
        if (x - xm).abs() <= tol2 - 0.5 * (b - a) {
            return Ok(ScalarOptimizeResult {
                x,
                fun: fx,
                nit: iter,
                function_evals: feval,
                success: true,
                message: "Optimization terminated successfully.".to_string(),
            });
        }

        // Parabolic interpolation through (v, fv), (w, fw), (x, fx).
        if e.abs() > tol1 {
            let r = (x - w) * (fx - fv);
            let q_temp = (x - v) * (fx - fw);
            let p_temp = (x - v) * q_temp - (x - w) * r;
            let mut q_val = 2.0 * (q_temp - r);

            // Normalize sign so the proposed step is p_val / q_val.
            let p_val = if q_val > 0.0 {
                q_val = -q_val;
                -p_temp
            } else {
                p_temp
            };

            let etemp = e;
            e = d;

            // Accept the parabolic step only if it lies inside (a, b) and
            // is smaller than half the step before last.
            if p_val.abs() < (0.5 * q_val * etemp).abs()
                && p_val > q_val * (a - x)
                && p_val < q_val * (b - x)
            {
                d = p_val / q_val;
                let u = x + d;

                // The trial point must not be too close to either endpoint.
                if (u - a) < tol2 || (b - u) < tol2 {
                    d = if xm > x { tol1 } else { -tol1 };
                }
            } else {
                // Fall back to a golden-section step.
                e = if x >= xm { a - x } else { b - x };
                d = GOLDEN * e;
            }
        } else {
            // Steps have shrunk below tolerance: golden-section step.
            e = if x >= xm { a - x } else { b - x };
            d = GOLDEN * e;
        }

        // Make sure we stay within bounds
        // (enforce a minimum step of tol1, then clamp to [xmin, xmax]).
        let u = (x + if d.abs() >= tol1 {
            d
        } else if d > 0.0 {
            tol1
        } else {
            -tol1
        })
        .max(xmin)
        .min(xmax);

        let fu = fun(u);
        feval += 1;

        // Update variables
        if fu <= fx {
            // u is the new best point: shrink the interval toward it.
            if u >= x {
                a = x;
            } else {
                b = x;
            }

            v = w;
            fv = fw;
            w = x;
            fw = fx;
            x = u;
            fx = fu;
        } else {
            // x remains best; u still tightens the interval.
            if u < x {
                a = u;
            } else {
                b = u;
            }

            if fu <= fw || w == x {
                v = w;
                fv = fw;
                w = u;
                fw = fu;
            } else if fu <= fv || v == x || v == w {
                v = u;
                fv = fu;
            }
        }

        iter += 1;
    }

    Err(OptimizeError::ConvergenceError(
        "Maximum number of iterations reached".to_string(),
    ))
}
432
433/// Golden section search for scalar minimization
434#[allow(dead_code)]
435fn minimize_scalar_golden<F>(
436    fun: F,
437    options: Options,
438) -> Result<ScalarOptimizeResult, OptimizeError>
439where
440    F: Fn(f64) -> f64,
441{
442    const GOLDEN: f64 = 0.6180339887498949; // (sqrt(5) - 1) / 2
443
444    // Get initial bracket or use default
445    let (a, b, c) = if let Some(bracket) = options.bracket {
446        bracket
447    } else {
448        let x0 = 0.0;
449        let x1 = 1.0;
450        bracket_minimum(&fun, x0, x1)?
451    };
452
453    // a, b, c form the bracket, reorder if needed
454    let (bracket_a, bracket_b) = if b < c { (b, c) } else { (c, b) };
455    let (mut a, mut b) = if a < bracket_a {
456        (a, bracket_a)
457    } else {
458        (bracket_a, a)
459    };
460
461    // Initialize points
462    let mut x1 = a + (1.0 - GOLDEN) * (b - a);
463    let mut x2 = a + GOLDEN * (b - a);
464    let mut f1 = fun(x1);
465    let mut f2 = fun(x2);
466
467    let mut iter = 0;
468    let mut feval = 2;
469
470    while iter < options.max_iter {
471        if (b - a).abs() < options.xatol {
472            let x = 0.5 * (a + b);
473            let fx = fun(x);
474            feval += 1;
475
476            return Ok(ScalarOptimizeResult {
477                x,
478                fun: fx,
479                nit: iter,
480                function_evals: feval,
481                success: true,
482                message: "Optimization terminated successfully.".to_string(),
483            });
484        }
485
486        if f1 < f2 {
487            b = x2;
488            x2 = x1;
489            f2 = f1;
490            x1 = a + (1.0 - GOLDEN) * (b - a);
491            f1 = fun(x1);
492            feval += 1;
493        } else {
494            a = x1;
495            x1 = x2;
496            f1 = f2;
497            x2 = a + GOLDEN * (b - a);
498            f2 = fun(x2);
499            feval += 1;
500        }
501
502        iter += 1;
503    }
504
505    Err(OptimizeError::ConvergenceError(
506        "Maximum number of iterations reached".to_string(),
507    ))
508}
509
510/// Bracket a minimum given two initial points
511#[allow(dead_code)]
512fn bracket_minimum<F>(fun: &F, xa: f64, xb: f64) -> Result<(f64, f64, f64), OptimizeError>
513where
514    F: Fn(f64) -> f64,
515{
516    const GOLDEN_RATIO: f64 = 1.618033988749895;
517    const TINY: f64 = 1e-21;
518    const MAX_ITER: usize = 50;
519
520    let (mut a, mut b) = (xa, xb);
521    let mut fa = fun(a);
522    let mut fb = fun(b);
523
524    if fa < fb {
525        std::mem::swap(&mut a, &mut b);
526        std::mem::swap(&mut fa, &mut fb);
527    }
528
529    let mut c = b + GOLDEN_RATIO * (b - a);
530    let mut fc = fun(c);
531    let mut iter = 0;
532
533    while fb >= fc {
534        let r = (b - a) * (fb - fc);
535        let q = (b - c) * (fb - fa);
536        let u = b - ((b - c) * q - (b - a) * r) / (2.0 * (q - r).max(TINY).copysign(q - r));
537        let ulim = b + 100.0 * (c - b);
538
539        let fu = if (b - u) * (u - c) > 0.0 {
540            let fu = fun(u);
541            if fu < fc {
542                return Ok((b, u, c));
543            } else if fu > fb {
544                return Ok((a, b, u));
545            }
546            let u = c + GOLDEN_RATIO * (c - b);
547            fun(u)
548        } else if (c - u) * (u - ulim) > 0.0 {
549            let fu = fun(u);
550            if fu < fc {
551                b = c;
552                fb = fc;
553                c = u;
554                fc = fu;
555                let u = c + GOLDEN_RATIO * (c - b);
556                fun(u)
557            } else {
558                fu
559            }
560        } else if (u - ulim) * (ulim - c) >= 0.0 {
561            let u = ulim;
562            fun(u)
563        } else {
564            let u = c + GOLDEN_RATIO * (c - b);
565            fun(u)
566        };
567
568        a = b;
569        fa = fb;
570        b = c;
571        fb = fc;
572        c = u;
573        fc = fu;
574
575        iter += 1;
576        if iter >= MAX_ITER {
577            return Err(OptimizeError::ValueError(
578                "Failed to bracket minimum".to_string(),
579            ));
580        }
581    }
582
583    Ok((a, b, c))
584}
585
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_brent_method() {
        // Simple quadratic with a unique minimum at x = 2.
        let quadratic = |x: f64| (x - 2.0).powi(2);

        let result = minimize_scalar(quadratic, None, Method::Brent, None).unwrap();
        assert!(result.success);
        assert_abs_diff_eq!(result.x, 2.0, epsilon = 2e-5);
        assert_abs_diff_eq!(result.fun, 0.0, epsilon = 5e-10);
    }

    #[test]
    fn test_bounded_method() {
        // The unconstrained minimum (x = 2) lies outside [-1, 1], so the
        // solver should land at or just inside the upper bound.
        let quadratic = |x: f64| (x - 2.0).powi(2);

        let result = minimize_scalar(quadratic, Some((-1.0, 1.0)), Method::Bounded, None).unwrap();
        assert!(result.success);
        assert!(result.x > 0.99 && result.x <= 1.0);
        assert!(result.fun >= 0.99 && result.fun <= 1.01);
    }

    #[test]
    fn test_golden_method() {
        // Quartic with a local minimum between 0.5 and 1.
        let quartic = |x: f64| x.powi(4) - 2.0 * x.powi(2) + x;

        let result = minimize_scalar(quartic, None, Method::Golden, None).unwrap();
        assert!(result.success);
        assert!(result.x > 0.5 && result.x < 1.0);
    }

    #[test]
    fn test_complexfunction() {
        // Quartic with several stationary points; verify the solver
        // terminates with a finite answer in a reasonable range.
        let objective = |x: f64| (x - 2.0) * x * (x + 2.0).powi(2);

        let result = minimize_scalar(objective, None, Method::Brent, None).unwrap();
        assert!(result.success);
        assert!(result.x > -5.0 && result.x < 5.0);
        assert!(result.fun.is_finite());
    }
}
637}