1use crate::dataset::Dataset;
7use crate::error::{Result, ScryLearnError};
8use crate::preprocess::Transformer;
9
10#[non_exhaustive]
21pub struct Pipeline {
22 transformers: Vec<Box<dyn TransformerBox>>,
23 model: Option<Box<dyn PipelineModel>>,
24}
25
26trait TransformerBox {
28 fn fit(&mut self, data: &Dataset) -> Result<()>;
29 fn transform(&self, data: &mut Dataset) -> Result<()>;
30}
31
32impl<T: Transformer> TransformerBox for T {
33 fn fit(&mut self, data: &Dataset) -> Result<()> {
34 Transformer::fit(self, data)
35 }
36 fn transform(&self, data: &mut Dataset) -> Result<()> {
37 Transformer::transform(self, data)
38 }
39}
40
41pub trait PipelineModel {
43 fn fit(&mut self, data: &Dataset) -> Result<()>;
45 fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>>;
47}
48
/// Implements [`PipelineModel`] for each listed estimator type by forwarding to
/// the estimator's own inherent `fit` and `predict` methods.
///
/// Forwarding uses the type-qualified form `<$ty>::fit(self, data)` rather than
/// `self.fit(data)`. Method-call syntax first probes the receiver type
/// `&mut $ty`. At that step the trait's own `fit(&mut self)` matches before an
/// inherent `fit(&self)` is considered. The call would then resolve to this
/// very impl and recurse forever. A type-qualified path always prefers inherent
/// associated functions, so the estimator's method is reached whatever its
/// receiver kind.
///
/// Every listed type must have inherent `fit` / `predict` methods accepting
/// `&Dataset` / `&[Vec<f64>]`. If one is missing, resolution falls back to the
/// trait method itself, which would recurse.
macro_rules! impl_pipeline_model {
    ($($ty:ty),* $(,)?) => {
        $(
            impl PipelineModel for $ty {
                fn fit(&mut self, data: &Dataset) -> Result<()> {
                    // Qualified path: inherent method wins regardless of receiver.
                    <$ty>::fit(self, data)
                }

                fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
                    <$ty>::predict(self, features)
                }
            }
        )*
    };
}
60
61impl_pipeline_model! {
62 crate::tree::DecisionTreeClassifier,
63 crate::tree::RandomForestClassifier,
64 crate::linear::LinearRegression,
65 crate::linear::LogisticRegression,
66 crate::neighbors::KnnClassifier,
67 crate::naive_bayes::GaussianNb,
68 crate::tree::DecisionTreeRegressor,
69 crate::tree::RandomForestRegressor,
70 crate::tree::GradientBoostingClassifier,
71 crate::tree::GradientBoostingRegressor,
72 crate::linear::LassoRegression,
73 crate::linear::ElasticNet,
74 crate::svm::LinearSVC,
75 crate::svm::LinearSVR,
76 crate::naive_bayes::BernoulliNB,
77 crate::naive_bayes::MultinomialNB,
78 crate::tree::HistGradientBoostingClassifier,
79 crate::tree::HistGradientBoostingRegressor,
80 crate::neural::MLPClassifier,
81 crate::neural::MLPRegressor,
82}
83
// Kernel SVMs only get a `PipelineModel` impl when the `experimental` feature is enabled.
#[cfg(feature = "experimental")]
impl_pipeline_model!(crate::svm::KernelSVC, crate::svm::KernelSVR);
89
90impl Pipeline {
91 pub fn new() -> Self {
93 Self {
94 transformers: Vec::new(),
95 model: None,
96 }
97 }
98
99 pub fn add_transformer<T: Transformer + 'static>(mut self, t: T) -> Self {
101 self.transformers.push(Box::new(t));
102 self
103 }
104
105 pub fn set_model<M: PipelineModel + 'static>(mut self, m: M) -> Self {
107 self.model = Some(Box::new(m));
108 self
109 }
110
111 pub fn fit(&mut self, data: &Dataset) -> Result<()> {
113 data.validate_finite()?;
114 let mut transformed = data.clone();
115
116 for t in &mut self.transformers {
117 t.fit(&transformed)?;
118 t.transform(&mut transformed)?;
119 }
120
121 if let Some(model) = &mut self.model {
122 model.fit(&transformed)?;
123 }
124
125 Ok(())
126 }
127
128 pub fn predict(&self, data: &Dataset) -> Result<Vec<f64>> {
130 let mut transformed = data.clone();
131
132 for t in &self.transformers {
133 t.transform(&mut transformed)?;
134 }
135
136 let model = self.model.as_ref().ok_or(ScryLearnError::NotFitted)?;
137 let features = transformed.feature_matrix();
138 model.predict(&features)
139 }
140}
141
142impl Default for Pipeline {
143 fn default() -> Self {
144 Self::new()
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::preprocess::StandardScaler;
    use crate::tree::DecisionTreeClassifier;

    /// Six samples over two identical features ("x", "y") with target
    /// "class". The low values (0.0–1.0) are labelled 0 and the high values
    /// (5.0–6.0) are labelled 1.
    fn two_cluster_dataset() -> Dataset {
        let column = vec![0.0, 0.5, 1.0, 5.0, 5.5, 6.0];
        let labels = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        Dataset::new(
            vec![column.clone(), column],
            labels,
            vec!["x".into(), "y".into()],
            "class",
        )
    }

    #[test]
    fn test_pipeline_fit_predict() {
        let data = two_cluster_dataset();
        let mut pipeline = Pipeline::new()
            .add_transformer(StandardScaler::new())
            .set_model(DecisionTreeClassifier::new());

        pipeline.fit(&data).unwrap();
        let predictions = pipeline.predict(&data).unwrap();
        assert_eq!(predictions.len(), 6);
    }

    #[test]
    fn test_pipeline_pca_then_model() {
        use crate::preprocess::Pca;

        let data = two_cluster_dataset();
        // PCA first, then scaling, then the classifier.
        let mut pipeline = Pipeline::new()
            .add_transformer(Pca::with_n_components(2))
            .add_transformer(StandardScaler::new())
            .set_model(DecisionTreeClassifier::new());

        pipeline.fit(&data).unwrap();
        let predictions = pipeline.predict(&data).unwrap();
        assert_eq!(predictions.len(), 6);
    }
}