sklears_utils/
optimization.rs

1//! Optimization utilities for numerical optimization algorithms
2//!
3//! This module provides utilities for optimization algorithms including:
4//! - Line search methods (Armijo, Wolfe conditions)
5//! - Convergence criteria checking
6//! - Gradient computation helpers
7//! - Constraint handling utilities
8//!
9//! # Examples
10//!
11//! ```rust
12//! use sklears_utils::optimization::{LineSearch, ConvergenceCriteria, GradientComputer};
13//! use scirs2_core::ndarray::Array1;
14//!
15//! let line_search = LineSearch::armijo(1e-4);
16//! let conv_criteria = ConvergenceCriteria::new()
17//!     .with_tolerance(1e-6)
18//!     .with_max_iterations(1000);
19//! ```
20
21use crate::UtilsResult;
22use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
23use std::collections::VecDeque;
24
/// Configuration for a one-dimensional step-length search along a direction.
///
/// Usually built through [`LineSearch::armijo`], [`LineSearch::wolfe`],
/// [`LineSearch::strong_wolfe`] or [`LineSearch::backtracking`]; all fields are public so
/// they can also be adjusted directly.
#[derive(Debug, Clone)]
pub struct LineSearch {
    /// Acceptance rule applied to each trial step.
    pub method: LineSearchMethod,
    /// Sufficient-decrease (Armijo) constant.
    pub c1: f64,
    /// Curvature (Wolfe) constant.
    pub c2: f64,
    /// Maximum number of trial steps.
    pub max_iterations: usize,
    /// Length of the first trial step.
    pub initial_step: f64,
    /// Multiplicative factor used to adjust the trial step between attempts.
    pub step_decay: f64,
}

/// Acceptance rule used by [`LineSearch`].
#[derive(Debug, Clone)]
pub enum LineSearchMethod {
    /// Sufficient decrease only.
    Armijo,
    /// Sufficient decrease plus the curvature condition.
    Wolfe,
    /// Sufficient decrease plus the absolute-value curvature condition.
    StrongWolfe,
    /// Backtracking on sufficient decrease.
    Backtracking,
}
43
44impl LineSearch {
45    /// Create Armijo line search with default parameters
46    pub fn armijo(c1: f64) -> Self {
47        Self {
48            method: LineSearchMethod::Armijo,
49            c1,
50            c2: 0.9,
51            max_iterations: 50,
52            initial_step: 1.0,
53            step_decay: 0.5,
54        }
55    }
56
57    /// Create Wolfe line search with default parameters
58    pub fn wolfe(c1: f64, c2: f64) -> Self {
59        Self {
60            method: LineSearchMethod::Wolfe,
61            c1,
62            c2,
63            max_iterations: 50,
64            initial_step: 1.0,
65            step_decay: 0.5,
66        }
67    }
68
69    /// Create strong Wolfe line search
70    pub fn strong_wolfe(c1: f64, c2: f64) -> Self {
71        Self {
72            method: LineSearchMethod::StrongWolfe,
73            c1,
74            c2,
75            max_iterations: 50,
76            initial_step: 1.0,
77            step_decay: 0.5,
78        }
79    }
80
81    /// Create backtracking line search
82    pub fn backtracking(c1: f64) -> Self {
83        Self {
84            method: LineSearchMethod::Backtracking,
85            c1,
86            c2: 0.9,
87            max_iterations: 50,
88            initial_step: 1.0,
89            step_decay: 0.5,
90        }
91    }
92
93    /// Perform line search to find appropriate step size
94    pub fn search<F, G>(
95        &self,
96        f: F,
97        grad_f: G,
98        x: &ArrayView1<f64>,
99        direction: &ArrayView1<f64>,
100        f_x: f64,
101        grad_x: &ArrayView1<f64>,
102    ) -> UtilsResult<f64>
103    where
104        F: Fn(&ArrayView1<f64>) -> f64,
105        G: Fn(&ArrayView1<f64>) -> Array1<f64>,
106    {
107        match self.method {
108            LineSearchMethod::Armijo => self.armijo_search(f, x, direction, f_x, grad_x),
109            LineSearchMethod::Backtracking => {
110                self.backtracking_search(f, x, direction, f_x, grad_x)
111            }
112            LineSearchMethod::Wolfe => self.wolfe_search(f, grad_f, x, direction, f_x, grad_x),
113            LineSearchMethod::StrongWolfe => {
114                self.strong_wolfe_search(f, grad_f, x, direction, f_x, grad_x)
115            }
116        }
117    }
118
119    fn armijo_search<F>(
120        &self,
121        f: F,
122        x: &ArrayView1<f64>,
123        direction: &ArrayView1<f64>,
124        f_x: f64,
125        grad_x: &ArrayView1<f64>,
126    ) -> UtilsResult<f64>
127    where
128        F: Fn(&ArrayView1<f64>) -> f64,
129    {
130        let mut alpha = self.initial_step;
131        let directional_derivative = grad_x.dot(direction);
132
133        for _ in 0..self.max_iterations {
134            let x_new = x + &(direction * alpha);
135            let f_new = f(&x_new.view());
136
137            // Armijo condition: f(x + α*p) ≤ f(x) + c₁*α*∇f(x)ᵀp
138            if f_new <= f_x + self.c1 * alpha * directional_derivative {
139                return Ok(alpha);
140            }
141
142            alpha *= self.step_decay;
143        }
144
145        Ok(alpha) // Return last alpha even if conditions not met
146    }
147
148    fn backtracking_search<F>(
149        &self,
150        f: F,
151        x: &ArrayView1<f64>,
152        direction: &ArrayView1<f64>,
153        f_x: f64,
154        grad_x: &ArrayView1<f64>,
155    ) -> UtilsResult<f64>
156    where
157        F: Fn(&ArrayView1<f64>) -> f64,
158    {
159        self.armijo_search(f, x, direction, f_x, grad_x)
160    }
161
162    fn wolfe_search<F, G>(
163        &self,
164        f: F,
165        grad_f: G,
166        x: &ArrayView1<f64>,
167        direction: &ArrayView1<f64>,
168        f_x: f64,
169        grad_x: &ArrayView1<f64>,
170    ) -> UtilsResult<f64>
171    where
172        F: Fn(&ArrayView1<f64>) -> f64,
173        G: Fn(&ArrayView1<f64>) -> Array1<f64>,
174    {
175        let mut alpha = self.initial_step;
176        let directional_derivative = grad_x.dot(direction);
177
178        for _ in 0..self.max_iterations {
179            let x_new = x + &(direction * alpha);
180            let f_new = f(&x_new.view());
181
182            // Armijo condition
183            if f_new > f_x + self.c1 * alpha * directional_derivative {
184                alpha *= self.step_decay;
185                continue;
186            }
187
188            // Wolfe condition: ∇f(x + α*p)ᵀp ≥ c₂*∇f(x)ᵀp
189            let grad_new = grad_f(&x_new.view());
190            let new_directional_derivative = grad_new.dot(direction);
191
192            if new_directional_derivative >= self.c2 * directional_derivative {
193                return Ok(alpha);
194            }
195
196            alpha /= self.step_decay; // Increase step size
197        }
198
199        Ok(alpha)
200    }
201
202    fn strong_wolfe_search<F, G>(
203        &self,
204        f: F,
205        grad_f: G,
206        x: &ArrayView1<f64>,
207        direction: &ArrayView1<f64>,
208        f_x: f64,
209        grad_x: &ArrayView1<f64>,
210    ) -> UtilsResult<f64>
211    where
212        F: Fn(&ArrayView1<f64>) -> f64,
213        G: Fn(&ArrayView1<f64>) -> Array1<f64>,
214    {
215        let mut alpha = self.initial_step;
216        let directional_derivative = grad_x.dot(direction);
217
218        for _ in 0..self.max_iterations {
219            let x_new = x + &(direction * alpha);
220            let f_new = f(&x_new.view());
221
222            // Armijo condition
223            if f_new > f_x + self.c1 * alpha * directional_derivative {
224                alpha *= self.step_decay;
225                continue;
226            }
227
228            // Strong Wolfe condition: |∇f(x + α*p)ᵀp| ≤ c₂*|∇f(x)ᵀp|
229            let grad_new = grad_f(&x_new.view());
230            let new_directional_derivative = grad_new.dot(direction);
231
232            if new_directional_derivative.abs() <= self.c2 * directional_derivative.abs() {
233                return Ok(alpha);
234            }
235
236            alpha *= if new_directional_derivative < 0.0 {
237                1.0 / self.step_decay
238            } else {
239                self.step_decay
240            };
241        }
242
243        Ok(alpha)
244    }
245}
246
/// Stopping rules for iterative optimization algorithms.
///
/// Build with [`ConvergenceCriteria::new`] (or `Default`) and adjust with the `with_*`
/// setters.
#[derive(Debug, Clone)]
pub struct ConvergenceCriteria {
    /// General-purpose tolerance.
    /// NOTE(review): not consulted by `ConvergenceCriteria::is_converged`.
    pub tolerance: f64,
    /// Threshold on the gradient 2-norm.
    pub gradient_tolerance: f64,
    /// Threshold on the caller-supplied parameter change.
    pub parameter_tolerance: f64,
    /// Iteration cap.
    pub max_iterations: usize,
    /// Iterations that must pass before any stopping rule is considered.
    pub min_iterations: usize,
    /// Threshold on the absolute change of the objective between iterations.
    pub function_tolerance: f64,
    /// Number of non-improving iterations tolerated before stopping early.
    pub patience: usize,
}

impl Default for ConvergenceCriteria {
    /// Defaults: iteration window `[1, 1000]`, patience 10, tolerances `1e-6` (general and
    /// gradient), `1e-8` (parameters) and `1e-9` (function value).
    fn default() -> Self {
        ConvergenceCriteria {
            max_iterations: 1000,
            min_iterations: 1,
            patience: 10,
            tolerance: 1e-6,
            gradient_tolerance: 1e-6,
            function_tolerance: 1e-9,
            parameter_tolerance: 1e-8,
        }
    }
}
272
273impl ConvergenceCriteria {
274    /// Create new convergence criteria with default values
275    pub fn new() -> Self {
276        Self::default()
277    }
278
279    /// Set gradient tolerance
280    pub fn with_tolerance(mut self, tol: f64) -> Self {
281        self.tolerance = tol;
282        self
283    }
284
285    /// Set gradient tolerance
286    pub fn with_gradient_tolerance(mut self, tol: f64) -> Self {
287        self.gradient_tolerance = tol;
288        self
289    }
290
291    /// Set parameter tolerance
292    pub fn with_parameter_tolerance(mut self, tol: f64) -> Self {
293        self.parameter_tolerance = tol;
294        self
295    }
296
297    /// Set maximum iterations
298    pub fn with_max_iterations(mut self, max_iter: usize) -> Self {
299        self.max_iterations = max_iter;
300        self
301    }
302
303    /// Set minimum iterations
304    pub fn with_min_iterations(mut self, min_iter: usize) -> Self {
305        self.min_iterations = min_iter;
306        self
307    }
308
309    /// Set function tolerance
310    pub fn with_function_tolerance(mut self, tol: f64) -> Self {
311        self.function_tolerance = tol;
312        self
313    }
314
315    /// Set patience for early stopping
316    pub fn with_patience(mut self, patience: usize) -> Self {
317        self.patience = patience;
318        self
319    }
320
321    /// Check if convergence is achieved
322    pub fn is_converged(
323        &self,
324        iteration: usize,
325        current_f: f64,
326        previous_f: Option<f64>,
327        gradient: Option<&ArrayView1<f64>>,
328        parameter_change: Option<f64>,
329        no_improvement_count: usize,
330    ) -> ConvergenceStatus {
331        // Check minimum iterations
332        if iteration < self.min_iterations {
333            return ConvergenceStatus::Continuing;
334        }
335
336        // Check maximum iterations
337        if iteration >= self.max_iterations {
338            return ConvergenceStatus::MaxIterationsReached;
339        }
340
341        // Check gradient tolerance
342        if let Some(grad) = gradient {
343            let grad_norm = grad.iter().map(|x| x * x).sum::<f64>().sqrt();
344            if grad_norm < self.gradient_tolerance {
345                return ConvergenceStatus::GradientTolerance;
346            }
347        }
348
349        // Check function tolerance
350        if let Some(prev_f) = previous_f {
351            let f_change = (current_f - prev_f).abs();
352            if f_change < self.function_tolerance {
353                return ConvergenceStatus::FunctionTolerance;
354            }
355        }
356
357        // Check parameter tolerance
358        if let Some(param_change) = parameter_change {
359            if param_change < self.parameter_tolerance {
360                return ConvergenceStatus::ParameterTolerance;
361            }
362        }
363
364        // Check early stopping patience
365        if no_improvement_count >= self.patience {
366            return ConvergenceStatus::NoImprovement;
367        }
368
369        ConvergenceStatus::Continuing
370    }
371}
372
/// Outcome of a [`ConvergenceCriteria::is_converged`] check.
#[derive(Debug, Clone, PartialEq)]
pub enum ConvergenceStatus {
    /// No stopping rule fired; keep iterating.
    Continuing,
    /// The gradient norm fell below its tolerance.
    GradientTolerance,
    /// The change in objective value fell below its tolerance.
    FunctionTolerance,
    /// The parameter change fell below its tolerance.
    ParameterTolerance,
    /// The iteration cap was hit.
    MaxIterationsReached,
    /// Too many iterations passed without improvement.
    NoImprovement,
}

impl ConvergenceStatus {
    /// `true` for every status other than [`ConvergenceStatus::Continuing`], i.e. whenever
    /// the optimizer should stop.
    pub fn is_converged(&self) -> bool {
        *self != ConvergenceStatus::Continuing
    }

    /// `true` only when a tolerance-based criterion fired (gradient, function or parameter).
    /// Hitting the iteration cap or running out of patience is not a success.
    pub fn is_successful(&self) -> bool {
        match self {
            Self::GradientTolerance | Self::FunctionTolerance | Self::ParameterTolerance => true,
            Self::Continuing | Self::MaxIterationsReached | Self::NoImprovement => false,
        }
    }
}
398
/// Settings for finite-difference gradient and Jacobian approximation.
#[derive(Debug, Clone)]
pub struct GradientComputer {
    /// Difference scheme to use.
    pub method: GradientMethod,
    /// Perturbation applied to each coordinate.
    pub epsilon: f64,
    /// Request for parallel evaluation.
    /// NOTE(review): not consulted by the computations in this module.
    pub parallel: bool,
}

/// Finite-difference scheme.
#[derive(Debug, Clone)]
pub enum GradientMethod {
    /// `(f(x + h·eᵢ) - f(x)) / h`
    Forward,
    /// `(f(x) - f(x - h·eᵢ)) / h`
    Backward,
    /// `(f(x + h·eᵢ) - f(x - h·eᵢ)) / (2h)`
    Central,
}

impl Default for GradientComputer {
    /// Central differences with `epsilon = 1e-8`, sequential evaluation.
    fn default() -> Self {
        let epsilon = 1e-8;
        GradientComputer {
            method: GradientMethod::Central,
            epsilon,
            parallel: false,
        }
    }
}
423
424impl GradientComputer {
425    /// Create new gradient computer
426    pub fn new() -> Self {
427        Self::default()
428    }
429
430    /// Set gradient method
431    pub fn with_method(mut self, method: GradientMethod) -> Self {
432        self.method = method;
433        self
434    }
435
436    /// Set finite difference epsilon
437    pub fn with_epsilon(mut self, eps: f64) -> Self {
438        self.epsilon = eps;
439        self
440    }
441
442    /// Enable parallel computation
443    pub fn with_parallel(mut self, parallel: bool) -> Self {
444        self.parallel = parallel;
445        self
446    }
447
448    /// Compute numerical gradient using finite differences
449    pub fn compute_gradient<F>(&self, f: F, x: &ArrayView1<f64>) -> UtilsResult<Array1<f64>>
450    where
451        F: Fn(&ArrayView1<f64>) -> f64 + Sync,
452    {
453        let n = x.len();
454        let mut gradient = Array1::zeros(n);
455
456        match self.method {
457            GradientMethod::Forward => {
458                for i in 0..n {
459                    let mut x_plus = x.to_owned();
460                    x_plus[i] += self.epsilon;
461                    gradient[i] = (f(&x_plus.view()) - f(x)) / self.epsilon;
462                }
463            }
464            GradientMethod::Backward => {
465                for i in 0..n {
466                    let mut x_minus = x.to_owned();
467                    x_minus[i] -= self.epsilon;
468                    gradient[i] = (f(x) - f(&x_minus.view())) / self.epsilon;
469                }
470            }
471            GradientMethod::Central => {
472                for i in 0..n {
473                    let mut x_plus = x.to_owned();
474                    let mut x_minus = x.to_owned();
475                    x_plus[i] += self.epsilon;
476                    x_minus[i] -= self.epsilon;
477                    gradient[i] = (f(&x_plus.view()) - f(&x_minus.view())) / (2.0 * self.epsilon);
478                }
479            }
480        }
481
482        Ok(gradient)
483    }
484
485    /// Compute Jacobian matrix for vector-valued functions
486    pub fn compute_jacobian<F>(
487        &self,
488        f: F,
489        x: &ArrayView1<f64>,
490        m: usize,
491    ) -> UtilsResult<Array2<f64>>
492    where
493        F: Fn(&ArrayView1<f64>) -> Array1<f64> + Sync,
494    {
495        let n = x.len();
496        let mut jacobian = Array2::zeros((m, n));
497
498        match self.method {
499            GradientMethod::Central => {
500                for j in 0..n {
501                    let mut x_plus = x.to_owned();
502                    let mut x_minus = x.to_owned();
503                    x_plus[j] += self.epsilon;
504                    x_minus[j] -= self.epsilon;
505
506                    let f_plus = f(&x_plus.view());
507                    let f_minus = f(&x_minus.view());
508
509                    for i in 0..m {
510                        jacobian[[i, j]] = (f_plus[i] - f_minus[i]) / (2.0 * self.epsilon);
511                    }
512                }
513            }
514            _ => {
515                // For Jacobian, prefer central differences for accuracy
516                return self
517                    .clone()
518                    .with_method(GradientMethod::Central)
519                    .compute_jacobian(f, x, m);
520            }
521        }
522
523        Ok(jacobian)
524    }
525}
526
527/// Type alias for constraint function
528pub type ConstraintFunction = Box<dyn Fn(&ArrayView1<f64>) -> f64 + Send + Sync>;
529
530/// Constraint handling utilities for constrained optimization
531pub struct ConstraintHandler {
532    pub equality_constraints: Vec<ConstraintFunction>,
533    pub inequality_constraints: Vec<ConstraintFunction>,
534    pub bounds: Option<(Array1<f64>, Array1<f64>)>, // (lower, upper)
535    pub penalty_parameter: f64,
536    pub tolerance: f64,
537}
538
539impl Default for ConstraintHandler {
540    fn default() -> Self {
541        Self::new()
542    }
543}
544
545impl std::fmt::Debug for ConstraintHandler {
546    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
547        f.debug_struct("ConstraintHandler")
548            .field(
549                "equality_constraints",
550                &format!("{} functions", self.equality_constraints.len()),
551            )
552            .field(
553                "inequality_constraints",
554                &format!("{} functions", self.inequality_constraints.len()),
555            )
556            .field("bounds", &self.bounds)
557            .field("penalty_parameter", &self.penalty_parameter)
558            .field("tolerance", &self.tolerance)
559            .finish()
560    }
561}
562
/// Per-constraint violation amounts at a point, plus aggregate measures.
#[derive(Debug, Clone)]
pub struct ConstraintViolation {
    /// One entry per equality constraint.
    pub equality_violations: Vec<f64>,
    /// One entry per inequality constraint.
    pub inequality_violations: Vec<f64>,
    /// One entry per coordinate when bounds are set, otherwise empty.
    pub bound_violations: Vec<f64>,
    /// Largest single violation.
    pub max_violation: f64,
    /// Sum of all violations.
    pub total_violation: f64,
}
571
572impl ConstraintHandler {
573    /// Create new constraint handler
574    pub fn new() -> Self {
575        Self {
576            equality_constraints: Vec::new(),
577            inequality_constraints: Vec::new(),
578            bounds: None,
579            penalty_parameter: 1.0,
580            tolerance: 1e-6,
581        }
582    }
583
584    /// Set bounds constraints
585    pub fn with_bounds(mut self, lower: Array1<f64>, upper: Array1<f64>) -> Self {
586        self.bounds = Some((lower, upper));
587        self
588    }
589
590    /// Set penalty parameter
591    pub fn with_penalty_parameter(mut self, penalty: f64) -> Self {
592        self.penalty_parameter = penalty;
593        self
594    }
595
596    /// Set tolerance
597    pub fn with_tolerance(mut self, tol: f64) -> Self {
598        self.tolerance = tol;
599        self
600    }
601
602    /// Project point onto bounds
603    pub fn project_bounds(&self, x: &Array1<f64>) -> Array1<f64> {
604        if let Some((ref lower, ref upper)) = self.bounds {
605            x.iter()
606                .zip(lower.iter())
607                .zip(upper.iter())
608                .map(|((x_i, l_i), u_i)| x_i.max(*l_i).min(*u_i))
609                .collect::<Array1<f64>>()
610        } else {
611            x.clone()
612        }
613    }
614
615    /// Check constraint violations
616    pub fn check_violations(&self, x: &ArrayView1<f64>) -> ConstraintViolation {
617        let mut equality_violations = Vec::new();
618        let mut inequality_violations = Vec::new();
619        let mut bound_violations = Vec::new();
620
621        // Check equality constraints: c_eq(x) = 0
622        for constraint in &self.equality_constraints {
623            let violation = constraint(x).abs();
624            equality_violations.push(violation);
625        }
626
627        // Check inequality constraints: c_ineq(x) <= 0
628        for constraint in &self.inequality_constraints {
629            let value = constraint(x);
630            let violation = if value > 0.0 { value } else { 0.0 };
631            inequality_violations.push(violation);
632        }
633
634        // Check bound constraints
635        if let Some((ref lower, ref upper)) = self.bounds {
636            for i in 0..x.len() {
637                let x_i = x[i];
638                let lower_violation = if x_i < lower[i] { lower[i] - x_i } else { 0.0 };
639                let upper_violation = if x_i > upper[i] { x_i - upper[i] } else { 0.0 };
640                bound_violations.push(lower_violation + upper_violation);
641            }
642        }
643
644        let max_violation = equality_violations
645            .iter()
646            .chain(&inequality_violations)
647            .chain(&bound_violations)
648            .fold(0.0f64, |acc, &x| acc.max(x));
649
650        let total_violation = equality_violations.iter().sum::<f64>()
651            + inequality_violations.iter().sum::<f64>()
652            + bound_violations.iter().sum::<f64>();
653
654        ConstraintViolation {
655            equality_violations,
656            inequality_violations,
657            bound_violations,
658            max_violation,
659            total_violation,
660        }
661    }
662
663    /// Check if constraints are satisfied
664    pub fn is_feasible(&self, x: &ArrayView1<f64>) -> bool {
665        let violations = self.check_violations(x);
666        violations.max_violation <= self.tolerance
667    }
668
669    /// Compute penalty function value
670    pub fn penalty_function(&self, x: &ArrayView1<f64>) -> f64 {
671        let violations = self.check_violations(x);
672        self.penalty_parameter * violations.total_violation
673    }
674}
675
/// Bounded history of per-iteration statistics for optimization algorithms.
///
/// Each series keeps at most `max_history_size` of its most recent entries; older entries
/// are evicted from the front.
#[derive(Debug, Clone)]
pub struct OptimizationHistory {
    /// Objective values, oldest first.
    pub function_values: VecDeque<f64>,
    /// Gradient norms, oldest first.
    pub gradient_norms: VecDeque<f64>,
    /// Parameter changes, oldest first.
    pub parameter_changes: VecDeque<f64>,
    /// Step sizes, oldest first.
    pub step_sizes: VecDeque<f64>,
    /// Capacity of each series. `0` means nothing is retained.
    pub max_history_size: usize,
}

impl OptimizationHistory {
    /// Create an empty history keeping at most `max_size` entries per series.
    pub fn new(max_size: usize) -> Self {
        Self {
            function_values: VecDeque::new(),
            gradient_norms: VecDeque::new(),
            parameter_changes: VecDeque::new(),
            step_sizes: VecDeque::new(),
            max_history_size: max_size,
        }
    }

    /// Append `value` to `series`, first evicting from the front until it fits in
    /// `max_size`.
    ///
    /// Using a loop (not a single pop) also shrinks a series after `max_history_size` has
    /// been lowered on a populated history. With `max_size == 0` nothing is kept.
    fn push_bounded(series: &mut VecDeque<f64>, value: f64, max_size: usize) {
        if max_size == 0 {
            series.clear();
            return;
        }
        while series.len() >= max_size {
            series.pop_front();
        }
        series.push_back(value);
    }

    /// Record an objective value.
    pub fn add_function_value(&mut self, value: f64) {
        Self::push_bounded(&mut self.function_values, value, self.max_history_size);
    }

    /// Record a gradient norm.
    pub fn add_gradient_norm(&mut self, norm: f64) {
        Self::push_bounded(&mut self.gradient_norms, norm, self.max_history_size);
    }

    /// Record a parameter change.
    pub fn add_parameter_change(&mut self, change: f64) {
        Self::push_bounded(&mut self.parameter_changes, change, self.max_history_size);
    }

    /// Record a step size.
    pub fn add_step_size(&mut self, step: f64) {
        Self::push_bounded(&mut self.step_sizes, step, self.max_history_size);
    }

    /// Return up to `n` most recent objective values, newest first.
    pub fn recent_function_values(&self, n: usize) -> Vec<f64> {
        self.function_values.iter().rev().take(n).copied().collect()
    }

    /// `true` when each of the last `window_size` transitions strictly decreased the
    /// objective. This requires at least `window_size + 1` recorded values and
    /// `window_size >= 1`.
    pub fn has_improvement_trend(&self, window_size: usize) -> bool {
        let needed = window_size.saturating_add(1);
        if needed < 2 || self.function_values.len() < needed {
            return false;
        }

        let newest_first = self.function_values.iter().rev().take(needed);
        newest_first
            .clone()
            .zip(newest_first.skip(1))
            .all(|(newer, older)| newer < older)
    }

    /// Mean per-iteration decrease (`older - newer`) over the last `window_size`
    /// transitions. Positive values mean the objective is going down.
    ///
    /// Returns `None` when fewer than `window_size + 1` values are recorded or
    /// `window_size == 0`.
    pub fn average_improvement_rate(&self, window_size: usize) -> Option<f64> {
        let needed = window_size.saturating_add(1);
        if needed < 2 || self.function_values.len() < needed {
            return None;
        }

        let newest_first = self.function_values.iter().rev().take(needed);
        let total: f64 = newest_first
            .clone()
            .zip(newest_first.skip(1))
            .map(|(newer, older)| older - newer)
            .sum();
        Some(total / window_size as f64)
    }
}
767
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// One-dimensional objective f(x) = x².
    fn square(v: &ArrayView1<f64>) -> f64 {
        v[0] * v[0]
    }

    /// Gradient of [`square`].
    fn square_grad(v: &ArrayView1<f64>) -> Array1<f64> {
        array![2.0 * v[0]]
    }

    #[test]
    fn test_line_search_armijo() {
        let start = array![2.0];
        let descent = array![-1.0];
        let value_at_start = square(&start.view());
        let grad_at_start = square_grad(&start.view());

        let step = LineSearch::armijo(1e-4)
            .search(
                square,
                square_grad,
                &start.view(),
                &descent.view(),
                value_at_start,
                &grad_at_start.view(),
            )
            .unwrap();

        assert!(step > 0.0 && step <= 1.0);
    }

    #[test]
    fn test_convergence_criteria() {
        let criteria = ConvergenceCriteria::new()
            .with_tolerance(1e-6)
            .with_max_iterations(100);
        let tiny_gradient = array![1e-7];

        let cases = [
            (50, None, ConvergenceStatus::Continuing),
            (100, None, ConvergenceStatus::MaxIterationsReached),
            (
                50,
                Some(tiny_gradient.view()),
                ConvergenceStatus::GradientTolerance,
            ),
        ];

        for (iteration, gradient, expected) in cases {
            let status =
                criteria.is_converged(iteration, 1.0, Some(1.1), gradient.as_ref(), None, 0);
            assert_eq!(status, expected);
        }
    }

    #[test]
    fn test_gradient_computer() {
        // f(x) = x₁² + x₂², whose gradient at (2, 3) is (4, 6).
        let point = array![2.0, 3.0];
        let grad = GradientComputer::new()
            .compute_gradient(|p: &ArrayView1<f64>| p[0] * p[0] + p[1] * p[1], &point.view())
            .unwrap();

        for (computed, expected) in grad.iter().zip([4.0, 6.0]) {
            assert!((computed - expected).abs() < 1e-6);
        }
    }

    #[test]
    fn test_constraint_handler_bounds() {
        let handler = ConstraintHandler::new().with_bounds(array![-1.0, -2.0], array![1.0, 2.0]);

        assert!(handler.is_feasible(&array![0.5, 1.0].view()));

        let outside = array![2.0, -3.0];
        assert!(!handler.is_feasible(&outside.view()));

        let clamped = handler.project_bounds(&outside);
        assert_eq!(clamped, array![1.0, -2.0]);
        assert!(handler.is_feasible(&clamped.view()));
    }

    #[test]
    fn test_optimization_history() {
        // Increasing values: only the newest five survive and there is no downward trend.
        let mut rising = OptimizationHistory::new(5);
        (0..7).for_each(|i| rising.add_function_value(f64::from(i)));

        assert_eq!(rising.function_values.len(), 5);
        assert_eq!(rising.recent_function_values(3), vec![6.0, 5.0, 4.0]);
        assert!(!rising.has_improvement_trend(3));

        // Decreasing values form an improvement trend.
        let mut falling = OptimizationHistory::new(10);
        (0..5).rev().for_each(|i| falling.add_function_value(f64::from(i)));
        assert!(falling.has_improvement_trend(3));
    }

    #[test]
    fn test_jacobian_computation() {
        // f(x) = [x₁², x₁·x₂] at (2, 3) has Jacobian [[4, 0], [3, 2]].
        let jac = GradientComputer::new()
            .compute_jacobian(
                |p: &ArrayView1<f64>| array![p[0] * p[0], p[0] * p[1]],
                &array![2.0, 3.0].view(),
                2,
            )
            .unwrap();

        let expected = [[4.0, 0.0], [3.0, 2.0]];
        for (row, expected_row) in expected.iter().enumerate() {
            for (col, &want) in expected_row.iter().enumerate() {
                assert!((jac[[row, col]] - want).abs() < 1e-6);
            }
        }
    }

    #[test]
    fn test_constraint_violations() {
        let report = ConstraintHandler::new()
            .with_bounds(array![-1.0, -1.0], array![1.0, 1.0])
            .with_tolerance(1e-6)
            .check_violations(&array![2.0, -2.0].view());

        assert!(report.max_violation > 0.0);
        assert!(report.total_violation > 0.0);
        assert_eq!(report.bound_violations.len(), 2);
        // Coordinate 0 is above its upper bound, coordinate 1 below its lower bound.
        assert!(report.bound_violations.iter().all(|&v| v > 0.0));
    }
}