scirs2_optimize/multi_objective/
mod.rs1pub mod advanced;
18pub mod algorithms;
19pub mod crossover;
20pub mod indicators;
21pub mod mutation;
22pub mod pareto;
23pub mod selection;
24pub mod solutions;
25
26pub use algorithms::{
28 MultiObjectiveConfig, MultiObjectiveOptimizer, MultiObjectiveOptimizerWrapper,
29 OptimizerFactory, MOEAD, NSGAII, NSGAIII, SPEA2,
30};
31pub use solutions::{
32 MultiObjectiveResult, MultiObjectiveSolution, OptimizationMetrics, Population,
33};
34
35use crate::error::OptimizeError;
36use scirs2_core::ndarray::{s, Array1, ArrayView1};
37
38pub fn nsga2<F>(
40 objective_function: F,
41 n_objectives: usize,
42 n_variables: usize,
43 config: Option<MultiObjectiveConfig>,
44) -> Result<MultiObjectiveResult, OptimizeError>
45where
46 F: Fn(&ArrayView1<f64>) -> Array1<f64> + Send + Sync,
47{
48 let config = config.unwrap_or_default();
49 let mut optimizer =
50 algorithms::OptimizerFactory::create_nsga2(config, n_objectives, n_variables)?;
51 optimizer.optimize(objective_function)
52}
53
54pub fn nsga3<F>(
56 objective_function: F,
57 n_objectives: usize,
58 n_variables: usize,
59 config: Option<MultiObjectiveConfig>,
60 reference_points: Option<Vec<Array1<f64>>>,
61) -> Result<MultiObjectiveResult, OptimizeError>
62where
63 F: Fn(&ArrayView1<f64>) -> Array1<f64> + Send + Sync,
64{
65 let config = config.unwrap_or_default();
66 let mut optimizer = algorithms::OptimizerFactory::create_nsga3(
67 config,
68 n_objectives,
69 n_variables,
70 reference_points,
71 )?;
72 optimizer.optimize(objective_function)
73}
74
75pub fn optimize<F>(
77 algorithm: &str,
78 objective_function: F,
79 n_objectives: usize,
80 n_variables: usize,
81 config: Option<MultiObjectiveConfig>,
82) -> Result<MultiObjectiveResult, OptimizeError>
83where
84 F: Fn(&ArrayView1<f64>) -> Array1<f64> + Send + Sync,
85{
86 let config = config.unwrap_or_default();
87 let mut optimizer =
88 algorithms::OptimizerFactory::create_by_name(algorithm, config, n_objectives, n_variables)?;
89
90 let adapted_fn = |x: &ArrayView1<f64>| objective_function(x);
92
93 optimizer.optimize(adapted_fn)
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// ZDT1 two-objective test problem.
    ///
    /// Assumes `x.len() >= 2`. With a single variable the `g` term divides by zero,
    /// and an empty `x` panics on `x[0]`.
    fn zdt1(x: &ArrayView1<f64>) -> Array1<f64> {
        let f1 = x[0];
        let g = 1.0 + 9.0 * x.slice(s![1..]).sum() / (x.len() - 1) as f64;
        let f2 = g * (1.0 - (f1 / g).sqrt());
        array![f1, f2]
    }

    /// Small, fast configuration shared by the tests: 2 variables bounded to `[0, 1]`,
    /// 5 generations, population of 10.
    fn small_config() -> MultiObjectiveConfig {
        let mut config = MultiObjectiveConfig::default();
        config.max_generations = 5;
        config.population_size = 10;
        config.bounds = Some((Array1::zeros(2), Array1::ones(2)));
        config
    }

    #[test]
    fn test_nsga2_convenience_function() {
        // `expect` (instead of a bare `assert!(is_ok())`) prints the error on failure.
        let result =
            nsga2(zdt1, 2, 2, Some(small_config())).expect("nsga2 should succeed on ZDT1");
        assert!(result.success, "nsga2 reported an unsuccessful run");
        assert!(!result.pareto_front.is_empty(), "Pareto front should not be empty");
    }

    #[test]
    fn test_optimize_by_name() {
        optimize("nsga2", zdt1, 2, 2, Some(small_config()))
            .expect("\"nsga2\" should be a recognised algorithm name");

        assert!(
            optimize("unknown", zdt1, 2, 2, Some(small_config())).is_err(),
            "unrecognised algorithm names must be rejected"
        );
    }

    #[test]
    fn test_default_config() {
        let result = nsga2(zdt1, 2, 2, None);
        assert!(
            result.is_ok() || matches!(result, Err(OptimizeError::MaxEvaluationsReached)),
            "unexpected error with default config: {:?}",
            result.as_ref().err()
        );
    }
}