1use crate::dataset::Dataset;
8use crate::error::{Result, ScryLearnError};
9
10use super::ParamValue;
11
/// Common interface for estimators whose hyperparameters can be set by
/// name, so search routines can tune, clone, fit and score any model
/// uniformly through a `Box<dyn Tunable>`.
pub trait Tunable {
    /// Sets the hyperparameter `name` to `value`.
    ///
    /// # Errors
    /// Returns `ScryLearnError::InvalidParameter` when `name` is not a
    /// tunable parameter of this model or `value` is the wrong
    /// `ParamValue` variant for it.
    fn set_param(&mut self, name: &str, value: ParamValue) -> Result<()>;

    /// Clones this estimator into a boxed trait object (the object-safe
    /// stand-in for `Clone`).
    fn clone_box(&self) -> Box<dyn Tunable>;

    /// Fits the estimator on `data`.
    fn fit(&mut self, data: &Dataset) -> Result<()>;

    /// Predicts one `f64` per row of `features`. Non-regression models
    /// adapt their output to this shape (e.g. labels cast to `f64`).
    fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>>;
}
41
/// Generates a `Tunable` impl for each listed model type.
///
/// Input grammar: `$(#[meta])* ModelPath { param: Kind, ... };` repeated,
/// where `Kind` is `Int` or `Float` (the `ParamValue` variant the
/// parameter accepts). Each `param` ident doubles as both the string key
/// matched in `set_param` (via `stringify!`) and the model's builder-style
/// setter method of the same name; setting a parameter clones the model,
/// applies the setter, and writes the result back through `*self`.
///
/// The `@extract` arms are internal helpers that unwrap the expected
/// `ParamValue` variant or report a typed `InvalidParameter` error.
macro_rules! impl_tunable {
    (
        $(
            $(#[$meta:meta])*
            $Model:ty {
                $( $param:ident : $kind:ident ),* $(,)?
            }
        );* $(;)?
    ) => {
        $(
            // Forward attributes (e.g. `#[cfg(feature = "...")]`) onto the impl.
            $(#[$meta])*
            impl Tunable for $Model {
                // `_value` is underscore-prefixed so models with an empty
                // parameter list (no match arms using it) compile warning-free.
                fn set_param(&mut self, name: &str, _value: ParamValue) -> Result<()> {
                    match name {
                        $(
                            stringify!($param) => {
                                impl_tunable!(@extract _value, $kind, $param, self)
                            }
                        )*
                        _ => Err(ScryLearnError::InvalidParameter(format!(
                            "unknown parameter: {name}"
                        ))),
                    }
                }

                fn clone_box(&self) -> Box<dyn Tunable> {
                    Box::new(self.clone())
                }

                // Delegates to the model's inherent `fit`/`predict`
                // (inherent methods win over the trait method in resolution).
                fn fit(&mut self, data: &Dataset) -> Result<()> {
                    self.fit(data)
                }

                fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
                    self.predict(features)
                }
            }
        )*
    };

    // NOTE(review): the else branch reuses `$value` after the by-value
    // pattern match, which the compiler only accepts if `ParamValue` is
    // `Copy` — presumably its variants carry only `i64`/`f64` payloads.
    (@extract $value:ident, Int, $param:ident, $self:ident) => {
        if let ParamValue::Int(v) = $value {
            *$self = $self.clone().$param(v);
            Ok(())
        } else {
            Err(ScryLearnError::InvalidParameter(format!(
                concat!(stringify!($param), " expects Int, got {}"), $value
            )))
        }
    };

    (@extract $value:ident, Float, $param:ident, $self:ident) => {
        if let ParamValue::Float(v) = $value {
            *$self = $self.clone().$param(v);
            Ok(())
        } else {
            Err(ScryLearnError::InvalidParameter(format!(
                concat!(stringify!($param), " expects Float, got {}"), $value
            )))
        }
    };
}
117
// Wire up `Tunable` for every builder-style estimator. Each `param: Kind`
// entry must name an existing builder setter on the model; `Kind` is the
// `ParamValue` variant it accepts. Models needing custom adaptation
// (KMeans, IsolationForest) get hand-written impls below instead.
impl_tunable! {
    // Tree-based models.
    crate::tree::DecisionTreeClassifier {
        max_depth: Int,
        min_samples_split: Int,
        min_samples_leaf: Int,
    };
    crate::tree::DecisionTreeRegressor {
        max_depth: Int,
        min_samples_split: Int,
        min_samples_leaf: Int,
    };
    crate::tree::RandomForestClassifier {
        n_estimators: Int,
        max_depth: Int,
    };
    // Linear models.
    crate::linear::LogisticRegression {
        learning_rate: Float,
        max_iter: Int,
        alpha: Float,
        tolerance: Float,
    };
    // Nearest-neighbor models.
    crate::neighbors::KnnClassifier {
        k: Int,
    };
    crate::neighbors::KnnRegressor {
        k: Int,
    };
    // Boosting ensembles.
    crate::tree::GradientBoostingRegressor {
        n_estimators: Int,
        learning_rate: Float,
        max_depth: Int,
        min_samples_split: Int,
        min_samples_leaf: Int,
    };
    crate::tree::GradientBoostingClassifier {
        n_estimators: Int,
        learning_rate: Float,
        max_depth: Int,
        min_samples_split: Int,
        min_samples_leaf: Int,
    };
    // Support vector machines.
    crate::svm::LinearSVC {
        c: Float,
        max_iter: Int,
        tol: Float,
    };
    crate::svm::LinearSVR {
        c: Float,
        epsilon: Float,
        max_iter: Int,
        tol: Float,
    };
    // Kernel SVMs are gated behind the `experimental` feature.
    #[cfg(feature = "experimental")]
    crate::svm::KernelSVC {
        c: Float,
        tol: Float,
        max_iter: Int,
    };
    #[cfg(feature = "experimental")]
    crate::svm::KernelSVR {
        c: Float,
        epsilon: Float,
        tol: Float,
        max_iter: Int,
    };
    // Naive Bayes. GaussianNb has no tunable parameters: `set_param`
    // rejects every name with `InvalidParameter`.
    crate::naive_bayes::GaussianNb {};
    crate::naive_bayes::BernoulliNB {
        alpha: Float,
    };
    crate::naive_bayes::MultinomialNB {
        alpha: Float,
    };
    // Regularized linear regressions.
    crate::linear::LassoRegression {
        alpha: Float,
        max_iter: Int,
        tol: Float,
    };
    crate::linear::ElasticNet {
        alpha: Float,
        l1_ratio: Float,
        max_iter: Int,
        tol: Float,
    };
    // Histogram-based boosting.
    crate::tree::HistGradientBoostingRegressor {
        n_estimators: Int,
        learning_rate: Float,
        max_leaf_nodes: Int,
        max_depth: Int,
        min_samples_leaf: Int,
    };
    crate::tree::HistGradientBoostingClassifier {
        n_estimators: Int,
        learning_rate: Float,
        max_leaf_nodes: Int,
        max_depth: Int,
        min_samples_leaf: Int,
    };
    // Neural networks.
    crate::neural::MLPClassifier {
        learning_rate: Float,
        alpha: Float,
        max_iter: Int,
        batch_size: Int,
    };
    crate::neural::MLPRegressor {
        learning_rate: Float,
        alpha: Float,
        max_iter: Int,
        batch_size: Int,
    };
}
232
233impl Tunable for crate::cluster::KMeans {
238 fn set_param(&mut self, name: &str, value: ParamValue) -> Result<()> {
239 match name {
240 "max_iter" => {
241 if let ParamValue::Int(v) = value {
242 *self = self.clone().max_iter(v);
243 Ok(())
244 } else {
245 Err(ScryLearnError::InvalidParameter(format!(
246 "max_iter expects Int, got {value}"
247 )))
248 }
249 }
250 "tolerance" => {
251 if let ParamValue::Float(v) = value {
252 *self = self.clone().tolerance(v);
253 Ok(())
254 } else {
255 Err(ScryLearnError::InvalidParameter(format!(
256 "tolerance expects Float, got {value}"
257 )))
258 }
259 }
260 "n_init" => {
261 if let ParamValue::Int(v) = value {
262 *self = self.clone().n_init(v);
263 Ok(())
264 } else {
265 Err(ScryLearnError::InvalidParameter(format!(
266 "n_init expects Int, got {value}"
267 )))
268 }
269 }
270 _ => Err(ScryLearnError::InvalidParameter(format!(
271 "unknown parameter: {name}"
272 ))),
273 }
274 }
275 fn clone_box(&self) -> Box<dyn Tunable> {
276 Box::new(self.clone())
277 }
278 fn fit(&mut self, data: &Dataset) -> Result<()> {
279 self.fit(data)
280 }
281 fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
282 let labels = crate::cluster::KMeans::predict(self, features)?;
283 Ok(labels.into_iter().map(|l| l as f64).collect())
284 }
285}
286
287impl Tunable for crate::anomaly::IsolationForest {
288 fn set_param(&mut self, name: &str, value: ParamValue) -> Result<()> {
289 match name {
290 "n_estimators" => {
291 if let ParamValue::Int(v) = value {
292 *self = self.clone().n_estimators(v);
293 Ok(())
294 } else {
295 Err(ScryLearnError::InvalidParameter(format!(
296 "n_estimators expects Int, got {value}"
297 )))
298 }
299 }
300 "max_samples" => {
301 if let ParamValue::Int(v) = value {
302 *self = self.clone().max_samples(v);
303 Ok(())
304 } else {
305 Err(ScryLearnError::InvalidParameter(format!(
306 "max_samples expects Int, got {value}"
307 )))
308 }
309 }
310 "contamination" => {
311 if let ParamValue::Float(v) = value {
312 *self = self.clone().contamination(v);
313 Ok(())
314 } else {
315 Err(ScryLearnError::InvalidParameter(format!(
316 "contamination expects Float, got {value}"
317 )))
318 }
319 }
320 _ => Err(ScryLearnError::InvalidParameter(format!(
321 "unknown parameter: {name}"
322 ))),
323 }
324 }
325 fn clone_box(&self) -> Box<dyn Tunable> {
326 Box::new(self.clone())
327 }
328 fn fit(&mut self, data: &Dataset) -> Result<()> {
329 let features = data.feature_matrix();
330 self.fit(&features)
331 }
332 fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
333 Ok(self.predict(features))
334 }
335}