1use crate::{TrainError, TrainResult};
10use scirs2_core::ndarray::{s, Array1, Array2};
11use std::collections::HashMap;
12use std::fs::File;
13use std::io::{BufRead, BufReader};
14use std::path::Path;
15
/// An in-memory supervised-learning dataset: a feature matrix paired with a
/// target matrix, plus optional column names for both.
///
/// Rows are samples; `features` and `targets` are expected to have the same
/// number of rows (not validated here — NOTE(review): confirm callers uphold this).
#[derive(Debug, Clone)]
pub struct Dataset {
    // Feature matrix, shape (num_samples, num_features).
    pub features: Array2<f64>,
    // Target matrix, shape (num_samples, num_targets).
    pub targets: Array2<f64>,
    // Optional names for each feature column.
    pub feature_names: Option<Vec<String>>,
    // Optional names for each target column.
    pub target_names: Option<Vec<String>>,
}
28
29impl Dataset {
30 pub fn new(features: Array2<f64>, targets: Array2<f64>) -> Self {
32 Self {
33 features,
34 targets,
35 feature_names: None,
36 target_names: None,
37 }
38 }
39
40 pub fn with_feature_names(mut self, names: Vec<String>) -> Self {
42 self.feature_names = Some(names);
43 self
44 }
45
46 pub fn with_target_names(mut self, names: Vec<String>) -> Self {
48 self.target_names = Some(names);
49 self
50 }
51
52 pub fn num_samples(&self) -> usize {
54 self.features.nrows()
55 }
56
57 pub fn num_features(&self) -> usize {
59 self.features.ncols()
60 }
61
62 pub fn num_targets(&self) -> usize {
64 self.targets.ncols()
65 }
66
67 pub fn shuffle(&mut self, seed: u64) {
69 let n = self.num_samples();
70 if n <= 1 {
71 return;
72 }
73
74 let mut rng_state = seed;
76 let lcg_next = |state: &mut u64| -> usize {
77 *state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
78 (*state >> 33) as usize
79 };
80
81 for i in (1..n).rev() {
82 let j = lcg_next(&mut rng_state) % (i + 1);
83 for col in 0..self.features.ncols() {
85 let tmp = self.features[[i, col]];
86 self.features[[i, col]] = self.features[[j, col]];
87 self.features[[j, col]] = tmp;
88 }
89 for col in 0..self.targets.ncols() {
90 let tmp = self.targets[[i, col]];
91 self.targets[[i, col]] = self.targets[[j, col]];
92 self.targets[[j, col]] = tmp;
93 }
94 }
95 }
96
97 pub fn split(&self, ratios: &[f64]) -> TrainResult<Vec<Dataset>> {
105 let total: f64 = ratios.iter().sum();
106 if (total - 1.0).abs() > 1e-6 {
107 return Err(TrainError::ConfigError(format!(
108 "Split ratios must sum to 1.0, got {}",
109 total
110 )));
111 }
112
113 let n = self.num_samples();
114 let mut splits = Vec::new();
115 let mut start = 0;
116
117 for (i, &ratio) in ratios.iter().enumerate() {
118 let end = if i == ratios.len() - 1 {
119 n } else {
121 start + (n as f64 * ratio).round() as usize
122 };
123
124 let features = self.features.slice(s![start..end, ..]).to_owned();
125 let targets = self.targets.slice(s![start..end, ..]).to_owned();
126
127 let mut dataset = Dataset::new(features, targets);
128 if let Some(ref names) = self.feature_names {
129 dataset.feature_names = Some(names.clone());
130 }
131 if let Some(ref names) = self.target_names {
132 dataset.target_names = Some(names.clone());
133 }
134
135 splits.push(dataset);
136 start = end;
137 }
138
139 Ok(splits)
140 }
141
142 pub fn train_test_split(&self, train_ratio: f64) -> TrainResult<(Dataset, Dataset)> {
144 let splits = self.split(&[train_ratio, 1.0 - train_ratio])?;
145 let mut iter = splits.into_iter();
146 Ok((iter.next().unwrap(), iter.next().unwrap()))
147 }
148
149 pub fn train_val_test_split(
151 &self,
152 train_ratio: f64,
153 val_ratio: f64,
154 ) -> TrainResult<(Dataset, Dataset, Dataset)> {
155 let test_ratio = 1.0 - train_ratio - val_ratio;
156 if test_ratio < 0.0 {
157 return Err(TrainError::ConfigError(
158 "Train and validation ratios exceed 1.0".to_string(),
159 ));
160 }
161 let splits = self.split(&[train_ratio, val_ratio, test_ratio])?;
162 let mut iter = splits.into_iter();
163 Ok((
164 iter.next().unwrap(),
165 iter.next().unwrap(),
166 iter.next().unwrap(),
167 ))
168 }
169
170 pub fn subset(&self, indices: &[usize]) -> TrainResult<Dataset> {
172 let n = self.num_samples();
173 for &idx in indices {
174 if idx >= n {
175 return Err(TrainError::ConfigError(format!(
176 "Index {} out of bounds for dataset with {} samples",
177 idx, n
178 )));
179 }
180 }
181
182 let features = Array2::from_shape_fn((indices.len(), self.num_features()), |(i, j)| {
183 self.features[[indices[i], j]]
184 });
185 let targets = Array2::from_shape_fn((indices.len(), self.num_targets()), |(i, j)| {
186 self.targets[[indices[i], j]]
187 });
188
189 let mut dataset = Dataset::new(features, targets);
190 dataset.feature_names = self.feature_names.clone();
191 dataset.target_names = self.target_names.clone();
192
193 Ok(dataset)
194 }
195}
196
/// Configurable loader that reads a delimited text file into a [`Dataset`].
#[derive(Debug, Clone)]
pub struct CsvLoader {
    // Whether the first line is a header row of column names.
    pub has_header: bool,
    // Field separator character (default ',').
    pub delimiter: char,
    // Zero-based column indices to load as targets; all others become features.
    pub target_columns: Vec<usize>,
    // Zero-based column indices to ignore entirely.
    pub skip_columns: Vec<usize>,
}
209
210impl Default for CsvLoader {
211 fn default() -> Self {
212 Self {
213 has_header: true,
214 delimiter: ',',
215 target_columns: vec![],
216 skip_columns: vec![],
217 }
218 }
219}
220
221impl CsvLoader {
222 pub fn new() -> Self {
224 Self::default()
225 }
226
227 pub fn with_header(mut self, has_header: bool) -> Self {
229 self.has_header = has_header;
230 self
231 }
232
233 pub fn with_delimiter(mut self, delimiter: char) -> Self {
235 self.delimiter = delimiter;
236 self
237 }
238
239 pub fn with_target_columns(mut self, columns: Vec<usize>) -> Self {
241 self.target_columns = columns;
242 self
243 }
244
245 pub fn with_skip_columns(mut self, columns: Vec<usize>) -> Self {
247 self.skip_columns = columns;
248 self
249 }
250
251 pub fn load<P: AsRef<Path>>(&self, path: P) -> TrainResult<Dataset> {
253 let file = File::open(path.as_ref())
254 .map_err(|e| TrainError::Other(format!("Failed to open CSV file: {}", e)))?;
255 let reader = BufReader::new(file);
256 let mut lines = reader.lines();
257
258 let mut feature_names = None;
259 let mut target_names = None;
260
261 if self.has_header {
263 if let Some(Ok(header)) = lines.next() {
264 let names: Vec<String> = header
265 .split(self.delimiter)
266 .map(|s| s.trim().to_string())
267 .collect();
268
269 let mut feat_names = Vec::new();
270 let mut targ_names = Vec::new();
271
272 for (i, name) in names.into_iter().enumerate() {
273 if self.skip_columns.contains(&i) {
274 continue;
275 }
276 if self.target_columns.contains(&i) {
277 targ_names.push(name);
278 } else {
279 feat_names.push(name);
280 }
281 }
282
283 feature_names = Some(feat_names);
284 target_names = Some(targ_names);
285 }
286 }
287
288 let mut features_data: Vec<Vec<f64>> = Vec::new();
290 let mut targets_data: Vec<Vec<f64>> = Vec::new();
291
292 for line_result in lines {
293 let line = line_result
294 .map_err(|e| TrainError::Other(format!("Failed to read CSV line: {}", e)))?;
295
296 if line.trim().is_empty() {
297 continue;
298 }
299
300 let values: Vec<&str> = line.split(self.delimiter).collect();
301 let mut row_features = Vec::new();
302 let mut row_targets = Vec::new();
303
304 for (i, value) in values.iter().enumerate() {
305 if self.skip_columns.contains(&i) {
306 continue;
307 }
308
309 let parsed: f64 = value.trim().parse().map_err(|e| {
310 TrainError::Other(format!("Failed to parse value '{}': {}", value, e))
311 })?;
312
313 if self.target_columns.contains(&i) {
314 row_targets.push(parsed);
315 } else {
316 row_features.push(parsed);
317 }
318 }
319
320 features_data.push(row_features);
321 targets_data.push(row_targets);
322 }
323
324 if features_data.is_empty() {
325 return Err(TrainError::Other("CSV file is empty".to_string()));
326 }
327
328 let n_samples = features_data.len();
329 let n_features = features_data[0].len();
330 let n_targets = if targets_data[0].is_empty() {
331 0
332 } else {
333 targets_data[0].len()
334 };
335
336 let features = Array2::from_shape_fn((n_samples, n_features), |(i, j)| features_data[i][j]);
338
339 let targets = if n_targets > 0 {
340 Array2::from_shape_fn((n_samples, n_targets), |(i, j)| targets_data[i][j])
341 } else {
342 Array2::zeros((n_samples, 1))
343 };
344
345 let mut dataset = Dataset::new(features, targets);
346 dataset.feature_names = feature_names;
347 dataset.target_names = target_names;
348
349 Ok(dataset)
350 }
351}
352
/// Fit/transform feature scaler supporting standardization and min-max
/// rescaling. Statistics are learned per column via `fit` and reused by
/// `transform` / `inverse_transform`.
#[derive(Debug, Clone)]
pub struct DataPreprocessor {
    // Which scaling strategy to apply.
    method: PreprocessingMethod,
    // Per-column statistics captured by `fit`; `None` until fitted.
    params: Option<PreprocessingParams>,
}
361
/// Scaling strategy used by [`DataPreprocessor`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PreprocessingMethod {
    // Rescale each column to zero mean and unit variance (z-score).
    Standardize,
    // Rescale each column to the range [0, 1].
    MinMaxNormalize,
    // Rescale each column to the integer-bounded range [min, max].
    MinMaxScale { min: i32, max: i32 },
    // Leave the data unchanged.
    None,
}
374
/// Per-column statistics computed by [`DataPreprocessor::fit`].
#[derive(Debug, Clone)]
struct PreprocessingParams {
    // Column means.
    means: Array1<f64>,
    // Column standard deviations, floored at 1e-8 to avoid division by zero.
    stds: Array1<f64>,
    // Column minima.
    mins: Array1<f64>,
    // Column maxima.
    maxs: Array1<f64>,
}
383
384impl DataPreprocessor {
385 pub fn standardize() -> Self {
387 Self {
388 method: PreprocessingMethod::Standardize,
389 params: None,
390 }
391 }
392
393 pub fn min_max_normalize() -> Self {
395 Self {
396 method: PreprocessingMethod::MinMaxNormalize,
397 params: None,
398 }
399 }
400
401 pub fn min_max_scale(min: i32, max: i32) -> Self {
403 Self {
404 method: PreprocessingMethod::MinMaxScale { min, max },
405 params: None,
406 }
407 }
408
409 pub fn none() -> Self {
411 Self {
412 method: PreprocessingMethod::None,
413 params: None,
414 }
415 }
416
417 pub fn fit(&mut self, data: &Array2<f64>) -> &mut Self {
419 let n_features = data.ncols();
420
421 let mut means = Array1::zeros(n_features);
422 let mut stds = Array1::zeros(n_features);
423 let mut mins = Array1::from_elem(n_features, f64::INFINITY);
424 let mut maxs = Array1::from_elem(n_features, f64::NEG_INFINITY);
425
426 for j in 0..n_features {
427 let col = data.column(j);
428 let n = col.len() as f64;
429
430 let mean: f64 = col.iter().sum::<f64>() / n;
432 means[j] = mean;
433
434 let variance: f64 = col.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n;
436 stds[j] = variance.sqrt().max(1e-8); for &x in col.iter() {
440 if x < mins[j] {
441 mins[j] = x;
442 }
443 if x > maxs[j] {
444 maxs[j] = x;
445 }
446 }
447 }
448
449 self.params = Some(PreprocessingParams {
450 means,
451 stds,
452 mins,
453 maxs,
454 });
455
456 self
457 }
458
459 pub fn transform(&self, data: &Array2<f64>) -> TrainResult<Array2<f64>> {
461 let params = self.params.as_ref().ok_or_else(|| {
462 TrainError::Other("Preprocessor not fitted. Call fit() first.".to_string())
463 })?;
464
465 let mut result = data.clone();
466
467 match self.method {
468 PreprocessingMethod::Standardize => {
469 for j in 0..data.ncols() {
470 for i in 0..data.nrows() {
471 result[[i, j]] = (data[[i, j]] - params.means[j]) / params.stds[j];
472 }
473 }
474 }
475 PreprocessingMethod::MinMaxNormalize => {
476 for j in 0..data.ncols() {
477 let range = (params.maxs[j] - params.mins[j]).max(1e-8);
478 for i in 0..data.nrows() {
479 result[[i, j]] = (data[[i, j]] - params.mins[j]) / range;
480 }
481 }
482 }
483 PreprocessingMethod::MinMaxScale { min, max } => {
484 let target_range = (max - min) as f64;
485 for j in 0..data.ncols() {
486 let range = (params.maxs[j] - params.mins[j]).max(1e-8);
487 for i in 0..data.nrows() {
488 let normalized = (data[[i, j]] - params.mins[j]) / range;
489 result[[i, j]] = normalized * target_range + min as f64;
490 }
491 }
492 }
493 PreprocessingMethod::None => {}
494 }
495
496 Ok(result)
497 }
498
499 pub fn fit_transform(&mut self, data: &Array2<f64>) -> TrainResult<Array2<f64>> {
501 self.fit(data);
502 self.transform(data)
503 }
504
505 pub fn inverse_transform(&self, data: &Array2<f64>) -> TrainResult<Array2<f64>> {
507 let params = self.params.as_ref().ok_or_else(|| {
508 TrainError::Other("Preprocessor not fitted. Call fit() first.".to_string())
509 })?;
510
511 let mut result = data.clone();
512
513 match self.method {
514 PreprocessingMethod::Standardize => {
515 for j in 0..data.ncols() {
516 for i in 0..data.nrows() {
517 result[[i, j]] = data[[i, j]] * params.stds[j] + params.means[j];
518 }
519 }
520 }
521 PreprocessingMethod::MinMaxNormalize => {
522 for j in 0..data.ncols() {
523 let range = params.maxs[j] - params.mins[j];
524 for i in 0..data.nrows() {
525 result[[i, j]] = data[[i, j]] * range + params.mins[j];
526 }
527 }
528 }
529 PreprocessingMethod::MinMaxScale { min, max } => {
530 let target_range = (max - min) as f64;
531 for j in 0..data.ncols() {
532 let range = params.maxs[j] - params.mins[j];
533 for i in 0..data.nrows() {
534 let normalized = (data[[i, j]] - min as f64) / target_range;
535 result[[i, j]] = normalized * range + params.mins[j];
536 }
537 }
538 }
539 PreprocessingMethod::None => {}
540 }
541
542 Ok(result)
543 }
544
545 pub fn is_fitted(&self) -> bool {
547 self.params.is_some()
548 }
549
550 pub fn method(&self) -> PreprocessingMethod {
552 self.method
553 }
554}
555
/// One-hot encoder for categorical string columns. Each fitted column gets a
/// stable category-to-index mapping (categories sorted lexicographically).
#[derive(Debug, Clone)]
pub struct OneHotEncoder {
    // Per-column mapping from category string to its one-hot index.
    categories: HashMap<usize, HashMap<String, usize>>,
    // Per-column number of distinct categories.
    n_categories: HashMap<usize, usize>,
}
564
565impl OneHotEncoder {
566 pub fn new() -> Self {
568 Self {
569 categories: HashMap::new(),
570 n_categories: HashMap::new(),
571 }
572 }
573
574 pub fn fit(&mut self, data: &[(usize, Vec<String>)]) -> &mut Self {
579 for (col_idx, values) in data {
580 let mut categories = HashMap::new();
581 let mut unique_values: Vec<&String> = values.iter().collect();
582 unique_values.sort();
583 unique_values.dedup();
584
585 for (i, value) in unique_values.into_iter().enumerate() {
586 categories.insert(value.clone(), i);
587 }
588
589 self.n_categories.insert(*col_idx, categories.len());
590 self.categories.insert(*col_idx, categories);
591 }
592
593 self
594 }
595
596 pub fn transform(&self, col_idx: usize, values: &[String]) -> TrainResult<Array2<f64>> {
598 let categories = self
599 .categories
600 .get(&col_idx)
601 .ok_or_else(|| TrainError::Other(format!("Column {} not fitted", col_idx)))?;
602
603 let n_samples = values.len();
604 let n_cats = *self.n_categories.get(&col_idx).unwrap();
605
606 let mut result = Array2::zeros((n_samples, n_cats));
607
608 for (i, value) in values.iter().enumerate() {
609 if let Some(&idx) = categories.get(value) {
610 result[[i, idx]] = 1.0;
611 } else {
612 return Err(TrainError::Other(format!(
613 "Unknown category '{}' for column {}",
614 value, col_idx
615 )));
616 }
617 }
618
619 Ok(result)
620 }
621
622 pub fn num_categories(&self, col_idx: usize) -> Option<usize> {
624 self.n_categories.get(&col_idx).copied()
625 }
626}
627
628impl Default for OneHotEncoder {
629 fn default() -> Self {
630 Self::new()
631 }
632}
633
/// Bidirectional mapping between string labels and dense integer class ids.
/// Ids are assigned in lexicographic label order by `fit`.
#[derive(Debug, Clone)]
pub struct LabelEncoder {
    // Label -> class id lookup.
    label_to_int: HashMap<String, usize>,
    // Class id -> label lookup (index is the id).
    int_to_label: Vec<String>,
}
642
643impl LabelEncoder {
644 pub fn new() -> Self {
646 Self {
647 label_to_int: HashMap::new(),
648 int_to_label: Vec::new(),
649 }
650 }
651
652 pub fn fit(&mut self, labels: &[String]) -> &mut Self {
654 let mut unique: Vec<&String> = labels.iter().collect();
655 unique.sort();
656 unique.dedup();
657
658 self.label_to_int.clear();
659 self.int_to_label.clear();
660
661 for (i, label) in unique.into_iter().enumerate() {
662 self.label_to_int.insert(label.clone(), i);
663 self.int_to_label.push(label.clone());
664 }
665
666 self
667 }
668
669 pub fn transform(&self, labels: &[String]) -> TrainResult<Array1<usize>> {
671 let mut result = Array1::zeros(labels.len());
672
673 for (i, label) in labels.iter().enumerate() {
674 result[i] = *self
675 .label_to_int
676 .get(label)
677 .ok_or_else(|| TrainError::Other(format!("Unknown label: {}", label)))?;
678 }
679
680 Ok(result)
681 }
682
683 pub fn inverse_transform(&self, indices: &Array1<usize>) -> TrainResult<Vec<String>> {
685 let mut result = Vec::with_capacity(indices.len());
686
687 for &idx in indices.iter() {
688 if idx >= self.int_to_label.len() {
689 return Err(TrainError::Other(format!(
690 "Index {} out of bounds for {} classes",
691 idx,
692 self.int_to_label.len()
693 )));
694 }
695 result.push(self.int_to_label[idx].clone());
696 }
697
698 Ok(result)
699 }
700
701 pub fn fit_transform(&mut self, labels: &[String]) -> TrainResult<Array1<usize>> {
703 self.fit(labels);
704 self.transform(labels)
705 }
706
707 pub fn num_classes(&self) -> usize {
709 self.int_to_label.len()
710 }
711
712 pub fn classes(&self) -> &[String] {
714 &self.int_to_label
715 }
716}
717
718impl Default for LabelEncoder {
719 fn default() -> Self {
720 Self::new()
721 }
722}
723
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly constructed dataset reports the expected dimensions.
    #[test]
    fn test_dataset_creation() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = Array2::from_shape_vec((3, 1), vec![0.0, 1.0, 0.0]).unwrap();

        let ds = Dataset::new(x, y);

        assert_eq!(ds.num_samples(), 3);
        assert_eq!(ds.num_features(), 2);
        assert_eq!(ds.num_targets(), 1);
    }

    // A 60/20/20 split of 10 rows yields blocks of 6, 2 and 2 samples.
    #[test]
    fn test_dataset_split() {
        let x = Array2::from_shape_fn((10, 2), |(i, j)| (i * 2 + j) as f64);
        let y = Array2::from_shape_fn((10, 1), |(i, _)| i as f64);

        let parts = Dataset::new(x, y).split(&[0.6, 0.2, 0.2]).unwrap();

        assert_eq!(parts.len(), 3);
        let sizes: Vec<usize> = parts.iter().map(|p| p.num_samples()).collect();
        assert_eq!(sizes, vec![6, 2, 2]);
    }

    // An 80% train ratio on 100 rows gives an 80/20 partition.
    #[test]
    fn test_train_test_split() {
        let x = Array2::from_shape_fn((100, 4), |(i, j)| (i * 4 + j) as f64);
        let y = Array2::from_shape_fn((100, 1), |(i, _)| (i % 2) as f64);

        let (train, test) = Dataset::new(x, y).train_test_split(0.8).unwrap();

        assert_eq!(train.num_samples(), 80);
        assert_eq!(test.num_samples(), 20);
    }

    // Shuffling with this seed must move at least one row.
    #[test]
    fn test_dataset_shuffle() {
        let x = Array2::from_shape_fn((10, 2), |(i, j)| (i * 2 + j) as f64);
        let y = Array2::from_shape_fn((10, 1), |(i, _)| i as f64);

        let mut ds = Dataset::new(x.clone(), y);
        ds.shuffle(42);

        let moved = (0..10).any(|i| ds.features[[i, 0]] != x[[i, 0]]);
        assert!(moved);
    }

    // Subsetting picks exactly the requested rows, in order.
    #[test]
    fn test_dataset_subset() {
        let x = Array2::from_shape_fn((10, 2), |(i, j)| (i * 2 + j) as f64);
        let y = Array2::from_shape_fn((10, 1), |(i, _)| i as f64);

        let picked = Dataset::new(x, y).subset(&[0, 2, 4]).unwrap();

        assert_eq!(picked.num_samples(), 3);
        assert_eq!(picked.features[[0, 0]], 0.0);
        assert_eq!(picked.features[[1, 0]], 4.0);
        assert_eq!(picked.features[[2, 0]], 8.0);
    }

    // Standardized columns have (near-)zero mean, and the transform is
    // exactly invertible.
    #[test]
    fn test_preprocessor_standardize() {
        let data =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();

        let mut scaler = DataPreprocessor::standardize();
        let scaled = scaler.fit_transform(&data).unwrap();

        for col in 0..2 {
            let mean: f64 = scaled.column(col).iter().sum::<f64>() / 4.0;
            assert!(mean.abs() < 1e-10);
        }

        let restored = scaler.inverse_transform(&scaled).unwrap();
        for (orig, back) in data.iter().zip(restored.iter()) {
            assert!((back - orig).abs() < 1e-10);
        }
    }

    // Min-max normalization maps every value into [0, 1], with the column
    // extremes hitting the endpoints exactly.
    #[test]
    fn test_preprocessor_min_max() {
        let data =
            Array2::from_shape_vec((4, 2), vec![0.0, 10.0, 5.0, 20.0, 10.0, 30.0, 15.0, 40.0])
                .unwrap();

        let mut scaler = DataPreprocessor::min_max_normalize();
        let scaled = scaler.fit_transform(&data).unwrap();

        assert!(scaled.iter().all(|v| (0.0..=1.0).contains(v)));
        assert!((scaled[[0, 0]] - 0.0).abs() < 1e-10);
        assert!((scaled[[3, 0]] - 1.0).abs() < 1e-10);
    }

    // Round-trips labels through integer ids and back.
    #[test]
    fn test_label_encoder() {
        let labels: Vec<String> = ["cat", "dog", "cat", "bird"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        let mut encoder = LabelEncoder::new();
        let ids = encoder.fit_transform(&labels).unwrap();

        assert_eq!(encoder.num_classes(), 3);
        assert_eq!(ids.len(), 4);
        // Identical labels encode to the same id.
        assert_eq!(ids[0], ids[2]);

        assert_eq!(encoder.inverse_transform(&ids).unwrap(), labels);
    }

    // Every one-hot row contains exactly one 1.0.
    #[test]
    fn test_one_hot_encoder() {
        let colors: Vec<String> = ["red", "green", "blue", "red"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        let mut encoder = OneHotEncoder::new();
        encoder.fit(&[(0, colors.clone())]);

        let encoded = encoder.transform(0, &colors).unwrap();

        assert_eq!(encoded.nrows(), 4);
        assert_eq!(encoded.ncols(), 3);

        for row in 0..4 {
            let total: f64 = encoded.row(row).iter().sum();
            assert!((total - 1.0).abs() < 1e-10);
        }
    }

    // Builder setters store the configured values.
    #[test]
    fn test_csv_loader_builder() {
        let loader = CsvLoader::new()
            .with_header(true)
            .with_delimiter(',')
            .with_target_columns(vec![3]);

        assert!(loader.has_header);
        assert_eq!(loader.delimiter, ',');
        assert_eq!(loader.target_columns, vec![3]);
    }

    // Ratios that do not sum to 1.0 are rejected.
    #[test]
    fn test_invalid_split_ratios() {
        let ds = Dataset::new(Array2::zeros((10, 2)), Array2::zeros((10, 1)));
        assert!(ds.split(&[0.5, 0.3]).is_err());
    }

    // transform before fit is an error.
    #[test]
    fn test_preprocessor_not_fitted() {
        let data = Array2::zeros((4, 2));
        let scaler = DataPreprocessor::standardize();
        assert!(scaler.transform(&data).is_err());
    }
}