// SPDX-License-Identifier: MIT OR Apache-2.0
//! Hyperparameter search via cross-validation.
//!
//! [`GridSearchCV`] performs exhaustive search over a parameter grid,
//! while [`RandomizedSearchCV`] samples random combinations for faster
//! exploration of large search spaces.
//!
//! # Examples
//!
//! ```ignore
//! use scry_learn::prelude::*;
//! use scry_learn::search::*;
//!
//! let mut grid = ParamGrid::new();
//! grid.insert("max_depth".into(), vec![ParamValue::Int(2), ParamValue::Int(6)]);
//!
//! let result = GridSearchCV::new(DecisionTreeClassifier::new(), grid)
//!     .cv(5)
//!     .scoring(accuracy)
//!     .fit(&data)
//!     .unwrap();
//!
//! println!("Best score: {}", result.best_score());
//! ```

// Search-strategy implementations.
mod bayes;
mod grid;
mod random;
mod tunable;

// Public surface of the `search` module.
pub use bayes::{BayesSearchCV, ParamDistribution, ParamSpace};
pub use grid::GridSearchCV;
pub use random::RandomizedSearchCV;
pub use tunable::Tunable;

use std::collections::HashMap;

use crate::dataset::Dataset;
use crate::error::Result;
use crate::split::ScoringFn;

// ---------------------------------------------------------------------------
// ParamValue + ParamGrid
// ---------------------------------------------------------------------------

/// A single hyperparameter value.
///
/// # Examples
///
/// ```
/// use scry_learn::search::ParamValue;
///
/// let depth = ParamValue::Int(5);
/// let lr = ParamValue::Float(0.01);
/// let flag = ParamValue::Bool(true);
/// ```
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
// New value kinds may be added later without a breaking release.
#[non_exhaustive]
pub enum ParamValue {
    /// Integer parameter (e.g. `max_depth`, `n_estimators`).
    Int(usize),
    /// Floating-point parameter (e.g. `learning_rate`).
    Float(f64),
    /// Boolean parameter (e.g. `bootstrap`).
    Bool(bool),
    /// Categorical / string parameter (e.g. `criterion = "gini"`).
    Categorical(String),
}

71impl std::fmt::Display for ParamValue {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        match self {
74            ParamValue::Int(v) => write!(f, "{v}"),
75            ParamValue::Float(v) => write!(f, "{v}"),
76            ParamValue::Bool(v) => write!(f, "{v}"),
77            ParamValue::Categorical(v) => write!(f, "{v}"),
78        }
79    }
80}
81
/// A grid of hyperparameter values to search over.
///
/// Keys are parameter names (e.g. `"max_depth"`), values are lists of
/// candidate values to try. This is a plain `HashMap` alias, so key
/// iteration order — and therefore combination order — is unspecified.
///
/// # Examples
///
/// ```
/// use scry_learn::search::{ParamGrid, ParamValue};
///
/// let mut grid = ParamGrid::new();
/// grid.insert("max_depth".into(), vec![
///     ParamValue::Int(2),
///     ParamValue::Int(4),
///     ParamValue::Int(8),
/// ]);
/// ```
pub type ParamGrid = HashMap<String, Vec<ParamValue>>;

// ---------------------------------------------------------------------------
// CvResult
// ---------------------------------------------------------------------------

/// Result of a single parameter combination evaluated via cross-validation.
///
/// # Examples
///
/// ```ignore
/// for r in search_result.cv_results() {
///     println!("params={:?}  mean_score={:.3}", r.params, r.mean_score);
/// }
/// ```
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CvResult {
    /// The parameter combination that was evaluated.
    pub params: HashMap<String, ParamValue>,
    /// Mean score across all CV folds (arithmetic mean of `fold_scores`).
    pub mean_score: f64,
    /// Individual fold scores, one per CV fold, in fold order.
    pub fold_scores: Vec<f64>,
}

// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

129/// Generate the cartesian product of all parameter lists.
130pub(super) fn cartesian_product(grid: &ParamGrid) -> Vec<HashMap<String, ParamValue>> {
131    let keys: Vec<&String> = grid.keys().collect();
132    if keys.is_empty() {
133        return Vec::new();
134    }
135
136    let mut combos: Vec<HashMap<String, ParamValue>> = vec![HashMap::new()];
137
138    for key in &keys {
139        let values = &grid[*key];
140        let mut new_combos = Vec::with_capacity(combos.len() * values.len());
141        for combo in &combos {
142            for val in values {
143                let mut c = combo.clone();
144                c.insert((*key).clone(), val.clone());
145                new_combos.push(c);
146            }
147        }
148        combos = new_combos;
149    }
150
151    combos
152}
153
154/// Evaluate a single parameter combination via k-fold CV.
155pub(super) fn evaluate_combo(
156    base: &dyn Tunable,
157    params: &HashMap<String, ParamValue>,
158    folds: &[(Dataset, Dataset)],
159    scorer: ScoringFn,
160) -> Result<CvResult> {
161    let mut scores = Vec::with_capacity(folds.len());
162
163    for (train, test) in folds {
164        let mut model = base.clone_box();
165        for (name, value) in params {
166            model.set_param(name, value.clone())?;
167        }
168        model.fit(train)?;
169        let features = test.feature_matrix();
170        let preds = model.predict(&features)?;
171        scores.push(scorer(&test.target, &preds));
172    }
173
174    let mean = scores.iter().sum::<f64>() / scores.len() as f64;
175
176    Ok(CvResult {
177        params: params.clone(),
178        mean_score: mean,
179        fold_scores: scores,
180    })
181}
182
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tree::{DecisionTreeClassifier, RandomForestClassifier};

    /// Build an Iris-like dataset with 3 well-separated classes.
    ///
    /// 30 samples per class, 4 features; class clusters are far apart so
    /// any reasonable classifier should score well. Seeded RNG keeps the
    /// dataset deterministic across runs.
    fn iris_like() -> Dataset {
        let n_per_class = 30;
        let n = n_per_class * 3;
        let mut f0 = Vec::with_capacity(n);
        let mut f1 = Vec::with_capacity(n);
        let mut f2 = Vec::with_capacity(n);
        let mut f3 = Vec::with_capacity(n);
        let mut target = Vec::with_capacity(n);

        let mut rng = crate::rng::FastRng::new(123);

        for _ in 0..n_per_class {
            // Class 0: small values
            f0.push(1.0 + rng.f64() * 0.5);
            f1.push(1.0 + rng.f64() * 0.5);
            f2.push(0.5 + rng.f64() * 0.3);
            f3.push(0.1 + rng.f64() * 0.2);
            target.push(0.0);
        }
        for _ in 0..n_per_class {
            // Class 1: medium values
            f0.push(5.0 + rng.f64() * 0.5);
            f1.push(3.0 + rng.f64() * 0.5);
            f2.push(3.5 + rng.f64() * 0.5);
            f3.push(1.0 + rng.f64() * 0.3);
            target.push(1.0);
        }
        for _ in 0..n_per_class {
            // Class 2: large values
            f0.push(6.5 + rng.f64() * 0.5);
            f1.push(3.0 + rng.f64() * 0.5);
            f2.push(5.5 + rng.f64() * 0.5);
            f3.push(2.0 + rng.f64() * 0.3);
            target.push(2.0);
        }

        Dataset::new(
            vec![f0, f1, f2, f3],
            target,
            vec![
                "sepal_len".into(),
                "sepal_wid".into(),
                "petal_len".into(),
                "petal_wid".into(),
            ],
            "species",
        )
    }

    #[test]
    fn test_grid_search_dt() {
        let data = iris_like();
        let mut grid = ParamGrid::new();
        grid.insert(
            "max_depth".into(),
            vec![
                ParamValue::Int(2),
                ParamValue::Int(4),
                ParamValue::Int(6),
                ParamValue::Int(8),
            ],
        );

        let result = GridSearchCV::new(DecisionTreeClassifier::new(), grid)
            .cv(3)
            .scoring(crate::metrics::accuracy)
            .seed(42)
            .fit(&data)
            .unwrap();

        // Should find a reasonable best score on well-separated data.
        assert!(
            result.best_score() > 0.7,
            "best score {:.3} too low",
            result.best_score()
        );
        // Should have evaluated all 4 combos.
        assert_eq!(result.cv_results().len(), 4);
        // Best params must include max_depth.
        assert!(result.best_params().contains_key("max_depth"));
    }

    #[test]
    fn test_randomized_search_rf() {
        let data = iris_like();
        let mut grid = ParamGrid::new();
        grid.insert(
            "n_estimators".into(),
            vec![ParamValue::Int(3), ParamValue::Int(5), ParamValue::Int(10)],
        );
        grid.insert(
            "max_depth".into(),
            vec![ParamValue::Int(2), ParamValue::Int(4), ParamValue::Int(6)],
        );

        let result = RandomizedSearchCV::new(RandomForestClassifier::new(), grid)
            .n_iter(5)
            .cv(3)
            .seed(99)
            .fit(&data)
            .unwrap();

        // Should have evaluated exactly 5 combos (out of 9 total).
        assert_eq!(result.cv_results().len(), 5);
        assert!(
            result.best_score() > 0.5,
            "randomized best score too low: {:.3}",
            result.best_score()
        );
        assert!(result.best_params().contains_key("n_estimators"));
        assert!(result.best_params().contains_key("max_depth"));
    }

    #[test]
    fn test_cartesian_product() {
        // 2 values for "a" x 2 values for "b" => 4 combinations.
        let mut grid = ParamGrid::new();
        grid.insert("a".into(), vec![ParamValue::Int(1), ParamValue::Int(2)]);
        grid.insert(
            "b".into(),
            vec![ParamValue::Float(0.1), ParamValue::Float(0.2)],
        );
        let combos = cartesian_product(&grid);
        assert_eq!(combos.len(), 4);
    }

    #[test]
    fn test_invalid_param() {
        // Type mismatch (Float for an Int param) and unknown names must
        // both be rejected by set_param.
        let mut dt = DecisionTreeClassifier::new();
        let err = dt.set_param("max_depth", ParamValue::Float(3.5));
        assert!(err.is_err());
        let err = dt.set_param("nonexistent", ParamValue::Int(3));
        assert!(err.is_err());
    }

    #[test]
    fn test_empty_grid() {
        // An empty grid means nothing to search; fit must error rather
        // than silently return no results.
        let data = iris_like();
        let grid = ParamGrid::new();
        let result = GridSearchCV::new(DecisionTreeClassifier::new(), grid).fit(&data);
        assert!(result.is_err());
    }

    #[test]
    fn test_grid_search_logistic() {
        let data = iris_like();
        let mut grid = ParamGrid::new();
        grid.insert(
            "max_iter".into(),
            vec![ParamValue::Int(50), ParamValue::Int(200)],
        );
        let result = GridSearchCV::new(crate::linear::LogisticRegression::new(), grid)
            .cv(3)
            .scoring(crate::metrics::accuracy)
            .fit(&data)
            .unwrap();

        assert_eq!(result.cv_results().len(), 2);
        assert!(
            result.best_score() > 0.5,
            "logistic best score too low: {:.3}",
            result.best_score()
        );
        assert!(result.best_params().contains_key("max_iter"));
    }

    #[test]
    fn test_grid_search_knn() {
        let data = iris_like();
        let mut grid = ParamGrid::new();
        grid.insert(
            "k".into(),
            vec![ParamValue::Int(1), ParamValue::Int(3), ParamValue::Int(5)],
        );
        let result = GridSearchCV::new(crate::neighbors::KnnClassifier::new(), grid)
            .cv(3)
            .scoring(crate::metrics::accuracy)
            .fit(&data)
            .unwrap();

        assert_eq!(result.cv_results().len(), 3);
        assert!(
            result.best_score() > 0.7,
            "knn best score too low: {:.3}",
            result.best_score()
        );
        assert!(result.best_params().contains_key("k"));
    }

    #[test]
    fn test_grid_search_gbc() {
        let data = iris_like();
        let mut grid = ParamGrid::new();
        grid.insert(
            "n_estimators".into(),
            vec![ParamValue::Int(10), ParamValue::Int(20)],
        );
        grid.insert(
            "max_depth".into(),
            vec![ParamValue::Int(2), ParamValue::Int(3)],
        );
        let result = GridSearchCV::new(crate::tree::GradientBoostingClassifier::new(), grid)
            .cv(3)
            .scoring(crate::metrics::accuracy)
            .fit(&data)
            .unwrap();

        assert_eq!(result.cv_results().len(), 4);
        assert!(
            result.best_score() > 0.6,
            "gbc best score too low: {:.3}",
            result.best_score()
        );
        assert!(result.best_params().contains_key("n_estimators"));
        assert!(result.best_params().contains_key("max_depth"));
    }

    #[test]
    fn test_grid_search_lasso() {
        // Regression dataset: y = 2*x + noise.
        let n = 60;
        let mut rng = crate::rng::FastRng::new(42);
        let x: Vec<f64> = (0..n).map(|i| i as f64 / 10.0).collect();
        let target: Vec<f64> = x.iter().map(|&xi| 2.0 * xi + rng.f64() * 0.5).collect();
        let data = crate::dataset::Dataset::new(vec![x], target, vec!["x".into()], "y");
        let mut grid = ParamGrid::new();
        grid.insert(
            "alpha".into(),
            vec![
                ParamValue::Float(0.01),
                ParamValue::Float(0.1),
                ParamValue::Float(1.0),
            ],
        );
        let result = GridSearchCV::new(crate::linear::LassoRegression::new(), grid)
            .cv(3)
            .scoring(crate::metrics::r2_score)
            .fit(&data)
            .unwrap();

        assert_eq!(result.cv_results().len(), 3);
        assert!(
            result.best_score() > 0.5,
            "lasso r2 too low: {:.3}",
            result.best_score()
        );
        assert!(result.best_params().contains_key("alpha"));
    }

    #[test]
    fn test_categorical_display() {
        // Display must print the raw string, unquoted.
        let v = ParamValue::Categorical("gini".into());
        assert_eq!(format!("{v}"), "gini");
    }

    #[test]
    fn test_grid_search_stratified() {
        let data = iris_like();
        let mut grid = ParamGrid::new();
        grid.insert(
            "max_depth".into(),
            vec![ParamValue::Int(2), ParamValue::Int(4)],
        );

        let result = GridSearchCV::new(DecisionTreeClassifier::new(), grid)
            .cv(3)
            .stratified(true)
            .scoring(crate::metrics::accuracy)
            .seed(42)
            .fit(&data)
            .unwrap();

        assert_eq!(result.cv_results().len(), 2);
        assert!(
            result.best_score() > 0.7,
            "stratified best score {:.3} too low",
            result.best_score()
        );
    }

    #[test]
    fn test_randomized_search_stratified() {
        let data = iris_like();
        let mut grid = ParamGrid::new();
        grid.insert(
            "max_depth".into(),
            vec![ParamValue::Int(2), ParamValue::Int(4), ParamValue::Int(6)],
        );

        let result = RandomizedSearchCV::new(DecisionTreeClassifier::new(), grid)
            .n_iter(2)
            .cv(3)
            .stratified(true)
            .seed(99)
            .fit(&data)
            .unwrap();

        assert_eq!(result.cv_results().len(), 2);
        assert!(
            result.best_score() > 0.5,
            "stratified randomized best score {:.3} too low",
            result.best_score()
        );
    }
}