1use crate::{CrossValidator, KFold, Scoring};
4use scirs2_core::ndarray::{Array1, Array2};
5use scirs2_core::rand_prelude::IndexedRandom;
6use scirs2_core::random::essentials::Normal as RandNormal;
7use sklears_core::{
9 error::Result,
10 prelude::SklearsError,
11 traits::{Fit, Predict, Score},
12 types::Float,
13};
14use sklears_metrics::{classification::accuracy_score, get_scorer, regression::mean_squared_error};
15use std::collections::HashMap;
16use std::marker::PhantomData;
17
/// Maps each parameter name to the list of candidate values to try;
/// `GridSearchCV` evaluates the Cartesian product of all the lists.
pub type ParameterGrid = HashMap<String, Vec<ParameterValue>>;
23
/// A single hyper-parameter value that can appear in a parameter grid.
///
/// Wraps the concrete Rust types that estimator configuration functions
/// understand, so heterogeneous parameter grids can live in one map.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ParameterValue {
    /// Integer-valued parameter (e.g. `n_estimators`).
    Int(i64),
    /// Floating-point parameter (e.g. `learning_rate`).
    Float(f64),
    /// Boolean flag.
    Bool(bool),
    /// String-valued parameter (e.g. a kernel name).
    String(String),
    /// Optional integer; `None` means "explicitly unset" (e.g. `max_depth`).
    OptionalInt(Option<i64>),
    /// Optional float; `None` means "explicitly unset".
    OptionalFloat(Option<f64>),
}

impl ParameterValue {
    /// Returns the inner integer, or `None` if the variant is not `Int`.
    pub fn as_int(&self) -> Option<i64> {
        match self {
            ParameterValue::Int(v) => Some(*v),
            _ => None,
        }
    }

    /// Returns the inner float, or `None` if the variant is not `Float`.
    pub fn as_float(&self) -> Option<f64> {
        match self {
            ParameterValue::Float(v) => Some(*v),
            _ => None,
        }
    }

    /// Returns the inner boolean, or `None` if the variant is not `Bool`.
    pub fn as_bool(&self) -> Option<bool> {
        match self {
            ParameterValue::Bool(v) => Some(*v),
            _ => None,
        }
    }

    /// Returns the inner string slice, or `None` if the variant is not
    /// `String`. Added for parity with the other variant accessors.
    pub fn as_str(&self) -> Option<&str> {
        match self {
            ParameterValue::String(v) => Some(v.as_str()),
            _ => None,
        }
    }

    /// Returns the inner optional integer (`Some(None)` means the parameter
    /// is explicitly unset), or `None` if the variant is not `OptionalInt`.
    pub fn as_optional_int(&self) -> Option<Option<i64>> {
        match self {
            ParameterValue::OptionalInt(v) => Some(*v),
            _ => None,
        }
    }

    /// Returns the inner optional float, or `None` if the variant is not
    /// `OptionalFloat`.
    pub fn as_optional_float(&self) -> Option<Option<f64>> {
        match self {
            ParameterValue::OptionalFloat(v) => Some(*v),
            _ => None,
        }
    }
}

impl From<i32> for ParameterValue {
    fn from(value: i32) -> Self {
        // i32 -> i64 is lossless; `From` makes that explicit (vs `as`).
        ParameterValue::Int(i64::from(value))
    }
}

impl From<i64> for ParameterValue {
    fn from(value: i64) -> Self {
        ParameterValue::Int(value)
    }
}

impl From<f32> for ParameterValue {
    fn from(value: f32) -> Self {
        // f32 -> f64 is lossless; `From` makes that explicit (vs `as`).
        ParameterValue::Float(f64::from(value))
    }
}

impl From<f64> for ParameterValue {
    fn from(value: f64) -> Self {
        ParameterValue::Float(value)
    }
}

impl From<bool> for ParameterValue {
    fn from(value: bool) -> Self {
        ParameterValue::Bool(value)
    }
}

impl From<String> for ParameterValue {
    fn from(value: String) -> Self {
        ParameterValue::String(value)
    }
}

impl From<&str> for ParameterValue {
    fn from(value: &str) -> Self {
        ParameterValue::String(value.to_string())
    }
}

impl From<Option<i32>> for ParameterValue {
    fn from(value: Option<i32>) -> Self {
        ParameterValue::OptionalInt(value.map(i64::from))
    }
}

impl From<Option<i64>> for ParameterValue {
    fn from(value: Option<i64>) -> Self {
        ParameterValue::OptionalInt(value)
    }
}

impl From<Option<f64>> for ParameterValue {
    fn from(value: Option<f64>) -> Self {
        ParameterValue::OptionalFloat(value)
    }
}

// NOTE(review): `Eq` is asserted even though `Float` holds an `f64`; a
// `Float(f64::NAN)` value would violate reflexivity (`NaN != NaN`). Callers
// must not put NaN into parameter grids — TODO confirm this invariant.
impl Eq for ParameterValue {}

impl std::hash::Hash for ParameterValue {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Normalize -0.0 to +0.0 before hashing bits: the derived PartialEq
        // uses numeric `==` (so 0.0 == -0.0), and the Hash/Eq contract
        // requires equal values to hash identically.
        fn float_bits(f: f64) -> u64 {
            let f = if f == 0.0 { 0.0f64 } else { f };
            f.to_bits()
        }

        // Hash the discriminant first so e.g. Int(1) and Bool(true) differ.
        std::mem::discriminant(self).hash(state);
        match self {
            ParameterValue::Int(v) => v.hash(state),
            // f64 is not Hash; hash the (normalized) raw bit pattern.
            ParameterValue::Float(v) => float_bits(*v).hash(state),
            ParameterValue::Bool(v) => v.hash(state),
            ParameterValue::String(v) => v.hash(state),
            ParameterValue::OptionalInt(v) => v.hash(state),
            ParameterValue::OptionalFloat(v) => v.map(float_bits).hash(state),
        }
    }
}
159
/// One concrete assignment of a value to each parameter name — a single
/// point in the search space.
pub type ParameterSet = HashMap<String, ParameterValue>;
162
/// Exhaustive search over a parameter grid with cross-validated scoring,
/// mirroring scikit-learn's `GridSearchCV`.
///
/// `E` is the unfitted estimator type, `F` its fitted counterpart, and
/// `ConfigFn` applies a `ParameterSet` to a fresh clone of the estimator.
pub struct GridSearchCV<E, F, ConfigFn>
where
    E: Clone,
    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
    F: Predict<Array2<Float>, Array1<Float>>,
    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
    ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
{
    /// Prototype estimator, cloned and reconfigured for every candidate.
    estimator: E,
    /// Candidate values per parameter; the Cartesian product is searched.
    param_grid: ParameterGrid,
    /// Cross-validation splitting strategy (defaults to 5-fold `KFold`).
    cv: Box<dyn CrossValidator>,
    /// How candidate models are scored on each validation fold.
    scoring: Scoring,
    /// Requested parallelism. Stored only — evaluation in this file is
    /// sequential; TODO confirm whether other code consumes this value.
    n_jobs: Option<usize>,
    /// When true, the best configuration is refit on the full data set.
    refit: bool,
    /// Applies a `ParameterSet` to a clone of `estimator`.
    config_fn: ConfigFn,
    /// Ties the fitted type `F` to this struct without storing one eagerly.
    _phantom: PhantomData<F>,
    // Fields below are populated by `fit`; the trailing underscore mirrors
    // scikit-learn's convention for fitted attributes.
    best_estimator_: Option<F>,
    best_params_: Option<ParameterSet>,
    best_score_: Option<f64>,
    cv_results_: Option<GridSearchResults>,
}
197
/// Per-candidate cross-validation results; all fields are index-aligned.
#[derive(Debug, Clone)]
pub struct GridSearchResults {
    /// The parameter set of each evaluated candidate.
    pub params: Vec<ParameterSet>,
    /// Mean validation score across folds, per candidate.
    pub mean_test_scores: Array1<f64>,
    /// Population standard deviation of fold scores, per candidate.
    pub std_test_scores: Array1<f64>,
    /// Mean wall-clock fit time per fold, in seconds.
    pub mean_fit_times: Array1<f64>,
    /// Mean wall-clock scoring time per fold, in seconds.
    pub mean_score_times: Array1<f64>,
    /// 1-based rank of each candidate by mean score (1 = best).
    pub rank_test_scores: Array1<usize>,
}
208
209fn compute_score_for_regression(
211 metric_name: &str,
212 y_true: &Array1<f64>,
213 y_pred: &Array1<f64>,
214) -> Result<f64> {
215 match metric_name {
216 "neg_mean_squared_error" => Ok(-mean_squared_error(y_true, y_pred)?),
217 "mean_squared_error" => Ok(mean_squared_error(y_true, y_pred)?),
218 _ => {
219 Err(SklearsError::InvalidInput(format!(
221 "Metric '{}' not supported for regression",
222 metric_name
223 )))
224 }
225 }
226}
227
228fn compute_score_for_classification(
230 metric_name: &str,
231 y_true: &Array1<i32>,
232 y_pred: &Array1<i32>,
233) -> Result<f64> {
234 match metric_name {
235 "accuracy" => Ok(accuracy_score(y_true, y_pred)?),
236 _ => {
237 let scorer = get_scorer(metric_name)?;
238 scorer.score(
239 y_true.as_slice().expect("operation should succeed"),
240 y_pred.as_slice().expect("operation should succeed"),
241 )
242 }
243 }
244}
245
246impl<E, F, ConfigFn> GridSearchCV<E, F, ConfigFn>
247where
248 E: Clone,
249 E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
250 F: Predict<Array2<Float>, Array1<Float>>,
251 F: Score<Array2<Float>, Array1<Float>, Float = f64>,
252 ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
253{
254 pub fn new(estimator: E, param_grid: ParameterGrid, config_fn: ConfigFn) -> Self {
256 Self {
257 estimator,
258 param_grid,
259 cv: Box::new(KFold::new(5)),
260 scoring: Scoring::EstimatorScore,
261 n_jobs: None,
262 refit: true,
263 config_fn,
264 _phantom: PhantomData,
265 best_estimator_: None,
266 best_params_: None,
267 best_score_: None,
268 cv_results_: None,
269 }
270 }
271
272 pub fn cv<C: CrossValidator + 'static>(mut self, cv: C) -> Self {
274 self.cv = Box::new(cv);
275 self
276 }
277
278 pub fn scoring(mut self, scoring: Scoring) -> Self {
280 self.scoring = scoring;
281 self
282 }
283
284 pub fn n_jobs(mut self, n_jobs: Option<usize>) -> Self {
286 self.n_jobs = n_jobs;
287 self
288 }
289
290 pub fn refit(mut self, refit: bool) -> Self {
292 self.refit = refit;
293 self
294 }
295
296 pub fn best_estimator(&self) -> Option<&F> {
298 self.best_estimator_.as_ref()
299 }
300
301 pub fn best_params(&self) -> Option<&ParameterSet> {
303 self.best_params_.as_ref()
304 }
305
306 pub fn best_score(&self) -> Option<f64> {
308 self.best_score_
309 }
310
311 pub fn cv_results(&self) -> Option<&GridSearchResults> {
313 self.cv_results_.as_ref()
314 }
315
316 fn generate_param_combinations(&self) -> Vec<ParameterSet> {
318 let mut combinations = vec![HashMap::new()];
319
320 for (param_name, param_values) in &self.param_grid {
321 let mut new_combinations = Vec::new();
322
323 for combination in combinations {
324 for param_value in param_values {
325 let mut new_combination = combination.clone();
326 new_combination.insert(param_name.clone(), param_value.clone());
327 new_combinations.push(new_combination);
328 }
329 }
330
331 combinations = new_combinations;
332 }
333
334 combinations
335 }
336
337 fn evaluate_params(
339 &self,
340 params: &ParameterSet,
341 x: &Array2<Float>,
342 y: &Array1<Float>,
343 ) -> Result<(f64, f64, f64, f64)> {
344 let configured_estimator = (self.config_fn)(self.estimator.clone(), params)?;
346
347 let splits = self.cv.split(x.nrows(), None);
349 let n_splits = splits.len();
350
351 let mut test_scores = Vec::with_capacity(n_splits);
352 let mut fit_times = Vec::with_capacity(n_splits);
353 let mut score_times = Vec::with_capacity(n_splits);
354
355 for (train_idx, test_idx) in splits {
357 let x_train = x.select(scirs2_core::ndarray::Axis(0), &train_idx);
359 let y_train = y.select(scirs2_core::ndarray::Axis(0), &train_idx);
360 let x_test = x.select(scirs2_core::ndarray::Axis(0), &test_idx);
361 let y_test = y.select(scirs2_core::ndarray::Axis(0), &test_idx);
362
363 let start = std::time::Instant::now();
365 let fitted = configured_estimator.clone().fit(&x_train, &y_train)?;
366 let fit_time = start.elapsed().as_secs_f64();
367 fit_times.push(fit_time);
368
369 let start = std::time::Instant::now();
371 let test_score = match &self.scoring {
372 Scoring::EstimatorScore => fitted.score(&x_test, &y_test)?,
373 Scoring::Custom(func) => {
374 let y_pred = fitted.predict(&x_test)?;
375 func(&y_test.to_owned(), &y_pred)?
376 }
377 Scoring::Metric(metric_name) => {
378 let y_pred = fitted.predict(&x_test)?;
379 compute_score_for_regression(metric_name, &y_test, &y_pred)?
380 }
381 Scoring::Scorer(_scorer) => {
382 let y_pred = fitted.predict(&x_test)?;
383 -mean_squared_error(&y_test, &y_pred)?
385 }
386 Scoring::MultiMetric(_metrics) => {
387 fitted.score(&x_test, &y_test)?
389 }
390 };
391 let score_time = start.elapsed().as_secs_f64();
392 score_times.push(score_time);
393 test_scores.push(test_score);
394 }
395
396 let mean_test_score = test_scores.iter().sum::<f64>() / test_scores.len() as f64;
398 let std_test_score = {
399 let variance = test_scores
400 .iter()
401 .map(|&score| (score - mean_test_score).powi(2))
402 .sum::<f64>()
403 / test_scores.len() as f64;
404 variance.sqrt()
405 };
406 let mean_fit_time = fit_times.iter().sum::<f64>() / fit_times.len() as f64;
407 let mean_score_time = score_times.iter().sum::<f64>() / score_times.len() as f64;
408
409 Ok((
410 mean_test_score,
411 std_test_score,
412 mean_fit_time,
413 mean_score_time,
414 ))
415 }
416}
417
418impl<E, F, ConfigFn> Fit<Array2<Float>, Array1<Float>> for GridSearchCV<E, F, ConfigFn>
419where
420 E: Clone + Send + Sync,
421 E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
422 F: Predict<Array2<Float>, Array1<Float>> + Send + Sync,
423 F: Score<Array2<Float>, Array1<Float>, Float = f64>,
424 ConfigFn: Fn(E, &ParameterSet) -> Result<E> + Send + Sync,
425{
426 type Fitted = GridSearchCV<E, F, ConfigFn>;
427
428 fn fit(mut self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
429 if x.nrows() == 0 {
430 return Err(SklearsError::InvalidInput(
431 "Cannot fit on empty dataset".to_string(),
432 ));
433 }
434
435 if x.nrows() != y.len() {
436 return Err(SklearsError::ShapeMismatch {
437 expected: format!("X.shape[0] = {}", x.nrows()),
438 actual: format!("y.shape[0] = {}", y.len()),
439 });
440 }
441
442 let param_combinations = self.generate_param_combinations();
444
445 if param_combinations.is_empty() {
446 return Err(SklearsError::InvalidInput(
447 "No parameter combinations to evaluate".to_string(),
448 ));
449 }
450
451 let mut results = Vec::with_capacity(param_combinations.len());
453
454 for params in ¶m_combinations {
455 let (mean_score, std_score, mean_fit_time, mean_score_time) =
456 self.evaluate_params(params, x, y)?;
457
458 results.push((mean_score, std_score, mean_fit_time, mean_score_time));
459 }
460
461 let best_idx = results
463 .iter()
464 .enumerate()
465 .max_by(|(_, a), (_, b)| a.0.partial_cmp(&b.0).expect("operation should succeed"))
466 .map(|(idx, _)| idx)
467 .ok_or_else(|| SklearsError::NumericalError("No valid scores found".to_string()))?;
468
469 let best_params = param_combinations[best_idx].clone();
470 let best_score = results[best_idx].0;
471
472 let mean_test_scores = Array1::from_vec(results.iter().map(|r| r.0).collect());
474 let std_test_scores = Array1::from_vec(results.iter().map(|r| r.1).collect());
475 let mean_fit_times = Array1::from_vec(results.iter().map(|r| r.2).collect());
476 let mean_score_times = Array1::from_vec(results.iter().map(|r| r.3).collect());
477
478 let mut scores_with_idx: Vec<(f64, usize)> = mean_test_scores
480 .iter()
481 .enumerate()
482 .map(|(i, &score)| (score, i))
483 .collect();
484 scores_with_idx.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("operation should succeed"));
485
486 let mut ranks = vec![0; param_combinations.len()];
487 for (rank, (_, idx)) in scores_with_idx.iter().enumerate() {
488 ranks[*idx] = rank + 1;
489 }
490
491 let cv_results = GridSearchResults {
492 params: param_combinations.clone(),
493 mean_test_scores,
494 std_test_scores,
495 mean_fit_times,
496 mean_score_times,
497 rank_test_scores: Array1::from_vec(ranks),
498 };
499
500 let best_estimator = if self.refit {
502 let configured_estimator = (self.config_fn)(self.estimator.clone(), &best_params)?;
503 Some(configured_estimator.fit(x, y)?)
504 } else {
505 None
506 };
507
508 self.best_estimator_ = best_estimator;
509 self.best_params_ = Some(best_params);
510 self.best_score_ = Some(best_score);
511 self.cv_results_ = Some(cv_results);
512
513 Ok(self)
514 }
515}
516
517impl<E, F, ConfigFn> Predict<Array2<Float>, Array1<Float>> for GridSearchCV<E, F, ConfigFn>
518where
519 E: Clone,
520 E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
521 F: Predict<Array2<Float>, Array1<Float>>,
522 F: Score<Array2<Float>, Array1<Float>, Float = f64>,
523 ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
524{
525 fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
526 match &self.best_estimator_ {
527 Some(estimator) => estimator.predict(x),
528 None => Err(SklearsError::NotFitted {
529 operation: "predict".to_string(),
530 }),
531 }
532 }
533}
534
535impl<E, F, ConfigFn> Score<Array2<Float>, Array1<Float>> for GridSearchCV<E, F, ConfigFn>
536where
537 E: Clone,
538 E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
539 F: Predict<Array2<Float>, Array1<Float>>,
540 F: Score<Array2<Float>, Array1<Float>, Float = f64>,
541 ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
542{
543 type Float = f64;
544 fn score(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<f64> {
545 match &self.best_estimator_ {
546 Some(estimator) => estimator.score(x, y),
547 None => Err(SklearsError::NotFitted {
548 operation: "score".to_string(),
549 }),
550 }
551 }
552}
553
/// Distribution to draw a parameter value from in randomized search.
#[derive(Debug, Clone)]
pub enum ParameterDistribution {
    /// Uniformly pick one of the listed values.
    Choice(Vec<ParameterValue>),
    /// Uniform integer in `[low, high)` — TODO confirm half-open bound
    /// against the scirs2 `Uniform` sampler.
    RandInt { low: i64, high: i64 },
    /// Uniform float in `[low, high)`.
    Uniform { low: f64, high: f64 },
    /// Log-uniform float: uniform in `[ln(low), ln(high))`, then
    /// exponentiated. Requires `low > 0` (ln of a non-positive bound is
    /// NaN/-inf).
    LogUniform { low: f64, high: f64 },
    /// Normally distributed float with the given mean and std deviation.
    Normal { mean: f64, std: f64 },
}
568
impl ParameterDistribution {
    /// Draw one value from this distribution using `rng`.
    ///
    /// # Panics
    /// Panics if the distribution is mis-specified — e.g. an empty
    /// `Choice` list, `low >= high`, or a `Normal` the backend rejects —
    /// because sampler construction is `expect`ed rather than propagated.
    pub fn sample(&self, rng: &mut impl scirs2_core::random::Rng) -> ParameterValue {
        use scirs2_core::essentials::Uniform;
        use scirs2_core::random::Distribution;

        match self {
            ParameterDistribution::Choice(values) => values
                .as_slice()
                .choose(rng)
                .expect("operation should succeed")
                .clone(),
            ParameterDistribution::RandInt { low, high } => {
                let dist = Uniform::new(*low, *high).expect("operation should succeed");
                ParameterValue::Int(dist.sample(rng))
            }
            ParameterDistribution::Uniform { low, high } => {
                let dist = Uniform::new(*low, *high).expect("operation should succeed");
                ParameterValue::Float(dist.sample(rng))
            }
            ParameterDistribution::LogUniform { low, high } => {
                // Sample uniformly in log-space, then map back with exp so
                // every order of magnitude is equally likely.
                let log_low = low.ln();
                let log_high = high.ln();
                let dist = Uniform::new(log_low, log_high).expect("operation should succeed");
                let log_sample = dist.sample(rng);
                ParameterValue::Float(log_sample.exp())
            }
            ParameterDistribution::Normal { mean, std } => {
                let dist = RandNormal::new(*mean, *std).expect("operation should succeed");
                ParameterValue::Float(dist.sample(rng))
            }
        }
    }
}
604
/// Maps each parameter name to the distribution it is sampled from in
/// randomized search.
pub type ParameterDistributions = HashMap<String, ParameterDistribution>;
607
/// Randomized hyper-parameter search with cross-validated scoring,
/// mirroring scikit-learn's `RandomizedSearchCV`: instead of the full
/// grid, `n_iter` parameter sets are drawn from per-parameter
/// distributions.
pub struct RandomizedSearchCV<E, F, ConfigFn>
where
    E: Clone,
    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
    F: Predict<Array2<Float>, Array1<Float>>,
    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
    ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
{
    /// Prototype estimator, cloned and reconfigured for every candidate.
    estimator: E,
    /// Sampling distribution for each parameter name.
    param_distributions: ParameterDistributions,
    /// Number of parameter sets to draw (default 10).
    n_iter: usize,
    /// Cross-validation splitting strategy (defaults to 5-fold `KFold`).
    cv: Box<dyn CrossValidator>,
    /// How candidate models are scored on each validation fold.
    scoring: Scoring,
    /// Requested parallelism. Stored only — evaluation here is sequential.
    n_jobs: Option<usize>,
    /// When true, the best configuration is refit on the full data set.
    refit: bool,
    /// RNG seed for sampling; `None` currently falls back to a fixed seed
    /// in `sample_parameters`.
    random_state: Option<u64>,
    /// Applies a `ParameterSet` to a clone of `estimator`.
    config_fn: ConfigFn,
    /// Ties the fitted type `F` to this struct without storing one eagerly.
    _phantom: PhantomData<F>,
    // Fields below are populated by `fit`.
    best_estimator_: Option<F>,
    best_params_: Option<ParameterSet>,
    best_score_: Option<f64>,
    cv_results_: Option<GridSearchResults>,
}
646
647impl<E, F, ConfigFn> RandomizedSearchCV<E, F, ConfigFn>
648where
649 E: Clone,
650 E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
651 F: Predict<Array2<Float>, Array1<Float>>,
652 F: Score<Array2<Float>, Array1<Float>, Float = f64>,
653 ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
654{
655 pub fn new(
656 estimator: E,
657 param_distributions: ParameterDistributions,
658 config_fn: ConfigFn,
659 ) -> Self {
660 Self {
661 estimator,
662 param_distributions,
663 n_iter: 10,
664 cv: Box::new(KFold::new(5)),
665 scoring: Scoring::EstimatorScore,
666 n_jobs: None,
667 refit: true,
668 random_state: None,
669 config_fn,
670 _phantom: PhantomData,
671 best_estimator_: None,
672 best_params_: None,
673 best_score_: None,
674 cv_results_: None,
675 }
676 }
677
678 pub fn n_iter(mut self, n_iter: usize) -> Self {
680 self.n_iter = n_iter;
681 self
682 }
683
684 pub fn cv<C: CrossValidator + 'static>(mut self, cv: C) -> Self {
686 self.cv = Box::new(cv);
687 self
688 }
689
690 pub fn scoring(mut self, scoring: Scoring) -> Self {
692 self.scoring = scoring;
693 self
694 }
695
696 pub fn n_jobs(mut self, n_jobs: Option<usize>) -> Self {
698 self.n_jobs = n_jobs;
699 self
700 }
701
702 pub fn refit(mut self, refit: bool) -> Self {
704 self.refit = refit;
705 self
706 }
707
708 pub fn random_state(mut self, random_state: Option<u64>) -> Self {
710 self.random_state = random_state;
711 self
712 }
713
714 pub fn best_estimator(&self) -> Option<&F> {
716 self.best_estimator_.as_ref()
717 }
718
719 pub fn best_params(&self) -> Option<&ParameterSet> {
721 self.best_params_.as_ref()
722 }
723
724 pub fn best_score(&self) -> Option<f64> {
726 self.best_score_
727 }
728
729 pub fn cv_results(&self) -> Option<&GridSearchResults> {
731 self.cv_results_.as_ref()
732 }
733
734 fn sample_parameters(&self, n_samples: usize) -> Vec<ParameterSet> {
736 use scirs2_core::random::rngs::StdRng;
737 use scirs2_core::random::SeedableRng;
738
739 let mut rng = match self.random_state {
740 Some(seed) => StdRng::seed_from_u64(seed),
741 None => StdRng::seed_from_u64(42),
742 };
743
744 let mut param_sets = Vec::with_capacity(n_samples);
745
746 for _ in 0..n_samples {
747 let mut param_set = HashMap::new();
748
749 for (param_name, distribution) in &self.param_distributions {
750 let value = distribution.sample(&mut rng);
751 param_set.insert(param_name.clone(), value);
752 }
753
754 param_sets.push(param_set);
755 }
756
757 param_sets
758 }
759
760 fn evaluate_params(
762 &self,
763 params: &ParameterSet,
764 x: &Array2<Float>,
765 y: &Array1<Float>,
766 ) -> Result<(f64, f64, f64, f64)> {
767 let configured_estimator = (self.config_fn)(self.estimator.clone(), params)?;
769
770 let splits = self.cv.split(x.nrows(), None);
772 let n_splits = splits.len();
773
774 let mut test_scores = Vec::with_capacity(n_splits);
775 let mut fit_times = Vec::with_capacity(n_splits);
776 let mut score_times = Vec::with_capacity(n_splits);
777
778 for (train_idx, test_idx) in splits {
780 let x_train = x.select(scirs2_core::ndarray::Axis(0), &train_idx);
782 let y_train = y.select(scirs2_core::ndarray::Axis(0), &train_idx);
783 let x_test = x.select(scirs2_core::ndarray::Axis(0), &test_idx);
784 let y_test = y.select(scirs2_core::ndarray::Axis(0), &test_idx);
785
786 let start = std::time::Instant::now();
788 let fitted = configured_estimator.clone().fit(&x_train, &y_train)?;
789 let fit_time = start.elapsed().as_secs_f64();
790 fit_times.push(fit_time);
791
792 let start = std::time::Instant::now();
794 let test_score = match &self.scoring {
795 Scoring::EstimatorScore => fitted.score(&x_test, &y_test)?,
796 Scoring::Custom(func) => {
797 let y_pred = fitted.predict(&x_test)?;
798 func(&y_test.to_owned(), &y_pred)?
799 }
800 Scoring::Metric(metric_name) => {
801 let y_pred = fitted.predict(&x_test)?;
802 compute_score_for_regression(metric_name, &y_test, &y_pred)?
803 }
804 Scoring::Scorer(_scorer) => {
805 let y_pred = fitted.predict(&x_test)?;
806 -mean_squared_error(&y_test, &y_pred)?
808 }
809 Scoring::MultiMetric(_metrics) => {
810 fitted.score(&x_test, &y_test)?
812 }
813 };
814 let score_time = start.elapsed().as_secs_f64();
815 score_times.push(score_time);
816 test_scores.push(test_score);
817 }
818
819 let mean_test_score = test_scores.iter().sum::<f64>() / test_scores.len() as f64;
821 let std_test_score = {
822 let variance = test_scores
823 .iter()
824 .map(|&score| (score - mean_test_score).powi(2))
825 .sum::<f64>()
826 / test_scores.len() as f64;
827 variance.sqrt()
828 };
829 let mean_fit_time = fit_times.iter().sum::<f64>() / fit_times.len() as f64;
830 let mean_score_time = score_times.iter().sum::<f64>() / score_times.len() as f64;
831
832 Ok((
833 mean_test_score,
834 std_test_score,
835 mean_fit_time,
836 mean_score_time,
837 ))
838 }
839}
840
841impl<E, F, ConfigFn> Fit<Array2<Float>, Array1<Float>> for RandomizedSearchCV<E, F, ConfigFn>
842where
843 E: Clone + Send + Sync,
844 E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
845 F: Predict<Array2<Float>, Array1<Float>> + Send + Sync,
846 F: Score<Array2<Float>, Array1<Float>, Float = f64>,
847 ConfigFn: Fn(E, &ParameterSet) -> Result<E> + Send + Sync,
848{
849 type Fitted = RandomizedSearchCV<E, F, ConfigFn>;
850
851 fn fit(mut self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
852 if x.nrows() == 0 {
853 return Err(SklearsError::InvalidInput(
854 "Cannot fit on empty dataset".to_string(),
855 ));
856 }
857
858 if x.nrows() != y.len() {
859 return Err(SklearsError::ShapeMismatch {
860 expected: format!("X.shape[0] = {}", x.nrows()),
861 actual: format!("y.shape[0] = {}", y.len()),
862 });
863 }
864
865 if self.param_distributions.is_empty() {
866 return Err(SklearsError::InvalidInput(
867 "No parameter distributions to sample from".to_string(),
868 ));
869 }
870
871 let param_combinations = self.sample_parameters(self.n_iter);
873
874 let mut results = Vec::with_capacity(param_combinations.len());
876
877 for params in ¶m_combinations {
878 let (mean_score, std_score, mean_fit_time, mean_score_time) =
879 self.evaluate_params(params, x, y)?;
880
881 results.push((mean_score, std_score, mean_fit_time, mean_score_time));
882 }
883
884 let best_idx = results
886 .iter()
887 .enumerate()
888 .max_by(|(_, a), (_, b)| a.0.partial_cmp(&b.0).expect("operation should succeed"))
889 .map(|(idx, _)| idx)
890 .ok_or_else(|| SklearsError::NumericalError("No valid scores found".to_string()))?;
891
892 let best_params = param_combinations[best_idx].clone();
893 let best_score = results[best_idx].0;
894
895 let mean_test_scores = Array1::from_vec(results.iter().map(|r| r.0).collect());
897 let std_test_scores = Array1::from_vec(results.iter().map(|r| r.1).collect());
898 let mean_fit_times = Array1::from_vec(results.iter().map(|r| r.2).collect());
899 let mean_score_times = Array1::from_vec(results.iter().map(|r| r.3).collect());
900
901 let mut scores_with_idx: Vec<(f64, usize)> = mean_test_scores
903 .iter()
904 .enumerate()
905 .map(|(i, &score)| (score, i))
906 .collect();
907 scores_with_idx.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("operation should succeed"));
908
909 let mut ranks = vec![0; param_combinations.len()];
910 for (rank, (_, idx)) in scores_with_idx.iter().enumerate() {
911 ranks[*idx] = rank + 1;
912 }
913
914 let cv_results = GridSearchResults {
915 params: param_combinations.clone(),
916 mean_test_scores,
917 std_test_scores,
918 mean_fit_times,
919 mean_score_times,
920 rank_test_scores: Array1::from_vec(ranks),
921 };
922
923 let best_estimator = if self.refit {
925 let configured_estimator = (self.config_fn)(self.estimator.clone(), &best_params)?;
926 Some(configured_estimator.fit(x, y)?)
927 } else {
928 None
929 };
930
931 self.best_estimator_ = best_estimator;
932 self.best_params_ = Some(best_params);
933 self.best_score_ = Some(best_score);
934 self.cv_results_ = Some(cv_results);
935
936 Ok(self)
937 }
938}
939
940impl<E, F, ConfigFn> Predict<Array2<Float>, Array1<Float>> for RandomizedSearchCV<E, F, ConfigFn>
941where
942 E: Clone,
943 E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
944 F: Predict<Array2<Float>, Array1<Float>>,
945 F: Score<Array2<Float>, Array1<Float>, Float = f64>,
946 ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
947{
948 fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
949 match &self.best_estimator_ {
950 Some(estimator) => estimator.predict(x),
951 None => Err(SklearsError::NotFitted {
952 operation: "predict".to_string(),
953 }),
954 }
955 }
956}
957
958impl<E, F, ConfigFn> Score<Array2<Float>, Array1<Float>> for RandomizedSearchCV<E, F, ConfigFn>
959where
960 E: Clone,
961 E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
962 F: Predict<Array2<Float>, Array1<Float>>,
963 F: Score<Array2<Float>, Array1<Float>, Float = f64>,
964 ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
965{
966 type Float = f64;
967 fn score(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<f64> {
968 match &self.best_estimator_ {
969 Some(estimator) => estimator.score(x, y),
970 None => Err(SklearsError::NotFitted {
971 operation: "score".to_string(),
972 }),
973 }
974 }
975}
976
977#[allow(non_snake_case)]
978#[cfg(test)]
979mod tests {
980 use super::*;
981 use crate::KFold;
982 use scirs2_core::ndarray::array;
    /// Minimal estimator used to exercise the search machinery without a
    /// real learning algorithm.
    #[derive(Debug, Clone)]
    struct MockRegressor {
        /// Stored hyper-parameter; never used by the mock's math.
        n_estimators: usize,
        /// Stored hyper-parameter; never used by the mock's math.
        learning_rate: f64,
        /// Stored hyper-parameter; never used by the mock's math.
        random_state: Option<u64>,
        /// Set by `fit`; `predict` refuses to run until it is true.
        fitted: bool,
    }
993
    impl MockRegressor {
        /// New, unfitted mock with arbitrary default hyper-parameters.
        fn new() -> Self {
            Self {
                n_estimators: 100,
                learning_rate: 0.1,
                random_state: None,
                fitted: false,
            }
        }

        /// Builder-style setter; the value is stored but otherwise unused.
        fn n_estimators(mut self, n: usize) -> Self {
            self.n_estimators = n;
            self
        }

        /// Builder-style setter; the value is stored but otherwise unused.
        fn learning_rate(mut self, lr: f64) -> Self {
            self.learning_rate = lr;
            self
        }

        /// Builder-style setter; the value is stored but otherwise unused.
        fn random_state(mut self, state: Option<u64>) -> Self {
            self.random_state = state;
            self
        }
    }
1019
    impl Fit<Array2<f64>, Array1<f64>> for MockRegressor {
        type Fitted = MockRegressor;

        /// "Training" just flips the fitted flag; the data are ignored.
        fn fit(mut self, _x: &Array2<f64>, _y: &Array1<f64>) -> Result<Self::Fitted> {
            self.fitted = true;
            Ok(self)
        }
    }
1028
    impl Predict<Array2<f64>, Array1<f64>> for MockRegressor {
        /// Predicts the row-wise feature sum; errors if not fitted.
        fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
            if !self.fitted {
                return Err(SklearsError::NotFitted {
                    operation: "predict".to_string(),
                });
            }
            // Deterministic stand-in model: y_hat = sum of features per row.
            Ok(x.sum_axis(scirs2_core::ndarray::Axis(1)))
        }
    }
1040
    impl Score<Array2<f64>, Array1<f64>> for MockRegressor {
        type Float = f64;

        /// Negated MSE, so that "higher is better" as the searches assume.
        fn score(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
            let y_pred = self.predict(x)?;
            let mse = mean_squared_error(y, &y_pred)?;
            Ok(-mse)
        }
    }
1050
    /// Alias so the tests read like end-user code driving a real estimator.
    type GradientBoostingRegressor = MockRegressor;
1052
    /// Each `as_*` accessor returns `Some` only for its own variant.
    #[test]
    fn test_parameter_value_extraction() {
        let int_param = ParameterValue::Int(42);
        assert_eq!(int_param.as_int(), Some(42));
        assert_eq!(int_param.as_float(), None);

        let float_param = ParameterValue::Float(std::f64::consts::PI);
        assert_eq!(float_param.as_float(), Some(std::f64::consts::PI));
        assert_eq!(float_param.as_int(), None);

        // `Some(Some(_))`: the variant matched AND the value is set.
        let opt_int_param = ParameterValue::OptionalInt(Some(10));
        assert_eq!(opt_int_param.as_optional_int(), Some(Some(10)));
    }
1066
    /// End-to-end grid search over a 2x2 grid (4 candidates) with the mock
    /// regressor: checks the best-* accessors, cv_results shapes, ranking,
    /// and prediction through the refit estimator.
    #[test]
    // NOTE(review): marked ignore in the source — presumably slow or flaky
    // in CI; confirm the reason before re-enabling.
    #[ignore]
    fn test_grid_search_cv() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0],];
        // y = x^2: a target the mock (row-sum predictor) cannot fit well.
        let y = array![1.0, 4.0, 9.0, 16.0, 25.0];
        let mut param_grid = HashMap::new();
        param_grid.insert(
            "n_estimators".to_string(),
            vec![ParameterValue::Int(5), ParameterValue::Int(10)],
        );
        param_grid.insert(
            "learning_rate".to_string(),
            vec![ParameterValue::Float(0.1), ParameterValue::Float(0.3)],
        );

        // Translate a ParameterSet into builder calls on the estimator.
        let config_fn = |estimator: GradientBoostingRegressor,
                         params: &ParameterSet|
         -> Result<GradientBoostingRegressor> {
            let mut configured = estimator;

            if let Some(n_est) = params.get("n_estimators").and_then(|p| p.as_int()) {
                configured = configured.n_estimators(n_est as usize);
            }

            if let Some(lr) = params.get("learning_rate").and_then(|p| p.as_float()) {
                configured = configured.learning_rate(lr);
            }

            Ok(configured)
        };

        let base_estimator = GradientBoostingRegressor::new().random_state(Some(42));
        let grid_search = GridSearchCV::new(base_estimator, param_grid, config_fn)
            .cv(KFold::new(3))
            .fit(&x, &y)
            .expect("operation should succeed");

        assert!(grid_search.best_score().is_some());
        assert!(grid_search.best_params().is_some());
        assert!(grid_search.best_estimator().is_some());
        assert!(grid_search.cv_results().is_some());

        // 2 values x 2 values = 4 candidates.
        let cv_results = grid_search.cv_results().expect("operation should succeed");
        assert_eq!(cv_results.params.len(), 4);
        assert_eq!(cv_results.mean_test_scores.len(), 4);
        assert_eq!(cv_results.rank_test_scores.len(), 4);

        // Exactly one candidate must hold rank 1 (the best).
        let best_rank = cv_results
            .rank_test_scores
            .iter()
            .min()
            .expect("operation should succeed");
        assert_eq!(*best_rank, 1);

        let predictions = grid_search.predict(&x).expect("operation should succeed");
        assert_eq!(predictions.len(), x.nrows());
    }
1133
    /// An empty grid must still evaluate the base estimator once with an
    /// empty parameter set (scikit-learn-compatible behavior).
    #[test]
    // NOTE(review): marked ignore in the source — presumably slow or flaky
    // in CI; confirm the reason before re-enabling.
    #[ignore]
    fn test_grid_search_empty_grid() {
        let x = array![[1.0], [2.0]];
        let y = array![1.0, 2.0];

        // No parameters at all: the search space is the single empty set.
        let param_grid = HashMap::new();
        let config_fn = |estimator: GradientBoostingRegressor,
                         _params: &ParameterSet|
         -> Result<GradientBoostingRegressor> { Ok(estimator) };

        let base_estimator = GradientBoostingRegressor::new();
        let result = GridSearchCV::new(base_estimator, param_grid, config_fn)
            .cv(KFold::new(2))
            .fit(&x, &y);

        assert!(result.is_ok());
        let grid_search = result.expect("operation should succeed");

        let cv_results = grid_search.cv_results().expect("operation should succeed");
        assert_eq!(cv_results.params.len(), 1);
        assert!(cv_results.params[0].is_empty());
    }
1159
    /// Spot-checks each distribution variant: sampled values land in the
    /// expected range and carry the expected `ParameterValue` variant.
    #[test]
    fn test_parameter_distribution_sampling() {
        use scirs2_core::random::rngs::StdRng;
        use scirs2_core::random::SeedableRng;
        // Seeded RNG keeps the test deterministic.
        let mut rng = StdRng::seed_from_u64(42);

        let choice_dist = ParameterDistribution::Choice(vec![
            ParameterValue::Int(1),
            ParameterValue::Int(2),
            ParameterValue::Int(3),
        ]);
        let sample = choice_dist.sample(&mut rng);
        if let ParameterValue::Int(val) = sample {
            assert!(val >= 1 && val <= 3);
        } else {
            panic!("Expected Int parameter value");
        }

        // RandInt samples the half-open range [low, high).
        let int_dist = ParameterDistribution::RandInt { low: 10, high: 20 };
        let sample = int_dist.sample(&mut rng);
        if let ParameterValue::Int(val) = sample {
            assert!(val >= 10 && val < 20);
        } else {
            panic!("Expected Int parameter value");
        }

        let uniform_dist = ParameterDistribution::Uniform {
            low: 0.0,
            high: 1.0,
        };
        let sample = uniform_dist.sample(&mut rng);
        if let ParameterValue::Float(val) = sample {
            assert!(val >= 0.0 && val < 1.0);
        } else {
            panic!("Expected Float parameter value");
        }

        // Normal is unbounded; only the variant can be asserted.
        let normal_dist = ParameterDistribution::Normal {
            mean: 0.0,
            std: 1.0,
        };
        let sample = normal_dist.sample(&mut rng);
        assert!(matches!(sample, ParameterValue::Float(_)));
    }
1208
1209 #[test]
1210 #[ignore] fn test_randomized_search_cv() {
1212 let x = array![[1.0], [2.0], [3.0], [4.0], [5.0],];
1214 let y = array![1.0, 4.0, 9.0, 16.0, 25.0]; let mut param_distributions = HashMap::new();
1218 param_distributions.insert(
1219 "n_estimators".to_string(),
1220 ParameterDistribution::Choice(vec![
1221 ParameterValue::Int(5),
1222 ParameterValue::Int(10),
1223 ParameterValue::Int(15),
1224 ]),
1225 );
1226 param_distributions.insert(
1227 "learning_rate".to_string(),
1228 ParameterDistribution::Uniform {
1229 low: 0.05,
1230 high: 0.5,
1231 },
1232 );
1233
1234 let config_fn = |estimator: GradientBoostingRegressor,
1236 params: &ParameterSet|
1237 -> Result<GradientBoostingRegressor> {
1238 let mut configured = estimator;
1239
1240 if let Some(n_est) = params.get("n_estimators").and_then(|p| p.as_int()) {
1241 configured = configured.n_estimators(n_est as usize);
1242 }
1243
1244 if let Some(lr) = params.get("learning_rate").and_then(|p| p.as_float()) {
1245 configured = configured.learning_rate(lr);
1246 }
1247
1248 Ok(configured)
1249 };
1250
1251 let base_estimator = GradientBoostingRegressor::new().random_state(Some(42));
1253 let randomized_search =
1254 RandomizedSearchCV::new(base_estimator, param_distributions, config_fn)
1255 .n_iter(8) .cv(KFold::new(3))
1257 .random_state(Some(42))
1258 .fit(&x, &y)
1259 .expect("operation should succeed");
1260
1261 assert!(randomized_search.best_score().is_some());
1263 assert!(randomized_search.best_params().is_some());
1264 assert!(randomized_search.best_estimator().is_some());
1265 assert!(randomized_search.cv_results().is_some());
1266
1267 let cv_results = randomized_search
1269 .cv_results()
1270 .expect("operation should succeed");
1271 assert_eq!(cv_results.params.len(), 8); assert_eq!(cv_results.mean_test_scores.len(), 8);
1273 assert_eq!(cv_results.rank_test_scores.len(), 8);
1274
1275 let best_rank = cv_results
1277 .rank_test_scores
1278 .iter()
1279 .min()
1280 .expect("operation should succeed");
1281 assert_eq!(*best_rank, 1);
1282
1283 let predictions = randomized_search
1285 .predict(&x)
1286 .expect("operation should succeed");
1287 assert_eq!(predictions.len(), x.nrows());
1288
1289 for params in &cv_results.params {
1291 if let Some(n_est) = params.get("n_estimators").and_then(|p| p.as_int()) {
1292 assert!(n_est == 5 || n_est == 10 || n_est == 15);
1293 }
1294 if let Some(lr) = params.get("learning_rate").and_then(|p| p.as_float()) {
1295 assert!(lr >= 0.05 && lr < 0.5);
1296 }
1297 }
1298 }
1299
1300 #[test]
1301 #[ignore] fn test_randomized_search_empty_distributions() {
1303 let x = array![[1.0], [2.0]];
1304 let y = array![1.0, 2.0];
1305
1306 let param_distributions = HashMap::new(); let config_fn = |estimator: GradientBoostingRegressor,
1308 _params: &ParameterSet|
1309 -> Result<GradientBoostingRegressor> { Ok(estimator) };
1310
1311 let base_estimator = GradientBoostingRegressor::new();
1312 let result = RandomizedSearchCV::new(base_estimator, param_distributions, config_fn)
1313 .cv(KFold::new(2))
1314 .fit(&x, &y);
1315
1316 assert!(result.is_err());
1317 }
1318
1319 #[test]
1320 #[ignore] fn test_randomized_search_reproducibility() {
1322 let x = array![[1.0], [2.0], [3.0], [4.0]];
1323 let y = array![1.0, 2.0, 3.0, 4.0];
1324
1325 let mut param_distributions = HashMap::new();
1327 param_distributions.insert(
1328 "learning_rate".to_string(),
1329 ParameterDistribution::Uniform {
1330 low: 0.1,
1331 high: 0.5,
1332 },
1333 );
1334
1335 let config_fn = |estimator: GradientBoostingRegressor,
1336 params: &ParameterSet|
1337 -> Result<GradientBoostingRegressor> {
1338 let mut configured = estimator;
1339 if let Some(lr) = params.get("learning_rate").and_then(|p| p.as_float()) {
1340 configured = configured.learning_rate(lr);
1341 }
1342 Ok(configured)
1343 };
1344
1345 let base_estimator1 = GradientBoostingRegressor::new().random_state(Some(42));
1347 let result1 =
1348 RandomizedSearchCV::new(base_estimator1, param_distributions.clone(), config_fn)
1349 .n_iter(5)
1350 .random_state(Some(123))
1351 .cv(KFold::new(2))
1352 .fit(&x, &y)
1353 .expect("operation should succeed");
1354
1355 let base_estimator2 = GradientBoostingRegressor::new().random_state(Some(42));
1356 let result2 = RandomizedSearchCV::new(base_estimator2, param_distributions, config_fn)
1357 .n_iter(5)
1358 .random_state(Some(123))
1359 .cv(KFold::new(2))
1360 .fit(&x, &y)
1361 .expect("operation should succeed");
1362
1363 assert_eq!(result1.best_score(), result2.best_score());
1365
1366 let params1 = result1.cv_results().expect("operation should succeed");
1367 let params2 = result2.cv_results().expect("operation should succeed");
1368
1369 for (p1, p2) in params1.params.iter().zip(params2.params.iter()) {
1371 assert_eq!(p1, p2);
1372 }
1373 }
1374}