sklears_multioutput/optimization/multi_objective_optimization.rs

use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
use scirs2_core::random::thread_rng;
use scirs2_core::random::{RandNormal, Rng};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};

#[derive(Debug, Clone)]
/// Evolutionary multi-objective optimizer for multi-output linear models.
pub struct MultiObjectiveOptimizer<S = Untrained> {
    state: S,
    config: MultiObjectiveConfig,
}

#[derive(Debug, Clone)]
/// Configuration for [`MultiObjectiveOptimizer`].
pub struct MultiObjectiveConfig {
    /// Number of candidate solutions per generation.
    pub population_size: usize,
    /// Number of evolutionary generations to run.
    pub generations: usize,
    /// Per-parameter probability of mutation.
    pub mutation_rate: Float,
    /// Probability of applying crossover to a pair of parents.
    pub crossover_rate: Float,
    /// Selection pressure for parent selection.
    pub selection_pressure: Float,
    /// Objectives to optimize; supported names are "accuracy", "complexity", "mse", and "mae".
    pub objectives: Vec<String>,
    /// Optional random seed for reproducible runs.
    pub random_state: Option<u64>,
}

impl Default for MultiObjectiveConfig {
    fn default() -> Self {
        Self {
            population_size: 100,
            generations: 100,
            mutation_rate: 0.1,
            crossover_rate: 0.8,
            selection_pressure: 2.0,
            objectives: vec!["accuracy".to_string(), "complexity".to_string()],
            random_state: None,
        }
    }
}

#[derive(Debug, Clone)]
/// A single candidate solution on (or near) the Pareto front.
pub struct ParetoSolution {
    /// Flattened model parameters: weights followed by biases.
    pub parameters: Array1<Float>,
    /// Objective values evaluated for this solution.
    pub objectives: Array1<Float>,
    /// Non-domination rank (0 = non-dominated within the population).
    pub rank: usize,
    /// Crowding distance used to preserve diversity within a rank.
    pub crowding_distance: Float,
}

#[derive(Debug, Clone)]
/// Trained state holding the Pareto front and the selected compromise solution.
pub struct MultiObjectiveOptimizerTrained {
    /// All rank-0 (non-dominated) solutions found during the search.
    pub pareto_solutions: Vec<ParetoSolution>,
    /// The compromise solution used for prediction.
    pub best_solution: ParetoSolution,
    /// Convergence indicator recorded at each generation.
    pub convergence_history: Vec<Float>,
    /// Configuration used for the search.
    pub config: MultiObjectiveConfig,
    /// Number of output targets seen during fitting.
    pub n_outputs: usize,
}

impl MultiObjectiveOptimizer<Untrained> {
    /// Create a new optimizer with the default configuration.
    pub fn new() -> Self {
        Self {
            state: Untrained,
            config: MultiObjectiveConfig::default(),
        }
    }

    /// Replace the entire configuration.
    pub fn config(mut self, config: MultiObjectiveConfig) -> Self {
        self.config = config;
        self
    }

    /// Set the population size.
    pub fn population_size(mut self, population_size: usize) -> Self {
        self.config.population_size = population_size;
        self
    }

    /// Set the number of generations.
    pub fn generations(mut self, generations: usize) -> Self {
        self.config.generations = generations;
        self
    }

    /// Set the per-parameter mutation rate.
    pub fn mutation_rate(mut self, mutation_rate: Float) -> Self {
        self.config.mutation_rate = mutation_rate;
        self
    }

    /// Set the crossover rate.
    pub fn crossover_rate(mut self, crossover_rate: Float) -> Self {
        self.config.crossover_rate = crossover_rate;
        self
    }

    /// Set the selection pressure.
    pub fn selection_pressure(mut self, selection_pressure: Float) -> Self {
        self.config.selection_pressure = selection_pressure;
        self
    }

    /// Set the list of objectives to optimize.
    pub fn objectives(mut self, objectives: Vec<String>) -> Self {
        self.config.objectives = objectives;
        self
    }

    /// Set the random seed.
    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.config.random_state = random_state;
        self
    }
}

impl Default for MultiObjectiveOptimizer<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for MultiObjectiveOptimizer<Untrained> {
    type Config = MultiObjectiveConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView2<'_, Float>> for MultiObjectiveOptimizer<Untrained> {
    type Fitted = MultiObjectiveOptimizer<MultiObjectiveOptimizerTrained>;

    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView2<'_, Float>) -> SklResult<Self::Fitted> {
        let (n_samples, n_features) = X.dim();
        let (y_samples, n_outputs) = y.dim();

        if n_samples != y_samples {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        let mut rng = thread_rng();

        // Start from a randomly initialized population and track convergence.
        let mut population = self.initialize_population(n_features, n_outputs, &mut rng)?;
        let mut convergence_history = Vec::new();

        for _generation in 0..self.config.generations {
            // Evaluate objectives, rank by non-domination, and measure crowding.
            self.evaluate_population(&mut population, X, y)?;
            self.non_dominated_sort(&mut population)?;
            self.calculate_crowding_distance(&mut population)?;

            // Produce the next generation via selection, crossover, and mutation.
            population = self.evolve_population(population, &mut rng)?;

            // Record a convergence indicator for this generation.
            let hypervolume = self.calculate_hypervolume(&population)?;
            convergence_history.push(hypervolume);
        }

        // Final evaluation and ranking of the evolved population.
        self.evaluate_population(&mut population, X, y)?;
        self.non_dominated_sort(&mut population)?;

        // Keep only the non-dominated (rank-0) solutions as the Pareto front.
        let pareto_solutions: Vec<ParetoSolution> =
            population.into_iter().filter(|sol| sol.rank == 0).collect();

        let best_solution = self.find_best_compromise(&pareto_solutions)?;

        Ok(MultiObjectiveOptimizer {
            state: MultiObjectiveOptimizerTrained {
                pareto_solutions,
                best_solution,
                convergence_history,
                config: self.config.clone(),
                n_outputs,
            },
            config: self.config,
        })
    }
}

impl MultiObjectiveOptimizer<Untrained> {
    fn initialize_population(
        &self,
        n_features: usize,
        n_outputs: usize,
        rng: &mut scirs2_core::random::CoreRandom,
    ) -> SklResult<Vec<ParetoSolution>> {
        let mut population = Vec::new();

        for _ in 0..self.config.population_size {
            // Each individual encodes a flattened weight matrix followed by a bias vector.
            let param_size = n_features * n_outputs + n_outputs;
            let normal_dist = RandNormal::new(0.0, 1.0).unwrap();
            let mut parameters = Array1::<Float>::zeros(param_size);
            for i in 0..param_size {
                parameters[i] = rng.sample(normal_dist);
            }

            let solution = ParetoSolution {
                parameters,
                objectives: Array1::<Float>::zeros(self.config.objectives.len()),
                rank: 0,
                crowding_distance: 0.0,
            };

            population.push(solution);
        }

        Ok(population)
    }

    fn evaluate_population(
        &self,
        population: &mut [ParetoSolution],
        X: &ArrayView2<'_, Float>,
        y: &ArrayView2<'_, Float>,
    ) -> SklResult<()> {
        let (_, n_features) = X.dim();
        let n_outputs = y.ncols();

        for solution in population.iter_mut() {
            // Unpack the flattened parameters into a weight matrix and a bias vector.
            let weights_size = n_features * n_outputs;
            let weights = solution
                .parameters
                .slice(s![..weights_size])
                .to_owned()
                .into_shape((n_features, n_outputs))
                .unwrap();
            let bias = solution.parameters.slice(s![weights_size..]).to_owned();

            // Linear predictions for this candidate.
            let predictions = X.dot(&weights) + &bias;

            let mut objectives = Array1::<Float>::zeros(self.config.objectives.len());

            for (i, objective) in self.config.objectives.iter().enumerate() {
                let objective_value = match objective.as_str() {
                    "accuracy" => self.calculate_accuracy(&predictions, y)?,
                    "complexity" => self.calculate_complexity(&weights, &bias)?,
                    "mse" => self.calculate_mse(&predictions, y)?,
                    "mae" => self.calculate_mae(&predictions, y)?,
                    _ => {
                        return Err(SklearsError::InvalidInput(format!(
                            "Unknown objective: {}",
                            objective
                        )))
                    }
                };
                objectives[i] = objective_value;
            }

            solution.objectives = objectives;
        }

        Ok(())
    }

    fn calculate_accuracy(
        &self,
        predictions: &Array2<Float>,
        y: &ArrayView2<'_, Float>,
    ) -> SklResult<Float> {
        let mse = predictions
            .iter()
            .zip(y.iter())
            .map(|(pred, true_val)| (pred - true_val).powi(2))
            .sum::<Float>()
            / (predictions.len() as Float);
        // Negated MSE, so that larger values correspond to better accuracy.
        Ok(-mse)
    }

    fn calculate_complexity(
        &self,
        weights: &Array2<Float>,
        bias: &Array1<Float>,
    ) -> SklResult<Float> {
        let weight_complexity = weights.mapv(|x| x.abs()).sum();
        let bias_complexity = bias.mapv(|x| x.abs()).sum();
        Ok(weight_complexity + bias_complexity)
    }

    fn calculate_mse(
        &self,
        predictions: &Array2<Float>,
        y: &ArrayView2<'_, Float>,
    ) -> SklResult<Float> {
        let mse = predictions
            .iter()
            .zip(y.iter())
            .map(|(pred, true_val)| (pred - true_val).powi(2))
            .sum::<Float>()
            / (predictions.len() as Float);
        Ok(mse)
    }

    fn calculate_mae(
        &self,
        predictions: &Array2<Float>,
        y: &ArrayView2<'_, Float>,
    ) -> SklResult<Float> {
        let mae = predictions
            .iter()
            .zip(y.iter())
            .map(|(pred, true_val)| (pred - true_val).abs())
            .sum::<Float>()
            / (predictions.len() as Float);
        Ok(mae)
    }

    fn non_dominated_sort(&self, population: &mut [ParetoSolution]) -> SklResult<()> {
        let n = population.len();
        let mut domination_count = vec![0; n];
        let mut dominated_solutions = vec![Vec::new(); n];

        // For each solution, count how many others dominate it and record which ones it dominates.
        for i in 0..n {
            for j in 0..n {
                if i != j {
                    if self.dominates(&population[i], &population[j]) {
                        dominated_solutions[i].push(j);
                    } else if self.dominates(&population[j], &population[i]) {
                        domination_count[i] += 1;
                    }
                }
            }
        }

        // Peel off successive fronts: rank 0 is the non-dominated set, rank 1 the next, and so on.
        let mut current_rank = 0;
        let mut current_front: Vec<usize> = (0..n).filter(|&i| domination_count[i] == 0).collect();

        while !current_front.is_empty() {
            let mut next_front = Vec::new();

            for &i in &current_front {
                population[i].rank = current_rank;

                for &j in &dominated_solutions[i] {
                    domination_count[j] -= 1;
                    if domination_count[j] == 0 {
                        next_front.push(j);
                    }
                }
            }

            current_front = next_front;
            current_rank += 1;
        }

        Ok(())
    }

    fn dominates(&self, a: &ParetoSolution, b: &ParetoSolution) -> bool {
        // Pareto dominance under a maximization convention: `a` dominates `b` when it is
        // no worse in every objective and strictly better in at least one.
        let mut at_least_one_better = false;

        for i in 0..a.objectives.len() {
            if a.objectives[i] < b.objectives[i] {
                return false;
            } else if a.objectives[i] > b.objectives[i] {
                at_least_one_better = true;
            }
        }

        at_least_one_better
    }

    fn calculate_crowding_distance(&self, population: &mut [ParetoSolution]) -> SklResult<()> {
        let n = population.len();
        let n_objectives = self.config.objectives.len();

        for solution in population.iter_mut() {
            solution.crowding_distance = 0.0;
        }

        for obj_idx in 0..n_objectives {
            // Sort indices by this objective so that neighbours are adjacent in objective space.
            let mut indices: Vec<usize> = (0..n).collect();
            indices.sort_by(|&i, &j| {
                population[i].objectives[obj_idx]
                    .partial_cmp(&population[j].objectives[obj_idx])
                    .unwrap()
            });

            // Boundary solutions are always preserved by assigning them infinite distance.
            population[indices[0]].crowding_distance = Float::INFINITY;
            population[indices[n - 1]].crowding_distance = Float::INFINITY;

            let obj_range = population[indices[n - 1]].objectives[obj_idx]
                - population[indices[0]].objectives[obj_idx];

            if obj_range > 0.0 {
                for i in 1..n - 1 {
                    let distance = (population[indices[i + 1]].objectives[obj_idx]
                        - population[indices[i - 1]].objectives[obj_idx])
                        / obj_range;
                    population[indices[i]].crowding_distance += distance;
                }
            }
        }

        Ok(())
    }

    fn evolve_population(
        &self,
        population: Vec<ParetoSolution>,
        rng: &mut scirs2_core::random::CoreRandom,
    ) -> SklResult<Vec<ParetoSolution>> {
        let mut new_population = Vec::new();

        while new_population.len() < self.config.population_size {
            let parent1 = self.tournament_selection(&population, rng)?;
            let parent2 = self.tournament_selection(&population, rng)?;

            let (mut child1, mut child2) = self.crossover(&parent1, &parent2, rng)?;

            self.mutate(&mut child1, rng)?;
            self.mutate(&mut child2, rng)?;

            new_population.push(child1);
            if new_population.len() < self.config.population_size {
                new_population.push(child2);
            }
        }

        Ok(new_population)
    }

    fn tournament_selection(
        &self,
        population: &[ParetoSolution],
        rng: &mut scirs2_core::random::CoreRandom,
    ) -> SklResult<ParetoSolution> {
        let tournament_size = 3;
        let mut best_solution = None;

        for _ in 0..tournament_size {
            let idx = rng.gen_range(0..population.len());
            let candidate = &population[idx];

            if let Some(ref current_best) = best_solution {
                if self.is_better_solution(candidate, current_best) {
                    best_solution = Some(candidate.clone());
                }
            } else {
                best_solution = Some(candidate.clone());
            }
        }

        best_solution
            .ok_or_else(|| SklearsError::InvalidInput("Tournament selection failed".to_string()))
    }

    fn is_better_solution(&self, a: &ParetoSolution, b: &ParetoSolution) -> bool {
        // Prefer lower rank; break ties with larger crowding distance.
        if a.rank < b.rank {
            true
        } else if a.rank == b.rank {
            a.crowding_distance > b.crowding_distance
        } else {
            false
        }
    }

    fn crossover(
        &self,
        parent1: &ParetoSolution,
        parent2: &ParetoSolution,
        rng: &mut scirs2_core::random::CoreRandom,
    ) -> SklResult<(ParetoSolution, ParetoSolution)> {
        let mut child1 = parent1.clone();
        let mut child2 = parent2.clone();

        if rng.gen::<Float>() < self.config.crossover_rate {
            // Uniform crossover: each parameter is swapped between the children with probability 0.5.
            for i in 0..parent1.parameters.len() {
                if rng.gen::<Float>() < 0.5 {
                    child1.parameters[i] = parent2.parameters[i];
                    child2.parameters[i] = parent1.parameters[i];
                }
            }
        }

        Ok((child1, child2))
    }

    fn mutate(
        &self,
        solution: &mut ParetoSolution,
        rng: &mut scirs2_core::random::CoreRandom,
    ) -> SklResult<()> {
        // Add a small uniform perturbation to each parameter with probability `mutation_rate`.
        for param in solution.parameters.iter_mut() {
            if rng.gen::<Float>() < self.config.mutation_rate {
                let mutation = rng.gen_range(-0.1..0.1);
                *param += mutation;
            }
        }
        Ok(())
    }

    fn calculate_hypervolume(&self, population: &[ParetoSolution]) -> SklResult<Float> {
        let pareto_front: Vec<&ParetoSolution> =
            population.iter().filter(|sol| sol.rank == 0).collect();

        if pareto_front.is_empty() {
            return Ok(0.0);
        }

        // Simplified convergence indicator: the mean of summed objective values over the
        // current rank-0 front, used in place of a true hypervolume computation.
        let hypervolume = pareto_front
            .iter()
            .map(|sol| sol.objectives.sum())
            .sum::<Float>()
            / pareto_front.len() as Float;

        Ok(hypervolume)
    }

    fn find_best_compromise(
        &self,
        pareto_solutions: &[ParetoSolution],
    ) -> SklResult<ParetoSolution> {
        if pareto_solutions.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No Pareto solutions available".to_string(),
            ));
        }

        // Pick the solution whose objective vector has the smallest Euclidean norm.
        let mut best_solution = pareto_solutions[0].clone();
        let mut best_distance = Float::INFINITY;

        for solution in pareto_solutions {
            let distance = solution.objectives.mapv(|x| x * x).sum().sqrt();
            if distance < best_distance {
                best_distance = distance;
                best_solution = solution.clone();
            }
        }

        Ok(best_solution)
    }
}

impl Predict<ArrayView2<'_, Float>, Array2<Float>>
    for MultiObjectiveOptimizer<MultiObjectiveOptimizerTrained>
{
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let (_, n_features) = X.dim();
        let best_solution = &self.state.best_solution;

        // Rebuild the weight matrix and bias vector from the best compromise solution.
        let n_outputs = self.state.n_outputs;
        let weights_size = n_features * n_outputs;
        let weights = best_solution
            .parameters
            .slice(s![..weights_size])
            .to_owned()
            .into_shape((n_features, n_outputs))
            .unwrap();
        let bias = best_solution
            .parameters
            .slice(s![weights_size..weights_size + n_outputs])
            .to_owned();

        let predictions = X.dot(&weights) + &bias;
        Ok(predictions)
    }
}

impl Estimator for MultiObjectiveOptimizer<MultiObjectiveOptimizerTrained> {
    type Config = MultiObjectiveConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.state.config
    }
}

impl MultiObjectiveOptimizer<MultiObjectiveOptimizerTrained> {
    /// All non-dominated solutions found during the search.
    pub fn pareto_solutions(&self) -> &[ParetoSolution] {
        &self.state.pareto_solutions
    }

    /// The compromise solution used by `predict`.
    pub fn best_solution(&self) -> &ParetoSolution {
        &self.state.best_solution
    }

    /// Per-generation convergence values recorded during fitting.
    pub fn convergence_history(&self) -> &[Float] {
        &self.state.convergence_history
    }
}
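
// A minimal usage sketch, added for illustration and not part of the original module.
// It exercises only the public API defined above (the builder, `Fit::fit`, `Predict::predict`,
// and the trained accessors) on tiny synthetic data; the trait signatures and the `Float`
// alias are assumed to be exactly those imported at the top of this file.
#[cfg(test)]
mod usage_sketch {
    use super::MultiObjectiveOptimizer;
    use scirs2_core::ndarray::Array2;
    use sklears_core::{
        traits::{Fit, Predict},
        types::Float,
    };

    #[test]
    fn fit_and_predict_on_toy_data() {
        // 4 samples, 2 features, 2 outputs.
        let x = Array2::<Float>::from_shape_vec(
            (4, 2),
            vec![0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0],
        )
        .unwrap();
        let y = Array2::<Float>::from_shape_vec(
            (4, 2),
            vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0],
        )
        .unwrap();

        // Keep the search budget small so the test runs quickly.
        let fitted = MultiObjectiveOptimizer::new()
            .population_size(10)
            .generations(5)
            .objectives(vec!["mse".to_string(), "complexity".to_string()])
            .fit(&x.view(), &y.view())
            .expect("fit should succeed on shape-consistent inputs");

        // Predictions have one row per sample and one column per output.
        let preds = fitted.predict(&x.view()).expect("predict should succeed");
        assert_eq!(preds.dim(), (4, 2));

        // At least one non-dominated solution is always retained, and one
        // convergence value is recorded per generation.
        assert!(!fitted.pareto_solutions().is_empty());
        assert_eq!(fitted.convergence_history().len(), 5);
    }
}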