// scry_learn/search/grid.rs
1// SPDX-License-Identifier: MIT OR Apache-2.0
2//! Exhaustive grid search with cross-validation.
3
4use std::collections::HashMap;
5
6use crate::dataset::Dataset;
7use crate::error::{Result, ScryLearnError};
8use crate::metrics::accuracy;
9use crate::split::{k_fold, stratified_k_fold, ScoringFn};
10
11use super::{cartesian_product, evaluate_combo, CvResult, ParamGrid, ParamValue, Tunable};
12
/// Exhaustive search over a hyperparameter grid with cross-validation.
///
/// Tries every combination in the grid, evaluates each with k-fold CV,
/// and reports the best-performing parameter set.
///
/// # Examples
///
/// ```ignore
/// use scry_learn::prelude::*;
/// use scry_learn::search::*;
///
/// let mut grid = ParamGrid::new();
/// grid.insert("max_depth".into(), vec![
///     ParamValue::Int(2), ParamValue::Int(4), ParamValue::Int(8),
/// ]);
///
/// let result = GridSearchCV::new(DecisionTreeClassifier::new(), grid)
///     .cv(5)
///     .scoring(accuracy)
///     .fit(&data)
///     .unwrap();
///
/// println!("Best: {:?} → {:.3}", result.best_params(), result.best_score());
/// ```
#[non_exhaustive]
pub struct GridSearchCV {
    // Model whose hyperparameters are tuned; boxed so any `Tunable` impl works.
    base_model: Box<dyn Tunable>,
    // Candidate values per parameter name; `fit` tries their cartesian product.
    param_grid: ParamGrid,
    // Number of CV folds; `fit` rejects values below 2.
    cv: usize,
    // Scoring function applied to each fold's predictions (default: accuracy).
    scorer: ScoringFn,
    // Seed for fold generation, so searches are reproducible.
    seed: u64,
    // When true, folds are built with stratified k-fold to preserve class proportions.
    stratified: bool,
    // Results (populated after fit)
    // `None` until `fit` records at least one finite-scoring combination.
    best_params_: Option<HashMap<String, ParamValue>>,
    // Best mean CV score seen; stays NEG_INFINITY until a finite score is recorded.
    best_score_: f64,
    // One entry per evaluated combination, in grid-product order.
    cv_results_: Vec<CvResult>,
}
50
51impl GridSearchCV {
52    /// Create a grid search over the given model and parameter grid.
53    ///
54    /// Defaults: 5-fold CV, accuracy scorer, seed 42, non-stratified.
55    pub fn new(model: impl Tunable + 'static, grid: ParamGrid) -> Self {
56        Self {
57            base_model: Box::new(model),
58            param_grid: grid,
59            cv: 5,
60            scorer: accuracy,
61            seed: 42,
62            stratified: false,
63            best_params_: None,
64            best_score_: f64::NEG_INFINITY,
65            cv_results_: Vec::new(),
66        }
67    }
68
69    /// Set the number of cross-validation folds (default: 5).
70    pub fn cv(mut self, k: usize) -> Self {
71        self.cv = k;
72        self
73    }
74
75    /// Set the scoring function (default: `accuracy`).
76    pub fn scoring(mut self, scorer: ScoringFn) -> Self {
77        self.scorer = scorer;
78        self
79    }
80
81    /// Set the random seed for fold generation (default: 42).
82    pub fn seed(mut self, seed: u64) -> Self {
83        self.seed = seed;
84        self
85    }
86
87    /// Enable stratified k-fold CV (default: `false`).
88    ///
89    /// When `true`, uses [`stratified_k_fold`](crate::split::stratified_k_fold)
90    /// to preserve class proportions in each fold.
91    pub fn stratified(mut self, stratified: bool) -> Self {
92        self.stratified = stratified;
93        self
94    }
95
96    /// Run the exhaustive grid search.
97    ///
98    /// Returns `self` for chained accessor calls.
99    pub fn fit(mut self, data: &Dataset) -> Result<Self> {
100        if self.cv < 2 {
101            return Err(ScryLearnError::InvalidParameter(format!(
102                "cv must be >= 2, got {}",
103                self.cv
104            )));
105        }
106        let combos = cartesian_product(&self.param_grid);
107        if combos.is_empty() {
108            return Err(ScryLearnError::InvalidParameter(
109                "parameter grid is empty".into(),
110            ));
111        }
112
113        let folds = if self.stratified {
114            stratified_k_fold(data, self.cv, self.seed)
115        } else {
116            k_fold(data, self.cv, self.seed)
117        };
118
119        for combo in &combos {
120            let result = evaluate_combo(&*self.base_model, combo, &folds, self.scorer)?;
121
122            if result.mean_score.is_finite()
123                && (self.best_params_.is_none() || result.mean_score > self.best_score_)
124            {
125                self.best_score_ = result.mean_score;
126                self.best_params_ = Some(result.params.clone());
127            }
128            self.cv_results_.push(result);
129        }
130
131        if self.best_params_.is_none() {
132            return Err(ScryLearnError::InvalidParameter(
133                "all parameter combinations produced NaN scores".into(),
134            ));
135        }
136
137        Ok(self)
138    }
139
140    /// The best parameter combination found.
141    ///
142    /// # Panics
143    ///
144    /// Panics if called before [`fit`](Self::fit).
145    pub fn best_params(&self) -> &HashMap<String, ParamValue> {
146        self.best_params_.as_ref().expect("call fit() first")
147    }
148
    /// The best mean CV score achieved.
    ///
    /// Stays at `f64::NEG_INFINITY` if [`fit`](Self::fit) has not run yet.
    pub fn best_score(&self) -> f64 {
        self.best_score_
    }
153
154    /// All evaluated combinations with their scores.
155    pub fn cv_results(&self) -> &[CvResult] {
156        &self.cv_results_
157    }
158}