1use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
14#[cfg(feature = "serde")]
15use serde::{Deserialize, Serialize};
16use sklears_core::error::SklearsError;
17use std::collections::HashMap;
18
19use crate::{
20 dummy_classifier::Strategy as ClassifierStrategy,
21 dummy_regressor::Strategy as RegressorStrategy, CausalDiscoveryStrategy, ContextAwareStrategy,
22 EnsembleStrategy, FairnessStrategy, FewShotStrategy, RobustStrategy,
23};
24
/// Summary statistics of a dataset, used to drive automatic baseline selection.
///
/// Produced by `AutoBaselineGenerator::analyze_data_characteristics`; fields
/// that analysis cannot derive from the raw arrays (missing data, temporal
/// dependency, categorical ratio, protected attributes) are left at neutral
/// defaults there.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct DataCharacteristics {
    /// Number of samples (rows of the feature matrix).
    pub n_samples: usize,
    /// Number of features (columns of the feature matrix).
    pub n_features: usize,
    /// Number of distinct class labels, when known.
    pub n_classes: Option<usize>,
    /// Ratio of the smallest to the largest class count; `None` with fewer
    /// than two classes.
    pub class_balance: Option<f64>,
    /// Fraction of near-zero entries in the feature matrix.
    pub feature_sparsity: f64,
    /// Fraction of missing values (0.0 when not computed).
    pub missing_data_ratio: f64,
    /// Fraction of entries flagged as outliers (IQR fence rule).
    pub outlier_ratio: f64,
    /// Mean per-feature standard deviation, used as a noise proxy.
    pub noise_level: f64,
    /// Mean absolute pairwise Pearson correlation between features.
    pub correlation_strength: f64,
    /// Whether the data carries temporal ordering dependencies.
    pub temporal_dependency: bool,
    /// Fraction of features that are categorical.
    pub categorical_features_ratio: f64,
    /// Whether the feature count is considered high (analyzer uses > 100).
    pub high_dimensional: bool,
    /// Whether the class distribution is considered imbalanced
    /// (analyzer uses class_balance < 0.1).
    pub imbalanced: bool,
    /// Whether protected/sensitive attributes are present (fairness handling).
    pub has_protected_attributes: bool,
}
58
/// Outcome of analyzing a dataset: a recommended baseline plus advisory flags.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct BaselineRecommendation {
    /// Highest-confidence baseline to try first.
    pub primary_strategy: BaselineType,
    /// Alternative baselines to fall back on.
    pub fallback_strategies: Vec<BaselineType>,
    /// Whether an ensemble baseline is advisable for this dataset.
    pub ensemble_recommended: bool,
    /// Whether preprocessing (e.g. outlier/sparsity handling) is advisable.
    pub preprocessing_needed: bool,
    /// Whether robust estimation is advisable (high outlier ratio).
    pub robustness_needed: bool,
    /// Whether fairness-aware handling is advisable (protected attributes).
    pub fairness_considerations: bool,
    /// Confidence of the matched recommendation rule, in [0, 1].
    pub confidence_score: f64,
    /// Human-readable explanation taken from the matched rule.
    pub reasoning: String,
}
80
/// The family of baseline estimator to use, each wrapping the strategy enum
/// of the corresponding module.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum BaselineType {
    /// Dummy classifier with the given prediction strategy.
    DummyClassifier(ClassifierStrategy),
    /// Dummy regressor with the given prediction strategy.
    DummyRegressor(RegressorStrategy),
    /// Ensemble of baselines combined with the given strategy.
    EnsembleBaseline(EnsembleStrategy),
    /// Outlier-robust baseline (e.g. trimmed mean).
    RobustBaseline(RobustStrategy),
    /// Context-aware baseline.
    ContextAwareBaseline(ContextAwareStrategy),
    /// Fairness-aware baseline.
    FairnessBaseline(FairnessStrategy),
    /// Baseline for few-shot / small-sample settings.
    FewShotBaseline(FewShotStrategy),
    /// Baseline driven by causal-discovery assumptions.
    CausalBaseline(CausalDiscoveryStrategy),
}
102
/// Full configuration for a baseline pipeline: preprocessing, the baseline
/// itself, and how it is evaluated and reported.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct PipelineConfig {
    /// Preprocessing steps applied before fitting, in order.
    pub preprocessing_steps: Vec<PreprocessingStep>,
    /// Which baseline estimator to fit.
    pub baseline_config: BaselineType,
    /// Names of evaluation metrics to compute.
    pub evaluation_metrics: Vec<String>,
    /// How to split data for validation.
    pub validation_strategy: ValidationStrategy,
    /// What kind of output the pipeline should produce.
    pub output_format: OutputFormat,
}
118
/// A single preprocessing step in a baseline pipeline. Methods/strategies are
/// identified by name strings interpreted downstream.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum PreprocessingStep {
    /// Scale features with the named method.
    FeatureScaling { method: String },
    /// Impute missing values with the named strategy.
    MissingDataImputation { strategy: String },
    /// Detect/handle outliers with the named method at the given threshold.
    OutlierHandling { method: String, threshold: f64 },
    /// Keep `n_features` features chosen by the named method.
    FeatureSelection { method: String, n_features: usize },
    /// Project to `n_components` dimensions with the named method.
    DimensionalityReduction { method: String, n_components: usize },
}
134
/// How data is split for validating a baseline.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ValidationStrategy {
    /// Single train/test split; `test_size` is the held-out fraction.
    HoldOut { test_size: f64 },
    /// k-fold cross-validation.
    KFold { k: usize },
    /// Ordered splits that respect temporal structure.
    TimeSeriesSplit { n_splits: usize },
    /// Class-stratified cross-validation.
    Stratified { n_splits: usize },
    /// Bootstrap resampling with `n_samples` draws.
    Bootstrap { n_samples: usize },
}
150
/// What a baseline pipeline reports after running.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum OutputFormat {
    /// Raw predictions only.
    Predictions,
    /// Predictions plus per-prediction confidence.
    PredictionsWithConfidence,
    /// Aggregate performance report.
    PerformanceReport,
    /// Comparison across multiple baselines.
    ComparativeAnalysis,
}
164
/// Analyzes a dataset and recommends a suitable baseline estimator.
///
/// Combines a rule-based recommendation engine with a configuration helper;
/// the optional random seed is stored for reproducibility.
#[derive(Debug, Clone)]
pub struct AutoBaselineGenerator {
    // Rule-based engine that maps data characteristics to baselines.
    recommendation_engine: BaselineRecommendationEngine,
    // Supplies parameter defaults and optimization hints.
    configuration_helper: ConfigurationHelper,
    // Optional seed; set via `with_random_state`.
    random_state: Option<u64>,
}
172
/// A configured baseline pipeline holding the (optionally fitted) estimator.
///
/// NOTE(review): no methods for this type are visible in this chunk —
/// presumably implemented elsewhere in the file/crate.
#[derive(Debug)]
pub struct BaselinePipeline {
    // Static configuration of the pipeline.
    config: PipelineConfig,
    // Trait object for the fitted baseline; `None` before fitting.
    fitted_baseline: Option<Box<dyn BaselineEstimator>>,
    // Whether preprocessing has been fitted.
    preprocessing_fitted: bool,
    // Optional seed for reproducibility.
    random_state: Option<u64>,
}
181
/// Selects a baseline from a list of criteria, with a fixed fallback.
///
/// NOTE(review): no methods for this type are visible in this chunk.
#[derive(Debug, Clone)]
pub struct SmartDefaultSelector {
    // Ordered criteria used to pick a baseline.
    selection_criteria: Vec<SelectionCriterion>,
    // Baseline used when no criterion applies.
    fallback_strategy: BaselineType,
    // Optional seed for reproducibility.
    random_state: Option<u64>,
}
189
/// Supplies default parameter values per baseline type and context-specific
/// optimization hints.
#[derive(Debug, Clone)]
pub struct ConfigurationHelper {
    // Defaults keyed by parameter name (e.g. "trim_proportion").
    parameter_defaults: HashMap<String, ParameterDefault>,
    // Hints filtered/sorted by `get_optimization_hints`.
    optimization_hints: Vec<OptimizationHint>,
}
196
/// Rule-based engine mapping `DataCharacteristics` to a baseline
/// recommendation.
#[derive(Debug, Clone)]
pub struct BaselineRecommendationEngine {
    // Condition→baseline rules checked by `recommend_baseline`.
    recommendation_rules: Vec<RecommendationRule>,
    // Past performance keyed by an identifier; not consulted in this chunk.
    performance_history: HashMap<String, PerformanceMetrics>,
    // Whether the engine may adapt from history; not consulted in this chunk.
    adaptation_enabled: bool,
}
204
/// A single criterion a `SmartDefaultSelector` can test against a dataset.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum SelectionCriterion {
    /// Sample count must lie within [min_samples, max_samples].
    DataSize {
        min_samples: usize,
        max_samples: usize,
    },
    /// Feature count must lie within [min_features, max_features].
    FeatureDimensionality {
        min_features: usize,
        max_features: usize,
    },
    /// Which task kinds the criterion applies to.
    TaskType {
        classification: bool,
        regression: bool,
    },
    /// Minimum accuracy and maximum runtime constraints.
    PerformanceRequirement { min_accuracy: f64, max_time: f64 },
    /// Required tolerance to outliers.
    RobustnessRequirement { outlier_tolerance: f64 },
}
229
/// Default value and valid range for a named baseline parameter.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ParameterDefault {
    /// Parameter identifier (matches the map key in `ConfigurationHelper`).
    pub parameter_name: String,
    /// Value used when the caller does not override it.
    pub default_value: f64,
    /// Inclusive (min, max) range of valid values.
    pub valid_range: (f64, f64),
    /// Human-readable description of the parameter.
    pub description: String,
}
243
/// A context-specific suggestion for improving a baseline setup.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct OptimizationHint {
    /// Condition name the hint applies to (e.g. "high_dimensional").
    pub context: String,
    /// What to do.
    pub suggestion: String,
    /// Expected effect of following the suggestion.
    pub impact: String,
    /// Relative importance; higher values are listed first.
    pub priority: u8,
}
257
/// One condition→baseline rule used by `BaselineRecommendationEngine`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RecommendationRule {
    /// Condition name matched against data characteristics
    /// (e.g. "imbalanced", "high_outlier_ratio").
    pub condition: String,
    /// Baseline to recommend when the condition holds.
    pub recommended_baseline: BaselineType,
    /// Confidence in [0, 1]; the highest-confidence matching rule wins.
    pub confidence: f64,
    /// Human-readable justification surfaced in the recommendation.
    pub reasoning: String,
}
271
/// Recorded performance of a baseline run, used as engine history.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct PerformanceMetrics {
    /// Classification accuracy in [0, 1].
    pub accuracy: f64,
    /// Precision score.
    pub precision: f64,
    /// Recall score.
    pub recall: f64,
    /// F1 score.
    pub f1_score: f64,
    /// Wall-clock time of the run — units not specified here; verify at call site.
    pub execution_time: f64,
    /// Memory used by the run — units not specified here; verify at call site.
    pub memory_usage: f64,
}
289
/// Minimal interface a baseline estimator must implement to be driven by a
/// `BaselinePipeline`.
pub trait BaselineEstimator: std::fmt::Debug {
    /// Fits the baseline on features `x` and integer labels `y`.
    fn fit_baseline(&mut self, x: &Array2<f64>, y: &Array1<i32>) -> Result<(), SklearsError>;
    /// Predicts integer labels for `x`.
    fn predict_baseline(&self, x: &Array2<f64>) -> Result<Array1<i32>, SklearsError>;
    /// Reports which baseline type this estimator implements.
    fn get_type(&self) -> BaselineType;
}
296
297impl AutoBaselineGenerator {
298 pub fn new() -> Self {
300 Self {
301 recommendation_engine: BaselineRecommendationEngine::new(),
302 configuration_helper: ConfigurationHelper::new(),
303 random_state: None,
304 }
305 }
306
307 pub fn with_random_state(mut self, seed: u64) -> Self {
309 self.random_state = Some(seed);
310 self
311 }
312
313 pub fn analyze_and_recommend(
315 &self,
316 x: &Array2<f64>,
317 y: &Array1<i32>,
318 ) -> Result<BaselineRecommendation, SklearsError> {
319 let characteristics = self.analyze_data_characteristics(x, y);
320 let recommendation = self
321 .recommendation_engine
322 .recommend_baseline(&characteristics);
323 Ok(recommendation)
324 }
325
326 pub fn generate_baseline(
328 &self,
329 x: &Array2<f64>,
330 y: &Array1<i32>,
331 ) -> Result<BaselineType, SklearsError> {
332 let recommendation = self.analyze_and_recommend(x, y)?;
333 Ok(recommendation.primary_strategy)
334 }
335
336 fn analyze_data_characteristics(
337 &self,
338 x: &Array2<f64>,
339 y: &Array1<i32>,
340 ) -> DataCharacteristics {
341 let n_samples = x.nrows();
342 let n_features = x.ncols();
343
344 let mut class_counts = HashMap::new();
346 for &class in y.iter() {
347 *class_counts.entry(class).or_insert(0) += 1;
348 }
349 let n_classes = Some(class_counts.len());
350
351 let class_balance = if class_counts.len() > 1 {
353 let min_count = *class_counts.values().min().unwrap() as f64;
354 let max_count = *class_counts.values().max().unwrap() as f64;
355 Some(min_count / max_count)
356 } else {
357 None
358 };
359
360 let total_elements = (n_samples * n_features) as f64;
362 let zero_elements = x.iter().filter(|&&val| val.abs() < 1e-10).count() as f64;
363 let feature_sparsity = zero_elements / total_elements;
364
365 let mut outlier_count = 0;
367 for col in 0..n_features {
368 let column = x.column(col);
369 let mut sorted_col = column.to_vec();
370 sorted_col.sort_by(|a, b| a.partial_cmp(b).unwrap());
371
372 let q1_idx = sorted_col.len() / 4;
373 let q3_idx = 3 * sorted_col.len() / 4;
374
375 if q1_idx < sorted_col.len() && q3_idx < sorted_col.len() {
376 let q1 = sorted_col[q1_idx];
377 let q3 = sorted_col[q3_idx];
378 let iqr = q3 - q1;
379 let lower_bound = q1 - 1.5 * iqr;
380 let upper_bound = q3 + 1.5 * iqr;
381
382 outlier_count += column
383 .iter()
384 .filter(|&&val| val < lower_bound || val > upper_bound)
385 .count();
386 }
387 }
388 let outlier_ratio = outlier_count as f64 / total_elements;
389
390 let feature_stds: Vec<f64> = (0..n_features).map(|col| x.column(col).std(0.0)).collect();
392 let noise_level = feature_stds.iter().sum::<f64>() / feature_stds.len() as f64;
393
394 let mut correlation_sum = 0.0;
396 let mut correlation_count = 0;
397 for i in 0..n_features {
398 for j in i + 1..n_features {
399 let col_i = x.column(i);
400 let col_j = x.column(j);
401 let correlation = self.compute_correlation(&col_i, &col_j);
402 correlation_sum += correlation.abs();
403 correlation_count += 1;
404 }
405 }
406 let correlation_strength = if correlation_count > 0 {
407 correlation_sum / correlation_count as f64
408 } else {
409 0.0
410 };
411
412 DataCharacteristics {
413 n_samples,
414 n_features,
415 n_classes,
416 class_balance,
417 feature_sparsity,
418 missing_data_ratio: 0.0, outlier_ratio,
420 noise_level,
421 correlation_strength,
422 temporal_dependency: false, categorical_features_ratio: 0.0, high_dimensional: n_features > 100,
425 imbalanced: class_balance.is_some_and(|balance| balance < 0.1),
426 has_protected_attributes: false, }
428 }
429
430 fn compute_correlation(&self, x: &ArrayView1<f64>, y: &ArrayView1<f64>) -> f64 {
431 let n = x.len() as f64;
432 let mean_x = x.mean().unwrap();
433 let mean_y = y.mean().unwrap();
434
435 let numerator: f64 = x
436 .iter()
437 .zip(y.iter())
438 .map(|(xi, yi)| (xi - mean_x) * (yi - mean_y))
439 .sum();
440
441 let var_x: f64 = x.iter().map(|xi| (xi - mean_x).powi(2)).sum();
442 let var_y: f64 = y.iter().map(|yi| (yi - mean_y).powi(2)).sum();
443
444 if var_x == 0.0 || var_y == 0.0 {
445 0.0
446 } else {
447 numerator / (var_x * var_y).sqrt()
448 }
449 }
450}
451
452impl Default for AutoBaselineGenerator {
453 fn default() -> Self {
454 Self::new()
455 }
456}
457
458impl BaselineRecommendationEngine {
459 pub fn new() -> Self {
461 let mut recommendation_rules = Vec::new();
462
463 recommendation_rules.push(RecommendationRule {
465 condition: "high_dimensional".to_string(),
466 recommended_baseline: BaselineType::DummyClassifier(ClassifierStrategy::MostFrequent),
467 confidence: 0.8,
468 reasoning: "High-dimensional data benefits from simple baselines".to_string(),
469 });
470
471 recommendation_rules.push(RecommendationRule {
472 condition: "imbalanced".to_string(),
473 recommended_baseline: BaselineType::DummyClassifier(ClassifierStrategy::Stratified),
474 confidence: 0.7,
475 reasoning: "Imbalanced data requires stratified sampling".to_string(),
476 });
477
478 recommendation_rules.push(RecommendationRule {
479 condition: "high_outlier_ratio".to_string(),
480 recommended_baseline: BaselineType::RobustBaseline(RobustStrategy::TrimmedMean {
481 trim_proportion: 0.1,
482 }),
483 confidence: 0.9,
484 reasoning: "High outlier ratio requires robust methods".to_string(),
485 });
486
487 Self {
488 recommendation_rules,
489 performance_history: HashMap::new(),
490 adaptation_enabled: true,
491 }
492 }
493
494 pub fn recommend_baseline(
496 &self,
497 characteristics: &DataCharacteristics,
498 ) -> BaselineRecommendation {
499 let mut candidate_recommendations = Vec::new();
500
501 for rule in &self.recommendation_rules {
503 let matches = match rule.condition.as_str() {
504 "high_dimensional" => characteristics.high_dimensional,
505 "imbalanced" => characteristics.imbalanced,
506 "high_outlier_ratio" => characteristics.outlier_ratio > 0.1,
507 "has_protected_attributes" => characteristics.has_protected_attributes,
508 "small_dataset" => characteristics.n_samples < 1000,
509 "large_dataset" => characteristics.n_samples > 10000,
510 "high_correlation" => characteristics.correlation_strength > 0.7,
511 "sparse_features" => characteristics.feature_sparsity > 0.5,
512 _ => false,
513 };
514
515 if matches {
516 candidate_recommendations.push((rule.clone(), rule.confidence));
517 }
518 }
519
520 let (primary_rule, confidence_score) = candidate_recommendations
522 .into_iter()
523 .max_by(|(_, conf_a), (_, conf_b)| conf_a.partial_cmp(conf_b).unwrap())
524 .unwrap_or((
525 RecommendationRule {
526 condition: "default".to_string(),
527 recommended_baseline: BaselineType::DummyClassifier(
528 ClassifierStrategy::MostFrequent,
529 ),
530 confidence: 0.5,
531 reasoning: "Default baseline when no specific conditions are met".to_string(),
532 },
533 0.5,
534 ));
535
536 let fallback_strategies = vec![
538 BaselineType::DummyClassifier(ClassifierStrategy::Uniform),
539 BaselineType::EnsembleBaseline(EnsembleStrategy::Average),
540 ];
541
542 BaselineRecommendation {
543 primary_strategy: primary_rule.recommended_baseline,
544 fallback_strategies,
545 ensemble_recommended: characteristics.n_samples > 1000,
546 preprocessing_needed: characteristics.outlier_ratio > 0.05
547 || characteristics.feature_sparsity > 0.3,
548 robustness_needed: characteristics.outlier_ratio > 0.1,
549 fairness_considerations: characteristics.has_protected_attributes,
550 confidence_score,
551 reasoning: primary_rule.reasoning,
552 }
553 }
554}
555
556impl Default for BaselineRecommendationEngine {
557 fn default() -> Self {
558 Self::new()
559 }
560}
561
562impl ConfigurationHelper {
563 pub fn new() -> Self {
565 let mut parameter_defaults = HashMap::new();
566
567 parameter_defaults.insert(
569 "trim_proportion".to_string(),
570 ParameterDefault {
571 parameter_name: "trim_proportion".to_string(),
572 default_value: 0.1,
573 valid_range: (0.0, 0.5),
574 description: "Proportion of extreme values to trim for robust estimation"
575 .to_string(),
576 },
577 );
578
579 parameter_defaults.insert(
580 "ensemble_size".to_string(),
581 ParameterDefault {
582 parameter_name: "ensemble_size".to_string(),
583 default_value: 5.0,
584 valid_range: (3.0, 50.0),
585 description: "Number of base estimators in ensemble".to_string(),
586 },
587 );
588
589 let optimization_hints = vec![
590 OptimizationHint {
591 context: "high_dimensional".to_string(),
592 suggestion: "Use feature selection or dimensionality reduction".to_string(),
593 impact: "Reduces overfitting and improves computational efficiency".to_string(),
594 priority: 8,
595 },
596 OptimizationHint {
597 context: "imbalanced".to_string(),
598 suggestion: "Use stratified sampling or class weighting".to_string(),
599 impact: "Improves performance on minority classes".to_string(),
600 priority: 9,
601 },
602 ];
603
604 Self {
605 parameter_defaults,
606 optimization_hints,
607 }
608 }
609
610 pub fn get_default_config(&self, baseline_type: &BaselineType) -> HashMap<String, f64> {
612 let mut config = HashMap::new();
613
614 match baseline_type {
615 BaselineType::RobustBaseline(_) => {
616 if let Some(default) = self.parameter_defaults.get("trim_proportion") {
617 config.insert("trim_proportion".to_string(), default.default_value);
618 }
619 }
620 BaselineType::EnsembleBaseline(_) => {
621 if let Some(default) = self.parameter_defaults.get("ensemble_size") {
622 config.insert("ensemble_size".to_string(), default.default_value);
623 }
624 }
625 _ => {}
626 }
627
628 config
629 }
630
631 pub fn get_optimization_hints(
633 &self,
634 characteristics: &DataCharacteristics,
635 ) -> Vec<OptimizationHint> {
636 let mut relevant_hints = Vec::new();
637
638 for hint in &self.optimization_hints {
639 let relevant = match hint.context.as_str() {
640 "high_dimensional" => characteristics.high_dimensional,
641 "imbalanced" => characteristics.imbalanced,
642 "sparse" => characteristics.feature_sparsity > 0.5,
643 "noisy" => characteristics.noise_level > 1.0,
644 _ => false,
645 };
646
647 if relevant {
648 relevant_hints.push(hint.clone());
649 }
650 }
651
652 relevant_hints.sort_by(|a, b| b.priority.cmp(&a.priority));
654
655 relevant_hints
656 }
657}
658
659impl Default for ConfigurationHelper {
660 fn default() -> Self {
661 Self::new()
662 }
663}
664
#[cfg(test)]
mod tests {
    // The previously present `use scirs2_core::ndarray::array;` was unused
    // (tests build arrays via from_shape_vec/from_vec), and the
    // `#[allow(non_snake_case)]` was dead — all test names are snake_case.
    use super::*;

    #[test]
    fn test_auto_baseline_generator() {
        // 100 samples x 5 features of sequential values; labels cycle over 3 classes.
        let x = Array2::from_shape_vec((100, 5), (0..500).map(|i| i as f64).collect()).unwrap();
        let y = Array1::from_vec((0..100).map(|i| i % 3).collect());

        let generator = AutoBaselineGenerator::new();
        let recommendation = generator.analyze_and_recommend(&x, &y).unwrap();

        assert!(recommendation.confidence_score > 0.0);
        assert!(!recommendation.reasoning.is_empty());
    }

    #[test]
    fn test_data_characteristics_analysis() {
        // 50 samples x 3 features filled with 1.0..=150.0 row-major; the
        // generator expression replaces the former 150-element literal.
        let x = Array2::from_shape_vec((50, 3), (1..=150).map(|i| i as f64).collect()).unwrap();
        let y = Array1::from_vec((0..50).map(|i| i % 2).collect());

        let generator = AutoBaselineGenerator::new();
        let characteristics = generator.analyze_data_characteristics(&x, &y);

        assert_eq!(characteristics.n_samples, 50);
        assert_eq!(characteristics.n_features, 3);
        assert_eq!(characteristics.n_classes, Some(2));
        assert!(characteristics.class_balance.is_some());
    }

    #[test]
    fn test_recommendation_engine() {
        // outlier_ratio above the 0.1 threshold should trigger the robust
        // rule and the robustness_needed flag.
        let characteristics = DataCharacteristics {
            n_samples: 1000,
            n_features: 50,
            n_classes: Some(3),
            class_balance: Some(0.8),
            feature_sparsity: 0.1,
            missing_data_ratio: 0.0,
            outlier_ratio: 0.15,
            noise_level: 0.5,
            correlation_strength: 0.3,
            temporal_dependency: false,
            categorical_features_ratio: 0.0,
            high_dimensional: false,
            imbalanced: false,
            has_protected_attributes: false,
        };

        let engine = BaselineRecommendationEngine::new();
        let recommendation = engine.recommend_baseline(&characteristics);

        assert!(recommendation.confidence_score > 0.0);
        assert!(recommendation.robustness_needed);
    }

    #[test]
    fn test_configuration_helper() {
        let helper = ConfigurationHelper::new();
        let baseline_type = BaselineType::RobustBaseline(RobustStrategy::TrimmedMean {
            trim_proportion: 0.1,
        });

        // Robust baselines should expose a trim_proportion default.
        let config = helper.get_default_config(&baseline_type);
        assert!(config.contains_key("trim_proportion"));

        // Both high_dimensional and imbalanced flags are set, so both hints
        // must be returned.
        let characteristics = DataCharacteristics {
            n_samples: 1000,
            n_features: 200,
            n_classes: Some(2),
            class_balance: Some(0.1),
            feature_sparsity: 0.0,
            missing_data_ratio: 0.0,
            outlier_ratio: 0.05,
            noise_level: 0.5,
            correlation_strength: 0.3,
            temporal_dependency: false,
            categorical_features_ratio: 0.0,
            high_dimensional: true,
            imbalanced: true,
            has_protected_attributes: false,
        };

        let hints = helper.get_optimization_hints(&characteristics);
        assert!(!hints.is_empty());
        assert!(hints.iter().any(|hint| hint.context == "high_dimensional"));
        assert!(hints.iter().any(|hint| hint.context == "imbalanced"));
    }
}