scry_learn/preprocess/
column_transformer.rs1use crate::dataset::Dataset;
18use crate::error::{Result, ScryLearnError};
19use crate::preprocess::Transformer;
20
21trait BoxedTransformer {
23 fn fit(&mut self, data: &Dataset) -> Result<()>;
24 fn transform(&self, data: &mut Dataset) -> Result<()>;
25}
26
27impl<T: Transformer> BoxedTransformer for T {
28 fn fit(&mut self, data: &Dataset) -> Result<()> {
29 Transformer::fit(self, data)
30 }
31 fn transform(&self, data: &mut Dataset) -> Result<()> {
32 Transformer::transform(self, data)
33 }
34}
35
36struct TransformerStep {
38 columns: Vec<usize>,
39 transformer: Box<dyn BoxedTransformer>,
40}
41
42#[non_exhaustive]
54pub struct ColumnTransformer {
55 steps: Vec<TransformerStep>,
56 fitted: bool,
57}
58
59impl ColumnTransformer {
60 pub fn new() -> Self {
62 Self {
63 steps: Vec::new(),
64 fitted: false,
65 }
66 }
67
68 pub fn add<T: Transformer + 'static>(mut self, columns: &[usize], transformer: T) -> Self {
70 self.steps.push(TransformerStep {
71 columns: columns.to_vec(),
72 transformer: Box::new(transformer),
73 });
74 self
75 }
76}
77
78impl Default for ColumnTransformer {
79 fn default() -> Self {
80 Self::new()
81 }
82}
83
84fn extract_columns(data: &Dataset, cols: &[usize]) -> Dataset {
86 let features: Vec<Vec<f64>> = cols.iter().map(|&c| data.features[c].clone()).collect();
87 let names: Vec<String> = cols
88 .iter()
89 .map(|&c| data.feature_names[c].clone())
90 .collect();
91 Dataset::new(features, data.target.clone(), names, &data.target_name)
92}
93
94impl Transformer for ColumnTransformer {
95 fn fit(&mut self, data: &Dataset) -> Result<()> {
96 if data.n_samples() == 0 {
97 return Err(ScryLearnError::EmptyDataset);
98 }
99 for step in &mut self.steps {
100 for &c in &step.columns {
102 if c >= data.n_features() {
103 return Err(ScryLearnError::InvalidColumn(format!(
104 "column index {c} out of range (dataset has {} features)",
105 data.n_features()
106 )));
107 }
108 }
109 let sub = extract_columns(data, &step.columns);
110 step.transformer.fit(&sub)?;
111 }
112 self.fitted = true;
113 Ok(())
114 }
115
116 fn transform(&self, data: &mut Dataset) -> Result<()> {
117 if !self.fitted {
118 return Err(ScryLearnError::NotFitted);
119 }
120
121 let mut result_cols: Vec<Vec<f64>> = Vec::new();
123 let mut result_names: Vec<String> = Vec::new();
124
125 for step in &self.steps {
126 let mut sub = extract_columns(data, &step.columns);
127 step.transformer.transform(&mut sub)?;
128 for (col, name) in sub.features.into_iter().zip(sub.feature_names) {
129 result_cols.push(col);
130 result_names.push(name);
131 }
132 }
133
134 data.features = result_cols;
136 data.feature_names = result_names;
137 data.sync_matrix();
138
139 Ok(())
140 }
141
142 fn inverse_transform(&self, _data: &mut Dataset) -> Result<()> {
143 Err(ScryLearnError::InvalidParameter(
144 "ColumnTransformer is not invertible".into(),
145 ))
146 }
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use crate::preprocess::{MinMaxScaler, StandardScaler};

    /// Arithmetic mean of a column, assuming the five samples used by
    /// `four_feature_dataset`.
    fn five_sample_mean(column: &[f64]) -> f64 {
        column.iter().sum::<f64>() / 5.0
    }

    /// Four increasing feature columns ("a".."d") with five samples each.
    fn four_feature_dataset() -> Dataset {
        Dataset::new(
            vec![
                vec![1.0, 2.0, 3.0, 4.0, 5.0],
                vec![10.0, 20.0, 30.0, 40.0, 50.0],
                vec![100.0, 200.0, 300.0, 400.0, 500.0],
                vec![5.0, 10.0, 15.0, 20.0, 25.0],
            ],
            vec![0.0; 5],
            vec!["a".into(), "b".into(), "c".into(), "d".into()],
            "y",
        )
    }

    /// A single feature "x" with two samples.
    fn one_feature_dataset() -> Dataset {
        Dataset::new(vec![vec![1.0, 2.0]], vec![0.0; 2], vec!["x".into()], "y")
    }

    #[test]
    fn test_column_transformer_basic() {
        let mut ct = ColumnTransformer::new()
            .add(&[0, 1], StandardScaler::new())
            .add(&[2, 3], MinMaxScaler::new());
        let mut ds = four_feature_dataset();

        ct.fit_transform(&mut ds).unwrap();
        assert_eq!(ds.n_features(), 4);

        // Columns 0 and 1 were standardized.
        let mean_a = five_sample_mean(&ds.features[0]);
        assert!(mean_a.abs() < 1e-10, "col 0 should be zero-mean, got {mean_a}");
        let mean_b = five_sample_mean(&ds.features[1]);
        assert!(mean_b.abs() < 1e-10, "col 1 should be zero-mean, got {mean_b}");

        // Columns 2 and 3 were min-max scaled. Their inputs are increasing, so
        // the first sample is the minimum and the last is the maximum.
        let col_c = &ds.features[2];
        assert!(col_c[0].abs() < 1e-10, "col 2 min should be 0");
        assert!((col_c[4] - 1.0).abs() < 1e-10, "col 2 max should be 1");
        let col_d = &ds.features[3];
        assert!(col_d[0].abs() < 1e-10, "col 3 min should be 0");
        assert!((col_d[4] - 1.0).abs() < 1e-10, "col 3 max should be 1");
    }

    #[test]
    fn test_column_transformer_not_fitted() {
        let ct = ColumnTransformer::new().add(&[0], StandardScaler::new());
        let mut ds = one_feature_dataset();
        let outcome = Transformer::transform(&ct, &mut ds);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_column_transformer_invalid_column() {
        let mut ct = ColumnTransformer::new().add(&[99], StandardScaler::new());
        let ds = one_feature_dataset();
        let outcome = Transformer::fit(&mut ct, &ds);
        assert!(outcome.is_err());
    }
}