// SPDX-License-Identifier: MIT OR Apache-2.0
//! Randomized search with cross-validation.

use std::collections::HashMap;

use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
use crate::metrics::accuracy;
use crate::split::{k_fold, stratified_k_fold, ScoringFn};

use super::{cartesian_product, evaluate_combo, CvResult, ParamGrid, ParamValue, Tunable};
12
13/// Randomized search over a hyperparameter grid with cross-validation.
14///
15/// Samples `n_iter` random combinations from the grid instead of trying
16/// every one — much faster for large grids.
17///
18/// # Examples
19///
20/// ```ignore
21/// use scry_learn::prelude::*;
22/// use scry_learn::search::*;
23///
24/// let mut grid = ParamGrid::new();
25/// grid.insert("max_depth".into(), vec![
26///     ParamValue::Int(2), ParamValue::Int(4),
27///     ParamValue::Int(6), ParamValue::Int(8),
28/// ]);
29///
30/// let result = RandomizedSearchCV::new(DecisionTreeClassifier::new(), grid)
31///     .n_iter(5)
32///     .cv(3)
33///     .fit(&data)
34///     .unwrap();
35/// ```
36#[non_exhaustive]
37pub struct RandomizedSearchCV {
38    base_model: Box<dyn Tunable>,
39    param_grid: ParamGrid,
40    n_iter: usize,
41    cv: usize,
42    scorer: ScoringFn,
43    seed: u64,
44    stratified: bool,
45    best_params_: Option<HashMap<String, ParamValue>>,
46    best_score_: f64,
47    cv_results_: Vec<CvResult>,
48}
49
50impl RandomizedSearchCV {
51    /// Create a randomized search with `n_iter` random samples.
52    ///
53    /// Defaults: 10 iterations, 5-fold CV, accuracy scorer, seed 42, non-stratified.
54    pub fn new(model: impl Tunable + 'static, grid: ParamGrid) -> Self {
55        Self {
56            base_model: Box::new(model),
57            param_grid: grid,
58            n_iter: 10,
59            cv: 5,
60            scorer: accuracy,
61            seed: 42,
62            stratified: false,
63            best_params_: None,
64            best_score_: f64::NEG_INFINITY,
65            cv_results_: Vec::new(),
66        }
67    }
68
69    /// Set the number of random combinations to try (default: 10).
70    pub fn n_iter(mut self, n: usize) -> Self {
71        self.n_iter = n;
72        self
73    }
74
75    /// Set the number of cross-validation folds (default: 5).
76    pub fn cv(mut self, k: usize) -> Self {
77        self.cv = k;
78        self
79    }
80
81    /// Set the scoring function (default: `accuracy`).
82    pub fn scoring(mut self, scorer: ScoringFn) -> Self {
83        self.scorer = scorer;
84        self
85    }
86
87    /// Set the random seed (default: 42).
88    pub fn seed(mut self, seed: u64) -> Self {
89        self.seed = seed;
90        self
91    }
92
93    /// Enable stratified k-fold CV (default: `false`).
94    ///
95    /// When `true`, uses [`stratified_k_fold`](crate::split::stratified_k_fold)
96    /// to preserve class proportions in each fold.
97    pub fn stratified(mut self, stratified: bool) -> Self {
98        self.stratified = stratified;
99        self
100    }
101
102    /// Run the randomized search.
103    ///
104    /// Samples up to `n_iter` random parameter combinations from the grid.
105    pub fn fit(mut self, data: &Dataset) -> Result<Self> {
106        if self.cv < 2 {
107            return Err(ScryLearnError::InvalidParameter(format!(
108                "cv must be >= 2, got {}",
109                self.cv
110            )));
111        }
112        let all_combos = cartesian_product(&self.param_grid);
113        if all_combos.is_empty() {
114            return Err(ScryLearnError::InvalidParameter(
115                "parameter grid is empty".into(),
116            ));
117        }
118
119        let folds = if self.stratified {
120            stratified_k_fold(data, self.cv, self.seed)
121        } else {
122            k_fold(data, self.cv, self.seed)
123        };
124        let mut rng = crate::rng::FastRng::new(self.seed);
125
126        // Sample n_iter unique combos (or all if grid is smaller).
127        let n = self.n_iter.min(all_combos.len());
128        let mut indices: Vec<usize> = (0..all_combos.len()).collect();
129        // Fisher-Yates shuffle and take first n.
130        for i in (1..indices.len()).rev() {
131            let j = rng.usize(0..=i);
132            indices.swap(i, j);
133        }
134
135        for &idx in &indices[..n] {
136            let combo = &all_combos[idx];
137            let result = evaluate_combo(&*self.base_model, combo, &folds, self.scorer)?;
138
139            if result.mean_score.is_finite()
140                && (self.best_params_.is_none() || result.mean_score > self.best_score_)
141            {
142                self.best_score_ = result.mean_score;
143                self.best_params_ = Some(result.params.clone());
144            }
145            self.cv_results_.push(result);
146        }
147
148        if self.best_params_.is_none() {
149            return Err(ScryLearnError::InvalidParameter(
150                "all parameter combinations produced NaN scores".into(),
151            ));
152        }
153
154        Ok(self)
155    }
156
157    /// The best parameter combination found.
158    ///
159    /// # Panics
160    ///
161    /// Panics if called before [`fit`](Self::fit).
162    pub fn best_params(&self) -> &HashMap<String, ParamValue> {
163        self.best_params_.as_ref().expect("call fit() first")
164    }
165
166    /// The best mean CV score achieved.
167    pub fn best_score(&self) -> f64 {
168        self.best_score_
169    }
170
171    /// All evaluated combinations with their scores.
172    pub fn cv_results(&self) -> &[CvResult] {
173        &self.cv_results_
174    }
175}