Skip to main content

scry_learn/preprocess/
column_transformer.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
//! Column-based transformer composition.
//!
//! [`ColumnTransformer`] applies different transformers to different subsets
//! of feature columns and concatenates the results in the order the steps
//! were added; columns not assigned to any step are dropped.
//!
//! # Example
//!
//! ```ignore
//! use scry_learn::preprocess::{ColumnTransformer, StandardScaler, MinMaxScaler};
//!
//! let ct = ColumnTransformer::new()
//!     .add(&[0, 1], StandardScaler::new())
//!     .add(&[2, 3], MinMaxScaler::new());
//! ```
16
17use crate::dataset::Dataset;
18use crate::error::{Result, ScryLearnError};
19use crate::preprocess::Transformer;
20
21/// Internal trait-object wrapper so we can store heterogeneous transformers.
22trait BoxedTransformer {
23    fn fit(&mut self, data: &Dataset) -> Result<()>;
24    fn transform(&self, data: &mut Dataset) -> Result<()>;
25}
26
27impl<T: Transformer> BoxedTransformer for T {
28    fn fit(&mut self, data: &Dataset) -> Result<()> {
29        Transformer::fit(self, data)
30    }
31    fn transform(&self, data: &mut Dataset) -> Result<()> {
32        Transformer::transform(self, data)
33    }
34}
35
36/// A step within the column transformer: column indices + transformer.
37struct TransformerStep {
38    columns: Vec<usize>,
39    transformer: Box<dyn BoxedTransformer>,
40}
41
42/// Apply different transformers to different column subsets, then
43/// concatenate all transformed outputs.
44///
45/// # Builder API
46///
47/// ```ignore
48/// let ct = ColumnTransformer::new()
49///     .add(&[0, 1], StandardScaler::new())
50///     .add(&[2, 3], MinMaxScaler::new());
51/// ct.fit_transform(&mut ds)?;
52/// ```
53#[non_exhaustive]
54pub struct ColumnTransformer {
55    steps: Vec<TransformerStep>,
56    fitted: bool,
57}
58
59impl ColumnTransformer {
60    /// Create an empty column transformer.
61    pub fn new() -> Self {
62        Self {
63            steps: Vec::new(),
64            fitted: false,
65        }
66    }
67
68    /// Add a transformer to be applied to the given column indices.
69    pub fn add<T: Transformer + 'static>(mut self, columns: &[usize], transformer: T) -> Self {
70        self.steps.push(TransformerStep {
71            columns: columns.to_vec(),
72            transformer: Box::new(transformer),
73        });
74        self
75    }
76}
77
78impl Default for ColumnTransformer {
79    fn default() -> Self {
80        Self::new()
81    }
82}
83
84/// Extract a sub-dataset containing only the specified feature columns.
85fn extract_columns(data: &Dataset, cols: &[usize]) -> Dataset {
86    let features: Vec<Vec<f64>> = cols.iter().map(|&c| data.features[c].clone()).collect();
87    let names: Vec<String> = cols
88        .iter()
89        .map(|&c| data.feature_names[c].clone())
90        .collect();
91    Dataset::new(features, data.target.clone(), names, &data.target_name)
92}
93
94impl Transformer for ColumnTransformer {
95    fn fit(&mut self, data: &Dataset) -> Result<()> {
96        if data.n_samples() == 0 {
97            return Err(ScryLearnError::EmptyDataset);
98        }
99        for step in &mut self.steps {
100            // Validate column indices.
101            for &c in &step.columns {
102                if c >= data.n_features() {
103                    return Err(ScryLearnError::InvalidColumn(format!(
104                        "column index {c} out of range (dataset has {} features)",
105                        data.n_features()
106                    )));
107                }
108            }
109            let sub = extract_columns(data, &step.columns);
110            step.transformer.fit(&sub)?;
111        }
112        self.fitted = true;
113        Ok(())
114    }
115
116    fn transform(&self, data: &mut Dataset) -> Result<()> {
117        if !self.fitted {
118            return Err(ScryLearnError::NotFitted);
119        }
120
121        // Transform each column subset independently, collect results.
122        let mut result_cols: Vec<Vec<f64>> = Vec::new();
123        let mut result_names: Vec<String> = Vec::new();
124
125        for step in &self.steps {
126            let mut sub = extract_columns(data, &step.columns);
127            step.transformer.transform(&mut sub)?;
128            for (col, name) in sub.features.into_iter().zip(sub.feature_names) {
129                result_cols.push(col);
130                result_names.push(name);
131            }
132        }
133
134        // Replace the dataset's features with the concatenated result.
135        data.features = result_cols;
136        data.feature_names = result_names;
137        data.sync_matrix();
138
139        Ok(())
140    }
141
142    fn inverse_transform(&self, _data: &mut Dataset) -> Result<()> {
143        Err(ScryLearnError::InvalidParameter(
144            "ColumnTransformer is not invertible".into(),
145        ))
146    }
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use crate::preprocess::{MinMaxScaler, StandardScaler};

    #[test]
    fn test_column_transformer_basic() {
        // 4 features, apply StandardScaler to [0,1], MinMaxScaler to [2,3]
        let mut ds = Dataset::new(
            vec![
                vec![1.0, 2.0, 3.0, 4.0, 5.0],           // col 0
                vec![10.0, 20.0, 30.0, 40.0, 50.0],      // col 1
                vec![100.0, 200.0, 300.0, 400.0, 500.0], // col 2
                vec![5.0, 10.0, 15.0, 20.0, 25.0],       // col 3
            ],
            vec![0.0; 5],
            vec!["a".into(), "b".into(), "c".into(), "d".into()],
            "y",
        );

        let mut ct = ColumnTransformer::new()
            .add(&[0, 1], StandardScaler::new())
            .add(&[2, 3], MinMaxScaler::new());

        ct.fit_transform(&mut ds).unwrap();

        assert_eq!(ds.n_features(), 4);

        // StandardScaler'd columns: mean ≈ 0
        let mean_a: f64 = ds.features[0].iter().sum::<f64>() / 5.0;
        assert!(
            mean_a.abs() < 1e-10,
            "col 0 should be zero-mean, got {mean_a}"
        );

        let mean_b: f64 = ds.features[1].iter().sum::<f64>() / 5.0;
        assert!(
            mean_b.abs() < 1e-10,
            "col 1 should be zero-mean, got {mean_b}"
        );

        // MinMaxScaler'd columns: min=0, max=1
        assert!(ds.features[2][0].abs() < 1e-10, "col 2 min should be 0");
        assert!(
            (ds.features[2][4] - 1.0).abs() < 1e-10,
            "col 2 max should be 1"
        );
        assert!(ds.features[3][0].abs() < 1e-10, "col 3 min should be 0");
        assert!(
            (ds.features[3][4] - 1.0).abs() < 1e-10,
            "col 3 max should be 1"
        );
    }

    #[test]
    fn test_column_transformer_not_fitted() {
        let ct = ColumnTransformer::new().add(&[0], StandardScaler::new());
        let mut ds = Dataset::new(vec![vec![1.0, 2.0]], vec![0.0; 2], vec!["x".into()], "y");
        assert!(Transformer::transform(&ct, &mut ds).is_err());
    }

    #[test]
    fn test_column_transformer_invalid_column() {
        let mut ct = ColumnTransformer::new().add(&[99], StandardScaler::new());
        let ds = Dataset::new(vec![vec![1.0, 2.0]], vec![0.0; 2], vec!["x".into()], "y");
        assert!(Transformer::fit(&mut ct, &ds).is_err());
    }

    #[test]
    fn test_column_transformer_drops_unselected_columns() {
        let mut ds = Dataset::new(
            vec![
                vec![1.0, 2.0, 3.0],
                vec![4.0, 5.0, 6.0],
                vec![7.0, 8.0, 9.0],
            ],
            vec![0.0; 3],
            vec!["a".into(), "b".into(), "c".into()],
            "y",
        );
        let mut ct = ColumnTransformer::new().add(&[2], MinMaxScaler::new());
        ct.fit_transform(&mut ds).unwrap();

        // Only the selected column survives; [7, 8, 9] scales to [0, 0.5, 1].
        assert_eq!(ds.n_features(), 1);
        assert_eq!(ds.features[0].len(), 3);
        for (got, want) in ds.features[0].iter().zip([0.0, 0.5, 1.0]) {
            assert!((got - want).abs() < 1e-10, "expected {want}, got {got}");
        }
    }

    #[test]
    fn test_column_transformer_output_follows_step_order() {
        // Column 0 scales to [0, 1/3, 1]; column 1 scales to [0, 2/3, 1].
        let mut ds = Dataset::new(
            vec![vec![1.0, 2.0, 4.0], vec![10.0, 30.0, 40.0]],
            vec![0.0; 3],
            vec!["a".into(), "b".into()],
            "y",
        );
        let mut ct = ColumnTransformer::new()
            .add(&[1], MinMaxScaler::new())
            .add(&[0], MinMaxScaler::new());
        ct.fit_transform(&mut ds).unwrap();

        assert_eq!(ds.n_features(), 2);
        assert!((ds.features[0][1] - 2.0 / 3.0).abs() < 1e-10);
        assert!((ds.features[1][1] - 1.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_column_transformer_not_invertible() {
        let mut ct = ColumnTransformer::new().add(&[0], StandardScaler::new());
        let mut ds = Dataset::new(vec![vec![1.0, 2.0]], vec![0.0; 2], vec!["x".into()], "y");
        Transformer::fit(&mut ct, &ds).unwrap();
        assert!(Transformer::inverse_transform(&ct, &mut ds).is_err());
    }
}