Skip to main content

scirs2_optimize/metaheuristics/
de.rs

1//! # Differential Evolution (DE) Metaheuristic
2//!
3//! A comprehensive implementation of Differential Evolution for continuous optimization:
4//! - **DE/rand/1**: Classic random-base mutation
5//! - **DE/best/1**: Best-member mutation for fast convergence
6//! - **DE/rand-to-best/1**: Hybrid using both random and best members
7//! - **Binomial and exponential crossover**
8//! - **Self-adaptive parameter control (jDE)**
9//! - **Opposition-based learning** for population initialization
10//! - **Constraint handling** via penalty and feasibility rules
11
12use crate::error::{OptimizeError, OptimizeResult};
13use crate::result::OptimizeResults;
14use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
15use scirs2_core::random::rngs::StdRng;
16use scirs2_core::random::{rng, Rng, SeedableRng};
17use scirs2_core::RngExt;
18
19// ---------------------------------------------------------------------------
20// Enums & Configuration
21// ---------------------------------------------------------------------------
22
/// DE mutation strategy.
///
/// In the formulas below `x_i` is the target (current) member, `x_best` is the member chosen
/// by the optimizer's best-member rule, and `x_r1`, `x_r2`, `x_r3` are distinct randomly
/// chosen members that all differ from `x_i`.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum DeStrategy {
    /// DE/rand/1: v = x_r1 + F * (x_r2 - x_r3). This is the default strategy.
    #[default]
    Rand1,
    /// DE/best/1: v = x_best + F * (x_r1 - x_r2)
    Best1,
    /// DE/rand-to-best/1 (also known as current-to-best/1):
    /// v = x_i + F * (x_best - x_i) + F * (x_r1 - x_r2).
    ///
    /// The base vector is the target member `x_i` itself, not a random member.
    RandToBest1,
}
39
/// Crossover scheme used to combine the target and the mutant into a trial vector.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum CrossoverType {
    /// Binomial (uniform) crossover.
    ///
    /// Each coordinate is copied from the mutant with probability CR. One randomly chosen
    /// coordinate is always copied, so the trial differs from the target. This is the default.
    #[default]
    Binomial,
    /// Exponential crossover.
    ///
    /// Starting at a random coordinate, a contiguous (wrap-around) run of coordinates is
    /// copied from the mutant. The run continues while uniform draws stay below CR, and at
    /// least one coordinate is always copied.
    Exponential,
}
54
/// How constraint violations influence selection in DE.
///
/// The total violation of a point is the sum of the positive parts of its constraint values.
#[derive(Debug, Clone)]
pub struct DeConstraintHandler {
    /// Weight applied to the total violation.
    ///
    /// The penalized fitness is `objective + penalty_coeff * violation`.
    pub penalty_coeff: f64,
    /// Compare candidates with feasibility rules instead of penalized fitness alone.
    ///
    /// A feasible point beats an infeasible one, two infeasible points are ranked by violation,
    /// and two feasible points are ranked by fitness.
    pub use_feasibility_rules: bool,
}

impl Default for DeConstraintHandler {
    /// A large penalty (`1e6`) with feasibility rules switched on.
    fn default() -> Self {
        let penalty_coeff = 1e6;
        let use_feasibility_rules = true;
        DeConstraintHandler {
            penalty_coeff,
            use_feasibility_rules,
        }
    }
}
72
/// Opposition-based learning (OBL) settings for population initialization.
#[derive(Debug, Clone)]
pub struct OppositionBasedInit {
    /// Turn opposition-based initialization on or off.
    ///
    /// When on, `DifferentialEvolutionOptimizer` replaces every odd-indexed initial member `x`
    /// with its opposite point `lo + hi - x`, without consulting the objective.
    pub enabled: bool,
    /// Intended per-generation probability of an opposition "jump".
    ///
    /// NOTE(review): no optimizer in this module reads this field; confirm before relying on it.
    pub jumping_rate: f64,
}

impl Default for OppositionBasedInit {
    /// Enabled, with a jumping rate of `0.3`.
    fn default() -> Self {
        OppositionBasedInit {
            jumping_rate: 0.3,
            enabled: true,
        }
    }
}
90
91// ---------------------------------------------------------------------------
92// DE Options
93// ---------------------------------------------------------------------------
94
95/// Options for Differential Evolution optimizer
96#[derive(Debug, Clone)]
97pub struct DeOptions {
98    /// Population size (typically 5-10 times the dimension)
99    pub pop_size: usize,
100    /// Maximum number of generations
101    pub max_generations: usize,
102    /// Mutation factor F in [0, 2]
103    pub mutation_factor: f64,
104    /// Crossover probability CR in [0, 1]
105    pub crossover_prob: f64,
106    /// Mutation strategy
107    pub strategy: DeStrategy,
108    /// Crossover type
109    pub crossover: CrossoverType,
110    /// Search bounds per dimension: (lower, upper)
111    pub bounds: Vec<(f64, f64)>,
112    /// Random seed
113    pub seed: Option<u64>,
114    /// Convergence tolerance on function value spread
115    pub tol: f64,
116    /// Patience: generations without improvement
117    pub patience: usize,
118    /// Opposition-based learning
119    pub opposition: OppositionBasedInit,
120    /// Constraint handler
121    pub constraint_handler: Option<DeConstraintHandler>,
122}
123
124impl Default for DeOptions {
125    fn default() -> Self {
126        Self {
127            pop_size: 50,
128            max_generations: 1000,
129            mutation_factor: 0.8,
130            crossover_prob: 0.9,
131            strategy: DeStrategy::Rand1,
132            crossover: CrossoverType::Binomial,
133            bounds: Vec::new(),
134            seed: None,
135            tol: 1e-12,
136            patience: 100,
137            opposition: OppositionBasedInit::default(),
138            constraint_handler: None,
139        }
140    }
141}
142
143// ---------------------------------------------------------------------------
144// jDE (self-adaptive) Options
145// ---------------------------------------------------------------------------
146
147/// Options for self-adaptive jDE variant
148#[derive(Debug, Clone)]
149pub struct JdeOptions {
150    /// Base DE options
151    pub base: DeOptions,
152    /// Probability of adapting F
153    pub tau_f: f64,
154    /// Probability of adapting CR
155    pub tau_cr: f64,
156    /// Lower bound for F
157    pub f_lower: f64,
158    /// Upper bound for F
159    pub f_upper: f64,
160}
161
162impl Default for JdeOptions {
163    fn default() -> Self {
164        Self {
165            base: DeOptions::default(),
166            tau_f: 0.1,
167            tau_cr: 0.1,
168            f_lower: 0.1,
169            f_upper: 0.9,
170        }
171    }
172}
173
174// ---------------------------------------------------------------------------
175// DE Result
176// ---------------------------------------------------------------------------
177
178/// Result from DE optimization
179#[derive(Debug, Clone)]
180pub struct DeResult {
181    /// Best solution found
182    pub x: Array1<f64>,
183    /// Objective value at best solution
184    pub fun: f64,
185    /// Number of function evaluations
186    pub nfev: usize,
187    /// Number of generations
188    pub generations: usize,
189    /// Whether the optimization converged
190    pub converged: bool,
191    /// Termination message
192    pub message: String,
193    /// Final population fitness spread (std dev)
194    pub population_spread: f64,
195}
196
197impl DeResult {
198    /// Convert to standard OptimizeResults
199    pub fn to_optimize_results(&self) -> OptimizeResults<f64> {
200        OptimizeResults {
201            x: self.x.clone(),
202            fun: self.fun,
203            jac: None,
204            hess: None,
205            constr: None,
206            nit: self.generations,
207            nfev: self.nfev,
208            njev: 0,
209            nhev: 0,
210            maxcv: 0,
211            message: self.message.clone(),
212            success: self.converged,
213            status: if self.converged { 0 } else { 1 },
214        }
215    }
216}
217
218// ---------------------------------------------------------------------------
219// Core DE Optimizer
220// ---------------------------------------------------------------------------
221
222/// Differential Evolution optimizer
223pub struct DifferentialEvolutionOptimizer {
224    options: DeOptions,
225    rng: StdRng,
226}
227
228impl DifferentialEvolutionOptimizer {
229    /// Create a new DE optimizer
230    pub fn new(options: DeOptions) -> OptimizeResult<Self> {
231        if options.bounds.is_empty() {
232            return Err(OptimizeError::InvalidInput(
233                "Bounds must be provided for DE".to_string(),
234            ));
235        }
236        if options.pop_size < 4 {
237            return Err(OptimizeError::InvalidParameter(
238                "Population size must be >= 4 for DE".to_string(),
239            ));
240        }
241        if options.mutation_factor < 0.0 || options.mutation_factor > 2.0 {
242            return Err(OptimizeError::InvalidParameter(
243                "Mutation factor F must be in [0, 2]".to_string(),
244            ));
245        }
246        if options.crossover_prob < 0.0 || options.crossover_prob > 1.0 {
247            return Err(OptimizeError::InvalidParameter(
248                "Crossover probability CR must be in [0, 1]".to_string(),
249            ));
250        }
251
252        let seed = options.seed.unwrap_or_else(|| rng().random());
253        Ok(Self {
254            options,
255            rng: StdRng::seed_from_u64(seed),
256        })
257    }
258
259    /// Optimize an unconstrained objective function
260    pub fn optimize<F>(&mut self, func: F) -> OptimizeResult<DeResult>
261    where
262        F: Fn(&ArrayView1<f64>) -> f64,
263    {
264        self.optimize_constrained(func, None::<fn(&ArrayView1<f64>) -> Vec<f64>>)
265    }
266
267    /// Optimize with optional constraint functions
268    ///
269    /// `constraints_fn` returns a vector of g_i(x) where g_i(x) > 0 means violation.
270    pub fn optimize_constrained<F, G>(
271        &mut self,
272        func: F,
273        constraints_fn: Option<G>,
274    ) -> OptimizeResult<DeResult>
275    where
276        F: Fn(&ArrayView1<f64>) -> f64,
277        G: Fn(&ArrayView1<f64>) -> Vec<f64>,
278    {
279        let ndim = self.options.bounds.len();
280        let np = self.options.pop_size;
281
282        // Initialize population
283        let mut population = self.initialize_population(ndim, np);
284        let mut fitness: Vec<f64> = Vec::with_capacity(np);
285        let mut violations: Vec<f64> = vec![0.0; np]; // total violation per member
286        let mut nfev: usize = 0;
287
288        // Evaluate initial population
289        for i in 0..np {
290            let row = population.row(i);
291            let f_val = func(&row);
292            nfev += 1;
293            let viol = if let Some(ref cf) = constraints_fn {
294                let v = cf(&row);
295                v.iter().map(|vi| vi.max(0.0)).sum::<f64>()
296            } else {
297                0.0
298            };
299            fitness.push(self.penalized_fitness(f_val, viol));
300            violations[i] = viol;
301        }
302
303        // Track best
304        let mut best_idx = self.find_best(&fitness, &violations);
305        let mut best_x = population.row(best_idx).to_owned();
306        let mut best_fun = func(&best_x.view());
307        let mut no_improve_count: usize = 0;
308
309        for gen in 0..self.options.max_generations {
310            let mut new_population = population.clone();
311            let mut new_fitness = fitness.clone();
312            let mut new_violations = violations.clone();
313
314            for i in 0..np {
315                // Generate mutant vector
316                let mutant = self.mutate(&population, i, best_idx, ndim);
317
318                // Crossover
319                let trial = self.crossover(&population.row(i).to_owned(), &mutant, ndim);
320
321                // Clip to bounds
322                let trial_clipped = self.clip_to_bounds(&trial);
323
324                // Evaluate trial
325                let trial_view = trial_clipped.view();
326                let trial_f = func(&trial_view);
327                nfev += 1;
328                let trial_viol = if let Some(ref cf) = constraints_fn {
329                    let v = cf(&trial_view);
330                    v.iter().map(|vi| vi.max(0.0)).sum::<f64>()
331                } else {
332                    0.0
333                };
334                let trial_penalized = self.penalized_fitness(trial_f, trial_viol);
335
336                // Selection
337                let replace = if self
338                    .options
339                    .constraint_handler
340                    .as_ref()
341                    .map_or(false, |ch| ch.use_feasibility_rules)
342                {
343                    self.feasibility_selection(
344                        trial_penalized,
345                        trial_viol,
346                        fitness[i],
347                        violations[i],
348                    )
349                } else {
350                    trial_penalized <= fitness[i]
351                };
352
353                if replace {
354                    for d in 0..ndim {
355                        new_population[[i, d]] = trial_clipped[d];
356                    }
357                    new_fitness[i] = trial_penalized;
358                    new_violations[i] = trial_viol;
359                }
360            }
361
362            population = new_population;
363            fitness = new_fitness;
364            violations = new_violations;
365
366            // Update best
367            let new_best_idx = self.find_best(&fitness, &violations);
368            let candidate_fun = func(&population.row(new_best_idx));
369
370            if candidate_fun < best_fun {
371                best_idx = new_best_idx;
372                best_x = population.row(best_idx).to_owned();
373                best_fun = candidate_fun;
374                no_improve_count = 0;
375            } else {
376                best_idx = new_best_idx;
377                no_improve_count += 1;
378            }
379
380            // Check convergence: population spread
381            let spread = self.population_spread(&fitness);
382            if spread < self.options.tol {
383                return Ok(DeResult {
384                    x: best_x,
385                    fun: best_fun,
386                    nfev,
387                    generations: gen + 1,
388                    converged: true,
389                    message: format!(
390                        "DE converged: spread {:.2e} < tol {:.2e} at generation {}",
391                        spread,
392                        self.options.tol,
393                        gen + 1
394                    ),
395                    population_spread: spread,
396                });
397            }
398
399            if no_improve_count >= self.options.patience {
400                return Ok(DeResult {
401                    x: best_x,
402                    fun: best_fun,
403                    nfev,
404                    generations: gen + 1,
405                    converged: true,
406                    message: format!(
407                        "DE converged: no improvement for {} generations",
408                        self.options.patience
409                    ),
410                    population_spread: spread,
411                });
412            }
413        }
414
415        let spread = self.population_spread(&fitness);
416        Ok(DeResult {
417            x: best_x,
418            fun: best_fun,
419            nfev,
420            generations: self.options.max_generations,
421            converged: false,
422            message: format!(
423                "DE completed {} generations without full convergence",
424                self.options.max_generations
425            ),
426            population_spread: spread,
427        })
428    }
429
430    // --- Population Initialization ---
431
432    fn initialize_population(&mut self, ndim: usize, np: usize) -> Array2<f64> {
433        let mut pop = Array2::zeros((np, ndim));
434
435        // Random initialization within bounds
436        for i in 0..np {
437            for d in 0..ndim {
438                let (lo, hi) = self.options.bounds[d];
439                pop[[i, d]] = lo + self.rng.random::<f64>() * (hi - lo);
440            }
441        }
442
443        // Opposition-based learning: double the candidates, keep the best NP
444        if self.options.opposition.enabled {
445            let mut all_candidates = Vec::with_capacity(2 * np);
446            for i in 0..np {
447                let mut member = Vec::with_capacity(ndim);
448                let mut opposite = Vec::with_capacity(ndim);
449                for d in 0..ndim {
450                    let val = pop[[i, d]];
451                    member.push(val);
452                    let (lo, hi) = self.options.bounds[d];
453                    opposite.push(lo + hi - val);
454                }
455                all_candidates.push(member);
456                all_candidates.push(opposite);
457            }
458
459            // We just return the first NP as-is plus opposition awareness
460            // For full OBL, one would evaluate and select top NP, but
461            // that requires the objective function here. So we interleave.
462            for i in 0..np {
463                if i < all_candidates.len() / 2 {
464                    // Use opposition for every other member
465                    let opp_idx = 2 * i + 1;
466                    if opp_idx < all_candidates.len() && i % 2 == 1 {
467                        for d in 0..ndim {
468                            pop[[i, d]] = all_candidates[opp_idx][d];
469                        }
470                    }
471                }
472            }
473        }
474
475        pop
476    }
477
478    // --- Mutation ---
479
480    fn mutate(
481        &mut self,
482        population: &Array2<f64>,
483        target_idx: usize,
484        best_idx: usize,
485        ndim: usize,
486    ) -> Array1<f64> {
487        let np = population.nrows();
488        let f = self.options.mutation_factor;
489
490        match self.options.strategy {
491            DeStrategy::Rand1 => {
492                let (r1, r2, r3) = self.pick_three_distinct(np, target_idx);
493                let mut mutant = Array1::zeros(ndim);
494                for d in 0..ndim {
495                    mutant[d] =
496                        population[[r1, d]] + f * (population[[r2, d]] - population[[r3, d]]);
497                }
498                mutant
499            }
500            DeStrategy::Best1 => {
501                let (r1, r2) = self.pick_two_distinct(np, target_idx);
502                let mut mutant = Array1::zeros(ndim);
503                for d in 0..ndim {
504                    mutant[d] =
505                        population[[best_idx, d]] + f * (population[[r1, d]] - population[[r2, d]]);
506                }
507                mutant
508            }
509            DeStrategy::RandToBest1 => {
510                let (r1, r2) = self.pick_two_distinct(np, target_idx);
511                let mut mutant = Array1::zeros(ndim);
512                for d in 0..ndim {
513                    mutant[d] = population[[target_idx, d]]
514                        + f * (population[[best_idx, d]] - population[[target_idx, d]])
515                        + f * (population[[r1, d]] - population[[r2, d]]);
516                }
517                mutant
518            }
519        }
520    }
521
522    // --- Crossover ---
523
524    fn crossover(
525        &mut self,
526        target: &Array1<f64>,
527        mutant: &Array1<f64>,
528        ndim: usize,
529    ) -> Array1<f64> {
530        match self.options.crossover {
531            CrossoverType::Binomial => self.binomial_crossover(target, mutant, ndim),
532            CrossoverType::Exponential => self.exponential_crossover(target, mutant, ndim),
533        }
534    }
535
536    fn binomial_crossover(
537        &mut self,
538        target: &Array1<f64>,
539        mutant: &Array1<f64>,
540        ndim: usize,
541    ) -> Array1<f64> {
542        let cr = self.options.crossover_prob;
543        let j_rand = self.rng.random_range(0..ndim);
544        let mut trial = target.clone();
545        for d in 0..ndim {
546            if self.rng.random::<f64>() < cr || d == j_rand {
547                trial[d] = mutant[d];
548            }
549        }
550        trial
551    }
552
553    fn exponential_crossover(
554        &mut self,
555        target: &Array1<f64>,
556        mutant: &Array1<f64>,
557        ndim: usize,
558    ) -> Array1<f64> {
559        let cr = self.options.crossover_prob;
560        let mut trial = target.clone();
561        let start = self.rng.random_range(0..ndim);
562        let mut d = start;
563        loop {
564            trial[d] = mutant[d];
565            d = (d + 1) % ndim;
566            if d == start || self.rng.random::<f64>() >= cr {
567                break;
568            }
569        }
570        trial
571    }
572
573    // --- Helpers ---
574
575    fn clip_to_bounds(&self, x: &Array1<f64>) -> Array1<f64> {
576        let mut clipped = x.clone();
577        for (d, (lo, hi)) in self.options.bounds.iter().enumerate() {
578            if d < clipped.len() {
579                clipped[d] = clipped[d].clamp(*lo, *hi);
580            }
581        }
582        clipped
583    }
584
585    fn penalized_fitness(&self, obj: f64, violation: f64) -> f64 {
586        if let Some(ref ch) = self.options.constraint_handler {
587            obj + ch.penalty_coeff * violation
588        } else {
589            obj
590        }
591    }
592
593    fn feasibility_selection(
594        &self,
595        trial_fit: f64,
596        trial_viol: f64,
597        current_fit: f64,
598        current_viol: f64,
599    ) -> bool {
600        let trial_feasible = trial_viol <= 1e-15;
601        let current_feasible = current_viol <= 1e-15;
602
603        match (trial_feasible, current_feasible) {
604            (true, true) => trial_fit <= current_fit,
605            (true, false) => true,  // feasible beats infeasible
606            (false, true) => false, // infeasible loses to feasible
607            (false, false) => trial_viol < current_viol, // less violation wins
608        }
609    }
610
611    fn find_best(&self, fitness: &[f64], violations: &[f64]) -> usize {
612        let mut best_idx = 0;
613        for i in 1..fitness.len() {
614            let is_better = if self
615                .options
616                .constraint_handler
617                .as_ref()
618                .map_or(false, |ch| ch.use_feasibility_rules)
619            {
620                self.feasibility_selection(
621                    fitness[i],
622                    violations[i],
623                    fitness[best_idx],
624                    violations[best_idx],
625                )
626            } else {
627                fitness[i] < fitness[best_idx]
628            };
629            if is_better {
630                best_idx = i;
631            }
632        }
633        best_idx
634    }
635
636    fn population_spread(&self, fitness: &[f64]) -> f64 {
637        if fitness.is_empty() {
638            return 0.0;
639        }
640        let mean = fitness.iter().sum::<f64>() / fitness.len() as f64;
641        let variance =
642            fitness.iter().map(|f| (f - mean).powi(2)).sum::<f64>() / fitness.len() as f64;
643        variance.sqrt()
644    }
645
646    fn pick_three_distinct(&mut self, np: usize, exclude: usize) -> (usize, usize, usize) {
647        let mut r1 = self.rng.random_range(0..np);
648        while r1 == exclude {
649            r1 = self.rng.random_range(0..np);
650        }
651        let mut r2 = self.rng.random_range(0..np);
652        while r2 == exclude || r2 == r1 {
653            r2 = self.rng.random_range(0..np);
654        }
655        let mut r3 = self.rng.random_range(0..np);
656        while r3 == exclude || r3 == r1 || r3 == r2 {
657            r3 = self.rng.random_range(0..np);
658        }
659        (r1, r2, r3)
660    }
661
662    fn pick_two_distinct(&mut self, np: usize, exclude: usize) -> (usize, usize) {
663        let mut r1 = self.rng.random_range(0..np);
664        while r1 == exclude {
665            r1 = self.rng.random_range(0..np);
666        }
667        let mut r2 = self.rng.random_range(0..np);
668        while r2 == exclude || r2 == r1 {
669            r2 = self.rng.random_range(0..np);
670        }
671        (r1, r2)
672    }
673}
674
675// ---------------------------------------------------------------------------
676// Self-Adaptive jDE
677// ---------------------------------------------------------------------------
678
679/// Self-adaptive Differential Evolution (jDE) optimizer
680///
681/// Automatically adapts F and CR during the search (Brest et al. 2006).
682pub fn jde_optimize<F>(func: F, options: JdeOptions) -> OptimizeResult<DeResult>
683where
684    F: Fn(&ArrayView1<f64>) -> f64,
685{
686    if options.base.bounds.is_empty() {
687        return Err(OptimizeError::InvalidInput(
688            "Bounds must be provided for jDE".to_string(),
689        ));
690    }
691    let ndim = options.base.bounds.len();
692    let np = options.base.pop_size;
693    if np < 4 {
694        return Err(OptimizeError::InvalidParameter(
695            "Population size must be >= 4".to_string(),
696        ));
697    }
698
699    let seed = options.base.seed.unwrap_or_else(|| rng().random());
700    let mut local_rng = StdRng::seed_from_u64(seed);
701    let bounds = &options.base.bounds;
702
703    // Initialize population
704    let mut population = Array2::zeros((np, ndim));
705    for i in 0..np {
706        for d in 0..ndim {
707            let (lo, hi) = bounds[d];
708            population[[i, d]] = lo + local_rng.random::<f64>() * (hi - lo);
709        }
710    }
711
712    // Per-member F and CR
713    let mut f_vec = vec![options.base.mutation_factor; np];
714    let mut cr_vec = vec![options.base.crossover_prob; np];
715
716    // Evaluate
717    let mut fitness: Vec<f64> = (0..np).map(|i| func(&population.row(i))).collect();
718    let mut nfev = np;
719
720    let mut best_idx = fitness
721        .iter()
722        .enumerate()
723        .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
724        .map(|(i, _)| i)
725        .unwrap_or(0);
726    let mut best_x = population.row(best_idx).to_owned();
727    let mut best_fun = fitness[best_idx];
728    let mut no_improve: usize = 0;
729
730    for gen in 0..options.base.max_generations {
731        let mut new_pop = population.clone();
732        let mut new_fitness = fitness.clone();
733        let mut new_f_vec = f_vec.clone();
734        let mut new_cr_vec = cr_vec.clone();
735
736        for i in 0..np {
737            // Self-adapt F
738            let fi = if local_rng.random::<f64>() < options.tau_f {
739                let new_f = options.f_lower + local_rng.random::<f64>() * options.f_upper;
740                new_f_vec[i] = new_f;
741                new_f
742            } else {
743                f_vec[i]
744            };
745
746            // Self-adapt CR
747            let cri = if local_rng.random::<f64>() < options.tau_cr {
748                let new_cr = local_rng.random::<f64>();
749                new_cr_vec[i] = new_cr;
750                new_cr
751            } else {
752                cr_vec[i]
753            };
754
755            // DE/rand/1 mutation with adapted F
756            let (r1, r2, r3) = pick_three_distinct_rng(&mut local_rng, np, i);
757            let mut mutant = Array1::zeros(ndim);
758            for d in 0..ndim {
759                mutant[d] = population[[r1, d]] + fi * (population[[r2, d]] - population[[r3, d]]);
760            }
761
762            // Binomial crossover with adapted CR
763            let j_rand = local_rng.random_range(0..ndim);
764            let mut trial = Array1::zeros(ndim);
765            for d in 0..ndim {
766                if local_rng.random::<f64>() < cri || d == j_rand {
767                    trial[d] = mutant[d];
768                } else {
769                    trial[d] = population[[i, d]];
770                }
771            }
772
773            // Clip
774            for d in 0..ndim {
775                let (lo, hi) = bounds[d];
776                trial[d] = trial[d].clamp(lo, hi);
777            }
778
779            let trial_f = func(&trial.view());
780            nfev += 1;
781
782            if trial_f <= fitness[i] {
783                for d in 0..ndim {
784                    new_pop[[i, d]] = trial[d];
785                }
786                new_fitness[i] = trial_f;
787            } else {
788                // Revert F and CR adaptation
789                new_f_vec[i] = f_vec[i];
790                new_cr_vec[i] = cr_vec[i];
791            }
792        }
793
794        population = new_pop;
795        fitness = new_fitness;
796        f_vec = new_f_vec;
797        cr_vec = new_cr_vec;
798
799        // Update best
800        let new_best_idx = fitness
801            .iter()
802            .enumerate()
803            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
804            .map(|(i, _)| i)
805            .unwrap_or(0);
806
807        if fitness[new_best_idx] < best_fun {
808            best_idx = new_best_idx;
809            best_x = population.row(best_idx).to_owned();
810            best_fun = fitness[best_idx];
811            no_improve = 0;
812        } else {
813            no_improve += 1;
814        }
815
816        // Check convergence
817        let spread = {
818            let mean = fitness.iter().sum::<f64>() / np as f64;
819            let var = fitness.iter().map(|f| (f - mean).powi(2)).sum::<f64>() / np as f64;
820            var.sqrt()
821        };
822
823        if spread < options.base.tol {
824            return Ok(DeResult {
825                x: best_x,
826                fun: best_fun,
827                nfev,
828                generations: gen + 1,
829                converged: true,
830                message: format!("jDE converged at generation {}", gen + 1),
831                population_spread: spread,
832            });
833        }
834
835        if no_improve >= options.base.patience {
836            return Ok(DeResult {
837                x: best_x,
838                fun: best_fun,
839                nfev,
840                generations: gen + 1,
841                converged: true,
842                message: format!(
843                    "jDE: no improvement for {} generations",
844                    options.base.patience
845                ),
846                population_spread: spread,
847            });
848        }
849    }
850
851    let spread = {
852        let mean = fitness.iter().sum::<f64>() / np as f64;
853        let var = fitness.iter().map(|f| (f - mean).powi(2)).sum::<f64>() / np as f64;
854        var.sqrt()
855    };
856
857    Ok(DeResult {
858        x: best_x,
859        fun: best_fun,
860        nfev,
861        generations: options.base.max_generations,
862        converged: false,
863        message: "jDE completed max generations".to_string(),
864        population_spread: spread,
865    })
866}
867
868fn pick_three_distinct_rng(
869    rng_ref: &mut StdRng,
870    np: usize,
871    exclude: usize,
872) -> (usize, usize, usize) {
873    let mut r1 = rng_ref.random_range(0..np);
874    while r1 == exclude {
875        r1 = rng_ref.random_range(0..np);
876    }
877    let mut r2 = rng_ref.random_range(0..np);
878    while r2 == exclude || r2 == r1 {
879        r2 = rng_ref.random_range(0..np);
880    }
881    let mut r3 = rng_ref.random_range(0..np);
882    while r3 == exclude || r3 == r1 || r3 == r2 {
883        r3 = rng_ref.random_range(0..np);
884    }
885    (r1, r2, r3)
886}
887
888// ---------------------------------------------------------------------------
889// Tests
890// ---------------------------------------------------------------------------
891
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Sphere function `f(x) = sum(x_i^2)`; global minimum 0 at the origin.
    fn sphere(x: &ArrayView1<f64>) -> f64 {
        x.iter().map(|xi| xi * xi).sum()
    }

    /// Rosenbrock function; global minimum 0 at `x = (1, ..., 1)`.
    ///
    /// `saturating_sub` keeps the helper well-defined (returning 0.0) for an
    /// empty input instead of underflowing `x.len() - 1`.
    fn rosenbrock(x: &ArrayView1<f64>) -> f64 {
        let mut sum = 0.0;
        for i in 0..x.len().saturating_sub(1) {
            sum += 100.0 * (x[i + 1] - x[i] * x[i]).powi(2) + (1.0 - x[i]).powi(2);
        }
        sum
    }

    /// Rastrigin function `10n + sum(x_i^2 - 10 cos(2 pi x_i))`; global minimum 0
    /// at the origin, with many local minima.
    fn rastrigin(x: &ArrayView1<f64>) -> f64 {
        let n = x.len() as f64;
        10.0 * n
            + x.iter()
                .map(|xi| xi * xi - 10.0 * (2.0 * std::f64::consts::PI * xi).cos())
                .sum::<f64>()
    }

    // --- DE/rand/1 tests ---

    #[test]
    fn test_de_rand1_sphere() {
        let opts = DeOptions {
            pop_size: 30,
            max_generations: 500,
            mutation_factor: 0.8,
            crossover_prob: 0.9,
            strategy: DeStrategy::Rand1,
            crossover: CrossoverType::Binomial,
            bounds: vec![(-5.0, 5.0); 2],
            seed: Some(42),
            patience: 200,
            ..Default::default()
        };

        let mut de = DifferentialEvolutionOptimizer::new(opts).expect("valid options");
        let result = de.optimize(sphere).expect("DE should optimize sphere");

        assert!(result.fun < 1e-4, "DE/rand/1 sphere: got {}", result.fun);
        assert!(result.nfev > 0);
    }

    #[test]
    fn test_de_best1_sphere() {
        let opts = DeOptions {
            pop_size: 30,
            max_generations: 500,
            mutation_factor: 0.5,
            crossover_prob: 0.9,
            strategy: DeStrategy::Best1,
            bounds: vec![(-5.0, 5.0); 2],
            seed: Some(42),
            patience: 200,
            ..Default::default()
        };

        let mut de = DifferentialEvolutionOptimizer::new(opts).expect("valid options");
        let result = de.optimize(sphere).expect("DE/best/1 should work");

        assert!(result.fun < 1e-4, "DE/best/1 sphere: got {}", result.fun);
    }

    #[test]
    fn test_de_rand_to_best1_sphere() {
        let opts = DeOptions {
            pop_size: 30,
            max_generations: 500,
            mutation_factor: 0.7,
            crossover_prob: 0.9,
            strategy: DeStrategy::RandToBest1,
            bounds: vec![(-5.0, 5.0); 2],
            seed: Some(42),
            patience: 200,
            ..Default::default()
        };

        let mut de = DifferentialEvolutionOptimizer::new(opts).expect("valid options");
        let result = de.optimize(sphere).expect("DE/rand-to-best/1 should work");

        assert!(
            result.fun < 1e-3,
            "DE/rand-to-best/1 sphere: got {}",
            result.fun
        );
    }

    // --- Crossover type tests ---

    #[test]
    fn test_de_exponential_crossover() {
        let opts = DeOptions {
            pop_size: 30,
            max_generations: 500,
            mutation_factor: 0.8,
            crossover_prob: 0.9,
            strategy: DeStrategy::Rand1,
            crossover: CrossoverType::Exponential,
            bounds: vec![(-5.0, 5.0); 2],
            seed: Some(42),
            patience: 200,
            ..Default::default()
        };

        let mut de = DifferentialEvolutionOptimizer::new(opts).expect("valid options");
        let result = de.optimize(sphere).expect("exp crossover should work");

        assert!(
            result.fun < 0.1,
            "Exponential crossover sphere: got {}",
            result.fun
        );
    }

    // --- Rastrigin (multimodal) ---

    #[test]
    fn test_de_rastrigin() {
        let opts = DeOptions {
            pop_size: 50,
            max_generations: 1000,
            mutation_factor: 0.8,
            crossover_prob: 0.9,
            strategy: DeStrategy::Rand1,
            bounds: vec![(-5.12, 5.12); 3],
            seed: Some(42),
            patience: 300,
            ..Default::default()
        };

        let mut de = DifferentialEvolutionOptimizer::new(opts).expect("valid options");
        let result = de.optimize(rastrigin).expect("DE on rastrigin");

        assert!(result.fun < 10.0, "DE rastrigin: got {}", result.fun);
    }

    // --- jDE tests ---

    #[test]
    fn test_jde_sphere() {
        let opts = JdeOptions {
            base: DeOptions {
                pop_size: 30,
                max_generations: 500,
                mutation_factor: 0.5,
                crossover_prob: 0.9,
                bounds: vec![(-5.0, 5.0); 2],
                seed: Some(42),
                patience: 200,
                ..Default::default()
            },
            tau_f: 0.1,
            tau_cr: 0.1,
            f_lower: 0.1,
            f_upper: 0.9,
        };

        let result = jde_optimize(sphere, opts).expect("jDE should work");

        assert!(result.fun < 1e-4, "jDE sphere: got {}", result.fun);
    }

    #[test]
    fn test_jde_rosenbrock() {
        let opts = JdeOptions {
            base: DeOptions {
                pop_size: 40,
                max_generations: 2000,
                bounds: vec![(-5.0, 5.0); 2],
                seed: Some(42),
                patience: 500,
                ..Default::default()
            },
            ..Default::default()
        };

        let result = jde_optimize(rosenbrock, opts).expect("jDE on rosenbrock");

        assert!(result.fun < 1.0, "jDE rosenbrock: got {}", result.fun);
    }

    // --- Constraint handling tests ---

    #[test]
    fn test_de_with_constraints() {
        // Minimize x^2 + y^2 subject to x + y >= 2
        let constraints = |x: &ArrayView1<f64>| -> Vec<f64> {
            vec![2.0 - (x[0] + x[1])] // violation when x+y < 2
        };

        let opts = DeOptions {
            pop_size: 40,
            max_generations: 500,
            bounds: vec![(-5.0, 5.0); 2],
            seed: Some(42),
            patience: 200,
            constraint_handler: Some(DeConstraintHandler {
                penalty_coeff: 1e4,
                use_feasibility_rules: true,
            }),
            ..Default::default()
        };

        let mut de = DifferentialEvolutionOptimizer::new(opts).expect("valid options");
        let result = de
            .optimize_constrained(sphere, Some(constraints))
            .expect("constrained DE should work");

        // Optimal: x = y = 1, f = 2
        let sum = result.x[0] + result.x[1];
        assert!(sum >= 1.5, "Constraint should be ~satisfied: sum = {}", sum);
        assert!(result.fun < 5.0, "Constrained DE fun: {}", result.fun);
    }

    // --- Opposition-based learning ---

    #[test]
    fn test_de_opposition_based_init() {
        let opts = DeOptions {
            pop_size: 30,
            max_generations: 300,
            bounds: vec![(-5.0, 5.0); 2],
            seed: Some(42),
            opposition: OppositionBasedInit {
                enabled: true,
                jumping_rate: 0.3,
            },
            patience: 150,
            ..Default::default()
        };

        let mut de = DifferentialEvolutionOptimizer::new(opts).expect("valid options");
        let result = de.optimize(sphere).expect("OBL DE should work");

        assert!(result.fun < 1.0, "OBL DE sphere: got {}", result.fun);
    }

    #[test]
    fn test_de_no_opposition() {
        let opts = DeOptions {
            pop_size: 30,
            max_generations: 300,
            bounds: vec![(-5.0, 5.0); 2],
            seed: Some(42),
            opposition: OppositionBasedInit {
                enabled: false,
                jumping_rate: 0.0,
            },
            patience: 150,
            ..Default::default()
        };

        let mut de = DifferentialEvolutionOptimizer::new(opts).expect("valid options");
        let result = de.optimize(sphere).expect("DE without OBL should work");

        assert!(result.fun < 1.0, "DE no-OBL sphere: got {}", result.fun);
    }

    // --- Helper tests ---

    #[test]
    fn test_pick_three_distinct_rng_properties() {
        // Every draw must be in range, pairwise distinct, and never the excluded index,
        // including the tightest case np = 4 where only three candidates remain.
        let mut gen = StdRng::seed_from_u64(7);
        for np in 4..12 {
            for exclude in 0..np {
                let (a, b, c) = pick_three_distinct_rng(&mut gen, np, exclude);
                for idx in [a, b, c] {
                    assert!(idx < np, "index {} out of range for np = {}", idx, np);
                    assert_ne!(idx, exclude, "excluded index {} was returned", exclude);
                }
                assert!(
                    a != b && a != c && b != c,
                    "indices not distinct: ({}, {}, {})",
                    a,
                    b,
                    c
                );
            }
        }
    }

    // --- Edge cases ---

    #[test]
    fn test_de_empty_bounds_error() {
        let opts = DeOptions {
            bounds: vec![],
            ..Default::default()
        };
        let result = DifferentialEvolutionOptimizer::new(opts);
        assert!(result.is_err());
    }

    #[test]
    fn test_de_small_popsize_error() {
        let opts = DeOptions {
            pop_size: 2,
            bounds: vec![(-1.0, 1.0)],
            ..Default::default()
        };
        let result = DifferentialEvolutionOptimizer::new(opts);
        assert!(result.is_err());
    }

    #[test]
    fn test_de_invalid_mutation_error() {
        let opts = DeOptions {
            mutation_factor: 3.0,
            bounds: vec![(-1.0, 1.0)],
            ..Default::default()
        };
        let result = DifferentialEvolutionOptimizer::new(opts);
        assert!(result.is_err());
    }

    #[test]
    fn test_de_invalid_crossover_error() {
        let opts = DeOptions {
            crossover_prob: 1.5,
            bounds: vec![(-1.0, 1.0)],
            ..Default::default()
        };
        let result = DifferentialEvolutionOptimizer::new(opts);
        assert!(result.is_err());
    }

    #[test]
    fn test_de_to_optimize_results() {
        let de_result = DeResult {
            x: array![1.0, 2.0],
            fun: 5.0,
            nfev: 1000,
            generations: 50,
            converged: true,
            message: "test".to_string(),
            population_spread: 0.01,
        };
        let opt = de_result.to_optimize_results();
        assert_eq!(opt.nfev, 1000);
        assert_eq!(opt.nit, 50);
        assert!(opt.success);
    }

    #[test]
    fn test_jde_empty_bounds_error() {
        let opts = JdeOptions {
            base: DeOptions {
                bounds: vec![],
                ..Default::default()
            },
            ..Default::default()
        };
        let result = jde_optimize(sphere, opts);
        assert!(result.is_err());
    }

    #[test]
    fn test_de_1d() {
        let opts = DeOptions {
            pop_size: 20,
            max_generations: 200,
            bounds: vec![(-10.0, 10.0)],
            seed: Some(42),
            patience: 100,
            ..Default::default()
        };

        let mut de = DifferentialEvolutionOptimizer::new(opts).expect("valid");
        let result = de
            .optimize(|x: &ArrayView1<f64>| (x[0] - 3.0).powi(2))
            .expect("1D DE");

        assert!(
            (result.x[0] - 3.0).abs() < 0.5,
            "1D DE: x = {}",
            result.x[0]
        );
    }
}