1use crate::error::OptimizeResult;
24use crate::multiobjective::indicators::{dominates, non_dominated_sort};
25use scirs2_core::random::rngs::StdRng;
26use scirs2_core::random::{Rng, SeedableRng};
27use scirs2_core::RngExt;
28
/// A candidate solution together with the bookkeeping NSGA-II attaches to it.
#[derive(Clone, Debug)]
pub struct Individual {
    /// Decision variables; `nsga2` creates one entry per element of its `bounds` slice.
    pub genes: Vec<f64>,
    /// Objective values the user's objective function returned for `genes`.
    /// NOTE(review): presumably minimised (dominance comes from `indicators::dominates`) — confirm.
    pub objectives: Vec<f64>,
    /// Index of the non-dominated front this individual belongs to (0 = first front),
    /// as assigned by `assign_ranks_and_crowding`.
    pub rank: usize,
    /// Crowding distance within its front: larger means less crowded. Extreme members of a
    /// front, and every member of a front with at most two individuals, get `f64::INFINITY`.
    pub crowding_distance: f64,
}
45
46impl Individual {
47 pub fn new(genes: Vec<f64>, objectives: Vec<f64>) -> Self {
48 Self {
49 genes,
50 objectives,
51 rank: 0,
52 crowding_distance: 0.0,
53 }
54 }
55}
56
/// Outcome of an `nsga2` run.
#[derive(Debug)]
pub struct Nsga2Result {
    /// First non-dominated front (rank 0) of the final population.
    pub pareto_front: Vec<Individual>,
    /// Non-dominated fronts of the final population in rank order; element 0 is the same
    /// front as `pareto_front`.
    pub all_fronts: Vec<Vec<Individual>>,
    /// Number of generations run (copied from `Nsga2Config::n_generations`).
    pub n_generations: usize,
    /// Number of objective-function evaluations performed during the run.
    pub n_evaluations: usize,
}
69
/// Tuning parameters for `nsga2`.
#[derive(Debug, Clone)]
pub struct Nsga2Config {
    /// Requested population size; `nsga2` rounds it up to an even number of at least 4.
    pub population_size: usize,
    /// Number of generations to evolve.
    pub n_generations: usize,
    /// Probability that a selected parent pair is recombined with SBX; otherwise the
    /// parents are copied unchanged before mutation.
    pub crossover_rate: f64,
    /// Per-variable probability of polynomial mutation. Values `<= 0.0` make `nsga2` use
    /// `1 / number_of_variables` instead.
    pub mutation_rate: f64,
    /// SBX distribution index; larger values keep offspring closer to their parents.
    pub eta_c: f64,
    /// Polynomial-mutation distribution index; larger values give smaller perturbations.
    pub eta_m: f64,
    /// Seed for the internal `StdRng`; with a deterministic objective function the run is
    /// reproducible for a fixed seed.
    pub seed: u64,
}
89
90impl Default for Nsga2Config {
91 fn default() -> Self {
92 Self {
93 population_size: 100,
94 n_generations: 200,
95 crossover_rate: 0.9,
96 mutation_rate: 0.0, eta_c: 20.0,
98 eta_m: 20.0,
99 seed: 12345,
100 }
101 }
102}
103
104pub fn nsga2<F>(
142 n_objectives: usize,
143 bounds: &[(f64, f64)],
144 objectives: F,
145 config: Nsga2Config,
146) -> OptimizeResult<Nsga2Result>
147where
148 F: Fn(&[f64]) -> Vec<f64>,
149{
150 use crate::error::OptimizeError;
151
152 if n_objectives == 0 {
153 return Err(OptimizeError::InvalidInput(
154 "n_objectives must be >= 1".to_string(),
155 ));
156 }
157 if bounds.is_empty() {
158 return Err(OptimizeError::InvalidInput(
159 "bounds must be non-empty".to_string(),
160 ));
161 }
162 for (i, &(lo, hi)) in bounds.iter().enumerate() {
163 if lo >= hi {
164 return Err(OptimizeError::InvalidInput(format!(
165 "bound[{i}]: lo ({lo}) must be < hi ({hi})"
166 )));
167 }
168 }
169
170 let n_vars = bounds.len();
171 let pop_size = if config.population_size % 2 == 0 {
172 config.population_size.max(4)
173 } else {
174 (config.population_size + 1).max(4)
175 };
176
177 let mutation_rate = if config.mutation_rate > 0.0 {
178 config.mutation_rate
179 } else {
180 1.0 / n_vars as f64
181 };
182
183 let mut rng = StdRng::seed_from_u64(config.seed);
184 let mut n_evaluations = 0usize;
185
186 let mut population: Vec<Individual> = (0..pop_size)
188 .map(|_| {
189 let genes = random_genes(bounds, &mut rng);
190 let objs = objectives(&genes);
191 n_evaluations += 1;
192 Individual::new(genes, objs)
193 })
194 .collect();
195
196 assign_ranks_and_crowding(&mut population);
198
199 for _ in 0..config.n_generations {
201 let offspring: Vec<Individual> = (0..pop_size / 2)
203 .flat_map(|_| {
204 let p1 = tournament_select(&population, &mut rng);
205 let p2 = tournament_select(&population, &mut rng);
206
207 let (c1_genes, c2_genes) = if rng.random::<f64>() < config.crossover_rate {
208 sbx_crossover(
209 &population[p1].genes,
210 &population[p2].genes,
211 config.eta_c,
212 bounds,
213 &mut rng,
214 )
215 } else {
216 (population[p1].genes.clone(), population[p2].genes.clone())
217 };
218
219 let c1_genes =
220 polynomial_mutation(c1_genes, mutation_rate, config.eta_m, bounds, &mut rng);
221 let c2_genes =
222 polynomial_mutation(c2_genes, mutation_rate, config.eta_m, bounds, &mut rng);
223
224 let objs1 = objectives(&c1_genes);
225 let objs2 = objectives(&c2_genes);
226 n_evaluations += 2;
227
228 vec![
229 Individual::new(c1_genes, objs1),
230 Individual::new(c2_genes, objs2),
231 ]
232 })
233 .collect();
234
235 let mut combined = population;
237 combined.extend(offspring);
238 assign_ranks_and_crowding(&mut combined);
239 population = select_survivors(combined, pop_size);
240 }
241
242 assign_ranks_and_crowding(&mut population);
244
245 let obj_vecs: Vec<Vec<f64>> = population
247 .iter()
248 .map(|ind| ind.objectives.clone())
249 .collect();
250 let front_indices = non_dominated_sort(&obj_vecs);
251
252 let mut all_fronts: Vec<Vec<Individual>> = front_indices
253 .iter()
254 .map(|idx_vec| idx_vec.iter().map(|&i| population[i].clone()).collect())
255 .collect();
256
257 let pareto_front = if all_fronts.is_empty() {
258 population.clone()
259 } else {
260 all_fronts.remove(0)
261 };
262
263 let obj_vecs2: Vec<Vec<f64>> = population
265 .iter()
266 .map(|ind| ind.objectives.clone())
267 .collect();
268 let front_indices2 = non_dominated_sort(&obj_vecs2);
269 let all_fronts_final: Vec<Vec<Individual>> = front_indices2
270 .iter()
271 .map(|idx_vec| idx_vec.iter().map(|&i| population[i].clone()).collect())
272 .collect();
273
274 Ok(Nsga2Result {
275 pareto_front,
276 all_fronts: all_fronts_final,
277 n_generations: config.n_generations,
278 n_evaluations,
279 })
280}
281
282pub(crate) fn assign_ranks_and_crowding(population: &mut [Individual]) {
288 if population.is_empty() {
289 return;
290 }
291
292 let obj_vecs: Vec<Vec<f64>> = population
293 .iter()
294 .map(|ind| ind.objectives.clone())
295 .collect();
296 let fronts = non_dominated_sort(&obj_vecs);
297
298 for (rank, front_idx) in fronts.iter().enumerate() {
299 for &i in front_idx {
300 population[i].rank = rank;
301 }
302 crowding_distance_assignment(population, front_idx);
303 }
304}
305
306fn crowding_distance_assignment(population: &mut [Individual], front_indices: &[usize]) {
309 let n = front_indices.len();
310 if n <= 2 {
311 for &i in front_indices {
313 population[i].crowding_distance = f64::INFINITY;
314 }
315 return;
316 }
317
318 for &i in front_indices {
320 population[i].crowding_distance = 0.0;
321 }
322
323 let n_obj = population[front_indices[0]].objectives.len();
324
325 for m in 0..n_obj {
326 let mut sorted = front_indices.to_vec();
328 sorted.sort_by(|&a, &b| {
329 population[a].objectives[m]
330 .partial_cmp(&population[b].objectives[m])
331 .unwrap_or(std::cmp::Ordering::Equal)
332 });
333
334 population[sorted[0]].crowding_distance = f64::INFINITY;
336 population[sorted[n - 1]].crowding_distance = f64::INFINITY;
337
338 let f_min = population[sorted[0]].objectives[m];
339 let f_max = population[sorted[n - 1]].objectives[m];
340 let range = f_max - f_min;
341
342 if range < f64::EPSILON {
343 continue; }
345
346 for k in 1..n - 1 {
347 let prev_val = population[sorted[k - 1]].objectives[m];
348 let next_val = population[sorted[k + 1]].objectives[m];
349 population[sorted[k]].crowding_distance += (next_val - prev_val) / range;
350 }
351 }
352}
353
354fn select_survivors(mut combined: Vec<Individual>, target_size: usize) -> Vec<Individual> {
365 combined.sort_by(|a, b| {
367 a.rank.cmp(&b.rank).then_with(|| {
368 b.crowding_distance
369 .partial_cmp(&a.crowding_distance)
370 .unwrap_or(std::cmp::Ordering::Equal)
371 })
372 });
373
374 combined.truncate(target_size);
375 combined
376}
377
378fn sbx_crossover(
390 parent1: &[f64],
391 parent2: &[f64],
392 eta_c: f64,
393 bounds: &[(f64, f64)],
394 rng: &mut StdRng,
395) -> (Vec<f64>, Vec<f64>) {
396 let n = parent1.len();
397 let mut child1 = parent1.to_vec();
398 let mut child2 = parent2.to_vec();
399
400 for i in 0..n {
401 if rng.random::<f64>() > 0.5 {
402 continue; }
404
405 let (lo, hi) = bounds[i];
406 let x1 = parent1[i].min(parent2[i]);
407 let x2 = parent1[i].max(parent2[i]);
408
409 if (x2 - x1).abs() < 1e-14 {
410 continue; }
412
413 let u: f64 = rng.random();
414
415 let beta_q = if u <= 0.5 {
417 let alpha = 2.0 - (1.0 / sbx_beta(x1, x2, lo, eta_c)).powf(eta_c + 1.0);
418 let alpha = alpha.max(0.0);
419 (2.0 * u * alpha).powf(1.0 / (eta_c + 1.0))
420 } else {
421 let alpha = 2.0 - (1.0 / sbx_beta(x1, x2, hi - x2 + x1, eta_c)).powf(eta_c + 1.0);
422 let alpha_inv = 2.0 * (1.0 - u) * alpha.max(0.0);
424 if alpha_inv < f64::EPSILON {
425 1.0
426 } else {
427 (1.0 / alpha_inv).powf(1.0 / (eta_c + 1.0))
428 }
429 };
430
431 let mid = 0.5 * (x1 + x2);
432 let half_diff = 0.5 * (x2 - x1);
433
434 let c1 = (mid - beta_q * half_diff).clamp(lo, hi);
435 let c2 = (mid + beta_q * half_diff).clamp(lo, hi);
436
437 if parent1[i] < parent2[i] {
439 child1[i] = c1;
440 child2[i] = c2;
441 } else {
442 child1[i] = c2;
443 child2[i] = c1;
444 }
445 }
446
447 (child1, child2)
448}
449
/// Returns `(1 + 2 * |bound - x1| / |x2 - x1|)^(eta + 1)`, with both distances floored at
/// 1e-14 to avoid division by zero and a degenerate base.
fn sbx_beta(x1: f64, x2: f64, bound: f64, eta: f64) -> f64 {
    const FLOOR: f64 = 1e-14;
    let spread = f64::max((x2 - x1).abs(), FLOOR);
    let gap = f64::max((bound - x1).abs(), FLOOR);
    let base = 1.0 + 2.0 * gap / spread;
    base.powf(eta + 1.0)
}
456
457fn polynomial_mutation(
463 mut genes: Vec<f64>,
464 mutation_rate: f64,
465 eta_m: f64,
466 bounds: &[(f64, f64)],
467 rng: &mut StdRng,
468) -> Vec<f64> {
469 for (i, gene) in genes.iter_mut().enumerate() {
470 if rng.random::<f64>() >= mutation_rate {
471 continue;
472 }
473
474 let (lo, hi) = bounds[i];
475 let delta = hi - lo;
476 if delta < f64::EPSILON {
477 continue;
478 }
479
480 let u: f64 = rng.random();
481 let delta_q = if u < 0.5 {
482 let delta_l = (*gene - lo) / delta;
483 let base = 2.0 * u + (1.0 - 2.0 * u) * (1.0 - delta_l).powf(eta_m + 1.0);
484 base.powf(1.0 / (eta_m + 1.0)) - 1.0
485 } else {
486 let delta_r = (hi - *gene) / delta;
487 let base = 2.0 * (1.0 - u) + 2.0 * (u - 0.5) * (1.0 - delta_r).powf(eta_m + 1.0);
488 1.0 - base.powf(1.0 / (eta_m + 1.0))
489 };
490
491 *gene = (*gene + delta_q * delta).clamp(lo, hi);
492 }
493
494 genes
495}
496
497fn tournament_select(population: &[Individual], rng: &mut StdRng) -> usize {
507 let n = population.len();
508 let a = rng.random_range(0..n);
509 let mut b = rng.random_range(0..n);
510 if b == a && n > 1 {
512 b = (a + 1) % n;
513 }
514
515 let ia = &population[a];
516 let ib = &population[b];
517
518 if ia.rank < ib.rank || (ia.rank == ib.rank && ia.crowding_distance > ib.crowding_distance) {
519 a
520 } else {
521 b
522 }
523}
524
525fn random_genes(bounds: &[(f64, f64)], rng: &mut StdRng) -> Vec<f64> {
530 bounds
531 .iter()
532 .map(|&(lo, hi)| lo + rng.random::<f64>() * (hi - lo))
533 .collect()
534}
535
#[cfg(test)]
mod tests {
    use super::*;

    /// ZDT1 benchmark: two minimised objectives, intended for `x` in [0, 1]^n with n >= 2.
    fn zdt1(x: &[f64]) -> Vec<f64> {
        let f1 = x[0];
        let g = 1.0 + 9.0 * x[1..].iter().sum::<f64>() / (x.len() - 1) as f64;
        let f2 = g * (1.0 - (f1 / g).sqrt());
        vec![f1, f2]
    }

    /// Total number of individuals across all returned fronts.
    fn front_population(result: &Nsga2Result) -> usize {
        result.all_fronts.iter().map(Vec::len).sum()
    }

    #[test]
    fn test_nsga2_returns_pareto_front() {
        let bounds: Vec<(f64, f64)> = vec![(0.0, 1.0); 5];
        let cfg = Nsga2Config {
            population_size: 20,
            n_generations: 10,
            seed: 99,
            ..Nsga2Config::default()
        };

        let result = nsga2(2, &bounds, zdt1, cfg).expect("nsga2 should succeed");
        assert!(!result.pareto_front.is_empty());
        // 20 initial evaluations + 10 generations x 20 offspring.
        assert_eq!(result.n_evaluations, 220);
        assert_eq!(result.n_generations, 10);
    }

    #[test]
    fn test_nsga2_objectives_are_evaluated() {
        let bounds = vec![(0.0, 1.0); 3];
        let cfg = Nsga2Config {
            population_size: 10,
            n_generations: 3,
            ..Nsga2Config::default()
        };

        let result =
            nsga2(2, &bounds, |x| vec![x[0], 1.0 - x[0]], cfg).expect("failed to create result");
        // 10 initial evaluations + 3 generations x 10 offspring.
        assert_eq!(result.n_evaluations, 40);
    }

    #[test]
    fn test_nsga2_odd_population_is_rounded_up() {
        let cfg = Nsga2Config {
            population_size: 7,
            n_generations: 1,
            ..Nsga2Config::default()
        };

        let result = nsga2(2, &[(0.0, 1.0); 3], |x| vec![x[0], 1.0 - x[0]], cfg)
            .expect("nsga2 should succeed");
        // 7 rounds up to 8: 8 initial evaluations + 8 offspring.
        assert_eq!(result.n_evaluations, 16);
        assert_eq!(front_population(&result), 8);
    }

    #[test]
    fn test_nsga2_fronts_partition_final_population() {
        let bounds = vec![(0.0, 1.0); 4];
        let cfg = Nsga2Config {
            population_size: 20,
            n_generations: 5,
            seed: 3,
            ..Nsga2Config::default()
        };

        let result = nsga2(2, &bounds, zdt1, cfg).expect("nsga2 should succeed");
        assert_eq!(front_population(&result), 20);
        assert_eq!(result.pareto_front.len(), result.all_fronts[0].len());
        for ind in &result.pareto_front {
            assert_eq!(ind.rank, 0);
        }
        for (rank, front) in result.all_fronts.iter().enumerate() {
            for ind in front {
                assert_eq!(ind.rank, rank, "front {rank} holds an individual of rank {}", ind.rank);
            }
        }
    }

    #[test]
    fn test_nsga2_pareto_front_non_dominated() {
        let bounds = vec![(0.0, 1.0); 5];
        let cfg = Nsga2Config {
            population_size: 20,
            n_generations: 20,
            seed: 7,
            ..Nsga2Config::default()
        };

        let result = nsga2(2, &bounds, zdt1, cfg).expect("failed to create result");

        let front = &result.pareto_front;
        for i in 0..front.len() {
            for j in 0..front.len() {
                if i != j {
                    assert!(
                        !dominates(&front[i].objectives, &front[j].objectives),
                        "front[{i}] dominates front[{j}]"
                    );
                }
            }
        }
    }

    #[test]
    fn test_nsga2_bounds_respected() {
        let bounds = vec![(0.2, 0.8); 3];
        let cfg = Nsga2Config {
            population_size: 20,
            n_generations: 10,
            ..Nsga2Config::default()
        };

        let result =
            nsga2(2, &bounds, |x| vec![x[0], 1.0 - x[0]], cfg).expect("failed to create result");

        // Check the whole final population, not just the Pareto front.
        for ind in result.all_fronts.iter().flatten() {
            for (i, &g) in ind.genes.iter().enumerate() {
                assert!(
                    g >= bounds[i].0 - 1e-9 && g <= bounds[i].1 + 1e-9,
                    "gene[{i}]={g} outside bounds"
                );
            }
        }
    }

    #[test]
    fn test_nsga2_zdt1_quality() {
        let bounds: Vec<(f64, f64)> = vec![(0.0, 1.0); 10];
        let cfg = Nsga2Config {
            population_size: 40,
            n_generations: 50,
            seed: 42,
            ..Nsga2Config::default()
        };

        let result = nsga2(2, &bounds, zdt1, cfg).expect("failed to create result");

        // On [0, 1]^n ZDT1 has f1 in [0, 1] and, since g >= 1 >= f1, f2 >= 0.
        for ind in &result.pareto_front {
            assert!(
                (0.0..=1.0).contains(&ind.objectives[0]),
                "f1 should be in [0, 1]"
            );
            assert!(ind.objectives[1] >= 0.0, "f2 should be >= 0");
        }
    }

    #[test]
    fn test_nsga2_invalid_bounds() {
        let result = nsga2(2, &[], |x| vec![x[0]], Nsga2Config::default());
        assert!(result.is_err());

        let result = nsga2(2, &[(1.0, 0.0)], |x| vec![x[0]], Nsga2Config::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_nsga2_invalid_objectives() {
        let result = nsga2(0, &[(0.0, 1.0)], |_| vec![], Nsga2Config::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_sbx_offspring_in_bounds() {
        let mut rng = StdRng::seed_from_u64(0);
        let bounds = vec![(0.0, 1.0); 4];
        let p1 = vec![0.2, 0.4, 0.6, 0.8];
        let p2 = vec![0.8, 0.6, 0.4, 0.2];

        for _ in 0..50 {
            let (c1, c2) = sbx_crossover(&p1, &p2, 20.0, &bounds, &mut rng);
            for (i, &v) in c1.iter().enumerate() {
                assert!(
                    v >= bounds[i].0 && v <= bounds[i].1,
                    "c1[{i}]={v} out of bounds"
                );
            }
            for (i, &v) in c2.iter().enumerate() {
                assert!(
                    v >= bounds[i].0 && v <= bounds[i].1,
                    "c2[{i}]={v} out of bounds"
                );
            }
        }
    }

    #[test]
    fn test_polynomial_mutation_in_bounds() {
        let mut rng = StdRng::seed_from_u64(1);
        let bounds = vec![(0.0, 1.0); 5];
        let genes = vec![0.5; 5];

        for _ in 0..100 {
            let mutated = polynomial_mutation(genes.clone(), 0.5, 20.0, &bounds, &mut rng);
            for (i, &v) in mutated.iter().enumerate() {
                assert!(
                    v >= bounds[i].0 && v <= bounds[i].1,
                    "mutated[{i}]={v} out of bounds"
                );
            }
        }
    }

    #[test]
    fn test_crowding_distance_boundary_gets_infinity() {
        let mut pop = vec![
            Individual::new(vec![0.0], vec![0.0, 3.0]),
            Individual::new(vec![0.5], vec![0.5, 2.0]),
            Individual::new(vec![1.0], vec![1.0, 0.0]),
        ];
        let indices: Vec<usize> = (0..3).collect();
        crowding_distance_assignment(&mut pop, &indices);

        assert_eq!(pop[0].crowding_distance, f64::INFINITY);
        assert_eq!(pop[2].crowding_distance, f64::INFINITY);
        // Objective 0 contributes (1.0 - 0.0) / 1.0, objective 1 contributes (3.0 - 0.0) / 3.0.
        assert_eq!(pop[1].crowding_distance, 2.0);
    }

    #[test]
    fn test_crowding_distance_small_front_is_infinite() {
        let mut pop = vec![
            Individual::new(vec![0.0], vec![0.0, 1.0]),
            Individual::new(vec![1.0], vec![1.0, 0.0]),
        ];
        crowding_distance_assignment(&mut pop, &[0, 1]);
        assert!(pop.iter().all(|ind| ind.crowding_distance == f64::INFINITY));
    }

    #[test]
    fn test_select_survivors_prefers_rank_then_crowding() {
        let mk = |rank: usize, cd: f64| {
            let mut ind = Individual::new(vec![], vec![]);
            ind.rank = rank;
            ind.crowding_distance = cd;
            ind
        };
        let survivors = select_survivors(
            vec![mk(1, 9.0), mk(0, 1.0), mk(0, f64::INFINITY), mk(2, 0.0), mk(0, 3.0)],
            3,
        );
        let got: Vec<(usize, f64)> = survivors
            .iter()
            .map(|ind| (ind.rank, ind.crowding_distance))
            .collect();
        assert_eq!(got, vec![(0, f64::INFINITY), (0, 3.0), (0, 1.0)]);
    }

    #[test]
    fn test_tournament_prefers_lower_rank() {
        let mut worse = Individual::new(vec![1.0], vec![1.0]);
        worse.rank = 1;
        let mut better = Individual::new(vec![0.0], vec![0.0]);
        better.rank = 0;
        let pop = vec![worse, better];

        // With two members the contestants are always distinct, so rank 0 always wins.
        let mut rng = StdRng::seed_from_u64(5);
        for _ in 0..20 {
            assert_eq!(tournament_select(&pop, &mut rng), 1);
        }
    }
}