// sklears_model_selection/automl_algorithm_selection.rs

1//! Automated Algorithm Selection for AutoML
2//!
3//! This module provides intelligent algorithm selection based on dataset characteristics,
4//! computational constraints, and performance requirements. It automatically selects
5//! and configures the best algorithms for classification and regression tasks.
6
7use crate::scoring::TaskType;
8use scirs2_core::ndarray::{Array1, Array2};
9use sklears_core::error::{Result, SklearsError};
10use std::collections::HashMap;
11use std::fmt;
12// use serde::{Deserialize, Serialize};
13
/// Algorithm family categories for classification and regression
///
/// Families group the concrete algorithms in the catalog (see `AlgorithmSpec`)
/// so selection can allow or exclude whole categories at once via
/// `AutoMLConfig::allowed_families` / `excluded_families`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AlgorithmFamily {
    /// Linear models (LogisticRegression, LinearRegression, etc.)
    Linear,
    /// Tree-based models (DecisionTree, RandomForest, etc.)
    TreeBased,
    /// Ensemble methods (AdaBoost, Stacking, etc.)
    Ensemble,
    /// Neighbor-based models (KNN, etc.)
    NeighborBased,
    /// Support Vector Machines
    SVM,
    /// Naive Bayes classifiers
    NaiveBayes,
    /// Neural networks
    NeuralNetwork,
    /// Gaussian processes
    GaussianProcess,
    /// Discriminant analysis
    DiscriminantAnalysis,
    /// Dummy/baseline models
    Dummy,
}
38
39impl fmt::Display for AlgorithmFamily {
40    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41        match self {
42            AlgorithmFamily::Linear => write!(f, "Linear"),
43            AlgorithmFamily::TreeBased => write!(f, "Tree-based"),
44            AlgorithmFamily::Ensemble => write!(f, "Ensemble"),
45            AlgorithmFamily::NeighborBased => write!(f, "Neighbor-based"),
46            AlgorithmFamily::SVM => write!(f, "Support Vector Machine"),
47            AlgorithmFamily::NaiveBayes => write!(f, "Naive Bayes"),
48            AlgorithmFamily::NeuralNetwork => write!(f, "Neural Network"),
49            AlgorithmFamily::GaussianProcess => write!(f, "Gaussian Process"),
50            AlgorithmFamily::DiscriminantAnalysis => write!(f, "Discriminant Analysis"),
51            AlgorithmFamily::Dummy => write!(f, "Dummy/Baseline"),
52        }
53    }
54}
55
/// Specific algorithm within a family
///
/// Hyperparameters are stringly-typed (`name -> value`) so one catalog type
/// can describe heterogeneous estimators; the evaluation layer is responsible
/// for parsing them.
#[derive(Debug, Clone, PartialEq)]
pub struct AlgorithmSpec {
    /// Algorithm family
    pub family: AlgorithmFamily,
    /// Specific algorithm name
    pub name: String,
    /// Default hyperparameters
    pub default_params: HashMap<String, String>,
    /// Hyperparameter search space (candidate values per parameter name)
    pub param_space: HashMap<String, Vec<String>>,
    /// Computational complexity (relative scale; used for constraint filtering)
    pub complexity: f64,
    /// Memory requirements (relative scale; used for constraint filtering)
    pub memory_requirement: f64,
    /// Supports probability prediction
    pub supports_proba: bool,
    /// Handles missing values
    pub handles_missing: bool,
    /// Handles categorical features
    pub handles_categorical: bool,
    /// Supports incremental learning
    pub supports_incremental: bool,
}
80
/// Dataset characteristics for algorithm selection
///
/// Produced by `AutoMLAlgorithmSelector::analyze_dataset`; the task-specific
/// fields are populated according to the configured task type
/// (classification fills `n_classes`/`class_distribution`, regression fills
/// `target_stats`).
#[derive(Debug, Clone)]
pub struct DatasetCharacteristics {
    /// Number of samples
    pub n_samples: usize,
    /// Number of features
    pub n_features: usize,
    /// Number of classes (for classification; `None` for regression)
    pub n_classes: Option<usize>,
    /// Class distribution (for classification; `None` for regression)
    pub class_distribution: Option<Vec<f64>>,
    /// Target distribution (for regression; `None` for classification)
    pub target_stats: Option<TargetStatistics>,
    /// Missing value ratio
    pub missing_ratio: f64,
    /// Categorical feature ratio
    pub categorical_ratio: f64,
    /// Feature correlation matrix condition number
    pub correlation_condition_number: f64,
    /// Dataset sparsity
    pub sparsity: f64,
    /// Effective dimensionality (based on PCA)
    pub effective_dimensionality: Option<usize>,
    /// Estimated noise level
    pub noise_level: f64,
    /// Linearity score (0-1, higher means more linear)
    pub linearity_score: f64,
}
109
/// Target statistics for regression tasks
///
/// Summarizes the distribution of `y`; stored in
/// `DatasetCharacteristics::target_stats` for regression problems.
#[derive(Debug, Clone)]
pub struct TargetStatistics {
    /// Mean of target values
    pub mean: f64,
    /// Standard deviation of target values
    pub std: f64,
    /// Skewness of target distribution
    pub skewness: f64,
    /// Kurtosis of target distribution
    pub kurtosis: f64,
    /// Number of outliers
    pub n_outliers: usize,
}
124
/// Computational constraints for algorithm selection
///
/// `None` on any limit means "unconstrained"; the derived `Default` therefore
/// yields a fully unconstrained budget with `has_gpu == false`.
#[derive(Debug, Clone, Default)]
pub struct ComputationalConstraints {
    /// Maximum training time in seconds
    pub max_training_time: Option<f64>,
    /// Maximum memory usage in GB
    pub max_memory_gb: Option<f64>,
    /// Maximum model size in MB
    pub max_model_size_mb: Option<f64>,
    /// Maximum inference time per sample in milliseconds
    pub max_inference_time_ms: Option<f64>,
    /// Available CPU cores
    pub n_cores: Option<usize>,
    /// GPU availability
    pub has_gpu: bool,
}
141
/// Configuration for automated algorithm selection
///
/// See the `Default` impl for the baseline settings (classification, 5-fold
/// CV, accuracy scoring, 10 algorithms, 5-minute hyperopt budget each).
#[derive(Debug, Clone)]
pub struct AutoMLConfig {
    /// Task type (classification or regression)
    pub task_type: TaskType,
    /// Computational constraints
    pub constraints: ComputationalConstraints,
    /// Algorithm families to consider (None means all)
    pub allowed_families: Option<Vec<AlgorithmFamily>>,
    /// Algorithm families to exclude
    pub excluded_families: Vec<AlgorithmFamily>,
    /// Maximum number of algorithms to evaluate
    pub max_algorithms: usize,
    /// Cross-validation strategy (number of folds)
    pub cv_folds: usize,
    /// Scoring metric
    pub scoring_metric: String,
    /// Time budget for hyperparameter optimization per algorithm (seconds)
    pub hyperopt_time_budget: f64,
    /// Random seed for reproducibility
    pub random_seed: Option<u64>,
    /// Whether to use ensemble methods
    pub enable_ensembles: bool,
    /// Whether to perform feature engineering
    pub enable_feature_engineering: bool,
}
168
169impl Default for AutoMLConfig {
170    fn default() -> Self {
171        Self {
172            task_type: TaskType::Classification,
173            constraints: ComputationalConstraints::default(),
174            allowed_families: None,
175            excluded_families: Vec::new(),
176            max_algorithms: 10,
177            cv_folds: 5,
178            scoring_metric: "accuracy".to_string(),
179            hyperopt_time_budget: 300.0, // 5 minutes per algorithm
180            random_seed: None,
181            enable_ensembles: true,
182            enable_feature_engineering: true,
183        }
184    }
185}
186
/// Result of algorithm selection process
///
/// Returned by `AutoMLAlgorithmSelector::select_algorithms`; `selected_algorithms`
/// is ordered best-first and `best_algorithm` duplicates its head entry.
#[derive(Debug, Clone)]
pub struct AlgorithmSelectionResult {
    /// Selected algorithm specifications, ranked best-first
    pub selected_algorithms: Vec<RankedAlgorithm>,
    /// Dataset characteristics used for selection
    pub dataset_characteristics: DatasetCharacteristics,
    /// Total evaluation time (seconds)
    pub total_evaluation_time: f64,
    /// Number of algorithms evaluated
    pub n_algorithms_evaluated: usize,
    /// Best performing algorithm
    pub best_algorithm: RankedAlgorithm,
    /// Performance improvement over baseline (best CV score minus baseline score)
    pub improvement_over_baseline: f64,
    /// Recommendation explanation
    pub explanation: String,
}
205
/// Algorithm with performance ranking
///
/// One evaluated candidate: the spec plus its cross-validation results,
/// resource usage, tuned hyperparameters, and position in the final ranking.
#[derive(Debug, Clone)]
pub struct RankedAlgorithm {
    /// Algorithm specification
    pub algorithm: AlgorithmSpec,
    /// Cross-validation score
    pub cv_score: f64,
    /// Standard deviation of CV scores
    pub cv_std: f64,
    /// Training time in seconds
    pub training_time: f64,
    /// Memory usage in MB
    pub memory_usage: f64,
    /// Optimized hyperparameters
    pub best_params: HashMap<String, String>,
    /// Rank among all algorithms (1 = best)
    pub rank: usize,
    /// Selection probability based on performance
    pub selection_probability: f64,
}
226
227impl fmt::Display for AlgorithmSelectionResult {
228    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
229        writeln!(f, "AutoML Algorithm Selection Results")?;
230        writeln!(f, "==================================")?;
231        writeln!(
232            f,
233            "Dataset: {} samples, {} features",
234            self.dataset_characteristics.n_samples, self.dataset_characteristics.n_features
235        )?;
236        writeln!(f, "Algorithms evaluated: {}", self.n_algorithms_evaluated)?;
237        writeln!(
238            f,
239            "Total evaluation time: {:.2}s",
240            self.total_evaluation_time
241        )?;
242        writeln!(f)?;
243        writeln!(
244            f,
245            "Best Algorithm: {} ({})",
246            self.best_algorithm.algorithm.name, self.best_algorithm.algorithm.family
247        )?;
248        writeln!(
249            f,
250            "Score: {:.4} ± {:.4}",
251            self.best_algorithm.cv_score, self.best_algorithm.cv_std
252        )?;
253        writeln!(
254            f,
255            "Training time: {:.2}s",
256            self.best_algorithm.training_time
257        )?;
258        writeln!(
259            f,
260            "Improvement over baseline: {:.4}",
261            self.improvement_over_baseline
262        )?;
263        writeln!(f)?;
264        writeln!(f, "Explanation: {}", self.explanation)?;
265        writeln!(f)?;
266        writeln!(
267            f,
268            "Top {} Algorithms:",
269            self.selected_algorithms.len().min(5)
270        )?;
271        for (i, alg) in self.selected_algorithms.iter().take(5).enumerate() {
272            writeln!(
273                f,
274                "{}. {} ({}) - Score: {:.4} ± {:.4}",
275                i + 1,
276                alg.algorithm.name,
277                alg.algorithm.family,
278                alg.cv_score,
279                alg.cv_std
280            )?;
281        }
282        Ok(())
283    }
284}
285
/// Automated algorithm selector
///
/// Holds the selection configuration and a catalog of candidate algorithm
/// specifications keyed by task type.
pub struct AutoMLAlgorithmSelector {
    /// Selection configuration (task type, constraints, budgets).
    config: AutoMLConfig,
    /// Candidate algorithms available for each task type.
    algorithm_catalog: HashMap<TaskType, Vec<AlgorithmSpec>>,
}
291
292impl Default for AutoMLAlgorithmSelector {
293    fn default() -> Self {
294        Self::new(AutoMLConfig::default())
295    }
296}
297
298impl AutoMLAlgorithmSelector {
299    /// Create a new AutoML algorithm selector
300    pub fn new(config: AutoMLConfig) -> Self {
301        let algorithm_catalog = Self::build_algorithm_catalog();
302        Self {
303            config,
304            algorithm_catalog,
305        }
306    }
307
308    /// Analyze dataset characteristics
309    pub fn analyze_dataset(&self, X: &Array2<f64>, y: &Array1<f64>) -> DatasetCharacteristics {
310        let n_samples = X.nrows();
311        let n_features = X.ncols();
312
313        // Basic statistics
314        let missing_ratio = self.calculate_missing_ratio(X);
315        let sparsity = self.calculate_sparsity(X);
316        let correlation_condition_number = self.calculate_correlation_condition_number(X);
317
318        // Task-specific characteristics
319        let (n_classes, class_distribution, target_stats) = match self.config.task_type {
320            TaskType::Classification => {
321                let classes = self.get_unique_classes(y);
322                let class_dist = self.calculate_class_distribution(y, &classes);
323                (Some(classes.len()), Some(class_dist), None)
324            }
325            TaskType::Regression => {
326                let stats = self.calculate_target_statistics(y);
327                (None, None, Some(stats))
328            }
329        };
330
331        // Advanced characteristics
332        let linearity_score = self.estimate_linearity_score(X, y);
333        let noise_level = self.estimate_noise_level(X, y);
334        let effective_dimensionality = self.estimate_effective_dimensionality(X);
335        let categorical_ratio = self.calculate_categorical_ratio(X);
336
337        DatasetCharacteristics {
338            n_samples,
339            n_features,
340            n_classes,
341            class_distribution,
342            target_stats,
343            missing_ratio,
344            categorical_ratio,
345            correlation_condition_number,
346            sparsity,
347            effective_dimensionality,
348            noise_level,
349            linearity_score,
350        }
351    }
352
353    /// Select best algorithms for the dataset
354    pub fn select_algorithms(
355        &self,
356        X: &Array2<f64>,
357        y: &Array1<f64>,
358    ) -> Result<AlgorithmSelectionResult> {
359        let start_time = std::time::Instant::now();
360
361        // Analyze dataset characteristics
362        let dataset_chars = self.analyze_dataset(X, y);
363
364        // Get candidate algorithms
365        let candidate_algorithms = self.get_candidate_algorithms(&dataset_chars)?;
366
367        // Filter by constraints
368        let filtered_algorithms = self.filter_by_constraints(&candidate_algorithms, &dataset_chars);
369
370        // Evaluate algorithms
371        let mut evaluated_algorithms = self.evaluate_algorithms(&filtered_algorithms, X, y)?;
372
373        // Rank algorithms by performance
374        evaluated_algorithms.sort_by(|a, b| {
375            b.cv_score
376                .partial_cmp(&a.cv_score)
377                .expect("operation should succeed")
378        });
379
380        // Assign ranks and selection probabilities
381        let algorithms_copy = evaluated_algorithms.clone();
382        for (i, alg) in evaluated_algorithms.iter_mut().enumerate() {
383            alg.rank = i + 1;
384            alg.selection_probability = self.calculate_selection_probability(alg, &algorithms_copy);
385        }
386
387        let best_algorithm = evaluated_algorithms[0].clone();
388        let baseline_score = self.get_baseline_score(X, y)?;
389        let improvement = best_algorithm.cv_score - baseline_score;
390
391        let explanation = self.generate_explanation(&best_algorithm, &dataset_chars);
392
393        let total_time = start_time.elapsed().as_secs_f64();
394
395        Ok(AlgorithmSelectionResult {
396            selected_algorithms: evaluated_algorithms,
397            dataset_characteristics: dataset_chars,
398            total_evaluation_time: total_time,
399            n_algorithms_evaluated: filtered_algorithms.len(),
400            best_algorithm,
401            improvement_over_baseline: improvement,
402            explanation,
403        })
404    }
405
406    /// Build catalog of available algorithms
407    fn build_algorithm_catalog() -> HashMap<TaskType, Vec<AlgorithmSpec>> {
408        let mut catalog = HashMap::new();
409
410        // Classification algorithms
411        let classification_algorithms = vec![
412            // Linear classifiers
413            AlgorithmSpec {
414                family: AlgorithmFamily::Linear,
415                name: "LogisticRegression".to_string(),
416                default_params: [("C".to_string(), "1.0".to_string())]
417                    .iter()
418                    .cloned()
419                    .collect(),
420                param_space: [(
421                    "C".to_string(),
422                    vec![
423                        "0.001".to_string(),
424                        "0.01".to_string(),
425                        "0.1".to_string(),
426                        "1.0".to_string(),
427                        "10.0".to_string(),
428                        "100.0".to_string(),
429                    ],
430                )]
431                .iter()
432                .cloned()
433                .collect(),
434                complexity: 1.0,
435                memory_requirement: 1.0,
436                supports_proba: true,
437                handles_missing: false,
438                handles_categorical: false,
439                supports_incremental: false,
440            },
441            AlgorithmSpec {
442                family: AlgorithmFamily::Linear,
443                name: "RidgeClassifier".to_string(),
444                default_params: [("alpha".to_string(), "1.0".to_string())]
445                    .iter()
446                    .cloned()
447                    .collect(),
448                param_space: [(
449                    "alpha".to_string(),
450                    vec![
451                        "0.1".to_string(),
452                        "1.0".to_string(),
453                        "10.0".to_string(),
454                        "100.0".to_string(),
455                    ],
456                )]
457                .iter()
458                .cloned()
459                .collect(),
460                complexity: 1.0,
461                memory_requirement: 1.0,
462                supports_proba: false,
463                handles_missing: false,
464                handles_categorical: false,
465                supports_incremental: false,
466            },
467            // Tree-based classifiers
468            AlgorithmSpec {
469                family: AlgorithmFamily::TreeBased,
470                name: "DecisionTreeClassifier".to_string(),
471                default_params: [("max_depth".to_string(), "None".to_string())]
472                    .iter()
473                    .cloned()
474                    .collect(),
475                param_space: [
476                    (
477                        "max_depth".to_string(),
478                        vec![
479                            "3".to_string(),
480                            "5".to_string(),
481                            "10".to_string(),
482                            "None".to_string(),
483                        ],
484                    ),
485                    (
486                        "min_samples_split".to_string(),
487                        vec!["2".to_string(), "5".to_string(), "10".to_string()],
488                    ),
489                ]
490                .iter()
491                .cloned()
492                .collect(),
493                complexity: 2.0,
494                memory_requirement: 2.0,
495                supports_proba: true,
496                handles_missing: false,
497                handles_categorical: true,
498                supports_incremental: false,
499            },
500            AlgorithmSpec {
501                family: AlgorithmFamily::TreeBased,
502                name: "RandomForestClassifier".to_string(),
503                default_params: [("n_estimators".to_string(), "100".to_string())]
504                    .iter()
505                    .cloned()
506                    .collect(),
507                param_space: [
508                    (
509                        "n_estimators".to_string(),
510                        vec!["50".to_string(), "100".to_string(), "200".to_string()],
511                    ),
512                    (
513                        "max_depth".to_string(),
514                        vec![
515                            "3".to_string(),
516                            "5".to_string(),
517                            "10".to_string(),
518                            "None".to_string(),
519                        ],
520                    ),
521                ]
522                .iter()
523                .cloned()
524                .collect(),
525                complexity: 4.0,
526                memory_requirement: 4.0,
527                supports_proba: true,
528                handles_missing: false,
529                handles_categorical: true,
530                supports_incremental: false,
531            },
532            // Ensemble methods
533            AlgorithmSpec {
534                family: AlgorithmFamily::Ensemble,
535                name: "AdaBoostClassifier".to_string(),
536                default_params: [("n_estimators".to_string(), "50".to_string())]
537                    .iter()
538                    .cloned()
539                    .collect(),
540                param_space: [
541                    (
542                        "n_estimators".to_string(),
543                        vec!["25".to_string(), "50".to_string(), "100".to_string()],
544                    ),
545                    (
546                        "learning_rate".to_string(),
547                        vec!["0.1".to_string(), "0.5".to_string(), "1.0".to_string()],
548                    ),
549                ]
550                .iter()
551                .cloned()
552                .collect(),
553                complexity: 3.0,
554                memory_requirement: 3.0,
555                supports_proba: true,
556                handles_missing: false,
557                handles_categorical: true,
558                supports_incremental: false,
559            },
560            // K-Nearest Neighbors
561            AlgorithmSpec {
562                family: AlgorithmFamily::NeighborBased,
563                name: "KNeighborsClassifier".to_string(),
564                default_params: [("n_neighbors".to_string(), "5".to_string())]
565                    .iter()
566                    .cloned()
567                    .collect(),
568                param_space: [
569                    (
570                        "n_neighbors".to_string(),
571                        vec![
572                            "3".to_string(),
573                            "5".to_string(),
574                            "7".to_string(),
575                            "11".to_string(),
576                        ],
577                    ),
578                    (
579                        "weights".to_string(),
580                        vec!["uniform".to_string(), "distance".to_string()],
581                    ),
582                ]
583                .iter()
584                .cloned()
585                .collect(),
586                complexity: 1.0,
587                memory_requirement: 5.0,
588                supports_proba: true,
589                handles_missing: false,
590                handles_categorical: false,
591                supports_incremental: false,
592            },
593            // Naive Bayes
594            AlgorithmSpec {
595                family: AlgorithmFamily::NaiveBayes,
596                name: "GaussianNB".to_string(),
597                default_params: HashMap::new(),
598                param_space: HashMap::new(),
599                complexity: 1.0,
600                memory_requirement: 1.0,
601                supports_proba: true,
602                handles_missing: false,
603                handles_categorical: false,
604                supports_incremental: true,
605            },
606            // Support Vector Machine
607            AlgorithmSpec {
608                family: AlgorithmFamily::SVM,
609                name: "SVC".to_string(),
610                default_params: [
611                    ("C".to_string(), "1.0".to_string()),
612                    ("kernel".to_string(), "rbf".to_string()),
613                ]
614                .iter()
615                .cloned()
616                .collect(),
617                param_space: [
618                    (
619                        "C".to_string(),
620                        vec!["0.1".to_string(), "1.0".to_string(), "10.0".to_string()],
621                    ),
622                    (
623                        "kernel".to_string(),
624                        vec!["linear".to_string(), "rbf".to_string(), "poly".to_string()],
625                    ),
626                ]
627                .iter()
628                .cloned()
629                .collect(),
630                complexity: 3.0,
631                memory_requirement: 3.0,
632                supports_proba: false,
633                handles_missing: false,
634                handles_categorical: false,
635                supports_incremental: false,
636            },
637            // Dummy classifier for baseline
638            AlgorithmSpec {
639                family: AlgorithmFamily::Dummy,
640                name: "DummyClassifier".to_string(),
641                default_params: [("strategy".to_string(), "stratified".to_string())]
642                    .iter()
643                    .cloned()
644                    .collect(),
645                param_space: [(
646                    "strategy".to_string(),
647                    vec![
648                        "stratified".to_string(),
649                        "most_frequent".to_string(),
650                        "uniform".to_string(),
651                    ],
652                )]
653                .iter()
654                .cloned()
655                .collect(),
656                complexity: 0.1,
657                memory_requirement: 0.1,
658                supports_proba: true,
659                handles_missing: true,
660                handles_categorical: true,
661                supports_incremental: true,
662            },
663        ];
664
665        // Regression algorithms
666        let regression_algorithms = vec![
667            // Linear regressors
668            AlgorithmSpec {
669                family: AlgorithmFamily::Linear,
670                name: "LinearRegression".to_string(),
671                default_params: HashMap::new(),
672                param_space: HashMap::new(),
673                complexity: 1.0,
674                memory_requirement: 1.0,
675                supports_proba: false,
676                handles_missing: false,
677                handles_categorical: false,
678                supports_incremental: false,
679            },
680            AlgorithmSpec {
681                family: AlgorithmFamily::Linear,
682                name: "Ridge".to_string(),
683                default_params: [("alpha".to_string(), "1.0".to_string())]
684                    .iter()
685                    .cloned()
686                    .collect(),
687                param_space: [(
688                    "alpha".to_string(),
689                    vec![
690                        "0.1".to_string(),
691                        "1.0".to_string(),
692                        "10.0".to_string(),
693                        "100.0".to_string(),
694                    ],
695                )]
696                .iter()
697                .cloned()
698                .collect(),
699                complexity: 1.0,
700                memory_requirement: 1.0,
701                supports_proba: false,
702                handles_missing: false,
703                handles_categorical: false,
704                supports_incremental: false,
705            },
706            AlgorithmSpec {
707                family: AlgorithmFamily::Linear,
708                name: "Lasso".to_string(),
709                default_params: [("alpha".to_string(), "1.0".to_string())]
710                    .iter()
711                    .cloned()
712                    .collect(),
713                param_space: [(
714                    "alpha".to_string(),
715                    vec![
716                        "0.001".to_string(),
717                        "0.01".to_string(),
718                        "0.1".to_string(),
719                        "1.0".to_string(),
720                    ],
721                )]
722                .iter()
723                .cloned()
724                .collect(),
725                complexity: 1.5,
726                memory_requirement: 1.0,
727                supports_proba: false,
728                handles_missing: false,
729                handles_categorical: false,
730                supports_incremental: false,
731            },
732            // Tree-based regressors
733            AlgorithmSpec {
734                family: AlgorithmFamily::TreeBased,
735                name: "DecisionTreeRegressor".to_string(),
736                default_params: [("max_depth".to_string(), "None".to_string())]
737                    .iter()
738                    .cloned()
739                    .collect(),
740                param_space: [
741                    (
742                        "max_depth".to_string(),
743                        vec![
744                            "3".to_string(),
745                            "5".to_string(),
746                            "10".to_string(),
747                            "None".to_string(),
748                        ],
749                    ),
750                    (
751                        "min_samples_split".to_string(),
752                        vec!["2".to_string(), "5".to_string(), "10".to_string()],
753                    ),
754                ]
755                .iter()
756                .cloned()
757                .collect(),
758                complexity: 2.0,
759                memory_requirement: 2.0,
760                supports_proba: false,
761                handles_missing: false,
762                handles_categorical: true,
763                supports_incremental: false,
764            },
765            AlgorithmSpec {
766                family: AlgorithmFamily::TreeBased,
767                name: "RandomForestRegressor".to_string(),
768                default_params: [("n_estimators".to_string(), "100".to_string())]
769                    .iter()
770                    .cloned()
771                    .collect(),
772                param_space: [
773                    (
774                        "n_estimators".to_string(),
775                        vec!["50".to_string(), "100".to_string(), "200".to_string()],
776                    ),
777                    (
778                        "max_depth".to_string(),
779                        vec![
780                            "3".to_string(),
781                            "5".to_string(),
782                            "10".to_string(),
783                            "None".to_string(),
784                        ],
785                    ),
786                ]
787                .iter()
788                .cloned()
789                .collect(),
790                complexity: 4.0,
791                memory_requirement: 4.0,
792                supports_proba: false,
793                handles_missing: false,
794                handles_categorical: true,
795                supports_incremental: false,
796            },
797            // K-Nearest Neighbors
798            AlgorithmSpec {
799                family: AlgorithmFamily::NeighborBased,
800                name: "KNeighborsRegressor".to_string(),
801                default_params: [("n_neighbors".to_string(), "5".to_string())]
802                    .iter()
803                    .cloned()
804                    .collect(),
805                param_space: [
806                    (
807                        "n_neighbors".to_string(),
808                        vec![
809                            "3".to_string(),
810                            "5".to_string(),
811                            "7".to_string(),
812                            "11".to_string(),
813                        ],
814                    ),
815                    (
816                        "weights".to_string(),
817                        vec!["uniform".to_string(), "distance".to_string()],
818                    ),
819                ]
820                .iter()
821                .cloned()
822                .collect(),
823                complexity: 1.0,
824                memory_requirement: 5.0,
825                supports_proba: false,
826                handles_missing: false,
827                handles_categorical: false,
828                supports_incremental: false,
829            },
830            // Support Vector Machine
831            AlgorithmSpec {
832                family: AlgorithmFamily::SVM,
833                name: "SVR".to_string(),
834                default_params: [
835                    ("C".to_string(), "1.0".to_string()),
836                    ("kernel".to_string(), "rbf".to_string()),
837                ]
838                .iter()
839                .cloned()
840                .collect(),
841                param_space: [
842                    (
843                        "C".to_string(),
844                        vec!["0.1".to_string(), "1.0".to_string(), "10.0".to_string()],
845                    ),
846                    (
847                        "kernel".to_string(),
848                        vec!["linear".to_string(), "rbf".to_string(), "poly".to_string()],
849                    ),
850                ]
851                .iter()
852                .cloned()
853                .collect(),
854                complexity: 3.0,
855                memory_requirement: 3.0,
856                supports_proba: false,
857                handles_missing: false,
858                handles_categorical: false,
859                supports_incremental: false,
860            },
861            // Dummy regressor for baseline
862            AlgorithmSpec {
863                family: AlgorithmFamily::Dummy,
864                name: "DummyRegressor".to_string(),
865                default_params: [("strategy".to_string(), "mean".to_string())]
866                    .iter()
867                    .cloned()
868                    .collect(),
869                param_space: [(
870                    "strategy".to_string(),
871                    vec![
872                        "mean".to_string(),
873                        "median".to_string(),
874                        "constant".to_string(),
875                    ],
876                )]
877                .iter()
878                .cloned()
879                .collect(),
880                complexity: 0.1,
881                memory_requirement: 0.1,
882                supports_proba: false,
883                handles_missing: true,
884                handles_categorical: true,
885                supports_incremental: true,
886            },
887        ];
888
889        catalog.insert(TaskType::Classification, classification_algorithms);
890        catalog.insert(TaskType::Regression, regression_algorithms);
891        catalog
892    }
893
894    /// Get candidate algorithms for the given dataset characteristics
895    fn get_candidate_algorithms(
896        &self,
897        dataset_chars: &DatasetCharacteristics,
898    ) -> Result<Vec<AlgorithmSpec>> {
899        let algorithms = self
900            .algorithm_catalog
901            .get(&self.config.task_type)
902            .ok_or_else(|| SklearsError::InvalidParameter {
903                name: "task_type".to_string(),
904                reason: format!(
905                    "No algorithms available for task type: {:?}",
906                    self.config.task_type
907                ),
908            })?;
909
910        let mut candidates = Vec::new();
911
912        for algorithm in algorithms {
913            // Check if family is allowed
914            if let Some(ref allowed) = self.config.allowed_families {
915                if !allowed.contains(&algorithm.family) {
916                    continue;
917                }
918            }
919
920            // Check if family is excluded
921            if self.config.excluded_families.contains(&algorithm.family) {
922                continue;
923            }
924
925            // Apply heuristic filters based on dataset characteristics
926            if self.is_algorithm_suitable(algorithm, dataset_chars) {
927                candidates.push(algorithm.clone());
928            }
929        }
930
931        // Limit to max_algorithms
932        candidates.truncate(self.config.max_algorithms);
933
934        Ok(candidates)
935    }
936
937    /// Check if algorithm is suitable for the dataset
938    fn is_algorithm_suitable(
939        &self,
940        algorithm: &AlgorithmSpec,
941        dataset_chars: &DatasetCharacteristics,
942    ) -> bool {
943        // Skip dummy algorithms unless specifically requested
944        if algorithm.family == AlgorithmFamily::Dummy && !self.config.excluded_families.is_empty() {
945            return false;
946        }
947
948        // High-dimensional data heuristics
949        if dataset_chars.n_features > dataset_chars.n_samples {
950            // Prefer linear models for high-dimensional data
951            match algorithm.family {
952                AlgorithmFamily::Linear | AlgorithmFamily::NaiveBayes => return true,
953                AlgorithmFamily::NeighborBased | AlgorithmFamily::SVM => return false,
954                _ => {}
955            }
956        }
957
958        // Small dataset heuristics
959        if dataset_chars.n_samples < 100 {
960            // Avoid overly complex models for small datasets
961            if algorithm.complexity > 3.0 {
962                return false;
963            }
964        }
965
966        // Large dataset heuristics
967        if dataset_chars.n_samples > 10000 {
968            // Prefer scalable algorithms for large datasets
969            match algorithm.family {
970                AlgorithmFamily::NeighborBased => return false, // KNN doesn't scale well
971                AlgorithmFamily::SVM => return dataset_chars.n_samples < 50000, // SVM has cubic complexity
972                _ => {}
973            }
974        }
975
976        // Linearity heuristics
977        if dataset_chars.linearity_score > 0.8 {
978            // Prefer linear models for linear data
979            match algorithm.family {
980                AlgorithmFamily::Linear => return true,
981                AlgorithmFamily::TreeBased | AlgorithmFamily::Ensemble => return false,
982                _ => {}
983            }
984        }
985
986        // Missing value handling
987        if dataset_chars.missing_ratio > 0.0 && !algorithm.handles_missing {
988            return false;
989        }
990
991        true
992    }
993
994    /// Filter algorithms by computational constraints
995    fn filter_by_constraints(
996        &self,
997        algorithms: &[AlgorithmSpec],
998        dataset_chars: &DatasetCharacteristics,
999    ) -> Vec<AlgorithmSpec> {
1000        algorithms
1001            .iter()
1002            .filter(|alg| self.satisfies_constraints(alg, dataset_chars))
1003            .cloned()
1004            .collect()
1005    }
1006
1007    /// Check if algorithm satisfies computational constraints
1008    fn satisfies_constraints(
1009        &self,
1010        algorithm: &AlgorithmSpec,
1011        dataset_chars: &DatasetCharacteristics,
1012    ) -> bool {
1013        // Rough complexity estimates
1014        let estimated_training_time = self.estimate_training_time(algorithm, dataset_chars);
1015        let estimated_memory_usage = self.estimate_memory_usage(algorithm, dataset_chars);
1016
1017        if let Some(max_time) = self.config.constraints.max_training_time {
1018            if estimated_training_time > max_time {
1019                return false;
1020            }
1021        }
1022
1023        if let Some(max_memory) = self.config.constraints.max_memory_gb {
1024            if estimated_memory_usage > max_memory {
1025                return false;
1026            }
1027        }
1028
1029        true
1030    }
1031
1032    /// Estimate training time for algorithm
1033    fn estimate_training_time(
1034        &self,
1035        algorithm: &AlgorithmSpec,
1036        dataset_chars: &DatasetCharacteristics,
1037    ) -> f64 {
1038        let n = dataset_chars.n_samples as f64;
1039        let p = dataset_chars.n_features as f64;
1040
1041        // Base time estimates (in seconds for 1000 samples, 10 features)
1042        let base_time = match algorithm.family {
1043            AlgorithmFamily::Linear => 0.1,
1044            AlgorithmFamily::TreeBased => {
1045                if algorithm.name.contains("Random") {
1046                    2.0
1047                } else {
1048                    0.5
1049                }
1050            }
1051            AlgorithmFamily::Ensemble => 3.0,
1052            AlgorithmFamily::NeighborBased => 0.05, // Training is fast, prediction is slow
1053            AlgorithmFamily::SVM => 1.0,
1054            AlgorithmFamily::NaiveBayes => 0.05,
1055            AlgorithmFamily::NeuralNetwork => 5.0,
1056            AlgorithmFamily::GaussianProcess => 2.0,
1057            AlgorithmFamily::DiscriminantAnalysis => 0.2,
1058            AlgorithmFamily::Dummy => 0.01,
1059        };
1060
1061        // Scale by complexity and data size
1062        base_time * algorithm.complexity * (n / 1000.0) * (p / 10.0).sqrt()
1063    }
1064
    /// Estimate memory usage for algorithm
    ///
    /// Returns a rough estimate in gigabytes: a per-family base cost in MB,
    /// scaled by the spec's `memory_requirement` factor and divided by 1000
    /// at the end to convert to GB. Neighbor-based models are special-cased
    /// because they retain the full training matrix (n * p f64 values).
    fn estimate_memory_usage(
        &self,
        algorithm: &AlgorithmSpec,
        dataset_chars: &DatasetCharacteristics,
    ) -> f64 {
        let n = dataset_chars.n_samples as f64;
        let p = dataset_chars.n_features as f64;

        // Base memory in MB (converted to GB at the end)
        let base_memory_mb = match algorithm.family {
            AlgorithmFamily::Linear => 1.0,
            AlgorithmFamily::TreeBased => {
                // Random forests keep many trees, hence the larger footprint.
                if algorithm.name.contains("Random") {
                    50.0
                } else {
                    10.0
                }
            }
            AlgorithmFamily::Ensemble => 100.0,
            AlgorithmFamily::NeighborBased => n * p * 8.0 / 1_000_000.0, // Store all training data (8 bytes per f64)
            AlgorithmFamily::SVM => 20.0,
            AlgorithmFamily::NaiveBayes => 1.0,
            AlgorithmFamily::NeuralNetwork => 50.0,
            AlgorithmFamily::GaussianProcess => 10.0,
            AlgorithmFamily::DiscriminantAnalysis => 5.0,
            AlgorithmFamily::Dummy => 0.1,
        };

        (base_memory_mb * algorithm.memory_requirement) / 1000.0 // Convert to GB
    }
1096
1097    /// Evaluate algorithms using cross-validation
1098    fn evaluate_algorithms(
1099        &self,
1100        algorithms: &[AlgorithmSpec],
1101        X: &Array2<f64>,
1102        y: &Array1<f64>,
1103    ) -> Result<Vec<RankedAlgorithm>> {
1104        let mut results = Vec::new();
1105
1106        for algorithm in algorithms {
1107            let start_time = std::time::Instant::now();
1108
1109            // Create mock cross-validation results
1110            // In a real implementation, this would actually train and evaluate the models
1111            let cv_score = self.mock_evaluate_algorithm(algorithm, X, y);
1112            let cv_std = cv_score * 0.05; // Mock standard deviation
1113
1114            let training_time = start_time.elapsed().as_secs_f64();
1115            let memory_usage = self.estimate_memory_usage(algorithm, &self.analyze_dataset(X, y));
1116
1117            results.push(RankedAlgorithm {
1118                algorithm: algorithm.clone(),
1119                cv_score,
1120                cv_std,
1121                training_time,
1122                memory_usage,
1123                best_params: algorithm.default_params.clone(),
1124                rank: 0,                    // Will be set later
1125                selection_probability: 0.0, // Will be set later
1126            });
1127        }
1128
1129        Ok(results)
1130    }
1131
1132    /// Mock algorithm evaluation (replace with actual implementation)
1133    fn mock_evaluate_algorithm(
1134        &self,
1135        algorithm: &AlgorithmSpec,
1136        X: &Array2<f64>,
1137        y: &Array1<f64>,
1138    ) -> f64 {
1139        // Generate mock scores based on algorithm characteristics and data
1140        let dataset_chars = self.analyze_dataset(X, y);
1141
1142        let base_score = match self.config.task_type {
1143            TaskType::Classification => 0.7, // Mock accuracy
1144            TaskType::Regression => 0.8,     // Mock R²
1145        };
1146
1147        // Adjust score based on algorithm suitability
1148        let mut score: f64 = base_score;
1149
1150        // Linear models perform better on linear data
1151        if algorithm.family == AlgorithmFamily::Linear && dataset_chars.linearity_score > 0.7 {
1152            score += 0.1;
1153        }
1154
1155        // Tree models perform better on non-linear data
1156        if matches!(
1157            algorithm.family,
1158            AlgorithmFamily::TreeBased | AlgorithmFamily::Ensemble
1159        ) && dataset_chars.linearity_score < 0.5
1160        {
1161            score += 0.1;
1162        }
1163
1164        // Ensemble methods generally perform better
1165        if algorithm.family == AlgorithmFamily::Ensemble {
1166            score += 0.05;
1167        }
1168
1169        // Add some noise
1170        //         use scirs2_core::random::Rng;
1171        let mut rng = scirs2_core::random::thread_rng();
1172        score += rng.gen_range(-0.05..0.05);
1173
1174        score.clamp(0.0, 1.0)
1175    }
1176
1177    /// Calculate selection probability based on performance
1178    fn calculate_selection_probability(
1179        &self,
1180        algorithm: &RankedAlgorithm,
1181        all_algorithms: &[RankedAlgorithm],
1182    ) -> f64 {
1183        let max_score = all_algorithms
1184            .iter()
1185            .map(|a| a.cv_score)
1186            .fold(0.0, f64::max);
1187        let min_score = all_algorithms
1188            .iter()
1189            .map(|a| a.cv_score)
1190            .fold(1.0, f64::min);
1191
1192        if max_score == min_score {
1193            return 1.0 / all_algorithms.len() as f64;
1194        }
1195
1196        // Softmax-like probability
1197        let normalized_score = (algorithm.cv_score - min_score) / (max_score - min_score);
1198        let exp_score = (normalized_score * 5.0).exp();
1199        let total_exp: f64 = all_algorithms
1200            .iter()
1201            .map(|a| {
1202                let norm = (a.cv_score - min_score) / (max_score - min_score);
1203                (norm * 5.0).exp()
1204            })
1205            .sum();
1206
1207        exp_score / total_exp
1208    }
1209
1210    /// Get baseline score for comparison
1211    fn get_baseline_score(&self, _X: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
1212        match self.config.task_type {
1213            TaskType::Classification => {
1214                // Most frequent class accuracy
1215                let classes = self.get_unique_classes(y);
1216                let class_counts = self.calculate_class_distribution(y, &classes);
1217                Ok(class_counts.iter().fold(0.0, |acc, &x| acc.max(x)))
1218            }
1219            TaskType::Regression => {
1220                // R² of predicting mean
1221                let mean = y.mean().expect("operation should succeed");
1222                let tss: f64 = y.iter().map(|&yi| (yi - mean).powi(2)).sum();
1223                let rss = tss; // Predicting mean gives R² = 0
1224                Ok(1.0 - rss / tss)
1225            }
1226        }
1227    }
1228
1229    /// Generate explanation for the selected algorithm
1230    fn generate_explanation(
1231        &self,
1232        best_algorithm: &RankedAlgorithm,
1233        dataset_chars: &DatasetCharacteristics,
1234    ) -> String {
1235        let mut explanation = format!(
1236            "{} ({}) was selected as the best algorithm with a cross-validation score of {:.4}.",
1237            best_algorithm.algorithm.name, best_algorithm.algorithm.family, best_algorithm.cv_score
1238        );
1239
1240        // Add reasoning based on dataset characteristics
1241        if dataset_chars.n_samples < 1000 {
1242            explanation.push_str(" This algorithm is well-suited for small datasets.");
1243        } else if dataset_chars.n_samples > 10000 {
1244            explanation.push_str(" This algorithm scales well to large datasets.");
1245        }
1246
1247        if dataset_chars.linearity_score > 0.7
1248            && best_algorithm.algorithm.family == AlgorithmFamily::Linear
1249        {
1250            explanation.push_str(
1251                " The linear nature of your data makes linear models particularly effective.",
1252            );
1253        }
1254
1255        if dataset_chars.n_features > dataset_chars.n_samples {
1256            explanation.push_str(" The high-dimensional nature of your data favors this algorithm's regularization capabilities.");
1257        }
1258
1259        if best_algorithm.algorithm.family == AlgorithmFamily::Ensemble {
1260            explanation.push_str(
1261                " Ensemble methods often provide robust performance across diverse datasets.",
1262            );
1263        }
1264
1265        explanation
1266    }
1267
1268    // Helper methods for dataset analysis
1269    fn calculate_missing_ratio(&self, X: &Array2<f64>) -> f64 {
1270        let total_values = X.len() as f64;
1271        let missing_count = X.iter().filter(|&&x| x.is_nan()).count() as f64;
1272        missing_count / total_values
1273    }
1274
1275    fn calculate_sparsity(&self, X: &Array2<f64>) -> f64 {
1276        let total_values = X.len() as f64;
1277        let zero_count = X.iter().filter(|&&x| x == 0.0).count() as f64;
1278        zero_count / total_values
1279    }
1280
1281    fn calculate_categorical_ratio(&self, X: &Array2<f64>) -> f64 {
1282        let n_features = X.ncols();
1283        if n_features == 0 {
1284            return 0.0;
1285        }
1286
1287        let mut categorical_count = 0;
1288        for col_idx in 0..n_features {
1289            let column = X.column(col_idx);
1290
1291            // Filter out NaN values
1292            let valid_values: Vec<f64> = column.iter().filter(|&&x| !x.is_nan()).copied().collect();
1293
1294            if valid_values.is_empty() {
1295                continue;
1296            }
1297
1298            // Count unique values
1299            let mut unique_values = valid_values.clone();
1300            unique_values.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
1301            unique_values.dedup();
1302
1303            let n_unique = unique_values.len();
1304            let n_total = valid_values.len();
1305
1306            // Heuristic: Consider categorical if:
1307            // 1. Has 10 or fewer unique values, OR
1308            // 2. Unique ratio < 5% and all values appear to be integers
1309            let unique_ratio = n_unique as f64 / n_total as f64;
1310            let all_integers = valid_values.iter().all(|&x| (x - x.round()).abs() < 1e-10);
1311
1312            if n_unique <= 10 || (unique_ratio < 0.05 && all_integers) {
1313                categorical_count += 1;
1314            }
1315        }
1316
1317        categorical_count as f64 / n_features as f64
1318    }
1319
1320    fn calculate_correlation_condition_number(&self, _X: &Array2<f64>) -> f64 {
1321        // Mock implementation - would need actual correlation matrix computation
1322        //         use scirs2_core::random::Rng;
1323        let mut rng = scirs2_core::random::thread_rng();
1324        rng.gen_range(1.0..100.0)
1325    }
1326
1327    fn get_unique_classes(&self, y: &Array1<f64>) -> Vec<i32> {
1328        let mut classes: Vec<i32> = y.iter().map(|&x| x as i32).collect();
1329        classes.sort_unstable();
1330        classes.dedup();
1331        classes
1332    }
1333
1334    fn calculate_class_distribution(&self, y: &Array1<f64>, classes: &[i32]) -> Vec<f64> {
1335        let total = y.len() as f64;
1336        classes
1337            .iter()
1338            .map(|&class| {
1339                let count = y.iter().filter(|&&yi| yi as i32 == class).count() as f64;
1340                count / total
1341            })
1342            .collect()
1343    }
1344
1345    fn calculate_target_statistics(&self, y: &Array1<f64>) -> TargetStatistics {
1346        let mean = y.mean().expect("operation should succeed");
1347        let std = y.std(0.0);
1348
1349        // Mock calculations for advanced statistics
1350        TargetStatistics {
1351            mean,
1352            std,
1353            skewness: 0.0, // Would need actual skewness calculation
1354            kurtosis: 0.0, // Would need actual kurtosis calculation
1355            n_outliers: 0, // Would need outlier detection
1356        }
1357    }
1358
1359    fn estimate_linearity_score(&self, _X: &Array2<f64>, _y: &Array1<f64>) -> f64 {
1360        // Mock implementation - would need actual linearity testing
1361        //         use scirs2_core::random::Rng;
1362        let mut rng = scirs2_core::random::thread_rng();
1363        rng.gen_range(0.0..1.0)
1364    }
1365
1366    fn estimate_noise_level(&self, _X: &Array2<f64>, _y: &Array1<f64>) -> f64 {
1367        // Mock implementation - would need actual noise estimation
1368        //         use scirs2_core::random::Rng;
1369        let mut rng = scirs2_core::random::thread_rng();
1370        rng.gen_range(0.0..0.5)
1371    }
1372
1373    fn estimate_effective_dimensionality(&self, X: &Array2<f64>) -> Option<usize> {
1374        // Mock implementation - would need PCA analysis
1375        Some((X.ncols() as f64 * 0.8) as usize)
1376    }
1377}
1378
1379/// Convenience function for quick algorithm selection
1380pub fn select_best_algorithm(
1381    X: &Array2<f64>,
1382    y: &Array1<f64>,
1383    task_type: TaskType,
1384) -> Result<AlgorithmSelectionResult> {
1385    let config = AutoMLConfig {
1386        task_type,
1387        ..Default::default()
1388    };
1389
1390    let selector = AutoMLAlgorithmSelector::new(config);
1391    selector.select_algorithms(X, y)
1392}
1393
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Build a 100x4 feature matrix of sequential values (0..400) and a
    /// 3-class cyclic label vector (0, 1, 2 repeating).
    #[allow(non_snake_case)]
    fn create_test_classification_data() -> (Array2<f64>, Array1<f64>) {
        let X = Array2::from_shape_vec((100, 4), (0..400).map(|i| i as f64).collect())
            .expect("operation should succeed");
        let y = Array1::from_vec((0..100).map(|i| (i % 3) as f64).collect());
        (X, y)
    }

    /// Build a 100x4 feature matrix of sequential values and a continuous
    /// target: the sample index plus uniform noise in [0, 1).
    #[allow(non_snake_case)]
    fn create_test_regression_data() -> (Array2<f64>, Array1<f64>) {
        let X = Array2::from_shape_vec((100, 4), (0..400).map(|i| i as f64).collect())
            .expect("operation should succeed");
        use scirs2_core::essentials::Uniform;
        use scirs2_core::random::{thread_rng, Distribution};
        let mut rng = thread_rng();
        let dist = Uniform::new(0.0, 1.0).expect("operation should succeed");
        let y = Array1::from_vec((0..100).map(|i| i as f64 + dist.sample(&mut rng)).collect());
        (X, y)
    }

    /// End-to-end smoke test: default-config selection on classification
    /// data succeeds and returns at least one positively-scored algorithm.
    #[test]
    fn test_algorithm_selection_classification() {
        let (X, y) = create_test_classification_data();
        let result = select_best_algorithm(&X, &y, TaskType::Classification);
        assert!(result.is_ok());

        let result = result.expect("operation should succeed");
        assert!(!result.selected_algorithms.is_empty());
        assert!(result.best_algorithm.cv_score > 0.0);
    }

    /// End-to-end smoke test: default-config selection on regression data
    /// succeeds and returns at least one positively-scored algorithm.
    #[test]
    fn test_algorithm_selection_regression() {
        let (X, y) = create_test_regression_data();
        let result = select_best_algorithm(&X, &y, TaskType::Regression);
        assert!(result.is_ok());

        let result = result.expect("operation should succeed");
        assert!(!result.selected_algorithms.is_empty());
        assert!(result.best_algorithm.cv_score > 0.0);
    }

    /// Dataset analysis should report the fixture's known shape (100 x 4)
    /// and its three distinct class labels.
    #[test]
    fn test_dataset_characteristics_analysis() {
        let (X, y) = create_test_classification_data();
        let config = AutoMLConfig::default();
        let selector = AutoMLAlgorithmSelector::new(config);

        let chars = selector.analyze_dataset(&X, &y);
        assert_eq!(chars.n_samples, 100);
        assert_eq!(chars.n_features, 4);
        assert_eq!(chars.n_classes, Some(3));
    }

    /// An allow-list plus `max_algorithms` cap should bound both the number
    /// of evaluated algorithms and the families of every selection.
    #[test]
    fn test_custom_config() {
        let (X, y) = create_test_classification_data();

        let config = AutoMLConfig {
            task_type: TaskType::Classification,
            max_algorithms: 3,
            allowed_families: Some(vec![AlgorithmFamily::Linear, AlgorithmFamily::TreeBased]),
            ..Default::default()
        };

        let selector = AutoMLAlgorithmSelector::new(config);
        let result = selector.select_algorithms(&X, &y);
        assert!(result.is_ok());

        let result = result.expect("operation should succeed");
        assert!(result.n_algorithms_evaluated <= 3);

        // Every selected algorithm must come from an allowed family.
        for alg in &result.selected_algorithms {
            assert!(matches!(
                alg.algorithm.family,
                AlgorithmFamily::Linear | AlgorithmFamily::TreeBased
            ));
        }
    }

    /// Even under tight time/memory constraints, selection should still
    /// produce at least one (simple) algorithm rather than fail.
    #[test]
    fn test_computational_constraints() {
        let (X, y) = create_test_classification_data();

        let config = AutoMLConfig {
            task_type: TaskType::Classification,
            constraints: ComputationalConstraints {
                max_training_time: Some(1.0), // Very short time limit
                max_memory_gb: Some(0.1),     // Very low memory limit
                ..Default::default()
            },
            ..Default::default()
        };

        let selector = AutoMLAlgorithmSelector::new(config);
        let result = selector.select_algorithms(&X, &y);
        assert!(result.is_ok());

        let result = result.expect("operation should succeed");
        // Should still find at least simple algorithms
        assert!(!result.selected_algorithms.is_empty());
    }
}