1use crate::error::OptimizeError;
7use std::fmt;
8
/// Algorithm used by [`minimize_scalar`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Method {
    /// Brent's method on a bracket; the `bounds` argument is ignored.
    Brent,
    /// Brent's method restricted to the `bounds` passed to `minimize_scalar`,
    /// which are required for this method.
    Bounded,
    /// Golden-section search on a bracket; the `bounds` argument is ignored.
    Golden,
}
19
/// Tuning parameters shared by the scalar minimizers.
#[derive(Debug, Clone, PartialEq)]
pub struct Options {
    /// Maximum number of iterations before a solver gives up with a
    /// convergence error (default `500`).
    pub max_iter: usize,
    /// Absolute tolerance on the location of the minimum (default `1e-5`).
    pub xatol: f64,
    /// Relative tolerance on the location of the minimum (default
    /// `sqrt(f64::EPSILON)`); not every method uses it.
    pub xrtol: f64,
    /// Optional starting bracket `(xa, xb, xc)` for `Method::Brent` and
    /// `Method::Golden`. When `None`, those methods search for a bracket
    /// starting from `0.0` and `1.0`. Ignored by `Method::Bounded`.
    pub bracket: Option<(f64, f64, f64)>,
    /// Display flag (default `false`).
    // NOTE(review): not read by any solver in this file.
    pub disp: bool,
}
34
35impl Default for Options {
36 fn default() -> Self {
37 Options {
38 max_iter: 500,
39 xatol: 1e-5,
40 xrtol: 1.4901161193847656e-8,
41 bracket: None,
42 disp: false,
43 }
44 }
45}
46
47impl fmt::Display for Method {
48 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
49 match self {
50 Method::Brent => write!(f, "Brent"),
51 Method::Bounded => write!(f, "Bounded"),
52 Method::Golden => write!(f, "Golden"),
53 }
54 }
55}
56
/// Outcome of a successful scalar minimization.
#[derive(Debug, Clone, PartialEq)]
pub struct ScalarOptimizeResult {
    /// Location of the best point found.
    pub x: f64,
    /// Value of the objective at `x`.
    pub fun: f64,
    /// Number of iterations performed.
    pub nit: usize,
    /// Number of objective-function evaluations.
    pub function_evals: usize,
    /// `true` when the solver converged.
    pub success: bool,
    /// Human-readable termination message.
    pub message: String,
}
73
74#[allow(dead_code)]
110pub fn minimize_scalar<F>(
111 fun: F,
112 bounds: Option<(f64, f64)>,
113 method: Method,
114 options: Option<Options>,
115) -> Result<ScalarOptimizeResult, OptimizeError>
116where
117 F: Fn(f64) -> f64,
118{
119 let opts = options.unwrap_or_default();
120
121 match method {
122 Method::Brent => minimize_scalar_brent(fun, opts),
123 Method::Bounded => {
124 if let Some((a, b)) = bounds {
125 minimize_scalar_bounded(fun, a, b, opts)
126 } else {
127 Err(OptimizeError::ValueError(
128 "Bounds are required for bounded method".to_string(),
129 ))
130 }
131 }
132 Method::Golden => minimize_scalar_golden(fun, opts),
133 }
134}
135
136#[allow(dead_code)]
138fn minimize_scalar_brent<F>(fun: F, options: Options) -> Result<ScalarOptimizeResult, OptimizeError>
139where
140 F: Fn(f64) -> f64,
141{
142 const GOLDEN: f64 = 0.3819660112501051; const SQRT_EPS: f64 = 1.4901161193847656e-8;
147
148 let (a, b, c) = if let Some(bracket) = options.bracket {
150 bracket
151 } else {
152 let x0 = 0.0;
154 let x1 = 1.0;
155 bracket_minimum(&fun, x0, x1)?
156 };
157
158 let tol = 3.0 * SQRT_EPS;
159 let (bracket_a, bracket_b) = if b < c { (b, c) } else { (c, b) };
161 let (mut a, mut b) = if a < bracket_a {
162 (a, bracket_a)
163 } else {
164 (bracket_a, a)
165 };
166
167 let mut v = a + GOLDEN * (b - a);
169 let mut w = v;
170 let mut x = v;
171 let mut fx = fun(x);
172 let mut fv = fx;
173 let mut fw = fx;
174
175 let mut d: f64 = 0.0;
176 let mut e: f64 = 0.0;
177 let mut iter = 0;
178 let mut feval = 1;
179
180 while iter < options.max_iter {
181 let xm = 0.5 * (a + b);
182 let tol1 = tol * x.abs() + options.xatol;
183 let tol2 = 2.0 * tol1;
184
185 if (x - xm).abs() <= tol2 - 0.5 * (b - a) {
187 return Ok(ScalarOptimizeResult {
188 x,
189 fun: fx,
190 nit: iter,
191 function_evals: feval,
192 success: true,
193 message: "Optimization terminated successfully.".to_string(),
194 });
195 }
196
197 if e.abs() > tol1 {
199 let r = (x - w) * (fx - fv);
200 let q_temp = (x - v) * (fx - fw);
201 let p_temp = (x - v) * q_temp - (x - w) * r;
202 let mut q_val = 2.0 * (q_temp - r);
203
204 let p_val = if q_val > 0.0 {
205 q_val = -q_val;
206 -p_temp
207 } else {
208 p_temp
209 };
210
211 let etemp = e;
212 e = d;
213
214 if p_val.abs() < (0.5 * q_val * etemp).abs()
216 && p_val > q_val * (a - x)
217 && p_val < q_val * (b - x)
218 {
219 d = p_val / q_val;
220 let u = x + d;
221
222 if (u - a) < tol2 || (b - u) < tol2 {
224 d = if xm > x { tol1 } else { -tol1 };
225 }
226 } else {
227 e = if x >= xm { a - x } else { b - x };
229 d = GOLDEN * e;
230 }
231 } else {
232 e = if x >= xm { a - x } else { b - x };
234 d = GOLDEN * e;
235 }
236
237 let u = if d.abs() >= tol1 {
239 x + d
240 } else {
241 x + if d > 0.0 { tol1 } else { -tol1 }
242 };
243
244 let fu = fun(u);
245 feval += 1;
246
247 if fu <= fx {
249 if u >= x {
250 a = x;
251 } else {
252 b = x;
253 }
254
255 v = w;
256 fv = fw;
257 w = x;
258 fw = fx;
259 x = u;
260 fx = fu;
261 } else {
262 if u < x {
263 a = u;
264 } else {
265 b = u;
266 }
267
268 if fu <= fw || w == x {
269 v = w;
270 fv = fw;
271 w = u;
272 fw = fu;
273 } else if fu <= fv || v == x || v == w {
274 v = u;
275 fv = fu;
276 }
277 }
278
279 iter += 1;
280 }
281
282 Err(OptimizeError::ConvergenceError(
283 "Maximum number of iterations reached".to_string(),
284 ))
285}
286
287#[allow(dead_code)]
289fn minimize_scalar_bounded<F>(
290 fun: F,
291 xmin: f64,
292 xmax: f64,
293 options: Options,
294) -> Result<ScalarOptimizeResult, OptimizeError>
295where
296 F: Fn(f64) -> f64,
297{
298 if xmin >= xmax {
299 return Err(OptimizeError::ValueError(
300 "Lower bound must be less than upper bound".to_string(),
301 ));
302 }
303
304 const GOLDEN: f64 = 0.3819660112501051;
308 const SQRT_EPS: f64 = 1.4901161193847656e-8;
309
310 let tol = 3.0 * SQRT_EPS;
311 let (mut a, mut b) = (xmin, xmax);
312
313 let mut v = a + GOLDEN * (b - a);
315 let mut w = v;
316 let mut x = v;
317 let mut fx = fun(x);
318 let mut fv = fx;
319 let mut fw = fx;
320
321 let mut d: f64 = 0.0;
322 let mut e: f64 = 0.0;
323 let mut iter = 0;
324 let mut feval = 1;
325
326 while iter < options.max_iter {
327 let xm = 0.5 * (a + b);
328 let tol1 = tol * x.abs() + options.xatol;
329 let tol2 = 2.0 * tol1;
330
331 if (x - xm).abs() <= tol2 - 0.5 * (b - a) {
333 return Ok(ScalarOptimizeResult {
334 x,
335 fun: fx,
336 nit: iter,
337 function_evals: feval,
338 success: true,
339 message: "Optimization terminated successfully.".to_string(),
340 });
341 }
342
343 if e.abs() > tol1 {
345 let r = (x - w) * (fx - fv);
346 let q_temp = (x - v) * (fx - fw);
347 let p_temp = (x - v) * q_temp - (x - w) * r;
348 let mut q_val = 2.0 * (q_temp - r);
349
350 let p_val = if q_val > 0.0 {
351 q_val = -q_val;
352 -p_temp
353 } else {
354 p_temp
355 };
356
357 let etemp = e;
358 e = d;
359
360 if p_val.abs() < (0.5 * q_val * etemp).abs()
361 && p_val > q_val * (a - x)
362 && p_val < q_val * (b - x)
363 {
364 d = p_val / q_val;
365 let u = x + d;
366
367 if (u - a) < tol2 || (b - u) < tol2 {
368 d = if xm > x { tol1 } else { -tol1 };
369 }
370 } else {
371 e = if x >= xm { a - x } else { b - x };
372 d = GOLDEN * e;
373 }
374 } else {
375 e = if x >= xm { a - x } else { b - x };
376 d = GOLDEN * e;
377 }
378
379 let u = (x + if d.abs() >= tol1 {
381 d
382 } else if d > 0.0 {
383 tol1
384 } else {
385 -tol1
386 })
387 .max(xmin)
388 .min(xmax);
389
390 let fu = fun(u);
391 feval += 1;
392
393 if fu <= fx {
395 if u >= x {
396 a = x;
397 } else {
398 b = x;
399 }
400
401 v = w;
402 fv = fw;
403 w = x;
404 fw = fx;
405 x = u;
406 fx = fu;
407 } else {
408 if u < x {
409 a = u;
410 } else {
411 b = u;
412 }
413
414 if fu <= fw || w == x {
415 v = w;
416 fv = fw;
417 w = u;
418 fw = fu;
419 } else if fu <= fv || v == x || v == w {
420 v = u;
421 fv = fu;
422 }
423 }
424
425 iter += 1;
426 }
427
428 Err(OptimizeError::ConvergenceError(
429 "Maximum number of iterations reached".to_string(),
430 ))
431}
432
433#[allow(dead_code)]
435fn minimize_scalar_golden<F>(
436 fun: F,
437 options: Options,
438) -> Result<ScalarOptimizeResult, OptimizeError>
439where
440 F: Fn(f64) -> f64,
441{
442 const GOLDEN: f64 = 0.6180339887498949; let (a, b, c) = if let Some(bracket) = options.bracket {
446 bracket
447 } else {
448 let x0 = 0.0;
449 let x1 = 1.0;
450 bracket_minimum(&fun, x0, x1)?
451 };
452
453 let (bracket_a, bracket_b) = if b < c { (b, c) } else { (c, b) };
455 let (mut a, mut b) = if a < bracket_a {
456 (a, bracket_a)
457 } else {
458 (bracket_a, a)
459 };
460
461 let mut x1 = a + (1.0 - GOLDEN) * (b - a);
463 let mut x2 = a + GOLDEN * (b - a);
464 let mut f1 = fun(x1);
465 let mut f2 = fun(x2);
466
467 let mut iter = 0;
468 let mut feval = 2;
469
470 while iter < options.max_iter {
471 if (b - a).abs() < options.xatol {
472 let x = 0.5 * (a + b);
473 let fx = fun(x);
474 feval += 1;
475
476 return Ok(ScalarOptimizeResult {
477 x,
478 fun: fx,
479 nit: iter,
480 function_evals: feval,
481 success: true,
482 message: "Optimization terminated successfully.".to_string(),
483 });
484 }
485
486 if f1 < f2 {
487 b = x2;
488 x2 = x1;
489 f2 = f1;
490 x1 = a + (1.0 - GOLDEN) * (b - a);
491 f1 = fun(x1);
492 feval += 1;
493 } else {
494 a = x1;
495 x1 = x2;
496 f1 = f2;
497 x2 = a + GOLDEN * (b - a);
498 f2 = fun(x2);
499 feval += 1;
500 }
501
502 iter += 1;
503 }
504
505 Err(OptimizeError::ConvergenceError(
506 "Maximum number of iterations reached".to_string(),
507 ))
508}
509
510#[allow(dead_code)]
512fn bracket_minimum<F>(fun: &F, xa: f64, xb: f64) -> Result<(f64, f64, f64), OptimizeError>
513where
514 F: Fn(f64) -> f64,
515{
516 const GOLDEN_RATIO: f64 = 1.618033988749895;
517 const TINY: f64 = 1e-21;
518 const MAX_ITER: usize = 50;
519
520 let (mut a, mut b) = (xa, xb);
521 let mut fa = fun(a);
522 let mut fb = fun(b);
523
524 if fa < fb {
525 std::mem::swap(&mut a, &mut b);
526 std::mem::swap(&mut fa, &mut fb);
527 }
528
529 let mut c = b + GOLDEN_RATIO * (b - a);
530 let mut fc = fun(c);
531 let mut iter = 0;
532
533 while fb >= fc {
534 let r = (b - a) * (fb - fc);
535 let q = (b - c) * (fb - fa);
536 let u = b - ((b - c) * q - (b - a) * r) / (2.0 * (q - r).max(TINY).copysign(q - r));
537 let ulim = b + 100.0 * (c - b);
538
539 let fu = if (b - u) * (u - c) > 0.0 {
540 let fu = fun(u);
541 if fu < fc {
542 return Ok((b, u, c));
543 } else if fu > fb {
544 return Ok((a, b, u));
545 }
546 let u = c + GOLDEN_RATIO * (c - b);
547 fun(u)
548 } else if (c - u) * (u - ulim) > 0.0 {
549 let fu = fun(u);
550 if fu < fc {
551 b = c;
552 fb = fc;
553 c = u;
554 fc = fu;
555 let u = c + GOLDEN_RATIO * (c - b);
556 fun(u)
557 } else {
558 fu
559 }
560 } else if (u - ulim) * (ulim - c) >= 0.0 {
561 let u = ulim;
562 fun(u)
563 } else {
564 let u = c + GOLDEN_RATIO * (c - b);
565 fun(u)
566 };
567
568 a = b;
569 fa = fb;
570 b = c;
571 fb = fc;
572 c = u;
573 fc = fu;
574
575 iter += 1;
576 if iter >= MAX_ITER {
577 return Err(OptimizeError::ValueError(
578 "Failed to bracket minimum".to_string(),
579 ));
580 }
581 }
582
583 Ok((a, b, c))
584}
585
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_brent_method() {
        let f = |x: f64| (x - 2.0).powi(2);

        let result = minimize_scalar(f, None, Method::Brent, None).unwrap();
        assert!(result.success);
        assert_abs_diff_eq!(result.x, 2.0, epsilon = 2e-5);
        assert_abs_diff_eq!(result.fun, 0.0, epsilon = 5e-10);
    }

    #[test]
    fn test_bounded_method() {
        let f = |x: f64| (x - 2.0).powi(2);

        // The unconstrained minimum (x = 2) is outside the bounds, so the
        // solver must stop at the upper bound.
        let result = minimize_scalar(f, Some((-1.0, 1.0)), Method::Bounded, None).unwrap();
        assert!(result.success);
        assert!(result.x > 0.99 && result.x <= 1.0);
        assert!(result.fun >= 0.99 && result.fun <= 1.01);
    }

    #[test]
    fn test_golden_method() {
        let f = |x: f64| x.powi(4) - 2.0 * x.powi(2) + x;

        let result = minimize_scalar(f, None, Method::Golden, None).unwrap();
        assert!(result.success);
        assert!(result.x > 0.5 && result.x < 1.0);
    }

    #[test]
    fn test_complexfunction() {
        let f = |x: f64| (x - 2.0) * x * (x + 2.0).powi(2);

        let result = minimize_scalar(f, None, Method::Brent, None).unwrap();
        assert!(result.success);
        assert!(result.x > -5.0 && result.x < 5.0);
        assert!(result.fun.is_finite());
    }

    #[test]
    fn test_bounded_requires_bounds() {
        let f = |x: f64| (x - 2.0).powi(2);

        assert!(minimize_scalar(f, None, Method::Bounded, None).is_err());
    }

    #[test]
    fn test_bounded_rejects_inverted_bounds() {
        let f = |x: f64| (x - 2.0).powi(2);

        assert!(minimize_scalar(f, Some((1.0, -1.0)), Method::Bounded, None).is_err());
        assert!(minimize_scalar(f, Some((1.0, 1.0)), Method::Bounded, None).is_err());
    }

    #[test]
    fn test_method_display() {
        assert_eq!(Method::Brent.to_string(), "Brent");
        assert_eq!(Method::Bounded.to_string(), "Bounded");
        assert_eq!(Method::Golden.to_string(), "Golden");
    }
}