// tensorlogic_train/data.rs
1//! Data loading and preprocessing utilities for training.
2//!
3//! This module provides tools for loading and preprocessing training data:
4//! - CSV and JSON data loading
5//! - Data normalization and standardization
6//! - Train/validation/test splitting
7//! - Data shuffling and sampling
8
9use crate::{TrainError, TrainResult};
10use scirs2_core::ndarray::{s, Array1, Array2};
11use std::collections::HashMap;
12use std::fs::File;
13use std::io::{BufRead, BufReader};
14use std::path::Path;
15
/// Dataset container for training data.
///
/// Holds a dense feature matrix alongside its targets, plus optional
/// human-readable names for the feature and target columns.
#[derive(Debug, Clone)]
pub struct Dataset {
    /// Feature matrix (samples x features).
    pub features: Array2<f64>,
    /// Target vector or matrix (samples x targets).
    pub targets: Array2<f64>,
    /// Feature names (if available), one per feature column.
    pub feature_names: Option<Vec<String>>,
    /// Target names (if available), one per target column.
    pub target_names: Option<Vec<String>>,
}
28
29impl Dataset {
30    /// Create a new dataset.
31    pub fn new(features: Array2<f64>, targets: Array2<f64>) -> Self {
32        Self {
33            features,
34            targets,
35            feature_names: None,
36            target_names: None,
37        }
38    }
39
40    /// Set feature names.
41    pub fn with_feature_names(mut self, names: Vec<String>) -> Self {
42        self.feature_names = Some(names);
43        self
44    }
45
46    /// Set target names.
47    pub fn with_target_names(mut self, names: Vec<String>) -> Self {
48        self.target_names = Some(names);
49        self
50    }
51
52    /// Get number of samples.
53    pub fn num_samples(&self) -> usize {
54        self.features.nrows()
55    }
56
57    /// Get number of features.
58    pub fn num_features(&self) -> usize {
59        self.features.ncols()
60    }
61
62    /// Get number of targets.
63    pub fn num_targets(&self) -> usize {
64        self.targets.ncols()
65    }
66
67    /// Shuffle the dataset in place using Fisher-Yates algorithm.
68    pub fn shuffle(&mut self, seed: u64) {
69        let n = self.num_samples();
70        if n <= 1 {
71            return;
72        }
73
74        // Simple LCG for deterministic shuffling
75        let mut rng_state = seed;
76        let lcg_next = |state: &mut u64| -> usize {
77            *state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
78            (*state >> 33) as usize
79        };
80
81        for i in (1..n).rev() {
82            let j = lcg_next(&mut rng_state) % (i + 1);
83            // Swap rows
84            for col in 0..self.features.ncols() {
85                let tmp = self.features[[i, col]];
86                self.features[[i, col]] = self.features[[j, col]];
87                self.features[[j, col]] = tmp;
88            }
89            for col in 0..self.targets.ncols() {
90                let tmp = self.targets[[i, col]];
91                self.targets[[i, col]] = self.targets[[j, col]];
92                self.targets[[j, col]] = tmp;
93            }
94        }
95    }
96
97    /// Split dataset into subsets.
98    ///
99    /// # Arguments
100    /// * `ratios` - Ratios for each split (must sum to 1.0)
101    ///
102    /// # Returns
103    /// Vector of datasets corresponding to each ratio
104    pub fn split(&self, ratios: &[f64]) -> TrainResult<Vec<Dataset>> {
105        let total: f64 = ratios.iter().sum();
106        if (total - 1.0).abs() > 1e-6 {
107            return Err(TrainError::ConfigError(format!(
108                "Split ratios must sum to 1.0, got {}",
109                total
110            )));
111        }
112
113        let n = self.num_samples();
114        let mut splits = Vec::new();
115        let mut start = 0;
116
117        for (i, &ratio) in ratios.iter().enumerate() {
118            let end = if i == ratios.len() - 1 {
119                n // Last split gets remaining samples
120            } else {
121                start + (n as f64 * ratio).round() as usize
122            };
123
124            let features = self.features.slice(s![start..end, ..]).to_owned();
125            let targets = self.targets.slice(s![start..end, ..]).to_owned();
126
127            let mut dataset = Dataset::new(features, targets);
128            if let Some(ref names) = self.feature_names {
129                dataset.feature_names = Some(names.clone());
130            }
131            if let Some(ref names) = self.target_names {
132                dataset.target_names = Some(names.clone());
133            }
134
135            splits.push(dataset);
136            start = end;
137        }
138
139        Ok(splits)
140    }
141
142    /// Split into train and test sets.
143    pub fn train_test_split(&self, train_ratio: f64) -> TrainResult<(Dataset, Dataset)> {
144        let splits = self.split(&[train_ratio, 1.0 - train_ratio])?;
145        let mut iter = splits.into_iter();
146        Ok((
147            iter.next().expect("split returns exactly 2 parts"),
148            iter.next().expect("split returns exactly 2 parts"),
149        ))
150    }
151
152    /// Split into train, validation, and test sets.
153    pub fn train_val_test_split(
154        &self,
155        train_ratio: f64,
156        val_ratio: f64,
157    ) -> TrainResult<(Dataset, Dataset, Dataset)> {
158        let test_ratio = 1.0 - train_ratio - val_ratio;
159        if test_ratio < 0.0 {
160            return Err(TrainError::ConfigError(
161                "Train and validation ratios exceed 1.0".to_string(),
162            ));
163        }
164        let splits = self.split(&[train_ratio, val_ratio, test_ratio])?;
165        let mut iter = splits.into_iter();
166        Ok((
167            iter.next().expect("split returns exactly 3 parts"),
168            iter.next().expect("split returns exactly 3 parts"),
169            iter.next().expect("split returns exactly 3 parts"),
170        ))
171    }
172
173    /// Get a subset of the dataset by indices.
174    pub fn subset(&self, indices: &[usize]) -> TrainResult<Dataset> {
175        let n = self.num_samples();
176        for &idx in indices {
177            if idx >= n {
178                return Err(TrainError::ConfigError(format!(
179                    "Index {} out of bounds for dataset with {} samples",
180                    idx, n
181                )));
182            }
183        }
184
185        let features = Array2::from_shape_fn((indices.len(), self.num_features()), |(i, j)| {
186            self.features[[indices[i], j]]
187        });
188        let targets = Array2::from_shape_fn((indices.len(), self.num_targets()), |(i, j)| {
189            self.targets[[indices[i], j]]
190        });
191
192        let mut dataset = Dataset::new(features, targets);
193        dataset.feature_names = self.feature_names.clone();
194        dataset.target_names = self.target_names.clone();
195
196        Ok(dataset)
197    }
198}
199
/// CSV data loader.
///
/// Configured via builder-style `with_*` methods; see [`CsvLoader::load`].
#[derive(Debug, Clone)]
pub struct CsvLoader {
    /// Whether the CSV has a header row.
    pub has_header: bool,
    /// Delimiter character.
    pub delimiter: char,
    /// Indices of target columns (0-based, counted before any skipping).
    pub target_columns: Vec<usize>,
    /// Columns to skip entirely (0-based).
    pub skip_columns: Vec<usize>,
}
212
213impl Default for CsvLoader {
214    fn default() -> Self {
215        Self {
216            has_header: true,
217            delimiter: ',',
218            target_columns: vec![],
219            skip_columns: vec![],
220        }
221    }
222}
223
224impl CsvLoader {
225    /// Create a new CSV loader.
226    pub fn new() -> Self {
227        Self::default()
228    }
229
230    /// Set whether CSV has a header.
231    pub fn with_header(mut self, has_header: bool) -> Self {
232        self.has_header = has_header;
233        self
234    }
235
236    /// Set the delimiter character.
237    pub fn with_delimiter(mut self, delimiter: char) -> Self {
238        self.delimiter = delimiter;
239        self
240    }
241
242    /// Set target column indices.
243    pub fn with_target_columns(mut self, columns: Vec<usize>) -> Self {
244        self.target_columns = columns;
245        self
246    }
247
248    /// Set columns to skip.
249    pub fn with_skip_columns(mut self, columns: Vec<usize>) -> Self {
250        self.skip_columns = columns;
251        self
252    }
253
254    /// Load data from a CSV file.
255    pub fn load<P: AsRef<Path>>(&self, path: P) -> TrainResult<Dataset> {
256        let file = File::open(path.as_ref())
257            .map_err(|e| TrainError::Other(format!("Failed to open CSV file: {}", e)))?;
258        let reader = BufReader::new(file);
259        let mut lines = reader.lines();
260
261        let mut feature_names = None;
262        let mut target_names = None;
263
264        // Parse header if present
265        if self.has_header {
266            if let Some(Ok(header)) = lines.next() {
267                let names: Vec<String> = header
268                    .split(self.delimiter)
269                    .map(|s| s.trim().to_string())
270                    .collect();
271
272                let mut feat_names = Vec::new();
273                let mut targ_names = Vec::new();
274
275                for (i, name) in names.into_iter().enumerate() {
276                    if self.skip_columns.contains(&i) {
277                        continue;
278                    }
279                    if self.target_columns.contains(&i) {
280                        targ_names.push(name);
281                    } else {
282                        feat_names.push(name);
283                    }
284                }
285
286                feature_names = Some(feat_names);
287                target_names = Some(targ_names);
288            }
289        }
290
291        // Parse data rows
292        let mut features_data: Vec<Vec<f64>> = Vec::new();
293        let mut targets_data: Vec<Vec<f64>> = Vec::new();
294
295        for line_result in lines {
296            let line = line_result
297                .map_err(|e| TrainError::Other(format!("Failed to read CSV line: {}", e)))?;
298
299            if line.trim().is_empty() {
300                continue;
301            }
302
303            let values: Vec<&str> = line.split(self.delimiter).collect();
304            let mut row_features = Vec::new();
305            let mut row_targets = Vec::new();
306
307            for (i, value) in values.iter().enumerate() {
308                if self.skip_columns.contains(&i) {
309                    continue;
310                }
311
312                let parsed: f64 = value.trim().parse().map_err(|e| {
313                    TrainError::Other(format!("Failed to parse value '{}': {}", value, e))
314                })?;
315
316                if self.target_columns.contains(&i) {
317                    row_targets.push(parsed);
318                } else {
319                    row_features.push(parsed);
320                }
321            }
322
323            features_data.push(row_features);
324            targets_data.push(row_targets);
325        }
326
327        if features_data.is_empty() {
328            return Err(TrainError::Other("CSV file is empty".to_string()));
329        }
330
331        let n_samples = features_data.len();
332        let n_features = features_data[0].len();
333        let n_targets = if targets_data[0].is_empty() {
334            0
335        } else {
336            targets_data[0].len()
337        };
338
339        // Convert to arrays
340        let features = Array2::from_shape_fn((n_samples, n_features), |(i, j)| features_data[i][j]);
341
342        let targets = if n_targets > 0 {
343            Array2::from_shape_fn((n_samples, n_targets), |(i, j)| targets_data[i][j])
344        } else {
345            Array2::zeros((n_samples, 1))
346        };
347
348        let mut dataset = Dataset::new(features, targets);
349        dataset.feature_names = feature_names;
350        dataset.target_names = target_names;
351
352        Ok(dataset)
353    }
354}
355
/// Data preprocessor for normalization and standardization.
///
/// Follows the fit/transform pattern: call `fit` (or `fit_transform`) to
/// learn per-column statistics, then `transform`/`inverse_transform`.
#[derive(Debug, Clone)]
pub struct DataPreprocessor {
    /// Preprocessing method.
    method: PreprocessingMethod,
    /// Fitted parameters (mean, std, min, max); `None` until `fit` runs.
    params: Option<PreprocessingParams>,
}
364
/// Preprocessing method.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PreprocessingMethod {
    /// Standardization (zero mean, unit variance).
    Standardize,
    /// Min-max normalization to [0, 1].
    MinMaxNormalize,
    /// Min-max scaling to a custom range `[min, max]`.
    // NOTE(review): bounds are `i32`, so only integer-valued target ranges
    // are expressible — confirm whether `f64` bounds were intended.
    MinMaxScale { min: i32, max: i32 },
    /// No preprocessing.
    None,
}
377
/// Fitted preprocessing parameters, one entry per feature column.
#[derive(Debug, Clone)]
struct PreprocessingParams {
    /// Per-column means.
    means: Array1<f64>,
    /// Per-column standard deviations (floored at 1e-8 during fitting).
    stds: Array1<f64>,
    /// Per-column minimums.
    mins: Array1<f64>,
    /// Per-column maximums.
    maxs: Array1<f64>,
}
386
387impl DataPreprocessor {
388    /// Create a new preprocessor with standardization.
389    pub fn standardize() -> Self {
390        Self {
391            method: PreprocessingMethod::Standardize,
392            params: None,
393        }
394    }
395
396    /// Create a new preprocessor with min-max normalization.
397    pub fn min_max_normalize() -> Self {
398        Self {
399            method: PreprocessingMethod::MinMaxNormalize,
400            params: None,
401        }
402    }
403
404    /// Create a new preprocessor with custom min-max scaling.
405    pub fn min_max_scale(min: i32, max: i32) -> Self {
406        Self {
407            method: PreprocessingMethod::MinMaxScale { min, max },
408            params: None,
409        }
410    }
411
412    /// Create a preprocessor that does nothing.
413    pub fn none() -> Self {
414        Self {
415            method: PreprocessingMethod::None,
416            params: None,
417        }
418    }
419
420    /// Fit the preprocessor to data.
421    pub fn fit(&mut self, data: &Array2<f64>) -> &mut Self {
422        let n_features = data.ncols();
423
424        let mut means = Array1::zeros(n_features);
425        let mut stds = Array1::zeros(n_features);
426        let mut mins = Array1::from_elem(n_features, f64::INFINITY);
427        let mut maxs = Array1::from_elem(n_features, f64::NEG_INFINITY);
428
429        for j in 0..n_features {
430            let col = data.column(j);
431            let n = col.len() as f64;
432
433            // Compute mean
434            let mean: f64 = col.iter().sum::<f64>() / n;
435            means[j] = mean;
436
437            // Compute std
438            let variance: f64 = col.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n;
439            stds[j] = variance.sqrt().max(1e-8); // Avoid division by zero
440
441            // Compute min/max
442            for &x in col.iter() {
443                if x < mins[j] {
444                    mins[j] = x;
445                }
446                if x > maxs[j] {
447                    maxs[j] = x;
448                }
449            }
450        }
451
452        self.params = Some(PreprocessingParams {
453            means,
454            stds,
455            mins,
456            maxs,
457        });
458
459        self
460    }
461
462    /// Transform data using fitted parameters.
463    pub fn transform(&self, data: &Array2<f64>) -> TrainResult<Array2<f64>> {
464        let params = self.params.as_ref().ok_or_else(|| {
465            TrainError::Other("Preprocessor not fitted. Call fit() first.".to_string())
466        })?;
467
468        let mut result = data.clone();
469
470        match self.method {
471            PreprocessingMethod::Standardize => {
472                for j in 0..data.ncols() {
473                    for i in 0..data.nrows() {
474                        result[[i, j]] = (data[[i, j]] - params.means[j]) / params.stds[j];
475                    }
476                }
477            }
478            PreprocessingMethod::MinMaxNormalize => {
479                for j in 0..data.ncols() {
480                    let range = (params.maxs[j] - params.mins[j]).max(1e-8);
481                    for i in 0..data.nrows() {
482                        result[[i, j]] = (data[[i, j]] - params.mins[j]) / range;
483                    }
484                }
485            }
486            PreprocessingMethod::MinMaxScale { min, max } => {
487                let target_range = (max - min) as f64;
488                for j in 0..data.ncols() {
489                    let range = (params.maxs[j] - params.mins[j]).max(1e-8);
490                    for i in 0..data.nrows() {
491                        let normalized = (data[[i, j]] - params.mins[j]) / range;
492                        result[[i, j]] = normalized * target_range + min as f64;
493                    }
494                }
495            }
496            PreprocessingMethod::None => {}
497        }
498
499        Ok(result)
500    }
501
502    /// Fit and transform in one step.
503    pub fn fit_transform(&mut self, data: &Array2<f64>) -> TrainResult<Array2<f64>> {
504        self.fit(data);
505        self.transform(data)
506    }
507
508    /// Inverse transform to original scale.
509    pub fn inverse_transform(&self, data: &Array2<f64>) -> TrainResult<Array2<f64>> {
510        let params = self.params.as_ref().ok_or_else(|| {
511            TrainError::Other("Preprocessor not fitted. Call fit() first.".to_string())
512        })?;
513
514        let mut result = data.clone();
515
516        match self.method {
517            PreprocessingMethod::Standardize => {
518                for j in 0..data.ncols() {
519                    for i in 0..data.nrows() {
520                        result[[i, j]] = data[[i, j]] * params.stds[j] + params.means[j];
521                    }
522                }
523            }
524            PreprocessingMethod::MinMaxNormalize => {
525                for j in 0..data.ncols() {
526                    let range = params.maxs[j] - params.mins[j];
527                    for i in 0..data.nrows() {
528                        result[[i, j]] = data[[i, j]] * range + params.mins[j];
529                    }
530                }
531            }
532            PreprocessingMethod::MinMaxScale { min, max } => {
533                let target_range = (max - min) as f64;
534                for j in 0..data.ncols() {
535                    let range = params.maxs[j] - params.mins[j];
536                    for i in 0..data.nrows() {
537                        let normalized = (data[[i, j]] - min as f64) / target_range;
538                        result[[i, j]] = normalized * range + params.mins[j];
539                    }
540                }
541            }
542            PreprocessingMethod::None => {}
543        }
544
545        Ok(result)
546    }
547
548    /// Check if the preprocessor is fitted.
549    pub fn is_fitted(&self) -> bool {
550        self.params.is_some()
551    }
552
553    /// Get the preprocessing method.
554    pub fn method(&self) -> PreprocessingMethod {
555        self.method
556    }
557}
558
/// One-hot encoder for categorical data.
///
/// Fit per column; category indices are assigned in lexicographic order
/// of the category strings.
#[derive(Debug, Clone)]
pub struct OneHotEncoder {
    /// Mapping from category string to one-hot index, per column.
    categories: HashMap<usize, HashMap<String, usize>>,
    /// Number of categories per column.
    n_categories: HashMap<usize, usize>,
}
567
568impl OneHotEncoder {
569    /// Create a new one-hot encoder.
570    pub fn new() -> Self {
571        Self {
572            categories: HashMap::new(),
573            n_categories: HashMap::new(),
574        }
575    }
576
577    /// Fit the encoder to categorical data.
578    ///
579    /// # Arguments
580    /// * `data` - Vector of (column_index, values) pairs
581    pub fn fit(&mut self, data: &[(usize, Vec<String>)]) -> &mut Self {
582        for (col_idx, values) in data {
583            let mut categories = HashMap::new();
584            let mut unique_values: Vec<&String> = values.iter().collect();
585            unique_values.sort();
586            unique_values.dedup();
587
588            for (i, value) in unique_values.into_iter().enumerate() {
589                categories.insert(value.clone(), i);
590            }
591
592            self.n_categories.insert(*col_idx, categories.len());
593            self.categories.insert(*col_idx, categories);
594        }
595
596        self
597    }
598
599    /// Transform categorical column to one-hot encoded array.
600    pub fn transform(&self, col_idx: usize, values: &[String]) -> TrainResult<Array2<f64>> {
601        let categories = self
602            .categories
603            .get(&col_idx)
604            .ok_or_else(|| TrainError::Other(format!("Column {} not fitted", col_idx)))?;
605
606        let n_samples = values.len();
607        let n_cats = *self
608            .n_categories
609            .get(&col_idx)
610            .expect("n_categories populated during fit for all fitted columns");
611
612        let mut result = Array2::zeros((n_samples, n_cats));
613
614        for (i, value) in values.iter().enumerate() {
615            if let Some(&idx) = categories.get(value) {
616                result[[i, idx]] = 1.0;
617            } else {
618                return Err(TrainError::Other(format!(
619                    "Unknown category '{}' for column {}",
620                    value, col_idx
621                )));
622            }
623        }
624
625        Ok(result)
626    }
627
628    /// Get number of categories for a column.
629    pub fn num_categories(&self, col_idx: usize) -> Option<usize> {
630        self.n_categories.get(&col_idx).copied()
631    }
632}
633
634impl Default for OneHotEncoder {
635    fn default() -> Self {
636        Self::new()
637    }
638}
639
/// Label encoder for converting string labels to integers.
///
/// Integer codes are assigned in lexicographic order of the unique labels.
#[derive(Debug, Clone)]
pub struct LabelEncoder {
    /// Mapping from label to integer code.
    label_to_int: HashMap<String, usize>,
    /// Mapping from integer code (index) back to label.
    int_to_label: Vec<String>,
}
648
649impl LabelEncoder {
650    /// Create a new label encoder.
651    pub fn new() -> Self {
652        Self {
653            label_to_int: HashMap::new(),
654            int_to_label: Vec::new(),
655        }
656    }
657
658    /// Fit the encoder to labels.
659    pub fn fit(&mut self, labels: &[String]) -> &mut Self {
660        let mut unique: Vec<&String> = labels.iter().collect();
661        unique.sort();
662        unique.dedup();
663
664        self.label_to_int.clear();
665        self.int_to_label.clear();
666
667        for (i, label) in unique.into_iter().enumerate() {
668            self.label_to_int.insert(label.clone(), i);
669            self.int_to_label.push(label.clone());
670        }
671
672        self
673    }
674
675    /// Transform labels to integers.
676    pub fn transform(&self, labels: &[String]) -> TrainResult<Array1<usize>> {
677        let mut result = Array1::zeros(labels.len());
678
679        for (i, label) in labels.iter().enumerate() {
680            result[i] = *self
681                .label_to_int
682                .get(label)
683                .ok_or_else(|| TrainError::Other(format!("Unknown label: {}", label)))?;
684        }
685
686        Ok(result)
687    }
688
689    /// Inverse transform integers to labels.
690    pub fn inverse_transform(&self, indices: &Array1<usize>) -> TrainResult<Vec<String>> {
691        let mut result = Vec::with_capacity(indices.len());
692
693        for &idx in indices.iter() {
694            if idx >= self.int_to_label.len() {
695                return Err(TrainError::Other(format!(
696                    "Index {} out of bounds for {} classes",
697                    idx,
698                    self.int_to_label.len()
699                )));
700            }
701            result.push(self.int_to_label[idx].clone());
702        }
703
704        Ok(result)
705    }
706
707    /// Fit and transform in one step.
708    pub fn fit_transform(&mut self, labels: &[String]) -> TrainResult<Array1<usize>> {
709        self.fit(labels);
710        self.transform(labels)
711    }
712
713    /// Get number of classes.
714    pub fn num_classes(&self) -> usize {
715        self.int_to_label.len()
716    }
717
718    /// Get class labels.
719    pub fn classes(&self) -> &[String] {
720        &self.int_to_label
721    }
722}
723
724impl Default for LabelEncoder {
725    fn default() -> Self {
726        Self::new()
727    }
728}
729
#[cfg(test)]
mod tests {
    // Unit tests covering Dataset construction/splitting/shuffling, the
    // preprocessors, the encoders, and the CSV loader builder.
    use super::*;

    #[test]
    fn test_dataset_creation() {
        let features =
            Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("unwrap");
        let targets = Array2::from_shape_vec((3, 1), vec![0.0, 1.0, 0.0]).expect("unwrap");

        let dataset = Dataset::new(features, targets);

        // Accessors reflect the underlying matrix shapes.
        assert_eq!(dataset.num_samples(), 3);
        assert_eq!(dataset.num_features(), 2);
        assert_eq!(dataset.num_targets(), 1);
    }

    #[test]
    fn test_dataset_split() {
        let features = Array2::from_shape_fn((10, 2), |(i, j)| (i * 2 + j) as f64);
        let targets = Array2::from_shape_fn((10, 1), |(i, _)| i as f64);

        let dataset = Dataset::new(features, targets);
        let splits = dataset.split(&[0.6, 0.2, 0.2]).expect("unwrap");

        // 10 samples at 60/20/20 divide exactly.
        assert_eq!(splits.len(), 3);
        assert_eq!(splits[0].num_samples(), 6);
        assert_eq!(splits[1].num_samples(), 2);
        assert_eq!(splits[2].num_samples(), 2);
    }

    #[test]
    fn test_train_test_split() {
        let features = Array2::from_shape_fn((100, 4), |(i, j)| (i * 4 + j) as f64);
        let targets = Array2::from_shape_fn((100, 1), |(i, _)| (i % 2) as f64);

        let dataset = Dataset::new(features, targets);
        let (train, test) = dataset.train_test_split(0.8).expect("unwrap");

        assert_eq!(train.num_samples(), 80);
        assert_eq!(test.num_samples(), 20);
    }

    #[test]
    fn test_dataset_shuffle() {
        let features = Array2::from_shape_fn((10, 2), |(i, j)| (i * 2 + j) as f64);
        let targets = Array2::from_shape_fn((10, 1), |(i, _)| i as f64);

        let mut dataset = Dataset::new(features.clone(), targets);
        dataset.shuffle(42);

        // After shuffle, data should be different (with this seed the
        // permutation is not the identity).
        let mut different = false;
        for i in 0..10 {
            if dataset.features[[i, 0]] != features[[i, 0]] {
                different = true;
                break;
            }
        }
        assert!(different);
    }

    #[test]
    fn test_dataset_subset() {
        let features = Array2::from_shape_fn((10, 2), |(i, j)| (i * 2 + j) as f64);
        let targets = Array2::from_shape_fn((10, 1), |(i, _)| i as f64);

        let dataset = Dataset::new(features, targets);
        let subset = dataset.subset(&[0, 2, 4]).expect("unwrap");

        // Rows are selected in the order given by the index slice.
        assert_eq!(subset.num_samples(), 3);
        assert_eq!(subset.features[[0, 0]], 0.0);
        assert_eq!(subset.features[[1, 0]], 4.0);
        assert_eq!(subset.features[[2, 0]], 8.0);
    }

    #[test]
    fn test_preprocessor_standardize() {
        let data = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
            .expect("unwrap");

        let mut preprocessor = DataPreprocessor::standardize();
        let transformed = preprocessor.fit_transform(&data).expect("unwrap");

        // Check that mean is approximately 0
        let col0_mean: f64 = transformed.column(0).iter().sum::<f64>() / 4.0;
        let col1_mean: f64 = transformed.column(1).iter().sum::<f64>() / 4.0;

        assert!(col0_mean.abs() < 1e-10);
        assert!(col1_mean.abs() < 1e-10);

        // Check inverse transform round-trips back to the original data.
        let recovered = preprocessor
            .inverse_transform(&transformed)
            .expect("unwrap");
        for i in 0..4 {
            for j in 0..2 {
                assert!((recovered[[i, j]] - data[[i, j]]).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_preprocessor_min_max() {
        let data =
            Array2::from_shape_vec((4, 2), vec![0.0, 10.0, 5.0, 20.0, 10.0, 30.0, 15.0, 40.0])
                .expect("unwrap");

        let mut preprocessor = DataPreprocessor::min_max_normalize();
        let transformed = preprocessor.fit_transform(&data).expect("unwrap");

        // Check that values are in [0, 1]
        for &val in transformed.iter() {
            assert!((0.0..=1.0).contains(&val));
        }

        // Check specific values
        assert!((transformed[[0, 0]] - 0.0).abs() < 1e-10); // min
        assert!((transformed[[3, 0]] - 1.0).abs() < 1e-10); // max
    }

    #[test]
    fn test_label_encoder() {
        let labels = vec![
            "cat".to_string(),
            "dog".to_string(),
            "cat".to_string(),
            "bird".to_string(),
        ];

        let mut encoder = LabelEncoder::new();
        let encoded = encoder.fit_transform(&labels).expect("unwrap");

        assert_eq!(encoder.num_classes(), 3);
        assert_eq!(encoded.len(), 4);

        // Same labels should have same encoding
        assert_eq!(encoded[0], encoded[2]);

        // Test inverse transform
        let decoded = encoder.inverse_transform(&encoded).expect("unwrap");
        assert_eq!(decoded, labels);
    }

    #[test]
    fn test_one_hot_encoder() {
        let values = vec![
            "red".to_string(),
            "green".to_string(),
            "blue".to_string(),
            "red".to_string(),
        ];

        let mut encoder = OneHotEncoder::new();
        encoder.fit(&[(0, values.clone())]);

        let encoded = encoder.transform(0, &values).expect("unwrap");

        // Three distinct colors -> three one-hot columns.
        assert_eq!(encoded.nrows(), 4);
        assert_eq!(encoded.ncols(), 3);

        // Each row should sum to 1
        for i in 0..4 {
            let row_sum: f64 = encoded.row(i).iter().sum();
            assert!((row_sum - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_csv_loader_builder() {
        let loader = CsvLoader::new()
            .with_header(true)
            .with_delimiter(',')
            .with_target_columns(vec![3]);

        assert!(loader.has_header);
        assert_eq!(loader.delimiter, ',');
        assert_eq!(loader.target_columns, vec![3]);
    }

    #[test]
    fn test_invalid_split_ratios() {
        let features = Array2::zeros((10, 2));
        let targets = Array2::zeros((10, 1));
        let dataset = Dataset::new(features, targets);

        // Ratios don't sum to 1
        let result = dataset.split(&[0.5, 0.3]);
        assert!(result.is_err());
    }

    #[test]
    fn test_preprocessor_not_fitted() {
        let data = Array2::zeros((4, 2));
        let preprocessor = DataPreprocessor::standardize();

        // transform() before fit() must fail, not panic.
        let result = preprocessor.transform(&data);
        assert!(result.is_err());
    }
}
929}