Skip to main content

scry_learn/
dataset.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2//! Tabular dataset container for ML workflows.
3//!
4//! [`Dataset`] provides a lightweight, column-major representation of
5//! features + target, with CSV loading and basic column access.
6
7use std::sync::OnceLock;
8
9use crate::error::{Result, ScryLearnError};
10
11use crate::matrix::DenseMatrix;
12use crate::sparse::CscMatrix;
13
/// Internal feature storage format.
///
/// `Dense` carries no payload: dense values live in `Dataset::features`
/// (column-major `Vec<Vec<f64>>`), so the variant is just a tag.
#[derive(Clone, Debug, Default)]
pub(crate) enum Storage {
    /// Dense column-major features (current default); actual values are
    /// held in `Dataset::features`, not in this variant.
    #[default]
    Dense,
    /// Sparse CSC matrix (column-oriented for fit).
    Sparse(CscMatrix),
}
23
/// Descriptive statistics for a single column.
///
/// Produced by [`Dataset::summary`]. Non-finite values are excluded from
/// `count` and from every statistic; a column with no finite values has
/// `count == 0` and all statistics set to NaN.
#[derive(Clone, Debug)]
pub struct ColumnStats {
    /// Column name.
    pub name: String,
    /// Number of finite (non-NaN, non-Inf) values.
    pub count: usize,
    /// Arithmetic mean of the finite values.
    pub mean: f64,
    /// Sample standard deviation (ddof=1); 0.0 when `count <= 1`.
    pub std: f64,
    /// Minimum value.
    pub min: f64,
    /// 25th percentile (linear interpolation).
    pub q25: f64,
    /// Median (50th percentile).
    pub median: f64,
    /// 75th percentile (linear interpolation).
    pub q75: f64,
    /// Maximum value.
    pub max: f64,
}
46
/// A tabular dataset with features and a target column.
///
/// Features are stored column-major (`features[feature_idx][sample_idx]`)
/// for cache-friendly access during tree split evaluation.
///
/// For sparse datasets (see [`Dataset::from_sparse`]) the `features` vec is
/// left empty and values live in the internal sparse storage; call
/// [`Dataset::ensure_dense`] before reading `features` directly.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct Dataset {
    /// Feature columns: `features[feature_idx][sample_idx]`.
    /// Empty when the dataset uses sparse storage.
    pub features: Vec<Vec<f64>>,
    /// Target values: `target[sample_idx]`.
    pub target: Vec<f64>,
    /// Feature column names.
    pub feature_names: Vec<String>,
    /// Target column name.
    pub target_name: String,
    /// Class label mapping (index → label string) for classification tasks.
    pub class_labels: Option<Vec<String>>,
    /// Lazily-computed contiguous column-major feature matrix.
    ///
    /// Built on first access from `features` via [`OnceLock::get_or_init`],
    /// avoiding the upfront clone in [`Dataset::new`].
    #[cfg_attr(feature = "serde", serde(skip))]
    matrix: OnceLock<DenseMatrix>,
    /// Lazily-computed contiguous row-major feature buffer.
    ///
    /// Layout: `[sample_0_feat_0, sample_0_feat_1, ..., sample_n_feat_m]`.
    /// Populated on first call to [`Dataset::flat_feature_matrix`].
    #[cfg_attr(feature = "serde", serde(skip))]
    row_major_cache: Option<Vec<f64>>,
    /// Storage format — dense (default) or sparse CSC.
    #[cfg_attr(feature = "serde", serde(skip))]
    storage: Storage,
}
81
82impl Dataset {
83    /// Create a dataset from pre-computed features and target.
84    ///
85    /// # Panics
86    ///
87    /// Panics if feature columns have mismatched lengths, or if
88    /// `feature_names.len() != features.len()`.
89    pub fn new(
90        features: Vec<Vec<f64>>,
91        target: Vec<f64>,
92        feature_names: Vec<String>,
93        target_name: impl Into<String>,
94    ) -> Self {
95        assert!(
96            feature_names.len() == features.len(),
97            "feature_names.len()={} but features.len()={}",
98            feature_names.len(),
99            features.len(),
100        );
101        if let Some(first) = features.first() {
102            for (i, col) in features.iter().enumerate().skip(1) {
103                assert!(
104                    col.len() == first.len(),
105                    "feature column {i} has {} rows but column 0 has {}",
106                    col.len(),
107                    first.len(),
108                );
109            }
110        }
111        Self {
112            features,
113            target,
114            feature_names,
115            target_name: target_name.into(),
116            class_labels: None,
117            matrix: OnceLock::new(),
118            row_major_cache: None,
119            storage: Storage::Dense,
120        }
121    }
122
123    /// Create a dataset from a [`DenseMatrix`], target, and column names.
124    ///
125    /// The `features` field is populated from the matrix for backward compat.
126    pub fn from_matrix(
127        matrix: DenseMatrix,
128        target: Vec<f64>,
129        feature_names: Vec<String>,
130        target_name: impl Into<String>,
131    ) -> Self {
132        let features = matrix.to_col_vecs();
133        let cell = OnceLock::new();
134        let _ = cell.set(matrix);
135        Self {
136            features,
137            target,
138            feature_names,
139            target_name: target_name.into(),
140            class_labels: None,
141            matrix: cell,
142            row_major_cache: None,
143            storage: Storage::Dense,
144        }
145    }
146
    /// The contiguous column-major feature matrix.
    ///
    /// Lazily built from `features` on first access. Subsequent calls
    /// return the cached matrix without recomputation.
    ///
    /// # Panics
    ///
    /// Panics if [`DenseMatrix::from_col_major_ref`] rejects `features`
    /// (NOTE(review): likely the case for a sparse dataset, where `features`
    /// is empty — confirm the builder's behavior on empty input).
    #[inline]
    pub fn matrix(&self) -> &DenseMatrix {
        self.matrix.get_or_init(|| {
            DenseMatrix::from_col_major_ref(&self.features)
                .expect("DenseMatrix build from features failed")
        })
    }
158
159    /// Load a dataset from a CSV file.
160    ///
161    /// The `target_column` is extracted as the target; all other numeric
162    /// columns become features. String columns are label-encoded automatically.
163    ///
164    /// Requires the `csv` feature.
165    #[cfg(feature = "csv")]
166    pub fn from_csv(path: &str, target_column: &str) -> Result<Self> {
167        let file = std::fs::File::open(path).map_err(ScryLearnError::Io)?;
168        Self::from_csv_reader(file, target_column)
169    }
170
    /// Load a dataset from any reader producing CSV data.
    ///
    /// The `target_column` header is matched case-insensitively. Every other
    /// column becomes a feature; feature cells that fail to parse as `f64`
    /// are stored as NaN (they are NOT label-encoded — only the target is).
    /// Ragged rows are accepted (`flexible`); missing cells parse as NaN.
    ///
    /// # Errors
    ///
    /// Returns [`ScryLearnError::Csv`] on malformed CSV,
    /// [`ScryLearnError::InvalidColumn`] if `target_column` is absent, and
    /// [`ScryLearnError::EmptyDataset`] when there are no data rows.
    ///
    /// Requires the `csv` feature.
    #[cfg(feature = "csv")]
    pub fn from_csv_reader(rdr: impl std::io::Read, target_column: &str) -> Result<Self> {
        let mut csv_rdr = csv::ReaderBuilder::new()
            .has_headers(true)
            .flexible(true)
            .from_reader(rdr);

        let headers: Vec<String> = csv_rdr
            .headers()
            .map_err(|e| ScryLearnError::Csv(e.to_string()))?
            .iter()
            .map(std::string::ToString::to_string)
            .collect();

        // Case-insensitive match of the requested target header.
        let target_idx = headers
            .iter()
            .position(|h| h.eq_ignore_ascii_case(target_column))
            .ok_or_else(|| ScryLearnError::InvalidColumn(target_column.to_string()))?;

        // Collect all rows as string records (two-pass parse: strings first,
        // then numeric/label conversion below).
        let mut rows: Vec<Vec<String>> = Vec::new();
        for result in csv_rdr.records() {
            let record = result.map_err(|e| ScryLearnError::Csv(e.to_string()))?;
            rows.push(
                record
                    .iter()
                    .map(std::string::ToString::to_string)
                    .collect(),
            );
        }

        if rows.is_empty() {
            return Err(ScryLearnError::EmptyDataset);
        }

        // Determine which columns are features (all except target).
        let feature_indices: Vec<usize> = (0..headers.len()).filter(|&i| i != target_idx).collect();

        let n_samples = rows.len();
        let n_features = feature_indices.len();

        // Parse target — try numeric first, fall back to label encoding.
        let (target, class_labels) = parse_target_column(&rows, target_idx);

        // Parse feature columns — non-numeric or missing cells become NaN.
        let mut features = vec![vec![0.0; n_samples]; n_features];
        let mut feature_names = Vec::with_capacity(n_features);

        for (feat_col, &col_idx) in feature_indices.iter().enumerate() {
            feature_names.push(headers[col_idx].clone());
            for (row_idx, row) in rows.iter().enumerate() {
                // A short (ragged) row yields "" here, which parses to NaN.
                let val = row.get(col_idx).map_or("", std::string::String::as_str);
                features[feat_col][row_idx] = val.parse::<f64>().unwrap_or(f64::NAN);
            }
        }

        Ok(Self {
            features,
            target,
            feature_names,
            target_name: headers[target_idx].clone(),
            class_labels,
            matrix: OnceLock::new(),
            row_major_cache: None,
            storage: Storage::Dense,
        })
    }
241
    /// Number of samples (rows).
    ///
    /// Derived from the target length, so it is correct for both dense and
    /// sparse storage.
    #[inline]
    pub fn n_samples(&self) -> usize {
        self.target.len()
    }
247
248    /// Number of features (columns).
249    #[inline]
250    pub fn n_features(&self) -> usize {
251        match &self.storage {
252            Storage::Sparse(csc) => csc.n_cols(),
253            Storage::Dense => self.features.len(),
254        }
255    }
256
257    /// Number of unique classes in the target (for classification).
258    pub fn n_classes(&self) -> usize {
259        self.class_labels.as_ref().map_or_else(
260            || {
261                let mut vals: Vec<i64> = self.target.iter().map(|&v| v as i64).collect();
262                vals.sort_unstable();
263                vals.dedup();
264                vals.len()
265            },
266            Vec::len,
267        )
268    }
269
    /// Get a single feature column by index.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is out of bounds — including on a sparse dataset,
    /// where `features` is empty until [`Dataset::ensure_dense`] is called.
    pub fn feature(&self, idx: usize) -> &[f64] {
        &self.features[idx]
    }
274
275    /// Get a single sample (row) as a vector of feature values.
276    pub fn sample(&self, idx: usize) -> Vec<f64> {
277        self.features.iter().map(|col| col[idx]).collect()
278    }
279
280    /// Get the feature matrix as row-major `[n_samples][n_features]`.
281    pub fn feature_matrix(&self) -> Vec<Vec<f64>> {
282        let n = self.n_samples();
283        let m = self.n_features();
284        let mut matrix = vec![vec![0.0; m]; n];
285        for (j, feat_col) in self.features.iter().enumerate() {
286            for (i, &val) in feat_col.iter().enumerate() {
287                matrix[i][j] = val;
288            }
289        }
290        matrix
291    }
292
293    /// Get a contiguous row-major feature buffer, computing on first call.
294    ///
295    /// Layout: `[sample_0_feat_0, sample_0_feat_1, ..., sample_n_feat_m]`.
296    /// Subsequent calls return the cached slice without recomputation.
297    pub fn flat_feature_matrix(&mut self) -> &[f64] {
298        if self.row_major_cache.is_none() {
299            let n = self.n_samples();
300            let m = self.n_features();
301            let mut buf = vec![0.0; n * m];
302            if let Some(mat) = self.matrix.get() {
303                let src = mat.as_slice();
304                for j in 0..m {
305                    let col_off = j * n;
306                    for i in 0..n {
307                        buf[i * m + j] = src[col_off + i];
308                    }
309                }
310            } else {
311                for j in 0..m {
312                    for i in 0..n {
313                        buf[i * m + j] = self.features[j][i];
314                    }
315                }
316            }
317            self.row_major_cache = Some(buf);
318        }
319        self.row_major_cache
320            .as_ref()
321            .expect("row_major_cache populated above")
322    }
323
    /// Get a zero-copy row slice from a pre-computed flat feature buffer.
    ///
    /// `cache` should be the result of [`Dataset::flat_feature_matrix`].
    ///
    /// # Panics
    ///
    /// Panics if `cache` is too short for row `idx`
    /// (i.e. `cache.len() < (idx + 1) * self.n_features()`).
    #[inline]
    pub fn sample_row<'a>(&self, cache: &'a [f64], idx: usize) -> &'a [f64] {
        let m = self.n_features();
        &cache[idx * m..(idx + 1) * m]
    }
332
333    /// Create a subset of this dataset with the given sample indices.
334    pub fn subset(&self, indices: &[usize]) -> Self {
335        let target: Vec<f64> = indices.iter().map(|&i| self.target[i]).collect();
336
337        if let Storage::Sparse(csc) = &self.storage {
338            let new_csc = subset_csc(csc, indices);
339            return Self {
340                features: Vec::new(),
341                target,
342                feature_names: self.feature_names.clone(),
343                target_name: self.target_name.clone(),
344                class_labels: self.class_labels.clone(),
345                matrix: OnceLock::new(),
346                row_major_cache: None,
347                storage: Storage::Sparse(new_csc),
348            };
349        }
350
351        let features: Vec<Vec<f64>> = self
352            .features
353            .iter()
354            .map(|col| indices.iter().map(|&i| col[i]).collect())
355            .collect();
356        Self {
357            features,
358            target,
359            feature_names: self.feature_names.clone(),
360            target_name: self.target_name.clone(),
361            class_labels: self.class_labels.clone(),
362            matrix: OnceLock::new(),
363            row_major_cache: None,
364            storage: Storage::Dense,
365        }
366    }
367
368    /// Clear the cached matrix so it will be lazily rebuilt from `features`
369    /// on the next call to [`matrix()`](Self::matrix).
370    ///
371    /// Call this after mutating `features` in place (e.g. after a
372    /// transformer's `transform()` step).
373    pub fn sync_matrix(&mut self) {
374        self.matrix = OnceLock::new();
375        self.row_major_cache = None;
376    }
377
    /// Mark the matrix cache as stale after in-place feature mutations.
    ///
    /// Drops both the cached column-major [`DenseMatrix`] and the row-major
    /// buffer; each is lazily rebuilt from `features` on next access.
    #[inline]
    pub fn invalidate_matrix(&mut self) {
        self.matrix = OnceLock::new();
        self.row_major_cache = None;
    }
386
387    /// Returns `Err(InvalidData)` if any feature or target value is NaN or ±Inf.
388    pub fn validate_finite(&self) -> Result<()> {
389        // Check sparse storage values if present.
390        if let Storage::Sparse(csc) = &self.storage {
391            for j in 0..csc.n_cols() {
392                for (i, v) in csc.col(j).iter() {
393                    if !v.is_finite() {
394                        let name = self
395                            .feature_names
396                            .get(j)
397                            .map_or_else(|| format!("feature[{j}]"), std::clone::Clone::clone);
398                        return Err(ScryLearnError::InvalidData(format!(
399                            "non-finite value ({v}) in {name} at sample {i}"
400                        )));
401                    }
402                }
403            }
404        } else {
405            for (j, col) in self.features.iter().enumerate() {
406                for (i, &v) in col.iter().enumerate() {
407                    if !v.is_finite() {
408                        let name = self
409                            .feature_names
410                            .get(j)
411                            .map_or_else(|| format!("feature[{j}]"), std::clone::Clone::clone);
412                        return Err(ScryLearnError::InvalidData(format!(
413                            "non-finite value ({v}) in {name} at sample {i}"
414                        )));
415                    }
416                }
417            }
418        }
419        for (i, &v) in self.target.iter().enumerate() {
420            if !v.is_finite() {
421                return Err(ScryLearnError::InvalidData(format!(
422                    "non-finite value ({v}) in target at sample {i}"
423                )));
424            }
425        }
426        Ok(())
427    }
428
429    /// Returns `Err(InvalidData)` if any feature or target value is ±Inf.
430    ///
431    /// Unlike [`validate_finite`](Self::validate_finite), this allows NaN
432    /// values (useful for imputers that intentionally handle NaN).
433    pub fn validate_no_inf(&self) -> Result<()> {
434        if let Storage::Sparse(csc) = &self.storage {
435            for j in 0..csc.n_cols() {
436                for (i, v) in csc.col(j).iter() {
437                    if v.is_infinite() {
438                        let name = self
439                            .feature_names
440                            .get(j)
441                            .map_or_else(|| format!("feature[{j}]"), std::clone::Clone::clone);
442                        return Err(ScryLearnError::InvalidData(format!(
443                            "infinite value ({v}) in {name} at sample {i}"
444                        )));
445                    }
446                }
447            }
448        } else {
449            for (j, col) in self.features.iter().enumerate() {
450                for (i, &v) in col.iter().enumerate() {
451                    if v.is_infinite() {
452                        let name = self
453                            .feature_names
454                            .get(j)
455                            .map_or_else(|| format!("feature[{j}]"), std::clone::Clone::clone);
456                        return Err(ScryLearnError::InvalidData(format!(
457                            "infinite value ({v}) in {name} at sample {i}"
458                        )));
459                    }
460                }
461            }
462        }
463        for (i, &v) in self.target.iter().enumerate() {
464            if v.is_infinite() {
465                return Err(ScryLearnError::InvalidData(format!(
466                    "infinite value ({v}) in target at sample {i}"
467                )));
468            }
469        }
470        Ok(())
471    }
472
    /// Attach class labels for classification.
    ///
    /// Builder-style: consumes and returns `self`. `labels[k]` is the label
    /// for encoded class `k`; once attached, [`Dataset::n_classes`] reports
    /// `labels.len()`.
    pub fn with_class_labels(mut self, labels: Vec<String>) -> Self {
        self.class_labels = Some(labels);
        self
    }
478
479    /// Create a dataset from a sparse CSC matrix.
480    ///
481    /// The `features` field is left empty. Call [`ensure_dense`](Self::ensure_dense)
482    /// before accessing `features` directly on a sparse dataset.
483    pub fn from_sparse(
484        csc: CscMatrix,
485        target: Vec<f64>,
486        feature_names: Vec<String>,
487        target_name: impl Into<String>,
488    ) -> Self {
489        Self {
490            features: Vec::new(),
491            target,
492            feature_names,
493            target_name: target_name.into(),
494            class_labels: None,
495            matrix: OnceLock::new(),
496            row_major_cache: None,
497            storage: Storage::Sparse(csc),
498        }
499    }
500
501    /// Whether this dataset uses sparse storage.
502    #[inline]
503    pub fn is_sparse(&self) -> bool {
504        matches!(self.storage, Storage::Sparse(_))
505    }
506
507    /// Get the sparse CSC matrix if available.
508    pub fn sparse_csc(&self) -> Option<&CscMatrix> {
509        match &self.storage {
510            Storage::Sparse(m) => Some(m),
511            Storage::Dense => None,
512        }
513    }
514
    /// Get the sparse CSR matrix (converted from CSC on demand).
    ///
    /// Returns `None` for dense datasets. The conversion builds a fresh
    /// owned matrix on every call — cache the result if used repeatedly.
    pub fn sparse_csr(&self) -> Option<crate::sparse::CsrMatrix> {
        self.sparse_csc().map(CscMatrix::to_csr)
    }
519
520    /// Compute descriptive statistics for every feature column and the target.
521    ///
522    /// Returns one [`ColumnStats`] per feature (in order) followed by one for
523    /// the target column. NaN values are excluded from all computations.
524    /// Standard deviation uses `ddof=1` (sample std) to match pandas.
525    pub fn summary(&self) -> Vec<ColumnStats> {
526        let n_feat = self.n_features();
527        let mut stats = Vec::with_capacity(n_feat + 1);
528
529        for j in 0..n_feat {
530            let name = self
531                .feature_names
532                .get(j)
533                .cloned()
534                .unwrap_or_else(|| format!("feature[{j}]"));
535
536            let col: Vec<f64> = if let Some(csc) = self.sparse_csc() {
537                // Reconstruct the full column (zeros + stored values).
538                let n_rows = csc.n_rows();
539                let mut dense = vec![0.0_f64; n_rows];
540                for (i, v) in csc.col(j).iter() {
541                    dense[i] = v;
542                }
543                dense
544            } else {
545                self.features[j].clone()
546            };
547
548            stats.push(compute_column_stats(&name, &col));
549        }
550
551        stats.push(compute_column_stats(&self.target_name, &self.target));
552        stats
553    }
554
555    /// Print a pandas-style descriptive statistics table to stdout.
556    ///
557    /// Internally calls [`summary()`](Self::summary).
558    pub fn describe(&self) {
559        let stats = self.summary();
560        if stats.is_empty() {
561            return;
562        }
563
564        let labels = ["count", "mean", "std", "min", "25%", "50%", "75%", "max"];
565        let label_width = labels.iter().map(|l| l.len()).max().unwrap_or(0);
566
567        let col_widths: Vec<usize> = stats.iter().map(|s| s.name.len().max(12)).collect();
568
569        // Header row.
570        print!("{:>width$}", "", width = label_width);
571        for (i, s) in stats.iter().enumerate() {
572            print!("  {:>width$}", s.name, width = col_widths[i]);
573        }
574        println!();
575
576        // Data rows.
577        for (row_idx, label) in labels.iter().enumerate() {
578            print!("{:>width$}", label, width = label_width);
579            for (i, s) in stats.iter().enumerate() {
580                let val = match row_idx {
581                    0 => s.count as f64,
582                    1 => s.mean,
583                    2 => s.std,
584                    3 => s.min,
585                    4 => s.q25,
586                    5 => s.median,
587                    6 => s.q75,
588                    7 => s.max,
589                    _ => unreachable!(),
590                };
591                print!("  {:>width$.6}", val, width = col_widths[i]);
592            }
593            println!();
594        }
595    }
596
597    /// Populate the `features` field from sparse storage.
598    ///
599    /// No-op if the dataset is already dense. After calling this,
600    /// `features[j][i]` is available as usual.
601    pub fn ensure_dense(&mut self) {
602        if let Storage::Sparse(csc) = &self.storage {
603            let n_cols = csc.n_cols();
604            let n_rows = csc.n_rows();
605            let mut features = vec![vec![0.0; n_rows]; n_cols];
606            for (j, feat_col) in features.iter_mut().enumerate() {
607                for (i, v) in csc.col(j).iter() {
608                    feat_col[i] = v;
609                }
610            }
611            self.features = features;
612            self.matrix = OnceLock::new();
613        }
614    }
615}
616
617/// Subset a CSC matrix by selecting specific row indices.
618///
619/// Returns a new CSC matrix with `indices.len()` rows, where row `k` in the
620/// output corresponds to row `indices[k]` in the input.
621///
622/// Uses `CscMatrix::from_dense` (column-major) to avoid a pre-existing
623/// dedup bug in `CscMatrix::from_triplets`.
624fn subset_csc(csc: &CscMatrix, indices: &[usize]) -> CscMatrix {
625    let n_new_rows = indices.len();
626    let n_cols = csc.n_cols();
627
628    // Build old→new row mapping.
629    let mut row_map = std::collections::HashMap::with_capacity(n_new_rows);
630    for (new_idx, &old_idx) in indices.iter().enumerate() {
631        row_map.insert(old_idx, new_idx);
632    }
633
634    // Build column-major dense vectors for the subset.
635    let mut cols: Vec<Vec<f64>> = vec![vec![0.0; n_new_rows]; n_cols];
636    for (j, col) in cols.iter_mut().enumerate() {
637        for (old_row, val) in csc.col(j).iter() {
638            if let Some(&new_row) = row_map.get(&old_row) {
639                col[new_row] = val;
640            }
641        }
642    }
643
644    CscMatrix::from_dense(&cols)
645}
646
647/// Compute descriptive statistics for a single column, filtering NaN values.
648fn compute_column_stats(name: &str, values: &[f64]) -> ColumnStats {
649    let mut sorted: Vec<f64> = values.iter().copied().filter(|v| v.is_finite()).collect();
650    sorted.sort_unstable_by(|a, b| a.total_cmp(b));
651
652    let count = sorted.len();
653    if count == 0 {
654        return ColumnStats {
655            name: name.to_string(),
656            count: 0,
657            mean: f64::NAN,
658            std: f64::NAN,
659            min: f64::NAN,
660            q25: f64::NAN,
661            median: f64::NAN,
662            q75: f64::NAN,
663            max: f64::NAN,
664        };
665    }
666
667    let sum: f64 = sorted.iter().sum();
668    let mean = sum / count as f64;
669
670    let std = if count <= 1 {
671        0.0
672    } else {
673        let var = sorted.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / (count - 1) as f64;
674        var.sqrt()
675    };
676
677    let min = sorted[0];
678    let max = sorted[count - 1];
679    let q25 = percentile(&sorted, 0.25);
680    let median = percentile(&sorted, 0.50);
681    let q75 = percentile(&sorted, 0.75);
682
683    ColumnStats {
684        name: name.to_string(),
685        count,
686        mean,
687        std,
688        min,
689        q25,
690        median,
691        q75,
692        max,
693    }
694}
695
/// Linear-interpolation percentile on a pre-sorted slice.
///
/// `p` is a fraction in `[0.0, 1.0]` (e.g. `0.25` for the 25th percentile).
/// Returns NaN for an empty slice (previously `n - 1` underflowed on
/// `usize` and panicked in debug builds); a single-element slice returns
/// that element.
fn percentile(sorted: &[f64], p: f64) -> f64 {
    let n = sorted.len();
    if n == 0 {
        return f64::NAN;
    }
    if n == 1 {
        return sorted[0];
    }
    let idx = p * (n - 1) as f64;
    let lo = idx.floor() as usize;
    let hi = lo + 1;
    let frac = idx - lo as f64;
    if hi >= n {
        sorted[lo]
    } else {
        sorted[lo] * (1.0 - frac) + sorted[hi] * frac
    }
}
712
#[cfg(feature = "csv")]
/// Parse a target column: try numeric, fall back to label encoding.
///
/// Returns `(encoded_values, Option<class_labels>)`: labels are `None` when
/// the whole column parsed as `f64`; otherwise values are label-encoded in
/// order of first appearance (a missing cell encodes the empty string).
fn parse_target_column(rows: &[Vec<String>], col_idx: usize) -> (Vec<f64>, Option<Vec<String>>) {
    // Fast path: the entire column is numeric.
    let numeric: Option<Vec<f64>> = rows
        .iter()
        .map(|row| row.get(col_idx).and_then(|s| s.parse::<f64>().ok()))
        .collect();
    if let Some(values) = numeric {
        return (values, None);
    }

    // Otherwise label-encode, assigning ids in order of first appearance.
    let mut labels: Vec<String> = Vec::new();
    let mut encoded = Vec::with_capacity(rows.len());
    for row in rows {
        let val = row.get(col_idx).map_or("", std::string::String::as_str);
        let idx = match labels.iter().position(|l| l == val) {
            Some(existing) => existing,
            None => {
                labels.push(val.to_string());
                labels.len() - 1
            }
        };
        encoded.push(idx as f64);
    }

    (encoded, Some(labels))
}
744
745#[cfg(test)]
746#[allow(clippy::float_cmp)]
747mod tests {
748    use super::*;
749
750    #[test]
751    fn test_dataset_new() {
752        let features = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
753        let target = vec![0.0, 1.0, 0.0];
754        let ds = Dataset::new(features, target, vec!["f1".into(), "f2".into()], "label");
755        assert_eq!(ds.n_samples(), 3);
756        assert_eq!(ds.n_features(), 2);
757        assert_eq!(ds.feature(0), &[1.0, 2.0, 3.0]);
758        assert_eq!(ds.sample(1), vec![2.0, 5.0]);
759    }
760
761    #[cfg(feature = "csv")]
762    #[test]
763    fn test_dataset_from_csv_reader() {
764        let csv = "f1,f2,target\n1.0,4.0,a\n2.0,5.0,b\n3.0,6.0,a\n";
765        let ds = Dataset::from_csv_reader(csv.as_bytes(), "target").unwrap();
766        assert_eq!(ds.n_samples(), 3);
767        assert_eq!(ds.n_features(), 2);
768        assert_eq!(ds.target, vec![0.0, 1.0, 0.0]);
769        assert_eq!(
770            ds.class_labels,
771            Some(vec!["a".to_string(), "b".to_string()])
772        );
773    }
774
775    #[test]
776    fn test_dataset_subset() {
777        let features = vec![vec![1.0, 2.0, 3.0, 4.0], vec![10.0, 20.0, 30.0, 40.0]];
778        let target = vec![0.0, 1.0, 0.0, 1.0];
779        let ds = Dataset::new(features, target, vec!["a".into(), "b".into()], "t");
780        let sub = ds.subset(&[0, 2]);
781        assert_eq!(sub.n_samples(), 2);
782        assert_eq!(sub.feature(0), &[1.0, 3.0]);
783        assert_eq!(sub.target, vec![0.0, 0.0]);
784    }
785
786    #[cfg(feature = "csv")]
787    #[test]
788    fn test_empty_csv() {
789        let csv = "f1,target\n";
790        let err = Dataset::from_csv_reader(csv.as_bytes(), "target");
791        assert!(err.is_err());
792    }
793
794    #[test]
795    fn test_n_classes() {
796        let ds = Dataset::new(
797            vec![vec![1.0, 2.0, 3.0]],
798            vec![0.0, 1.0, 2.0],
799            vec!["f".into()],
800            "t",
801        );
802        assert_eq!(ds.n_classes(), 3);
803    }
804
805    #[test]
806    fn test_matrix_accessor() {
807        let features = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
808        let ds = Dataset::new(features, vec![0.0, 1.0], vec!["a".into(), "b".into()], "t");
809        let mat = ds.matrix();
810        assert_eq!(mat.n_rows(), 2);
811        assert_eq!(mat.n_cols(), 2);
812        assert_eq!(mat.col(0), &[1.0, 2.0]);
813        assert_eq!(mat.col(1), &[3.0, 4.0]);
814    }
815
816    #[test]
817    fn test_from_matrix() {
818        let mat = DenseMatrix::from_col_major(vec![vec![1.0, 2.0], vec![3.0, 4.0]]).unwrap();
819        let ds = Dataset::from_matrix(mat, vec![0.0, 1.0], vec!["a".into(), "b".into()], "t");
820        assert_eq!(ds.n_samples(), 2);
821        assert_eq!(ds.n_features(), 2);
822        assert_eq!(ds.feature(0), &[1.0, 2.0]);
823        assert_eq!(ds.matrix().col(1), &[3.0, 4.0]);
824    }
825
826    // -------------------------------------------------------------------
827    // Sparse dataset tests
828    // -------------------------------------------------------------------
829
830    fn sample_csc() -> CscMatrix {
831        // 3 samples × 2 features (column-major):
832        //   col 0: [1.0, 0.0, 3.0]
833        //   col 1: [0.0, 2.0, 0.0]
834        CscMatrix::from_dense(&[vec![1.0, 0.0, 3.0], vec![0.0, 2.0, 0.0]])
835    }
836
837    #[test]
838    fn test_from_sparse_basic() {
839        let csc = sample_csc();
840        let ds = Dataset::from_sparse(csc, vec![0.0, 1.0, 0.0], vec!["a".into(), "b".into()], "t");
841        assert!(ds.is_sparse());
842        assert_eq!(ds.n_samples(), 3);
843        assert_eq!(ds.n_features(), 2);
844    }
845
846    #[test]
847    fn test_sparse_csc_accessor() {
848        let csc = sample_csc();
849        let ds = Dataset::from_sparse(csc, vec![0.0, 1.0, 0.0], vec!["a".into(), "b".into()], "t");
850        let csc_ref = ds.sparse_csc().expect("should have CSC");
851        assert_eq!(csc_ref.n_rows(), 3);
852        assert_eq!(csc_ref.n_cols(), 2);
853        assert_eq!(csc_ref.get(0, 0), 1.0);
854        assert_eq!(csc_ref.get(1, 1), 2.0);
855        assert_eq!(csc_ref.get(1, 0), 0.0);
856    }
857
858    #[test]
859    fn test_sparse_csr_conversion() {
860        let csc = sample_csc();
861        let ds = Dataset::from_sparse(csc, vec![0.0, 1.0, 0.0], vec!["a".into(), "b".into()], "t");
862        let csr = ds.sparse_csr().expect("should convert to CSR");
863        assert_eq!(csr.n_rows(), 3);
864        assert_eq!(csr.n_cols(), 2);
865        assert_eq!(csr.get(0, 0), 1.0);
866        assert_eq!(csr.get(2, 0), 3.0);
867        assert_eq!(csr.get(1, 1), 2.0);
868    }
869
870    #[test]
871    fn test_sparse_subset() {
872        let csc = sample_csc();
873        let ds = Dataset::from_sparse(csc, vec![0.0, 1.0, 2.0], vec!["a".into(), "b".into()], "t");
874        let sub = ds.subset(&[0, 2]);
875        assert!(sub.is_sparse());
876        assert_eq!(sub.n_samples(), 2);
877        assert_eq!(sub.n_features(), 2);
878        assert_eq!(sub.target, vec![0.0, 2.0]);
879        let csc_ref = sub.sparse_csc().unwrap();
880        assert_eq!(csc_ref.get(0, 0), 1.0); // row 0 of subset = original row 0
881        assert_eq!(csc_ref.get(1, 0), 3.0); // row 1 of subset = original row 2
882    }
883
884    #[test]
885    fn test_sparse_with_class_labels() {
886        let csc = sample_csc();
887        let ds = Dataset::from_sparse(csc, vec![0.0, 1.0, 0.0], vec!["a".into(), "b".into()], "t")
888            .with_class_labels(vec!["cat".into(), "dog".into()]);
889        assert!(ds.is_sparse());
890        assert_eq!(
891            ds.class_labels,
892            Some(vec!["cat".to_string(), "dog".to_string()])
893        );
894    }
895
896    #[test]
897    fn test_n_features_consistency() {
898        // Dense and sparse datasets with same data should report same n_features.
899        let dense_ds = Dataset::new(
900            vec![vec![1.0, 0.0, 3.0], vec![0.0, 2.0, 0.0]],
901            vec![0.0, 1.0, 0.0],
902            vec!["a".into(), "b".into()],
903            "t",
904        );
905        let csc = sample_csc();
906        let sparse_ds =
907            Dataset::from_sparse(csc, vec![0.0, 1.0, 0.0], vec!["a".into(), "b".into()], "t");
908        assert_eq!(dense_ds.n_features(), sparse_ds.n_features());
909    }
910
911    #[test]
912    fn test_ensure_dense() {
913        let csc = sample_csc();
914        let mut ds =
915            Dataset::from_sparse(csc, vec![0.0, 1.0, 0.0], vec!["a".into(), "b".into()], "t");
916        assert!(ds.features.is_empty());
917        ds.ensure_dense();
918        assert_eq!(ds.features.len(), 2);
919        assert_eq!(ds.features[0], vec![1.0, 0.0, 3.0]);
920        assert_eq!(ds.features[1], vec![0.0, 2.0, 0.0]);
921    }
922
923    #[test]
924    fn test_dense_not_sparse() {
925        let ds = Dataset::new(vec![vec![1.0, 2.0]], vec![0.0, 1.0], vec!["x".into()], "y");
926        assert!(!ds.is_sparse());
927        assert!(ds.sparse_csc().is_none());
928        assert!(ds.sparse_csr().is_none());
929    }
930
931    #[test]
932    fn test_matrix_lazy_rebuild_after_invalidate() {
933        let features = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
934        let mut ds = Dataset::new(features, vec![0.0, 1.0], vec!["a".into(), "b".into()], "t");
935
936        // Matrix is available after construction.
937        assert_eq!(ds.matrix().col(0), &[1.0, 2.0]);
938
939        // Invalidate.
940        ds.invalidate_matrix();
941
942        // matrix() should lazily rebuild — no panic.
943        assert_eq!(ds.matrix().col(0), &[1.0, 2.0]);
944        assert_eq!(ds.matrix().col(1), &[3.0, 4.0]);
945    }
946
947    #[test]
948    fn test_describe_summary() {
949        let features = vec![vec![1.0, 2.0, 3.0, 4.0], vec![10.0, 20.0, 30.0, 40.0]];
950        let target = vec![0.0, 1.0, 0.0, 1.0];
951        let ds = Dataset::new(features, target, vec!["a".into(), "b".into()], "t");
952
953        let stats = ds.summary();
954        assert_eq!(stats.len(), 3); // 2 features + 1 target
955
956        // Feature "a": [1, 2, 3, 4]
957        assert_eq!(stats[0].name, "a");
958        assert_eq!(stats[0].count, 4);
959        assert!((stats[0].mean - 2.5).abs() < 1e-10);
960        assert!((stats[0].min - 1.0).abs() < 1e-10);
961        assert!((stats[0].max - 4.0).abs() < 1e-10);
962
963        // Feature "b": [10, 20, 30, 40]
964        assert_eq!(stats[1].name, "b");
965        assert_eq!(stats[1].count, 4);
966        assert!((stats[1].mean - 25.0).abs() < 1e-10);
967        assert!((stats[1].min - 10.0).abs() < 1e-10);
968        assert!((stats[1].max - 40.0).abs() < 1e-10);
969
970        // Target "t": [0, 1, 0, 1]
971        assert_eq!(stats[2].name, "t");
972        assert_eq!(stats[2].count, 4);
973        assert!((stats[2].mean - 0.5).abs() < 1e-10);
974
975        // describe() should not panic.
976        ds.describe();
977    }
978
979    #[test]
980    fn test_matrix_lazy_rebuild_reflects_feature_mutation() {
981        let features = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
982        let mut ds = Dataset::new(features, vec![0.0, 1.0], vec!["a".into(), "b".into()], "t");
983
984        // Mutate features and invalidate.
985        ds.features[0][0] = 99.0;
986        ds.invalidate_matrix();
987
988        // Lazy rebuild should reflect the mutation.
989        assert_eq!(ds.matrix().col(0), &[99.0, 2.0]);
990    }
991}