// sklears_model_selection/grid_search.rs

1//! Grid search and randomized search for hyperparameter tuning
2
3use crate::{CrossValidator, KFold, Scoring};
4use scirs2_core::ndarray::{Array1, Array2};
5use scirs2_core::rand_prelude::IndexedRandom;
6use scirs2_core::random::essentials::Normal as RandNormal;
7// use scirs2_core::random::prelude::*;
8use sklears_core::{
9    error::Result,
10    prelude::SklearsError,
11    traits::{Fit, Predict, Score},
12    types::Float,
13};
14use sklears_metrics::{classification::accuracy_score, get_scorer, regression::mean_squared_error};
15use std::collections::HashMap;
16use std::marker::PhantomData;
17
/// Parameter grid for grid search
///
/// Maps each hyperparameter name to the list of candidate values to try;
/// [`GridSearchCV`] evaluates the Cartesian product of all entries.
/// Note: `HashMap` iteration order is unspecified, so the order in which
/// combinations are generated is not deterministic across runs.
pub type ParameterGrid = HashMap<String, Vec<ParameterValue>>;
23
/// A parameter value that can be of different types
///
/// Element type of [`ParameterGrid`] / [`ParameterSet`], allowing
/// heterogeneous hyperparameters (integers, floats, flags, names, and
/// optional values) to live in a single map.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ParameterValue {
    /// Integer parameter
    Int(i64),
    /// Float parameter
    Float(f64),
    /// Boolean parameter
    Bool(bool),
    /// String parameter
    String(String),
    /// Option integer parameter (Some/None)
    OptionalInt(Option<i64>),
    /// Option float parameter (Some/None)
    OptionalFloat(Option<f64>),
}

impl ParameterValue {
    /// Extract an integer value; `None` if this is not an `Int` variant.
    pub fn as_int(&self) -> Option<i64> {
        match self {
            ParameterValue::Int(v) => Some(*v),
            _ => None,
        }
    }

    /// Extract a float value; `None` if this is not a `Float` variant.
    pub fn as_float(&self) -> Option<f64> {
        match self {
            ParameterValue::Float(v) => Some(*v),
            _ => None,
        }
    }

    /// Extract a boolean value; `None` if this is not a `Bool` variant.
    pub fn as_bool(&self) -> Option<bool> {
        match self {
            ParameterValue::Bool(v) => Some(*v),
            _ => None,
        }
    }

    /// Extract a string slice; `None` if this is not a `String` variant.
    ///
    /// Added for consistency: every other variant already has an accessor.
    pub fn as_str(&self) -> Option<&str> {
        match self {
            ParameterValue::String(v) => Some(v.as_str()),
            _ => None,
        }
    }

    /// Extract an optional integer value (`Some(inner)` only for the
    /// `OptionalInt` variant; the inner `Option` may itself be `None`).
    pub fn as_optional_int(&self) -> Option<Option<i64>> {
        match self {
            ParameterValue::OptionalInt(v) => Some(*v),
            _ => None,
        }
    }

    /// Extract an optional float value (`Some(inner)` only for the
    /// `OptionalFloat` variant; the inner `Option` may itself be `None`).
    pub fn as_optional_float(&self) -> Option<Option<f64>> {
        match self {
            ParameterValue::OptionalFloat(v) => Some(*v),
            _ => None,
        }
    }
}
83
84impl From<i32> for ParameterValue {
85    fn from(value: i32) -> Self {
86        ParameterValue::Int(value as i64)
87    }
88}
89
90impl From<i64> for ParameterValue {
91    fn from(value: i64) -> Self {
92        ParameterValue::Int(value)
93    }
94}
95
96impl From<f32> for ParameterValue {
97    fn from(value: f32) -> Self {
98        ParameterValue::Float(value as f64)
99    }
100}
101
102impl From<f64> for ParameterValue {
103    fn from(value: f64) -> Self {
104        ParameterValue::Float(value)
105    }
106}
107
108impl From<bool> for ParameterValue {
109    fn from(value: bool) -> Self {
110        ParameterValue::Bool(value)
111    }
112}
113
114impl From<String> for ParameterValue {
115    fn from(value: String) -> Self {
116        ParameterValue::String(value)
117    }
118}
119
120impl From<&str> for ParameterValue {
121    fn from(value: &str) -> Self {
122        ParameterValue::String(value.to_string())
123    }
124}
125
126impl From<Option<i32>> for ParameterValue {
127    fn from(value: Option<i32>) -> Self {
128        ParameterValue::OptionalInt(value.map(|v| v as i64))
129    }
130}
131
132impl From<Option<i64>> for ParameterValue {
133    fn from(value: Option<i64>) -> Self {
134        ParameterValue::OptionalInt(value)
135    }
136}
137
138impl From<Option<f64>> for ParameterValue {
139    fn from(value: Option<f64>) -> Self {
140        ParameterValue::OptionalFloat(value)
141    }
142}
143
144impl Eq for ParameterValue {}
145
146impl std::hash::Hash for ParameterValue {
147    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
148        std::mem::discriminant(self).hash(state);
149        match self {
150            ParameterValue::Int(v) => v.hash(state),
151            ParameterValue::Float(v) => v.to_bits().hash(state), // Use bits representation for f64
152            ParameterValue::Bool(v) => v.hash(state),
153            ParameterValue::String(v) => v.hash(state),
154            ParameterValue::OptionalInt(v) => v.hash(state),
155            ParameterValue::OptionalFloat(v) => v.map(|f| f.to_bits()).hash(state),
156        }
157    }
158}
159
/// A single hyperparameter assignment (one point in the search space):
/// each parameter name mapped to the concrete value chosen for one
/// grid-search / randomized-search iteration.
pub type ParameterSet = HashMap<String, ParameterValue>;
162
/// Grid search cross-validation
///
/// Exhaustive search over specified parameter values for an estimator.
/// Uses cross-validation to evaluate each parameter combination.
///
/// Type parameters:
/// - `E`: the unfitted estimator (cloned once per parameter combination)
/// - `F`: the fitted estimator produced by `E::fit`
/// - `ConfigFn`: closure that applies a [`ParameterSet`] to a clone of `E`
pub struct GridSearchCV<E, F, ConfigFn>
where
    E: Clone,
    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
    F: Predict<Array2<Float>, Array1<Float>>,
    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
    ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
{
    /// Base estimator to use for each parameter combination
    estimator: E,
    /// Parameter grid to search
    param_grid: ParameterGrid,
    /// Cross-validation strategy
    cv: Box<dyn CrossValidator>,
    /// Scoring method
    scoring: Scoring,
    /// Number of parallel jobs (`None` = unset)
    // NOTE(review): this field is stored but never read by `fit` in this
    // file — evaluation currently runs sequentially.
    n_jobs: Option<usize>,
    /// Whether to refit on the entire dataset with best parameters
    refit: bool,
    /// Configuration function to apply parameters to estimator
    config_fn: ConfigFn,
    /// Phantom data for fitted estimator type
    _phantom: PhantomData<F>,
    // Fitted results — populated by `fit`, all `None` before fitting.
    /// Estimator refit on the full data with the best parameters (only if `refit`)
    best_estimator_: Option<F>,
    /// Parameter set achieving the best mean CV score
    best_params_: Option<ParameterSet>,
    /// Best mean cross-validated test score
    best_score_: Option<f64>,
    /// Per-combination CV statistics
    cv_results_: Option<GridSearchResults>,
}
197
/// Results from grid search cross-validation
///
/// One entry per evaluated parameter combination; every array is indexed
/// in the same order as `params`.
#[derive(Debug, Clone)]
pub struct GridSearchResults {
    /// The parameter combination evaluated at each index
    pub params: Vec<ParameterSet>,
    /// Mean test score across CV folds
    pub mean_test_scores: Array1<f64>,
    /// Population standard deviation of test scores across CV folds
    pub std_test_scores: Array1<f64>,
    /// Mean wall-clock fit time per fold, in seconds
    pub mean_fit_times: Array1<f64>,
    /// Mean wall-clock scoring time per fold, in seconds
    pub mean_score_times: Array1<f64>,
    /// Rank of each combination by mean test score (1 = best)
    pub rank_test_scores: Array1<usize>,
}
208
209/// Helper function for scoring that handles both regression and classification
210fn compute_score_for_regression(
211    metric_name: &str,
212    y_true: &Array1<f64>,
213    y_pred: &Array1<f64>,
214) -> Result<f64> {
215    match metric_name {
216        "neg_mean_squared_error" => Ok(-mean_squared_error(y_true, y_pred)?),
217        "mean_squared_error" => Ok(mean_squared_error(y_true, y_pred)?),
218        _ => {
219            // For unsupported metrics, return a default score
220            Err(SklearsError::InvalidInput(format!(
221                "Metric '{}' not supported for regression",
222                metric_name
223            )))
224        }
225    }
226}
227
228/// Helper function for scoring classification data
229fn compute_score_for_classification(
230    metric_name: &str,
231    y_true: &Array1<i32>,
232    y_pred: &Array1<i32>,
233) -> Result<f64> {
234    match metric_name {
235        "accuracy" => Ok(accuracy_score(y_true, y_pred)?),
236        _ => {
237            let scorer = get_scorer(metric_name)?;
238            scorer.score(y_true.as_slice().unwrap(), y_pred.as_slice().unwrap())
239        }
240    }
241}
242
243impl<E, F, ConfigFn> GridSearchCV<E, F, ConfigFn>
244where
245    E: Clone,
246    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
247    F: Predict<Array2<Float>, Array1<Float>>,
248    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
249    ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
250{
251    /// Create a new grid search CV
252    pub fn new(estimator: E, param_grid: ParameterGrid, config_fn: ConfigFn) -> Self {
253        Self {
254            estimator,
255            param_grid,
256            cv: Box::new(KFold::new(5)),
257            scoring: Scoring::EstimatorScore,
258            n_jobs: None,
259            refit: true,
260            config_fn,
261            _phantom: PhantomData,
262            best_estimator_: None,
263            best_params_: None,
264            best_score_: None,
265            cv_results_: None,
266        }
267    }
268
269    /// Set the cross-validation strategy
270    pub fn cv<C: CrossValidator + 'static>(mut self, cv: C) -> Self {
271        self.cv = Box::new(cv);
272        self
273    }
274
275    /// Set the scoring method
276    pub fn scoring(mut self, scoring: Scoring) -> Self {
277        self.scoring = scoring;
278        self
279    }
280
281    /// Set the number of parallel jobs
282    pub fn n_jobs(mut self, n_jobs: Option<usize>) -> Self {
283        self.n_jobs = n_jobs;
284        self
285    }
286
287    /// Set whether to refit with best parameters
288    pub fn refit(mut self, refit: bool) -> Self {
289        self.refit = refit;
290        self
291    }
292
293    /// Get the best estimator (after fitting)
294    pub fn best_estimator(&self) -> Option<&F> {
295        self.best_estimator_.as_ref()
296    }
297
298    /// Get the best parameters (after fitting)
299    pub fn best_params(&self) -> Option<&ParameterSet> {
300        self.best_params_.as_ref()
301    }
302
303    /// Get the best score (after fitting)
304    pub fn best_score(&self) -> Option<f64> {
305        self.best_score_
306    }
307
308    /// Get the CV results (after fitting)
309    pub fn cv_results(&self) -> Option<&GridSearchResults> {
310        self.cv_results_.as_ref()
311    }
312
313    /// Generate all parameter combinations from the grid
314    fn generate_param_combinations(&self) -> Vec<ParameterSet> {
315        let mut combinations = vec![HashMap::new()];
316
317        for (param_name, param_values) in &self.param_grid {
318            let mut new_combinations = Vec::new();
319
320            for combination in combinations {
321                for param_value in param_values {
322                    let mut new_combination = combination.clone();
323                    new_combination.insert(param_name.clone(), param_value.clone());
324                    new_combinations.push(new_combination);
325                }
326            }
327
328            combinations = new_combinations;
329        }
330
331        combinations
332    }
333
334    /// Evaluate a single parameter combination using cross-validation
335    fn evaluate_params(
336        &self,
337        params: &ParameterSet,
338        x: &Array2<Float>,
339        y: &Array1<Float>,
340    ) -> Result<(f64, f64, f64, f64)> {
341        // Configure estimator with current parameters
342        let configured_estimator = (self.config_fn)(self.estimator.clone(), params)?;
343
344        // Get CV splits
345        let splits = self.cv.split(x.nrows(), None);
346        let n_splits = splits.len();
347
348        let mut test_scores = Vec::with_capacity(n_splits);
349        let mut fit_times = Vec::with_capacity(n_splits);
350        let mut score_times = Vec::with_capacity(n_splits);
351
352        // Evaluate on each CV fold
353        for (train_idx, test_idx) in splits {
354            // Extract train and test data
355            let x_train = x.select(scirs2_core::ndarray::Axis(0), &train_idx);
356            let y_train = y.select(scirs2_core::ndarray::Axis(0), &train_idx);
357            let x_test = x.select(scirs2_core::ndarray::Axis(0), &test_idx);
358            let y_test = y.select(scirs2_core::ndarray::Axis(0), &test_idx);
359
360            // Fit the estimator
361            let start = std::time::Instant::now();
362            let fitted = configured_estimator.clone().fit(&x_train, &y_train)?;
363            let fit_time = start.elapsed().as_secs_f64();
364            fit_times.push(fit_time);
365
366            // Score on test set
367            let start = std::time::Instant::now();
368            let test_score = match &self.scoring {
369                Scoring::EstimatorScore => fitted.score(&x_test, &y_test)?,
370                Scoring::Custom(func) => {
371                    let y_pred = fitted.predict(&x_test)?;
372                    func(&y_test.to_owned(), &y_pred)?
373                }
374                Scoring::Metric(metric_name) => {
375                    let y_pred = fitted.predict(&x_test)?;
376                    compute_score_for_regression(metric_name, &y_test, &y_pred)?
377                }
378                Scoring::Scorer(_scorer) => {
379                    let y_pred = fitted.predict(&x_test)?;
380                    // Default to negative MSE for regression
381                    -mean_squared_error(&y_test, &y_pred)?
382                }
383                Scoring::MultiMetric(_metrics) => {
384                    // For multi-metric, just use the first metric for now
385                    fitted.score(&x_test, &y_test)?
386                }
387            };
388            let score_time = start.elapsed().as_secs_f64();
389            score_times.push(score_time);
390            test_scores.push(test_score);
391        }
392
393        // Calculate statistics
394        let mean_test_score = test_scores.iter().sum::<f64>() / test_scores.len() as f64;
395        let std_test_score = {
396            let variance = test_scores
397                .iter()
398                .map(|&score| (score - mean_test_score).powi(2))
399                .sum::<f64>()
400                / test_scores.len() as f64;
401            variance.sqrt()
402        };
403        let mean_fit_time = fit_times.iter().sum::<f64>() / fit_times.len() as f64;
404        let mean_score_time = score_times.iter().sum::<f64>() / score_times.len() as f64;
405
406        Ok((
407            mean_test_score,
408            std_test_score,
409            mean_fit_time,
410            mean_score_time,
411        ))
412    }
413}
414
415impl<E, F, ConfigFn> Fit<Array2<Float>, Array1<Float>> for GridSearchCV<E, F, ConfigFn>
416where
417    E: Clone + Send + Sync,
418    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
419    F: Predict<Array2<Float>, Array1<Float>> + Send + Sync,
420    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
421    ConfigFn: Fn(E, &ParameterSet) -> Result<E> + Send + Sync,
422{
423    type Fitted = GridSearchCV<E, F, ConfigFn>;
424
425    fn fit(mut self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
426        if x.nrows() == 0 {
427            return Err(SklearsError::InvalidInput(
428                "Cannot fit on empty dataset".to_string(),
429            ));
430        }
431
432        if x.nrows() != y.len() {
433            return Err(SklearsError::ShapeMismatch {
434                expected: format!("X.shape[0] = {}", x.nrows()),
435                actual: format!("y.shape[0] = {}", y.len()),
436            });
437        }
438
439        // Generate all parameter combinations
440        let param_combinations = self.generate_param_combinations();
441
442        if param_combinations.is_empty() {
443            return Err(SklearsError::InvalidInput(
444                "No parameter combinations to evaluate".to_string(),
445            ));
446        }
447
448        // Evaluate each parameter combination
449        let mut results = Vec::with_capacity(param_combinations.len());
450
451        for params in &param_combinations {
452            let (mean_score, std_score, mean_fit_time, mean_score_time) =
453                self.evaluate_params(params, x, y)?;
454
455            results.push((mean_score, std_score, mean_fit_time, mean_score_time));
456        }
457
458        // Find best parameters
459        let best_idx = results
460            .iter()
461            .enumerate()
462            .max_by(|(_, a), (_, b)| a.0.partial_cmp(&b.0).unwrap())
463            .map(|(idx, _)| idx)
464            .ok_or_else(|| SklearsError::NumericalError("No valid scores found".to_string()))?;
465
466        let best_params = param_combinations[best_idx].clone();
467        let best_score = results[best_idx].0;
468
469        // Create CV results
470        let mean_test_scores = Array1::from_vec(results.iter().map(|r| r.0).collect());
471        let std_test_scores = Array1::from_vec(results.iter().map(|r| r.1).collect());
472        let mean_fit_times = Array1::from_vec(results.iter().map(|r| r.2).collect());
473        let mean_score_times = Array1::from_vec(results.iter().map(|r| r.3).collect());
474
475        // Calculate ranks (1 = best)
476        let mut scores_with_idx: Vec<(f64, usize)> = mean_test_scores
477            .iter()
478            .enumerate()
479            .map(|(i, &score)| (score, i))
480            .collect();
481        scores_with_idx.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
482
483        let mut ranks = vec![0; param_combinations.len()];
484        for (rank, (_, idx)) in scores_with_idx.iter().enumerate() {
485            ranks[*idx] = rank + 1;
486        }
487
488        let cv_results = GridSearchResults {
489            params: param_combinations.clone(),
490            mean_test_scores,
491            std_test_scores,
492            mean_fit_times,
493            mean_score_times,
494            rank_test_scores: Array1::from_vec(ranks),
495        };
496
497        // Refit with best parameters if requested
498        let best_estimator = if self.refit {
499            let configured_estimator = (self.config_fn)(self.estimator.clone(), &best_params)?;
500            Some(configured_estimator.fit(x, y)?)
501        } else {
502            None
503        };
504
505        self.best_estimator_ = best_estimator;
506        self.best_params_ = Some(best_params);
507        self.best_score_ = Some(best_score);
508        self.cv_results_ = Some(cv_results);
509
510        Ok(self)
511    }
512}
513
514impl<E, F, ConfigFn> Predict<Array2<Float>, Array1<Float>> for GridSearchCV<E, F, ConfigFn>
515where
516    E: Clone,
517    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
518    F: Predict<Array2<Float>, Array1<Float>>,
519    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
520    ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
521{
522    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
523        match &self.best_estimator_ {
524            Some(estimator) => estimator.predict(x),
525            None => Err(SklearsError::NotFitted {
526                operation: "predict".to_string(),
527            }),
528        }
529    }
530}
531
532impl<E, F, ConfigFn> Score<Array2<Float>, Array1<Float>> for GridSearchCV<E, F, ConfigFn>
533where
534    E: Clone,
535    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
536    F: Predict<Array2<Float>, Array1<Float>>,
537    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
538    ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
539{
540    type Float = f64;
541    fn score(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<f64> {
542        match &self.best_estimator_ {
543            Some(estimator) => estimator.score(x, y),
544            None => Err(SklearsError::NotFitted {
545                operation: "score".to_string(),
546            }),
547        }
548    }
549}
550
/// Parameter distribution for randomized search
///
/// Describes how [`RandomizedSearchCV`] draws one value for a single
/// hyperparameter on each sampling iteration.
#[derive(Debug, Clone)]
pub enum ParameterDistribution {
    /// Uniform distribution over discrete values
    Choice(Vec<ParameterValue>),
    /// Uniform distribution over integer range [low, high)
    RandInt { low: i64, high: i64 },
    /// Uniform distribution over float range [low, high)
    Uniform { low: f64, high: f64 },
    /// Log-uniform distribution over float range [low, high)
    /// (both bounds must be positive for the logarithm to be defined)
    LogUniform { low: f64, high: f64 },
    /// Normal distribution with mean and std (support is unbounded)
    Normal { mean: f64, std: f64 },
}
565
impl ParameterDistribution {
    /// Sample a value from this distribution
    ///
    /// # Panics
    ///
    /// Panics (via `unwrap`) if the distribution is mis-specified — e.g. an
    /// empty `Choice` list, or bounds/parameters the underlying
    /// `Uniform::new` / `Normal::new` constructors reject (presumably
    /// `low >= high` or an invalid std — TODO confirm against the scirs2
    /// distribution APIs). For `LogUniform`, non-positive bounds make
    /// `ln()` produce NaN, which also leads to a constructor failure.
    pub fn sample(&self, rng: &mut impl scirs2_core::random::Rng) -> ParameterValue {
        use scirs2_core::essentials::Uniform;
        use scirs2_core::random::Distribution;

        match self {
            // Uniform pick over the provided discrete values.
            ParameterDistribution::Choice(values) => values.as_slice().choose(rng).unwrap().clone(),
            ParameterDistribution::RandInt { low, high } => {
                let dist = Uniform::new(*low, *high).unwrap();
                ParameterValue::Int(dist.sample(rng))
            }
            ParameterDistribution::Uniform { low, high } => {
                let dist = Uniform::new(*low, *high).unwrap();
                ParameterValue::Float(dist.sample(rng))
            }
            ParameterDistribution::LogUniform { low, high } => {
                // Implement log-uniform manually: sample uniformly in log
                // space, then exponentiate back to the original scale.
                let log_low = low.ln();
                let log_high = high.ln();
                let dist = Uniform::new(log_low, log_high).unwrap();
                let log_sample = dist.sample(rng);
                ParameterValue::Float(log_sample.exp())
            }
            ParameterDistribution::Normal { mean, std } => {
                let dist = RandNormal::new(*mean, *std).unwrap();
                ParameterValue::Float(dist.sample(rng))
            }
        }
    }
}
597
/// Parameter distribution grid for randomized search: maps each
/// hyperparameter name to the distribution its candidate values are
/// drawn from.
pub type ParameterDistributions = HashMap<String, ParameterDistribution>;
600
/// Randomized search cross-validation
///
/// Search over parameter distributions with random sampling.
/// More efficient than GridSearchCV for large parameter spaces.
///
/// Type parameters:
/// - `E`: the unfitted estimator (cloned once per sampled combination)
/// - `F`: the fitted estimator produced by `E::fit`
/// - `ConfigFn`: closure that applies a [`ParameterSet`] to a clone of `E`
pub struct RandomizedSearchCV<E, F, ConfigFn>
where
    E: Clone,
    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
    F: Predict<Array2<Float>, Array1<Float>>,
    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
    ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
{
    /// Base estimator to use for each parameter combination
    estimator: E,
    /// Parameter distributions to search
    param_distributions: ParameterDistributions,
    /// Number of parameter settings to sample
    n_iter: usize,
    /// Cross-validation strategy
    cv: Box<dyn CrossValidator>,
    /// Scoring method
    scoring: Scoring,
    /// Number of parallel jobs
    // NOTE(review): stored but never read in the visible code — evaluation
    // currently runs sequentially.
    n_jobs: Option<usize>,
    /// Whether to refit on the entire dataset with best parameters
    refit: bool,
    /// Random seed for reproducibility (a fixed default seed is used when `None`)
    random_state: Option<u64>,
    /// Configuration function to apply parameters to estimator
    config_fn: ConfigFn,
    /// Phantom data for fitted estimator type
    _phantom: PhantomData<F>,
    // Fitted results — populated by `fit`, all `None` before fitting.
    /// Estimator refit on the full data with the best parameters (only if `refit`)
    best_estimator_: Option<F>,
    /// Parameter set achieving the best mean CV score
    best_params_: Option<ParameterSet>,
    /// Best mean cross-validated test score
    best_score_: Option<f64>,
    /// Per-combination CV statistics
    cv_results_: Option<GridSearchResults>,
}
639
640impl<E, F, ConfigFn> RandomizedSearchCV<E, F, ConfigFn>
641where
642    E: Clone,
643    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
644    F: Predict<Array2<Float>, Array1<Float>>,
645    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
646    ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
647{
648    pub fn new(
649        estimator: E,
650        param_distributions: ParameterDistributions,
651        config_fn: ConfigFn,
652    ) -> Self {
653        Self {
654            estimator,
655            param_distributions,
656            n_iter: 10,
657            cv: Box::new(KFold::new(5)),
658            scoring: Scoring::EstimatorScore,
659            n_jobs: None,
660            refit: true,
661            random_state: None,
662            config_fn,
663            _phantom: PhantomData,
664            best_estimator_: None,
665            best_params_: None,
666            best_score_: None,
667            cv_results_: None,
668        }
669    }
670
671    /// Set the number of parameter settings to sample
672    pub fn n_iter(mut self, n_iter: usize) -> Self {
673        self.n_iter = n_iter;
674        self
675    }
676
677    /// Set the cross-validation strategy
678    pub fn cv<C: CrossValidator + 'static>(mut self, cv: C) -> Self {
679        self.cv = Box::new(cv);
680        self
681    }
682
683    /// Set the scoring method
684    pub fn scoring(mut self, scoring: Scoring) -> Self {
685        self.scoring = scoring;
686        self
687    }
688
689    /// Set the number of parallel jobs
690    pub fn n_jobs(mut self, n_jobs: Option<usize>) -> Self {
691        self.n_jobs = n_jobs;
692        self
693    }
694
695    /// Set whether to refit with best parameters
696    pub fn refit(mut self, refit: bool) -> Self {
697        self.refit = refit;
698        self
699    }
700
701    /// Set the random seed
702    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
703        self.random_state = random_state;
704        self
705    }
706
707    /// Get the best estimator (after fitting)
708    pub fn best_estimator(&self) -> Option<&F> {
709        self.best_estimator_.as_ref()
710    }
711
712    /// Get the best parameters (after fitting)
713    pub fn best_params(&self) -> Option<&ParameterSet> {
714        self.best_params_.as_ref()
715    }
716
717    /// Get the best score (after fitting)
718    pub fn best_score(&self) -> Option<f64> {
719        self.best_score_
720    }
721
722    /// Get the CV results (after fitting)
723    pub fn cv_results(&self) -> Option<&GridSearchResults> {
724        self.cv_results_.as_ref()
725    }
726
727    /// Sample parameter combinations from distributions
728    fn sample_parameters(&self, n_samples: usize) -> Vec<ParameterSet> {
729        use scirs2_core::random::rngs::StdRng;
730        use scirs2_core::random::SeedableRng;
731
732        let mut rng = match self.random_state {
733            Some(seed) => StdRng::seed_from_u64(seed),
734            None => StdRng::seed_from_u64(42),
735        };
736
737        let mut param_sets = Vec::with_capacity(n_samples);
738
739        for _ in 0..n_samples {
740            let mut param_set = HashMap::new();
741
742            for (param_name, distribution) in &self.param_distributions {
743                let value = distribution.sample(&mut rng);
744                param_set.insert(param_name.clone(), value);
745            }
746
747            param_sets.push(param_set);
748        }
749
750        param_sets
751    }
752
753    /// Evaluate a single parameter combination using cross-validation
754    fn evaluate_params(
755        &self,
756        params: &ParameterSet,
757        x: &Array2<Float>,
758        y: &Array1<Float>,
759    ) -> Result<(f64, f64, f64, f64)> {
760        // Configure estimator with current parameters
761        let configured_estimator = (self.config_fn)(self.estimator.clone(), params)?;
762
763        // Get CV splits
764        let splits = self.cv.split(x.nrows(), None);
765        let n_splits = splits.len();
766
767        let mut test_scores = Vec::with_capacity(n_splits);
768        let mut fit_times = Vec::with_capacity(n_splits);
769        let mut score_times = Vec::with_capacity(n_splits);
770
771        // Evaluate on each CV fold
772        for (train_idx, test_idx) in splits {
773            // Extract train and test data
774            let x_train = x.select(scirs2_core::ndarray::Axis(0), &train_idx);
775            let y_train = y.select(scirs2_core::ndarray::Axis(0), &train_idx);
776            let x_test = x.select(scirs2_core::ndarray::Axis(0), &test_idx);
777            let y_test = y.select(scirs2_core::ndarray::Axis(0), &test_idx);
778
779            // Fit the estimator
780            let start = std::time::Instant::now();
781            let fitted = configured_estimator.clone().fit(&x_train, &y_train)?;
782            let fit_time = start.elapsed().as_secs_f64();
783            fit_times.push(fit_time);
784
785            // Score on test set
786            let start = std::time::Instant::now();
787            let test_score = match &self.scoring {
788                Scoring::EstimatorScore => fitted.score(&x_test, &y_test)?,
789                Scoring::Custom(func) => {
790                    let y_pred = fitted.predict(&x_test)?;
791                    func(&y_test.to_owned(), &y_pred)?
792                }
793                Scoring::Metric(metric_name) => {
794                    let y_pred = fitted.predict(&x_test)?;
795                    compute_score_for_regression(metric_name, &y_test, &y_pred)?
796                }
797                Scoring::Scorer(_scorer) => {
798                    let y_pred = fitted.predict(&x_test)?;
799                    // Default to negative MSE for regression
800                    -mean_squared_error(&y_test, &y_pred)?
801                }
802                Scoring::MultiMetric(_metrics) => {
803                    // For multi-metric, just use the first metric for now
804                    fitted.score(&x_test, &y_test)?
805                }
806            };
807            let score_time = start.elapsed().as_secs_f64();
808            score_times.push(score_time);
809            test_scores.push(test_score);
810        }
811
812        // Calculate statistics
813        let mean_test_score = test_scores.iter().sum::<f64>() / test_scores.len() as f64;
814        let std_test_score = {
815            let variance = test_scores
816                .iter()
817                .map(|&score| (score - mean_test_score).powi(2))
818                .sum::<f64>()
819                / test_scores.len() as f64;
820            variance.sqrt()
821        };
822        let mean_fit_time = fit_times.iter().sum::<f64>() / fit_times.len() as f64;
823        let mean_score_time = score_times.iter().sum::<f64>() / score_times.len() as f64;
824
825        Ok((
826            mean_test_score,
827            std_test_score,
828            mean_fit_time,
829            mean_score_time,
830        ))
831    }
832}
833
834impl<E, F, ConfigFn> Fit<Array2<Float>, Array1<Float>> for RandomizedSearchCV<E, F, ConfigFn>
835where
836    E: Clone + Send + Sync,
837    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
838    F: Predict<Array2<Float>, Array1<Float>> + Send + Sync,
839    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
840    ConfigFn: Fn(E, &ParameterSet) -> Result<E> + Send + Sync,
841{
842    type Fitted = RandomizedSearchCV<E, F, ConfigFn>;
843
844    fn fit(mut self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
845        if x.nrows() == 0 {
846            return Err(SklearsError::InvalidInput(
847                "Cannot fit on empty dataset".to_string(),
848            ));
849        }
850
851        if x.nrows() != y.len() {
852            return Err(SklearsError::ShapeMismatch {
853                expected: format!("X.shape[0] = {}", x.nrows()),
854                actual: format!("y.shape[0] = {}", y.len()),
855            });
856        }
857
858        if self.param_distributions.is_empty() {
859            return Err(SklearsError::InvalidInput(
860                "No parameter distributions to sample from".to_string(),
861            ));
862        }
863
864        // Sample parameter combinations
865        let param_combinations = self.sample_parameters(self.n_iter);
866
867        // Evaluate each parameter combination
868        let mut results = Vec::with_capacity(param_combinations.len());
869
870        for params in &param_combinations {
871            let (mean_score, std_score, mean_fit_time, mean_score_time) =
872                self.evaluate_params(params, x, y)?;
873
874            results.push((mean_score, std_score, mean_fit_time, mean_score_time));
875        }
876
877        // Find best parameters
878        let best_idx = results
879            .iter()
880            .enumerate()
881            .max_by(|(_, a), (_, b)| a.0.partial_cmp(&b.0).unwrap())
882            .map(|(idx, _)| idx)
883            .ok_or_else(|| SklearsError::NumericalError("No valid scores found".to_string()))?;
884
885        let best_params = param_combinations[best_idx].clone();
886        let best_score = results[best_idx].0;
887
888        // Create CV results
889        let mean_test_scores = Array1::from_vec(results.iter().map(|r| r.0).collect());
890        let std_test_scores = Array1::from_vec(results.iter().map(|r| r.1).collect());
891        let mean_fit_times = Array1::from_vec(results.iter().map(|r| r.2).collect());
892        let mean_score_times = Array1::from_vec(results.iter().map(|r| r.3).collect());
893
894        // Calculate ranks (1 = best)
895        let mut scores_with_idx: Vec<(f64, usize)> = mean_test_scores
896            .iter()
897            .enumerate()
898            .map(|(i, &score)| (score, i))
899            .collect();
900        scores_with_idx.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
901
902        let mut ranks = vec![0; param_combinations.len()];
903        for (rank, (_, idx)) in scores_with_idx.iter().enumerate() {
904            ranks[*idx] = rank + 1;
905        }
906
907        let cv_results = GridSearchResults {
908            params: param_combinations.clone(),
909            mean_test_scores,
910            std_test_scores,
911            mean_fit_times,
912            mean_score_times,
913            rank_test_scores: Array1::from_vec(ranks),
914        };
915
916        // Refit with best parameters if requested
917        let best_estimator = if self.refit {
918            let configured_estimator = (self.config_fn)(self.estimator.clone(), &best_params)?;
919            Some(configured_estimator.fit(x, y)?)
920        } else {
921            None
922        };
923
924        self.best_estimator_ = best_estimator;
925        self.best_params_ = Some(best_params);
926        self.best_score_ = Some(best_score);
927        self.cv_results_ = Some(cv_results);
928
929        Ok(self)
930    }
931}
932
933impl<E, F, ConfigFn> Predict<Array2<Float>, Array1<Float>> for RandomizedSearchCV<E, F, ConfigFn>
934where
935    E: Clone,
936    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
937    F: Predict<Array2<Float>, Array1<Float>>,
938    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
939    ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
940{
941    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
942        match &self.best_estimator_ {
943            Some(estimator) => estimator.predict(x),
944            None => Err(SklearsError::NotFitted {
945                operation: "predict".to_string(),
946            }),
947        }
948    }
949}
950
951impl<E, F, ConfigFn> Score<Array2<Float>, Array1<Float>> for RandomizedSearchCV<E, F, ConfigFn>
952where
953    E: Clone,
954    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
955    F: Predict<Array2<Float>, Array1<Float>>,
956    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
957    ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
958{
959    type Float = f64;
960    fn score(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<f64> {
961        match &self.best_estimator_ {
962            Some(estimator) => estimator.score(x, y),
963            None => Err(SklearsError::NotFitted {
964                operation: "score".to_string(),
965            }),
966        }
967    }
968}
969
970#[allow(non_snake_case)]
971#[cfg(test)]
972mod tests {
973    use super::*;
974    use crate::KFold;
975    use scirs2_core::ndarray::array;
976    // use sklears_ensemble::GradientBoostingRegressor;
977
    // Mock estimator for testing.
    //
    // Stands in for a real regressor (the sklears_ensemble GradientBoostingRegressor,
    // currently unavailable — see the commented import above) so the search
    // machinery can be exercised without that dependency.
    #[derive(Debug, Clone)]
    struct MockRegressor {
        n_estimators: usize, // builder-set hyperparameter; unused by fit/predict below
        learning_rate: f64, // builder-set hyperparameter; unused by fit/predict below
        random_state: Option<u64>, // seed placeholder; unused by fit/predict below
        fitted: bool, // set by fit(); predict() returns NotFitted while false
    }
986
987    impl MockRegressor {
988        fn new() -> Self {
989            Self {
990                n_estimators: 100,
991                learning_rate: 0.1,
992                random_state: None,
993                fitted: false,
994            }
995        }
996
997        fn n_estimators(mut self, n: usize) -> Self {
998            self.n_estimators = n;
999            self
1000        }
1001
1002        fn learning_rate(mut self, lr: f64) -> Self {
1003            self.learning_rate = lr;
1004            self
1005        }
1006
1007        fn random_state(mut self, state: Option<u64>) -> Self {
1008            self.random_state = state;
1009            self
1010        }
1011    }
1012
1013    impl Fit<Array2<f64>, Array1<f64>> for MockRegressor {
1014        type Fitted = MockRegressor;
1015
1016        fn fit(mut self, _x: &Array2<f64>, _y: &Array1<f64>) -> Result<Self::Fitted> {
1017            self.fitted = true;
1018            Ok(self)
1019        }
1020    }
1021
1022    impl Predict<Array2<f64>, Array1<f64>> for MockRegressor {
1023        fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
1024            if !self.fitted {
1025                return Err(SklearsError::NotFitted {
1026                    operation: "predict".to_string(),
1027                });
1028            }
1029            // Simple prediction: sum of features
1030            Ok(x.sum_axis(scirs2_core::ndarray::Axis(1)))
1031        }
1032    }
1033
1034    impl Score<Array2<f64>, Array1<f64>> for MockRegressor {
1035        type Float = f64;
1036
1037        fn score(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
1038            let y_pred = self.predict(x)?;
1039            let mse = mean_squared_error(y, &y_pred)?;
1040            Ok(-mse) // Return negative MSE as score (higher is better)
1041        }
1042    }
1043
1044    type GradientBoostingRegressor = MockRegressor;
1045
1046    #[test]
1047    fn test_parameter_value_extraction() {
1048        let int_param = ParameterValue::Int(42);
1049        assert_eq!(int_param.as_int(), Some(42));
1050        assert_eq!(int_param.as_float(), None);
1051
1052        let float_param = ParameterValue::Float(std::f64::consts::PI);
1053        assert_eq!(float_param.as_float(), Some(std::f64::consts::PI));
1054        assert_eq!(float_param.as_int(), None);
1055
1056        let opt_int_param = ParameterValue::OptionalInt(Some(10));
1057        assert_eq!(opt_int_param.as_optional_int(), Some(Some(10)));
1058    }
1059
    #[test]
    #[ignore] // Temporarily disabled due to sklears_ensemble dependency issues
    fn test_grid_search_cv() {
        // End-to-end grid search: a 2x2 parameter grid with 3-fold CV. Verifies
        // that all best_* accessors are populated, that all 4 combinations are
        // evaluated and ranked (best rank == 1), and that the refitted best
        // estimator can predict.

        // Create a simple dataset
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0],];
        let y = array![1.0, 4.0, 9.0, 16.0, 25.0]; // roughly x^2

        // Create parameter grid
        let mut param_grid = HashMap::new();
        param_grid.insert(
            "n_estimators".to_string(),
            vec![ParameterValue::Int(5), ParameterValue::Int(10)],
        );
        param_grid.insert(
            "learning_rate".to_string(),
            vec![ParameterValue::Float(0.1), ParameterValue::Float(0.3)],
        );

        // Configuration function for GradientBoostingRegressor: maps entries of
        // a ParameterSet onto the estimator's builder methods; absent keys
        // leave the corresponding setting untouched.
        let config_fn = |estimator: GradientBoostingRegressor,
                         params: &ParameterSet|
         -> Result<GradientBoostingRegressor> {
            let mut configured = estimator;

            if let Some(n_est) = params.get("n_estimators").and_then(|p| p.as_int()) {
                configured = configured.n_estimators(n_est as usize);
            }

            if let Some(lr) = params.get("learning_rate").and_then(|p| p.as_float()) {
                configured = configured.learning_rate(lr);
            }

            Ok(configured)
        };

        // Create and fit grid search
        let base_estimator = GradientBoostingRegressor::new().random_state(Some(42));
        let grid_search = GridSearchCV::new(base_estimator, param_grid, config_fn)
            .cv(KFold::new(3))
            .fit(&x, &y)
            .unwrap();

        // Check that we have results
        assert!(grid_search.best_score().is_some());
        assert!(grid_search.best_params().is_some());
        assert!(grid_search.best_estimator().is_some());
        assert!(grid_search.cv_results().is_some());

        // Check CV results structure
        let cv_results = grid_search.cv_results().unwrap();
        assert_eq!(cv_results.params.len(), 4); // 2 x 2 = 4 combinations
        assert_eq!(cv_results.mean_test_scores.len(), 4);
        assert_eq!(cv_results.rank_test_scores.len(), 4);

        // Best rank should be 1
        let best_rank = cv_results.rank_test_scores.iter().min().unwrap();
        assert_eq!(*best_rank, 1);

        // Test prediction with best estimator
        let predictions = grid_search.predict(&x).unwrap();
        assert_eq!(predictions.len(), x.nrows());
    }
1122
1123    #[test]
1124    #[ignore] // Temporarily disabled due to sklears_ensemble dependency issues
1125    fn test_grid_search_empty_grid() {
1126        let x = array![[1.0], [2.0]];
1127        let y = array![1.0, 2.0];
1128
1129        let param_grid = HashMap::new(); // Empty grid
1130        let config_fn = |estimator: GradientBoostingRegressor,
1131                         _params: &ParameterSet|
1132         -> Result<GradientBoostingRegressor> { Ok(estimator) };
1133
1134        let base_estimator = GradientBoostingRegressor::new();
1135        let result = GridSearchCV::new(base_estimator, param_grid, config_fn)
1136            .cv(KFold::new(2)) // Use 2 folds for 2 samples
1137            .fit(&x, &y);
1138
1139        // Empty grid should succeed with default parameters
1140        assert!(result.is_ok());
1141        let grid_search = result.unwrap();
1142
1143        // Should have one parameter combination (empty set = default params)
1144        let cv_results = grid_search.cv_results().unwrap();
1145        assert_eq!(cv_results.params.len(), 1);
1146        assert!(cv_results.params[0].is_empty()); // Empty parameter set
1147    }
1148
1149    #[test]
1150    fn test_parameter_distribution_sampling() {
1151        use scirs2_core::random::rngs::StdRng;
1152        use scirs2_core::random::SeedableRng;
1153        let mut rng = StdRng::seed_from_u64(42);
1154
1155        // Test Choice distribution
1156        let choice_dist = ParameterDistribution::Choice(vec![
1157            ParameterValue::Int(1),
1158            ParameterValue::Int(2),
1159            ParameterValue::Int(3),
1160        ]);
1161        let sample = choice_dist.sample(&mut rng);
1162        if let ParameterValue::Int(val) = sample {
1163            assert!(val >= 1 && val <= 3);
1164        } else {
1165            panic!("Expected Int parameter value");
1166        }
1167
1168        // Test RandInt distribution
1169        let int_dist = ParameterDistribution::RandInt { low: 10, high: 20 };
1170        let sample = int_dist.sample(&mut rng);
1171        if let ParameterValue::Int(val) = sample {
1172            assert!(val >= 10 && val < 20);
1173        } else {
1174            panic!("Expected Int parameter value");
1175        }
1176
1177        // Test Uniform distribution
1178        let uniform_dist = ParameterDistribution::Uniform {
1179            low: 0.0,
1180            high: 1.0,
1181        };
1182        let sample = uniform_dist.sample(&mut rng);
1183        if let ParameterValue::Float(val) = sample {
1184            assert!(val >= 0.0 && val < 1.0);
1185        } else {
1186            panic!("Expected Float parameter value");
1187        }
1188
1189        // Test Normal distribution
1190        let normal_dist = ParameterDistribution::Normal {
1191            mean: 0.0,
1192            std: 1.0,
1193        };
1194        let sample = normal_dist.sample(&mut rng);
1195        assert!(matches!(sample, ParameterValue::Float(_)));
1196    }
1197
    #[test]
    #[ignore] // Temporarily disabled due to sklears_ensemble dependency issues
    fn test_randomized_search_cv() {
        // End-to-end randomized search: 8 candidates sampled from a Choice
        // distribution (n_estimators) and a Uniform distribution (learning_rate),
        // evaluated with 3-fold CV. Verifies result bookkeeping, ranking,
        // prediction via the refitted best estimator, and that every sampled
        // value stays inside its distribution's support.

        // Create a simple dataset
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0],];
        let y = array![1.0, 4.0, 9.0, 16.0, 25.0]; // roughly x^2

        // Create parameter distributions
        let mut param_distributions = HashMap::new();
        param_distributions.insert(
            "n_estimators".to_string(),
            ParameterDistribution::Choice(vec![
                ParameterValue::Int(5),
                ParameterValue::Int(10),
                ParameterValue::Int(15),
            ]),
        );
        param_distributions.insert(
            "learning_rate".to_string(),
            ParameterDistribution::Uniform {
                low: 0.05,
                high: 0.5,
            },
        );

        // Configuration function for GradientBoostingRegressor: maps sampled
        // ParameterSet entries onto the estimator's builder methods.
        let config_fn = |estimator: GradientBoostingRegressor,
                         params: &ParameterSet|
         -> Result<GradientBoostingRegressor> {
            let mut configured = estimator;

            if let Some(n_est) = params.get("n_estimators").and_then(|p| p.as_int()) {
                configured = configured.n_estimators(n_est as usize);
            }

            if let Some(lr) = params.get("learning_rate").and_then(|p| p.as_float()) {
                configured = configured.learning_rate(lr);
            }

            Ok(configured)
        };

        // Create and fit randomized search
        let base_estimator = GradientBoostingRegressor::new().random_state(Some(42));
        let randomized_search =
            RandomizedSearchCV::new(base_estimator, param_distributions, config_fn)
                .n_iter(8) // Sample 8 parameter combinations
                .cv(KFold::new(3))
                .random_state(Some(42))
                .fit(&x, &y)
                .unwrap();

        // Check that we have results
        assert!(randomized_search.best_score().is_some());
        assert!(randomized_search.best_params().is_some());
        assert!(randomized_search.best_estimator().is_some());
        assert!(randomized_search.cv_results().is_some());

        // Check CV results structure
        let cv_results = randomized_search.cv_results().unwrap();
        assert_eq!(cv_results.params.len(), 8); // Should have 8 sampled combinations
        assert_eq!(cv_results.mean_test_scores.len(), 8);
        assert_eq!(cv_results.rank_test_scores.len(), 8);

        // Best rank should be 1
        let best_rank = cv_results.rank_test_scores.iter().min().unwrap();
        assert_eq!(*best_rank, 1);

        // Test prediction with best estimator
        let predictions = randomized_search.predict(&x).unwrap();
        assert_eq!(predictions.len(), x.nrows());

        // Check that parameter values are within expected ranges
        for params in &cv_results.params {
            if let Some(n_est) = params.get("n_estimators").and_then(|p| p.as_int()) {
                assert!(n_est == 5 || n_est == 10 || n_est == 15);
            }
            if let Some(lr) = params.get("learning_rate").and_then(|p| p.as_float()) {
                assert!(lr >= 0.05 && lr < 0.5);
            }
        }
    }
1280
1281    #[test]
1282    #[ignore] // Temporarily disabled due to sklears_ensemble dependency issues
1283    fn test_randomized_search_empty_distributions() {
1284        let x = array![[1.0], [2.0]];
1285        let y = array![1.0, 2.0];
1286
1287        let param_distributions = HashMap::new(); // Empty distributions
1288        let config_fn = |estimator: GradientBoostingRegressor,
1289                         _params: &ParameterSet|
1290         -> Result<GradientBoostingRegressor> { Ok(estimator) };
1291
1292        let base_estimator = GradientBoostingRegressor::new();
1293        let result = RandomizedSearchCV::new(base_estimator, param_distributions, config_fn)
1294            .cv(KFold::new(2))
1295            .fit(&x, &y);
1296
1297        assert!(result.is_err());
1298    }
1299
    #[test]
    #[ignore] // Temporarily disabled due to sklears_ensemble dependency issues
    fn test_randomized_search_reproducibility() {
        // Two searches with the same random_state (123) must sample the same
        // parameter sequence and therefore report identical best scores.
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![1.0, 2.0, 3.0, 4.0];

        // Create parameter distributions
        let mut param_distributions = HashMap::new();
        param_distributions.insert(
            "learning_rate".to_string(),
            ParameterDistribution::Uniform {
                low: 0.1,
                high: 0.5,
            },
        );

        // config_fn captures nothing, so it is Copy and can be handed to both
        // searches below without cloning.
        let config_fn = |estimator: GradientBoostingRegressor,
                         params: &ParameterSet|
         -> Result<GradientBoostingRegressor> {
            let mut configured = estimator;
            if let Some(lr) = params.get("learning_rate").and_then(|p| p.as_float()) {
                configured = configured.learning_rate(lr);
            }
            Ok(configured)
        };

        // Run twice with same random state
        let base_estimator1 = GradientBoostingRegressor::new().random_state(Some(42));
        let result1 =
            RandomizedSearchCV::new(base_estimator1, param_distributions.clone(), config_fn)
                .n_iter(5)
                .random_state(Some(123))
                .cv(KFold::new(2))
                .fit(&x, &y)
                .unwrap();

        let base_estimator2 = GradientBoostingRegressor::new().random_state(Some(42));
        let result2 = RandomizedSearchCV::new(base_estimator2, param_distributions, config_fn)
            .n_iter(5)
            .random_state(Some(123))
            .cv(KFold::new(2))
            .fit(&x, &y)
            .unwrap();

        // Should get identical results
        assert_eq!(result1.best_score(), result2.best_score());

        let params1 = result1.cv_results().unwrap();
        let params2 = result2.cv_results().unwrap();

        // Check that the same parameters were sampled
        for (p1, p2) in params1.params.iter().zip(params2.params.iter()) {
            assert_eq!(p1, p2);
        }
    }
1355}