1use crate::scoring::TaskType;
8use scirs2_core::ndarray::{Array1, Array2};
9use sklears_core::error::{Result, SklearsError};
10use std::collections::HashMap;
11use std::fmt;
/// Broad family an algorithm belongs to; used for allow/deny filtering and
/// for dataset-suitability heuristics.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AlgorithmFamily {
    /// Linear models (e.g. logistic/ridge/lasso regression).
    Linear,
    /// Decision trees and random forests.
    TreeBased,
    /// Boosting/bagging meta-estimators (e.g. AdaBoost).
    Ensemble,
    /// Instance-based methods such as k-nearest neighbors.
    NeighborBased,
    /// Support vector machines.
    SVM,
    /// Naive Bayes classifiers.
    NaiveBayes,
    /// Neural networks.
    NeuralNetwork,
    /// Gaussian process models.
    GaussianProcess,
    /// Linear/quadratic discriminant analysis.
    DiscriminantAnalysis,
    /// Trivial baseline models used for reference scores.
    Dummy,
}
38
39impl fmt::Display for AlgorithmFamily {
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 match self {
42 AlgorithmFamily::Linear => write!(f, "Linear"),
43 AlgorithmFamily::TreeBased => write!(f, "Tree-based"),
44 AlgorithmFamily::Ensemble => write!(f, "Ensemble"),
45 AlgorithmFamily::NeighborBased => write!(f, "Neighbor-based"),
46 AlgorithmFamily::SVM => write!(f, "Support Vector Machine"),
47 AlgorithmFamily::NaiveBayes => write!(f, "Naive Bayes"),
48 AlgorithmFamily::NeuralNetwork => write!(f, "Neural Network"),
49 AlgorithmFamily::GaussianProcess => write!(f, "Gaussian Process"),
50 AlgorithmFamily::DiscriminantAnalysis => write!(f, "Discriminant Analysis"),
51 AlgorithmFamily::Dummy => write!(f, "Dummy/Baseline"),
52 }
53 }
54}
55
/// Static description of one candidate algorithm in the catalog.
#[derive(Debug, Clone, PartialEq)]
pub struct AlgorithmSpec {
    /// Family the algorithm belongs to.
    pub family: AlgorithmFamily,
    /// Estimator name (e.g. "RandomForestClassifier").
    pub name: String,
    /// Default hyperparameters (stringly typed: name -> value).
    pub default_params: HashMap<String, String>,
    /// Hyperparameter search space: name -> candidate values.
    pub param_space: HashMap<String, Vec<String>>,
    /// Relative training-cost factor consumed by the time estimator.
    pub complexity: f64,
    /// Relative memory-footprint factor consumed by the memory estimator.
    pub memory_requirement: f64,
    /// Whether the estimator can output class probabilities.
    pub supports_proba: bool,
    /// Whether the estimator tolerates missing (NaN) values.
    pub handles_missing: bool,
    /// Whether the estimator handles categorical features natively.
    pub handles_categorical: bool,
    /// Whether the estimator supports incremental (online) fitting.
    pub supports_incremental: bool,
}
80
/// Measured/estimated properties of a dataset, used to pick candidate
/// algorithms.
#[derive(Debug, Clone)]
pub struct DatasetCharacteristics {
    /// Number of rows in X.
    pub n_samples: usize,
    /// Number of columns in X.
    pub n_features: usize,
    /// Number of distinct classes (classification only).
    pub n_classes: Option<usize>,
    /// Relative class frequencies (classification only).
    pub class_distribution: Option<Vec<f64>>,
    /// Target summary statistics (regression only).
    pub target_stats: Option<TargetStatistics>,
    /// Fraction of NaN entries in X.
    pub missing_ratio: f64,
    /// Fraction of columns that look categorical.
    pub categorical_ratio: f64,
    /// Condition number of the feature correlation matrix
    /// (currently produced by a randomized placeholder).
    pub correlation_condition_number: f64,
    /// Fraction of exactly-zero entries in X.
    pub sparsity: f64,
    /// Estimated number of informative dimensions.
    pub effective_dimensionality: Option<usize>,
    /// Estimated noise level (currently a randomized placeholder).
    pub noise_level: f64,
    /// Estimated linearity of the feature-target relationship in [0, 1]
    /// (currently a randomized placeholder).
    pub linearity_score: f64,
}
109
/// Summary statistics of a regression target.
#[derive(Debug, Clone)]
pub struct TargetStatistics {
    /// Arithmetic mean of the target values.
    pub mean: f64,
    /// Standard deviation of the target values.
    pub std: f64,
    /// Skewness (currently always 0 — placeholder).
    pub skewness: f64,
    /// Kurtosis (currently always 0 — placeholder).
    pub kurtosis: f64,
    /// Number of detected outliers (currently always 0 — placeholder).
    pub n_outliers: usize,
}
124
/// Optional resource limits applied when filtering candidate algorithms.
/// A `None` limit means "unconstrained".
#[derive(Debug, Clone, Default)]
pub struct ComputationalConstraints {
    /// Maximum allowed training time, in seconds.
    pub max_training_time: Option<f64>,
    /// Maximum allowed memory, in gigabytes.
    pub max_memory_gb: Option<f64>,
    /// Maximum serialized model size, in megabytes (not yet enforced here).
    pub max_model_size_mb: Option<f64>,
    /// Maximum per-prediction latency, in milliseconds (not yet enforced here).
    pub max_inference_time_ms: Option<f64>,
    /// Number of CPU cores available (not yet enforced here).
    pub n_cores: Option<usize>,
    /// Whether a GPU is available (not yet enforced here).
    pub has_gpu: bool,
}
141
/// Configuration for an AutoML algorithm-selection run.
#[derive(Debug, Clone)]
pub struct AutoMLConfig {
    /// Whether this is a classification or regression problem.
    pub task_type: TaskType,
    /// Resource limits applied when filtering candidates.
    pub constraints: ComputationalConstraints,
    /// If set, only these families are considered.
    pub allowed_families: Option<Vec<AlgorithmFamily>>,
    /// Families that are never considered.
    pub excluded_families: Vec<AlgorithmFamily>,
    /// Maximum number of candidate algorithms to evaluate.
    pub max_algorithms: usize,
    /// Number of cross-validation folds (not used by the mock evaluator yet).
    pub cv_folds: usize,
    /// Name of the scoring metric (e.g. "accuracy").
    pub scoring_metric: String,
    /// Time budget for hyperparameter optimization, in seconds.
    pub hyperopt_time_budget: f64,
    /// Seed for reproducibility (currently unused by the mock evaluator).
    pub random_seed: Option<u64>,
    /// Whether ensemble construction is enabled.
    pub enable_ensembles: bool,
    /// Whether automatic feature engineering is enabled.
    pub enable_feature_engineering: bool,
}
168
169impl Default for AutoMLConfig {
170 fn default() -> Self {
171 Self {
172 task_type: TaskType::Classification,
173 constraints: ComputationalConstraints::default(),
174 allowed_families: None,
175 excluded_families: Vec::new(),
176 max_algorithms: 10,
177 cv_folds: 5,
178 scoring_metric: "accuracy".to_string(),
179 hyperopt_time_budget: 300.0, random_seed: None,
181 enable_ensembles: true,
182 enable_feature_engineering: true,
183 }
184 }
185}
186
/// Outcome of a full algorithm-selection run.
#[derive(Debug, Clone)]
pub struct AlgorithmSelectionResult {
    /// All evaluated algorithms, best first.
    pub selected_algorithms: Vec<RankedAlgorithm>,
    /// Characteristics computed for the input dataset.
    pub dataset_characteristics: DatasetCharacteristics,
    /// Wall-clock time for the whole selection, in seconds.
    pub total_evaluation_time: f64,
    /// How many algorithms survived filtering and were evaluated.
    pub n_algorithms_evaluated: usize,
    /// The top-ranked algorithm.
    pub best_algorithm: RankedAlgorithm,
    /// Best CV score minus the trivial-baseline score.
    pub improvement_over_baseline: f64,
    /// Human-readable justification of the choice.
    pub explanation: String,
}
205
/// One evaluated algorithm together with its ranking metadata.
#[derive(Debug, Clone)]
pub struct RankedAlgorithm {
    /// Static specification of the algorithm.
    pub algorithm: AlgorithmSpec,
    /// Mean cross-validation score.
    pub cv_score: f64,
    /// Spread of the CV score (currently mocked as 5% of the mean).
    pub cv_std: f64,
    /// Wall-clock evaluation time, in seconds.
    pub training_time: f64,
    /// Estimated memory usage, in gigabytes.
    pub memory_usage: f64,
    /// Best hyperparameters found (currently the defaults).
    pub best_params: HashMap<String, String>,
    /// 1-based position after sorting by score (best = 1).
    pub rank: usize,
    /// Softmax-style probability of this algorithm being the best choice.
    pub selection_probability: f64,
}
226
impl fmt::Display for AlgorithmSelectionResult {
    /// Render a multi-line, human-readable report: dataset summary, best
    /// algorithm with score/time/improvement, the textual explanation, and
    /// up to the top five ranked algorithms.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "AutoML Algorithm Selection Results")?;
        writeln!(f, "==================================")?;
        writeln!(
            f,
            "Dataset: {} samples, {} features",
            self.dataset_characteristics.n_samples, self.dataset_characteristics.n_features
        )?;
        writeln!(f, "Algorithms evaluated: {}", self.n_algorithms_evaluated)?;
        writeln!(
            f,
            "Total evaluation time: {:.2}s",
            self.total_evaluation_time
        )?;
        writeln!(f)?;
        writeln!(
            f,
            "Best Algorithm: {} ({})",
            self.best_algorithm.algorithm.name, self.best_algorithm.algorithm.family
        )?;
        writeln!(
            f,
            "Score: {:.4} ± {:.4}",
            self.best_algorithm.cv_score, self.best_algorithm.cv_std
        )?;
        writeln!(
            f,
            "Training time: {:.2}s",
            self.best_algorithm.training_time
        )?;
        writeln!(
            f,
            "Improvement over baseline: {:.4}",
            self.improvement_over_baseline
        )?;
        writeln!(f)?;
        writeln!(f, "Explanation: {}", self.explanation)?;
        writeln!(f)?;
        // Header count matches the number of entries actually printed below.
        writeln!(
            f,
            "Top {} Algorithms:",
            self.selected_algorithms.len().min(5)
        )?;
        for (i, alg) in self.selected_algorithms.iter().take(5).enumerate() {
            writeln!(
                f,
                "{}. {} ({}) - Score: {:.4} ± {:.4}",
                i + 1,
                alg.algorithm.name,
                alg.algorithm.family,
                alg.cv_score,
                alg.cv_std
            )?;
        }
        Ok(())
    }
}
285
/// Selects the most promising algorithms for a dataset based on heuristics,
/// computational constraints, and (currently mocked) cross-validation scores.
pub struct AutoMLAlgorithmSelector {
    // User-supplied configuration.
    config: AutoMLConfig,
    // Pre-built catalog of known algorithms, keyed by task type.
    algorithm_catalog: HashMap<TaskType, Vec<AlgorithmSpec>>,
}
291
292impl Default for AutoMLAlgorithmSelector {
293 fn default() -> Self {
294 Self::new(AutoMLConfig::default())
295 }
296}
297
298impl AutoMLAlgorithmSelector {
299 pub fn new(config: AutoMLConfig) -> Self {
301 let algorithm_catalog = Self::build_algorithm_catalog();
302 Self {
303 config,
304 algorithm_catalog,
305 }
306 }
307
308 pub fn analyze_dataset(&self, X: &Array2<f64>, y: &Array1<f64>) -> DatasetCharacteristics {
310 let n_samples = X.nrows();
311 let n_features = X.ncols();
312
313 let missing_ratio = self.calculate_missing_ratio(X);
315 let sparsity = self.calculate_sparsity(X);
316 let correlation_condition_number = self.calculate_correlation_condition_number(X);
317
318 let (n_classes, class_distribution, target_stats) = match self.config.task_type {
320 TaskType::Classification => {
321 let classes = self.get_unique_classes(y);
322 let class_dist = self.calculate_class_distribution(y, &classes);
323 (Some(classes.len()), Some(class_dist), None)
324 }
325 TaskType::Regression => {
326 let stats = self.calculate_target_statistics(y);
327 (None, None, Some(stats))
328 }
329 };
330
331 let linearity_score = self.estimate_linearity_score(X, y);
333 let noise_level = self.estimate_noise_level(X, y);
334 let effective_dimensionality = self.estimate_effective_dimensionality(X);
335 let categorical_ratio = self.calculate_categorical_ratio(X);
336
337 DatasetCharacteristics {
338 n_samples,
339 n_features,
340 n_classes,
341 class_distribution,
342 target_stats,
343 missing_ratio,
344 categorical_ratio,
345 correlation_condition_number,
346 sparsity,
347 effective_dimensionality,
348 noise_level,
349 linearity_score,
350 }
351 }
352
353 pub fn select_algorithms(
355 &self,
356 X: &Array2<f64>,
357 y: &Array1<f64>,
358 ) -> Result<AlgorithmSelectionResult> {
359 let start_time = std::time::Instant::now();
360
361 let dataset_chars = self.analyze_dataset(X, y);
363
364 let candidate_algorithms = self.get_candidate_algorithms(&dataset_chars)?;
366
367 let filtered_algorithms = self.filter_by_constraints(&candidate_algorithms, &dataset_chars);
369
370 let mut evaluated_algorithms = self.evaluate_algorithms(&filtered_algorithms, X, y)?;
372
373 evaluated_algorithms.sort_by(|a, b| {
375 b.cv_score
376 .partial_cmp(&a.cv_score)
377 .expect("operation should succeed")
378 });
379
380 let algorithms_copy = evaluated_algorithms.clone();
382 for (i, alg) in evaluated_algorithms.iter_mut().enumerate() {
383 alg.rank = i + 1;
384 alg.selection_probability = self.calculate_selection_probability(alg, &algorithms_copy);
385 }
386
387 let best_algorithm = evaluated_algorithms[0].clone();
388 let baseline_score = self.get_baseline_score(X, y)?;
389 let improvement = best_algorithm.cv_score - baseline_score;
390
391 let explanation = self.generate_explanation(&best_algorithm, &dataset_chars);
392
393 let total_time = start_time.elapsed().as_secs_f64();
394
395 Ok(AlgorithmSelectionResult {
396 selected_algorithms: evaluated_algorithms,
397 dataset_characteristics: dataset_chars,
398 total_evaluation_time: total_time,
399 n_algorithms_evaluated: filtered_algorithms.len(),
400 best_algorithm,
401 improvement_over_baseline: improvement,
402 explanation,
403 })
404 }
405
406 fn build_algorithm_catalog() -> HashMap<TaskType, Vec<AlgorithmSpec>> {
408 let mut catalog = HashMap::new();
409
410 let classification_algorithms = vec![
412 AlgorithmSpec {
414 family: AlgorithmFamily::Linear,
415 name: "LogisticRegression".to_string(),
416 default_params: [("C".to_string(), "1.0".to_string())]
417 .iter()
418 .cloned()
419 .collect(),
420 param_space: [(
421 "C".to_string(),
422 vec![
423 "0.001".to_string(),
424 "0.01".to_string(),
425 "0.1".to_string(),
426 "1.0".to_string(),
427 "10.0".to_string(),
428 "100.0".to_string(),
429 ],
430 )]
431 .iter()
432 .cloned()
433 .collect(),
434 complexity: 1.0,
435 memory_requirement: 1.0,
436 supports_proba: true,
437 handles_missing: false,
438 handles_categorical: false,
439 supports_incremental: false,
440 },
441 AlgorithmSpec {
442 family: AlgorithmFamily::Linear,
443 name: "RidgeClassifier".to_string(),
444 default_params: [("alpha".to_string(), "1.0".to_string())]
445 .iter()
446 .cloned()
447 .collect(),
448 param_space: [(
449 "alpha".to_string(),
450 vec![
451 "0.1".to_string(),
452 "1.0".to_string(),
453 "10.0".to_string(),
454 "100.0".to_string(),
455 ],
456 )]
457 .iter()
458 .cloned()
459 .collect(),
460 complexity: 1.0,
461 memory_requirement: 1.0,
462 supports_proba: false,
463 handles_missing: false,
464 handles_categorical: false,
465 supports_incremental: false,
466 },
467 AlgorithmSpec {
469 family: AlgorithmFamily::TreeBased,
470 name: "DecisionTreeClassifier".to_string(),
471 default_params: [("max_depth".to_string(), "None".to_string())]
472 .iter()
473 .cloned()
474 .collect(),
475 param_space: [
476 (
477 "max_depth".to_string(),
478 vec![
479 "3".to_string(),
480 "5".to_string(),
481 "10".to_string(),
482 "None".to_string(),
483 ],
484 ),
485 (
486 "min_samples_split".to_string(),
487 vec!["2".to_string(), "5".to_string(), "10".to_string()],
488 ),
489 ]
490 .iter()
491 .cloned()
492 .collect(),
493 complexity: 2.0,
494 memory_requirement: 2.0,
495 supports_proba: true,
496 handles_missing: false,
497 handles_categorical: true,
498 supports_incremental: false,
499 },
500 AlgorithmSpec {
501 family: AlgorithmFamily::TreeBased,
502 name: "RandomForestClassifier".to_string(),
503 default_params: [("n_estimators".to_string(), "100".to_string())]
504 .iter()
505 .cloned()
506 .collect(),
507 param_space: [
508 (
509 "n_estimators".to_string(),
510 vec!["50".to_string(), "100".to_string(), "200".to_string()],
511 ),
512 (
513 "max_depth".to_string(),
514 vec![
515 "3".to_string(),
516 "5".to_string(),
517 "10".to_string(),
518 "None".to_string(),
519 ],
520 ),
521 ]
522 .iter()
523 .cloned()
524 .collect(),
525 complexity: 4.0,
526 memory_requirement: 4.0,
527 supports_proba: true,
528 handles_missing: false,
529 handles_categorical: true,
530 supports_incremental: false,
531 },
532 AlgorithmSpec {
534 family: AlgorithmFamily::Ensemble,
535 name: "AdaBoostClassifier".to_string(),
536 default_params: [("n_estimators".to_string(), "50".to_string())]
537 .iter()
538 .cloned()
539 .collect(),
540 param_space: [
541 (
542 "n_estimators".to_string(),
543 vec!["25".to_string(), "50".to_string(), "100".to_string()],
544 ),
545 (
546 "learning_rate".to_string(),
547 vec!["0.1".to_string(), "0.5".to_string(), "1.0".to_string()],
548 ),
549 ]
550 .iter()
551 .cloned()
552 .collect(),
553 complexity: 3.0,
554 memory_requirement: 3.0,
555 supports_proba: true,
556 handles_missing: false,
557 handles_categorical: true,
558 supports_incremental: false,
559 },
560 AlgorithmSpec {
562 family: AlgorithmFamily::NeighborBased,
563 name: "KNeighborsClassifier".to_string(),
564 default_params: [("n_neighbors".to_string(), "5".to_string())]
565 .iter()
566 .cloned()
567 .collect(),
568 param_space: [
569 (
570 "n_neighbors".to_string(),
571 vec![
572 "3".to_string(),
573 "5".to_string(),
574 "7".to_string(),
575 "11".to_string(),
576 ],
577 ),
578 (
579 "weights".to_string(),
580 vec!["uniform".to_string(), "distance".to_string()],
581 ),
582 ]
583 .iter()
584 .cloned()
585 .collect(),
586 complexity: 1.0,
587 memory_requirement: 5.0,
588 supports_proba: true,
589 handles_missing: false,
590 handles_categorical: false,
591 supports_incremental: false,
592 },
593 AlgorithmSpec {
595 family: AlgorithmFamily::NaiveBayes,
596 name: "GaussianNB".to_string(),
597 default_params: HashMap::new(),
598 param_space: HashMap::new(),
599 complexity: 1.0,
600 memory_requirement: 1.0,
601 supports_proba: true,
602 handles_missing: false,
603 handles_categorical: false,
604 supports_incremental: true,
605 },
606 AlgorithmSpec {
608 family: AlgorithmFamily::SVM,
609 name: "SVC".to_string(),
610 default_params: [
611 ("C".to_string(), "1.0".to_string()),
612 ("kernel".to_string(), "rbf".to_string()),
613 ]
614 .iter()
615 .cloned()
616 .collect(),
617 param_space: [
618 (
619 "C".to_string(),
620 vec!["0.1".to_string(), "1.0".to_string(), "10.0".to_string()],
621 ),
622 (
623 "kernel".to_string(),
624 vec!["linear".to_string(), "rbf".to_string(), "poly".to_string()],
625 ),
626 ]
627 .iter()
628 .cloned()
629 .collect(),
630 complexity: 3.0,
631 memory_requirement: 3.0,
632 supports_proba: false,
633 handles_missing: false,
634 handles_categorical: false,
635 supports_incremental: false,
636 },
637 AlgorithmSpec {
639 family: AlgorithmFamily::Dummy,
640 name: "DummyClassifier".to_string(),
641 default_params: [("strategy".to_string(), "stratified".to_string())]
642 .iter()
643 .cloned()
644 .collect(),
645 param_space: [(
646 "strategy".to_string(),
647 vec![
648 "stratified".to_string(),
649 "most_frequent".to_string(),
650 "uniform".to_string(),
651 ],
652 )]
653 .iter()
654 .cloned()
655 .collect(),
656 complexity: 0.1,
657 memory_requirement: 0.1,
658 supports_proba: true,
659 handles_missing: true,
660 handles_categorical: true,
661 supports_incremental: true,
662 },
663 ];
664
665 let regression_algorithms = vec![
667 AlgorithmSpec {
669 family: AlgorithmFamily::Linear,
670 name: "LinearRegression".to_string(),
671 default_params: HashMap::new(),
672 param_space: HashMap::new(),
673 complexity: 1.0,
674 memory_requirement: 1.0,
675 supports_proba: false,
676 handles_missing: false,
677 handles_categorical: false,
678 supports_incremental: false,
679 },
680 AlgorithmSpec {
681 family: AlgorithmFamily::Linear,
682 name: "Ridge".to_string(),
683 default_params: [("alpha".to_string(), "1.0".to_string())]
684 .iter()
685 .cloned()
686 .collect(),
687 param_space: [(
688 "alpha".to_string(),
689 vec![
690 "0.1".to_string(),
691 "1.0".to_string(),
692 "10.0".to_string(),
693 "100.0".to_string(),
694 ],
695 )]
696 .iter()
697 .cloned()
698 .collect(),
699 complexity: 1.0,
700 memory_requirement: 1.0,
701 supports_proba: false,
702 handles_missing: false,
703 handles_categorical: false,
704 supports_incremental: false,
705 },
706 AlgorithmSpec {
707 family: AlgorithmFamily::Linear,
708 name: "Lasso".to_string(),
709 default_params: [("alpha".to_string(), "1.0".to_string())]
710 .iter()
711 .cloned()
712 .collect(),
713 param_space: [(
714 "alpha".to_string(),
715 vec![
716 "0.001".to_string(),
717 "0.01".to_string(),
718 "0.1".to_string(),
719 "1.0".to_string(),
720 ],
721 )]
722 .iter()
723 .cloned()
724 .collect(),
725 complexity: 1.5,
726 memory_requirement: 1.0,
727 supports_proba: false,
728 handles_missing: false,
729 handles_categorical: false,
730 supports_incremental: false,
731 },
732 AlgorithmSpec {
734 family: AlgorithmFamily::TreeBased,
735 name: "DecisionTreeRegressor".to_string(),
736 default_params: [("max_depth".to_string(), "None".to_string())]
737 .iter()
738 .cloned()
739 .collect(),
740 param_space: [
741 (
742 "max_depth".to_string(),
743 vec![
744 "3".to_string(),
745 "5".to_string(),
746 "10".to_string(),
747 "None".to_string(),
748 ],
749 ),
750 (
751 "min_samples_split".to_string(),
752 vec!["2".to_string(), "5".to_string(), "10".to_string()],
753 ),
754 ]
755 .iter()
756 .cloned()
757 .collect(),
758 complexity: 2.0,
759 memory_requirement: 2.0,
760 supports_proba: false,
761 handles_missing: false,
762 handles_categorical: true,
763 supports_incremental: false,
764 },
765 AlgorithmSpec {
766 family: AlgorithmFamily::TreeBased,
767 name: "RandomForestRegressor".to_string(),
768 default_params: [("n_estimators".to_string(), "100".to_string())]
769 .iter()
770 .cloned()
771 .collect(),
772 param_space: [
773 (
774 "n_estimators".to_string(),
775 vec!["50".to_string(), "100".to_string(), "200".to_string()],
776 ),
777 (
778 "max_depth".to_string(),
779 vec![
780 "3".to_string(),
781 "5".to_string(),
782 "10".to_string(),
783 "None".to_string(),
784 ],
785 ),
786 ]
787 .iter()
788 .cloned()
789 .collect(),
790 complexity: 4.0,
791 memory_requirement: 4.0,
792 supports_proba: false,
793 handles_missing: false,
794 handles_categorical: true,
795 supports_incremental: false,
796 },
797 AlgorithmSpec {
799 family: AlgorithmFamily::NeighborBased,
800 name: "KNeighborsRegressor".to_string(),
801 default_params: [("n_neighbors".to_string(), "5".to_string())]
802 .iter()
803 .cloned()
804 .collect(),
805 param_space: [
806 (
807 "n_neighbors".to_string(),
808 vec![
809 "3".to_string(),
810 "5".to_string(),
811 "7".to_string(),
812 "11".to_string(),
813 ],
814 ),
815 (
816 "weights".to_string(),
817 vec!["uniform".to_string(), "distance".to_string()],
818 ),
819 ]
820 .iter()
821 .cloned()
822 .collect(),
823 complexity: 1.0,
824 memory_requirement: 5.0,
825 supports_proba: false,
826 handles_missing: false,
827 handles_categorical: false,
828 supports_incremental: false,
829 },
830 AlgorithmSpec {
832 family: AlgorithmFamily::SVM,
833 name: "SVR".to_string(),
834 default_params: [
835 ("C".to_string(), "1.0".to_string()),
836 ("kernel".to_string(), "rbf".to_string()),
837 ]
838 .iter()
839 .cloned()
840 .collect(),
841 param_space: [
842 (
843 "C".to_string(),
844 vec!["0.1".to_string(), "1.0".to_string(), "10.0".to_string()],
845 ),
846 (
847 "kernel".to_string(),
848 vec!["linear".to_string(), "rbf".to_string(), "poly".to_string()],
849 ),
850 ]
851 .iter()
852 .cloned()
853 .collect(),
854 complexity: 3.0,
855 memory_requirement: 3.0,
856 supports_proba: false,
857 handles_missing: false,
858 handles_categorical: false,
859 supports_incremental: false,
860 },
861 AlgorithmSpec {
863 family: AlgorithmFamily::Dummy,
864 name: "DummyRegressor".to_string(),
865 default_params: [("strategy".to_string(), "mean".to_string())]
866 .iter()
867 .cloned()
868 .collect(),
869 param_space: [(
870 "strategy".to_string(),
871 vec![
872 "mean".to_string(),
873 "median".to_string(),
874 "constant".to_string(),
875 ],
876 )]
877 .iter()
878 .cloned()
879 .collect(),
880 complexity: 0.1,
881 memory_requirement: 0.1,
882 supports_proba: false,
883 handles_missing: true,
884 handles_categorical: true,
885 supports_incremental: true,
886 },
887 ];
888
889 catalog.insert(TaskType::Classification, classification_algorithms);
890 catalog.insert(TaskType::Regression, regression_algorithms);
891 catalog
892 }
893
894 fn get_candidate_algorithms(
896 &self,
897 dataset_chars: &DatasetCharacteristics,
898 ) -> Result<Vec<AlgorithmSpec>> {
899 let algorithms = self
900 .algorithm_catalog
901 .get(&self.config.task_type)
902 .ok_or_else(|| SklearsError::InvalidParameter {
903 name: "task_type".to_string(),
904 reason: format!(
905 "No algorithms available for task type: {:?}",
906 self.config.task_type
907 ),
908 })?;
909
910 let mut candidates = Vec::new();
911
912 for algorithm in algorithms {
913 if let Some(ref allowed) = self.config.allowed_families {
915 if !allowed.contains(&algorithm.family) {
916 continue;
917 }
918 }
919
920 if self.config.excluded_families.contains(&algorithm.family) {
922 continue;
923 }
924
925 if self.is_algorithm_suitable(algorithm, dataset_chars) {
927 candidates.push(algorithm.clone());
928 }
929 }
930
931 candidates.truncate(self.config.max_algorithms);
933
934 Ok(candidates)
935 }
936
937 fn is_algorithm_suitable(
939 &self,
940 algorithm: &AlgorithmSpec,
941 dataset_chars: &DatasetCharacteristics,
942 ) -> bool {
943 if algorithm.family == AlgorithmFamily::Dummy && !self.config.excluded_families.is_empty() {
945 return false;
946 }
947
948 if dataset_chars.n_features > dataset_chars.n_samples {
950 match algorithm.family {
952 AlgorithmFamily::Linear | AlgorithmFamily::NaiveBayes => return true,
953 AlgorithmFamily::NeighborBased | AlgorithmFamily::SVM => return false,
954 _ => {}
955 }
956 }
957
958 if dataset_chars.n_samples < 100 {
960 if algorithm.complexity > 3.0 {
962 return false;
963 }
964 }
965
966 if dataset_chars.n_samples > 10000 {
968 match algorithm.family {
970 AlgorithmFamily::NeighborBased => return false, AlgorithmFamily::SVM => return dataset_chars.n_samples < 50000, _ => {}
973 }
974 }
975
976 if dataset_chars.linearity_score > 0.8 {
978 match algorithm.family {
980 AlgorithmFamily::Linear => return true,
981 AlgorithmFamily::TreeBased | AlgorithmFamily::Ensemble => return false,
982 _ => {}
983 }
984 }
985
986 if dataset_chars.missing_ratio > 0.0 && !algorithm.handles_missing {
988 return false;
989 }
990
991 true
992 }
993
994 fn filter_by_constraints(
996 &self,
997 algorithms: &[AlgorithmSpec],
998 dataset_chars: &DatasetCharacteristics,
999 ) -> Vec<AlgorithmSpec> {
1000 algorithms
1001 .iter()
1002 .filter(|alg| self.satisfies_constraints(alg, dataset_chars))
1003 .cloned()
1004 .collect()
1005 }
1006
1007 fn satisfies_constraints(
1009 &self,
1010 algorithm: &AlgorithmSpec,
1011 dataset_chars: &DatasetCharacteristics,
1012 ) -> bool {
1013 let estimated_training_time = self.estimate_training_time(algorithm, dataset_chars);
1015 let estimated_memory_usage = self.estimate_memory_usage(algorithm, dataset_chars);
1016
1017 if let Some(max_time) = self.config.constraints.max_training_time {
1018 if estimated_training_time > max_time {
1019 return false;
1020 }
1021 }
1022
1023 if let Some(max_memory) = self.config.constraints.max_memory_gb {
1024 if estimated_memory_usage > max_memory {
1025 return false;
1026 }
1027 }
1028
1029 true
1030 }
1031
1032 fn estimate_training_time(
1034 &self,
1035 algorithm: &AlgorithmSpec,
1036 dataset_chars: &DatasetCharacteristics,
1037 ) -> f64 {
1038 let n = dataset_chars.n_samples as f64;
1039 let p = dataset_chars.n_features as f64;
1040
1041 let base_time = match algorithm.family {
1043 AlgorithmFamily::Linear => 0.1,
1044 AlgorithmFamily::TreeBased => {
1045 if algorithm.name.contains("Random") {
1046 2.0
1047 } else {
1048 0.5
1049 }
1050 }
1051 AlgorithmFamily::Ensemble => 3.0,
1052 AlgorithmFamily::NeighborBased => 0.05, AlgorithmFamily::SVM => 1.0,
1054 AlgorithmFamily::NaiveBayes => 0.05,
1055 AlgorithmFamily::NeuralNetwork => 5.0,
1056 AlgorithmFamily::GaussianProcess => 2.0,
1057 AlgorithmFamily::DiscriminantAnalysis => 0.2,
1058 AlgorithmFamily::Dummy => 0.01,
1059 };
1060
1061 base_time * algorithm.complexity * (n / 1000.0) * (p / 10.0).sqrt()
1063 }
1064
1065 fn estimate_memory_usage(
1067 &self,
1068 algorithm: &AlgorithmSpec,
1069 dataset_chars: &DatasetCharacteristics,
1070 ) -> f64 {
1071 let n = dataset_chars.n_samples as f64;
1072 let p = dataset_chars.n_features as f64;
1073
1074 let base_memory_mb = match algorithm.family {
1076 AlgorithmFamily::Linear => 1.0,
1077 AlgorithmFamily::TreeBased => {
1078 if algorithm.name.contains("Random") {
1079 50.0
1080 } else {
1081 10.0
1082 }
1083 }
1084 AlgorithmFamily::Ensemble => 100.0,
1085 AlgorithmFamily::NeighborBased => n * p * 8.0 / 1_000_000.0, AlgorithmFamily::SVM => 20.0,
1087 AlgorithmFamily::NaiveBayes => 1.0,
1088 AlgorithmFamily::NeuralNetwork => 50.0,
1089 AlgorithmFamily::GaussianProcess => 10.0,
1090 AlgorithmFamily::DiscriminantAnalysis => 5.0,
1091 AlgorithmFamily::Dummy => 0.1,
1092 };
1093
1094 (base_memory_mb * algorithm.memory_requirement) / 1000.0 }
1096
1097 fn evaluate_algorithms(
1099 &self,
1100 algorithms: &[AlgorithmSpec],
1101 X: &Array2<f64>,
1102 y: &Array1<f64>,
1103 ) -> Result<Vec<RankedAlgorithm>> {
1104 let mut results = Vec::new();
1105
1106 for algorithm in algorithms {
1107 let start_time = std::time::Instant::now();
1108
1109 let cv_score = self.mock_evaluate_algorithm(algorithm, X, y);
1112 let cv_std = cv_score * 0.05; let training_time = start_time.elapsed().as_secs_f64();
1115 let memory_usage = self.estimate_memory_usage(algorithm, &self.analyze_dataset(X, y));
1116
1117 results.push(RankedAlgorithm {
1118 algorithm: algorithm.clone(),
1119 cv_score,
1120 cv_std,
1121 training_time,
1122 memory_usage,
1123 best_params: algorithm.default_params.clone(),
1124 rank: 0, selection_probability: 0.0, });
1127 }
1128
1129 Ok(results)
1130 }
1131
1132 fn mock_evaluate_algorithm(
1134 &self,
1135 algorithm: &AlgorithmSpec,
1136 X: &Array2<f64>,
1137 y: &Array1<f64>,
1138 ) -> f64 {
1139 let dataset_chars = self.analyze_dataset(X, y);
1141
1142 let base_score = match self.config.task_type {
1143 TaskType::Classification => 0.7, TaskType::Regression => 0.8, };
1146
1147 let mut score: f64 = base_score;
1149
1150 if algorithm.family == AlgorithmFamily::Linear && dataset_chars.linearity_score > 0.7 {
1152 score += 0.1;
1153 }
1154
1155 if matches!(
1157 algorithm.family,
1158 AlgorithmFamily::TreeBased | AlgorithmFamily::Ensemble
1159 ) && dataset_chars.linearity_score < 0.5
1160 {
1161 score += 0.1;
1162 }
1163
1164 if algorithm.family == AlgorithmFamily::Ensemble {
1166 score += 0.05;
1167 }
1168
1169 let mut rng = scirs2_core::random::thread_rng();
1172 score += rng.gen_range(-0.05..0.05);
1173
1174 score.clamp(0.0, 1.0)
1175 }
1176
1177 fn calculate_selection_probability(
1179 &self,
1180 algorithm: &RankedAlgorithm,
1181 all_algorithms: &[RankedAlgorithm],
1182 ) -> f64 {
1183 let max_score = all_algorithms
1184 .iter()
1185 .map(|a| a.cv_score)
1186 .fold(0.0, f64::max);
1187 let min_score = all_algorithms
1188 .iter()
1189 .map(|a| a.cv_score)
1190 .fold(1.0, f64::min);
1191
1192 if max_score == min_score {
1193 return 1.0 / all_algorithms.len() as f64;
1194 }
1195
1196 let normalized_score = (algorithm.cv_score - min_score) / (max_score - min_score);
1198 let exp_score = (normalized_score * 5.0).exp();
1199 let total_exp: f64 = all_algorithms
1200 .iter()
1201 .map(|a| {
1202 let norm = (a.cv_score - min_score) / (max_score - min_score);
1203 (norm * 5.0).exp()
1204 })
1205 .sum();
1206
1207 exp_score / total_exp
1208 }
1209
1210 fn get_baseline_score(&self, _X: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
1212 match self.config.task_type {
1213 TaskType::Classification => {
1214 let classes = self.get_unique_classes(y);
1216 let class_counts = self.calculate_class_distribution(y, &classes);
1217 Ok(class_counts.iter().fold(0.0, |acc, &x| acc.max(x)))
1218 }
1219 TaskType::Regression => {
1220 let mean = y.mean().expect("operation should succeed");
1222 let tss: f64 = y.iter().map(|&yi| (yi - mean).powi(2)).sum();
1223 let rss = tss; Ok(1.0 - rss / tss)
1225 }
1226 }
1227 }
1228
1229 fn generate_explanation(
1231 &self,
1232 best_algorithm: &RankedAlgorithm,
1233 dataset_chars: &DatasetCharacteristics,
1234 ) -> String {
1235 let mut explanation = format!(
1236 "{} ({}) was selected as the best algorithm with a cross-validation score of {:.4}.",
1237 best_algorithm.algorithm.name, best_algorithm.algorithm.family, best_algorithm.cv_score
1238 );
1239
1240 if dataset_chars.n_samples < 1000 {
1242 explanation.push_str(" This algorithm is well-suited for small datasets.");
1243 } else if dataset_chars.n_samples > 10000 {
1244 explanation.push_str(" This algorithm scales well to large datasets.");
1245 }
1246
1247 if dataset_chars.linearity_score > 0.7
1248 && best_algorithm.algorithm.family == AlgorithmFamily::Linear
1249 {
1250 explanation.push_str(
1251 " The linear nature of your data makes linear models particularly effective.",
1252 );
1253 }
1254
1255 if dataset_chars.n_features > dataset_chars.n_samples {
1256 explanation.push_str(" The high-dimensional nature of your data favors this algorithm's regularization capabilities.");
1257 }
1258
1259 if best_algorithm.algorithm.family == AlgorithmFamily::Ensemble {
1260 explanation.push_str(
1261 " Ensemble methods often provide robust performance across diverse datasets.",
1262 );
1263 }
1264
1265 explanation
1266 }
1267
1268 fn calculate_missing_ratio(&self, X: &Array2<f64>) -> f64 {
1270 let total_values = X.len() as f64;
1271 let missing_count = X.iter().filter(|&&x| x.is_nan()).count() as f64;
1272 missing_count / total_values
1273 }
1274
1275 fn calculate_sparsity(&self, X: &Array2<f64>) -> f64 {
1276 let total_values = X.len() as f64;
1277 let zero_count = X.iter().filter(|&&x| x == 0.0).count() as f64;
1278 zero_count / total_values
1279 }
1280
1281 fn calculate_categorical_ratio(&self, X: &Array2<f64>) -> f64 {
1282 let n_features = X.ncols();
1283 if n_features == 0 {
1284 return 0.0;
1285 }
1286
1287 let mut categorical_count = 0;
1288 for col_idx in 0..n_features {
1289 let column = X.column(col_idx);
1290
1291 let valid_values: Vec<f64> = column.iter().filter(|&&x| !x.is_nan()).copied().collect();
1293
1294 if valid_values.is_empty() {
1295 continue;
1296 }
1297
1298 let mut unique_values = valid_values.clone();
1300 unique_values.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
1301 unique_values.dedup();
1302
1303 let n_unique = unique_values.len();
1304 let n_total = valid_values.len();
1305
1306 let unique_ratio = n_unique as f64 / n_total as f64;
1310 let all_integers = valid_values.iter().all(|&x| (x - x.round()).abs() < 1e-10);
1311
1312 if n_unique <= 10 || (unique_ratio < 0.05 && all_integers) {
1313 categorical_count += 1;
1314 }
1315 }
1316
1317 categorical_count as f64 / n_features as f64
1318 }
1319
1320 fn calculate_correlation_condition_number(&self, _X: &Array2<f64>) -> f64 {
1321 let mut rng = scirs2_core::random::thread_rng();
1324 rng.gen_range(1.0..100.0)
1325 }
1326
1327 fn get_unique_classes(&self, y: &Array1<f64>) -> Vec<i32> {
1328 let mut classes: Vec<i32> = y.iter().map(|&x| x as i32).collect();
1329 classes.sort_unstable();
1330 classes.dedup();
1331 classes
1332 }
1333
1334 fn calculate_class_distribution(&self, y: &Array1<f64>, classes: &[i32]) -> Vec<f64> {
1335 let total = y.len() as f64;
1336 classes
1337 .iter()
1338 .map(|&class| {
1339 let count = y.iter().filter(|&&yi| yi as i32 == class).count() as f64;
1340 count / total
1341 })
1342 .collect()
1343 }
1344
1345 fn calculate_target_statistics(&self, y: &Array1<f64>) -> TargetStatistics {
1346 let mean = y.mean().expect("operation should succeed");
1347 let std = y.std(0.0);
1348
1349 TargetStatistics {
1351 mean,
1352 std,
1353 skewness: 0.0, kurtosis: 0.0, n_outliers: 0, }
1357 }
1358
1359 fn estimate_linearity_score(&self, _X: &Array2<f64>, _y: &Array1<f64>) -> f64 {
1360 let mut rng = scirs2_core::random::thread_rng();
1363 rng.gen_range(0.0..1.0)
1364 }
1365
1366 fn estimate_noise_level(&self, _X: &Array2<f64>, _y: &Array1<f64>) -> f64 {
1367 let mut rng = scirs2_core::random::thread_rng();
1370 rng.gen_range(0.0..0.5)
1371 }
1372
1373 fn estimate_effective_dimensionality(&self, X: &Array2<f64>) -> Option<usize> {
1374 Some((X.ncols() as f64 * 0.8) as usize)
1376 }
1377}
1378
1379pub fn select_best_algorithm(
1381 X: &Array2<f64>,
1382 y: &Array1<f64>,
1383 task_type: TaskType,
1384) -> Result<AlgorithmSelectionResult> {
1385 let config = AutoMLConfig {
1386 task_type,
1387 ..Default::default()
1388 };
1389
1390 let selector = AutoMLAlgorithmSelector::new(config);
1391 selector.select_algorithms(X, y)
1392}
1393
// The module-level `allow(non_snake_case)` covers every item inside `tests`,
// so the redundant per-function repeats of the attribute were removed.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    /// 100x4 feature matrix (values 0..400) with a 3-class cyclic label vector.
    fn create_test_classification_data() -> (Array2<f64>, Array1<f64>) {
        let X = Array2::from_shape_vec((100, 4), (0..400).map(|i| i as f64).collect())
            .expect("operation should succeed");
        let y = Array1::from_vec((0..100).map(|i| (i % 3) as f64).collect());
        (X, y)
    }

    /// 100x4 feature matrix with a noisy increasing regression target.
    fn create_test_regression_data() -> (Array2<f64>, Array1<f64>) {
        let X = Array2::from_shape_vec((100, 4), (0..400).map(|i| i as f64).collect())
            .expect("operation should succeed");
        use scirs2_core::essentials::Uniform;
        use scirs2_core::random::{thread_rng, Distribution};
        let mut rng = thread_rng();
        let dist = Uniform::new(0.0, 1.0).expect("operation should succeed");
        let y = Array1::from_vec((0..100).map(|i| i as f64 + dist.sample(&mut rng)).collect());
        (X, y)
    }

    #[test]
    fn test_algorithm_selection_classification() {
        let (X, y) = create_test_classification_data();
        let result = select_best_algorithm(&X, &y, TaskType::Classification);
        assert!(result.is_ok());

        let result = result.expect("operation should succeed");
        assert!(!result.selected_algorithms.is_empty());
        assert!(result.best_algorithm.cv_score > 0.0);
    }

    #[test]
    fn test_algorithm_selection_regression() {
        let (X, y) = create_test_regression_data();
        let result = select_best_algorithm(&X, &y, TaskType::Regression);
        assert!(result.is_ok());

        let result = result.expect("operation should succeed");
        assert!(!result.selected_algorithms.is_empty());
        assert!(result.best_algorithm.cv_score > 0.0);
    }

    #[test]
    fn test_dataset_characteristics_analysis() {
        let (X, y) = create_test_classification_data();
        let config = AutoMLConfig::default();
        let selector = AutoMLAlgorithmSelector::new(config);

        let chars = selector.analyze_dataset(&X, &y);
        assert_eq!(chars.n_samples, 100);
        assert_eq!(chars.n_features, 4);
        assert_eq!(chars.n_classes, Some(3));
    }

    #[test]
    fn test_custom_config() {
        let (X, y) = create_test_classification_data();

        // Restrict the search to two families and at most three algorithms.
        let config = AutoMLConfig {
            task_type: TaskType::Classification,
            max_algorithms: 3,
            allowed_families: Some(vec![AlgorithmFamily::Linear, AlgorithmFamily::TreeBased]),
            ..Default::default()
        };

        let selector = AutoMLAlgorithmSelector::new(config);
        let result = selector.select_algorithms(&X, &y);
        assert!(result.is_ok());

        let result = result.expect("operation should succeed");
        assert!(result.n_algorithms_evaluated <= 3);

        // Every selected algorithm must come from an allowed family.
        for alg in &result.selected_algorithms {
            assert!(matches!(
                alg.algorithm.family,
                AlgorithmFamily::Linear | AlgorithmFamily::TreeBased
            ));
        }
    }

    #[test]
    fn test_computational_constraints() {
        let (X, y) = create_test_classification_data();

        // Tight time/memory budget should still yield at least one candidate.
        let config = AutoMLConfig {
            task_type: TaskType::Classification,
            constraints: ComputationalConstraints {
                max_training_time: Some(1.0),
                max_memory_gb: Some(0.1),
                ..Default::default()
            },
            ..Default::default()
        };

        let selector = AutoMLAlgorithmSelector::new(config);
        let result = selector.select_algorithms(&X, &y);
        assert!(result.is_ok());

        let result = result.expect("operation should succeed");
        assert!(!result.selected_algorithms.is_empty());
    }
}