scirs2_optimize/multi_objective/
mod.rs1pub mod algorithms;
18pub mod crossover;
19pub mod indicators;
20pub mod mutation;
21pub mod pareto;
22pub mod selection;
23pub mod solutions;
24
25pub use algorithms::{
27 MultiObjectiveConfig, MultiObjectiveOptimizer, MultiObjectiveOptimizerWrapper,
28 OptimizerFactory, MOEAD, NSGAII, NSGAIII, SPEA2,
29};
30pub use solutions::{
31 MultiObjectiveResult, MultiObjectiveSolution, OptimizationMetrics, Population,
32};
33
34use crate::error::OptimizeError;
35use scirs2_core::ndarray::{s, Array1, ArrayView1};
36
37pub fn nsga2<F>(
39 objective_function: F,
40 n_objectives: usize,
41 n_variables: usize,
42 config: Option<MultiObjectiveConfig>,
43) -> Result<MultiObjectiveResult, OptimizeError>
44where
45 F: Fn(&ArrayView1<f64>) -> Array1<f64> + Send + Sync,
46{
47 let config = config.unwrap_or_default();
48 let mut optimizer =
49 algorithms::OptimizerFactory::create_nsga2(config, n_objectives, n_variables)?;
50 optimizer.optimize(objective_function)
51}
52
53pub fn nsga3<F>(
55 objective_function: F,
56 n_objectives: usize,
57 n_variables: usize,
58 config: Option<MultiObjectiveConfig>,
59 reference_points: Option<Vec<Array1<f64>>>,
60) -> Result<MultiObjectiveResult, OptimizeError>
61where
62 F: Fn(&ArrayView1<f64>) -> Array1<f64> + Send + Sync,
63{
64 let config = config.unwrap_or_default();
65 let mut optimizer = algorithms::OptimizerFactory::create_nsga3(
66 config,
67 n_objectives,
68 n_variables,
69 reference_points,
70 )?;
71 optimizer.optimize(objective_function)
72}
73
74pub fn optimize<F>(
76 algorithm: &str,
77 objective_function: F,
78 n_objectives: usize,
79 n_variables: usize,
80 config: Option<MultiObjectiveConfig>,
81) -> Result<MultiObjectiveResult, OptimizeError>
82where
83 F: Fn(&ArrayView1<f64>) -> Array1<f64> + Send + Sync,
84{
85 let config = config.unwrap_or_default();
86 let mut optimizer =
87 algorithms::OptimizerFactory::create_by_name(algorithm, config, n_objectives, n_variables)?;
88
89 let adapted_fn = |x: &ArrayView1<f64>| objective_function(x);
91
92 optimizer.optimize(adapted_fn)
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// ZDT1 benchmark with two objectives.
    ///
    /// `f1 = x[0]`, `g = 1 + 9 * sum(x[1..]) / (n - 1)`, `f2 = g * (1 - sqrt(f1 / g))`.
    /// Requires at least two variables. With exactly one, the `n - 1` divisor is zero
    /// and `g` becomes NaN. With none, indexing `x[0]` panics.
    fn zdt1(x: &ArrayView1<f64>) -> Array1<f64> {
        let f1 = x[0];
        let g = 1.0 + 9.0 * x.slice(s![1..]).sum() / (x.len() - 1) as f64;
        let f2 = g * (1.0 - (f1 / g).sqrt());
        array![f1, f2]
    }

    /// Small configuration shared by the fast tests: 5 generations, a population of
    /// 10, and `[0, 1]` bounds for each of the two variables.
    fn small_config() -> MultiObjectiveConfig {
        let mut config = MultiObjectiveConfig::default();
        config.max_generations = 5;
        config.population_size = 10;
        config.bounds = Some((Array1::zeros(2), Array1::ones(2)));
        config
    }

    #[test]
    fn test_nsga2_convenience_function() {
        // `expect` both asserts success and reports the optimizer's error on failure.
        let result = nsga2(zdt1, 2, 2, Some(small_config())).expect("Operation failed");
        assert!(result.success);
        assert!(!result.pareto_front.is_empty());
    }

    #[test]
    fn test_optimize_by_name() {
        let config = small_config();

        let result = optimize("nsga2", zdt1, 2, 2, Some(config.clone()));
        assert!(result.is_ok(), "optimize(\"nsga2\") failed: {:?}", result.err());

        // An algorithm name the factory does not know is expected to produce an error.
        let result = optimize("unknown", zdt1, 2, 2, Some(config));
        assert!(result.is_err());
    }

    // NOTE(review): presumably ignored because the default configuration is much more
    // expensive than `small_config()`; confirm the reason and consider
    // `#[ignore = "..."]` to record it.
    #[test]
    #[ignore]
    fn test_default_config() {
        let result = nsga2(zdt1, 2, 2, None);
        assert!(result.is_ok() || matches!(result, Err(OptimizeError::MaxEvaluationsReached)));
    }
}