Skip to main content

sklears_model_selection/
grid_search.rs

1//! Grid search and randomized search for hyperparameter tuning
2
3use crate::{CrossValidator, KFold, Scoring};
4use scirs2_core::ndarray::{Array1, Array2};
5use scirs2_core::rand_prelude::IndexedRandom;
6use scirs2_core::random::essentials::Normal as RandNormal;
7// use scirs2_core::random::prelude::*;
8use sklears_core::{
9    error::Result,
10    prelude::SklearsError,
11    traits::{Fit, Predict, Score},
12    types::Float,
13};
14use sklears_metrics::{classification::accuracy_score, get_scorer, regression::mean_squared_error};
15use std::collections::HashMap;
16use std::marker::PhantomData;
17
/// Candidate values for each hyperparameter, keyed by parameter name.
///
/// A grid search evaluates the Cartesian product of these lists, i.e. every
/// combination that takes exactly one value per parameter. A parameter mapped
/// to an empty vector therefore produces no combinations at all.
pub type ParameterGrid = HashMap<String, Vec<ParameterValue>>;

/// A single hyperparameter value of one of several supported types.
///
/// Equality is variant-sensitive: `Int(3)`, `Float(3.0)` and
/// `OptionalInt(Some(3))` are three different values. Float payloads compare
/// with IEEE-754 semantics, so a `NaN` payload is never equal to itself.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ParameterValue {
    /// Integer parameter
    Int(i64),
    /// Floating-point parameter
    Float(f64),
    /// Boolean parameter
    Bool(bool),
    /// String parameter
    String(String),
    /// Optional integer parameter (`Some(n)` or `None`)
    OptionalInt(Option<i64>),
    /// Optional floating-point parameter (`Some(x)` or `None`)
    OptionalFloat(Option<f64>),
}
41
42impl ParameterValue {
43    /// Extract an integer value
44    pub fn as_int(&self) -> Option<i64> {
45        match self {
46            ParameterValue::Int(v) => Some(*v),
47            _ => None,
48        }
49    }
50
51    /// Extract a float value
52    pub fn as_float(&self) -> Option<f64> {
53        match self {
54            ParameterValue::Float(v) => Some(*v),
55            _ => None,
56        }
57    }
58
59    /// Extract a boolean value
60    pub fn as_bool(&self) -> Option<bool> {
61        match self {
62            ParameterValue::Bool(v) => Some(*v),
63            _ => None,
64        }
65    }
66
67    /// Extract an optional integer value
68    pub fn as_optional_int(&self) -> Option<Option<i64>> {
69        match self {
70            ParameterValue::OptionalInt(v) => Some(*v),
71            _ => None,
72        }
73    }
74
75    /// Extract an optional float value
76    pub fn as_optional_float(&self) -> Option<Option<f64>> {
77        match self {
78            ParameterValue::OptionalFloat(v) => Some(*v),
79            _ => None,
80        }
81    }
82}
83
84impl From<i32> for ParameterValue {
85    fn from(value: i32) -> Self {
86        ParameterValue::Int(value as i64)
87    }
88}
89
90impl From<i64> for ParameterValue {
91    fn from(value: i64) -> Self {
92        ParameterValue::Int(value)
93    }
94}
95
96impl From<f32> for ParameterValue {
97    fn from(value: f32) -> Self {
98        ParameterValue::Float(value as f64)
99    }
100}
101
102impl From<f64> for ParameterValue {
103    fn from(value: f64) -> Self {
104        ParameterValue::Float(value)
105    }
106}
107
108impl From<bool> for ParameterValue {
109    fn from(value: bool) -> Self {
110        ParameterValue::Bool(value)
111    }
112}
113
114impl From<String> for ParameterValue {
115    fn from(value: String) -> Self {
116        ParameterValue::String(value)
117    }
118}
119
120impl From<&str> for ParameterValue {
121    fn from(value: &str) -> Self {
122        ParameterValue::String(value.to_string())
123    }
124}
125
126impl From<Option<i32>> for ParameterValue {
127    fn from(value: Option<i32>) -> Self {
128        ParameterValue::OptionalInt(value.map(|v| v as i64))
129    }
130}
131
132impl From<Option<i64>> for ParameterValue {
133    fn from(value: Option<i64>) -> Self {
134        ParameterValue::OptionalInt(value)
135    }
136}
137
138impl From<Option<f64>> for ParameterValue {
139    fn from(value: Option<f64>) -> Self {
140        ParameterValue::OptionalFloat(value)
141    }
142}
143
144impl Eq for ParameterValue {}
145
146impl std::hash::Hash for ParameterValue {
147    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
148        std::mem::discriminant(self).hash(state);
149        match self {
150            ParameterValue::Int(v) => v.hash(state),
151            ParameterValue::Float(v) => v.to_bits().hash(state), // Use bits representation for f64
152            ParameterValue::Bool(v) => v.hash(state),
153            ParameterValue::String(v) => v.hash(state),
154            ParameterValue::OptionalInt(v) => v.hash(state),
155            ParameterValue::OptionalFloat(v) => v.map(|f| f.to_bits()).hash(state),
156        }
157    }
158}
159
160/// A parameter combination for one grid search iteration
161pub type ParameterSet = HashMap<String, ParameterValue>;
162
163/// Grid search cross-validation
164///
165/// Exhaustive search over specified parameter values for an estimator.
166/// Uses cross-validation to evaluate each parameter combination.
167pub struct GridSearchCV<E, F, ConfigFn>
168where
169    E: Clone,
170    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
171    F: Predict<Array2<Float>, Array1<Float>>,
172    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
173    ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
174{
175    /// Base estimator to use for each parameter combination
176    estimator: E,
177    /// Parameter grid to search
178    param_grid: ParameterGrid,
179    /// Cross-validation strategy
180    cv: Box<dyn CrossValidator>,
181    /// Scoring method
182    scoring: Scoring,
183    /// Number of parallel jobs (-1 for all cores)
184    n_jobs: Option<usize>,
185    /// Whether to refit on the entire dataset with best parameters
186    refit: bool,
187    /// Configuration function to apply parameters to estimator
188    config_fn: ConfigFn,
189    /// Phantom data for fitted estimator type
190    _phantom: PhantomData<F>,
191    // Fitted results
192    best_estimator_: Option<F>,
193    best_params_: Option<ParameterSet>,
194    best_score_: Option<f64>,
195    cv_results_: Option<GridSearchResults>,
196}
197
198/// Results from grid search cross-validation
199#[derive(Debug, Clone)]
200pub struct GridSearchResults {
201    pub params: Vec<ParameterSet>,
202    pub mean_test_scores: Array1<f64>,
203    pub std_test_scores: Array1<f64>,
204    pub mean_fit_times: Array1<f64>,
205    pub mean_score_times: Array1<f64>,
206    pub rank_test_scores: Array1<usize>,
207}
208
209/// Helper function for scoring that handles both regression and classification
210fn compute_score_for_regression(
211    metric_name: &str,
212    y_true: &Array1<f64>,
213    y_pred: &Array1<f64>,
214) -> Result<f64> {
215    match metric_name {
216        "neg_mean_squared_error" => Ok(-mean_squared_error(y_true, y_pred)?),
217        "mean_squared_error" => Ok(mean_squared_error(y_true, y_pred)?),
218        _ => {
219            // For unsupported metrics, return a default score
220            Err(SklearsError::InvalidInput(format!(
221                "Metric '{}' not supported for regression",
222                metric_name
223            )))
224        }
225    }
226}
227
228/// Helper function for scoring classification data
229fn compute_score_for_classification(
230    metric_name: &str,
231    y_true: &Array1<i32>,
232    y_pred: &Array1<i32>,
233) -> Result<f64> {
234    match metric_name {
235        "accuracy" => Ok(accuracy_score(y_true, y_pred)?),
236        _ => {
237            let scorer = get_scorer(metric_name)?;
238            scorer.score(
239                y_true.as_slice().expect("operation should succeed"),
240                y_pred.as_slice().expect("operation should succeed"),
241            )
242        }
243    }
244}
245
246impl<E, F, ConfigFn> GridSearchCV<E, F, ConfigFn>
247where
248    E: Clone,
249    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
250    F: Predict<Array2<Float>, Array1<Float>>,
251    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
252    ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
253{
254    /// Create a new grid search CV
255    pub fn new(estimator: E, param_grid: ParameterGrid, config_fn: ConfigFn) -> Self {
256        Self {
257            estimator,
258            param_grid,
259            cv: Box::new(KFold::new(5)),
260            scoring: Scoring::EstimatorScore,
261            n_jobs: None,
262            refit: true,
263            config_fn,
264            _phantom: PhantomData,
265            best_estimator_: None,
266            best_params_: None,
267            best_score_: None,
268            cv_results_: None,
269        }
270    }
271
272    /// Set the cross-validation strategy
273    pub fn cv<C: CrossValidator + 'static>(mut self, cv: C) -> Self {
274        self.cv = Box::new(cv);
275        self
276    }
277
278    /// Set the scoring method
279    pub fn scoring(mut self, scoring: Scoring) -> Self {
280        self.scoring = scoring;
281        self
282    }
283
284    /// Set the number of parallel jobs
285    pub fn n_jobs(mut self, n_jobs: Option<usize>) -> Self {
286        self.n_jobs = n_jobs;
287        self
288    }
289
290    /// Set whether to refit with best parameters
291    pub fn refit(mut self, refit: bool) -> Self {
292        self.refit = refit;
293        self
294    }
295
296    /// Get the best estimator (after fitting)
297    pub fn best_estimator(&self) -> Option<&F> {
298        self.best_estimator_.as_ref()
299    }
300
301    /// Get the best parameters (after fitting)
302    pub fn best_params(&self) -> Option<&ParameterSet> {
303        self.best_params_.as_ref()
304    }
305
306    /// Get the best score (after fitting)
307    pub fn best_score(&self) -> Option<f64> {
308        self.best_score_
309    }
310
311    /// Get the CV results (after fitting)
312    pub fn cv_results(&self) -> Option<&GridSearchResults> {
313        self.cv_results_.as_ref()
314    }
315
316    /// Generate all parameter combinations from the grid
317    fn generate_param_combinations(&self) -> Vec<ParameterSet> {
318        let mut combinations = vec![HashMap::new()];
319
320        for (param_name, param_values) in &self.param_grid {
321            let mut new_combinations = Vec::new();
322
323            for combination in combinations {
324                for param_value in param_values {
325                    let mut new_combination = combination.clone();
326                    new_combination.insert(param_name.clone(), param_value.clone());
327                    new_combinations.push(new_combination);
328                }
329            }
330
331            combinations = new_combinations;
332        }
333
334        combinations
335    }
336
337    /// Evaluate a single parameter combination using cross-validation
338    fn evaluate_params(
339        &self,
340        params: &ParameterSet,
341        x: &Array2<Float>,
342        y: &Array1<Float>,
343    ) -> Result<(f64, f64, f64, f64)> {
344        // Configure estimator with current parameters
345        let configured_estimator = (self.config_fn)(self.estimator.clone(), params)?;
346
347        // Get CV splits
348        let splits = self.cv.split(x.nrows(), None);
349        let n_splits = splits.len();
350
351        let mut test_scores = Vec::with_capacity(n_splits);
352        let mut fit_times = Vec::with_capacity(n_splits);
353        let mut score_times = Vec::with_capacity(n_splits);
354
355        // Evaluate on each CV fold
356        for (train_idx, test_idx) in splits {
357            // Extract train and test data
358            let x_train = x.select(scirs2_core::ndarray::Axis(0), &train_idx);
359            let y_train = y.select(scirs2_core::ndarray::Axis(0), &train_idx);
360            let x_test = x.select(scirs2_core::ndarray::Axis(0), &test_idx);
361            let y_test = y.select(scirs2_core::ndarray::Axis(0), &test_idx);
362
363            // Fit the estimator
364            let start = std::time::Instant::now();
365            let fitted = configured_estimator.clone().fit(&x_train, &y_train)?;
366            let fit_time = start.elapsed().as_secs_f64();
367            fit_times.push(fit_time);
368
369            // Score on test set
370            let start = std::time::Instant::now();
371            let test_score = match &self.scoring {
372                Scoring::EstimatorScore => fitted.score(&x_test, &y_test)?,
373                Scoring::Custom(func) => {
374                    let y_pred = fitted.predict(&x_test)?;
375                    func(&y_test.to_owned(), &y_pred)?
376                }
377                Scoring::Metric(metric_name) => {
378                    let y_pred = fitted.predict(&x_test)?;
379                    compute_score_for_regression(metric_name, &y_test, &y_pred)?
380                }
381                Scoring::Scorer(_scorer) => {
382                    let y_pred = fitted.predict(&x_test)?;
383                    // Default to negative MSE for regression
384                    -mean_squared_error(&y_test, &y_pred)?
385                }
386                Scoring::MultiMetric(_metrics) => {
387                    // For multi-metric, just use the first metric for now
388                    fitted.score(&x_test, &y_test)?
389                }
390            };
391            let score_time = start.elapsed().as_secs_f64();
392            score_times.push(score_time);
393            test_scores.push(test_score);
394        }
395
396        // Calculate statistics
397        let mean_test_score = test_scores.iter().sum::<f64>() / test_scores.len() as f64;
398        let std_test_score = {
399            let variance = test_scores
400                .iter()
401                .map(|&score| (score - mean_test_score).powi(2))
402                .sum::<f64>()
403                / test_scores.len() as f64;
404            variance.sqrt()
405        };
406        let mean_fit_time = fit_times.iter().sum::<f64>() / fit_times.len() as f64;
407        let mean_score_time = score_times.iter().sum::<f64>() / score_times.len() as f64;
408
409        Ok((
410            mean_test_score,
411            std_test_score,
412            mean_fit_time,
413            mean_score_time,
414        ))
415    }
416}
417
418impl<E, F, ConfigFn> Fit<Array2<Float>, Array1<Float>> for GridSearchCV<E, F, ConfigFn>
419where
420    E: Clone + Send + Sync,
421    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
422    F: Predict<Array2<Float>, Array1<Float>> + Send + Sync,
423    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
424    ConfigFn: Fn(E, &ParameterSet) -> Result<E> + Send + Sync,
425{
426    type Fitted = GridSearchCV<E, F, ConfigFn>;
427
428    fn fit(mut self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
429        if x.nrows() == 0 {
430            return Err(SklearsError::InvalidInput(
431                "Cannot fit on empty dataset".to_string(),
432            ));
433        }
434
435        if x.nrows() != y.len() {
436            return Err(SklearsError::ShapeMismatch {
437                expected: format!("X.shape[0] = {}", x.nrows()),
438                actual: format!("y.shape[0] = {}", y.len()),
439            });
440        }
441
442        // Generate all parameter combinations
443        let param_combinations = self.generate_param_combinations();
444
445        if param_combinations.is_empty() {
446            return Err(SklearsError::InvalidInput(
447                "No parameter combinations to evaluate".to_string(),
448            ));
449        }
450
451        // Evaluate each parameter combination
452        let mut results = Vec::with_capacity(param_combinations.len());
453
454        for params in &param_combinations {
455            let (mean_score, std_score, mean_fit_time, mean_score_time) =
456                self.evaluate_params(params, x, y)?;
457
458            results.push((mean_score, std_score, mean_fit_time, mean_score_time));
459        }
460
461        // Find best parameters
462        let best_idx = results
463            .iter()
464            .enumerate()
465            .max_by(|(_, a), (_, b)| a.0.partial_cmp(&b.0).expect("operation should succeed"))
466            .map(|(idx, _)| idx)
467            .ok_or_else(|| SklearsError::NumericalError("No valid scores found".to_string()))?;
468
469        let best_params = param_combinations[best_idx].clone();
470        let best_score = results[best_idx].0;
471
472        // Create CV results
473        let mean_test_scores = Array1::from_vec(results.iter().map(|r| r.0).collect());
474        let std_test_scores = Array1::from_vec(results.iter().map(|r| r.1).collect());
475        let mean_fit_times = Array1::from_vec(results.iter().map(|r| r.2).collect());
476        let mean_score_times = Array1::from_vec(results.iter().map(|r| r.3).collect());
477
478        // Calculate ranks (1 = best)
479        let mut scores_with_idx: Vec<(f64, usize)> = mean_test_scores
480            .iter()
481            .enumerate()
482            .map(|(i, &score)| (score, i))
483            .collect();
484        scores_with_idx.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("operation should succeed"));
485
486        let mut ranks = vec![0; param_combinations.len()];
487        for (rank, (_, idx)) in scores_with_idx.iter().enumerate() {
488            ranks[*idx] = rank + 1;
489        }
490
491        let cv_results = GridSearchResults {
492            params: param_combinations.clone(),
493            mean_test_scores,
494            std_test_scores,
495            mean_fit_times,
496            mean_score_times,
497            rank_test_scores: Array1::from_vec(ranks),
498        };
499
500        // Refit with best parameters if requested
501        let best_estimator = if self.refit {
502            let configured_estimator = (self.config_fn)(self.estimator.clone(), &best_params)?;
503            Some(configured_estimator.fit(x, y)?)
504        } else {
505            None
506        };
507
508        self.best_estimator_ = best_estimator;
509        self.best_params_ = Some(best_params);
510        self.best_score_ = Some(best_score);
511        self.cv_results_ = Some(cv_results);
512
513        Ok(self)
514    }
515}
516
517impl<E, F, ConfigFn> Predict<Array2<Float>, Array1<Float>> for GridSearchCV<E, F, ConfigFn>
518where
519    E: Clone,
520    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
521    F: Predict<Array2<Float>, Array1<Float>>,
522    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
523    ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
524{
525    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
526        match &self.best_estimator_ {
527            Some(estimator) => estimator.predict(x),
528            None => Err(SklearsError::NotFitted {
529                operation: "predict".to_string(),
530            }),
531        }
532    }
533}
534
535impl<E, F, ConfigFn> Score<Array2<Float>, Array1<Float>> for GridSearchCV<E, F, ConfigFn>
536where
537    E: Clone,
538    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
539    F: Predict<Array2<Float>, Array1<Float>>,
540    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
541    ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
542{
543    type Float = f64;
544    fn score(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<f64> {
545        match &self.best_estimator_ {
546            Some(estimator) => estimator.score(x, y),
547            None => Err(SklearsError::NotFitted {
548                operation: "score".to_string(),
549            }),
550        }
551    }
552}
553
554/// Parameter distribution for randomized search
555#[derive(Debug, Clone)]
556pub enum ParameterDistribution {
557    /// Uniform distribution over discrete values
558    Choice(Vec<ParameterValue>),
559    /// Uniform distribution over integer range [low, high)
560    RandInt { low: i64, high: i64 },
561    /// Uniform distribution over float range [low, high)
562    Uniform { low: f64, high: f64 },
563    /// Log-uniform distribution over float range [low, high)
564    LogUniform { low: f64, high: f64 },
565    /// Normal distribution with mean and std
566    Normal { mean: f64, std: f64 },
567}
568
569impl ParameterDistribution {
570    /// Sample a value from this distribution
571    pub fn sample(&self, rng: &mut impl scirs2_core::random::Rng) -> ParameterValue {
572        use scirs2_core::essentials::Uniform;
573        use scirs2_core::random::Distribution;
574
575        match self {
576            ParameterDistribution::Choice(values) => values
577                .as_slice()
578                .choose(rng)
579                .expect("operation should succeed")
580                .clone(),
581            ParameterDistribution::RandInt { low, high } => {
582                let dist = Uniform::new(*low, *high).expect("operation should succeed");
583                ParameterValue::Int(dist.sample(rng))
584            }
585            ParameterDistribution::Uniform { low, high } => {
586                let dist = Uniform::new(*low, *high).expect("operation should succeed");
587                ParameterValue::Float(dist.sample(rng))
588            }
589            ParameterDistribution::LogUniform { low, high } => {
590                // Implement log-uniform manually: sample from log scale then exponentiate
591                let log_low = low.ln();
592                let log_high = high.ln();
593                let dist = Uniform::new(log_low, log_high).expect("operation should succeed");
594                let log_sample = dist.sample(rng);
595                ParameterValue::Float(log_sample.exp())
596            }
597            ParameterDistribution::Normal { mean, std } => {
598                let dist = RandNormal::new(*mean, *std).expect("operation should succeed");
599                ParameterValue::Float(dist.sample(rng))
600            }
601        }
602    }
603}
604
605/// Parameter distribution grid for randomized search
606pub type ParameterDistributions = HashMap<String, ParameterDistribution>;
607
608/// Randomized search cross-validation
609///
610/// Search over parameter distributions with random sampling.
611/// More efficient than GridSearchCV for large parameter spaces.
612pub struct RandomizedSearchCV<E, F, ConfigFn>
613where
614    E: Clone,
615    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
616    F: Predict<Array2<Float>, Array1<Float>>,
617    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
618    ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
619{
620    /// Base estimator to use for each parameter combination
621    estimator: E,
622    /// Parameter distributions to search
623    param_distributions: ParameterDistributions,
624    /// Number of parameter settings to sample
625    n_iter: usize,
626    /// Cross-validation strategy
627    cv: Box<dyn CrossValidator>,
628    /// Scoring method
629    scoring: Scoring,
630    /// Number of parallel jobs
631    n_jobs: Option<usize>,
632    /// Whether to refit on the entire dataset with best parameters
633    refit: bool,
634    /// Random seed for reproducibility
635    random_state: Option<u64>,
636    /// Configuration function to apply parameters to estimator
637    config_fn: ConfigFn,
638    /// Phantom data for fitted estimator type
639    _phantom: PhantomData<F>,
640    // Fitted results
641    best_estimator_: Option<F>,
642    best_params_: Option<ParameterSet>,
643    best_score_: Option<f64>,
644    cv_results_: Option<GridSearchResults>,
645}
646
647impl<E, F, ConfigFn> RandomizedSearchCV<E, F, ConfigFn>
648where
649    E: Clone,
650    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
651    F: Predict<Array2<Float>, Array1<Float>>,
652    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
653    ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
654{
655    pub fn new(
656        estimator: E,
657        param_distributions: ParameterDistributions,
658        config_fn: ConfigFn,
659    ) -> Self {
660        Self {
661            estimator,
662            param_distributions,
663            n_iter: 10,
664            cv: Box::new(KFold::new(5)),
665            scoring: Scoring::EstimatorScore,
666            n_jobs: None,
667            refit: true,
668            random_state: None,
669            config_fn,
670            _phantom: PhantomData,
671            best_estimator_: None,
672            best_params_: None,
673            best_score_: None,
674            cv_results_: None,
675        }
676    }
677
678    /// Set the number of parameter settings to sample
679    pub fn n_iter(mut self, n_iter: usize) -> Self {
680        self.n_iter = n_iter;
681        self
682    }
683
684    /// Set the cross-validation strategy
685    pub fn cv<C: CrossValidator + 'static>(mut self, cv: C) -> Self {
686        self.cv = Box::new(cv);
687        self
688    }
689
690    /// Set the scoring method
691    pub fn scoring(mut self, scoring: Scoring) -> Self {
692        self.scoring = scoring;
693        self
694    }
695
696    /// Set the number of parallel jobs
697    pub fn n_jobs(mut self, n_jobs: Option<usize>) -> Self {
698        self.n_jobs = n_jobs;
699        self
700    }
701
702    /// Set whether to refit with best parameters
703    pub fn refit(mut self, refit: bool) -> Self {
704        self.refit = refit;
705        self
706    }
707
708    /// Set the random seed
709    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
710        self.random_state = random_state;
711        self
712    }
713
714    /// Get the best estimator (after fitting)
715    pub fn best_estimator(&self) -> Option<&F> {
716        self.best_estimator_.as_ref()
717    }
718
719    /// Get the best parameters (after fitting)
720    pub fn best_params(&self) -> Option<&ParameterSet> {
721        self.best_params_.as_ref()
722    }
723
724    /// Get the best score (after fitting)
725    pub fn best_score(&self) -> Option<f64> {
726        self.best_score_
727    }
728
729    /// Get the CV results (after fitting)
730    pub fn cv_results(&self) -> Option<&GridSearchResults> {
731        self.cv_results_.as_ref()
732    }
733
734    /// Sample parameter combinations from distributions
735    fn sample_parameters(&self, n_samples: usize) -> Vec<ParameterSet> {
736        use scirs2_core::random::rngs::StdRng;
737        use scirs2_core::random::SeedableRng;
738
739        let mut rng = match self.random_state {
740            Some(seed) => StdRng::seed_from_u64(seed),
741            None => StdRng::seed_from_u64(42),
742        };
743
744        let mut param_sets = Vec::with_capacity(n_samples);
745
746        for _ in 0..n_samples {
747            let mut param_set = HashMap::new();
748
749            for (param_name, distribution) in &self.param_distributions {
750                let value = distribution.sample(&mut rng);
751                param_set.insert(param_name.clone(), value);
752            }
753
754            param_sets.push(param_set);
755        }
756
757        param_sets
758    }
759
760    /// Evaluate a single parameter combination using cross-validation
761    fn evaluate_params(
762        &self,
763        params: &ParameterSet,
764        x: &Array2<Float>,
765        y: &Array1<Float>,
766    ) -> Result<(f64, f64, f64, f64)> {
767        // Configure estimator with current parameters
768        let configured_estimator = (self.config_fn)(self.estimator.clone(), params)?;
769
770        // Get CV splits
771        let splits = self.cv.split(x.nrows(), None);
772        let n_splits = splits.len();
773
774        let mut test_scores = Vec::with_capacity(n_splits);
775        let mut fit_times = Vec::with_capacity(n_splits);
776        let mut score_times = Vec::with_capacity(n_splits);
777
778        // Evaluate on each CV fold
779        for (train_idx, test_idx) in splits {
780            // Extract train and test data
781            let x_train = x.select(scirs2_core::ndarray::Axis(0), &train_idx);
782            let y_train = y.select(scirs2_core::ndarray::Axis(0), &train_idx);
783            let x_test = x.select(scirs2_core::ndarray::Axis(0), &test_idx);
784            let y_test = y.select(scirs2_core::ndarray::Axis(0), &test_idx);
785
786            // Fit the estimator
787            let start = std::time::Instant::now();
788            let fitted = configured_estimator.clone().fit(&x_train, &y_train)?;
789            let fit_time = start.elapsed().as_secs_f64();
790            fit_times.push(fit_time);
791
792            // Score on test set
793            let start = std::time::Instant::now();
794            let test_score = match &self.scoring {
795                Scoring::EstimatorScore => fitted.score(&x_test, &y_test)?,
796                Scoring::Custom(func) => {
797                    let y_pred = fitted.predict(&x_test)?;
798                    func(&y_test.to_owned(), &y_pred)?
799                }
800                Scoring::Metric(metric_name) => {
801                    let y_pred = fitted.predict(&x_test)?;
802                    compute_score_for_regression(metric_name, &y_test, &y_pred)?
803                }
804                Scoring::Scorer(_scorer) => {
805                    let y_pred = fitted.predict(&x_test)?;
806                    // Default to negative MSE for regression
807                    -mean_squared_error(&y_test, &y_pred)?
808                }
809                Scoring::MultiMetric(_metrics) => {
810                    // For multi-metric, just use the first metric for now
811                    fitted.score(&x_test, &y_test)?
812                }
813            };
814            let score_time = start.elapsed().as_secs_f64();
815            score_times.push(score_time);
816            test_scores.push(test_score);
817        }
818
819        // Calculate statistics
820        let mean_test_score = test_scores.iter().sum::<f64>() / test_scores.len() as f64;
821        let std_test_score = {
822            let variance = test_scores
823                .iter()
824                .map(|&score| (score - mean_test_score).powi(2))
825                .sum::<f64>()
826                / test_scores.len() as f64;
827            variance.sqrt()
828        };
829        let mean_fit_time = fit_times.iter().sum::<f64>() / fit_times.len() as f64;
830        let mean_score_time = score_times.iter().sum::<f64>() / score_times.len() as f64;
831
832        Ok((
833            mean_test_score,
834            std_test_score,
835            mean_fit_time,
836            mean_score_time,
837        ))
838    }
839}
840
841impl<E, F, ConfigFn> Fit<Array2<Float>, Array1<Float>> for RandomizedSearchCV<E, F, ConfigFn>
842where
843    E: Clone + Send + Sync,
844    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
845    F: Predict<Array2<Float>, Array1<Float>> + Send + Sync,
846    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
847    ConfigFn: Fn(E, &ParameterSet) -> Result<E> + Send + Sync,
848{
849    type Fitted = RandomizedSearchCV<E, F, ConfigFn>;
850
851    fn fit(mut self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
852        if x.nrows() == 0 {
853            return Err(SklearsError::InvalidInput(
854                "Cannot fit on empty dataset".to_string(),
855            ));
856        }
857
858        if x.nrows() != y.len() {
859            return Err(SklearsError::ShapeMismatch {
860                expected: format!("X.shape[0] = {}", x.nrows()),
861                actual: format!("y.shape[0] = {}", y.len()),
862            });
863        }
864
865        if self.param_distributions.is_empty() {
866            return Err(SklearsError::InvalidInput(
867                "No parameter distributions to sample from".to_string(),
868            ));
869        }
870
871        // Sample parameter combinations
872        let param_combinations = self.sample_parameters(self.n_iter);
873
874        // Evaluate each parameter combination
875        let mut results = Vec::with_capacity(param_combinations.len());
876
877        for params in &param_combinations {
878            let (mean_score, std_score, mean_fit_time, mean_score_time) =
879                self.evaluate_params(params, x, y)?;
880
881            results.push((mean_score, std_score, mean_fit_time, mean_score_time));
882        }
883
884        // Find best parameters
885        let best_idx = results
886            .iter()
887            .enumerate()
888            .max_by(|(_, a), (_, b)| a.0.partial_cmp(&b.0).expect("operation should succeed"))
889            .map(|(idx, _)| idx)
890            .ok_or_else(|| SklearsError::NumericalError("No valid scores found".to_string()))?;
891
892        let best_params = param_combinations[best_idx].clone();
893        let best_score = results[best_idx].0;
894
895        // Create CV results
896        let mean_test_scores = Array1::from_vec(results.iter().map(|r| r.0).collect());
897        let std_test_scores = Array1::from_vec(results.iter().map(|r| r.1).collect());
898        let mean_fit_times = Array1::from_vec(results.iter().map(|r| r.2).collect());
899        let mean_score_times = Array1::from_vec(results.iter().map(|r| r.3).collect());
900
901        // Calculate ranks (1 = best)
902        let mut scores_with_idx: Vec<(f64, usize)> = mean_test_scores
903            .iter()
904            .enumerate()
905            .map(|(i, &score)| (score, i))
906            .collect();
907        scores_with_idx.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("operation should succeed"));
908
909        let mut ranks = vec![0; param_combinations.len()];
910        for (rank, (_, idx)) in scores_with_idx.iter().enumerate() {
911            ranks[*idx] = rank + 1;
912        }
913
914        let cv_results = GridSearchResults {
915            params: param_combinations.clone(),
916            mean_test_scores,
917            std_test_scores,
918            mean_fit_times,
919            mean_score_times,
920            rank_test_scores: Array1::from_vec(ranks),
921        };
922
923        // Refit with best parameters if requested
924        let best_estimator = if self.refit {
925            let configured_estimator = (self.config_fn)(self.estimator.clone(), &best_params)?;
926            Some(configured_estimator.fit(x, y)?)
927        } else {
928            None
929        };
930
931        self.best_estimator_ = best_estimator;
932        self.best_params_ = Some(best_params);
933        self.best_score_ = Some(best_score);
934        self.cv_results_ = Some(cv_results);
935
936        Ok(self)
937    }
938}
939
940impl<E, F, ConfigFn> Predict<Array2<Float>, Array1<Float>> for RandomizedSearchCV<E, F, ConfigFn>
941where
942    E: Clone,
943    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
944    F: Predict<Array2<Float>, Array1<Float>>,
945    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
946    ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
947{
948    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
949        match &self.best_estimator_ {
950            Some(estimator) => estimator.predict(x),
951            None => Err(SklearsError::NotFitted {
952                operation: "predict".to_string(),
953            }),
954        }
955    }
956}
957
958impl<E, F, ConfigFn> Score<Array2<Float>, Array1<Float>> for RandomizedSearchCV<E, F, ConfigFn>
959where
960    E: Clone,
961    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
962    F: Predict<Array2<Float>, Array1<Float>>,
963    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
964    ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
965{
966    type Float = f64;
967    fn score(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<f64> {
968        match &self.best_estimator_ {
969            Some(estimator) => estimator.score(x, y),
970            None => Err(SklearsError::NotFitted {
971                operation: "score".to_string(),
972            }),
973        }
974    }
975}
976
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    // Unit tests for `ParameterValue` accessors, `ParameterDistribution` sampling, and the
    // `GridSearchCV` / `RandomizedSearchCV` estimators (driven by a local mock regressor).
    use super::*;
    use crate::KFold;
    use scirs2_core::ndarray::array;
    // use sklears_ensemble::GradientBoostingRegressor;

    // Mock estimator for testing.
    // The hyperparameter fields are only recorded by the builder methods below. `fit` just
    // sets `fitted`, and `predict` returns the row sums of `x`, so the hyperparameters never
    // influence the scores.
    #[derive(Debug, Clone)]
    struct MockRegressor {
        n_estimators: usize,
        learning_rate: f64,
        random_state: Option<u64>,
        // Set by `fit`; `predict` returns `NotFitted` while this is false.
        fitted: bool,
    }

    impl MockRegressor {
        // Defaults: 100 estimators, learning rate 0.1, no seed, not yet fitted.
        fn new() -> Self {
            Self {
                n_estimators: 100,
                learning_rate: 0.1,
                random_state: None,
                fitted: false,
            }
        }

        fn n_estimators(mut self, n: usize) -> Self {
            self.n_estimators = n;
            self
        }

        fn learning_rate(mut self, lr: f64) -> Self {
            self.learning_rate = lr;
            self
        }

        fn random_state(mut self, state: Option<u64>) -> Self {
            self.random_state = state;
            self
        }
    }

    impl Fit<Array2<f64>, Array1<f64>> for MockRegressor {
        type Fitted = MockRegressor;

        // Marks the mock as fitted; the training data is ignored.
        fn fit(mut self, _x: &Array2<f64>, _y: &Array1<f64>) -> Result<Self::Fitted> {
            self.fitted = true;
            Ok(self)
        }
    }

    impl Predict<Array2<f64>, Array1<f64>> for MockRegressor {
        fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
            if !self.fitted {
                return Err(SklearsError::NotFitted {
                    operation: "predict".to_string(),
                });
            }
            // Simple prediction: sum of features
            Ok(x.sum_axis(scirs2_core::ndarray::Axis(1)))
        }
    }

    impl Score<Array2<f64>, Array1<f64>> for MockRegressor {
        type Float = f64;

        fn score(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
            let y_pred = self.predict(x)?;
            let mse = mean_squared_error(y, &y_pred)?;
            Ok(-mse) // Return negative MSE as score (higher is better)
        }
    }

    // Stand-in for the commented-out sklears_ensemble import above, so the search tests
    // build against the mock.
    // NOTE(review): the search tests below are `#[ignore]`d citing sklears_ensemble
    // dependency issues, yet they only use this alias. The stated reason may be stale;
    // confirm and re-enable them if they pass.
    type GradientBoostingRegressor = MockRegressor;

    // Each accessor returns `Some` only for its own variant.
    #[test]
    fn test_parameter_value_extraction() {
        let int_param = ParameterValue::Int(42);
        assert_eq!(int_param.as_int(), Some(42));
        assert_eq!(int_param.as_float(), None);

        let float_param = ParameterValue::Float(std::f64::consts::PI);
        assert_eq!(float_param.as_float(), Some(std::f64::consts::PI));
        assert_eq!(float_param.as_int(), None);

        let opt_int_param = ParameterValue::OptionalInt(Some(10));
        assert_eq!(opt_int_param.as_optional_int(), Some(Some(10)));
    }

    #[test]
    #[ignore] // Temporarily disabled due to sklears_ensemble dependency issues
    fn test_grid_search_cv() {
        // Create a simple dataset
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0],];
        let y = array![1.0, 4.0, 9.0, 16.0, 25.0]; // roughly x^2

        // Create parameter grid (two values for each of two parameters)
        let mut param_grid = HashMap::new();
        param_grid.insert(
            "n_estimators".to_string(),
            vec![ParameterValue::Int(5), ParameterValue::Int(10)],
        );
        param_grid.insert(
            "learning_rate".to_string(),
            vec![ParameterValue::Float(0.1), ParameterValue::Float(0.3)],
        );

        // Configuration function for GradientBoostingRegressor.
        // Parameters that are absent or have the wrong variant are left at their current value.
        let config_fn = |estimator: GradientBoostingRegressor,
                         params: &ParameterSet|
         -> Result<GradientBoostingRegressor> {
            let mut configured = estimator;

            if let Some(n_est) = params.get("n_estimators").and_then(|p| p.as_int()) {
                configured = configured.n_estimators(n_est as usize);
            }

            if let Some(lr) = params.get("learning_rate").and_then(|p| p.as_float()) {
                configured = configured.learning_rate(lr);
            }

            Ok(configured)
        };

        // Create and fit grid search
        let base_estimator = GradientBoostingRegressor::new().random_state(Some(42));
        let grid_search = GridSearchCV::new(base_estimator, param_grid, config_fn)
            .cv(KFold::new(3))
            .fit(&x, &y)
            .expect("operation should succeed");

        // Check that we have results
        assert!(grid_search.best_score().is_some());
        assert!(grid_search.best_params().is_some());
        assert!(grid_search.best_estimator().is_some());
        assert!(grid_search.cv_results().is_some());

        // Check CV results structure
        let cv_results = grid_search.cv_results().expect("operation should succeed");
        assert_eq!(cv_results.params.len(), 4); // 2 x 2 = 4 combinations
        assert_eq!(cv_results.mean_test_scores.len(), 4);
        assert_eq!(cv_results.rank_test_scores.len(), 4);

        // Best rank should be 1
        let best_rank = cv_results
            .rank_test_scores
            .iter()
            .min()
            .expect("operation should succeed");
        assert_eq!(*best_rank, 1);

        // Test prediction with best estimator
        let predictions = grid_search.predict(&x).expect("operation should succeed");
        assert_eq!(predictions.len(), x.nrows());
    }

    #[test]
    #[ignore] // Temporarily disabled due to sklears_ensemble dependency issues
    fn test_grid_search_empty_grid() {
        let x = array![[1.0], [2.0]];
        let y = array![1.0, 2.0];

        let param_grid = HashMap::new(); // Empty grid
        let config_fn = |estimator: GradientBoostingRegressor,
                         _params: &ParameterSet|
         -> Result<GradientBoostingRegressor> { Ok(estimator) };

        let base_estimator = GradientBoostingRegressor::new();
        let result = GridSearchCV::new(base_estimator, param_grid, config_fn)
            .cv(KFold::new(2)) // Use 2 folds for 2 samples
            .fit(&x, &y);

        // Empty grid should succeed with default parameters
        assert!(result.is_ok());
        let grid_search = result.expect("operation should succeed");

        // Should have one parameter combination (empty set = default params)
        let cv_results = grid_search.cv_results().expect("operation should succeed");
        assert_eq!(cv_results.params.len(), 1);
        assert!(cv_results.params[0].is_empty()); // Empty parameter set
    }

    #[test]
    fn test_parameter_distribution_sampling() {
        use scirs2_core::random::rngs::StdRng;
        use scirs2_core::random::SeedableRng;
        // Fixed seed so the sampled values are deterministic across runs.
        let mut rng = StdRng::seed_from_u64(42);

        // Test Choice distribution: the sample must be one of the listed values
        let choice_dist = ParameterDistribution::Choice(vec![
            ParameterValue::Int(1),
            ParameterValue::Int(2),
            ParameterValue::Int(3),
        ]);
        let sample = choice_dist.sample(&mut rng);
        if let ParameterValue::Int(val) = sample {
            assert!(val >= 1 && val <= 3);
        } else {
            panic!("Expected Int parameter value");
        }

        // Test RandInt distribution (the assertion expects the half-open range [low, high))
        let int_dist = ParameterDistribution::RandInt { low: 10, high: 20 };
        let sample = int_dist.sample(&mut rng);
        if let ParameterValue::Int(val) = sample {
            assert!(val >= 10 && val < 20);
        } else {
            panic!("Expected Int parameter value");
        }

        // Test Uniform distribution (the assertion expects the half-open range [low, high))
        let uniform_dist = ParameterDistribution::Uniform {
            low: 0.0,
            high: 1.0,
        };
        let sample = uniform_dist.sample(&mut rng);
        if let ParameterValue::Float(val) = sample {
            assert!(val >= 0.0 && val < 1.0);
        } else {
            panic!("Expected Float parameter value");
        }

        // Test Normal distribution (only the variant is checked; the value is unbounded)
        let normal_dist = ParameterDistribution::Normal {
            mean: 0.0,
            std: 1.0,
        };
        let sample = normal_dist.sample(&mut rng);
        assert!(matches!(sample, ParameterValue::Float(_)));
    }

    #[test]
    #[ignore] // Temporarily disabled due to sklears_ensemble dependency issues
    fn test_randomized_search_cv() {
        // Create a simple dataset
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0],];
        let y = array![1.0, 4.0, 9.0, 16.0, 25.0]; // roughly x^2

        // Create parameter distributions (one discrete, one continuous)
        let mut param_distributions = HashMap::new();
        param_distributions.insert(
            "n_estimators".to_string(),
            ParameterDistribution::Choice(vec![
                ParameterValue::Int(5),
                ParameterValue::Int(10),
                ParameterValue::Int(15),
            ]),
        );
        param_distributions.insert(
            "learning_rate".to_string(),
            ParameterDistribution::Uniform {
                low: 0.05,
                high: 0.5,
            },
        );

        // Configuration function for GradientBoostingRegressor
        let config_fn = |estimator: GradientBoostingRegressor,
                         params: &ParameterSet|
         -> Result<GradientBoostingRegressor> {
            let mut configured = estimator;

            if let Some(n_est) = params.get("n_estimators").and_then(|p| p.as_int()) {
                configured = configured.n_estimators(n_est as usize);
            }

            if let Some(lr) = params.get("learning_rate").and_then(|p| p.as_float()) {
                configured = configured.learning_rate(lr);
            }

            Ok(configured)
        };

        // Create and fit randomized search
        let base_estimator = GradientBoostingRegressor::new().random_state(Some(42));
        let randomized_search =
            RandomizedSearchCV::new(base_estimator, param_distributions, config_fn)
                .n_iter(8) // Sample 8 parameter combinations
                .cv(KFold::new(3))
                .random_state(Some(42))
                .fit(&x, &y)
                .expect("operation should succeed");

        // Check that we have results
        assert!(randomized_search.best_score().is_some());
        assert!(randomized_search.best_params().is_some());
        assert!(randomized_search.best_estimator().is_some());
        assert!(randomized_search.cv_results().is_some());

        // Check CV results structure
        let cv_results = randomized_search
            .cv_results()
            .expect("operation should succeed");
        assert_eq!(cv_results.params.len(), 8); // Should have 8 sampled combinations
        assert_eq!(cv_results.mean_test_scores.len(), 8);
        assert_eq!(cv_results.rank_test_scores.len(), 8);

        // Best rank should be 1
        let best_rank = cv_results
            .rank_test_scores
            .iter()
            .min()
            .expect("operation should succeed");
        assert_eq!(*best_rank, 1);

        // Test prediction with best estimator
        let predictions = randomized_search
            .predict(&x)
            .expect("operation should succeed");
        assert_eq!(predictions.len(), x.nrows());

        // Check that parameter values are within expected ranges
        for params in &cv_results.params {
            if let Some(n_est) = params.get("n_estimators").and_then(|p| p.as_int()) {
                assert!(n_est == 5 || n_est == 10 || n_est == 15);
            }
            if let Some(lr) = params.get("learning_rate").and_then(|p| p.as_float()) {
                assert!(lr >= 0.05 && lr < 0.5);
            }
        }
    }

    #[test]
    #[ignore] // Temporarily disabled due to sklears_ensemble dependency issues
    fn test_randomized_search_empty_distributions() {
        let x = array![[1.0], [2.0]];
        let y = array![1.0, 2.0];

        // `RandomizedSearchCV::fit` rejects an empty distribution map with an error.
        let param_distributions = HashMap::new(); // Empty distributions
        let config_fn = |estimator: GradientBoostingRegressor,
                         _params: &ParameterSet|
         -> Result<GradientBoostingRegressor> { Ok(estimator) };

        let base_estimator = GradientBoostingRegressor::new();
        let result = RandomizedSearchCV::new(base_estimator, param_distributions, config_fn)
            .cv(KFold::new(2))
            .fit(&x, &y);

        assert!(result.is_err());
    }

    #[test]
    #[ignore] // Temporarily disabled due to sklears_ensemble dependency issues
    fn test_randomized_search_reproducibility() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![1.0, 2.0, 3.0, 4.0];

        // Create parameter distributions
        let mut param_distributions = HashMap::new();
        param_distributions.insert(
            "learning_rate".to_string(),
            ParameterDistribution::Uniform {
                low: 0.1,
                high: 0.5,
            },
        );

        // This closure captures nothing, so it is `Copy` and can be handed to both searches.
        let config_fn = |estimator: GradientBoostingRegressor,
                         params: &ParameterSet|
         -> Result<GradientBoostingRegressor> {
            let mut configured = estimator;
            if let Some(lr) = params.get("learning_rate").and_then(|p| p.as_float()) {
                configured = configured.learning_rate(lr);
            }
            Ok(configured)
        };

        // Run twice with same random state
        let base_estimator1 = GradientBoostingRegressor::new().random_state(Some(42));
        let result1 =
            RandomizedSearchCV::new(base_estimator1, param_distributions.clone(), config_fn)
                .n_iter(5)
                .random_state(Some(123))
                .cv(KFold::new(2))
                .fit(&x, &y)
                .expect("operation should succeed");

        let base_estimator2 = GradientBoostingRegressor::new().random_state(Some(42));
        let result2 = RandomizedSearchCV::new(base_estimator2, param_distributions, config_fn)
            .n_iter(5)
            .random_state(Some(123))
            .cv(KFold::new(2))
            .fit(&x, &y)
            .expect("operation should succeed");

        // Should get identical results
        assert_eq!(result1.best_score(), result2.best_score());

        let params1 = result1.cv_results().expect("operation should succeed");
        let params2 = result2.cv_results().expect("operation should succeed");

        // Check that the same parameters were sampled
        for (p1, p2) in params1.params.iter().zip(params2.params.iter()) {
            assert_eq!(p1, p2);
        }
    }
}