use std::collections::HashMap;

use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::prelude::*;
use scirs2_core::random::SliceRandomExt;
use scirs2_core::random::Uniform;
use serde::{Deserialize, Serialize};

use crate::error::{DatasetsError, Result};
use crate::utils::{BalancingStrategy, CrossValidationFolds, Dataset};

/// Configuration for an [`MLPipeline`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLPipelineConfig {
    /// Seed for reproducible shuffling and sampling.
    pub random_state: Option<u64>,
    /// Fraction of samples held out for testing (e.g. `0.2` for an 80/20 split).
    pub test_size: f64,
    /// Number of cross-validation folds.
    pub cv_folds: usize,
    /// Whether splits should preserve class proportions.
    pub stratify: bool,
    /// Optional class-balancing strategy applied before scaling.
    pub balancing_strategy: Option<BalancingStrategy>,
    /// Optional feature-scaling method fitted during preparation.
    pub scaling_method: Option<ScalingMethod>,
}

impl Default for MLPipelineConfig {
    fn default() -> Self {
        Self {
            random_state: Some(42),
            test_size: 0.2,
            cv_folds: 5,
            stratify: true,
            balancing_strategy: None,
            scaling_method: Some(ScalingMethod::StandardScaler),
        }
    }
}

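/// Feature-scaling methods supported by the pipeline.
///
/// Per feature, `StandardScaler` computes `(x - mean) / std`,
/// `MinMaxScaler` computes `(x - min) / (max - min)`, and
/// `RobustScaler` computes `(x - median) / MAD`, where MAD is the
/// median absolute deviation. NaN values are ignored while fitting and
/// left untouched while transforming.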
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum ScalingMethod {
    /// Zero mean, unit variance.
    StandardScaler,
    /// Rescale to the `[0, 1]` range.
    MinMaxScaler,
    /// Center on the median, scale by the MAD.
    RobustScaler,
    /// Leave features unchanged.
    None,
}

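/// End-to-end helper that wires together train/test splitting,
/// cross-validation, class balancing, and feature scaling.
///
/// A minimal usage sketch (the `dataset` variable is hypothetical; any
/// labeled [`Dataset`] works):
///
/// ```ignore
/// let mut pipeline = MLPipeline::new(MLPipelineConfig::default());
/// let prepared = pipeline.prepare_dataset(&dataset)?; // balance + fit scalers
/// let split = pipeline.train_test_split(&prepared)?;  // 80/20 by default
/// ```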
pub struct MLPipeline {
    config: MLPipelineConfig,
    fitted_scalers: Option<HashMap<String, ScalerParams>>,
}

/// Parameters learned by a fitted scaler for a single feature. Only the
/// fields relevant to `method` are populated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalerParams {
    pub method: ScalingMethod,
    pub mean: Option<f64>,
    pub std: Option<f64>,
    pub min: Option<f64>,
    pub max: Option<f64>,
    pub median: Option<f64>,
    pub mad: Option<f64>,
}

/// Record of a single ML experiment: dataset summary, model
/// configuration, and results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLExperiment {
    pub name: String,
    pub dataset_info: DatasetInfo,
    pub model_config: ModelConfig,
    pub results: ExperimentResults,
    pub cv_scores: Option<CrossValidationResults>,
}

/// Summary statistics describing a dataset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetInfo {
    pub n_samples: usize,
    pub n_features: usize,
    pub n_classes: Option<usize>,
    pub class_distribution: Option<HashMap<String, usize>>,
    pub missing_data_percentage: f64,
}

/// Model type, hyperparameters, and preprocessing steps used in an
/// experiment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    pub model_type: String,
    pub hyperparameters: HashMap<String, serde_json::Value>,
    pub preprocessing_steps: Vec<String>,
}

/// Scores and timings collected for an experiment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExperimentResults {
    pub training_score: f64,
    pub validation_score: f64,
    pub test_score: Option<f64>,
    pub training_time: f64,
    pub inference_time: Option<f64>,
    pub feature_importance: Option<Vec<(String, f64)>>,
}

/// Aggregated cross-validation scores plus per-fold details.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossValidationResults {
    pub scores: Vec<f64>,
    pub mean_score: f64,
    pub std_score: f64,
    pub fold_details: Vec<FoldResult>,
}

/// Scores and timing for a single cross-validation fold.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FoldResult {
    pub fold_index: usize,
    pub train_score: f64,
    pub validation_score: f64,
    pub training_time: f64,
}

/// Train/test partition of features and targets.
#[derive(Debug, Clone)]
pub struct DataSplit {
    pub x_train: Array2<f64>,
    pub x_test: Array2<f64>,
    pub y_train: Array1<f64>,
    pub y_test: Array1<f64>,
}

impl Default for MLPipeline {
    fn default() -> Self {
        Self::new(MLPipelineConfig::default())
    }
}

impl MLPipeline {
    /// Creates a pipeline with the given configuration. No scalers are
    /// fitted until [`MLPipeline::prepare_dataset`] runs.
    pub fn new(config: MLPipelineConfig) -> Self {
        Self {
            config,
            fitted_scalers: None,
        }
    }

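    /// Applies the configured preprocessing: balancing first (if any),
    /// then scaling (if any). Fitted scaler parameters are stored on the
    /// pipeline so [`MLPipeline::transform`] can reuse them later.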
    pub fn prepare_dataset(&mut self, dataset: &Dataset) -> Result<Dataset> {
        let mut prepared = dataset.clone();

        if let Some(ref strategy) = self.config.balancing_strategy {
            prepared = self.apply_balancing(&prepared, strategy)?;
        }

        if let Some(method) = self.config.scaling_method {
            prepared = self.fit_and_transform_scaling(&prepared, method)?;
        }

        Ok(prepared)
    }

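    /// Splits a dataset into train and test partitions according to
    /// `config.test_size`, shuffling first (stratified when configured).
    ///
    /// A minimal sketch (`dataset` is any labeled [`Dataset`]):
    ///
    /// ```ignore
    /// let split = pipeline.train_test_split(&dataset)?;
    /// assert_eq!(
    ///     split.x_train.nrows() + split.x_test.nrows(),
    ///     dataset.n_samples()
    /// );
    /// ```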
    pub fn train_test_split(&self, dataset: &Dataset) -> Result<DataSplit> {
        let n_samples = dataset.n_samples();
        let test_samples = (n_samples as f64 * self.config.test_size) as usize;
        let train_samples = n_samples - test_samples;

        let indices = self.generate_split_indices(n_samples, dataset.target.as_ref())?;

        let train_indices = &indices[..train_samples];
        let test_indices = &indices[train_samples..];

        let x_train = dataset.data.select(Axis(0), train_indices);
        let x_test = dataset.data.select(Axis(0), test_indices);

        let (y_train, y_test) = if let Some(ref target) = dataset.target {
            let y_train = target.select(Axis(0), train_indices);
            let y_test = target.select(Axis(0), test_indices);
            (y_train, y_test)
        } else {
            return Err(DatasetsError::InvalidFormat(
                "Target variable required for train/test split".to_string(),
            ));
        };

        Ok(DataSplit {
            x_train,
            x_test,
            y_train,
            y_test,
        })
    }

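    /// Produces k-fold cross-validation folds, stratified by the target
    /// when `config.stratify` is set. Fails if the dataset has no target.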
    pub fn cross_validation_split(&self, dataset: &Dataset) -> Result<CrossValidationFolds> {
        let target = dataset.target.as_ref().ok_or_else(|| {
            DatasetsError::InvalidFormat(
                "Target variable required for cross-validation".to_string(),
            )
        })?;

        if self.config.stratify {
            crate::utils::stratified_k_fold_split(
                target,
                self.config.cv_folds,
                true,
                self.config.random_state,
            )
        } else {
            crate::utils::k_fold_split(
                dataset.n_samples(),
                self.config.cv_folds,
                true,
                self.config.random_state,
            )
        }
    }

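    /// Applies the scalers fitted by [`MLPipeline::prepare_dataset`] to a
    /// new dataset, matching scalers to columns by feature name.
    ///
    /// Typical fit-on-train, apply-to-test flow (sketch):
    ///
    /// ```ignore
    /// let train_scaled = pipeline.prepare_dataset(&train)?; // fits scalers
    /// let test_scaled = pipeline.transform(&test)?;         // reuses them
    /// ```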
    pub fn transform(&self, dataset: &Dataset) -> Result<Dataset> {
        let scalers = self.fitted_scalers.as_ref().ok_or_else(|| {
            DatasetsError::InvalidFormat(
                "Pipeline not fitted. Call prepare_dataset first.".to_string(),
            )
        })?;

        let mut transformed_data = dataset.data.clone();

        for (col_idx, mut column) in transformed_data.columns_mut().into_iter().enumerate() {
            let defaultname = format!("feature_{col_idx}");
            let featurename = dataset
                .featurenames
                .as_ref()
                .and_then(|names| names.get(col_idx))
                .map(|s| s.as_str())
                .unwrap_or(&defaultname);

            if let Some(scaler) = scalers.get(featurename) {
                Self::apply_scaler_to_column(&mut column, scaler)?;
            }
        }

        Ok(Dataset {
            data: transformed_data,
            target: dataset.target.clone(),
            featurenames: dataset.featurenames.clone(),
            targetnames: dataset.targetnames.clone(),
            feature_descriptions: dataset.feature_descriptions.clone(),
            description: Some("Transformed dataset".to_string()),
            metadata: dataset.metadata.clone(),
        })
    }

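    /// Creates an experiment record pre-filled with dataset statistics.
    /// The model configuration and results start empty and are filled in
    /// by the caller after training.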
    pub fn create_experiment(&self, name: &str, dataset: &Dataset) -> MLExperiment {
        let dataset_info = self.extract_dataset_info(dataset);

        MLExperiment {
            name: name.to_string(),
            dataset_info,
            model_config: ModelConfig {
                model_type: "undefined".to_string(),
                hyperparameters: HashMap::new(),
                preprocessing_steps: Vec::new(),
            },
            results: ExperimentResults {
                training_score: 0.0,
                validation_score: 0.0,
                test_score: None,
                training_time: 0.0,
                inference_time: None,
                feature_importance: None,
            },
            cv_scores: None,
        }
    }

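    /// Runs k-fold cross-validation, calling `train_fn` once per fold
    /// with `(x_train, y_train, x_val, y_val)`. The closure must return
    /// `(train_score, validation_score, training_time)`.
    ///
    /// A minimal sketch with a placeholder model (the scores are dummies):
    ///
    /// ```ignore
    /// let results = pipeline.evaluate_with_cv(&dataset, |x_tr, _y_tr, _x_val, _y_val| {
    ///     let start = std::time::Instant::now();
    ///     // ... fit a model on x_tr, score it on the validation fold ...
    ///     Ok((0.95, 0.90, start.elapsed().as_secs_f64()))
    /// })?;
    /// println!("CV: {:.3} +/- {:.3}", results.mean_score, results.std_score);
    /// ```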
    pub fn evaluate_with_cv<F>(
        &self,
        dataset: &Dataset,
        train_fn: F,
    ) -> Result<CrossValidationResults>
    where
        F: Fn(&Array2<f64>, &Array1<f64>, &Array2<f64>, &Array1<f64>) -> Result<(f64, f64, f64)>,
    {
        let folds = self.cross_validation_split(dataset)?;
        let mut scores = Vec::new();
        let mut fold_details = Vec::new();

        for (fold_idx, (train_indices, val_indices)) in folds.into_iter().enumerate() {
            let x_train = dataset.data.select(Axis(0), &train_indices);
            let x_val = dataset.data.select(Axis(0), &val_indices);

            let target = dataset
                .target
                .as_ref()
                .expect("target presence already checked by cross_validation_split");
            let y_train = target.select(Axis(0), &train_indices);
            let y_val = target.select(Axis(0), &val_indices);

            let (train_score, val_score, training_time) =
                train_fn(&x_train, &y_train, &x_val, &y_val)?;

            scores.push(val_score);
            fold_details.push(FoldResult {
                fold_index: fold_idx,
                train_score,
                validation_score: val_score,
                training_time,
            });
        }

        let mean_score = scores.iter().sum::<f64>() / scores.len() as f64;
        let variance = scores
            .iter()
            .map(|score| (score - mean_score).powi(2))
            .sum::<f64>()
            / scores.len() as f64;
        let std_score = variance.sqrt();

        Ok(CrossValidationResults {
            scores,
            mean_score,
            std_score,
            fold_details,
        })
    }

    fn apply_balancing(&self, dataset: &Dataset, strategy: &BalancingStrategy) -> Result<Dataset> {
        match strategy {
            BalancingStrategy::RandomUndersample => {
                self.random_undersample(dataset, self.config.random_state)
            }
            BalancingStrategy::RandomOversample => {
                self.random_oversample(dataset, self.config.random_state)
            }
            // Other strategies are not implemented here; pass the data
            // through unchanged rather than failing.
            _ => Ok(dataset.clone()),
        }
    }

    fn random_undersample(&self, dataset: &Dataset, randomstate: Option<u64>) -> Result<Dataset> {
        use scirs2_core::random::{rngs::StdRng, RngCore, SeedableRng};

        let target = dataset.target.as_ref().ok_or_else(|| {
            DatasetsError::InvalidFormat("Target required for balancing".to_string())
        })?;

        // Count samples per class (NaN targets are skipped).
        let mut class_counts: HashMap<i64, usize> = HashMap::new();
        for &value in target.iter() {
            if !value.is_nan() {
                *class_counts.entry(value as i64).or_insert(0) += 1;
            }
        }

        let min_count = class_counts.values().min().copied().unwrap_or(0);

        let mut rng: Box<dyn RngCore> = match randomstate {
            Some(seed) => Box::new(StdRng::seed_from_u64(seed)),
            None => Box::new(thread_rng()),
        };

        let mut selected_indices = Vec::new();

        for (class_, _count) in class_counts {
            let mut class_indices: Vec<usize> = target
                .iter()
                .enumerate()
                .filter(|(_, &val)| !val.is_nan() && val as i64 == class_)
                .map(|(idx, _)| idx)
                .collect();

            // Shuffle before truncating so the kept rows are a random
            // subset rather than the first `min_count` occurrences.
            class_indices.shuffle(&mut *rng);
            class_indices.truncate(min_count);

            selected_indices.extend(class_indices);
        }

        let balanced_data = dataset.data.select(Axis(0), &selected_indices);
        let balanced_target = target.select(Axis(0), &selected_indices);

        Ok(Dataset {
            data: balanced_data,
            target: Some(balanced_target),
            featurenames: dataset.featurenames.clone(),
            targetnames: dataset.targetnames.clone(),
            feature_descriptions: dataset.feature_descriptions.clone(),
            description: Some("Undersampled dataset".to_string()),
            metadata: dataset.metadata.clone(),
        })
    }

    fn random_oversample(&self, dataset: &Dataset, randomstate: Option<u64>) -> Result<Dataset> {
        use scirs2_core::random::prelude::*;
        use scirs2_core::random::{rngs::StdRng, RngCore, SeedableRng};
        use std::collections::HashMap;

        let target = dataset.target.as_ref().ok_or_else(|| {
            DatasetsError::InvalidFormat("Random oversampling requires target labels".to_string())
        })?;

        if target.len() != dataset.data.nrows() {
            return Err(DatasetsError::InvalidFormat(
                "Target length must match number of samples".to_string(),
            ));
        }

        // Count samples per class and remember each class's row indices.
        let mut class_counts: HashMap<i32, usize> = HashMap::new();
        let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();

        for (idx, &label) in target.iter().enumerate() {
            let class = label as i32;
            *class_counts.entry(class).or_insert(0) += 1;
            class_indices.entry(class).or_default().push(idx);
        }

        let max_count = class_counts.values().max().copied().unwrap_or(0);

        if max_count == 0 {
            return Err(DatasetsError::InvalidFormat(
                "No samples found in dataset".to_string(),
            ));
        }

        let mut rng: Box<dyn RngCore> = match randomstate {
            Some(seed) => Box::new(StdRng::seed_from_u64(seed)),
            None => Box::new(thread_rng()),
        };

        let mut all_indices = Vec::new();

        for (_class, indices) in class_indices.iter() {
            let current_count = indices.len();

            // Keep every original sample of this class.
            all_indices.extend(indices.iter().copied());

            // Resample with replacement until the class matches the majority.
            let samples_needed = max_count - current_count;

            if samples_needed > 0 {
                for _ in 0..samples_needed {
                    let random_idx =
                        rng.sample(Uniform::new(0, indices.len()).expect("valid index range"));
                    all_indices.push(indices[random_idx]);
                }
            }
        }

        // Shuffle so duplicated rows are not grouped by class.
        all_indices.shuffle(&mut *rng);

        let oversampled_data = dataset.data.select(Axis(0), &all_indices);
        let oversampled_target = target.select(Axis(0), &all_indices);

        Ok(Dataset {
            data: oversampled_data,
            target: Some(oversampled_target),
            featurenames: dataset.featurenames.clone(),
            targetnames: dataset.targetnames.clone(),
            feature_descriptions: dataset.feature_descriptions.clone(),
            description: Some(format!(
                "Random oversampled dataset (original: {} samples, oversampled: {} samples)",
                dataset.n_samples(),
                all_indices.len()
            )),
            metadata: dataset.metadata.clone(),
        })
    }

    fn fit_and_transform_scaling(
        &mut self,
        dataset: &Dataset,
        method: ScalingMethod,
    ) -> Result<Dataset> {
        let mut scalers = HashMap::new();
        let mut scaled_data = dataset.data.clone();

        for (col_idx, mut column) in scaled_data.columns_mut().into_iter().enumerate() {
            let featurename = dataset
                .featurenames
                .as_ref()
                .and_then(|names| names.get(col_idx))
                .cloned()
                .unwrap_or_else(|| format!("feature_{col_idx}"));

            let column_view = column.view();
            let scaler_params = Self::fit_scaler(&column_view, method)?;
            Self::apply_scaler_to_column(&mut column, &scaler_params)?;

            scalers.insert(featurename, scaler_params);
        }

        self.fitted_scalers = Some(scalers);

        Ok(Dataset {
            data: scaled_data,
            target: dataset.target.clone(),
            featurenames: dataset.featurenames.clone(),
            targetnames: dataset.targetnames.clone(),
            feature_descriptions: dataset.feature_descriptions.clone(),
            description: Some("Scaled dataset".to_string()),
            metadata: dataset.metadata.clone(),
        })
    }

    fn fit_scaler(
        column: &scirs2_core::ndarray::ArrayView1<f64>,
        method: ScalingMethod,
    ) -> Result<ScalerParams> {
        // Fit statistics on the non-NaN values only.
        let values: Vec<f64> = column.iter().copied().filter(|x| !x.is_nan()).collect();

        if values.is_empty() {
            return Ok(ScalerParams {
                method,
                mean: None,
                std: None,
                min: None,
                max: None,
                median: None,
                mad: None,
            });
        }

        match method {
            ScalingMethod::StandardScaler => {
                let mean = values.iter().sum::<f64>() / values.len() as f64;
                let variance =
                    values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / values.len() as f64;
                let std = variance.sqrt();

                Ok(ScalerParams {
                    method,
                    mean: Some(mean),
                    std: Some(std),
                    min: None,
                    max: None,
                    median: None,
                    mad: None,
                })
            }
            ScalingMethod::MinMaxScaler => {
                let min = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
                let max = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));

                Ok(ScalerParams {
                    method,
                    mean: None,
                    std: None,
                    min: Some(min),
                    max: Some(max),
                    median: None,
                    mad: None,
                })
            }
            ScalingMethod::RobustScaler => {
                let mut sorted_values = values.clone();
                sorted_values.sort_by(|a, b| a.partial_cmp(b).expect("NaNs were filtered out"));

                let median = Self::percentile(&sorted_values, 0.5).unwrap_or(0.0);
                let mad = Self::compute_mad(&sorted_values, median);

                Ok(ScalerParams {
                    method,
                    mean: None,
                    std: None,
                    min: None,
                    max: None,
                    median: Some(median),
                    mad: Some(mad),
                })
            }
            ScalingMethod::None => Ok(ScalerParams {
                method,
                mean: None,
                std: None,
                min: None,
                max: None,
                median: None,
                mad: None,
            }),
        }
    }

    fn apply_scaler_to_column(
        column: &mut scirs2_core::ndarray::ArrayViewMut1<f64>,
        params: &ScalerParams,
    ) -> Result<()> {
        match params.method {
            ScalingMethod::StandardScaler => {
                if let (Some(mean), Some(std)) = (params.mean, params.std) {
                    // Skip near-constant columns to avoid dividing by ~0.
                    if std > 1e-8 {
                        for value in column.iter_mut() {
                            if !value.is_nan() {
                                *value = (*value - mean) / std;
                            }
                        }
                    }
                }
            }
            ScalingMethod::MinMaxScaler => {
                if let (Some(min), Some(max)) = (params.min, params.max) {
                    let range = max - min;
                    if range > 1e-8 {
                        for value in column.iter_mut() {
                            if !value.is_nan() {
                                *value = (*value - min) / range;
                            }
                        }
                    }
                }
            }
            ScalingMethod::RobustScaler => {
                if let (Some(median), Some(mad)) = (params.median, params.mad) {
                    if mad > 1e-8 {
                        for value in column.iter_mut() {
                            if !value.is_nan() {
                                *value = (*value - median) / mad;
                            }
                        }
                    }
                }
            }
            ScalingMethod::None => {
                // Intentionally a no-op.
            }
        }

        Ok(())
    }

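    /// Percentile by linear interpolation over pre-sorted values, with
    /// `p` in `[0, 1]` (so `p = 0.5` is the median).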
    fn percentile(sorted_values: &[f64], p: f64) -> Option<f64> {
        if sorted_values.is_empty() {
            return None;
        }

        let index = p * (sorted_values.len() - 1) as f64;
        let lower = index.floor() as usize;
        let upper = index.ceil() as usize;

        if lower == upper {
            Some(sorted_values[lower])
        } else {
            let weight = index - lower as f64;
            Some(sorted_values[lower] * (1.0 - weight) + sorted_values[upper] * weight)
        }
    }

    /// Median absolute deviation: the median of `|x - median|`.
    fn compute_mad(sorted_values: &[f64], median: f64) -> f64 {
        let deviations: Vec<f64> = sorted_values.iter().map(|&x| (x - median).abs()).collect();

        let mut sorted_deviations = deviations;
        sorted_deviations.sort_by(|a, b| a.partial_cmp(b).expect("NaNs were filtered out"));

        Self::percentile(&sorted_deviations, 0.5).unwrap_or(1.0)
    }

    fn generate_split_indices(
        &self,
        n_samples: usize,
        target: Option<&Array1<f64>>,
    ) -> Result<Vec<usize>> {
        let mut indices: Vec<usize> = (0..n_samples).collect();

        if self.config.stratify && target.is_some() {
            self.stratified_shuffle(
                &mut indices,
                target.expect("Target required for stratified split"),
            )?;
        } else {
            match self.config.random_state {
                Some(seed) => {
                    let mut rng = StdRng::seed_from_u64(seed);
                    indices.shuffle(&mut rng);
                }
                None => {
                    let mut rng = thread_rng();
                    indices.shuffle(&mut rng);
                }
            }
        }

        Ok(indices)
    }

    fn stratified_shuffle(&self, indices: &mut Vec<usize>, target: &Array1<f64>) -> Result<()> {
        use scirs2_core::random::{rngs::StdRng, RngCore, SeedableRng};

        // Group indices by class label.
        let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();

        for &idx in indices.iter() {
            let class = target[idx] as i32;
            class_indices.entry(class).or_default().push(idx);
        }

        // Create the RNG once so each class gets an independent shuffle;
        // re-seeding per class would repeat the same permutation pattern.
        let mut rng: Box<dyn RngCore> = match self.config.random_state {
            Some(seed) => Box::new(StdRng::seed_from_u64(seed)),
            None => Box::new(thread_rng()),
        };

        for class_group in class_indices.values_mut() {
            class_group.shuffle(&mut *rng);
        }

        // Interleave the classes round-robin so any prefix of `indices`
        // (e.g. the train portion) keeps roughly the original class mix.
        indices.clear();
        let mut class_iterators: HashMap<i32, std::vec::IntoIter<usize>> = class_indices
            .into_iter()
            .map(|(class, group)| (class, group.into_iter()))
            .collect();

        while !class_iterators.is_empty() {
            let mut to_remove = Vec::new();
            for (&class, iterator) in class_iterators.iter_mut() {
                if let Some(idx) = iterator.next() {
                    indices.push(idx);
                } else {
                    to_remove.push(class);
                }
            }
            for class in to_remove {
                class_iterators.remove(&class);
            }
        }

        Ok(())
    }

    fn extract_dataset_info(&self, dataset: &Dataset) -> DatasetInfo {
        let n_samples = dataset.n_samples();
        let n_features = dataset.n_features();

        let (n_classes, class_distribution) = if let Some(ref target) = dataset.target {
            let mut class_counts: HashMap<String, usize> = HashMap::new();
            for &value in target.iter() {
                if !value.is_nan() {
                    let classname = format!("{value:.0}");
                    *class_counts.entry(classname).or_insert(0) += 1;
                }
            }

            let n_classes = class_counts.len();
            (Some(n_classes), Some(class_counts))
        } else {
            (None, None)
        };

        let total_values = n_samples * n_features;
        let missing_values = dataset.data.iter().filter(|&&x| x.is_nan()).count();
        let missing_data_percentage = missing_values as f64 / total_values as f64 * 100.0;

        DatasetInfo {
            n_samples,
            n_features,
            n_classes,
            class_distribution,
            missing_data_percentage,
        }
    }
}

/// Convenience wrappers that build a short-lived [`MLPipeline`] with
/// default configuration for one-off operations.
pub mod convenience {
    use super::*;

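    /// One-off train/test split; `testsize` falls back to the pipeline
    /// default (0.2) when `None`.
    ///
    /// Sketch (`dataset` is any labeled [`Dataset`]):
    ///
    /// ```ignore
    /// let split = convenience::train_test_split(&dataset, Some(0.3))?;
    /// ```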
    pub fn train_test_split(dataset: &Dataset, testsize: Option<f64>) -> Result<DataSplit> {
        let mut config = MLPipelineConfig::default();
        if let Some(size) = testsize {
            config.test_size = size;
        }

        let pipeline = MLPipeline::new(config);
        pipeline.train_test_split(dataset)
    }

    /// Prepares a dataset with optional standard scaling and random
    /// undersampling, using default pipeline settings otherwise.
    pub fn prepare_for_ml(dataset: &Dataset, scale: bool, balance: bool) -> Result<Dataset> {
        let mut config = MLPipelineConfig::default();

        if !scale {
            config.scaling_method = None;
        }

        if balance {
            config.balancing_strategy = Some(BalancingStrategy::RandomUndersample);
        }

        let mut pipeline = MLPipeline::new(config);
        pipeline.prepare_dataset(dataset)
    }

    /// K-fold split with optional overrides for fold count and
    /// stratification.
    pub fn cv_split(
        dataset: &Dataset,
        n_folds: Option<usize>,
        stratify: Option<bool>,
    ) -> Result<CrossValidationFolds> {
        let mut config = MLPipelineConfig::default();

        if let Some(folds) = n_folds {
            config.cv_folds = folds;
        }

        if let Some(strat) = stratify {
            config.stratify = strat;
        }

        let pipeline = MLPipeline::new(config);
        pipeline.cross_validation_split(dataset)
    }

    /// Creates an experiment record using a default pipeline.
    pub fn create_experiment(name: &str, dataset: &Dataset) -> MLExperiment {
        let pipeline = MLPipeline::default();
        pipeline.create_experiment(name, dataset)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::make_classification;

    #[test]
    fn test_ml_pipeline_creation() {
        let pipeline = MLPipeline::default();
        assert_eq!(pipeline.config.test_size, 0.2);
        assert_eq!(pipeline.config.cv_folds, 5);
    }

    #[test]
    fn test_train_test_split() {
        let dataset = make_classification(100, 5, 2, 1, 1, Some(42)).expect("dataset generation");
        let split = convenience::train_test_split(&dataset, Some(0.3)).expect("train/test split");

        assert_eq!(split.x_train.nrows() + split.x_test.nrows(), 100);
        assert_eq!(split.y_train.len() + split.y_test.len(), 100);
        assert_eq!(split.x_train.ncols(), 5);
        assert_eq!(split.x_test.ncols(), 5);
    }

    #[test]
    fn test_cross_validation_split() {
        let dataset = make_classification(100, 3, 2, 1, 1, Some(42)).expect("dataset generation");
        let folds = convenience::cv_split(&dataset, Some(5), Some(true)).expect("cv split");

        assert_eq!(folds.len(), 5);

        // Each fold partitions all 100 samples between train and test, so
        // the per-fold average of train + test sizes must be 100.
        let total_samples: usize = folds
            .iter()
            .map(|(train, test)| train.len() + test.len())
            .sum::<usize>()
            / 5;
        assert_eq!(total_samples, 100);
    }

    #[test]
    fn test_dataset_preparation() {
        let dataset = make_classification(50, 4, 2, 1, 1, Some(42)).expect("dataset generation");
        let prepared =
            convenience::prepare_for_ml(&dataset, true, false).expect("preparation");

        assert_eq!(prepared.n_samples(), dataset.n_samples());
        assert_eq!(prepared.n_features(), dataset.n_features());
    }

    #[test]
    fn test_experiment_creation() {
        let dataset = make_classification(100, 5, 2, 1, 1, Some(42)).expect("dataset generation");
        let experiment = convenience::create_experiment("test_experiment", &dataset);

        assert_eq!(experiment.name, "test_experiment");
        assert_eq!(experiment.dataset_info.n_samples, 100);
        assert_eq!(experiment.dataset_info.n_features, 5);
        assert_eq!(experiment.dataset_info.n_classes, Some(2));
    }

    #[test]
    fn test_scaler_fitting() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let array = Array1::from_vec(data);
        let view = array.view();

        let scaler_params =
            MLPipeline::fit_scaler(&view, ScalingMethod::StandardScaler).expect("scaler fit");

        assert!(scaler_params.mean.is_some());
        assert!(scaler_params.std.is_some());
        assert_eq!(scaler_params.mean.expect("mean fitted"), 3.0);
    }

    #[test]
    fn test_min_max_scaler() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let array = Array1::from_vec(data);
        let view = array.view();

        let scaler_params =
            MLPipeline::fit_scaler(&view, ScalingMethod::MinMaxScaler).expect("scaler fit");

        assert!(scaler_params.min.is_some());
        assert!(scaler_params.max.is_some());
        assert_eq!(scaler_params.min.expect("min fitted"), 1.0);
        assert_eq!(scaler_params.max.expect("max fitted"), 5.0);
    }
}