sklears_multioutput/optimization/
nsga2_algorithms.rs

1//! Enhanced Evolutionary Multi-Objective Optimization Algorithms
2//!
3//! This module provides NSGA-II (Non-dominated Sorting Genetic Algorithm II) and related
4//! evolutionary algorithms for multi-objective optimization. NSGA-II is one of the most
5//! popular and effective multi-objective evolutionary algorithms.
6//!
7//! ## Key Features
8//!
9//! - **Non-dominated Sorting**: Fast and efficient ranking of solutions based on Pareto dominance
10//! - **Crowding Distance**: Maintains diversity in the population and Pareto front
11//! - **Multiple Algorithm Variants**: Standard NSGA-II, SBX crossover, and differential evolution
12//! - **Advanced Operators**: Simulated binary crossover (SBX) and polynomial mutation
13//! - **Elitism**: Preserves good solutions across generations
14//! - **Hypervolume Tracking**: Monitors convergence quality over generations
15
16// Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
17use scirs2_core::ndarray::{array, s, Array1, Array2, ArrayView2};
18use scirs2_core::random::RandNormal;
19use scirs2_core::random::Rng;
20use sklears_core::{
21    error::{Result as SklResult, SklearsError},
22    traits::{Estimator, Fit, Predict, Untrained},
23    types::Float,
24};
25
26use super::multi_objective_optimization::ParetoSolution;
27
/// Variant of NSGA-II (Non-dominated Sorting Genetic Algorithm II) to run.
///
/// The variant decides which crossover and mutation operators are used when a new
/// generation is bred. The type is a plain tag, so it is `Copy` and can be used as a
/// hash-map key or compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NSGA2Algorithm {
    /// Standard NSGA-II (uniform crossover with Gaussian mutation in this module)
    Standard,
    /// NSGA-II with simulated binary crossover (paired with polynomial mutation)
    SBX,
    /// NSGA-II with differential evolution
    ///
    /// NOTE(review): the offspring generator in this module currently breeds this
    /// variant with the same operators as `Standard`; no DE-specific operator exists yet.
    DE,
}
38
39/// NSGA-II Configuration
40#[derive(Debug, Clone)]
41pub struct NSGA2Config {
42    /// Population size
43    pub population_size: usize,
44    /// Number of generations
45    pub generations: usize,
46    /// Crossover probability
47    pub crossover_prob: Float,
48    /// Mutation probability
49    pub mutation_prob: Float,
50    /// Distribution index for SBX crossover
51    pub eta_c: Float,
52    /// Distribution index for polynomial mutation
53    pub eta_m: Float,
54    /// Algorithm variant
55    pub algorithm: NSGA2Algorithm,
56    /// Random state for reproducibility
57    pub random_state: Option<u64>,
58}
59
60impl Default for NSGA2Config {
61    fn default() -> Self {
62        Self {
63            population_size: 100,
64            generations: 250,
65            crossover_prob: 0.9,
66            mutation_prob: 0.1,
67            eta_c: 20.0,
68            eta_m: 20.0,
69            algorithm: NSGA2Algorithm::Standard,
70            random_state: None,
71        }
72    }
73}
74
75/// NSGA-II Multi-Objective Optimizer
76#[derive(Debug, Clone)]
77pub struct NSGA2Optimizer<S = Untrained> {
78    state: S,
79    config: NSGA2Config,
80}
81
82/// Trained state for NSGA-II Optimizer
83#[derive(Debug, Clone)]
84pub struct NSGA2OptimizerTrained {
85    /// Pareto-optimal solutions
86    pub pareto_solutions: Vec<ParetoSolution>,
87    /// Best compromise solution
88    pub best_solution: ParetoSolution,
89    /// Convergence history (hypervolume indicator)
90    pub convergence_history: Vec<Float>,
91    /// Final population
92    pub final_population: Vec<ParetoSolution>,
93    /// Configuration used for optimization
94    pub config: NSGA2Config,
95    /// Number of objectives
96    pub n_objectives: usize,
97}
98
99impl Default for NSGA2Optimizer<Untrained> {
100    fn default() -> Self {
101        Self::new()
102    }
103}
104
105impl NSGA2Optimizer<Untrained> {
106    /// Create a new NSGA-II Optimizer
107    pub fn new() -> Self {
108        Self {
109            state: Untrained,
110            config: NSGA2Config::default(),
111        }
112    }
113
114    /// Set the configuration
115    pub fn config(mut self, config: NSGA2Config) -> Self {
116        self.config = config;
117        self
118    }
119
120    /// Set the population size
121    pub fn population_size(mut self, population_size: usize) -> Self {
122        self.config.population_size = population_size;
123        self
124    }
125
126    /// Set the number of generations
127    pub fn generations(mut self, generations: usize) -> Self {
128        self.config.generations = generations;
129        self
130    }
131
132    /// Set the crossover probability
133    pub fn crossover_prob(mut self, crossover_prob: Float) -> Self {
134        self.config.crossover_prob = crossover_prob;
135        self
136    }
137
138    /// Set the mutation probability
139    pub fn mutation_prob(mut self, mutation_prob: Float) -> Self {
140        self.config.mutation_prob = mutation_prob;
141        self
142    }
143
144    /// Set the algorithm variant
145    pub fn algorithm(mut self, algorithm: NSGA2Algorithm) -> Self {
146        self.config.algorithm = algorithm;
147        self
148    }
149}
150
151impl Estimator for NSGA2Optimizer<Untrained> {
152    type Config = NSGA2Config;
153    type Error = SklearsError;
154    type Float = Float;
155
156    fn config(&self) -> &Self::Config {
157        &self.config
158    }
159}
160
161impl Estimator for NSGA2Optimizer<NSGA2OptimizerTrained> {
162    type Config = NSGA2Config;
163    type Error = SklearsError;
164    type Float = Float;
165
166    fn config(&self) -> &Self::Config {
167        &self.state.config
168    }
169}
170
171impl Fit<ArrayView2<'_, Float>, ArrayView2<'_, Float>> for NSGA2Optimizer<Untrained> {
172    type Fitted = NSGA2Optimizer<NSGA2OptimizerTrained>;
173
174    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView2<'_, Float>) -> SklResult<Self::Fitted> {
175        if X.nrows() != y.nrows() {
176            return Err(SklearsError::InvalidInput(
177                "X and y must have the same number of samples".to_string(),
178            ));
179        }
180
181        let n_samples = X.nrows();
182        let n_features = X.ncols();
183        let n_outputs = y.ncols();
184        let n_objectives = 2; // Default: accuracy and complexity
185
186        let mut rng = match self.config.random_state {
187            Some(seed) => scirs2_core::random::seeded_rng(seed),
188            None => scirs2_core::random::seeded_rng(42), // Default seed
189        };
190
191        // Initialize population
192        let mut population = self.initialize_population(n_features, n_outputs, &mut rng)?;
193
194        // Evaluate initial population
195        self.evaluate_population(&mut population, X, y)?;
196
197        let mut convergence_history = Vec::new();
198
199        // Main evolutionary loop
200        for generation in 0..self.config.generations {
201            // Non-dominated sorting
202            self.nsga2_non_dominated_sort(&mut population)?;
203
204            // Calculate crowding distance
205            self.nsga2_crowding_distance(&mut population)?;
206
207            // Calculate hypervolume for convergence tracking
208            let hypervolume = self.calculate_hypervolume(&population)?;
209            convergence_history.push(hypervolume);
210
211            // Generate offspring population
212            let mut offspring = self.nsga2_generate_offspring(&population, &mut rng)?;
213
214            // Evaluate offspring
215            self.evaluate_population(&mut offspring, X, y)?;
216
217            // Combine parent and offspring populations
218            population.extend(offspring);
219
220            // Environmental selection
221            population = self.nsga2_environmental_selection(population)?;
222
223            if generation % 50 == 0 {
224                println!(
225                    "Generation {}: Hypervolume = {:.6}",
226                    generation, hypervolume
227                );
228            }
229        }
230
231        // Final evaluation
232        self.nsga2_non_dominated_sort(&mut population)?;
233        let pareto_solutions = self.extract_pareto_front(&population)?;
234        let best_solution = self.find_best_compromise(&pareto_solutions)?;
235
236        Ok(NSGA2Optimizer {
237            state: NSGA2OptimizerTrained {
238                pareto_solutions: pareto_solutions.clone(),
239                best_solution,
240                convergence_history,
241                final_population: population,
242                config: self.config.clone(),
243                n_objectives,
244            },
245            config: self.config,
246        })
247    }
248}
249
250impl NSGA2Optimizer<Untrained> {
251    /// Initialize population for NSGA-II
252    fn initialize_population<R: Rng>(
253        &self,
254        n_features: usize,
255        n_outputs: usize,
256        rng: &mut R,
257    ) -> SklResult<Vec<ParetoSolution>> {
258        let mut population = Vec::with_capacity(self.config.population_size);
259        let param_size = n_features * n_outputs + n_outputs; // weights + bias
260
261        for _ in 0..self.config.population_size {
262            let parameters = Array1::from_shape_fn(param_size, |_| rng.gen_range(-1.0..1.0));
263            let solution = ParetoSolution {
264                parameters,
265                objectives: Array1::zeros(2), // Will be filled during evaluation
266                rank: 0,
267                crowding_distance: 0.0,
268            };
269            population.push(solution);
270        }
271
272        Ok(population)
273    }
274
275    /// NSGA-II Non-dominated sorting
276    fn nsga2_non_dominated_sort(&self, population: &mut Vec<ParetoSolution>) -> SklResult<()> {
277        let n = population.len();
278        let mut domination_counts = vec![0; n];
279        let mut dominated_solutions = vec![Vec::new(); n];
280        let mut fronts: Vec<Vec<usize>> = Vec::new();
281
282        // Calculate domination relationships
283        for i in 0..n {
284            for j in 0..n {
285                if i != j {
286                    if self.nsga2_dominates(&population[i], &population[j]) {
287                        dominated_solutions[i].push(j);
288                    } else if self.nsga2_dominates(&population[j], &population[i]) {
289                        domination_counts[i] += 1;
290                    }
291                }
292            }
293        }
294
295        // First front
296        let mut current_front: Vec<usize> = (0..n).filter(|&i| domination_counts[i] == 0).collect();
297        let mut rank = 0;
298
299        while !current_front.is_empty() {
300            // Assign rank to current front
301            for &i in &current_front {
302                population[i].rank = rank;
303            }
304
305            fronts.push(current_front.clone());
306
307            // Generate next front
308            let mut next_front = Vec::new();
309            for &i in &current_front {
310                for &j in &dominated_solutions[i] {
311                    domination_counts[j] -= 1;
312                    if domination_counts[j] == 0 {
313                        next_front.push(j);
314                    }
315                }
316            }
317
318            current_front = next_front;
319            rank += 1;
320        }
321
322        Ok(())
323    }
324
325    /// Check if solution a dominates solution b for NSGA-II
326    pub fn nsga2_dominates(&self, a: &ParetoSolution, b: &ParetoSolution) -> bool {
327        let mut at_least_one_better = false;
328
329        for i in 0..a.objectives.len() {
330            if a.objectives[i] > b.objectives[i] {
331                return false; // b is better in at least one objective
332            }
333            if a.objectives[i] < b.objectives[i] {
334                at_least_one_better = true;
335            }
336        }
337
338        at_least_one_better
339    }
340
341    /// Calculate crowding distance for NSGA-II
342    fn nsga2_crowding_distance(&self, population: &mut Vec<ParetoSolution>) -> SklResult<()> {
343        let n = population.len();
344        if n == 0 {
345            return Ok(());
346        }
347
348        // Initialize crowding distances
349        for solution in population.iter_mut() {
350            solution.crowding_distance = 0.0;
351        }
352
353        let n_objectives = population[0].objectives.len();
354
355        for obj_idx in 0..n_objectives {
356            // Sort by objective value
357            let mut indices: Vec<usize> = (0..n).collect();
358            indices.sort_by(|&a, &b| {
359                population[a].objectives[obj_idx]
360                    .partial_cmp(&population[b].objectives[obj_idx])
361                    .unwrap_or(std::cmp::Ordering::Equal)
362            });
363
364            // Set boundary points to infinite distance
365            population[indices[0]].crowding_distance = Float::INFINITY;
366            population[indices[n - 1]].crowding_distance = Float::INFINITY;
367
368            // Calculate distances for intermediate points
369            let obj_range = population[indices[n - 1]].objectives[obj_idx]
370                - population[indices[0]].objectives[obj_idx];
371
372            if obj_range > 0.0 {
373                for i in 1..(n - 1) {
374                    let distance = (population[indices[i + 1]].objectives[obj_idx]
375                        - population[indices[i - 1]].objectives[obj_idx])
376                        / obj_range;
377                    population[indices[i]].crowding_distance += distance;
378                }
379            }
380        }
381
382        Ok(())
383    }
384
385    /// Generate offspring population using NSGA-II
386    fn nsga2_generate_offspring<R: Rng>(
387        &self,
388        population: &[ParetoSolution],
389        rng: &mut R,
390    ) -> SklResult<Vec<ParetoSolution>> {
391        let mut offspring = Vec::new();
392
393        for _ in 0..self.config.population_size {
394            // Binary tournament selection
395            let parent1 = self.nsga2_tournament_selection(population, rng)?;
396            let parent2 = self.nsga2_tournament_selection(population, rng)?;
397
398            // Crossover
399            let mut child = match self.config.algorithm {
400                NSGA2Algorithm::SBX => self.simulated_binary_crossover(&parent1, &parent2, rng)?,
401                _ => self.uniform_crossover(&parent1, &parent2, rng)?,
402            };
403
404            // Mutation
405            match self.config.algorithm {
406                NSGA2Algorithm::SBX => self.polynomial_mutation(&mut child, rng)?,
407                _ => self.gaussian_mutation(&mut child, rng)?,
408            }
409
410            offspring.push(child);
411        }
412
413        Ok(offspring)
414    }
415
416    /// Binary tournament selection for NSGA-II
417    fn nsga2_tournament_selection<R: Rng>(
418        &self,
419        population: &[ParetoSolution],
420        rng: &mut R,
421    ) -> SklResult<ParetoSolution> {
422        let idx1 = rng.gen_range(0..population.len());
423        let idx2 = rng.gen_range(0..population.len());
424
425        let solution1 = &population[idx1];
426        let solution2 = &population[idx2];
427
428        // Compare by rank first, then by crowding distance
429        if solution1.rank < solution2.rank {
430            Ok(solution1.clone())
431        } else if solution1.rank > solution2.rank {
432            Ok(solution2.clone())
433        } else {
434            // Same rank, compare by crowding distance (higher is better)
435            if solution1.crowding_distance > solution2.crowding_distance {
436                Ok(solution1.clone())
437            } else {
438                Ok(solution2.clone())
439            }
440        }
441    }
442
443    /// Simulated Binary Crossover (SBX)
444    fn simulated_binary_crossover<R: Rng>(
445        &self,
446        parent1: &ParetoSolution,
447        parent2: &ParetoSolution,
448        rng: &mut R,
449    ) -> SklResult<ParetoSolution> {
450        let mut child_params = parent1.parameters.clone();
451
452        if rng.random::<Float>() <= self.config.crossover_prob {
453            for i in 0..child_params.len() {
454                let p1 = parent1.parameters[i];
455                let p2 = parent2.parameters[i];
456
457                if rng.random::<Float>() <= 0.5 {
458                    let u = rng.random::<Float>();
459                    let beta = if u <= 0.5 {
460                        (2.0 * u).powf(1.0 / (self.config.eta_c + 1.0))
461                    } else {
462                        (1.0 / (2.0 * (1.0 - u))).powf(1.0 / (self.config.eta_c + 1.0))
463                    };
464
465                    let child_val = 0.5 * ((1.0 + beta) * p1 + (1.0 - beta) * p2);
466                    child_params[i] = child_val.clamp(-2.0, 2.0);
467                }
468            }
469        }
470
471        Ok(ParetoSolution {
472            parameters: child_params,
473            objectives: Array1::zeros(parent1.objectives.len()),
474            rank: 0,
475            crowding_distance: 0.0,
476        })
477    }
478
479    /// Polynomial mutation
480    fn polynomial_mutation<R: Rng>(
481        &self,
482        solution: &mut ParetoSolution,
483        rng: &mut R,
484    ) -> SklResult<()> {
485        for i in 0..solution.parameters.len() {
486            if rng.random::<Float>() <= self.config.mutation_prob {
487                let u = rng.random::<Float>();
488                let delta = if u < 0.5 {
489                    (2.0 * u).powf(1.0 / (self.config.eta_m + 1.0)) - 1.0
490                } else {
491                    1.0 - (2.0 * (1.0 - u)).powf(1.0 / (self.config.eta_m + 1.0))
492                };
493
494                solution.parameters[i] += delta * 0.1;
495                solution.parameters[i] = solution.parameters[i].clamp(-2.0, 2.0);
496            }
497        }
498        Ok(())
499    }
500
501    /// Environmental selection for NSGA-II
502    fn nsga2_environmental_selection(
503        &self,
504        mut population: Vec<ParetoSolution>,
505    ) -> SklResult<Vec<ParetoSolution>> {
506        // Sort by rank and crowding distance
507        self.nsga2_non_dominated_sort(&mut population)?;
508        self.nsga2_crowding_distance(&mut population)?;
509
510        // Sort population by rank, then by crowding distance
511        population.sort_by(|a, b| {
512            match a.rank.cmp(&b.rank) {
513                std::cmp::Ordering::Equal => {
514                    // Higher crowding distance is better
515                    b.crowding_distance
516                        .partial_cmp(&a.crowding_distance)
517                        .unwrap_or(std::cmp::Ordering::Equal)
518                }
519                other => other,
520            }
521        });
522
523        // Take the best individuals up to population size
524        population.truncate(self.config.population_size);
525        Ok(population)
526    }
527
528    /// Extract Pareto front (rank 0 solutions)
529    fn extract_pareto_front(
530        &self,
531        population: &[ParetoSolution],
532    ) -> SklResult<Vec<ParetoSolution>> {
533        Ok(population
534            .iter()
535            .filter(|sol| sol.rank == 0)
536            .cloned()
537            .collect())
538    }
539
540    /// Evaluate population fitness
541    fn evaluate_population(
542        &self,
543        population: &mut [ParetoSolution],
544        X: &ArrayView2<Float>,
545        y: &ArrayView2<Float>,
546    ) -> SklResult<()> {
547        let n_features = X.ncols();
548        let n_outputs = y.ncols();
549
550        for solution in population.iter_mut() {
551            // Extract weights and bias from parameters
552            let weights = solution
553                .parameters
554                .slice(s![..n_features * n_outputs])
555                .to_owned()
556                .into_shape((n_features, n_outputs))
557                .map_err(|e| SklearsError::InvalidInput(format!("Shape error: {}", e)))?;
558
559            let bias = solution
560                .parameters
561                .slice(s![n_features * n_outputs..])
562                .to_owned();
563
564            // Make predictions
565            let mut predictions = X.dot(&weights);
566            for mut row in predictions.rows_mut() {
567                row += &bias;
568            }
569
570            // Calculate objectives
571            let mse = self.calculate_mse(&predictions.view(), y)?;
572            let complexity = self.calculate_complexity(&weights)?;
573
574            solution.objectives = array![mse, complexity];
575        }
576
577        Ok(())
578    }
579
580    /// Calculate Mean Squared Error
581    fn calculate_mse(
582        &self,
583        predictions: &ArrayView2<Float>,
584        y: &ArrayView2<Float>,
585    ) -> SklResult<Float> {
586        let diff = predictions - y;
587        let squared_diff = &diff * &diff;
588        Ok(squared_diff.sum() / (predictions.nrows() * predictions.ncols()) as Float)
589    }
590
591    /// Calculate model complexity (based on parameter magnitudes)
592    fn calculate_complexity(&self, weights: &Array2<Float>) -> SklResult<Float> {
593        Ok(weights.mapv(|x| x.abs()).sum())
594    }
595
596    /// Uniform crossover operation
597    fn uniform_crossover<R: Rng>(
598        &self,
599        parent1: &ParetoSolution,
600        parent2: &ParetoSolution,
601        rng: &mut R,
602    ) -> SklResult<ParetoSolution> {
603        let mut child_params = parent1.parameters.clone();
604
605        if rng.random::<Float>() <= self.config.crossover_prob {
606            for i in 0..child_params.len() {
607                if rng.random::<Float>() <= 0.5 {
608                    child_params[i] = parent2.parameters[i];
609                }
610            }
611        }
612
613        Ok(ParetoSolution {
614            parameters: child_params,
615            objectives: Array1::zeros(parent1.objectives.len()),
616            rank: 0,
617            crowding_distance: 0.0,
618        })
619    }
620
621    /// Gaussian mutation operation
622    fn gaussian_mutation<R: Rng>(
623        &self,
624        solution: &mut ParetoSolution,
625        rng: &mut R,
626    ) -> SklResult<()> {
627        for i in 0..solution.parameters.len() {
628            if rng.random::<Float>() <= self.config.mutation_prob {
629                let normal = RandNormal::new(0.0, 0.1).map_err(|e| {
630                    SklearsError::InvalidInput(format!(
631                        "Failed to create normal distribution: {}",
632                        e
633                    ))
634                })?;
635                let mutation = rng.sample(normal);
636                solution.parameters[i] += mutation;
637                solution.parameters[i] = solution.parameters[i].clamp(-2.0, 2.0);
638            }
639        }
640        Ok(())
641    }
642
643    /// Calculate hypervolume indicator
644    fn calculate_hypervolume(&self, population: &[ParetoSolution]) -> SklResult<Float> {
645        // Extract non-dominated solutions (Pareto front)
646        let pareto_front: Vec<&ParetoSolution> =
647            population.iter().filter(|sol| sol.rank == 0).collect();
648
649        if pareto_front.is_empty() {
650            return Ok(0.0);
651        }
652
653        // Simple hypervolume calculation using reference point (1.0, 1.0)
654        let reference_point = array![1.0, 1.0];
655        let mut hypervolume = 0.0;
656
657        for solution in &pareto_front {
658            let mut volume = 1.0;
659            for i in 0..solution.objectives.len() {
660                let contribution = (reference_point[i] - solution.objectives[i]).max(0.0);
661                volume *= contribution;
662            }
663            hypervolume += volume;
664        }
665
666        Ok(hypervolume / pareto_front.len() as Float)
667    }
668
669    /// Find best compromise solution from Pareto solutions
670    fn find_best_compromise(
671        &self,
672        pareto_solutions: &[ParetoSolution],
673    ) -> SklResult<ParetoSolution> {
674        if pareto_solutions.is_empty() {
675            return Err(SklearsError::InvalidInput(
676                "No Pareto solutions available".to_string(),
677            ));
678        }
679
680        let mut best_solution = pareto_solutions[0].clone();
681        let mut best_distance = Float::INFINITY;
682
683        // Find solution closest to ideal point (0, 0)
684        for solution in pareto_solutions {
685            let distance = solution.objectives.mapv(|x| x * x).sum().sqrt();
686            if distance < best_distance {
687                best_distance = distance;
688                best_solution = solution.clone();
689            }
690        }
691
692        Ok(best_solution)
693    }
694}
695
696impl Predict<ArrayView2<'_, Float>, Array2<Float>> for NSGA2Optimizer<NSGA2OptimizerTrained> {
697    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
698        let n_samples = X.nrows();
699        let n_features = X.ncols();
700        let n_outputs = self.state.best_solution.parameters.len() / (n_features + 1);
701
702        // Extract weights and bias from best solution
703        let weights = self
704            .state
705            .best_solution
706            .parameters
707            .slice(s![..n_features * n_outputs])
708            .to_owned()
709            .into_shape((n_features, n_outputs))
710            .map_err(|e| SklearsError::InvalidInput(format!("Shape error: {}", e)))?;
711
712        let bias = self
713            .state
714            .best_solution
715            .parameters
716            .slice(s![n_features * n_outputs..])
717            .to_owned();
718
719        // Make predictions: y = X * W + b
720        let mut predictions = X.dot(&weights);
721        for mut row in predictions.rows_mut() {
722            row += &bias;
723        }
724
725        Ok(predictions)
726    }
727}
728
729impl NSGA2Optimizer<NSGA2OptimizerTrained> {
730    /// Get the Pareto-optimal solutions
731    pub fn pareto_solutions(&self) -> &[ParetoSolution] {
732        &self.state.pareto_solutions
733    }
734
735    /// Get the best compromise solution
736    pub fn best_solution(&self) -> &ParetoSolution {
737        &self.state.best_solution
738    }
739
740    /// Get the convergence history
741    pub fn convergence_history(&self) -> &[Float] {
742        &self.state.convergence_history
743    }
744
745    /// Get the final population
746    pub fn final_population(&self) -> &[ParetoSolution] {
747        &self.state.final_population
748    }
749}