// scirs2_optimize/scalar.rs
//! Scalar optimization algorithms
//!
//! This module provides algorithms for minimizing univariate scalar functions.
//! It is similar to `scipy.optimize.minimize_scalar`.

use crate::error::OptimizeError;
use num_traits::Float;
use std::fmt;
/// Methods for scalar optimization
#[derive(Debug, Clone, Copy)]
pub enum Method {
    /// Brent's algorithm: parabolic interpolation with a golden-section fallback
    Brent,
    /// Brent's algorithm restricted to a caller-supplied `[lower, upper]` interval
    Bounded,
    /// Pure golden-section search
    Golden,
}
/// Options for scalar optimization
#[derive(Debug, Clone)]
pub struct Options {
    /// Maximum number of iterations before the solver gives up
    pub max_iter: usize,
    /// Absolute tolerance on the bracket width used for convergence
    pub xatol: f64,
    /// Relative tolerance
    pub xrtol: f64,
    /// Optional initial bracket (a, b, c); the solvers use the outer
    /// points a and c as the search interval
    pub bracket: Option<(f64, f64, f64)>,
    /// Whether to display convergence messages
    pub disp: bool,
}
36impl Default for Options {
37    fn default() -> Self {
38        Options {
39            max_iter: 500,
40            xatol: 1e-5,
41            xrtol: 1.4901161193847656e-8,
42            bracket: None,
43            disp: false,
44        }
45    }
46}
47
48impl fmt::Display for Method {
49    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
50        match self {
51            Method::Brent => write!(f, "Brent"),
52            Method::Bounded => write!(f, "Bounded"),
53            Method::Golden => write!(f, "Golden"),
54        }
55    }
56}
57
/// Result type for scalar optimization
#[derive(Debug, Clone)]
pub struct ScalarOptimizeResult {
    /// Location of the minimum that was found
    pub x: f64,
    /// Objective value at the minimum
    pub fun: f64,
    /// Number of iterations performed
    pub iterations: usize,
    /// Number of objective-function evaluations
    pub function_evals: usize,
    /// Whether the optimization converged
    pub success: bool,
    /// Human-readable description of the outcome
    pub message: String,
}
75/// Main minimize scalar function
76///
77/// Minimization of scalar function of one variable.
78///
79/// # Arguments
80///
81/// * `fun` - The objective function to be minimized
82/// * `bounds` - Optional bounds as (lower, upper) tuple
83/// * `method` - Optimization method to use
84/// * `options` - Optional algorithm options
85///
86/// # Returns
87///
88/// Returns a `ScalarOptimizeResult` containing the optimization result.
89///
90/// # Examples
91///
92/// ```no_run
93/// use scirs2_optimize::scalar::{minimize_scalar, Method};
94///
95/// fn f(x: f64) -> f64 {
96///     (x - 2.0) * x * (x + 2.0).powi(2)
97/// }
98///
99/// // Using the Brent method
100/// let result = minimize_scalar(f, None, Method::Brent, None)?;
101/// println!("Minimum at x = {}", result.x);
102/// println!("Function value = {}", result.fun);
103///
104/// // Using the bounded method
105/// let bounds = Some((-3.0, -1.0));
106/// let result = minimize_scalar(f, bounds, Method::Bounded, None)?;
107/// println!("Bounded minimum at x = {}", result.x);
108/// # Ok::<(), Box<dyn std::error::Error>>(())
109/// ```
110pub fn minimize_scalar<F>(
111    fun: F,
112    bounds: Option<(f64, f64)>,
113    method: Method,
114    options: Option<Options>,
115) -> Result<ScalarOptimizeResult, OptimizeError>
116where
117    F: Fn(f64) -> f64,
118{
119    let opts = options.unwrap_or_default();
120
121    match method {
122        Method::Brent => minimize_scalar_brent(fun, opts),
123        Method::Bounded => {
124            if let Some((a, b)) = bounds {
125                minimize_scalar_bounded(fun, a, b, opts)
126            } else {
127                Err(OptimizeError::ValueError(
128                    "Bounds are required for bounded method".to_string(),
129                ))
130            }
131        }
132        Method::Golden => minimize_scalar_golden(fun, opts),
133    }
134}
135
136/// Brent's method for scalar minimization
137fn minimize_scalar_brent<F>(fun: F, options: Options) -> Result<ScalarOptimizeResult, OptimizeError>
138where
139    F: Fn(f64) -> f64,
140{
141    // Implementation of Brent's method
142    // This combines parabolic interpolation with golden section search
143
144    const GOLDEN: f64 = 0.3819660112501051; // (3 - sqrt(5)) / 2
145    const SQRT_EPS: f64 = 1.4901161193847656e-8;
146
147    // Get initial bracket or use default
148    let (a, _b, c) = if let Some(bracket) = options.bracket {
149        bracket
150    } else {
151        // Use simple bracketing strategy
152        let x0 = 0.0;
153        let x1 = 1.0;
154        bracket_minimum(&fun, x0, x1)?
155    };
156
157    let tol = 3.0 * SQRT_EPS;
158    let (mut a, mut b) = if a < c { (a, c) } else { (c, a) };
159
160    // Initialize
161    let mut v = a + GOLDEN * (b - a);
162    let mut w = v;
163    let mut x = v;
164    let mut fx = fun(x);
165    let mut fv = fx;
166    let mut fw = fx;
167
168    let mut d = 0.0;
169    let mut e = 0.0;
170    let mut iter = 0;
171    let mut feval = 1;
172
173    while iter < options.max_iter {
174        let xm = 0.5 * (a + b);
175        let tol1 = tol * x.abs() + options.xatol;
176        let tol2 = 2.0 * tol1;
177
178        // Check for convergence
179        if (x - xm).abs() <= tol2 - 0.5 * (b - a) {
180            return Ok(ScalarOptimizeResult {
181                x,
182                fun: fx,
183                iterations: iter,
184                function_evals: feval,
185                success: true,
186                message: "Optimization terminated successfully.".to_string(),
187            });
188        }
189
190        // Fit parabola
191        if e.abs() > tol1 {
192            let r = (x - w) * (fx - fv);
193            let q_temp = (x - v) * (fx - fw);
194            let p_temp = (x - v) * q_temp - (x - w) * r;
195            let mut q_val = 2.0 * (q_temp - r);
196
197            let p_val = if q_val > 0.0 {
198                q_val = -q_val;
199                -p_temp
200            } else {
201                p_temp
202            };
203
204            let etemp = e;
205            e = d;
206
207            // Check if parabolic interpolation is acceptable
208            if p_val.abs() < (0.5 * q_val * etemp).abs()
209                && p_val > q_val * (a - x)
210                && p_val < q_val * (b - x)
211            {
212                d = p_val / q_val;
213                let u = x + d;
214
215                // f(x + d) must not be too close to a or b
216                if (u - a) < tol2 || (b - u) < tol2 {
217                    d = if xm > x { tol1 } else { -tol1 };
218                }
219            } else {
220                // Golden section step
221                e = if x >= xm { a - x } else { b - x };
222                d = GOLDEN * e;
223            }
224        } else {
225            // Golden section step
226            e = if x >= xm { a - x } else { b - x };
227            d = GOLDEN * e;
228        }
229
230        // Evaluate new point
231        let u = if d.abs() >= tol1 {
232            x + d
233        } else {
234            x + if d > 0.0 { tol1 } else { -tol1 }
235        };
236
237        let fu = fun(u);
238        feval += 1;
239
240        // Update bracket
241        if fu <= fx {
242            if u >= x {
243                a = x;
244            } else {
245                b = x;
246            }
247
248            v = w;
249            fv = fw;
250            w = x;
251            fw = fx;
252            x = u;
253            fx = fu;
254        } else {
255            if u < x {
256                a = u;
257            } else {
258                b = u;
259            }
260
261            if fu <= fw || w == x {
262                v = w;
263                fv = fw;
264                w = u;
265                fw = fu;
266            } else if fu <= fv || v == x || v == w {
267                v = u;
268                fv = fu;
269            }
270        }
271
272        iter += 1;
273    }
274
275    Err(OptimizeError::ConvergenceError(
276        "Maximum number of iterations reached".to_string(),
277    ))
278}
279
280/// Bounded Brent method for scalar minimization
281fn minimize_scalar_bounded<F>(
282    fun: F,
283    xmin: f64,
284    xmax: f64,
285    options: Options,
286) -> Result<ScalarOptimizeResult, OptimizeError>
287where
288    F: Fn(f64) -> f64,
289{
290    if xmin >= xmax {
291        return Err(OptimizeError::ValueError(
292            "Lower bound must be less than upper bound".to_string(),
293        ));
294    }
295
296    // Bounded version of Brent's method
297    // Similar to regular Brent but ensures x stays within [xmin, xmax]
298
299    const GOLDEN: f64 = 0.3819660112501051;
300    const SQRT_EPS: f64 = 1.4901161193847656e-8;
301
302    let tol = 3.0 * SQRT_EPS;
303    let (mut a, mut b) = (xmin, xmax);
304
305    // Initial points
306    let mut v = a + GOLDEN * (b - a);
307    let mut w = v;
308    let mut x = v;
309    let mut fx = fun(x);
310    let mut fv = fx;
311    let mut fw = fx;
312
313    let mut d = 0.0;
314    let mut e = 0.0;
315    let mut iter = 0;
316    let mut feval = 1;
317
318    while iter < options.max_iter {
319        let xm = 0.5 * (a + b);
320        let tol1 = tol * x.abs() + options.xatol;
321        let tol2 = 2.0 * tol1;
322
323        // Check for convergence
324        if (x - xm).abs() <= tol2 - 0.5 * (b - a) {
325            return Ok(ScalarOptimizeResult {
326                x,
327                fun: fx,
328                iterations: iter,
329                function_evals: feval,
330                success: true,
331                message: "Optimization terminated successfully.".to_string(),
332            });
333        }
334
335        // Parabolic interpolation
336        if e.abs() > tol1 {
337            let r = (x - w) * (fx - fv);
338            let q_temp = (x - v) * (fx - fw);
339            let p_temp = (x - v) * q_temp - (x - w) * r;
340            let mut q_val = 2.0 * (q_temp - r);
341
342            let p_val = if q_val > 0.0 {
343                q_val = -q_val;
344                -p_temp
345            } else {
346                p_temp
347            };
348
349            let etemp = e;
350            e = d;
351
352            if p_val.abs() < (0.5 * q_val * etemp).abs()
353                && p_val > q_val * (a - x)
354                && p_val < q_val * (b - x)
355            {
356                d = p_val / q_val;
357                let u = x + d;
358
359                if (u - a) < tol2 || (b - u) < tol2 {
360                    d = if xm > x { tol1 } else { -tol1 };
361                }
362            } else {
363                e = if x >= xm { a - x } else { b - x };
364                d = GOLDEN * e;
365            }
366        } else {
367            e = if x >= xm { a - x } else { b - x };
368            d = GOLDEN * e;
369        }
370
371        // Make sure we stay within bounds
372        let u = (x + if d.abs() >= tol1 {
373            d
374        } else if d > 0.0 {
375            tol1
376        } else {
377            -tol1
378        })
379        .max(xmin)
380        .min(xmax);
381
382        let fu = fun(u);
383        feval += 1;
384
385        // Update variables
386        if fu <= fx {
387            if u >= x {
388                a = x;
389            } else {
390                b = x;
391            }
392
393            v = w;
394            fv = fw;
395            w = x;
396            fw = fx;
397            x = u;
398            fx = fu;
399        } else {
400            if u < x {
401                a = u;
402            } else {
403                b = u;
404            }
405
406            if fu <= fw || w == x {
407                v = w;
408                fv = fw;
409                w = u;
410                fw = fu;
411            } else if fu <= fv || v == x || v == w {
412                v = u;
413                fv = fu;
414            }
415        }
416
417        iter += 1;
418    }
419
420    Err(OptimizeError::ConvergenceError(
421        "Maximum number of iterations reached".to_string(),
422    ))
423}
424
425/// Golden section search for scalar minimization
426fn minimize_scalar_golden<F>(
427    fun: F,
428    options: Options,
429) -> Result<ScalarOptimizeResult, OptimizeError>
430where
431    F: Fn(f64) -> f64,
432{
433    const GOLDEN: f64 = 0.6180339887498949; // (sqrt(5) - 1) / 2
434
435    // Get initial bracket or use default
436    let (a, _b, c) = if let Some(bracket) = options.bracket {
437        bracket
438    } else {
439        let x0 = 0.0;
440        let x1 = 1.0;
441        bracket_minimum(&fun, x0, x1)?
442    };
443
444    let (mut a, mut b) = if a < c { (a, c) } else { (c, a) };
445
446    // Initialize points
447    let mut x1 = a + (1.0 - GOLDEN) * (b - a);
448    let mut x2 = a + GOLDEN * (b - a);
449    let mut f1 = fun(x1);
450    let mut f2 = fun(x2);
451
452    let mut iter = 0;
453    let mut feval = 2;
454
455    while iter < options.max_iter {
456        if (b - a).abs() < options.xatol {
457            let x = 0.5 * (a + b);
458            let fx = fun(x);
459            feval += 1;
460
461            return Ok(ScalarOptimizeResult {
462                x,
463                fun: fx,
464                iterations: iter,
465                function_evals: feval,
466                success: true,
467                message: "Optimization terminated successfully.".to_string(),
468            });
469        }
470
471        if f1 < f2 {
472            b = x2;
473            x2 = x1;
474            f2 = f1;
475            x1 = a + (1.0 - GOLDEN) * (b - a);
476            f1 = fun(x1);
477            feval += 1;
478        } else {
479            a = x1;
480            x1 = x2;
481            f1 = f2;
482            x2 = a + GOLDEN * (b - a);
483            f2 = fun(x2);
484            feval += 1;
485        }
486
487        iter += 1;
488    }
489
490    Err(OptimizeError::ConvergenceError(
491        "Maximum number of iterations reached".to_string(),
492    ))
493}
494
495/// Bracket a minimum given two initial points
496fn bracket_minimum<F>(fun: &F, xa: f64, xb: f64) -> Result<(f64, f64, f64), OptimizeError>
497where
498    F: Fn(f64) -> f64,
499{
500    const GOLDEN_RATIO: f64 = 1.618033988749895;
501    const TINY: f64 = 1e-21;
502    const MAX_ITER: usize = 50;
503
504    let (mut a, mut b) = (xa, xb);
505    let mut fa = fun(a);
506    let mut fb = fun(b);
507
508    if fa < fb {
509        std::mem::swap(&mut a, &mut b);
510        std::mem::swap(&mut fa, &mut fb);
511    }
512
513    let mut c = b + GOLDEN_RATIO * (b - a);
514    let mut fc = fun(c);
515    let mut iter = 0;
516
517    while fb >= fc {
518        let r = (b - a) * (fb - fc);
519        let q = (b - c) * (fb - fa);
520        let u = b - ((b - c) * q - (b - a) * r) / (2.0 * (q - r).max(TINY).copysign(q - r));
521        let ulim = b + 100.0 * (c - b);
522
523        let fu = if (b - u) * (u - c) > 0.0 {
524            let fu = fun(u);
525            if fu < fc {
526                return Ok((b, u, c));
527            } else if fu > fb {
528                return Ok((a, b, u));
529            }
530            let u = c + GOLDEN_RATIO * (c - b);
531            fun(u)
532        } else if (c - u) * (u - ulim) > 0.0 {
533            let fu = fun(u);
534            if fu < fc {
535                b = c;
536                fb = fc;
537                c = u;
538                fc = fu;
539                let u = c + GOLDEN_RATIO * (c - b);
540                fun(u)
541            } else {
542                fu
543            }
544        } else if (u - ulim) * (ulim - c) >= 0.0 {
545            let u = ulim;
546            fun(u)
547        } else {
548            let u = c + GOLDEN_RATIO * (c - b);
549            fun(u)
550        };
551
552        a = b;
553        fa = fb;
554        b = c;
555        fb = fc;
556        c = u;
557        fc = fu;
558
559        iter += 1;
560        if iter >= MAX_ITER {
561            return Err(OptimizeError::ValueError(
562                "Failed to bracket minimum".to_string(),
563            ));
564        }
565    }
566
567    Ok((a, b, c))
568}
569
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_brent_method() {
        // Simple quadratic with its minimum at x = 2.
        let objective = |x: f64| (x - 2.0).powi(2);

        let result = minimize_scalar(objective, None, Method::Brent, None).unwrap();
        assert!(result.success);
        assert_abs_diff_eq!(result.x, 2.0, epsilon = 1e-5);
        assert_abs_diff_eq!(result.fun, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_bounded_method() {
        // Same quadratic restricted to [-1, 1]; the constrained minimum
        // sits at the upper bound.
        let objective = |x: f64| (x - 2.0).powi(2);

        let result = minimize_scalar(objective, Some((-1.0, 1.0)), Method::Bounded, None).unwrap();
        assert!(result.success);
        // Allow some numerical tolerance around the bound.
        assert!(result.x > 0.99 && result.x <= 1.0);
        assert!(result.fun >= 0.99 && result.fun <= 1.01);
    }

    #[test]
    fn test_golden_method() {
        // Quartic with a local minimum between 0.5 and 1.
        let objective = |x: f64| x.powi(4) - 2.0 * x.powi(2) + x;

        let result = minimize_scalar(objective, None, Method::Golden, None).unwrap();
        assert!(result.success);
        // The exact minimizer depends on bracketing details; just check
        // that it lands in the expected range.
        assert!(result.x > 0.5 && result.x < 1.0);
    }

    #[test]
    fn test_complex_function() {
        // Quartic polynomial whose relevant minimum is near x ≈ 1.28.
        let objective = |x: f64| (x - 2.0) * x * (x + 2.0).powi(2);

        let result = minimize_scalar(objective, None, Method::Brent, None).unwrap();
        assert!(result.success);
        assert!(result.x > 1.2 && result.x < 1.3);
    }
}