// scirs2_optimize/multi_objective/algorithms/nsga3.rs

//! NSGA-III (Non-dominated Sorting Genetic Algorithm III)
//!
//! Reference-point-based many-objective evolutionary algorithm.
//! NSGA-III extends NSGA-II to many-objective optimization (more than 3 objectives)
//! by replacing crowding distance with reference-point-based selection.
//!
//! Key features:
//! - Das-Dennis reference point generation
//! - Adaptive normalization of objective values
//! - Reference-point association using perpendicular distance
//! - Niching-based selection for population diversity
//!
//! # References
//!
//! - Deb & Jain, "An Evolutionary Many-Objective Optimization Algorithm Using
//!   Reference-Point-Based Nondominated Sorting Approach, Part I", IEEE TEC 2014
17
18use super::{utils, MultiObjectiveConfig, MultiObjectiveOptimizer};
19use crate::error::OptimizeError;
20use crate::multi_objective::crossover::{CrossoverOperator, SimulatedBinaryCrossover};
21use crate::multi_objective::mutation::{MutationOperator, PolynomialMutation};
22use crate::multi_objective::selection::{SelectionOperator, TournamentSelection};
23use crate::multi_objective::solutions::{MultiObjectiveResult, MultiObjectiveSolution, Population};
24use scirs2_core::ndarray::{Array1, ArrayView1};
25use scirs2_core::random::rngs::StdRng;
26use scirs2_core::random::{Rng, RngExt, SeedableRng};
27use std::cmp::Ordering;
28
/// NSGA-III optimizer for many-objective optimization
///
/// Bundles the run configuration, the reference points used for niching, the
/// current population, the variation/selection operators, and the counters
/// that the optimization loop updates.
pub struct NSGAIII {
    /// Run configuration (population size, operator parameters, bounds, limits)
    config: MultiObjectiveConfig,
    /// Number of objectives each evaluation is required to return
    n_objectives: usize,
    /// Number of decision variables per individual
    n_variables: usize,
    /// Structured reference points (Das-Dennis unless the caller supplied its own)
    reference_points: Vec<Array1<f64>>,
    /// Current population
    population: Population,
    /// Number of completed generations
    generation: usize,
    /// Number of objective-function evaluations counted so far
    n_evaluations: usize,
    /// Random number generator, seeded from `config.random_seed` or the system clock
    rng: StdRng,
    /// Crossover operator built from `config.crossover_eta` / `config.crossover_probability`
    crossover: SimulatedBinaryCrossover,
    /// Mutation operator built from `config.mutation_probability` / `config.mutation_eta`
    mutation: PolynomialMutation,
    /// Parent selection operator (constructed with argument 2 — presumably the tournament size)
    selection: TournamentSelection,
    /// Ideal point (best value per objective)
    ideal_point: Array1<f64>,
    /// Niche count for each reference point.
    /// Zeroed on construction and on re-initialization; the niching step
    /// itself tallies counts in a local vector.
    niche_count: Vec<usize>,
    /// Hypervolume recorded after each generation (only when a reference point is configured)
    convergence_history: Vec<f64>,
}
49
50impl NSGAIII {
51    /// Create new NSGA-III optimizer with default reference points
52    pub fn new(population_size: usize, n_objectives: usize, n_variables: usize) -> Self {
53        let config = MultiObjectiveConfig {
54            population_size,
55            ..Default::default()
56        };
57        Self::with_config(config, n_objectives, n_variables, None)
58    }
59
60    /// Create NSGA-III with full configuration and optional custom reference points
61    pub fn with_config(
62        config: MultiObjectiveConfig,
63        n_objectives: usize,
64        n_variables: usize,
65        custom_reference_points: Option<Vec<Array1<f64>>>,
66    ) -> Self {
67        let seed = config.random_seed.unwrap_or_else(|| {
68            use std::time::{SystemTime, UNIX_EPOCH};
69            SystemTime::now()
70                .duration_since(UNIX_EPOCH)
71                .map(|d| d.as_secs())
72                .unwrap_or(42)
73        });
74
75        let rng = StdRng::seed_from_u64(seed);
76
77        let crossover =
78            SimulatedBinaryCrossover::new(config.crossover_eta, config.crossover_probability);
79        let mutation = PolynomialMutation::new(config.mutation_probability, config.mutation_eta);
80        let selection = TournamentSelection::new(2);
81
82        // Generate reference points
83        let reference_points = custom_reference_points.unwrap_or_else(|| {
84            let n_partitions = Self::auto_partitions(n_objectives, config.population_size);
85            utils::generate_das_dennis_points(n_objectives, n_partitions)
86        });
87
88        let niche_count = vec![0; reference_points.len()];
89
90        Self {
91            config,
92            n_objectives,
93            n_variables,
94            reference_points,
95            population: Population::new(),
96            generation: 0,
97            n_evaluations: 0,
98            rng,
99            crossover,
100            mutation,
101            selection,
102            ideal_point: Array1::from_elem(n_objectives, f64::INFINITY),
103            niche_count,
104            convergence_history: Vec::new(),
105        }
106    }
107
108    /// Return a reference to the current structured reference points.
109    pub fn reference_points(&self) -> &[Array1<f64>] {
110        &self.reference_points
111    }
112
113    /// Automatically determine the number of partitions based on objectives and pop size
114    fn auto_partitions(n_objectives: usize, pop_size: usize) -> usize {
115        // Heuristic: choose partitions so Das-Dennis points ~ pop_size
116        // Number of points = C(H + M - 1, M - 1) where H = partitions, M = objectives
117        if n_objectives <= 2 {
118            return pop_size.max(4);
119        }
120        // For higher dimensions, use fewer partitions
121        let mut p = 1;
122        loop {
123            let n_points = binomial_coefficient(p + n_objectives - 1, n_objectives - 1);
124            if n_points >= pop_size / 2 {
125                return p;
126            }
127            p += 1;
128            if p > 50 {
129                return p;
130            }
131        }
132    }
133
134    /// Evaluate a single individual
135    fn evaluate_individual<F>(
136        &mut self,
137        variables: &Array1<f64>,
138        objective_function: &F,
139    ) -> Result<Array1<f64>, OptimizeError>
140    where
141        F: Fn(&ArrayView1<f64>) -> Array1<f64> + Send + Sync,
142    {
143        self.n_evaluations += 1;
144
145        if let Some(max_evals) = self.config.max_evaluations {
146            if self.n_evaluations > max_evals {
147                return Err(OptimizeError::MaxEvaluationsReached);
148            }
149        }
150
151        let objectives = objective_function(&variables.view());
152        if objectives.len() != self.n_objectives {
153            return Err(OptimizeError::InvalidInput(format!(
154                "Expected {} objectives, got {}",
155                self.n_objectives,
156                objectives.len()
157            )));
158        }
159
160        Ok(objectives)
161    }
162
163    /// Update the ideal point from a set of solutions
164    fn update_ideal_point(&mut self, solutions: &[MultiObjectiveSolution]) {
165        for sol in solutions {
166            for (i, &obj) in sol.objectives.iter().enumerate() {
167                if obj < self.ideal_point[i] {
168                    self.ideal_point[i] = obj;
169                }
170            }
171        }
172    }
173
174    /// Normalize objectives using ideal and extreme points (ASF-based)
175    fn normalize_objectives(&self, solutions: &[MultiObjectiveSolution]) -> Vec<Array1<f64>> {
176        let n = solutions.len();
177        if n == 0 {
178            return vec![];
179        }
180
181        // Translate by ideal point
182        let translated: Vec<Array1<f64>> = solutions
183            .iter()
184            .map(|sol| &sol.objectives - &self.ideal_point)
185            .collect();
186
187        // Find extreme points using Achievement Scalarizing Function (ASF)
188        let mut extreme_points = Vec::with_capacity(self.n_objectives);
189        for j in 0..self.n_objectives {
190            let mut best_asf = f64::INFINITY;
191            let mut best_idx = 0;
192            for (i, t) in translated.iter().enumerate() {
193                // ASF: max(t_k / w_k) where w is the j-th axis weight
194                let asf = (0..self.n_objectives)
195                    .map(|k| {
196                        if k == j {
197                            t[k] // weight = 1 for target objective
198                        } else {
199                            t[k] * 1e6 // large weight for others
200                        }
201                    })
202                    .fold(f64::NEG_INFINITY, f64::max);
203                if asf < best_asf {
204                    best_asf = asf;
205                    best_idx = i;
206                }
207            }
208            extreme_points.push(translated[best_idx].clone());
209        }
210
211        // Compute intercepts from the extreme points
212        let mut intercepts = Array1::from_elem(self.n_objectives, 1.0);
213        if extreme_points.len() == self.n_objectives {
214            // Build the matrix of extreme points and solve for intercepts
215            // Simplified: use the diagonal elements as intercepts
216            for (j, ep) in extreme_points.iter().enumerate() {
217                let val = ep[j];
218                if val > 1e-10 {
219                    intercepts[j] = val;
220                }
221            }
222        }
223
224        // Normalize each translated objective
225        translated
226            .iter()
227            .map(|t| {
228                let mut normalized = Array1::zeros(self.n_objectives);
229                for j in 0..self.n_objectives {
230                    normalized[j] = if intercepts[j] > 1e-10 {
231                        t[j] / intercepts[j]
232                    } else {
233                        t[j]
234                    };
235                }
236                normalized
237            })
238            .collect()
239    }
240
241    /// Associate each solution to its nearest reference point
242    /// Returns (reference_point_index, perpendicular_distance) for each solution
243    fn associate_to_reference_points(
244        &self,
245        normalized_objectives: &[Array1<f64>],
246    ) -> Vec<(usize, f64)> {
247        normalized_objectives
248            .iter()
249            .map(|obj| {
250                let mut min_dist = f64::INFINITY;
251                let mut min_idx = 0;
252
253                for (rp_idx, rp) in self.reference_points.iter().enumerate() {
254                    let dist = perpendicular_distance(obj, rp);
255                    if dist < min_dist {
256                        min_dist = dist;
257                        min_idx = rp_idx;
258                    }
259                }
260
261                (min_idx, min_dist)
262            })
263            .collect()
264    }
265
266    /// Niching-based selection from the last front
267    fn niching_selection(
268        &mut self,
269        last_front_indices: &[usize],
270        all_solutions: &[MultiObjectiveSolution],
271        associations: &[(usize, f64)],
272        selected_so_far: &[usize],
273        remaining: usize,
274    ) -> Vec<usize> {
275        // Compute niche counts for already-selected solutions
276        let mut niche_count = vec![0usize; self.reference_points.len()];
277        for &idx in selected_so_far {
278            let (rp_idx, _) = associations[idx];
279            niche_count[rp_idx] += 1;
280        }
281
282        let mut selected = Vec::with_capacity(remaining);
283        let mut available: Vec<usize> = last_front_indices.to_vec();
284
285        for _ in 0..remaining {
286            if available.is_empty() {
287                break;
288            }
289
290            // Find the reference point with minimum niche count
291            // among those that have at least one member in the last front
292            let relevant_rps: Vec<usize> = available
293                .iter()
294                .map(|&idx| associations[idx].0)
295                .collect::<std::collections::HashSet<_>>()
296                .into_iter()
297                .collect();
298
299            if relevant_rps.is_empty() {
300                break;
301            }
302
303            let min_niche = relevant_rps
304                .iter()
305                .map(|&rp| niche_count[rp])
306                .min()
307                .unwrap_or(0);
308
309            // Collect reference points with minimum niche count
310            let min_niche_rps: Vec<usize> = relevant_rps
311                .iter()
312                .filter(|&&rp| niche_count[rp] == min_niche)
313                .copied()
314                .collect();
315
316            // Randomly pick one
317            let chosen_rp_idx = self.rng.random_range(0..min_niche_rps.len());
318            let chosen_rp = min_niche_rps[chosen_rp_idx];
319
320            // Find members of the last front associated with this reference point
321            let rp_members: Vec<usize> = available
322                .iter()
323                .filter(|&&idx| associations[idx].0 == chosen_rp)
324                .copied()
325                .collect();
326
327            if rp_members.is_empty() {
328                continue;
329            }
330
331            let chosen_member = if min_niche == 0 {
332                // Pick the one with smallest perpendicular distance
333                *rp_members
334                    .iter()
335                    .min_by(|&&a, &&b| {
336                        associations[a]
337                            .1
338                            .partial_cmp(&associations[b].1)
339                            .unwrap_or(Ordering::Equal)
340                    })
341                    .unwrap_or(&rp_members[0])
342            } else {
343                // Randomly pick
344                rp_members[self.rng.random_range(0..rp_members.len())]
345            };
346
347            selected.push(chosen_member);
348            niche_count[chosen_rp] += 1;
349            available.retain(|&idx| idx != chosen_member);
350        }
351
352        selected
353    }
354
355    /// Create offspring through crossover and mutation
356    fn create_offspring<F>(
357        &mut self,
358        objective_function: &F,
359    ) -> Result<Vec<MultiObjectiveSolution>, OptimizeError>
360    where
361        F: Fn(&ArrayView1<f64>) -> Array1<f64> + Send + Sync,
362    {
363        let mut offspring = Vec::new();
364        let pop_solutions = self.population.solutions().to_vec();
365
366        while offspring.len() < self.config.population_size {
367            let selected = self.selection.select(&pop_solutions, 2);
368            if selected.len() < 2 {
369                break;
370            }
371
372            let p1_vars = selected[0].variables.as_slice().unwrap_or(&[]);
373            let p2_vars = selected[1].variables.as_slice().unwrap_or(&[]);
374
375            let (mut c1_vars, mut c2_vars) = self.crossover.crossover(p1_vars, p2_vars);
376
377            let bounds: Vec<(f64, f64)> = if let Some((lower, upper)) = &self.config.bounds {
378                lower
379                    .iter()
380                    .zip(upper.iter())
381                    .map(|(&l, &u)| (l, u))
382                    .collect()
383            } else {
384                vec![(-1.0, 1.0); self.n_variables]
385            };
386
387            self.mutation.mutate(&mut c1_vars, &bounds);
388            self.mutation.mutate(&mut c2_vars, &bounds);
389
390            let c1_arr = Array1::from_vec(c1_vars);
391            let c1_obj = self.evaluate_individual(&c1_arr, objective_function)?;
392            offspring.push(MultiObjectiveSolution::new(c1_arr, c1_obj));
393
394            if offspring.len() < self.config.population_size {
395                let c2_arr = Array1::from_vec(c2_vars);
396                let c2_obj = self.evaluate_individual(&c2_arr, objective_function)?;
397                offspring.push(MultiObjectiveSolution::new(c2_arr, c2_obj));
398            }
399        }
400
401        Ok(offspring)
402    }
403
404    /// NSGA-III environmental selection using reference-point-based niching
405    fn environmental_selection(
406        &mut self,
407        combined: Vec<MultiObjectiveSolution>,
408    ) -> Vec<MultiObjectiveSolution> {
409        let target_size = self.config.population_size;
410        if combined.len() <= target_size {
411            return combined;
412        }
413
414        // Non-dominated sorting
415        let mut temp_pop = Population::from_solutions(combined.clone());
416        let fronts = temp_pop.non_dominated_sort();
417
418        // Fill until the critical front
419        let mut selected_indices: Vec<usize> = Vec::new();
420        let mut last_front_idx = 0;
421
422        for (fi, front) in fronts.iter().enumerate() {
423            if selected_indices.len() + front.len() <= target_size {
424                selected_indices.extend(front);
425            } else {
426                last_front_idx = fi;
427                break;
428            }
429        }
430
431        if selected_indices.len() >= target_size {
432            return selected_indices
433                .iter()
434                .take(target_size)
435                .map(|&i| combined[i].clone())
436                .collect();
437        }
438
439        // Need to select from the last (critical) front using niching
440        let remaining = target_size - selected_indices.len();
441        let last_front = &fronts[last_front_idx];
442
443        // Update ideal point
444        self.update_ideal_point(&combined);
445
446        // Normalize objectives for ALL solutions considered
447        let all_considered: Vec<usize> = selected_indices
448            .iter()
449            .chain(last_front.iter())
450            .copied()
451            .collect();
452        let all_solutions: Vec<MultiObjectiveSolution> = all_considered
453            .iter()
454            .map(|&i| combined[i].clone())
455            .collect();
456
457        let normalized = self.normalize_objectives(&all_solutions);
458        let associations = self.associate_to_reference_points(&normalized);
459
460        // Map indices back
461        let n_selected = selected_indices.len();
462        let last_front_local: Vec<usize> = (n_selected..all_solutions.len()).collect();
463        let selected_local: Vec<usize> = (0..n_selected).collect();
464
465        let niching_result = self.niching_selection(
466            &last_front_local,
467            &all_solutions,
468            &associations,
469            &selected_local,
470            remaining,
471        );
472
473        // Build final selection
474        let mut result: Vec<MultiObjectiveSolution> = selected_indices
475            .iter()
476            .map(|&i| combined[i].clone())
477            .collect();
478
479        for local_idx in niching_result {
480            let global_idx = all_considered[local_idx];
481            result.push(combined[global_idx].clone());
482        }
483
484        result
485    }
486
487    /// Calculate metrics for the current generation
488    fn calculate_metrics(&mut self) {
489        if let Some(ref_point) = &self.config.reference_point {
490            let pareto_front = self.population.extract_pareto_front();
491            let hv = utils::calculate_hypervolume(&pareto_front, ref_point);
492            self.convergence_history.push(hv);
493        }
494    }
495}
496
497impl MultiObjectiveOptimizer for NSGAIII {
498    fn optimize<F>(&mut self, objective_function: F) -> Result<MultiObjectiveResult, OptimizeError>
499    where
500        F: Fn(&ArrayView1<f64>) -> Array1<f64> + Send + Sync,
501    {
502        self.initialize_population()?;
503
504        // Generate and evaluate initial population
505        let initial_vars = utils::generate_random_population(
506            self.config.population_size,
507            self.n_variables,
508            &self.config.bounds,
509        );
510
511        let mut initial_solutions = Vec::new();
512        for vars in initial_vars {
513            let objs = self.evaluate_individual(&vars, &objective_function)?;
514            initial_solutions.push(MultiObjectiveSolution::new(vars, objs));
515        }
516
517        self.update_ideal_point(&initial_solutions);
518        self.population = Population::from_solutions(initial_solutions);
519
520        // Main evolution loop
521        while self.generation < self.config.max_generations {
522            if self.check_convergence() {
523                break;
524            }
525            self.evolve_generation(&objective_function)?;
526        }
527
528        // Extract results
529        let pareto_front = self.population.extract_pareto_front();
530        let hypervolume = self
531            .config
532            .reference_point
533            .as_ref()
534            .map(|rp| utils::calculate_hypervolume(&pareto_front, rp));
535
536        let mut result = MultiObjectiveResult::new(
537            pareto_front,
538            self.population.solutions().to_vec(),
539            self.n_evaluations,
540            self.generation,
541        );
542        result.hypervolume = hypervolume;
543        result.metrics.convergence_history = self.convergence_history.clone();
544        result.metrics.population_stats = self.population.calculate_statistics();
545
546        Ok(result)
547    }
548
549    fn evolve_generation<F>(&mut self, objective_function: &F) -> Result<(), OptimizeError>
550    where
551        F: Fn(&ArrayView1<f64>) -> Array1<f64> + Send + Sync,
552    {
553        let offspring = self.create_offspring(objective_function)?;
554
555        let mut combined = self.population.solutions().to_vec();
556        combined.extend(offspring);
557
558        let next_pop = self.environmental_selection(combined);
559        self.population = Population::from_solutions(next_pop);
560
561        self.generation += 1;
562        self.calculate_metrics();
563
564        Ok(())
565    }
566
567    fn initialize_population(&mut self) -> Result<(), OptimizeError> {
568        self.population.clear();
569        self.generation = 0;
570        self.n_evaluations = 0;
571        self.ideal_point = Array1::from_elem(self.n_objectives, f64::INFINITY);
572        self.niche_count = vec![0; self.reference_points.len()];
573        self.convergence_history.clear();
574        Ok(())
575    }
576
577    fn check_convergence(&self) -> bool {
578        if let Some(max_evals) = self.config.max_evaluations {
579            if self.n_evaluations >= max_evals {
580                return true;
581            }
582        }
583
584        if self.convergence_history.len() >= 10 {
585            let recent = &self.convergence_history[self.convergence_history.len() - 10..];
586            let max_hv = recent.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
587            let min_hv = recent.iter().fold(f64::INFINITY, |a, &b| a.min(b));
588            if (max_hv - min_hv) < self.config.tolerance {
589                return true;
590            }
591        }
592
593        false
594    }
595
    /// Borrow the current population.
    fn get_population(&self) -> &Population {
        &self.population
    }
599
    /// Number of generations completed so far.
    fn get_generation(&self) -> usize {
        self.generation
    }
603
    /// Value of the objective-function evaluation counter.
    fn get_evaluations(&self) -> usize {
        self.n_evaluations
    }
607
    /// Human-readable algorithm name.
    fn name(&self) -> &str {
        "NSGA-III"
    }
611}
612
613/// Compute perpendicular distance from a point to a reference line (through origin)
614fn perpendicular_distance(point: &Array1<f64>, direction: &Array1<f64>) -> f64 {
615    let dir_norm_sq = direction.dot(direction);
616    if dir_norm_sq < 1e-30 {
617        return point.dot(point).sqrt();
618    }
619
620    // Projection of point onto direction
621    let proj_scalar = point.dot(direction) / dir_norm_sq;
622    let proj = proj_scalar * direction;
623
624    // Distance = |point - projection|
625    let diff = point - &proj;
626    diff.dot(&diff).sqrt()
627}
628
/// Compute binomial coefficient C(n, k)
///
/// Returns `0` when `k > n`. The result is exact whenever it fits in `usize`
/// and saturates at `usize::MAX` otherwise.
///
/// The running product is kept in `u128`: each step multiplies a value that
/// fits in `usize` by a factor that fits in `usize`, so the intermediate never
/// overflows, and the exact division keeps every partial result equal to
/// `C(n, i + 1)`.
fn binomial_coefficient(n: usize, k: usize) -> usize {
    if k > n {
        return 0;
    }
    if k == 0 || k == n {
        return 1;
    }

    // Use symmetry so the partial results C(n, 1), C(n, 2), ... only grow.
    let k = k.min(n - k);
    let limit = usize::MAX as u128;

    let mut acc: u128 = 1;
    for i in 0..k {
        // Widening casts (usize -> u128) cannot truncate.
        let next = acc * (n - i) as u128 / (i as u128 + 1);
        if next > limit {
            // Partial results are non-decreasing, so the final value overflows too.
            return usize::MAX;
        }
        acc = next;
    }

    // `acc <= usize::MAX` is guaranteed by the check above.
    acc as usize
}
645
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, s};

    /// Two-objective ZDT1 benchmark (expects at least two variables).
    fn zdt1(x: &ArrayView1<f64>) -> Array1<f64> {
        let f1 = x[0];
        let g = 1.0 + 9.0 * x.slice(s![1..]).sum() / (x.len() - 1) as f64;
        let f2 = g * (1.0 - (f1 / g).sqrt());
        array![f1, f2]
    }

    /// Three-objective DTLZ1 benchmark.
    fn dtlz1(x: &ArrayView1<f64>) -> Array1<f64> {
        // 3-objective DTLZ1
        let n = x.len();
        let k = n - 2; // for 3 objectives, k = n - M + 1
        let g: f64 = (0..k)
            .map(|i| {
                let xi = x[n - k + i];
                (xi - 0.5).powi(2) - (20.0 * std::f64::consts::PI * (xi - 0.5)).cos()
            })
            .sum::<f64>();
        let g = 100.0 * (k as f64 + g);

        let f1 = 0.5 * x[0] * x[1] * (1.0 + g);
        let f2 = 0.5 * x[0] * (1.0 - x[1]) * (1.0 + g);
        let f3 = 0.5 * (1.0 - x[0]) * (1.0 + g);
        array![f1, f2, f3]
    }

    #[test]
    fn test_nsga3_creation() {
        let nsga3 = NSGAIII::new(100, 3, 5);
        assert_eq!(nsga3.n_objectives, 3);
        assert_eq!(nsga3.n_variables, 5);
        assert!(!nsga3.reference_points.is_empty());
        assert_eq!(nsga3.generation, 0);
    }

    #[test]
    fn test_nsga3_with_config() {
        let config = MultiObjectiveConfig {
            population_size: 50,
            max_generations: 10,
            random_seed: Some(42),
            ..Default::default()
        };

        let nsga3 = NSGAIII::with_config(config, 2, 3, None);
        assert_eq!(nsga3.n_objectives, 2);
        assert!(!nsga3.reference_points.is_empty());
    }

    #[test]
    fn test_nsga3_custom_reference_points() {
        let rps = vec![
            array![1.0, 0.0, 0.0],
            array![0.0, 1.0, 0.0],
            array![0.0, 0.0, 1.0],
            array![0.5, 0.5, 0.0],
            array![0.5, 0.0, 0.5],
            array![0.0, 0.5, 0.5],
            array![0.333, 0.333, 0.334],
        ];

        let config = MultiObjectiveConfig {
            population_size: 20,
            max_generations: 5,
            random_seed: Some(42),
            ..Default::default()
        };

        // `rps` is not used afterwards, so it is moved rather than cloned.
        let nsga3 = NSGAIII::with_config(config, 3, 5, Some(rps));
        assert_eq!(nsga3.reference_points.len(), 7);
    }

    #[test]
    fn test_nsga3_optimize_zdt1() {
        let config = MultiObjectiveConfig {
            max_generations: 10,
            population_size: 20,
            bounds: Some((Array1::zeros(3), Array1::ones(3))),
            random_seed: Some(42),
            ..Default::default()
        };

        let mut nsga3 = NSGAIII::with_config(config, 2, 3, None);
        let res = nsga3.optimize(zdt1).expect("should succeed");

        assert!(res.success);
        assert!(!res.pareto_front.is_empty());
        assert!(res.n_evaluations > 0);
    }

    #[test]
    fn test_nsga3_optimize_dtlz1() {
        let config = MultiObjectiveConfig {
            max_generations: 10,
            population_size: 20,
            bounds: Some((Array1::zeros(5), Array1::ones(5))),
            random_seed: Some(42),
            ..Default::default()
        };

        let mut nsga3 = NSGAIII::with_config(config, 3, 5, None);
        let res = nsga3.optimize(dtlz1).expect("should succeed");

        assert!(res.success);
        assert!(!res.pareto_front.is_empty());
    }

    #[test]
    fn test_nsga3_max_evaluations() {
        let config = MultiObjectiveConfig {
            max_generations: 1000,
            max_evaluations: Some(50),
            population_size: 10,
            bounds: Some((Array1::zeros(3), Array1::ones(3))),
            random_seed: Some(42),
            ..Default::default()
        };

        let mut nsga3 = NSGAIII::with_config(config, 2, 3, None);
        let res = nsga3.optimize(zdt1).expect("should succeed");
        assert!(res.n_evaluations <= 60); // Allow slight overshoot
    }

    #[test]
    fn test_perpendicular_distance() {
        let point = array![1.0, 1.0];
        let direction = array![1.0, 0.0];

        let dist = perpendicular_distance(&point, &direction);
        assert!(
            (dist - 1.0).abs() < 1e-10,
            "Distance should be 1.0, got {}",
            dist
        );

        // Point on the line
        let point2 = array![2.0, 0.0];
        let dist2 = perpendicular_distance(&point2, &direction);
        assert!(dist2.abs() < 1e-10, "Distance should be 0.0, got {}", dist2);
    }

    #[test]
    fn test_binomial_coefficient() {
        assert_eq!(binomial_coefficient(5, 2), 10);
        assert_eq!(binomial_coefficient(10, 3), 120);
        assert_eq!(binomial_coefficient(4, 0), 1);
        assert_eq!(binomial_coefficient(4, 4), 1);
        assert_eq!(binomial_coefficient(0, 0), 1);
    }

    #[test]
    fn test_nsga3_name() {
        let nsga3 = NSGAIII::new(50, 2, 3);
        assert_eq!(nsga3.name(), "NSGA-III");
    }

    #[test]
    fn test_nsga3_convergence_check() {
        let config = MultiObjectiveConfig {
            tolerance: 1e-10,
            max_generations: 2,
            population_size: 10,
            ..Default::default()
        };
        let nsga3 = NSGAIII::with_config(config, 2, 2, None);
        assert!(!nsga3.check_convergence());
    }
}
815}