1mod bayes;
27mod grid;
28mod random;
29mod tunable;
30
31pub use bayes::{BayesSearchCV, ParamDistribution, ParamSpace};
32pub use grid::GridSearchCV;
33pub use random::RandomizedSearchCV;
34pub use tunable::Tunable;
35
36use std::collections::HashMap;
37
38use crate::dataset::Dataset;
39use crate::error::Result;
40use crate::split::ScoringFn;
41
/// A single hyperparameter value tried during a search.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum ParamValue {
    /// Integer-valued parameter (e.g. tree depth, neighbor count).
    Int(usize),
    /// Real-valued parameter (e.g. regularization strength).
    Float(f64),
    /// On/off flag.
    Bool(bool),
    /// Free-form string choice (e.g. a split criterion name).
    Categorical(String),
}

impl std::fmt::Display for ParamValue {
    /// Renders the wrapped value exactly as the inner type's own
    /// `Display` would (no variant name, no quoting).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Int(n) => write!(f, "{n}"),
            Self::Float(x) => write!(f, "{x}"),
            Self::Bool(b) => write!(f, "{b}"),
            // Strings need no formatting machinery; write them through directly.
            Self::Categorical(s) => f.write_str(s),
        }
    }
}
81
/// Search space for grid/randomized search: maps a parameter name to the
/// list of candidate values to try for it.
pub type ParamGrid = HashMap<String, Vec<ParamValue>>;
100
/// Cross-validation outcome for one evaluated parameter combination.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CvResult {
    /// The parameter assignment that was evaluated.
    pub params: HashMap<String, ParamValue>,
    /// Arithmetic mean of `fold_scores`.
    pub mean_score: f64,
    /// Score returned by the scoring function on each CV fold, in fold order.
    pub fold_scores: Vec<f64>,
}
124
125pub(super) fn cartesian_product(grid: &ParamGrid) -> Vec<HashMap<String, ParamValue>> {
131 let keys: Vec<&String> = grid.keys().collect();
132 if keys.is_empty() {
133 return Vec::new();
134 }
135
136 let mut combos: Vec<HashMap<String, ParamValue>> = vec![HashMap::new()];
137
138 for key in &keys {
139 let values = &grid[*key];
140 let mut new_combos = Vec::with_capacity(combos.len() * values.len());
141 for combo in &combos {
142 for val in values {
143 let mut c = combo.clone();
144 c.insert((*key).clone(), val.clone());
145 new_combos.push(c);
146 }
147 }
148 combos = new_combos;
149 }
150
151 combos
152}
153
154pub(super) fn evaluate_combo(
156 base: &dyn Tunable,
157 params: &HashMap<String, ParamValue>,
158 folds: &[(Dataset, Dataset)],
159 scorer: ScoringFn,
160) -> Result<CvResult> {
161 let mut scores = Vec::with_capacity(folds.len());
162
163 for (train, test) in folds {
164 let mut model = base.clone_box();
165 for (name, value) in params {
166 model.set_param(name, value.clone())?;
167 }
168 model.fit(train)?;
169 let features = test.feature_matrix();
170 let preds = model.predict(&features)?;
171 scores.push(scorer(&test.target, &preds));
172 }
173
174 let mean = scores.iter().sum::<f64>() / scores.len() as f64;
175
176 Ok(CvResult {
177 params: params.clone(),
178 mean_score: mean,
179 fold_scores: scores,
180 })
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tree::{DecisionTreeClassifier, RandomForestClassifier};

    /// Builds a synthetic, well-separated 3-class dataset (30 rows per class,
    /// 4 features) with a fixed RNG seed so every test sees identical data.
    fn iris_like() -> Dataset {
        let n_per_class = 30;
        let n = n_per_class * 3;
        let mut f0 = Vec::with_capacity(n);
        let mut f1 = Vec::with_capacity(n);
        let mut f2 = Vec::with_capacity(n);
        let mut f3 = Vec::with_capacity(n);
        let mut target = Vec::with_capacity(n);

        let mut rng = crate::rng::FastRng::new(123);

        // Class 0: small values on every feature.
        for _ in 0..n_per_class {
            f0.push(1.0 + rng.f64() * 0.5);
            f1.push(1.0 + rng.f64() * 0.5);
            f2.push(0.5 + rng.f64() * 0.3);
            f3.push(0.1 + rng.f64() * 0.2);
            target.push(0.0);
        }
        // Class 1: intermediate values.
        for _ in 0..n_per_class {
            f0.push(5.0 + rng.f64() * 0.5);
            f1.push(3.0 + rng.f64() * 0.5);
            f2.push(3.5 + rng.f64() * 0.5);
            f3.push(1.0 + rng.f64() * 0.3);
            target.push(1.0);
        }
        // Class 2: large values, clearly separable from the others.
        for _ in 0..n_per_class {
            f0.push(6.5 + rng.f64() * 0.5);
            f1.push(3.0 + rng.f64() * 0.5);
            f2.push(5.5 + rng.f64() * 0.5);
            f3.push(2.0 + rng.f64() * 0.3);
            target.push(2.0);
        }

        Dataset::new(
            vec![f0, f1, f2, f3],
            target,
            vec![
                "sepal_len".into(),
                "sepal_wid".into(),
                "petal_len".into(),
                "petal_wid".into(),
            ],
            "species",
        )
    }

    /// Grid search over a decision tree: 4 depth candidates must yield 4 CV
    /// results, a decent best score, and the tuned key in `best_params`.
    #[test]
    fn test_grid_search_dt() {
        let data = iris_like();
        let mut grid = ParamGrid::new();
        grid.insert(
            "max_depth".into(),
            vec![
                ParamValue::Int(2),
                ParamValue::Int(4),
                ParamValue::Int(6),
                ParamValue::Int(8),
            ],
        );

        let result = GridSearchCV::new(DecisionTreeClassifier::new(), grid)
            .cv(3)
            .scoring(crate::metrics::accuracy)
            .seed(42)
            .fit(&data)
            .unwrap();

        assert!(
            result.best_score() > 0.7,
            "best score {:.3} too low",
            result.best_score()
        );
        // One CvResult per grid point (4 depths).
        assert_eq!(result.cv_results().len(), 4);
        assert!(result.best_params().contains_key("max_depth"));
    }

    /// Randomized search on a random forest samples exactly `n_iter`
    /// combinations from a 3x3 grid and tunes both parameters.
    #[test]
    fn test_randomized_search_rf() {
        let data = iris_like();
        let mut grid = ParamGrid::new();
        grid.insert(
            "n_estimators".into(),
            vec![ParamValue::Int(3), ParamValue::Int(5), ParamValue::Int(10)],
        );
        grid.insert(
            "max_depth".into(),
            vec![ParamValue::Int(2), ParamValue::Int(4), ParamValue::Int(6)],
        );

        let result = RandomizedSearchCV::new(RandomForestClassifier::new(), grid)
            .n_iter(5)
            .cv(3)
            .seed(99)
            .fit(&data)
            .unwrap();

        // n_iter(5) caps the number of sampled combinations at 5.
        assert_eq!(result.cv_results().len(), 5);
        assert!(
            result.best_score() > 0.5,
            "randomized best score too low: {:.3}",
            result.best_score()
        );
        assert!(result.best_params().contains_key("n_estimators"));
        assert!(result.best_params().contains_key("max_depth"));
    }

    /// 2 values x 2 values must expand to 4 combinations.
    #[test]
    fn test_cartesian_product() {
        let mut grid = ParamGrid::new();
        grid.insert("a".into(), vec![ParamValue::Int(1), ParamValue::Int(2)]);
        grid.insert(
            "b".into(),
            vec![ParamValue::Float(0.1), ParamValue::Float(0.2)],
        );
        let combos = cartesian_product(&grid);
        assert_eq!(combos.len(), 4);
    }

    /// `set_param` must reject both a type mismatch and an unknown name.
    #[test]
    fn test_invalid_param() {
        let mut dt = DecisionTreeClassifier::new();
        let err = dt.set_param("max_depth", ParamValue::Float(3.5));
        assert!(err.is_err());
        let err = dt.set_param("nonexistent", ParamValue::Int(3));
        assert!(err.is_err());
    }

    /// Fitting a grid search with no parameters at all is an error,
    /// not a silent no-op.
    #[test]
    fn test_empty_grid() {
        let data = iris_like();
        let grid = ParamGrid::new();
        let result = GridSearchCV::new(DecisionTreeClassifier::new(), grid).fit(&data);
        assert!(result.is_err());
    }

    /// Grid search also works for a non-tree model (logistic regression).
    #[test]
    fn test_grid_search_logistic() {
        let data = iris_like();
        let mut grid = ParamGrid::new();
        grid.insert(
            "max_iter".into(),
            vec![ParamValue::Int(50), ParamValue::Int(200)],
        );
        let result = GridSearchCV::new(crate::linear::LogisticRegression::new(), grid)
            .cv(3)
            .scoring(crate::metrics::accuracy)
            .fit(&data)
            .unwrap();

        assert_eq!(result.cv_results().len(), 2);
        assert!(
            result.best_score() > 0.5,
            "logistic best score too low: {:.3}",
            result.best_score()
        );
        assert!(result.best_params().contains_key("max_iter"));
    }

    /// Grid search over KNN's `k`; well-separated classes should score high.
    #[test]
    fn test_grid_search_knn() {
        let data = iris_like();
        let mut grid = ParamGrid::new();
        grid.insert(
            "k".into(),
            vec![ParamValue::Int(1), ParamValue::Int(3), ParamValue::Int(5)],
        );
        let result = GridSearchCV::new(crate::neighbors::KnnClassifier::new(), grid)
            .cv(3)
            .scoring(crate::metrics::accuracy)
            .fit(&data)
            .unwrap();

        assert_eq!(result.cv_results().len(), 3);
        assert!(
            result.best_score() > 0.7,
            "knn best score too low: {:.3}",
            result.best_score()
        );
        assert!(result.best_params().contains_key("k"));
    }

    /// Two-parameter grid (2x2) over gradient boosting: 4 results expected.
    #[test]
    fn test_grid_search_gbc() {
        let data = iris_like();
        let mut grid = ParamGrid::new();
        grid.insert(
            "n_estimators".into(),
            vec![ParamValue::Int(10), ParamValue::Int(20)],
        );
        grid.insert(
            "max_depth".into(),
            vec![ParamValue::Int(2), ParamValue::Int(3)],
        );
        let result = GridSearchCV::new(crate::tree::GradientBoostingClassifier::new(), grid)
            .cv(3)
            .scoring(crate::metrics::accuracy)
            .fit(&data)
            .unwrap();

        assert_eq!(result.cv_results().len(), 4);
        assert!(
            result.best_score() > 0.6,
            "gbc best score too low: {:.3}",
            result.best_score()
        );
        assert!(result.best_params().contains_key("n_estimators"));
        assert!(result.best_params().contains_key("max_depth"));
    }

    /// Regression path: tune lasso `alpha` on a noisy linear target and
    /// score with R^2 instead of accuracy.
    #[test]
    fn test_grid_search_lasso() {
        let n = 60;
        let mut rng = crate::rng::FastRng::new(42);
        let x: Vec<f64> = (0..n).map(|i| i as f64 / 10.0).collect();
        let target: Vec<f64> = x.iter().map(|&xi| 2.0 * xi + rng.f64() * 0.5).collect();
        let data = crate::dataset::Dataset::new(vec![x], target, vec!["x".into()], "y");
        let mut grid = ParamGrid::new();
        grid.insert(
            "alpha".into(),
            vec![
                ParamValue::Float(0.01),
                ParamValue::Float(0.1),
                ParamValue::Float(1.0),
            ],
        );
        let result = GridSearchCV::new(crate::linear::LassoRegression::new(), grid)
            .cv(3)
            .scoring(crate::metrics::r2_score)
            .fit(&data)
            .unwrap();

        assert_eq!(result.cv_results().len(), 3);
        assert!(
            result.best_score() > 0.5,
            "lasso r2 too low: {:.3}",
            result.best_score()
        );
        assert!(result.best_params().contains_key("alpha"));
    }

    /// Categorical values display as their bare string.
    #[test]
    fn test_categorical_display() {
        let v = ParamValue::Categorical("gini".into());
        assert_eq!(format!("{v}"), "gini");
    }

    /// Grid search with stratified CV folds still evaluates every grid point.
    #[test]
    fn test_grid_search_stratified() {
        let data = iris_like();
        let mut grid = ParamGrid::new();
        grid.insert(
            "max_depth".into(),
            vec![ParamValue::Int(2), ParamValue::Int(4)],
        );

        let result = GridSearchCV::new(DecisionTreeClassifier::new(), grid)
            .cv(3)
            .stratified(true)
            .scoring(crate::metrics::accuracy)
            .seed(42)
            .fit(&data)
            .unwrap();

        assert_eq!(result.cv_results().len(), 2);
        assert!(
            result.best_score() > 0.7,
            "stratified best score {:.3} too low",
            result.best_score()
        );
    }

    /// Randomized search with stratified folds respects `n_iter`.
    #[test]
    fn test_randomized_search_stratified() {
        let data = iris_like();
        let mut grid = ParamGrid::new();
        grid.insert(
            "max_depth".into(),
            vec![ParamValue::Int(2), ParamValue::Int(4), ParamValue::Int(6)],
        );

        let result = RandomizedSearchCV::new(DecisionTreeClassifier::new(), grid)
            .n_iter(2)
            .cv(3)
            .stratified(true)
            .seed(99)
            .fit(&data)
            .unwrap();

        assert_eq!(result.cv_results().len(), 2);
        assert!(
            result.best_score() > 0.5,
            "stratified randomized best score {:.3} too low",
            result.best_score()
        );
    }
}