use scirs2_core::ndarray::{Array1, ArrayView1, ArrayView2};
use scirs2_core::random::{thread_rng, Rng};
use sklears_core::{
    error::Result as SklResult,
    prelude::SklearsError,
    types::{Float, FloatBounds},
};
use std::collections::HashMap;
use std::time::Instant;

use crate::Pipeline;

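/// Strategy used to explore the hyperparameter space.
///
/// `GridSearch` evaluates every combination in the defined grid, `RandomSearch`
/// samples `n_iter` random combinations, `BayesianOptimization` is a placeholder
/// that currently returns a `NotImplemented` error, and `EvolutionarySearch`
/// evolves a population of candidates with a simple genetic algorithm.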
pub enum SearchStrategy {
    GridSearch,
    RandomSearch { n_iter: usize },
    BayesianOptimization,
    EvolutionarySearch {
        population_size: usize,
        generations: usize,
    },
}

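/// A named hyperparameter together with the candidate values to search over.
///
/// Candidate values are always stored as `f64`; categorical choices are encoded
/// by their index into `choices`.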
#[derive(Debug, Clone)]
pub struct ParameterSpace {
    pub name: String,
    pub values: Vec<f64>,
    pub param_type: ParameterType,
}

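/// The kind of values a parameter can take: a continuous range, a discrete
/// integer range, or a categorical set of choices.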
#[derive(Debug, Clone)]
pub enum ParameterType {
    Continuous { min: f64, max: f64 },
    Discrete { min: i32, max: i32 },
    Categorical { choices: Vec<String> },
}

impl ParameterSpace {
    #[must_use]
    pub fn continuous(name: &str, min: f64, max: f64, n_points: usize) -> Self {
        // Guard against n_points <= 1, which would otherwise divide by zero.
        let step = if n_points > 1 {
            (max - min) / (n_points - 1) as f64
        } else {
            0.0
        };
        let values = (0..n_points).map(|i| min + i as f64 * step).collect();

        Self {
            name: name.to_string(),
            values,
            param_type: ParameterType::Continuous { min, max },
        }
    }

    #[must_use]
    pub fn discrete(name: &str, min: i32, max: i32) -> Self {
        let values = (min..=max).map(f64::from).collect();

        Self {
            name: name.to_string(),
            values,
            param_type: ParameterType::Discrete { min, max },
        }
    }

    #[must_use]
    pub fn categorical(name: &str, choices: Vec<String>) -> Self {
        let values = (0..choices.len()).map(|i| i as f64).collect();

        Self {
            name: name.to_string(),
            values,
            param_type: ParameterType::Categorical { choices },
        }
    }
}

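/// Builder-style hyperparameter optimizer for a [`Pipeline`].
///
/// Parameter spaces, search strategy, cross-validation folds, and scoring metric
/// are configured through chained setters before calling `optimize`.
///
/// A minimal sketch of the intended builder usage (the parameter name `"alpha"`
/// and the data handles are illustrative only):
///
/// ```ignore
/// let optimizer = PipelineOptimizer::new()
///     .parameter_space(ParameterSpace::continuous("alpha", 0.01, 1.0, 10))
///     .search_strategy(SearchStrategy::RandomSearch { n_iter: 20 })
///     .cv_folds(5)
///     .scoring(ScoringMetric::MeanSquaredError)
///     .verbose(true);
/// let results = optimizer.optimize(pipeline, &x.view(), &y.view())?;
/// ```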
pub struct PipelineOptimizer {
    parameter_spaces: Vec<ParameterSpace>,
    search_strategy: SearchStrategy,
    cv_folds: usize,
    scoring: ScoringMetric,
    n_jobs: Option<i32>,
    verbose: bool,
}

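/// Metric used to score candidate parameter combinations during cross-validation.
///
/// Scores are compared as "higher is better" throughout the optimizer, so
/// error-based metrics are negated when evaluated.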
#[derive(Debug, Clone)]
pub enum ScoringMetric {
    MeanSquaredError,
    MeanAbsoluteError,
    Accuracy,
    F1Score,
    Custom { name: String },
    MultiObjective { metrics: Vec<ScoringMetric> },
}

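/// A single candidate solution in a multi-objective search: its parameters, one
/// score per objective, and its non-dominated sorting rank.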
#[derive(Debug, Clone)]
pub struct MultiObjectiveResult {
    pub params: HashMap<String, f64>,
    pub scores: Vec<f64>,
    pub dominated: bool,
    pub rank: usize,
}

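/// The set of non-dominated solutions found by multi-objective optimization,
/// together with the number of objectives and an approximate hypervolume.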
#[derive(Debug)]
pub struct ParetoFront {
    pub solutions: Vec<MultiObjectiveResult>,
    pub n_objectives: usize,
    pub hypervolume: f64,
}

impl PipelineOptimizer {
    #[must_use]
    pub fn new() -> Self {
        Self {
            parameter_spaces: Vec::new(),
            search_strategy: SearchStrategy::GridSearch,
            cv_folds: 5,
            scoring: ScoringMetric::MeanSquaredError,
            n_jobs: None,
            verbose: false,
        }
    }

    #[must_use]
    pub fn parameter_space(mut self, space: ParameterSpace) -> Self {
        self.parameter_spaces.push(space);
        self
    }

    #[must_use]
    pub fn search_strategy(mut self, strategy: SearchStrategy) -> Self {
        self.search_strategy = strategy;
        self
    }

    #[must_use]
    pub fn cv_folds(mut self, folds: usize) -> Self {
        self.cv_folds = folds;
        self
    }

    #[must_use]
    pub fn scoring(mut self, metric: ScoringMetric) -> Self {
        self.scoring = metric;
        self
    }

    #[must_use]
    pub fn verbose(mut self, verbose: bool) -> Self {
        self.verbose = verbose;
        self
    }

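    /// Runs the configured search strategy over the defined parameter spaces and
    /// returns the best parameters found along with all cross-validation scores.
    ///
    /// # Errors
    ///
    /// Returns an error if no parameter spaces are defined or if the selected
    /// strategy is not yet implemented (currently `BayesianOptimization`).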
    pub fn optimize<S>(
        &self,
        pipeline: Pipeline<S>,
        x: &ArrayView2<'_, Float>,
        y: &ArrayView1<'_, Float>,
    ) -> SklResult<OptimizationResults>
    where
        S: std::fmt::Debug,
    {
        match self.search_strategy {
            SearchStrategy::GridSearch => self.grid_search(pipeline, x, y),
            SearchStrategy::RandomSearch { n_iter } => self.random_search(pipeline, x, y, n_iter),
            SearchStrategy::BayesianOptimization => Err(SklearsError::NotImplemented(
                "Bayesian optimization not yet implemented".to_string(),
            )),
            SearchStrategy::EvolutionarySearch {
                population_size,
                generations,
            } => self.evolutionary_search(pipeline, x, y, population_size, generations),
        }
    }

    fn grid_search<S>(
        &self,
        pipeline: Pipeline<S>,
        x: &ArrayView2<'_, Float>,
        y: &ArrayView1<'_, Float>,
    ) -> SklResult<OptimizationResults>
    where
        S: std::fmt::Debug,
    {
        let start_time = Instant::now();

        if self.parameter_spaces.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No parameter spaces defined for optimization".to_string(),
            ));
        }

        let param_combinations = self.generate_grid_combinations()?;

        if self.verbose {
            println!(
                "Grid search: evaluating {} parameter combinations",
                param_combinations.len()
            );
        }

        let mut best_score = f64::NEG_INFINITY;
        let mut best_params = HashMap::new();
        let mut all_scores = Vec::new();

        for (i, params) in param_combinations.iter().enumerate() {
            if self.verbose {
                println!(
                    "Evaluating combination {}/{}",
                    i + 1,
                    param_combinations.len()
                );
            }

            let cv_score = self.cross_validate_pipeline(&pipeline, x, y)?;
            all_scores.push(cv_score);

            if cv_score > best_score {
                best_score = cv_score;
                best_params = params.clone();
            }
        }

        let search_time = start_time.elapsed().as_secs_f64();

        Ok(OptimizationResults {
            best_params,
            best_score,
            cv_scores: all_scores,
            search_time,
        })
    }

    fn random_search<S>(
        &self,
        pipeline: Pipeline<S>,
        x: &ArrayView2<'_, Float>,
        y: &ArrayView1<'_, Float>,
        n_iter: usize,
    ) -> SklResult<OptimizationResults>
    where
        S: std::fmt::Debug,
    {
        let start_time = Instant::now();
        let mut rng = thread_rng();

        if self.parameter_spaces.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No parameter spaces defined for optimization".to_string(),
            ));
        }

        if self.verbose {
            println!("Random search: evaluating {n_iter} random parameter combinations");
        }

        let mut best_score = f64::NEG_INFINITY;
        let mut best_params = HashMap::new();
        let mut all_scores = Vec::new();

        for i in 0..n_iter {
            if self.verbose {
                println!("Evaluating combination {}/{}", i + 1, n_iter);
            }

            let params = self.generate_random_parameters(&mut rng)?;

            let cv_score = self.cross_validate_pipeline(&pipeline, x, y)?;
            all_scores.push(cv_score);

            if cv_score > best_score {
                best_score = cv_score;
                best_params = params;
            }
        }

        let search_time = start_time.elapsed().as_secs_f64();

        Ok(OptimizationResults {
            best_params,
            best_score,
            cv_scores: all_scores,
            search_time,
        })
    }

    fn evolutionary_search<S>(
        &self,
        pipeline: Pipeline<S>,
        x: &ArrayView2<'_, Float>,
        y: &ArrayView1<'_, Float>,
        population_size: usize,
        generations: usize,
    ) -> SklResult<OptimizationResults>
    where
        S: std::fmt::Debug,
    {
        let start_time = Instant::now();
        let mut rng = thread_rng();

        if self.parameter_spaces.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No parameter spaces defined for optimization".to_string(),
            ));
        }

        if self.verbose {
            println!(
                "Evolutionary search: {generations} generations with population size {population_size}"
            );
        }

        let mut population = Vec::new();
        for _ in 0..population_size {
            let params = self.generate_random_parameters(&mut rng)?;
            population.push(params);
        }

        let mut best_score = f64::NEG_INFINITY;
        let mut best_params = HashMap::new();
        let mut all_scores = Vec::new();

        for generation in 0..generations {
            if self.verbose {
                println!("Generation {}/{}", generation + 1, generations);
            }

            let mut fitness_scores = Vec::new();
            for params in &population {
                let score = self.cross_validate_pipeline(&pipeline, x, y)?;
                fitness_scores.push(score);
                all_scores.push(score);

                if score > best_score {
                    best_score = score;
                    best_params = params.clone();
                }
            }

            let mut new_population = Vec::new();

            let elite_count = population_size / 4;
            let mut indexed_fitness: Vec<(usize, f64)> = fitness_scores
                .iter()
                .enumerate()
                .map(|(i, &score)| (i, score))
                .collect();
            indexed_fitness.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

            for i in 0..elite_count {
                let elite_idx = indexed_fitness[i].0;
                new_population.push(population[elite_idx].clone());
            }

            while new_population.len() < population_size {
                let parent1_idx = self.tournament_selection(&fitness_scores, &mut rng);
                let parent2_idx = self.tournament_selection(&fitness_scores, &mut rng);

                let offspring =
                    self.crossover(&population[parent1_idx], &population[parent2_idx], &mut rng)?;

                let mutated_offspring = self.mutate(offspring, &mut rng)?;

                new_population.push(mutated_offspring);
            }

            population = new_population;
        }

        let search_time = start_time.elapsed().as_secs_f64();

        Ok(OptimizationResults {
            best_params,
            best_score,
            cv_scores: all_scores,
            search_time,
        })
    }

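    /// Optimizes the pipeline against several scoring metrics at once and returns
    /// the resulting Pareto front of non-dominated parameter combinations.
    ///
    /// # Errors
    ///
    /// Returns an error if `metrics` is empty or contains a nested
    /// `ScoringMetric::MultiObjective`.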
    pub fn multi_objective_optimize<S>(
        &self,
        pipeline: Pipeline<S>,
        x: &ArrayView2<'_, Float>,
        y: &ArrayView1<'_, Float>,
        metrics: Vec<ScoringMetric>,
    ) -> SklResult<ParetoFront>
    where
        S: std::fmt::Debug,
    {
        let start_time = Instant::now();

        if metrics.is_empty() {
            return Err(SklearsError::InvalidInput(
                "At least one metric must be specified for multi-objective optimization"
                    .to_string(),
            ));
        }

        let population_size = 100;
        let generations = 50;

        let results = self.nsga_ii(pipeline, x, y, &metrics, population_size, generations)?;

        let search_time = start_time.elapsed().as_secs_f64();

        if self.verbose {
            println!("Multi-objective optimization completed in {search_time:.2}s");
            println!(
                "Found {} solutions in Pareto front",
                results.solutions.len()
            );
        }

        Ok(results)
    }

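    /// Simplified NSGA-II loop: an initial random population is evolved with
    /// crossover and mutation, and each generation is reduced back to
    /// `population_size` by non-dominated sorting and crowding distance.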
    fn nsga_ii<S>(
        &self,
        pipeline: Pipeline<S>,
        x: &ArrayView2<'_, Float>,
        y: &ArrayView1<'_, Float>,
        metrics: &[ScoringMetric],
        population_size: usize,
        generations: usize,
    ) -> SklResult<ParetoFront>
    where
        S: std::fmt::Debug,
    {
        let mut rng = thread_rng();

        let mut population = Vec::new();
        for _ in 0..population_size {
            let params = self.generate_random_parameters(&mut rng)?;
            let scores = self.evaluate_multi_objective(&pipeline, x, y, &params, metrics)?;

            population.push(MultiObjectiveResult {
                params,
                scores,
                dominated: false,
                rank: 0,
            });
        }

        for generation in 0..generations {
            if self.verbose && generation % 10 == 0 {
                println!("NSGA-II Generation {}/{}", generation + 1, generations);
            }

            let mut offspring = Vec::new();
            while offspring.len() < population_size {
                let parent1_idx = rng.gen_range(0..population.len());
                let parent2_idx = rng.gen_range(0..population.len());

                let child_params = self.crossover(
                    &population[parent1_idx].params,
                    &population[parent2_idx].params,
                    &mut rng,
                )?;

                let mutated_params = self.mutate(child_params, &mut rng)?;

                let scores =
                    self.evaluate_multi_objective(&pipeline, x, y, &mutated_params, metrics)?;

                offspring.push(MultiObjectiveResult {
                    params: mutated_params,
                    scores,
                    dominated: false,
                    rank: 0,
                });
            }

            let mut combined_population = population;
            combined_population.extend(offspring);

            population = self.select_next_generation(combined_population, population_size);
        }

        let pareto_solutions: Vec<MultiObjectiveResult> =
            population.into_iter().filter(|sol| sol.rank == 0).collect();

        let hypervolume = self.calculate_hypervolume(&pareto_solutions, metrics.len());

        Ok(ParetoFront {
            solutions: pareto_solutions,
            n_objectives: metrics.len(),
            hypervolume,
        })
    }

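    /// Tournament selection: samples three candidates at random and returns the
    /// index of the one with the highest fitness.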
    fn tournament_selection(&self, fitness_scores: &[f64], rng: &mut impl Rng) -> usize {
        let tournament_size = 3;
        let mut best_idx = rng.gen_range(0..fitness_scores.len());
        let mut best_score = fitness_scores[best_idx];

        for _ in 1..tournament_size {
            let candidate_idx = rng.gen_range(0..fitness_scores.len());
            let candidate_score = fitness_scores[candidate_idx];

            if candidate_score > best_score {
                best_idx = candidate_idx;
                best_score = candidate_score;
            }
        }

        best_idx
    }

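    /// Uniform crossover of two parents, with occasional blend crossover for
    /// continuous parameters; offspring values are clamped to each parameter's
    /// valid range.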
    fn crossover(
        &self,
        parent1: &HashMap<String, f64>,
        parent2: &HashMap<String, f64>,
        rng: &mut impl Rng,
    ) -> SklResult<HashMap<String, f64>> {
        let mut offspring = HashMap::new();

        for space in &self.parameter_spaces {
            let value1 = parent1.get(&space.name).copied().unwrap_or(0.0);
            let value2 = parent2.get(&space.name).copied().unwrap_or(0.0);

            let offspring_value = if rng.gen_bool(0.5) { value1 } else { value2 };

            let final_value = match &space.param_type {
                ParameterType::Continuous { min, max } => {
                    if rng.gen_bool(0.3) {
                        let alpha = 0.5;
                        let range = (value2 - value1).abs();
                        let min_blend = value1.min(value2) - alpha * range;
                        let max_blend = value1.max(value2) + alpha * range;

                        rng.gen_range(min_blend.max(*min)..=max_blend.min(*max))
                    } else {
                        offspring_value.clamp(*min, *max)
                    }
                }
                ParameterType::Discrete { min, max } => {
                    f64::from((offspring_value.round() as i32).clamp(*min, *max))
                }
                ParameterType::Categorical { choices } => {
                    (offspring_value as usize % choices.len()) as f64
                }
            };

            offspring.insert(space.name.clone(), final_value);
        }

        Ok(offspring)
    }

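    /// Mutates each parameter with probability 0.1: bounded uniform noise for
    /// continuous values, and resampling for discrete and categorical values.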
    fn mutate(
        &self,
        mut individual: HashMap<String, f64>,
        rng: &mut impl Rng,
    ) -> SklResult<HashMap<String, f64>> {
        let mutation_rate = 0.1;

        for space in &self.parameter_spaces {
            if rng.gen_bool(mutation_rate) {
                let current_value = individual.get(&space.name).copied().unwrap_or(0.0);

                let mutated_value = match &space.param_type {
                    ParameterType::Continuous { min, max } => {
                        let sigma = (max - min) * 0.1;
                        let noise = rng.gen_range(-sigma..=sigma);
                        (current_value + noise).clamp(*min, *max)
                    }
                    ParameterType::Discrete { min, max } => {
                        f64::from(rng.gen_range(*min..=*max))
                    }
                    ParameterType::Categorical { choices } => {
                        rng.gen_range(0..choices.len()) as f64
                    }
                };

                individual.insert(space.name.clone(), mutated_value);
            }
        }

        Ok(individual)
    }

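    /// Evaluates one candidate against every requested metric by running the
    /// cross-validation scorer once per metric.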
    fn evaluate_multi_objective<S>(
        &self,
        pipeline: &Pipeline<S>,
        x: &ArrayView2<'_, Float>,
        y: &ArrayView1<'_, Float>,
        _params: &HashMap<String, f64>,
        metrics: &[ScoringMetric],
    ) -> SklResult<Vec<f64>>
    where
        S: std::fmt::Debug,
    {
        let mut scores = Vec::new();

        for metric in metrics {
            let score = if let ScoringMetric::MultiObjective { .. } = metric {
                return Err(SklearsError::InvalidInput(
                    "Nested multi-objective metrics not supported".to_string(),
                ));
            } else {
                let temp_optimizer = PipelineOptimizer {
                    parameter_spaces: Vec::new(),
                    search_strategy: SearchStrategy::GridSearch,
                    cv_folds: self.cv_folds,
                    scoring: metric.clone(),
                    n_jobs: self.n_jobs,
                    verbose: false,
                };
                temp_optimizer.cross_validate_pipeline(pipeline, x, y)?
            };
            scores.push(score);
        }

        Ok(scores)
    }

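    /// Builds the next generation from the combined parent and offspring pool:
    /// whole Pareto fronts are added in rank order, and the last front that fits
    /// only partially is truncated by crowding distance.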
    fn select_next_generation(
        &self,
        mut population: Vec<MultiObjectiveResult>,
        target_size: usize,
    ) -> Vec<MultiObjectiveResult> {
        let fronts = self.non_dominated_sort(&mut population);

        let mut next_generation = Vec::new();

        for (rank, front) in fronts.iter().enumerate() {
            if next_generation.len() + front.len() <= target_size {
                for &idx in front {
                    population[idx].rank = rank;
                    next_generation.push(population[idx].clone());
                }
            } else {
                let remaining_slots = target_size - next_generation.len();
                let mut front_with_distance: Vec<(usize, f64)> = front
                    .iter()
                    .map(|&idx| {
                        let distance = self.calculate_crowding_distance(&population, front, idx);
                        (idx, distance)
                    })
                    .collect();

                front_with_distance.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

                for i in 0..remaining_slots {
                    let idx = front_with_distance[i].0;
                    population[idx].rank = rank;
                    next_generation.push(population[idx].clone());
                }
                break;
            }
        }

        next_generation
    }

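    /// Fast non-dominated sorting: returns successive Pareto fronts as lists of
    /// indices into `population`, best front first.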
    fn non_dominated_sort(&self, population: &mut [MultiObjectiveResult]) -> Vec<Vec<usize>> {
        let n = population.len();
        let mut fronts = Vec::new();
        let mut dominated_count = vec![0; n];
        let mut dominated_solutions: Vec<Vec<usize>> = vec![Vec::new(); n];

        let mut current_front = Vec::new();

        for i in 0..n {
            for j in 0..n {
                if i != j {
                    let dominates = self.dominates(&population[i], &population[j]);
                    let dominated_by = self.dominates(&population[j], &population[i]);

                    if dominates {
                        dominated_solutions[i].push(j);
                    } else if dominated_by {
                        dominated_count[i] += 1;
                    }
                }
            }

            if dominated_count[i] == 0 {
                current_front.push(i);
            }
        }

        fronts.push(current_front.clone());

        while !current_front.is_empty() {
            let mut next_front = Vec::new();

            for &i in &current_front {
                for &j in &dominated_solutions[i] {
                    dominated_count[j] -= 1;
                    if dominated_count[j] == 0 {
                        next_front.push(j);
                    }
                }
            }

            if !next_front.is_empty() {
                fronts.push(next_front.clone());
            }
            current_front = next_front;
        }

        fronts
    }

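    /// Returns `true` if `a` Pareto-dominates `b`: no objective is worse and at
    /// least one is strictly better (all objectives are maximized).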
    fn dominates(&self, a: &MultiObjectiveResult, b: &MultiObjectiveResult) -> bool {
        let mut at_least_one_better = false;

        for i in 0..a.scores.len() {
            if a.scores[i] < b.scores[i] {
                return false;
            }
            if a.scores[i] > b.scores[i] {
                at_least_one_better = true;
            }
        }

        at_least_one_better
    }

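    /// Crowding distance of one individual within its front; boundary individuals
    /// get infinite distance so they are always preferred when truncating a front.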
    fn calculate_crowding_distance(
        &self,
        population: &[MultiObjectiveResult],
        front: &[usize],
        individual_idx: usize,
    ) -> f64 {
        if front.len() <= 2 {
            return f64::INFINITY;
        }

        let n_objectives = population[individual_idx].scores.len();
        let mut distance = 0.0;

        for obj in 0..n_objectives {
            let mut sorted_front = front.to_vec();
            sorted_front.sort_by(|&a, &b| {
                population[a].scores[obj]
                    .partial_cmp(&population[b].scores[obj])
                    .unwrap()
            });

            let pos = sorted_front
                .iter()
                .position(|&idx| idx == individual_idx)
                .unwrap();

            if pos == 0 || pos == sorted_front.len() - 1 {
                return f64::INFINITY;
            }

            let obj_min = population[sorted_front[0]].scores[obj];
            let obj_max = population[sorted_front[sorted_front.len() - 1]].scores[obj];

            if obj_max > obj_min {
                let prev_obj = population[sorted_front[pos - 1]].scores[obj];
                let next_obj = population[sorted_front[pos + 1]].scores[obj];
                distance += (next_obj - prev_obj) / (obj_max - obj_min);
            }
        }

        distance
    }

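    /// Approximate hypervolume of the Pareto front. An exact sweep is only done
    /// for two objectives (with the origin as reference point); for more
    /// objectives the solution count is returned as a rough proxy.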
    fn calculate_hypervolume(
        &self,
        solutions: &[MultiObjectiveResult],
        n_objectives: usize,
    ) -> f64 {
        if solutions.is_empty() {
            return 0.0;
        }

        if n_objectives == 2 {
            let mut sorted_solutions = solutions.to_vec();
            sorted_solutions.sort_by(|a, b| a.scores[0].partial_cmp(&b.scores[0]).unwrap());

            let mut hypervolume = 0.0;
            let mut prev_x = 0.0;

            for solution in &sorted_solutions {
                if solution.scores[0] > prev_x {
                    hypervolume += (solution.scores[0] - prev_x) * solution.scores[1];
                    prev_x = solution.scores[0];
                }
            }

            hypervolume
        } else {
            solutions.len() as f64
        }
    }

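    /// Expands the defined parameter spaces into the full Cartesian product of
    /// candidate values.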
    fn generate_grid_combinations(&self) -> SklResult<Vec<HashMap<String, f64>>> {
        if self.parameter_spaces.is_empty() {
            return Ok(vec![HashMap::new()]);
        }

        let mut combinations = vec![HashMap::new()];

        for space in &self.parameter_spaces {
            let mut new_combinations = Vec::new();

            for value in &space.values {
                for existing_combo in &combinations {
                    let mut new_combo = existing_combo.clone();
                    new_combo.insert(space.name.clone(), *value);
                    new_combinations.push(new_combo);
                }
            }

            combinations = new_combinations;
        }

        Ok(combinations)
    }

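    /// Samples one random value per parameter space, respecting each parameter's
    /// type and bounds.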
    fn generate_random_parameters(&self, rng: &mut impl Rng) -> SklResult<HashMap<String, f64>> {
        let mut params = HashMap::new();

        for space in &self.parameter_spaces {
            let value = match &space.param_type {
                ParameterType::Continuous { min, max } => rng.gen_range(*min..*max),
                ParameterType::Discrete { min, max } => f64::from(rng.gen_range(*min..=*max)),
                ParameterType::Categorical { choices } => {
                    let idx = rng.gen_range(0..choices.len());
                    idx as f64
                }
            };

            params.insert(space.name.clone(), value);
        }

        Ok(params)
    }

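    /// K-fold cross-validation over contiguous folds. Note that this is currently
    /// a mock evaluation: the candidate parameters are not applied to the pipeline
    /// and the score is derived from the test-fold targets only.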
    fn cross_validate_pipeline<S>(
        &self,
        _pipeline: &Pipeline<S>,
        x: &ArrayView2<'_, Float>,
        y: &ArrayView1<'_, Float>,
    ) -> SklResult<f64>
    where
        S: std::fmt::Debug,
    {
        let n_samples = x.nrows();
        let fold_size = n_samples / self.cv_folds;
        let mut scores = Vec::new();

        for fold in 0..self.cv_folds {
            let start_idx = fold * fold_size;
            let end_idx = if fold == self.cv_folds - 1 {
                n_samples
            } else {
                (fold + 1) * fold_size
            };

            let mut train_indices = Vec::new();
            let mut test_indices = Vec::new();

            for i in 0..n_samples {
                if i >= start_idx && i < end_idx {
                    test_indices.push(i);
                } else {
                    train_indices.push(i);
                }
            }

            let score = self.compute_mock_score(x, y, &train_indices, &test_indices)?;
            scores.push(score);
        }

        Ok(scores.iter().sum::<f64>() / scores.len() as f64)
    }

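    /// Placeholder scorer used until real pipeline fitting is wired in: it returns
    /// the negated spread of the test-fold targets for error metrics, a
    /// chance-level baseline for classification metrics, and fixed constants
    /// otherwise.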
    fn compute_mock_score(
        &self,
        _x: &ArrayView2<'_, Float>,
        y: &ArrayView1<'_, Float>,
        _train_indices: &[usize],
        test_indices: &[usize],
    ) -> SklResult<f64> {
        match self.scoring {
            ScoringMetric::MeanSquaredError => {
                let test_targets: Vec<f64> = test_indices.iter().map(|&i| y[i]).collect();

                if test_targets.is_empty() {
                    return Ok(0.0);
                }

                let mean = test_targets.iter().sum::<f64>() / test_targets.len() as f64;
                let variance = test_targets
                    .iter()
                    .map(|&val| (val - mean).powi(2))
                    .sum::<f64>()
                    / test_targets.len() as f64;

                Ok(-variance.sqrt())
            }
            ScoringMetric::MeanAbsoluteError => {
                let test_targets: Vec<f64> = test_indices.iter().map(|&i| y[i]).collect();

                if test_targets.is_empty() {
                    return Ok(0.0);
                }

                let mean = test_targets.iter().sum::<f64>() / test_targets.len() as f64;
                let mae = test_targets
                    .iter()
                    .map(|&val| (val - mean).abs())
                    .sum::<f64>()
                    / test_targets.len() as f64;

                Ok(-mae)
            }
            ScoringMetric::Accuracy | ScoringMetric::F1Score => {
                let unique_classes = y
                    .iter()
                    .map(|&val| val as i32)
                    .collect::<std::collections::HashSet<_>>();

                Ok(1.0 / unique_classes.len() as f64)
            }
            ScoringMetric::Custom { .. } => Ok(0.8),
            ScoringMetric::MultiObjective { .. } => Ok(0.5),
        }
    }
}

impl Default for PipelineOptimizer {
    fn default() -> Self {
        Self::new()
    }
}

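/// Outcome of a hyperparameter search: the best parameters and score, every
/// cross-validation score observed, and the wall-clock search time in seconds.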
#[derive(Debug)]
pub struct OptimizationResults {
    pub best_params: HashMap<String, f64>,
    pub best_score: f64,
    pub cv_scores: Vec<f64>,
    pub search_time: f64,
}

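/// Configurable validator for data fed into a pipeline; the default configuration
/// checks features and targets for missing (NaN) and infinite values.
///
/// A minimal sketch of the intended usage (the data handles are illustrative):
///
/// ```ignore
/// let validator = PipelineValidator::new().check_infinite_values(true);
/// validator.validate_data(&x.view(), Some(&y.view()))?;
/// ```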
pub struct PipelineValidator {
    check_data_types: bool,
    check_missing_values: bool,
    check_infinite_values: bool,
    check_feature_names: bool,
    verbose: bool,
}

impl PipelineValidator {
    #[must_use]
    pub fn new() -> Self {
        Self {
            check_data_types: true,
            check_missing_values: true,
            check_infinite_values: true,
            check_feature_names: false,
            verbose: false,
        }
    }

    #[must_use]
    pub fn check_data_types(mut self, check: bool) -> Self {
        self.check_data_types = check;
        self
    }

    #[must_use]
    pub fn check_missing_values(mut self, check: bool) -> Self {
        self.check_missing_values = check;
        self
    }

    #[must_use]
    pub fn check_infinite_values(mut self, check: bool) -> Self {
        self.check_infinite_values = check;
        self
    }

    #[must_use]
    pub fn check_feature_names(mut self, check: bool) -> Self {
        self.check_feature_names = check;
        self
    }

    #[must_use]
    pub fn verbose(mut self, verbose: bool) -> Self {
        self.verbose = verbose;
        self
    }

    pub fn validate_data(
        &self,
        x: &ArrayView2<'_, Float>,
        y: Option<&ArrayView1<'_, Float>>,
    ) -> SklResult<()> {
        if self.check_missing_values {
            self.check_for_missing_values(x)?;
        }

        if self.check_infinite_values {
            self.check_for_infinite_values(x)?;
        }

        if let Some(y_values) = y {
            self.validate_target(y_values)?;
        }

        Ok(())
    }

    fn check_for_missing_values(&self, x: &ArrayView2<'_, Float>) -> SklResult<()> {
        for (i, row) in x.rows().into_iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                if value.is_nan() {
                    return Err(SklearsError::InvalidData {
                        reason: format!("Missing value (NaN) found at position ({i}, {j})"),
                    });
                }
            }
        }
        Ok(())
    }

    fn check_for_infinite_values(&self, x: &ArrayView2<'_, Float>) -> SklResult<()> {
        for (i, row) in x.rows().into_iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                if value.is_infinite() {
                    return Err(SklearsError::InvalidData {
                        reason: format!("Infinite value found at position ({i}, {j})"),
                    });
                }
            }
        }
        Ok(())
    }

    fn validate_target(&self, y: &ArrayView1<'_, Float>) -> SklResult<()> {
        for (i, &value) in y.iter().enumerate() {
            if value.is_nan() {
                return Err(SklearsError::InvalidData {
                    reason: format!("Missing value (NaN) found in target at position {i}"),
                });
            }
            if value.is_infinite() {
                return Err(SklearsError::InvalidData {
                    reason: format!("Infinite value found in target at position {i}"),
                });
            }
        }
        Ok(())
    }

    pub fn validate_pipeline<S>(&self, _pipeline: &Pipeline<S>) -> SklResult<()>
    where
        S: std::fmt::Debug,
    {
        Ok(())
    }
}

impl Default for PipelineValidator {
    fn default() -> Self {
        Self::new()
    }
}

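/// Executes a pipeline with retries, a configurable error-handling strategy, and
/// a fallback prediction path for when execution keeps failing.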
pub struct RobustPipelineExecutor {
    max_retries: usize,
    fallback_strategy: FallbackStrategy,
    error_handling: ErrorHandlingStrategy,
    timeout_seconds: Option<u64>,
}

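/// What to do when pipeline execution has failed after all retries.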
#[derive(Debug, Clone)]
pub enum FallbackStrategy {
    ReturnError,
    SimplerPipeline,
    DefaultValues,
    SkipStep,
}

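/// How execution errors are handled while retries are still available.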
#[derive(Debug, Clone)]
pub enum ErrorHandlingStrategy {
    FailFast,
    ContinueWithWarnings,
    AttemptRecovery,
}

impl RobustPipelineExecutor {
    #[must_use]
    pub fn new() -> Self {
        Self {
            max_retries: 3,
            fallback_strategy: FallbackStrategy::ReturnError,
            error_handling: ErrorHandlingStrategy::FailFast,
            timeout_seconds: None,
        }
    }

    #[must_use]
    pub fn max_retries(mut self, retries: usize) -> Self {
        self.max_retries = retries;
        self
    }

    #[must_use]
    pub fn fallback_strategy(mut self, strategy: FallbackStrategy) -> Self {
        self.fallback_strategy = strategy;
        self
    }

    #[must_use]
    pub fn error_handling(mut self, strategy: ErrorHandlingStrategy) -> Self {
        self.error_handling = strategy;
        self
    }

    #[must_use]
    pub fn timeout_seconds(mut self, timeout: u64) -> Self {
        self.timeout_seconds = Some(timeout);
        self
    }

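    /// Runs the pipeline, retrying up to `max_retries` times according to the
    /// configured error-handling strategy and falling back to
    /// `apply_fallback_strategy` if all attempts fail.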
    pub fn execute<S>(
        &self,
        mut pipeline: Pipeline<S>,
        x: &ArrayView2<'_, Float>,
        y: Option<&ArrayView1<'_, Float>>,
    ) -> SklResult<Array1<f64>>
    where
        S: std::fmt::Debug,
    {
        let mut attempt = 0;

        while attempt <= self.max_retries {
            match self.try_execute(&mut pipeline, x, y) {
                Ok(result) => return Ok(result),
                Err(error) => match self.error_handling {
                    ErrorHandlingStrategy::FailFast => {
                        return Err(error);
                    }
                    ErrorHandlingStrategy::ContinueWithWarnings => {
                        eprintln!(
                            "Warning: Pipeline execution failed (attempt {}): {:?}",
                            attempt + 1,
                            error
                        );
                        if attempt == self.max_retries {
                            return self.apply_fallback_strategy(x, y);
                        }
                    }
                    ErrorHandlingStrategy::AttemptRecovery => {
                        eprintln!(
                            "Attempting recovery from error (attempt {}): {:?}",
                            attempt + 1,
                            error
                        );
                        if attempt == self.max_retries {
                            return self.apply_fallback_strategy(x, y);
                        }
                    }
                },
            }
            attempt += 1;
        }

        self.apply_fallback_strategy(x, y)
    }

    fn try_execute<S>(
        &self,
        _pipeline: &mut Pipeline<S>,
        x: &ArrayView2<'_, Float>,
        _y: Option<&ArrayView1<'_, Float>>,
    ) -> SklResult<Array1<f64>>
    where
        S: std::fmt::Debug,
    {
        if x.nrows() == 0 {
            return Err(SklearsError::InvalidInput("Empty input data".to_string()));
        }

        let predictions: Vec<f64> = x
            .rows()
            .into_iter()
            .map(|row| row.iter().copied().sum::<f64>() / row.len() as f64)
            .collect();

        Ok(Array1::from_vec(predictions))
    }

    fn apply_fallback_strategy(
        &self,
        x: &ArrayView2<'_, Float>,
        _y: Option<&ArrayView1<'_, Float>>,
    ) -> SklResult<Array1<f64>> {
        match self.fallback_strategy {
            FallbackStrategy::ReturnError => Err(SklearsError::InvalidData {
                reason: "Pipeline execution failed after maximum retries".to_string(),
            }),
            FallbackStrategy::SimplerPipeline => {
                eprintln!("Falling back to simpler pipeline");
                let simple_predictions: Vec<f64> = x
                    .rows()
                    .into_iter()
                    .map(|row| if row.is_empty() { 0.0 } else { row[0] })
                    .collect();
                Ok(Array1::from_vec(simple_predictions))
            }
            FallbackStrategy::DefaultValues => {
                eprintln!("Falling back to default values");
                Ok(Array1::zeros(x.nrows()))
            }
            FallbackStrategy::SkipStep => {
                eprintln!("Falling back by skipping failed step");
                let fallback_predictions: Vec<f64> = x
                    .rows()
                    .into_iter()
                    .map(|row| row.iter().copied().sum())
                    .collect();
                Ok(Array1::from_vec(fallback_predictions))
            }
        }
    }
}

impl Default for RobustPipelineExecutor {
    fn default() -> Self {
        Self::new()
    }
}