Skip to main content

scry_learn/search/
tunable.rs

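// SPDX-License-Identifier: MIT OR Apache-2.0
//! `Tunable` trait and its implementations for the crate's model types.
//!
//! Models that implement [`Tunable`] can participate in
//! [`GridSearchCV`](super::GridSearchCV) and
//! [`RandomizedSearchCV`](super::RandomizedSearchCV) hyperparameter search.
//! Most implementations are generated by the `impl_tunable!` macro; models
//! whose `fit`/`predict` signatures differ from the trait's are implemented by hand.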
6
7use crate::dataset::Dataset;
8use crate::error::{Result, ScryLearnError};
9
10use super::ParamValue;
11
12/// A model whose hyperparameters can be set dynamically by name.
13///
14/// Implement this trait on any model that should participate in
15/// [`GridSearchCV`](super::GridSearchCV) or [`RandomizedSearchCV`](super::RandomizedSearchCV).
16///
17/// # Examples
18///
19/// ```ignore
20/// use scry_learn::search::{Tunable, ParamValue};
21///
22/// let mut dt = DecisionTreeClassifier::new();
23/// dt.set_param("max_depth", ParamValue::Int(5)).unwrap();
24/// ```
25pub trait Tunable {
26    /// Apply a named hyperparameter.
27    ///
28    /// Returns [`ScryLearnError::InvalidParameter`] if the parameter name
29    /// is unrecognised or the value type is wrong.
30    fn set_param(&mut self, name: &str, value: ParamValue) -> Result<()>;
31
32    /// Clone this model into a boxed trait object.
33    fn clone_box(&self) -> Box<dyn Tunable>;
34
35    /// Train on a dataset.
36    fn fit(&mut self, data: &Dataset) -> Result<()>;
37
38    /// Predict on row-major features.
39    fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>>;
40}
41
// ---------------------------------------------------------------------------
// impl_tunable! — generates `Tunable` impls for the common case, where a model:
//   - is `Clone` and has a builder method named after each tunable parameter
//     that returns the updated model (e.g. `.max_depth(v)`);
//   - has an inherent `fit(&mut self, &Dataset) -> Result<()>`;
//   - has an inherent `predict(&self, &[Vec<f64>]) -> Result<Vec<f64>>`.
//
// Each parameter is declared as `name: Kind`, where `Kind` is the `ParamValue`
// variant it accepts (`Int`, `Float`, ...). Attributes placed before a model
// (e.g. `#[cfg(...)]`) are forwarded to its generated impl.
//
// For models that need custom fit/predict (KMeans, IsolationForest),
// keep a manual impl below the macro invocation.
// ---------------------------------------------------------------------------

macro_rules! impl_tunable {
    // Internal: if `$value` is `ParamValue::$kind(v)`, rebuild `$self` through
    // the builder `$param(v)`; otherwise report the expected variant.
    //
    // Internal rules go first so `@extract` invocations never reach the public
    // arm below.
    (@extract $value:ident, $kind:ident, $param:ident, $self:ident) => {
        if let ParamValue::$kind(v) = $value {
            // The builder returns a new model, so rebuild from a clone and
            // write the result back in place.
            *$self = $self.clone().$param(v);
            Ok(())
        } else {
            Err(ScryLearnError::InvalidParameter(format!(
                concat!(stringify!($param), " expects ", stringify!($kind), ", got {}"),
                $value
            )))
        }
    };

    (
        $(
            $(#[$meta:meta])*
            $Model:ty {
                $( $param:ident : $kind:ident ),* $(,)?
            }
        );* $(;)?
    ) => {
        $(
            $(#[$meta])*
            impl Tunable for $Model {
                // `_value` keeps parameter-less models (e.g. `GaussianNb {}`)
                // free of unused-variable warnings.
                fn set_param(&mut self, name: &str, _value: ParamValue) -> Result<()> {
                    match name {
                        $(
                            stringify!($param) => {
                                impl_tunable!(@extract _value, $kind, $param, self)
                            }
                        )*
                        _ => Err(ScryLearnError::InvalidParameter(format!(
                            "unknown parameter: {name}"
                        ))),
                    }
                }

                fn clone_box(&self) -> Box<dyn Tunable> {
                    Box::new(self.clone())
                }

                // Explicit paths make it obvious these call the model's
                // *inherent* methods (which path resolution prefers over the
                // `Tunable` methods being defined here).
                fn fit(&mut self, data: &Dataset) -> Result<()> {
                    <$Model>::fit(self, data)
                }

                fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
                    <$Model>::predict(self, features)
                }
            }
        )*
    };
}
117
118// ---------------------------------------------------------------------------
119// Standard impls via macro
120// ---------------------------------------------------------------------------
121
122impl_tunable! {
123    crate::tree::DecisionTreeClassifier {
124        max_depth: Int,
125        min_samples_split: Int,
126        min_samples_leaf: Int,
127    };
128    crate::tree::DecisionTreeRegressor {
129        max_depth: Int,
130        min_samples_split: Int,
131        min_samples_leaf: Int,
132    };
133    crate::tree::RandomForestClassifier {
134        n_estimators: Int,
135        max_depth: Int,
136    };
137    crate::linear::LogisticRegression {
138        learning_rate: Float,
139        max_iter: Int,
140        alpha: Float,
141        tolerance: Float,
142    };
143    crate::neighbors::KnnClassifier {
144        k: Int,
145    };
146    crate::neighbors::KnnRegressor {
147        k: Int,
148    };
149    crate::tree::GradientBoostingRegressor {
150        n_estimators: Int,
151        learning_rate: Float,
152        max_depth: Int,
153        min_samples_split: Int,
154        min_samples_leaf: Int,
155    };
156    crate::tree::GradientBoostingClassifier {
157        n_estimators: Int,
158        learning_rate: Float,
159        max_depth: Int,
160        min_samples_split: Int,
161        min_samples_leaf: Int,
162    };
163    crate::svm::LinearSVC {
164        c: Float,
165        max_iter: Int,
166        tol: Float,
167    };
168    crate::svm::LinearSVR {
169        c: Float,
170        epsilon: Float,
171        max_iter: Int,
172        tol: Float,
173    };
174    #[cfg(feature = "experimental")]
175    crate::svm::KernelSVC {
176        c: Float,
177        tol: Float,
178        max_iter: Int,
179    };
180    #[cfg(feature = "experimental")]
181    crate::svm::KernelSVR {
182        c: Float,
183        epsilon: Float,
184        tol: Float,
185        max_iter: Int,
186    };
187    crate::naive_bayes::GaussianNb {};
188    crate::naive_bayes::BernoulliNB {
189        alpha: Float,
190    };
191    crate::naive_bayes::MultinomialNB {
192        alpha: Float,
193    };
194    crate::linear::LassoRegression {
195        alpha: Float,
196        max_iter: Int,
197        tol: Float,
198    };
199    crate::linear::ElasticNet {
200        alpha: Float,
201        l1_ratio: Float,
202        max_iter: Int,
203        tol: Float,
204    };
205    crate::tree::HistGradientBoostingRegressor {
206        n_estimators: Int,
207        learning_rate: Float,
208        max_leaf_nodes: Int,
209        max_depth: Int,
210        min_samples_leaf: Int,
211    };
212    crate::tree::HistGradientBoostingClassifier {
213        n_estimators: Int,
214        learning_rate: Float,
215        max_leaf_nodes: Int,
216        max_depth: Int,
217        min_samples_leaf: Int,
218    };
219    crate::neural::MLPClassifier {
220        learning_rate: Float,
221        alpha: Float,
222        max_iter: Int,
223        batch_size: Int,
224    };
225    crate::neural::MLPRegressor {
226        learning_rate: Float,
227        alpha: Float,
228        max_iter: Int,
229        batch_size: Int,
230    };
231}
232
233// ---------------------------------------------------------------------------
234// Manual impls for models with custom fit/predict
235// ---------------------------------------------------------------------------
236
237impl Tunable for crate::cluster::KMeans {
238    fn set_param(&mut self, name: &str, value: ParamValue) -> Result<()> {
239        match name {
240            "max_iter" => {
241                if let ParamValue::Int(v) = value {
242                    *self = self.clone().max_iter(v);
243                    Ok(())
244                } else {
245                    Err(ScryLearnError::InvalidParameter(format!(
246                        "max_iter expects Int, got {value}"
247                    )))
248                }
249            }
250            "tolerance" => {
251                if let ParamValue::Float(v) = value {
252                    *self = self.clone().tolerance(v);
253                    Ok(())
254                } else {
255                    Err(ScryLearnError::InvalidParameter(format!(
256                        "tolerance expects Float, got {value}"
257                    )))
258                }
259            }
260            "n_init" => {
261                if let ParamValue::Int(v) = value {
262                    *self = self.clone().n_init(v);
263                    Ok(())
264                } else {
265                    Err(ScryLearnError::InvalidParameter(format!(
266                        "n_init expects Int, got {value}"
267                    )))
268                }
269            }
270            _ => Err(ScryLearnError::InvalidParameter(format!(
271                "unknown parameter: {name}"
272            ))),
273        }
274    }
275    fn clone_box(&self) -> Box<dyn Tunable> {
276        Box::new(self.clone())
277    }
278    fn fit(&mut self, data: &Dataset) -> Result<()> {
279        self.fit(data)
280    }
281    fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
282        let labels = crate::cluster::KMeans::predict(self, features)?;
283        Ok(labels.into_iter().map(|l| l as f64).collect())
284    }
285}
286
287impl Tunable for crate::anomaly::IsolationForest {
288    fn set_param(&mut self, name: &str, value: ParamValue) -> Result<()> {
289        match name {
290            "n_estimators" => {
291                if let ParamValue::Int(v) = value {
292                    *self = self.clone().n_estimators(v);
293                    Ok(())
294                } else {
295                    Err(ScryLearnError::InvalidParameter(format!(
296                        "n_estimators expects Int, got {value}"
297                    )))
298                }
299            }
300            "max_samples" => {
301                if let ParamValue::Int(v) = value {
302                    *self = self.clone().max_samples(v);
303                    Ok(())
304                } else {
305                    Err(ScryLearnError::InvalidParameter(format!(
306                        "max_samples expects Int, got {value}"
307                    )))
308                }
309            }
310            "contamination" => {
311                if let ParamValue::Float(v) = value {
312                    *self = self.clone().contamination(v);
313                    Ok(())
314                } else {
315                    Err(ScryLearnError::InvalidParameter(format!(
316                        "contamination expects Float, got {value}"
317                    )))
318                }
319            }
320            _ => Err(ScryLearnError::InvalidParameter(format!(
321                "unknown parameter: {name}"
322            ))),
323        }
324    }
325    fn clone_box(&self) -> Box<dyn Tunable> {
326        Box::new(self.clone())
327    }
328    fn fit(&mut self, data: &Dataset) -> Result<()> {
329        let features = data.feature_matrix();
330        self.fit(&features)
331    }
332    fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
333        Ok(self.predict(features))
334    }
335}