1use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
9use scirs2_core::random::Rng;
10use sklears_core::{
11 error::{Result as SklResult, SklearsError},
12 types::Float,
13};
14use std::cmp::Ordering;
15
16#[derive(Debug, Clone)]
18pub struct Individual {
19 pub variables: Array1<Float>,
21 pub objectives: Array1<Float>,
23 pub rank: usize,
25 pub crowding_distance: Float,
27}
28
29impl Individual {
30 pub fn new(variables: Array1<Float>) -> Self {
32 Self {
33 variables,
34 objectives: Array1::zeros(0), rank: 0,
36 crowding_distance: 0.0,
37 }
38 }
39
40 pub fn dominates(&self, other: &Individual) -> bool {
42 let mut at_least_one_better = false;
43
44 for i in 0..self.objectives.len() {
45 if self.objectives[i] > other.objectives[i] {
46 return false; }
48 if self.objectives[i] < other.objectives[i] {
49 at_least_one_better = true;
50 }
51 }
52
53 at_least_one_better
54 }
55}
56
57#[derive(Debug, Clone)]
88pub struct NSGAII {
89 population_size: usize,
91 n_generations: usize,
93 crossover_probability: Float,
95 mutation_probability: Float,
97 variable_bounds: Vec<(Float, Float)>,
99 random_state: Option<u64>,
101 crossover_eta: Float,
103 mutation_eta: Float,
105}
106
107impl Default for NSGAII {
108 fn default() -> Self {
109 Self {
110 population_size: 100,
111 n_generations: 100,
112 crossover_probability: 0.9,
113 mutation_probability: 0.1,
114 variable_bounds: Vec::new(),
115 random_state: None,
116 crossover_eta: 20.0,
117 mutation_eta: 20.0,
118 }
119 }
120}
121
122impl NSGAII {
123 pub fn new() -> Self {
125 Self::default()
126 }
127
128 pub fn population_size(mut self, size: usize) -> Self {
130 self.population_size = size;
131 self
132 }
133
134 pub fn n_generations(mut self, generations: usize) -> Self {
136 self.n_generations = generations;
137 self
138 }
139
140 pub fn crossover_probability(mut self, prob: Float) -> Self {
142 self.crossover_probability = prob;
143 self
144 }
145
146 pub fn mutation_probability(mut self, prob: Float) -> Self {
148 self.mutation_probability = prob;
149 self
150 }
151
152 pub fn variable_bounds(mut self, bounds: Vec<(Float, Float)>) -> Self {
154 self.variable_bounds = bounds;
155 self
156 }
157
158 pub fn random_state(mut self, seed: u64) -> Self {
160 self.random_state = Some(seed);
161 self
162 }
163
164 pub fn crossover_eta(mut self, eta: Float) -> Self {
166 self.crossover_eta = eta;
167 self
168 }
169
170 pub fn mutation_eta(mut self, eta: Float) -> Self {
172 self.mutation_eta = eta;
173 self
174 }
175
176 pub fn optimize<F>(&self, objective_fn: F, n_objectives: usize) -> SklResult<OptimizationResult>
178 where
179 F: Fn(&ArrayView1<Float>) -> Array1<Float>,
180 {
181 if self.variable_bounds.is_empty() {
182 return Err(SklearsError::InvalidInput(
183 "Variable bounds must be specified".to_string(),
184 ));
185 }
186
187 let n_variables = self.variable_bounds.len();
188 let mut rng = if let Some(seed) = self.random_state {
189 scirs2_core::random::seeded_rng(seed)
190 } else {
191 scirs2_core::random::seeded_rng(42)
192 };
193
194 let mut population = self.initialize_population(&mut rng, n_variables)?;
196
197 for individual in &mut population {
199 individual.objectives = objective_fn(&individual.variables.view());
200 }
201
202 let mut generation_stats = Vec::new();
203
204 for generation in 0..self.n_generations {
206 let offspring = self.create_offspring(&population, &mut rng, n_variables)?;
208
209 let mut combined_population = population;
211 combined_population.extend(offspring);
212
213 for individual in &mut combined_population {
215 if individual.objectives.len() != n_objectives {
216 individual.objectives = objective_fn(&individual.variables.view());
217 }
218 }
219
220 let fronts = self.non_dominated_sort(&combined_population);
222
223 population = self.environmental_selection(combined_population, fronts)?;
225
226 let stats = self.calculate_generation_stats(&population);
228 generation_stats.push(stats);
229
230 }
232
233 let fronts = self.non_dominated_sort(&population);
235 let pareto_front = if fronts.is_empty() {
236 Vec::new()
237 } else {
238 fronts[0].clone()
239 };
240
241 Ok(OptimizationResult {
242 pareto_front,
243 final_population: population,
244 generation_stats,
245 n_generations: self.n_generations,
246 })
247 }
248
249 fn initialize_population<R: Rng>(
251 &self,
252 rng: &mut scirs2_core::random::CoreRandom<R>,
253 n_variables: usize,
254 ) -> SklResult<Vec<Individual>> {
255 let mut population = Vec::with_capacity(self.population_size);
256
257 for _ in 0..self.population_size {
258 let mut variables = Array1::zeros(n_variables);
259
260 for j in 0..n_variables {
261 let (min_val, max_val) = self.variable_bounds[j];
262 variables[j] = rng.gen_range(min_val..max_val + 1.0);
263 }
264
265 population.push(Individual::new(variables));
266 }
267
268 Ok(population)
269 }
270
271 fn create_offspring<R: Rng>(
273 &self,
274 population: &[Individual],
275 rng: &mut scirs2_core::random::CoreRandom<R>,
276 n_variables: usize,
277 ) -> SklResult<Vec<Individual>> {
278 let mut offspring = Vec::with_capacity(self.population_size);
279
280 for _ in 0..self.population_size {
281 let parent1 = self.tournament_selection(population, rng);
283 let parent2 = self.tournament_selection(population, rng);
284
285 let mut child = if rng.gen::<Float>() < self.crossover_probability {
287 self.sbx_crossover(parent1, parent2, rng)?
288 } else {
289 parent1.clone()
290 };
291
292 if rng.gen::<Float>() < self.mutation_probability {
294 self.polynomial_mutation(&mut child, rng);
295 }
296
297 offspring.push(child);
298 }
299
300 Ok(offspring)
301 }
302
303 fn tournament_selection<'a, R: Rng>(
305 &self,
306 population: &'a [Individual],
307 rng: &mut scirs2_core::random::CoreRandom<R>,
308 ) -> &'a Individual {
309 let tournament_size = 2;
310 let mut best = &population[rng.gen_range(0..population.len())];
311
312 for _ in 1..tournament_size {
313 let candidate = &population[rng.gen_range(0..population.len())];
314 if self.compare_individuals(candidate, best) == Ordering::Less {
315 best = candidate;
316 }
317 }
318
319 best
320 }
321
322 fn compare_individuals(&self, a: &Individual, b: &Individual) -> Ordering {
324 match a.rank.cmp(&b.rank) {
326 Ordering::Equal => {
327 b.crowding_distance
329 .partial_cmp(&a.crowding_distance)
330 .unwrap_or(Ordering::Equal)
331 }
332 other => other,
333 }
334 }
335
336 fn sbx_crossover<R: Rng>(
338 &self,
339 parent1: &Individual,
340 parent2: &Individual,
341 rng: &mut scirs2_core::random::CoreRandom<R>,
342 ) -> SklResult<Individual> {
343 let n_variables = parent1.variables.len();
344 let mut child_variables = Array1::zeros(n_variables);
345
346 for i in 0..n_variables {
347 let p1 = parent1.variables[i];
348 let p2 = parent2.variables[i];
349 let (min_val, max_val) = self.variable_bounds[i];
350
351 if (p1 - p2).abs() > 1e-14 {
352 let u = rng.gen::<Float>();
353 let beta = if u <= 0.5 {
354 (2.0 * u).powf(1.0 / (self.crossover_eta + 1.0))
355 } else {
356 (1.0 / (2.0 * (1.0 - u))).powf(1.0 / (self.crossover_eta + 1.0))
357 };
358
359 let c1 = 0.5 * (p1 + p2 - beta * (p1 - p2).abs());
360 child_variables[i] = c1.clamp(min_val, max_val);
361 } else {
362 child_variables[i] = p1;
363 }
364 }
365
366 Ok(Individual::new(child_variables))
367 }
368
369 fn polynomial_mutation<R: Rng>(
371 &self,
372 individual: &mut Individual,
373 rng: &mut scirs2_core::random::CoreRandom<R>,
374 ) {
375 for i in 0..individual.variables.len() {
376 if rng.gen::<Float>() < (1.0 / individual.variables.len() as Float) {
377 let (min_val, max_val) = self.variable_bounds[i];
378 let x = individual.variables[i];
379 let u = rng.gen::<Float>();
380
381 let delta = if u <= 0.5 {
382 let bl = (x - min_val) / (max_val - min_val);
383 let b = 2.0 * u + (1.0 - 2.0 * u) * (1.0 - bl).powf(self.mutation_eta + 1.0);
384 b.powf(1.0 / (self.mutation_eta + 1.0)) - 1.0
385 } else {
386 let bu = (max_val - x) / (max_val - min_val);
387 let b = 2.0 * (1.0 - u)
388 + 2.0 * (u - 0.5) * (1.0 - bu).powf(self.mutation_eta + 1.0);
389 1.0 - b.powf(1.0 / (self.mutation_eta + 1.0))
390 };
391
392 individual.variables[i] = (x + delta * (max_val - min_val)).clamp(min_val, max_val);
393 }
394 }
395 }
396
397 fn non_dominated_sort(&self, population: &[Individual]) -> Vec<Vec<Individual>> {
399 let mut fronts = Vec::new();
400 let mut domination_count = vec![0; population.len()];
401 let mut dominated_solutions = vec![Vec::new(); population.len()];
402
403 for i in 0..population.len() {
405 for j in 0..population.len() {
406 if population[i].dominates(&population[j]) {
407 dominated_solutions[i].push(j);
408 } else if population[j].dominates(&population[i]) {
409 domination_count[i] += 1;
410 }
411 }
412 }
413
414 let mut current_front = Vec::new();
416 for i in 0..population.len() {
417 if domination_count[i] == 0 {
418 current_front.push(population[i].clone());
419 }
420 }
421
422 let mut front_index = 0;
423 while !current_front.is_empty() {
424 fronts.push(current_front.clone());
425 let mut next_front = Vec::new();
426
427 for individual in ¤t_front {
428 if let Some(ind_idx) = population.iter().position(|p| {
430 p.variables
431 .iter()
432 .zip(individual.variables.iter())
433 .all(|(a, b)| (a - b).abs() < 1e-10)
434 }) {
435 for &dominated_idx in &dominated_solutions[ind_idx] {
436 domination_count[dominated_idx] -= 1;
437 if domination_count[dominated_idx] == 0 {
438 next_front.push(population[dominated_idx].clone());
439 }
440 }
441 }
442 }
443
444 current_front = next_front;
445 front_index += 1;
446 }
447
448 fronts
449 }
450
451 fn environmental_selection(
453 &self,
454 population: Vec<Individual>,
455 fronts: Vec<Vec<Individual>>,
456 ) -> SklResult<Vec<Individual>> {
457 let mut selected = Vec::new();
458
459 for front in &fronts {
460 if selected.len() + front.len() <= self.population_size {
461 selected.extend(front.clone());
463 } else {
464 let remaining = self.population_size - selected.len();
466 let mut front_with_distance = front.clone();
467
468 self.calculate_crowding_distance(&mut front_with_distance);
470
471 front_with_distance.sort_by(|a, b| {
473 b.crowding_distance
474 .partial_cmp(&a.crowding_distance)
475 .unwrap_or(Ordering::Equal)
476 });
477
478 selected.extend(front_with_distance.into_iter().take(remaining));
479 break;
480 }
481 }
482
483 for (rank, front) in fronts.iter().enumerate() {
485 for individual in &mut selected {
486 if front.iter().any(|f| {
487 f.variables
488 .iter()
489 .zip(individual.variables.iter())
490 .all(|(a, b)| (a - b).abs() < 1e-10)
491 }) {
492 individual.rank = rank;
493 }
494 }
495 }
496
497 Ok(selected)
498 }
499
500 fn calculate_crowding_distance(&self, front: &mut [Individual]) {
502 if front.len() <= 2 {
503 for individual in front {
505 individual.crowding_distance = Float::INFINITY;
506 }
507 return;
508 }
509
510 let n_objectives = front[0].objectives.len();
511
512 for individual in front.iter_mut() {
514 individual.crowding_distance = 0.0;
515 }
516
517 for obj_idx in 0..n_objectives {
519 front.sort_by(|a, b| {
521 a.objectives[obj_idx]
522 .partial_cmp(&b.objectives[obj_idx])
523 .unwrap_or(Ordering::Equal)
524 });
525
526 front[0].crowding_distance = Float::INFINITY;
528 front[front.len() - 1].crowding_distance = Float::INFINITY;
529
530 let max_obj = front[front.len() - 1].objectives[obj_idx];
532 let min_obj = front[0].objectives[obj_idx];
533
534 if max_obj - min_obj > 0.0 {
535 for i in 1..front.len() - 1 {
536 if front[i].crowding_distance != Float::INFINITY {
537 front[i].crowding_distance += (front[i + 1].objectives[obj_idx]
538 - front[i - 1].objectives[obj_idx])
539 / (max_obj - min_obj);
540 }
541 }
542 }
543 }
544 }
545
546 fn calculate_generation_stats(&self, population: &[Individual]) -> GenerationStats {
548 let mut hypervolume = 0.0;
549 let spacing = 0.0;
550 let pareto_front_size = population.iter().filter(|ind| ind.rank == 0).count();
551
552 let pareto_individuals: Vec<_> = population.iter().filter(|ind| ind.rank == 0).collect();
554 if !pareto_individuals.is_empty() {
555 let n_objectives = pareto_individuals[0].objectives.len();
556 let reference_point = Array1::from_elem(n_objectives, 10.0); for individual in &pareto_individuals {
559 let mut volume = 1.0;
560 for obj_idx in 0..n_objectives {
561 volume *= (reference_point[obj_idx] - individual.objectives[obj_idx]).max(0.0);
562 }
563 hypervolume += volume;
564 }
565 }
566
567 GenerationStats {
568 hypervolume,
569 spacing,
570 pareto_front_size,
571 }
572 }
573}
574
575#[derive(Debug, Clone)]
577pub struct OptimizationResult {
578 pub pareto_front: Vec<Individual>,
580 pub final_population: Vec<Individual>,
582 pub generation_stats: Vec<GenerationStats>,
584 pub n_generations: usize,
586}
587
588impl OptimizationResult {
589 pub fn pareto_front(&self) -> &[Individual] {
591 &self.pareto_front
592 }
593
594 pub fn final_population(&self) -> &[Individual] {
596 &self.final_population
597 }
598
599 pub fn generation_stats(&self) -> &[GenerationStats] {
601 &self.generation_stats
602 }
603
604 pub fn pareto_objectives(&self) -> Array2<Float> {
606 if self.pareto_front.is_empty() {
607 return Array2::zeros((0, 0));
608 }
609
610 let n_objectives = self.pareto_front[0].objectives.len();
611 let mut objectives = Array2::zeros((self.pareto_front.len(), n_objectives));
612
613 for (i, individual) in self.pareto_front.iter().enumerate() {
614 for (j, &obj_val) in individual.objectives.iter().enumerate() {
615 objectives[[i, j]] = obj_val;
616 }
617 }
618
619 objectives
620 }
621
622 pub fn pareto_variables(&self) -> Array2<Float> {
624 if self.pareto_front.is_empty() {
625 return Array2::zeros((0, 0));
626 }
627
628 let n_variables = self.pareto_front[0].variables.len();
629 let mut variables = Array2::zeros((self.pareto_front.len(), n_variables));
630
631 for (i, individual) in self.pareto_front.iter().enumerate() {
632 for (j, &var_val) in individual.variables.iter().enumerate() {
633 variables[[i, j]] = var_val;
634 }
635 }
636
637 variables
638 }
639}
640
641#[derive(Debug, Clone)]
643pub struct GenerationStats {
644 pub hypervolume: Float,
646 pub spacing: Float,
648 pub pareto_front_size: usize,
650}
651
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Builds an individual with default rank and crowding distance.
    fn individual(variables: Array1<Float>, objectives: Array1<Float>) -> Individual {
        Individual {
            variables,
            objectives,
            rank: 0,
            crowding_distance: 0.0,
        }
    }

    /// One 3-point non-dominated front plus the point (2, 3), which is dominated by both
    /// (1, 3) and (2, 2). The single decision variable doubles as an identifier.
    fn two_front_population() -> Vec<Individual> {
        vec![
            individual(array![1.0], array![1.0, 3.0]),
            individual(array![2.0], array![2.0, 2.0]),
            individual(array![3.0], array![3.0, 1.0]),
            individual(array![4.0], array![2.0, 3.0]),
        ]
    }

    #[test]
    fn test_nsga2_creation() {
        let nsga2 = NSGAII::new()
            .population_size(50)
            .n_generations(25)
            .crossover_probability(0.8)
            .mutation_probability(0.2)
            .variable_bounds(vec![(-1.0, 1.0), (-1.0, 1.0)])
            .random_state(42);

        assert_eq!(nsga2.population_size, 50);
        assert_eq!(nsga2.n_generations, 25);
        assert_abs_diff_eq!(nsga2.crossover_probability, 0.8);
        assert_abs_diff_eq!(nsga2.mutation_probability, 0.2);
        assert_eq!(nsga2.variable_bounds.len(), 2);
        assert_eq!(nsga2.random_state, Some(42));
    }

    #[test]
    fn test_nsga2_optimization() {
        // ZDT1 with two variables.
        let objective_fn = |x: &ArrayView1<Float>| {
            let f1 = x[0];
            let g = 1.0 + 9.0 * x.iter().skip(1).sum::<Float>() / (x.len() - 1) as Float;
            let f2 = g * (1.0 - (f1 / g).sqrt());
            array![f1, f2]
        };

        let nsga2 = NSGAII::new()
            .population_size(20)
            .n_generations(10)
            .variable_bounds(vec![(0.0, 1.0), (0.0, 1.0)])
            .random_state(42);

        let result = nsga2.optimize(objective_fn, 2).unwrap();

        assert!(!result.pareto_front().is_empty());
        assert!(result.pareto_front().len() <= 20);
        for individual in result.pareto_front() {
            assert_eq!(individual.objectives.len(), 2);
            assert!(individual.objectives[0] >= 0.0 && individual.objectives[0] <= 1.0);
            assert!(individual.objectives[1] >= 0.0);
        }

        assert_eq!(result.generation_stats().len(), 10);
        assert!(result.generation_stats()[0].pareto_front_size > 0);
    }

    #[test]
    fn test_optimize_requires_variable_bounds() {
        let result = NSGAII::new().optimize(|x: &ArrayView1<Float>| x.to_owned(), 1);
        assert!(result.is_err());
    }

    #[test]
    fn test_individual_dominance() {
        let ind1 = individual(array![1.0, 2.0], array![1.0, 2.0]);
        let ind2 = individual(array![2.0, 3.0], array![2.0, 3.0]);

        assert!(ind1.dominates(&ind2));
        assert!(!ind2.dominates(&ind1));

        // Identical objective vectors never dominate each other.
        assert!(!ind1.dominates(&ind1.clone()));

        // A trade-off between objectives means neither dominates.
        let ind3 = individual(array![0.0, 0.0], array![0.5, 5.0]);
        assert!(!ind1.dominates(&ind3));
        assert!(!ind3.dominates(&ind1));
    }

    #[test]
    fn test_non_dominated_sort() {
        let nsga2 = NSGAII::new();
        let population = two_front_population();

        let fronts = nsga2.non_dominated_sort(&population);

        assert_eq!(fronts.len(), 2);
        assert_eq!(fronts[0].len(), 3);
        assert_eq!(fronts[1].len(), 1);
        assert_abs_diff_eq!(fronts[1][0].variables[0], 4.0);
    }

    #[test]
    fn test_crowding_distance_boundaries_and_interior() {
        let nsga2 = NSGAII::new();
        let mut front = vec![
            individual(array![0.0], array![0.0, 4.0]),
            individual(array![1.0], array![1.0, 2.0]),
            individual(array![2.0], array![2.0, 1.0]),
            individual(array![3.0], array![4.0, 0.0]),
        ];

        nsga2.calculate_crowding_distance(&mut front);

        for ind in &front {
            match ind.variables[0] as i64 {
                0 | 3 => assert!(ind.crowding_distance.is_infinite()),
                // Each interior point: 0.5 from one objective plus 0.75 from the other.
                _ => assert_abs_diff_eq!(ind.crowding_distance, 1.25, epsilon = 1e-12),
            }
        }
    }

    #[test]
    fn test_environmental_selection_truncates_by_crowding() {
        let nsga2 = NSGAII::new().population_size(2);
        let population = two_front_population();
        let fronts = nsga2.non_dominated_sort(&population);

        let selected = nsga2.environmental_selection(population, fronts).unwrap();

        assert_eq!(selected.len(), 2);
        // The first front has three members; the two boundary points win the truncation.
        assert!(selected.iter().all(|ind| ind.rank == 0));
        assert!(selected.iter().all(|ind| ind.crowding_distance.is_infinite()));
    }

    #[test]
    fn test_generation_stats_use_rank_zero_individuals() {
        let nsga2 = NSGAII::new();
        let mut population = two_front_population();
        population[3].rank = 1;

        let stats = nsga2.calculate_generation_stats(&population);

        assert_eq!(stats.pareto_front_size, 3);
        assert!(stats.hypervolume > 0.0);
        assert!(stats.spacing >= 0.0);
    }

    #[test]
    fn test_optimization_result_accessors() {
        let individuals = vec![
            individual(array![1.0, 2.0], array![1.0, 2.0]),
            individual(array![2.0, 1.0], array![2.0, 1.0]),
        ];

        let result = OptimizationResult {
            pareto_front: individuals.clone(),
            final_population: individuals,
            generation_stats: vec![GenerationStats {
                hypervolume: 1.0,
                spacing: 0.5,
                pareto_front_size: 2,
            }],
            n_generations: 10,
        };

        let objectives = result.pareto_objectives();
        assert_eq!(objectives.shape(), &[2, 2]);
        assert_abs_diff_eq!(objectives[[0, 0]], 1.0);
        assert_abs_diff_eq!(objectives[[0, 1]], 2.0);

        let variables = result.pareto_variables();
        assert_eq!(variables.shape(), &[2, 2]);
        assert_abs_diff_eq!(variables[[0, 0]], 1.0);
        assert_abs_diff_eq!(variables[[0, 1]], 2.0);
    }
}