sklears_model_selection/
automl_algorithm_selection.rs

1//! Automated Algorithm Selection for AutoML
2//!
3//! This module provides intelligent algorithm selection based on dataset characteristics,
4//! computational constraints, and performance requirements. It automatically selects
5//! and configures the best algorithms for classification and regression tasks.
6
7use crate::scoring::TaskType;
8use scirs2_core::ndarray::{Array1, Array2};
9use sklears_core::error::{Result, SklearsError};
10use std::collections::HashMap;
11use std::fmt;
12// use serde::{Deserialize, Serialize};
13
/// Algorithm family categories for classification and regression
///
/// Used as a coarse grouping when filtering candidates (see
/// `AutoMLConfig::allowed_families` / `excluded_families`) and as a
/// `HashMap` key, hence the `Eq + Hash` derives.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AlgorithmFamily {
    /// Linear models (LogisticRegression, LinearRegression, etc.)
    Linear,
    /// Tree-based models (DecisionTree, RandomForest, etc.)
    TreeBased,
    /// Ensemble methods (AdaBoost, Stacking, etc.)
    Ensemble,
    /// Neighbor-based models (KNN, etc.)
    NeighborBased,
    /// Support Vector Machines
    SVM,
    /// Naive Bayes classifiers
    NaiveBayes,
    /// Neural networks
    NeuralNetwork,
    /// Gaussian processes
    GaussianProcess,
    /// Discriminant analysis
    DiscriminantAnalysis,
    /// Dummy/baseline models
    Dummy,
}
38
39impl fmt::Display for AlgorithmFamily {
40    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41        match self {
42            AlgorithmFamily::Linear => write!(f, "Linear"),
43            AlgorithmFamily::TreeBased => write!(f, "Tree-based"),
44            AlgorithmFamily::Ensemble => write!(f, "Ensemble"),
45            AlgorithmFamily::NeighborBased => write!(f, "Neighbor-based"),
46            AlgorithmFamily::SVM => write!(f, "Support Vector Machine"),
47            AlgorithmFamily::NaiveBayes => write!(f, "Naive Bayes"),
48            AlgorithmFamily::NeuralNetwork => write!(f, "Neural Network"),
49            AlgorithmFamily::GaussianProcess => write!(f, "Gaussian Process"),
50            AlgorithmFamily::DiscriminantAnalysis => write!(f, "Discriminant Analysis"),
51            AlgorithmFamily::Dummy => write!(f, "Dummy/Baseline"),
52        }
53    }
54}
55
/// Specific algorithm within a family
///
/// A purely declarative description of a candidate estimator: parameters and
/// their search spaces are kept as strings, so the spec carries metadata
/// only, not an estimator implementation.
#[derive(Debug, Clone, PartialEq)]
pub struct AlgorithmSpec {
    /// Algorithm family
    pub family: AlgorithmFamily,
    /// Specific algorithm name (e.g. "LogisticRegression")
    pub name: String,
    /// Default hyperparameters (parameter name -> stringified value)
    pub default_params: HashMap<String, String>,
    /// Hyperparameter search space (parameter name -> candidate stringified values)
    pub param_space: HashMap<String, Vec<String>>,
    /// Computational complexity (relative scale; unitless — only meaningful
    /// for comparing specs against each other)
    pub complexity: f64,
    /// Memory requirements (relative scale; unitless, same caveat as above)
    pub memory_requirement: f64,
    /// Supports probability prediction
    pub supports_proba: bool,
    /// Handles missing values
    pub handles_missing: bool,
    /// Handles categorical features
    pub handles_categorical: bool,
    /// Supports incremental learning
    pub supports_incremental: bool,
}
80
/// Dataset characteristics for algorithm selection
///
/// Produced by `AutoMLAlgorithmSelector::analyze_dataset` and consumed when
/// filtering and explaining candidate algorithms.
#[derive(Debug, Clone)]
pub struct DatasetCharacteristics {
    /// Number of samples
    pub n_samples: usize,
    /// Number of features
    pub n_features: usize,
    /// Number of classes (for classification; `None` for regression)
    pub n_classes: Option<usize>,
    /// Class distribution (for classification; `None` for regression)
    pub class_distribution: Option<Vec<f64>>,
    /// Target distribution (for regression; `None` for classification)
    pub target_stats: Option<TargetStatistics>,
    /// Missing value ratio (presumably in [0, 1] — computed by a helper not
    /// shown here; TODO confirm)
    pub missing_ratio: f64,
    /// Categorical feature ratio (same presumed [0, 1] range)
    pub categorical_ratio: f64,
    /// Feature correlation matrix condition number (large values suggest
    /// multicollinearity)
    pub correlation_condition_number: f64,
    /// Dataset sparsity
    pub sparsity: f64,
    /// Effective dimensionality (based on PCA)
    pub effective_dimensionality: Option<usize>,
    /// Estimated noise level
    pub noise_level: f64,
    /// Linearity score (0-1, higher means more linear)
    pub linearity_score: f64,
}
109
/// Target statistics for regression tasks
///
/// Summary moments of the target vector; used to characterize the target
/// distribution when the task is regression.
#[derive(Debug, Clone)]
pub struct TargetStatistics {
    /// Mean of target values
    pub mean: f64,
    /// Standard deviation of target values
    pub std: f64,
    /// Skewness of target distribution
    pub skewness: f64,
    /// Kurtosis of target distribution
    pub kurtosis: f64,
    /// Number of outliers (outlier criterion is defined by the computing
    /// helper, which is not visible here — TODO confirm)
    pub n_outliers: usize,
}
124
/// Computational constraints for algorithm selection
///
/// All limits are optional; a `None` presumably means "unconstrained"
/// (enforced by `filter_by_constraints` — TODO confirm).
#[derive(Debug, Clone, Default)]
pub struct ComputationalConstraints {
    /// Maximum training time in seconds
    pub max_training_time: Option<f64>,
    /// Maximum memory usage in GB
    pub max_memory_gb: Option<f64>,
    /// Maximum model size in MB
    pub max_model_size_mb: Option<f64>,
    /// Maximum inference time per sample in milliseconds
    pub max_inference_time_ms: Option<f64>,
    /// Available CPU cores
    pub n_cores: Option<usize>,
    /// GPU availability
    pub has_gpu: bool,
}
141
/// Configuration for automated algorithm selection
///
/// See `Default` for the out-of-the-box settings (classification, 5-fold CV,
/// accuracy scoring).
#[derive(Debug, Clone)]
pub struct AutoMLConfig {
    /// Task type (classification or regression)
    pub task_type: TaskType,
    /// Computational constraints
    pub constraints: ComputationalConstraints,
    /// Algorithm families to consider (None means all)
    pub allowed_families: Option<Vec<AlgorithmFamily>>,
    /// Algorithm families to exclude (applied in addition to `allowed_families`)
    pub excluded_families: Vec<AlgorithmFamily>,
    /// Maximum number of algorithms to evaluate
    pub max_algorithms: usize,
    /// Cross-validation strategy (number of folds)
    pub cv_folds: usize,
    /// Scoring metric (name string, e.g. "accuracy")
    pub scoring_metric: String,
    /// Time budget for hyperparameter optimization per algorithm (seconds)
    pub hyperopt_time_budget: f64,
    /// Random seed for reproducibility (`None` = non-deterministic)
    pub random_seed: Option<u64>,
    /// Whether to use ensemble methods
    pub enable_ensembles: bool,
    /// Whether to perform feature engineering
    pub enable_feature_engineering: bool,
}
168
169impl Default for AutoMLConfig {
170    fn default() -> Self {
171        Self {
172            task_type: TaskType::Classification,
173            constraints: ComputationalConstraints::default(),
174            allowed_families: None,
175            excluded_families: Vec::new(),
176            max_algorithms: 10,
177            cv_folds: 5,
178            scoring_metric: "accuracy".to_string(),
179            hyperopt_time_budget: 300.0, // 5 minutes per algorithm
180            random_seed: None,
181            enable_ensembles: true,
182            enable_feature_engineering: true,
183        }
184    }
185}
186
/// Result of algorithm selection process
///
/// Returned by `AutoMLAlgorithmSelector::select_algorithms`; also carries a
/// `Display` impl that renders a human-readable report.
#[derive(Debug, Clone)]
pub struct AlgorithmSelectionResult {
    /// Selected algorithm specifications, sorted best-first
    pub selected_algorithms: Vec<RankedAlgorithm>,
    /// Dataset characteristics used for selection
    pub dataset_characteristics: DatasetCharacteristics,
    /// Total evaluation time (seconds)
    pub total_evaluation_time: f64,
    /// Number of algorithms evaluated
    pub n_algorithms_evaluated: usize,
    /// Best performing algorithm (copy of the top-ranked entry)
    pub best_algorithm: RankedAlgorithm,
    /// Performance improvement over baseline (best CV score minus baseline score)
    pub improvement_over_baseline: f64,
    /// Recommendation explanation
    pub explanation: String,
}
205
/// Algorithm with performance ranking
///
/// One evaluated candidate: its spec, cross-validation performance, resource
/// usage, tuned parameters, and its position in the final ranking.
#[derive(Debug, Clone)]
pub struct RankedAlgorithm {
    /// Algorithm specification
    pub algorithm: AlgorithmSpec,
    /// Cross-validation score (mean across folds)
    pub cv_score: f64,
    /// Standard deviation of CV scores
    pub cv_std: f64,
    /// Training time in seconds
    pub training_time: f64,
    /// Memory usage in MB
    pub memory_usage: f64,
    /// Optimized hyperparameters (parameter name -> stringified value)
    pub best_params: HashMap<String, String>,
    /// Rank among all algorithms (1 = best)
    pub rank: usize,
    /// Selection probability based on performance
    pub selection_probability: f64,
}
226
227impl fmt::Display for AlgorithmSelectionResult {
228    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
229        writeln!(f, "AutoML Algorithm Selection Results")?;
230        writeln!(f, "==================================")?;
231        writeln!(
232            f,
233            "Dataset: {} samples, {} features",
234            self.dataset_characteristics.n_samples, self.dataset_characteristics.n_features
235        )?;
236        writeln!(f, "Algorithms evaluated: {}", self.n_algorithms_evaluated)?;
237        writeln!(
238            f,
239            "Total evaluation time: {:.2}s",
240            self.total_evaluation_time
241        )?;
242        writeln!(f)?;
243        writeln!(
244            f,
245            "Best Algorithm: {} ({})",
246            self.best_algorithm.algorithm.name, self.best_algorithm.algorithm.family
247        )?;
248        writeln!(
249            f,
250            "Score: {:.4} ± {:.4}",
251            self.best_algorithm.cv_score, self.best_algorithm.cv_std
252        )?;
253        writeln!(
254            f,
255            "Training time: {:.2}s",
256            self.best_algorithm.training_time
257        )?;
258        writeln!(
259            f,
260            "Improvement over baseline: {:.4}",
261            self.improvement_over_baseline
262        )?;
263        writeln!(f)?;
264        writeln!(f, "Explanation: {}", self.explanation)?;
265        writeln!(f)?;
266        writeln!(
267            f,
268            "Top {} Algorithms:",
269            self.selected_algorithms.len().min(5)
270        )?;
271        for (i, alg) in self.selected_algorithms.iter().take(5).enumerate() {
272            writeln!(
273                f,
274                "{}. {} ({}) - Score: {:.4} ± {:.4}",
275                i + 1,
276                alg.algorithm.name,
277                alg.algorithm.family,
278                alg.cv_score,
279                alg.cv_std
280            )?;
281        }
282        Ok(())
283    }
284}
285
/// Automated algorithm selector
///
/// Holds the selection configuration plus a catalog of candidate algorithm
/// specifications keyed by task type, built once at construction.
pub struct AutoMLAlgorithmSelector {
    // Selection configuration: task type, constraints, budgets.
    config: AutoMLConfig,
    // Static catalog of candidate specs per task type.
    algorithm_catalog: HashMap<TaskType, Vec<AlgorithmSpec>>,
}
291
292impl Default for AutoMLAlgorithmSelector {
293    fn default() -> Self {
294        Self::new(AutoMLConfig::default())
295    }
296}
297
298impl AutoMLAlgorithmSelector {
299    /// Create a new AutoML algorithm selector
300    pub fn new(config: AutoMLConfig) -> Self {
301        let algorithm_catalog = Self::build_algorithm_catalog();
302        Self {
303            config,
304            algorithm_catalog,
305        }
306    }
307
308    /// Analyze dataset characteristics
309    pub fn analyze_dataset(&self, X: &Array2<f64>, y: &Array1<f64>) -> DatasetCharacteristics {
310        let n_samples = X.nrows();
311        let n_features = X.ncols();
312
313        // Basic statistics
314        let missing_ratio = self.calculate_missing_ratio(X);
315        let sparsity = self.calculate_sparsity(X);
316        let correlation_condition_number = self.calculate_correlation_condition_number(X);
317
318        // Task-specific characteristics
319        let (n_classes, class_distribution, target_stats) = match self.config.task_type {
320            TaskType::Classification => {
321                let classes = self.get_unique_classes(y);
322                let class_dist = self.calculate_class_distribution(y, &classes);
323                (Some(classes.len()), Some(class_dist), None)
324            }
325            TaskType::Regression => {
326                let stats = self.calculate_target_statistics(y);
327                (None, None, Some(stats))
328            }
329        };
330
331        // Advanced characteristics
332        let linearity_score = self.estimate_linearity_score(X, y);
333        let noise_level = self.estimate_noise_level(X, y);
334        let effective_dimensionality = self.estimate_effective_dimensionality(X);
335        let categorical_ratio = self.calculate_categorical_ratio(X);
336
337        DatasetCharacteristics {
338            n_samples,
339            n_features,
340            n_classes,
341            class_distribution,
342            target_stats,
343            missing_ratio,
344            categorical_ratio,
345            correlation_condition_number,
346            sparsity,
347            effective_dimensionality,
348            noise_level,
349            linearity_score,
350        }
351    }
352
353    /// Select best algorithms for the dataset
354    pub fn select_algorithms(
355        &self,
356        X: &Array2<f64>,
357        y: &Array1<f64>,
358    ) -> Result<AlgorithmSelectionResult> {
359        let start_time = std::time::Instant::now();
360
361        // Analyze dataset characteristics
362        let dataset_chars = self.analyze_dataset(X, y);
363
364        // Get candidate algorithms
365        let candidate_algorithms = self.get_candidate_algorithms(&dataset_chars)?;
366
367        // Filter by constraints
368        let filtered_algorithms = self.filter_by_constraints(&candidate_algorithms, &dataset_chars);
369
370        // Evaluate algorithms
371        let mut evaluated_algorithms = self.evaluate_algorithms(&filtered_algorithms, X, y)?;
372
373        // Rank algorithms by performance
374        evaluated_algorithms.sort_by(|a, b| b.cv_score.partial_cmp(&a.cv_score).unwrap());
375
376        // Assign ranks and selection probabilities
377        let algorithms_copy = evaluated_algorithms.clone();
378        for (i, alg) in evaluated_algorithms.iter_mut().enumerate() {
379            alg.rank = i + 1;
380            alg.selection_probability = self.calculate_selection_probability(alg, &algorithms_copy);
381        }
382
383        let best_algorithm = evaluated_algorithms[0].clone();
384        let baseline_score = self.get_baseline_score(X, y)?;
385        let improvement = best_algorithm.cv_score - baseline_score;
386
387        let explanation = self.generate_explanation(&best_algorithm, &dataset_chars);
388
389        let total_time = start_time.elapsed().as_secs_f64();
390
391        Ok(AlgorithmSelectionResult {
392            selected_algorithms: evaluated_algorithms,
393            dataset_characteristics: dataset_chars,
394            total_evaluation_time: total_time,
395            n_algorithms_evaluated: filtered_algorithms.len(),
396            best_algorithm,
397            improvement_over_baseline: improvement,
398            explanation,
399        })
400    }
401
402    /// Build catalog of available algorithms
403    fn build_algorithm_catalog() -> HashMap<TaskType, Vec<AlgorithmSpec>> {
404        let mut catalog = HashMap::new();
405
406        // Classification algorithms
407        let classification_algorithms = vec![
408            // Linear classifiers
409            AlgorithmSpec {
410                family: AlgorithmFamily::Linear,
411                name: "LogisticRegression".to_string(),
412                default_params: [("C".to_string(), "1.0".to_string())]
413                    .iter()
414                    .cloned()
415                    .collect(),
416                param_space: [(
417                    "C".to_string(),
418                    vec![
419                        "0.001".to_string(),
420                        "0.01".to_string(),
421                        "0.1".to_string(),
422                        "1.0".to_string(),
423                        "10.0".to_string(),
424                        "100.0".to_string(),
425                    ],
426                )]
427                .iter()
428                .cloned()
429                .collect(),
430                complexity: 1.0,
431                memory_requirement: 1.0,
432                supports_proba: true,
433                handles_missing: false,
434                handles_categorical: false,
435                supports_incremental: false,
436            },
437            AlgorithmSpec {
438                family: AlgorithmFamily::Linear,
439                name: "RidgeClassifier".to_string(),
440                default_params: [("alpha".to_string(), "1.0".to_string())]
441                    .iter()
442                    .cloned()
443                    .collect(),
444                param_space: [(
445                    "alpha".to_string(),
446                    vec![
447                        "0.1".to_string(),
448                        "1.0".to_string(),
449                        "10.0".to_string(),
450                        "100.0".to_string(),
451                    ],
452                )]
453                .iter()
454                .cloned()
455                .collect(),
456                complexity: 1.0,
457                memory_requirement: 1.0,
458                supports_proba: false,
459                handles_missing: false,
460                handles_categorical: false,
461                supports_incremental: false,
462            },
463            // Tree-based classifiers
464            AlgorithmSpec {
465                family: AlgorithmFamily::TreeBased,
466                name: "DecisionTreeClassifier".to_string(),
467                default_params: [("max_depth".to_string(), "None".to_string())]
468                    .iter()
469                    .cloned()
470                    .collect(),
471                param_space: [
472                    (
473                        "max_depth".to_string(),
474                        vec![
475                            "3".to_string(),
476                            "5".to_string(),
477                            "10".to_string(),
478                            "None".to_string(),
479                        ],
480                    ),
481                    (
482                        "min_samples_split".to_string(),
483                        vec!["2".to_string(), "5".to_string(), "10".to_string()],
484                    ),
485                ]
486                .iter()
487                .cloned()
488                .collect(),
489                complexity: 2.0,
490                memory_requirement: 2.0,
491                supports_proba: true,
492                handles_missing: false,
493                handles_categorical: true,
494                supports_incremental: false,
495            },
496            AlgorithmSpec {
497                family: AlgorithmFamily::TreeBased,
498                name: "RandomForestClassifier".to_string(),
499                default_params: [("n_estimators".to_string(), "100".to_string())]
500                    .iter()
501                    .cloned()
502                    .collect(),
503                param_space: [
504                    (
505                        "n_estimators".to_string(),
506                        vec!["50".to_string(), "100".to_string(), "200".to_string()],
507                    ),
508                    (
509                        "max_depth".to_string(),
510                        vec![
511                            "3".to_string(),
512                            "5".to_string(),
513                            "10".to_string(),
514                            "None".to_string(),
515                        ],
516                    ),
517                ]
518                .iter()
519                .cloned()
520                .collect(),
521                complexity: 4.0,
522                memory_requirement: 4.0,
523                supports_proba: true,
524                handles_missing: false,
525                handles_categorical: true,
526                supports_incremental: false,
527            },
528            // Ensemble methods
529            AlgorithmSpec {
530                family: AlgorithmFamily::Ensemble,
531                name: "AdaBoostClassifier".to_string(),
532                default_params: [("n_estimators".to_string(), "50".to_string())]
533                    .iter()
534                    .cloned()
535                    .collect(),
536                param_space: [
537                    (
538                        "n_estimators".to_string(),
539                        vec!["25".to_string(), "50".to_string(), "100".to_string()],
540                    ),
541                    (
542                        "learning_rate".to_string(),
543                        vec!["0.1".to_string(), "0.5".to_string(), "1.0".to_string()],
544                    ),
545                ]
546                .iter()
547                .cloned()
548                .collect(),
549                complexity: 3.0,
550                memory_requirement: 3.0,
551                supports_proba: true,
552                handles_missing: false,
553                handles_categorical: true,
554                supports_incremental: false,
555            },
556            // K-Nearest Neighbors
557            AlgorithmSpec {
558                family: AlgorithmFamily::NeighborBased,
559                name: "KNeighborsClassifier".to_string(),
560                default_params: [("n_neighbors".to_string(), "5".to_string())]
561                    .iter()
562                    .cloned()
563                    .collect(),
564                param_space: [
565                    (
566                        "n_neighbors".to_string(),
567                        vec![
568                            "3".to_string(),
569                            "5".to_string(),
570                            "7".to_string(),
571                            "11".to_string(),
572                        ],
573                    ),
574                    (
575                        "weights".to_string(),
576                        vec!["uniform".to_string(), "distance".to_string()],
577                    ),
578                ]
579                .iter()
580                .cloned()
581                .collect(),
582                complexity: 1.0,
583                memory_requirement: 5.0,
584                supports_proba: true,
585                handles_missing: false,
586                handles_categorical: false,
587                supports_incremental: false,
588            },
589            // Naive Bayes
590            AlgorithmSpec {
591                family: AlgorithmFamily::NaiveBayes,
592                name: "GaussianNB".to_string(),
593                default_params: HashMap::new(),
594                param_space: HashMap::new(),
595                complexity: 1.0,
596                memory_requirement: 1.0,
597                supports_proba: true,
598                handles_missing: false,
599                handles_categorical: false,
600                supports_incremental: true,
601            },
602            // Support Vector Machine
603            AlgorithmSpec {
604                family: AlgorithmFamily::SVM,
605                name: "SVC".to_string(),
606                default_params: [
607                    ("C".to_string(), "1.0".to_string()),
608                    ("kernel".to_string(), "rbf".to_string()),
609                ]
610                .iter()
611                .cloned()
612                .collect(),
613                param_space: [
614                    (
615                        "C".to_string(),
616                        vec!["0.1".to_string(), "1.0".to_string(), "10.0".to_string()],
617                    ),
618                    (
619                        "kernel".to_string(),
620                        vec!["linear".to_string(), "rbf".to_string(), "poly".to_string()],
621                    ),
622                ]
623                .iter()
624                .cloned()
625                .collect(),
626                complexity: 3.0,
627                memory_requirement: 3.0,
628                supports_proba: false,
629                handles_missing: false,
630                handles_categorical: false,
631                supports_incremental: false,
632            },
633            // Dummy classifier for baseline
634            AlgorithmSpec {
635                family: AlgorithmFamily::Dummy,
636                name: "DummyClassifier".to_string(),
637                default_params: [("strategy".to_string(), "stratified".to_string())]
638                    .iter()
639                    .cloned()
640                    .collect(),
641                param_space: [(
642                    "strategy".to_string(),
643                    vec![
644                        "stratified".to_string(),
645                        "most_frequent".to_string(),
646                        "uniform".to_string(),
647                    ],
648                )]
649                .iter()
650                .cloned()
651                .collect(),
652                complexity: 0.1,
653                memory_requirement: 0.1,
654                supports_proba: true,
655                handles_missing: true,
656                handles_categorical: true,
657                supports_incremental: true,
658            },
659        ];
660
661        // Regression algorithms
662        let regression_algorithms = vec![
663            // Linear regressors
664            AlgorithmSpec {
665                family: AlgorithmFamily::Linear,
666                name: "LinearRegression".to_string(),
667                default_params: HashMap::new(),
668                param_space: HashMap::new(),
669                complexity: 1.0,
670                memory_requirement: 1.0,
671                supports_proba: false,
672                handles_missing: false,
673                handles_categorical: false,
674                supports_incremental: false,
675            },
676            AlgorithmSpec {
677                family: AlgorithmFamily::Linear,
678                name: "Ridge".to_string(),
679                default_params: [("alpha".to_string(), "1.0".to_string())]
680                    .iter()
681                    .cloned()
682                    .collect(),
683                param_space: [(
684                    "alpha".to_string(),
685                    vec![
686                        "0.1".to_string(),
687                        "1.0".to_string(),
688                        "10.0".to_string(),
689                        "100.0".to_string(),
690                    ],
691                )]
692                .iter()
693                .cloned()
694                .collect(),
695                complexity: 1.0,
696                memory_requirement: 1.0,
697                supports_proba: false,
698                handles_missing: false,
699                handles_categorical: false,
700                supports_incremental: false,
701            },
702            AlgorithmSpec {
703                family: AlgorithmFamily::Linear,
704                name: "Lasso".to_string(),
705                default_params: [("alpha".to_string(), "1.0".to_string())]
706                    .iter()
707                    .cloned()
708                    .collect(),
709                param_space: [(
710                    "alpha".to_string(),
711                    vec![
712                        "0.001".to_string(),
713                        "0.01".to_string(),
714                        "0.1".to_string(),
715                        "1.0".to_string(),
716                    ],
717                )]
718                .iter()
719                .cloned()
720                .collect(),
721                complexity: 1.5,
722                memory_requirement: 1.0,
723                supports_proba: false,
724                handles_missing: false,
725                handles_categorical: false,
726                supports_incremental: false,
727            },
728            // Tree-based regressors
729            AlgorithmSpec {
730                family: AlgorithmFamily::TreeBased,
731                name: "DecisionTreeRegressor".to_string(),
732                default_params: [("max_depth".to_string(), "None".to_string())]
733                    .iter()
734                    .cloned()
735                    .collect(),
736                param_space: [
737                    (
738                        "max_depth".to_string(),
739                        vec![
740                            "3".to_string(),
741                            "5".to_string(),
742                            "10".to_string(),
743                            "None".to_string(),
744                        ],
745                    ),
746                    (
747                        "min_samples_split".to_string(),
748                        vec!["2".to_string(), "5".to_string(), "10".to_string()],
749                    ),
750                ]
751                .iter()
752                .cloned()
753                .collect(),
754                complexity: 2.0,
755                memory_requirement: 2.0,
756                supports_proba: false,
757                handles_missing: false,
758                handles_categorical: true,
759                supports_incremental: false,
760            },
761            AlgorithmSpec {
762                family: AlgorithmFamily::TreeBased,
763                name: "RandomForestRegressor".to_string(),
764                default_params: [("n_estimators".to_string(), "100".to_string())]
765                    .iter()
766                    .cloned()
767                    .collect(),
768                param_space: [
769                    (
770                        "n_estimators".to_string(),
771                        vec!["50".to_string(), "100".to_string(), "200".to_string()],
772                    ),
773                    (
774                        "max_depth".to_string(),
775                        vec![
776                            "3".to_string(),
777                            "5".to_string(),
778                            "10".to_string(),
779                            "None".to_string(),
780                        ],
781                    ),
782                ]
783                .iter()
784                .cloned()
785                .collect(),
786                complexity: 4.0,
787                memory_requirement: 4.0,
788                supports_proba: false,
789                handles_missing: false,
790                handles_categorical: true,
791                supports_incremental: false,
792            },
793            // K-Nearest Neighbors
794            AlgorithmSpec {
795                family: AlgorithmFamily::NeighborBased,
796                name: "KNeighborsRegressor".to_string(),
797                default_params: [("n_neighbors".to_string(), "5".to_string())]
798                    .iter()
799                    .cloned()
800                    .collect(),
801                param_space: [
802                    (
803                        "n_neighbors".to_string(),
804                        vec![
805                            "3".to_string(),
806                            "5".to_string(),
807                            "7".to_string(),
808                            "11".to_string(),
809                        ],
810                    ),
811                    (
812                        "weights".to_string(),
813                        vec!["uniform".to_string(), "distance".to_string()],
814                    ),
815                ]
816                .iter()
817                .cloned()
818                .collect(),
819                complexity: 1.0,
820                memory_requirement: 5.0,
821                supports_proba: false,
822                handles_missing: false,
823                handles_categorical: false,
824                supports_incremental: false,
825            },
826            // Support Vector Machine
827            AlgorithmSpec {
828                family: AlgorithmFamily::SVM,
829                name: "SVR".to_string(),
830                default_params: [
831                    ("C".to_string(), "1.0".to_string()),
832                    ("kernel".to_string(), "rbf".to_string()),
833                ]
834                .iter()
835                .cloned()
836                .collect(),
837                param_space: [
838                    (
839                        "C".to_string(),
840                        vec!["0.1".to_string(), "1.0".to_string(), "10.0".to_string()],
841                    ),
842                    (
843                        "kernel".to_string(),
844                        vec!["linear".to_string(), "rbf".to_string(), "poly".to_string()],
845                    ),
846                ]
847                .iter()
848                .cloned()
849                .collect(),
850                complexity: 3.0,
851                memory_requirement: 3.0,
852                supports_proba: false,
853                handles_missing: false,
854                handles_categorical: false,
855                supports_incremental: false,
856            },
857            // Dummy regressor for baseline
858            AlgorithmSpec {
859                family: AlgorithmFamily::Dummy,
860                name: "DummyRegressor".to_string(),
861                default_params: [("strategy".to_string(), "mean".to_string())]
862                    .iter()
863                    .cloned()
864                    .collect(),
865                param_space: [(
866                    "strategy".to_string(),
867                    vec![
868                        "mean".to_string(),
869                        "median".to_string(),
870                        "constant".to_string(),
871                    ],
872                )]
873                .iter()
874                .cloned()
875                .collect(),
876                complexity: 0.1,
877                memory_requirement: 0.1,
878                supports_proba: false,
879                handles_missing: true,
880                handles_categorical: true,
881                supports_incremental: true,
882            },
883        ];
884
885        catalog.insert(TaskType::Classification, classification_algorithms);
886        catalog.insert(TaskType::Regression, regression_algorithms);
887        catalog
888    }
889
890    /// Get candidate algorithms for the given dataset characteristics
891    fn get_candidate_algorithms(
892        &self,
893        dataset_chars: &DatasetCharacteristics,
894    ) -> Result<Vec<AlgorithmSpec>> {
895        let algorithms = self
896            .algorithm_catalog
897            .get(&self.config.task_type)
898            .ok_or_else(|| SklearsError::InvalidParameter {
899                name: "task_type".to_string(),
900                reason: format!(
901                    "No algorithms available for task type: {:?}",
902                    self.config.task_type
903                ),
904            })?;
905
906        let mut candidates = Vec::new();
907
908        for algorithm in algorithms {
909            // Check if family is allowed
910            if let Some(ref allowed) = self.config.allowed_families {
911                if !allowed.contains(&algorithm.family) {
912                    continue;
913                }
914            }
915
916            // Check if family is excluded
917            if self.config.excluded_families.contains(&algorithm.family) {
918                continue;
919            }
920
921            // Apply heuristic filters based on dataset characteristics
922            if self.is_algorithm_suitable(algorithm, dataset_chars) {
923                candidates.push(algorithm.clone());
924            }
925        }
926
927        // Limit to max_algorithms
928        candidates.truncate(self.config.max_algorithms);
929
930        Ok(candidates)
931    }
932
    /// Apply rule-of-thumb filters to decide whether `algorithm` is worth
    /// evaluating on a dataset with the given characteristics.
    ///
    /// Several branches return early with `true`, which intentionally skips
    /// the remaining checks (e.g. a linear model accepted for
    /// high-dimensional data bypasses the missing-value screen below).
    fn is_algorithm_suitable(
        &self,
        algorithm: &AlgorithmSpec,
        dataset_chars: &DatasetCharacteristics,
    ) -> bool {
        // Skip dummy algorithms unless specifically requested.
        // NOTE(review): this tests `excluded_families`, so Dummy is dropped
        // whenever *any* family is excluded — presumably the intent was to
        // check `allowed_families` instead; confirm before changing.
        if algorithm.family == AlgorithmFamily::Dummy && !self.config.excluded_families.is_empty() {
            return false;
        }

        // High-dimensional data heuristics (more features than samples).
        if dataset_chars.n_features > dataset_chars.n_samples {
            // Prefer regularization-friendly models; reject distance/kernel
            // methods that degrade in high dimensions.
            match algorithm.family {
                AlgorithmFamily::Linear | AlgorithmFamily::NaiveBayes => return true,
                AlgorithmFamily::NeighborBased | AlgorithmFamily::SVM => return false,
                _ => {}
            }
        }

        // Small dataset heuristics: avoid overly complex models that would
        // likely overfit fewer than 100 samples.
        if dataset_chars.n_samples < 100 {
            if algorithm.complexity > 3.0 {
                return false;
            }
        }

        // Large dataset heuristics: prefer algorithms that scale.
        if dataset_chars.n_samples > 10000 {
            match algorithm.family {
                AlgorithmFamily::NeighborBased => return false, // KNN doesn't scale well
                AlgorithmFamily::SVM => return dataset_chars.n_samples < 50000, // SVM has cubic complexity
                _ => {}
            }
        }

        // Linearity heuristics: strongly linear data favors linear models and
        // makes tree/ensemble flexibility unnecessary.
        if dataset_chars.linearity_score > 0.8 {
            match algorithm.family {
                AlgorithmFamily::Linear => return true,
                AlgorithmFamily::TreeBased | AlgorithmFamily::Ensemble => return false,
                _ => {}
            }
        }

        // Reject algorithms that cannot cope with missing values when any are present.
        if dataset_chars.missing_ratio > 0.0 && !algorithm.handles_missing {
            return false;
        }

        true
    }
989
990    /// Filter algorithms by computational constraints
991    fn filter_by_constraints(
992        &self,
993        algorithms: &[AlgorithmSpec],
994        dataset_chars: &DatasetCharacteristics,
995    ) -> Vec<AlgorithmSpec> {
996        algorithms
997            .iter()
998            .filter(|alg| self.satisfies_constraints(alg, dataset_chars))
999            .cloned()
1000            .collect()
1001    }
1002
1003    /// Check if algorithm satisfies computational constraints
1004    fn satisfies_constraints(
1005        &self,
1006        algorithm: &AlgorithmSpec,
1007        dataset_chars: &DatasetCharacteristics,
1008    ) -> bool {
1009        // Rough complexity estimates
1010        let estimated_training_time = self.estimate_training_time(algorithm, dataset_chars);
1011        let estimated_memory_usage = self.estimate_memory_usage(algorithm, dataset_chars);
1012
1013        if let Some(max_time) = self.config.constraints.max_training_time {
1014            if estimated_training_time > max_time {
1015                return false;
1016            }
1017        }
1018
1019        if let Some(max_memory) = self.config.constraints.max_memory_gb {
1020            if estimated_memory_usage > max_memory {
1021                return false;
1022            }
1023        }
1024
1025        true
1026    }
1027
1028    /// Estimate training time for algorithm
1029    fn estimate_training_time(
1030        &self,
1031        algorithm: &AlgorithmSpec,
1032        dataset_chars: &DatasetCharacteristics,
1033    ) -> f64 {
1034        let n = dataset_chars.n_samples as f64;
1035        let p = dataset_chars.n_features as f64;
1036
1037        // Base time estimates (in seconds for 1000 samples, 10 features)
1038        let base_time = match algorithm.family {
1039            AlgorithmFamily::Linear => 0.1,
1040            AlgorithmFamily::TreeBased => {
1041                if algorithm.name.contains("Random") {
1042                    2.0
1043                } else {
1044                    0.5
1045                }
1046            }
1047            AlgorithmFamily::Ensemble => 3.0,
1048            AlgorithmFamily::NeighborBased => 0.05, // Training is fast, prediction is slow
1049            AlgorithmFamily::SVM => 1.0,
1050            AlgorithmFamily::NaiveBayes => 0.05,
1051            AlgorithmFamily::NeuralNetwork => 5.0,
1052            AlgorithmFamily::GaussianProcess => 2.0,
1053            AlgorithmFamily::DiscriminantAnalysis => 0.2,
1054            AlgorithmFamily::Dummy => 0.01,
1055        };
1056
1057        // Scale by complexity and data size
1058        base_time * algorithm.complexity * (n / 1000.0) * (p / 10.0).sqrt()
1059    }
1060
    /// Rough peak-memory estimate in GB for training `algorithm` on the
    /// given dataset, scaled by the algorithm's memory-requirement score.
    fn estimate_memory_usage(
        &self,
        algorithm: &AlgorithmSpec,
        dataset_chars: &DatasetCharacteristics,
    ) -> f64 {
        let n = dataset_chars.n_samples as f64;
        let p = dataset_chars.n_features as f64;

        // Base memory in MB (converted to GB at the end).
        let base_memory_mb = match algorithm.family {
            AlgorithmFamily::Linear => 1.0,
            AlgorithmFamily::TreeBased => {
                if algorithm.name.contains("Random") {
                    50.0
                } else {
                    10.0
                }
            }
            AlgorithmFamily::Ensemble => 100.0,
            AlgorithmFamily::NeighborBased => n * p * 8.0 / 1_000_000.0, // Stores all training data (8 bytes per f64)
            AlgorithmFamily::SVM => 20.0,
            AlgorithmFamily::NaiveBayes => 1.0,
            AlgorithmFamily::NeuralNetwork => 50.0,
            AlgorithmFamily::GaussianProcess => 10.0,
            AlgorithmFamily::DiscriminantAnalysis => 5.0,
            AlgorithmFamily::Dummy => 0.1,
        };

        (base_memory_mb * algorithm.memory_requirement) / 1000.0 // Convert to GB
    }
1092
1093    /// Evaluate algorithms using cross-validation
1094    fn evaluate_algorithms(
1095        &self,
1096        algorithms: &[AlgorithmSpec],
1097        X: &Array2<f64>,
1098        y: &Array1<f64>,
1099    ) -> Result<Vec<RankedAlgorithm>> {
1100        let mut results = Vec::new();
1101
1102        for algorithm in algorithms {
1103            let start_time = std::time::Instant::now();
1104
1105            // Create mock cross-validation results
1106            // In a real implementation, this would actually train and evaluate the models
1107            let cv_score = self.mock_evaluate_algorithm(algorithm, X, y);
1108            let cv_std = cv_score * 0.05; // Mock standard deviation
1109
1110            let training_time = start_time.elapsed().as_secs_f64();
1111            let memory_usage = self.estimate_memory_usage(algorithm, &self.analyze_dataset(X, y));
1112
1113            results.push(RankedAlgorithm {
1114                algorithm: algorithm.clone(),
1115                cv_score,
1116                cv_std,
1117                training_time,
1118                memory_usage,
1119                best_params: algorithm.default_params.clone(),
1120                rank: 0,                    // Will be set later
1121                selection_probability: 0.0, // Will be set later
1122            });
1123        }
1124
1125        Ok(results)
1126    }
1127
    /// Mock algorithm evaluation (replace with actual implementation).
    ///
    /// Produces a plausible score in `[0, 1]` from the algorithm family, the
    /// dataset's (mock) linearity score, and a small random perturbation.
    /// Non-deterministic: repeated calls return different values.
    fn mock_evaluate_algorithm(
        &self,
        algorithm: &AlgorithmSpec,
        X: &Array2<f64>,
        y: &Array1<f64>,
    ) -> f64 {
        // Generate mock scores based on algorithm characteristics and data.
        let dataset_chars = self.analyze_dataset(X, y);

        // Optimistic starting point per task type.
        let base_score = match self.config.task_type {
            TaskType::Classification => 0.7, // Mock accuracy
            TaskType::Regression => 0.8,     // Mock R²
        };

        // Adjust score based on algorithm/dataset fit.
        let mut score: f64 = base_score;

        // Linear models perform better on linear data.
        if algorithm.family == AlgorithmFamily::Linear && dataset_chars.linearity_score > 0.7 {
            score += 0.1;
        }

        // Tree models perform better on non-linear data.
        if matches!(
            algorithm.family,
            AlgorithmFamily::TreeBased | AlgorithmFamily::Ensemble
        ) && dataset_chars.linearity_score < 0.5
        {
            score += 0.1;
        }

        // Ensemble methods generally perform better.
        if algorithm.family == AlgorithmFamily::Ensemble {
            score += 0.05;
        }

        // Add noise so repeated evaluations of the same algorithm differ.
        let mut rng = scirs2_core::random::thread_rng();
        score += rng.gen_range(-0.05..0.05);

        score.clamp(0.0, 1.0)
    }
1172
1173    /// Calculate selection probability based on performance
1174    fn calculate_selection_probability(
1175        &self,
1176        algorithm: &RankedAlgorithm,
1177        all_algorithms: &[RankedAlgorithm],
1178    ) -> f64 {
1179        let max_score = all_algorithms
1180            .iter()
1181            .map(|a| a.cv_score)
1182            .fold(0.0, f64::max);
1183        let min_score = all_algorithms
1184            .iter()
1185            .map(|a| a.cv_score)
1186            .fold(1.0, f64::min);
1187
1188        if max_score == min_score {
1189            return 1.0 / all_algorithms.len() as f64;
1190        }
1191
1192        // Softmax-like probability
1193        let normalized_score = (algorithm.cv_score - min_score) / (max_score - min_score);
1194        let exp_score = (normalized_score * 5.0).exp();
1195        let total_exp: f64 = all_algorithms
1196            .iter()
1197            .map(|a| {
1198                let norm = (a.cv_score - min_score) / (max_score - min_score);
1199                (norm * 5.0).exp()
1200            })
1201            .sum();
1202
1203        exp_score / total_exp
1204    }
1205
1206    /// Get baseline score for comparison
1207    fn get_baseline_score(&self, _X: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
1208        match self.config.task_type {
1209            TaskType::Classification => {
1210                // Most frequent class accuracy
1211                let classes = self.get_unique_classes(y);
1212                let class_counts = self.calculate_class_distribution(y, &classes);
1213                Ok(class_counts.iter().fold(0.0, |acc, &x| acc.max(x)))
1214            }
1215            TaskType::Regression => {
1216                // R² of predicting mean
1217                let mean = y.mean().unwrap();
1218                let tss: f64 = y.iter().map(|&yi| (yi - mean).powi(2)).sum();
1219                let rss = tss; // Predicting mean gives R² = 0
1220                Ok(1.0 - rss / tss)
1221            }
1222        }
1223    }
1224
1225    /// Generate explanation for the selected algorithm
1226    fn generate_explanation(
1227        &self,
1228        best_algorithm: &RankedAlgorithm,
1229        dataset_chars: &DatasetCharacteristics,
1230    ) -> String {
1231        let mut explanation = format!(
1232            "{} ({}) was selected as the best algorithm with a cross-validation score of {:.4}.",
1233            best_algorithm.algorithm.name, best_algorithm.algorithm.family, best_algorithm.cv_score
1234        );
1235
1236        // Add reasoning based on dataset characteristics
1237        if dataset_chars.n_samples < 1000 {
1238            explanation.push_str(" This algorithm is well-suited for small datasets.");
1239        } else if dataset_chars.n_samples > 10000 {
1240            explanation.push_str(" This algorithm scales well to large datasets.");
1241        }
1242
1243        if dataset_chars.linearity_score > 0.7
1244            && best_algorithm.algorithm.family == AlgorithmFamily::Linear
1245        {
1246            explanation.push_str(
1247                " The linear nature of your data makes linear models particularly effective.",
1248            );
1249        }
1250
1251        if dataset_chars.n_features > dataset_chars.n_samples {
1252            explanation.push_str(" The high-dimensional nature of your data favors this algorithm's regularization capabilities.");
1253        }
1254
1255        if best_algorithm.algorithm.family == AlgorithmFamily::Ensemble {
1256            explanation.push_str(
1257                " Ensemble methods often provide robust performance across diverse datasets.",
1258            );
1259        }
1260
1261        explanation
1262    }
1263
1264    // Helper methods for dataset analysis
1265    fn calculate_missing_ratio(&self, X: &Array2<f64>) -> f64 {
1266        let total_values = X.len() as f64;
1267        let missing_count = X.iter().filter(|&&x| x.is_nan()).count() as f64;
1268        missing_count / total_values
1269    }
1270
1271    fn calculate_sparsity(&self, X: &Array2<f64>) -> f64 {
1272        let total_values = X.len() as f64;
1273        let zero_count = X.iter().filter(|&&x| x == 0.0).count() as f64;
1274        zero_count / total_values
1275    }
1276
1277    fn calculate_categorical_ratio(&self, X: &Array2<f64>) -> f64 {
1278        let n_features = X.ncols();
1279        if n_features == 0 {
1280            return 0.0;
1281        }
1282
1283        let mut categorical_count = 0;
1284        for col_idx in 0..n_features {
1285            let column = X.column(col_idx);
1286
1287            // Filter out NaN values
1288            let valid_values: Vec<f64> = column.iter().filter(|&&x| !x.is_nan()).copied().collect();
1289
1290            if valid_values.is_empty() {
1291                continue;
1292            }
1293
1294            // Count unique values
1295            let mut unique_values = valid_values.clone();
1296            unique_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
1297            unique_values.dedup();
1298
1299            let n_unique = unique_values.len();
1300            let n_total = valid_values.len();
1301
1302            // Heuristic: Consider categorical if:
1303            // 1. Has 10 or fewer unique values, OR
1304            // 2. Unique ratio < 5% and all values appear to be integers
1305            let unique_ratio = n_unique as f64 / n_total as f64;
1306            let all_integers = valid_values.iter().all(|&x| (x - x.round()).abs() < 1e-10);
1307
1308            if n_unique <= 10 || (unique_ratio < 0.05 && all_integers) {
1309                categorical_count += 1;
1310            }
1311        }
1312
1313        categorical_count as f64 / n_features as f64
1314    }
1315
    /// Condition number of the feature correlation matrix, used as a proxy
    /// for multicollinearity.
    ///
    /// Mock implementation: returns a uniform random value in [1, 100); a
    /// real version would compute the correlation matrix and its condition
    /// number.
    fn calculate_correlation_condition_number(&self, _X: &Array2<f64>) -> f64 {
        let mut rng = scirs2_core::random::thread_rng();
        rng.gen_range(1.0..100.0)
    }
1322
1323    fn get_unique_classes(&self, y: &Array1<f64>) -> Vec<i32> {
1324        let mut classes: Vec<i32> = y.iter().map(|&x| x as i32).collect();
1325        classes.sort_unstable();
1326        classes.dedup();
1327        classes
1328    }
1329
1330    fn calculate_class_distribution(&self, y: &Array1<f64>, classes: &[i32]) -> Vec<f64> {
1331        let total = y.len() as f64;
1332        classes
1333            .iter()
1334            .map(|&class| {
1335                let count = y.iter().filter(|&&yi| yi as i32 == class).count() as f64;
1336                count / total
1337            })
1338            .collect()
1339    }
1340
1341    fn calculate_target_statistics(&self, y: &Array1<f64>) -> TargetStatistics {
1342        let mean = y.mean().unwrap();
1343        let std = y.std(0.0);
1344
1345        // Mock calculations for advanced statistics
1346        TargetStatistics {
1347            mean,
1348            std,
1349            skewness: 0.0, // Would need actual skewness calculation
1350            kurtosis: 0.0, // Would need actual kurtosis calculation
1351            n_outliers: 0, // Would need outlier detection
1352        }
1353    }
1354
    /// How linear the feature-target relationship appears, in [0, 1).
    ///
    /// Mock implementation: returns a uniform random value; a real version
    /// would fit a linear model and compare it against a non-linear baseline.
    fn estimate_linearity_score(&self, _X: &Array2<f64>, _y: &Array1<f64>) -> f64 {
        let mut rng = scirs2_core::random::thread_rng();
        rng.gen_range(0.0..1.0)
    }
1361
    /// Estimated noise level of the target, in [0, 0.5).
    ///
    /// Mock implementation: returns a uniform random value; a real version
    /// would estimate residual variance from the data.
    fn estimate_noise_level(&self, _X: &Array2<f64>, _y: &Array1<f64>) -> f64 {
        let mut rng = scirs2_core::random::thread_rng();
        rng.gen_range(0.0..0.5)
    }
1368
    /// Intrinsic dimensionality estimate of the feature matrix.
    ///
    /// Mock implementation: assumes 80% of the features carry information; a
    /// real version would use PCA explained-variance analysis.
    fn estimate_effective_dimensionality(&self, X: &Array2<f64>) -> Option<usize> {
        Some((X.ncols() as f64 * 0.8) as usize)
    }
1373}
1374
1375/// Convenience function for quick algorithm selection
1376pub fn select_best_algorithm(
1377    X: &Array2<f64>,
1378    y: &Array1<f64>,
1379    task_type: TaskType,
1380) -> Result<AlgorithmSelectionResult> {
1381    let config = AutoMLConfig {
1382        task_type,
1383        ..Default::default()
1384    };
1385
1386    let selector = AutoMLAlgorithmSelector::new(config);
1387    selector.select_algorithms(X, y)
1388}
1389
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Build a small 100x4 dataset with three integer class labels (0, 1, 2).
    #[allow(non_snake_case)]
    fn create_test_classification_data() -> (Array2<f64>, Array1<f64>) {
        let X = Array2::from_shape_vec((100, 4), (0..400).map(|i| i as f64).collect()).unwrap();
        let y = Array1::from_vec((0..100).map(|i| (i % 3) as f64).collect());
        (X, y)
    }

    /// Build a small 100x4 dataset with a noisy linear regression target.
    #[allow(non_snake_case)]
    fn create_test_regression_data() -> (Array2<f64>, Array1<f64>) {
        let X = Array2::from_shape_vec((100, 4), (0..400).map(|i| i as f64).collect()).unwrap();
        use scirs2_core::essentials::Uniform;
        use scirs2_core::random::{thread_rng, Distribution};
        let mut rng = thread_rng();
        let dist = Uniform::new(0.0, 1.0).unwrap();
        let y = Array1::from_vec((0..100).map(|i| i as f64 + dist.sample(&mut rng)).collect());
        (X, y)
    }

    /// Default-config selection on a classification task succeeds and yields
    /// at least one ranked algorithm with a positive mock score.
    #[test]
    fn test_algorithm_selection_classification() {
        let (X, y) = create_test_classification_data();
        let result = select_best_algorithm(&X, &y, TaskType::Classification);
        assert!(result.is_ok());

        let result = result.unwrap();
        assert!(!result.selected_algorithms.is_empty());
        assert!(result.best_algorithm.cv_score > 0.0);
    }

    /// Default-config selection on a regression task succeeds likewise.
    #[test]
    fn test_algorithm_selection_regression() {
        let (X, y) = create_test_regression_data();
        let result = select_best_algorithm(&X, &y, TaskType::Regression);
        assert!(result.is_ok());

        let result = result.unwrap();
        assert!(!result.selected_algorithms.is_empty());
        assert!(result.best_algorithm.cv_score > 0.0);
    }

    /// Dataset analysis reports the expected shape and class count.
    #[test]
    fn test_dataset_characteristics_analysis() {
        let (X, y) = create_test_classification_data();
        let config = AutoMLConfig::default();
        let selector = AutoMLAlgorithmSelector::new(config);

        let chars = selector.analyze_dataset(&X, &y);
        assert_eq!(chars.n_samples, 100);
        assert_eq!(chars.n_features, 4);
        assert_eq!(chars.n_classes, Some(3));
    }

    /// Restricting `allowed_families` and `max_algorithms` must be respected
    /// by every returned candidate.
    #[test]
    fn test_custom_config() {
        let (X, y) = create_test_classification_data();

        let config = AutoMLConfig {
            task_type: TaskType::Classification,
            max_algorithms: 3,
            allowed_families: Some(vec![AlgorithmFamily::Linear, AlgorithmFamily::TreeBased]),
            ..Default::default()
        };

        let selector = AutoMLAlgorithmSelector::new(config);
        let result = selector.select_algorithms(&X, &y);
        assert!(result.is_ok());

        let result = result.unwrap();
        assert!(result.n_algorithms_evaluated <= 3);

        for alg in &result.selected_algorithms {
            assert!(matches!(
                alg.algorithm.family,
                AlgorithmFamily::Linear | AlgorithmFamily::TreeBased
            ));
        }
    }

    /// Even under tight time/memory budgets, at least one cheap algorithm
    /// should survive constraint filtering.
    #[test]
    fn test_computational_constraints() {
        let (X, y) = create_test_classification_data();

        let config = AutoMLConfig {
            task_type: TaskType::Classification,
            constraints: ComputationalConstraints {
                max_training_time: Some(1.0), // Very short time limit
                max_memory_gb: Some(0.1),     // Very low memory limit
                ..Default::default()
            },
            ..Default::default()
        };

        let selector = AutoMLAlgorithmSelector::new(config);
        let result = selector.select_algorithms(&X, &y);
        assert!(result.is_ok());

        let result = result.unwrap();
        // Should still find at least simple algorithms
        assert!(!result.selected_algorithms.is_empty());
    }
}