use std::collections::HashMap;

use ndarray::{Array1, Array2, Axis};
use rand::rngs::StdRng;
use rand::seq::SliceRandom;
use rand::{rng, Rng, RngCore, SeedableRng};
use rand_distr::Uniform;
use serde::{Deserialize, Serialize};

use crate::error::{DatasetsError, Result};
use crate::utils::{BalancingStrategy, CrossValidationFolds, Dataset};

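/// Configuration for an [`MLPipeline`]: controls the train/test split,
/// cross-validation, stratification, class balancing, and feature scaling.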
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLPipelineConfig {
    pub random_state: Option<u64>,
    pub test_size: f64,
    pub cv_folds: usize,
    pub stratify: bool,
    pub balancing_strategy: Option<BalancingStrategy>,
    pub scaling_method: Option<ScalingMethod>,
}

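/// Defaults: seed 42, 20% test split, 5 stratified folds, standard scaling,
/// and no class balancing.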
impl Default for MLPipelineConfig {
    fn default() -> Self {
        Self {
            random_state: Some(42),
            test_size: 0.2,
            cv_folds: 5,
            stratify: true,
            balancing_strategy: None,
            scaling_method: Some(ScalingMethod::StandardScaler),
        }
    }
}

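/// Feature scaling method applied column-wise during dataset preparation:
/// z-score standardization, min-max scaling to [0, 1], robust scaling by
/// median/MAD, or no scaling at all.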
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum ScalingMethod {
    StandardScaler,
    MinMaxScaler,
    RobustScaler,
    None,
}

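/// End-to-end pipeline for preparing datasets for machine learning.
///
/// A minimal sketch of typical usage (the `dataset` value is assumed to come
/// from elsewhere, e.g. this crate's generators):
///
/// ```ignore
/// let mut pipeline = MLPipeline::new(MLPipelineConfig::default());
/// let prepared = pipeline.prepare_dataset(&dataset)?;
/// let split = pipeline.train_test_split(&prepared)?;
/// ```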
pub struct MLPipeline {
    config: MLPipelineConfig,
    fitted_scalers: Option<HashMap<String, ScalerParams>>,
}

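/// Per-feature parameters fitted by a [`ScalingMethod`]; only the fields
/// relevant to the chosen method are populated.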
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalerParams {
    pub method: ScalingMethod,
    pub mean: Option<f64>,
    pub std: Option<f64>,
    pub min: Option<f64>,
    pub max: Option<f64>,
    pub median: Option<f64>,
    pub mad: Option<f64>,
}

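/// A named experiment bundling dataset statistics, model configuration, and
/// evaluation results for tracking and serialization.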
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLExperiment {
    pub name: String,
    pub dataset_info: DatasetInfo,
    pub model_config: ModelConfig,
    pub results: ExperimentResults,
    pub cv_scores: Option<CrossValidationResults>,
}

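/// Summary statistics of a dataset: shape, class distribution, and the
/// percentage of missing (NaN) values.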
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetInfo {
    pub n_samples: usize,
    pub n_features: usize,
    pub n_classes: Option<usize>,
    pub class_distribution: Option<HashMap<String, usize>>,
    pub missing_data_percentage: f64,
}

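/// Model description attached to an experiment: model type, hyperparameters
/// as JSON values, and the preprocessing steps applied.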
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    pub model_type: String,
    pub hyperparameters: HashMap<String, serde_json::Value>,
    pub preprocessing_steps: Vec<String>,
}

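/// Scores and timings recorded for a single experiment run.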
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExperimentResults {
    pub training_score: f64,
    pub validation_score: f64,
    pub test_score: Option<f64>,
    pub training_time: f64,
    pub inference_time: Option<f64>,
    pub feature_importance: Option<Vec<(String, f64)>>,
}

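/// Aggregated cross-validation scores plus per-fold details.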
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossValidationResults {
    pub scores: Vec<f64>,
    pub mean_score: f64,
    pub std_score: f64,
    pub fold_details: Vec<FoldResult>,
}

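/// Scores and training time for a single cross-validation fold.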
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FoldResult {
    pub fold_index: usize,
    pub train_score: f64,
    pub validation_score: f64,
    pub training_time: f64,
}

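/// Train/test partition of a dataset's features and targets.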
#[derive(Debug, Clone)]
pub struct DataSplit {
    pub x_train: Array2<f64>,
    pub x_test: Array2<f64>,
    pub y_train: Array1<f64>,
    pub y_test: Array1<f64>,
}

impl Default for MLPipeline {
    fn default() -> Self {
        Self::new(MLPipelineConfig::default())
    }
}

impl MLPipeline {
    pub fn new(config: MLPipelineConfig) -> Self {
        Self {
            config,
            fitted_scalers: None,
        }
    }

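    /// Applies the configured balancing and scaling steps, fitting the
    /// scalers on this dataset so that [`MLPipeline::transform`] can later
    /// apply the same transformation to new data.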
    pub fn prepare_dataset(&mut self, dataset: &Dataset) -> Result<Dataset> {
        let mut prepared = dataset.clone();

        if let Some(ref strategy) = self.config.balancing_strategy {
            prepared = self.apply_balancing(&prepared, strategy)?;
        }

        if let Some(method) = self.config.scaling_method {
            prepared = self.fit_and_transform_scaling(&prepared, method)?;
        }

        Ok(prepared)
    }

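    /// Splits the dataset into train and test sets according to
    /// `config.test_size`, shuffling (stratified if configured) beforehand.
    /// Returns an error if the dataset has no target.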
    pub fn train_test_split(&self, dataset: &Dataset) -> Result<DataSplit> {
        let n_samples = dataset.n_samples();
        let test_samples = (n_samples as f64 * self.config.test_size) as usize;
        let train_samples = n_samples - test_samples;

        let indices = self.generate_split_indices(n_samples, dataset.target.as_ref())?;

        let train_indices = &indices[..train_samples];
        let test_indices = &indices[train_samples..];

        let x_train = dataset.data.select(Axis(0), train_indices);
        let x_test = dataset.data.select(Axis(0), test_indices);

        let (y_train, y_test) = if let Some(ref target) = dataset.target {
            let y_train = target.select(Axis(0), train_indices);
            let y_test = target.select(Axis(0), test_indices);
            (y_train, y_test)
        } else {
            return Err(DatasetsError::InvalidFormat(
                "Target variable required for train/test split".to_string(),
            ));
        };

        Ok(DataSplit {
            x_train,
            x_test,
            y_train,
            y_test,
        })
    }

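    /// Produces k-fold cross-validation index folds, stratified by class
    /// when `config.stratify` is set.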
    pub fn cross_validation_split(&self, dataset: &Dataset) -> Result<CrossValidationFolds> {
        let target = dataset.target.as_ref().ok_or_else(|| {
            DatasetsError::InvalidFormat("Target variable required for cross-validation".to_string())
        })?;

        if self.config.stratify {
            crate::utils::stratified_k_fold_split(
                target,
                self.config.cv_folds,
                true,
                self.config.random_state,
            )
        } else {
            crate::utils::k_fold_split(
                dataset.n_samples(),
                self.config.cv_folds,
                true,
                self.config.random_state,
            )
        }
    }

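    /// Applies the scalers fitted by [`MLPipeline::prepare_dataset`] to a new
    /// dataset, matching columns by feature name. Errors if the pipeline has
    /// not been fitted yet.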
    pub fn transform(&self, dataset: &Dataset) -> Result<Dataset> {
        let scalers = self.fitted_scalers.as_ref().ok_or_else(|| {
            DatasetsError::InvalidFormat("Pipeline not fitted. Call prepare_dataset first.".to_string())
        })?;

        let mut transformed_data = dataset.data.clone();

        for (col_idx, mut column) in transformed_data.columns_mut().into_iter().enumerate() {
            let default_name = format!("feature_{col_idx}");
            let feature_name = dataset
                .featurenames
                .as_ref()
                .and_then(|names| names.get(col_idx))
                .map(|s| s.as_str())
                .unwrap_or(&default_name);

            if let Some(scaler) = scalers.get(feature_name) {
                Self::apply_scaler_to_column(&mut column, scaler)?;
            }
        }

        Ok(Dataset {
            data: transformed_data,
            target: dataset.target.clone(),
            featurenames: dataset.featurenames.clone(),
            targetnames: dataset.targetnames.clone(),
            feature_descriptions: dataset.feature_descriptions.clone(),
            description: Some("Transformed dataset".to_string()),
            metadata: dataset.metadata.clone(),
        })
    }

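    /// Creates an empty experiment record for this dataset; model
    /// configuration and results are filled in by the caller.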
    pub fn create_experiment(&self, name: &str, dataset: &Dataset) -> MLExperiment {
        let dataset_info = self.extract_dataset_info(dataset);

        MLExperiment {
            name: name.to_string(),
            dataset_info,
            model_config: ModelConfig {
                model_type: "undefined".to_string(),
                hyperparameters: HashMap::new(),
                preprocessing_steps: Vec::new(),
            },
            results: ExperimentResults {
                training_score: 0.0,
                validation_score: 0.0,
                test_score: None,
                training_time: 0.0,
                inference_time: None,
                feature_importance: None,
            },
            cv_scores: None,
        }
    }

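    /// Runs cross-validation, invoking `train_fn(x_train, y_train, x_val, y_val)`
    /// once per fold. The closure returns `(train_score, val_score, training_time)`.
    ///
    /// A minimal sketch with a stub model (the constant scores stand in for a
    /// real training routine):
    ///
    /// ```ignore
    /// let results = pipeline.evaluate_with_cv(&dataset, |_xt, _yt, _xv, _yv| {
    ///     // Train and score a model here; stub values shown.
    ///     Ok((0.95, 0.90, 0.01))
    /// })?;
    /// println!("CV score: {:.3} +/- {:.3}", results.mean_score, results.std_score);
    /// ```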
    pub fn evaluate_with_cv<F>(
        &self,
        dataset: &Dataset,
        train_fn: F,
    ) -> Result<CrossValidationResults>
    where
        F: Fn(&Array2<f64>, &Array1<f64>, &Array2<f64>, &Array1<f64>) -> Result<(f64, f64, f64)>,
    {
        let folds = self.cross_validation_split(dataset)?;
        let mut scores = Vec::new();
        let mut fold_details = Vec::new();

        for (fold_idx, (train_indices, val_indices)) in folds.into_iter().enumerate() {
            let x_train = dataset.data.select(Axis(0), &train_indices);
            let x_val = dataset.data.select(Axis(0), &val_indices);

            // Safe: cross_validation_split already failed if the target was missing.
            let target = dataset.target.as_ref().unwrap();
            let y_train = target.select(Axis(0), &train_indices);
            let y_val = target.select(Axis(0), &val_indices);

            let (train_score, val_score, training_time) =
                train_fn(&x_train, &y_train, &x_val, &y_val)?;

            scores.push(val_score);
            fold_details.push(FoldResult {
                fold_index: fold_idx,
                train_score,
                validation_score: val_score,
                training_time,
            });
        }

        let mean_score = scores.iter().sum::<f64>() / scores.len() as f64;
        let variance = scores
            .iter()
            .map(|score| (score - mean_score).powi(2))
            .sum::<f64>()
            / scores.len() as f64;
        let std_score = variance.sqrt();

        Ok(CrossValidationResults {
            scores,
            mean_score,
            std_score,
            fold_details,
        })
    }

    fn apply_balancing(&self, dataset: &Dataset, strategy: &BalancingStrategy) -> Result<Dataset> {
        match strategy {
            BalancingStrategy::RandomUndersample => self.random_undersample(dataset, None),
            BalancingStrategy::RandomOversample => self.random_oversample(dataset, None),
            // Other strategies are not implemented by this pipeline; pass the data through.
            _ => Ok(dataset.clone()),
        }
    }

    fn random_undersample(&self, dataset: &Dataset, random_state: Option<u64>) -> Result<Dataset> {
        let target = dataset.target.as_ref().ok_or_else(|| {
            DatasetsError::InvalidFormat("Target required for balancing".to_string())
        })?;

        let mut class_counts: HashMap<i64, usize> = HashMap::new();
        for &value in target.iter() {
            if !value.is_nan() {
                *class_counts.entry(value as i64).or_insert(0) += 1;
            }
        }

        let min_count = class_counts.values().min().copied().unwrap_or(0);

        let mut selected_indices = Vec::new();

        for (class, _count) in class_counts {
            let mut class_indices: Vec<usize> = target
                .iter()
                .enumerate()
                .filter(|(_, &val)| !val.is_nan() && val as i64 == class)
                .map(|(idx, _)| idx)
                .collect();

            // Shuffle before truncating so the retained samples are drawn at
            // random rather than always being the first `min_count` rows.
            match random_state.or(self.config.random_state) {
                Some(seed) => class_indices.shuffle(&mut StdRng::seed_from_u64(seed)),
                None => class_indices.shuffle(&mut rng()),
            }
            class_indices.truncate(min_count);

            selected_indices.extend(class_indices);
        }

        let balanced_data = dataset.data.select(Axis(0), &selected_indices);
        let balanced_target = target.select(Axis(0), &selected_indices);

        Ok(Dataset {
            data: balanced_data,
            target: Some(balanced_target),
            featurenames: dataset.featurenames.clone(),
            targetnames: dataset.targetnames.clone(),
            feature_descriptions: dataset.feature_descriptions.clone(),
            description: Some("Undersampled dataset".to_string()),
            metadata: dataset.metadata.clone(),
        })
    }

    fn random_oversample(&self, dataset: &Dataset, random_state: Option<u64>) -> Result<Dataset> {
        let target = dataset.target.as_ref().ok_or_else(|| {
            DatasetsError::InvalidFormat("Random oversampling requires target labels".to_string())
        })?;

        if target.len() != dataset.data.nrows() {
            return Err(DatasetsError::InvalidFormat(
                "Target length must match number of samples".to_string(),
            ));
        }

        let mut class_counts: HashMap<i32, usize> = HashMap::new();
        let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();

        for (idx, &label) in target.iter().enumerate() {
            let class = label as i32;
            *class_counts.entry(class).or_insert(0) += 1;
            class_indices.entry(class).or_default().push(idx);
        }

        let max_count = class_counts.values().max().copied().unwrap_or(0);

        if max_count == 0 {
            return Err(DatasetsError::InvalidFormat(
                "No samples found in dataset".to_string(),
            ));
        }

        let mut rng: Box<dyn RngCore> = match random_state.or(self.config.random_state) {
            Some(seed) => Box::new(StdRng::seed_from_u64(seed)),
            None => Box::new(rng()),
        };

        // Iterate classes in sorted order so a fixed seed gives reproducible
        // output; HashMap iteration order would not.
        let mut classes: Vec<i32> = class_indices.keys().copied().collect();
        classes.sort_unstable();

        let mut all_indices = Vec::new();

        for class in classes {
            let indices = &class_indices[&class];
            let current_count = indices.len();

            // Keep every original sample, then resample with replacement
            // until this class reaches the majority-class count.
            all_indices.extend(indices.iter().copied());

            for _ in 0..max_count - current_count {
                let random_idx = rng.sample(Uniform::new(0, indices.len()).unwrap());
                all_indices.push(indices[random_idx]);
            }
        }

        all_indices.shuffle(&mut *rng);

        let oversampled_data = dataset.data.select(Axis(0), &all_indices);
        let oversampled_target = target.select(Axis(0), &all_indices);

        Ok(Dataset {
            data: oversampled_data,
            target: Some(oversampled_target),
            featurenames: dataset.featurenames.clone(),
            targetnames: dataset.targetnames.clone(),
            feature_descriptions: dataset.feature_descriptions.clone(),
            description: Some(format!(
                "Random oversampled dataset (original: {} samples, oversampled: {} samples)",
                dataset.n_samples(),
                all_indices.len()
            )),
            metadata: dataset.metadata.clone(),
        })
    }

    fn fit_and_transform_scaling(
        &mut self,
        dataset: &Dataset,
        method: ScalingMethod,
    ) -> Result<Dataset> {
        let mut scalers = HashMap::new();
        let mut scaled_data = dataset.data.clone();

        for (col_idx, mut column) in scaled_data.columns_mut().into_iter().enumerate() {
            let feature_name = dataset
                .featurenames
                .as_ref()
                .and_then(|names| names.get(col_idx))
                .cloned()
                .unwrap_or_else(|| format!("feature_{col_idx}"));

            let column_view = column.view();
            let scaler_params = Self::fit_scaler(&column_view, method)?;
            Self::apply_scaler_to_column(&mut column, &scaler_params)?;

            scalers.insert(feature_name, scaler_params);
        }

        self.fitted_scalers = Some(scalers);

        Ok(Dataset {
            data: scaled_data,
            target: dataset.target.clone(),
            featurenames: dataset.featurenames.clone(),
            targetnames: dataset.targetnames.clone(),
            feature_descriptions: dataset.feature_descriptions.clone(),
            description: Some("Scaled dataset".to_string()),
            metadata: dataset.metadata.clone(),
        })
    }

    fn fit_scaler(
        column: &ndarray::ArrayView1<f64>,
        method: ScalingMethod,
    ) -> Result<ScalerParams> {
        // Fit on the non-NaN values only; missing entries are likewise left
        // untouched when the scaler is applied.
        let values: Vec<f64> = column.iter().copied().filter(|x| !x.is_nan()).collect();

        if values.is_empty() {
            return Ok(ScalerParams {
                method,
                mean: None,
                std: None,
                min: None,
                max: None,
                median: None,
                mad: None,
            });
        }

        match method {
            ScalingMethod::StandardScaler => {
                let mean = values.iter().sum::<f64>() / values.len() as f64;
                let variance =
                    values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / values.len() as f64;
                let std = variance.sqrt();

                Ok(ScalerParams {
                    method,
                    mean: Some(mean),
                    std: Some(std),
                    min: None,
                    max: None,
                    median: None,
                    mad: None,
                })
            }
            ScalingMethod::MinMaxScaler => {
                let min = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
                let max = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));

                Ok(ScalerParams {
                    method,
                    mean: None,
                    std: None,
                    min: Some(min),
                    max: Some(max),
                    median: None,
                    mad: None,
                })
            }
            ScalingMethod::RobustScaler => {
                // NaNs were filtered above, so partial_cmp cannot fail here.
                let mut sorted_values = values.clone();
                sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap());

                let median = Self::percentile(&sorted_values, 0.5).unwrap_or(0.0);
                let mad = Self::compute_mad(&sorted_values, median);

                Ok(ScalerParams {
                    method,
                    mean: None,
                    std: None,
                    min: None,
                    max: None,
                    median: Some(median),
                    mad: Some(mad),
                })
            }
            ScalingMethod::None => Ok(ScalerParams {
                method,
                mean: None,
                std: None,
                min: None,
                max: None,
                median: None,
                mad: None,
            }),
        }
    }

    fn apply_scaler_to_column(
        column: &mut ndarray::ArrayViewMut1<f64>,
        params: &ScalerParams,
    ) -> Result<()> {
        // A small epsilon guards against division by a (near-)zero spread;
        // constant columns are left unchanged.
        match params.method {
            ScalingMethod::StandardScaler => {
                if let (Some(mean), Some(std)) = (params.mean, params.std) {
                    if std > 1e-8 {
                        for value in column.iter_mut() {
                            if !value.is_nan() {
                                *value = (*value - mean) / std;
                            }
                        }
                    }
                }
            }
            ScalingMethod::MinMaxScaler => {
                if let (Some(min), Some(max)) = (params.min, params.max) {
                    let range = max - min;
                    if range > 1e-8 {
                        for value in column.iter_mut() {
                            if !value.is_nan() {
                                *value = (*value - min) / range;
                            }
                        }
                    }
                }
            }
            ScalingMethod::RobustScaler => {
                if let (Some(median), Some(mad)) = (params.median, params.mad) {
                    if mad > 1e-8 {
                        for value in column.iter_mut() {
                            if !value.is_nan() {
                                *value = (*value - median) / mad;
                            }
                        }
                    }
                }
            }
            ScalingMethod::None => {}
        }

        Ok(())
    }

    fn percentile(sorted_values: &[f64], p: f64) -> Option<f64> {
        if sorted_values.is_empty() {
            return None;
        }

        // Linear interpolation between the two nearest order statistics.
        let index = p * (sorted_values.len() - 1) as f64;
        let lower = index.floor() as usize;
        let upper = index.ceil() as usize;

        if lower == upper {
            Some(sorted_values[lower])
        } else {
            let weight = index - lower as f64;
            Some(sorted_values[lower] * (1.0 - weight) + sorted_values[upper] * weight)
        }
    }

    fn compute_mad(sorted_values: &[f64], median: f64) -> f64 {
        // Median absolute deviation: the median of |x - median|.
        let mut deviations: Vec<f64> = sorted_values.iter().map(|&x| (x - median).abs()).collect();
        deviations.sort_by(|a, b| a.partial_cmp(b).unwrap());

        Self::percentile(&deviations, 0.5).unwrap_or(1.0)
    }

    fn generate_split_indices(
        &self,
        n_samples: usize,
        target: Option<&Array1<f64>>,
    ) -> Result<Vec<usize>> {
        let mut indices: Vec<usize> = (0..n_samples).collect();

        match target {
            Some(target) if self.config.stratify => {
                self.stratified_shuffle(&mut indices, target)?;
            }
            _ => match self.config.random_state {
                Some(seed) => indices.shuffle(&mut StdRng::seed_from_u64(seed)),
                None => indices.shuffle(&mut rng()),
            },
        }

        Ok(indices)
    }

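    /// Shuffles indices within each class, then interleaves the classes
    /// round-robin so that any contiguous slice of the result (such as the
    /// train portion of a split) keeps approximately the original class
    /// proportions.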
    fn stratified_shuffle(&self, indices: &mut Vec<usize>, target: &Array1<f64>) -> Result<()> {
        let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();

        for &idx in indices.iter() {
            let class = target[idx] as i32;
            class_indices.entry(class).or_default().push(idx);
        }

        // Sort classes so that a fixed random_state yields reproducible output;
        // iterating the HashMap directly would be nondeterministic.
        let mut groups: Vec<(i32, Vec<usize>)> = class_indices.into_iter().collect();
        groups.sort_unstable_by_key(|(class, _)| *class);

        // One RNG shared across classes so the per-class shuffles are independent
        // rather than identical replays of the same seed.
        let mut rng: Box<dyn RngCore> = match self.config.random_state {
            Some(seed) => Box::new(StdRng::seed_from_u64(seed)),
            None => Box::new(rng()),
        };
        for (_, group) in groups.iter_mut() {
            group.shuffle(&mut *rng);
        }

        // Interleave the shuffled classes round-robin.
        indices.clear();
        let mut iterators: Vec<std::vec::IntoIter<usize>> = groups
            .into_iter()
            .map(|(_, group)| group.into_iter())
            .collect();

        loop {
            let mut exhausted = true;
            for iterator in iterators.iter_mut() {
                if let Some(idx) = iterator.next() {
                    indices.push(idx);
                    exhausted = false;
                }
            }
            if exhausted {
                break;
            }
        }

        Ok(())
    }

    fn extract_dataset_info(&self, dataset: &Dataset) -> DatasetInfo {
        let n_samples = dataset.n_samples();
        let n_features = dataset.n_features();

        let (n_classes, class_distribution) = if let Some(ref target) = dataset.target {
            let mut class_counts: HashMap<String, usize> = HashMap::new();
            for &value in target.iter() {
                if !value.is_nan() {
                    let class_name = format!("{value:.0}");
                    *class_counts.entry(class_name).or_insert(0) += 1;
                }
            }

            let n_classes = class_counts.len();
            (Some(n_classes), Some(class_counts))
        } else {
            (None, None)
        };

        let total_values = n_samples * n_features;
        let missing_values = dataset.data.iter().filter(|&&x| x.is_nan()).count();
        // Guard against empty datasets to avoid a 0/0 NaN percentage.
        let missing_data_percentage = if total_values == 0 {
            0.0
        } else {
            missing_values as f64 / total_values as f64 * 100.0
        };

        DatasetInfo {
            n_samples,
            n_features,
            n_classes,
            class_distribution,
            missing_data_percentage,
        }
    }
}

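/// One-call helpers that wrap [`MLPipeline`] with default configuration.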
pub mod convenience {
    use super::*;

    pub fn train_test_split(dataset: &Dataset, test_size: Option<f64>) -> Result<DataSplit> {
        let mut config = MLPipelineConfig::default();
        if let Some(size) = test_size {
            config.test_size = size;
        }

        let pipeline = MLPipeline::new(config);
        pipeline.train_test_split(dataset)
    }

    pub fn prepare_for_ml(dataset: &Dataset, scale: bool, balance: bool) -> Result<Dataset> {
        let mut config = MLPipelineConfig::default();

        if !scale {
            config.scaling_method = None;
        }

        if balance {
            config.balancing_strategy = Some(BalancingStrategy::RandomUndersample);
        }

        let mut pipeline = MLPipeline::new(config);
        pipeline.prepare_dataset(dataset)
    }

    pub fn cv_split(
        dataset: &Dataset,
        n_folds: Option<usize>,
        stratify: Option<bool>,
    ) -> Result<CrossValidationFolds> {
        let mut config = MLPipelineConfig::default();

        if let Some(folds) = n_folds {
            config.cv_folds = folds;
        }

        if let Some(strat) = stratify {
            config.stratify = strat;
        }

        let pipeline = MLPipeline::new(config);
        pipeline.cross_validation_split(dataset)
    }

    pub fn create_experiment(name: &str, dataset: &Dataset) -> MLExperiment {
        let pipeline = MLPipeline::default();
        pipeline.create_experiment(name, dataset)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::make_classification;

    #[test]
    fn test_ml_pipeline_creation() {
        let pipeline = MLPipeline::default();
        assert_eq!(pipeline.config.test_size, 0.2);
        assert_eq!(pipeline.config.cv_folds, 5);
    }

    #[test]
    fn test_train_test_split() {
        let dataset = make_classification(100, 5, 2, 1, 1, Some(42)).unwrap();
        let split = convenience::train_test_split(&dataset, Some(0.3)).unwrap();

        assert_eq!(split.x_train.nrows() + split.x_test.nrows(), 100);
        assert_eq!(split.y_train.len() + split.y_test.len(), 100);
        assert_eq!(split.x_train.ncols(), 5);
        assert_eq!(split.x_test.ncols(), 5);
    }

    #[test]
    fn test_cross_validation_split() {
        let dataset = make_classification(100, 3, 2, 1, 1, Some(42)).unwrap();
        let folds = convenience::cv_split(&dataset, Some(5), Some(true)).unwrap();

        assert_eq!(folds.len(), 5);

        // Every fold must partition all 100 samples between train and validation.
        for (train, test) in folds.iter() {
            assert_eq!(train.len() + test.len(), 100);
        }
    }

    #[test]
    fn test_dataset_preparation() {
        let dataset = make_classification(50, 4, 2, 1, 1, Some(42)).unwrap();
        let prepared = convenience::prepare_for_ml(&dataset, true, false).unwrap();

        assert_eq!(prepared.n_samples(), dataset.n_samples());
        assert_eq!(prepared.n_features(), dataset.n_features());
    }

    #[test]
    fn test_experiment_creation() {
        let dataset = make_classification(100, 5, 2, 1, 1, Some(42)).unwrap();
        let experiment = convenience::create_experiment("test_experiment", &dataset);

        assert_eq!(experiment.name, "test_experiment");
        assert_eq!(experiment.dataset_info.n_samples, 100);
        assert_eq!(experiment.dataset_info.n_features, 5);
        assert_eq!(experiment.dataset_info.n_classes, Some(2));
    }

    #[test]
    fn test_scaler_fitting() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let array = Array1::from_vec(data);
        let view = array.view();

        let scaler_params = MLPipeline::fit_scaler(&view, ScalingMethod::StandardScaler).unwrap();

        assert!(scaler_params.mean.is_some());
        assert!(scaler_params.std.is_some());
        assert_eq!(scaler_params.mean.unwrap(), 3.0);
    }

    #[test]
    fn test_min_max_scaler() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let array = Array1::from_vec(data);
        let view = array.view();

        let scaler_params = MLPipeline::fit_scaler(&view, ScalingMethod::MinMaxScaler).unwrap();

        assert!(scaler_params.min.is_some());
        assert!(scaler_params.max.is_some());
        assert_eq!(scaler_params.min.unwrap(), 1.0);
        assert_eq!(scaler_params.max.unwrap(), 5.0);
    }
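
    // Sanity checks for the private percentile/MAD helpers behind RobustScaler;
    // expected values computed by hand.
    #[test]
    fn test_percentile_and_mad() {
        // The median of an even-length slice interpolates the middle pair.
        let sorted = [1.0, 2.0, 3.0, 4.0];
        assert_eq!(MLPipeline::percentile(&sorted, 0.5), Some(2.5));
        assert_eq!(MLPipeline::percentile(&[], 0.5), None);

        // Deviations from the median 3.0 are [2, 1, 0, 1, 2]; their median is 1.0.
        let sorted = [1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(MLPipeline::compute_mad(&sorted, 3.0), 1.0);
    }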
}