1use scirs2_core::ndarray::{array, s, Array1, Array2, ArrayView2};
18use scirs2_core::random::RandNormal;
19use scirs2_core::random::Rng;
20use sklears_core::{
21 error::{Result as SklResult, SklearsError},
22 traits::{Estimator, Fit, Predict, Untrained},
23 types::Float,
24};
25
26use super::multi_objective_optimization::ParetoSolution;
27
/// Variation-operator scheme the optimizer uses when breeding offspring.
///
/// The enum is fieldless, so it is cheap to copy and can be compared or hashed
/// directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NSGA2Algorithm {
    /// Uniform crossover followed by Gaussian mutation.
    Standard,
    /// Simulated binary crossover (SBX) followed by polynomial mutation.
    SBX,
    /// Differential-evolution variant.
    ///
    /// The optimizer in this file has no dedicated DE operators: offspring are
    /// bred exactly as for [`NSGA2Algorithm::Standard`].
    DE,
}
38
39#[derive(Debug, Clone)]
41pub struct NSGA2Config {
42 pub population_size: usize,
44 pub generations: usize,
46 pub crossover_prob: Float,
48 pub mutation_prob: Float,
50 pub eta_c: Float,
52 pub eta_m: Float,
54 pub algorithm: NSGA2Algorithm,
56 pub random_state: Option<u64>,
58}
59
60impl Default for NSGA2Config {
61 fn default() -> Self {
62 Self {
63 population_size: 100,
64 generations: 250,
65 crossover_prob: 0.9,
66 mutation_prob: 0.1,
67 eta_c: 20.0,
68 eta_m: 20.0,
69 algorithm: NSGA2Algorithm::Standard,
70 random_state: None,
71 }
72 }
73}
74
75#[derive(Debug, Clone)]
77pub struct NSGA2Optimizer<S = Untrained> {
78 state: S,
79 config: NSGA2Config,
80}
81
82#[derive(Debug, Clone)]
84pub struct NSGA2OptimizerTrained {
85 pub pareto_solutions: Vec<ParetoSolution>,
87 pub best_solution: ParetoSolution,
89 pub convergence_history: Vec<Float>,
91 pub final_population: Vec<ParetoSolution>,
93 pub config: NSGA2Config,
95 pub n_objectives: usize,
97}
98
99impl Default for NSGA2Optimizer<Untrained> {
100 fn default() -> Self {
101 Self::new()
102 }
103}
104
105impl NSGA2Optimizer<Untrained> {
106 pub fn new() -> Self {
108 Self {
109 state: Untrained,
110 config: NSGA2Config::default(),
111 }
112 }
113
114 pub fn config(mut self, config: NSGA2Config) -> Self {
116 self.config = config;
117 self
118 }
119
120 pub fn population_size(mut self, population_size: usize) -> Self {
122 self.config.population_size = population_size;
123 self
124 }
125
126 pub fn generations(mut self, generations: usize) -> Self {
128 self.config.generations = generations;
129 self
130 }
131
132 pub fn crossover_prob(mut self, crossover_prob: Float) -> Self {
134 self.config.crossover_prob = crossover_prob;
135 self
136 }
137
138 pub fn mutation_prob(mut self, mutation_prob: Float) -> Self {
140 self.config.mutation_prob = mutation_prob;
141 self
142 }
143
144 pub fn algorithm(mut self, algorithm: NSGA2Algorithm) -> Self {
146 self.config.algorithm = algorithm;
147 self
148 }
149}
150
151impl Estimator for NSGA2Optimizer<Untrained> {
152 type Config = NSGA2Config;
153 type Error = SklearsError;
154 type Float = Float;
155
156 fn config(&self) -> &Self::Config {
157 &self.config
158 }
159}
160
161impl Estimator for NSGA2Optimizer<NSGA2OptimizerTrained> {
162 type Config = NSGA2Config;
163 type Error = SklearsError;
164 type Float = Float;
165
166 fn config(&self) -> &Self::Config {
167 &self.state.config
168 }
169}
170
171impl Fit<ArrayView2<'_, Float>, ArrayView2<'_, Float>> for NSGA2Optimizer<Untrained> {
172 type Fitted = NSGA2Optimizer<NSGA2OptimizerTrained>;
173
174 fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView2<'_, Float>) -> SklResult<Self::Fitted> {
175 if X.nrows() != y.nrows() {
176 return Err(SklearsError::InvalidInput(
177 "X and y must have the same number of samples".to_string(),
178 ));
179 }
180
181 let n_samples = X.nrows();
182 let n_features = X.ncols();
183 let n_outputs = y.ncols();
184 let n_objectives = 2; let mut rng = match self.config.random_state {
187 Some(seed) => scirs2_core::random::seeded_rng(seed),
188 None => scirs2_core::random::seeded_rng(42), };
190
191 let mut population = self.initialize_population(n_features, n_outputs, &mut rng)?;
193
194 self.evaluate_population(&mut population, X, y)?;
196
197 let mut convergence_history = Vec::new();
198
199 for generation in 0..self.config.generations {
201 self.nsga2_non_dominated_sort(&mut population)?;
203
204 self.nsga2_crowding_distance(&mut population)?;
206
207 let hypervolume = self.calculate_hypervolume(&population)?;
209 convergence_history.push(hypervolume);
210
211 let mut offspring = self.nsga2_generate_offspring(&population, &mut rng)?;
213
214 self.evaluate_population(&mut offspring, X, y)?;
216
217 population.extend(offspring);
219
220 population = self.nsga2_environmental_selection(population)?;
222
223 if generation % 50 == 0 {
224 println!(
225 "Generation {}: Hypervolume = {:.6}",
226 generation, hypervolume
227 );
228 }
229 }
230
231 self.nsga2_non_dominated_sort(&mut population)?;
233 let pareto_solutions = self.extract_pareto_front(&population)?;
234 let best_solution = self.find_best_compromise(&pareto_solutions)?;
235
236 Ok(NSGA2Optimizer {
237 state: NSGA2OptimizerTrained {
238 pareto_solutions: pareto_solutions.clone(),
239 best_solution,
240 convergence_history,
241 final_population: population,
242 config: self.config.clone(),
243 n_objectives,
244 },
245 config: self.config,
246 })
247 }
248}
249
250impl NSGA2Optimizer<Untrained> {
251 fn initialize_population<R: Rng>(
253 &self,
254 n_features: usize,
255 n_outputs: usize,
256 rng: &mut R,
257 ) -> SklResult<Vec<ParetoSolution>> {
258 let mut population = Vec::with_capacity(self.config.population_size);
259 let param_size = n_features * n_outputs + n_outputs; for _ in 0..self.config.population_size {
262 let parameters = Array1::from_shape_fn(param_size, |_| rng.gen_range(-1.0..1.0));
263 let solution = ParetoSolution {
264 parameters,
265 objectives: Array1::zeros(2), rank: 0,
267 crowding_distance: 0.0,
268 };
269 population.push(solution);
270 }
271
272 Ok(population)
273 }
274
275 fn nsga2_non_dominated_sort(&self, population: &mut Vec<ParetoSolution>) -> SklResult<()> {
277 let n = population.len();
278 let mut domination_counts = vec![0; n];
279 let mut dominated_solutions = vec![Vec::new(); n];
280 let mut fronts: Vec<Vec<usize>> = Vec::new();
281
282 for i in 0..n {
284 for j in 0..n {
285 if i != j {
286 if self.nsga2_dominates(&population[i], &population[j]) {
287 dominated_solutions[i].push(j);
288 } else if self.nsga2_dominates(&population[j], &population[i]) {
289 domination_counts[i] += 1;
290 }
291 }
292 }
293 }
294
295 let mut current_front: Vec<usize> = (0..n).filter(|&i| domination_counts[i] == 0).collect();
297 let mut rank = 0;
298
299 while !current_front.is_empty() {
300 for &i in ¤t_front {
302 population[i].rank = rank;
303 }
304
305 fronts.push(current_front.clone());
306
307 let mut next_front = Vec::new();
309 for &i in ¤t_front {
310 for &j in &dominated_solutions[i] {
311 domination_counts[j] -= 1;
312 if domination_counts[j] == 0 {
313 next_front.push(j);
314 }
315 }
316 }
317
318 current_front = next_front;
319 rank += 1;
320 }
321
322 Ok(())
323 }
324
325 pub fn nsga2_dominates(&self, a: &ParetoSolution, b: &ParetoSolution) -> bool {
327 let mut at_least_one_better = false;
328
329 for i in 0..a.objectives.len() {
330 if a.objectives[i] > b.objectives[i] {
331 return false; }
333 if a.objectives[i] < b.objectives[i] {
334 at_least_one_better = true;
335 }
336 }
337
338 at_least_one_better
339 }
340
341 fn nsga2_crowding_distance(&self, population: &mut Vec<ParetoSolution>) -> SklResult<()> {
343 let n = population.len();
344 if n == 0 {
345 return Ok(());
346 }
347
348 for solution in population.iter_mut() {
350 solution.crowding_distance = 0.0;
351 }
352
353 let n_objectives = population[0].objectives.len();
354
355 for obj_idx in 0..n_objectives {
356 let mut indices: Vec<usize> = (0..n).collect();
358 indices.sort_by(|&a, &b| {
359 population[a].objectives[obj_idx]
360 .partial_cmp(&population[b].objectives[obj_idx])
361 .unwrap_or(std::cmp::Ordering::Equal)
362 });
363
364 population[indices[0]].crowding_distance = Float::INFINITY;
366 population[indices[n - 1]].crowding_distance = Float::INFINITY;
367
368 let obj_range = population[indices[n - 1]].objectives[obj_idx]
370 - population[indices[0]].objectives[obj_idx];
371
372 if obj_range > 0.0 {
373 for i in 1..(n - 1) {
374 let distance = (population[indices[i + 1]].objectives[obj_idx]
375 - population[indices[i - 1]].objectives[obj_idx])
376 / obj_range;
377 population[indices[i]].crowding_distance += distance;
378 }
379 }
380 }
381
382 Ok(())
383 }
384
385 fn nsga2_generate_offspring<R: Rng>(
387 &self,
388 population: &[ParetoSolution],
389 rng: &mut R,
390 ) -> SklResult<Vec<ParetoSolution>> {
391 let mut offspring = Vec::new();
392
393 for _ in 0..self.config.population_size {
394 let parent1 = self.nsga2_tournament_selection(population, rng)?;
396 let parent2 = self.nsga2_tournament_selection(population, rng)?;
397
398 let mut child = match self.config.algorithm {
400 NSGA2Algorithm::SBX => self.simulated_binary_crossover(&parent1, &parent2, rng)?,
401 _ => self.uniform_crossover(&parent1, &parent2, rng)?,
402 };
403
404 match self.config.algorithm {
406 NSGA2Algorithm::SBX => self.polynomial_mutation(&mut child, rng)?,
407 _ => self.gaussian_mutation(&mut child, rng)?,
408 }
409
410 offspring.push(child);
411 }
412
413 Ok(offspring)
414 }
415
416 fn nsga2_tournament_selection<R: Rng>(
418 &self,
419 population: &[ParetoSolution],
420 rng: &mut R,
421 ) -> SklResult<ParetoSolution> {
422 let idx1 = rng.gen_range(0..population.len());
423 let idx2 = rng.gen_range(0..population.len());
424
425 let solution1 = &population[idx1];
426 let solution2 = &population[idx2];
427
428 if solution1.rank < solution2.rank {
430 Ok(solution1.clone())
431 } else if solution1.rank > solution2.rank {
432 Ok(solution2.clone())
433 } else {
434 if solution1.crowding_distance > solution2.crowding_distance {
436 Ok(solution1.clone())
437 } else {
438 Ok(solution2.clone())
439 }
440 }
441 }
442
443 fn simulated_binary_crossover<R: Rng>(
445 &self,
446 parent1: &ParetoSolution,
447 parent2: &ParetoSolution,
448 rng: &mut R,
449 ) -> SklResult<ParetoSolution> {
450 let mut child_params = parent1.parameters.clone();
451
452 if rng.random::<Float>() <= self.config.crossover_prob {
453 for i in 0..child_params.len() {
454 let p1 = parent1.parameters[i];
455 let p2 = parent2.parameters[i];
456
457 if rng.random::<Float>() <= 0.5 {
458 let u = rng.random::<Float>();
459 let beta = if u <= 0.5 {
460 (2.0 * u).powf(1.0 / (self.config.eta_c + 1.0))
461 } else {
462 (1.0 / (2.0 * (1.0 - u))).powf(1.0 / (self.config.eta_c + 1.0))
463 };
464
465 let child_val = 0.5 * ((1.0 + beta) * p1 + (1.0 - beta) * p2);
466 child_params[i] = child_val.clamp(-2.0, 2.0);
467 }
468 }
469 }
470
471 Ok(ParetoSolution {
472 parameters: child_params,
473 objectives: Array1::zeros(parent1.objectives.len()),
474 rank: 0,
475 crowding_distance: 0.0,
476 })
477 }
478
479 fn polynomial_mutation<R: Rng>(
481 &self,
482 solution: &mut ParetoSolution,
483 rng: &mut R,
484 ) -> SklResult<()> {
485 for i in 0..solution.parameters.len() {
486 if rng.random::<Float>() <= self.config.mutation_prob {
487 let u = rng.random::<Float>();
488 let delta = if u < 0.5 {
489 (2.0 * u).powf(1.0 / (self.config.eta_m + 1.0)) - 1.0
490 } else {
491 1.0 - (2.0 * (1.0 - u)).powf(1.0 / (self.config.eta_m + 1.0))
492 };
493
494 solution.parameters[i] += delta * 0.1;
495 solution.parameters[i] = solution.parameters[i].clamp(-2.0, 2.0);
496 }
497 }
498 Ok(())
499 }
500
501 fn nsga2_environmental_selection(
503 &self,
504 mut population: Vec<ParetoSolution>,
505 ) -> SklResult<Vec<ParetoSolution>> {
506 self.nsga2_non_dominated_sort(&mut population)?;
508 self.nsga2_crowding_distance(&mut population)?;
509
510 population.sort_by(|a, b| {
512 match a.rank.cmp(&b.rank) {
513 std::cmp::Ordering::Equal => {
514 b.crowding_distance
516 .partial_cmp(&a.crowding_distance)
517 .unwrap_or(std::cmp::Ordering::Equal)
518 }
519 other => other,
520 }
521 });
522
523 population.truncate(self.config.population_size);
525 Ok(population)
526 }
527
528 fn extract_pareto_front(
530 &self,
531 population: &[ParetoSolution],
532 ) -> SklResult<Vec<ParetoSolution>> {
533 Ok(population
534 .iter()
535 .filter(|sol| sol.rank == 0)
536 .cloned()
537 .collect())
538 }
539
540 fn evaluate_population(
542 &self,
543 population: &mut [ParetoSolution],
544 X: &ArrayView2<Float>,
545 y: &ArrayView2<Float>,
546 ) -> SklResult<()> {
547 let n_features = X.ncols();
548 let n_outputs = y.ncols();
549
550 for solution in population.iter_mut() {
551 let weights = solution
553 .parameters
554 .slice(s![..n_features * n_outputs])
555 .to_owned()
556 .into_shape((n_features, n_outputs))
557 .map_err(|e| SklearsError::InvalidInput(format!("Shape error: {}", e)))?;
558
559 let bias = solution
560 .parameters
561 .slice(s![n_features * n_outputs..])
562 .to_owned();
563
564 let mut predictions = X.dot(&weights);
566 for mut row in predictions.rows_mut() {
567 row += &bias;
568 }
569
570 let mse = self.calculate_mse(&predictions.view(), y)?;
572 let complexity = self.calculate_complexity(&weights)?;
573
574 solution.objectives = array![mse, complexity];
575 }
576
577 Ok(())
578 }
579
580 fn calculate_mse(
582 &self,
583 predictions: &ArrayView2<Float>,
584 y: &ArrayView2<Float>,
585 ) -> SklResult<Float> {
586 let diff = predictions - y;
587 let squared_diff = &diff * &diff;
588 Ok(squared_diff.sum() / (predictions.nrows() * predictions.ncols()) as Float)
589 }
590
591 fn calculate_complexity(&self, weights: &Array2<Float>) -> SklResult<Float> {
593 Ok(weights.mapv(|x| x.abs()).sum())
594 }
595
596 fn uniform_crossover<R: Rng>(
598 &self,
599 parent1: &ParetoSolution,
600 parent2: &ParetoSolution,
601 rng: &mut R,
602 ) -> SklResult<ParetoSolution> {
603 let mut child_params = parent1.parameters.clone();
604
605 if rng.random::<Float>() <= self.config.crossover_prob {
606 for i in 0..child_params.len() {
607 if rng.random::<Float>() <= 0.5 {
608 child_params[i] = parent2.parameters[i];
609 }
610 }
611 }
612
613 Ok(ParetoSolution {
614 parameters: child_params,
615 objectives: Array1::zeros(parent1.objectives.len()),
616 rank: 0,
617 crowding_distance: 0.0,
618 })
619 }
620
621 fn gaussian_mutation<R: Rng>(
623 &self,
624 solution: &mut ParetoSolution,
625 rng: &mut R,
626 ) -> SklResult<()> {
627 for i in 0..solution.parameters.len() {
628 if rng.random::<Float>() <= self.config.mutation_prob {
629 let normal = RandNormal::new(0.0, 0.1).map_err(|e| {
630 SklearsError::InvalidInput(format!(
631 "Failed to create normal distribution: {}",
632 e
633 ))
634 })?;
635 let mutation = rng.sample(normal);
636 solution.parameters[i] += mutation;
637 solution.parameters[i] = solution.parameters[i].clamp(-2.0, 2.0);
638 }
639 }
640 Ok(())
641 }
642
643 fn calculate_hypervolume(&self, population: &[ParetoSolution]) -> SklResult<Float> {
645 let pareto_front: Vec<&ParetoSolution> =
647 population.iter().filter(|sol| sol.rank == 0).collect();
648
649 if pareto_front.is_empty() {
650 return Ok(0.0);
651 }
652
653 let reference_point = array![1.0, 1.0];
655 let mut hypervolume = 0.0;
656
657 for solution in &pareto_front {
658 let mut volume = 1.0;
659 for i in 0..solution.objectives.len() {
660 let contribution = (reference_point[i] - solution.objectives[i]).max(0.0);
661 volume *= contribution;
662 }
663 hypervolume += volume;
664 }
665
666 Ok(hypervolume / pareto_front.len() as Float)
667 }
668
669 fn find_best_compromise(
671 &self,
672 pareto_solutions: &[ParetoSolution],
673 ) -> SklResult<ParetoSolution> {
674 if pareto_solutions.is_empty() {
675 return Err(SklearsError::InvalidInput(
676 "No Pareto solutions available".to_string(),
677 ));
678 }
679
680 let mut best_solution = pareto_solutions[0].clone();
681 let mut best_distance = Float::INFINITY;
682
683 for solution in pareto_solutions {
685 let distance = solution.objectives.mapv(|x| x * x).sum().sqrt();
686 if distance < best_distance {
687 best_distance = distance;
688 best_solution = solution.clone();
689 }
690 }
691
692 Ok(best_solution)
693 }
694}
695
696impl Predict<ArrayView2<'_, Float>, Array2<Float>> for NSGA2Optimizer<NSGA2OptimizerTrained> {
697 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
698 let n_samples = X.nrows();
699 let n_features = X.ncols();
700 let n_outputs = self.state.best_solution.parameters.len() / (n_features + 1);
701
702 let weights = self
704 .state
705 .best_solution
706 .parameters
707 .slice(s![..n_features * n_outputs])
708 .to_owned()
709 .into_shape((n_features, n_outputs))
710 .map_err(|e| SklearsError::InvalidInput(format!("Shape error: {}", e)))?;
711
712 let bias = self
713 .state
714 .best_solution
715 .parameters
716 .slice(s![n_features * n_outputs..])
717 .to_owned();
718
719 let mut predictions = X.dot(&weights);
721 for mut row in predictions.rows_mut() {
722 row += &bias;
723 }
724
725 Ok(predictions)
726 }
727}
728
729impl NSGA2Optimizer<NSGA2OptimizerTrained> {
730 pub fn pareto_solutions(&self) -> &[ParetoSolution] {
732 &self.state.pareto_solutions
733 }
734
735 pub fn best_solution(&self) -> &ParetoSolution {
737 &self.state.best_solution
738 }
739
740 pub fn convergence_history(&self) -> &[Float] {
742 &self.state.convergence_history
743 }
744
745 pub fn final_population(&self) -> &[ParetoSolution] {
747 &self.state.final_population
748 }
749}