scirs2_datasets/
ml_integration.rs

1//! Machine learning pipeline integration
2//!
3//! This module provides integration utilities for common ML frameworks and pipelines:
4//! - Model training data preparation
5//! - Cross-validation utilities
6//! - Feature engineering pipelines
7//! - Model evaluation and metrics
8//! - Integration with popular ML libraries
9
10use std::collections::HashMap;
11
12use scirs2_core::ndarray::{Array1, Array2, Axis};
13use scirs2_core::random::prelude::*;
14use scirs2_core::random::SliceRandomExt;
15use scirs2_core::random::Uniform;
16use serde::{Deserialize, Serialize};
17
18use crate::error::{DatasetsError, Result};
19use crate::utils::{BalancingStrategy, CrossValidationFolds, Dataset};
20
21/// Configuration for ML pipeline integration
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct MLPipelineConfig {
24    /// Random seed for reproducibility
25    pub random_state: Option<u64>,
26    /// Default test size for train/test splits
27    pub test_size: f64,
28    /// Number of folds for cross-validation
29    pub cv_folds: usize,
30    /// Whether to stratify splits for classification
31    pub stratify: bool,
32    /// Data balancing strategy
33    pub balancing_strategy: Option<BalancingStrategy>,
34    /// Feature scaling method
35    pub scaling_method: Option<ScalingMethod>,
36}
37
38impl Default for MLPipelineConfig {
39    fn default() -> Self {
40        Self {
41            random_state: Some(42),
42            test_size: 0.2,
43            cv_folds: 5,
44            stratify: true,
45            balancing_strategy: None,
46            scaling_method: Some(ScalingMethod::StandardScaler),
47        }
48    }
49}
50
51/// Feature scaling methods
52#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
53pub enum ScalingMethod {
54    /// Z-score normalization
55    StandardScaler,
56    /// Min-max scaling to [0, 1]
57    MinMaxScaler,
58    /// Robust scaling using median and MAD
59    RobustScaler,
60    /// No scaling
61    None,
62}
63
64/// ML pipeline for data preprocessing and preparation
65pub struct MLPipeline {
66    config: MLPipelineConfig,
67    fitted_scalers: Option<HashMap<String, ScalerParams>>,
68}
69
70/// Parameters for fitted scalers
71#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct ScalerParams {
73    /// Scaling method used
74    pub method: ScalingMethod,
75    /// Mean value (for StandardScaler)
76    pub mean: Option<f64>,
77    /// Standard deviation (for StandardScaler)
78    pub std: Option<f64>,
79    /// Minimum value (for MinMaxScaler)
80    pub min: Option<f64>,
81    /// Maximum value (for MinMaxScaler)
82    pub max: Option<f64>,
83    /// Median value (for RobustScaler)
84    pub median: Option<f64>,
85    /// Median absolute deviation (for RobustScaler)
86    pub mad: Option<f64>,
87}
88
89/// ML experiment tracking
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct MLExperiment {
92    /// Experiment name
93    pub name: String,
94    /// Dataset information
95    pub dataset_info: DatasetInfo,
96    /// Model configuration
97    pub model_config: ModelConfig,
98    /// Training results
99    pub results: ExperimentResults,
100    /// Cross-validation scores
101    pub cv_scores: Option<CrossValidationResults>,
102}
103
104/// Dataset information for experiments
105#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct DatasetInfo {
107    /// Number of samples in the dataset
108    pub n_samples: usize,
109    /// Number of features in the dataset
110    pub n_features: usize,
111    /// Number of classes (for classification tasks)
112    pub n_classes: Option<usize>,
113    /// Distribution of classes in the dataset
114    pub class_distribution: Option<HashMap<String, usize>>,
115    /// Percentage of missing data
116    pub missing_data_percentage: f64,
117}
118
119/// Model configuration
120#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct ModelConfig {
122    /// Type of ML model used
123    pub model_type: String,
124    /// Hyperparameter settings
125    pub hyperparameters: HashMap<String, serde_json::Value>,
126    /// List of preprocessing steps applied
127    pub preprocessing_steps: Vec<String>,
128}
129
130/// Experiment results
131#[derive(Debug, Clone, Serialize, Deserialize)]
132pub struct ExperimentResults {
133    /// Score on training data
134    pub training_score: f64,
135    /// Score on validation data
136    pub validation_score: f64,
137    /// Score on test data (if available)
138    pub test_score: Option<f64>,
139    /// Time taken for training (in seconds)
140    pub training_time: f64,
141    /// Average inference time per sample (in milliseconds)
142    pub inference_time: Option<f64>,
143    /// Feature importance scores
144    pub feature_importance: Option<Vec<(String, f64)>>,
145}
146
147/// Cross-validation results
148#[derive(Debug, Clone, Serialize, Deserialize)]
149pub struct CrossValidationResults {
150    /// Individual scores for each fold
151    pub scores: Vec<f64>,
152    /// Mean score across all folds
153    pub mean_score: f64,
154    /// Standard deviation of scores
155    pub std_score: f64,
156    /// Detailed results for each fold
157    pub fold_details: Vec<FoldResult>,
158}
159
160/// Result for a single fold
161#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct FoldResult {
163    /// Index of the fold
164    pub fold_index: usize,
165    /// Training score for this fold
166    pub train_score: f64,
167    /// Validation score for this fold
168    pub validation_score: f64,
169    /// Training time in seconds for this fold
170    pub training_time: f64,
171}
172
173/// Data split for ML training
174#[derive(Debug, Clone)]
175pub struct DataSplit {
176    /// Training features
177    pub x_train: Array2<f64>,
178    /// Testing features
179    pub x_test: Array2<f64>,
180    /// Training targets
181    pub y_train: Array1<f64>,
182    /// Testing targets
183    pub y_test: Array1<f64>,
184}
185
186impl Default for MLPipeline {
187    fn default() -> Self {
188        Self::new(MLPipelineConfig::default())
189    }
190}
191
192impl MLPipeline {
193    /// Create a new ML pipeline
194    pub fn new(config: MLPipelineConfig) -> Self {
195        Self {
196            config,
197            fitted_scalers: None,
198        }
199    }
200
201    /// Prepare dataset for ML training
202    pub fn prepare_dataset(&mut self, dataset: &Dataset) -> Result<Dataset> {
203        let mut prepared = dataset.clone();
204
205        // Apply balancing if specified
206        if let Some(ref strategy) = self.config.balancing_strategy {
207            prepared = self.apply_balancing(&prepared, strategy)?;
208        }
209
210        // Apply scaling if specified
211        if let Some(method) = self.config.scaling_method {
212            prepared = self.fit_and_transform_scaling(&prepared, method)?;
213        }
214
215        Ok(prepared)
216    }
217
218    /// Split dataset into train/test sets
219    pub fn train_test_split(&self, dataset: &Dataset) -> Result<DataSplit> {
220        let n_samples = dataset.n_samples();
221        let test_samples = (n_samples as f64 * self.config.test_size) as usize;
222        let train_samples = n_samples - test_samples;
223
224        let indices = self.generate_split_indices(n_samples, dataset.target.as_ref())?;
225
226        let train_indices = &indices[..train_samples];
227        let test_indices = &indices[train_samples..];
228
229        let x_train = dataset.data.select(Axis(0), train_indices);
230        let x_test = dataset.data.select(Axis(0), test_indices);
231
232        let (y_train, y_test) = if let Some(ref target) = dataset.target {
233            let y_train = target.select(Axis(0), train_indices);
234            let y_test = target.select(Axis(0), test_indices);
235            (y_train, y_test)
236        } else {
237            return Err(DatasetsError::InvalidFormat(
238                "Target variable required for train/test split".to_string(),
239            ));
240        };
241
242        Ok(DataSplit {
243            x_train,
244            x_test,
245            y_train,
246            y_test,
247        })
248    }
249
250    /// Generate cross-validation folds
251    pub fn cross_validation_split(&self, dataset: &Dataset) -> Result<CrossValidationFolds> {
252        let target = dataset.target.as_ref().ok_or_else(|| {
253            DatasetsError::InvalidFormat(
254                "Target variable required for cross-validation".to_string(),
255            )
256        })?;
257
258        if self.config.stratify {
259            crate::utils::stratified_k_fold_split(
260                target,
261                self.config.cv_folds,
262                true,
263                self.config.random_state,
264            )
265        } else {
266            crate::utils::k_fold_split(
267                dataset.n_samples(),
268                self.config.cv_folds,
269                true,
270                self.config.random_state,
271            )
272        }
273    }
274
275    /// Transform new data using fitted scalers
276    pub fn transform(&self, dataset: &Dataset) -> Result<Dataset> {
277        let scalers = self.fitted_scalers.as_ref().ok_or_else(|| {
278            DatasetsError::InvalidFormat(
279                "Pipeline not fitted. Call prepare_dataset first.".to_string(),
280            )
281        })?;
282
283        let mut transformed_data = dataset.data.clone();
284
285        for (col_idx, mut column) in transformed_data.columns_mut().into_iter().enumerate() {
286            let defaultname = format!("feature_{col_idx}");
287            let featurename = dataset
288                .featurenames
289                .as_ref()
290                .and_then(|names| names.get(col_idx))
291                .map(|s| s.as_str())
292                .unwrap_or(&defaultname);
293
294            if let Some(scaler) = scalers.get(featurename) {
295                Self::apply_scaler_to_column(&mut column, scaler)?;
296            }
297        }
298
299        Ok(Dataset {
300            data: transformed_data,
301            target: dataset.target.clone(),
302            featurenames: dataset.featurenames.clone(),
303            targetnames: dataset.targetnames.clone(),
304            feature_descriptions: dataset.feature_descriptions.clone(),
305            description: Some("Transformed dataset".to_string()),
306            metadata: dataset.metadata.clone(),
307        })
308    }
309
310    /// Create an ML experiment tracker
311    pub fn create_experiment(&self, name: &str, dataset: &Dataset) -> MLExperiment {
312        let dataset_info = self.extract_dataset_info(dataset);
313
314        MLExperiment {
315            name: name.to_string(),
316            dataset_info,
317            model_config: ModelConfig {
318                model_type: "undefined".to_string(),
319                hyperparameters: HashMap::new(),
320                preprocessing_steps: Vec::new(),
321            },
322            results: ExperimentResults {
323                training_score: 0.0,
324                validation_score: 0.0,
325                test_score: None,
326                training_time: 0.0,
327                inference_time: None,
328                feature_importance: None,
329            },
330            cv_scores: None,
331        }
332    }
333
334    /// Evaluate model performance with cross-validation
335    pub fn evaluate_with_cv<F>(
336        &self,
337        dataset: &Dataset,
338        train_fn: F,
339    ) -> Result<CrossValidationResults>
340    where
341        F: Fn(&Array2<f64>, &Array1<f64>, &Array2<f64>, &Array1<f64>) -> Result<(f64, f64, f64)>,
342    {
343        let folds = self.cross_validation_split(dataset)?;
344        let mut scores = Vec::new();
345        let mut fold_details = Vec::new();
346
347        for (fold_idx, (train_indices, val_indices)) in folds.into_iter().enumerate() {
348            let x_train = dataset.data.select(Axis(0), &train_indices);
349            let x_val = dataset.data.select(Axis(0), &val_indices);
350
351            let target = dataset.target.as_ref().unwrap();
352            let y_train = target.select(Axis(0), &train_indices);
353            let y_val = target.select(Axis(0), &val_indices);
354
355            let (train_score, val_score, training_time) =
356                train_fn(&x_train, &y_train, &x_val, &y_val)?;
357
358            scores.push(val_score);
359            fold_details.push(FoldResult {
360                fold_index: fold_idx,
361                train_score,
362                validation_score: val_score,
363                training_time,
364            });
365        }
366
367        let mean_score = scores.iter().sum::<f64>() / scores.len() as f64;
368        let variance = scores
369            .iter()
370            .map(|score| (score - mean_score).powi(2))
371            .sum::<f64>()
372            / scores.len() as f64;
373        let std_score = variance.sqrt();
374
375        Ok(CrossValidationResults {
376            scores,
377            mean_score,
378            std_score,
379            fold_details,
380        })
381    }
382
383    // Private helper methods
384
385    fn apply_balancing(&self, dataset: &Dataset, strategy: &BalancingStrategy) -> Result<Dataset> {
386        // Simplified balancing implementation
387        // In a full implementation, you'd use the actual balancing utilities
388        match strategy {
389            BalancingStrategy::RandomUndersample => self.random_undersample(dataset, None),
390            BalancingStrategy::RandomOversample => self.random_oversample(dataset, None),
391            _ => Ok(dataset.clone()), // Placeholder for other strategies
392        }
393    }
394
395    fn random_undersample(&self, dataset: &Dataset, _randomstate: Option<u64>) -> Result<Dataset> {
396        let target = dataset.target.as_ref().ok_or_else(|| {
397            DatasetsError::InvalidFormat("Target required for balancing".to_string())
398        })?;
399
400        // Find minority class size
401        let mut class_counts: HashMap<i64, usize> = HashMap::new();
402        for &value in target.iter() {
403            if !value.is_nan() {
404                *class_counts.entry(value as i64).or_insert(0) += 1;
405            }
406        }
407
408        let min_count = class_counts.values().min().copied().unwrap_or(0);
409
410        // Sample min_count samples from each class
411        let mut selected_indices = Vec::new();
412
413        for (class_, _count) in class_counts {
414            let class_indices: Vec<usize> = target
415                .iter()
416                .enumerate()
417                .filter(|(_, &val)| !val.is_nan() && val as i64 == class_)
418                .map(|(idx, _)| idx)
419                .collect();
420
421            let mut sampled_indices = class_indices;
422            if sampled_indices.len() > min_count {
423                // Simple random sampling (in a real implementation, use proper random sampling)
424                sampled_indices.truncate(min_count);
425            }
426
427            selected_indices.extend(sampled_indices);
428        }
429
430        let balanced_data = dataset.data.select(Axis(0), &selected_indices);
431        let balanced_target = target.select(Axis(0), &selected_indices);
432
433        Ok(Dataset {
434            data: balanced_data,
435            target: Some(balanced_target),
436            featurenames: dataset.featurenames.clone(),
437            targetnames: dataset.targetnames.clone(),
438            feature_descriptions: dataset.feature_descriptions.clone(),
439            description: Some("Undersampled dataset".to_string()),
440            metadata: dataset.metadata.clone(),
441        })
442    }
443
444    fn random_oversample(&self, dataset: &Dataset, randomstate: Option<u64>) -> Result<Dataset> {
445        use scirs2_core::random::prelude::*;
446        use scirs2_core::random::{rngs::StdRng, RngCore, SeedableRng};
447        use std::collections::HashMap;
448
449        let target = dataset.target.as_ref().ok_or_else(|| {
450            DatasetsError::InvalidFormat("Random oversampling requires target labels".to_string())
451        })?;
452
453        if target.len() != dataset.data.nrows() {
454            return Err(DatasetsError::InvalidFormat(
455                "Target length must match number of samples".to_string(),
456            ));
457        }
458
459        // Count samples per class
460        let mut class_counts: HashMap<i32, usize> = HashMap::new();
461        let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
462
463        for (idx, &label) in target.iter().enumerate() {
464            let class = label as i32;
465            *class_counts.entry(class).or_insert(0) += 1;
466            class_indices.entry(class).or_default().push(idx);
467        }
468
469        // Find the majority class size (the maximum count)
470        let max_count = class_counts.values().max().copied().unwrap_or(0);
471
472        if max_count == 0 {
473            return Err(DatasetsError::InvalidFormat(
474                "No samples found in dataset".to_string(),
475            ));
476        }
477
478        // Create RNG
479        let mut rng: Box<dyn RngCore> = match randomstate {
480            Some(seed) => Box::new(StdRng::seed_from_u64(seed)),
481            None => Box::new(thread_rng()),
482        };
483
484        // Collect all indices for the oversampled dataset
485        let mut all_indices = Vec::new();
486
487        for (_class, indices) in class_indices.iter() {
488            let current_count = indices.len();
489
490            // Add all original samples
491            all_indices.extend(indices.iter().copied());
492
493            // Add additional samples by random oversampling with replacement
494            let samples_needed = max_count - current_count;
495
496            if samples_needed > 0 {
497                for _ in 0..samples_needed {
498                    let random_idx = rng.sample(Uniform::new(0, indices.len()).unwrap());
499                    all_indices.push(indices[random_idx]);
500                }
501            }
502        }
503
504        // Shuffle the final indices to mix classes
505        all_indices.shuffle(&mut *rng);
506
507        // Create the oversampled dataset
508        let oversampled_data = dataset.data.select(Axis(0), &all_indices);
509        let oversampled_target = target.select(Axis(0), &all_indices);
510
511        Ok(Dataset {
512            data: oversampled_data,
513            target: Some(oversampled_target),
514            featurenames: dataset.featurenames.clone(),
515            targetnames: dataset.targetnames.clone(),
516            feature_descriptions: dataset.feature_descriptions.clone(),
517            description: Some(format!(
518                "Random oversampled dataset (original: {} samples, oversampled: {} samples)",
519                dataset.n_samples(),
520                all_indices.len()
521            )),
522            metadata: dataset.metadata.clone(),
523        })
524    }
525
526    fn fit_and_transform_scaling(
527        &mut self,
528        dataset: &Dataset,
529        method: ScalingMethod,
530    ) -> Result<Dataset> {
531        let mut scalers = HashMap::new();
532        let mut scaled_data = dataset.data.clone();
533
534        for (col_idx, mut column) in scaled_data.columns_mut().into_iter().enumerate() {
535            let featurename = dataset
536                .featurenames
537                .as_ref()
538                .and_then(|names| names.get(col_idx))
539                .cloned()
540                .unwrap_or_else(|| format!("feature_{col_idx}"));
541
542            let column_view = column.view();
543            let scaler_params = Self::fit_scaler(&column_view, method)?;
544            Self::apply_scaler_to_column(&mut column, &scaler_params)?;
545
546            scalers.insert(featurename, scaler_params);
547        }
548
549        self.fitted_scalers = Some(scalers);
550
551        Ok(Dataset {
552            data: scaled_data,
553            target: dataset.target.clone(),
554            featurenames: dataset.featurenames.clone(),
555            targetnames: dataset.targetnames.clone(),
556            feature_descriptions: dataset.feature_descriptions.clone(),
557            description: Some("Scaled dataset".to_string()),
558            metadata: dataset.metadata.clone(),
559        })
560    }
561
562    fn fit_scaler(
563        column: &scirs2_core::ndarray::ArrayView1<f64>,
564        method: ScalingMethod,
565    ) -> Result<ScalerParams> {
566        let values: Vec<f64> = column.iter().copied().filter(|x| !x.is_nan()).collect();
567
568        if values.is_empty() {
569            return Ok(ScalerParams {
570                method,
571                mean: None,
572                std: None,
573                min: None,
574                max: None,
575                median: None,
576                mad: None,
577            });
578        }
579
580        match method {
581            ScalingMethod::StandardScaler => {
582                let mean = values.iter().sum::<f64>() / values.len() as f64;
583                let variance =
584                    values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / values.len() as f64;
585                let std = variance.sqrt();
586
587                Ok(ScalerParams {
588                    method,
589                    mean: Some(mean),
590                    std: Some(std),
591                    min: None,
592                    max: None,
593                    median: None,
594                    mad: None,
595                })
596            }
597            ScalingMethod::MinMaxScaler => {
598                let min = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
599                let max = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
600
601                Ok(ScalerParams {
602                    method,
603                    mean: None,
604                    std: None,
605                    min: Some(min),
606                    max: Some(max),
607                    median: None,
608                    mad: None,
609                })
610            }
611            ScalingMethod::RobustScaler => {
612                let mut sorted_values = values.clone();
613                sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
614
615                let median = Self::percentile(&sorted_values, 0.5).unwrap_or(0.0);
616                let mad = Self::compute_mad(&sorted_values, median);
617
618                Ok(ScalerParams {
619                    method,
620                    mean: None,
621                    std: None,
622                    min: None,
623                    max: None,
624                    median: Some(median),
625                    mad: Some(mad),
626                })
627            }
628            ScalingMethod::None => Ok(ScalerParams {
629                method,
630                mean: None,
631                std: None,
632                min: None,
633                max: None,
634                median: None,
635                mad: None,
636            }),
637        }
638    }
639
640    fn apply_scaler_to_column(
641        column: &mut scirs2_core::ndarray::ArrayViewMut1<f64>,
642        params: &ScalerParams,
643    ) -> Result<()> {
644        match params.method {
645            ScalingMethod::StandardScaler => {
646                if let (Some(mean), Some(std)) = (params.mean, params.std) {
647                    if std > 1e-8 {
648                        // Avoid division by zero
649                        for value in column.iter_mut() {
650                            if !value.is_nan() {
651                                *value = (*value - mean) / std;
652                            }
653                        }
654                    }
655                }
656            }
657            ScalingMethod::MinMaxScaler => {
658                if let (Some(min), Some(max)) = (params.min, params.max) {
659                    let range = max - min;
660                    if range > 1e-8 {
661                        // Avoid division by zero
662                        for value in column.iter_mut() {
663                            if !value.is_nan() {
664                                *value = (*value - min) / range;
665                            }
666                        }
667                    }
668                }
669            }
670            ScalingMethod::RobustScaler => {
671                if let (Some(median), Some(mad)) = (params.median, params.mad) {
672                    if mad > 1e-8 {
673                        // Avoid division by zero
674                        for value in column.iter_mut() {
675                            if !value.is_nan() {
676                                *value = (*value - median) / mad;
677                            }
678                        }
679                    }
680                }
681            }
682            ScalingMethod::None => {
683                // No scaling applied
684            }
685        }
686
687        Ok(())
688    }
689
690    fn percentile(sorted_values: &[f64], p: f64) -> Option<f64> {
691        if sorted_values.is_empty() {
692            return None;
693        }
694
695        let index = p * (sorted_values.len() - 1) as f64;
696        let lower = index.floor() as usize;
697        let upper = index.ceil() as usize;
698
699        if lower == upper {
700            Some(sorted_values[lower])
701        } else {
702            let weight = index - lower as f64;
703            Some(sorted_values[lower] * (1.0 - weight) + sorted_values[upper] * weight)
704        }
705    }
706
707    fn compute_mad(sorted_values: &[f64], median: f64) -> f64 {
708        let deviations: Vec<f64> = sorted_values.iter().map(|&x| (x - median).abs()).collect();
709
710        let mut sorted_deviations = deviations;
711        sorted_deviations.sort_by(|a, b| a.partial_cmp(b).unwrap());
712
713        Self::percentile(&sorted_deviations, 0.5).unwrap_or(1.0)
714    }
715
716    fn generate_split_indices(
717        &self,
718        n_samples: usize,
719        target: Option<&Array1<f64>>,
720    ) -> Result<Vec<usize>> {
721        let mut indices: Vec<usize> = (0..n_samples).collect();
722
723        // Use proper random shuffling based on configuration
724        if self.config.stratify && target.is_some() {
725            // Implement stratified shuffling
726            self.stratified_shuffle(&mut indices, target.unwrap())?;
727        } else {
728            // Regular shuffling with optional random state
729            match self.config.random_state {
730                Some(seed) => {
731                    let mut rng = StdRng::seed_from_u64(seed);
732                    indices.shuffle(&mut rng);
733                }
734                None => {
735                    let mut rng = thread_rng();
736                    indices.shuffle(&mut rng);
737                }
738            }
739        }
740
741        Ok(indices)
742    }
743
744    /// Perform stratified shuffling to maintain class proportions
745    fn stratified_shuffle(&self, indices: &mut Vec<usize>, target: &Array1<f64>) -> Result<()> {
746        // Group indices by class
747        let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
748
749        for &idx in indices.iter() {
750            let class = target[idx] as i32;
751            class_indices.entry(class).or_default().push(idx);
752        }
753
754        // Shuffle each class group separately
755        for class_group in class_indices.values_mut() {
756            match self.config.random_state {
757                Some(seed) => {
758                    let mut rng = StdRng::seed_from_u64(seed);
759                    class_group.shuffle(&mut rng);
760                }
761                None => {
762                    let mut rng = thread_rng();
763                    class_group.shuffle(&mut rng);
764                }
765            }
766        }
767
768        // Recombine shuffled class groups while maintaining order
769        indices.clear();
770        let mut class_iterators: HashMap<i32, std::vec::IntoIter<usize>> = class_indices
771            .into_iter()
772            .map(|(class, group)| (class, group.into_iter()))
773            .collect();
774
775        // Interleave samples from different classes to maintain distribution
776        while !class_iterators.is_empty() {
777            let mut to_remove = Vec::new();
778            for (&class, iterator) in class_iterators.iter_mut() {
779                if let Some(idx) = iterator.next() {
780                    indices.push(idx);
781                } else {
782                    to_remove.push(class);
783                }
784            }
785            for class in to_remove {
786                class_iterators.remove(&class);
787            }
788        }
789
790        Ok(())
791    }
792
793    fn extract_dataset_info(&self, dataset: &Dataset) -> DatasetInfo {
794        let n_samples = dataset.n_samples();
795        let n_features = dataset.n_features();
796
797        let (n_classes, class_distribution) = if let Some(ref target) = dataset.target {
798            let mut class_counts: HashMap<String, usize> = HashMap::new();
799            for &value in target.iter() {
800                if !value.is_nan() {
801                    let classname = format!("{value:.0}");
802                    *class_counts.entry(classname).or_insert(0) += 1;
803                }
804            }
805
806            let n_classes = class_counts.len();
807            (Some(n_classes), Some(class_counts))
808        } else {
809            (None, None)
810        };
811
812        // Calculate missing data percentage
813        let total_values = n_samples * n_features;
814        let missing_values = dataset.data.iter().filter(|&&x| x.is_nan()).count();
815        let missing_data_percentage = missing_values as f64 / total_values as f64 * 100.0;
816
817        DatasetInfo {
818            n_samples,
819            n_features,
820            n_classes,
821            class_distribution,
822            missing_data_percentage,
823        }
824    }
825}
826
827/// Convenience functions for ML pipeline integration
828pub mod convenience {
829    use super::*;
830
831    /// Quick train/test split with default configuration
832    pub fn train_test_split(_dataset: &Dataset, testsize: Option<f64>) -> Result<DataSplit> {
833        let mut config = MLPipelineConfig::default();
834        if let Some(_size) = testsize {
835            config.test_size = _size;
836        }
837
838        let pipeline = MLPipeline::new(config);
839        pipeline.train_test_split(_dataset)
840    }
841
842    /// Prepare dataset for ML with standard preprocessing
843    pub fn prepare_for_ml(dataset: &Dataset, scale: bool, balance: bool) -> Result<Dataset> {
844        let mut config = MLPipelineConfig::default();
845
846        if !scale {
847            config.scaling_method = None;
848        }
849
850        if balance {
851            config.balancing_strategy = Some(BalancingStrategy::RandomUndersample);
852        }
853
854        let mut pipeline = MLPipeline::new(config);
855        pipeline.prepare_dataset(dataset)
856    }
857
858    /// Generate cross-validation folds
859    pub fn cv_split(
860        dataset: &Dataset,
861        n_folds: Option<usize>,
862        stratify: Option<bool>,
863    ) -> Result<CrossValidationFolds> {
864        let mut config = MLPipelineConfig::default();
865
866        if let Some(_folds) = n_folds {
867            config.cv_folds = _folds;
868        }
869
870        if let Some(strat) = stratify {
871            config.stratify = strat;
872        }
873
874        let pipeline = MLPipeline::new(config);
875        pipeline.cross_validation_split(dataset)
876    }
877
878    /// Create a simple ML experiment
879    pub fn create_experiment(name: &str, dataset: &Dataset) -> MLExperiment {
880        let pipeline = MLPipeline::default();
881        pipeline.create_experiment(name, dataset)
882    }
883}
884
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::make_classification;

    #[test]
    fn test_ml_pipeline_creation() {
        let pipeline = MLPipeline::default();
        let cfg = &pipeline.config;
        assert_eq!(cfg.test_size, 0.2);
        assert_eq!(cfg.cv_folds, 5);
    }

    #[test]
    fn test_train_test_split() {
        let data = make_classification(100, 5, 2, 1, 1, Some(42)).unwrap();
        let DataSplit {
            x_train,
            x_test,
            y_train,
            y_test,
        } = convenience::train_test_split(&data, Some(0.3)).unwrap();

        assert_eq!(x_train.nrows() + x_test.nrows(), 100);
        assert_eq!(y_train.len() + y_test.len(), 100);
        assert_eq!(x_train.ncols(), 5);
        assert_eq!(x_test.ncols(), 5);
    }

    #[test]
    fn test_cross_validation_split() {
        let data = make_classification(100, 3, 2, 1, 1, Some(42)).unwrap();
        let folds = convenience::cv_split(&data, Some(5), Some(true)).unwrap();
        assert_eq!(folds.len(), 5);

        // Each fold partitions all samples, so dividing the grand total by the
        // number of folds recovers the sample count.
        let covered: usize = folds
            .iter()
            .map(|(train, test)| train.len() + test.len())
            .sum();
        assert_eq!(covered / 5, 100);
    }

    #[test]
    fn test_dataset_preparation() {
        let original = make_classification(50, 4, 2, 1, 1, Some(42)).unwrap();
        let prepared = convenience::prepare_for_ml(&original, true, false).unwrap();

        assert_eq!(prepared.n_samples(), original.n_samples());
        assert_eq!(prepared.n_features(), original.n_features());
    }

    #[test]
    fn test_experiment_creation() {
        let data = make_classification(100, 5, 2, 1, 1, Some(42)).unwrap();
        let MLExperiment {
            name, dataset_info, ..
        } = convenience::create_experiment("test_experiment", &data);

        assert_eq!(name, "test_experiment");
        assert_eq!(dataset_info.n_samples, 100);
        assert_eq!(dataset_info.n_features, 5);
        assert_eq!(dataset_info.n_classes, Some(2));
    }

    #[test]
    fn test_scaler_fitting() {
        let column = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let params = MLPipeline::fit_scaler(&column.view(), ScalingMethod::StandardScaler).unwrap();

        assert!(params.mean.is_some());
        assert!(params.std.is_some());
        assert_eq!(params.mean.unwrap(), 3.0);
    }

    #[test]
    fn test_min_max_scaler() {
        let column = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let params = MLPipeline::fit_scaler(&column.view(), ScalingMethod::MinMaxScaler).unwrap();

        assert!(params.min.is_some());
        assert!(params.max.is_some());
        assert_eq!(params.min.unwrap(), 1.0);
        assert_eq!(params.max.unwrap(), 5.0);
    }
}