tensorlogic_train/data.rs

//! Data loading and preprocessing utilities for training.
//!
//! This module provides tools for loading and preprocessing training data:
//! - CSV data loading
//! - Data normalization and standardization
//! - Train/validation/test splitting
//! - Data shuffling and sampling
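//!
//! A minimal end-to-end sketch (module paths are assumed, so the example is
//! marked `ignore`):
//!
//! ```ignore
//! use tensorlogic_train::data::{CsvLoader, DataPreprocessor};
//!
//! // Load a CSV whose fourth column (index 3) holds the target.
//! let dataset = CsvLoader::new()
//!     .with_header(true)
//!     .with_target_columns(vec![3])
//!     .load("train.csv")?;
//!
//! // Split 80/20 into train and test, then standardize the features.
//! let (train, test) = dataset.train_test_split(0.8)?;
//! let mut prep = DataPreprocessor::standardize();
//! let train_x = prep.fit_transform(&train.features)?;
//! let test_x = prep.transform(&test.features)?; // reuse train statistics
//! ```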

use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{s, Array1, Array2};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;

/// Dataset container for training data.
#[derive(Debug, Clone)]
pub struct Dataset {
    /// Feature matrix (samples x features).
    pub features: Array2<f64>,
    /// Target vector or matrix.
    pub targets: Array2<f64>,
    /// Feature names (if available).
    pub feature_names: Option<Vec<String>>,
    /// Target names (if available).
    pub target_names: Option<Vec<String>>,
}

impl Dataset {
    /// Create a new dataset.
    pub fn new(features: Array2<f64>, targets: Array2<f64>) -> Self {
        Self {
            features,
            targets,
            feature_names: None,
            target_names: None,
        }
    }

    /// Set feature names.
    pub fn with_feature_names(mut self, names: Vec<String>) -> Self {
        self.feature_names = Some(names);
        self
    }

    /// Set target names.
    pub fn with_target_names(mut self, names: Vec<String>) -> Self {
        self.target_names = Some(names);
        self
    }

    /// Get number of samples.
    pub fn num_samples(&self) -> usize {
        self.features.nrows()
    }

    /// Get number of features.
    pub fn num_features(&self) -> usize {
        self.features.ncols()
    }

    /// Get number of targets.
    pub fn num_targets(&self) -> usize {
        self.targets.ncols()
    }

    /// Shuffle the dataset in place using the Fisher-Yates algorithm.
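    ///
    /// Uses a small seeded LCG internally, so the same seed always produces
    /// the same permutation. A reproducibility sketch (`ignore`d, since it
    /// assumes a `dataset: Dataset` value):
    ///
    /// ```ignore
    /// let mut a = dataset.clone();
    /// let mut b = dataset.clone();
    /// a.shuffle(42);
    /// b.shuffle(42);
    /// assert_eq!(a.features, b.features); // same seed, same order
    /// ```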
    pub fn shuffle(&mut self, seed: u64) {
        let n = self.num_samples();
        if n <= 1 {
            return;
        }

        // Simple LCG for deterministic, seed-reproducible shuffling
        let mut rng_state = seed;
        let lcg_next = |state: &mut u64| -> usize {
            *state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
            (*state >> 33) as usize
        };

        for i in (1..n).rev() {
            let j = lcg_next(&mut rng_state) % (i + 1);
            // Swap rows i and j in both features and targets
            for col in 0..self.features.ncols() {
                self.features.swap([i, col], [j, col]);
            }
            for col in 0..self.targets.ncols() {
                self.targets.swap([i, col], [j, col]);
            }
        }
    }

    /// Split dataset into subsets.
    ///
    /// # Arguments
    /// * `ratios` - Ratios for each split (must be non-negative and sum to 1.0)
    ///
    /// # Returns
    /// Vector of datasets corresponding to each ratio
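    ///
    /// # Example
    ///
    /// A 60/20/20 split sketch (`ignore`d, assumes a `dataset: Dataset`):
    ///
    /// ```ignore
    /// let parts = dataset.split(&[0.6, 0.2, 0.2])?;
    /// assert_eq!(parts.len(), 3); // train, validation, test
    /// ```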
    pub fn split(&self, ratios: &[f64]) -> TrainResult<Vec<Dataset>> {
        if ratios.iter().any(|&r| r < 0.0) {
            return Err(TrainError::ConfigError(
                "Split ratios must be non-negative".to_string(),
            ));
        }
        let total: f64 = ratios.iter().sum();
        if (total - 1.0).abs() > 1e-6 {
            return Err(TrainError::ConfigError(format!(
                "Split ratios must sum to 1.0, got {}",
                total
            )));
        }

        let n = self.num_samples();
        let mut splits = Vec::new();
        let mut start = 0;

        for (i, &ratio) in ratios.iter().enumerate() {
            let end = if i == ratios.len() - 1 {
                n // Last split gets remaining samples
            } else {
                // Clamp to n so cumulative rounding can never overrun the data
                (start + (n as f64 * ratio).round() as usize).min(n)
            };

            let features = self.features.slice(s![start..end, ..]).to_owned();
            let targets = self.targets.slice(s![start..end, ..]).to_owned();

            let mut dataset = Dataset::new(features, targets);
            dataset.feature_names = self.feature_names.clone();
            dataset.target_names = self.target_names.clone();

            splits.push(dataset);
            start = end;
        }

        Ok(splits)
    }

    /// Split into train and test sets.
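    ///
    /// For example, `train_test_split(0.8)` assigns the first 80% of rows to
    /// the train set and the remaining 20% to the test set; call `shuffle`
    /// first if the rows are ordered.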
    pub fn train_test_split(&self, train_ratio: f64) -> TrainResult<(Dataset, Dataset)> {
        let splits = self.split(&[train_ratio, 1.0 - train_ratio])?;
        let mut iter = splits.into_iter();
        Ok((iter.next().unwrap(), iter.next().unwrap()))
    }

    /// Split into train, validation, and test sets.
    pub fn train_val_test_split(
        &self,
        train_ratio: f64,
        val_ratio: f64,
    ) -> TrainResult<(Dataset, Dataset, Dataset)> {
        let test_ratio = 1.0 - train_ratio - val_ratio;
        if test_ratio < 0.0 {
            return Err(TrainError::ConfigError(
                "Train and validation ratios exceed 1.0".to_string(),
            ));
        }
        let splits = self.split(&[train_ratio, val_ratio, test_ratio])?;
        let mut iter = splits.into_iter();
        Ok((
            iter.next().unwrap(),
            iter.next().unwrap(),
            iter.next().unwrap(),
        ))
    }

    /// Get a subset of the dataset by indices.
    pub fn subset(&self, indices: &[usize]) -> TrainResult<Dataset> {
        let n = self.num_samples();
        for &idx in indices {
            if idx >= n {
                return Err(TrainError::ConfigError(format!(
                    "Index {} out of bounds for dataset with {} samples",
                    idx, n
                )));
            }
        }

        let features = Array2::from_shape_fn((indices.len(), self.num_features()), |(i, j)| {
            self.features[[indices[i], j]]
        });
        let targets = Array2::from_shape_fn((indices.len(), self.num_targets()), |(i, j)| {
            self.targets[[indices[i], j]]
        });

        let mut dataset = Dataset::new(features, targets);
        dataset.feature_names = self.feature_names.clone();
        dataset.target_names = self.target_names.clone();

        Ok(dataset)
    }
}

/// CSV data loader.
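///
/// Configured with a fluent builder. A loading sketch (the file name is
/// hypothetical, so the example is marked `ignore`):
///
/// ```ignore
/// let dataset = CsvLoader::new()
///     .with_header(true)
///     .with_delimiter(';')
///     .with_target_columns(vec![0]) // first column holds the label
///     .with_skip_columns(vec![1])   // drop an ID column
///     .load("data.csv")?;
/// ```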
#[derive(Debug, Clone)]
pub struct CsvLoader {
    /// Whether the CSV has a header row.
    pub has_header: bool,
    /// Delimiter character.
    pub delimiter: char,
    /// Indices of target columns (0-based).
    pub target_columns: Vec<usize>,
    /// Columns to skip.
    pub skip_columns: Vec<usize>,
}

impl Default for CsvLoader {
    fn default() -> Self {
        Self {
            has_header: true,
            delimiter: ',',
            target_columns: vec![],
            skip_columns: vec![],
        }
    }
}

impl CsvLoader {
    /// Create a new CSV loader.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set whether the CSV has a header.
    pub fn with_header(mut self, has_header: bool) -> Self {
        self.has_header = has_header;
        self
    }

    /// Set the delimiter character.
    pub fn with_delimiter(mut self, delimiter: char) -> Self {
        self.delimiter = delimiter;
        self
    }

    /// Set target column indices.
    pub fn with_target_columns(mut self, columns: Vec<usize>) -> Self {
        self.target_columns = columns;
        self
    }

    /// Set columns to skip.
    pub fn with_skip_columns(mut self, columns: Vec<usize>) -> Self {
        self.skip_columns = columns;
        self
    }

    /// Load data from a CSV file.
    pub fn load<P: AsRef<Path>>(&self, path: P) -> TrainResult<Dataset> {
        let file = File::open(path.as_ref())
            .map_err(|e| TrainError::Other(format!("Failed to open CSV file: {}", e)))?;
        let reader = BufReader::new(file);
        let mut lines = reader.lines();

        let mut feature_names = None;
        let mut target_names = None;

        // Parse header if present
        if self.has_header {
            if let Some(Ok(header)) = lines.next() {
                let names: Vec<String> = header
                    .split(self.delimiter)
                    .map(|s| s.trim().to_string())
                    .collect();

                let mut feat_names = Vec::new();
                let mut targ_names = Vec::new();

                for (i, name) in names.into_iter().enumerate() {
                    if self.skip_columns.contains(&i) {
                        continue;
                    }
                    if self.target_columns.contains(&i) {
                        targ_names.push(name);
                    } else {
                        feat_names.push(name);
                    }
                }

                feature_names = Some(feat_names);
                target_names = Some(targ_names);
            }
        }

        // Parse data rows
        let mut features_data: Vec<Vec<f64>> = Vec::new();
        let mut targets_data: Vec<Vec<f64>> = Vec::new();

        for line_result in lines {
            let line = line_result
                .map_err(|e| TrainError::Other(format!("Failed to read CSV line: {}", e)))?;

            if line.trim().is_empty() {
                continue;
            }

            let values: Vec<&str> = line.split(self.delimiter).collect();
            let mut row_features = Vec::new();
            let mut row_targets = Vec::new();

            for (i, value) in values.iter().enumerate() {
                if self.skip_columns.contains(&i) {
                    continue;
                }

                let parsed: f64 = value.trim().parse().map_err(|e| {
                    TrainError::Other(format!("Failed to parse value '{}': {}", value, e))
                })?;

                if self.target_columns.contains(&i) {
                    row_targets.push(parsed);
                } else {
                    row_features.push(parsed);
                }
            }

            // Guard against ragged rows, which would otherwise cause
            // out-of-bounds indexing when packing into Array2 below.
            if !features_data.is_empty()
                && (row_features.len() != features_data[0].len()
                    || row_targets.len() != targets_data[0].len())
            {
                return Err(TrainError::Other(
                    "CSV rows have inconsistent column counts".to_string(),
                ));
            }

            features_data.push(row_features);
            targets_data.push(row_targets);
        }

        if features_data.is_empty() {
            return Err(TrainError::Other("CSV file contains no data rows".to_string()));
        }

        let n_samples = features_data.len();
        let n_features = features_data[0].len();
        let n_targets = targets_data[0].len();

        // Convert to arrays
        let features = Array2::from_shape_fn((n_samples, n_features), |(i, j)| features_data[i][j]);

        let targets = if n_targets > 0 {
            Array2::from_shape_fn((n_samples, n_targets), |(i, j)| targets_data[i][j])
        } else {
            // No target columns configured; use a single zero column as placeholder
            Array2::zeros((n_samples, 1))
        };

        let mut dataset = Dataset::new(features, targets);
        dataset.feature_names = feature_names;
        dataset.target_names = target_names;

        Ok(dataset)
    }
}

/// Data preprocessor for normalization and standardization.
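///
/// Follows a fit/transform pattern: `fit` learns per-column statistics,
/// `transform` applies them, and `inverse_transform` undoes them. A
/// round-trip sketch (`ignore`d, assumes a `data: Array2<f64>` value):
///
/// ```ignore
/// let mut prep = DataPreprocessor::standardize();
/// let scaled = prep.fit_transform(&data)?;         // z = (x - mean) / std
/// let restored = prep.inverse_transform(&scaled)?; // x = z * std + mean
/// ```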
#[derive(Debug, Clone)]
pub struct DataPreprocessor {
    /// Preprocessing method.
    method: PreprocessingMethod,
    /// Fitted parameters (mean, std, min, max).
    params: Option<PreprocessingParams>,
}

/// Preprocessing method.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PreprocessingMethod {
    /// Standardization (zero mean, unit variance): z = (x - mean) / std.
    Standardize,
    /// Min-max normalization to [0, 1].
    MinMaxNormalize,
    /// Min-max scaling to a custom [min, max] range.
    MinMaxScale { min: i32, max: i32 },
    /// No preprocessing.
    None,
}

/// Fitted preprocessing parameters.
#[derive(Debug, Clone)]
struct PreprocessingParams {
    means: Array1<f64>,
    stds: Array1<f64>,
    mins: Array1<f64>,
    maxs: Array1<f64>,
}

impl DataPreprocessor {
    /// Create a new preprocessor with standardization.
    pub fn standardize() -> Self {
        Self {
            method: PreprocessingMethod::Standardize,
            params: None,
        }
    }

    /// Create a new preprocessor with min-max normalization.
    pub fn min_max_normalize() -> Self {
        Self {
            method: PreprocessingMethod::MinMaxNormalize,
            params: None,
        }
    }

    /// Create a new preprocessor with custom min-max scaling.
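    ///
    /// For example, `DataPreprocessor::min_max_scale(-1, 1)` maps each
    /// column's minimum to -1.0 and its maximum to 1.0.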
    ///
    /// `min` must be strictly less than `max`; otherwise the target range
    /// collapses and `inverse_transform` would divide by zero.
    pub fn min_max_scale(min: i32, max: i32) -> Self {
        Self {
            method: PreprocessingMethod::MinMaxScale { min, max },
            params: None,
        }
    }

    /// Create a preprocessor that does nothing.
    pub fn none() -> Self {
        Self {
            method: PreprocessingMethod::None,
            params: None,
        }
    }

    /// Fit the preprocessor to data, learning per-column statistics.
    pub fn fit(&mut self, data: &Array2<f64>) -> &mut Self {
        let n_features = data.ncols();

        let mut means = Array1::zeros(n_features);
        let mut stds = Array1::zeros(n_features);
        let mut mins = Array1::from_elem(n_features, f64::INFINITY);
        let mut maxs = Array1::from_elem(n_features, f64::NEG_INFINITY);

        for j in 0..n_features {
            let col = data.column(j);
            let n = col.len() as f64;

            // Compute mean
            let mean: f64 = col.iter().sum::<f64>() / n;
            means[j] = mean;

            // Compute population std (divides by n, not n - 1)
            let variance: f64 = col.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n;
            stds[j] = variance.sqrt().max(1e-8); // Avoid division by zero

            // Compute min/max
            for &x in col.iter() {
                if x < mins[j] {
                    mins[j] = x;
                }
                if x > maxs[j] {
                    maxs[j] = x;
                }
            }
        }

        self.params = Some(PreprocessingParams {
            means,
            stds,
            mins,
            maxs,
        });

        self
    }

    /// Transform data using fitted parameters.
    pub fn transform(&self, data: &Array2<f64>) -> TrainResult<Array2<f64>> {
        let params = self.params.as_ref().ok_or_else(|| {
            TrainError::Other("Preprocessor not fitted. Call fit() first.".to_string())
        })?;

        let mut result = data.clone();

        match self.method {
            PreprocessingMethod::Standardize => {
                for j in 0..data.ncols() {
                    for i in 0..data.nrows() {
                        result[[i, j]] = (data[[i, j]] - params.means[j]) / params.stds[j];
                    }
                }
            }
            PreprocessingMethod::MinMaxNormalize => {
                for j in 0..data.ncols() {
                    let range = (params.maxs[j] - params.mins[j]).max(1e-8);
                    for i in 0..data.nrows() {
                        result[[i, j]] = (data[[i, j]] - params.mins[j]) / range;
                    }
                }
            }
            PreprocessingMethod::MinMaxScale { min, max } => {
                let target_range = (max - min) as f64;
                for j in 0..data.ncols() {
                    let range = (params.maxs[j] - params.mins[j]).max(1e-8);
                    for i in 0..data.nrows() {
                        let normalized = (data[[i, j]] - params.mins[j]) / range;
                        result[[i, j]] = normalized * target_range + min as f64;
                    }
                }
            }
            PreprocessingMethod::None => {}
        }

        Ok(result)
    }

    /// Fit and transform in one step.
    pub fn fit_transform(&mut self, data: &Array2<f64>) -> TrainResult<Array2<f64>> {
        self.fit(data);
        self.transform(data)
    }

    /// Inverse transform back to the original scale.
    pub fn inverse_transform(&self, data: &Array2<f64>) -> TrainResult<Array2<f64>> {
        let params = self.params.as_ref().ok_or_else(|| {
            TrainError::Other("Preprocessor not fitted. Call fit() first.".to_string())
        })?;

        let mut result = data.clone();

        match self.method {
            PreprocessingMethod::Standardize => {
                for j in 0..data.ncols() {
                    for i in 0..data.nrows() {
                        result[[i, j]] = data[[i, j]] * params.stds[j] + params.means[j];
                    }
                }
            }
            PreprocessingMethod::MinMaxNormalize => {
                for j in 0..data.ncols() {
                    let range = params.maxs[j] - params.mins[j];
                    for i in 0..data.nrows() {
                        result[[i, j]] = data[[i, j]] * range + params.mins[j];
                    }
                }
            }
            PreprocessingMethod::MinMaxScale { min, max } => {
                let target_range = (max - min) as f64;
                for j in 0..data.ncols() {
                    let range = params.maxs[j] - params.mins[j];
                    for i in 0..data.nrows() {
                        // Undo the forward mapping: back to [0, 1], then to
                        // the original column range.
                        let normalized = (data[[i, j]] - min as f64) / target_range;
                        result[[i, j]] = normalized * range + params.mins[j];
                    }
                }
            }
            PreprocessingMethod::None => {}
        }

        Ok(result)
    }

    /// Check if the preprocessor is fitted.
    pub fn is_fitted(&self) -> bool {
        self.params.is_some()
    }

    /// Get the preprocessing method.
    pub fn method(&self) -> PreprocessingMethod {
        self.method
    }
}

/// One-hot encoder for categorical data.
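///
/// Categories are sorted during `fit`, so column indices are deterministic.
/// A sketch (`ignore`d, error handling elided):
///
/// ```ignore
/// let colors = vec!["red".to_string(), "blue".to_string(), "red".to_string()];
/// let mut enc = OneHotEncoder::new();
/// enc.fit(&[(0, colors.clone())]);
/// let encoded = enc.transform(0, &colors)?; // 3 x 2: blue -> [1, 0], red -> [0, 1]
/// ```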
#[derive(Debug, Clone)]
pub struct OneHotEncoder {
    /// Mapping from category to index for each column.
    categories: HashMap<usize, HashMap<String, usize>>,
    /// Number of categories per column.
    n_categories: HashMap<usize, usize>,
}

impl OneHotEncoder {
    /// Create a new one-hot encoder.
    pub fn new() -> Self {
        Self {
            categories: HashMap::new(),
            n_categories: HashMap::new(),
        }
    }

    /// Fit the encoder to categorical data.
    ///
    /// # Arguments
    /// * `data` - Vector of (column_index, values) pairs
    pub fn fit(&mut self, data: &[(usize, Vec<String>)]) -> &mut Self {
        for (col_idx, values) in data {
            let mut categories = HashMap::new();
            // Sort and deduplicate so category indices are deterministic
            let mut unique_values: Vec<&String> = values.iter().collect();
            unique_values.sort();
            unique_values.dedup();

            for (i, value) in unique_values.into_iter().enumerate() {
                categories.insert(value.clone(), i);
            }

            self.n_categories.insert(*col_idx, categories.len());
            self.categories.insert(*col_idx, categories);
        }

        self
    }

    /// Transform a categorical column to a one-hot encoded array.
    pub fn transform(&self, col_idx: usize, values: &[String]) -> TrainResult<Array2<f64>> {
        let categories = self
            .categories
            .get(&col_idx)
            .ok_or_else(|| TrainError::Other(format!("Column {} not fitted", col_idx)))?;

        let n_samples = values.len();
        let n_cats = categories.len();

        let mut result = Array2::zeros((n_samples, n_cats));

        for (i, value) in values.iter().enumerate() {
            if let Some(&idx) = categories.get(value) {
                result[[i, idx]] = 1.0;
            } else {
                return Err(TrainError::Other(format!(
                    "Unknown category '{}' for column {}",
                    value, col_idx
                )));
            }
        }

        Ok(result)
    }

    /// Get the number of categories for a column.
    pub fn num_categories(&self, col_idx: usize) -> Option<usize> {
        self.n_categories.get(&col_idx).copied()
    }
}

impl Default for OneHotEncoder {
    fn default() -> Self {
        Self::new()
    }
}

/// Label encoder for converting string labels to integers.
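///
/// Labels are sorted before indices are assigned, so the mapping is
/// deterministic. A sketch (`ignore`d, error handling elided):
///
/// ```ignore
/// let labels = vec!["cat".to_string(), "dog".to_string(), "cat".to_string()];
/// let mut enc = LabelEncoder::new();
/// let ids = enc.fit_transform(&labels)?; // cat -> 0, dog -> 1
/// assert_eq!(enc.num_classes(), 2);
/// ```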
#[derive(Debug, Clone)]
pub struct LabelEncoder {
    /// Mapping from label to integer.
    label_to_int: HashMap<String, usize>,
    /// Mapping from integer to label.
    int_to_label: Vec<String>,
}

impl LabelEncoder {
    /// Create a new label encoder.
    pub fn new() -> Self {
        Self {
            label_to_int: HashMap::new(),
            int_to_label: Vec::new(),
        }
    }

    /// Fit the encoder to labels.
    pub fn fit(&mut self, labels: &[String]) -> &mut Self {
        // Sort and deduplicate so label indices are deterministic
        let mut unique: Vec<&String> = labels.iter().collect();
        unique.sort();
        unique.dedup();

        self.label_to_int.clear();
        self.int_to_label.clear();

        for (i, label) in unique.into_iter().enumerate() {
            self.label_to_int.insert(label.clone(), i);
            self.int_to_label.push(label.clone());
        }

        self
    }

    /// Transform labels to integers.
    pub fn transform(&self, labels: &[String]) -> TrainResult<Array1<usize>> {
        let mut result = Array1::zeros(labels.len());

        for (i, label) in labels.iter().enumerate() {
            result[i] = *self
                .label_to_int
                .get(label)
                .ok_or_else(|| TrainError::Other(format!("Unknown label: {}", label)))?;
        }

        Ok(result)
    }

    /// Inverse transform integers to labels.
    pub fn inverse_transform(&self, indices: &Array1<usize>) -> TrainResult<Vec<String>> {
        let mut result = Vec::with_capacity(indices.len());

        for &idx in indices.iter() {
            if idx >= self.int_to_label.len() {
                return Err(TrainError::Other(format!(
                    "Index {} out of bounds for {} classes",
                    idx,
                    self.int_to_label.len()
                )));
            }
            result.push(self.int_to_label[idx].clone());
        }

        Ok(result)
    }

    /// Fit and transform in one step.
    pub fn fit_transform(&mut self, labels: &[String]) -> TrainResult<Array1<usize>> {
        self.fit(labels);
        self.transform(labels)
    }

    /// Get the number of classes.
    pub fn num_classes(&self) -> usize {
        self.int_to_label.len()
    }

    /// Get the class labels.
    pub fn classes(&self) -> &[String] {
        &self.int_to_label
    }
}

impl Default for LabelEncoder {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dataset_creation() {
        let features = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let targets = Array2::from_shape_vec((3, 1), vec![0.0, 1.0, 0.0]).unwrap();

        let dataset = Dataset::new(features, targets);

        assert_eq!(dataset.num_samples(), 3);
        assert_eq!(dataset.num_features(), 2);
        assert_eq!(dataset.num_targets(), 1);
    }

    #[test]
    fn test_dataset_split() {
        let features = Array2::from_shape_fn((10, 2), |(i, j)| (i * 2 + j) as f64);
        let targets = Array2::from_shape_fn((10, 1), |(i, _)| i as f64);

        let dataset = Dataset::new(features, targets);
        let splits = dataset.split(&[0.6, 0.2, 0.2]).unwrap();

        assert_eq!(splits.len(), 3);
        assert_eq!(splits[0].num_samples(), 6);
        assert_eq!(splits[1].num_samples(), 2);
        assert_eq!(splits[2].num_samples(), 2);
    }

    #[test]
    fn test_train_test_split() {
        let features = Array2::from_shape_fn((100, 4), |(i, j)| (i * 4 + j) as f64);
        let targets = Array2::from_shape_fn((100, 1), |(i, _)| (i % 2) as f64);

        let dataset = Dataset::new(features, targets);
        let (train, test) = dataset.train_test_split(0.8).unwrap();

        assert_eq!(train.num_samples(), 80);
        assert_eq!(test.num_samples(), 20);
    }

    #[test]
    fn test_dataset_shuffle() {
        let features = Array2::from_shape_fn((10, 2), |(i, j)| (i * 2 + j) as f64);
        let targets = Array2::from_shape_fn((10, 1), |(i, _)| i as f64);

        let mut dataset = Dataset::new(features.clone(), targets);
        dataset.shuffle(42);

        // After shuffle, data should be different
        let mut different = false;
        for i in 0..10 {
            if dataset.features[[i, 0]] != features[[i, 0]] {
                different = true;
                break;
            }
        }
        assert!(different);
    }

    #[test]
    fn test_dataset_subset() {
        let features = Array2::from_shape_fn((10, 2), |(i, j)| (i * 2 + j) as f64);
        let targets = Array2::from_shape_fn((10, 1), |(i, _)| i as f64);

        let dataset = Dataset::new(features, targets);
        let subset = dataset.subset(&[0, 2, 4]).unwrap();

        assert_eq!(subset.num_samples(), 3);
        assert_eq!(subset.features[[0, 0]], 0.0);
        assert_eq!(subset.features[[1, 0]], 4.0);
        assert_eq!(subset.features[[2, 0]], 8.0);
    }

    #[test]
    fn test_preprocessor_standardize() {
        let data =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();

        let mut preprocessor = DataPreprocessor::standardize();
        let transformed = preprocessor.fit_transform(&data).unwrap();

        // Check that mean is approximately 0
        let col0_mean: f64 = transformed.column(0).iter().sum::<f64>() / 4.0;
        let col1_mean: f64 = transformed.column(1).iter().sum::<f64>() / 4.0;

        assert!(col0_mean.abs() < 1e-10);
        assert!(col1_mean.abs() < 1e-10);

        // Check inverse transform
        let recovered = preprocessor.inverse_transform(&transformed).unwrap();
        for i in 0..4 {
            for j in 0..2 {
                assert!((recovered[[i, j]] - data[[i, j]]).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_preprocessor_min_max() {
        let data =
            Array2::from_shape_vec((4, 2), vec![0.0, 10.0, 5.0, 20.0, 10.0, 30.0, 15.0, 40.0])
                .unwrap();

        let mut preprocessor = DataPreprocessor::min_max_normalize();
        let transformed = preprocessor.fit_transform(&data).unwrap();

        // Check that values are in [0, 1]
        for &val in transformed.iter() {
            assert!((0.0..=1.0).contains(&val));
        }

        // Check specific values
        assert!((transformed[[0, 0]] - 0.0).abs() < 1e-10); // min
        assert!((transformed[[3, 0]] - 1.0).abs() < 1e-10); // max
    }

    #[test]
    fn test_label_encoder() {
        let labels = vec![
            "cat".to_string(),
            "dog".to_string(),
            "cat".to_string(),
            "bird".to_string(),
        ];

        let mut encoder = LabelEncoder::new();
        let encoded = encoder.fit_transform(&labels).unwrap();

        assert_eq!(encoder.num_classes(), 3);
        assert_eq!(encoded.len(), 4);

        // Same labels should have same encoding
        assert_eq!(encoded[0], encoded[2]);

        // Test inverse transform
        let decoded = encoder.inverse_transform(&encoded).unwrap();
        assert_eq!(decoded, labels);
    }

    #[test]
    fn test_one_hot_encoder() {
        let values = vec![
            "red".to_string(),
            "green".to_string(),
            "blue".to_string(),
            "red".to_string(),
        ];

        let mut encoder = OneHotEncoder::new();
        encoder.fit(&[(0, values.clone())]);

        let encoded = encoder.transform(0, &values).unwrap();

        assert_eq!(encoded.nrows(), 4);
        assert_eq!(encoded.ncols(), 3);

        // Each row should sum to 1
        for i in 0..4 {
            let row_sum: f64 = encoded.row(i).iter().sum();
            assert!((row_sum - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_csv_loader_builder() {
        let loader = CsvLoader::new()
            .with_header(true)
            .with_delimiter(',')
            .with_target_columns(vec![3]);

        assert!(loader.has_header);
        assert_eq!(loader.delimiter, ',');
        assert_eq!(loader.target_columns, vec![3]);
    }

    #[test]
    fn test_invalid_split_ratios() {
        let features = Array2::zeros((10, 2));
        let targets = Array2::zeros((10, 1));
        let dataset = Dataset::new(features, targets);

        // Ratios don't sum to 1
        let result = dataset.split(&[0.5, 0.3]);
        assert!(result.is_err());
    }

    #[test]
    fn test_preprocessor_not_fitted() {
        let data = Array2::zeros((4, 2));
        let preprocessor = DataPreprocessor::standardize();

        let result = preprocessor.transform(&data);
        assert!(result.is_err());
    }
}