// SPDX-License-Identifier: MIT OR Apache-2.0
//! One-hot encoding for categorical features.
//!
//! Expands integer-encoded categorical columns into binary indicator
//! columns.  Supports [`DropStrategy`] (avoid multicollinearity) and
//! [`UnknownStrategy`] (handle unseen categories at transform time).
//!
//! # Example
//!
//! ```ignore
//! use scry_learn::prelude::*;
//!
//! // Assume feature 0 is label-encoded colour: 0=red, 1=green, 2=blue
//! let mut enc = OneHotEncoder::new(vec![0])
//!     .drop(DropStrategy::First)
//!     .handle_unknown(UnknownStrategy::Ignore);
//! enc.fit_transform(&mut dataset)?;
//! ```

20use crate::dataset::Dataset;
21use crate::error::{Result, ScryLearnError};
22use crate::preprocess::Transformer;
23
// ── Public enums ──────────────────────────────────────────────────

/// Strategy for dropping one-hot columns to avoid multicollinearity.
///
/// Dropping a column removes the linear dependence between the
/// indicator columns of one feature (they otherwise sum to 1).
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum DropStrategy {
    /// Keep all categories (default).
    #[default]
    None,
    /// Drop the first (lowest-valued) category from each encoded feature.
    First,
    /// Drop the first category only for binary (2-category) features.
    IfBinary,
}

/// Strategy for handling categories seen at transform time but not at fit time.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum UnknownStrategy {
    /// Raise an error (default).
    Error,
    /// Encode the unknown value as an all-zeros row.
    Ignore,
}

// ── Struct ────────────────────────────────────────────────────────

/// One-hot encoder for integer-encoded categorical features.
///
/// Replaces each selected column with `n_categories` binary columns
/// (minus any dropped by the [`DropStrategy`]).  Non-selected columns
/// pass through untouched.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct OneHotEncoder {
    /// Column indices (into the dataset's features) to one-hot encode.
    feature_indices: Vec<usize>,
    /// Column-drop policy applied per encoded feature.
    drop_strategy: DropStrategy,
    /// Policy for values unseen at fit time.
    unknown_strategy: UnknownStrategy,
    // — fitted state —
    /// `categories[i]` = sorted unique values of `feature_indices[i]`.
    categories: Vec<Vec<f64>>,
    /// Original feature names captured at fit time.
    orig_feature_names: Vec<String>,
    /// Set by a successful `fit`; guards `transform`/`inverse_transform`.
    fitted: bool,
}

74// ── Builder ───────────────────────────────────────────────────────
75
76impl OneHotEncoder {
77    /// Create a new encoder for the given feature column indices.
78    pub fn new(feature_indices: Vec<usize>) -> Self {
79        Self {
80            feature_indices,
81            drop_strategy: DropStrategy::None,
82            unknown_strategy: UnknownStrategy::Error,
83            categories: Vec::new(),
84            orig_feature_names: Vec::new(),
85            fitted: false,
86        }
87    }
88
89    /// Set the drop strategy.
90    pub fn drop(mut self, strategy: DropStrategy) -> Self {
91        self.drop_strategy = strategy;
92        self
93    }
94
95    /// Set the unknown-category strategy.
96    pub fn handle_unknown(mut self, strategy: UnknownStrategy) -> Self {
97        self.unknown_strategy = strategy;
98        self
99    }
100
101    // ── Accessors ─────────────────────────────────────────────────
102
103    /// Learned categories for each encoded feature.
104    pub fn categories(&self) -> &[Vec<f64>] {
105        &self.categories
106    }
107
108    /// Compute the output feature names that `transform` would produce.
109    pub fn get_feature_names(&self) -> Vec<String> {
110        if !self.fitted || self.orig_feature_names.is_empty() {
111            return Vec::new();
112        }
113        let encoded_set: std::collections::HashSet<usize> =
114            self.feature_indices.iter().copied().collect();
115        let mut names = Vec::new();
116        for (j, orig_name) in self.orig_feature_names.iter().enumerate() {
117            if encoded_set.contains(&j) {
118                let cat_idx = self
119                    .feature_indices
120                    .iter()
121                    .position(|&fi| fi == j)
122                    .expect("encoded_set built from feature_indices");
123                let cats = &self.categories[cat_idx];
124                let skip = self.n_drop(cat_idx);
125                for (ci, &cat_val) in cats.iter().enumerate() {
126                    if ci < skip {
127                        continue;
128                    }
129                    names.push(format!("{}_{}", orig_name, cat_val as i64));
130                }
131            } else {
132                names.push(orig_name.clone());
133            }
134        }
135        names
136    }
137}
138
139// ── Helpers ───────────────────────────────────────────────────────
140
141impl OneHotEncoder {
142    /// Should we drop a column for feature `idx` (index into `categories`)?
143    /// Returns the number of categories to *skip* from the front (0 or 1).
144    fn n_drop(&self, cat_idx: usize) -> usize {
145        match self.drop_strategy {
146            DropStrategy::None => 0,
147            DropStrategy::First => 1,
148            DropStrategy::IfBinary => usize::from(self.categories[cat_idx].len() == 2),
149        }
150    }
151}
152
153// ── Transformer impl ─────────────────────────────────────────────
154
155impl Transformer for OneHotEncoder {
156    fn fit(&mut self, data: &Dataset) -> Result<()> {
157        if data.n_samples() == 0 {
158            return Err(ScryLearnError::EmptyDataset);
159        }
160        for &idx in &self.feature_indices {
161            if idx >= data.n_features() {
162                return Err(ScryLearnError::InvalidParameter(format!(
163                    "feature index {idx} out of range (dataset has {} features)",
164                    data.n_features()
165                )));
166            }
167        }
168
169        self.categories.clear();
170        self.orig_feature_names.clone_from(&data.feature_names);
171        for &idx in &self.feature_indices {
172            let mut unique: Vec<f64> = data.features[idx].clone();
173            unique.sort_by(|a, b| a.total_cmp(b));
174            unique.dedup();
175            self.categories.push(unique);
176        }
177        self.fitted = true;
178        Ok(())
179    }
180
181    fn transform(&self, data: &mut Dataset) -> Result<()> {
182        if !self.fitted {
183            return Err(ScryLearnError::NotFitted);
184        }
185        let n = data.n_samples();
186
187        // Build the set of encoded column indices for fast lookup.
188        let encoded_set: std::collections::HashSet<usize> =
189            self.feature_indices.iter().copied().collect();
190
191        let mut new_features: Vec<Vec<f64>> = Vec::new();
192        let mut new_names: Vec<String> = Vec::new();
193
194        for j in 0..data.n_features() {
195            if encoded_set.contains(&j) {
196                // Find which cat_idx this corresponds to.
197                let cat_idx = self
198                    .feature_indices
199                    .iter()
200                    .position(|&fi| fi == j)
201                    .ok_or(ScryLearnError::InvalidFeatureIndex(j))?;
202                let cats = &self.categories[cat_idx];
203                let skip = self.n_drop(cat_idx);
204                let orig_name = &data.feature_names[j];
205
206                for (ci, &cat_val) in cats.iter().enumerate() {
207                    if ci < skip {
208                        continue;
209                    }
210                    let mut col = Vec::with_capacity(n);
211                    for s in 0..n {
212                        let val = data.features[j][s];
213                        if (val - cat_val).abs() < 1e-10 {
214                            col.push(1.0);
215                        } else if cats.iter().any(|&c| (val - c).abs() < 1e-10) {
216                            col.push(0.0);
217                        } else {
218                            // Unknown category.
219                            match self.unknown_strategy {
220                                UnknownStrategy::Error => {
221                                    return Err(ScryLearnError::InvalidParameter(format!(
222                                        "unknown category {val} in feature '{orig_name}'"
223                                    )));
224                                }
225                                UnknownStrategy::Ignore => {
226                                    col.push(0.0);
227                                }
228                            }
229                        }
230                    }
231                    new_features.push(col);
232                    new_names.push(format!("{}_{}", orig_name, cat_val as i64));
233                }
234            } else {
235                // Pass through.
236                new_features.push(data.features[j].clone());
237                new_names.push(data.feature_names[j].clone());
238            }
239        }
240
241        data.features = new_features;
242        data.feature_names = new_names;
243        data.sync_matrix();
244        Ok(())
245    }
246
247    fn inverse_transform(&self, data: &mut Dataset) -> Result<()> {
248        if !self.fitted {
249            return Err(ScryLearnError::NotFitted);
250        }
251        let n = data.n_samples();
252
253        // We need to identify the one-hot column groups and collapse them back.
254        // Walk through features and feature_names to reconstruct.
255        let mut new_features: Vec<Vec<f64>> = Vec::new();
256        let mut new_names: Vec<String> = Vec::new();
257
258        let mut j = 0;
259        let mut cat_idx = 0;
260
261        // Build a plan: for each original feature, was it encoded?
262        // We reconstruct by scanning through the current features.
263        // One-hot columns are named "<orig>_<val>".
264        // We detect groups by checking consecutive columns whose names
265        // share a prefix matching a known encoded feature.
266
267        // For simplicity, use the category counts directly:
268        while j < data.n_features() {
269            if cat_idx < self.feature_indices.len() {
270                let cats = &self.categories[cat_idx];
271                let skip = self.n_drop(cat_idx);
272                let n_cols = cats.len() - skip;
273
274                // Check if the current block of columns looks like one-hot.
275                if j + n_cols <= data.n_features() {
276                    // Try to extract the original feature name from the first column name.
277                    let first_name = &data.feature_names[j];
278                    let prefix = first_name
279                        .rfind('_')
280                        .map_or(first_name.as_str(), |pos| &first_name[..pos]);
281
282                    // Collapse: for each sample, find which column is 1.
283                    let mut col = Vec::with_capacity(n);
284                    for s in 0..n {
285                        let mut found = false;
286                        for (ci, &cat_val) in cats.iter().enumerate().skip(skip) {
287                            let col_idx = j + ci - skip;
288                            if data.features[col_idx][s] > 0.5 {
289                                col.push(cat_val);
290                                found = true;
291                                break;
292                            }
293                        }
294                        if !found {
295                            // If drop was used, the dropped category is the one
296                            // where all columns are zero.
297                            if skip > 0 {
298                                col.push(cats[0]);
299                            } else {
300                                col.push(f64::NAN);
301                            }
302                        }
303                    }
304                    new_features.push(col);
305                    new_names.push(prefix.to_string());
306                    j += n_cols;
307                    cat_idx += 1;
308                    continue;
309                }
310            }
311
312            // Pass through.
313            new_features.push(data.features[j].clone());
314            new_names.push(data.feature_names[j].clone());
315            j += 1;
316        }
317
318        data.features = new_features;
319        data.feature_names = new_names;
320        data.sync_matrix();
321        Ok(())
322    }
323}
324
// ── Tests ─────────────────────────────────────────────────────────

#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    /// Two-column fixture: feature 0 is a 3-category colour code
    /// (0=red, 1=green, 2=blue), feature 1 is plain numeric.
    fn color_dataset() -> Dataset {
        Dataset::new(
            vec![
                vec![0.0, 1.0, 2.0, 0.0, 1.0, 2.0],
                vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            ],
            vec![0.0, 0.0, 1.0, 1.0, 0.0, 1.0],
            vec!["color".into(), "value".into()],
            "target",
        )
    }

    #[test]
    fn onehot_basic_encoding() {
        let mut ds = color_dataset();
        OneHotEncoder::new(vec![0]).fit_transform(&mut ds).unwrap();

        // Three indicator columns plus one passthrough.
        assert_eq!(ds.n_features(), 4);
        assert_eq!(
            ds.feature_names,
            vec!["color_0", "color_1", "color_2", "value"]
        );

        // Sample 0 has color=0 → [1, 0, 0].
        assert_eq!(
            [ds.features[0][0], ds.features[1][0], ds.features[2][0]],
            [1.0, 0.0, 0.0]
        );
        // Sample 2 has color=2 → [0, 0, 1].
        assert_eq!(
            [ds.features[0][2], ds.features[1][2], ds.features[2][2]],
            [0.0, 0.0, 1.0]
        );
    }

    #[test]
    fn onehot_drop_first() {
        let mut ds = color_dataset();
        OneHotEncoder::new(vec![0])
            .drop(DropStrategy::First)
            .fit_transform(&mut ds)
            .unwrap();

        // First category dropped: 2 indicators + 1 passthrough.
        assert_eq!(ds.n_features(), 3);
        assert_eq!(ds.feature_names[..2], ["color_1", "color_2"]);
    }

    #[test]
    fn onehot_drop_if_binary() {
        // Two-category feature: the first category gets dropped.
        let mut ds = Dataset::new(
            vec![vec![0.0, 1.0, 0.0, 1.0], vec![10.0, 20.0, 30.0, 40.0]],
            vec![0.0; 4],
            vec!["binary".into(), "num".into()],
            "y",
        );
        OneHotEncoder::new(vec![0])
            .drop(DropStrategy::IfBinary)
            .fit_transform(&mut ds)
            .unwrap();
        assert_eq!(ds.n_features(), 2);
        assert_eq!(ds.feature_names[0], "binary_1");

        // Three-category feature: nothing gets dropped.
        let mut ds3 = color_dataset();
        OneHotEncoder::new(vec![0])
            .drop(DropStrategy::IfBinary)
            .fit_transform(&mut ds3)
            .unwrap();
        assert_eq!(ds3.n_features(), 4); // 3 one-hot + 1 passthrough
    }

    #[test]
    fn onehot_unknown_error() {
        let mut ds = color_dataset();
        let mut enc = OneHotEncoder::new(vec![0]);
        enc.fit(&ds).unwrap();

        // A category never seen at fit time must be rejected.
        ds.features[0][0] = 99.0;
        assert!(enc.transform(&mut ds).is_err());
    }

    #[test]
    fn onehot_unknown_ignore() {
        let mut ds = color_dataset();
        let mut enc = OneHotEncoder::new(vec![0]).handle_unknown(UnknownStrategy::Ignore);
        enc.fit(&ds).unwrap();

        ds.features[0][0] = 99.0;
        enc.transform(&mut ds).unwrap();

        // An unknown category encodes as an all-zeros row.
        for col in 0..3 {
            assert_eq!(ds.features[col][0], 0.0);
        }
    }

    #[test]
    fn onehot_roundtrip_inverse() {
        let original = color_dataset();
        let mut ds = original.clone();
        let mut enc = OneHotEncoder::new(vec![0]);
        enc.fit_transform(&mut ds).unwrap();
        enc.inverse_transform(&mut ds).unwrap();

        assert_eq!(ds.n_features(), 2);
        for (i, (got, want)) in ds.features[0]
            .iter()
            .zip(&original.features[0])
            .enumerate()
        {
            assert!(
                (got - want).abs() < 1e-10,
                "roundtrip mismatch at sample {i}"
            );
        }
    }

    #[test]
    fn onehot_feature_names() {
        let mut ds = color_dataset();
        let mut enc = OneHotEncoder::new(vec![0]);
        enc.fit_transform(&mut ds).unwrap();

        assert_eq!(
            enc.get_feature_names(),
            vec!["color_0", "color_1", "color_2", "value"]
        );
    }

    #[test]
    fn onehot_not_fitted_error() {
        let mut ds = color_dataset();
        assert!(OneHotEncoder::new(vec![0]).transform(&mut ds).is_err());
    }

    #[test]
    fn onehot_multiple_features() {
        // Encode two features simultaneously.
        let mut ds = Dataset::new(
            vec![
                vec![0.0, 1.0, 0.0, 1.0], // binary categorical
                vec![0.0, 1.0, 2.0, 0.0], // 3-category categorical
                vec![5.0, 6.0, 7.0, 8.0], // numeric passthrough
            ],
            vec![0.0; 4],
            vec!["a".into(), "b".into(), "num".into()],
            "y",
        );
        OneHotEncoder::new(vec![0, 1]).fit_transform(&mut ds).unwrap();

        // 2 one-hot from "a" + 3 one-hot from "b" + 1 passthrough = 6.
        assert_eq!(ds.n_features(), 6);
        assert_eq!(
            ds.feature_names,
            vec!["a_0", "a_1", "b_0", "b_1", "b_2", "num"]
        );
    }
}
490}