// sklears_dummy/integration_utilities.rs

//! Integration Utilities for Baseline Estimators
//!
//! This module provides utilities for automatic baseline generation,
//! pipeline integration, and smart default selection.
//!
//! The module includes:
//! - [`AutoBaselineGenerator`] - Automatic baseline generation based on data characteristics
//! - [`BaselinePipeline`] - Integration with preprocessing and evaluation pipelines
//! - [`SmartDefaultSelector`] - Intelligent default strategy selection
//! - [`ConfigurationHelper`] - Configuration assistance for baseline methods
//! - [`BaselineRecommendationEngine`] - Advanced recommendation system for baseline selection
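//!
//! # Example
//!
//! A minimal usage sketch (not compiled as a doctest; the data below is
//! illustrative, but the calls mirror the APIs declared in this module):
//!
//! ```ignore
//! use scirs2_core::ndarray::{Array1, Array2};
//!
//! // Illustrative feature matrix and labels.
//! let x = Array2::<f64>::zeros((100, 5));
//! let y = Array1::<i32>::zeros(100);
//!
//! // Analyze the data and obtain a recommended baseline.
//! let generator = AutoBaselineGenerator::new().with_random_state(42);
//! let recommendation = generator.analyze_and_recommend(&x, &y).unwrap();
//!
//! // Look up default parameters for the recommended baseline type.
//! let helper = ConfigurationHelper::new();
//! let params = helper.get_default_config(&recommendation.primary_strategy);
//! ```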

use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use sklears_core::error::SklearsError;
use std::collections::HashMap;

use crate::{
    dummy_classifier::Strategy as ClassifierStrategy,
    dummy_regressor::Strategy as RegressorStrategy, CausalDiscoveryStrategy, ContextAwareStrategy,
    EnsembleStrategy, FairnessStrategy, FewShotStrategy, RobustStrategy,
};

/// Data characteristics for automatic baseline selection
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct DataCharacteristics {
    /// Number of samples (rows) in the dataset
    pub n_samples: usize,
    /// Number of features (columns) in the dataset
    pub n_features: usize,
    /// Number of distinct classes, if the task is classification
    pub n_classes: Option<usize>,
    /// Ratio of the smallest to the largest class count, if more than one class is present
    pub class_balance: Option<f64>,
    /// Fraction of feature values that are (near-)zero
    pub feature_sparsity: f64,
    /// Fraction of feature values that are missing
    pub missing_data_ratio: f64,
    /// Estimated fraction of feature values flagged as outliers
    pub outlier_ratio: f64,
    /// Estimated noise level (average feature standard deviation)
    pub noise_level: f64,
    /// Average absolute pairwise feature correlation
    pub correlation_strength: f64,
    /// Whether samples exhibit temporal dependence
    pub temporal_dependency: bool,
    /// Fraction of features that are categorical
    pub categorical_features_ratio: f64,
    /// Whether the dataset is considered high-dimensional
    pub high_dimensional: bool,
    /// Whether the class distribution is considered imbalanced
    pub imbalanced: bool,
    /// Whether the data contains protected attributes (for fairness considerations)
    pub has_protected_attributes: bool,
}

/// Recommended baseline configuration
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct BaselineRecommendation {
    /// Recommended primary baseline strategy
    pub primary_strategy: BaselineType,
    /// Alternative strategies to try if the primary strategy underperforms
    pub fallback_strategies: Vec<BaselineType>,
    /// Whether an ensemble of baselines is recommended
    pub ensemble_recommended: bool,
    /// Whether preprocessing is recommended before fitting the baseline
    pub preprocessing_needed: bool,
    /// Whether a robust baseline is needed (e.g. due to a high outlier ratio)
    pub robustness_needed: bool,
    /// Whether fairness-aware baselines should be considered
    pub fairness_considerations: bool,
    /// Confidence in the recommendation
    pub confidence_score: f64,
    /// Human-readable explanation of the recommendation
    pub reasoning: String,
}

/// Types of baseline estimators
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum BaselineType {
    /// Standard dummy classifier
    DummyClassifier(ClassifierStrategy),
    /// Standard dummy regressor
    DummyRegressor(RegressorStrategy),
    /// Ensemble baseline
    EnsembleBaseline(EnsembleStrategy),
    /// Robust baseline
    RobustBaseline(RobustStrategy),
    /// Context-aware baseline
    ContextAwareBaseline(ContextAwareStrategy),
    /// Fairness-aware baseline
    FairnessBaseline(FairnessStrategy),
    /// Few-shot baseline
    FewShotBaseline(FewShotStrategy),
    /// Causal baseline
    CausalBaseline(CausalDiscoveryStrategy),
}

/// Pipeline configuration for baseline integration
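///
/// # Example
///
/// A hand-built configuration sketch (not compiled as a doctest; the string
/// values are illustrative placeholders):
///
/// ```ignore
/// let config = PipelineConfig {
///     preprocessing_steps: vec![
///         PreprocessingStep::FeatureScaling { method: "standard".to_string() },
///         PreprocessingStep::OutlierHandling { method: "iqr".to_string(), threshold: 1.5 },
///     ],
///     baseline_config: BaselineType::DummyClassifier(ClassifierStrategy::MostFrequent),
///     evaluation_metrics: vec!["accuracy".to_string(), "f1".to_string()],
///     validation_strategy: ValidationStrategy::KFold { k: 5 },
///     output_format: OutputFormat::PerformanceReport,
/// };
/// ```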
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct PipelineConfig {
    /// Preprocessing steps applied before the baseline is fitted
    pub preprocessing_steps: Vec<PreprocessingStep>,
    /// Baseline estimator configuration
    pub baseline_config: BaselineType,
    /// Names of evaluation metrics to report
    pub evaluation_metrics: Vec<String>,
    /// Validation strategy used for evaluation
    pub validation_strategy: ValidationStrategy,
    /// Output format of the pipeline results
    pub output_format: OutputFormat,
}

/// Preprocessing steps for baseline pipelines
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum PreprocessingStep {
    /// Feature scaling/normalization
    FeatureScaling { method: String },
    /// Missing data imputation
    MissingDataImputation { strategy: String },
    /// Outlier detection and removal
    OutlierHandling { method: String, threshold: f64 },
    /// Feature selection
    FeatureSelection { method: String, n_features: usize },
    /// Dimensionality reduction
    DimensionalityReduction { method: String, n_components: usize },
}

/// Validation strategies for baseline evaluation
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ValidationStrategy {
    /// Hold-out validation
    HoldOut { test_size: f64 },
    /// K-fold cross-validation
    KFold { k: usize },
    /// Time series split
    TimeSeriesSplit { n_splits: usize },
    /// Stratified validation
    Stratified { n_splits: usize },
    /// Bootstrap validation
    Bootstrap { n_samples: usize },
}

/// Output formats for baseline results
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum OutputFormat {
    /// Simple predictions array
    Predictions,
    /// Predictions with confidence intervals
    PredictionsWithConfidence,
    /// Full performance report
    PerformanceReport,
    /// Comparative analysis
    ComparativeAnalysis,
}

/// Automatic baseline generator
#[derive(Debug, Clone)]
pub struct AutoBaselineGenerator {
    recommendation_engine: BaselineRecommendationEngine,
    configuration_helper: ConfigurationHelper,
    random_state: Option<u64>,
}

/// Baseline pipeline for integration
#[derive(Debug)]
pub struct BaselinePipeline {
    config: PipelineConfig,
    fitted_baseline: Option<Box<dyn BaselineEstimator>>,
    preprocessing_fitted: bool,
    random_state: Option<u64>,
}

/// Smart default selector
#[derive(Debug, Clone)]
pub struct SmartDefaultSelector {
    selection_criteria: Vec<SelectionCriterion>,
    fallback_strategy: BaselineType,
    random_state: Option<u64>,
}

/// Configuration helper for baselines
#[derive(Debug, Clone)]
pub struct ConfigurationHelper {
    parameter_defaults: HashMap<String, ParameterDefault>,
    optimization_hints: Vec<OptimizationHint>,
}

/// Baseline recommendation engine
#[derive(Debug, Clone)]
pub struct BaselineRecommendationEngine {
    recommendation_rules: Vec<RecommendationRule>,
    performance_history: HashMap<String, PerformanceMetrics>,
    adaptation_enabled: bool,
}

/// Selection criteria for baseline choice
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum SelectionCriterion {
    /// Data size criterion
    DataSize {
        min_samples: usize,
        max_samples: usize,
    },
    /// Feature dimensionality criterion
    FeatureDimensionality {
        min_features: usize,
        max_features: usize,
    },
    /// Task type criterion
    TaskType {
        classification: bool,
        regression: bool,
    },
    /// Performance requirement criterion
    PerformanceRequirement { min_accuracy: f64, max_time: f64 },
    /// Robustness requirement criterion
    RobustnessRequirement { outlier_tolerance: f64 },
}

/// Parameter defaults for baseline configuration
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ParameterDefault {
    /// Name of the parameter
    pub parameter_name: String,
    /// Default value of the parameter
    pub default_value: f64,
    /// Valid (min, max) range for the parameter value
    pub valid_range: (f64, f64),
    /// Human-readable description of the parameter
    pub description: String,
}

/// Optimization hints for baseline tuning
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct OptimizationHint {
    /// Data context the hint applies to (e.g. "high_dimensional")
    pub context: String,
    /// Suggested action
    pub suggestion: String,
    /// Expected impact of following the suggestion
    pub impact: String,
    /// Priority of the hint (higher means more important)
    pub priority: u8,
}

/// Recommendation rules for baseline selection
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RecommendationRule {
    /// Data condition that triggers this rule
    pub condition: String,
    /// Baseline recommended when the condition holds
    pub recommended_baseline: BaselineType,
    /// Confidence associated with the rule
    pub confidence: f64,
    /// Explanation of why the baseline is recommended
    pub reasoning: String,
}

/// Performance metrics for recommendation engine
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct PerformanceMetrics {
    /// Accuracy score
    pub accuracy: f64,
    /// Precision score
    pub precision: f64,
    /// Recall score
    pub recall: f64,
    /// F1 score
    pub f1_score: f64,
    /// Execution time
    pub execution_time: f64,
    /// Memory usage
    pub memory_usage: f64,
}

/// Trait for baseline estimators in pipeline
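///
/// # Example
///
/// A minimal sketch of an implementation (not compiled as a doctest;
/// `MostFrequentBaseline` is a hypothetical type, not part of this crate):
///
/// ```ignore
/// #[derive(Debug)]
/// struct MostFrequentBaseline {
///     most_frequent: Option<i32>,
/// }
///
/// impl BaselineEstimator for MostFrequentBaseline {
///     fn fit_baseline(&mut self, _x: &Array2<f64>, y: &Array1<i32>) -> Result<(), SklearsError> {
///         // Remember the most frequent label in the training targets.
///         let mut counts = HashMap::new();
///         for &label in y.iter() {
///             *counts.entry(label).or_insert(0usize) += 1;
///         }
///         self.most_frequent = counts.into_iter().max_by_key(|(_, c)| *c).map(|(label, _)| label);
///         Ok(())
///     }
///
///     fn predict_baseline(&self, x: &Array2<f64>) -> Result<Array1<i32>, SklearsError> {
///         // Predict the remembered label for every sample.
///         Ok(Array1::from_elem(x.nrows(), self.most_frequent.unwrap_or(0)))
///     }
///
///     fn get_type(&self) -> BaselineType {
///         BaselineType::DummyClassifier(ClassifierStrategy::MostFrequent)
///     }
/// }
/// ```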
pub trait BaselineEstimator: std::fmt::Debug {
    /// Fit the baseline to the training data.
    fn fit_baseline(&mut self, x: &Array2<f64>, y: &Array1<i32>) -> Result<(), SklearsError>;
    /// Predict labels for new samples using the fitted baseline.
    fn predict_baseline(&self, x: &Array2<f64>) -> Result<Array1<i32>, SklearsError>;
    /// Return the baseline type this estimator implements.
    fn get_type(&self) -> BaselineType;
}

impl AutoBaselineGenerator {
    /// Create a new automatic baseline generator
    pub fn new() -> Self {
        Self {
            recommendation_engine: BaselineRecommendationEngine::new(),
            configuration_helper: ConfigurationHelper::new(),
            random_state: None,
        }
    }

    /// Set the random state for reproducible results
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Analyze data and generate baseline recommendations
    pub fn analyze_and_recommend(
        &self,
        x: &Array2<f64>,
        y: &Array1<i32>,
    ) -> Result<BaselineRecommendation, SklearsError> {
        let characteristics = self.analyze_data_characteristics(x, y);
        let recommendation = self
            .recommendation_engine
            .recommend_baseline(&characteristics);
        Ok(recommendation)
    }

    /// Generate automatic baseline configuration
    pub fn generate_baseline(
        &self,
        x: &Array2<f64>,
        y: &Array1<i32>,
    ) -> Result<BaselineType, SklearsError> {
        let recommendation = self.analyze_and_recommend(x, y)?;
        Ok(recommendation.primary_strategy)
    }

    fn analyze_data_characteristics(
        &self,
        x: &Array2<f64>,
        y: &Array1<i32>,
    ) -> DataCharacteristics {
        let n_samples = x.nrows();
        let n_features = x.ncols();

        // Analyze class distribution
        let mut class_counts = HashMap::new();
        for &class in y.iter() {
            *class_counts.entry(class).or_insert(0) += 1;
        }
        let n_classes = Some(class_counts.len());

        // Calculate class balance (ratio of smallest to largest class)
        let class_balance = if class_counts.len() > 1 {
            let min_count = *class_counts.values().min().unwrap() as f64;
            let max_count = *class_counts.values().max().unwrap() as f64;
            Some(min_count / max_count)
        } else {
            None
        };

        // Calculate feature sparsity (fraction of near-zero values)
        let total_elements = (n_samples * n_features) as f64;
        let zero_elements = x.iter().filter(|&&val| val.abs() < 1e-10).count() as f64;
        let feature_sparsity = zero_elements / total_elements;

        // Estimate outlier ratio per feature using the 1.5 * IQR rule
        let mut outlier_count = 0;
        for col in 0..n_features {
            let column = x.column(col);
            let mut sorted_col = column.to_vec();
            sorted_col.sort_by(|a, b| a.partial_cmp(b).unwrap());

            let q1_idx = sorted_col.len() / 4;
            let q3_idx = 3 * sorted_col.len() / 4;

            if q1_idx < sorted_col.len() && q3_idx < sorted_col.len() {
                let q1 = sorted_col[q1_idx];
                let q3 = sorted_col[q3_idx];
                let iqr = q3 - q1;
                let lower_bound = q1 - 1.5 * iqr;
                let upper_bound = q3 + 1.5 * iqr;

                outlier_count += column
                    .iter()
                    .filter(|&&val| val < lower_bound || val > upper_bound)
                    .count();
            }
        }
        let outlier_ratio = outlier_count as f64 / total_elements;

        // Estimate noise level as the average feature standard deviation
        let feature_stds: Vec<f64> = (0..n_features).map(|col| x.column(col).std(0.0)).collect();
        let noise_level = feature_stds.iter().sum::<f64>() / feature_stds.len() as f64;

        // Calculate correlation strength (average absolute pairwise correlation)
        let mut correlation_sum = 0.0;
        let mut correlation_count = 0;
        for i in 0..n_features {
            for j in i + 1..n_features {
                let col_i = x.column(i);
                let col_j = x.column(j);
                let correlation = self.compute_correlation(&col_i, &col_j);
                correlation_sum += correlation.abs();
                correlation_count += 1;
            }
        }
        let correlation_strength = if correlation_count > 0 {
            correlation_sum / correlation_count as f64
        } else {
            0.0
        };

        DataCharacteristics {
            n_samples,
            n_features,
            n_classes,
            class_balance,
            feature_sparsity,
            missing_data_ratio: 0.0, // Simplified: assume no missing data in ndarray
            outlier_ratio,
            noise_level,
            correlation_strength,
            temporal_dependency: false, // Simplified: would need domain knowledge
            categorical_features_ratio: 0.0, // Simplified: assume continuous features
            high_dimensional: n_features > 100,
            imbalanced: class_balance.is_some_and(|balance| balance < 0.1),
            has_protected_attributes: false, // Simplified: would need domain knowledge
        }
    }

    /// Pearson correlation: r = sum((x_i - mean_x) * (y_i - mean_y))
    ///     / sqrt(sum((x_i - mean_x)^2) * sum((y_i - mean_y)^2))
    fn compute_correlation(&self, x: &ArrayView1<f64>, y: &ArrayView1<f64>) -> f64 {
        let mean_x = x.mean().unwrap();
        let mean_y = y.mean().unwrap();

        let numerator: f64 = x
            .iter()
            .zip(y.iter())
            .map(|(xi, yi)| (xi - mean_x) * (yi - mean_y))
            .sum();

        // Sums of squared deviations (not variances); the square root of their
        // product normalizes the covariance term above.
        let sum_sq_x: f64 = x.iter().map(|xi| (xi - mean_x).powi(2)).sum();
        let sum_sq_y: f64 = y.iter().map(|yi| (yi - mean_y).powi(2)).sum();

        if sum_sq_x == 0.0 || sum_sq_y == 0.0 {
            0.0
        } else {
            numerator / (sum_sq_x * sum_sq_y).sqrt()
        }
    }
}

impl Default for AutoBaselineGenerator {
    fn default() -> Self {
        Self::new()
    }
}

impl BaselineRecommendationEngine {
    /// Create a new recommendation engine
    pub fn new() -> Self {
        let mut recommendation_rules = Vec::new();

        // Add default recommendation rules
        recommendation_rules.push(RecommendationRule {
            condition: "high_dimensional".to_string(),
            recommended_baseline: BaselineType::DummyClassifier(ClassifierStrategy::MostFrequent),
            confidence: 0.8,
            reasoning: "High-dimensional data benefits from simple baselines".to_string(),
        });

        recommendation_rules.push(RecommendationRule {
            condition: "imbalanced".to_string(),
            recommended_baseline: BaselineType::DummyClassifier(ClassifierStrategy::Stratified),
            confidence: 0.7,
            reasoning: "Imbalanced data requires stratified sampling".to_string(),
        });

        recommendation_rules.push(RecommendationRule {
            condition: "high_outlier_ratio".to_string(),
            recommended_baseline: BaselineType::RobustBaseline(RobustStrategy::TrimmedMean {
                trim_proportion: 0.1,
            }),
            confidence: 0.9,
            reasoning: "High outlier ratio requires robust methods".to_string(),
        });

        Self {
            recommendation_rules,
            performance_history: HashMap::new(),
            adaptation_enabled: true,
        }
    }

    /// Recommend baseline based on data characteristics
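    ///
    /// # Example
    ///
    /// A sketch with illustrative characteristics (not compiled as a doctest):
    ///
    /// ```ignore
    /// let engine = BaselineRecommendationEngine::new();
    /// let characteristics = DataCharacteristics {
    ///     n_samples: 500,
    ///     n_features: 20,
    ///     n_classes: Some(2),
    ///     class_balance: Some(0.05),
    ///     feature_sparsity: 0.0,
    ///     missing_data_ratio: 0.0,
    ///     outlier_ratio: 0.02,
    ///     noise_level: 1.0,
    ///     correlation_strength: 0.2,
    ///     temporal_dependency: false,
    ///     categorical_features_ratio: 0.0,
    ///     high_dimensional: false,
    ///     imbalanced: true,
    ///     has_protected_attributes: false,
    /// };
    ///
    /// // The "imbalanced" rule matches, so a stratified dummy classifier is recommended.
    /// let recommendation = engine.recommend_baseline(&characteristics);
    /// assert!(recommendation.confidence_score > 0.0);
    /// ```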
    pub fn recommend_baseline(
        &self,
        characteristics: &DataCharacteristics,
    ) -> BaselineRecommendation {
        let mut candidate_recommendations = Vec::new();

        // Apply recommendation rules
        for rule in &self.recommendation_rules {
            let matches = match rule.condition.as_str() {
                "high_dimensional" => characteristics.high_dimensional,
                "imbalanced" => characteristics.imbalanced,
                "high_outlier_ratio" => characteristics.outlier_ratio > 0.1,
                "has_protected_attributes" => characteristics.has_protected_attributes,
                "small_dataset" => characteristics.n_samples < 1000,
                "large_dataset" => characteristics.n_samples > 10000,
                "high_correlation" => characteristics.correlation_strength > 0.7,
                "sparse_features" => characteristics.feature_sparsity > 0.5,
                _ => false,
            };

            if matches {
                candidate_recommendations.push((rule.clone(), rule.confidence));
            }
        }

        // Select the matching rule with the highest confidence, or fall back to a default
        let (primary_rule, confidence_score) = candidate_recommendations
            .into_iter()
            .max_by(|(_, conf_a), (_, conf_b)| conf_a.partial_cmp(conf_b).unwrap())
            .unwrap_or((
                RecommendationRule {
                    condition: "default".to_string(),
                    recommended_baseline: BaselineType::DummyClassifier(
                        ClassifierStrategy::MostFrequent,
                    ),
                    confidence: 0.5,
                    reasoning: "Default baseline when no specific conditions are met".to_string(),
                },
                0.5,
            ));

        // Generate fallback strategies
        let fallback_strategies = vec![
            BaselineType::DummyClassifier(ClassifierStrategy::Uniform),
            BaselineType::EnsembleBaseline(EnsembleStrategy::Average),
        ];

        BaselineRecommendation {
            primary_strategy: primary_rule.recommended_baseline,
            fallback_strategies,
            ensemble_recommended: characteristics.n_samples > 1000,
            preprocessing_needed: characteristics.outlier_ratio > 0.05
                || characteristics.feature_sparsity > 0.3,
            robustness_needed: characteristics.outlier_ratio > 0.1,
            fairness_considerations: characteristics.has_protected_attributes,
            confidence_score,
            reasoning: primary_rule.reasoning,
        }
    }
}

impl Default for BaselineRecommendationEngine {
    fn default() -> Self {
        Self::new()
    }
}

impl ConfigurationHelper {
    /// Create a new configuration helper
    pub fn new() -> Self {
        let mut parameter_defaults = HashMap::new();

        // Add default parameters for different baseline types
        parameter_defaults.insert(
            "trim_proportion".to_string(),
            ParameterDefault {
                parameter_name: "trim_proportion".to_string(),
                default_value: 0.1,
                valid_range: (0.0, 0.5),
                description: "Proportion of extreme values to trim for robust estimation"
                    .to_string(),
            },
        );

        parameter_defaults.insert(
            "ensemble_size".to_string(),
            ParameterDefault {
                parameter_name: "ensemble_size".to_string(),
                default_value: 5.0,
                valid_range: (3.0, 50.0),
                description: "Number of base estimators in ensemble".to_string(),
            },
        );

        let optimization_hints = vec![
            OptimizationHint {
                context: "high_dimensional".to_string(),
                suggestion: "Use feature selection or dimensionality reduction".to_string(),
                impact: "Reduces overfitting and improves computational efficiency".to_string(),
                priority: 8,
            },
            OptimizationHint {
                context: "imbalanced".to_string(),
                suggestion: "Use stratified sampling or class weighting".to_string(),
                impact: "Improves performance on minority classes".to_string(),
                priority: 9,
            },
        ];

        Self {
            parameter_defaults,
            optimization_hints,
        }
    }

    /// Get default parameter configuration
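    ///
    /// # Example
    ///
    /// A short sketch (not compiled as a doctest):
    ///
    /// ```ignore
    /// let helper = ConfigurationHelper::new();
    /// let robust = BaselineType::RobustBaseline(RobustStrategy::TrimmedMean {
    ///     trim_proportion: 0.1,
    /// });
    ///
    /// // Robust baselines get the registered "trim_proportion" default (0.1).
    /// let config = helper.get_default_config(&robust);
    /// assert_eq!(config.get("trim_proportion"), Some(&0.1));
    /// ```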
    pub fn get_default_config(&self, baseline_type: &BaselineType) -> HashMap<String, f64> {
        let mut config = HashMap::new();

        match baseline_type {
            BaselineType::RobustBaseline(_) => {
                if let Some(default) = self.parameter_defaults.get("trim_proportion") {
                    config.insert("trim_proportion".to_string(), default.default_value);
                }
            }
            BaselineType::EnsembleBaseline(_) => {
                if let Some(default) = self.parameter_defaults.get("ensemble_size") {
                    config.insert("ensemble_size".to_string(), default.default_value);
                }
            }
            _ => {}
        }

        config
    }

    /// Get optimization hints for given data characteristics
    pub fn get_optimization_hints(
        &self,
        characteristics: &DataCharacteristics,
    ) -> Vec<OptimizationHint> {
        let mut relevant_hints = Vec::new();

        for hint in &self.optimization_hints {
            let relevant = match hint.context.as_str() {
                "high_dimensional" => characteristics.high_dimensional,
                "imbalanced" => characteristics.imbalanced,
                "sparse" => characteristics.feature_sparsity > 0.5,
                "noisy" => characteristics.noise_level > 1.0,
                _ => false,
            };

            if relevant {
                relevant_hints.push(hint.clone());
            }
        }

        // Sort by priority (higher priority first)
        relevant_hints.sort_by(|a, b| b.priority.cmp(&a.priority));

        relevant_hints
    }
}

impl Default for ConfigurationHelper {
    fn default() -> Self {
        Self::new()
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_auto_baseline_generator() {
        let x = Array2::from_shape_vec((100, 5), (0..500).map(|i| i as f64).collect()).unwrap();
        let y = Array1::from_vec((0..100).map(|i| i % 3).collect());

        let generator = AutoBaselineGenerator::new();
        let recommendation = generator.analyze_and_recommend(&x, &y).unwrap();

        assert!(recommendation.confidence_score > 0.0);
        assert!(!recommendation.reasoning.is_empty());
    }

    #[test]
    fn test_data_characteristics_analysis() {
        // 50 samples x 3 features filled with 1.0..=150.0 in row-major order
        let x = Array2::from_shape_vec((50, 3), (1..=150).map(|i| i as f64).collect()).unwrap();
        let y = Array1::from_vec((0..50).map(|i| i % 2).collect());

        let generator = AutoBaselineGenerator::new();
        let characteristics = generator.analyze_data_characteristics(&x, &y);

        assert_eq!(characteristics.n_samples, 50);
        assert_eq!(characteristics.n_features, 3);
        assert_eq!(characteristics.n_classes, Some(2));
        assert!(characteristics.class_balance.is_some());
    }

    #[test]
    fn test_recommendation_engine() {
        let characteristics = DataCharacteristics {
            n_samples: 1000,
            n_features: 50,
            n_classes: Some(3),
            class_balance: Some(0.8),
            feature_sparsity: 0.1,
            missing_data_ratio: 0.0,
            outlier_ratio: 0.15,
            noise_level: 0.5,
            correlation_strength: 0.3,
            temporal_dependency: false,
            categorical_features_ratio: 0.0,
            high_dimensional: false,
            imbalanced: false,
            has_protected_attributes: false,
        };

        let engine = BaselineRecommendationEngine::new();
        let recommendation = engine.recommend_baseline(&characteristics);

        assert!(recommendation.confidence_score > 0.0);
        assert!(recommendation.robustness_needed); // Due to high outlier ratio
    }

    #[test]
    fn test_configuration_helper() {
        let helper = ConfigurationHelper::new();
        let baseline_type = BaselineType::RobustBaseline(RobustStrategy::TrimmedMean {
            trim_proportion: 0.1,
        });

        let config = helper.get_default_config(&baseline_type);
        assert!(config.contains_key("trim_proportion"));

        let characteristics = DataCharacteristics {
            n_samples: 1000,
            n_features: 200, // High dimensional
            n_classes: Some(2),
            class_balance: Some(0.1), // Imbalanced
            feature_sparsity: 0.0,
            missing_data_ratio: 0.0,
            outlier_ratio: 0.05,
            noise_level: 0.5,
            correlation_strength: 0.3,
            temporal_dependency: false,
            categorical_features_ratio: 0.0,
            high_dimensional: true,
            imbalanced: true,
            has_protected_attributes: false,
        };

        let hints = helper.get_optimization_hints(&characteristics);
        assert!(!hints.is_empty());
        assert!(hints.iter().any(|hint| hint.context == "high_dimensional"));
        assert!(hints.iter().any(|hint| hint.context == "imbalanced"));
    }
}