1use crate::{TrainError, TrainResult};
10use scirs2_core::ndarray::{s, Array1, Array2};
11use std::collections::HashMap;
12use std::fs::File;
13use std::io::{BufRead, BufReader};
14use std::path::Path;
15
/// An in-memory supervised-learning dataset: a feature matrix paired
/// row-for-row with a target matrix.
#[derive(Debug, Clone)]
pub struct Dataset {
    /// Feature matrix, shape `(num_samples, num_features)`.
    pub features: Array2<f64>,
    /// Target matrix, shape `(num_samples, num_targets)`.
    pub targets: Array2<f64>,
    /// Optional names for the feature columns (one per column).
    pub feature_names: Option<Vec<String>>,
    /// Optional names for the target columns (one per column).
    pub target_names: Option<Vec<String>>,
}
28
29impl Dataset {
30 pub fn new(features: Array2<f64>, targets: Array2<f64>) -> Self {
32 Self {
33 features,
34 targets,
35 feature_names: None,
36 target_names: None,
37 }
38 }
39
40 pub fn with_feature_names(mut self, names: Vec<String>) -> Self {
42 self.feature_names = Some(names);
43 self
44 }
45
46 pub fn with_target_names(mut self, names: Vec<String>) -> Self {
48 self.target_names = Some(names);
49 self
50 }
51
52 pub fn num_samples(&self) -> usize {
54 self.features.nrows()
55 }
56
57 pub fn num_features(&self) -> usize {
59 self.features.ncols()
60 }
61
62 pub fn num_targets(&self) -> usize {
64 self.targets.ncols()
65 }
66
67 pub fn shuffle(&mut self, seed: u64) {
69 let n = self.num_samples();
70 if n <= 1 {
71 return;
72 }
73
74 let mut rng_state = seed;
76 let lcg_next = |state: &mut u64| -> usize {
77 *state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
78 (*state >> 33) as usize
79 };
80
81 for i in (1..n).rev() {
82 let j = lcg_next(&mut rng_state) % (i + 1);
83 for col in 0..self.features.ncols() {
85 let tmp = self.features[[i, col]];
86 self.features[[i, col]] = self.features[[j, col]];
87 self.features[[j, col]] = tmp;
88 }
89 for col in 0..self.targets.ncols() {
90 let tmp = self.targets[[i, col]];
91 self.targets[[i, col]] = self.targets[[j, col]];
92 self.targets[[j, col]] = tmp;
93 }
94 }
95 }
96
97 pub fn split(&self, ratios: &[f64]) -> TrainResult<Vec<Dataset>> {
105 let total: f64 = ratios.iter().sum();
106 if (total - 1.0).abs() > 1e-6 {
107 return Err(TrainError::ConfigError(format!(
108 "Split ratios must sum to 1.0, got {}",
109 total
110 )));
111 }
112
113 let n = self.num_samples();
114 let mut splits = Vec::new();
115 let mut start = 0;
116
117 for (i, &ratio) in ratios.iter().enumerate() {
118 let end = if i == ratios.len() - 1 {
119 n } else {
121 start + (n as f64 * ratio).round() as usize
122 };
123
124 let features = self.features.slice(s![start..end, ..]).to_owned();
125 let targets = self.targets.slice(s![start..end, ..]).to_owned();
126
127 let mut dataset = Dataset::new(features, targets);
128 if let Some(ref names) = self.feature_names {
129 dataset.feature_names = Some(names.clone());
130 }
131 if let Some(ref names) = self.target_names {
132 dataset.target_names = Some(names.clone());
133 }
134
135 splits.push(dataset);
136 start = end;
137 }
138
139 Ok(splits)
140 }
141
142 pub fn train_test_split(&self, train_ratio: f64) -> TrainResult<(Dataset, Dataset)> {
144 let splits = self.split(&[train_ratio, 1.0 - train_ratio])?;
145 let mut iter = splits.into_iter();
146 Ok((
147 iter.next().expect("split returns exactly 2 parts"),
148 iter.next().expect("split returns exactly 2 parts"),
149 ))
150 }
151
152 pub fn train_val_test_split(
154 &self,
155 train_ratio: f64,
156 val_ratio: f64,
157 ) -> TrainResult<(Dataset, Dataset, Dataset)> {
158 let test_ratio = 1.0 - train_ratio - val_ratio;
159 if test_ratio < 0.0 {
160 return Err(TrainError::ConfigError(
161 "Train and validation ratios exceed 1.0".to_string(),
162 ));
163 }
164 let splits = self.split(&[train_ratio, val_ratio, test_ratio])?;
165 let mut iter = splits.into_iter();
166 Ok((
167 iter.next().expect("split returns exactly 3 parts"),
168 iter.next().expect("split returns exactly 3 parts"),
169 iter.next().expect("split returns exactly 3 parts"),
170 ))
171 }
172
173 pub fn subset(&self, indices: &[usize]) -> TrainResult<Dataset> {
175 let n = self.num_samples();
176 for &idx in indices {
177 if idx >= n {
178 return Err(TrainError::ConfigError(format!(
179 "Index {} out of bounds for dataset with {} samples",
180 idx, n
181 )));
182 }
183 }
184
185 let features = Array2::from_shape_fn((indices.len(), self.num_features()), |(i, j)| {
186 self.features[[indices[i], j]]
187 });
188 let targets = Array2::from_shape_fn((indices.len(), self.num_targets()), |(i, j)| {
189 self.targets[[indices[i], j]]
190 });
191
192 let mut dataset = Dataset::new(features, targets);
193 dataset.feature_names = self.feature_names.clone();
194 dataset.target_names = self.target_names.clone();
195
196 Ok(dataset)
197 }
198}
199
/// Configurable loader that parses a purely numeric CSV file into a `Dataset`.
#[derive(Debug, Clone)]
pub struct CsvLoader {
    /// Whether the first line is a header row of column names.
    pub has_header: bool,
    /// Field delimiter character (default `','`).
    pub delimiter: char,
    /// 0-based indices of columns loaded as targets instead of features.
    pub target_columns: Vec<usize>,
    /// 0-based indices of columns ignored entirely.
    pub skip_columns: Vec<usize>,
}
212
213impl Default for CsvLoader {
214 fn default() -> Self {
215 Self {
216 has_header: true,
217 delimiter: ',',
218 target_columns: vec![],
219 skip_columns: vec![],
220 }
221 }
222}
223
224impl CsvLoader {
225 pub fn new() -> Self {
227 Self::default()
228 }
229
230 pub fn with_header(mut self, has_header: bool) -> Self {
232 self.has_header = has_header;
233 self
234 }
235
236 pub fn with_delimiter(mut self, delimiter: char) -> Self {
238 self.delimiter = delimiter;
239 self
240 }
241
242 pub fn with_target_columns(mut self, columns: Vec<usize>) -> Self {
244 self.target_columns = columns;
245 self
246 }
247
248 pub fn with_skip_columns(mut self, columns: Vec<usize>) -> Self {
250 self.skip_columns = columns;
251 self
252 }
253
254 pub fn load<P: AsRef<Path>>(&self, path: P) -> TrainResult<Dataset> {
256 let file = File::open(path.as_ref())
257 .map_err(|e| TrainError::Other(format!("Failed to open CSV file: {}", e)))?;
258 let reader = BufReader::new(file);
259 let mut lines = reader.lines();
260
261 let mut feature_names = None;
262 let mut target_names = None;
263
264 if self.has_header {
266 if let Some(Ok(header)) = lines.next() {
267 let names: Vec<String> = header
268 .split(self.delimiter)
269 .map(|s| s.trim().to_string())
270 .collect();
271
272 let mut feat_names = Vec::new();
273 let mut targ_names = Vec::new();
274
275 for (i, name) in names.into_iter().enumerate() {
276 if self.skip_columns.contains(&i) {
277 continue;
278 }
279 if self.target_columns.contains(&i) {
280 targ_names.push(name);
281 } else {
282 feat_names.push(name);
283 }
284 }
285
286 feature_names = Some(feat_names);
287 target_names = Some(targ_names);
288 }
289 }
290
291 let mut features_data: Vec<Vec<f64>> = Vec::new();
293 let mut targets_data: Vec<Vec<f64>> = Vec::new();
294
295 for line_result in lines {
296 let line = line_result
297 .map_err(|e| TrainError::Other(format!("Failed to read CSV line: {}", e)))?;
298
299 if line.trim().is_empty() {
300 continue;
301 }
302
303 let values: Vec<&str> = line.split(self.delimiter).collect();
304 let mut row_features = Vec::new();
305 let mut row_targets = Vec::new();
306
307 for (i, value) in values.iter().enumerate() {
308 if self.skip_columns.contains(&i) {
309 continue;
310 }
311
312 let parsed: f64 = value.trim().parse().map_err(|e| {
313 TrainError::Other(format!("Failed to parse value '{}': {}", value, e))
314 })?;
315
316 if self.target_columns.contains(&i) {
317 row_targets.push(parsed);
318 } else {
319 row_features.push(parsed);
320 }
321 }
322
323 features_data.push(row_features);
324 targets_data.push(row_targets);
325 }
326
327 if features_data.is_empty() {
328 return Err(TrainError::Other("CSV file is empty".to_string()));
329 }
330
331 let n_samples = features_data.len();
332 let n_features = features_data[0].len();
333 let n_targets = if targets_data[0].is_empty() {
334 0
335 } else {
336 targets_data[0].len()
337 };
338
339 let features = Array2::from_shape_fn((n_samples, n_features), |(i, j)| features_data[i][j]);
341
342 let targets = if n_targets > 0 {
343 Array2::from_shape_fn((n_samples, n_targets), |(i, j)| targets_data[i][j])
344 } else {
345 Array2::zeros((n_samples, 1))
346 };
347
348 let mut dataset = Dataset::new(features, targets);
349 dataset.feature_names = feature_names;
350 dataset.target_names = target_names;
351
352 Ok(dataset)
353 }
354}
355
/// Fit/transform feature scaler supporting standardization and min-max scaling.
#[derive(Debug, Clone)]
pub struct DataPreprocessor {
    /// Which scaling strategy `transform`/`inverse_transform` apply.
    method: PreprocessingMethod,
    /// Per-column statistics captured by `fit()`; `None` until fitted.
    params: Option<PreprocessingParams>,
}
364
/// Scaling strategy applied by `DataPreprocessor`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PreprocessingMethod {
    /// Rescale each column to zero mean and unit variance.
    Standardize,
    /// Rescale each column to the range [0, 1].
    MinMaxNormalize,
    /// Rescale each column to the integer-bounded range [min, max].
    MinMaxScale { min: i32, max: i32 },
    /// Pass data through unchanged.
    None,
}
377
/// Per-column statistics captured during fitting.
#[derive(Debug, Clone)]
struct PreprocessingParams {
    /// Column means.
    means: Array1<f64>,
    /// Column standard deviations (floored at 1e-8 during fit).
    stds: Array1<f64>,
    /// Column minima.
    mins: Array1<f64>,
    /// Column maxima.
    maxs: Array1<f64>,
}
386
387impl DataPreprocessor {
388 pub fn standardize() -> Self {
390 Self {
391 method: PreprocessingMethod::Standardize,
392 params: None,
393 }
394 }
395
396 pub fn min_max_normalize() -> Self {
398 Self {
399 method: PreprocessingMethod::MinMaxNormalize,
400 params: None,
401 }
402 }
403
404 pub fn min_max_scale(min: i32, max: i32) -> Self {
406 Self {
407 method: PreprocessingMethod::MinMaxScale { min, max },
408 params: None,
409 }
410 }
411
412 pub fn none() -> Self {
414 Self {
415 method: PreprocessingMethod::None,
416 params: None,
417 }
418 }
419
420 pub fn fit(&mut self, data: &Array2<f64>) -> &mut Self {
422 let n_features = data.ncols();
423
424 let mut means = Array1::zeros(n_features);
425 let mut stds = Array1::zeros(n_features);
426 let mut mins = Array1::from_elem(n_features, f64::INFINITY);
427 let mut maxs = Array1::from_elem(n_features, f64::NEG_INFINITY);
428
429 for j in 0..n_features {
430 let col = data.column(j);
431 let n = col.len() as f64;
432
433 let mean: f64 = col.iter().sum::<f64>() / n;
435 means[j] = mean;
436
437 let variance: f64 = col.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n;
439 stds[j] = variance.sqrt().max(1e-8); for &x in col.iter() {
443 if x < mins[j] {
444 mins[j] = x;
445 }
446 if x > maxs[j] {
447 maxs[j] = x;
448 }
449 }
450 }
451
452 self.params = Some(PreprocessingParams {
453 means,
454 stds,
455 mins,
456 maxs,
457 });
458
459 self
460 }
461
462 pub fn transform(&self, data: &Array2<f64>) -> TrainResult<Array2<f64>> {
464 let params = self.params.as_ref().ok_or_else(|| {
465 TrainError::Other("Preprocessor not fitted. Call fit() first.".to_string())
466 })?;
467
468 let mut result = data.clone();
469
470 match self.method {
471 PreprocessingMethod::Standardize => {
472 for j in 0..data.ncols() {
473 for i in 0..data.nrows() {
474 result[[i, j]] = (data[[i, j]] - params.means[j]) / params.stds[j];
475 }
476 }
477 }
478 PreprocessingMethod::MinMaxNormalize => {
479 for j in 0..data.ncols() {
480 let range = (params.maxs[j] - params.mins[j]).max(1e-8);
481 for i in 0..data.nrows() {
482 result[[i, j]] = (data[[i, j]] - params.mins[j]) / range;
483 }
484 }
485 }
486 PreprocessingMethod::MinMaxScale { min, max } => {
487 let target_range = (max - min) as f64;
488 for j in 0..data.ncols() {
489 let range = (params.maxs[j] - params.mins[j]).max(1e-8);
490 for i in 0..data.nrows() {
491 let normalized = (data[[i, j]] - params.mins[j]) / range;
492 result[[i, j]] = normalized * target_range + min as f64;
493 }
494 }
495 }
496 PreprocessingMethod::None => {}
497 }
498
499 Ok(result)
500 }
501
502 pub fn fit_transform(&mut self, data: &Array2<f64>) -> TrainResult<Array2<f64>> {
504 self.fit(data);
505 self.transform(data)
506 }
507
508 pub fn inverse_transform(&self, data: &Array2<f64>) -> TrainResult<Array2<f64>> {
510 let params = self.params.as_ref().ok_or_else(|| {
511 TrainError::Other("Preprocessor not fitted. Call fit() first.".to_string())
512 })?;
513
514 let mut result = data.clone();
515
516 match self.method {
517 PreprocessingMethod::Standardize => {
518 for j in 0..data.ncols() {
519 for i in 0..data.nrows() {
520 result[[i, j]] = data[[i, j]] * params.stds[j] + params.means[j];
521 }
522 }
523 }
524 PreprocessingMethod::MinMaxNormalize => {
525 for j in 0..data.ncols() {
526 let range = params.maxs[j] - params.mins[j];
527 for i in 0..data.nrows() {
528 result[[i, j]] = data[[i, j]] * range + params.mins[j];
529 }
530 }
531 }
532 PreprocessingMethod::MinMaxScale { min, max } => {
533 let target_range = (max - min) as f64;
534 for j in 0..data.ncols() {
535 let range = params.maxs[j] - params.mins[j];
536 for i in 0..data.nrows() {
537 let normalized = (data[[i, j]] - min as f64) / target_range;
538 result[[i, j]] = normalized * range + params.mins[j];
539 }
540 }
541 }
542 PreprocessingMethod::None => {}
543 }
544
545 Ok(result)
546 }
547
548 pub fn is_fitted(&self) -> bool {
550 self.params.is_some()
551 }
552
553 pub fn method(&self) -> PreprocessingMethod {
555 self.method
556 }
557}
558
/// One-hot encoder mapping string categories to 0/1 indicator columns,
/// fitted independently per column index.
#[derive(Debug, Clone)]
pub struct OneHotEncoder {
    /// column index -> (category string -> category position).
    categories: HashMap<usize, HashMap<String, usize>>,
    /// column index -> number of distinct categories for that column.
    n_categories: HashMap<usize, usize>,
}
567
568impl OneHotEncoder {
569 pub fn new() -> Self {
571 Self {
572 categories: HashMap::new(),
573 n_categories: HashMap::new(),
574 }
575 }
576
577 pub fn fit(&mut self, data: &[(usize, Vec<String>)]) -> &mut Self {
582 for (col_idx, values) in data {
583 let mut categories = HashMap::new();
584 let mut unique_values: Vec<&String> = values.iter().collect();
585 unique_values.sort();
586 unique_values.dedup();
587
588 for (i, value) in unique_values.into_iter().enumerate() {
589 categories.insert(value.clone(), i);
590 }
591
592 self.n_categories.insert(*col_idx, categories.len());
593 self.categories.insert(*col_idx, categories);
594 }
595
596 self
597 }
598
599 pub fn transform(&self, col_idx: usize, values: &[String]) -> TrainResult<Array2<f64>> {
601 let categories = self
602 .categories
603 .get(&col_idx)
604 .ok_or_else(|| TrainError::Other(format!("Column {} not fitted", col_idx)))?;
605
606 let n_samples = values.len();
607 let n_cats = *self
608 .n_categories
609 .get(&col_idx)
610 .expect("n_categories populated during fit for all fitted columns");
611
612 let mut result = Array2::zeros((n_samples, n_cats));
613
614 for (i, value) in values.iter().enumerate() {
615 if let Some(&idx) = categories.get(value) {
616 result[[i, idx]] = 1.0;
617 } else {
618 return Err(TrainError::Other(format!(
619 "Unknown category '{}' for column {}",
620 value, col_idx
621 )));
622 }
623 }
624
625 Ok(result)
626 }
627
628 pub fn num_categories(&self, col_idx: usize) -> Option<usize> {
630 self.n_categories.get(&col_idx).copied()
631 }
632}
633
impl Default for OneHotEncoder {
    /// Equivalent to [`OneHotEncoder::new`].
    fn default() -> Self {
        Self::new()
    }
}
639
/// Encoder mapping string labels to dense integer indices (assigned in
/// sorted label order during fitting).
#[derive(Debug, Clone)]
pub struct LabelEncoder {
    /// label -> integer index.
    label_to_int: HashMap<String, usize>,
    /// index -> label; its length defines the number of classes.
    int_to_label: Vec<String>,
}
648
649impl LabelEncoder {
650 pub fn new() -> Self {
652 Self {
653 label_to_int: HashMap::new(),
654 int_to_label: Vec::new(),
655 }
656 }
657
658 pub fn fit(&mut self, labels: &[String]) -> &mut Self {
660 let mut unique: Vec<&String> = labels.iter().collect();
661 unique.sort();
662 unique.dedup();
663
664 self.label_to_int.clear();
665 self.int_to_label.clear();
666
667 for (i, label) in unique.into_iter().enumerate() {
668 self.label_to_int.insert(label.clone(), i);
669 self.int_to_label.push(label.clone());
670 }
671
672 self
673 }
674
675 pub fn transform(&self, labels: &[String]) -> TrainResult<Array1<usize>> {
677 let mut result = Array1::zeros(labels.len());
678
679 for (i, label) in labels.iter().enumerate() {
680 result[i] = *self
681 .label_to_int
682 .get(label)
683 .ok_or_else(|| TrainError::Other(format!("Unknown label: {}", label)))?;
684 }
685
686 Ok(result)
687 }
688
689 pub fn inverse_transform(&self, indices: &Array1<usize>) -> TrainResult<Vec<String>> {
691 let mut result = Vec::with_capacity(indices.len());
692
693 for &idx in indices.iter() {
694 if idx >= self.int_to_label.len() {
695 return Err(TrainError::Other(format!(
696 "Index {} out of bounds for {} classes",
697 idx,
698 self.int_to_label.len()
699 )));
700 }
701 result.push(self.int_to_label[idx].clone());
702 }
703
704 Ok(result)
705 }
706
707 pub fn fit_transform(&mut self, labels: &[String]) -> TrainResult<Array1<usize>> {
709 self.fit(labels);
710 self.transform(labels)
711 }
712
713 pub fn num_classes(&self) -> usize {
715 self.int_to_label.len()
716 }
717
718 pub fn classes(&self) -> &[String] {
720 &self.int_to_label
721 }
722}
723
impl Default for LabelEncoder {
    /// Equivalent to [`LabelEncoder::new`].
    fn default() -> Self {
        Self::new()
    }
}
729
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dataset_creation() {
        // 3 samples x 2 features with a single target column.
        let features =
            Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("unwrap");
        let targets = Array2::from_shape_vec((3, 1), vec![0.0, 1.0, 0.0]).expect("unwrap");

        let dataset = Dataset::new(features, targets);

        assert_eq!(dataset.num_samples(), 3);
        assert_eq!(dataset.num_features(), 2);
        assert_eq!(dataset.num_targets(), 1);
    }

    #[test]
    fn test_dataset_split() {
        let features = Array2::from_shape_fn((10, 2), |(i, j)| (i * 2 + j) as f64);
        let targets = Array2::from_shape_fn((10, 1), |(i, _)| i as f64);

        let dataset = Dataset::new(features, targets);
        // 60/20/20 split of 10 samples -> 6/2/2 rows.
        let splits = dataset.split(&[0.6, 0.2, 0.2]).expect("unwrap");

        assert_eq!(splits.len(), 3);
        assert_eq!(splits[0].num_samples(), 6);
        assert_eq!(splits[1].num_samples(), 2);
        assert_eq!(splits[2].num_samples(), 2);
    }

    #[test]
    fn test_train_test_split() {
        let features = Array2::from_shape_fn((100, 4), |(i, j)| (i * 4 + j) as f64);
        let targets = Array2::from_shape_fn((100, 1), |(i, _)| (i % 2) as f64);

        let dataset = Dataset::new(features, targets);
        let (train, test) = dataset.train_test_split(0.8).expect("unwrap");

        assert_eq!(train.num_samples(), 80);
        assert_eq!(test.num_samples(), 20);
    }

    #[test]
    fn test_dataset_shuffle() {
        let features = Array2::from_shape_fn((10, 2), |(i, j)| (i * 2 + j) as f64);
        let targets = Array2::from_shape_fn((10, 1), |(i, _)| i as f64);

        let mut dataset = Dataset::new(features.clone(), targets);
        // The shuffle is deterministic for a given seed; with seed 42 at
        // least one row is known to move.
        dataset.shuffle(42);

        let mut different = false;
        for i in 0..10 {
            if dataset.features[[i, 0]] != features[[i, 0]] {
                different = true;
                break;
            }
        }
        assert!(different);
    }

    #[test]
    fn test_dataset_subset() {
        let features = Array2::from_shape_fn((10, 2), |(i, j)| (i * 2 + j) as f64);
        let targets = Array2::from_shape_fn((10, 1), |(i, _)| i as f64);

        let dataset = Dataset::new(features, targets);
        // Rows 0, 2, 4 in order; feature[i][0] == 2*i in the source data.
        let subset = dataset.subset(&[0, 2, 4]).expect("unwrap");

        assert_eq!(subset.num_samples(), 3);
        assert_eq!(subset.features[[0, 0]], 0.0);
        assert_eq!(subset.features[[1, 0]], 4.0);
        assert_eq!(subset.features[[2, 0]], 8.0);
    }

    #[test]
    fn test_preprocessor_standardize() {
        let data = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
            .expect("unwrap");

        let mut preprocessor = DataPreprocessor::standardize();
        let transformed = preprocessor.fit_transform(&data).expect("unwrap");

        // Standardized columns must have (numerically) zero mean.
        let col0_mean: f64 = transformed.column(0).iter().sum::<f64>() / 4.0;
        let col1_mean: f64 = transformed.column(1).iter().sum::<f64>() / 4.0;

        assert!(col0_mean.abs() < 1e-10);
        assert!(col1_mean.abs() < 1e-10);

        // inverse_transform must recover the original data exactly (to fp
        // tolerance).
        let recovered = preprocessor
            .inverse_transform(&transformed)
            .expect("unwrap");
        for i in 0..4 {
            for j in 0..2 {
                assert!((recovered[[i, j]] - data[[i, j]]).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_preprocessor_min_max() {
        let data =
            Array2::from_shape_vec((4, 2), vec![0.0, 10.0, 5.0, 20.0, 10.0, 30.0, 15.0, 40.0])
                .expect("unwrap");

        let mut preprocessor = DataPreprocessor::min_max_normalize();
        let transformed = preprocessor.fit_transform(&data).expect("unwrap");

        // All normalized values land in [0, 1].
        for &val in transformed.iter() {
            assert!((0.0..=1.0).contains(&val));
        }

        // Column minimum maps to 0.0 and maximum to 1.0.
        assert!((transformed[[0, 0]] - 0.0).abs() < 1e-10);
        assert!((transformed[[3, 0]] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_label_encoder() {
        let labels = vec![
            "cat".to_string(),
            "dog".to_string(),
            "cat".to_string(),
            "bird".to_string(),
        ];

        let mut encoder = LabelEncoder::new();
        let encoded = encoder.fit_transform(&labels).expect("unwrap");

        assert_eq!(encoder.num_classes(), 3);
        assert_eq!(encoded.len(), 4);

        // Identical labels encode to the same index.
        assert_eq!(encoded[0], encoded[2]);

        // Round-trip through inverse_transform restores the input labels.
        let decoded = encoder.inverse_transform(&encoded).expect("unwrap");
        assert_eq!(decoded, labels);
    }

    #[test]
    fn test_one_hot_encoder() {
        let values = vec![
            "red".to_string(),
            "green".to_string(),
            "blue".to_string(),
            "red".to_string(),
        ];

        let mut encoder = OneHotEncoder::new();
        encoder.fit(&[(0, values.clone())]);

        let encoded = encoder.transform(0, &values).expect("unwrap");

        // 4 samples, 3 distinct categories.
        assert_eq!(encoded.nrows(), 4);
        assert_eq!(encoded.ncols(), 3);

        // Each one-hot row has exactly one 1.0.
        for i in 0..4 {
            let row_sum: f64 = encoded.row(i).iter().sum();
            assert!((row_sum - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_csv_loader_builder() {
        let loader = CsvLoader::new()
            .with_header(true)
            .with_delimiter(',')
            .with_target_columns(vec![3]);

        assert!(loader.has_header);
        assert_eq!(loader.delimiter, ',');
        assert_eq!(loader.target_columns, vec![3]);
    }

    #[test]
    fn test_invalid_split_ratios() {
        let features = Array2::zeros((10, 2));
        let targets = Array2::zeros((10, 1));
        let dataset = Dataset::new(features, targets);

        // Ratios sum to 0.8, not 1.0, so split must reject them.
        let result = dataset.split(&[0.5, 0.3]);
        assert!(result.is_err());
    }

    #[test]
    fn test_preprocessor_not_fitted() {
        let data = Array2::zeros((4, 2));
        let preprocessor = DataPreprocessor::standardize();

        // transform() before fit() must fail, not panic.
        let result = preprocessor.transform(&data);
        assert!(result.is_err());
    }
}