scry_learn/search/
grid.rs1use std::collections::HashMap;
5
6use crate::dataset::Dataset;
7use crate::error::{Result, ScryLearnError};
8use crate::metrics::accuracy;
9use crate::split::{k_fold, stratified_k_fold, ScoringFn};
10
11use super::{cartesian_product, evaluate_combo, CvResult, ParamGrid, ParamValue, Tunable};
12
13#[non_exhaustive]
38pub struct GridSearchCV {
39 base_model: Box<dyn Tunable>,
40 param_grid: ParamGrid,
41 cv: usize,
42 scorer: ScoringFn,
43 seed: u64,
44 stratified: bool,
45 best_params_: Option<HashMap<String, ParamValue>>,
47 best_score_: f64,
48 cv_results_: Vec<CvResult>,
49}
50
51impl GridSearchCV {
52 pub fn new(model: impl Tunable + 'static, grid: ParamGrid) -> Self {
56 Self {
57 base_model: Box::new(model),
58 param_grid: grid,
59 cv: 5,
60 scorer: accuracy,
61 seed: 42,
62 stratified: false,
63 best_params_: None,
64 best_score_: f64::NEG_INFINITY,
65 cv_results_: Vec::new(),
66 }
67 }
68
69 pub fn cv(mut self, k: usize) -> Self {
71 self.cv = k;
72 self
73 }
74
75 pub fn scoring(mut self, scorer: ScoringFn) -> Self {
77 self.scorer = scorer;
78 self
79 }
80
81 pub fn seed(mut self, seed: u64) -> Self {
83 self.seed = seed;
84 self
85 }
86
87 pub fn stratified(mut self, stratified: bool) -> Self {
92 self.stratified = stratified;
93 self
94 }
95
96 pub fn fit(mut self, data: &Dataset) -> Result<Self> {
100 if self.cv < 2 {
101 return Err(ScryLearnError::InvalidParameter(format!(
102 "cv must be >= 2, got {}",
103 self.cv
104 )));
105 }
106 let combos = cartesian_product(&self.param_grid);
107 if combos.is_empty() {
108 return Err(ScryLearnError::InvalidParameter(
109 "parameter grid is empty".into(),
110 ));
111 }
112
113 let folds = if self.stratified {
114 stratified_k_fold(data, self.cv, self.seed)
115 } else {
116 k_fold(data, self.cv, self.seed)
117 };
118
119 for combo in &combos {
120 let result = evaluate_combo(&*self.base_model, combo, &folds, self.scorer)?;
121
122 if result.mean_score.is_finite()
123 && (self.best_params_.is_none() || result.mean_score > self.best_score_)
124 {
125 self.best_score_ = result.mean_score;
126 self.best_params_ = Some(result.params.clone());
127 }
128 self.cv_results_.push(result);
129 }
130
131 if self.best_params_.is_none() {
132 return Err(ScryLearnError::InvalidParameter(
133 "all parameter combinations produced NaN scores".into(),
134 ));
135 }
136
137 Ok(self)
138 }
139
140 pub fn best_params(&self) -> &HashMap<String, ParamValue> {
146 self.best_params_.as_ref().expect("call fit() first")
147 }
148
149 pub fn best_score(&self) -> f64 {
151 self.best_score_
152 }
153
154 pub fn cv_results(&self) -> &[CvResult] {
156 &self.cv_results_
157 }
158}