//! Machine-learning pipeline utilities: train/test splitting, cross-validation,
//! class balancing, and feature scaling for [`Dataset`]s.

use std::collections::HashMap;

use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::prelude::*;
use scirs2_core::random::SliceRandomExt;
use scirs2_core::random::Uniform;
use serde::{Deserialize, Serialize};

use crate::error::{DatasetsError, Result};
use crate::utils::{BalancingStrategy, CrossValidationFolds, Dataset};

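/// Configuration for the ML pipeline.
///
/// A minimal construction sketch (the field values here are illustrative,
/// not recommendations):
///
/// ```ignore
/// let config = MLPipelineConfig {
///     test_size: 0.25,
///     stratify: false,
///     ..Default::default()
/// };
/// ```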
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLPipelineConfig {
    /// Seed for reproducible shuffling and sampling.
    pub random_state: Option<u64>,
    /// Fraction of samples held out for testing (between 0.0 and 1.0).
    pub test_size: f64,
    /// Number of cross-validation folds.
    pub cv_folds: usize,
    /// Whether to stratify splits by the target variable.
    pub stratify: bool,
    /// Optional class-balancing strategy applied during preparation.
    pub balancing_strategy: Option<BalancingStrategy>,
    /// Optional feature-scaling method applied during preparation.
    pub scaling_method: Option<ScalingMethod>,
}

impl Default for MLPipelineConfig {
    fn default() -> Self {
        Self {
            random_state: Some(42),
            test_size: 0.2,
            cv_folds: 5,
            stratify: true,
            balancing_strategy: None,
            scaling_method: Some(ScalingMethod::StandardScaler),
        }
    }
}

/// Feature scaling methods supported by the pipeline.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum ScalingMethod {
    /// Center to zero mean and scale to unit variance.
    StandardScaler,
    /// Rescale each feature to the [0, 1] range.
    MinMaxScaler,
    /// Center on the median and scale by the median absolute deviation.
    RobustScaler,
    /// Leave features unchanged.
    None,
}

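/// Stateful pipeline that holds the configuration plus any scalers fitted by
/// [`prepare_dataset`](Self::prepare_dataset).
///
/// A usage sketch (assumes a `Dataset` obtained elsewhere, e.g. from this
/// crate's generators):
///
/// ```ignore
/// let mut pipeline = MLPipeline::new(MLPipelineConfig::default());
/// let prepared = pipeline.prepare_dataset(&dataset)?;
/// let split = pipeline.train_test_split(&prepared)?;
/// ```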
pub struct MLPipeline {
    /// Active configuration.
    config: MLPipelineConfig,
    /// Per-feature scaler parameters, populated by `prepare_dataset`.
    fitted_scalers: Option<HashMap<String, ScalerParams>>,
}

/// Fitted parameters for a single feature's scaler; only the fields relevant
/// to `method` are populated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalerParams {
    /// Scaling method these parameters belong to.
    pub method: ScalingMethod,
    /// Column mean (standard scaling).
    pub mean: Option<f64>,
    /// Column standard deviation (standard scaling).
    pub std: Option<f64>,
    /// Column minimum (min-max scaling).
    pub min: Option<f64>,
    /// Column maximum (min-max scaling).
    pub max: Option<f64>,
    /// Column median (robust scaling).
    pub median: Option<f64>,
    /// Median absolute deviation (robust scaling).
    pub mad: Option<f64>,
}

/// Record of a single ML experiment: the dataset, model setup, and results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLExperiment {
    /// Experiment name.
    pub name: String,
    /// Summary statistics of the dataset used.
    pub dataset_info: DatasetInfo,
    /// Model type, hyperparameters, and preprocessing steps.
    pub model_config: ModelConfig,
    /// Scores and timings from the final run.
    pub results: ExperimentResults,
    /// Per-fold cross-validation results, if CV was run.
    pub cv_scores: Option<CrossValidationResults>,
}

/// Summary statistics describing a dataset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetInfo {
    /// Number of samples (rows).
    pub n_samples: usize,
    /// Number of features (columns).
    pub n_features: usize,
    /// Number of distinct classes, if a target is present.
    pub n_classes: Option<usize>,
    /// Sample count per class label, if a target is present.
    pub class_distribution: Option<HashMap<String, usize>>,
    /// Percentage of NaN entries in the feature matrix.
    pub missing_data_percentage: f64,
}

/// Model configuration attached to an experiment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Identifier of the model type.
    pub model_type: String,
    /// Hyperparameters as arbitrary JSON values.
    pub hyperparameters: HashMap<String, serde_json::Value>,
    /// Names of the preprocessing steps applied before training.
    pub preprocessing_steps: Vec<String>,
}

/// Scores and timings recorded for an experiment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExperimentResults {
    /// Score on the training set.
    pub training_score: f64,
    /// Score on the validation set.
    pub validation_score: f64,
    /// Score on the held-out test set, if evaluated.
    pub test_score: Option<f64>,
    /// Time spent training.
    pub training_time: f64,
    /// Time spent on inference, if measured.
    pub inference_time: Option<f64>,
    /// Per-feature importance scores, if the model provides them.
    pub feature_importance: Option<Vec<(String, f64)>>,
}

/// Aggregated results from k-fold cross-validation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossValidationResults {
    /// Validation score of each fold.
    pub scores: Vec<f64>,
    /// Mean of the fold scores.
    pub mean_score: f64,
    /// Standard deviation of the fold scores.
    pub std_score: f64,
    /// Detailed per-fold records.
    pub fold_details: Vec<FoldResult>,
}

/// Scores and timing for a single cross-validation fold.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FoldResult {
    /// Zero-based index of the fold.
    pub fold_index: usize,
    /// Score on the fold's training portion.
    pub train_score: f64,
    /// Score on the fold's validation portion.
    pub validation_score: f64,
    /// Time spent training on this fold.
    pub training_time: f64,
}

/// Train/test partitions produced by [`MLPipeline::train_test_split`].
#[derive(Debug, Clone)]
pub struct DataSplit {
    /// Training features.
    pub x_train: Array2<f64>,
    /// Test features.
    pub x_test: Array2<f64>,
    /// Training targets.
    pub y_train: Array1<f64>,
    /// Test targets.
    pub y_test: Array1<f64>,
}

impl Default for MLPipeline {
    fn default() -> Self {
        Self::new(MLPipelineConfig::default())
    }
}

impl MLPipeline {
    /// Creates a pipeline from an explicit configuration.
    pub fn new(config: MLPipelineConfig) -> Self {
        Self {
            config,
            fitted_scalers: None,
        }
    }

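    /// Applies the configured balancing and scaling steps, storing the fitted
    /// scaler parameters so that [`transform`](Self::transform) can reuse them.
    ///
    /// A fit-then-transform sketch (`train` and `test` are assumed to be
    /// `Dataset`s from elsewhere):
    ///
    /// ```ignore
    /// let train_prepared = pipeline.prepare_dataset(&train)?; // fits scalers
    /// let test_prepared = pipeline.transform(&test)?;         // reuses them
    /// ```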
    pub fn prepare_dataset(&mut self, dataset: &Dataset) -> Result<Dataset> {
        let mut prepared = dataset.clone();

        // Balance classes first so scaling statistics are computed on the
        // balanced data.
        if let Some(ref strategy) = self.config.balancing_strategy {
            prepared = self.apply_balancing(&prepared, strategy)?;
        }

        if let Some(method) = self.config.scaling_method {
            prepared = self.fit_and_transform_scaling(&prepared, method)?;
        }

        Ok(prepared)
    }

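    /// Shuffles (optionally stratified) and splits the dataset according to
    /// `config.test_size`.
    ///
    /// A sketch: with `test_size = 0.2`, roughly 80% of rows land in training.
    ///
    /// ```ignore
    /// let split = pipeline.train_test_split(&dataset)?;
    /// assert_eq!(
    ///     split.x_train.nrows() + split.x_test.nrows(),
    ///     dataset.n_samples()
    /// );
    /// ```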
    pub fn train_test_split(&self, dataset: &Dataset) -> Result<DataSplit> {
        let n_samples = dataset.n_samples();
        let test_samples = (n_samples as f64 * self.config.test_size) as usize;
        let train_samples = n_samples - test_samples;

        let indices = self.generate_split_indices(n_samples, dataset.target.as_ref())?;

        let train_indices = &indices[..train_samples];
        let test_indices = &indices[train_samples..];

        let x_train = dataset.data.select(Axis(0), train_indices);
        let x_test = dataset.data.select(Axis(0), test_indices);

        let (y_train, y_test) = if let Some(ref target) = dataset.target {
            let y_train = target.select(Axis(0), train_indices);
            let y_test = target.select(Axis(0), test_indices);
            (y_train, y_test)
        } else {
            return Err(DatasetsError::InvalidFormat(
                "Target variable required for train/test split".to_string(),
            ));
        };

        Ok(DataSplit {
            x_train,
            x_test,
            y_train,
            y_test,
        })
    }

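    /// Builds k-fold (optionally stratified) cross-validation folds.
    ///
    /// A sketch iterating the folds as `(train_indices, validation_indices)`
    /// pairs:
    ///
    /// ```ignore
    /// for (train_idx, val_idx) in pipeline.cross_validation_split(&dataset)? {
    ///     let x_train = dataset.data.select(Axis(0), &train_idx);
    ///     // ...train and score a model here...
    /// }
    /// ```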
    pub fn cross_validation_split(&self, dataset: &Dataset) -> Result<CrossValidationFolds> {
        let target = dataset.target.as_ref().ok_or_else(|| {
            DatasetsError::InvalidFormat(
                "Target variable required for cross-validation".to_string(),
            )
        })?;

        if self.config.stratify {
            crate::utils::stratified_k_fold_split(
                target,
                self.config.cv_folds,
                true,
                self.config.random_state,
            )
        } else {
            crate::utils::k_fold_split(
                dataset.n_samples(),
                self.config.cv_folds,
                true,
                self.config.random_state,
            )
        }
    }

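    /// Applies previously fitted scalers to new data, matching columns by
    /// feature name and falling back to `feature_{index}`.
    ///
    /// ```ignore
    /// let test_scaled = pipeline.transform(&test)?; // errors if never fitted
    /// ```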
    pub fn transform(&self, dataset: &Dataset) -> Result<Dataset> {
        let scalers = self.fitted_scalers.as_ref().ok_or_else(|| {
            DatasetsError::InvalidFormat(
                "Pipeline not fitted. Call prepare_dataset first.".to_string(),
            )
        })?;

        let mut transformed_data = dataset.data.clone();

        for (col_idx, mut column) in transformed_data.columns_mut().into_iter().enumerate() {
            let default_name = format!("feature_{col_idx}");
            let feature_name = dataset
                .featurenames
                .as_ref()
                .and_then(|names| names.get(col_idx))
                .map(|s| s.as_str())
                .unwrap_or(&default_name);

            if let Some(scaler) = scalers.get(feature_name) {
                Self::apply_scaler_to_column(&mut column, scaler)?;
            }
        }

        Ok(Dataset {
            data: transformed_data,
            target: dataset.target.clone(),
            featurenames: dataset.featurenames.clone(),
            targetnames: dataset.targetnames.clone(),
            feature_descriptions: dataset.feature_descriptions.clone(),
            description: Some("Transformed dataset".to_string()),
            metadata: dataset.metadata.clone(),
        })
    }

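    /// Creates an experiment record pre-populated with dataset statistics;
    /// the model configuration and results start as placeholders.
    ///
    /// A sketch (the model name is illustrative):
    ///
    /// ```ignore
    /// let mut exp = pipeline.create_experiment("baseline", &dataset);
    /// exp.model_config.model_type = "logistic_regression".to_string();
    /// ```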
    pub fn create_experiment(&self, name: &str, dataset: &Dataset) -> MLExperiment {
        let dataset_info = self.extract_dataset_info(dataset);

        MLExperiment {
            name: name.to_string(),
            dataset_info,
            model_config: ModelConfig {
                model_type: "undefined".to_string(),
                hyperparameters: HashMap::new(),
                preprocessing_steps: Vec::new(),
            },
            results: ExperimentResults {
                training_score: 0.0,
                validation_score: 0.0,
                test_score: None,
                training_time: 0.0,
                inference_time: None,
                feature_importance: None,
            },
            cv_scores: None,
        }
    }

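    /// Runs `train_fn` on every fold and aggregates the validation scores.
    ///
    /// The closure receives `(x_train, y_train, x_val, y_val)` and returns
    /// `(train_score, validation_score, training_time)`. A sketch with a
    /// stub closure:
    ///
    /// ```ignore
    /// let cv = pipeline.evaluate_with_cv(&dataset, |_xt, _yt, _xv, _yv| {
    ///     Ok((1.0, 0.9, 0.01)) // plug in real training and scoring here
    /// })?;
    /// println!("{:.3} ± {:.3}", cv.mean_score, cv.std_score);
    /// ```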
    pub fn evaluate_with_cv<F>(
        &self,
        dataset: &Dataset,
        train_fn: F,
    ) -> Result<CrossValidationResults>
    where
        F: Fn(&Array2<f64>, &Array1<f64>, &Array2<f64>, &Array1<f64>) -> Result<(f64, f64, f64)>,
    {
        let folds = self.cross_validation_split(dataset)?;
        let mut scores = Vec::new();
        let mut fold_details = Vec::new();

        for (fold_idx, (train_indices, val_indices)) in folds.into_iter().enumerate() {
            let x_train = dataset.data.select(Axis(0), &train_indices);
            let x_val = dataset.data.select(Axis(0), &val_indices);

            // Safe: cross_validation_split already errored if the target is missing.
            let target = dataset.target.as_ref().unwrap();
            let y_train = target.select(Axis(0), &train_indices);
            let y_val = target.select(Axis(0), &val_indices);

            let (train_score, val_score, training_time) =
                train_fn(&x_train, &y_train, &x_val, &y_val)?;

            scores.push(val_score);
            fold_details.push(FoldResult {
                fold_index: fold_idx,
                train_score,
                validation_score: val_score,
                training_time,
            });
        }

        // Population mean and standard deviation of the fold scores.
        let mean_score = scores.iter().sum::<f64>() / scores.len() as f64;
        let variance = scores
            .iter()
            .map(|score| (score - mean_score).powi(2))
            .sum::<f64>()
            / scores.len() as f64;
        let std_score = variance.sqrt();

        Ok(CrossValidationResults {
            scores,
            mean_score,
            std_score,
            fold_details,
        })
    }

    fn apply_balancing(&self, dataset: &Dataset, strategy: &BalancingStrategy) -> Result<Dataset> {
        match strategy {
            BalancingStrategy::RandomUndersample => {
                self.random_undersample(dataset, self.config.random_state)
            }
            BalancingStrategy::RandomOversample => {
                self.random_oversample(dataset, self.config.random_state)
            }
            // Remaining strategies are not implemented yet; pass data through.
            _ => Ok(dataset.clone()),
        }
    }

    fn random_undersample(&self, dataset: &Dataset, random_state: Option<u64>) -> Result<Dataset> {
        use scirs2_core::random::{rngs::StdRng, RngCore, SeedableRng};

        let target = dataset.target.as_ref().ok_or_else(|| {
            DatasetsError::InvalidFormat("Target required for balancing".to_string())
        })?;

        // Count samples per class, ignoring NaN targets.
        let mut class_counts: HashMap<i64, usize> = HashMap::new();
        for &value in target.iter() {
            if !value.is_nan() {
                *class_counts.entry(value as i64).or_insert(0) += 1;
            }
        }

        let min_count = class_counts.values().min().copied().unwrap_or(0);

        let mut rng: Box<dyn RngCore> = match random_state {
            Some(seed) => Box::new(StdRng::seed_from_u64(seed)),
            None => Box::new(thread_rng()),
        };

        let mut selected_indices = Vec::new();

        for (class_label, _count) in class_counts {
            let mut class_indices: Vec<usize> = target
                .iter()
                .enumerate()
                .filter(|(_, &val)| !val.is_nan() && val as i64 == class_label)
                .map(|(idx, _)| idx)
                .collect();

            // Shuffle before truncating so each class keeps a random subset,
            // not merely its first `min_count` occurrences.
            class_indices.shuffle(&mut *rng);
            class_indices.truncate(min_count);
            selected_indices.extend(class_indices);
        }

        let balanced_data = dataset.data.select(Axis(0), &selected_indices);
        let balanced_target = target.select(Axis(0), &selected_indices);

        Ok(Dataset {
            data: balanced_data,
            target: Some(balanced_target),
            featurenames: dataset.featurenames.clone(),
            targetnames: dataset.targetnames.clone(),
            feature_descriptions: dataset.feature_descriptions.clone(),
            description: Some("Undersampled dataset".to_string()),
            metadata: dataset.metadata.clone(),
        })
    }

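    /// Randomly duplicates samples of minority classes until every class
    /// reaches the majority count, then shuffles the combined index list.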
    fn random_oversample(&self, dataset: &Dataset, random_state: Option<u64>) -> Result<Dataset> {
        use scirs2_core::random::{rngs::StdRng, RngCore, SeedableRng};

        let target = dataset.target.as_ref().ok_or_else(|| {
            DatasetsError::InvalidFormat("Random oversampling requires target labels".to_string())
        })?;

        if target.len() != dataset.data.nrows() {
            return Err(DatasetsError::InvalidFormat(
                "Target length must match number of samples".to_string(),
            ));
        }

        // Count samples per class and remember each class's row indices.
        let mut class_counts: HashMap<i32, usize> = HashMap::new();
        let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();

        for (idx, &label) in target.iter().enumerate() {
            let class = label as i32;
            *class_counts.entry(class).or_insert(0) += 1;
            class_indices.entry(class).or_default().push(idx);
        }

        let max_count = class_counts.values().max().copied().unwrap_or(0);

        if max_count == 0 {
            return Err(DatasetsError::InvalidFormat(
                "No samples found in dataset".to_string(),
            ));
        }

        let mut rng: Box<dyn RngCore> = match random_state {
            Some(seed) => Box::new(StdRng::seed_from_u64(seed)),
            None => Box::new(thread_rng()),
        };

        let mut all_indices = Vec::new();

        for (_class, indices) in class_indices.iter() {
            let current_count = indices.len();

            // Keep every original sample of this class.
            all_indices.extend(indices.iter().copied());

            // Duplicate randomly drawn samples until the class matches the
            // majority count.
            let samples_needed = max_count - current_count;

            if samples_needed > 0 {
                for _ in 0..samples_needed {
                    let random_idx = rng.sample(Uniform::new(0, indices.len()).unwrap());
                    all_indices.push(indices[random_idx]);
                }
            }
        }

        all_indices.shuffle(&mut *rng);

        let oversampled_data = dataset.data.select(Axis(0), &all_indices);
        let oversampled_target = target.select(Axis(0), &all_indices);

        Ok(Dataset {
            data: oversampled_data,
            target: Some(oversampled_target),
            featurenames: dataset.featurenames.clone(),
            targetnames: dataset.targetnames.clone(),
            feature_descriptions: dataset.feature_descriptions.clone(),
            description: Some(format!(
                "Random oversampled dataset (original: {} samples, oversampled: {} samples)",
                dataset.n_samples(),
                all_indices.len()
            )),
            metadata: dataset.metadata.clone(),
        })
    }

    fn fit_and_transform_scaling(
        &mut self,
        dataset: &Dataset,
        method: ScalingMethod,
    ) -> Result<Dataset> {
        let mut scalers = HashMap::new();
        let mut scaled_data = dataset.data.clone();

        for (col_idx, mut column) in scaled_data.columns_mut().into_iter().enumerate() {
            let feature_name = dataset
                .featurenames
                .as_ref()
                .and_then(|names| names.get(col_idx))
                .cloned()
                .unwrap_or_else(|| format!("feature_{col_idx}"));

            let column_view = column.view();
            let scaler_params = Self::fit_scaler(&column_view, method)?;
            Self::apply_scaler_to_column(&mut column, &scaler_params)?;

            scalers.insert(feature_name, scaler_params);
        }

        self.fitted_scalers = Some(scalers);

        Ok(Dataset {
            data: scaled_data,
            target: dataset.target.clone(),
            featurenames: dataset.featurenames.clone(),
            targetnames: dataset.targetnames.clone(),
            feature_descriptions: dataset.feature_descriptions.clone(),
            description: Some("Scaled dataset".to_string()),
            metadata: dataset.metadata.clone(),
        })
    }

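    /// Computes the scaler statistics for one column, ignoring NaN entries.
    ///
    /// Standard: `z = (x - mean) / std`; min-max: `(x - min) / (max - min)`;
    /// robust: `(x - median) / MAD`.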
    fn fit_scaler(
        column: &scirs2_core::ndarray::ArrayView1<f64>,
        method: ScalingMethod,
    ) -> Result<ScalerParams> {
        let values: Vec<f64> = column.iter().copied().filter(|x| !x.is_nan()).collect();

        if values.is_empty() {
            return Ok(ScalerParams {
                method,
                mean: None,
                std: None,
                min: None,
                max: None,
                median: None,
                mad: None,
            });
        }

        match method {
            ScalingMethod::StandardScaler => {
                let mean = values.iter().sum::<f64>() / values.len() as f64;
                let variance =
                    values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / values.len() as f64;
                let std = variance.sqrt();

                Ok(ScalerParams {
                    method,
                    mean: Some(mean),
                    std: Some(std),
                    min: None,
                    max: None,
                    median: None,
                    mad: None,
                })
            }
            ScalingMethod::MinMaxScaler => {
                let min = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
                let max = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));

                Ok(ScalerParams {
                    method,
                    mean: None,
                    std: None,
                    min: Some(min),
                    max: Some(max),
                    median: None,
                    mad: None,
                })
            }
            ScalingMethod::RobustScaler => {
                let mut sorted_values = values.clone();
                sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap());

                let median = Self::percentile(&sorted_values, 0.5).unwrap_or(0.0);
                let mad = Self::compute_mad(&sorted_values, median);

                Ok(ScalerParams {
                    method,
                    mean: None,
                    std: None,
                    min: None,
                    max: None,
                    median: Some(median),
                    mad: Some(mad),
                })
            }
            ScalingMethod::None => Ok(ScalerParams {
                method,
                mean: None,
                std: None,
                min: None,
                max: None,
                median: None,
                mad: None,
            }),
        }
    }

    fn apply_scaler_to_column(
        column: &mut scirs2_core::ndarray::ArrayViewMut1<f64>,
        params: &ScalerParams,
    ) -> Result<()> {
        match params.method {
            ScalingMethod::StandardScaler => {
                if let (Some(mean), Some(std)) = (params.mean, params.std) {
                    // Skip near-constant columns to avoid dividing by ~0.
                    if std > 1e-8 {
                        for value in column.iter_mut() {
                            if !value.is_nan() {
                                *value = (*value - mean) / std;
                            }
                        }
                    }
                }
            }
            ScalingMethod::MinMaxScaler => {
                if let (Some(min), Some(max)) = (params.min, params.max) {
                    let range = max - min;
                    if range > 1e-8 {
                        for value in column.iter_mut() {
                            if !value.is_nan() {
                                *value = (*value - min) / range;
                            }
                        }
                    }
                }
            }
            ScalingMethod::RobustScaler => {
                if let (Some(median), Some(mad)) = (params.median, params.mad) {
                    if mad > 1e-8 {
                        for value in column.iter_mut() {
                            if !value.is_nan() {
                                *value = (*value - median) / mad;
                            }
                        }
                    }
                }
            }
            ScalingMethod::None => {
                // No scaling requested; leave the column unchanged.
            }
        }

        Ok(())
    }

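    /// Linear-interpolation percentile over pre-sorted data.
    ///
    /// Worked example: for `[1.0, 2.0, 3.0, 4.0]` and `p = 0.5`, the index is
    /// `0.5 * 3 = 1.5`, so the result interpolates halfway between 2.0 and
    /// 3.0, giving 2.5.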
    fn percentile(sorted_values: &[f64], p: f64) -> Option<f64> {
        if sorted_values.is_empty() {
            return None;
        }

        let index = p * (sorted_values.len() - 1) as f64;
        let lower = index.floor() as usize;
        let upper = index.ceil() as usize;

        if lower == upper {
            Some(sorted_values[lower])
        } else {
            let weight = index - lower as f64;
            Some(sorted_values[lower] * (1.0 - weight) + sorted_values[upper] * weight)
        }
    }

    fn compute_mad(sorted_values: &[f64], median: f64) -> f64 {
        let deviations: Vec<f64> = sorted_values.iter().map(|&x| (x - median).abs()).collect();

        let mut sorted_deviations = deviations;
        sorted_deviations.sort_by(|a, b| a.partial_cmp(b).unwrap());

        Self::percentile(&sorted_deviations, 0.5).unwrap_or(1.0)
    }

    fn generate_split_indices(
        &self,
        n_samples: usize,
        target: Option<&Array1<f64>>,
    ) -> Result<Vec<usize>> {
        let mut indices: Vec<usize> = (0..n_samples).collect();

        if self.config.stratify && target.is_some() {
            // Interleave classes so any prefix of the list stays stratified.
            self.stratified_shuffle(&mut indices, target.unwrap())?;
        } else {
            // Plain shuffle, seeded if a random_state was configured.
            match self.config.random_state {
                Some(seed) => {
                    let mut rng = StdRng::seed_from_u64(seed);
                    indices.shuffle(&mut rng);
                }
                None => {
                    let mut rng = thread_rng();
                    indices.shuffle(&mut rng);
                }
            }
        }

        Ok(indices)
    }

    fn stratified_shuffle(&self, indices: &mut Vec<usize>, target: &Array1<f64>) -> Result<()> {
        use scirs2_core::random::{rngs::StdRng, RngCore, SeedableRng};

        // Group indices by class label.
        let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();

        for &idx in indices.iter() {
            let class = target[idx] as i32;
            class_indices.entry(class).or_default().push(idx);
        }

        // Create the RNG once: reseeding inside the loop would shuffle every
        // class group with an identical sequence.
        let mut rng: Box<dyn RngCore> = match self.config.random_state {
            Some(seed) => Box::new(StdRng::seed_from_u64(seed)),
            None => Box::new(thread_rng()),
        };

        for class_group in class_indices.values_mut() {
            class_group.shuffle(&mut *rng);
        }

        // Rebuild the index list by drawing one index per class in round-robin
        // order, so class proportions stay roughly even along the list.
        indices.clear();
        let mut class_iterators: HashMap<i32, std::vec::IntoIter<usize>> = class_indices
            .into_iter()
            .map(|(class, group)| (class, group.into_iter()))
            .collect();

        while !class_iterators.is_empty() {
            let mut to_remove = Vec::new();
            for (&class, iterator) in class_iterators.iter_mut() {
                if let Some(idx) = iterator.next() {
                    indices.push(idx);
                } else {
                    to_remove.push(class);
                }
            }
            for class in to_remove {
                class_iterators.remove(&class);
            }
        }

        Ok(())
    }

    fn extract_dataset_info(&self, dataset: &Dataset) -> DatasetInfo {
        let n_samples = dataset.n_samples();
        let n_features = dataset.n_features();

        let (n_classes, class_distribution) = if let Some(ref target) = dataset.target {
            let mut class_counts: HashMap<String, usize> = HashMap::new();
            for &value in target.iter() {
                if !value.is_nan() {
                    let class_name = format!("{value:.0}");
                    *class_counts.entry(class_name).or_insert(0) += 1;
                }
            }

            let n_classes = class_counts.len();
            (Some(n_classes), Some(class_counts))
        } else {
            (None, None)
        };

        let total_values = n_samples * n_features;
        let missing_values = dataset.data.iter().filter(|&&x| x.is_nan()).count();
        // Guard against an empty dataset producing NaN.
        let missing_data_percentage = if total_values > 0 {
            missing_values as f64 / total_values as f64 * 100.0
        } else {
            0.0
        };

        DatasetInfo {
            n_samples,
            n_features,
            n_classes,
            class_distribution,
            missing_data_percentage,
        }
    }
}

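/// One-call convenience wrappers around [`MLPipeline`].
///
/// A sketch of typical use (assumes a `Dataset` from this crate's
/// generators):
///
/// ```ignore
/// let split = convenience::train_test_split(&dataset, Some(0.3))?;
/// let folds = convenience::cv_split(&dataset, Some(5), Some(true))?;
/// ```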
pub mod convenience {
    use super::*;

    /// Splits a dataset with the default pipeline, overriding `test_size` if given.
    pub fn train_test_split(dataset: &Dataset, test_size: Option<f64>) -> Result<DataSplit> {
        let mut config = MLPipelineConfig::default();
        if let Some(size) = test_size {
            config.test_size = size;
        }

        let pipeline = MLPipeline::new(config);
        pipeline.train_test_split(dataset)
    }

    /// Prepares a dataset for ML with optional scaling and class balancing.
    pub fn prepare_for_ml(dataset: &Dataset, scale: bool, balance: bool) -> Result<Dataset> {
        let mut config = MLPipelineConfig::default();

        if !scale {
            config.scaling_method = None;
        }

        if balance {
            config.balancing_strategy = Some(BalancingStrategy::RandomUndersample);
        }

        let mut pipeline = MLPipeline::new(config);
        pipeline.prepare_dataset(dataset)
    }

    /// Creates cross-validation folds, overriding fold count and stratification if given.
    pub fn cv_split(
        dataset: &Dataset,
        n_folds: Option<usize>,
        stratify: Option<bool>,
    ) -> Result<CrossValidationFolds> {
        let mut config = MLPipelineConfig::default();

        if let Some(folds) = n_folds {
            config.cv_folds = folds;
        }

        if let Some(strat) = stratify {
            config.stratify = strat;
        }

        let pipeline = MLPipeline::new(config);
        pipeline.cross_validation_split(dataset)
    }

    /// Creates an experiment record using the default pipeline.
    pub fn create_experiment(name: &str, dataset: &Dataset) -> MLExperiment {
        let pipeline = MLPipeline::default();
        pipeline.create_experiment(name, dataset)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::make_classification;

    #[test]
    fn test_ml_pipeline_creation() {
        let pipeline = MLPipeline::default();
        assert_eq!(pipeline.config.test_size, 0.2);
        assert_eq!(pipeline.config.cv_folds, 5);
    }

    #[test]
    fn test_train_test_split() {
        let dataset = make_classification(100, 5, 2, 1, 1, Some(42)).unwrap();
        let split = convenience::train_test_split(&dataset, Some(0.3)).unwrap();

        assert_eq!(split.x_train.nrows() + split.x_test.nrows(), 100);
        assert_eq!(split.y_train.len() + split.y_test.len(), 100);
        assert_eq!(split.x_train.ncols(), 5);
        assert_eq!(split.x_test.ncols(), 5);
    }

    #[test]
    fn test_cross_validation_split() {
        let dataset = make_classification(100, 3, 2, 1, 1, Some(42)).unwrap();
        let folds = convenience::cv_split(&dataset, Some(5), Some(true)).unwrap();

        assert_eq!(folds.len(), 5);

        // Each fold's train + validation partitions together cover all samples.
        let total_samples: usize = folds
            .iter()
            .map(|(train, test)| train.len() + test.len())
            .sum::<usize>()
            / 5;
        assert_eq!(total_samples, 100);
    }

    #[test]
    fn test_dataset_preparation() {
        let dataset = make_classification(50, 4, 2, 1, 1, Some(42)).unwrap();
        let prepared = convenience::prepare_for_ml(&dataset, true, false).unwrap();

        assert_eq!(prepared.n_samples(), dataset.n_samples());
        assert_eq!(prepared.n_features(), dataset.n_features());
    }

    #[test]
    fn test_experiment_creation() {
        let dataset = make_classification(100, 5, 2, 1, 1, Some(42)).unwrap();
        let experiment = convenience::create_experiment("test_experiment", &dataset);

        assert_eq!(experiment.name, "test_experiment");
        assert_eq!(experiment.dataset_info.n_samples, 100);
        assert_eq!(experiment.dataset_info.n_features, 5);
        assert_eq!(experiment.dataset_info.n_classes, Some(2));
    }

    #[test]
    fn test_scaler_fitting() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let array = Array1::from_vec(data);
        let view = array.view();

        let scaler_params = MLPipeline::fit_scaler(&view, ScalingMethod::StandardScaler).unwrap();

        assert!(scaler_params.mean.is_some());
        assert!(scaler_params.std.is_some());
        assert_eq!(scaler_params.mean.unwrap(), 3.0);
    }

    #[test]
    fn test_min_max_scaler() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let array = Array1::from_vec(data);
        let view = array.view();

        let scaler_params = MLPipeline::fit_scaler(&view, ScalingMethod::MinMaxScaler).unwrap();

        assert!(scaler_params.min.is_some());
        assert!(scaler_params.max.is_some());
        assert_eq!(scaler_params.min.unwrap(), 1.0);
        assert_eq!(scaler_params.max.unwrap(), 5.0);
    }
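
    #[test]
    fn test_transform_requires_fit() {
        // Added guard-rail check (a sketch, not part of the original suite):
        // transform before prepare_dataset must fail because no scalers
        // have been fitted yet.
        let dataset = make_classification(20, 3, 2, 1, 1, Some(42)).unwrap();
        let pipeline = MLPipeline::default();
        assert!(pipeline.transform(&dataset).is_err());
    }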
971 }
972}