1use crate::error::OptimizeError;
7use num_traits::Float;
8use std::fmt;
9
/// Algorithm selected by [`minimize_scalar`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Method {
    /// Brent's method: golden-section steps combined with parabolic
    /// interpolation, run on a bracket taken from `Options::bracket` or
    /// searched for automatically.
    Brent,
    /// Brent-style search restricted to the `bounds` interval passed to
    /// [`minimize_scalar`]; bounds are mandatory for this method.
    Bounded,
    /// Plain golden-section search on a bracket.
    Golden,
}
20
/// Settings shared by the scalar minimizers behind [`minimize_scalar`].
#[derive(Debug, Clone, PartialEq)]
pub struct Options {
    /// Maximum number of iterations before a solver gives up and returns
    /// `OptimizeError::ConvergenceError`.
    pub max_iter: usize,
    /// Absolute tolerance on `x` used in the solvers' stopping tests.
    pub xatol: f64,
    /// Relative tolerance on `x`.
    ///
    /// NOTE(review): not read by any solver in this file; the Brent-style
    /// solvers use a fixed relative term of `3 * sqrt(f64::EPSILON)` instead.
    pub xrtol: f64,
    /// Initial bracketing triple `(a, b, c)` for the Brent and Golden methods.
    /// When `None`, a bracket is searched for starting from `0.0` and `1.0`.
    /// The Bounded method does not use it.
    pub bracket: Option<(f64, f64, f64)>,
    /// Display flag. NOTE(review): not read anywhere in this file.
    pub disp: bool,
}

impl Default for Options {
    /// 500 iterations, `xatol = 1e-5`, `xrtol = sqrt(f64::EPSILON)`, no
    /// bracket, no display.
    fn default() -> Self {
        Options {
            max_iter: 500,
            xatol: 1e-5,
            // sqrt(f64::EPSILON) == 2^-26
            xrtol: 1.4901161193847656e-8,
            bracket: None,
            disp: false,
        }
    }
}
47
48impl fmt::Display for Method {
49 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
50 match self {
51 Method::Brent => write!(f, "Brent"),
52 Method::Bounded => write!(f, "Bounded"),
53 Method::Golden => write!(f, "Golden"),
54 }
55 }
56}
57
/// Outcome of a successful scalar minimization.
#[derive(Debug, Clone, PartialEq)]
pub struct ScalarOptimizeResult {
    /// Location of the best point found.
    pub x: f64,
    /// Objective value at `x`.
    pub fun: f64,
    /// Number of iterations the solver performed.
    pub iterations: usize,
    /// Number of objective evaluations made by the solver itself (evaluations
    /// spent searching for an initial bracket are not counted).
    pub function_evals: usize,
    /// Whether the solver reported convergence.
    pub success: bool,
    /// Human-readable status message.
    pub message: String,
}
74
75pub fn minimize_scalar<F>(
111 fun: F,
112 bounds: Option<(f64, f64)>,
113 method: Method,
114 options: Option<Options>,
115) -> Result<ScalarOptimizeResult, OptimizeError>
116where
117 F: Fn(f64) -> f64,
118{
119 let opts = options.unwrap_or_default();
120
121 match method {
122 Method::Brent => minimize_scalar_brent(fun, opts),
123 Method::Bounded => {
124 if let Some((a, b)) = bounds {
125 minimize_scalar_bounded(fun, a, b, opts)
126 } else {
127 Err(OptimizeError::ValueError(
128 "Bounds are required for bounded method".to_string(),
129 ))
130 }
131 }
132 Method::Golden => minimize_scalar_golden(fun, opts),
133 }
134}
135
136fn minimize_scalar_brent<F>(fun: F, options: Options) -> Result<ScalarOptimizeResult, OptimizeError>
138where
139 F: Fn(f64) -> f64,
140{
141 const GOLDEN: f64 = 0.3819660112501051; const SQRT_EPS: f64 = 1.4901161193847656e-8;
146
147 let (a, _b, c) = if let Some(bracket) = options.bracket {
149 bracket
150 } else {
151 let x0 = 0.0;
153 let x1 = 1.0;
154 bracket_minimum(&fun, x0, x1)?
155 };
156
157 let tol = 3.0 * SQRT_EPS;
158 let (mut a, mut b) = if a < c { (a, c) } else { (c, a) };
159
160 let mut v = a + GOLDEN * (b - a);
162 let mut w = v;
163 let mut x = v;
164 let mut fx = fun(x);
165 let mut fv = fx;
166 let mut fw = fx;
167
168 let mut d = 0.0;
169 let mut e = 0.0;
170 let mut iter = 0;
171 let mut feval = 1;
172
173 while iter < options.max_iter {
174 let xm = 0.5 * (a + b);
175 let tol1 = tol * x.abs() + options.xatol;
176 let tol2 = 2.0 * tol1;
177
178 if (x - xm).abs() <= tol2 - 0.5 * (b - a) {
180 return Ok(ScalarOptimizeResult {
181 x,
182 fun: fx,
183 iterations: iter,
184 function_evals: feval,
185 success: true,
186 message: "Optimization terminated successfully.".to_string(),
187 });
188 }
189
190 if e.abs() > tol1 {
192 let r = (x - w) * (fx - fv);
193 let q_temp = (x - v) * (fx - fw);
194 let p_temp = (x - v) * q_temp - (x - w) * r;
195 let mut q_val = 2.0 * (q_temp - r);
196
197 let p_val = if q_val > 0.0 {
198 q_val = -q_val;
199 -p_temp
200 } else {
201 p_temp
202 };
203
204 let etemp = e;
205 e = d;
206
207 if p_val.abs() < (0.5 * q_val * etemp).abs()
209 && p_val > q_val * (a - x)
210 && p_val < q_val * (b - x)
211 {
212 d = p_val / q_val;
213 let u = x + d;
214
215 if (u - a) < tol2 || (b - u) < tol2 {
217 d = if xm > x { tol1 } else { -tol1 };
218 }
219 } else {
220 e = if x >= xm { a - x } else { b - x };
222 d = GOLDEN * e;
223 }
224 } else {
225 e = if x >= xm { a - x } else { b - x };
227 d = GOLDEN * e;
228 }
229
230 let u = if d.abs() >= tol1 {
232 x + d
233 } else {
234 x + if d > 0.0 { tol1 } else { -tol1 }
235 };
236
237 let fu = fun(u);
238 feval += 1;
239
240 if fu <= fx {
242 if u >= x {
243 a = x;
244 } else {
245 b = x;
246 }
247
248 v = w;
249 fv = fw;
250 w = x;
251 fw = fx;
252 x = u;
253 fx = fu;
254 } else {
255 if u < x {
256 a = u;
257 } else {
258 b = u;
259 }
260
261 if fu <= fw || w == x {
262 v = w;
263 fv = fw;
264 w = u;
265 fw = fu;
266 } else if fu <= fv || v == x || v == w {
267 v = u;
268 fv = fu;
269 }
270 }
271
272 iter += 1;
273 }
274
275 Err(OptimizeError::ConvergenceError(
276 "Maximum number of iterations reached".to_string(),
277 ))
278}
279
280fn minimize_scalar_bounded<F>(
282 fun: F,
283 xmin: f64,
284 xmax: f64,
285 options: Options,
286) -> Result<ScalarOptimizeResult, OptimizeError>
287where
288 F: Fn(f64) -> f64,
289{
290 if xmin >= xmax {
291 return Err(OptimizeError::ValueError(
292 "Lower bound must be less than upper bound".to_string(),
293 ));
294 }
295
296 const GOLDEN: f64 = 0.3819660112501051;
300 const SQRT_EPS: f64 = 1.4901161193847656e-8;
301
302 let tol = 3.0 * SQRT_EPS;
303 let (mut a, mut b) = (xmin, xmax);
304
305 let mut v = a + GOLDEN * (b - a);
307 let mut w = v;
308 let mut x = v;
309 let mut fx = fun(x);
310 let mut fv = fx;
311 let mut fw = fx;
312
313 let mut d = 0.0;
314 let mut e = 0.0;
315 let mut iter = 0;
316 let mut feval = 1;
317
318 while iter < options.max_iter {
319 let xm = 0.5 * (a + b);
320 let tol1 = tol * x.abs() + options.xatol;
321 let tol2 = 2.0 * tol1;
322
323 if (x - xm).abs() <= tol2 - 0.5 * (b - a) {
325 return Ok(ScalarOptimizeResult {
326 x,
327 fun: fx,
328 iterations: iter,
329 function_evals: feval,
330 success: true,
331 message: "Optimization terminated successfully.".to_string(),
332 });
333 }
334
335 if e.abs() > tol1 {
337 let r = (x - w) * (fx - fv);
338 let q_temp = (x - v) * (fx - fw);
339 let p_temp = (x - v) * q_temp - (x - w) * r;
340 let mut q_val = 2.0 * (q_temp - r);
341
342 let p_val = if q_val > 0.0 {
343 q_val = -q_val;
344 -p_temp
345 } else {
346 p_temp
347 };
348
349 let etemp = e;
350 e = d;
351
352 if p_val.abs() < (0.5 * q_val * etemp).abs()
353 && p_val > q_val * (a - x)
354 && p_val < q_val * (b - x)
355 {
356 d = p_val / q_val;
357 let u = x + d;
358
359 if (u - a) < tol2 || (b - u) < tol2 {
360 d = if xm > x { tol1 } else { -tol1 };
361 }
362 } else {
363 e = if x >= xm { a - x } else { b - x };
364 d = GOLDEN * e;
365 }
366 } else {
367 e = if x >= xm { a - x } else { b - x };
368 d = GOLDEN * e;
369 }
370
371 let u = (x + if d.abs() >= tol1 {
373 d
374 } else if d > 0.0 {
375 tol1
376 } else {
377 -tol1
378 })
379 .max(xmin)
380 .min(xmax);
381
382 let fu = fun(u);
383 feval += 1;
384
385 if fu <= fx {
387 if u >= x {
388 a = x;
389 } else {
390 b = x;
391 }
392
393 v = w;
394 fv = fw;
395 w = x;
396 fw = fx;
397 x = u;
398 fx = fu;
399 } else {
400 if u < x {
401 a = u;
402 } else {
403 b = u;
404 }
405
406 if fu <= fw || w == x {
407 v = w;
408 fv = fw;
409 w = u;
410 fw = fu;
411 } else if fu <= fv || v == x || v == w {
412 v = u;
413 fv = fu;
414 }
415 }
416
417 iter += 1;
418 }
419
420 Err(OptimizeError::ConvergenceError(
421 "Maximum number of iterations reached".to_string(),
422 ))
423}
424
425fn minimize_scalar_golden<F>(
427 fun: F,
428 options: Options,
429) -> Result<ScalarOptimizeResult, OptimizeError>
430where
431 F: Fn(f64) -> f64,
432{
433 const GOLDEN: f64 = 0.6180339887498949; let (a, _b, c) = if let Some(bracket) = options.bracket {
437 bracket
438 } else {
439 let x0 = 0.0;
440 let x1 = 1.0;
441 bracket_minimum(&fun, x0, x1)?
442 };
443
444 let (mut a, mut b) = if a < c { (a, c) } else { (c, a) };
445
446 let mut x1 = a + (1.0 - GOLDEN) * (b - a);
448 let mut x2 = a + GOLDEN * (b - a);
449 let mut f1 = fun(x1);
450 let mut f2 = fun(x2);
451
452 let mut iter = 0;
453 let mut feval = 2;
454
455 while iter < options.max_iter {
456 if (b - a).abs() < options.xatol {
457 let x = 0.5 * (a + b);
458 let fx = fun(x);
459 feval += 1;
460
461 return Ok(ScalarOptimizeResult {
462 x,
463 fun: fx,
464 iterations: iter,
465 function_evals: feval,
466 success: true,
467 message: "Optimization terminated successfully.".to_string(),
468 });
469 }
470
471 if f1 < f2 {
472 b = x2;
473 x2 = x1;
474 f2 = f1;
475 x1 = a + (1.0 - GOLDEN) * (b - a);
476 f1 = fun(x1);
477 feval += 1;
478 } else {
479 a = x1;
480 x1 = x2;
481 f1 = f2;
482 x2 = a + GOLDEN * (b - a);
483 f2 = fun(x2);
484 feval += 1;
485 }
486
487 iter += 1;
488 }
489
490 Err(OptimizeError::ConvergenceError(
491 "Maximum number of iterations reached".to_string(),
492 ))
493}
494
495fn bracket_minimum<F>(fun: &F, xa: f64, xb: f64) -> Result<(f64, f64, f64), OptimizeError>
497where
498 F: Fn(f64) -> f64,
499{
500 const GOLDEN_RATIO: f64 = 1.618033988749895;
501 const TINY: f64 = 1e-21;
502 const MAX_ITER: usize = 50;
503
504 let (mut a, mut b) = (xa, xb);
505 let mut fa = fun(a);
506 let mut fb = fun(b);
507
508 if fa < fb {
509 std::mem::swap(&mut a, &mut b);
510 std::mem::swap(&mut fa, &mut fb);
511 }
512
513 let mut c = b + GOLDEN_RATIO * (b - a);
514 let mut fc = fun(c);
515 let mut iter = 0;
516
517 while fb >= fc {
518 let r = (b - a) * (fb - fc);
519 let q = (b - c) * (fb - fa);
520 let u = b - ((b - c) * q - (b - a) * r) / (2.0 * (q - r).max(TINY).copysign(q - r));
521 let ulim = b + 100.0 * (c - b);
522
523 let fu = if (b - u) * (u - c) > 0.0 {
524 let fu = fun(u);
525 if fu < fc {
526 return Ok((b, u, c));
527 } else if fu > fb {
528 return Ok((a, b, u));
529 }
530 let u = c + GOLDEN_RATIO * (c - b);
531 fun(u)
532 } else if (c - u) * (u - ulim) > 0.0 {
533 let fu = fun(u);
534 if fu < fc {
535 b = c;
536 fb = fc;
537 c = u;
538 fc = fu;
539 let u = c + GOLDEN_RATIO * (c - b);
540 fun(u)
541 } else {
542 fu
543 }
544 } else if (u - ulim) * (ulim - c) >= 0.0 {
545 let u = ulim;
546 fun(u)
547 } else {
548 let u = c + GOLDEN_RATIO * (c - b);
549 fun(u)
550 };
551
552 a = b;
553 fa = fb;
554 b = c;
555 fb = fc;
556 c = u;
557 fc = fu;
558
559 iter += 1;
560 if iter >= MAX_ITER {
561 return Err(OptimizeError::ValueError(
562 "Failed to bracket minimum".to_string(),
563 ));
564 }
565 }
566
567 Ok((a, b, c))
568}
569
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Parabola whose unconstrained minimum is at x = 2 with value 0.
    fn parabola_at_two(x: f64) -> f64 {
        (x - 2.0).powi(2)
    }

    #[test]
    fn test_brent_method() {
        // Unbounded Brent search should land on the parabola's vertex.
        let outcome = minimize_scalar(parabola_at_two, None, Method::Brent, None).unwrap();
        assert!(outcome.success);
        assert_abs_diff_eq!(outcome.x, 2.0, epsilon = 1e-5);
        assert_abs_diff_eq!(outcome.fun, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_bounded_method() {
        // The vertex (x = 2) lies outside [-1, 1], so the constrained optimum
        // is the upper bound, where f(1) = 1.
        let outcome =
            minimize_scalar(parabola_at_two, Some((-1.0, 1.0)), Method::Bounded, None).unwrap();
        assert!(outcome.success);
        assert!(outcome.x > 0.99);
        assert!(outcome.x <= 1.0);
        assert!(outcome.fun >= 0.99);
        assert!(outcome.fun <= 1.01);
    }

    #[test]
    fn test_golden_method() {
        // Double-well quartic; the test expects the search from the default
        // bracket to settle in the local minimum between 0.5 and 1 (the global
        // minimum lies near x = -1.11).
        let quartic = |x: f64| x.powi(4) - 2.0 * x.powi(2) + x;
        let outcome = minimize_scalar(quartic, None, Method::Golden, None).unwrap();
        assert!(outcome.success);
        assert!(outcome.x > 0.5);
        assert!(outcome.x < 1.0);
    }

    #[test]
    fn test_complex_function() {
        // Quartic with its global minimum at x = (1 + sqrt(17)) / 4, about 1.28.
        let quartic = |x: f64| (x - 2.0) * x * (x + 2.0).powi(2);
        let outcome = minimize_scalar(quartic, None, Method::Brent, None).unwrap();
        assert!(outcome.success);
        assert!(outcome.x > 1.2);
        assert!(outcome.x < 1.3);
    }
}