1use crate::error::{OptimizeError, OptimizeResult};
24use scirs2_core::ndarray::{Array1, ArrayView1};
25
/// Tunable settings for the DIRECT optimizer implemented by `Direct`.
#[derive(Debug, Clone)]
pub struct DirectOptions {
    /// Budget of objective evaluations. It is checked at the start of each
    /// iteration and before each rectangle division, so the final count can
    /// overshoot by the cost of one division.
    pub max_fevals: usize,
    /// Upper bound on the number of select-and-divide iterations.
    pub max_iterations: usize,
    /// Absolute threshold on the per-iteration change of the best value,
    /// consulted by the stagnation test in `Direct::run`.
    pub ftol_abs: f64,
    /// Relative counterpart of `ftol_abs` (change divided by the previous
    /// best value's magnitude).
    pub ftol_rel: f64,
    /// The run stops once the largest rectangle volume, measured in the
    /// normalized unit-cube coordinates, drops below this value.
    pub vol_tol: f64,
    /// Relative margin used during selection: a rectangle's projected value
    /// must not exceed `f_min - epsilon * |f_min|`.
    pub epsilon: f64,
    /// When `true`, the upper slope bound of a selection candidate is taken
    /// from the next larger size class only, instead of from all larger ones.
    pub locally_biased: bool,
}
46
47impl Default for DirectOptions {
48 fn default() -> Self {
49 Self {
50 max_fevals: 10_000,
51 max_iterations: 1_000,
52 ftol_abs: 1e-12,
53 ftol_rel: 1e-12,
54 vol_tol: 1e-16,
55 epsilon: 1e-4,
56 locally_biased: false,
57 }
58 }
59}
60
/// Outcome of a `Direct::run` call.
#[derive(Debug, Clone)]
pub struct DirectResult {
    /// Best point found, expressed in the caller's original coordinates.
    pub x: Array1<f64>,
    /// Objective value at `x`.
    pub fun: f64,
    /// Number of objective evaluations performed.
    pub nfev: usize,
    /// Iteration count reported at termination.
    pub nit: usize,
    /// Number of rectangles held by the optimizer at termination.
    pub n_rectangles: usize,
    /// Success flag; every exit path of `Direct::run` sets it to `true`.
    pub success: bool,
    /// Human-readable reason for termination.
    pub message: String,
}
79
/// One hyper-rectangle of the search partition, in normalized coordinates.
#[derive(Debug, Clone)]
struct Rectangle {
    /// Centre point (normalized; mapped to the user's box on evaluation).
    center: Vec<f64>,
    /// Objective value sampled at `center`.
    f_center: f64,
    /// Half of the side length along each dimension.
    half_widths: Vec<f64>,
    /// Largest entry of `half_widths`, computed once in `Rectangle::new`.
    size: f64,
}
92
93impl Rectangle {
94 fn new(center: Vec<f64>, f_center: f64, half_widths: Vec<f64>) -> Self {
95 let size = half_widths.iter().copied().fold(0.0_f64, f64::max);
96 Self {
97 center,
98 f_center,
99 half_widths,
100 size,
101 }
102 }
103
104 fn diagonal(&self) -> f64 {
106 self.half_widths.iter().map(|w| w * w).sum::<f64>().sqrt()
107 }
108
109 fn volume(&self) -> f64 {
111 self.half_widths.iter().map(|w| 2.0 * w).product::<f64>()
112 }
113}
114
/// State of a DIRECT optimization run over a box-bounded domain.
///
/// Rectangles are stored in normalized coordinates; the objective is always
/// called with points mapped back into `[lower_bounds, upper_bounds]`.
pub struct Direct<F>
where
    F: Fn(&ArrayView1<f64>) -> f64,
{
    /// Objective function being minimized.
    func: F,
    /// Per-dimension lower bounds of the search box.
    lower_bounds: Vec<f64>,
    /// Per-dimension upper bounds of the search box.
    upper_bounds: Vec<f64>,
    /// Run settings.
    options: DirectOptions,
    /// Number of dimensions (length of `lower_bounds`).
    ndim: usize,
    /// Current set of rectangles.
    rectangles: Vec<Rectangle>,
    /// Lowest objective value seen so far (`+inf` before the first sample).
    best_f: f64,
    /// Point, in original coordinates, at which `best_f` was observed.
    best_x: Vec<f64>,
    /// Number of objective evaluations performed so far.
    fevals: usize,
}
134
135impl<F> Direct<F>
136where
137 F: Fn(&ArrayView1<f64>) -> f64,
138{
139 pub fn new(
141 func: F,
142 lower_bounds: Vec<f64>,
143 upper_bounds: Vec<f64>,
144 options: DirectOptions,
145 ) -> OptimizeResult<Self> {
146 let ndim = lower_bounds.len();
147 if ndim == 0 {
148 return Err(OptimizeError::InvalidInput(
149 "Dimension must be at least 1".to_string(),
150 ));
151 }
152 if upper_bounds.len() != ndim {
153 return Err(OptimizeError::InvalidInput(
154 "Lower and upper bounds must have the same length".to_string(),
155 ));
156 }
157 for i in 0..ndim {
158 if lower_bounds[i] >= upper_bounds[i] {
159 return Err(OptimizeError::InvalidInput(format!(
160 "Lower bound must be less than upper bound for dimension {}: {} >= {}",
161 i, lower_bounds[i], upper_bounds[i]
162 )));
163 }
164 }
165
166 Ok(Self {
167 func,
168 lower_bounds,
169 upper_bounds,
170 options,
171 ndim,
172 rectangles: Vec::new(),
173 best_f: f64::INFINITY,
174 best_x: vec![0.0; ndim],
175 fevals: 0,
176 })
177 }
178
179 fn to_original(&self, normalized: &[f64]) -> Vec<f64> {
181 normalized
182 .iter()
183 .enumerate()
184 .map(|(i, &x)| self.lower_bounds[i] + x * (self.upper_bounds[i] - self.lower_bounds[i]))
185 .collect()
186 }
187
188 fn evaluate(&mut self, normalized_point: &[f64]) -> f64 {
190 let original = self.to_original(normalized_point);
191 let arr = Array1::from_vec(original.clone());
192 let f_val = (self.func)(&arr.view());
193 self.fevals += 1;
194
195 if f_val < self.best_f {
196 self.best_f = f_val;
197 self.best_x = original;
198 }
199
200 f_val
201 }
202
203 fn initialize(&mut self) {
205 let center = vec![0.5; self.ndim];
206 let f_center = self.evaluate(¢er);
207 let half_widths = vec![0.5; self.ndim];
208 let rect = Rectangle::new(center, f_center, half_widths);
209 self.rectangles.push(rect);
210 }
211
212 fn select_potentially_optimal(&self) -> Vec<usize> {
219 if self.rectangles.is_empty() {
220 return Vec::new();
221 }
222
223 let epsilon = self.options.epsilon;
224 let f_min = self.best_f;
225
226 let mut size_groups: std::collections::BTreeMap<u64, Vec<usize>> =
228 std::collections::BTreeMap::new();
229 for (idx, rect) in self.rectangles.iter().enumerate() {
230 let size_key = (rect.diagonal() * 1e12) as u64;
231 size_groups.entry(size_key).or_default().push(idx);
232 }
233
234 let mut hull_candidates: Vec<(f64, f64, usize)> = Vec::new(); for (_size_key, indices) in &size_groups {
237 let mut best_idx = indices[0];
238 let mut best_f = self.rectangles[indices[0]].f_center;
239 for &idx in &indices[1..] {
240 if self.rectangles[idx].f_center < best_f {
241 best_f = self.rectangles[idx].f_center;
242 best_idx = idx;
243 }
244 }
245 hull_candidates.push((self.rectangles[best_idx].diagonal(), best_f, best_idx));
246 }
247
248 hull_candidates.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
250
251 if hull_candidates.is_empty() {
252 return Vec::new();
253 }
254
255 let mut selected = Vec::new();
272
273 for i in 0..hull_candidates.len() {
274 let (d_i, f_i, idx) = hull_candidates[i];
275
276 let mut k_lower = 0.0_f64; for k in 0..i {
279 let (d_k, f_k, _) = hull_candidates[k];
280 if d_k < d_i && (d_i - d_k).abs() > 1e-15 {
281 let slope = (f_i - f_k) / (d_i - d_k);
282 if slope > k_lower {
283 k_lower = slope;
284 }
285 }
286 }
287
288 let k_upper = if self.options.locally_biased {
290 if i + 1 < hull_candidates.len() {
292 let (d_next, f_next, _) = hull_candidates[i + 1];
293 if (d_next - d_i).abs() > 1e-15 {
294 (f_next - f_i) / (d_next - d_i)
295 } else {
296 f64::INFINITY
297 }
298 } else {
299 f64::INFINITY
300 }
301 } else {
302 let mut k_up = f64::INFINITY;
304 for j in (i + 1)..hull_candidates.len() {
305 let (d_j, f_j, _) = hull_candidates[j];
306 if d_j > d_i && (d_j - d_i).abs() > 1e-15 {
307 let slope = (f_j - f_i) / (d_j - d_i);
308 if slope < k_up {
309 k_up = slope;
310 }
311 }
312 }
313 k_up
314 };
315
316 if k_upper < k_lower {
318 continue; }
320
321 let k_use = if k_upper.is_finite() {
324 k_upper
325 } else {
326 k_lower
327 };
328 let f_projected = f_i - k_use * d_i;
329 if f_projected <= f_min - epsilon * f_min.abs() {
330 selected.push(idx);
331 }
332 }
333
334 let best_rect_idx = self
337 .rectangles
338 .iter()
339 .enumerate()
340 .min_by(|(_, a), (_, b)| {
341 a.f_center
342 .partial_cmp(&b.f_center)
343 .unwrap_or(std::cmp::Ordering::Equal)
344 })
345 .map(|(i, _)| i);
346
347 if let Some(best_idx) = best_rect_idx {
348 if !selected.contains(&best_idx) {
349 selected.push(best_idx);
350 }
351 }
352
353 selected
354 }
355
356 fn divide_rectangle(&mut self, rect_idx: usize) -> Vec<Rectangle> {
358 let rect = self.rectangles[rect_idx].clone();
359
360 let max_width = rect.half_widths.iter().copied().fold(0.0_f64, f64::max);
362
363 let long_dims: Vec<usize> = rect
364 .half_widths
365 .iter()
366 .enumerate()
367 .filter(|(_, &w)| (w - max_width).abs() < 1e-15)
368 .map(|(i, _)| i)
369 .collect();
370
371 let new_hw = max_width / 3.0; let mut dim_sort: Vec<(usize, f64)> = Vec::new();
399 for &dim in &long_dims {
400 let mut c_probe_p = rect.center.clone();
401 c_probe_p[dim] += new_hw;
402 let f_probe_p = self.evaluate(&c_probe_p);
403
404 let mut c_probe_m = rect.center.clone();
405 c_probe_m[dim] -= new_hw;
406 let f_probe_m = self.evaluate(&c_probe_m);
407
408 dim_sort.push((dim, f_probe_p.min(f_probe_m)));
409 }
410
411 dim_sort.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
413
414 let mut new_rects = Vec::new();
421
422 for &(dim, _) in &dim_sort {
423 let mut c_child_p = rect.center.clone();
425 c_child_p[dim] += 2.0 * new_hw;
426 let f_child_p = self.evaluate(&c_child_p);
427
428 let mut c_child_m = rect.center.clone();
429 c_child_m[dim] -= 2.0 * new_hw;
430 let f_child_m = self.evaluate(&c_child_m);
431
432 let mut hw_child = rect.half_widths.clone();
434 hw_child[dim] = new_hw;
435
436 new_rects.push(Rectangle::new(c_child_p, f_child_p, hw_child.clone()));
437 new_rects.push(Rectangle::new(c_child_m, f_child_m, hw_child));
438 }
439
440 let mut parent_hw = rect.half_widths.clone();
443 for &(dim, _) in &dim_sort {
444 parent_hw[dim] = new_hw;
445 }
446 let parent_rect = Rectangle::new(rect.center.clone(), rect.f_center, parent_hw);
447 new_rects.push(parent_rect);
448
449 new_rects
450 }
451
452 pub fn run(&mut self) -> DirectResult {
454 self.initialize();
455
456 let mut prev_best_f = self.best_f;
457
458 for iteration in 0..self.options.max_iterations {
459 if self.fevals >= self.options.max_fevals {
461 return DirectResult {
462 x: Array1::from_vec(self.best_x.clone()),
463 fun: self.best_f,
464 nfev: self.fevals,
465 nit: iteration,
466 n_rectangles: self.rectangles.len(),
467 success: true,
468 message: format!(
469 "Maximum function evaluations ({}) reached",
470 self.options.max_fevals
471 ),
472 };
473 }
474
475 let po_indices = self.select_potentially_optimal();
477 if po_indices.is_empty() {
478 return DirectResult {
479 x: Array1::from_vec(self.best_x.clone()),
480 fun: self.best_f,
481 nfev: self.fevals,
482 nit: iteration,
483 n_rectangles: self.rectangles.len(),
484 success: true,
485 message: "No potentially optimal rectangles found".to_string(),
486 };
487 }
488
489 let mut sorted_indices = po_indices;
492 sorted_indices.sort_unstable_by(|a, b| b.cmp(a));
493
494 let mut new_rects_all = Vec::new();
495 for &idx in &sorted_indices {
496 if self.fevals >= self.options.max_fevals {
497 break;
498 }
499 let new_rects = self.divide_rectangle(idx);
500 new_rects_all.push((idx, new_rects));
501 }
502
503 let mut indices_to_remove: Vec<usize> =
506 new_rects_all.iter().map(|(idx, _)| *idx).collect();
507 indices_to_remove.sort_unstable_by(|a, b| b.cmp(a));
508 for idx in indices_to_remove {
509 self.rectangles.swap_remove(idx);
510 }
511 for (_, new_rects) in new_rects_all {
512 self.rectangles.extend(new_rects);
513 }
514
515 let f_improvement = (prev_best_f - self.best_f).abs();
520 let abs_stagnant = f_improvement < self.options.ftol_abs;
521 let rel_stagnant = if prev_best_f.abs() > 1e-30 {
522 f_improvement / prev_best_f.abs() < self.options.ftol_rel
523 } else {
524 abs_stagnant
525 };
526 if (abs_stagnant || rel_stagnant) && iteration > 10 {
527 let stagnation_limit = (self.options.max_iterations / 10).max(50);
530 if f_improvement == 0.0 && iteration > stagnation_limit {
535 return DirectResult {
536 x: Array1::from_vec(self.best_x.clone()),
537 fun: self.best_f,
538 nfev: self.fevals,
539 nit: iteration,
540 n_rectangles: self.rectangles.len(),
541 success: true,
542 message: "Function tolerance reached (stagnation)".to_string(),
543 };
544 }
545 }
546
547 let max_vol = self
549 .rectangles
550 .iter()
551 .map(|r| r.volume())
552 .fold(0.0_f64, f64::max);
553 if max_vol < self.options.vol_tol {
554 return DirectResult {
555 x: Array1::from_vec(self.best_x.clone()),
556 fun: self.best_f,
557 nfev: self.fevals,
558 nit: iteration,
559 n_rectangles: self.rectangles.len(),
560 success: true,
561 message: "Volume tolerance reached".to_string(),
562 };
563 }
564
565 prev_best_f = self.best_f;
566 }
567
568 DirectResult {
569 x: Array1::from_vec(self.best_x.clone()),
570 fun: self.best_f,
571 nfev: self.fevals,
572 nit: self.options.max_iterations,
573 n_rectangles: self.rectangles.len(),
574 success: true,
575 message: format!(
576 "Maximum iterations ({}) reached",
577 self.options.max_iterations
578 ),
579 }
580 }
581}
582
583pub fn direct_minimize<F>(
596 func: F,
597 lower_bounds: Vec<f64>,
598 upper_bounds: Vec<f64>,
599 options: Option<DirectOptions>,
600) -> OptimizeResult<DirectResult>
601where
602 F: Fn(&ArrayView1<f64>) -> f64,
603{
604 let options = options.unwrap_or_default();
605 let mut optimizer = Direct::new(func, lower_bounds, upper_bounds, options)?;
606 Ok(optimizer.run())
607}
608
#[cfg(test)]
mod tests {
    use super::*;

    /// Default options with only the evaluation budget overridden.
    fn with_budget(max_fevals: usize) -> DirectOptions {
        DirectOptions {
            max_fevals,
            ..Default::default()
        }
    }

    /// Sum of squares; minimum 0 at the origin.
    fn sphere(x: &ArrayView1<f64>) -> f64 {
        x.iter().map(|&v| v * v).sum()
    }

    /// Rosenbrock valley; minimum 0 at (1, ..., 1).
    fn rosenbrock(x: &ArrayView1<f64>) -> f64 {
        (0..x.len() - 1).fold(0.0, |acc, i| {
            acc + (100.0 * (x[i + 1] - x[i] * x[i]).powi(2) + (1.0 - x[i]).powi(2))
        })
    }

    /// Rastrigin function; highly multimodal, minimum 0 at the origin.
    fn rastrigin(x: &ArrayView1<f64>) -> f64 {
        let n = x.len() as f64;
        x.iter().fold(10.0 * n, |acc, &xi| {
            acc + (xi * xi - 10.0 * (2.0 * std::f64::consts::PI * xi).cos())
        })
    }

    /// Branin function; global minimum value is about 0.398.
    fn branin(x: &ArrayView1<f64>) -> f64 {
        let pi = std::f64::consts::PI;
        let (x1, x2) = (x[0], x[1]);
        let a = 1.0;
        let b = 5.1 / (4.0 * pi * pi);
        let c = 5.0 / pi;
        let r = 6.0;
        let s = 10.0;
        let t = 1.0 / (8.0 * pi);
        let quadratic = a * (x2 - b * x1 * x1 + c * x1 - r).powi(2);
        let cosine = s * (1.0 - t) * x1.cos();
        quadratic + cosine + s
    }

    #[test]
    fn test_direct_sphere_2d() {
        let result = direct_minimize(sphere, vec![-5.0; 2], vec![5.0; 2], Some(with_budget(500)));
        assert!(result.is_ok());
        let res = result.expect("DIRECT sphere 2D failed");
        assert!(res.fun < 0.1, "DIRECT sphere value: {}", res.fun);
        // The budget is checked between divisions, so a small overshoot is tolerated.
        assert!(res.nfev <= 516, "Used {} evaluations", res.nfev);
    }

    #[test]
    fn test_direct_sphere_3d() {
        let result = direct_minimize(
            sphere,
            vec![-5.0; 3],
            vec![5.0; 3],
            Some(with_budget(2_000)),
        );
        assert!(result.is_ok());
        let res = result.expect("DIRECT sphere 3D failed");
        assert!(res.fun < 1.0, "DIRECT sphere 3D value: {}", res.fun);
    }

    #[test]
    fn test_direct_rosenbrock() {
        let result = direct_minimize(
            rosenbrock,
            vec![-2.0; 2],
            vec![2.0; 2],
            Some(with_budget(5_000)),
        );
        assert!(result.is_ok());
        let res = result.expect("DIRECT Rosenbrock failed");
        assert!(res.fun < 1.0, "DIRECT Rosenbrock value: {}", res.fun);
    }

    #[test]
    fn test_direct_rastrigin() {
        let result = direct_minimize(
            rastrigin,
            vec![-5.12; 2],
            vec![5.12; 2],
            Some(with_budget(5_000)),
        );
        assert!(result.is_ok());
        let res = result.expect("DIRECT Rastrigin failed");
        assert!(res.fun < 5.0, "DIRECT Rastrigin value: {}", res.fun);
    }

    #[test]
    fn test_direct_branin() {
        let result = direct_minimize(
            branin,
            vec![-5.0, 0.0],
            vec![10.0, 15.0],
            Some(with_budget(3_000)),
        );
        assert!(result.is_ok());
        let res = result.expect("DIRECT Branin failed");
        assert!(
            res.fun < 1.0,
            "DIRECT Branin value: {} (expected ~0.398)",
            res.fun
        );
    }

    #[test]
    fn test_direct_locally_biased() {
        let opts = DirectOptions {
            locally_biased: true,
            ..with_budget(500)
        };
        let result = direct_minimize(sphere, vec![-5.0; 2], vec![5.0; 2], Some(opts));
        assert!(result.is_ok());
        let res = result.expect("DIRECT-L sphere failed");
        assert!(res.fun < 1.0, "DIRECT-L sphere value: {}", res.fun);
    }

    #[test]
    fn test_direct_invalid_bounds() {
        // First dimension has lower > upper.
        let lower = vec![5.0, -5.0];
        let upper = vec![-5.0, 5.0];
        let result = direct_minimize(sphere, lower, upper, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_direct_empty_dimensions() {
        let result: OptimizeResult<DirectResult> =
            direct_minimize(sphere, Vec::new(), Vec::new(), None);
        assert!(result.is_err());
    }

    #[test]
    fn test_direct_1d() {
        /// Parabola with minimum value 1 at x = 3.
        fn parabola(x: &ArrayView1<f64>) -> f64 {
            (x[0] - 3.0).powi(2) + 1.0
        }

        let result = direct_minimize(parabola, vec![0.0], vec![6.0], Some(with_budget(200)));
        assert!(result.is_ok());
        let res = result.expect("DIRECT 1D parabola failed");
        assert!(
            (res.x[0] - 3.0).abs() < 0.5,
            "DIRECT 1D minimum at x={} (expected 3.0)",
            res.x[0]
        );
        assert!(
            (res.fun - 1.0).abs() < 0.5,
            "DIRECT 1D value {} (expected 1.0)",
            res.fun
        );
    }

    #[test]
    fn test_direct_budget_management() {
        let result = direct_minimize(
            sphere,
            vec![-10.0; 2],
            vec![10.0; 2],
            Some(with_budget(50)),
        );
        assert!(result.is_ok());
        let res = result.expect("DIRECT budget test failed");
        assert!(res.nfev <= 55, "Budget exceeded: {} > 50", res.nfev);
    }
}