1use crate::error::OptimizeError;
12use scirs2_core::ndarray::{Array1, Array2, Axis};
13use scirs2_core::random::rand_distributions::Distribution;
14use scirs2_core::random::rngs::StdRng;
15use scirs2_core::random::{Normal, SeedableRng};
16
/// User-facing settings for a CMA-ES run.
#[derive(Clone, Debug)]
pub struct CmaEsConfig {
    /// Number of candidates sampled per generation. When `None`, `CmaEs::new`
    /// derives it from the dimension `n` as `4 + floor(3 * ln(n))`.
    pub population_size: Option<usize>,
    /// Initial global step size; `CmaEs::new` stores it as the starting `sigma`.
    pub initial_sigma: f64,
    /// Objective-evaluation budget; `minimize_cmaes` compares the running count
    /// against it after every fully evaluated generation.
    pub max_fevals: usize,
    /// Function-value tolerance.
    /// NOTE(review): no code visible in this file reads this field — confirm
    /// whether a stopping criterion based on it is intended.
    pub ftol: f64,
    /// Step-size tolerance; `minimize_cmaes` reports convergence once `sigma`
    /// drops below this value.
    pub xtol: f64,
    /// Seed handed to `StdRng::seed_from_u64`, making runs reproducible.
    pub seed: u64,
}
33
34impl Default for CmaEsConfig {
35 fn default() -> Self {
36 Self {
37 population_size: None,
38 initial_sigma: 0.3,
39 max_fevals: 10_000,
40 ftol: 1e-10,
41 xtol: 1e-10,
42 seed: 0,
43 }
44 }
45}
46
47#[derive(Debug, Clone)]
49pub struct CmaEsResult {
50 pub x: Array1<f64>,
52 pub fun: f64,
54 pub fevals: usize,
56 pub generations: usize,
58 pub converged: bool,
60 pub message: String,
62}
63
64pub struct CmaEs {
69 n: usize, lambda: usize, mu: usize, weights: Vec<f64>, mueff: f64, cs: f64, ds: f64, chi_n: f64, cc: f64, c1: f64, cmu: f64, mean: Array1<f64>, sigma: f64, ps: Array1<f64>, pc: Array1<f64>, cov: Array2<f64>, eigenvalues: Array1<f64>, eigenvectors: Array2<f64>, rng: StdRng,
97 normal_dist: Normal<f64>,
98
99 pub fevals: usize,
101 pub generations: usize,
103}
104
105impl CmaEs {
106 pub fn new(x0: Array1<f64>, config: &CmaEsConfig) -> Result<Self, OptimizeError> {
111 let n = x0.len();
112 if n == 0 {
113 return Err(OptimizeError::InvalidInput(
114 "dimension must be > 0".to_string(),
115 ));
116 }
117
118 let lambda = config
119 .population_size
120 .unwrap_or_else(|| 4 + (3.0 * (n as f64).ln()) as usize);
121 if lambda < 2 {
122 return Err(OptimizeError::InvalidInput(
123 "population_size must be >= 2".to_string(),
124 ));
125 }
126 let mu = lambda / 2;
127
128 let raw_weights: Vec<f64> = (1..=mu)
130 .map(|i| ((mu as f64 + 0.5) / i as f64).ln())
131 .collect();
132 let w_sum: f64 = raw_weights.iter().sum();
133 let weights: Vec<f64> = raw_weights.iter().map(|&w| w / w_sum).collect();
134 let mueff = 1.0 / weights.iter().map(|&w| w * w).sum::<f64>();
135
136 let cs = (mueff + 2.0) / (n as f64 + mueff + 5.0);
138 let ds = 1.0 + 2.0 * (0.0f64.max((mueff - 1.0) / (n as f64 + 1.0) - 1.0)).sqrt() + cs;
139 let chi_n =
140 (n as f64).sqrt() * (1.0 - 1.0 / (4.0 * n as f64) + 1.0 / (21.0 * n as f64 * n as f64));
141
142 let cc = (4.0 + mueff / n as f64) / (n as f64 + 4.0 + 2.0 * mueff / n as f64);
144 let c1 = 2.0 / ((n as f64 + 1.3).powi(2) + mueff);
145 let alpha_mu = 2.0;
146 let cmu = (alpha_mu * (mueff - 2.0 + 1.0 / mueff)
147 / ((n as f64 + 2.0).powi(2) + alpha_mu * mueff / 2.0))
148 .min(1.0 - c1);
149
150 let normal_dist =
151 Normal::new(0.0, 1.0).map_err(|e| OptimizeError::InitializationError(e.to_string()))?;
152
153 Ok(Self {
154 n,
155 lambda,
156 mu,
157 weights,
158 mueff,
159 cs,
160 ds,
161 chi_n,
162 cc,
163 c1,
164 cmu,
165 mean: x0,
166 sigma: config.initial_sigma,
167 ps: Array1::zeros(n),
168 pc: Array1::zeros(n),
169 cov: Array2::eye(n),
170 eigenvalues: Array1::ones(n),
171 eigenvectors: Array2::eye(n),
172 rng: StdRng::seed_from_u64(config.seed),
173 normal_dist,
174 fevals: 0,
175 generations: 0,
176 })
177 }
178
179 fn sample_population(&mut self) -> Vec<Array1<f64>> {
183 let mut pop = Vec::with_capacity(self.lambda);
184 for _ in 0..self.lambda {
185 let z: Array1<f64> = (0..self.n)
187 .map(|_| self.normal_dist.sample(&mut self.rng))
188 .collect::<Vec<f64>>()
189 .into();
190
191 let dz: Array1<f64> = &z * &self.eigenvalues.mapv(|v| v.sqrt());
193
194 let bdz = self.eigenvectors.dot(&dz);
196
197 pop.push(&self.mean + &(self.sigma * &bdz));
199 }
200 pop
201 }
202
203 fn update(&mut self, ranked: &[(usize, f64)], population: &[Array1<f64>]) {
212 let n = self.n;
213 let gen_f64 = self.generations as f64 + 1.0;
214
215 let old_mean = self.mean.clone();
217 let mut new_mean = Array1::zeros(n);
218 for (k, &(idx, _)) in ranked[..self.mu].iter().enumerate() {
219 new_mean = new_mean + &population[idx] * self.weights[k];
220 }
221 self.mean = new_mean;
222
223 let mean_diff = (&self.mean - &old_mean) / self.sigma;
225
226 let inv_sqrt_diag: Array1<f64> =
228 self.eigenvalues
229 .mapv(|v| if v > 1e-14 { v.recip().sqrt() } else { 0.0 });
230 let inv_sqrt_c: Array2<f64> = self
232 .eigenvectors
233 .dot(&Array2::from_diag(&inv_sqrt_diag))
234 .dot(&self.eigenvectors.t());
235
236 let invsqrt_diff = inv_sqrt_c.dot(&mean_diff);
238 let cs = self.cs;
239 self.ps = (1.0 - cs) * &self.ps + (cs * (2.0 - cs) * self.mueff).sqrt() * &invsqrt_diff;
240
241 let ps_norm = self.ps.mapv(|v| v * v).sum().sqrt();
243 let h_thresh = 1.4 + 2.0 / (n as f64 + 1.0);
244 let ps_norm_normalized =
245 ps_norm / (1.0 - (1.0 - cs).powf(2.0 * gen_f64)).sqrt() / self.chi_n;
246 let h_sig = ps_norm_normalized < h_thresh;
247
248 let delta_h = if h_sig {
250 0.0
251 } else {
252 (2.0 - self.cc) * self.cc
253 };
254
255 self.pc = (1.0 - self.cc) * &self.pc
256 + if h_sig {
257 (self.cc * (2.0 - self.cc) * self.mueff).sqrt() * &mean_diff
258 } else {
259 Array1::zeros(n)
260 };
261
262 let pc_col = self.pc.view().insert_axis(Axis(1));
264 let rank_one: Array2<f64> = pc_col.dot(&pc_col.t());
265
266 let mut rank_mu: Array2<f64> = Array2::zeros((n, n));
268 for (k, &(idx, _)) in ranked[..self.mu].iter().enumerate() {
269 let diff = (&population[idx] - &old_mean) / self.sigma;
270 let diff_col = diff.view().insert_axis(Axis(1));
271 rank_mu = rank_mu + self.weights[k] * diff_col.dot(&diff_col.t());
272 }
273
274 let c1 = self.c1;
277 let cmu = self.cmu;
278 self.cov = (1.0 + c1 * delta_h - c1 - cmu) * &self.cov + c1 * &rank_one + cmu * &rank_mu;
279
280 let ps_norm_new = self.ps.mapv(|v| v * v).sum().sqrt();
282 self.sigma *= ((cs / self.ds) * (ps_norm_new / self.chi_n - 1.0)).exp();
283
284 self.update_eigen();
286
287 self.generations += 1;
288 }
289
290 fn update_eigen(&mut self) {
295 let n = self.n;
296 let mut a = self.cov.clone();
298 for i in 0..n {
300 for j in (i + 1)..n {
301 let sym = (a[[i, j]] + a[[j, i]]) * 0.5;
302 a[[i, j]] = sym;
303 a[[j, i]] = sym;
304 }
305 }
306
307 let mut v = Array2::eye(n);
308
309 for _ in 0..20 {
312 let mut off_norm_sq = 0.0;
313 for i in 0..n {
314 for j in (i + 1)..n {
315 off_norm_sq += a[[i, j]] * a[[i, j]];
316 }
317 }
318 if off_norm_sq < 1e-28 {
320 break;
321 }
322
323 for p in 0..n {
324 for q in (p + 1)..n {
325 let apq = a[[p, q]];
326 if apq.abs() < 1e-15 {
327 continue;
328 }
329 let app = a[[p, p]];
330 let aqq = a[[q, q]];
331
332 let theta = 0.5 * (aqq - app) / apq;
334 let t = theta.signum() / (theta.abs() + (1.0 + theta * theta).sqrt());
335 let c_r = 1.0 / (1.0 + t * t).sqrt();
336 let s_r = t * c_r;
337
338 a[[p, p]] = app - t * apq;
340 a[[q, q]] = aqq + t * apq;
341 a[[p, q]] = 0.0;
342 a[[q, p]] = 0.0;
343
344 for r in 0..n {
346 if r != p && r != q {
347 let arp = a[[r, p]];
348 let arq = a[[r, q]];
349 let new_arp = c_r * arp - s_r * arq;
350 let new_arq = s_r * arp + c_r * arq;
351 a[[r, p]] = new_arp;
352 a[[p, r]] = new_arp;
353 a[[r, q]] = new_arq;
354 a[[q, r]] = new_arq;
355 }
356 }
357
358 for r in 0..n {
360 let vrp = v[[r, p]];
361 let vrq = v[[r, q]];
362 v[[r, p]] = c_r * vrp - s_r * vrq;
363 v[[r, q]] = s_r * vrp + c_r * vrq;
364 }
365 }
366 }
367 }
368
369 for i in 0..n {
371 self.eigenvalues[i] = a[[i, i]].max(1e-20);
372 }
373 self.eigenvectors = v;
374 }
375}
376
377pub fn minimize_cmaes<F>(
403 f: F,
404 x0: Array1<f64>,
405 bounds: Option<&[(f64, f64)]>,
406 config: CmaEsConfig,
407) -> Result<CmaEsResult, OptimizeError>
408where
409 F: Fn(&Array1<f64>) -> f64,
410{
411 let n = x0.len();
412 if n == 0 {
413 return Err(OptimizeError::InvalidInput(
414 "dimension must be > 0".to_string(),
415 ));
416 }
417 if let Some(b) = bounds {
418 if b.len() != n {
419 return Err(OptimizeError::InvalidInput(format!(
420 "bounds length {} does not match x0 length {}",
421 b.len(),
422 n
423 )));
424 }
425 for (i, &(lo, hi)) in b.iter().enumerate() {
426 if lo > hi {
427 return Err(OptimizeError::InvalidInput(format!(
428 "bounds[{}]: lower {} > upper {}",
429 i, lo, hi
430 )));
431 }
432 }
433 }
434
435 let max_fevals = config.max_fevals;
436 let xtol = config.xtol;
437
438 let mut state = CmaEs::new(x0.clone(), &config)?;
439
440 let x0_clipped = clip_to_bounds(&x0, bounds);
442 let mut best_f = f(&x0_clipped);
443 let mut best_x = x0_clipped;
444 state.fevals += 1;
445
446 loop {
447 let population = state.sample_population();
449
450 let mut fitness: Vec<(usize, f64)> = population
452 .iter()
453 .enumerate()
454 .map(|(i, xi)| {
455 let clipped = clip_to_bounds(xi, bounds);
456 let fval = f(&clipped);
457 (i, fval)
458 })
459 .collect();
460 state.fevals += population.len();
461
462 fitness.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
464
465 let best_this_gen_f = fitness[0].1;
467 if best_this_gen_f < best_f {
468 best_f = best_this_gen_f;
469 best_x = clip_to_bounds(&population[fitness[0].0], bounds);
470 }
471
472 if state.fevals >= max_fevals {
474 return Ok(CmaEsResult {
475 x: best_x,
476 fun: best_f,
477 fevals: state.fevals,
478 generations: state.generations,
479 converged: false,
480 message: "Maximum function evaluations reached".to_string(),
481 });
482 }
483
484 if state.sigma < xtol {
486 return Ok(CmaEsResult {
487 x: best_x,
488 fun: best_f,
489 fevals: state.fevals,
490 generations: state.generations,
491 converged: true,
492 message: "Step size (sigma) converged below xtol".to_string(),
493 });
494 }
495
496 state.update(&fitness, &population);
498 }
499}
500
501#[inline]
503fn clip_to_bounds(x: &Array1<f64>, bounds: Option<&[(f64, f64)]>) -> Array1<f64> {
504 match bounds {
505 None => x.clone(),
506 Some(b) => {
507 let clipped: Vec<f64> = x
508 .iter()
509 .zip(b.iter())
510 .map(|(&v, &(lo, hi))| v.clamp(lo, hi))
511 .collect();
512 Array1::from(clipped)
513 }
514 }
515}
516
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// One-dimensional sphere started far from the optimum.
    #[test]
    fn test_cmaes_sphere_1d() {
        let settings = CmaEsConfig {
            initial_sigma: 1.0,
            max_fevals: 2000,
            xtol: 1e-6,
            ftol: 1e-8,
            ..CmaEsConfig::default()
        };
        let outcome = minimize_cmaes(|p| p[0] * p[0], array![5.0], None, settings)
            .expect("cmaes failed");

        assert!(outcome.fun < 1e-6, "f* = {}, expected < 1e-6", outcome.fun);
        let x_star = outcome.x[0];
        assert!(x_star.abs() < 1e-3, "x* = {}, expected near 0", x_star);
    }

    /// Five-dimensional sphere.
    #[test]
    fn test_cmaes_sphere_nd() {
        let start = array![3.0, -2.0, 1.0, 4.0, -1.0];
        let settings = CmaEsConfig {
            initial_sigma: 1.0,
            max_fevals: 10_000,
            xtol: 1e-5,
            ..CmaEsConfig::default()
        };
        let outcome = minimize_cmaes(
            |p| p.iter().map(|&c| c * c).sum::<f64>(),
            start,
            None,
            settings,
        )
        .expect("cmaes failed");

        assert!(outcome.fun < 1e-4, "f* = {}, expected < 1e-4", outcome.fun);
    }

    /// Two-dimensional Rosenbrock function started at the origin.
    #[test]
    fn test_cmaes_rosenbrock() {
        let settings = CmaEsConfig {
            initial_sigma: 0.5,
            max_fevals: 20_000,
            xtol: 1e-8,
            ..CmaEsConfig::default()
        };
        let outcome = minimize_cmaes(
            |p| {
                let offset = 1.0 - p[0];
                let valley = p[1] - p[0] * p[0];
                offset * offset + 100.0 * valley * valley
            },
            array![0.0, 0.0],
            None,
            settings,
        )
        .expect("cmaes failed");

        assert!(
            outcome.fun < 0.01,
            "Rosenbrock f* = {}, expected < 0.01",
            outcome.fun
        );
    }

    /// Sphere whose unconstrained optimum lies outside the box `[1, 5]^2`.
    #[test]
    fn test_cmaes_with_bounds() {
        let limits = [(1.0, 5.0), (1.0, 5.0)];
        let settings = CmaEsConfig {
            initial_sigma: 0.5,
            max_fevals: 5000,
            ..CmaEsConfig::default()
        };
        let outcome = minimize_cmaes(
            |p| p[0] * p[0] + p[1] * p[1],
            array![3.0, 3.0],
            Some(&limits[..]),
            settings,
        )
        .expect("cmaes failed");

        let first = outcome.x[0];
        assert!(
            (0.9..=2.0).contains(&first),
            "x[0] = {}, expected in [0.9, 2.0]",
            first
        );
        let second = outcome.x[1];
        assert!(
            (0.9..=2.0).contains(&second),
            "x[1] = {}, expected in [0.9, 2.0]",
            second
        );
    }

    /// The result carries non-trivial bookkeeping values.
    #[test]
    fn test_cmaes_result_fevals_and_generations() {
        let settings = CmaEsConfig::default();
        let outcome =
            minimize_cmaes(|p| p[0] * p[0], array![1.0], None, settings).expect("cmaes failed");

        assert!(outcome.fevals > 0, "fevals should be > 0");
        assert!(outcome.generations > 0, "generations should be > 0");
        assert!(!outcome.message.is_empty(), "message should not be empty");
    }

    /// A zero-length starting point is rejected.
    #[test]
    fn test_cmaes_invalid_dimension() {
        use scirs2_core::ndarray::Array1;
        let no_coordinates: Array1<f64> = Array1::from(Vec::new());
        let attempt = minimize_cmaes(
            |p| p.iter().map(|&c| c * c).sum(),
            no_coordinates,
            None,
            CmaEsConfig::default(),
        );

        assert!(attempt.is_err(), "empty input should return error");
    }

    /// A bounds slice shorter than the starting point is rejected.
    #[test]
    fn test_cmaes_bounds_mismatch_error() {
        let too_few = [(0.0, 5.0)];
        let attempt = minimize_cmaes(
            |p| p[0] * p[0],
            array![1.0, 2.0],
            Some(&too_few[..]),
            CmaEsConfig::default(),
        );

        assert!(attempt.is_err(), "bounds mismatch should return error");
    }

    /// An explicit population size is honoured and still solves a shifted quadratic.
    #[test]
    fn test_cmaes_population_size_override() {
        let settings = CmaEsConfig {
            initial_sigma: 0.8,
            max_fevals: 5000,
            population_size: Some(20),
            ..CmaEsConfig::default()
        };
        let outcome = minimize_cmaes(
            |p| (p[0] - 1.0).powi(2) + (p[1] + 1.0).powi(2),
            array![0.0, 0.0],
            None,
            settings,
        )
        .expect("cmaes failed");

        assert!(outcome.fun < 0.01, "f* = {}", outcome.fun);
    }

    /// With a loose `xtol` the run ends either converged or out of budget.
    #[test]
    fn test_cmaes_sigma_convergence() {
        let settings = CmaEsConfig {
            initial_sigma: 0.5,
            max_fevals: 100_000,
            xtol: 1e-3,
            ..CmaEsConfig::default()
        };
        let outcome = minimize_cmaes(|p| p[0] * p[0], array![0.1], None, settings)
            .expect("cmaes failed");

        assert!(outcome.converged || outcome.fevals >= 100_000);
    }

    /// Quadratic with a cross term, exercising covariance adaptation.
    #[test]
    fn test_cmaes_quadratic_with_correlation() {
        let settings = CmaEsConfig {
            initial_sigma: 1.0,
            max_fevals: 5000,
            xtol: 1e-7,
            ..CmaEsConfig::default()
        };
        let outcome = minimize_cmaes(
            |p| 2.0 * p[0] * p[0] + 2.0 * p[0] * p[1] + 3.0 * p[1] * p[1],
            array![3.0, -3.0],
            None,
            settings,
        )
        .expect("cmaes failed");

        assert!(
            outcome.fun < 1e-4,
            "Correlated quadratic f* = {}",
            outcome.fun
        );
    }
}