scirs2_transform/pipeline/
mod.rs1use ndarray::{Array2, ArrayBase, Data, Ix2};
11use num_traits::{Float, NumCast};
12use std::any::Any;
13
14use crate::error::{Result, TransformError};
15
16pub trait Transformer: Send + Sync {
18 fn fit(&mut self, x: &Array2<f64>) -> Result<()>;
20
21 fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>>;
23
24 fn fit_transform(&mut self, x: &Array2<f64>) -> Result<Array2<f64>> {
26 self.fit(x)?;
27 self.transform(x)
28 }
29
30 fn clone_box(&self) -> Box<dyn Transformer>;
32
33 fn as_any(&self) -> &dyn Any;
35
36 fn as_any_mut(&mut self) -> &mut dyn Any;
38}
39
40pub struct Pipeline {
42 steps: Vec<(String, Box<dyn Transformer>)>,
44 fitted: bool,
46}
47
48impl Pipeline {
49 pub fn new() -> Self {
51 Pipeline {
52 steps: Vec::new(),
53 fitted: false,
54 }
55 }
56
57 pub fn add_step(mut self, name: impl Into<String>, transformer: Box<dyn Transformer>) -> Self {
66 self.steps.push((name.into(), transformer));
67 self
68 }
69
70 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
78 where
79 S: Data,
80 S::Elem: Float + NumCast,
81 {
82 let mut x_transformed = x.mapv(|x| num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0));
83
84 for (name, transformer) in &mut self.steps {
85 transformer.fit(&x_transformed).map_err(|e| {
86 TransformError::TransformationError(format!("Failed to fit step '{name}': {e}"))
87 })?;
88
89 x_transformed = transformer.transform(&x_transformed).map_err(|e| {
90 TransformError::TransformationError(format!(
91 "Failed to transform in step '{name}': {e}"
92 ))
93 })?;
94 }
95
96 self.fitted = true;
97 Ok(())
98 }
99
100 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
108 where
109 S: Data,
110 S::Elem: Float + NumCast,
111 {
112 if !self.fitted {
113 return Err(TransformError::TransformationError(
114 "Pipeline has not been fitted".to_string(),
115 ));
116 }
117
118 let mut x_transformed = x.mapv(|x| num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0));
119
120 for (name, transformer) in &self.steps {
121 x_transformed = transformer.transform(&x_transformed).map_err(|e| {
122 TransformError::TransformationError(format!(
123 "Failed to transform in step '{name}': {e}"
124 ))
125 })?;
126 }
127
128 Ok(x_transformed)
129 }
130
131 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
139 where
140 S: Data,
141 S::Elem: Float + NumCast,
142 {
143 self.fit(x)?;
144 self.transform(x)
145 }
146
147 pub fn len(&self) -> usize {
149 self.steps.len()
150 }
151
152 pub fn is_empty(&self) -> bool {
154 self.steps.is_empty()
155 }
156
157 pub fn get_step(&self, name: &str) -> Option<&dyn Transformer> {
159 self.steps
160 .iter()
161 .find(|(n, _)| n == name)
162 .map(|(_, t)| t.as_ref())
163 }
164
165 pub fn get_step_mut(&mut self, name: &str) -> Option<&mut Box<dyn Transformer>> {
167 self.steps
168 .iter_mut()
169 .find(|(n, _)| n == name)
170 .map(|(_, t)| t)
171 }
172}
173
174impl Default for Pipeline {
175 fn default() -> Self {
176 Self::new()
177 }
178}
179
180pub struct ColumnTransformer {
182 transformers: Vec<(String, Box<dyn Transformer>, Vec<usize>)>,
184 remainder: RemainderOption,
186 fitted: bool,
188}
189
/// Policy for input columns that are not selected by any transformer.
///
/// Derives `PartialEq`/`Eq` so callers can compare policies directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RemainderOption {
    /// Leave unselected columns out of the output.
    Drop,
    /// Append unselected columns, unchanged, after the transformed columns.
    Passthrough,
}
198
199impl ColumnTransformer {
200 pub fn new(remainder: RemainderOption) -> Self {
205 ColumnTransformer {
206 transformers: Vec::new(),
207 remainder,
208 fitted: false,
209 }
210 }
211
212 pub fn add_transformer(
222 mut self,
223 name: impl Into<String>,
224 transformer: Box<dyn Transformer>,
225 columns: Vec<usize>,
226 ) -> Self {
227 self.transformers.push((name.into(), transformer, columns));
228 self
229 }
230
231 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
239 where
240 S: Data,
241 S::Elem: Float + NumCast,
242 {
243 let x_f64 = x.mapv(|x| num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0));
244 let n_features = x_f64.shape()[1];
245
246 for (name_, transformer, columns) in &self.transformers {
248 for &col in columns {
249 if col >= n_features {
250 return Err(TransformError::InvalidInput(format!(
251 "Column index {col} in transformer '{name_}' exceeds number of features {n_features}"
252 )));
253 }
254 }
255 }
256
257 for (name, transformer, columns) in &mut self.transformers {
259 let subset = extract_columns(&x_f64, columns);
261
262 transformer.fit(&subset).map_err(|e| {
263 TransformError::TransformationError(format!(
264 "Failed to fit transformer '{name}': {e}"
265 ))
266 })?;
267 }
268
269 self.fitted = true;
270 Ok(())
271 }
272
273 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
281 where
282 S: Data,
283 S::Elem: Float + NumCast,
284 {
285 if !self.fitted {
286 return Err(TransformError::TransformationError(
287 "ColumnTransformer has not been fitted".to_string(),
288 ));
289 }
290
291 let x_f64 = x.mapv(|x| num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0));
292 let n_samples = x_f64.shape()[0];
293 let n_features = x_f64.shape()[1];
294
295 let mut used_columns = vec![false; n_features];
297 let mut transformed_parts = Vec::new();
298
299 for (name, transformer, columns) in &self.transformers {
301 for &col in columns {
303 used_columns[col] = true;
304 }
305
306 let subset = extract_columns(&x_f64, columns);
308 let transformed = transformer.transform(&subset).map_err(|e| {
309 TransformError::TransformationError(format!(
310 "Failed to transform with '{name}': {e}"
311 ))
312 })?;
313
314 transformed_parts.push(transformed);
315 }
316
317 match self.remainder {
319 RemainderOption::Passthrough => {
320 let unused_columns: Vec<usize> =
322 (0..n_features).filter(|&i| !used_columns[i]).collect();
323
324 if !unused_columns.is_empty() {
325 let remainder = extract_columns(&x_f64, &unused_columns);
326 transformed_parts.push(remainder);
327 }
328 }
329 RemainderOption::Drop => {
330 }
332 }
333
334 if transformed_parts.is_empty() {
336 return Ok(Array2::zeros((n_samples, 0)));
337 }
338
339 concatenate_horizontal(&transformed_parts)
340 }
341
342 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
344 where
345 S: Data,
346 S::Elem: Float + NumCast,
347 {
348 self.fit(x)?;
349 self.transform(x)
350 }
351}
352
353#[allow(dead_code)]
355fn extract_columns(data: &Array2<f64>, columns: &[usize]) -> Array2<f64> {
356 let n_samples = data.shape()[0];
357 let n_cols = columns.len();
358
359 let mut result = Array2::zeros((n_samples, n_cols));
360
361 for (j, &col_idx) in columns.iter().enumerate() {
362 for i in 0..n_samples {
363 result[[i, j]] = data[[i, col_idx]];
364 }
365 }
366
367 result
368}
369
370#[allow(dead_code)]
372fn concatenate_horizontal(arrays: &[Array2<f64>]) -> Result<Array2<f64>> {
373 if arrays.is_empty() {
374 return Err(TransformError::InvalidInput(
375 "Cannot concatenate empty array list".to_string(),
376 ));
377 }
378
379 let n_samples = arrays[0].shape()[0];
380 let total_features: usize = arrays.iter().map(|a| a.shape()[1]).sum();
381
382 for arr in arrays {
384 if arr.shape()[0] != n_samples {
385 return Err(TransformError::InvalidInput(
386 "All _arrays must have the same number of samples".to_string(),
387 ));
388 }
389 }
390
391 let mut result = Array2::zeros((n_samples, total_features));
392 let mut col_offset = 0;
393
394 for arr in arrays {
395 let n_cols = arr.shape()[1];
396 for i in 0..n_samples {
397 for j in 0..n_cols {
398 result[[i, col_offset + j]] = arr[[i, j]];
399 }
400 }
401 col_offset += n_cols;
402 }
403
404 Ok(result)
405}
406
407#[allow(dead_code)]
409pub fn make_pipeline(steps: Vec<(&str, Box<dyn Transformer>)>) -> Pipeline {
410 let mut pipeline = Pipeline::new();
411 for (name, transformer) in steps {
412 pipeline = pipeline.add_step(name, transformer);
413 }
414 pipeline
415}
416
417#[allow(dead_code)]
419pub fn make_column_transformer(
420 transformers: Vec<(&str, Box<dyn Transformer>, Vec<usize>)>,
421 remainder: RemainderOption,
422) -> ColumnTransformer {
423 let mut ct = ColumnTransformer::new(remainder);
424 for (name, transformer, columns) in transformers {
425 ct = ct.add_transformer(name, transformer, columns);
426 }
427 ct
428}