Skip to main content

scry_learn/
pipeline.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2//! Composable ML pipeline.
3//!
4//! Chain preprocessing steps with a final model in a single workflow.
5
6use crate::dataset::Dataset;
7use crate::error::{Result, ScryLearnError};
8use crate::preprocess::Transformer;
9
10/// A composable ML pipeline.
11///
12/// ```ignore
13/// let pipeline = Pipeline::new()
14///     .add_transformer(StandardScaler::new())
15///     .set_model(RandomForestClassifier::new());
16///
17/// pipeline.fit(&train)?;
18/// let preds = pipeline.predict(&test)?;
19/// ```
20#[non_exhaustive]
21pub struct Pipeline {
22    transformers: Vec<Box<dyn TransformerBox>>,
23    model: Option<Box<dyn PipelineModel>>,
24}
25
26/// Trait object wrapper for transformers (to store heterogeneous types).
27trait TransformerBox {
28    fn fit(&mut self, data: &Dataset) -> Result<()>;
29    fn transform(&self, data: &mut Dataset) -> Result<()>;
30}
31
32impl<T: Transformer> TransformerBox for T {
33    fn fit(&mut self, data: &Dataset) -> Result<()> {
34        Transformer::fit(self, data)
35    }
36    fn transform(&self, data: &mut Dataset) -> Result<()> {
37        Transformer::transform(self, data)
38    }
39}
40
41/// Trait for models that can be used in a pipeline.
42pub trait PipelineModel {
43    /// Train the model on a dataset.
44    fn fit(&mut self, data: &Dataset) -> Result<()>;
45    /// Predict on row-major feature matrix.
46    fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>>;
47}
48
// Generates a `PipelineModel` impl for each listed type by forwarding to the
// type's own `fit` / `predict` methods. Accepts a comma-separated list of
// types with an optional trailing comma.
//
// NOTE(review): the forwarding uses method-call syntax, so it relies on every
// listed type exposing inherent `fit` / `predict` methods with these
// signatures (inherent methods win over trait methods during lookup). A type
// lacking them would resolve back to the trait method defined here — confirm
// when adding new types.
macro_rules! impl_pipeline_model {
    ($($model:ty),* $(,)?) => {
        $(
            impl PipelineModel for $model {
                fn fit(&mut self, data: &Dataset) -> Result<()> {
                    self.fit(data)
                }

                fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
                    self.predict(features)
                }
            }
        )*
    };
}
60
61impl_pipeline_model! {
62    crate::tree::DecisionTreeClassifier,
63    crate::tree::RandomForestClassifier,
64    crate::linear::LinearRegression,
65    crate::linear::LogisticRegression,
66    crate::neighbors::KnnClassifier,
67    crate::naive_bayes::GaussianNb,
68    crate::tree::DecisionTreeRegressor,
69    crate::tree::RandomForestRegressor,
70    crate::tree::GradientBoostingClassifier,
71    crate::tree::GradientBoostingRegressor,
72    crate::linear::LassoRegression,
73    crate::linear::ElasticNet,
74    crate::svm::LinearSVC,
75    crate::svm::LinearSVR,
76    crate::naive_bayes::BernoulliNB,
77    crate::naive_bayes::MultinomialNB,
78    crate::tree::HistGradientBoostingClassifier,
79    crate::tree::HistGradientBoostingRegressor,
80    crate::neural::MLPClassifier,
81    crate::neural::MLPRegressor,
82}
83
84#[cfg(feature = "experimental")]
85impl_pipeline_model! {
86    crate::svm::KernelSVC,
87    crate::svm::KernelSVR,
88}
89
90impl Pipeline {
91    /// Create an empty pipeline.
92    pub fn new() -> Self {
93        Self {
94            transformers: Vec::new(),
95            model: None,
96        }
97    }
98
99    /// Add a preprocessing transformer.
100    pub fn add_transformer<T: Transformer + 'static>(mut self, t: T) -> Self {
101        self.transformers.push(Box::new(t));
102        self
103    }
104
105    /// Set the final model.
106    pub fn set_model<M: PipelineModel + 'static>(mut self, m: M) -> Self {
107        self.model = Some(Box::new(m));
108        self
109    }
110
111    /// Fit all transformers and the model.
112    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
113        data.validate_finite()?;
114        let mut transformed = data.clone();
115
116        for t in &mut self.transformers {
117            t.fit(&transformed)?;
118            t.transform(&mut transformed)?;
119        }
120
121        if let Some(model) = &mut self.model {
122            model.fit(&transformed)?;
123        }
124
125        Ok(())
126    }
127
128    /// Transform data through all preprocessing steps and predict.
129    pub fn predict(&self, data: &Dataset) -> Result<Vec<f64>> {
130        let mut transformed = data.clone();
131
132        for t in &self.transformers {
133            t.transform(&mut transformed)?;
134        }
135
136        let model = self.model.as_ref().ok_or(ScryLearnError::NotFitted)?;
137        let features = transformed.feature_matrix();
138        model.predict(&features)
139    }
140}
141
142impl Default for Pipeline {
143    fn default() -> Self {
144        Self::new()
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::preprocess::StandardScaler;
    use crate::tree::DecisionTreeClassifier;

    /// Shared fixture: two identical feature columns ("x" and "y") over six
    /// samples, labelled 0.0 for the first three and 1.0 for the last three.
    fn two_class_dataset() -> Dataset {
        let column: Vec<f64> = vec![0.0, 0.5, 1.0, 5.0, 5.5, 6.0];
        let labels = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        Dataset::new(
            vec![column.clone(), column],
            labels,
            vec!["x".into(), "y".into()],
            "class",
        )
    }

    /// A scaler followed by a decision tree fits and yields one prediction
    /// per sample.
    #[test]
    fn test_pipeline_fit_predict() {
        let data = two_class_dataset();

        let mut pipeline = Pipeline::new()
            .add_transformer(StandardScaler::new())
            .set_model(DecisionTreeClassifier::new());

        pipeline.fit(&data).unwrap();
        let predictions = pipeline.predict(&data).unwrap();
        assert_eq!(predictions.len(), 6);
    }

    /// Regression test: `Pca::transform` invalidates the matrix cache.
    /// Without a lazy rebuild, the following transformer's `fit` panics
    /// when it calls `data.matrix()`.
    #[test]
    fn test_pipeline_pca_then_model() {
        use crate::preprocess::Pca;

        let data = two_class_dataset();

        let mut pipeline = Pipeline::new()
            .add_transformer(Pca::with_n_components(2))
            .add_transformer(StandardScaler::new())
            .set_model(DecisionTreeClassifier::new());

        // The fit below is the step that used to panic: PCA's transform
        // invalidates the cache, then StandardScaler's fit asks for the matrix.
        pipeline.fit(&data).unwrap();
        let predictions = pipeline.predict(&data).unwrap();
        assert_eq!(predictions.len(), 6);
    }
}
197}