scirs2_datasets/
ml_integration.rs

1//! Machine learning pipeline integration
2//!
3//! This module provides integration utilities for common ML frameworks and pipelines:
4//! - Model training data preparation
5//! - Cross-validation utilities
6//! - Feature engineering pipelines
7//! - Model evaluation and metrics
8//! - Integration with popular ML libraries
9
10use std::collections::HashMap;
11
12use ndarray::{Array1, Array2, Axis};
13use rand::rngs::StdRng;
14use rand::seq::SliceRandom;
15use rand::{rng, SeedableRng};
16use rand_distr::Uniform;
17use serde::{Deserialize, Serialize};
18
19use crate::error::{DatasetsError, Result};
20use crate::utils::{BalancingStrategy, CrossValidationFolds, Dataset};
21
/// Configuration for ML pipeline integration.
///
/// Controls how [`MLPipeline`] splits, balances and scales datasets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLPipelineConfig {
    /// Random seed for reproducibility (`None` draws fresh randomness on every call)
    pub random_state: Option<u64>,
    /// Fraction of samples placed in the test set by train/test splits
    /// (the sample count is `n_samples * test_size`, rounded down)
    pub test_size: f64,
    /// Number of folds for cross-validation
    pub cv_folds: usize,
    /// Whether to stratify by target class, both for train/test splits and for
    /// cross-validation folds
    pub stratify: bool,
    /// Data balancing strategy applied by `MLPipeline::prepare_dataset` (`None` disables balancing)
    pub balancing_strategy: Option<BalancingStrategy>,
    /// Feature scaling method fitted by `MLPipeline::prepare_dataset` (`None` disables scaling)
    pub scaling_method: Option<ScalingMethod>,
}
38
39impl Default for MLPipelineConfig {
40    fn default() -> Self {
41        Self {
42            random_state: Some(42),
43            test_size: 0.2,
44            cv_folds: 5,
45            stratify: true,
46            balancing_strategy: None,
47            scaling_method: Some(ScalingMethod::StandardScaler),
48        }
49    }
50}
51
52/// Feature scaling methods
53#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
54pub enum ScalingMethod {
55    /// Z-score normalization
56    StandardScaler,
57    /// Min-max scaling to [0, 1]
58    MinMaxScaler,
59    /// Robust scaling using median and MAD
60    RobustScaler,
61    /// No scaling
62    None,
63}
64
/// ML pipeline for data preprocessing and preparation.
///
/// Holds the [`MLPipelineConfig`]. After `prepare_dataset` has fitted a scaling
/// method, it also holds the per-feature scaler parameters that `transform` reuses
/// on new data.
pub struct MLPipeline {
    /// Pipeline configuration.
    config: MLPipelineConfig,
    /// Fitted scaler parameters keyed by feature name (`feature_{i}` for columns
    /// without a name). `None` until scaling has been fitted.
    fitted_scalers: Option<HashMap<String, ScalerParams>>,
}
70
/// Parameters of a scaler fitted to a single feature column.
///
/// Only the fields relevant to `method` are populated; the rest are `None`.
/// Statistics are computed over the column's non-NaN values. Every statistic is
/// `None` when the column has no non-NaN values, and applying such a scaler
/// leaves the column unchanged.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalerParams {
    /// Scaling method used
    pub method: ScalingMethod,
    /// Mean value (for StandardScaler)
    pub mean: Option<f64>,
    /// Population standard deviation, i.e. the variance is divided by `n` (for StandardScaler)
    pub std: Option<f64>,
    /// Minimum value (for MinMaxScaler)
    pub min: Option<f64>,
    /// Maximum value (for MinMaxScaler)
    pub max: Option<f64>,
    /// Median value (for RobustScaler)
    pub median: Option<f64>,
    /// Median absolute deviation from `median`, without a consistency factor (for RobustScaler)
    pub mad: Option<f64>,
}
89
/// ML experiment tracking record.
///
/// Bundles dataset statistics, model configuration and results for one experiment.
/// `MLPipeline::create_experiment` fills in `dataset_info`. It leaves the model
/// configuration and results as placeholders for the caller to complete.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLExperiment {
    /// Experiment name
    pub name: String,
    /// Dataset information
    pub dataset_info: DatasetInfo,
    /// Model configuration
    pub model_config: ModelConfig,
    /// Training results
    pub results: ExperimentResults,
    /// Cross-validation scores (`None` until cross-validation has been recorded)
    pub cv_scores: Option<CrossValidationResults>,
}
104
/// Dataset information for experiments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetInfo {
    /// Number of samples in the dataset
    pub n_samples: usize,
    /// Number of features in the dataset
    pub n_features: usize,
    /// Number of classes (for classification tasks); `None` when the dataset has no target
    pub n_classes: Option<usize>,
    /// Distribution of classes in the dataset: sample count per class label, with the
    /// label formatted without decimals (NaN targets are not counted)
    pub class_distribution: Option<HashMap<String, usize>>,
    /// Percentage (0-100) of NaN entries in the feature matrix
    pub missing_data_percentage: f64,
}
119
/// Model configuration recorded alongside an experiment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Type of ML model used (free-form label)
    pub model_type: String,
    /// Hyperparameter settings, keyed by name, stored as arbitrary JSON values
    pub hyperparameters: HashMap<String, serde_json::Value>,
    /// List of preprocessing steps applied, in order
    pub preprocessing_steps: Vec<String>,
}
130
/// Experiment results.
///
/// Values are supplied by the caller. `MLPipeline::create_experiment` initialises
/// the scores and times to `0.0` and the optional fields to `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExperimentResults {
    /// Score on training data
    pub training_score: f64,
    /// Score on validation data
    pub validation_score: f64,
    /// Score on test data (if available)
    pub test_score: Option<f64>,
    /// Time taken for training (in seconds)
    pub training_time: f64,
    /// Average inference time per sample (in milliseconds)
    pub inference_time: Option<f64>,
    /// Feature importance scores as `(feature name, importance)` pairs
    pub feature_importance: Option<Vec<(String, f64)>>,
}
147
/// Cross-validation results, as produced by `MLPipeline::evaluate_with_cv`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossValidationResults {
    /// Validation score of each fold, in fold order
    pub scores: Vec<f64>,
    /// Mean score across all folds
    pub mean_score: f64,
    /// Population standard deviation of `scores` (the variance is divided by the number of folds)
    pub std_score: f64,
    /// Detailed results for each fold
    pub fold_details: Vec<FoldResult>,
}
160
/// Result for a single cross-validation fold.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FoldResult {
    /// Zero-based index of the fold
    pub fold_index: usize,
    /// Training score for this fold, as reported by the training callback
    pub train_score: f64,
    /// Validation score for this fold, as reported by the training callback
    pub validation_score: f64,
    /// Training time in seconds for this fold, as reported by the training callback
    pub training_time: f64,
}
173
/// Data split for ML training.
///
/// Row `i` of `x_train` corresponds to `y_train[i]`, and row `i` of `x_test`
/// corresponds to `y_test[i]`.
#[derive(Debug, Clone)]
pub struct DataSplit {
    /// Training features
    pub x_train: Array2<f64>,
    /// Testing features
    pub x_test: Array2<f64>,
    /// Training targets
    pub y_train: Array1<f64>,
    /// Testing targets
    pub y_test: Array1<f64>,
}
186
187impl Default for MLPipeline {
188    fn default() -> Self {
189        Self::new(MLPipelineConfig::default())
190    }
191}
192
193impl MLPipeline {
194    /// Create a new ML pipeline
195    pub fn new(config: MLPipelineConfig) -> Self {
196        Self {
197            config,
198            fitted_scalers: None,
199        }
200    }
201
202    /// Prepare dataset for ML training
203    pub fn prepare_dataset(&mut self, dataset: &Dataset) -> Result<Dataset> {
204        let mut prepared = dataset.clone();
205
206        // Apply balancing if specified
207        if let Some(ref strategy) = self.config.balancing_strategy {
208            prepared = self.apply_balancing(&prepared, strategy)?;
209        }
210
211        // Apply scaling if specified
212        if let Some(method) = self.config.scaling_method {
213            prepared = self.fit_and_transform_scaling(&prepared, method)?;
214        }
215
216        Ok(prepared)
217    }
218
219    /// Split dataset into train/test sets
220    pub fn train_test_split(&self, dataset: &Dataset) -> Result<DataSplit> {
221        let n_samples = dataset.n_samples();
222        let test_samples = (n_samples as f64 * self.config.test_size) as usize;
223        let train_samples = n_samples - test_samples;
224
225        let indices = self.generate_split_indices(n_samples, dataset.target.as_ref())?;
226
227        let train_indices = &indices[..train_samples];
228        let test_indices = &indices[train_samples..];
229
230        let x_train = dataset.data.select(Axis(0), train_indices);
231        let x_test = dataset.data.select(Axis(0), test_indices);
232
233        let (y_train, y_test) = if let Some(ref target) = dataset.target {
234            let y_train = target.select(Axis(0), train_indices);
235            let y_test = target.select(Axis(0), test_indices);
236            (y_train, y_test)
237        } else {
238            return Err(DatasetsError::InvalidFormat(
239                "Target variable required for train/test split".to_string(),
240            ));
241        };
242
243        Ok(DataSplit {
244            x_train,
245            x_test,
246            y_train,
247            y_test,
248        })
249    }
250
251    /// Generate cross-validation folds
252    pub fn cross_validation_split(&self, dataset: &Dataset) -> Result<CrossValidationFolds> {
253        let target = dataset.target.as_ref().ok_or_else(|| {
254            DatasetsError::InvalidFormat(
255                "Target variable required for cross-validation".to_string(),
256            )
257        })?;
258
259        if self.config.stratify {
260            crate::utils::stratified_k_fold_split(
261                target,
262                self.config.cv_folds,
263                true,
264                self.config.random_state,
265            )
266        } else {
267            crate::utils::k_fold_split(
268                dataset.n_samples(),
269                self.config.cv_folds,
270                true,
271                self.config.random_state,
272            )
273        }
274    }
275
276    /// Transform new data using fitted scalers
277    pub fn transform(&self, dataset: &Dataset) -> Result<Dataset> {
278        let scalers = self.fitted_scalers.as_ref().ok_or_else(|| {
279            DatasetsError::InvalidFormat(
280                "Pipeline not fitted. Call prepare_dataset first.".to_string(),
281            )
282        })?;
283
284        let mut transformed_data = dataset.data.clone();
285
286        for (col_idx, mut column) in transformed_data.columns_mut().into_iter().enumerate() {
287            let defaultname = format!("feature_{col_idx}");
288            let featurename = dataset
289                .featurenames
290                .as_ref()
291                .and_then(|names| names.get(col_idx))
292                .map(|s| s.as_str())
293                .unwrap_or(&defaultname);
294
295            if let Some(scaler) = scalers.get(featurename) {
296                Self::apply_scaler_to_column(&mut column, scaler)?;
297            }
298        }
299
300        Ok(Dataset {
301            data: transformed_data,
302            target: dataset.target.clone(),
303            featurenames: dataset.featurenames.clone(),
304            targetnames: dataset.targetnames.clone(),
305            feature_descriptions: dataset.feature_descriptions.clone(),
306            description: Some("Transformed dataset".to_string()),
307            metadata: dataset.metadata.clone(),
308        })
309    }
310
311    /// Create an ML experiment tracker
312    pub fn create_experiment(&self, name: &str, dataset: &Dataset) -> MLExperiment {
313        let dataset_info = self.extract_dataset_info(dataset);
314
315        MLExperiment {
316            name: name.to_string(),
317            dataset_info,
318            model_config: ModelConfig {
319                model_type: "undefined".to_string(),
320                hyperparameters: HashMap::new(),
321                preprocessing_steps: Vec::new(),
322            },
323            results: ExperimentResults {
324                training_score: 0.0,
325                validation_score: 0.0,
326                test_score: None,
327                training_time: 0.0,
328                inference_time: None,
329                feature_importance: None,
330            },
331            cv_scores: None,
332        }
333    }
334
335    /// Evaluate model performance with cross-validation
336    pub fn evaluate_with_cv<F>(
337        &self,
338        dataset: &Dataset,
339        train_fn: F,
340    ) -> Result<CrossValidationResults>
341    where
342        F: Fn(&Array2<f64>, &Array1<f64>, &Array2<f64>, &Array1<f64>) -> Result<(f64, f64, f64)>,
343    {
344        let folds = self.cross_validation_split(dataset)?;
345        let mut scores = Vec::new();
346        let mut fold_details = Vec::new();
347
348        for (fold_idx, (train_indices, val_indices)) in folds.into_iter().enumerate() {
349            let x_train = dataset.data.select(Axis(0), &train_indices);
350            let x_val = dataset.data.select(Axis(0), &val_indices);
351
352            let target = dataset.target.as_ref().unwrap();
353            let y_train = target.select(Axis(0), &train_indices);
354            let y_val = target.select(Axis(0), &val_indices);
355
356            let (train_score, val_score, training_time) =
357                train_fn(&x_train, &y_train, &x_val, &y_val)?;
358
359            scores.push(val_score);
360            fold_details.push(FoldResult {
361                fold_index: fold_idx,
362                train_score,
363                validation_score: val_score,
364                training_time,
365            });
366        }
367
368        let mean_score = scores.iter().sum::<f64>() / scores.len() as f64;
369        let variance = scores
370            .iter()
371            .map(|score| (score - mean_score).powi(2))
372            .sum::<f64>()
373            / scores.len() as f64;
374        let std_score = variance.sqrt();
375
376        Ok(CrossValidationResults {
377            scores,
378            mean_score,
379            std_score,
380            fold_details,
381        })
382    }
383
384    // Private helper methods
385
386    fn apply_balancing(&self, dataset: &Dataset, strategy: &BalancingStrategy) -> Result<Dataset> {
387        // Simplified balancing implementation
388        // In a full implementation, you'd use the actual balancing utilities
389        match strategy {
390            BalancingStrategy::RandomUndersample => self.random_undersample(dataset, None),
391            BalancingStrategy::RandomOversample => self.random_oversample(dataset, None),
392            _ => Ok(dataset.clone()), // Placeholder for other strategies
393        }
394    }
395
396    fn random_undersample(&self, dataset: &Dataset, _randomstate: Option<u64>) -> Result<Dataset> {
397        let target = dataset.target.as_ref().ok_or_else(|| {
398            DatasetsError::InvalidFormat("Target required for balancing".to_string())
399        })?;
400
401        // Find minority class size
402        let mut class_counts: HashMap<i64, usize> = HashMap::new();
403        for &value in target.iter() {
404            if !value.is_nan() {
405                *class_counts.entry(value as i64).or_insert(0) += 1;
406            }
407        }
408
409        let min_count = class_counts.values().min().copied().unwrap_or(0);
410
411        // Sample min_count samples from each class
412        let mut selected_indices = Vec::new();
413
414        for (class_, _count) in class_counts {
415            let class_indices: Vec<usize> = target
416                .iter()
417                .enumerate()
418                .filter(|(_, &val)| !val.is_nan() && val as i64 == class_)
419                .map(|(idx, _)| idx)
420                .collect();
421
422            let mut sampled_indices = class_indices;
423            if sampled_indices.len() > min_count {
424                // Simple random sampling (in a real implementation, use proper random sampling)
425                sampled_indices.truncate(min_count);
426            }
427
428            selected_indices.extend(sampled_indices);
429        }
430
431        let balanced_data = dataset.data.select(Axis(0), &selected_indices);
432        let balanced_target = target.select(Axis(0), &selected_indices);
433
434        Ok(Dataset {
435            data: balanced_data,
436            target: Some(balanced_target),
437            featurenames: dataset.featurenames.clone(),
438            targetnames: dataset.targetnames.clone(),
439            feature_descriptions: dataset.feature_descriptions.clone(),
440            description: Some("Undersampled dataset".to_string()),
441            metadata: dataset.metadata.clone(),
442        })
443    }
444
445    fn random_oversample(&self, dataset: &Dataset, randomstate: Option<u64>) -> Result<Dataset> {
446        use rand::prelude::*;
447        use rand::{rngs::StdRng, RngCore, SeedableRng};
448        use std::collections::HashMap;
449
450        let target = dataset.target.as_ref().ok_or_else(|| {
451            DatasetsError::InvalidFormat("Random oversampling requires target labels".to_string())
452        })?;
453
454        if target.len() != dataset.data.nrows() {
455            return Err(DatasetsError::InvalidFormat(
456                "Target length must match number of samples".to_string(),
457            ));
458        }
459
460        // Count samples per class
461        let mut class_counts: HashMap<i32, usize> = HashMap::new();
462        let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
463
464        for (idx, &label) in target.iter().enumerate() {
465            let class = label as i32;
466            *class_counts.entry(class).or_insert(0) += 1;
467            class_indices.entry(class).or_default().push(idx);
468        }
469
470        // Find the majority class size (the maximum count)
471        let max_count = class_counts.values().max().copied().unwrap_or(0);
472
473        if max_count == 0 {
474            return Err(DatasetsError::InvalidFormat(
475                "No samples found in dataset".to_string(),
476            ));
477        }
478
479        // Create RNG
480        let mut rng: Box<dyn RngCore> = match randomstate {
481            Some(seed) => Box::new(StdRng::seed_from_u64(seed)),
482            None => Box::new(rng()),
483        };
484
485        // Collect all indices for the oversampled dataset
486        let mut all_indices = Vec::new();
487
488        for (_class, indices) in class_indices.iter() {
489            let current_count = indices.len();
490
491            // Add all original samples
492            all_indices.extend(indices.iter().copied());
493
494            // Add additional samples by random oversampling with replacement
495            let samples_needed = max_count - current_count;
496
497            if samples_needed > 0 {
498                for _ in 0..samples_needed {
499                    let random_idx = rng.sample(Uniform::new(0, indices.len()).unwrap());
500                    all_indices.push(indices[random_idx]);
501                }
502            }
503        }
504
505        // Shuffle the final indices to mix classes
506        all_indices.shuffle(&mut *rng);
507
508        // Create the oversampled dataset
509        let oversampled_data = dataset.data.select(Axis(0), &all_indices);
510        let oversampled_target = target.select(Axis(0), &all_indices);
511
512        Ok(Dataset {
513            data: oversampled_data,
514            target: Some(oversampled_target),
515            featurenames: dataset.featurenames.clone(),
516            targetnames: dataset.targetnames.clone(),
517            feature_descriptions: dataset.feature_descriptions.clone(),
518            description: Some(format!(
519                "Random oversampled dataset (original: {} samples, oversampled: {} samples)",
520                dataset.n_samples(),
521                all_indices.len()
522            )),
523            metadata: dataset.metadata.clone(),
524        })
525    }
526
527    fn fit_and_transform_scaling(
528        &mut self,
529        dataset: &Dataset,
530        method: ScalingMethod,
531    ) -> Result<Dataset> {
532        let mut scalers = HashMap::new();
533        let mut scaled_data = dataset.data.clone();
534
535        for (col_idx, mut column) in scaled_data.columns_mut().into_iter().enumerate() {
536            let featurename = dataset
537                .featurenames
538                .as_ref()
539                .and_then(|names| names.get(col_idx))
540                .cloned()
541                .unwrap_or_else(|| format!("feature_{col_idx}"));
542
543            let column_view = column.view();
544            let scaler_params = Self::fit_scaler(&column_view, method)?;
545            Self::apply_scaler_to_column(&mut column, &scaler_params)?;
546
547            scalers.insert(featurename, scaler_params);
548        }
549
550        self.fitted_scalers = Some(scalers);
551
552        Ok(Dataset {
553            data: scaled_data,
554            target: dataset.target.clone(),
555            featurenames: dataset.featurenames.clone(),
556            targetnames: dataset.targetnames.clone(),
557            feature_descriptions: dataset.feature_descriptions.clone(),
558            description: Some("Scaled dataset".to_string()),
559            metadata: dataset.metadata.clone(),
560        })
561    }
562
563    fn fit_scaler(
564        column: &ndarray::ArrayView1<f64>,
565        method: ScalingMethod,
566    ) -> Result<ScalerParams> {
567        let values: Vec<f64> = column.iter().copied().filter(|x| !x.is_nan()).collect();
568
569        if values.is_empty() {
570            return Ok(ScalerParams {
571                method,
572                mean: None,
573                std: None,
574                min: None,
575                max: None,
576                median: None,
577                mad: None,
578            });
579        }
580
581        match method {
582            ScalingMethod::StandardScaler => {
583                let mean = values.iter().sum::<f64>() / values.len() as f64;
584                let variance =
585                    values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / values.len() as f64;
586                let std = variance.sqrt();
587
588                Ok(ScalerParams {
589                    method,
590                    mean: Some(mean),
591                    std: Some(std),
592                    min: None,
593                    max: None,
594                    median: None,
595                    mad: None,
596                })
597            }
598            ScalingMethod::MinMaxScaler => {
599                let min = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
600                let max = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
601
602                Ok(ScalerParams {
603                    method,
604                    mean: None,
605                    std: None,
606                    min: Some(min),
607                    max: Some(max),
608                    median: None,
609                    mad: None,
610                })
611            }
612            ScalingMethod::RobustScaler => {
613                let mut sorted_values = values.clone();
614                sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
615
616                let median = Self::percentile(&sorted_values, 0.5).unwrap_or(0.0);
617                let mad = Self::compute_mad(&sorted_values, median);
618
619                Ok(ScalerParams {
620                    method,
621                    mean: None,
622                    std: None,
623                    min: None,
624                    max: None,
625                    median: Some(median),
626                    mad: Some(mad),
627                })
628            }
629            ScalingMethod::None => Ok(ScalerParams {
630                method,
631                mean: None,
632                std: None,
633                min: None,
634                max: None,
635                median: None,
636                mad: None,
637            }),
638        }
639    }
640
641    fn apply_scaler_to_column(
642        column: &mut ndarray::ArrayViewMut1<f64>,
643        params: &ScalerParams,
644    ) -> Result<()> {
645        match params.method {
646            ScalingMethod::StandardScaler => {
647                if let (Some(mean), Some(std)) = (params.mean, params.std) {
648                    if std > 1e-8 {
649                        // Avoid division by zero
650                        for value in column.iter_mut() {
651                            if !value.is_nan() {
652                                *value = (*value - mean) / std;
653                            }
654                        }
655                    }
656                }
657            }
658            ScalingMethod::MinMaxScaler => {
659                if let (Some(min), Some(max)) = (params.min, params.max) {
660                    let range = max - min;
661                    if range > 1e-8 {
662                        // Avoid division by zero
663                        for value in column.iter_mut() {
664                            if !value.is_nan() {
665                                *value = (*value - min) / range;
666                            }
667                        }
668                    }
669                }
670            }
671            ScalingMethod::RobustScaler => {
672                if let (Some(median), Some(mad)) = (params.median, params.mad) {
673                    if mad > 1e-8 {
674                        // Avoid division by zero
675                        for value in column.iter_mut() {
676                            if !value.is_nan() {
677                                *value = (*value - median) / mad;
678                            }
679                        }
680                    }
681                }
682            }
683            ScalingMethod::None => {
684                // No scaling applied
685            }
686        }
687
688        Ok(())
689    }
690
691    fn percentile(sorted_values: &[f64], p: f64) -> Option<f64> {
692        if sorted_values.is_empty() {
693            return None;
694        }
695
696        let index = p * (sorted_values.len() - 1) as f64;
697        let lower = index.floor() as usize;
698        let upper = index.ceil() as usize;
699
700        if lower == upper {
701            Some(sorted_values[lower])
702        } else {
703            let weight = index - lower as f64;
704            Some(sorted_values[lower] * (1.0 - weight) + sorted_values[upper] * weight)
705        }
706    }
707
708    fn compute_mad(sorted_values: &[f64], median: f64) -> f64 {
709        let deviations: Vec<f64> = sorted_values.iter().map(|&x| (x - median).abs()).collect();
710
711        let mut sorted_deviations = deviations;
712        sorted_deviations.sort_by(|a, b| a.partial_cmp(b).unwrap());
713
714        Self::percentile(&sorted_deviations, 0.5).unwrap_or(1.0)
715    }
716
717    fn generate_split_indices(
718        &self,
719        n_samples: usize,
720        target: Option<&Array1<f64>>,
721    ) -> Result<Vec<usize>> {
722        let mut indices: Vec<usize> = (0..n_samples).collect();
723
724        // Use proper random shuffling based on configuration
725        if self.config.stratify && target.is_some() {
726            // Implement stratified shuffling
727            self.stratified_shuffle(&mut indices, target.unwrap())?;
728        } else {
729            // Regular shuffling with optional random state
730            match self.config.random_state {
731                Some(seed) => {
732                    let mut rng = StdRng::seed_from_u64(seed);
733                    indices.shuffle(&mut rng);
734                }
735                None => {
736                    let mut rng = rng();
737                    indices.shuffle(&mut rng);
738                }
739            }
740        }
741
742        Ok(indices)
743    }
744
745    /// Perform stratified shuffling to maintain class proportions
746    fn stratified_shuffle(&self, indices: &mut Vec<usize>, target: &Array1<f64>) -> Result<()> {
747        // Group indices by class
748        let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
749
750        for &idx in indices.iter() {
751            let class = target[idx] as i32;
752            class_indices.entry(class).or_default().push(idx);
753        }
754
755        // Shuffle each class group separately
756        for class_group in class_indices.values_mut() {
757            match self.config.random_state {
758                Some(seed) => {
759                    let mut rng = StdRng::seed_from_u64(seed);
760                    class_group.shuffle(&mut rng);
761                }
762                None => {
763                    let mut rng = rng();
764                    class_group.shuffle(&mut rng);
765                }
766            }
767        }
768
769        // Recombine shuffled class groups while maintaining order
770        indices.clear();
771        let mut class_iterators: HashMap<i32, std::vec::IntoIter<usize>> = class_indices
772            .into_iter()
773            .map(|(class, group)| (class, group.into_iter()))
774            .collect();
775
776        // Interleave samples from different classes to maintain distribution
777        while !class_iterators.is_empty() {
778            let mut to_remove = Vec::new();
779            for (&class, iterator) in class_iterators.iter_mut() {
780                if let Some(idx) = iterator.next() {
781                    indices.push(idx);
782                } else {
783                    to_remove.push(class);
784                }
785            }
786            for class in to_remove {
787                class_iterators.remove(&class);
788            }
789        }
790
791        Ok(())
792    }
793
794    fn extract_dataset_info(&self, dataset: &Dataset) -> DatasetInfo {
795        let n_samples = dataset.n_samples();
796        let n_features = dataset.n_features();
797
798        let (n_classes, class_distribution) = if let Some(ref target) = dataset.target {
799            let mut class_counts: HashMap<String, usize> = HashMap::new();
800            for &value in target.iter() {
801                if !value.is_nan() {
802                    let classname = format!("{value:.0}");
803                    *class_counts.entry(classname).or_insert(0) += 1;
804                }
805            }
806
807            let n_classes = class_counts.len();
808            (Some(n_classes), Some(class_counts))
809        } else {
810            (None, None)
811        };
812
813        // Calculate missing data percentage
814        let total_values = n_samples * n_features;
815        let missing_values = dataset.data.iter().filter(|&&x| x.is_nan()).count();
816        let missing_data_percentage = missing_values as f64 / total_values as f64 * 100.0;
817
818        DatasetInfo {
819            n_samples,
820            n_features,
821            n_classes,
822            class_distribution,
823            missing_data_percentage,
824        }
825    }
826}
827
828/// Convenience functions for ML pipeline integration
829pub mod convenience {
830    use super::*;
831
832    /// Quick train/test split with default configuration
833    pub fn train_test_split(_dataset: &Dataset, testsize: Option<f64>) -> Result<DataSplit> {
834        let mut config = MLPipelineConfig::default();
835        if let Some(_size) = testsize {
836            config.test_size = _size;
837        }
838
839        let pipeline = MLPipeline::new(config);
840        pipeline.train_test_split(_dataset)
841    }
842
843    /// Prepare dataset for ML with standard preprocessing
844    pub fn prepare_for_ml(dataset: &Dataset, scale: bool, balance: bool) -> Result<Dataset> {
845        let mut config = MLPipelineConfig::default();
846
847        if !scale {
848            config.scaling_method = None;
849        }
850
851        if balance {
852            config.balancing_strategy = Some(BalancingStrategy::RandomUndersample);
853        }
854
855        let mut pipeline = MLPipeline::new(config);
856        pipeline.prepare_dataset(dataset)
857    }
858
859    /// Generate cross-validation folds
860    pub fn cv_split(
861        dataset: &Dataset,
862        n_folds: Option<usize>,
863        stratify: Option<bool>,
864    ) -> Result<CrossValidationFolds> {
865        let mut config = MLPipelineConfig::default();
866
867        if let Some(_folds) = n_folds {
868            config.cv_folds = _folds;
869        }
870
871        if let Some(strat) = stratify {
872            config.stratify = strat;
873        }
874
875        let pipeline = MLPipeline::new(config);
876        pipeline.cross_validation_split(dataset)
877    }
878
879    /// Create a simple ML experiment
880    pub fn create_experiment(name: &str, dataset: &Dataset) -> MLExperiment {
881        let pipeline = MLPipeline::default();
882        pipeline.create_experiment(name, dataset)
883    }
884}
885
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::make_classification;

    #[test]
    fn test_ml_pipeline_creation() {
        let pipeline = MLPipeline::default();
        assert_eq!(pipeline.config.test_size, 0.2);
        assert_eq!(pipeline.config.cv_folds, 5);
    }

    #[test]
    fn test_train_test_split() {
        let dataset = make_classification(100, 5, 2, 1, 1, Some(42)).unwrap();
        let split = convenience::train_test_split(&dataset, Some(0.3)).unwrap();

        assert_eq!(split.x_train.nrows() + split.x_test.nrows(), 100);
        assert_eq!(split.y_train.len() + split.y_test.len(), 100);
        assert_eq!(split.x_train.ncols(), 5);
        assert_eq!(split.x_test.ncols(), 5);
    }

    #[test]
    fn test_cross_validation_split() {
        let dataset = make_classification(100, 3, 2, 1, 1, Some(42)).unwrap();
        let folds = convenience::cv_split(&dataset, Some(5), Some(true)).unwrap();

        assert_eq!(folds.len(), 5);

        // Every fold partitions the whole dataset into train and validation indices.
        for (train, test) in folds.iter() {
            assert_eq!(train.len() + test.len(), 100);
        }

        // Each sample appears in exactly one validation set.
        let mut validation: Vec<usize> = folds
            .iter()
            .flat_map(|(_, test)| test.iter().copied())
            .collect();
        validation.sort_unstable();
        assert_eq!(validation, (0..100).collect::<Vec<usize>>());
    }

    #[test]
    fn test_dataset_preparation() {
        let dataset = make_classification(50, 4, 2, 1, 1, Some(42)).unwrap();
        let prepared = convenience::prepare_for_ml(&dataset, true, false).unwrap();

        assert_eq!(prepared.n_samples(), dataset.n_samples());
        assert_eq!(prepared.n_features(), dataset.n_features());
    }

    #[test]
    fn test_transform_requires_fit() {
        let dataset = make_classification(50, 4, 2, 1, 1, Some(42)).unwrap();
        assert!(MLPipeline::default().transform(&dataset).is_err());
    }

    #[test]
    fn test_transform_matches_fitted_scaling() {
        let dataset = make_classification(50, 4, 2, 1, 1, Some(42)).unwrap();
        let mut pipeline = MLPipeline::default();

        let prepared = pipeline.prepare_dataset(&dataset).unwrap();
        let transformed = pipeline.transform(&dataset).unwrap();

        // Re-applying the fitted scalers to the same data must reproduce the fit output.
        assert_eq!(prepared.data, transformed.data);
    }

    #[test]
    fn test_experiment_creation() {
        let dataset = make_classification(100, 5, 2, 1, 1, Some(42)).unwrap();
        let experiment = convenience::create_experiment("test_experiment", &dataset);

        assert_eq!(experiment.name, "test_experiment");
        assert_eq!(experiment.dataset_info.n_samples, 100);
        assert_eq!(experiment.dataset_info.n_features, 5);
        assert_eq!(experiment.dataset_info.n_classes, Some(2));
    }

    #[test]
    fn test_scaler_fitting() {
        let array = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let view = array.view();

        let scaler_params = MLPipeline::fit_scaler(&view, ScalingMethod::StandardScaler).unwrap();

        assert!(scaler_params.mean.is_some());
        assert!(scaler_params.std.is_some());
        assert_eq!(scaler_params.mean.unwrap(), 3.0);
        // Population variance of 1..=5 is 2.
        assert!((scaler_params.std.unwrap() - 2.0_f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn test_scaler_ignores_nan() {
        let array = Array1::from_vec(vec![1.0, f64::NAN, 3.0]);
        let view = array.view();

        let scaler_params = MLPipeline::fit_scaler(&view, ScalingMethod::StandardScaler).unwrap();

        assert_eq!(scaler_params.mean, Some(2.0));
    }

    #[test]
    fn test_min_max_scaler() {
        let array = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let view = array.view();

        let scaler_params = MLPipeline::fit_scaler(&view, ScalingMethod::MinMaxScaler).unwrap();

        assert!(scaler_params.min.is_some());
        assert!(scaler_params.max.is_some());
        assert_eq!(scaler_params.min.unwrap(), 1.0);
        assert_eq!(scaler_params.max.unwrap(), 5.0);
    }

    #[test]
    fn test_robust_scaler() {
        let array = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let view = array.view();

        let scaler_params = MLPipeline::fit_scaler(&view, ScalingMethod::RobustScaler).unwrap();

        // Deviations from the median 3 are [2, 1, 0, 1, 2], whose median is 1.
        assert_eq!(scaler_params.median, Some(3.0));
        assert_eq!(scaler_params.mad, Some(1.0));
    }
}