scirs2_optimize/multi_objective/
mod.rs1pub mod algorithms;
18pub mod crossover;
19pub mod indicators;
20pub mod mutation;
21pub mod selection;
22pub mod solutions;
23
24pub use algorithms::{
26 MultiObjectiveConfig, MultiObjectiveOptimizer, MultiObjectiveOptimizerWrapper,
27 OptimizerFactory, NSGAII, NSGAIII,
28};
29pub use solutions::{
30 MultiObjectiveResult, MultiObjectiveSolution, OptimizationMetrics, Population,
31};
32
33use crate::error::OptimizeError;
34use ndarray::{s, Array1, ArrayView1};
35
36pub fn nsga2<F>(
38 objective_function: F,
39 n_objectives: usize,
40 n_variables: usize,
41 config: Option<MultiObjectiveConfig>,
42) -> Result<MultiObjectiveResult, OptimizeError>
43where
44 F: Fn(&ArrayView1<f64>) -> Array1<f64> + Send + Sync,
45{
46 let config = config.unwrap_or_default();
47 let mut optimizer =
48 algorithms::OptimizerFactory::create_nsga2(config, n_objectives, n_variables)?;
49 optimizer.optimize(objective_function)
50}
51
52pub fn nsga3<F>(
54 objective_function: F,
55 n_objectives: usize,
56 n_variables: usize,
57 config: Option<MultiObjectiveConfig>,
58 reference_points: Option<Vec<Array1<f64>>>,
59) -> Result<MultiObjectiveResult, OptimizeError>
60where
61 F: Fn(&ArrayView1<f64>) -> Array1<f64> + Send + Sync,
62{
63 let config = config.unwrap_or_default();
64 let mut optimizer = algorithms::OptimizerFactory::create_nsga3(
65 config,
66 n_objectives,
67 n_variables,
68 reference_points,
69 )?;
70 optimizer.optimize(objective_function)
71}
72
73pub fn optimize<F>(
75 algorithm: &str,
76 objective_function: F,
77 n_objectives: usize,
78 n_variables: usize,
79 config: Option<MultiObjectiveConfig>,
80) -> Result<MultiObjectiveResult, OptimizeError>
81where
82 F: Fn(&ArrayView1<f64>) -> Array1<f64> + Send + Sync,
83{
84 let config = config.unwrap_or_default();
85 let mut optimizer =
86 algorithms::OptimizerFactory::create_by_name(algorithm, config, n_objectives, n_variables)?;
87
88 let adapted_fn = |x: &ArrayView1<f64>| objective_function(x);
90
91 optimizer.optimize(adapted_fn)
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// ZDT1 benchmark with two objectives.
    ///
    /// The formulas are:
    /// * `f1 = x[0]`
    /// * `g = 1 + 9 * sum(x[1..]) / (len - 1)`
    /// * `f2 = g * (1 - sqrt(f1 / g))`
    ///
    /// With a single variable there is no tail. In that case `g` is taken as `1.0`
    /// instead of evaluating `0.0 / 0.0`, which would make both `g` and `f2` NaN.
    ///
    /// # Panics
    ///
    /// Panics if `x` is empty, because `x[0]` is out of bounds.
    fn zdt1(x: &ArrayView1<f64>) -> Array1<f64> {
        let f1 = x[0];
        let tail_len = x.len() - 1;
        let g = if tail_len == 0 {
            1.0
        } else {
            1.0 + 9.0 * x.slice(s![1..]).sum() / tail_len as f64
        };
        let f2 = g * (1.0 - (f1 / g).sqrt());
        array![f1, f2]
    }

    /// Small, fast configuration shared by the tests.
    ///
    /// It uses 5 generations, a population of 10, and `[0, 1]` bounds for two
    /// decision variables.
    fn small_config() -> MultiObjectiveConfig {
        let mut config = MultiObjectiveConfig::default();
        config.max_generations = 5;
        config.population_size = 10;
        config.bounds = Some((Array1::zeros(2), Array1::ones(2)));
        config
    }

    #[test]
    fn test_nsga2_convenience_function() {
        let result = nsga2(zdt1, 2, 2, Some(small_config()))
            .expect("nsga2 should succeed on ZDT1 with a small configuration");
        assert!(result.success);
        assert!(!result.pareto_front.is_empty());
    }

    #[test]
    fn test_optimize_by_name() {
        let by_name = optimize("nsga2", zdt1, 2, 2, Some(small_config()));
        assert!(
            by_name.is_ok(),
            "optimize(\"nsga2\") failed: {:?}",
            by_name.err()
        );

        let unknown = optimize("unknown", zdt1, 2, 2, Some(small_config()));
        assert!(unknown.is_err(), "unknown algorithm names must be rejected");
    }

    #[test]
    fn test_default_config() {
        let result = nsga2(zdt1, 2, 2, None);
        // The default budget may end the run early; that outcome is accepted here.
        assert!(result.is_ok() || matches!(result, Err(OptimizeError::MaxEvaluationsReached)));
    }
}