Skip to main content

scirs2_optimize/evolution/
cma_es.rs

1//! Covariance Matrix Adaptation Evolution Strategy (CMA-ES)
2//!
3//! CMA-ES is a stochastic, derivative-free method for numerical optimization
4//! of non-linear or non-convex continuous optimization problems. It belongs
5//! to the class of evolutionary algorithms and evolution strategies.
6//!
7//! ## Key features
8//!
9//! - Invariant under order-preserving transformations of the fitness function
10//! - Robust to scaling, rotation, and translation of the search space
11//! - Adaptive step-size control (Cumulative Step-size Adaptation / CSA)
12//! - Covariance matrix adaptation via rank-1 and rank-mu updates
13//! - IPOP restart strategy for escaping local optima
14//! - Multiple boundary handling strategies
15//!
16//! ## References
17//!
18//! - Hansen, N. (2016). The CMA Evolution Strategy: A Tutorial. arXiv:1604.00772
19//! - Auger, A. & Hansen, N. (2005). A Restart CMA Evolution Strategy with
20//!   Increasing Population Size. CEC 2005.
21
22use crate::error::{OptimizeError, OptimizeResult};
23use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
24use scirs2_core::random::rngs::StdRng;
25use scirs2_core::random::{Rng, SeedableRng};
26use scirs2_core::RngExt;
27
/// Strategy applied to candidate solutions that violate the box constraints.
///
/// The default strategy is [`BoundaryHandling::Reflection`].
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum BoundaryHandling {
    /// No boundary handling (unconstrained)
    None,
    /// Project infeasible solutions onto the boundary
    Projection,
    /// Reflect infeasible solutions at the boundary
    #[default]
    Reflection,
    /// Penalize infeasible solutions with a quadratic penalty
    Penalty {
        /// Penalty weight factor multiplying the squared bound violation
        weight: f64,
    },
    /// Resample infeasible solutions until feasible
    Resampling {
        /// Maximum number of resampling attempts
        max_attempts: usize,
    },
}
54
/// Policy deciding whether and how CMA-ES is restarted after a run ends.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RestartStrategy {
    /// Perform a single run and never restart.
    NoRestart,
    /// IPOP (Increasing Population Size): the population doubles on each restart.
    Ipop {
        /// Upper limit on the number of restarts
        max_restarts: usize,
    },
    /// BIPOP (Bi-Population): restarts alternate between large and small populations.
    Bipop {
        /// Upper limit on the number of restarts
        max_restarts: usize,
    },
}

impl Default for RestartStrategy {
    /// IPOP with at most nine restarts.
    fn default() -> Self {
        Self::Ipop { max_restarts: 9 }
    }
}
79
80/// Options for CMA-ES optimization
81#[derive(Debug, Clone)]
82pub struct CmaEsOptions {
83    /// Initial step size (sigma). Controls the spread of the initial search distribution.
84    pub sigma0: f64,
85    /// Population size (lambda). If None, uses default 4 + floor(3 * ln(n)).
86    pub population_size: Option<usize>,
87    /// Maximum number of function evaluations
88    pub max_fevals: usize,
89    /// Maximum number of iterations
90    pub max_iterations: usize,
91    /// Function value tolerance for convergence
92    pub ftol: f64,
93    /// Solution tolerance for convergence
94    pub xtol: f64,
95    /// Boundary handling strategy
96    pub boundary_handling: BoundaryHandling,
97    /// Restart strategy
98    pub restart_strategy: RestartStrategy,
99    /// Lower bounds for each dimension (None = unbounded)
100    pub lower_bounds: Option<Vec<f64>>,
101    /// Upper bounds for each dimension (None = unbounded)
102    pub upper_bounds: Option<Vec<f64>>,
103    /// Random seed for reproducibility
104    pub seed: Option<u64>,
105    /// Verbosity level (0 = silent, 1 = per-restart, 2 = per-iteration)
106    pub verbosity: usize,
107}
108
109impl Default for CmaEsOptions {
110    fn default() -> Self {
111        Self {
112            sigma0: 0.3,
113            population_size: None,
114            max_fevals: 100_000,
115            max_iterations: 100_000,
116            ftol: 1e-12,
117            xtol: 1e-12,
118            boundary_handling: BoundaryHandling::default(),
119            restart_strategy: RestartStrategy::default(),
120            lower_bounds: None,
121            upper_bounds: None,
122            seed: None,
123            verbosity: 0,
124        }
125    }
126}
127
128/// Result of CMA-ES optimization
129#[derive(Debug, Clone)]
130pub struct CmaEsResult {
131    /// Best solution found
132    pub x: Array1<f64>,
133    /// Best function value found
134    pub fun: f64,
135    /// Total number of function evaluations
136    pub nfev: usize,
137    /// Total number of iterations (across all restarts)
138    pub nit: usize,
139    /// Number of restarts performed
140    pub n_restarts: usize,
141    /// Whether optimization converged successfully
142    pub success: bool,
143    /// Termination message
144    pub message: String,
145    /// Final step size (sigma)
146    pub sigma_final: f64,
147    /// Final condition number of the covariance matrix
148    pub cond_final: f64,
149}
150
151/// Internal state of the CMA-ES algorithm
152#[derive(Debug)]
153pub struct CmaEsState {
154    /// Dimension of the problem
155    n: usize,
156    /// Current mean of the distribution
157    mean: Array1<f64>,
158    /// Current step size
159    sigma: f64,
160    /// Population size (lambda)
161    lambda: usize,
162    /// Number of selected parents (mu)
163    mu: usize,
164    /// Recombination weights
165    weights: Vec<f64>,
166    /// Variance effective selection mass
167    mu_eff: f64,
168    /// Evolution path for sigma (cumulation)
169    p_sigma: Array1<f64>,
170    /// Evolution path for covariance matrix (cumulation)
171    p_c: Array1<f64>,
172    /// Covariance matrix
173    cov: Array2<f64>,
174    /// Eigenvalues of the covariance matrix (squared, i.e. D^2)
175    eigenvalues: Array1<f64>,
176    /// Eigenvectors (columns of B)
177    eigenvectors: Array2<f64>,
178    /// Inverse square root of covariance matrix: C^(-1/2)
179    inv_sqrt_cov: Array2<f64>,
180    /// Learning rate for cumulation of sigma
181    c_sigma: f64,
182    /// Damping for sigma
183    d_sigma: f64,
184    /// Learning rate for cumulation of C
185    c_c: f64,
186    /// Learning rate for rank-1 update
187    c_1: f64,
188    /// Learning rate for rank-mu update
189    c_mu: f64,
190    /// Expected norm of N(0,I) distributed random vector
191    chi_n: f64,
192    /// Iteration counter
193    generation: usize,
194    /// Total function evaluations
195    fevals: usize,
196    /// Best solution found so far
197    best_x: Array1<f64>,
198    /// Best function value found so far
199    best_f: f64,
200    /// Eigendecomposition update counter
201    eigen_update_counter: usize,
202    /// RNG
203    rng: StdRng,
204}
205
206impl CmaEsState {
207    /// Create a new CMA-ES state
208    pub fn new(x0: &[f64], sigma0: f64, lambda: Option<usize>, seed: u64) -> OptimizeResult<Self> {
209        let n = x0.len();
210        if n == 0 {
211            return Err(OptimizeError::InvalidInput(
212                "Dimension must be at least 1".to_string(),
213            ));
214        }
215        if sigma0 <= 0.0 || !sigma0.is_finite() {
216            return Err(OptimizeError::InvalidInput(format!(
217                "sigma0 must be positive and finite, got {}",
218                sigma0
219            )));
220        }
221
222        // Default population size: 4 + floor(3 * ln(n))
223        let lambda = lambda.unwrap_or_else(|| 4 + (3.0 * (n as f64).ln()).floor() as usize);
224        let lambda = lambda.max(4); // minimum population size
225        let mu = lambda / 2;
226
227        // Recombination weights (log-linear)
228        let raw_weights: Vec<f64> = (0..mu)
229            .map(|i| ((mu as f64 + 0.5).ln() - ((i + 1) as f64).ln()).max(0.0))
230            .collect();
231        let w_sum: f64 = raw_weights.iter().sum();
232        let weights: Vec<f64> = raw_weights.iter().map(|w| w / w_sum).collect();
233
234        // Variance effective selection mass
235        let w_sq_sum: f64 = weights.iter().map(|w| w * w).sum();
236        let mu_eff = 1.0 / w_sq_sum;
237
238        // Strategy parameter setting: adaptation
239        let c_sigma = (mu_eff + 2.0) / (n as f64 + mu_eff + 5.0);
240        let d_sigma =
241            1.0 + 2.0 * (((mu_eff - 1.0) / (n as f64 + 1.0)).sqrt() - 1.0).max(0.0) + c_sigma;
242        let c_c = (4.0 + mu_eff / n as f64) / (n as f64 + 4.0 + 2.0 * mu_eff / n as f64);
243        let c_1 = 2.0 / ((n as f64 + 1.3).powi(2) + mu_eff);
244        let c_mu_candidate = (1.0 - c_1)
245            .min(2.0 * (mu_eff - 2.0 + 1.0 / mu_eff) / ((n as f64 + 2.0).powi(2) + mu_eff));
246        let c_mu = c_mu_candidate.max(0.0);
247
248        // Expected norm of N(0,I)
249        let chi_n =
250            (n as f64).sqrt() * (1.0 - 1.0 / (4.0 * n as f64) + 1.0 / (21.0 * (n as f64).powi(2)));
251
252        let mean = Array1::from_vec(x0.to_vec());
253        let cov = Array2::eye(n);
254        let eigenvalues = Array1::ones(n);
255        let eigenvectors = Array2::eye(n);
256        let inv_sqrt_cov = Array2::eye(n);
257        let p_sigma = Array1::zeros(n);
258        let p_c = Array1::zeros(n);
259
260        let rng = StdRng::seed_from_u64(seed);
261
262        Ok(Self {
263            n,
264            mean: mean.clone(),
265            sigma: sigma0,
266            lambda,
267            mu,
268            weights,
269            mu_eff,
270            p_sigma,
271            p_c,
272            cov,
273            eigenvalues,
274            eigenvectors,
275            inv_sqrt_cov,
276            c_sigma,
277            d_sigma,
278            c_c,
279            c_1,
280            c_mu,
281            chi_n,
282            generation: 0,
283            fevals: 0,
284            best_x: mean,
285            best_f: f64::INFINITY,
286            eigen_update_counter: 0,
287            rng,
288        })
289    }
290
291    /// Sample a population of candidate solutions
292    fn sample_population(&mut self) -> Vec<Array1<f64>> {
293        let mut population = Vec::with_capacity(self.lambda);
294        for _ in 0..self.lambda {
295            // z ~ N(0, I)
296            let z: Array1<f64> = Array1::from_vec(
297                (0..self.n)
298                    .map(|_| sample_standard_normal(&mut self.rng))
299                    .collect(),
300            );
301            // y = B * D * z  (where C = B * D^2 * B^T)
302            let d_z = &z * &self.eigenvalues.mapv(f64::sqrt);
303            let y = self.eigenvectors.dot(&d_z);
304            // x = mean + sigma * y
305            let x = &self.mean + &(y * self.sigma);
306            population.push(x);
307        }
308        population
309    }
310
311    /// Apply boundary handling to a candidate solution
312    fn apply_boundary_handling(
313        &mut self,
314        x: &Array1<f64>,
315        lower: &Option<Vec<f64>>,
316        upper: &Option<Vec<f64>>,
317        handling: BoundaryHandling,
318    ) -> (Array1<f64>, f64) {
319        let mut x_fixed = x.clone();
320        let mut penalty = 0.0;
321
322        let lb = lower.as_deref();
323        let ub = upper.as_deref();
324
325        if lb.is_none() && ub.is_none() {
326            return (x_fixed, 0.0);
327        }
328
329        match handling {
330            BoundaryHandling::None => {}
331            BoundaryHandling::Projection => {
332                for i in 0..self.n {
333                    if let Some(lb_vals) = lb {
334                        if x_fixed[i] < lb_vals[i] {
335                            x_fixed[i] = lb_vals[i];
336                        }
337                    }
338                    if let Some(ub_vals) = ub {
339                        if x_fixed[i] > ub_vals[i] {
340                            x_fixed[i] = ub_vals[i];
341                        }
342                    }
343                }
344            }
345            BoundaryHandling::Reflection => {
346                for i in 0..self.n {
347                    let lo = lb.map_or(f64::NEG_INFINITY, |v| v[i]);
348                    let hi = ub.map_or(f64::INFINITY, |v| v[i]);
349                    if lo.is_finite() && hi.is_finite() {
350                        let range = hi - lo;
351                        if range > 0.0 {
352                            // Reflect until within bounds
353                            let mut val = x_fixed[i];
354                            for _ in 0..10 {
355                                if val < lo {
356                                    val = lo + (lo - val);
357                                } else if val > hi {
358                                    val = hi - (val - hi);
359                                } else {
360                                    break;
361                                }
362                            }
363                            // Final clamp if reflection didn't converge
364                            x_fixed[i] = val.clamp(lo, hi);
365                        }
366                    } else {
367                        if lo.is_finite() && x_fixed[i] < lo {
368                            x_fixed[i] = lo;
369                        }
370                        if hi.is_finite() && x_fixed[i] > hi {
371                            x_fixed[i] = hi;
372                        }
373                    }
374                }
375            }
376            BoundaryHandling::Penalty { weight } => {
377                for i in 0..self.n {
378                    if let Some(lb_vals) = lb {
379                        if x_fixed[i] < lb_vals[i] {
380                            let diff = lb_vals[i] - x_fixed[i];
381                            penalty += weight * diff * diff;
382                            x_fixed[i] = lb_vals[i];
383                        }
384                    }
385                    if let Some(ub_vals) = ub {
386                        if x_fixed[i] > ub_vals[i] {
387                            let diff = x_fixed[i] - ub_vals[i];
388                            penalty += weight * diff * diff;
389                            x_fixed[i] = ub_vals[i];
390                        }
391                    }
392                }
393            }
394            BoundaryHandling::Resampling { max_attempts } => {
395                let mut feasible = true;
396                for i in 0..self.n {
397                    let lo = lb.map_or(f64::NEG_INFINITY, |v| v[i]);
398                    let hi = ub.map_or(f64::INFINITY, |v| v[i]);
399                    if x_fixed[i] < lo || x_fixed[i] > hi {
400                        feasible = false;
401                        break;
402                    }
403                }
404                if !feasible {
405                    // Try resampling
406                    for _ in 0..max_attempts {
407                        let z: Array1<f64> = Array1::from_vec(
408                            (0..self.n)
409                                .map(|_| sample_standard_normal(&mut self.rng))
410                                .collect(),
411                        );
412                        let d_z = &z * &self.eigenvalues.mapv(f64::sqrt);
413                        let y = self.eigenvectors.dot(&d_z);
414                        x_fixed = &self.mean + &(y * self.sigma);
415
416                        let mut all_feasible = true;
417                        for i in 0..self.n {
418                            let lo = lb.map_or(f64::NEG_INFINITY, |v| v[i]);
419                            let hi = ub.map_or(f64::INFINITY, |v| v[i]);
420                            if x_fixed[i] < lo || x_fixed[i] > hi {
421                                all_feasible = false;
422                                break;
423                            }
424                        }
425                        if all_feasible {
426                            return (x_fixed, 0.0);
427                        }
428                    }
429                    // If all resampling failed, project
430                    for i in 0..self.n {
431                        if let Some(lb_vals) = lb {
432                            if x_fixed[i] < lb_vals[i] {
433                                x_fixed[i] = lb_vals[i];
434                            }
435                        }
436                        if let Some(ub_vals) = ub {
437                            if x_fixed[i] > ub_vals[i] {
438                                x_fixed[i] = ub_vals[i];
439                            }
440                        }
441                    }
442                }
443            }
444        }
445
446        (x_fixed, penalty)
447    }
448
449    /// Perform eigendecomposition of the covariance matrix
450    fn update_eigen(&mut self) {
451        // Symmetrize C (numerical safety)
452        let n = self.n;
453        for i in 0..n {
454            for j in (i + 1)..n {
455                let avg = 0.5 * (self.cov[[i, j]] + self.cov[[j, i]]);
456                self.cov[[i, j]] = avg;
457                self.cov[[j, i]] = avg;
458            }
459        }
460
461        // Simple Jacobi eigendecomposition for small-medium problems
462        let (eigenvalues, eigenvectors) = jacobi_eigen(&self.cov, n);
463
464        // Ensure eigenvalues are positive (numerical stability)
465        let min_eigenval = 1e-20;
466        self.eigenvalues = eigenvalues.mapv(|v| v.max(min_eigenval));
467        self.eigenvectors = eigenvectors;
468
469        // Compute C^(-1/2) = B * D^(-1) * B^T
470        let d_inv = self.eigenvalues.mapv(|v| 1.0 / v.sqrt());
471        let mut inv_sqrt = Array2::zeros((n, n));
472        for i in 0..n {
473            for j in 0..n {
474                let mut sum = 0.0;
475                for k in 0..n {
476                    sum += self.eigenvectors[[i, k]] * d_inv[k] * self.eigenvectors[[j, k]];
477                }
478                inv_sqrt[[i, j]] = sum;
479            }
480        }
481        self.inv_sqrt_cov = inv_sqrt;
482    }
483
484    /// Update the CMA-ES state after evaluating a generation
485    fn update(&mut self, population: &[Array1<f64>], fitness: &[f64]) {
486        // Sort by fitness (ascending)
487        let mut indices: Vec<usize> = (0..self.lambda).collect();
488        indices.sort_by(|&a, &b| {
489            fitness[a]
490                .partial_cmp(&fitness[b])
491                .unwrap_or(std::cmp::Ordering::Equal)
492        });
493
494        // Update best
495        if fitness[indices[0]] < self.best_f {
496            self.best_f = fitness[indices[0]];
497            self.best_x = population[indices[0]].clone();
498        }
499
500        // Compute weighted mean of selected points
501        let old_mean = self.mean.clone();
502        self.mean = Array1::zeros(self.n);
503        for i in 0..self.mu {
504            let idx = indices[i];
505            self.mean = &self.mean + &(&population[idx] * self.weights[i]);
506        }
507
508        // Compute displacement from old mean
509        let mean_diff = &self.mean - &old_mean;
510
511        // Update evolution path for sigma (p_sigma)
512        let inv_sqrt_cov_mean_diff = self.inv_sqrt_cov.dot(&mean_diff);
513        let c_sigma_factor = (self.c_sigma * (2.0 - self.c_sigma) * self.mu_eff).sqrt();
514        self.p_sigma = &self.p_sigma * (1.0 - self.c_sigma)
515            + &(&inv_sqrt_cov_mean_diff * (c_sigma_factor / self.sigma));
516
517        // Determine if we should use heavy-side function (h_sigma)
518        let p_sigma_norm = self.p_sigma.dot(&self.p_sigma).sqrt();
519        let threshold = (1.0 - (1.0 - self.c_sigma).powi(2 * (self.generation as i32 + 1))).sqrt()
520            * (1.4 + 2.0 / (self.n as f64 + 1.0))
521            * self.chi_n;
522        let h_sigma: f64 = if p_sigma_norm < threshold { 1.0 } else { 0.0 };
523
524        // Update evolution path for C (p_c)
525        let c_c_factor = (self.c_c * (2.0 - self.c_c) * self.mu_eff).sqrt();
526        self.p_c =
527            &self.p_c * (1.0 - self.c_c) + &(&mean_diff * (h_sigma * c_c_factor / self.sigma));
528
529        // Rank-1 update component
530        let pc_outer = outer_product(&self.p_c, &self.p_c);
531
532        // Correction factor for h_sigma
533        let delta_h = (1.0 - h_sigma) * self.c_c * (2.0 - self.c_c);
534
535        // Rank-mu update component
536        let mut rank_mu_update: Array2<f64> = Array2::zeros((self.n, self.n));
537        for i in 0..self.mu {
538            let idx = indices[i];
539            let y_i = (&population[idx] - &old_mean) / self.sigma;
540            let y_outer = outer_product(&y_i, &y_i);
541            rank_mu_update = rank_mu_update + &(&y_outer * self.weights[i]);
542        }
543
544        // Update covariance matrix
545        // C = (1 - c_1 - c_mu + delta_h * c_1) * C + c_1 * pc_outer + c_mu * rank_mu_update
546        let scale = 1.0 - self.c_1 - self.c_mu + delta_h * self.c_1;
547        self.cov = &self.cov * scale + &(&pc_outer * self.c_1) + &(&rank_mu_update * self.c_mu);
548
549        // Update step size (sigma) via CSA
550        let sigma_exp = (self.c_sigma / self.d_sigma) * (p_sigma_norm / self.chi_n - 1.0);
551        self.sigma *= sigma_exp.exp();
552
553        // Clamp sigma to prevent explosion/implosion
554        self.sigma = self.sigma.clamp(1e-20, 1e20);
555
556        // Update eigendecomposition periodically
557        self.eigen_update_counter += 1;
558        let update_freq = (self.lambda as f64 / ((self.c_1 + self.c_mu) * self.n as f64 * 10.0))
559            .max(1.0) as usize;
560        if self.eigen_update_counter >= update_freq {
561            self.update_eigen();
562            self.eigen_update_counter = 0;
563        }
564
565        self.generation += 1;
566    }
567
568    /// Check stopping criteria
569    fn check_termination(&self, recent_fitness: &[f64], options: &CmaEsOptions) -> Option<String> {
570        // Max function evaluations
571        if self.fevals >= options.max_fevals {
572            return Some(format!(
573                "Maximum function evaluations ({}) reached",
574                options.max_fevals
575            ));
576        }
577
578        // Max iterations
579        if self.generation >= options.max_iterations {
580            return Some(format!(
581                "Maximum iterations ({}) reached",
582                options.max_iterations
583            ));
584        }
585
586        // Function value tolerance (flat fitness landscape)
587        if recent_fitness.len() >= self.lambda {
588            let f_best = recent_fitness.iter().copied().fold(f64::INFINITY, f64::min);
589            let f_worst = recent_fitness
590                .iter()
591                .copied()
592                .fold(f64::NEG_INFINITY, f64::max);
593            if (f_worst - f_best).abs() < options.ftol {
594                return Some("Function tolerance reached (flat fitness)".to_string());
595            }
596        }
597
598        // Solution tolerance (sigma * max(eigenvalue) is very small)
599        let max_eigenval = self.eigenvalues.iter().copied().fold(0.0_f64, f64::max);
600        if self.sigma * max_eigenval.sqrt() < options.xtol {
601            return Some("Solution tolerance reached".to_string());
602        }
603
604        // Condition number check
605        let min_eigenval = self
606            .eigenvalues
607            .iter()
608            .copied()
609            .fold(f64::INFINITY, f64::min);
610        if min_eigenval > 0.0 {
611            let cond = max_eigenval / min_eigenval;
612            if cond > 1e14 {
613                return Some(format!("Condition number too large: {:.2e}", cond));
614            }
615        }
616
617        // Sigma too small
618        if self.sigma < 1e-20 {
619            return Some("Step size sigma below minimum threshold".to_string());
620        }
621
622        None
623    }
624
625    /// Get the condition number of the covariance matrix
626    pub fn condition_number(&self) -> f64 {
627        let max_ev = self.eigenvalues.iter().copied().fold(0.0_f64, f64::max);
628        let min_ev = self
629            .eigenvalues
630            .iter()
631            .copied()
632            .fold(f64::INFINITY, f64::min);
633        if min_ev > 0.0 {
634            max_ev / min_ev
635        } else {
636            f64::INFINITY
637        }
638    }
639}
640
641/// IPOP-CMA-ES: CMA-ES with Increasing Population Restarts
642pub struct IpopCmaEs<F>
643where
644    F: Fn(&ArrayView1<f64>) -> f64,
645{
646    func: F,
647    x0: Vec<f64>,
648    options: CmaEsOptions,
649}
650
651impl<F> IpopCmaEs<F>
652where
653    F: Fn(&ArrayView1<f64>) -> f64,
654{
655    /// Create a new IPOP-CMA-ES optimizer
656    pub fn new(func: F, x0: &[f64], options: CmaEsOptions) -> Self {
657        Self {
658            func,
659            x0: x0.to_vec(),
660            options,
661        }
662    }
663
664    /// Run the IPOP-CMA-ES algorithm
665    pub fn run(&self) -> OptimizeResult<CmaEsResult> {
666        let max_restarts = match self.options.restart_strategy {
667            RestartStrategy::NoRestart => 0,
668            RestartStrategy::Ipop { max_restarts } => max_restarts,
669            RestartStrategy::Bipop { max_restarts } => max_restarts,
670        };
671
672        let seed = self
673            .options
674            .seed
675            .unwrap_or_else(|| scirs2_core::random::rng().random_range(0..u64::MAX));
676
677        let mut overall_best_x = Array1::from_vec(self.x0.clone());
678        let mut overall_best_f = f64::INFINITY;
679        let mut total_fevals = 0_usize;
680        let mut total_iterations = 0_usize;
681        let mut final_message = String::new();
682        let mut final_sigma = self.options.sigma0;
683        let mut final_cond = 1.0;
684        let mut any_success = false;
685
686        let base_lambda = self
687            .options
688            .population_size
689            .unwrap_or_else(|| 4 + (3.0 * (self.x0.len() as f64).ln()).floor() as usize);
690
691        let mut rng = StdRng::seed_from_u64(seed);
692
693        for restart_idx in 0..=max_restarts {
694            // Increase population size for IPOP
695            let current_lambda = if restart_idx == 0 {
696                base_lambda
697            } else {
698                match self.options.restart_strategy {
699                    RestartStrategy::Ipop { .. } => base_lambda * (1 << restart_idx.min(10)),
700                    RestartStrategy::Bipop { .. } => {
701                        // Alternate between large and small
702                        if restart_idx % 2 == 1 {
703                            base_lambda * (1 << ((restart_idx / 2 + 1).min(10)))
704                        } else {
705                            (base_lambda / 2).max(4)
706                        }
707                    }
708                    RestartStrategy::NoRestart => base_lambda,
709                }
710            };
711
712            // For restarts > 0, randomize starting point within bounds
713            let x0_restart = if restart_idx == 0 {
714                self.x0.clone()
715            } else {
716                let mut x0_new = self.x0.clone();
717                for i in 0..x0_new.len() {
718                    let lo = self.options.lower_bounds.as_ref().map_or(-10.0, |lb| lb[i]);
719                    let hi = self.options.upper_bounds.as_ref().map_or(10.0, |ub| ub[i]);
720                    x0_new[i] = rng.random_range(lo..hi);
721                }
722                x0_new
723            };
724
725            let run_seed = seed.wrapping_add(restart_idx as u64 * 1_000_000);
726            let result = self.run_single(
727                &x0_restart,
728                current_lambda,
729                run_seed,
730                self.options.max_fevals.saturating_sub(total_fevals),
731            )?;
732
733            total_fevals += result.nfev;
734            total_iterations += result.nit;
735
736            if result.fun < overall_best_f {
737                overall_best_f = result.fun;
738                overall_best_x = result.x.clone();
739                final_sigma = result.sigma_final;
740                final_cond = result.cond_final;
741            }
742            if result.success {
743                any_success = true;
744            }
745            final_message = result.message.clone();
746
747            // Check if we've exhausted the budget
748            if total_fevals >= self.options.max_fevals {
749                final_message = "Budget exhausted across restarts".to_string();
750                break;
751            }
752
753            // If converged with good tolerance, no need for more restarts
754            if result.success && result.fun < self.options.ftol {
755                break;
756            }
757        }
758
759        let n_restarts = if any_success { 0 } else { max_restarts };
760
761        Ok(CmaEsResult {
762            x: overall_best_x,
763            fun: overall_best_f,
764            nfev: total_fevals,
765            nit: total_iterations,
766            n_restarts,
767            success: any_success || overall_best_f < f64::INFINITY,
768            message: final_message,
769            sigma_final: final_sigma,
770            cond_final: final_cond,
771        })
772    }
773
774    /// Run a single CMA-ES optimization
775    fn run_single(
776        &self,
777        x0: &[f64],
778        lambda: usize,
779        seed: u64,
780        max_fevals: usize,
781    ) -> OptimizeResult<CmaEsResult> {
782        let mut state = CmaEsState::new(x0, self.options.sigma0, Some(lambda), seed)?;
783
784        // Evaluate initial point
785        let x0_arr = Array1::from_vec(x0.to_vec());
786        let f0 = (self.func)(&x0_arr.view());
787        state.fevals += 1;
788        if f0 < state.best_f {
789            state.best_f = f0;
790            state.best_x = x0_arr;
791        }
792
793        let mut termination_msg = String::new();
794        let mut converged = false;
795
796        loop {
797            // Sample population
798            let mut population = state.sample_population();
799
800            // Apply boundary handling
801            let mut penalties = vec![0.0; state.lambda];
802            for i in 0..state.lambda {
803                let (x_fixed, pen) = state.apply_boundary_handling(
804                    &population[i],
805                    &self.options.lower_bounds,
806                    &self.options.upper_bounds,
807                    self.options.boundary_handling,
808                );
809                population[i] = x_fixed;
810                penalties[i] = pen;
811            }
812
813            // Evaluate fitness
814            let mut fitness: Vec<f64> = population.iter().map(|x| (self.func)(&x.view())).collect();
815            state.fevals += state.lambda;
816
817            // Add penalties
818            for (f, p) in fitness.iter_mut().zip(penalties.iter()) {
819                *f += *p;
820            }
821
822            // Check termination
823            if let Some(msg) = state.check_termination(&fitness, &self.options) {
824                termination_msg = msg;
825                // If solution tolerance was reached, it's a success
826                converged = termination_msg.contains("tolerance reached");
827                break;
828            }
829
830            // Check local budget
831            if state.fevals >= max_fevals {
832                termination_msg = "Local budget exhausted".to_string();
833                break;
834            }
835
836            // Update state
837            state.update(&population, &fitness);
838        }
839
840        let cond_final = state.condition_number();
841        Ok(CmaEsResult {
842            x: state.best_x,
843            fun: state.best_f,
844            nfev: state.fevals,
845            nit: state.generation,
846            n_restarts: 0,
847            success: converged,
848            message: termination_msg,
849            sigma_final: state.sigma,
850            cond_final,
851        })
852    }
853}
854
855/// Convenience function to minimize using CMA-ES
856///
857/// # Arguments
858///
859/// * `func` - Objective function to minimize
860/// * `x0` - Initial guess
861/// * `options` - CMA-ES options (uses defaults if None)
862///
863/// # Returns
864///
865/// * `CmaEsResult` with the best solution found
866pub fn cma_es_minimize<F>(
867    func: F,
868    x0: &[f64],
869    options: Option<CmaEsOptions>,
870) -> OptimizeResult<CmaEsResult>
871where
872    F: Fn(&ArrayView1<f64>) -> f64,
873{
874    let options = options.unwrap_or_default();
875    let optimizer = IpopCmaEs::new(func, x0, options);
876    optimizer.run()
877}
878
879// ---- Utility functions ----
880
881/// Sample from standard normal using Box-Muller transform
882fn sample_standard_normal(rng: &mut StdRng) -> f64 {
883    let u1: f64 = rng.random_range(1e-10..1.0);
884    let u2: f64 = rng.random_range(0.0..std::f64::consts::TAU);
885    (-2.0 * u1.ln()).sqrt() * u2.cos()
886}
887
888/// Compute outer product of two vectors
889fn outer_product(a: &Array1<f64>, b: &Array1<f64>) -> Array2<f64> {
890    let n = a.len();
891    let m = b.len();
892    let mut result = Array2::zeros((n, m));
893    for i in 0..n {
894        for j in 0..m {
895            result[[i, j]] = a[i] * b[j];
896        }
897    }
898    result
899}
900
901/// Jacobi eigendecomposition for symmetric matrices
902/// Returns (eigenvalues, eigenvectors) where eigenvectors are column-wise
903fn jacobi_eigen(mat: &Array2<f64>, n: usize) -> (Array1<f64>, Array2<f64>) {
904    let mut a = mat.clone();
905    let mut v = Array2::eye(n);
906    let max_iter = 100 * n * n;
907    let tol = 1e-15;
908
909    for _ in 0..max_iter {
910        // Find largest off-diagonal element
911        let mut max_val = 0.0_f64;
912        let mut p = 0;
913        let mut q = 1;
914        for i in 0..n {
915            for j in (i + 1)..n {
916                if a[[i, j]].abs() > max_val {
917                    max_val = a[[i, j]].abs();
918                    p = i;
919                    q = j;
920                }
921            }
922        }
923
924        if max_val < tol {
925            break;
926        }
927
928        // Compute rotation
929        let app = a[[p, p]];
930        let aqq = a[[q, q]];
931        let apq = a[[p, q]];
932
933        let theta = if (app - aqq).abs() < tol {
934            std::f64::consts::FRAC_PI_4
935        } else {
936            0.5 * (2.0 * apq / (app - aqq)).atan()
937        };
938
939        let cos_t = theta.cos();
940        let sin_t = theta.sin();
941
942        // Apply Givens rotation
943        let mut new_a = a.clone();
944        for i in 0..n {
945            if i != p && i != q {
946                new_a[[i, p]] = cos_t * a[[i, p]] + sin_t * a[[i, q]];
947                new_a[[p, i]] = new_a[[i, p]];
948                new_a[[i, q]] = -sin_t * a[[i, p]] + cos_t * a[[i, q]];
949                new_a[[q, i]] = new_a[[i, q]];
950            }
951        }
952        new_a[[p, p]] = cos_t * cos_t * app + 2.0 * sin_t * cos_t * apq + sin_t * sin_t * aqq;
953        new_a[[q, q]] = sin_t * sin_t * app - 2.0 * sin_t * cos_t * apq + cos_t * cos_t * aqq;
954        new_a[[p, q]] = 0.0;
955        new_a[[q, p]] = 0.0;
956        a = new_a;
957
958        // Update eigenvectors
959        let mut new_v = v.clone();
960        for i in 0..n {
961            new_v[[i, p]] = cos_t * v[[i, p]] + sin_t * v[[i, q]];
962            new_v[[i, q]] = -sin_t * v[[i, p]] + cos_t * v[[i, q]];
963        }
964        v = new_v;
965    }
966
967    let eigenvalues = Array1::from_vec((0..n).map(|i| a[[i, i]]).collect());
968    (eigenvalues, v)
969}
970
#[cfg(test)]
mod tests {
    //! Unit tests for the CMA-ES optimizer, its state construction, and the
    //! Jacobi eigendecomposition helper.

    use super::*;

    /// Sphere function: sum(x_i^2)
    fn sphere(x: &ArrayView1<f64>) -> f64 {
        x.iter().map(|xi| xi * xi).sum()
    }

    /// Rosenbrock function
    ///
    /// Indexed from 1 so that an empty input yields an empty range instead of
    /// underflowing `x.len() - 1`.
    fn rosenbrock(x: &ArrayView1<f64>) -> f64 {
        let mut sum = 0.0;
        for i in 1..x.len() {
            sum += 100.0 * (x[i] - x[i - 1] * x[i - 1]).powi(2) + (1.0 - x[i - 1]).powi(2);
        }
        sum
    }

    /// Rastrigin function (multimodal)
    fn rastrigin(x: &ArrayView1<f64>) -> f64 {
        let n = x.len() as f64;
        let mut sum = 10.0 * n;
        for &xi in x.iter() {
            sum += xi * xi - 10.0 * (2.0 * std::f64::consts::PI * xi).cos();
        }
        sum
    }

    #[test]
    fn test_cma_es_sphere_2d() {
        let options = CmaEsOptions {
            sigma0: 0.5,
            max_fevals: 10_000,
            restart_strategy: RestartStrategy::NoRestart,
            seed: Some(42),
            ..Default::default()
        };
        let result = cma_es_minimize(sphere, &[3.0, -2.0], Some(options));
        assert!(result.is_ok());
        let res = result.expect("CMA-ES sphere 2D failed");
        assert!(res.fun < 1e-6, "Sphere function value: {}", res.fun);
        for &xi in res.x.iter() {
            assert!(xi.abs() < 1e-3, "Solution component: {}", xi);
        }
    }

    #[test]
    fn test_cma_es_sphere_5d() {
        let options = CmaEsOptions {
            sigma0: 1.0,
            max_fevals: 50_000,
            restart_strategy: RestartStrategy::NoRestart,
            seed: Some(123),
            ..Default::default()
        };
        let result = cma_es_minimize(sphere, &[5.0, -3.0, 2.0, -1.0, 4.0], Some(options));
        assert!(result.is_ok());
        let res = result.expect("CMA-ES sphere 5D failed");
        assert!(res.fun < 1e-4, "Sphere 5D value: {}", res.fun);
    }

    #[test]
    fn test_cma_es_rosenbrock_2d() {
        let options = CmaEsOptions {
            sigma0: 0.5,
            max_fevals: 50_000,
            restart_strategy: RestartStrategy::NoRestart,
            seed: Some(99),
            ..Default::default()
        };
        let result = cma_es_minimize(rosenbrock, &[0.0, 0.0], Some(options));
        assert!(result.is_ok());
        let res = result.expect("CMA-ES Rosenbrock 2D failed");
        // Rosenbrock is harder, allow more tolerance
        assert!(res.fun < 1e-2, "Rosenbrock value: {}", res.fun);
    }

    #[test]
    fn test_cma_es_with_bounds() {
        let options = CmaEsOptions {
            sigma0: 0.3,
            max_fevals: 10_000,
            restart_strategy: RestartStrategy::NoRestart,
            lower_bounds: Some(vec![0.0, 0.0]),
            upper_bounds: Some(vec![5.0, 5.0]),
            boundary_handling: BoundaryHandling::Reflection,
            seed: Some(77),
            ..Default::default()
        };

        // Minimum of sphere is at (0,0), which is on the boundary
        let result = cma_es_minimize(sphere, &[2.5, 2.5], Some(options));
        assert!(result.is_ok());
        let res = result.expect("CMA-ES bounded failed");
        assert!(res.fun < 0.1, "Bounded sphere value: {}", res.fun);
        // Check bounds are respected
        for &xi in res.x.iter() {
            assert!(xi >= -0.01, "Lower bound violated: {}", xi);
            assert!(xi <= 5.01, "Upper bound violated: {}", xi);
        }
    }

    #[test]
    fn test_cma_es_ipop_restart() {
        let options = CmaEsOptions {
            sigma0: 2.0,
            max_fevals: 50_000,
            restart_strategy: RestartStrategy::Ipop { max_restarts: 3 },
            lower_bounds: Some(vec![-5.12, -5.12]),
            upper_bounds: Some(vec![5.12, 5.12]),
            seed: Some(55),
            ..Default::default()
        };

        let result = cma_es_minimize(rastrigin, &[3.0, -2.0], Some(options));
        assert!(result.is_ok());
        let res = result.expect("IPOP CMA-ES Rastrigin failed");
        // Rastrigin global minimum is 0 at origin
        assert!(res.fun < 5.0, "Rastrigin IPOP value: {}", res.fun);
    }

    #[test]
    fn test_cma_es_penalty_boundary() {
        let options = CmaEsOptions {
            sigma0: 0.5,
            max_fevals: 10_000,
            restart_strategy: RestartStrategy::NoRestart,
            lower_bounds: Some(vec![-1.0, -1.0]),
            upper_bounds: Some(vec![1.0, 1.0]),
            boundary_handling: BoundaryHandling::Penalty { weight: 100.0 },
            seed: Some(42),
            ..Default::default()
        };

        let result = cma_es_minimize(sphere, &[0.5, 0.5], Some(options));
        assert!(result.is_ok());
    }

    #[test]
    fn test_cma_es_projection_boundary() {
        let options = CmaEsOptions {
            sigma0: 0.5,
            max_fevals: 10_000,
            restart_strategy: RestartStrategy::NoRestart,
            lower_bounds: Some(vec![-2.0, -2.0]),
            upper_bounds: Some(vec![2.0, 2.0]),
            boundary_handling: BoundaryHandling::Projection,
            seed: Some(42),
            ..Default::default()
        };

        let result = cma_es_minimize(sphere, &[1.0, 1.0], Some(options));
        assert!(result.is_ok());
        let res = result.expect("CMA-ES projection failed");
        assert!(res.fun < 0.01, "Projection sphere value: {}", res.fun);
    }

    #[test]
    fn test_cma_es_resampling_boundary() {
        let options = CmaEsOptions {
            sigma0: 0.3,
            max_fevals: 10_000,
            restart_strategy: RestartStrategy::NoRestart,
            lower_bounds: Some(vec![-2.0, -2.0]),
            upper_bounds: Some(vec![2.0, 2.0]),
            boundary_handling: BoundaryHandling::Resampling { max_attempts: 100 },
            seed: Some(42),
            ..Default::default()
        };

        let result = cma_es_minimize(sphere, &[1.0, 1.0], Some(options));
        assert!(result.is_ok());
    }

    #[test]
    fn test_state_creation() {
        let state = CmaEsState::new(&[0.0, 0.0, 0.0], 0.5, None, 42);
        assert!(state.is_ok());
        let s = state.expect("State creation failed");
        assert_eq!(s.n, 3);
        assert_eq!(s.sigma, 0.5);
        assert!(s.lambda >= 4);
        assert!(s.mu >= 2);
    }

    #[test]
    fn test_state_creation_invalid() {
        let state = CmaEsState::new(&[], 0.5, None, 42);
        assert!(state.is_err());

        let state = CmaEsState::new(&[0.0], -1.0, None, 42);
        assert!(state.is_err());

        let state = CmaEsState::new(&[0.0], f64::NAN, None, 42);
        assert!(state.is_err());
    }

    #[test]
    fn test_jacobi_eigen_identity() {
        let mat = Array2::eye(3);
        let (eigenvalues, eigenvectors) = jacobi_eigen(&mat, 3);
        for &ev in eigenvalues.iter() {
            assert!((ev - 1.0).abs() < 1e-10);
        }
        // Eigenvectors should form orthonormal basis
        let prod = eigenvectors.t().dot(&eigenvectors);
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (prod[[i, j]] - expected).abs() < 1e-10,
                    "Orthogonality check failed at ({}, {}): {}",
                    i,
                    j,
                    prod[[i, j]]
                );
            }
        }
    }

    #[test]
    fn test_jacobi_eigen_2x2_known_spectrum() {
        // [[2, 1], [1, 2]] has eigenvalues 1 and 3. Unlike the identity, the
        // off-diagonal entry forces at least one rotation to be applied.
        let entries = [[2.0, 1.0], [1.0, 2.0]];
        let mat = Array2::from_shape_fn((2, 2), |(i, j)| entries[i][j]);
        let (eigenvalues, _) = jacobi_eigen(&mat, 2);

        let mut sorted: Vec<f64> = eigenvalues.iter().copied().collect();
        sorted.sort_by(f64::total_cmp);
        assert!((sorted[0] - 1.0).abs() < 1e-12, "Smallest eigenvalue: {}", sorted[0]);
        assert!((sorted[1] - 3.0).abs() < 1e-12, "Largest eigenvalue: {}", sorted[1]);
    }

    #[test]
    fn test_jacobi_eigen_symmetric_3x3() {
        // Dense symmetric matrix: several rotations on different (p, q) pairs
        let entries = [[4.0, 1.0, 0.5], [1.0, 3.0, -1.0], [0.5, -1.0, 2.0]];
        let mat = Array2::from_shape_fn((3, 3), |(i, j)| entries[i][j]);
        let (eigenvalues, eigenvectors) = jacobi_eigen(&mat, 3);

        // Similarity transforms preserve the trace
        let trace: f64 = eigenvalues.iter().sum();
        assert!((trace - 9.0).abs() < 1e-9, "Trace of eigenvalues: {}", trace);

        // Each column must satisfy A * v_j == lambda_j * v_j
        let av = mat.dot(&eigenvectors);
        for i in 0..3 {
            for j in 0..3 {
                let expected = eigenvalues[j] * eigenvectors[[i, j]];
                assert!(
                    (av[[i, j]] - expected).abs() < 1e-9,
                    "Eigenpair residual at ({}, {}): {} vs {}",
                    i,
                    j,
                    av[[i, j]],
                    expected
                );
            }
        }

        // Accumulated rotations must stay orthonormal
        let prod = eigenvectors.t().dot(&eigenvectors);
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (prod[[i, j]] - expected).abs() < 1e-9,
                    "Orthogonality check failed at ({}, {}): {}",
                    i,
                    j,
                    prod[[i, j]]
                );
            }
        }
    }

    #[test]
    fn test_bipop_restart() {
        let options = CmaEsOptions {
            sigma0: 1.0,
            max_fevals: 20_000,
            restart_strategy: RestartStrategy::Bipop { max_restarts: 2 },
            seed: Some(42),
            ..Default::default()
        };

        let result = cma_es_minimize(sphere, &[3.0, -2.0], Some(options));
        assert!(result.is_ok());
    }
}