1use crate::error::{NeuralError, Result};
8use crate::nas::{architecture_encoding::ArchitectureEncoding, EvaluationMetrics, SearchResult};
9use std::collections::HashMap;
10use std::sync::Arc;
11
/// One optimisation objective tracked by the multi-objective search.
///
/// The optimizer looks the objective up in evaluation metrics by `name`,
/// scalarises it with `weight`, and, when both `target` and `tolerance` are
/// set, treats it as a soft constraint.
#[derive(Debug, Clone)]
pub struct Objective {
    /// Metric name, also used as the key when reading evaluation metrics.
    pub name: String,
    /// `true` when smaller values are better; `false` when larger values are better.
    pub minimize: bool,
    /// Relative importance of this objective in weighted (scalarised) scores.
    pub weight: f64,
    /// Desired value for constraint handling, if any.
    pub target: Option<f64>,
    /// Absolute deviation from `target` tolerated before a violation is counted.
    pub tolerance: Option<f64>,
}
21
22impl Objective {
23 pub fn new(name: &str, minimize: bool, weight: f64) -> Self {
24 Self {
25 name: name.to_string(),
26 minimize,
27 weight,
28 target: None,
29 tolerance: None,
30 }
31 }
32
33 pub fn with_constraint(mut self, target: f64, tolerance: f64) -> Self {
34 self.target = Some(target);
35 self.tolerance = Some(tolerance);
36 self
37 }
38}
39
40pub struct MultiObjectiveConfig {
41 pub objectives: Vec<Objective>,
42 pub algorithm: MultiObjectiveAlgorithm,
43 pub population_size: usize,
44 pub max_generations: usize,
45 pub pareto_front_limit: usize,
46 pub reference_point: Option<Vec<f64>>,
47}
48
49impl Default for MultiObjectiveConfig {
50 fn default() -> Self {
51 Self {
52 objectives: vec![
53 Objective::new("validation_accuracy", false, 0.4),
54 Objective::new("model_flops", true, 0.3),
55 Objective::new("model_params", true, 0.2),
56 Objective::new("inference_latency", true, 0.1),
57 ],
58 algorithm: MultiObjectiveAlgorithm::NSGA2,
59 population_size: 50,
60 max_generations: 100,
61 pareto_front_limit: 20,
62 reference_point: None,
63 }
64 }
65}
66
/// Selection scheme used by `MultiObjectiveOptimizer::evolve_generation`.
///
/// The variant names are part of the public API, so the acronym spelling is kept.
#[allow(clippy::upper_case_acronyms)] // public variant names; renaming would break callers
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MultiObjectiveAlgorithm {
    /// Non-dominated sorting with crowding-distance truncation.
    NSGA2,
    /// Strength-Pareto fitness (strength, raw fitness and k-NN density).
    SPEA2,
    /// Decomposition into Tchebycheff subproblems.
    MOEAD,
    /// Hypervolume-oriented step with a small number of mutated offspring.
    HYPERE,
    /// Scalarisation with the configured objective weights.
    WeightedSum,
    /// Feasibility-first selection using objective targets and tolerances.
    ConstraintHandling,
}
75
76pub struct MultiObjectiveSolution {
77 pub architecture: Arc<dyn ArchitectureEncoding>,
78 pub objectives: Vec<f64>,
79 pub constraint_violations: Vec<f64>,
80 pub rank: usize,
81 pub crowding_distance: f64,
82 pub dominance_count: usize,
83 pub dominated_solutions: Vec<usize>,
84}
85
86impl Clone for MultiObjectiveSolution {
87 fn clone(&self) -> Self {
88 Self {
89 architecture: self.architecture.clone(),
90 objectives: self.objectives.clone(),
91 constraint_violations: self.constraint_violations.clone(),
92 rank: self.rank,
93 crowding_distance: self.crowding_distance,
94 dominance_count: self.dominance_count,
95 dominated_solutions: self.dominated_solutions.clone(),
96 }
97 }
98}
99
100impl MultiObjectiveSolution {
101 pub fn new(architecture: Arc<dyn ArchitectureEncoding>, objectives: Vec<f64>) -> Self {
102 Self {
103 architecture,
104 objectives,
105 constraint_violations: Vec::new(),
106 rank: 0,
107 crowding_distance: 0.0,
108 dominance_count: 0,
109 dominated_solutions: Vec::new(),
110 }
111 }
112
113 pub fn dominates(&self, other: &Self, config: &MultiObjectiveConfig) -> bool {
114 let mut better = false;
115 for (i, obj) in config.objectives.iter().enumerate() {
116 if i >= self.objectives.len() || i >= other.objectives.len() {
117 continue;
118 }
119 let sv = self.objectives[i];
120 let ov = other.objectives[i];
121 if obj.minimize {
122 if sv > ov {
123 return false;
124 } else if sv < ov {
125 better = true;
126 }
127 } else {
128 if sv < ov {
129 return false;
130 } else if sv > ov {
131 better = true;
132 }
133 }
134 }
135 better
136 }
137}
138
139pub struct MultiObjectiveOptimizer {
140 config: MultiObjectiveConfig,
141 population: Vec<MultiObjectiveSolution>,
142 pareto_front: Vec<MultiObjectiveSolution>,
143 generation: usize,
144 hypervolume_history: Vec<f64>,
145}
146
147impl MultiObjectiveOptimizer {
148 pub fn new(config: MultiObjectiveConfig) -> Self {
149 Self {
150 config,
151 population: Vec::new(),
152 pareto_front: Vec::new(),
153 generation: 0,
154 hypervolume_history: Vec::new(),
155 }
156 }
157
158 pub fn initialize_population(&mut self, results: &[SearchResult]) -> Result<()> {
159 self.population.clear();
160 for result in results.iter().take(self.config.population_size) {
161 let objectives = self.extract_objectives(&result.metrics)?;
162 self.population.push(MultiObjectiveSolution::new(
163 result.architecture.clone(),
164 objectives,
165 ));
166 }
167 while self.population.len() < self.config.population_size {
168 let arch = self.generate_random_architecture()?;
169 let objs = self.estimate_random_objectives();
170 self.population
171 .push(MultiObjectiveSolution::new(arch, objs));
172 }
173 Ok(())
174 }
175
176 pub fn evolve_generation(&mut self) -> Result<()> {
177 match self.config.algorithm {
178 MultiObjectiveAlgorithm::NSGA2 => self.nsga2_step()?,
179 MultiObjectiveAlgorithm::SPEA2 => self.spea2_step()?,
180 MultiObjectiveAlgorithm::MOEAD => self.moead_step()?,
181 MultiObjectiveAlgorithm::HYPERE => self.hypere_step()?,
182 MultiObjectiveAlgorithm::WeightedSum => self.weighted_sum_step()?,
183 MultiObjectiveAlgorithm::ConstraintHandling => self.constraint_handling_step()?,
184 }
185 self.generation += 1;
186 self.update_pareto_front()?;
187 let hv = self.compute_hypervolume()?;
188 self.hypervolume_history.push(hv);
189 Ok(())
190 }
191
192 fn nsga2_step(&mut self) -> Result<()> {
193 let offspring = self.create_offspring()?;
194 let mut combined = self.population.clone();
195 combined.extend(offspring);
196 self.non_dominated_sort(&mut combined)?;
197 self.population = self.environmental_selection(combined)?;
198 Ok(())
199 }
200
201 fn spea2_step(&mut self) -> Result<()> {
202 let offspring = self.create_offspring()?;
203 let mut combined = self.population.clone();
204 combined.extend(offspring);
205 self.calculate_spea2_fitness_for_population(&mut combined)?;
206 self.population = self.spea2_environmental_selection(combined)?;
207 Ok(())
208 }
209
210 fn moead_step(&mut self) -> Result<()> {
211 let weight_vectors = self.generate_weight_vectors()?;
212 for (i, weights) in weight_vectors
213 .iter()
214 .enumerate()
215 .take(weight_vectors.len().min(self.population.len()))
216 {
217 let weights = weights.clone();
218 let new_solution = self.update_subproblem(i, &weights)?;
219 self.update_neighbors(i, &new_solution)?;
220 }
221 Ok(())
222 }
223
224 fn hypere_step(&mut self) -> Result<()> {
225 let parent_count = 10.min(self.population.len());
226 let mut offspring = Vec::new();
227 for idx in 0..parent_count {
228 let child_arch_box = self.population[idx].architecture.mutate(0.1)?;
229 let child_arch: std::sync::Arc<
230 dyn crate::nas::architecture_encoding::ArchitectureEncoding,
231 > = std::sync::Arc::from(child_arch_box);
232 let objectives = self.estimate_objectives(&child_arch)?;
233 offspring.push(MultiObjectiveSolution::new(child_arch, objectives));
234 }
235 let mut combined = self.population.clone();
236 combined.extend(offspring);
237 self.population = self.hypervolume_environmental_selection(combined)?;
238 Ok(())
239 }
240
241 fn weighted_sum_step(&mut self) -> Result<()> {
242 for solution in &mut self.population {
243 let ws: f64 = solution
244 .objectives
245 .iter()
246 .zip(self.config.objectives.iter())
247 .map(|(v, o)| v * o.weight)
248 .sum();
249 solution.objectives = vec![ws];
250 }
251 self.population.sort_by(|a, b| {
252 let ao = a.objectives.first().copied().unwrap_or(0.0);
253 let bo = b.objectives.first().copied().unwrap_or(0.0);
254 ao.partial_cmp(&bo).unwrap_or(std::cmp::Ordering::Equal)
255 });
256 let offspring = self.create_offspring()?;
257 self.population.extend(offspring);
258 self.population.truncate(self.config.population_size);
259 Ok(())
260 }
261
262 fn constraint_handling_step(&mut self) -> Result<()> {
263 let violations: Vec<Vec<f64>> = self
264 .population
265 .iter()
266 .map(|s| self.evaluate_constraints(s))
267 .collect::<Result<Vec<_>>>()?;
268 for (sol, viols) in self.population.iter_mut().zip(violations) {
269 sol.constraint_violations = viols;
270 }
271 self.population.sort_by(|a, b| {
272 let av: f64 = a.constraint_violations.iter().sum();
273 let bv: f64 = b.constraint_violations.iter().sum();
274 if (av - bv).abs() > 1e-12 {
275 av.partial_cmp(&bv).unwrap_or(std::cmp::Ordering::Equal)
276 } else {
277 a.objectives
278 .first()
279 .copied()
280 .unwrap_or(0.0)
281 .partial_cmp(&b.objectives.first().copied().unwrap_or(0.0))
282 .unwrap_or(std::cmp::Ordering::Equal)
283 }
284 });
285 let offspring = self.create_offspring()?;
286 self.population = self.constraint_environmental_selection(offspring)?;
287 Ok(())
288 }
289
290 fn non_dominated_sort(&self, population: &mut [MultiObjectiveSolution]) -> Result<()> {
291 let n = population.len();
292 let mut dominated_by: Vec<Vec<usize>> = vec![Vec::new(); n];
293 let mut dom_counts: Vec<usize> = vec![0; n];
294
295 for i in 0..n {
296 for j in 0..n {
297 if i == j {
298 continue;
299 }
300 if self.dominates_by_values(&population[i].objectives, &population[j].objectives) {
301 dominated_by[i].push(j);
302 } else if self
303 .dominates_by_values(&population[j].objectives, &population[i].objectives)
304 {
305 dom_counts[i] += 1;
306 }
307 }
308 }
309
310 let mut first_front = Vec::new();
311 for i in 0..n {
312 population[i].dominated_solutions = dominated_by[i].clone();
313 population[i].dominance_count = dom_counts[i];
314 if dom_counts[i] == 0 {
315 population[i].rank = 0;
316 first_front.push(i);
317 }
318 }
319
320 let mut fronts = vec![first_front];
321 let mut fi = 0;
322 while fi < fronts.len() && !fronts[fi].is_empty() {
323 let mut next_front = Vec::new();
324 let current = fronts[fi].clone();
325 for &i in ¤t {
326 let doms = population[i].dominated_solutions.clone();
327 for &j in &doms {
328 if population[j].dominance_count > 0 {
329 population[j].dominance_count -= 1;
330 if population[j].dominance_count == 0 {
331 population[j].rank = fi + 1;
332 next_front.push(j);
333 }
334 }
335 }
336 }
337 fi += 1;
338 fronts.push(next_front);
339 }
340 Ok(())
341 }
342
343 fn dominates_by_values(&self, a: &[f64], b: &[f64]) -> bool {
344 let mut better = false;
345 for (k, obj_cfg) in self.config.objectives.iter().enumerate() {
346 let oa = a.get(k).copied().unwrap_or(0.0);
347 let ob = b.get(k).copied().unwrap_or(0.0);
348 if obj_cfg.minimize {
349 if oa > ob {
350 return false;
351 } else if oa < ob {
352 better = true;
353 }
354 } else {
355 if oa < ob {
356 return false;
357 } else if oa > ob {
358 better = true;
359 }
360 }
361 }
362 better
363 }
364
365 fn calculate_crowding_distance(
366 &self,
367 front: &[usize],
368 population: &mut [MultiObjectiveSolution],
369 ) -> Result<()> {
370 if front.len() <= 2 {
371 for &i in front {
372 population[i].crowding_distance = f64::INFINITY;
373 }
374 return Ok(());
375 }
376 for &i in front {
377 population[i].crowding_distance = 0.0;
378 }
379 for obj_idx in 0..self.config.objectives.len() {
380 let mut sorted = front.to_vec();
381 sorted.sort_by(|&a, &b| {
382 let oa = population[a]
383 .objectives
384 .get(obj_idx)
385 .copied()
386 .unwrap_or(0.0);
387 let ob = population[b]
388 .objectives
389 .get(obj_idx)
390 .copied()
391 .unwrap_or(0.0);
392 oa.partial_cmp(&ob).unwrap_or(std::cmp::Ordering::Equal)
393 });
394 let first = sorted[0];
395 let last = sorted[sorted.len() - 1];
396 population[first].crowding_distance = f64::INFINITY;
397 population[last].crowding_distance = f64::INFINITY;
398 let obj_min = population[first]
399 .objectives
400 .get(obj_idx)
401 .copied()
402 .unwrap_or(0.0);
403 let obj_max = population[last]
404 .objectives
405 .get(obj_idx)
406 .copied()
407 .unwrap_or(0.0);
408 let range = obj_max - obj_min;
409 if range > 0.0 {
410 for i in 1..sorted.len() - 1 {
411 let prev = population[sorted[i - 1]]
412 .objectives
413 .get(obj_idx)
414 .copied()
415 .unwrap_or(0.0);
416 let next = population[sorted[i + 1]]
417 .objectives
418 .get(obj_idx)
419 .copied()
420 .unwrap_or(0.0);
421 population[sorted[i]].crowding_distance += (next - prev) / range;
422 }
423 }
424 }
425 Ok(())
426 }
427
428 fn environmental_selection(
429 &mut self,
430 mut population: Vec<MultiObjectiveSolution>,
431 ) -> Result<Vec<MultiObjectiveSolution>> {
432 let mut result = Vec::new();
433 let mut fronts: HashMap<usize, Vec<usize>> = HashMap::new();
434 for (i, s) in population.iter().enumerate() {
435 fronts.entry(s.rank).or_default().push(i);
436 }
437 let mut current_front = 0;
438 while current_front < fronts.len() {
439 if let Some(front) = fronts.get(¤t_front) {
440 if result.len() + front.len() <= self.config.population_size {
441 for &i in front {
442 result.push(population[i].clone());
443 }
444 } else {
445 self.calculate_crowding_distance(front, &mut population)?;
446 let mut fd: Vec<(usize, f64)> = front
447 .iter()
448 .map(|&i| (i, population[i].crowding_distance))
449 .collect();
450 fd.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
451 let remaining = self.config.population_size - result.len();
452 for item in fd.into_iter().take(remaining) {
453 result.push(population[item.0].clone());
454 }
455 break;
456 }
457 }
458 current_front += 1;
459 }
460 Ok(result)
461 }
462
463 fn create_offspring(&self) -> Result<Vec<MultiObjectiveSolution>> {
464 if self.population.is_empty() {
465 return Ok(Vec::new());
466 }
467 let mut offspring = Vec::new();
468 for _ in 0..self.config.population_size {
469 let p1 = self.tournament_selection()?;
470 let p2 = self.tournament_selection()?;
471 let child = p1.architecture.crossover(p2.architecture.as_ref())?;
472 let mutated_box = child.mutate(0.1)?;
473 let mutated: Arc<dyn ArchitectureEncoding> = Arc::from(mutated_box);
474 let objectives = self.estimate_objectives(&mutated)?;
475 offspring.push(MultiObjectiveSolution::new(mutated, objectives));
476 }
477 Ok(offspring)
478 }
479
480 fn tournament_selection(&self) -> Result<&MultiObjectiveSolution> {
481 use scirs2_core::random::prelude::*;
482 let mut rng_inst = thread_rng();
483 if self.population.is_empty() {
484 return Err(NeuralError::InvalidArgument(
485 "Population is empty".to_string(),
486 ));
487 }
488 let mut best = rng_inst.random_range(0..self.population.len());
489 for _ in 1..3 {
490 let candidate = rng_inst.random_range(0..self.population.len());
491 if self.is_better(&self.population[candidate], &self.population[best]) {
492 best = candidate;
493 }
494 }
495 Ok(&self.population[best])
496 }
497
498 fn is_better(&self, a: &MultiObjectiveSolution, b: &MultiObjectiveSolution) -> bool {
499 if a.rank < b.rank {
500 true
501 } else if a.rank > b.rank {
502 false
503 } else {
504 a.crowding_distance > b.crowding_distance
505 }
506 }
507
508 fn extract_objectives(&self, metrics: &EvaluationMetrics) -> Result<Vec<f64>> {
509 Ok(self
510 .config
511 .objectives
512 .iter()
513 .map(|o| metrics.get(&o.name).copied().unwrap_or(0.0))
514 .collect())
515 }
516
517 fn estimate_objectives(&self, _arch: &Arc<dyn ArchitectureEncoding>) -> Result<Vec<f64>> {
518 Ok(self
519 .config
520 .objectives
521 .iter()
522 .map(|o| match o.name.as_str() {
523 "validation_accuracy" => 0.7 + 0.2 * scirs2_core::random::random::<f64>(),
524 "model_flops" => 1e6 + 1e6 * scirs2_core::random::random::<f64>(),
525 "model_params" => 1e5 + 1e5 * scirs2_core::random::random::<f64>(),
526 "inference_latency" => 10.0 + 10.0 * scirs2_core::random::random::<f64>(),
527 _ => 0.5,
528 })
529 .collect())
530 }
531
532 fn generate_random_architecture(&self) -> Result<Arc<dyn ArchitectureEncoding>> {
533 use scirs2_core::random::prelude::*;
534 let mut rng_inst = thread_rng();
535 let enc = crate::nas::architecture_encoding::SequentialEncoding::random(&mut rng_inst)?;
536 Ok(Arc::new(enc) as Arc<dyn ArchitectureEncoding>)
537 }
538
539 fn estimate_random_objectives(&self) -> Vec<f64> {
540 self.config
541 .objectives
542 .iter()
543 .map(|o| match o.name.as_str() {
544 "validation_accuracy" => 0.3 + 0.4 * scirs2_core::random::random::<f64>(),
545 "model_flops" => 1e5 + 1e6 * scirs2_core::random::random::<f64>(),
546 "model_params" => 1e4 + 1e5 * scirs2_core::random::random::<f64>(),
547 "inference_latency" => 1.0 + 20.0 * scirs2_core::random::random::<f64>(),
548 _ => scirs2_core::random::random::<f64>(),
549 })
550 .collect()
551 }
552
553 fn update_pareto_front(&mut self) -> Result<()> {
554 let mut pareto_indices = Vec::new();
555 for i in 0..self.population.len() {
556 let mut dominated = false;
557 for j in 0..self.population.len() {
558 if i != j
559 && self.dominates_by_values(
560 &self.population[j].objectives.clone(),
561 &self.population[i].objectives.clone(),
562 )
563 {
564 dominated = true;
565 break;
566 }
567 }
568 if !dominated {
569 pareto_indices.push(i);
570 }
571 }
572 let mut pareto: Vec<MultiObjectiveSolution> = pareto_indices
573 .iter()
574 .map(|&i| self.population[i].clone())
575 .collect();
576 if pareto.len() > self.config.pareto_front_limit {
577 let indices: Vec<usize> = (0..pareto.len()).collect();
578 self.calculate_crowding_distance(&indices, &mut pareto)?;
579 pareto.sort_by(|a, b| {
580 b.crowding_distance
581 .partial_cmp(&a.crowding_distance)
582 .unwrap_or(std::cmp::Ordering::Equal)
583 });
584 pareto.truncate(self.config.pareto_front_limit);
585 }
586 self.pareto_front = pareto;
587 Ok(())
588 }
589
590 fn compute_hypervolume(&self) -> Result<f64> {
591 if self.pareto_front.is_empty() {
592 return Ok(0.0);
593 }
594 let rp = self
595 .config
596 .reference_point
597 .as_ref()
598 .cloned()
599 .unwrap_or_else(|| self.estimate_reference_point());
600 match self.config.objectives.len() {
601 2 => self.compute_hypervolume_2d(&rp),
602 3 => self.compute_hypervolume_3d(&rp),
603 _ => self.compute_hypervolume_monte_carlo(&rp),
604 }
605 }
606
607 fn estimate_reference_point(&self) -> Vec<f64> {
608 let n = self.config.objectives.len();
609 let mut rp = vec![0.0f64; n];
610 for (i, obj) in self.config.objectives.iter().enumerate() {
611 if obj.minimize {
612 let max_val = self
613 .pareto_front
614 .iter()
615 .filter_map(|s| s.objectives.get(i).copied())
616 .fold(f64::NEG_INFINITY, f64::max);
617 rp[i] = if max_val.is_finite() {
618 max_val * 1.1
619 } else {
620 1.0
621 };
622 } else {
623 let min_val = self
624 .pareto_front
625 .iter()
626 .filter_map(|s| s.objectives.get(i).copied())
627 .fold(f64::INFINITY, f64::min);
628 rp[i] = if min_val.is_finite() {
629 min_val * 0.9
630 } else {
631 0.0
632 };
633 }
634 }
635 rp
636 }
637
638 fn compute_hypervolume_2d(&self, rp: &[f64]) -> Result<f64> {
639 let min0 = self
640 .config
641 .objectives
642 .first()
643 .map(|o| o.minimize)
644 .unwrap_or(true);
645 let min1 = self
646 .config
647 .objectives
648 .get(1)
649 .map(|o| o.minimize)
650 .unwrap_or(true);
651 let rp0 = rp.first().copied().unwrap_or(0.0);
652 let rp1 = rp.get(1).copied().unwrap_or(0.0);
653 let mut points: Vec<(f64, f64)> = self
654 .pareto_front
655 .iter()
656 .map(|s| {
657 let v0 = s.objectives.first().copied().unwrap_or(0.0);
658 let v1 = s.objectives.get(1).copied().unwrap_or(0.0);
659 let x = if min0 {
660 (rp0 - v0).max(0.0)
661 } else {
662 (v0 - rp0).max(0.0)
663 };
664 let y = if min1 {
665 (rp1 - v1).max(0.0)
666 } else {
667 (v1 - rp1).max(0.0)
668 };
669 (x, y)
670 })
671 .collect();
672 points.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
673 let mut volume = 0.0f64;
674 let mut prev_y = 0.0f64;
675 for (x, y) in points {
676 if y > prev_y {
677 volume += x * (y - prev_y);
678 prev_y = y;
679 }
680 }
681 Ok(volume)
682 }
683
684 fn compute_hypervolume_3d(&self, rp: &[f64]) -> Result<f64> {
685 let points: Vec<[f64; 3]> = self
686 .pareto_front
687 .iter()
688 .map(|s| {
689 let mut arr = [0.0f64; 3];
690 for (i, cell) in arr.iter_mut().enumerate() {
691 let r = rp.get(i).copied().unwrap_or(0.0);
692 let v = s.objectives.get(i).copied().unwrap_or(0.0);
693 *cell = if self
694 .config
695 .objectives
696 .get(i)
697 .map(|o| o.minimize)
698 .unwrap_or(true)
699 {
700 (r - v).max(0.0)
701 } else {
702 (v - r).max(0.0)
703 };
704 }
705 arr
706 })
707 .collect();
708 let n = points.len();
709 let mut volume = 0.0f64;
710 for p in &points {
711 volume += p[0] * p[1] * p[2];
712 }
713 for i in 0..n {
714 for j in (i + 1)..n {
715 volume -= points[i][0].min(points[j][0])
716 * points[i][1].min(points[j][1])
717 * points[i][2].min(points[j][2]);
718 }
719 }
720 Ok(volume.max(0.0))
721 }
722
723 fn compute_hypervolume_monte_carlo(&self, rp: &[f64]) -> Result<f64> {
724 use scirs2_core::random::prelude::*;
725 let mut rng_inst = thread_rng();
726 let num_samples = 10000usize;
727 let n_obj = self.config.objectives.len();
728 let mut lower_bounds = vec![f64::INFINITY; n_obj];
729 let upper_bounds = rp.to_vec();
730 for sol in &self.pareto_front {
731 for (i, &v) in sol.objectives.iter().enumerate() {
732 if i < n_obj {
733 lower_bounds[i] = lower_bounds[i].min(v);
734 }
735 }
736 }
737 for (i, lb) in lower_bounds.iter_mut().enumerate() {
738 if !lb.is_finite() {
739 *lb = upper_bounds.get(i).copied().unwrap_or(0.0) - 1.0;
740 }
741 }
742 let mut dominated_count = 0usize;
743 for _ in 0..num_samples {
744 let sample: Vec<f64> = (0..n_obj)
745 .map(|i| {
746 let lo = lower_bounds[i];
747 let hi = upper_bounds.get(i).copied().unwrap_or(lo + 1.0);
748 if hi > lo {
749 lo + rng_inst.random::<f64>() * (hi - lo)
750 } else {
751 lo
752 }
753 })
754 .collect();
755 let mut is_dominated = false;
756 'outer: for sol in &self.pareto_front {
757 let mut dom = true;
758 let mut better = false;
759 for (i, (&sv, &pv)) in sol.objectives.iter().zip(sample.iter()).enumerate() {
760 let min = self
761 .config
762 .objectives
763 .get(i)
764 .map(|o| o.minimize)
765 .unwrap_or(true);
766 if min {
767 if sv > pv {
768 dom = false;
769 break;
770 } else if sv < pv {
771 better = true;
772 }
773 } else {
774 if sv < pv {
775 dom = false;
776 break;
777 } else if sv > pv {
778 better = true;
779 }
780 }
781 }
782 if dom && better {
783 is_dominated = true;
784 break 'outer;
785 }
786 }
787 if is_dominated {
788 dominated_count += 1;
789 }
790 }
791 let sampling_vol: f64 = upper_bounds
792 .iter()
793 .zip(lower_bounds.iter())
794 .map(|(u, l)| (u - l).max(0.0))
795 .product();
796 Ok(sampling_vol * (dominated_count as f64 / num_samples as f64))
797 }
798
799 pub fn get_pareto_front(&self) -> &[MultiObjectiveSolution] {
800 &self.pareto_front
801 }
802 pub fn get_hypervolume_history(&self) -> &[f64] {
803 &self.hypervolume_history
804 }
805 pub fn get_generation(&self) -> usize {
806 self.generation
807 }
808
809 fn calculate_spea2_fitness_for_population(
810 &self,
811 population: &mut [MultiObjectiveSolution],
812 ) -> Result<()> {
813 let n = population.len();
814 let mut strengths = vec![0usize; n];
815 let mut raw_fitness = vec![0.0f64; n];
816 let mut densities = vec![0.0f64; n];
817 for i in 0..n {
818 let mut count = 0;
819 for j in 0..n {
820 if i != j {
821 let oi: f64 = population[i].objectives.iter().sum();
822 let oj: f64 = population[j].objectives.iter().sum();
823 if oi < oj {
824 count += 1;
825 }
826 }
827 }
828 strengths[i] = count;
829 }
830 for i in 0..n {
831 let mut fitness = 0.0;
832 for j in 0..n {
833 if i != j {
834 let oi: f64 = population[i].objectives.iter().sum();
835 let oj: f64 = population[j].objectives.iter().sum();
836 if oj < oi {
837 fitness += strengths[j] as f64;
838 }
839 }
840 }
841 raw_fitness[i] = fitness;
842 }
843 let k = (n as f64).sqrt() as usize;
844 for i in 0..n {
845 let mut dists: Vec<f64> = (0..n)
846 .filter(|&j| j != i)
847 .map(|j| self.euclidean_distance(&population[i], &population[j]))
848 .collect();
849 dists.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
850 let kth = if k > 0 && k <= dists.len() {
851 dists[k - 1]
852 } else {
853 dists.last().copied().unwrap_or(0.0)
854 };
855 densities[i] = 1.0 / (kth + 2.0);
856 }
857 for i in 0..n {
858 population[i].crowding_distance = raw_fitness[i] + densities[i];
859 }
860 Ok(())
861 }
862
863 fn spea2_environmental_selection(
864 &self,
865 mut population: Vec<MultiObjectiveSolution>,
866 ) -> Result<Vec<MultiObjectiveSolution>> {
867 population.sort_by(|a, b| {
868 a.crowding_distance
869 .partial_cmp(&b.crowding_distance)
870 .unwrap_or(std::cmp::Ordering::Equal)
871 });
872 let mut selected = Vec::new();
873 for sol in &population {
874 if sol.crowding_distance < 1.0 && selected.len() < self.config.population_size {
875 selected.push(sol.clone());
876 }
877 }
878 if selected.len() < self.config.population_size {
879 for sol in &population {
880 if sol.crowding_distance >= 1.0 && selected.len() < self.config.population_size {
881 selected.push(sol.clone());
882 }
883 }
884 }
885 selected.truncate(self.config.population_size);
886 Ok(selected)
887 }
888
889 fn euclidean_distance(&self, a: &MultiObjectiveSolution, b: &MultiObjectiveSolution) -> f64 {
890 a.objectives
891 .iter()
892 .zip(b.objectives.iter())
893 .map(|(x, y)| (x - y).powi(2))
894 .sum::<f64>()
895 .sqrt()
896 }
897
898 fn generate_weight_vectors(&self) -> Result<Vec<Vec<f64>>> {
899 let n_obj = self.config.objectives.len();
900 let n_weights = self.config.population_size;
901 let mut weights = Vec::new();
902 if n_obj == 2 {
903 for i in 0..n_weights {
904 let w1 = i as f64 / (n_weights - 1).max(1) as f64;
905 weights.push(vec![w1, 1.0 - w1]);
906 }
907 } else {
908 while weights.len() < n_weights {
909 let raw: Vec<f64> = (0..n_obj)
910 .map(|_| scirs2_core::random::random::<f64>())
911 .collect();
912 let sum: f64 = raw.iter().sum();
913 if sum > 1e-12 {
914 weights.push(raw.iter().map(|w| w / sum).collect());
915 }
916 }
917 }
918 weights.truncate(n_weights);
919 Ok(weights)
920 }
921
922 fn update_subproblem(&self, index: usize, weights: &[f64]) -> Result<MultiObjectiveSolution> {
923 if index >= self.population.len() {
924 return Err(NeuralError::InvalidArgument(
925 "Subproblem index out of bounds".to_string(),
926 ));
927 }
928 let current = &self.population[index];
929 let neighbor = self.select_neighbor(index)?;
930 let p2 = &self.population[neighbor];
931 let child = current.architecture.crossover(p2.architecture.as_ref())?;
932 let mutated_box = child.mutate(0.1)?;
933 let mutated: Arc<dyn ArchitectureEncoding> = Arc::from(mutated_box);
934 let objectives = self.estimate_objectives(&mutated)?;
935 let mut child_sol = MultiObjectiveSolution::new(mutated, objectives);
936 let cur_fit = self.tchebycheff_fitness(¤t.objectives, weights);
937 let child_fit = self.tchebycheff_fitness(&child_sol.objectives, weights);
938 if child_fit < cur_fit {
939 child_sol.crowding_distance = child_fit;
940 Ok(child_sol)
941 } else {
942 let mut cur_clone = current.clone();
943 cur_clone.crowding_distance = cur_fit;
944 Ok(cur_clone)
945 }
946 }
947
948 fn select_neighbor(&self, index: usize) -> Result<usize> {
949 use scirs2_core::random::prelude::*;
950 let mut rng_inst = thread_rng();
951 if self.population.len() <= 1 {
952 return Ok(0);
953 }
954 let nbhood = 10.min(self.population.len());
955 let start = index.saturating_sub(nbhood / 2);
956 let end = (index + nbhood / 2).min(self.population.len() - 1);
957 if end <= start {
958 return Ok(if index > 0 { index - 1 } else { 0 });
959 }
960 let ni = rng_inst.random_range(start..=end);
961 if ni == index && end > start {
962 Ok(if ni == start { end } else { start })
963 } else {
964 Ok(ni)
965 }
966 }
967
968 fn tchebycheff_fitness(&self, objectives: &[f64], weights: &[f64]) -> f64 {
969 let mut max_diff = 0.0f64;
970 for (i, (&v, &w)) in objectives.iter().zip(weights.iter()).enumerate() {
971 let ideal = if self
972 .config
973 .objectives
974 .get(i)
975 .map(|o| o.minimize)
976 .unwrap_or(true)
977 {
978 0.0
979 } else {
980 1.0
981 };
982 max_diff = max_diff.max(w * (v - ideal).abs());
983 }
984 max_diff
985 }
986
987 fn update_neighbors(&mut self, index: usize, solution: &MultiObjectiveSolution) -> Result<()> {
988 let nbhood = 10.min(self.population.len());
989 let start = index.saturating_sub(nbhood / 2);
990 let end = (index + nbhood / 2).min(self.population.len());
991 let wvecs = self.generate_weight_vectors()?;
992 for i in start..end {
993 if i != index && i < wvecs.len() {
994 let w = wvecs[i].clone();
995 let cur_fit = self.tchebycheff_fitness(&self.population[i].objectives, &w);
996 let new_fit = self.tchebycheff_fitness(&solution.objectives, &w);
997 if new_fit < cur_fit {
998 self.population[i] = solution.clone();
999 }
1000 }
1001 }
1002 Ok(())
1003 }
1004
1005 fn hypervolume_environmental_selection(
1006 &self,
1007 mut combined: Vec<MultiObjectiveSolution>,
1008 ) -> Result<Vec<MultiObjectiveSolution>> {
1009 combined.sort_by(|a, b| {
1010 b.crowding_distance
1011 .partial_cmp(&a.crowding_distance)
1012 .unwrap_or(std::cmp::Ordering::Equal)
1013 });
1014 combined.truncate(self.config.population_size);
1015 Ok(combined)
1016 }
1017
1018 fn evaluate_constraints(&self, solution: &MultiObjectiveSolution) -> Result<Vec<f64>> {
1019 let mut violations = Vec::new();
1020 for (i, obj) in self.config.objectives.iter().enumerate() {
1021 if let (Some(target), Some(tol)) = (obj.target, obj.tolerance) {
1022 let v = solution.objectives.get(i).copied().unwrap_or(0.0);
1023 violations.push(((v - target).abs() - tol).max(0.0));
1024 }
1025 }
1026 Ok(violations)
1027 }
1028
1029 fn constraint_environmental_selection(
1030 &self,
1031 mut offspring: Vec<MultiObjectiveSolution>,
1032 ) -> Result<Vec<MultiObjectiveSolution>> {
1033 let mut combined = self.population.clone();
1034 combined.append(&mut offspring);
1035 combined.sort_by(|a, b| {
1036 let av: f64 = a.constraint_violations.iter().sum();
1037 let bv: f64 = b.constraint_violations.iter().sum();
1038 if (av - bv).abs() > 1e-12 {
1039 av.partial_cmp(&bv).unwrap_or(std::cmp::Ordering::Equal)
1040 } else {
1041 a.objectives
1042 .first()
1043 .copied()
1044 .unwrap_or(0.0)
1045 .partial_cmp(&b.objectives.first().copied().unwrap_or(0.0))
1046 .unwrap_or(std::cmp::Ordering::Equal)
1047 }
1048 });
1049 combined.truncate(self.config.population_size);
1050 Ok(combined)
1051 }
1052}
1053
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nas::architecture_encoding::SequentialEncoding;

    /// An empty sequential architecture; only objective vectors matter here.
    fn empty_arch() -> Arc<dyn ArchitectureEncoding> {
        Arc::new(SequentialEncoding::new(vec![]))
    }

    #[test]
    fn test_multi_objective_config() {
        let config = MultiObjectiveConfig::default();
        assert_eq!(config.objectives.len(), 4);
        assert_eq!(config.population_size, 50);
    }

    #[test]
    fn test_solution_dominance() {
        // A trade-off: higher accuracy but also higher cost, so neither dominates.
        let config = MultiObjectiveConfig::default();
        let sol1 = MultiObjectiveSolution::new(empty_arch(), vec![0.9, 1000.0, 500.0, 5.0]);
        let sol2 = MultiObjectiveSolution::new(empty_arch(), vec![0.8, 500.0, 250.0, 2.5]);
        assert!(!sol1.dominates(&sol2, &config));
        assert!(!sol2.dominates(&sol1, &config));
    }

    #[test]
    fn test_solution_strict_dominance() {
        // Better accuracy (maximised) and lower cost (minimised) on every axis.
        let config = MultiObjectiveConfig::default();
        let better = MultiObjectiveSolution::new(empty_arch(), vec![0.9, 500.0, 250.0, 2.5]);
        let worse = MultiObjectiveSolution::new(empty_arch(), vec![0.8, 1000.0, 500.0, 5.0]);
        assert!(better.dominates(&worse, &config));
        assert!(!worse.dominates(&better, &config));
        // A solution never dominates an identical copy of itself.
        assert!(!better.dominates(&better.clone(), &config));
    }

    #[test]
    fn test_optimizer_creation() {
        let optimizer = MultiObjectiveOptimizer::new(MultiObjectiveConfig::default());
        assert_eq!(optimizer.get_generation(), 0);
        assert!(optimizer.get_pareto_front().is_empty());
        assert!(optimizer.get_hypervolume_history().is_empty());
    }
}