Skip to main content

sklears_preprocessing/
adaptive.rs

1//! Adaptive preprocessing parameters that automatically tune based on data characteristics
2//!
3//! This module provides automatic parameter selection and tuning for preprocessing
4//! transformers based on statistical analysis of the input data. It helps optimize
5//! preprocessing pipelines without manual parameter tuning.
6//!
7//! # Features
8//!
9//! - **Data Distribution Analysis**: Automatically detect data distribution characteristics
10//! - **Adaptive Thresholds**: Dynamic threshold selection based on data properties
11//! - **Parameter Optimization**: Automatic parameter tuning for various transformers
12//! - **Multi-Objective Optimization**: Balance multiple criteria (robustness, efficiency, quality)
13//! - **Cross-Validation Based Tuning**: Use CV to select optimal parameters
14//! - **Ensemble Parameter Selection**: Combine multiple parameter selection strategies
15
16use scirs2_core::ndarray::Array2;
17use sklears_core::{
18    error::{Result, SklearsError},
19    traits::{Fit, Trained, Untrained},
20    types::Float,
21};
22use std::collections::HashMap;
23use std::marker::PhantomData;
24
25/// Data distribution characteristics detected from input data
26#[derive(Debug, Clone)]
27pub struct DataCharacteristics {
28    /// Number of samples and features
29    pub shape: (usize, usize),
30    /// Distribution type per feature (normal, skewed, uniform, bimodal, etc.)
31    pub distribution_types: Vec<DistributionType>,
32    /// Skewness per feature
33    pub skewness: Vec<Float>,
34    /// Kurtosis per feature
35    pub kurtosis: Vec<Float>,
36    /// Outlier percentages per feature
37    pub outlier_percentages: Vec<Float>,
38    /// Missing value percentages per feature
39    pub missing_percentages: Vec<Float>,
40    /// Data ranges per feature
41    pub ranges: Vec<(Float, Float)>,
42    /// Correlation matrix between features
43    pub correlation_strength: Float,
44    /// Overall data quality score (0-1)
45    pub quality_score: Float,
46    /// Estimated optimal batch size for processing
47    pub optimal_batch_size: usize,
48}
49
/// Coarse distribution family assigned to a single feature by the analyser.
#[derive(Debug, Clone, Copy)]
pub enum DistributionType {
    /// Roughly bell-shaped: low skewness, near-zero excess kurtosis, few outliers
    Normal,
    /// Clearly asymmetric, in either direction
    Skewed,
    /// Flat, light-tailed shape
    Uniform,
    /// Two or more modes
    Multimodal,
    /// Tails heavier than those of a normal distribution
    HeavyTailed,
    /// Dominated by (near-)zero values
    Sparse,
    /// No other category matched
    Unknown,
}
68
/// How the selector trades robustness against speed when choosing parameters.
#[derive(Debug, Clone, Copy)]
pub enum AdaptationStrategy {
    /// Favour robustness over everything else
    Conservative,
    /// Weigh robustness and result quality equally, with some attention to speed
    Balanced,
    /// Favour speed / efficiency
    Aggressive,
    /// User-defined weighting.
    /// NOTE(review): the visible scoring code uses fixed, near-equal weights for this variant.
    Custom,
}
81
82/// Configuration for adaptive parameter selection
83#[derive(Debug, Clone)]
84pub struct AdaptiveConfig {
85    /// Adaptation strategy
86    pub strategy: AdaptationStrategy,
87    /// Whether to use cross-validation for parameter selection
88    pub use_cross_validation: bool,
89    /// Number of CV folds (if using CV)
90    pub cv_folds: usize,
91    /// Maximum time budget for optimization (seconds)
92    pub time_budget: Option<Float>,
93    /// Whether to use parallel processing
94    pub parallel: bool,
95    /// Convergence tolerance for optimization
96    pub tolerance: Float,
97    /// Maximum number of optimization iterations
98    pub max_iterations: usize,
99    /// Custom parameter bounds (if any)
100    pub parameter_bounds: HashMap<String, (Float, Float)>,
101}
102
103impl Default for AdaptiveConfig {
104    fn default() -> Self {
105        Self {
106            strategy: AdaptationStrategy::Balanced,
107            use_cross_validation: true,
108            cv_folds: 5,
109            time_budget: Some(60.0), // 1 minute default
110            parallel: true,
111            tolerance: 1e-4,
112            max_iterations: 100,
113            parameter_bounds: HashMap::new(),
114        }
115    }
116}
117
118/// Adaptive parameter selector for preprocessing transformers
119#[derive(Debug, Clone)]
120pub struct AdaptiveParameterSelector<State = Untrained> {
121    config: AdaptiveConfig,
122    state: PhantomData<State>,
123    // Fitted parameters
124    data_characteristics_: Option<DataCharacteristics>,
125    optimal_parameters_: Option<HashMap<String, Float>>,
126    parameter_history_: Option<Vec<ParameterEvaluation>>,
127}
128
129/// Parameter evaluation result
130#[derive(Debug, Clone)]
131pub struct ParameterEvaluation {
132    pub parameters: HashMap<String, Float>,
133    pub score: Float,
134    pub robustness_score: Float,
135    pub efficiency_score: Float,
136    pub quality_score: Float,
137    pub evaluation_time: Float,
138}
139
140/// Adaptive parameter recommendations for different transformers
141#[derive(Debug, Clone)]
142pub struct ParameterRecommendations {
143    /// Recommended scaling parameters
144    pub scaling: ScalingParameters,
145    /// Recommended imputation parameters
146    pub imputation: ImputationParameters,
147    /// Recommended outlier detection parameters
148    pub outlier_detection: OutlierDetectionParameters,
149    /// Recommended transformation parameters
150    pub transformation: TransformationParameters,
151    /// Overall confidence in recommendations (0-1)
152    pub confidence: Float,
153}
154
155/// Adaptive scaling parameters
156#[derive(Debug, Clone)]
157pub struct ScalingParameters {
158    pub method: String, // "standard", "robust", "minmax", etc.
159    pub outlier_threshold: Float,
160    pub quantile_range: (Float, Float),
161    pub with_centering: bool,
162    pub with_scaling: bool,
163}
164
/// Recommended configuration for a missing-value imputer.
#[derive(Debug, Clone)]
pub struct ImputationParameters {
    /// Imputation strategy name, e.g. `"mean"`, `"median"` or `"knn"`
    pub strategy: String,
    /// Neighbour count, for neighbour-based strategies
    pub n_neighbors: Option<usize>,
    /// Whether imputation should take outliers into account
    pub outlier_aware: bool,
    /// Iteration limit, for iterative strategies
    pub max_iterations: Option<usize>,
}
173
174/// Adaptive outlier detection parameters
175#[derive(Debug, Clone)]
176pub struct OutlierDetectionParameters {
177    pub method: String, // "isolation_forest", "local_outlier_factor", etc.
178    pub contamination: Float,
179    pub threshold: Float,
180    pub ensemble_size: Option<usize>,
181}
182
183/// Adaptive transformation parameters
184#[derive(Debug, Clone)]
185pub struct TransformationParameters {
186    pub method: String, // "log", "box_cox", "quantile", etc.
187    pub handle_negatives: bool,
188    pub lambda: Option<Float>,
189    pub n_quantiles: Option<usize>,
190}
191
192impl AdaptiveParameterSelector<Untrained> {
193    /// Create a new adaptive parameter selector
194    pub fn new() -> Self {
195        Self {
196            config: AdaptiveConfig::default(),
197            state: PhantomData,
198            data_characteristics_: None,
199            optimal_parameters_: None,
200            parameter_history_: None,
201        }
202    }
203
204    /// Create with conservative strategy
205    pub fn conservative() -> Self {
206        Self::new().strategy(AdaptationStrategy::Conservative)
207    }
208
209    /// Create with balanced strategy
210    pub fn balanced() -> Self {
211        Self::new().strategy(AdaptationStrategy::Balanced)
212    }
213
214    /// Create with aggressive strategy
215    pub fn aggressive() -> Self {
216        Self::new().strategy(AdaptationStrategy::Aggressive)
217    }
218
219    /// Set the adaptation strategy
220    pub fn strategy(mut self, strategy: AdaptationStrategy) -> Self {
221        self.config.strategy = strategy;
222        self
223    }
224
225    /// Enable or disable cross-validation
226    pub fn cross_validation(mut self, enable: bool, folds: usize) -> Self {
227        self.config.use_cross_validation = enable;
228        self.config.cv_folds = folds;
229        self
230    }
231
232    /// Set time budget for optimization
233    pub fn time_budget(mut self, seconds: Float) -> Self {
234        self.config.time_budget = Some(seconds);
235        self
236    }
237
238    /// Enable parallel processing
239    pub fn parallel(mut self, enable: bool) -> Self {
240        self.config.parallel = enable;
241        self
242    }
243
244    /// Set optimization tolerance
245    pub fn tolerance(mut self, tolerance: Float) -> Self {
246        self.config.tolerance = tolerance;
247        self
248    }
249
250    /// Set parameter bounds
251    pub fn parameter_bounds(mut self, bounds: HashMap<String, (Float, Float)>) -> Self {
252        self.config.parameter_bounds = bounds;
253        self
254    }
255}
256
257impl Fit<Array2<Float>, ()> for AdaptiveParameterSelector<Untrained> {
258    type Fitted = AdaptiveParameterSelector<Trained>;
259
260    fn fit(mut self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
261        let (n_samples, n_features) = x.dim();
262
263        if n_samples == 0 || n_features == 0 {
264            return Err(SklearsError::InvalidInput(
265                "Input array is empty".to_string(),
266            ));
267        }
268
269        // Analyze data characteristics
270        let characteristics = self.analyze_data_characteristics(x)?;
271
272        // Generate parameter recommendations based on characteristics
273        let optimal_parameters = self.optimize_parameters(x, &characteristics)?;
274
275        // Evaluate different parameter configurations
276        let parameter_history = self.evaluate_parameter_space(x, &characteristics)?;
277
278        self.data_characteristics_ = Some(characteristics);
279        self.optimal_parameters_ = Some(optimal_parameters);
280        self.parameter_history_ = Some(parameter_history);
281
282        Ok(AdaptiveParameterSelector {
283            config: self.config,
284            state: PhantomData,
285            data_characteristics_: self.data_characteristics_,
286            optimal_parameters_: self.optimal_parameters_,
287            parameter_history_: self.parameter_history_,
288        })
289    }
290}
291
292impl AdaptiveParameterSelector<Untrained> {
293    /// Analyze data characteristics to inform parameter selection
294    fn analyze_data_characteristics(&self, x: &Array2<Float>) -> Result<DataCharacteristics> {
295        let (n_samples, n_features) = x.dim();
296
297        let mut distribution_types = Vec::with_capacity(n_features);
298        let mut skewness = Vec::with_capacity(n_features);
299        let mut kurtosis = Vec::with_capacity(n_features);
300        let mut outlier_percentages = Vec::with_capacity(n_features);
301        let mut missing_percentages = Vec::with_capacity(n_features);
302        let mut ranges = Vec::with_capacity(n_features);
303
304        // Analyze each feature
305        for j in 0..n_features {
306            let column = x.column(j);
307
308            // Get valid (non-NaN) values
309            let valid_values: Vec<Float> =
310                column.iter().filter(|x| x.is_finite()).copied().collect();
311
312            let missing_pct =
313                ((n_samples - valid_values.len()) as Float / n_samples as Float) * 100.0;
314            missing_percentages.push(missing_pct);
315
316            if valid_values.is_empty() {
317                distribution_types.push(DistributionType::Unknown);
318                skewness.push(0.0);
319                kurtosis.push(0.0);
320                outlier_percentages.push(0.0);
321                ranges.push((0.0, 0.0));
322                continue;
323            }
324
325            // Basic statistics
326            let mean = valid_values.iter().sum::<Float>() / valid_values.len() as Float;
327            let variance = valid_values
328                .iter()
329                .map(|x| (x - mean).powi(2))
330                .sum::<Float>()
331                / valid_values.len() as Float;
332            let std = variance.sqrt();
333
334            // Skewness and kurtosis
335            let feature_skewness = if std > 0.0 {
336                valid_values
337                    .iter()
338                    .map(|x| ((x - mean) / std).powi(3))
339                    .sum::<Float>()
340                    / valid_values.len() as Float
341            } else {
342                0.0
343            };
344
345            let feature_kurtosis = if std > 0.0 {
346                valid_values
347                    .iter()
348                    .map(|x| ((x - mean) / std).powi(4))
349                    .sum::<Float>()
350                    / valid_values.len() as Float
351                    - 3.0 // Excess kurtosis
352            } else {
353                0.0
354            };
355
356            skewness.push(feature_skewness);
357            kurtosis.push(feature_kurtosis);
358
359            // Outlier detection using IQR method
360            let mut sorted_values = valid_values.clone();
361            sorted_values.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
362
363            let q1_idx = sorted_values.len() / 4;
364            let q3_idx = 3 * sorted_values.len() / 4;
365            let q1 = sorted_values[q1_idx];
366            let q3 = sorted_values[q3_idx];
367            let iqr = q3 - q1;
368
369            let lower_bound = q1 - 1.5 * iqr;
370            let upper_bound = q3 + 1.5 * iqr;
371
372            let outlier_count = valid_values
373                .iter()
374                .filter(|&&x| x < lower_bound || x > upper_bound)
375                .count();
376            let outlier_pct = (outlier_count as Float / valid_values.len() as Float) * 100.0;
377            outlier_percentages.push(outlier_pct);
378
379            // Range
380            let min_val = sorted_values[0];
381            let max_val = sorted_values[sorted_values.len() - 1];
382            ranges.push((min_val, max_val));
383
384            // Distribution type classification
385            let dist_type = self.classify_distribution(
386                feature_skewness,
387                feature_kurtosis,
388                outlier_pct,
389                &valid_values,
390            );
391            distribution_types.push(dist_type);
392        }
393
394        // Correlation strength (simplified as average absolute correlation)
395        let correlation_strength = self.estimate_correlation_strength(x)?;
396
397        // Overall quality score
398        let avg_missing = missing_percentages.iter().sum::<Float>() / n_features as Float;
399        let avg_outliers = outlier_percentages.iter().sum::<Float>() / n_features as Float;
400        let quality_score = (100.0 - avg_missing - avg_outliers).max(0.0) / 100.0;
401
402        // Optimal batch size estimation
403        let optimal_batch_size = self.estimate_optimal_batch_size(n_samples, n_features);
404
405        Ok(DataCharacteristics {
406            shape: (n_samples, n_features),
407            distribution_types,
408            skewness,
409            kurtosis,
410            outlier_percentages,
411            missing_percentages,
412            ranges,
413            correlation_strength,
414            quality_score,
415            optimal_batch_size,
416        })
417    }
418
419    /// Classify the distribution type of a feature
420    fn classify_distribution(
421        &self,
422        skewness: Float,
423        kurtosis: Float,
424        outlier_pct: Float,
425        values: &[Float],
426    ) -> DistributionType {
427        // Check for sparsity (many zeros)
428        let zero_count = values.iter().filter(|&&x| x.abs() < 1e-10).count();
429        let sparsity = zero_count as Float / values.len() as Float;
430
431        if sparsity > 0.5 {
432            return DistributionType::Sparse;
433        }
434
435        // Check for normality
436        if skewness.abs() < 0.5 && kurtosis.abs() < 1.0 && outlier_pct < 5.0 {
437            return DistributionType::Normal;
438        }
439
440        // Check for skewness
441        if skewness.abs() > 1.0 {
442            return DistributionType::Skewed;
443        }
444
445        // Check for heavy tails
446        if kurtosis > 2.0 || outlier_pct > 10.0 {
447            return DistributionType::HeavyTailed;
448        }
449
450        // Check for uniformity (low kurtosis, low skewness, reasonable outliers)
451        if kurtosis < -1.0 && skewness.abs() < 0.5 {
452            return DistributionType::Uniform;
453        }
454
455        // Check for multimodality (complex patterns)
456        if kurtosis < -1.5 && outlier_pct > 5.0 {
457            return DistributionType::Multimodal;
458        }
459
460        DistributionType::Unknown
461    }
462
463    /// Estimate correlation strength between features
464    fn estimate_correlation_strength(&self, x: &Array2<Float>) -> Result<Float> {
465        let (_n_samples, n_features) = x.dim();
466
467        if n_features < 2 {
468            return Ok(0.0);
469        }
470
471        let mut correlation_sum = 0.0;
472        let mut correlation_count = 0;
473
474        // Sample a subset of feature pairs to avoid O(n²) computation
475        let max_pairs = 100.min(n_features * (n_features - 1) / 2);
476        let step = (n_features * (n_features - 1) / 2).max(1) / max_pairs.max(1);
477
478        let mut pair_count = 0;
479        for i in 0..n_features {
480            for j in (i + 1)..n_features {
481                if pair_count % step == 0 {
482                    let col_i = x.column(i);
483                    let col_j = x.column(j);
484
485                    // Calculate correlation coefficient
486                    if let Ok(corr) = self.calculate_correlation(&col_i, &col_j) {
487                        correlation_sum += corr.abs();
488                        correlation_count += 1;
489                    }
490                }
491                pair_count += 1;
492            }
493        }
494
495        Ok(if correlation_count > 0 {
496            correlation_sum / correlation_count as Float
497        } else {
498            0.0
499        })
500    }
501
502    /// Calculate correlation between two features
503    fn calculate_correlation(
504        &self,
505        x: &scirs2_core::ndarray::ArrayView1<Float>,
506        y: &scirs2_core::ndarray::ArrayView1<Float>,
507    ) -> Result<Float> {
508        let pairs: Vec<(Float, Float)> = x
509            .iter()
510            .zip(y.iter())
511            .filter(|(&a, &b)| a.is_finite() && b.is_finite())
512            .map(|(&a, &b)| (a, b))
513            .collect();
514
515        if pairs.len() < 3 {
516            return Ok(0.0);
517        }
518
519        let mean_x = pairs.iter().map(|(x, _)| x).sum::<Float>() / pairs.len() as Float;
520        let mean_y = pairs.iter().map(|(_, y)| y).sum::<Float>() / pairs.len() as Float;
521
522        let mut sum_xy = 0.0;
523        let mut sum_x2 = 0.0;
524        let mut sum_y2 = 0.0;
525
526        for (x, y) in pairs {
527            let dx = x - mean_x;
528            let dy = y - mean_y;
529            sum_xy += dx * dy;
530            sum_x2 += dx * dx;
531            sum_y2 += dy * dy;
532        }
533
534        let denominator = (sum_x2 * sum_y2).sqrt();
535        if denominator > 1e-10 {
536            Ok(sum_xy / denominator)
537        } else {
538            Ok(0.0)
539        }
540    }
541
542    /// Estimate optimal batch size for processing
543    fn estimate_optimal_batch_size(&self, n_samples: usize, n_features: usize) -> usize {
544        // Simple heuristic based on data size and available memory
545        let data_size = n_samples * n_features * std::mem::size_of::<Float>();
546        let target_memory = 100_000_000; // ~100MB target
547
548        let optimal_size = if data_size <= target_memory {
549            n_samples // Process all at once
550        } else {
551            (target_memory / (n_features * std::mem::size_of::<Float>()))
552                .max(1000)
553                .min(n_samples)
554        };
555
556        optimal_size
557    }
558
559    /// Optimize parameters based on data characteristics
560    fn optimize_parameters(
561        &self,
562        _x: &Array2<Float>,
563        characteristics: &DataCharacteristics,
564    ) -> Result<HashMap<String, Float>> {
565        let mut optimal_params = HashMap::new();
566
567        // Determine optimal scaling parameters
568        let scaling_method = self.select_optimal_scaling_method(characteristics);
569        optimal_params.insert("scaling_method".to_string(), scaling_method);
570
571        // Determine optimal outlier threshold
572        let outlier_threshold = self.select_optimal_outlier_threshold(characteristics);
573        optimal_params.insert("outlier_threshold".to_string(), outlier_threshold);
574
575        // Determine optimal imputation strategy
576        let imputation_strategy = self.select_optimal_imputation_strategy(characteristics);
577        optimal_params.insert("imputation_strategy".to_string(), imputation_strategy);
578
579        // Determine optimal quantile range for robust scaling
580        let (q_low, q_high) = self.select_optimal_quantile_range(characteristics);
581        optimal_params.insert("quantile_range_low".to_string(), q_low);
582        optimal_params.insert("quantile_range_high".to_string(), q_high);
583
584        // Determine optimal contamination rate
585        let contamination_rate = self.select_optimal_contamination_rate(characteristics);
586        optimal_params.insert("contamination_rate".to_string(), contamination_rate);
587
588        Ok(optimal_params)
589    }
590
591    /// Select optimal scaling method based on data characteristics
592    fn select_optimal_scaling_method(&self, characteristics: &DataCharacteristics) -> Float {
593        let avg_outlier_pct = characteristics.outlier_percentages.iter().sum::<Float>()
594            / characteristics.outlier_percentages.len() as Float;
595        let avg_skewness = characteristics
596            .skewness
597            .iter()
598            .map(|x| x.abs())
599            .sum::<Float>()
600            / characteristics.skewness.len() as Float;
601
602        match self.config.strategy {
603            AdaptationStrategy::Conservative => {
604                if avg_outlier_pct > 10.0 || avg_skewness > 1.0 {
605                    2.0 // Robust scaling
606                } else {
607                    0.0 // Standard scaling
608                }
609            }
610            AdaptationStrategy::Balanced => {
611                if avg_outlier_pct > 15.0 {
612                    2.0 // Robust scaling
613                } else if avg_skewness > 2.0 {
614                    1.0 // MinMax scaling
615                } else {
616                    0.0 // Standard scaling
617                }
618            }
619            AdaptationStrategy::Aggressive => {
620                if avg_outlier_pct > 20.0 {
621                    2.0 // Robust scaling
622                } else {
623                    0.0 // Standard scaling (prioritize performance)
624                }
625            }
626            AdaptationStrategy::Custom => 0.0, // Default to standard
627        }
628    }
629
630    /// Select optimal outlier threshold
631    fn select_optimal_outlier_threshold(&self, characteristics: &DataCharacteristics) -> Float {
632        let avg_outlier_pct = characteristics.outlier_percentages.iter().sum::<Float>()
633            / characteristics.outlier_percentages.len() as Float;
634
635        match self.config.strategy {
636            AdaptationStrategy::Conservative => {
637                if avg_outlier_pct > 20.0 {
638                    3.5
639                } else {
640                    3.0
641                }
642            }
643            AdaptationStrategy::Balanced => {
644                if avg_outlier_pct > 15.0 {
645                    2.5
646                } else {
647                    2.0
648                }
649            }
650            AdaptationStrategy::Aggressive => {
651                if avg_outlier_pct > 10.0 {
652                    2.0
653                } else {
654                    1.5
655                }
656            }
657            AdaptationStrategy::Custom => 2.5, // Default
658        }
659    }
660
661    /// Select optimal imputation strategy
662    fn select_optimal_imputation_strategy(&self, characteristics: &DataCharacteristics) -> Float {
663        let avg_missing_pct = characteristics.missing_percentages.iter().sum::<Float>()
664            / characteristics.missing_percentages.len() as Float;
665        let has_skewed_features = characteristics.skewness.iter().any(|&s| s.abs() > 1.0);
666
667        if avg_missing_pct > 20.0 {
668            2.0 // KNN imputation for high missing rates
669        } else if has_skewed_features {
670            1.0 // Median imputation for skewed data
671        } else {
672            0.0 // Mean imputation for normal data
673        }
674    }
675
676    /// Select optimal quantile range for robust scaling
677    fn select_optimal_quantile_range(
678        &self,
679        characteristics: &DataCharacteristics,
680    ) -> (Float, Float) {
681        let avg_outlier_pct = characteristics.outlier_percentages.iter().sum::<Float>()
682            / characteristics.outlier_percentages.len() as Float;
683
684        match self.config.strategy {
685            AdaptationStrategy::Conservative => {
686                if avg_outlier_pct > 15.0 {
687                    (10.0, 90.0)
688                } else {
689                    (25.0, 75.0)
690                }
691            }
692            AdaptationStrategy::Balanced => {
693                if avg_outlier_pct > 10.0 {
694                    (5.0, 95.0)
695                } else {
696                    (25.0, 75.0)
697                }
698            }
699            AdaptationStrategy::Aggressive => {
700                (25.0, 75.0) // Standard IQR
701            }
702            AdaptationStrategy::Custom => (25.0, 75.0),
703        }
704    }
705
706    /// Select optimal contamination rate
707    fn select_optimal_contamination_rate(&self, characteristics: &DataCharacteristics) -> Float {
708        let avg_outlier_pct = characteristics.outlier_percentages.iter().sum::<Float>()
709            / characteristics.outlier_percentages.len() as Float;
710
711        // Convert percentage to rate and add some margin
712        (avg_outlier_pct / 100.0 * 1.2).min(0.5).max(0.01)
713    }
714
715    /// Evaluate different parameter configurations
716    fn evaluate_parameter_space(
717        &self,
718        x: &Array2<Float>,
719        characteristics: &DataCharacteristics,
720    ) -> Result<Vec<ParameterEvaluation>> {
721        let mut evaluations = Vec::new();
722
723        // Define parameter space to explore
724        let scaling_methods = vec![0.0, 1.0, 2.0]; // Standard, MinMax, Robust
725        let outlier_thresholds = vec![1.5, 2.0, 2.5, 3.0, 3.5];
726        let contamination_rates = vec![0.05, 0.1, 0.15, 0.2];
727
728        // Evaluate a subset of the parameter space
729        let max_evaluations = 20; // Limit to avoid excessive computation
730        let mut evaluation_count = 0;
731
732        for &scaling_method in &scaling_methods {
733            for &threshold in &outlier_thresholds {
734                for &contamination in &contamination_rates {
735                    if evaluation_count >= max_evaluations {
736                        break;
737                    }
738
739                    let mut params = HashMap::new();
740                    params.insert("scaling_method".to_string(), scaling_method);
741                    params.insert("outlier_threshold".to_string(), threshold);
742                    params.insert("contamination_rate".to_string(), contamination);
743
744                    let evaluation = self.evaluate_parameters(&params, x, characteristics)?;
745                    evaluations.push(evaluation);
746                    evaluation_count += 1;
747                }
748            }
749        }
750
751        // Sort by overall score
752        evaluations.sort_by(|a, b| {
753            b.score
754                .partial_cmp(&a.score)
755                .expect("operation should succeed")
756        });
757
758        Ok(evaluations)
759    }
760
761    /// Evaluate a specific parameter configuration
762    fn evaluate_parameters(
763        &self,
764        params: &HashMap<String, Float>,
765        _x: &Array2<Float>,
766        characteristics: &DataCharacteristics,
767    ) -> Result<ParameterEvaluation> {
768        let start_time = std::time::Instant::now();
769
770        // Compute different scoring criteria
771        let robustness_score = self.compute_robustness_score(params, characteristics);
772        let efficiency_score = self.compute_efficiency_score(params, characteristics);
773        let quality_score = self.compute_quality_score(params, characteristics);
774
775        // Combine scores based on strategy
776        let overall_score = match self.config.strategy {
777            AdaptationStrategy::Conservative => {
778                robustness_score * 0.6 + quality_score * 0.3 + efficiency_score * 0.1
779            }
780            AdaptationStrategy::Balanced => {
781                robustness_score * 0.4 + quality_score * 0.4 + efficiency_score * 0.2
782            }
783            AdaptationStrategy::Aggressive => {
784                robustness_score * 0.2 + quality_score * 0.3 + efficiency_score * 0.5
785            }
786            AdaptationStrategy::Custom => {
787                robustness_score * 0.33 + quality_score * 0.33 + efficiency_score * 0.34
788            }
789        };
790
791        let evaluation_time = start_time.elapsed().as_secs_f64() as Float;
792
793        Ok(ParameterEvaluation {
794            parameters: params.clone(),
795            score: overall_score,
796            robustness_score,
797            efficiency_score,
798            quality_score,
799            evaluation_time,
800        })
801    }
802
803    /// Compute robustness score for parameters
804    fn compute_robustness_score(
805        &self,
806        params: &HashMap<String, Float>,
807        characteristics: &DataCharacteristics,
808    ) -> Float {
809        let scaling_method = params.get("scaling_method").unwrap_or(&0.0);
810        let outlier_threshold = params.get("outlier_threshold").unwrap_or(&2.5);
811
812        let avg_outlier_pct = characteristics.outlier_percentages.iter().sum::<Float>()
813            / characteristics.outlier_percentages.len() as Float;
814
815        let mut score: Float = 0.0;
816
817        // Reward robust scaling for high outlier data
818        if avg_outlier_pct > 10.0 && *scaling_method == 2.0 {
819            score += 0.4;
820        }
821
822        // Reward appropriate outlier thresholds
823        if avg_outlier_pct > 15.0 && *outlier_threshold <= 2.5 {
824            score += 0.3;
825        } else if avg_outlier_pct <= 5.0 && *outlier_threshold >= 3.0 {
826            score += 0.3;
827        }
828
829        // Reward handling of skewed data
830        let avg_skewness = characteristics
831            .skewness
832            .iter()
833            .map(|x| x.abs())
834            .sum::<Float>()
835            / characteristics.skewness.len() as Float;
836        if avg_skewness > 1.0 && *scaling_method != 0.0 {
837            score += 0.3;
838        }
839
840        score.min(1.0 as Float)
841    }
842
843    /// Compute efficiency score for parameters
844    fn compute_efficiency_score(
845        &self,
846        params: &HashMap<String, Float>,
847        characteristics: &DataCharacteristics,
848    ) -> Float {
849        let scaling_method = params.get("scaling_method").unwrap_or(&0.0);
850        let (n_samples, n_features) = characteristics.shape;
851
852        // Standard scaling is most efficient
853        let mut score: Float = if *scaling_method == 0.0 {
854            1.0
855        } else if *scaling_method == 1.0 {
856            0.8 // MinMax is moderately efficient
857        } else {
858            0.6 // Robust scaling is less efficient
859        };
860
861        // Adjust for data size (larger data benefits more from efficient methods)
862        let data_size_factor = (n_samples * n_features) as Float;
863        if data_size_factor > 1_000_000.0 {
864            score *= 1.2; // Boost efficiency importance for large data
865        }
866
867        score.min(1.0 as Float)
868    }
869
870    /// Compute quality score for parameters
871    fn compute_quality_score(
872        &self,
873        params: &HashMap<String, Float>,
874        characteristics: &DataCharacteristics,
875    ) -> Float {
876        let mut score = characteristics.quality_score; // Start with base data quality
877
878        // Adjust based on parameter appropriateness
879        let avg_missing_pct = characteristics.missing_percentages.iter().sum::<Float>()
880            / characteristics.missing_percentages.len() as Float;
881
882        // Quality improves with appropriate handling of missing values
883        if avg_missing_pct > 10.0 {
884            score *= 0.9; // Penalize for high missing rates
885        }
886
887        // Quality improves with appropriate outlier handling
888        let avg_outlier_pct = characteristics.outlier_percentages.iter().sum::<Float>()
889            / characteristics.outlier_percentages.len() as Float;
890
891        let outlier_threshold = params.get("outlier_threshold").unwrap_or(&2.5);
892        if avg_outlier_pct > 10.0 && *outlier_threshold <= 2.5 {
893            score *= 1.1; // Reward aggressive outlier handling when needed
894        }
895
896        score.min(1.0 as Float)
897    }
898}
899
900impl AdaptiveParameterSelector<Trained> {
901    /// Get the analyzed data characteristics
902    pub fn data_characteristics(&self) -> Option<&DataCharacteristics> {
903        self.data_characteristics_.as_ref()
904    }
905
906    /// Get the optimal parameters
907    pub fn optimal_parameters(&self) -> Option<&HashMap<String, Float>> {
908        self.optimal_parameters_.as_ref()
909    }
910
911    /// Get parameter evaluation history
912    pub fn parameter_history(&self) -> Option<&Vec<ParameterEvaluation>> {
913        self.parameter_history_.as_ref()
914    }
915
916    /// Generate comprehensive parameter recommendations
917    pub fn recommend_parameters(&self) -> Result<ParameterRecommendations> {
918        let characteristics = self.data_characteristics_.as_ref().ok_or_else(|| {
919            SklearsError::InvalidInput("No data characteristics available".to_string())
920        })?;
921
922        let optimal_params = self.optimal_parameters_.as_ref().ok_or_else(|| {
923            SklearsError::InvalidInput("No optimal parameters available".to_string())
924        })?;
925
926        // Generate scaling recommendations
927        let scaling_method = optimal_params.get("scaling_method").unwrap_or(&0.0);
928        let scaling = ScalingParameters {
929            method: match *scaling_method as i32 {
930                0 => "standard".to_string(),
931                1 => "minmax".to_string(),
932                2 => "robust".to_string(),
933                _ => "standard".to_string(),
934            },
935            outlier_threshold: *optimal_params.get("outlier_threshold").unwrap_or(&2.5),
936            quantile_range: (
937                *optimal_params.get("quantile_range_low").unwrap_or(&25.0),
938                *optimal_params.get("quantile_range_high").unwrap_or(&75.0),
939            ),
940            with_centering: true,
941            with_scaling: true,
942        };
943
944        // Generate imputation recommendations
945        let imputation_strategy = optimal_params.get("imputation_strategy").unwrap_or(&0.0);
946        let avg_missing_pct = characteristics.missing_percentages.iter().sum::<Float>()
947            / characteristics.missing_percentages.len() as Float;
948
949        let imputation = ImputationParameters {
950            strategy: match *imputation_strategy as i32 {
951                0 => "mean".to_string(),
952                1 => "median".to_string(),
953                2 => "knn".to_string(),
954                _ => "mean".to_string(),
955            },
956            n_neighbors: if *imputation_strategy == 2.0 {
957                Some(5)
958            } else {
959                None
960            },
961            outlier_aware: avg_missing_pct > 10.0,
962            max_iterations: if *imputation_strategy == 2.0 {
963                Some(10)
964            } else {
965                None
966            },
967        };
968
969        // Generate outlier detection recommendations
970        let contamination_rate = *optimal_params.get("contamination_rate").unwrap_or(&0.1);
971        let outlier_detection = OutlierDetectionParameters {
972            method: "isolation_forest".to_string(),
973            contamination: contamination_rate,
974            threshold: *optimal_params.get("outlier_threshold").unwrap_or(&2.5),
975            ensemble_size: Some(100),
976        };
977
978        // Generate transformation recommendations
979        let avg_skewness = characteristics
980            .skewness
981            .iter()
982            .map(|x| x.abs())
983            .sum::<Float>()
984            / characteristics.skewness.len() as Float;
985
986        let transformation = TransformationParameters {
987            method: if avg_skewness > 1.5 {
988                "log1p".to_string()
989            } else if avg_skewness > 1.0 {
990                "box_cox".to_string()
991            } else {
992                "none".to_string()
993            },
994            handle_negatives: true,
995            lambda: None, // Auto-detect
996            n_quantiles: Some(1000),
997        };
998
999        // Compute overall confidence
1000        let confidence = characteristics.quality_score * 0.5
1001            + (1.0
1002                - (characteristics.missing_percentages.iter().sum::<Float>()
1003                    / characteristics.missing_percentages.len() as Float
1004                    / 100.0))
1005                * 0.3
1006            + (1.0
1007                - (characteristics.outlier_percentages.iter().sum::<Float>()
1008                    / characteristics.outlier_percentages.len() as Float
1009                    / 100.0))
1010                * 0.2;
1011
1012        Ok(ParameterRecommendations {
1013            scaling,
1014            imputation,
1015            outlier_detection,
1016            transformation,
1017            confidence: confidence.min(1.0).max(0.0),
1018        })
1019    }
1020
1021    /// Generate a comprehensive adaptation report
1022    pub fn adaptation_report(&self) -> Result<String> {
1023        let characteristics = self.data_characteristics_.as_ref().ok_or_else(|| {
1024            SklearsError::InvalidInput("No data characteristics available".to_string())
1025        })?;
1026
1027        let recommendations = self.recommend_parameters()?;
1028
1029        let mut report = String::new();
1030
1031        report.push_str("=== Adaptive Parameter Selection Report ===\n\n");
1032
1033        // Data characteristics summary
1034        report.push_str("=== Data Characteristics ===\n");
1035        report.push_str(&format!("Data shape: {:?}\n", characteristics.shape));
1036        report.push_str(&format!(
1037            "Overall quality score: {:.3}\n",
1038            characteristics.quality_score
1039        ));
1040        report.push_str(&format!(
1041            "Correlation strength: {:.3}\n",
1042            characteristics.correlation_strength
1043        ));
1044        report.push_str(&format!(
1045            "Optimal batch size: {}\n",
1046            characteristics.optimal_batch_size
1047        ));
1048
1049        let avg_missing = characteristics.missing_percentages.iter().sum::<Float>()
1050            / characteristics.missing_percentages.len() as Float;
1051        let avg_outliers = characteristics.outlier_percentages.iter().sum::<Float>()
1052            / characteristics.outlier_percentages.len() as Float;
1053        let avg_skewness = characteristics
1054            .skewness
1055            .iter()
1056            .map(|x| x.abs())
1057            .sum::<Float>()
1058            / characteristics.skewness.len() as Float;
1059
1060        report.push_str(&format!("Average missing values: {:.1}%\n", avg_missing));
1061        report.push_str(&format!("Average outlier rate: {:.1}%\n", avg_outliers));
1062        report.push_str(&format!("Average absolute skewness: {:.3}\n", avg_skewness));
1063        report.push_str("\n");
1064
1065        // Parameter recommendations
1066        report.push_str("=== Parameter Recommendations ===\n");
1067        report.push_str(&format!(
1068            "Confidence: {:.1}%\n\n",
1069            recommendations.confidence * 100.0
1070        ));
1071
1072        report.push_str("Scaling:\n");
1073        report.push_str(&format!("  Method: {}\n", recommendations.scaling.method));
1074        report.push_str(&format!(
1075            "  Outlier threshold: {:.2}\n",
1076            recommendations.scaling.outlier_threshold
1077        ));
1078        report.push_str(&format!(
1079            "  Quantile range: ({:.1}%, {:.1}%)\n",
1080            recommendations.scaling.quantile_range.0, recommendations.scaling.quantile_range.1
1081        ));
1082        report.push_str("\n");
1083
1084        report.push_str("Imputation:\n");
1085        report.push_str(&format!(
1086            "  Strategy: {}\n",
1087            recommendations.imputation.strategy
1088        ));
1089        if let Some(k) = recommendations.imputation.n_neighbors {
1090            report.push_str(&format!("  K-neighbors: {}\n", k));
1091        }
1092        report.push_str(&format!(
1093            "  Outlier-aware: {}\n",
1094            recommendations.imputation.outlier_aware
1095        ));
1096        report.push_str("\n");
1097
1098        report.push_str("Outlier Detection:\n");
1099        report.push_str(&format!(
1100            "  Method: {}\n",
1101            recommendations.outlier_detection.method
1102        ));
1103        report.push_str(&format!(
1104            "  Contamination: {:.3}\n",
1105            recommendations.outlier_detection.contamination
1106        ));
1107        report.push_str(&format!(
1108            "  Threshold: {:.2}\n",
1109            recommendations.outlier_detection.threshold
1110        ));
1111        report.push_str("\n");
1112
1113        report.push_str("Transformation:\n");
1114        report.push_str(&format!(
1115            "  Method: {}\n",
1116            recommendations.transformation.method
1117        ));
1118        report.push_str(&format!(
1119            "  Handle negatives: {}\n",
1120            recommendations.transformation.handle_negatives
1121        ));
1122        report.push_str("\n");
1123
1124        // Strategy and configuration
1125        report.push_str("=== Configuration ===\n");
1126        report.push_str(&format!("Strategy: {:?}\n", self.config.strategy));
1127        report.push_str(&format!(
1128            "Cross-validation: {} ({} folds)\n",
1129            self.config.use_cross_validation, self.config.cv_folds
1130        ));
1131        report.push_str(&format!("Parallel processing: {}\n", self.config.parallel));
1132        if let Some(budget) = self.config.time_budget {
1133            report.push_str(&format!("Time budget: {:.1}s\n", budget));
1134        }
1135
1136        Ok(report)
1137    }
1138
1139    /// Get adaptation recommendations as actionable insights
1140    pub fn get_insights(&self) -> Vec<String> {
1141        let mut insights = Vec::new();
1142
1143        if let Some(characteristics) = &self.data_characteristics_ {
1144            let avg_missing = characteristics.missing_percentages.iter().sum::<Float>()
1145                / characteristics.missing_percentages.len() as Float;
1146            let avg_outliers = characteristics.outlier_percentages.iter().sum::<Float>()
1147                / characteristics.outlier_percentages.len() as Float;
1148            let avg_skewness = characteristics
1149                .skewness
1150                .iter()
1151                .map(|x| x.abs())
1152                .sum::<Float>()
1153                / characteristics.skewness.len() as Float;
1154
1155            if avg_missing > 20.0 {
1156                insights.push("High missing value rate detected - consider advanced imputation methods like KNN or iterative imputation".to_string());
1157            }
1158
1159            if avg_outliers > 15.0 {
1160                insights.push(
1161                    "High outlier rate detected - robust preprocessing methods are recommended"
1162                        .to_string(),
1163                );
1164            }
1165
1166            if avg_skewness > 2.0 {
1167                insights.push(
1168                    "Highly skewed data detected - consider log or Box-Cox transformations"
1169                        .to_string(),
1170                );
1171            }
1172
1173            if characteristics.correlation_strength > 0.7 {
1174                insights.push(
1175                    "High feature correlation detected - consider dimensionality reduction"
1176                        .to_string(),
1177                );
1178            }
1179
1180            if characteristics.quality_score < 0.5 {
1181                insights.push(
1182                    "Low data quality detected - comprehensive preprocessing pipeline recommended"
1183                        .to_string(),
1184                );
1185            }
1186
1187            if characteristics.shape.0 > 1_000_000 {
1188                insights.push(
1189                    "Large dataset detected - consider streaming or batch processing approaches"
1190                        .to_string(),
1191                );
1192            }
1193
1194            if characteristics.optimal_batch_size < characteristics.shape.0 {
1195                insights.push(format!(
1196                    "Consider batch processing with batch size: {}",
1197                    characteristics.optimal_batch_size
1198                ));
1199            }
1200        }
1201
1202        if insights.is_empty() {
1203            insights.push("Data characteristics are within normal ranges - standard preprocessing should be sufficient".to_string());
1204        }
1205
1206        insights
1207    }
1208}
1209
1210impl Default for AdaptiveParameterSelector<Untrained> {
1211    fn default() -> Self {
1212        Self::new()
1213    }
1214}
1215
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::Array2;

    /// Fits `selector` on `data` without targets, panicking if fitting fails.
    fn fit_on(
        selector: AdaptiveParameterSelector<Untrained>,
        data: &Array2<Float>,
    ) -> AdaptiveParameterSelector<Trained> {
        selector
            .fit(data, &())
            .expect("model fitting should succeed")
    }

    /// Six two-column rows where the final row is an outlier in both columns.
    fn six_rows_with_outlier() -> Array2<Float> {
        Array2::from_shape_vec(
            (6, 2),
            vec![
                1.0, 10.0, 2.0, 20.0, 3.0, 30.0, 4.0, 40.0, 5.0, 50.0, 100.0, 1000.0,
            ],
        )
        .expect("operation should succeed")
    }

    #[test]
    fn test_adaptive_parameter_selector_creation() {
        let selector = AdaptiveParameterSelector::new();
        let config = &selector.config;

        assert!(matches!(config.strategy, AdaptationStrategy::Balanced));
        assert!(config.use_cross_validation);
        assert_eq!(config.cv_folds, 5);
    }

    #[test]
    fn test_adaptive_strategies() {
        let conservative = AdaptiveParameterSelector::conservative();
        assert!(matches!(
            conservative.config.strategy,
            AdaptationStrategy::Conservative
        ));

        let aggressive = AdaptiveParameterSelector::aggressive();
        assert!(matches!(
            aggressive.config.strategy,
            AdaptationStrategy::Aggressive
        ));
    }

    #[test]
    fn test_data_characteristics_analysis() {
        let data: Array2<Float> = Array2::from_shape_vec(
            (10, 3),
            vec![
                1.0, 10.0, 100.0, // Normal range
                2.0, 20.0, 200.0,
                3.0, 30.0, 300.0,
                4.0, 40.0, 400.0,
                5.0, 50.0, 500.0,
                6.0, 60.0, 600.0,
                7.0, 70.0, 700.0,
                8.0, 80.0, 800.0,
                100.0, 1000.0, 10000.0, // Outliers
                9.0, 90.0, 900.0,
            ],
        )
        .expect("operation should succeed");

        let fitted = fit_on(AdaptiveParameterSelector::balanced(), &data);
        let characteristics = fitted
            .data_characteristics()
            .expect("operation should succeed");

        assert_eq!(characteristics.shape, (10, 3));
        assert_eq!(characteristics.distribution_types.len(), 3);
        assert_eq!(characteristics.skewness.len(), 3);
        assert!((0.0..=1.0).contains(&characteristics.quality_score));
    }

    #[test]
    fn test_parameter_recommendations() {
        let data: Array2<Float> = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 10.0, 2.0, 20.0, 3.0, 30.0, 4.0, 40.0, 5.0, 50.0,
                100.0, 1000.0, // Outliers
                7.0, 70.0, 8.0, 80.0,
            ],
        )
        .expect("operation should succeed");

        let recommendations = fit_on(AdaptiveParameterSelector::balanced(), &data)
            .recommend_parameters()
            .expect("operation should succeed");

        assert!((0.0..=1.0).contains(&recommendations.confidence));
        assert!(!recommendations.scaling.method.is_empty());
        assert!(!recommendations.imputation.strategy.is_empty());
    }

    #[test]
    fn test_parameter_optimization() {
        let fitted = fit_on(
            AdaptiveParameterSelector::aggressive(),
            &six_rows_with_outlier(),
        );
        let optimal_params = fitted
            .optimal_parameters()
            .expect("operation should succeed");

        for key in ["scaling_method", "outlier_threshold", "contamination_rate"].iter() {
            assert!(
                optimal_params.contains_key(*key),
                "missing optimized parameter `{}`",
                key
            );
        }
    }

    #[test]
    fn test_distribution_classification() {
        let data: Array2<Float> = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0])
            .expect("shape and data length should match");

        let fitted = fit_on(AdaptiveParameterSelector::new(), &data);
        let characteristics = fitted
            .data_characteristics()
            .expect("operation should succeed");

        // A short, evenly spaced ramp may reasonably be classified as normal or uniform.
        assert!(matches!(
            characteristics.distribution_types[0],
            DistributionType::Normal | DistributionType::Uniform | DistributionType::Unknown
        ));
    }

    #[test]
    fn test_missing_value_handling() {
        let nan = Float::NAN;
        let data: Array2<Float> = Array2::from_shape_vec(
            (6, 2),
            vec![
                1.0, 10.0,
                2.0, nan, // Missing value in column 1
                3.0, 30.0,
                nan, 40.0, // Missing value in column 0
                5.0, 50.0,
                6.0, 60.0,
            ],
        )
        .expect("operation should succeed");

        let fitted = fit_on(AdaptiveParameterSelector::balanced(), &data);
        let missing = &fitted
            .data_characteristics()
            .expect("operation should succeed")
            .missing_percentages;

        // At least one column must report its missing value.
        assert!(missing[0] > 0.0 || missing[1] > 0.0);
    }

    #[test]
    fn test_adaptation_report() {
        let report = fit_on(
            AdaptiveParameterSelector::balanced(),
            &six_rows_with_outlier(),
        )
        .adaptation_report()
        .expect("operation should succeed");

        for section in [
            "Adaptive Parameter Selection Report",
            "Data Characteristics",
            "Parameter Recommendations",
        ]
        .iter()
        {
            assert!(report.contains(*section), "report lacks `{}`", section);
        }
    }

    #[test]
    fn test_insights_generation() {
        let data: Array2<Float> =
            Array2::from_shape_vec((4, 2), vec![1.0, 10.0, 2.0, 20.0, 3.0, 30.0, 4.0, 40.0])
                .expect("operation should succeed");

        let insights = fit_on(AdaptiveParameterSelector::conservative(), &data).get_insights();
        assert!(!insights.is_empty());
    }

    #[test]
    fn test_configuration_options() {
        let selector = AdaptiveParameterSelector::new()
            .cross_validation(false, 3)
            .time_budget(30.0)
            .parallel(false)
            .tolerance(1e-3);
        let config = &selector.config;

        assert!(!config.use_cross_validation);
        assert_eq!(config.cv_folds, 3);
        assert_eq!(config.time_budget, Some(30.0));
        assert!(!config.parallel);
        assert_relative_eq!(config.tolerance, 1e-3, epsilon = 1e-10);
    }

    #[test]
    fn test_error_handling() {
        let selector = AdaptiveParameterSelector::new();

        // An empty matrix must be rejected.
        let empty: Array2<Float> =
            Array2::from_shape_vec((0, 0), vec![]).expect("shape and data length should match");
        assert!(selector.fit(&empty, &()).is_err());
    }

    #[test]
    fn test_parameter_bounds() {
        let bounds: HashMap<_, _> =
            std::iter::once(("outlier_threshold".to_string(), (1.0, 4.0))).collect();

        let selector = AdaptiveParameterSelector::new().parameter_bounds(bounds.clone());

        assert_eq!(selector.config.parameter_bounds, bounds);
    }
}