use crate::scoring::TaskType;
use scirs2_core::ndarray::{Array1, Array2};
use sklears_core::error::{Result, SklearsError};
use std::collections::HashMap;
use std::fmt;
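/// Broad families of learning algorithms the AutoML selector can consider
/// when ranking candidates for a task.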
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AlgorithmFamily {
    Linear,
    TreeBased,
    Ensemble,
    NeighborBased,
    SVM,
    NaiveBayes,
    NeuralNetwork,
    GaussianProcess,
    DiscriminantAnalysis,
    Dummy,
}

impl fmt::Display for AlgorithmFamily {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            AlgorithmFamily::Linear => write!(f, "Linear"),
            AlgorithmFamily::TreeBased => write!(f, "Tree-based"),
            AlgorithmFamily::Ensemble => write!(f, "Ensemble"),
            AlgorithmFamily::NeighborBased => write!(f, "Neighbor-based"),
            AlgorithmFamily::SVM => write!(f, "Support Vector Machine"),
            AlgorithmFamily::NaiveBayes => write!(f, "Naive Bayes"),
            AlgorithmFamily::NeuralNetwork => write!(f, "Neural Network"),
            AlgorithmFamily::GaussianProcess => write!(f, "Gaussian Process"),
            AlgorithmFamily::DiscriminantAnalysis => write!(f, "Discriminant Analysis"),
            AlgorithmFamily::Dummy => write!(f, "Dummy/Baseline"),
        }
    }
}

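/// Static description of one candidate algorithm: its family, default
/// hyperparameters, a coarse hyperparameter search space, and relative
/// cost/capability flags used when filtering and ranking candidates.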
#[derive(Debug, Clone, PartialEq)]
pub struct AlgorithmSpec {
    pub family: AlgorithmFamily,
    pub name: String,
    pub default_params: HashMap<String, String>,
    pub param_space: HashMap<String, Vec<String>>,
    pub complexity: f64,
    pub memory_requirement: f64,
    pub supports_proba: bool,
    pub handles_missing: bool,
    pub handles_categorical: bool,
    pub supports_incremental: bool,
}

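/// Dataset profile produced by `AutoMLAlgorithmSelector::analyze_dataset` and
/// consumed by the suitability and constraint checks.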
#[derive(Debug, Clone)]
pub struct DatasetCharacteristics {
    pub n_samples: usize,
    pub n_features: usize,
    pub n_classes: Option<usize>,
    pub class_distribution: Option<Vec<f64>>,
    pub target_stats: Option<TargetStatistics>,
    pub missing_ratio: f64,
    pub categorical_ratio: f64,
    pub correlation_condition_number: f64,
    pub sparsity: f64,
    pub effective_dimensionality: Option<usize>,
    pub noise_level: f64,
    pub linearity_score: f64,
}

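/// Basic statistics of a regression target. Skewness, kurtosis, and the
/// outlier count are currently placeholders (see `calculate_target_statistics`).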
#[derive(Debug, Clone)]
pub struct TargetStatistics {
    pub mean: f64,
    pub std: f64,
    pub skewness: f64,
    pub kurtosis: f64,
    pub n_outliers: usize,
}

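/// Resource budgets a candidate must respect; `None` disables the
/// corresponding limit.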
#[derive(Debug, Clone, Default)]
pub struct ComputationalConstraints {
    pub max_training_time: Option<f64>,
    pub max_memory_gb: Option<f64>,
    pub max_model_size_mb: Option<f64>,
    pub max_inference_time_ms: Option<f64>,
    pub n_cores: Option<usize>,
    pub has_gpu: bool,
}

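/// Configuration for automatic algorithm selection.
///
/// A minimal usage sketch, assuming `x: Array2<f64>` and `y: Array1<f64>` are
/// already prepared (the import path is illustrative and depends on how the
/// crate exposes this module):
///
/// ```ignore
/// let config = AutoMLConfig {
///     task_type: TaskType::Regression,
///     max_algorithms: 5,
///     ..Default::default()
/// };
/// let selector = AutoMLAlgorithmSelector::new(config);
/// let result = selector.select_algorithms(&x, &y)?;
/// println!("{result}");
/// ```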
#[derive(Debug, Clone)]
pub struct AutoMLConfig {
    pub task_type: TaskType,
    pub constraints: ComputationalConstraints,
    pub allowed_families: Option<Vec<AlgorithmFamily>>,
    pub excluded_families: Vec<AlgorithmFamily>,
    pub max_algorithms: usize,
    pub cv_folds: usize,
    pub scoring_metric: String,
    pub hyperopt_time_budget: f64,
    pub random_seed: Option<u64>,
    pub enable_ensembles: bool,
    pub enable_feature_engineering: bool,
}

impl Default for AutoMLConfig {
    fn default() -> Self {
        Self {
            task_type: TaskType::Classification,
            constraints: ComputationalConstraints::default(),
            allowed_families: None,
            excluded_families: Vec::new(),
            max_algorithms: 10,
            cv_folds: 5,
            scoring_metric: "accuracy".to_string(),
            hyperopt_time_budget: 300.0,
            random_seed: None,
            enable_ensembles: true,
            enable_feature_engineering: true,
        }
    }
}

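/// Outcome of a selection run: the ranked candidates, the dataset profile they
/// were evaluated against, and a human-readable explanation of the winner.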
#[derive(Debug, Clone)]
pub struct AlgorithmSelectionResult {
    pub selected_algorithms: Vec<RankedAlgorithm>,
    pub dataset_characteristics: DatasetCharacteristics,
    pub total_evaluation_time: f64,
    pub n_algorithms_evaluated: usize,
    pub best_algorithm: RankedAlgorithm,
    pub improvement_over_baseline: f64,
    pub explanation: String,
}

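/// A single evaluated candidate with its cross-validation score, resource
/// estimates, final rank, and softmax-style selection probability.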
#[derive(Debug, Clone)]
pub struct RankedAlgorithm {
    pub algorithm: AlgorithmSpec,
    pub cv_score: f64,
    pub cv_std: f64,
    pub training_time: f64,
    pub memory_usage: f64,
    pub best_params: HashMap<String, String>,
    pub rank: usize,
    pub selection_probability: f64,
}

impl fmt::Display for AlgorithmSelectionResult {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "AutoML Algorithm Selection Results")?;
        writeln!(f, "==================================")?;
        writeln!(
            f,
            "Dataset: {} samples, {} features",
            self.dataset_characteristics.n_samples, self.dataset_characteristics.n_features
        )?;
        writeln!(f, "Algorithms evaluated: {}", self.n_algorithms_evaluated)?;
        writeln!(
            f,
            "Total evaluation time: {:.2}s",
            self.total_evaluation_time
        )?;
        writeln!(f)?;
        writeln!(
            f,
            "Best Algorithm: {} ({})",
            self.best_algorithm.algorithm.name, self.best_algorithm.algorithm.family
        )?;
        writeln!(
            f,
            "Score: {:.4} ± {:.4}",
            self.best_algorithm.cv_score, self.best_algorithm.cv_std
        )?;
        writeln!(
            f,
            "Training time: {:.2}s",
            self.best_algorithm.training_time
        )?;
        writeln!(
            f,
            "Improvement over baseline: {:.4}",
            self.improvement_over_baseline
        )?;
        writeln!(f)?;
        writeln!(f, "Explanation: {}", self.explanation)?;
        writeln!(f)?;
        writeln!(
            f,
            "Top {} Algorithms:",
            self.selected_algorithms.len().min(5)
        )?;
        for (i, alg) in self.selected_algorithms.iter().take(5).enumerate() {
            writeln!(
                f,
                "{}. {} ({}) - Score: {:.4} ± {:.4}",
                i + 1,
                alg.algorithm.name,
                alg.algorithm.family,
                alg.cv_score,
                alg.cv_std
            )?;
        }
        Ok(())
    }
}

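/// Drives the selection pipeline: dataset analysis, candidate filtering,
/// evaluation, and ranking.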
pub struct AutoMLAlgorithmSelector {
    config: AutoMLConfig,
    algorithm_catalog: HashMap<TaskType, Vec<AlgorithmSpec>>,
}

impl Default for AutoMLAlgorithmSelector {
    fn default() -> Self {
        Self::new(AutoMLConfig::default())
    }
}

impl AutoMLAlgorithmSelector {
    pub fn new(config: AutoMLConfig) -> Self {
        let algorithm_catalog = Self::build_algorithm_catalog();
        Self {
            config,
            algorithm_catalog,
        }
    }

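    /// Profiles `X` and `y` into a `DatasetCharacteristics` summary. Note that
    /// several estimators used here (linearity, noise level, correlation
    /// condition number) are currently randomized placeholders.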
    pub fn analyze_dataset(&self, X: &Array2<f64>, y: &Array1<f64>) -> DatasetCharacteristics {
        let n_samples = X.nrows();
        let n_features = X.ncols();

        let missing_ratio = self.calculate_missing_ratio(X);
        let sparsity = self.calculate_sparsity(X);
        let correlation_condition_number = self.calculate_correlation_condition_number(X);

        let (n_classes, class_distribution, target_stats) = match self.config.task_type {
            TaskType::Classification => {
                let classes = self.get_unique_classes(y);
                let class_dist = self.calculate_class_distribution(y, &classes);
                (Some(classes.len()), Some(class_dist), None)
            }
            TaskType::Regression => {
                let stats = self.calculate_target_statistics(y);
                (None, None, Some(stats))
            }
        };

        let linearity_score = self.estimate_linearity_score(X, y);
        let noise_level = self.estimate_noise_level(X, y);
        let effective_dimensionality = self.estimate_effective_dimensionality(X);
        let categorical_ratio = self.calculate_categorical_ratio(X);

        DatasetCharacteristics {
            n_samples,
            n_features,
            n_classes,
            class_distribution,
            target_stats,
            missing_ratio,
            categorical_ratio,
            correlation_condition_number,
            sparsity,
            effective_dimensionality,
            noise_level,
            linearity_score,
        }
    }

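    /// Runs the full pipeline: analyze the dataset, gather and filter
    /// candidates, evaluate them, rank by cross-validation score, and build an
    /// explanation of the best candidate.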
    pub fn select_algorithms(
        &self,
        X: &Array2<f64>,
        y: &Array1<f64>,
    ) -> Result<AlgorithmSelectionResult> {
        let start_time = std::time::Instant::now();

        let dataset_chars = self.analyze_dataset(X, y);

        let candidate_algorithms = self.get_candidate_algorithms(&dataset_chars)?;

        let filtered_algorithms = self.filter_by_constraints(&candidate_algorithms, &dataset_chars);

        let mut evaluated_algorithms = self.evaluate_algorithms(&filtered_algorithms, X, y)?;

        // Rank candidates by descending cross-validation score.
        evaluated_algorithms.sort_by(|a, b| b.cv_score.partial_cmp(&a.cv_score).unwrap());

        let algorithms_copy = evaluated_algorithms.clone();
        for (i, alg) in evaluated_algorithms.iter_mut().enumerate() {
            alg.rank = i + 1;
            alg.selection_probability = self.calculate_selection_probability(alg, &algorithms_copy);
        }

        // Fail gracefully instead of panicking when every candidate was filtered out.
        let best_algorithm = evaluated_algorithms
            .first()
            .cloned()
            .ok_or_else(|| SklearsError::InvalidParameter {
                name: "constraints".to_string(),
                reason: "no candidate algorithm satisfied the configured constraints".to_string(),
            })?;
        let baseline_score = self.get_baseline_score(X, y)?;
        let improvement = best_algorithm.cv_score - baseline_score;

        let explanation = self.generate_explanation(&best_algorithm, &dataset_chars);

        let total_time = start_time.elapsed().as_secs_f64();

        Ok(AlgorithmSelectionResult {
            selected_algorithms: evaluated_algorithms,
            dataset_characteristics: dataset_chars,
            total_evaluation_time: total_time,
            n_algorithms_evaluated: filtered_algorithms.len(),
            best_algorithm,
            improvement_over_baseline: improvement,
            explanation,
        })
    }

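    /// Builds the static per-task catalog of candidate algorithm specifications.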
    fn build_algorithm_catalog() -> HashMap<TaskType, Vec<AlgorithmSpec>> {
        let mut catalog = HashMap::new();

        let classification_algorithms = vec![
            AlgorithmSpec {
                family: AlgorithmFamily::Linear,
                name: "LogisticRegression".to_string(),
                default_params: [("C".to_string(), "1.0".to_string())]
                    .iter()
                    .cloned()
                    .collect(),
                param_space: [(
                    "C".to_string(),
                    vec![
                        "0.001".to_string(),
                        "0.01".to_string(),
                        "0.1".to_string(),
                        "1.0".to_string(),
                        "10.0".to_string(),
                        "100.0".to_string(),
                    ],
                )]
                .iter()
                .cloned()
                .collect(),
                complexity: 1.0,
                memory_requirement: 1.0,
                supports_proba: true,
                handles_missing: false,
                handles_categorical: false,
                supports_incremental: false,
            },
            AlgorithmSpec {
                family: AlgorithmFamily::Linear,
                name: "RidgeClassifier".to_string(),
                default_params: [("alpha".to_string(), "1.0".to_string())]
                    .iter()
                    .cloned()
                    .collect(),
                param_space: [(
                    "alpha".to_string(),
                    vec![
                        "0.1".to_string(),
                        "1.0".to_string(),
                        "10.0".to_string(),
                        "100.0".to_string(),
                    ],
                )]
                .iter()
                .cloned()
                .collect(),
                complexity: 1.0,
                memory_requirement: 1.0,
                supports_proba: false,
                handles_missing: false,
                handles_categorical: false,
                supports_incremental: false,
            },
            AlgorithmSpec {
                family: AlgorithmFamily::TreeBased,
                name: "DecisionTreeClassifier".to_string(),
                default_params: [("max_depth".to_string(), "None".to_string())]
                    .iter()
                    .cloned()
                    .collect(),
                param_space: [
                    (
                        "max_depth".to_string(),
                        vec![
                            "3".to_string(),
                            "5".to_string(),
                            "10".to_string(),
                            "None".to_string(),
                        ],
                    ),
                    (
                        "min_samples_split".to_string(),
                        vec!["2".to_string(), "5".to_string(), "10".to_string()],
                    ),
                ]
                .iter()
                .cloned()
                .collect(),
                complexity: 2.0,
                memory_requirement: 2.0,
                supports_proba: true,
                handles_missing: false,
                handles_categorical: true,
                supports_incremental: false,
            },
            AlgorithmSpec {
                family: AlgorithmFamily::TreeBased,
                name: "RandomForestClassifier".to_string(),
                default_params: [("n_estimators".to_string(), "100".to_string())]
                    .iter()
                    .cloned()
                    .collect(),
                param_space: [
                    (
                        "n_estimators".to_string(),
                        vec!["50".to_string(), "100".to_string(), "200".to_string()],
                    ),
                    (
                        "max_depth".to_string(),
                        vec![
                            "3".to_string(),
                            "5".to_string(),
                            "10".to_string(),
                            "None".to_string(),
                        ],
                    ),
                ]
                .iter()
                .cloned()
                .collect(),
                complexity: 4.0,
                memory_requirement: 4.0,
                supports_proba: true,
                handles_missing: false,
                handles_categorical: true,
                supports_incremental: false,
            },
            AlgorithmSpec {
                family: AlgorithmFamily::Ensemble,
                name: "AdaBoostClassifier".to_string(),
                default_params: [("n_estimators".to_string(), "50".to_string())]
                    .iter()
                    .cloned()
                    .collect(),
                param_space: [
                    (
                        "n_estimators".to_string(),
                        vec!["25".to_string(), "50".to_string(), "100".to_string()],
                    ),
                    (
                        "learning_rate".to_string(),
                        vec!["0.1".to_string(), "0.5".to_string(), "1.0".to_string()],
                    ),
                ]
                .iter()
                .cloned()
                .collect(),
                complexity: 3.0,
                memory_requirement: 3.0,
                supports_proba: true,
                handles_missing: false,
                handles_categorical: true,
                supports_incremental: false,
            },
            AlgorithmSpec {
                family: AlgorithmFamily::NeighborBased,
                name: "KNeighborsClassifier".to_string(),
                default_params: [("n_neighbors".to_string(), "5".to_string())]
                    .iter()
                    .cloned()
                    .collect(),
                param_space: [
                    (
                        "n_neighbors".to_string(),
                        vec![
                            "3".to_string(),
                            "5".to_string(),
                            "7".to_string(),
                            "11".to_string(),
                        ],
                    ),
                    (
                        "weights".to_string(),
                        vec!["uniform".to_string(), "distance".to_string()],
                    ),
                ]
                .iter()
                .cloned()
                .collect(),
                complexity: 1.0,
                memory_requirement: 5.0,
                supports_proba: true,
                handles_missing: false,
                handles_categorical: false,
                supports_incremental: false,
            },
            AlgorithmSpec {
                family: AlgorithmFamily::NaiveBayes,
                name: "GaussianNB".to_string(),
                default_params: HashMap::new(),
                param_space: HashMap::new(),
                complexity: 1.0,
                memory_requirement: 1.0,
                supports_proba: true,
                handles_missing: false,
                handles_categorical: false,
                supports_incremental: true,
            },
            AlgorithmSpec {
                family: AlgorithmFamily::SVM,
                name: "SVC".to_string(),
                default_params: [
                    ("C".to_string(), "1.0".to_string()),
                    ("kernel".to_string(), "rbf".to_string()),
                ]
                .iter()
                .cloned()
                .collect(),
                param_space: [
                    (
                        "C".to_string(),
                        vec!["0.1".to_string(), "1.0".to_string(), "10.0".to_string()],
                    ),
                    (
                        "kernel".to_string(),
                        vec!["linear".to_string(), "rbf".to_string(), "poly".to_string()],
                    ),
                ]
                .iter()
                .cloned()
                .collect(),
                complexity: 3.0,
                memory_requirement: 3.0,
                supports_proba: false,
                handles_missing: false,
                handles_categorical: false,
                supports_incremental: false,
            },
            AlgorithmSpec {
                family: AlgorithmFamily::Dummy,
                name: "DummyClassifier".to_string(),
                default_params: [("strategy".to_string(), "stratified".to_string())]
                    .iter()
                    .cloned()
                    .collect(),
                param_space: [(
                    "strategy".to_string(),
                    vec![
                        "stratified".to_string(),
                        "most_frequent".to_string(),
                        "uniform".to_string(),
                    ],
                )]
                .iter()
                .cloned()
                .collect(),
                complexity: 0.1,
                memory_requirement: 0.1,
                supports_proba: true,
                handles_missing: true,
                handles_categorical: true,
                supports_incremental: true,
            },
        ];

        let regression_algorithms = vec![
            AlgorithmSpec {
                family: AlgorithmFamily::Linear,
                name: "LinearRegression".to_string(),
                default_params: HashMap::new(),
                param_space: HashMap::new(),
                complexity: 1.0,
                memory_requirement: 1.0,
                supports_proba: false,
                handles_missing: false,
                handles_categorical: false,
                supports_incremental: false,
            },
            AlgorithmSpec {
                family: AlgorithmFamily::Linear,
                name: "Ridge".to_string(),
                default_params: [("alpha".to_string(), "1.0".to_string())]
                    .iter()
                    .cloned()
                    .collect(),
                param_space: [(
                    "alpha".to_string(),
                    vec![
                        "0.1".to_string(),
                        "1.0".to_string(),
                        "10.0".to_string(),
                        "100.0".to_string(),
                    ],
                )]
                .iter()
                .cloned()
                .collect(),
                complexity: 1.0,
                memory_requirement: 1.0,
                supports_proba: false,
                handles_missing: false,
                handles_categorical: false,
                supports_incremental: false,
            },
            AlgorithmSpec {
                family: AlgorithmFamily::Linear,
                name: "Lasso".to_string(),
                default_params: [("alpha".to_string(), "1.0".to_string())]
                    .iter()
                    .cloned()
                    .collect(),
                param_space: [(
                    "alpha".to_string(),
                    vec![
                        "0.001".to_string(),
                        "0.01".to_string(),
                        "0.1".to_string(),
                        "1.0".to_string(),
                    ],
                )]
                .iter()
                .cloned()
                .collect(),
                complexity: 1.5,
                memory_requirement: 1.0,
                supports_proba: false,
                handles_missing: false,
                handles_categorical: false,
                supports_incremental: false,
            },
            AlgorithmSpec {
                family: AlgorithmFamily::TreeBased,
                name: "DecisionTreeRegressor".to_string(),
                default_params: [("max_depth".to_string(), "None".to_string())]
                    .iter()
                    .cloned()
                    .collect(),
                param_space: [
                    (
                        "max_depth".to_string(),
                        vec![
                            "3".to_string(),
                            "5".to_string(),
                            "10".to_string(),
                            "None".to_string(),
                        ],
                    ),
                    (
                        "min_samples_split".to_string(),
                        vec!["2".to_string(), "5".to_string(), "10".to_string()],
                    ),
                ]
                .iter()
                .cloned()
                .collect(),
                complexity: 2.0,
                memory_requirement: 2.0,
                supports_proba: false,
                handles_missing: false,
                handles_categorical: true,
                supports_incremental: false,
            },
            AlgorithmSpec {
                family: AlgorithmFamily::TreeBased,
                name: "RandomForestRegressor".to_string(),
                default_params: [("n_estimators".to_string(), "100".to_string())]
                    .iter()
                    .cloned()
                    .collect(),
                param_space: [
                    (
                        "n_estimators".to_string(),
                        vec!["50".to_string(), "100".to_string(), "200".to_string()],
                    ),
                    (
                        "max_depth".to_string(),
                        vec![
                            "3".to_string(),
                            "5".to_string(),
                            "10".to_string(),
                            "None".to_string(),
                        ],
                    ),
                ]
                .iter()
                .cloned()
                .collect(),
                complexity: 4.0,
                memory_requirement: 4.0,
                supports_proba: false,
                handles_missing: false,
                handles_categorical: true,
                supports_incremental: false,
            },
            AlgorithmSpec {
                family: AlgorithmFamily::NeighborBased,
                name: "KNeighborsRegressor".to_string(),
                default_params: [("n_neighbors".to_string(), "5".to_string())]
                    .iter()
                    .cloned()
                    .collect(),
                param_space: [
                    (
                        "n_neighbors".to_string(),
                        vec![
                            "3".to_string(),
                            "5".to_string(),
                            "7".to_string(),
                            "11".to_string(),
                        ],
                    ),
                    (
                        "weights".to_string(),
                        vec!["uniform".to_string(), "distance".to_string()],
                    ),
                ]
                .iter()
                .cloned()
                .collect(),
                complexity: 1.0,
                memory_requirement: 5.0,
                supports_proba: false,
                handles_missing: false,
                handles_categorical: false,
                supports_incremental: false,
            },
            AlgorithmSpec {
                family: AlgorithmFamily::SVM,
                name: "SVR".to_string(),
                default_params: [
                    ("C".to_string(), "1.0".to_string()),
                    ("kernel".to_string(), "rbf".to_string()),
                ]
                .iter()
                .cloned()
                .collect(),
                param_space: [
                    (
                        "C".to_string(),
                        vec!["0.1".to_string(), "1.0".to_string(), "10.0".to_string()],
                    ),
                    (
                        "kernel".to_string(),
                        vec!["linear".to_string(), "rbf".to_string(), "poly".to_string()],
                    ),
                ]
                .iter()
                .cloned()
                .collect(),
                complexity: 3.0,
                memory_requirement: 3.0,
                supports_proba: false,
                handles_missing: false,
                handles_categorical: false,
                supports_incremental: false,
            },
            AlgorithmSpec {
                family: AlgorithmFamily::Dummy,
                name: "DummyRegressor".to_string(),
                default_params: [("strategy".to_string(), "mean".to_string())]
                    .iter()
                    .cloned()
                    .collect(),
                param_space: [(
                    "strategy".to_string(),
                    vec![
                        "mean".to_string(),
                        "median".to_string(),
                        "constant".to_string(),
                    ],
                )]
                .iter()
                .cloned()
                .collect(),
                complexity: 0.1,
                memory_requirement: 0.1,
                supports_proba: false,
                handles_missing: true,
                handles_categorical: true,
                supports_incremental: true,
            },
        ];

        catalog.insert(TaskType::Classification, classification_algorithms);
        catalog.insert(TaskType::Regression, regression_algorithms);
        catalog
    }

    fn get_candidate_algorithms(
        &self,
        dataset_chars: &DatasetCharacteristics,
    ) -> Result<Vec<AlgorithmSpec>> {
        let algorithms = self
            .algorithm_catalog
            .get(&self.config.task_type)
            .ok_or_else(|| SklearsError::InvalidParameter {
                name: "task_type".to_string(),
                reason: format!(
                    "No algorithms available for task type: {:?}",
                    self.config.task_type
                ),
            })?;

        let mut candidates = Vec::new();

        for algorithm in algorithms {
            if let Some(ref allowed) = self.config.allowed_families {
                if !allowed.contains(&algorithm.family) {
                    continue;
                }
            }

            if self.config.excluded_families.contains(&algorithm.family) {
                continue;
            }

            if self.is_algorithm_suitable(algorithm, dataset_chars) {
                candidates.push(algorithm.clone());
            }
        }

        candidates.truncate(self.config.max_algorithms);

        Ok(candidates)
    }

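    /// Heuristic suitability rules based on dataset shape, sample count,
    /// estimated linearity, and missing-value handling.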
    fn is_algorithm_suitable(
        &self,
        algorithm: &AlgorithmSpec,
        dataset_chars: &DatasetCharacteristics,
    ) -> bool {
        if algorithm.family == AlgorithmFamily::Dummy && !self.config.excluded_families.is_empty() {
            return false;
        }

        if dataset_chars.n_features > dataset_chars.n_samples {
            match algorithm.family {
                AlgorithmFamily::Linear | AlgorithmFamily::NaiveBayes => return true,
                AlgorithmFamily::NeighborBased | AlgorithmFamily::SVM => return false,
                _ => {}
            }
        }

        if dataset_chars.n_samples < 100 {
            if algorithm.complexity > 3.0 {
                return false;
            }
        }

        if dataset_chars.n_samples > 10000 {
            match algorithm.family {
                AlgorithmFamily::NeighborBased => return false,
                AlgorithmFamily::SVM => return dataset_chars.n_samples < 50000,
                _ => {}
            }
        }

        if dataset_chars.linearity_score > 0.8 {
            match algorithm.family {
                AlgorithmFamily::Linear => return true,
                AlgorithmFamily::TreeBased | AlgorithmFamily::Ensemble => return false,
                _ => {}
            }
        }

        if dataset_chars.missing_ratio > 0.0 && !algorithm.handles_missing {
            return false;
        }

        true
    }

    fn filter_by_constraints(
        &self,
        algorithms: &[AlgorithmSpec],
        dataset_chars: &DatasetCharacteristics,
    ) -> Vec<AlgorithmSpec> {
        algorithms
            .iter()
            .filter(|alg| self.satisfies_constraints(alg, dataset_chars))
            .cloned()
            .collect()
    }

    fn satisfies_constraints(
        &self,
        algorithm: &AlgorithmSpec,
        dataset_chars: &DatasetCharacteristics,
    ) -> bool {
        let estimated_training_time = self.estimate_training_time(algorithm, dataset_chars);
        let estimated_memory_usage = self.estimate_memory_usage(algorithm, dataset_chars);

        if let Some(max_time) = self.config.constraints.max_training_time {
            if estimated_training_time > max_time {
                return false;
            }
        }

        if let Some(max_memory) = self.config.constraints.max_memory_gb {
            if estimated_memory_usage > max_memory {
                return false;
            }
        }

        true
    }

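    /// Rough training-time estimate in seconds:
    /// `base_time * complexity * (n / 1000) * sqrt(p / 10)`.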
    fn estimate_training_time(
        &self,
        algorithm: &AlgorithmSpec,
        dataset_chars: &DatasetCharacteristics,
    ) -> f64 {
        let n = dataset_chars.n_samples as f64;
        let p = dataset_chars.n_features as f64;

        let base_time = match algorithm.family {
            AlgorithmFamily::Linear => 0.1,
            AlgorithmFamily::TreeBased => {
                if algorithm.name.contains("Random") {
                    2.0
                } else {
                    0.5
                }
            }
            AlgorithmFamily::Ensemble => 3.0,
            AlgorithmFamily::NeighborBased => 0.05,
            AlgorithmFamily::SVM => 1.0,
            AlgorithmFamily::NaiveBayes => 0.05,
            AlgorithmFamily::NeuralNetwork => 5.0,
            AlgorithmFamily::GaussianProcess => 2.0,
            AlgorithmFamily::DiscriminantAnalysis => 0.2,
            AlgorithmFamily::Dummy => 0.01,
        };

        base_time * algorithm.complexity * (n / 1000.0) * (p / 10.0).sqrt()
    }

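    /// Rough memory estimate in GB, derived from a per-family base cost in MB
    /// scaled by the algorithm's relative memory requirement.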
    fn estimate_memory_usage(
        &self,
        algorithm: &AlgorithmSpec,
        dataset_chars: &DatasetCharacteristics,
    ) -> f64 {
        let n = dataset_chars.n_samples as f64;
        let p = dataset_chars.n_features as f64;

        let base_memory_mb = match algorithm.family {
            AlgorithmFamily::Linear => 1.0,
            AlgorithmFamily::TreeBased => {
                if algorithm.name.contains("Random") {
                    50.0
                } else {
                    10.0
                }
            }
            AlgorithmFamily::Ensemble => 100.0,
            AlgorithmFamily::NeighborBased => n * p * 8.0 / 1_000_000.0,
            AlgorithmFamily::SVM => 20.0,
            AlgorithmFamily::NaiveBayes => 1.0,
            AlgorithmFamily::NeuralNetwork => 50.0,
            AlgorithmFamily::GaussianProcess => 10.0,
            AlgorithmFamily::DiscriminantAnalysis => 5.0,
            AlgorithmFamily::Dummy => 0.1,
        };

        // Scale by the algorithm's relative memory requirement and convert MB to GB.
        (base_memory_mb * algorithm.memory_requirement) / 1000.0
    }

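    /// Evaluates each candidate. Scores currently come from
    /// `mock_evaluate_algorithm` rather than real cross-validation.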
    fn evaluate_algorithms(
        &self,
        algorithms: &[AlgorithmSpec],
        X: &Array2<f64>,
        y: &Array1<f64>,
    ) -> Result<Vec<RankedAlgorithm>> {
        let mut results = Vec::new();

        for algorithm in algorithms {
            let start_time = std::time::Instant::now();

            let cv_score = self.mock_evaluate_algorithm(algorithm, X, y);
            let cv_std = cv_score * 0.05;
            let training_time = start_time.elapsed().as_secs_f64();
            let memory_usage = self.estimate_memory_usage(algorithm, &self.analyze_dataset(X, y));

            results.push(RankedAlgorithm {
                algorithm: algorithm.clone(),
                cv_score,
                cv_std,
                training_time,
                memory_usage,
                best_params: algorithm.default_params.clone(),
                rank: 0,
                selection_probability: 0.0,
            });
        }

        Ok(results)
    }

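    /// Mock scoring: a task-dependent base score adjusted by family/dataset
    /// heuristics plus a small random perturbation, clamped to [0, 1].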
    fn mock_evaluate_algorithm(
        &self,
        algorithm: &AlgorithmSpec,
        X: &Array2<f64>,
        y: &Array1<f64>,
    ) -> f64 {
        let dataset_chars = self.analyze_dataset(X, y);

        let base_score = match self.config.task_type {
            TaskType::Classification => 0.7,
            TaskType::Regression => 0.8,
        };

        let mut score: f64 = base_score;

        if algorithm.family == AlgorithmFamily::Linear && dataset_chars.linearity_score > 0.7 {
            score += 0.1;
        }

        if matches!(
            algorithm.family,
            AlgorithmFamily::TreeBased | AlgorithmFamily::Ensemble
        ) && dataset_chars.linearity_score < 0.5
        {
            score += 0.1;
        }

        if algorithm.family == AlgorithmFamily::Ensemble {
            score += 0.05;
        }

        let mut rng = scirs2_core::random::thread_rng();
        score += rng.gen_range(-0.05..0.05);

        score.clamp(0.0, 1.0)
    }

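    /// Softmax over min-max normalized CV scores with a temperature-like
    /// factor of 5, so better-scoring candidates get proportionally higher
    /// selection probabilities.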
    fn calculate_selection_probability(
        &self,
        algorithm: &RankedAlgorithm,
        all_algorithms: &[RankedAlgorithm],
    ) -> f64 {
        let max_score = all_algorithms
            .iter()
            .map(|a| a.cv_score)
            .fold(0.0, f64::max);
        let min_score = all_algorithms
            .iter()
            .map(|a| a.cv_score)
            .fold(1.0, f64::min);

        if max_score == min_score {
            return 1.0 / all_algorithms.len() as f64;
        }

        let normalized_score = (algorithm.cv_score - min_score) / (max_score - min_score);
        let exp_score = (normalized_score * 5.0).exp();
        let total_exp: f64 = all_algorithms
            .iter()
            .map(|a| {
                let norm = (a.cv_score - min_score) / (max_score - min_score);
                (norm * 5.0).exp()
            })
            .sum();

        exp_score / total_exp
    }

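    /// Baseline score: majority-class accuracy for classification, or the R²
    /// of a mean predictor (zero) for regression.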
    fn get_baseline_score(&self, _X: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
        match self.config.task_type {
            TaskType::Classification => {
                // Baseline: always predict the majority class.
                let classes = self.get_unique_classes(y);
                let class_counts = self.calculate_class_distribution(y, &classes);
                Ok(class_counts.iter().fold(0.0, |acc, &x| acc.max(x)))
            }
            TaskType::Regression => {
                // Baseline: predict the mean target, which has an R² of zero by construction.
                let mean = y.mean().unwrap();
                let tss: f64 = y.iter().map(|&yi| (yi - mean).powi(2)).sum();
                let rss = tss;
                Ok(1.0 - rss / tss)
            }
        }
    }

    fn generate_explanation(
        &self,
        best_algorithm: &RankedAlgorithm,
        dataset_chars: &DatasetCharacteristics,
    ) -> String {
        let mut explanation = format!(
            "{} ({}) was selected as the best algorithm with a cross-validation score of {:.4}.",
            best_algorithm.algorithm.name, best_algorithm.algorithm.family, best_algorithm.cv_score
        );

        if dataset_chars.n_samples < 1000 {
            explanation.push_str(" This algorithm is well-suited for small datasets.");
        } else if dataset_chars.n_samples > 10000 {
            explanation.push_str(" This algorithm scales well to large datasets.");
        }

        if dataset_chars.linearity_score > 0.7
            && best_algorithm.algorithm.family == AlgorithmFamily::Linear
        {
            explanation.push_str(
                " The linear nature of your data makes linear models particularly effective.",
            );
        }

        if dataset_chars.n_features > dataset_chars.n_samples {
            explanation.push_str(" The high-dimensional nature of your data favors this algorithm's regularization capabilities.");
        }

        if best_algorithm.algorithm.family == AlgorithmFamily::Ensemble {
            explanation.push_str(
                " Ensemble methods often provide robust performance across diverse datasets.",
            );
        }

        explanation
    }

    fn calculate_missing_ratio(&self, X: &Array2<f64>) -> f64 {
        let total_values = X.len() as f64;
        let missing_count = X.iter().filter(|&&x| x.is_nan()).count() as f64;
        missing_count / total_values
    }

    fn calculate_sparsity(&self, X: &Array2<f64>) -> f64 {
        let total_values = X.len() as f64;
        let zero_count = X.iter().filter(|&&x| x == 0.0).count() as f64;
        zero_count / total_values
    }

    fn calculate_categorical_ratio(&self, X: &Array2<f64>) -> f64 {
        let n_features = X.ncols();
        if n_features == 0 {
            return 0.0;
        }

        let mut categorical_count = 0;
        for col_idx in 0..n_features {
            let column = X.column(col_idx);

            let valid_values: Vec<f64> = column.iter().filter(|&&x| !x.is_nan()).copied().collect();

            if valid_values.is_empty() {
                continue;
            }

            let mut unique_values = valid_values.clone();
            unique_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
            unique_values.dedup();

            let n_unique = unique_values.len();
            let n_total = valid_values.len();

            // Treat a column as categorical when it has few unique values, or when the
            // unique ratio is tiny and every value is an integer code.
            let unique_ratio = n_unique as f64 / n_total as f64;
            let all_integers = valid_values.iter().all(|&x| (x - x.round()).abs() < 1e-10);

            if n_unique <= 10 || (unique_ratio < 0.05 && all_integers) {
                categorical_count += 1;
            }
        }

        categorical_count as f64 / n_features as f64
    }

    fn calculate_correlation_condition_number(&self, _X: &Array2<f64>) -> f64 {
        // Placeholder: a real implementation would compute the condition number of the
        // feature correlation matrix; for now a random value stands in.
        let mut rng = scirs2_core::random::thread_rng();
        rng.gen_range(1.0..100.0)
    }

    fn get_unique_classes(&self, y: &Array1<f64>) -> Vec<i32> {
        let mut classes: Vec<i32> = y.iter().map(|&x| x as i32).collect();
        classes.sort_unstable();
        classes.dedup();
        classes
    }

    fn calculate_class_distribution(&self, y: &Array1<f64>, classes: &[i32]) -> Vec<f64> {
        let total = y.len() as f64;
        classes
            .iter()
            .map(|&class| {
                let count = y.iter().filter(|&&yi| yi as i32 == class).count() as f64;
                count / total
            })
            .collect()
    }

    fn calculate_target_statistics(&self, y: &Array1<f64>) -> TargetStatistics {
        let mean = y.mean().unwrap();
        let std = y.std(0.0);

        TargetStatistics {
            mean,
            std,
            // Higher moments and outlier detection are not implemented yet.
            skewness: 0.0,
            kurtosis: 0.0,
            n_outliers: 0,
        }
    }

    fn estimate_linearity_score(&self, _X: &Array2<f64>, _y: &Array1<f64>) -> f64 {
        // Placeholder estimate: returns a random score until a proper linearity test is wired in.
        let mut rng = scirs2_core::random::thread_rng();
        rng.gen_range(0.0..1.0)
    }

    fn estimate_noise_level(&self, _X: &Array2<f64>, _y: &Array1<f64>) -> f64 {
        // Placeholder estimate: returns a random noise level in [0, 0.5).
        let mut rng = scirs2_core::random::thread_rng();
        rng.gen_range(0.0..0.5)
    }

    fn estimate_effective_dimensionality(&self, X: &Array2<f64>) -> Option<usize> {
        // Rough heuristic: assume about 80% of the raw features carry independent information.
        Some((X.ncols() as f64 * 0.8) as usize)
    }
}

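/// Convenience wrapper that runs algorithm selection with default settings for
/// the given task type.
///
/// Illustrative call, assuming `x: Array2<f64>` and `y: Array1<f64>` are
/// already prepared (the import path depends on the crate layout):
///
/// ```ignore
/// let result = select_best_algorithm(&x, &y, TaskType::Classification)?;
/// println!("Best: {}", result.best_algorithm.algorithm.name);
/// ```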
pub fn select_best_algorithm(
    X: &Array2<f64>,
    y: &Array1<f64>,
    task_type: TaskType,
) -> Result<AlgorithmSelectionResult> {
    let config = AutoMLConfig {
        task_type,
        ..Default::default()
    };

    let selector = AutoMLAlgorithmSelector::new(config);
    selector.select_algorithms(X, y)
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    #[allow(non_snake_case)]
    fn create_test_classification_data() -> (Array2<f64>, Array1<f64>) {
        let X = Array2::from_shape_vec((100, 4), (0..400).map(|i| i as f64).collect()).unwrap();
        let y = Array1::from_vec((0..100).map(|i| (i % 3) as f64).collect());
        (X, y)
    }

    #[allow(non_snake_case)]
    fn create_test_regression_data() -> (Array2<f64>, Array1<f64>) {
        let X = Array2::from_shape_vec((100, 4), (0..400).map(|i| i as f64).collect()).unwrap();
        use scirs2_core::essentials::Uniform;
        use scirs2_core::random::{thread_rng, Distribution};
        let mut rng = thread_rng();
        let dist = Uniform::new(0.0, 1.0).unwrap();
        let y = Array1::from_vec((0..100).map(|i| i as f64 + dist.sample(&mut rng)).collect());
        (X, y)
    }

    #[test]
    fn test_algorithm_selection_classification() {
        let (X, y) = create_test_classification_data();
        let result = select_best_algorithm(&X, &y, TaskType::Classification);
        assert!(result.is_ok());

        let result = result.unwrap();
        assert!(!result.selected_algorithms.is_empty());
        assert!(result.best_algorithm.cv_score > 0.0);
    }

    #[test]
    fn test_algorithm_selection_regression() {
        let (X, y) = create_test_regression_data();
        let result = select_best_algorithm(&X, &y, TaskType::Regression);
        assert!(result.is_ok());

        let result = result.unwrap();
        assert!(!result.selected_algorithms.is_empty());
        assert!(result.best_algorithm.cv_score > 0.0);
    }

    #[test]
    fn test_dataset_characteristics_analysis() {
        let (X, y) = create_test_classification_data();
        let config = AutoMLConfig::default();
        let selector = AutoMLAlgorithmSelector::new(config);

        let chars = selector.analyze_dataset(&X, &y);
        assert_eq!(chars.n_samples, 100);
        assert_eq!(chars.n_features, 4);
        assert_eq!(chars.n_classes, Some(3));
    }

    #[test]
    fn test_custom_config() {
        let (X, y) = create_test_classification_data();

        let config = AutoMLConfig {
            task_type: TaskType::Classification,
            max_algorithms: 3,
            allowed_families: Some(vec![AlgorithmFamily::Linear, AlgorithmFamily::TreeBased]),
            ..Default::default()
        };

        let selector = AutoMLAlgorithmSelector::new(config);
        let result = selector.select_algorithms(&X, &y);
        assert!(result.is_ok());

        let result = result.unwrap();
        assert!(result.n_algorithms_evaluated <= 3);

        for alg in &result.selected_algorithms {
            assert!(matches!(
                alg.algorithm.family,
                AlgorithmFamily::Linear | AlgorithmFamily::TreeBased
            ));
        }
    }

    #[test]
    fn test_computational_constraints() {
        let (X, y) = create_test_classification_data();

        let config = AutoMLConfig {
            task_type: TaskType::Classification,
            constraints: ComputationalConstraints {
                max_training_time: Some(1.0),
                max_memory_gb: Some(0.1),
                ..Default::default()
            },
            ..Default::default()
        };

        let selector = AutoMLAlgorithmSelector::new(config);
        let result = selector.select_algorithms(&X, &y);
        assert!(result.is_ok());

        let result = result.unwrap();
        assert!(!result.selected_algorithms.is_empty());
    }
}