Skip to main content

scirs2_neural/nas/
multi_objective.rs

1//! Multi-objective optimization for Neural Architecture Search
2//!
3//! This module provides multi-objective optimization capabilities for NAS,
4//! allowing optimization of multiple conflicting objectives simultaneously
5//! such as accuracy, latency, FLOPs, memory usage, and energy consumption.
6
7use crate::error::{NeuralError, Result};
8use crate::nas::{architecture_encoding::ArchitectureEncoding, EvaluationMetrics, SearchResult};
9use std::collections::HashMap;
10use std::sync::Arc;
11
/// A single optimisation objective tracked during multi-objective NAS.
///
/// `name` is the metric key the objective's value is read under, `minimize`
/// selects the optimisation direction and `weight` is the objective's share
/// when objectives are combined into one score. `target`/`tolerance` form an
/// optional constraint attached via [`Objective::with_constraint`].
#[derive(Debug, Clone)]
pub struct Objective {
    /// Metric name, e.g. `"validation_accuracy"`.
    pub name: String,
    /// `true` when smaller values are better, `false` when larger values are better.
    pub minimize: bool,
    /// Relative importance used when objectives are scalarised.
    pub weight: f64,
    /// Optional constraint target value.
    pub target: Option<f64>,
    /// Optional tolerance paired with `target` (interpreted by the constraint evaluator).
    pub tolerance: Option<f64>,
}

impl Objective {
    /// Creates an unconstrained objective.
    pub fn new(name: &str, minimize: bool, weight: f64) -> Self {
        Objective {
            name: String::from(name),
            minimize,
            weight,
            target: None,
            tolerance: None,
        }
    }

    /// Returns this objective with a `target`/`tolerance` constraint attached.
    pub fn with_constraint(self, target: f64, tolerance: f64) -> Self {
        Objective {
            target: Some(target),
            tolerance: Some(tolerance),
            ..self
        }
    }
}
39
40pub struct MultiObjectiveConfig {
41    pub objectives: Vec<Objective>,
42    pub algorithm: MultiObjectiveAlgorithm,
43    pub population_size: usize,
44    pub max_generations: usize,
45    pub pareto_front_limit: usize,
46    pub reference_point: Option<Vec<f64>>,
47}
48
49impl Default for MultiObjectiveConfig {
50    fn default() -> Self {
51        Self {
52            objectives: vec![
53                Objective::new("validation_accuracy", false, 0.4),
54                Objective::new("model_flops", true, 0.3),
55                Objective::new("model_params", true, 0.2),
56                Objective::new("inference_latency", true, 0.1),
57            ],
58            algorithm: MultiObjectiveAlgorithm::NSGA2,
59            population_size: 50,
60            max_generations: 100,
61            pareto_front_limit: 20,
62            reference_point: None,
63        }
64    }
65}
66
67pub enum MultiObjectiveAlgorithm {
68    NSGA2,
69    SPEA2,
70    MOEAD,
71    HYPERE,
72    WeightedSum,
73    ConstraintHandling,
74}
75
76pub struct MultiObjectiveSolution {
77    pub architecture: Arc<dyn ArchitectureEncoding>,
78    pub objectives: Vec<f64>,
79    pub constraint_violations: Vec<f64>,
80    pub rank: usize,
81    pub crowding_distance: f64,
82    pub dominance_count: usize,
83    pub dominated_solutions: Vec<usize>,
84}
85
86impl Clone for MultiObjectiveSolution {
87    fn clone(&self) -> Self {
88        Self {
89            architecture: self.architecture.clone(),
90            objectives: self.objectives.clone(),
91            constraint_violations: self.constraint_violations.clone(),
92            rank: self.rank,
93            crowding_distance: self.crowding_distance,
94            dominance_count: self.dominance_count,
95            dominated_solutions: self.dominated_solutions.clone(),
96        }
97    }
98}
99
100impl MultiObjectiveSolution {
101    pub fn new(architecture: Arc<dyn ArchitectureEncoding>, objectives: Vec<f64>) -> Self {
102        Self {
103            architecture,
104            objectives,
105            constraint_violations: Vec::new(),
106            rank: 0,
107            crowding_distance: 0.0,
108            dominance_count: 0,
109            dominated_solutions: Vec::new(),
110        }
111    }
112
113    pub fn dominates(&self, other: &Self, config: &MultiObjectiveConfig) -> bool {
114        let mut better = false;
115        for (i, obj) in config.objectives.iter().enumerate() {
116            if i >= self.objectives.len() || i >= other.objectives.len() {
117                continue;
118            }
119            let sv = self.objectives[i];
120            let ov = other.objectives[i];
121            if obj.minimize {
122                if sv > ov {
123                    return false;
124                } else if sv < ov {
125                    better = true;
126                }
127            } else {
128                if sv < ov {
129                    return false;
130                } else if sv > ov {
131                    better = true;
132                }
133            }
134        }
135        better
136    }
137}
138
139pub struct MultiObjectiveOptimizer {
140    config: MultiObjectiveConfig,
141    population: Vec<MultiObjectiveSolution>,
142    pareto_front: Vec<MultiObjectiveSolution>,
143    generation: usize,
144    hypervolume_history: Vec<f64>,
145}
146
147impl MultiObjectiveOptimizer {
148    pub fn new(config: MultiObjectiveConfig) -> Self {
149        Self {
150            config,
151            population: Vec::new(),
152            pareto_front: Vec::new(),
153            generation: 0,
154            hypervolume_history: Vec::new(),
155        }
156    }
157
158    pub fn initialize_population(&mut self, results: &[SearchResult]) -> Result<()> {
159        self.population.clear();
160        for result in results.iter().take(self.config.population_size) {
161            let objectives = self.extract_objectives(&result.metrics)?;
162            self.population.push(MultiObjectiveSolution::new(
163                result.architecture.clone(),
164                objectives,
165            ));
166        }
167        while self.population.len() < self.config.population_size {
168            let arch = self.generate_random_architecture()?;
169            let objs = self.estimate_random_objectives();
170            self.population
171                .push(MultiObjectiveSolution::new(arch, objs));
172        }
173        Ok(())
174    }
175
176    pub fn evolve_generation(&mut self) -> Result<()> {
177        match self.config.algorithm {
178            MultiObjectiveAlgorithm::NSGA2 => self.nsga2_step()?,
179            MultiObjectiveAlgorithm::SPEA2 => self.spea2_step()?,
180            MultiObjectiveAlgorithm::MOEAD => self.moead_step()?,
181            MultiObjectiveAlgorithm::HYPERE => self.hypere_step()?,
182            MultiObjectiveAlgorithm::WeightedSum => self.weighted_sum_step()?,
183            MultiObjectiveAlgorithm::ConstraintHandling => self.constraint_handling_step()?,
184        }
185        self.generation += 1;
186        self.update_pareto_front()?;
187        let hv = self.compute_hypervolume()?;
188        self.hypervolume_history.push(hv);
189        Ok(())
190    }
191
192    fn nsga2_step(&mut self) -> Result<()> {
193        let offspring = self.create_offspring()?;
194        let mut combined = self.population.clone();
195        combined.extend(offspring);
196        self.non_dominated_sort(&mut combined)?;
197        self.population = self.environmental_selection(combined)?;
198        Ok(())
199    }
200
201    fn spea2_step(&mut self) -> Result<()> {
202        let offspring = self.create_offspring()?;
203        let mut combined = self.population.clone();
204        combined.extend(offspring);
205        self.calculate_spea2_fitness_for_population(&mut combined)?;
206        self.population = self.spea2_environmental_selection(combined)?;
207        Ok(())
208    }
209
210    fn moead_step(&mut self) -> Result<()> {
211        let weight_vectors = self.generate_weight_vectors()?;
212        for (i, weights) in weight_vectors
213            .iter()
214            .enumerate()
215            .take(weight_vectors.len().min(self.population.len()))
216        {
217            let weights = weights.clone();
218            let new_solution = self.update_subproblem(i, &weights)?;
219            self.update_neighbors(i, &new_solution)?;
220        }
221        Ok(())
222    }
223
224    fn hypere_step(&mut self) -> Result<()> {
225        let parent_count = 10.min(self.population.len());
226        let mut offspring = Vec::new();
227        for idx in 0..parent_count {
228            let child_arch_box = self.population[idx].architecture.mutate(0.1)?;
229            let child_arch: std::sync::Arc<
230                dyn crate::nas::architecture_encoding::ArchitectureEncoding,
231            > = std::sync::Arc::from(child_arch_box);
232            let objectives = self.estimate_objectives(&child_arch)?;
233            offspring.push(MultiObjectiveSolution::new(child_arch, objectives));
234        }
235        let mut combined = self.population.clone();
236        combined.extend(offspring);
237        self.population = self.hypervolume_environmental_selection(combined)?;
238        Ok(())
239    }
240
241    fn weighted_sum_step(&mut self) -> Result<()> {
242        for solution in &mut self.population {
243            let ws: f64 = solution
244                .objectives
245                .iter()
246                .zip(self.config.objectives.iter())
247                .map(|(v, o)| v * o.weight)
248                .sum();
249            solution.objectives = vec![ws];
250        }
251        self.population.sort_by(|a, b| {
252            let ao = a.objectives.first().copied().unwrap_or(0.0);
253            let bo = b.objectives.first().copied().unwrap_or(0.0);
254            ao.partial_cmp(&bo).unwrap_or(std::cmp::Ordering::Equal)
255        });
256        let offspring = self.create_offspring()?;
257        self.population.extend(offspring);
258        self.population.truncate(self.config.population_size);
259        Ok(())
260    }
261
262    fn constraint_handling_step(&mut self) -> Result<()> {
263        let violations: Vec<Vec<f64>> = self
264            .population
265            .iter()
266            .map(|s| self.evaluate_constraints(s))
267            .collect::<Result<Vec<_>>>()?;
268        for (sol, viols) in self.population.iter_mut().zip(violations) {
269            sol.constraint_violations = viols;
270        }
271        self.population.sort_by(|a, b| {
272            let av: f64 = a.constraint_violations.iter().sum();
273            let bv: f64 = b.constraint_violations.iter().sum();
274            if (av - bv).abs() > 1e-12 {
275                av.partial_cmp(&bv).unwrap_or(std::cmp::Ordering::Equal)
276            } else {
277                a.objectives
278                    .first()
279                    .copied()
280                    .unwrap_or(0.0)
281                    .partial_cmp(&b.objectives.first().copied().unwrap_or(0.0))
282                    .unwrap_or(std::cmp::Ordering::Equal)
283            }
284        });
285        let offspring = self.create_offspring()?;
286        self.population = self.constraint_environmental_selection(offspring)?;
287        Ok(())
288    }
289
290    fn non_dominated_sort(&self, population: &mut [MultiObjectiveSolution]) -> Result<()> {
291        let n = population.len();
292        let mut dominated_by: Vec<Vec<usize>> = vec![Vec::new(); n];
293        let mut dom_counts: Vec<usize> = vec![0; n];
294
295        for i in 0..n {
296            for j in 0..n {
297                if i == j {
298                    continue;
299                }
300                if self.dominates_by_values(&population[i].objectives, &population[j].objectives) {
301                    dominated_by[i].push(j);
302                } else if self
303                    .dominates_by_values(&population[j].objectives, &population[i].objectives)
304                {
305                    dom_counts[i] += 1;
306                }
307            }
308        }
309
310        let mut first_front = Vec::new();
311        for i in 0..n {
312            population[i].dominated_solutions = dominated_by[i].clone();
313            population[i].dominance_count = dom_counts[i];
314            if dom_counts[i] == 0 {
315                population[i].rank = 0;
316                first_front.push(i);
317            }
318        }
319
320        let mut fronts = vec![first_front];
321        let mut fi = 0;
322        while fi < fronts.len() && !fronts[fi].is_empty() {
323            let mut next_front = Vec::new();
324            let current = fronts[fi].clone();
325            for &i in &current {
326                let doms = population[i].dominated_solutions.clone();
327                for &j in &doms {
328                    if population[j].dominance_count > 0 {
329                        population[j].dominance_count -= 1;
330                        if population[j].dominance_count == 0 {
331                            population[j].rank = fi + 1;
332                            next_front.push(j);
333                        }
334                    }
335                }
336            }
337            fi += 1;
338            fronts.push(next_front);
339        }
340        Ok(())
341    }
342
343    fn dominates_by_values(&self, a: &[f64], b: &[f64]) -> bool {
344        let mut better = false;
345        for (k, obj_cfg) in self.config.objectives.iter().enumerate() {
346            let oa = a.get(k).copied().unwrap_or(0.0);
347            let ob = b.get(k).copied().unwrap_or(0.0);
348            if obj_cfg.minimize {
349                if oa > ob {
350                    return false;
351                } else if oa < ob {
352                    better = true;
353                }
354            } else {
355                if oa < ob {
356                    return false;
357                } else if oa > ob {
358                    better = true;
359                }
360            }
361        }
362        better
363    }
364
365    fn calculate_crowding_distance(
366        &self,
367        front: &[usize],
368        population: &mut [MultiObjectiveSolution],
369    ) -> Result<()> {
370        if front.len() <= 2 {
371            for &i in front {
372                population[i].crowding_distance = f64::INFINITY;
373            }
374            return Ok(());
375        }
376        for &i in front {
377            population[i].crowding_distance = 0.0;
378        }
379        for obj_idx in 0..self.config.objectives.len() {
380            let mut sorted = front.to_vec();
381            sorted.sort_by(|&a, &b| {
382                let oa = population[a]
383                    .objectives
384                    .get(obj_idx)
385                    .copied()
386                    .unwrap_or(0.0);
387                let ob = population[b]
388                    .objectives
389                    .get(obj_idx)
390                    .copied()
391                    .unwrap_or(0.0);
392                oa.partial_cmp(&ob).unwrap_or(std::cmp::Ordering::Equal)
393            });
394            let first = sorted[0];
395            let last = sorted[sorted.len() - 1];
396            population[first].crowding_distance = f64::INFINITY;
397            population[last].crowding_distance = f64::INFINITY;
398            let obj_min = population[first]
399                .objectives
400                .get(obj_idx)
401                .copied()
402                .unwrap_or(0.0);
403            let obj_max = population[last]
404                .objectives
405                .get(obj_idx)
406                .copied()
407                .unwrap_or(0.0);
408            let range = obj_max - obj_min;
409            if range > 0.0 {
410                for i in 1..sorted.len() - 1 {
411                    let prev = population[sorted[i - 1]]
412                        .objectives
413                        .get(obj_idx)
414                        .copied()
415                        .unwrap_or(0.0);
416                    let next = population[sorted[i + 1]]
417                        .objectives
418                        .get(obj_idx)
419                        .copied()
420                        .unwrap_or(0.0);
421                    population[sorted[i]].crowding_distance += (next - prev) / range;
422                }
423            }
424        }
425        Ok(())
426    }
427
428    fn environmental_selection(
429        &mut self,
430        mut population: Vec<MultiObjectiveSolution>,
431    ) -> Result<Vec<MultiObjectiveSolution>> {
432        let mut result = Vec::new();
433        let mut fronts: HashMap<usize, Vec<usize>> = HashMap::new();
434        for (i, s) in population.iter().enumerate() {
435            fronts.entry(s.rank).or_default().push(i);
436        }
437        let mut current_front = 0;
438        while current_front < fronts.len() {
439            if let Some(front) = fronts.get(&current_front) {
440                if result.len() + front.len() <= self.config.population_size {
441                    for &i in front {
442                        result.push(population[i].clone());
443                    }
444                } else {
445                    self.calculate_crowding_distance(front, &mut population)?;
446                    let mut fd: Vec<(usize, f64)> = front
447                        .iter()
448                        .map(|&i| (i, population[i].crowding_distance))
449                        .collect();
450                    fd.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
451                    let remaining = self.config.population_size - result.len();
452                    for item in fd.into_iter().take(remaining) {
453                        result.push(population[item.0].clone());
454                    }
455                    break;
456                }
457            }
458            current_front += 1;
459        }
460        Ok(result)
461    }
462
463    fn create_offspring(&self) -> Result<Vec<MultiObjectiveSolution>> {
464        if self.population.is_empty() {
465            return Ok(Vec::new());
466        }
467        let mut offspring = Vec::new();
468        for _ in 0..self.config.population_size {
469            let p1 = self.tournament_selection()?;
470            let p2 = self.tournament_selection()?;
471            let child = p1.architecture.crossover(p2.architecture.as_ref())?;
472            let mutated_box = child.mutate(0.1)?;
473            let mutated: Arc<dyn ArchitectureEncoding> = Arc::from(mutated_box);
474            let objectives = self.estimate_objectives(&mutated)?;
475            offspring.push(MultiObjectiveSolution::new(mutated, objectives));
476        }
477        Ok(offspring)
478    }
479
480    fn tournament_selection(&self) -> Result<&MultiObjectiveSolution> {
481        use scirs2_core::random::prelude::*;
482        let mut rng_inst = thread_rng();
483        if self.population.is_empty() {
484            return Err(NeuralError::InvalidArgument(
485                "Population is empty".to_string(),
486            ));
487        }
488        let mut best = rng_inst.random_range(0..self.population.len());
489        for _ in 1..3 {
490            let candidate = rng_inst.random_range(0..self.population.len());
491            if self.is_better(&self.population[candidate], &self.population[best]) {
492                best = candidate;
493            }
494        }
495        Ok(&self.population[best])
496    }
497
498    fn is_better(&self, a: &MultiObjectiveSolution, b: &MultiObjectiveSolution) -> bool {
499        if a.rank < b.rank {
500            true
501        } else if a.rank > b.rank {
502            false
503        } else {
504            a.crowding_distance > b.crowding_distance
505        }
506    }
507
508    fn extract_objectives(&self, metrics: &EvaluationMetrics) -> Result<Vec<f64>> {
509        Ok(self
510            .config
511            .objectives
512            .iter()
513            .map(|o| metrics.get(&o.name).copied().unwrap_or(0.0))
514            .collect())
515    }
516
517    fn estimate_objectives(&self, _arch: &Arc<dyn ArchitectureEncoding>) -> Result<Vec<f64>> {
518        Ok(self
519            .config
520            .objectives
521            .iter()
522            .map(|o| match o.name.as_str() {
523                "validation_accuracy" => 0.7 + 0.2 * scirs2_core::random::random::<f64>(),
524                "model_flops" => 1e6 + 1e6 * scirs2_core::random::random::<f64>(),
525                "model_params" => 1e5 + 1e5 * scirs2_core::random::random::<f64>(),
526                "inference_latency" => 10.0 + 10.0 * scirs2_core::random::random::<f64>(),
527                _ => 0.5,
528            })
529            .collect())
530    }
531
532    fn generate_random_architecture(&self) -> Result<Arc<dyn ArchitectureEncoding>> {
533        use scirs2_core::random::prelude::*;
534        let mut rng_inst = thread_rng();
535        let enc = crate::nas::architecture_encoding::SequentialEncoding::random(&mut rng_inst)?;
536        Ok(Arc::new(enc) as Arc<dyn ArchitectureEncoding>)
537    }
538
539    fn estimate_random_objectives(&self) -> Vec<f64> {
540        self.config
541            .objectives
542            .iter()
543            .map(|o| match o.name.as_str() {
544                "validation_accuracy" => 0.3 + 0.4 * scirs2_core::random::random::<f64>(),
545                "model_flops" => 1e5 + 1e6 * scirs2_core::random::random::<f64>(),
546                "model_params" => 1e4 + 1e5 * scirs2_core::random::random::<f64>(),
547                "inference_latency" => 1.0 + 20.0 * scirs2_core::random::random::<f64>(),
548                _ => scirs2_core::random::random::<f64>(),
549            })
550            .collect()
551    }
552
553    fn update_pareto_front(&mut self) -> Result<()> {
554        let mut pareto_indices = Vec::new();
555        for i in 0..self.population.len() {
556            let mut dominated = false;
557            for j in 0..self.population.len() {
558                if i != j
559                    && self.dominates_by_values(
560                        &self.population[j].objectives.clone(),
561                        &self.population[i].objectives.clone(),
562                    )
563                {
564                    dominated = true;
565                    break;
566                }
567            }
568            if !dominated {
569                pareto_indices.push(i);
570            }
571        }
572        let mut pareto: Vec<MultiObjectiveSolution> = pareto_indices
573            .iter()
574            .map(|&i| self.population[i].clone())
575            .collect();
576        if pareto.len() > self.config.pareto_front_limit {
577            let indices: Vec<usize> = (0..pareto.len()).collect();
578            self.calculate_crowding_distance(&indices, &mut pareto)?;
579            pareto.sort_by(|a, b| {
580                b.crowding_distance
581                    .partial_cmp(&a.crowding_distance)
582                    .unwrap_or(std::cmp::Ordering::Equal)
583            });
584            pareto.truncate(self.config.pareto_front_limit);
585        }
586        self.pareto_front = pareto;
587        Ok(())
588    }
589
590    fn compute_hypervolume(&self) -> Result<f64> {
591        if self.pareto_front.is_empty() {
592            return Ok(0.0);
593        }
594        let rp = self
595            .config
596            .reference_point
597            .as_ref()
598            .cloned()
599            .unwrap_or_else(|| self.estimate_reference_point());
600        match self.config.objectives.len() {
601            2 => self.compute_hypervolume_2d(&rp),
602            3 => self.compute_hypervolume_3d(&rp),
603            _ => self.compute_hypervolume_monte_carlo(&rp),
604        }
605    }
606
607    fn estimate_reference_point(&self) -> Vec<f64> {
608        let n = self.config.objectives.len();
609        let mut rp = vec![0.0f64; n];
610        for (i, obj) in self.config.objectives.iter().enumerate() {
611            if obj.minimize {
612                let max_val = self
613                    .pareto_front
614                    .iter()
615                    .filter_map(|s| s.objectives.get(i).copied())
616                    .fold(f64::NEG_INFINITY, f64::max);
617                rp[i] = if max_val.is_finite() {
618                    max_val * 1.1
619                } else {
620                    1.0
621                };
622            } else {
623                let min_val = self
624                    .pareto_front
625                    .iter()
626                    .filter_map(|s| s.objectives.get(i).copied())
627                    .fold(f64::INFINITY, f64::min);
628                rp[i] = if min_val.is_finite() {
629                    min_val * 0.9
630                } else {
631                    0.0
632                };
633            }
634        }
635        rp
636    }
637
638    fn compute_hypervolume_2d(&self, rp: &[f64]) -> Result<f64> {
639        let min0 = self
640            .config
641            .objectives
642            .first()
643            .map(|o| o.minimize)
644            .unwrap_or(true);
645        let min1 = self
646            .config
647            .objectives
648            .get(1)
649            .map(|o| o.minimize)
650            .unwrap_or(true);
651        let rp0 = rp.first().copied().unwrap_or(0.0);
652        let rp1 = rp.get(1).copied().unwrap_or(0.0);
653        let mut points: Vec<(f64, f64)> = self
654            .pareto_front
655            .iter()
656            .map(|s| {
657                let v0 = s.objectives.first().copied().unwrap_or(0.0);
658                let v1 = s.objectives.get(1).copied().unwrap_or(0.0);
659                let x = if min0 {
660                    (rp0 - v0).max(0.0)
661                } else {
662                    (v0 - rp0).max(0.0)
663                };
664                let y = if min1 {
665                    (rp1 - v1).max(0.0)
666                } else {
667                    (v1 - rp1).max(0.0)
668                };
669                (x, y)
670            })
671            .collect();
672        points.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
673        let mut volume = 0.0f64;
674        let mut prev_y = 0.0f64;
675        for (x, y) in points {
676            if y > prev_y {
677                volume += x * (y - prev_y);
678                prev_y = y;
679            }
680        }
681        Ok(volume)
682    }
683
684    fn compute_hypervolume_3d(&self, rp: &[f64]) -> Result<f64> {
685        let points: Vec<[f64; 3]> = self
686            .pareto_front
687            .iter()
688            .map(|s| {
689                let mut arr = [0.0f64; 3];
690                for (i, cell) in arr.iter_mut().enumerate() {
691                    let r = rp.get(i).copied().unwrap_or(0.0);
692                    let v = s.objectives.get(i).copied().unwrap_or(0.0);
693                    *cell = if self
694                        .config
695                        .objectives
696                        .get(i)
697                        .map(|o| o.minimize)
698                        .unwrap_or(true)
699                    {
700                        (r - v).max(0.0)
701                    } else {
702                        (v - r).max(0.0)
703                    };
704                }
705                arr
706            })
707            .collect();
708        let n = points.len();
709        let mut volume = 0.0f64;
710        for p in &points {
711            volume += p[0] * p[1] * p[2];
712        }
713        for i in 0..n {
714            for j in (i + 1)..n {
715                volume -= points[i][0].min(points[j][0])
716                    * points[i][1].min(points[j][1])
717                    * points[i][2].min(points[j][2]);
718            }
719        }
720        Ok(volume.max(0.0))
721    }
722
723    fn compute_hypervolume_monte_carlo(&self, rp: &[f64]) -> Result<f64> {
724        use scirs2_core::random::prelude::*;
725        let mut rng_inst = thread_rng();
726        let num_samples = 10000usize;
727        let n_obj = self.config.objectives.len();
728        let mut lower_bounds = vec![f64::INFINITY; n_obj];
729        let upper_bounds = rp.to_vec();
730        for sol in &self.pareto_front {
731            for (i, &v) in sol.objectives.iter().enumerate() {
732                if i < n_obj {
733                    lower_bounds[i] = lower_bounds[i].min(v);
734                }
735            }
736        }
737        for (i, lb) in lower_bounds.iter_mut().enumerate() {
738            if !lb.is_finite() {
739                *lb = upper_bounds.get(i).copied().unwrap_or(0.0) - 1.0;
740            }
741        }
742        let mut dominated_count = 0usize;
743        for _ in 0..num_samples {
744            let sample: Vec<f64> = (0..n_obj)
745                .map(|i| {
746                    let lo = lower_bounds[i];
747                    let hi = upper_bounds.get(i).copied().unwrap_or(lo + 1.0);
748                    if hi > lo {
749                        lo + rng_inst.random::<f64>() * (hi - lo)
750                    } else {
751                        lo
752                    }
753                })
754                .collect();
755            let mut is_dominated = false;
756            'outer: for sol in &self.pareto_front {
757                let mut dom = true;
758                let mut better = false;
759                for (i, (&sv, &pv)) in sol.objectives.iter().zip(sample.iter()).enumerate() {
760                    let min = self
761                        .config
762                        .objectives
763                        .get(i)
764                        .map(|o| o.minimize)
765                        .unwrap_or(true);
766                    if min {
767                        if sv > pv {
768                            dom = false;
769                            break;
770                        } else if sv < pv {
771                            better = true;
772                        }
773                    } else {
774                        if sv < pv {
775                            dom = false;
776                            break;
777                        } else if sv > pv {
778                            better = true;
779                        }
780                    }
781                }
782                if dom && better {
783                    is_dominated = true;
784                    break 'outer;
785                }
786            }
787            if is_dominated {
788                dominated_count += 1;
789            }
790        }
791        let sampling_vol: f64 = upper_bounds
792            .iter()
793            .zip(lower_bounds.iter())
794            .map(|(u, l)| (u - l).max(0.0))
795            .product();
796        Ok(sampling_vol * (dominated_count as f64 / num_samples as f64))
797    }
798
799    pub fn get_pareto_front(&self) -> &[MultiObjectiveSolution] {
800        &self.pareto_front
801    }
802    pub fn get_hypervolume_history(&self) -> &[f64] {
803        &self.hypervolume_history
804    }
805    pub fn get_generation(&self) -> usize {
806        self.generation
807    }
808
809    fn calculate_spea2_fitness_for_population(
810        &self,
811        population: &mut [MultiObjectiveSolution],
812    ) -> Result<()> {
813        let n = population.len();
814        let mut strengths = vec![0usize; n];
815        let mut raw_fitness = vec![0.0f64; n];
816        let mut densities = vec![0.0f64; n];
817        for i in 0..n {
818            let mut count = 0;
819            for j in 0..n {
820                if i != j {
821                    let oi: f64 = population[i].objectives.iter().sum();
822                    let oj: f64 = population[j].objectives.iter().sum();
823                    if oi < oj {
824                        count += 1;
825                    }
826                }
827            }
828            strengths[i] = count;
829        }
830        for i in 0..n {
831            let mut fitness = 0.0;
832            for j in 0..n {
833                if i != j {
834                    let oi: f64 = population[i].objectives.iter().sum();
835                    let oj: f64 = population[j].objectives.iter().sum();
836                    if oj < oi {
837                        fitness += strengths[j] as f64;
838                    }
839                }
840            }
841            raw_fitness[i] = fitness;
842        }
843        let k = (n as f64).sqrt() as usize;
844        for i in 0..n {
845            let mut dists: Vec<f64> = (0..n)
846                .filter(|&j| j != i)
847                .map(|j| self.euclidean_distance(&population[i], &population[j]))
848                .collect();
849            dists.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
850            let kth = if k > 0 && k <= dists.len() {
851                dists[k - 1]
852            } else {
853                dists.last().copied().unwrap_or(0.0)
854            };
855            densities[i] = 1.0 / (kth + 2.0);
856        }
857        for i in 0..n {
858            population[i].crowding_distance = raw_fitness[i] + densities[i];
859        }
860        Ok(())
861    }
862
863    fn spea2_environmental_selection(
864        &self,
865        mut population: Vec<MultiObjectiveSolution>,
866    ) -> Result<Vec<MultiObjectiveSolution>> {
867        population.sort_by(|a, b| {
868            a.crowding_distance
869                .partial_cmp(&b.crowding_distance)
870                .unwrap_or(std::cmp::Ordering::Equal)
871        });
872        let mut selected = Vec::new();
873        for sol in &population {
874            if sol.crowding_distance < 1.0 && selected.len() < self.config.population_size {
875                selected.push(sol.clone());
876            }
877        }
878        if selected.len() < self.config.population_size {
879            for sol in &population {
880                if sol.crowding_distance >= 1.0 && selected.len() < self.config.population_size {
881                    selected.push(sol.clone());
882                }
883            }
884        }
885        selected.truncate(self.config.population_size);
886        Ok(selected)
887    }
888
889    fn euclidean_distance(&self, a: &MultiObjectiveSolution, b: &MultiObjectiveSolution) -> f64 {
890        a.objectives
891            .iter()
892            .zip(b.objectives.iter())
893            .map(|(x, y)| (x - y).powi(2))
894            .sum::<f64>()
895            .sqrt()
896    }
897
898    fn generate_weight_vectors(&self) -> Result<Vec<Vec<f64>>> {
899        let n_obj = self.config.objectives.len();
900        let n_weights = self.config.population_size;
901        let mut weights = Vec::new();
902        if n_obj == 2 {
903            for i in 0..n_weights {
904                let w1 = i as f64 / (n_weights - 1).max(1) as f64;
905                weights.push(vec![w1, 1.0 - w1]);
906            }
907        } else {
908            while weights.len() < n_weights {
909                let raw: Vec<f64> = (0..n_obj)
910                    .map(|_| scirs2_core::random::random::<f64>())
911                    .collect();
912                let sum: f64 = raw.iter().sum();
913                if sum > 1e-12 {
914                    weights.push(raw.iter().map(|w| w / sum).collect());
915                }
916            }
917        }
918        weights.truncate(n_weights);
919        Ok(weights)
920    }
921
922    fn update_subproblem(&self, index: usize, weights: &[f64]) -> Result<MultiObjectiveSolution> {
923        if index >= self.population.len() {
924            return Err(NeuralError::InvalidArgument(
925                "Subproblem index out of bounds".to_string(),
926            ));
927        }
928        let current = &self.population[index];
929        let neighbor = self.select_neighbor(index)?;
930        let p2 = &self.population[neighbor];
931        let child = current.architecture.crossover(p2.architecture.as_ref())?;
932        let mutated_box = child.mutate(0.1)?;
933        let mutated: Arc<dyn ArchitectureEncoding> = Arc::from(mutated_box);
934        let objectives = self.estimate_objectives(&mutated)?;
935        let mut child_sol = MultiObjectiveSolution::new(mutated, objectives);
936        let cur_fit = self.tchebycheff_fitness(&current.objectives, weights);
937        let child_fit = self.tchebycheff_fitness(&child_sol.objectives, weights);
938        if child_fit < cur_fit {
939            child_sol.crowding_distance = child_fit;
940            Ok(child_sol)
941        } else {
942            let mut cur_clone = current.clone();
943            cur_clone.crowding_distance = cur_fit;
944            Ok(cur_clone)
945        }
946    }
947
948    fn select_neighbor(&self, index: usize) -> Result<usize> {
949        use scirs2_core::random::prelude::*;
950        let mut rng_inst = thread_rng();
951        if self.population.len() <= 1 {
952            return Ok(0);
953        }
954        let nbhood = 10.min(self.population.len());
955        let start = index.saturating_sub(nbhood / 2);
956        let end = (index + nbhood / 2).min(self.population.len() - 1);
957        if end <= start {
958            return Ok(if index > 0 { index - 1 } else { 0 });
959        }
960        let ni = rng_inst.random_range(start..=end);
961        if ni == index && end > start {
962            Ok(if ni == start { end } else { start })
963        } else {
964            Ok(ni)
965        }
966    }
967
968    fn tchebycheff_fitness(&self, objectives: &[f64], weights: &[f64]) -> f64 {
969        let mut max_diff = 0.0f64;
970        for (i, (&v, &w)) in objectives.iter().zip(weights.iter()).enumerate() {
971            let ideal = if self
972                .config
973                .objectives
974                .get(i)
975                .map(|o| o.minimize)
976                .unwrap_or(true)
977            {
978                0.0
979            } else {
980                1.0
981            };
982            max_diff = max_diff.max(w * (v - ideal).abs());
983        }
984        max_diff
985    }
986
987    fn update_neighbors(&mut self, index: usize, solution: &MultiObjectiveSolution) -> Result<()> {
988        let nbhood = 10.min(self.population.len());
989        let start = index.saturating_sub(nbhood / 2);
990        let end = (index + nbhood / 2).min(self.population.len());
991        let wvecs = self.generate_weight_vectors()?;
992        for i in start..end {
993            if i != index && i < wvecs.len() {
994                let w = wvecs[i].clone();
995                let cur_fit = self.tchebycheff_fitness(&self.population[i].objectives, &w);
996                let new_fit = self.tchebycheff_fitness(&solution.objectives, &w);
997                if new_fit < cur_fit {
998                    self.population[i] = solution.clone();
999                }
1000            }
1001        }
1002        Ok(())
1003    }
1004
1005    fn hypervolume_environmental_selection(
1006        &self,
1007        mut combined: Vec<MultiObjectiveSolution>,
1008    ) -> Result<Vec<MultiObjectiveSolution>> {
1009        combined.sort_by(|a, b| {
1010            b.crowding_distance
1011                .partial_cmp(&a.crowding_distance)
1012                .unwrap_or(std::cmp::Ordering::Equal)
1013        });
1014        combined.truncate(self.config.population_size);
1015        Ok(combined)
1016    }
1017
1018    fn evaluate_constraints(&self, solution: &MultiObjectiveSolution) -> Result<Vec<f64>> {
1019        let mut violations = Vec::new();
1020        for (i, obj) in self.config.objectives.iter().enumerate() {
1021            if let (Some(target), Some(tol)) = (obj.target, obj.tolerance) {
1022                let v = solution.objectives.get(i).copied().unwrap_or(0.0);
1023                violations.push(((v - target).abs() - tol).max(0.0));
1024            }
1025        }
1026        Ok(violations)
1027    }
1028
1029    fn constraint_environmental_selection(
1030        &self,
1031        mut offspring: Vec<MultiObjectiveSolution>,
1032    ) -> Result<Vec<MultiObjectiveSolution>> {
1033        let mut combined = self.population.clone();
1034        combined.append(&mut offspring);
1035        combined.sort_by(|a, b| {
1036            let av: f64 = a.constraint_violations.iter().sum();
1037            let bv: f64 = b.constraint_violations.iter().sum();
1038            if (av - bv).abs() > 1e-12 {
1039                av.partial_cmp(&bv).unwrap_or(std::cmp::Ordering::Equal)
1040            } else {
1041                a.objectives
1042                    .first()
1043                    .copied()
1044                    .unwrap_or(0.0)
1045                    .partial_cmp(&b.objectives.first().copied().unwrap_or(0.0))
1046                    .unwrap_or(std::cmp::Ordering::Equal)
1047            }
1048        });
1049        combined.truncate(self.config.population_size);
1050        Ok(combined)
1051    }
1052}
1053
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nas::architecture_encoding::SequentialEncoding;

    /// Builds an empty sequential architecture for solution fixtures.
    fn empty_architecture() -> Arc<SequentialEncoding> {
        Arc::new(SequentialEncoding::new(vec![]))
    }

    #[test]
    fn test_multi_objective_config() {
        let defaults = MultiObjectiveConfig::default();
        assert_eq!(defaults.objectives.len(), 4);
        assert_eq!(defaults.population_size, 50);
    }

    #[test]
    fn test_solution_dominance() {
        let config = MultiObjectiveConfig::default();
        // One is more accurate but larger/slower, the other smaller/faster but less
        // accurate: each wins somewhere, so neither dominates.
        let accurate =
            MultiObjectiveSolution::new(empty_architecture(), vec![0.9, 1000.0, 500.0, 5.0]);
        let compact =
            MultiObjectiveSolution::new(empty_architecture(), vec![0.8, 500.0, 250.0, 2.5]);
        assert!(!accurate.dominates(&compact, &config));
        assert!(!compact.dominates(&accurate, &config));
    }

    #[test]
    fn test_optimizer_creation() {
        let fresh = MultiObjectiveOptimizer::new(MultiObjectiveConfig::default());
        assert_eq!(fresh.generation, 0);
        assert!(fresh.pareto_front.is_empty());
    }
}
1087}