scirs2_transform/pipeline/mod.rs

1//! Pipeline API for chaining transformations
2//!
3//! This module provides utilities for creating pipelines of transformations
4//! that can be applied sequentially, similar to scikit-learn's Pipeline.
5
6// mod adapters;
7
8// pub use adapters::boxed;
9
10use ndarray::{Array2, ArrayBase, Data, Ix2};
11use num_traits::{Float, NumCast};
12use std::any::Any;
13
14use crate::error::{Result, TransformError};
15
16/// Trait for all transformers that can be used in pipelines
17pub trait Transformer: Send + Sync {
18    /// Fits the transformer to the input data
19    fn fit(&mut self, x: &Array2<f64>) -> Result<()>;
20
21    /// Transforms the input data
22    fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>>;
23
24    /// Fits and transforms the data in one step
25    fn fit_transform(&mut self, x: &Array2<f64>) -> Result<Array2<f64>> {
26        self.fit(x)?;
27        self.transform(x)
28    }
29
30    /// Returns a boxed clone of the transformer
31    fn clone_box(&self) -> Box<dyn Transformer>;
32
33    /// Returns the transformer as Any for downcasting
34    fn as_any(&self) -> &dyn Any;
35
36    /// Returns the transformer as mutable Any for downcasting
37    fn as_any_mut(&mut self) -> &mut dyn Any;
38}
39
40/// A pipeline of transformations to be applied sequentially
41pub struct Pipeline {
42    /// List of named steps in the pipeline
43    steps: Vec<(String, Box<dyn Transformer>)>,
44    /// Whether the pipeline has been fitted
45    fitted: bool,
46}
47
48impl Pipeline {
49    /// Creates a new empty pipeline
50    pub fn new() -> Self {
51        Pipeline {
52            steps: Vec::new(),
53            fitted: false,
54        }
55    }
56
57    /// Adds a step to the pipeline
58    ///
59    /// # Arguments
60    /// * `name` - Name of the step
61    /// * `transformer` - The transformer to add
62    ///
63    /// # Returns
64    /// * `Self` - The pipeline for chaining
65    pub fn add_step(mut self, name: impl Into<String>, transformer: Box<dyn Transformer>) -> Self {
66        self.steps.push((name.into(), transformer));
67        self
68    }
69
70    /// Fits all steps in the pipeline
71    ///
72    /// # Arguments
73    /// * `x` - The input data
74    ///
75    /// # Returns
76    /// * `Result<()>` - Ok if successful, Err otherwise
77    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
78    where
79        S: Data,
80        S::Elem: Float + NumCast,
81    {
82        let mut x_transformed = x.mapv(|x| num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0));
83
84        for (name, transformer) in &mut self.steps {
85            transformer.fit(&x_transformed).map_err(|e| {
86                TransformError::TransformationError(format!("Failed to fit step '{name}': {e}"))
87            })?;
88
89            x_transformed = transformer.transform(&x_transformed).map_err(|e| {
90                TransformError::TransformationError(format!(
91                    "Failed to transform in step '{name}': {e}"
92                ))
93            })?;
94        }
95
96        self.fitted = true;
97        Ok(())
98    }
99
100    /// Transforms data through all steps in the pipeline
101    ///
102    /// # Arguments
103    /// * `x` - The input data
104    ///
105    /// # Returns
106    /// * `Result<Array2<f64>>` - The transformed data
107    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
108    where
109        S: Data,
110        S::Elem: Float + NumCast,
111    {
112        if !self.fitted {
113            return Err(TransformError::TransformationError(
114                "Pipeline has not been fitted".to_string(),
115            ));
116        }
117
118        let mut x_transformed = x.mapv(|x| num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0));
119
120        for (name, transformer) in &self.steps {
121            x_transformed = transformer.transform(&x_transformed).map_err(|e| {
122                TransformError::TransformationError(format!(
123                    "Failed to transform in step '{name}': {e}"
124                ))
125            })?;
126        }
127
128        Ok(x_transformed)
129    }
130
131    /// Fits and transforms data through all steps in the pipeline
132    ///
133    /// # Arguments
134    /// * `x` - The input data
135    ///
136    /// # Returns
137    /// * `Result<Array2<f64>>` - The transformed data
138    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
139    where
140        S: Data,
141        S::Elem: Float + NumCast,
142    {
143        self.fit(x)?;
144        self.transform(x)
145    }
146
147    /// Returns the number of steps in the pipeline
148    pub fn len(&self) -> usize {
149        self.steps.len()
150    }
151
152    /// Returns whether the pipeline is empty
153    pub fn is_empty(&self) -> bool {
154        self.steps.is_empty()
155    }
156
157    /// Gets a reference to a step by name
158    pub fn get_step(&self, name: &str) -> Option<&dyn Transformer> {
159        self.steps
160            .iter()
161            .find(|(n, _)| n == name)
162            .map(|(_, t)| t.as_ref())
163    }
164
165    /// Gets a mutable reference to a step by name
166    pub fn get_step_mut(&mut self, name: &str) -> Option<&mut Box<dyn Transformer>> {
167        self.steps
168            .iter_mut()
169            .find(|(n, _)| n == name)
170            .map(|(_, t)| t)
171    }
172}
173
174impl Default for Pipeline {
175    fn default() -> Self {
176        Self::new()
177    }
178}
179
180/// ColumnTransformer applies different transformers to different columns
181pub struct ColumnTransformer {
182    /// List of transformers with their column indices
183    transformers: Vec<(String, Box<dyn Transformer>, Vec<usize>)>,
184    /// Whether to pass through columns not specified
185    remainder: RemainderOption,
186    /// Whether the transformer has been fitted
187    fitted: bool,
188}
189
/// How a [`ColumnTransformer`] treats input columns that no transformer claims.
///
/// Derives `PartialEq`, `Eq` and `Hash` so callers can compare options and use
/// them as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RemainderOption {
    /// Drop unclaimed columns from the output.
    Drop,
    /// Append unclaimed columns, unchanged, after all transformed columns.
    Passthrough,
}
198
199impl ColumnTransformer {
200    /// Creates a new ColumnTransformer
201    ///
202    /// # Arguments
203    /// * `remainder` - How to handle unspecified columns
204    pub fn new(remainder: RemainderOption) -> Self {
205        ColumnTransformer {
206            transformers: Vec::new(),
207            remainder,
208            fitted: false,
209        }
210    }
211
212    /// Adds a transformer for specific columns
213    ///
214    /// # Arguments
215    /// * `name` - Name of the transformer
216    /// * `transformer` - The transformer to apply
217    /// * `columns` - Column indices to apply the transformer to
218    ///
219    /// # Returns
220    /// * `Self` - The ColumnTransformer for chaining
221    pub fn add_transformer(
222        mut self,
223        name: impl Into<String>,
224        transformer: Box<dyn Transformer>,
225        columns: Vec<usize>,
226    ) -> Self {
227        self.transformers.push((name.into(), transformer, columns));
228        self
229    }
230
231    /// Fits all transformers to their respective columns
232    ///
233    /// # Arguments
234    /// * `x` - The input data
235    ///
236    /// # Returns
237    /// * `Result<()>` - Ok if successful, Err otherwise
238    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
239    where
240        S: Data,
241        S::Elem: Float + NumCast,
242    {
243        let x_f64 = x.mapv(|x| num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0));
244        let n_features = x_f64.shape()[1];
245
246        // Validate column indices
247        for (name_, transformer, columns) in &self.transformers {
248            for &col in columns {
249                if col >= n_features {
250                    return Err(TransformError::InvalidInput(format!(
251                        "Column index {col} in transformer '{name_}' exceeds number of features {n_features}"
252                    )));
253                }
254            }
255        }
256
257        // Fit each transformer on its columns
258        for (name, transformer, columns) in &mut self.transformers {
259            // Extract relevant columns
260            let subset = extract_columns(&x_f64, columns);
261
262            transformer.fit(&subset).map_err(|e| {
263                TransformError::TransformationError(format!(
264                    "Failed to fit transformer '{name}': {e}"
265                ))
266            })?;
267        }
268
269        self.fitted = true;
270        Ok(())
271    }
272
273    /// Transforms data using all configured transformers
274    ///
275    /// # Arguments
276    /// * `x` - The input data
277    ///
278    /// # Returns
279    /// * `Result<Array2<f64>>` - The transformed data
280    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
281    where
282        S: Data,
283        S::Elem: Float + NumCast,
284    {
285        if !self.fitted {
286            return Err(TransformError::TransformationError(
287                "ColumnTransformer has not been fitted".to_string(),
288            ));
289        }
290
291        let x_f64 = x.mapv(|x| num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0));
292        let n_samples = x_f64.shape()[0];
293        let n_features = x_f64.shape()[1];
294
295        // Track which columns have been transformed
296        let mut used_columns = vec![false; n_features];
297        let mut transformed_parts = Vec::new();
298
299        // Transform each group of columns
300        for (name, transformer, columns) in &self.transformers {
301            // Mark columns as used
302            for &col in columns {
303                used_columns[col] = true;
304            }
305
306            // Extract and transform columns
307            let subset = extract_columns(&x_f64, columns);
308            let transformed = transformer.transform(&subset).map_err(|e| {
309                TransformError::TransformationError(format!(
310                    "Failed to transform with '{name}': {e}"
311                ))
312            })?;
313
314            transformed_parts.push(transformed);
315        }
316
317        // Handle remainder columns
318        match self.remainder {
319            RemainderOption::Passthrough => {
320                // Collect unused columns
321                let unused_columns: Vec<usize> =
322                    (0..n_features).filter(|&i| !used_columns[i]).collect();
323
324                if !unused_columns.is_empty() {
325                    let remainder = extract_columns(&x_f64, &unused_columns);
326                    transformed_parts.push(remainder);
327                }
328            }
329            RemainderOption::Drop => {
330                // Do nothing - unused columns are dropped
331            }
332        }
333
334        // Concatenate all parts horizontally
335        if transformed_parts.is_empty() {
336            return Ok(Array2::zeros((n_samples, 0)));
337        }
338
339        concatenate_horizontal(&transformed_parts)
340    }
341
342    /// Fits and transforms data in one step
343    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
344    where
345        S: Data,
346        S::Elem: Float + NumCast,
347    {
348        self.fit(x)?;
349        self.transform(x)
350    }
351}
352
353/// Extracts specific columns from a 2D array
354#[allow(dead_code)]
355fn extract_columns(data: &Array2<f64>, columns: &[usize]) -> Array2<f64> {
356    let n_samples = data.shape()[0];
357    let n_cols = columns.len();
358
359    let mut result = Array2::zeros((n_samples, n_cols));
360
361    for (j, &col_idx) in columns.iter().enumerate() {
362        for i in 0..n_samples {
363            result[[i, j]] = data[[i, col_idx]];
364        }
365    }
366
367    result
368}
369
370/// Concatenates arrays horizontally
371#[allow(dead_code)]
372fn concatenate_horizontal(arrays: &[Array2<f64>]) -> Result<Array2<f64>> {
373    if arrays.is_empty() {
374        return Err(TransformError::InvalidInput(
375            "Cannot concatenate empty array list".to_string(),
376        ));
377    }
378
379    let n_samples = arrays[0].shape()[0];
380    let total_features: usize = arrays.iter().map(|a| a.shape()[1]).sum();
381
382    // Verify all _arrays have the same number of samples
383    for arr in arrays {
384        if arr.shape()[0] != n_samples {
385            return Err(TransformError::InvalidInput(
386                "All _arrays must have the same number of samples".to_string(),
387            ));
388        }
389    }
390
391    let mut result = Array2::zeros((n_samples, total_features));
392    let mut col_offset = 0;
393
394    for arr in arrays {
395        let n_cols = arr.shape()[1];
396        for i in 0..n_samples {
397            for j in 0..n_cols {
398                result[[i, col_offset + j]] = arr[[i, j]];
399            }
400        }
401        col_offset += n_cols;
402    }
403
404    Ok(result)
405}
406
407/// Make a pipeline from a list of (name, transformer) tuples
408#[allow(dead_code)]
409pub fn make_pipeline(steps: Vec<(&str, Box<dyn Transformer>)>) -> Pipeline {
410    let mut pipeline = Pipeline::new();
411    for (name, transformer) in steps {
412        pipeline = pipeline.add_step(name, transformer);
413    }
414    pipeline
415}
416
417/// Make a column transformer from a list of (name, transformer, columns) tuples
418#[allow(dead_code)]
419pub fn make_column_transformer(
420    transformers: Vec<(&str, Box<dyn Transformer>, Vec<usize>)>,
421    remainder: RemainderOption,
422) -> ColumnTransformer {
423    let mut ct = ColumnTransformer::new(remainder);
424    for (name, transformer, columns) in transformers {
425        ct = ct.add_transformer(name, transformer, columns);
426    }
427    ct
428}