1use super::{utils, MultiObjectiveConfig, MultiObjectiveOptimizer};
20use crate::error::OptimizeError;
21use crate::multi_objective::crossover::{CrossoverOperator, SimulatedBinaryCrossover};
22use crate::multi_objective::mutation::{MutationOperator, PolynomialMutation};
23use crate::multi_objective::solutions::{MultiObjectiveResult, MultiObjectiveSolution, Population};
24use scirs2_core::ndarray::{Array1, ArrayView1};
25use scirs2_core::random::rngs::StdRng;
26use scirs2_core::random::{Rng, RngExt, SeedableRng};
27
28pub struct SPEA2 {
30 config: MultiObjectiveConfig,
31 archive_size: usize,
32 n_objectives: usize,
33 n_variables: usize,
34 archive: Vec<MultiObjectiveSolution>,
36 population: Population,
37 generation: usize,
38 n_evaluations: usize,
39 rng: StdRng,
40 crossover: SimulatedBinaryCrossover,
41 mutation: PolynomialMutation,
42 convergence_history: Vec<f64>,
43}
44
45impl SPEA2 {
46 pub fn new(population_size: usize, n_objectives: usize, n_variables: usize) -> Self {
48 let config = MultiObjectiveConfig {
49 population_size,
50 ..Default::default()
51 };
52 Self::with_config(config, n_objectives, n_variables)
53 }
54
55 pub fn with_config(
57 config: MultiObjectiveConfig,
58 n_objectives: usize,
59 n_variables: usize,
60 ) -> Self {
61 let archive_size = config.archive_size.unwrap_or(config.population_size);
62
63 let seed = config.random_seed.unwrap_or_else(|| {
64 use std::time::{SystemTime, UNIX_EPOCH};
65 SystemTime::now()
66 .duration_since(UNIX_EPOCH)
67 .map(|d| d.as_secs())
68 .unwrap_or(42)
69 });
70
71 let rng = StdRng::seed_from_u64(seed);
72
73 let crossover =
74 SimulatedBinaryCrossover::new(config.crossover_eta, config.crossover_probability);
75 let mutation = PolynomialMutation::new(config.mutation_probability, config.mutation_eta);
76
77 Self {
78 config,
79 archive_size,
80 n_objectives,
81 n_variables,
82 archive: Vec::new(),
83 population: Population::new(),
84 generation: 0,
85 n_evaluations: 0,
86 rng,
87 crossover,
88 mutation,
89 convergence_history: Vec::new(),
90 }
91 }
92
93 fn evaluate_individual<F>(
95 &mut self,
96 variables: &Array1<f64>,
97 objective_function: &F,
98 ) -> Result<Array1<f64>, OptimizeError>
99 where
100 F: Fn(&ArrayView1<f64>) -> Array1<f64> + Send + Sync,
101 {
102 self.n_evaluations += 1;
103
104 if let Some(max_evals) = self.config.max_evaluations {
105 if self.n_evaluations > max_evals {
106 return Err(OptimizeError::MaxEvaluationsReached);
107 }
108 }
109
110 let objectives = objective_function(&variables.view());
111 if objectives.len() != self.n_objectives {
112 return Err(OptimizeError::InvalidInput(format!(
113 "Expected {} objectives, got {}",
114 self.n_objectives,
115 objectives.len()
116 )));
117 }
118
119 Ok(objectives)
120 }
121
122 fn dominates(a: &MultiObjectiveSolution, b: &MultiObjectiveSolution) -> bool {
124 let mut at_least_one_better = false;
125 for i in 0..a.objectives.len() {
126 if a.objectives[i] > b.objectives[i] {
127 return false;
128 }
129 if a.objectives[i] < b.objectives[i] {
130 at_least_one_better = true;
131 }
132 }
133 at_least_one_better
134 }
135
136 fn calculate_strengths(combined: &[MultiObjectiveSolution]) -> Vec<usize> {
139 let n = combined.len();
140 let mut strengths = vec![0usize; n];
141
142 for i in 0..n {
143 for j in 0..n {
144 if i != j && Self::dominates(&combined[i], &combined[j]) {
145 strengths[i] += 1;
146 }
147 }
148 }
149
150 strengths
151 }
152
153 fn calculate_raw_fitness(combined: &[MultiObjectiveSolution], strengths: &[usize]) -> Vec<f64> {
155 let n = combined.len();
156 let mut raw_fitness = vec![0.0f64; n];
157
158 for i in 0..n {
159 for j in 0..n {
160 if i != j && Self::dominates(&combined[j], &combined[i]) {
161 raw_fitness[i] += strengths[j] as f64;
162 }
163 }
164 }
165
166 raw_fitness
167 }
168
169 fn calculate_density(combined: &[MultiObjectiveSolution]) -> Vec<f64> {
172 let n = combined.len();
173 if n <= 1 {
174 return vec![0.0; n];
175 }
176
177 let k = (n as f64).sqrt().floor() as usize;
179 let k = k.max(1).min(n - 1);
180
181 let mut densities = vec![0.0f64; n];
182
183 for i in 0..n {
184 let mut distances: Vec<f64> = Vec::with_capacity(n - 1);
186 for j in 0..n {
187 if i != j {
188 let dist = euclidean_distance_objectives(&combined[i], &combined[j]);
189 distances.push(dist);
190 }
191 }
192
193 distances.sort_by(|a, b| a.total_cmp(b));
195
196 let sigma_k = if k <= distances.len() {
197 distances[k - 1]
198 } else {
199 distances.last().copied().unwrap_or(0.0)
200 };
201
202 densities[i] = 1.0 / (sigma_k + 2.0);
203 }
204
205 densities
206 }
207
208 fn calculate_fitness(combined: &[MultiObjectiveSolution]) -> Vec<f64> {
210 let strengths = Self::calculate_strengths(combined);
211 let raw_fitness = Self::calculate_raw_fitness(combined, &strengths);
212 let densities = Self::calculate_density(combined);
213
214 raw_fitness
215 .iter()
216 .zip(densities.iter())
217 .map(|(r, d)| r + d)
218 .collect()
219 }
220
221 fn truncate_archive(&self, candidates: &mut Vec<(MultiObjectiveSolution, f64)>) {
225 while candidates.len() > self.archive_size {
226 let n = candidates.len();
227
228 let mut dist_matrix = vec![vec![0.0f64; n]; n];
230 for i in 0..n {
231 for j in (i + 1)..n {
232 let d = euclidean_distance_objectives(&candidates[i].0, &candidates[j].0);
233 dist_matrix[i][j] = d;
234 dist_matrix[j][i] = d;
235 }
236 }
237
238 let mut sorted_distances: Vec<Vec<(f64, usize)>> = Vec::with_capacity(n);
240 for i in 0..n {
241 let mut dists: Vec<(f64, usize)> = (0..n)
242 .filter(|&j| j != i)
243 .map(|j| (dist_matrix[i][j], j))
244 .collect();
245 dists.sort_by(|a, b| a.0.total_cmp(&b.0));
246 sorted_distances.push(dists);
247 }
248
249 let mut remove_idx = 0;
253 for candidate_idx in 1..n {
254 let mut is_smaller = false;
255 for k in 0..sorted_distances[0]
256 .len()
257 .min(sorted_distances[candidate_idx].len())
258 {
259 let d_current = sorted_distances[remove_idx]
260 .get(k)
261 .map(|(d, _)| *d)
262 .unwrap_or(f64::INFINITY);
263 let d_candidate = sorted_distances[candidate_idx]
264 .get(k)
265 .map(|(d, _)| *d)
266 .unwrap_or(f64::INFINITY);
267
268 if d_candidate < d_current - 1e-15 {
269 is_smaller = true;
270 break;
271 } else if d_candidate > d_current + 1e-15 {
272 break;
273 }
274 }
276 if is_smaller {
277 remove_idx = candidate_idx;
278 }
279 }
280
281 candidates.remove(remove_idx);
282 }
283 }
284
285 fn environmental_selection(
287 &self,
288 combined: &[MultiObjectiveSolution],
289 fitness: &[f64],
290 ) -> Vec<MultiObjectiveSolution> {
291 let mut next_archive: Vec<(MultiObjectiveSolution, f64)> = combined
293 .iter()
294 .zip(fitness.iter())
295 .filter(|(_, &f)| f < 1.0)
296 .map(|(sol, &f)| (sol.clone(), f))
297 .collect();
298
299 if next_archive.len() < self.archive_size {
300 let mut dominated: Vec<(MultiObjectiveSolution, f64)> = combined
302 .iter()
303 .zip(fitness.iter())
304 .filter(|(_, &f)| f >= 1.0)
305 .map(|(sol, &f)| (sol.clone(), f))
306 .collect();
307
308 dominated.sort_by(|a, b| a.1.total_cmp(&b.1));
310
311 let remaining = self.archive_size - next_archive.len();
312 next_archive.extend(dominated.into_iter().take(remaining));
313 } else if next_archive.len() > self.archive_size {
314 self.truncate_archive(&mut next_archive);
316 }
317
318 next_archive.into_iter().map(|(sol, _)| sol).collect()
319 }
320
321 fn binary_tournament_selection(
323 &mut self,
324 archive: &[MultiObjectiveSolution],
325 fitness: &[f64],
326 n_select: usize,
327 ) -> Vec<MultiObjectiveSolution> {
328 let n = archive.len();
329 if n == 0 {
330 return vec![];
331 }
332
333 let mut selected = Vec::with_capacity(n_select);
334
335 for _ in 0..n_select {
336 let idx1 = self.rng.random_range(0..n);
337 let idx2 = self.rng.random_range(0..n);
338
339 let winner = if fitness[idx1] <= fitness[idx2] {
340 idx1
341 } else {
342 idx2
343 };
344 selected.push(archive[winner].clone());
345 }
346
347 selected
348 }
349
350 fn create_offspring<F>(
352 &mut self,
353 archive: &[MultiObjectiveSolution],
354 archive_fitness: &[f64],
355 objective_function: &F,
356 ) -> Result<Vec<MultiObjectiveSolution>, OptimizeError>
357 where
358 F: Fn(&ArrayView1<f64>) -> Array1<f64> + Send + Sync,
359 {
360 let mut offspring = Vec::new();
361
362 while offspring.len() < self.config.population_size {
363 let parents = self.binary_tournament_selection(archive, archive_fitness, 2);
364 if parents.len() < 2 {
365 break;
366 }
367
368 let p1_vars = parents[0].variables.as_slice().unwrap_or(&[]);
369 let p2_vars = parents[1].variables.as_slice().unwrap_or(&[]);
370
371 let (mut c1_vars, mut c2_vars) = self.crossover.crossover(p1_vars, p2_vars);
372
373 let bounds: Vec<(f64, f64)> = if let Some((lower, upper)) = &self.config.bounds {
374 lower
375 .iter()
376 .zip(upper.iter())
377 .map(|(&l, &u)| (l, u))
378 .collect()
379 } else {
380 vec![(-1.0, 1.0); self.n_variables]
381 };
382
383 self.mutation.mutate(&mut c1_vars, &bounds);
384 self.mutation.mutate(&mut c2_vars, &bounds);
385
386 let c1_arr = Array1::from_vec(c1_vars);
387 let c1_obj = self.evaluate_individual(&c1_arr, objective_function)?;
388 offspring.push(MultiObjectiveSolution::new(c1_arr, c1_obj));
389
390 if offspring.len() < self.config.population_size {
391 let c2_arr = Array1::from_vec(c2_vars);
392 let c2_obj = self.evaluate_individual(&c2_arr, objective_function)?;
393 offspring.push(MultiObjectiveSolution::new(c2_arr, c2_obj));
394 }
395 }
396
397 Ok(offspring)
398 }
399
400 fn calculate_metrics(&mut self) {
402 if let Some(ref_point) = &self.config.reference_point {
403 let pareto_front = extract_pareto_front_from_slice(&self.archive);
404 let hv = utils::calculate_hypervolume(&pareto_front, ref_point);
405 self.convergence_history.push(hv);
406 }
407 }
408}
409
410impl MultiObjectiveOptimizer for SPEA2 {
411 fn optimize<F>(&mut self, objective_function: F) -> Result<MultiObjectiveResult, OptimizeError>
412 where
413 F: Fn(&ArrayView1<f64>) -> Array1<f64> + Send + Sync,
414 {
415 self.initialize_population()?;
416
417 let initial_vars = utils::generate_random_population(
419 self.config.population_size,
420 self.n_variables,
421 &self.config.bounds,
422 );
423
424 let mut initial_solutions = Vec::new();
425 for vars in initial_vars {
426 let objs = self.evaluate_individual(&vars, &objective_function)?;
427 initial_solutions.push(MultiObjectiveSolution::new(vars, objs));
428 }
429
430 self.population = Population::from_solutions(initial_solutions);
431 self.archive.clear();
432
433 while self.generation < self.config.max_generations {
435 if self.check_convergence() {
436 break;
437 }
438 self.evolve_generation(&objective_function)?;
439 }
440
441 let pareto_front = extract_pareto_front_from_slice(&self.archive);
443 let hypervolume = self
444 .config
445 .reference_point
446 .as_ref()
447 .map(|rp| utils::calculate_hypervolume(&pareto_front, rp));
448
449 let mut result = MultiObjectiveResult::new(
450 pareto_front,
451 self.archive.clone(),
452 self.n_evaluations,
453 self.generation,
454 );
455 result.hypervolume = hypervolume;
456 result.metrics.convergence_history = self.convergence_history.clone();
457
458 Ok(result)
459 }
460
461 fn evolve_generation<F>(&mut self, objective_function: &F) -> Result<(), OptimizeError>
462 where
463 F: Fn(&ArrayView1<f64>) -> Array1<f64> + Send + Sync,
464 {
465 let mut combined: Vec<MultiObjectiveSolution> = self.population.solutions().to_vec();
467 combined.extend(self.archive.clone());
468
469 let fitness = Self::calculate_fitness(&combined);
471
472 self.archive = self.environmental_selection(&combined, &fitness);
474
475 let archive_clone = self.archive.clone();
477 let archive_fitness = Self::calculate_fitness(&archive_clone);
478
479 let offspring =
481 self.create_offspring(&archive_clone, &archive_fitness, objective_function)?;
482
483 self.population = Population::from_solutions(offspring);
484 self.generation += 1;
485 self.calculate_metrics();
486
487 Ok(())
488 }
489
490 fn initialize_population(&mut self) -> Result<(), OptimizeError> {
491 self.population.clear();
492 self.archive.clear();
493 self.generation = 0;
494 self.n_evaluations = 0;
495 self.convergence_history.clear();
496 Ok(())
497 }
498
499 fn check_convergence(&self) -> bool {
500 if let Some(max_evals) = self.config.max_evaluations {
501 if self.n_evaluations >= max_evals {
502 return true;
503 }
504 }
505
506 if self.convergence_history.len() >= 10 {
507 let recent = &self.convergence_history[self.convergence_history.len() - 10..];
508 let max_hv = recent.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
509 let min_hv = recent.iter().fold(f64::INFINITY, |a, &b| a.min(b));
510 if (max_hv - min_hv) < self.config.tolerance {
511 return true;
512 }
513 }
514
515 false
516 }
517
518 fn get_population(&self) -> &Population {
519 &self.population
520 }
521
522 fn get_generation(&self) -> usize {
523 self.generation
524 }
525
526 fn get_evaluations(&self) -> usize {
527 self.n_evaluations
528 }
529
530 fn name(&self) -> &str {
531 "SPEA2"
532 }
533}
534
535fn euclidean_distance_objectives(a: &MultiObjectiveSolution, b: &MultiObjectiveSolution) -> f64 {
537 a.objectives
538 .iter()
539 .zip(b.objectives.iter())
540 .map(|(x, y)| (x - y).powi(2))
541 .sum::<f64>()
542 .sqrt()
543}
544
545fn extract_pareto_front_from_slice(
547 solutions: &[MultiObjectiveSolution],
548) -> Vec<MultiObjectiveSolution> {
549 let mut pareto_front: Vec<MultiObjectiveSolution> = Vec::new();
550
551 for candidate in solutions {
552 let mut is_dominated = false;
553 for existing in &pareto_front {
554 if SPEA2::dominates(existing, candidate) {
555 is_dominated = true;
556 break;
557 }
558 }
559 if !is_dominated {
560 pareto_front.retain(|existing| !SPEA2::dominates(candidate, existing));
561 pareto_front.push(candidate.clone());
562 }
563 }
564
565 pareto_front
566}
567
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, s};

    /// Two-objective ZDT1-style test function: `f1 = x[0]`,
    /// `g = 1 + 9 * mean(x[1..])`, `f2 = g * (1 - sqrt(f1 / g))`.
    fn zdt1(x: &ArrayView1<f64>) -> Array1<f64> {
        let f1 = x[0];
        let g = 1.0 + 9.0 * x.slice(s![1..]).sum() / (x.len() - 1) as f64;
        let f2 = g * (1.0 - (f1 / g).sqrt());
        array![f1, f2]
    }

    #[test]
    fn test_spea2_creation() {
        let spea2 = SPEA2::new(100, 2, 3);
        assert_eq!(spea2.n_objectives, 2);
        assert_eq!(spea2.n_variables, 3);
        // Archive size falls back to the population size.
        assert_eq!(spea2.archive_size, 100);
        assert_eq!(spea2.generation, 0);
    }

    #[test]
    fn test_spea2_with_config() {
        let config = MultiObjectiveConfig {
            population_size: 50,
            archive_size: Some(30),
            max_generations: 10,
            random_seed: Some(42),
            ..Default::default()
        };

        let spea2 = SPEA2::with_config(config, 2, 3);
        assert_eq!(spea2.archive_size, 30);
    }

    #[test]
    fn test_spea2_optimize_zdt1() {
        let config = MultiObjectiveConfig {
            max_generations: 10,
            population_size: 20,
            bounds: Some((Array1::zeros(3), Array1::ones(3))),
            random_seed: Some(42),
            ..Default::default()
        };

        let mut spea2 = SPEA2::with_config(config, 2, 3);
        let res = spea2.optimize(zdt1).expect("optimisation should succeed");

        assert!(res.success);
        assert!(!res.pareto_front.is_empty());
        assert!(res.n_evaluations > 0);
    }

    #[test]
    fn test_spea2_strength_calculation() {
        let solutions = vec![
            MultiObjectiveSolution::new(array![0.0], array![1.0, 3.0]),
            MultiObjectiveSolution::new(array![1.0], array![2.0, 2.0]),
            MultiObjectiveSolution::new(array![2.0], array![3.0, 1.0]),
            // Dominated by each of the three solutions above.
            MultiObjectiveSolution::new(array![3.0], array![4.0, 4.0]),
        ];

        let strengths = SPEA2::calculate_strengths(&solutions);
        // The first three are mutually non-dominated and each dominates
        // only the last one, which dominates nothing.
        assert_eq!(strengths, vec![1, 1, 1, 0]);
    }

    #[test]
    fn test_spea2_fitness_calculation() {
        let solutions = vec![
            MultiObjectiveSolution::new(array![0.0], array![1.0, 3.0]),
            MultiObjectiveSolution::new(array![1.0], array![2.0, 2.0]),
            // Dominated by both solutions above.
            MultiObjectiveSolution::new(array![2.0], array![4.0, 4.0]),
        ];

        let fitness = SPEA2::calculate_fitness(&solutions);
        // Non-dominated solutions have zero raw fitness, so only the
        // density term (at most 0.5) contributes.
        assert!(fitness[0] < 1.0);
        assert!(fitness[1] < 1.0);
        // The dominated solution collects the strength (1 each) of both
        // dominators.
        assert!(fitness[2] >= 2.0);
        assert!(fitness[0] < fitness[2] && fitness[1] < fitness[2]);
    }

    #[test]
    fn test_spea2_dominance() {
        let a = MultiObjectiveSolution::new(array![0.0], array![1.0, 2.0]);
        let b = MultiObjectiveSolution::new(array![0.0], array![2.0, 3.0]);
        let c = MultiObjectiveSolution::new(array![0.0], array![0.5, 3.5]);

        // `a` is better than `b` in both objectives.
        assert!(SPEA2::dominates(&a, &b));
        assert!(!SPEA2::dominates(&b, &a));
        // `a` and `c` each win one objective: neither dominates.
        assert!(!SPEA2::dominates(&a, &c));
        assert!(!SPEA2::dominates(&c, &a));
    }

    #[test]
    fn test_spea2_max_evaluations() {
        let config = MultiObjectiveConfig {
            max_generations: 1000,
            max_evaluations: Some(50),
            population_size: 10,
            bounds: Some((Array1::zeros(3), Array1::ones(3))),
            random_seed: Some(42),
            ..Default::default()
        };

        let mut spea2 = SPEA2::with_config(config, 2, 3);
        let res = spea2.optimize(zdt1).expect("optimisation should succeed");
        // A successful run can never report more evaluations than the budget.
        assert!(res.n_evaluations <= 50);
    }

    #[test]
    fn test_spea2_name() {
        let spea2 = SPEA2::new(50, 2, 3);
        assert_eq!(spea2.name(), "SPEA2");
    }

    #[test]
    fn test_euclidean_distance() {
        let a = MultiObjectiveSolution::new(array![0.0], array![0.0, 0.0]);
        let b = MultiObjectiveSolution::new(array![0.0], array![3.0, 4.0]);
        let dist = euclidean_distance_objectives(&a, &b);
        // 3-4-5 triangle.
        assert!((dist - 5.0).abs() < 1e-10);
    }
}