scirs2_datasets/
ml_integration.rs

//! Machine learning pipeline integration
//!
//! This module provides integration utilities for common ML frameworks and pipelines:
//! - Model training data preparation
//! - Cross-validation utilities
//! - Feature engineering pipelines
//! - Model evaluation and metrics
//! - Integration with popular ML libraries
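//!
//! # Example
//!
//! A minimal sketch of the intended flow, assuming a `Dataset` named `dataset`
//! is already loaded (the calls mirror the public API defined below):
//!
//! ```rust,ignore
//! let mut pipeline = MLPipeline::new(MLPipelineConfig::default());
//! let prepared = pipeline.prepare_dataset(&dataset)?;      // optional balancing + scaling
//! let split = pipeline.train_test_split(&prepared)?;       // 80/20 split by default
//! let folds = pipeline.cross_validation_split(&prepared)?; // 5-fold, stratified by default
//! ```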

use std::collections::HashMap;

use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::prelude::*;
use scirs2_core::random::SliceRandomExt;
use scirs2_core::random::Uniform;
use serde::{Deserialize, Serialize};

use crate::error::{DatasetsError, Result};
use crate::utils::{BalancingStrategy, CrossValidationFolds, Dataset};

/// Configuration for ML pipeline integration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLPipelineConfig {
    /// Random seed for reproducibility
    pub random_state: Option<u64>,
    /// Default test size for train/test splits
    pub test_size: f64,
    /// Number of folds for cross-validation
    pub cv_folds: usize,
    /// Whether to stratify splits for classification
    pub stratify: bool,
    /// Data balancing strategy
    pub balancing_strategy: Option<BalancingStrategy>,
    /// Feature scaling method
    pub scaling_method: Option<ScalingMethod>,
}

impl Default for MLPipelineConfig {
    fn default() -> Self {
        Self {
            random_state: Some(42),
            test_size: 0.2,
            cv_folds: 5,
            stratify: true,
            balancing_strategy: None,
            scaling_method: Some(ScalingMethod::StandardScaler),
        }
    }
}
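
// A hedged sketch of overriding the defaults via struct-update syntax; the
// field names are the real ones above, the particular values are illustrative:
//
//     let config = MLPipelineConfig {
//         test_size: 0.25,
//         cv_folds: 10,
//         scaling_method: Some(ScalingMethod::MinMaxScaler),
//         ..MLPipelineConfig::default()
//     };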

/// Feature scaling methods
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum ScalingMethod {
    /// Z-score normalization
    StandardScaler,
    /// Min-max scaling to [0, 1]
    MinMaxScaler,
    /// Robust scaling using median and MAD
    RobustScaler,
    /// No scaling
    None,
}
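
// Per-feature transforms applied by `apply_scaler_to_column` below, where `x`
// is a single value in the column:
//   StandardScaler: (x - mean) / std
//   MinMaxScaler:   (x - min) / (max - min)
//   RobustScaler:   (x - median) / MAD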

/// ML pipeline for data preprocessing and preparation
pub struct MLPipeline {
    config: MLPipelineConfig,
    fitted_scalers: Option<HashMap<String, ScalerParams>>,
}

/// Parameters for fitted scalers
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalerParams {
    /// Scaling method used
    pub method: ScalingMethod,
    /// Mean value (for StandardScaler)
    pub mean: Option<f64>,
    /// Standard deviation (for StandardScaler)
    pub std: Option<f64>,
    /// Minimum value (for MinMaxScaler)
    pub min: Option<f64>,
    /// Maximum value (for MinMaxScaler)
    pub max: Option<f64>,
    /// Median value (for RobustScaler)
    pub median: Option<f64>,
    /// Median absolute deviation (for RobustScaler)
    pub mad: Option<f64>,
}

/// ML experiment tracking
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLExperiment {
    /// Experiment name
    pub name: String,
    /// Dataset information
    pub dataset_info: DatasetInfo,
    /// Model configuration
    pub model_config: ModelConfig,
    /// Training results
    pub results: ExperimentResults,
    /// Cross-validation scores
    pub cv_scores: Option<CrossValidationResults>,
}

/// Dataset information for experiments
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetInfo {
    /// Number of samples in the dataset
    pub n_samples: usize,
    /// Number of features in the dataset
    pub n_features: usize,
    /// Number of classes (for classification tasks)
    pub n_classes: Option<usize>,
    /// Distribution of classes in the dataset
    pub class_distribution: Option<HashMap<String, usize>>,
    /// Percentage of missing data
    pub missing_data_percentage: f64,
}

/// Model configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Type of ML model used
    pub model_type: String,
    /// Hyperparameter settings
    pub hyperparameters: HashMap<String, serde_json::Value>,
    /// List of preprocessing steps applied
    pub preprocessing_steps: Vec<String>,
}

/// Experiment results
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExperimentResults {
    /// Score on training data
    pub training_score: f64,
    /// Score on validation data
    pub validation_score: f64,
    /// Score on test data (if available)
    pub test_score: Option<f64>,
    /// Time taken for training (in seconds)
    pub training_time: f64,
    /// Average inference time per sample (in milliseconds)
    pub inference_time: Option<f64>,
    /// Feature importance scores
    pub feature_importance: Option<Vec<(String, f64)>>,
}

/// Cross-validation results
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossValidationResults {
    /// Individual scores for each fold
    pub scores: Vec<f64>,
    /// Mean score across all folds
    pub mean_score: f64,
    /// Standard deviation of scores
    pub std_score: f64,
    /// Detailed results for each fold
    pub fold_details: Vec<FoldResult>,
}

/// Result for a single fold
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FoldResult {
    /// Index of the fold
    pub fold_index: usize,
    /// Training score for this fold
    pub train_score: f64,
    /// Validation score for this fold
    pub validation_score: f64,
    /// Training time in seconds for this fold
    pub training_time: f64,
}

/// Data split for ML training
#[derive(Debug, Clone)]
pub struct DataSplit {
    /// Training features
    pub x_train: Array2<f64>,
    /// Testing features
    pub x_test: Array2<f64>,
    /// Training targets
    pub y_train: Array1<f64>,
    /// Testing targets
    pub y_test: Array1<f64>,
}

impl Default for MLPipeline {
    fn default() -> Self {
        Self::new(MLPipelineConfig::default())
    }
}

impl MLPipeline {
    /// Create a new ML pipeline
    pub fn new(config: MLPipelineConfig) -> Self {
        Self {
            config,
            fitted_scalers: None,
        }
    }

    /// Prepare dataset for ML training
    pub fn prepare_dataset(&mut self, dataset: &Dataset) -> Result<Dataset> {
        let mut prepared = dataset.clone();

        // Apply balancing if specified
        if let Some(ref strategy) = self.config.balancing_strategy {
            prepared = self.apply_balancing(&prepared, strategy)?;
        }

        // Apply scaling if specified
        if let Some(method) = self.config.scaling_method {
            prepared = self.fit_and_transform_scaling(&prepared, method)?;
        }

        Ok(prepared)
    }

    /// Split dataset into train/test sets
    pub fn train_test_split(&self, dataset: &Dataset) -> Result<DataSplit> {
        let n_samples = dataset.n_samples();
        let test_samples = (n_samples as f64 * self.config.test_size) as usize;
        let train_samples = n_samples - test_samples;

        let indices = self.generate_split_indices(n_samples, dataset.target.as_ref())?;

        let train_indices = &indices[..train_samples];
        let test_indices = &indices[train_samples..];

        let x_train = dataset.data.select(Axis(0), train_indices);
        let x_test = dataset.data.select(Axis(0), test_indices);

        let (y_train, y_test) = if let Some(ref target) = dataset.target {
            let y_train = target.select(Axis(0), train_indices);
            let y_test = target.select(Axis(0), test_indices);
            (y_train, y_test)
        } else {
            return Err(DatasetsError::InvalidFormat(
                "Target variable required for train/test split".to_string(),
            ));
        };

        Ok(DataSplit {
            x_train,
            x_test,
            y_train,
            y_test,
        })
    }

    /// Generate cross-validation folds
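    ///
    /// Returns `(train_indices, validation_indices)` pairs, one per fold. A
    /// hedged sketch of consuming them (this mirrors `evaluate_with_cv` below):
    ///
    /// ```rust,ignore
    /// for (train_idx, val_idx) in pipeline.cross_validation_split(&dataset)? {
    ///     let x_train = dataset.data.select(Axis(0), &train_idx);
    ///     let x_val = dataset.data.select(Axis(0), &val_idx);
    ///     // fit a model on x_train and score it on x_val here
    /// }
    /// ```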
    pub fn cross_validation_split(&self, dataset: &Dataset) -> Result<CrossValidationFolds> {
        let target = dataset.target.as_ref().ok_or_else(|| {
            DatasetsError::InvalidFormat(
                "Target variable required for cross-validation".to_string(),
            )
        })?;

        if self.config.stratify {
            crate::utils::stratified_k_fold_split(
                target,
                self.config.cv_folds,
                true,
                self.config.random_state,
            )
        } else {
            crate::utils::k_fold_split(
                dataset.n_samples(),
                self.config.cv_folds,
                true,
                self.config.random_state,
            )
        }
    }

    /// Transform new data using fitted scalers
    pub fn transform(&self, dataset: &Dataset) -> Result<Dataset> {
        let scalers = self.fitted_scalers.as_ref().ok_or_else(|| {
            DatasetsError::InvalidFormat(
                "Pipeline not fitted. Call prepare_dataset first.".to_string(),
            )
        })?;

        let mut transformed_data = dataset.data.clone();

        for (col_idx, mut column) in transformed_data.columns_mut().into_iter().enumerate() {
            let default_name = format!("feature_{col_idx}");
            let feature_name = dataset
                .featurenames
                .as_ref()
                .and_then(|names| names.get(col_idx))
                .map(|s| s.as_str())
                .unwrap_or(&default_name);

            if let Some(scaler) = scalers.get(feature_name) {
                Self::apply_scaler_to_column(&mut column, scaler)?;
            }
        }

        Ok(Dataset {
            data: transformed_data,
            target: dataset.target.clone(),
            featurenames: dataset.featurenames.clone(),
            targetnames: dataset.targetnames.clone(),
            feature_descriptions: dataset.feature_descriptions.clone(),
            description: Some("Transformed dataset".to_string()),
            metadata: dataset.metadata.clone(),
        })
    }

    /// Create an ML experiment tracker
    pub fn create_experiment(&self, name: &str, dataset: &Dataset) -> MLExperiment {
        let dataset_info = self.extract_dataset_info(dataset);

        MLExperiment {
            name: name.to_string(),
            dataset_info,
            model_config: ModelConfig {
                model_type: "undefined".to_string(),
                hyperparameters: HashMap::new(),
                preprocessing_steps: Vec::new(),
            },
            results: ExperimentResults {
                training_score: 0.0,
                validation_score: 0.0,
                test_score: None,
                training_time: 0.0,
                inference_time: None,
                feature_importance: None,
            },
            cv_scores: None,
        }
    }

    /// Evaluate model performance with cross-validation
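    ///
    /// The closure receives `(x_train, y_train, x_val, y_val)` and must return
    /// `(train_score, validation_score, training_time)`. A minimal sketch with
    /// a stand-in "model" that reports fixed scores (real callers would fit and
    /// score an actual model here):
    ///
    /// ```rust,ignore
    /// let results = pipeline.evaluate_with_cv(&dataset, |_x_tr, _y_tr, _x_val, _y_val| {
    ///     Ok((1.0, 0.9, 0.01))
    /// })?;
    /// println!("CV score: {:.3} +/- {:.3}", results.mean_score, results.std_score);
    /// ```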
    pub fn evaluate_with_cv<F>(
        &self,
        dataset: &Dataset,
        train_fn: F,
    ) -> Result<CrossValidationResults>
    where
        F: Fn(&Array2<f64>, &Array1<f64>, &Array2<f64>, &Array1<f64>) -> Result<(f64, f64, f64)>,
    {
        let folds = self.cross_validation_split(dataset)?;
        let mut scores = Vec::new();
        let mut fold_details = Vec::new();

        for (fold_idx, (train_indices, val_indices)) in folds.into_iter().enumerate() {
            let x_train = dataset.data.select(Axis(0), &train_indices);
            let x_val = dataset.data.select(Axis(0), &val_indices);

            let target = dataset
                .target
                .as_ref()
                .expect("target presence was verified by cross_validation_split");
            let y_train = target.select(Axis(0), &train_indices);
            let y_val = target.select(Axis(0), &val_indices);

            let (train_score, val_score, training_time) =
                train_fn(&x_train, &y_train, &x_val, &y_val)?;

            scores.push(val_score);
            fold_details.push(FoldResult {
                fold_index: fold_idx,
                train_score,
                validation_score: val_score,
                training_time,
            });
        }

        let mean_score = scores.iter().sum::<f64>() / scores.len() as f64;
        let variance = scores
            .iter()
            .map(|score| (score - mean_score).powi(2))
            .sum::<f64>()
            / scores.len() as f64;
        let std_score = variance.sqrt();

        Ok(CrossValidationResults {
            scores,
            mean_score,
            std_score,
            fold_details,
        })
    }

    // Private helper methods

    fn apply_balancing(&self, dataset: &Dataset, strategy: &BalancingStrategy) -> Result<Dataset> {
        // Simplified balancing implementation; a full implementation would
        // delegate to the crate's balancing utilities. The configured random
        // state is threaded through so resampling is reproducible.
        match strategy {
            BalancingStrategy::RandomUndersample => {
                self.random_undersample(dataset, self.config.random_state)
            }
            BalancingStrategy::RandomOversample => {
                self.random_oversample(dataset, self.config.random_state)
            }
            _ => Ok(dataset.clone()), // Placeholder for other strategies
        }
    }

    fn random_undersample(&self, dataset: &Dataset, randomstate: Option<u64>) -> Result<Dataset> {
        use scirs2_core::random::{rngs::StdRng, RngCore, SeedableRng};

        let target = dataset.target.as_ref().ok_or_else(|| {
            DatasetsError::InvalidFormat("Target required for balancing".to_string())
        })?;

        // Find minority class size
        let mut class_counts: HashMap<i64, usize> = HashMap::new();
        for &value in target.iter() {
            if !value.is_nan() {
                *class_counts.entry(value as i64).or_insert(0) += 1;
            }
        }

        let min_count = class_counts.values().min().copied().unwrap_or(0);

        // Create RNG (seeded for reproducibility if requested)
        let mut rng: Box<dyn RngCore> = match randomstate {
            Some(seed) => Box::new(StdRng::seed_from_u64(seed)),
            None => Box::new(thread_rng()),
        };

        // Sample min_count samples from each class
        let mut selected_indices = Vec::new();

        for (class, _count) in class_counts {
            let mut class_indices: Vec<usize> = target
                .iter()
                .enumerate()
                .filter(|(_, &val)| !val.is_nan() && val as i64 == class)
                .map(|(idx, _)| idx)
                .collect();

            // Shuffle before truncating so the kept samples are a uniform
            // random subset rather than just the first min_count indices.
            if class_indices.len() > min_count {
                class_indices.shuffle(&mut *rng);
                class_indices.truncate(min_count);
            }

            selected_indices.extend(class_indices);
        }

        let balanced_data = dataset.data.select(Axis(0), &selected_indices);
        let balanced_target = target.select(Axis(0), &selected_indices);

        Ok(Dataset {
            data: balanced_data,
            target: Some(balanced_target),
            featurenames: dataset.featurenames.clone(),
            targetnames: dataset.targetnames.clone(),
            feature_descriptions: dataset.feature_descriptions.clone(),
            description: Some("Undersampled dataset".to_string()),
            metadata: dataset.metadata.clone(),
        })
    }

    fn random_oversample(&self, dataset: &Dataset, randomstate: Option<u64>) -> Result<Dataset> {
        use scirs2_core::random::{rngs::StdRng, RngCore, SeedableRng};

        let target = dataset.target.as_ref().ok_or_else(|| {
            DatasetsError::InvalidFormat("Random oversampling requires target labels".to_string())
        })?;

        if target.len() != dataset.data.nrows() {
            return Err(DatasetsError::InvalidFormat(
                "Target length must match number of samples".to_string(),
            ));
        }

        // Count samples per class
        let mut class_counts: HashMap<i32, usize> = HashMap::new();
        let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();

        for (idx, &label) in target.iter().enumerate() {
            let class = label as i32;
            *class_counts.entry(class).or_insert(0) += 1;
            class_indices.entry(class).or_default().push(idx);
        }

        // Find the majority class size (the maximum count)
        let max_count = class_counts.values().max().copied().unwrap_or(0);

        if max_count == 0 {
            return Err(DatasetsError::InvalidFormat(
                "No samples found in dataset".to_string(),
            ));
        }

        // Create RNG (seeded for reproducibility if requested)
        let mut rng: Box<dyn RngCore> = match randomstate {
            Some(seed) => Box::new(StdRng::seed_from_u64(seed)),
            None => Box::new(thread_rng()),
        };

        // Collect all indices for the oversampled dataset
        let mut all_indices = Vec::new();

        for (_class, indices) in class_indices.iter() {
            let current_count = indices.len();

            // Add all original samples
            all_indices.extend(indices.iter().copied());

            // Add additional samples by random oversampling with replacement
            let samples_needed = max_count - current_count;

            if samples_needed > 0 {
                for _ in 0..samples_needed {
                    let random_idx = rng
                        .sample(Uniform::new(0, indices.len()).expect("non-empty sampling range"));
                    all_indices.push(indices[random_idx]);
                }
            }
        }

        // Shuffle the final indices to mix classes
        all_indices.shuffle(&mut *rng);

        // Create the oversampled dataset
        let oversampled_data = dataset.data.select(Axis(0), &all_indices);
        let oversampled_target = target.select(Axis(0), &all_indices);

        Ok(Dataset {
            data: oversampled_data,
            target: Some(oversampled_target),
            featurenames: dataset.featurenames.clone(),
            targetnames: dataset.targetnames.clone(),
            feature_descriptions: dataset.feature_descriptions.clone(),
            description: Some(format!(
                "Random oversampled dataset (original: {} samples, oversampled: {} samples)",
                dataset.n_samples(),
                all_indices.len()
            )),
            metadata: dataset.metadata.clone(),
        })
    }

    fn fit_and_transform_scaling(
        &mut self,
        dataset: &Dataset,
        method: ScalingMethod,
    ) -> Result<Dataset> {
        let mut scalers = HashMap::new();
        let mut scaled_data = dataset.data.clone();

        for (col_idx, mut column) in scaled_data.columns_mut().into_iter().enumerate() {
            let feature_name = dataset
                .featurenames
                .as_ref()
                .and_then(|names| names.get(col_idx))
                .cloned()
                .unwrap_or_else(|| format!("feature_{col_idx}"));

            let column_view = column.view();
            let scaler_params = Self::fit_scaler(&column_view, method)?;
            Self::apply_scaler_to_column(&mut column, &scaler_params)?;

            scalers.insert(feature_name, scaler_params);
        }

        self.fitted_scalers = Some(scalers);

        Ok(Dataset {
            data: scaled_data,
            target: dataset.target.clone(),
            featurenames: dataset.featurenames.clone(),
            targetnames: dataset.targetnames.clone(),
            feature_descriptions: dataset.feature_descriptions.clone(),
            description: Some("Scaled dataset".to_string()),
            metadata: dataset.metadata.clone(),
        })
    }

    fn fit_scaler(
        column: &scirs2_core::ndarray::ArrayView1<f64>,
        method: ScalingMethod,
    ) -> Result<ScalerParams> {
        let values: Vec<f64> = column.iter().copied().filter(|x| !x.is_nan()).collect();

        if values.is_empty() {
            return Ok(ScalerParams {
                method,
                mean: None,
                std: None,
                min: None,
                max: None,
                median: None,
                mad: None,
            });
        }

        match method {
            ScalingMethod::StandardScaler => {
                let mean = values.iter().sum::<f64>() / values.len() as f64;
                let variance =
                    values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / values.len() as f64;
                let std = variance.sqrt();

                Ok(ScalerParams {
                    method,
                    mean: Some(mean),
                    std: Some(std),
                    min: None,
                    max: None,
                    median: None,
                    mad: None,
                })
            }
            ScalingMethod::MinMaxScaler => {
                let min = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
                let max = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));

                Ok(ScalerParams {
                    method,
                    mean: None,
                    std: None,
                    min: Some(min),
                    max: Some(max),
                    median: None,
                    mad: None,
                })
            }
            ScalingMethod::RobustScaler => {
                let mut sorted_values = values.clone();
                // NaNs were filtered above, so partial_cmp cannot fail here.
                sorted_values.sort_by(|a, b| a.partial_cmp(b).expect("no NaNs after filtering"));

                let median = Self::percentile(&sorted_values, 0.5).unwrap_or(0.0);
                let mad = Self::compute_mad(&sorted_values, median);

                Ok(ScalerParams {
                    method,
                    mean: None,
                    std: None,
                    min: None,
                    max: None,
                    median: Some(median),
                    mad: Some(mad),
                })
            }
            ScalingMethod::None => Ok(ScalerParams {
                method,
                mean: None,
                std: None,
                min: None,
                max: None,
                median: None,
                mad: None,
            }),
        }
    }

    fn apply_scaler_to_column(
        column: &mut scirs2_core::ndarray::ArrayViewMut1<f64>,
        params: &ScalerParams,
    ) -> Result<()> {
        match params.method {
            ScalingMethod::StandardScaler => {
                if let (Some(mean), Some(std)) = (params.mean, params.std) {
                    if std > 1e-8 {
                        // Avoid division by zero
                        for value in column.iter_mut() {
                            if !value.is_nan() {
                                *value = (*value - mean) / std;
                            }
                        }
                    }
                }
            }
            ScalingMethod::MinMaxScaler => {
                if let (Some(min), Some(max)) = (params.min, params.max) {
                    let range = max - min;
                    if range > 1e-8 {
                        // Avoid division by zero
                        for value in column.iter_mut() {
                            if !value.is_nan() {
                                *value = (*value - min) / range;
                            }
                        }
                    }
                }
            }
            ScalingMethod::RobustScaler => {
                if let (Some(median), Some(mad)) = (params.median, params.mad) {
                    if mad > 1e-8 {
                        // Avoid division by zero
                        for value in column.iter_mut() {
                            if !value.is_nan() {
                                *value = (*value - median) / mad;
                            }
                        }
                    }
                }
            }
            ScalingMethod::None => {
                // No scaling applied
            }
        }

        Ok(())
    }

    fn percentile(sorted_values: &[f64], p: f64) -> Option<f64> {
        if sorted_values.is_empty() {
            return None;
        }

        let index = p * (sorted_values.len() - 1) as f64;
        let lower = index.floor() as usize;
        let upper = index.ceil() as usize;

        if lower == upper {
            Some(sorted_values[lower])
        } else {
            let weight = index - lower as f64;
            Some(sorted_values[lower] * (1.0 - weight) + sorted_values[upper] * weight)
        }
    }
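
    // Worked example of the linear interpolation above: for sorted values
    // [1.0, 2.0, 3.0, 4.0] and p = 0.5, index = 1.5, so the result is
    // 2.0 * 0.5 + 3.0 * 0.5 = 2.5.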

    fn compute_mad(sorted_values: &[f64], median: f64) -> f64 {
        let deviations: Vec<f64> = sorted_values.iter().map(|&x| (x - median).abs()).collect();

        let mut sorted_deviations = deviations;
        // Input values are NaN-free, so partial_cmp cannot fail here.
        sorted_deviations.sort_by(|a, b| a.partial_cmp(b).expect("no NaNs after filtering"));

        Self::percentile(&sorted_deviations, 0.5).unwrap_or(1.0)
    }

    fn generate_split_indices(
        &self,
        n_samples: usize,
        target: Option<&Array1<f64>>,
    ) -> Result<Vec<usize>> {
        let mut indices: Vec<usize> = (0..n_samples).collect();

        // Use proper random shuffling based on configuration
        if let (true, Some(t)) = (self.config.stratify, target) {
            // Stratified shuffling keeps class proportions intact
            self.stratified_shuffle(&mut indices, t)?;
        } else {
            // Regular shuffling with optional random state
            match self.config.random_state {
                Some(seed) => {
                    let mut rng = StdRng::seed_from_u64(seed);
                    indices.shuffle(&mut rng);
                }
                None => {
                    let mut rng = thread_rng();
                    indices.shuffle(&mut rng);
                }
            }
        }

        Ok(indices)
    }

    /// Perform stratified shuffling to maintain class proportions
    fn stratified_shuffle(&self, indices: &mut Vec<usize>, target: &Array1<f64>) -> Result<()> {
        // Group indices by class
        let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();

        for &idx in indices.iter() {
            let class = target[idx] as i32;
            class_indices.entry(class).or_default().push(idx);
        }

        // Shuffle each class group with a single RNG; reseeding per group
        // would give every class the same permutation pattern.
        match self.config.random_state {
            Some(seed) => {
                let mut rng = StdRng::seed_from_u64(seed);
                for class_group in class_indices.values_mut() {
                    class_group.shuffle(&mut rng);
                }
            }
            None => {
                let mut rng = thread_rng();
                for class_group in class_indices.values_mut() {
                    class_group.shuffle(&mut rng);
                }
            }
        }

        // Recombine shuffled class groups while maintaining order
        indices.clear();
        let mut class_iterators: HashMap<i32, std::vec::IntoIter<usize>> = class_indices
            .into_iter()
            .map(|(class, group)| (class, group.into_iter()))
            .collect();

        // Interleave samples from different classes to maintain distribution
        while !class_iterators.is_empty() {
            let mut to_remove = Vec::new();
            for (&class, iterator) in class_iterators.iter_mut() {
                if let Some(idx) = iterator.next() {
                    indices.push(idx);
                } else {
                    to_remove.push(class);
                }
            }
            for class in to_remove {
                class_iterators.remove(&class);
            }
        }

        Ok(())
    }

    fn extract_dataset_info(&self, dataset: &Dataset) -> DatasetInfo {
        let n_samples = dataset.n_samples();
        let n_features = dataset.n_features();

        let (n_classes, class_distribution) = if let Some(ref target) = dataset.target {
            let mut class_counts: HashMap<String, usize> = HashMap::new();
            for &value in target.iter() {
                if !value.is_nan() {
                    let class_name = format!("{value:.0}");
                    *class_counts.entry(class_name).or_insert(0) += 1;
                }
            }

            let n_classes = class_counts.len();
            (Some(n_classes), Some(class_counts))
        } else {
            (None, None)
        };

        // Calculate missing data percentage
        let total_values = n_samples * n_features;
        let missing_values = dataset.data.iter().filter(|&&x| x.is_nan()).count();
        let missing_data_percentage = missing_values as f64 / total_values as f64 * 100.0;

        DatasetInfo {
            n_samples,
            n_features,
            n_classes,
            class_distribution,
            missing_data_percentage,
        }
    }
}

/// Convenience functions for ML pipeline integration
pub mod convenience {
    use super::*;

    /// Quick train/test split with default configuration
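    ///
    /// A minimal sketch, assuming `dataset` carries a target column:
    ///
    /// ```rust,ignore
    /// let split = convenience::train_test_split(&dataset, Some(0.3))?;
    /// assert_eq!(split.x_train.nrows() + split.x_test.nrows(), dataset.n_samples());
    /// ```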
    pub fn train_test_split(dataset: &Dataset, test_size: Option<f64>) -> Result<DataSplit> {
        let mut config = MLPipelineConfig::default();
        if let Some(size) = test_size {
            config.test_size = size;
        }

        let pipeline = MLPipeline::new(config);
        pipeline.train_test_split(dataset)
    }

    /// Prepare dataset for ML with standard preprocessing
    pub fn prepare_for_ml(dataset: &Dataset, scale: bool, balance: bool) -> Result<Dataset> {
        let mut config = MLPipelineConfig::default();

        if !scale {
            config.scaling_method = None;
        }

        if balance {
            config.balancing_strategy = Some(BalancingStrategy::RandomUndersample);
        }

        let mut pipeline = MLPipeline::new(config);
        pipeline.prepare_dataset(dataset)
    }

    /// Generate cross-validation folds
    pub fn cv_split(
        dataset: &Dataset,
        n_folds: Option<usize>,
        stratify: Option<bool>,
    ) -> Result<CrossValidationFolds> {
        let mut config = MLPipelineConfig::default();

        if let Some(folds) = n_folds {
            config.cv_folds = folds;
        }

        if let Some(strat) = stratify {
            config.stratify = strat;
        }

        let pipeline = MLPipeline::new(config);
        pipeline.cross_validation_split(dataset)
    }

    /// Create a simple ML experiment
    pub fn create_experiment(name: &str, dataset: &Dataset) -> MLExperiment {
        let pipeline = MLPipeline::default();
        pipeline.create_experiment(name, dataset)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::make_classification;

    #[test]
    fn test_ml_pipeline_creation() {
        let pipeline = MLPipeline::default();
        assert_eq!(pipeline.config.test_size, 0.2);
        assert_eq!(pipeline.config.cv_folds, 5);
    }

    #[test]
    fn test_train_test_split() {
        let dataset =
            make_classification(100, 5, 2, 1, 1, Some(42)).expect("dataset generation failed");
        let split = convenience::train_test_split(&dataset, Some(0.3)).expect("split failed");

        assert_eq!(split.x_train.nrows() + split.x_test.nrows(), 100);
        assert_eq!(split.y_train.len() + split.y_test.len(), 100);
        assert_eq!(split.x_train.ncols(), 5);
        assert_eq!(split.x_test.ncols(), 5);
    }

    #[test]
    fn test_cross_validation_split() {
        let dataset =
            make_classification(100, 3, 2, 1, 1, Some(42)).expect("dataset generation failed");
        let folds = convenience::cv_split(&dataset, Some(5), Some(true)).expect("cv split failed");

        assert_eq!(folds.len(), 5);

        // Each fold partitions all 100 samples into train + test, so summing
        // over the 5 folds and dividing by 5 should recover the sample count.
        let total_samples: usize = folds
            .iter()
            .map(|(train, test)| train.len() + test.len())
            .sum::<usize>()
            / 5;

        assert_eq!(total_samples, 100);
    }

    #[test]
    fn test_dataset_preparation() {
        let dataset =
            make_classification(50, 4, 2, 1, 1, Some(42)).expect("dataset generation failed");
        let prepared =
            convenience::prepare_for_ml(&dataset, true, false).expect("preparation failed");

        assert_eq!(prepared.n_samples(), dataset.n_samples());
        assert_eq!(prepared.n_features(), dataset.n_features());
    }

    #[test]
    fn test_experiment_creation() {
        let dataset =
            make_classification(100, 5, 2, 1, 1, Some(42)).expect("dataset generation failed");
        let experiment = convenience::create_experiment("test_experiment", &dataset);

        assert_eq!(experiment.name, "test_experiment");
        assert_eq!(experiment.dataset_info.n_samples, 100);
        assert_eq!(experiment.dataset_info.n_features, 5);
        assert_eq!(experiment.dataset_info.n_classes, Some(2));
    }

    #[test]
    fn test_scaler_fitting() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let array = Array1::from_vec(data);
        let view = array.view();

        let scaler_params =
            MLPipeline::fit_scaler(&view, ScalingMethod::StandardScaler).expect("fit failed");

        assert!(scaler_params.mean.is_some());
        assert!(scaler_params.std.is_some());
        assert_eq!(scaler_params.mean.expect("mean missing"), 3.0);
    }
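
    // A companion check for the robust scaler, derived from fit_scaler above:
    // for [1, 2, 3, 4, 5] the median is 3.0 and the MAD (the median of
    // |x - 3|, i.e. of [0, 1, 1, 2, 2]) is 1.0.
    #[test]
    fn test_robust_scaler() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let array = Array1::from_vec(data);
        let view = array.view();

        let scaler_params =
            MLPipeline::fit_scaler(&view, ScalingMethod::RobustScaler).expect("fit failed");

        assert_eq!(scaler_params.median, Some(3.0));
        assert_eq!(scaler_params.mad, Some(1.0));
    }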

    #[test]
    fn test_min_max_scaler() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let array = Array1::from_vec(data);
        let view = array.view();

        let scaler_params =
            MLPipeline::fit_scaler(&view, ScalingMethod::MinMaxScaler).expect("fit failed");

        assert!(scaler_params.min.is_some());
        assert!(scaler_params.max.is_some());
        assert_eq!(scaler_params.min.expect("min missing"), 1.0);
        assert_eq!(scaler_params.max.expect("max missing"), 5.0);
    }
}
976}