sklears_multioutput/optimization/multi_objective_optimization.rs

//! Multi-Objective Optimization for Multi-Output Learning
//!
//! This module provides multi-objective optimization techniques for multi-output
//! learning problems, where multiple conflicting objectives must be optimized simultaneously.
//! It implements a genetic algorithm that searches for Pareto-optimal solutions.
//!
//! ## Key Features
//!
//! - **Genetic Algorithm**: Population-based evolutionary optimization
//! - **Pareto Optimization**: Find trade-off solutions between conflicting objectives
//! - **Non-dominated Sorting**: Efficient ranking of solutions using NSGA-II principles
//! - **Crowding Distance**: Maintain diversity in the Pareto front
//! - **Multiple Objectives**: Support for accuracy, complexity, MSE, and MAE objectives
//! - **Tournament Selection**: Efficient parent selection for reproduction
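//!
//! ## Example
//!
//! A minimal usage sketch. The import paths below are assumptions (the exact
//! re-exports of `MultiObjectiveOptimizer`, `Fit`, and `Predict` depend on how
//! the crate is laid out), so the snippet is marked `ignore` rather than
//! compiled as a doctest:
//!
//! ```ignore
//! use scirs2_core::ndarray::Array2;
//!
//! // Toy data: 4 samples, 2 features, 2 outputs.
//! let x = Array2::<f64>::from_shape_vec(
//!     (4, 2),
//!     vec![0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.5, 0.5],
//! )
//! .unwrap();
//! let y = Array2::<f64>::from_shape_vec(
//!     (4, 2),
//!     vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.5, 0.5],
//! )
//! .unwrap();
//!
//! let fitted = MultiObjectiveOptimizer::new()
//!     .population_size(50)
//!     .generations(20)
//!     .objectives(vec!["accuracy".to_string(), "complexity".to_string()])
//!     .fit(&x.view(), &y.view())
//!     .expect("optimization failed");
//!
//! let predictions = fitted.predict(&x.view()).expect("prediction failed");
//! println!("Pareto front size: {}", fitted.pareto_solutions().len());
//! ```
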
// Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
use scirs2_core::random::thread_rng;
use scirs2_core::random::{RandNormal, Rng};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};

/// Multi-Objective Optimization for multi-output learning
#[derive(Debug, Clone)]
pub struct MultiObjectiveOptimizer<S = Untrained> {
    state: S,
    config: MultiObjectiveConfig,
}

/// Configuration for Multi-Objective Optimization
#[derive(Debug, Clone)]
pub struct MultiObjectiveConfig {
    /// Population size for genetic algorithm
    pub population_size: usize,
    /// Number of generations
    pub generations: usize,
    /// Mutation rate
    pub mutation_rate: Float,
    /// Crossover rate
    pub crossover_rate: Float,
    /// Selection pressure (currently not used by the optimization loop)
    pub selection_pressure: Float,
    /// Objective names: "accuracy", "complexity", "mse", or "mae"
    pub objectives: Vec<String>,
    /// Random state for reproducibility (currently not applied; `thread_rng` is used)
    pub random_state: Option<u64>,
}

impl Default for MultiObjectiveConfig {
    fn default() -> Self {
        Self {
            population_size: 100,
            generations: 100,
            mutation_rate: 0.1,
            crossover_rate: 0.8,
            selection_pressure: 2.0,
            objectives: vec!["accuracy".to_string(), "complexity".to_string()],
            random_state: None,
        }
    }
}

/// Pareto-optimal solution
#[derive(Debug, Clone)]
pub struct ParetoSolution {
    /// Solution parameters
    pub parameters: Array1<Float>,
    /// Objective values (stored so that larger is better for every objective)
    pub objectives: Array1<Float>,
    /// Dominance rank
    pub rank: usize,
    /// Crowding distance
    pub crowding_distance: Float,
}

/// Trained state for Multi-Objective Optimizer
#[derive(Debug, Clone)]
pub struct MultiObjectiveOptimizerTrained {
    /// Pareto-optimal solutions
    pub pareto_solutions: Vec<ParetoSolution>,
    /// Best compromise solution
    pub best_solution: ParetoSolution,
    /// Convergence history
    pub convergence_history: Vec<Float>,
    /// Configuration used for optimization
    pub config: MultiObjectiveConfig,
    /// Number of outputs
    pub n_outputs: usize,
}

impl MultiObjectiveOptimizer<Untrained> {
    /// Create a new Multi-Objective Optimizer
    pub fn new() -> Self {
        Self {
            state: Untrained,
            config: MultiObjectiveConfig::default(),
        }
    }

    /// Set the configuration
    pub fn config(mut self, config: MultiObjectiveConfig) -> Self {
        self.config = config;
        self
    }

    /// Set the population size
    pub fn population_size(mut self, population_size: usize) -> Self {
        self.config.population_size = population_size;
        self
    }

    /// Set the number of generations
    pub fn generations(mut self, generations: usize) -> Self {
        self.config.generations = generations;
        self
    }

    /// Set the mutation rate
    pub fn mutation_rate(mut self, mutation_rate: Float) -> Self {
        self.config.mutation_rate = mutation_rate;
        self
    }

    /// Set the crossover rate
    pub fn crossover_rate(mut self, crossover_rate: Float) -> Self {
        self.config.crossover_rate = crossover_rate;
        self
    }

    /// Set the selection pressure
    pub fn selection_pressure(mut self, selection_pressure: Float) -> Self {
        self.config.selection_pressure = selection_pressure;
        self
    }

    /// Set the objective functions
    pub fn objectives(mut self, objectives: Vec<String>) -> Self {
        self.config.objectives = objectives;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.config.random_state = random_state;
        self
    }
}

impl Default for MultiObjectiveOptimizer<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for MultiObjectiveOptimizer<Untrained> {
    type Config = MultiObjectiveConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView2<'_, Float>> for MultiObjectiveOptimizer<Untrained> {
    type Fitted = MultiObjectiveOptimizer<MultiObjectiveOptimizerTrained>;

    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView2<'_, Float>) -> SklResult<Self::Fitted> {
        let (n_samples, n_features) = X.dim();
        let (y_samples, n_outputs) = y.dim();

        if n_samples != y_samples {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        let mut rng = thread_rng();

        // Initialize population
        let mut population = self.initialize_population(n_features, n_outputs, &mut rng)?;
        let mut convergence_history = Vec::new();

        for _generation in 0..self.config.generations {
            // Evaluate objectives for all solutions
            self.evaluate_population(&mut population, X, y)?;

            // Non-dominated sorting
            self.non_dominated_sort(&mut population)?;

            // Calculate crowding distance
            self.calculate_crowding_distance(&mut population)?;

            // Track convergence while objective values and ranks are current
            let hypervolume = self.calculate_hypervolume(&population)?;
            convergence_history.push(hypervolume);

            // Selection, crossover, and mutation
            population = self.evolve_population(population, &mut rng)?;
        }

        // Final evaluation and sorting
        self.evaluate_population(&mut population, X, y)?;
        self.non_dominated_sort(&mut population)?;

        // Extract Pareto-optimal solutions
        let pareto_solutions: Vec<ParetoSolution> =
            population.into_iter().filter(|sol| sol.rank == 0).collect();

        // Find best compromise solution (closest to ideal point)
        let best_solution = self.find_best_compromise(&pareto_solutions)?;

        Ok(MultiObjectiveOptimizer {
            state: MultiObjectiveOptimizerTrained {
                pareto_solutions,
                best_solution,
                convergence_history,
                config: self.config.clone(),
                n_outputs,
            },
            config: self.config,
        })
    }
}

impl MultiObjectiveOptimizer<Untrained> {
    /// Initialize population with random solutions
    fn initialize_population(
        &self,
        n_features: usize,
        n_outputs: usize,
        rng: &mut scirs2_core::random::CoreRandom,
    ) -> SklResult<Vec<ParetoSolution>> {
        let mut population = Vec::with_capacity(self.config.population_size);

        // Each individual encodes a flattened weight matrix followed by a bias vector.
        let param_size = n_features * n_outputs + n_outputs;
        let normal_dist = RandNormal::new(0.0, 1.0).unwrap();

        for _ in 0..self.config.population_size {
            // Random parameters (weights and bias) drawn from a standard normal
            let mut parameters = Array1::<Float>::zeros(param_size);
            for i in 0..param_size {
                parameters[i] = rng.sample(normal_dist);
            }

            let solution = ParetoSolution {
                parameters,
                objectives: Array1::<Float>::zeros(self.config.objectives.len()),
                rank: 0,
                crowding_distance: 0.0,
            };

            population.push(solution);
        }

        Ok(population)
    }

262    fn evaluate_population(
263        &self,
264        population: &mut [ParetoSolution],
265        X: &ArrayView2<'_, Float>,
266        y: &ArrayView2<'_, Float>,
267    ) -> SklResult<()> {
268        let (n_samples, n_features) = X.dim();
269        let n_outputs = y.ncols();
270
271        for solution in population.iter_mut() {
272            // Extract weights and bias from parameters
273            let weights_size = n_features * n_outputs;
274            let weights = solution
275                .parameters
276                .slice(s![..weights_size])
277                .to_owned()
278                .into_shape((n_features, n_outputs))
279                .unwrap();
280            let bias = solution.parameters.slice(s![weights_size..]).to_owned();
281
282            // Make predictions
283            let predictions = X.dot(&weights) + &bias;
284
285            // Calculate objectives
286            let mut objectives = Array1::<Float>::zeros(self.config.objectives.len());
287
288            for (i, objective) in self.config.objectives.iter().enumerate() {
289                let objective_value = match objective.as_str() {
290                    "accuracy" => self.calculate_accuracy(&predictions, y)?,
291                    "complexity" => self.calculate_complexity(&weights, &bias)?,
292                    "mse" => self.calculate_mse(&predictions, y)?,
293                    "mae" => self.calculate_mae(&predictions, y)?,
294                    _ => {
295                        return Err(SklearsError::InvalidInput(format!(
296                            "Unknown objective: {}",
297                            objective
298                        )))
299                    }
300                };
301                objectives[i] = objective_value;
302            }
303
304            solution.objectives = objectives;
305        }
306
307        Ok(())
308    }
309
    /// Calculate accuracy objective
    fn calculate_accuracy(
        &self,
        predictions: &Array2<Float>,
        y: &ArrayView2<'_, Float>,
    ) -> SklResult<Float> {
        let mse = predictions
            .iter()
            .zip(y.iter())
            .map(|(pred, true_val)| (pred - true_val).powi(2))
            .sum::<Float>()
            / (predictions.len() as Float);
        Ok(-mse) // Negative because we want to minimize MSE (maximize accuracy)
    }

    /// Calculate complexity objective
    fn calculate_complexity(
        &self,
        weights: &Array2<Float>,
        bias: &Array1<Float>,
    ) -> SklResult<Float> {
        let weight_complexity = weights.mapv(|x| x.abs()).sum();
        let bias_complexity = bias.mapv(|x| x.abs()).sum();
        Ok(weight_complexity + bias_complexity)
    }

    /// Calculate MSE objective
    fn calculate_mse(
        &self,
        predictions: &Array2<Float>,
        y: &ArrayView2<'_, Float>,
    ) -> SklResult<Float> {
        let mse = predictions
            .iter()
            .zip(y.iter())
            .map(|(pred, true_val)| (pred - true_val).powi(2))
            .sum::<Float>()
            / (predictions.len() as Float);
        Ok(mse)
    }

    /// Calculate MAE objective
    fn calculate_mae(
        &self,
        predictions: &Array2<Float>,
        y: &ArrayView2<'_, Float>,
    ) -> SklResult<Float> {
        let mae = predictions
            .iter()
            .zip(y.iter())
            .map(|(pred, true_val)| (pred - true_val).abs())
            .sum::<Float>()
            / (predictions.len() as Float);
        Ok(mae)
    }

    /// Non-dominated sorting (fast non-dominated sort from NSGA-II)
    fn non_dominated_sort(&self, population: &mut [ParetoSolution]) -> SklResult<()> {
        let n = population.len();
        let mut domination_count = vec![0; n];
        let mut dominated_solutions = vec![Vec::new(); n];

        // Calculate domination relationships
        for i in 0..n {
            for j in 0..n {
                if i != j {
                    if self.dominates(&population[i], &population[j]) {
                        dominated_solutions[i].push(j);
                    } else if self.dominates(&population[j], &population[i]) {
                        domination_count[i] += 1;
                    }
                }
            }
        }

        // Assign ranks
        let mut current_rank = 0;
        let mut current_front: Vec<usize> = (0..n).filter(|&i| domination_count[i] == 0).collect();

        while !current_front.is_empty() {
            let mut next_front = Vec::new();

            for &i in &current_front {
                population[i].rank = current_rank;

                for &j in &dominated_solutions[i] {
                    domination_count[j] -= 1;
                    if domination_count[j] == 0 {
                        next_front.push(j);
                    }
                }
            }

            current_front = next_front;
            current_rank += 1;
        }

        Ok(())
    }

    /// Check if solution a dominates solution b
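    ///
    /// For example (illustrative values, using the "larger is better" convention):
    /// with objectives `a = [-0.2, -1.0]` and `b = [-0.5, -1.0]`, `a` dominates `b`
    /// because `a` is at least as good in every objective and strictly better in
    /// the first one.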
    fn dominates(&self, a: &ParetoSolution, b: &ParetoSolution) -> bool {
        let mut at_least_one_better = false;

        for i in 0..a.objectives.len() {
            if a.objectives[i] < b.objectives[i] {
                return false; // a is worse in at least one objective
            } else if a.objectives[i] > b.objectives[i] {
                at_least_one_better = true;
            }
        }

        at_least_one_better
    }

    /// Calculate crowding distance
    fn calculate_crowding_distance(&self, population: &mut [ParetoSolution]) -> SklResult<()> {
        let n = population.len();
        let n_objectives = self.config.objectives.len();

        // Initialize crowding distances
        for solution in population.iter_mut() {
            solution.crowding_distance = 0.0;
        }

        // Calculate crowding distance for each objective
        for obj_idx in 0..n_objectives {
            // Sort by objective value
            let mut indices: Vec<usize> = (0..n).collect();
            indices.sort_by(|&i, &j| {
                population[i].objectives[obj_idx]
                    .partial_cmp(&population[j].objectives[obj_idx])
                    .unwrap()
            });

            // Set boundary points to infinite distance
            population[indices[0]].crowding_distance = Float::INFINITY;
            population[indices[n - 1]].crowding_distance = Float::INFINITY;

            // Calculate crowding distance for middle points
            let obj_range = population[indices[n - 1]].objectives[obj_idx]
                - population[indices[0]].objectives[obj_idx];

            if obj_range > 0.0 {
                for i in 1..n - 1 {
                    let distance = (population[indices[i + 1]].objectives[obj_idx]
                        - population[indices[i - 1]].objectives[obj_idx])
                        / obj_range;
                    population[indices[i]].crowding_distance += distance;
                }
            }
        }

        Ok(())
    }

    /// Evolve population through selection, crossover, and mutation
    fn evolve_population(
        &self,
        population: Vec<ParetoSolution>,
        rng: &mut scirs2_core::random::CoreRandom,
    ) -> SklResult<Vec<ParetoSolution>> {
        let mut new_population = Vec::new();

        while new_population.len() < self.config.population_size {
            // Tournament selection
            let parent1 = self.tournament_selection(&population, rng)?;
            let parent2 = self.tournament_selection(&population, rng)?;

            // Crossover
            let (mut child1, mut child2) = self.crossover(&parent1, &parent2, rng)?;

            // Mutation
            self.mutate(&mut child1, rng)?;
            self.mutate(&mut child2, rng)?;

            new_population.push(child1);
            if new_population.len() < self.config.population_size {
                new_population.push(child2);
            }
        }

        Ok(new_population)
    }

    /// Tournament selection
    fn tournament_selection(
        &self,
        population: &[ParetoSolution],
        rng: &mut scirs2_core::random::CoreRandom,
    ) -> SklResult<ParetoSolution> {
        let tournament_size = 3;
        let mut best_solution = None;

        for _ in 0..tournament_size {
            let idx = rng.gen_range(0..population.len());
            let candidate = &population[idx];

            if let Some(ref current_best) = best_solution {
                if self.is_better_solution(candidate, current_best) {
                    best_solution = Some(candidate.clone());
                }
            } else {
                best_solution = Some(candidate.clone());
            }
        }

        best_solution
            .ok_or_else(|| SklearsError::InvalidInput("Tournament selection failed".to_string()))
    }

    /// Check if solution a is better than solution b (lower rank, then higher crowding distance)
    fn is_better_solution(&self, a: &ParetoSolution, b: &ParetoSolution) -> bool {
        if a.rank < b.rank {
            true
        } else if a.rank == b.rank {
            a.crowding_distance > b.crowding_distance
        } else {
            false
        }
    }

    /// Crossover operation (uniform crossover)
    fn crossover(
        &self,
        parent1: &ParetoSolution,
        parent2: &ParetoSolution,
        rng: &mut scirs2_core::random::CoreRandom,
    ) -> SklResult<(ParetoSolution, ParetoSolution)> {
        let mut child1 = parent1.clone();
        let mut child2 = parent2.clone();

        if rng.gen::<Float>() < self.config.crossover_rate {
            // Uniform crossover: each parameter is swapped between children with probability 0.5
            for i in 0..parent1.parameters.len() {
                if rng.gen::<Float>() < 0.5 {
                    child1.parameters[i] = parent2.parameters[i];
                    child2.parameters[i] = parent1.parameters[i];
                }
            }
        }

        Ok((child1, child2))
    }

    /// Mutation operation (additive uniform noise in [-0.1, 0.1))
    fn mutate(
        &self,
        solution: &mut ParetoSolution,
        rng: &mut scirs2_core::random::CoreRandom,
    ) -> SklResult<()> {
        for param in solution.parameters.iter_mut() {
            if rng.gen::<Float>() < self.config.mutation_rate {
                let mutation = rng.gen_range(-0.1..0.1);
                *param += mutation;
            }
        }
        Ok(())
    }

    /// Calculate hypervolume proxy (convergence metric)
    fn calculate_hypervolume(&self, population: &[ParetoSolution]) -> SklResult<Float> {
        // Simplified hypervolume calculation
        let pareto_front: Vec<&ParetoSolution> =
            population.iter().filter(|sol| sol.rank == 0).collect();

        if pareto_front.is_empty() {
            return Ok(0.0);
        }

        // Use the mean sum of objective values as a proxy for hypervolume
        let hypervolume = pareto_front
            .iter()
            .map(|sol| sol.objectives.sum())
            .sum::<Float>()
            / pareto_front.len() as Float;

        Ok(hypervolume)
    }

    /// Find best compromise solution
    fn find_best_compromise(
        &self,
        pareto_solutions: &[ParetoSolution],
    ) -> SklResult<ParetoSolution> {
        if pareto_solutions.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No Pareto solutions available".to_string(),
            ));
        }

        // Find the solution closest to the ideal point (origin)
        let mut best_solution = pareto_solutions[0].clone();
        let mut best_distance = Float::INFINITY;

        for solution in pareto_solutions {
            let distance = solution.objectives.mapv(|x| x * x).sum().sqrt();
            if distance < best_distance {
                best_distance = distance;
                best_solution = solution.clone();
            }
        }

        Ok(best_solution)
    }
}

impl Predict<ArrayView2<'_, Float>, Array2<Float>>
    for MultiObjectiveOptimizer<MultiObjectiveOptimizerTrained>
{
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let (_n_samples, n_features) = X.dim();
        let best_solution = &self.state.best_solution;

        // Extract weights and bias from best solution parameters
        let n_outputs = self.state.n_outputs;
        let weights_size = n_features * n_outputs;
        let weights = best_solution
            .parameters
            .slice(s![..weights_size])
            .to_owned()
            .into_shape((n_features, n_outputs))
            .unwrap();
        let bias = best_solution
            .parameters
            .slice(s![weights_size..weights_size + n_outputs])
            .to_owned();

        let predictions = X.dot(&weights) + &bias;
        Ok(predictions)
    }
}

impl Estimator for MultiObjectiveOptimizer<MultiObjectiveOptimizerTrained> {
    type Config = MultiObjectiveConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.state.config
    }
}

impl MultiObjectiveOptimizer<MultiObjectiveOptimizerTrained> {
    /// Get the Pareto-optimal solutions
    pub fn pareto_solutions(&self) -> &[ParetoSolution] {
        &self.state.pareto_solutions
    }

    /// Get the best compromise solution
    pub fn best_solution(&self) -> &ParetoSolution {
        &self.state.best_solution
    }

    /// Get the convergence history
    pub fn convergence_history(&self) -> &[Float] {
        &self.state.convergence_history
    }
}
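
// A minimal smoke-test sketch: it exercises fit/predict end to end on a tiny,
// illustrative dataset and checks basic invariants of the Pareto front. The
// test data and sizes are arbitrary choices, and array construction assumes
// the re-exported ndarray API (`Array2::from_shape_vec`).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fit_predict_smoke_test() {
        // Tiny dataset: 4 samples, 2 features, 2 outputs.
        let x = Array2::<Float>::from_shape_vec(
            (4, 2),
            vec![0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.5, 0.5],
        )
        .unwrap();
        let y = Array2::<Float>::from_shape_vec(
            (4, 2),
            vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.5, 0.5],
        )
        .unwrap();

        // Keep the population and generation counts small so the test runs quickly.
        let fitted = MultiObjectiveOptimizer::new()
            .population_size(20)
            .generations(5)
            .fit(&x.view(), &y.view())
            .expect("optimization should succeed");

        // The Pareto front is non-empty and every member has rank 0.
        assert!(!fitted.pareto_solutions().is_empty());
        assert!(fitted.pareto_solutions().iter().all(|sol| sol.rank == 0));

        // Predictions keep the (n_samples, n_outputs) shape of the targets.
        let predictions = fitted.predict(&x.view()).expect("prediction should succeed");
        assert_eq!(predictions.dim(), (4, 2));

        // One hypervolume value is recorded per generation.
        assert_eq!(fitted.convergence_history().len(), 5);
    }
}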