1use crate::{CrossValidator, KFold, Scoring};
4use scirs2_core::ndarray::{Array1, Array2};
5use scirs2_core::rand_prelude::IndexedRandom;
6use scirs2_core::random::essentials::Normal as RandNormal;
7use sklears_core::{
9 error::Result,
10 prelude::SklearsError,
11 traits::{Fit, Predict, Score},
12 types::Float,
13};
14use sklears_metrics::{classification::accuracy_score, get_scorer, regression::mean_squared_error};
15use std::collections::HashMap;
16use std::marker::PhantomData;
17
/// Maps a parameter name to the list of candidate values to try for it.
pub type ParameterGrid = HashMap<String, Vec<ParameterValue>>;
23
/// A single hyperparameter value; covers the primitive types a search grid can hold.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ParameterValue {
    /// Integer-valued parameter.
    Int(i64),
    /// Floating-point parameter.
    Float(f64),
    /// Boolean flag parameter.
    Bool(bool),
    /// String-valued parameter.
    String(String),
    /// Optional integer; `None` encodes "unset".
    OptionalInt(Option<i64>),
    /// Optional float; `None` encodes "unset".
    OptionalFloat(Option<f64>),
}
41
42impl ParameterValue {
43 pub fn as_int(&self) -> Option<i64> {
45 match self {
46 ParameterValue::Int(v) => Some(*v),
47 _ => None,
48 }
49 }
50
51 pub fn as_float(&self) -> Option<f64> {
53 match self {
54 ParameterValue::Float(v) => Some(*v),
55 _ => None,
56 }
57 }
58
59 pub fn as_bool(&self) -> Option<bool> {
61 match self {
62 ParameterValue::Bool(v) => Some(*v),
63 _ => None,
64 }
65 }
66
67 pub fn as_optional_int(&self) -> Option<Option<i64>> {
69 match self {
70 ParameterValue::OptionalInt(v) => Some(*v),
71 _ => None,
72 }
73 }
74
75 pub fn as_optional_float(&self) -> Option<Option<f64>> {
77 match self {
78 ParameterValue::OptionalFloat(v) => Some(*v),
79 _ => None,
80 }
81 }
82}
83
// Ergonomic conversions so grids can be built from plain literals
// (`1.into()`, `0.5.into()`, `"auto".into()`, `Some(3).into()`, ...).
impl From<i32> for ParameterValue {
    fn from(value: i32) -> Self {
        ParameterValue::Int(value as i64)
    }
}

impl From<i64> for ParameterValue {
    fn from(value: i64) -> Self {
        ParameterValue::Int(value)
    }
}

impl From<f32> for ParameterValue {
    fn from(value: f32) -> Self {
        ParameterValue::Float(value as f64)
    }
}

impl From<f64> for ParameterValue {
    fn from(value: f64) -> Self {
        ParameterValue::Float(value)
    }
}

impl From<bool> for ParameterValue {
    fn from(value: bool) -> Self {
        ParameterValue::Bool(value)
    }
}

impl From<String> for ParameterValue {
    fn from(value: String) -> Self {
        ParameterValue::String(value)
    }
}

impl From<&str> for ParameterValue {
    fn from(value: &str) -> Self {
        ParameterValue::String(value.to_string())
    }
}

// Option conversions map `None` to the "unset" optional variants.
impl From<Option<i32>> for ParameterValue {
    fn from(value: Option<i32>) -> Self {
        ParameterValue::OptionalInt(value.map(|v| v as i64))
    }
}

impl From<Option<i64>> for ParameterValue {
    fn from(value: Option<i64>) -> Self {
        ParameterValue::OptionalInt(value)
    }
}

impl From<Option<f64>> for ParameterValue {
    fn from(value: Option<f64>) -> Self {
        ParameterValue::OptionalFloat(value)
    }
}
143
144impl Eq for ParameterValue {}
145
146impl std::hash::Hash for ParameterValue {
147 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
148 std::mem::discriminant(self).hash(state);
149 match self {
150 ParameterValue::Int(v) => v.hash(state),
151 ParameterValue::Float(v) => v.to_bits().hash(state), ParameterValue::Bool(v) => v.hash(state),
153 ParameterValue::String(v) => v.hash(state),
154 ParameterValue::OptionalInt(v) => v.hash(state),
155 ParameterValue::OptionalFloat(v) => v.map(|f| f.to_bits()).hash(state),
156 }
157 }
158}
159
/// One concrete assignment of a value to each parameter name.
pub type ParameterSet = HashMap<String, ParameterValue>;
162
/// Exhaustive hyperparameter search over a [`ParameterGrid`] with cross-validation.
///
/// `E` is the unfitted estimator type, `F` its fitted counterpart, and
/// `ConfigFn` applies one [`ParameterSet`] to a clone of the base estimator.
/// Fields with a trailing underscore are populated by `fit`.
pub struct GridSearchCV<E, F, ConfigFn>
where
    E: Clone,
    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
    F: Predict<Array2<Float>, Array1<Float>>,
    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
    ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
{
    // Base (unconfigured) estimator, cloned once per parameter combination.
    estimator: E,
    // Candidate values for each parameter name.
    param_grid: ParameterGrid,
    // Cross-validation splitter (default: 5-fold KFold).
    cv: Box<dyn CrossValidator>,
    // Scoring strategy applied to each test fold.
    scoring: Scoring,
    // Requested parallelism. NOTE(review): not consulted by the serial
    // evaluation loop in this file — confirm whether it is used elsewhere.
    n_jobs: Option<usize>,
    // Whether to refit the best configuration on the full data after the search.
    refit: bool,
    // Applies a parameter set to a clone of `estimator`.
    config_fn: ConfigFn,
    // Ties the fitted type `F` into the struct's generics without storing one.
    _phantom: PhantomData<F>,
    // Best estimator refitted on the full data (only when `refit` is true).
    best_estimator_: Option<F>,
    // Parameter set that achieved the best mean CV score.
    best_params_: Option<ParameterSet>,
    // Best mean cross-validated test score.
    best_score_: Option<f64>,
    // Per-combination CV statistics.
    cv_results_: Option<GridSearchResults>,
}
197
/// Per-combination cross-validation statistics collected during a search.
/// All arrays are parallel to `params` (one entry per parameter combination).
#[derive(Debug, Clone)]
pub struct GridSearchResults {
    /// The evaluated parameter combinations.
    pub params: Vec<ParameterSet>,
    /// Mean test score across CV folds.
    pub mean_test_scores: Array1<f64>,
    /// Population standard deviation of the test scores across folds.
    pub std_test_scores: Array1<f64>,
    /// Mean fit wall-clock time per fold, in seconds.
    pub mean_fit_times: Array1<f64>,
    /// Mean scoring wall-clock time per fold, in seconds.
    pub mean_score_times: Array1<f64>,
    /// 1-based rank by mean test score (1 = best).
    pub rank_test_scores: Array1<usize>,
}
208
209fn compute_score_for_regression(
211 metric_name: &str,
212 y_true: &Array1<f64>,
213 y_pred: &Array1<f64>,
214) -> Result<f64> {
215 match metric_name {
216 "neg_mean_squared_error" => Ok(-mean_squared_error(y_true, y_pred)?),
217 "mean_squared_error" => Ok(mean_squared_error(y_true, y_pred)?),
218 _ => {
219 Err(SklearsError::InvalidInput(format!(
221 "Metric '{}' not supported for regression",
222 metric_name
223 )))
224 }
225 }
226}
227
228fn compute_score_for_classification(
230 metric_name: &str,
231 y_true: &Array1<i32>,
232 y_pred: &Array1<i32>,
233) -> Result<f64> {
234 match metric_name {
235 "accuracy" => Ok(accuracy_score(y_true, y_pred)?),
236 _ => {
237 let scorer = get_scorer(metric_name)?;
238 scorer.score(y_true.as_slice().unwrap(), y_pred.as_slice().unwrap())
239 }
240 }
241}
242
impl<E, F, ConfigFn> GridSearchCV<E, F, ConfigFn>
where
    E: Clone,
    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
    F: Predict<Array2<Float>, Array1<Float>>,
    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
    ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
{
    /// Create a grid search over `param_grid` for `estimator`.
    ///
    /// `config_fn` applies one parameter set to a clone of the base estimator.
    /// Defaults: 5-fold CV, the estimator's own `score`, serial execution,
    /// refit enabled.
    pub fn new(estimator: E, param_grid: ParameterGrid, config_fn: ConfigFn) -> Self {
        Self {
            estimator,
            param_grid,
            cv: Box::new(KFold::new(5)),
            scoring: Scoring::EstimatorScore,
            n_jobs: None,
            refit: true,
            config_fn,
            _phantom: PhantomData,
            best_estimator_: None,
            best_params_: None,
            best_score_: None,
            cv_results_: None,
        }
    }

    /// Builder: replace the cross-validation splitter.
    pub fn cv<C: CrossValidator + 'static>(mut self, cv: C) -> Self {
        self.cv = Box::new(cv);
        self
    }

    /// Builder: replace the scoring strategy.
    pub fn scoring(mut self, scoring: Scoring) -> Self {
        self.scoring = scoring;
        self
    }

    /// Builder: set the requested parallelism.
    /// NOTE(review): the evaluation loop below is serial and does not read this.
    pub fn n_jobs(mut self, n_jobs: Option<usize>) -> Self {
        self.n_jobs = n_jobs;
        self
    }

    /// Builder: enable/disable refitting the best configuration after the search.
    pub fn refit(mut self, refit: bool) -> Self {
        self.refit = refit;
        self
    }

    /// Best refitted estimator; `None` before `fit` or when `refit` was disabled.
    pub fn best_estimator(&self) -> Option<&F> {
        self.best_estimator_.as_ref()
    }

    /// Parameter set with the best mean CV score; `None` before `fit`.
    pub fn best_params(&self) -> Option<&ParameterSet> {
        self.best_params_.as_ref()
    }

    /// Best mean CV test score; `None` before `fit`.
    pub fn best_score(&self) -> Option<f64> {
        self.best_score_
    }

    /// Full per-combination CV statistics; `None` before `fit`.
    pub fn cv_results(&self) -> Option<&GridSearchResults> {
        self.cv_results_.as_ref()
    }

    /// Build the Cartesian product of all parameter lists.
    ///
    /// An empty grid yields a single empty `ParameterSet` (the base estimator
    /// unchanged). Iteration follows `HashMap` order, so the order of the
    /// combinations is unspecified.
    fn generate_param_combinations(&self) -> Vec<ParameterSet> {
        let mut combinations = vec![HashMap::new()];

        for (param_name, param_values) in &self.param_grid {
            let mut new_combinations = Vec::new();

            // Extend every partial combination with each candidate value.
            for combination in combinations {
                for param_value in param_values {
                    let mut new_combination = combination.clone();
                    new_combination.insert(param_name.clone(), param_value.clone());
                    new_combinations.push(new_combination);
                }
            }

            combinations = new_combinations;
        }

        combinations
    }

    /// Cross-validate one parameter set.
    ///
    /// Returns `(mean_test_score, std_test_score, mean_fit_time, mean_score_time)`
    /// averaged over all CV splits (times in seconds).
    ///
    /// NOTE(review): if the splitter yields zero splits, the means divide by
    /// zero and become NaN — confirm all `CrossValidator`s yield >= 1 split.
    fn evaluate_params(
        &self,
        params: &ParameterSet,
        x: &Array2<Float>,
        y: &Array1<Float>,
    ) -> Result<(f64, f64, f64, f64)> {
        let configured_estimator = (self.config_fn)(self.estimator.clone(), params)?;

        let splits = self.cv.split(x.nrows(), None);
        let n_splits = splits.len();

        let mut test_scores = Vec::with_capacity(n_splits);
        let mut fit_times = Vec::with_capacity(n_splits);
        let mut score_times = Vec::with_capacity(n_splits);

        for (train_idx, test_idx) in splits {
            // Materialize the train/test row subsets for this fold.
            let x_train = x.select(scirs2_core::ndarray::Axis(0), &train_idx);
            let y_train = y.select(scirs2_core::ndarray::Axis(0), &train_idx);
            let x_test = x.select(scirs2_core::ndarray::Axis(0), &test_idx);
            let y_test = y.select(scirs2_core::ndarray::Axis(0), &test_idx);

            // Time the fit separately from the scoring.
            let start = std::time::Instant::now();
            let fitted = configured_estimator.clone().fit(&x_train, &y_train)?;
            let fit_time = start.elapsed().as_secs_f64();
            fit_times.push(fit_time);

            let start = std::time::Instant::now();
            let test_score = match &self.scoring {
                Scoring::EstimatorScore => fitted.score(&x_test, &y_test)?,
                Scoring::Custom(func) => {
                    let y_pred = fitted.predict(&x_test)?;
                    func(&y_test.to_owned(), &y_pred)?
                }
                Scoring::Metric(metric_name) => {
                    let y_pred = fitted.predict(&x_test)?;
                    compute_score_for_regression(metric_name, &y_test, &y_pred)?
                }
                // NOTE(review): the provided scorer object is ignored and
                // negated MSE is used instead — confirm this is intentional.
                Scoring::Scorer(_scorer) => {
                    let y_pred = fitted.predict(&x_test)?;
                    -mean_squared_error(&y_test, &y_pred)?
                }
                // NOTE(review): multi-metric falls back to the estimator's own
                // score; the metric list is ignored here.
                Scoring::MultiMetric(_metrics) => {
                    fitted.score(&x_test, &y_test)?
                }
            };
            let score_time = start.elapsed().as_secs_f64();
            score_times.push(score_time);
            test_scores.push(test_score);
        }

        let mean_test_score = test_scores.iter().sum::<f64>() / test_scores.len() as f64;
        // Population (not sample) standard deviation of the fold scores.
        let std_test_score = {
            let variance = test_scores
                .iter()
                .map(|&score| (score - mean_test_score).powi(2))
                .sum::<f64>()
                / test_scores.len() as f64;
            variance.sqrt()
        };
        let mean_fit_time = fit_times.iter().sum::<f64>() / fit_times.len() as f64;
        let mean_score_time = score_times.iter().sum::<f64>() / score_times.len() as f64;

        Ok((
            mean_test_score,
            std_test_score,
            mean_fit_time,
            mean_score_time,
        ))
    }
}
414
415impl<E, F, ConfigFn> Fit<Array2<Float>, Array1<Float>> for GridSearchCV<E, F, ConfigFn>
416where
417 E: Clone + Send + Sync,
418 E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
419 F: Predict<Array2<Float>, Array1<Float>> + Send + Sync,
420 F: Score<Array2<Float>, Array1<Float>, Float = f64>,
421 ConfigFn: Fn(E, &ParameterSet) -> Result<E> + Send + Sync,
422{
423 type Fitted = GridSearchCV<E, F, ConfigFn>;
424
425 fn fit(mut self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
426 if x.nrows() == 0 {
427 return Err(SklearsError::InvalidInput(
428 "Cannot fit on empty dataset".to_string(),
429 ));
430 }
431
432 if x.nrows() != y.len() {
433 return Err(SklearsError::ShapeMismatch {
434 expected: format!("X.shape[0] = {}", x.nrows()),
435 actual: format!("y.shape[0] = {}", y.len()),
436 });
437 }
438
439 let param_combinations = self.generate_param_combinations();
441
442 if param_combinations.is_empty() {
443 return Err(SklearsError::InvalidInput(
444 "No parameter combinations to evaluate".to_string(),
445 ));
446 }
447
448 let mut results = Vec::with_capacity(param_combinations.len());
450
451 for params in ¶m_combinations {
452 let (mean_score, std_score, mean_fit_time, mean_score_time) =
453 self.evaluate_params(params, x, y)?;
454
455 results.push((mean_score, std_score, mean_fit_time, mean_score_time));
456 }
457
458 let best_idx = results
460 .iter()
461 .enumerate()
462 .max_by(|(_, a), (_, b)| a.0.partial_cmp(&b.0).unwrap())
463 .map(|(idx, _)| idx)
464 .ok_or_else(|| SklearsError::NumericalError("No valid scores found".to_string()))?;
465
466 let best_params = param_combinations[best_idx].clone();
467 let best_score = results[best_idx].0;
468
469 let mean_test_scores = Array1::from_vec(results.iter().map(|r| r.0).collect());
471 let std_test_scores = Array1::from_vec(results.iter().map(|r| r.1).collect());
472 let mean_fit_times = Array1::from_vec(results.iter().map(|r| r.2).collect());
473 let mean_score_times = Array1::from_vec(results.iter().map(|r| r.3).collect());
474
475 let mut scores_with_idx: Vec<(f64, usize)> = mean_test_scores
477 .iter()
478 .enumerate()
479 .map(|(i, &score)| (score, i))
480 .collect();
481 scores_with_idx.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
482
483 let mut ranks = vec![0; param_combinations.len()];
484 for (rank, (_, idx)) in scores_with_idx.iter().enumerate() {
485 ranks[*idx] = rank + 1;
486 }
487
488 let cv_results = GridSearchResults {
489 params: param_combinations.clone(),
490 mean_test_scores,
491 std_test_scores,
492 mean_fit_times,
493 mean_score_times,
494 rank_test_scores: Array1::from_vec(ranks),
495 };
496
497 let best_estimator = if self.refit {
499 let configured_estimator = (self.config_fn)(self.estimator.clone(), &best_params)?;
500 Some(configured_estimator.fit(x, y)?)
501 } else {
502 None
503 };
504
505 self.best_estimator_ = best_estimator;
506 self.best_params_ = Some(best_params);
507 self.best_score_ = Some(best_score);
508 self.cv_results_ = Some(cv_results);
509
510 Ok(self)
511 }
512}
513
514impl<E, F, ConfigFn> Predict<Array2<Float>, Array1<Float>> for GridSearchCV<E, F, ConfigFn>
515where
516 E: Clone,
517 E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
518 F: Predict<Array2<Float>, Array1<Float>>,
519 F: Score<Array2<Float>, Array1<Float>, Float = f64>,
520 ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
521{
522 fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
523 match &self.best_estimator_ {
524 Some(estimator) => estimator.predict(x),
525 None => Err(SklearsError::NotFitted {
526 operation: "predict".to_string(),
527 }),
528 }
529 }
530}
531
532impl<E, F, ConfigFn> Score<Array2<Float>, Array1<Float>> for GridSearchCV<E, F, ConfigFn>
533where
534 E: Clone,
535 E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
536 F: Predict<Array2<Float>, Array1<Float>>,
537 F: Score<Array2<Float>, Array1<Float>, Float = f64>,
538 ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
539{
540 type Float = f64;
541 fn score(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<f64> {
542 match &self.best_estimator_ {
543 Some(estimator) => estimator.score(x, y),
544 None => Err(SklearsError::NotFitted {
545 operation: "score".to_string(),
546 }),
547 }
548 }
549}
550
/// Distribution to draw a hyperparameter value from during randomized search.
#[derive(Debug, Clone)]
pub enum ParameterDistribution {
    /// Uniform choice among an explicit list of candidate values.
    Choice(Vec<ParameterValue>),
    /// Uniform integer drawn via `Uniform::new(low, high)`
    /// (presumably the half-open range `[low, high)` — TODO confirm scirs2 semantics).
    RandInt { low: i64, high: i64 },
    /// Uniform float drawn via `Uniform::new(low, high)`.
    Uniform { low: f64, high: f64 },
    /// Log-uniform float: uniform in log space between `ln(low)` and `ln(high)`;
    /// requires `low > 0` for the logarithm to be finite.
    LogUniform { low: f64, high: f64 },
    /// Normal (Gaussian) float with the given mean and standard deviation.
    Normal { mean: f64, std: f64 },
}
565
impl ParameterDistribution {
    /// Draw a single value from this distribution using `rng`.
    ///
    /// # Panics
    /// - `Choice` with an empty candidate list (`choose` returns `None`).
    /// - `RandInt` / `Uniform` / `LogUniform` when `Uniform::new(low, high)`
    ///   fails (presumably `low >= high` — TODO confirm scirs2 behavior).
    /// - `LogUniform` with `low <= 0` (logarithm is not finite).
    /// - `Normal` when `RandNormal::new` rejects the parameters
    ///   (e.g. a negative standard deviation).
    pub fn sample(&self, rng: &mut impl scirs2_core::random::Rng) -> ParameterValue {
        use scirs2_core::essentials::Uniform;
        use scirs2_core::random::Distribution;

        match self {
            // Uniformly pick one of the explicit candidates.
            ParameterDistribution::Choice(values) => values.as_slice().choose(rng).unwrap().clone(),
            ParameterDistribution::RandInt { low, high } => {
                let dist = Uniform::new(*low, *high).unwrap();
                ParameterValue::Int(dist.sample(rng))
            }
            ParameterDistribution::Uniform { low, high } => {
                let dist = Uniform::new(*low, *high).unwrap();
                ParameterValue::Float(dist.sample(rng))
            }
            // Sample uniformly in log space, then map back with exp().
            ParameterDistribution::LogUniform { low, high } => {
                let log_low = low.ln();
                let log_high = high.ln();
                let dist = Uniform::new(log_low, log_high).unwrap();
                let log_sample = dist.sample(rng);
                ParameterValue::Float(log_sample.exp())
            }
            ParameterDistribution::Normal { mean, std } => {
                let dist = RandNormal::new(*mean, *std).unwrap();
                ParameterValue::Float(dist.sample(rng))
            }
        }
    }
}
597
/// Maps a parameter name to the distribution it is sampled from.
pub type ParameterDistributions = HashMap<String, ParameterDistribution>;
600
/// Randomized hyperparameter search: samples `n_iter` parameter sets from
/// per-parameter distributions and evaluates each with cross-validation.
///
/// Same type parameters and fitted-state conventions as [`GridSearchCV`];
/// fields with a trailing underscore are populated by `fit`.
pub struct RandomizedSearchCV<E, F, ConfigFn>
where
    E: Clone,
    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
    F: Predict<Array2<Float>, Array1<Float>>,
    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
    ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
{
    // Base (unconfigured) estimator, cloned once per sampled parameter set.
    estimator: E,
    // Distribution to sample each parameter from.
    param_distributions: ParameterDistributions,
    // Number of parameter sets to sample (default 10).
    n_iter: usize,
    // Cross-validation splitter (default: 5-fold KFold).
    cv: Box<dyn CrossValidator>,
    // Scoring strategy applied to each test fold.
    scoring: Scoring,
    // Requested parallelism. NOTE(review): not consulted by the serial
    // evaluation loop in this file — confirm whether it is used elsewhere.
    n_jobs: Option<usize>,
    // Whether to refit the best configuration on the full data after the search.
    refit: bool,
    // Seed for parameter sampling; see `sample_parameters` for the default.
    random_state: Option<u64>,
    // Applies a parameter set to a clone of `estimator`.
    config_fn: ConfigFn,
    // Ties the fitted type `F` into the struct's generics without storing one.
    _phantom: PhantomData<F>,
    // Best estimator refitted on the full data (only when `refit` is true).
    best_estimator_: Option<F>,
    // Parameter set that achieved the best mean CV score.
    best_params_: Option<ParameterSet>,
    // Best mean cross-validated test score.
    best_score_: Option<f64>,
    // Per-sample CV statistics (reuses the grid-search results type).
    cv_results_: Option<GridSearchResults>,
}
639
impl<E, F, ConfigFn> RandomizedSearchCV<E, F, ConfigFn>
where
    E: Clone,
    E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
    F: Predict<Array2<Float>, Array1<Float>>,
    F: Score<Array2<Float>, Array1<Float>, Float = f64>,
    ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
{
    /// Create a randomized search over `param_distributions` for `estimator`.
    ///
    /// `config_fn` applies one sampled parameter set to a clone of the base
    /// estimator. Defaults: 10 samples, 5-fold CV, the estimator's own
    /// `score`, serial execution, refit enabled, no explicit seed.
    pub fn new(
        estimator: E,
        param_distributions: ParameterDistributions,
        config_fn: ConfigFn,
    ) -> Self {
        Self {
            estimator,
            param_distributions,
            n_iter: 10,
            cv: Box::new(KFold::new(5)),
            scoring: Scoring::EstimatorScore,
            n_jobs: None,
            refit: true,
            random_state: None,
            config_fn,
            _phantom: PhantomData,
            best_estimator_: None,
            best_params_: None,
            best_score_: None,
            cv_results_: None,
        }
    }

    /// Builder: set the number of parameter sets to sample.
    pub fn n_iter(mut self, n_iter: usize) -> Self {
        self.n_iter = n_iter;
        self
    }

    /// Builder: replace the cross-validation splitter.
    pub fn cv<C: CrossValidator + 'static>(mut self, cv: C) -> Self {
        self.cv = Box::new(cv);
        self
    }

    /// Builder: replace the scoring strategy.
    pub fn scoring(mut self, scoring: Scoring) -> Self {
        self.scoring = scoring;
        self
    }

    /// Builder: set the requested parallelism.
    /// NOTE(review): the evaluation loop below is serial and does not read this.
    pub fn n_jobs(mut self, n_jobs: Option<usize>) -> Self {
        self.n_jobs = n_jobs;
        self
    }

    /// Builder: enable/disable refitting the best configuration after the search.
    pub fn refit(mut self, refit: bool) -> Self {
        self.refit = refit;
        self
    }

    /// Builder: set the sampling seed for reproducible searches.
    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.random_state = random_state;
        self
    }

    /// Best refitted estimator; `None` before `fit` or when `refit` was disabled.
    pub fn best_estimator(&self) -> Option<&F> {
        self.best_estimator_.as_ref()
    }

    /// Sampled parameter set with the best mean CV score; `None` before `fit`.
    pub fn best_params(&self) -> Option<&ParameterSet> {
        self.best_params_.as_ref()
    }

    /// Best mean CV test score; `None` before `fit`.
    pub fn best_score(&self) -> Option<f64> {
        self.best_score_
    }

    /// Full per-sample CV statistics; `None` before `fit`.
    pub fn cv_results(&self) -> Option<&GridSearchResults> {
        self.cv_results_.as_ref()
    }

    /// Draw `n_samples` parameter sets, one value per configured distribution.
    ///
    /// NOTE(review): when `random_state` is `None` a fixed seed (42) is used,
    /// so sampling is deterministic even without an explicit seed — confirm
    /// this is intended rather than seeding from entropy.
    fn sample_parameters(&self, n_samples: usize) -> Vec<ParameterSet> {
        use scirs2_core::random::rngs::StdRng;
        use scirs2_core::random::SeedableRng;

        let mut rng = match self.random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::seed_from_u64(42),
        };

        let mut param_sets = Vec::with_capacity(n_samples);

        for _ in 0..n_samples {
            let mut param_set = HashMap::new();

            for (param_name, distribution) in &self.param_distributions {
                let value = distribution.sample(&mut rng);
                param_set.insert(param_name.clone(), value);
            }

            param_sets.push(param_set);
        }

        param_sets
    }

    /// Cross-validate one sampled parameter set.
    ///
    /// Returns `(mean_test_score, std_test_score, mean_fit_time, mean_score_time)`
    /// averaged over all CV splits (times in seconds). Mirrors
    /// `GridSearchCV::evaluate_params`.
    ///
    /// NOTE(review): if the splitter yields zero splits, the means divide by
    /// zero and become NaN — confirm all `CrossValidator`s yield >= 1 split.
    fn evaluate_params(
        &self,
        params: &ParameterSet,
        x: &Array2<Float>,
        y: &Array1<Float>,
    ) -> Result<(f64, f64, f64, f64)> {
        let configured_estimator = (self.config_fn)(self.estimator.clone(), params)?;

        let splits = self.cv.split(x.nrows(), None);
        let n_splits = splits.len();

        let mut test_scores = Vec::with_capacity(n_splits);
        let mut fit_times = Vec::with_capacity(n_splits);
        let mut score_times = Vec::with_capacity(n_splits);

        for (train_idx, test_idx) in splits {
            // Materialize the train/test row subsets for this fold.
            let x_train = x.select(scirs2_core::ndarray::Axis(0), &train_idx);
            let y_train = y.select(scirs2_core::ndarray::Axis(0), &train_idx);
            let x_test = x.select(scirs2_core::ndarray::Axis(0), &test_idx);
            let y_test = y.select(scirs2_core::ndarray::Axis(0), &test_idx);

            // Time the fit separately from the scoring.
            let start = std::time::Instant::now();
            let fitted = configured_estimator.clone().fit(&x_train, &y_train)?;
            let fit_time = start.elapsed().as_secs_f64();
            fit_times.push(fit_time);

            let start = std::time::Instant::now();
            let test_score = match &self.scoring {
                Scoring::EstimatorScore => fitted.score(&x_test, &y_test)?,
                Scoring::Custom(func) => {
                    let y_pred = fitted.predict(&x_test)?;
                    func(&y_test.to_owned(), &y_pred)?
                }
                Scoring::Metric(metric_name) => {
                    let y_pred = fitted.predict(&x_test)?;
                    compute_score_for_regression(metric_name, &y_test, &y_pred)?
                }
                // NOTE(review): the provided scorer object is ignored and
                // negated MSE is used instead — confirm this is intentional.
                Scoring::Scorer(_scorer) => {
                    let y_pred = fitted.predict(&x_test)?;
                    -mean_squared_error(&y_test, &y_pred)?
                }
                // NOTE(review): multi-metric falls back to the estimator's own
                // score; the metric list is ignored here.
                Scoring::MultiMetric(_metrics) => {
                    fitted.score(&x_test, &y_test)?
                }
            };
            let score_time = start.elapsed().as_secs_f64();
            score_times.push(score_time);
            test_scores.push(test_score);
        }

        let mean_test_score = test_scores.iter().sum::<f64>() / test_scores.len() as f64;
        // Population (not sample) standard deviation of the fold scores.
        let std_test_score = {
            let variance = test_scores
                .iter()
                .map(|&score| (score - mean_test_score).powi(2))
                .sum::<f64>()
                / test_scores.len() as f64;
            variance.sqrt()
        };
        let mean_fit_time = fit_times.iter().sum::<f64>() / fit_times.len() as f64;
        let mean_score_time = score_times.iter().sum::<f64>() / score_times.len() as f64;

        Ok((
            mean_test_score,
            std_test_score,
            mean_fit_time,
            mean_score_time,
        ))
    }
}
833
834impl<E, F, ConfigFn> Fit<Array2<Float>, Array1<Float>> for RandomizedSearchCV<E, F, ConfigFn>
835where
836 E: Clone + Send + Sync,
837 E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
838 F: Predict<Array2<Float>, Array1<Float>> + Send + Sync,
839 F: Score<Array2<Float>, Array1<Float>, Float = f64>,
840 ConfigFn: Fn(E, &ParameterSet) -> Result<E> + Send + Sync,
841{
842 type Fitted = RandomizedSearchCV<E, F, ConfigFn>;
843
844 fn fit(mut self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
845 if x.nrows() == 0 {
846 return Err(SklearsError::InvalidInput(
847 "Cannot fit on empty dataset".to_string(),
848 ));
849 }
850
851 if x.nrows() != y.len() {
852 return Err(SklearsError::ShapeMismatch {
853 expected: format!("X.shape[0] = {}", x.nrows()),
854 actual: format!("y.shape[0] = {}", y.len()),
855 });
856 }
857
858 if self.param_distributions.is_empty() {
859 return Err(SklearsError::InvalidInput(
860 "No parameter distributions to sample from".to_string(),
861 ));
862 }
863
864 let param_combinations = self.sample_parameters(self.n_iter);
866
867 let mut results = Vec::with_capacity(param_combinations.len());
869
870 for params in ¶m_combinations {
871 let (mean_score, std_score, mean_fit_time, mean_score_time) =
872 self.evaluate_params(params, x, y)?;
873
874 results.push((mean_score, std_score, mean_fit_time, mean_score_time));
875 }
876
877 let best_idx = results
879 .iter()
880 .enumerate()
881 .max_by(|(_, a), (_, b)| a.0.partial_cmp(&b.0).unwrap())
882 .map(|(idx, _)| idx)
883 .ok_or_else(|| SklearsError::NumericalError("No valid scores found".to_string()))?;
884
885 let best_params = param_combinations[best_idx].clone();
886 let best_score = results[best_idx].0;
887
888 let mean_test_scores = Array1::from_vec(results.iter().map(|r| r.0).collect());
890 let std_test_scores = Array1::from_vec(results.iter().map(|r| r.1).collect());
891 let mean_fit_times = Array1::from_vec(results.iter().map(|r| r.2).collect());
892 let mean_score_times = Array1::from_vec(results.iter().map(|r| r.3).collect());
893
894 let mut scores_with_idx: Vec<(f64, usize)> = mean_test_scores
896 .iter()
897 .enumerate()
898 .map(|(i, &score)| (score, i))
899 .collect();
900 scores_with_idx.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
901
902 let mut ranks = vec![0; param_combinations.len()];
903 for (rank, (_, idx)) in scores_with_idx.iter().enumerate() {
904 ranks[*idx] = rank + 1;
905 }
906
907 let cv_results = GridSearchResults {
908 params: param_combinations.clone(),
909 mean_test_scores,
910 std_test_scores,
911 mean_fit_times,
912 mean_score_times,
913 rank_test_scores: Array1::from_vec(ranks),
914 };
915
916 let best_estimator = if self.refit {
918 let configured_estimator = (self.config_fn)(self.estimator.clone(), &best_params)?;
919 Some(configured_estimator.fit(x, y)?)
920 } else {
921 None
922 };
923
924 self.best_estimator_ = best_estimator;
925 self.best_params_ = Some(best_params);
926 self.best_score_ = Some(best_score);
927 self.cv_results_ = Some(cv_results);
928
929 Ok(self)
930 }
931}
932
933impl<E, F, ConfigFn> Predict<Array2<Float>, Array1<Float>> for RandomizedSearchCV<E, F, ConfigFn>
934where
935 E: Clone,
936 E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
937 F: Predict<Array2<Float>, Array1<Float>>,
938 F: Score<Array2<Float>, Array1<Float>, Float = f64>,
939 ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
940{
941 fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
942 match &self.best_estimator_ {
943 Some(estimator) => estimator.predict(x),
944 None => Err(SklearsError::NotFitted {
945 operation: "predict".to_string(),
946 }),
947 }
948 }
949}
950
951impl<E, F, ConfigFn> Score<Array2<Float>, Array1<Float>> for RandomizedSearchCV<E, F, ConfigFn>
952where
953 E: Clone,
954 E: Fit<Array2<Float>, Array1<Float>, Fitted = F>,
955 F: Predict<Array2<Float>, Array1<Float>>,
956 F: Score<Array2<Float>, Array1<Float>, Float = f64>,
957 ConfigFn: Fn(E, &ParameterSet) -> Result<E>,
958{
959 type Float = f64;
960 fn score(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<f64> {
961 match &self.best_estimator_ {
962 Some(estimator) => estimator.score(x, y),
963 None => Err(SklearsError::NotFitted {
964 operation: "score".to_string(),
965 }),
966 }
967 }
968}
969
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::KFold;
    use scirs2_core::ndarray::array;

    // Minimal deterministic regressor used to exercise the search machinery
    // without pulling in a real model implementation.
    #[derive(Debug, Clone)]
    struct MockRegressor {
        n_estimators: usize,
        learning_rate: f64,
        random_state: Option<u64>,
        fitted: bool,
    }

    impl MockRegressor {
        fn new() -> Self {
            Self {
                n_estimators: 100,
                learning_rate: 0.1,
                random_state: None,
                fitted: false,
            }
        }

        fn n_estimators(mut self, n: usize) -> Self {
            self.n_estimators = n;
            self
        }

        fn learning_rate(mut self, lr: f64) -> Self {
            self.learning_rate = lr;
            self
        }

        fn random_state(mut self, state: Option<u64>) -> Self {
            self.random_state = state;
            self
        }
    }

    impl Fit<Array2<f64>, Array1<f64>> for MockRegressor {
        type Fitted = MockRegressor;

        // Fitting only flips the `fitted` flag; the data is ignored.
        fn fit(mut self, _x: &Array2<f64>, _y: &Array1<f64>) -> Result<Self::Fitted> {
            self.fitted = true;
            Ok(self)
        }
    }

    impl Predict<Array2<f64>, Array1<f64>> for MockRegressor {
        fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
            if !self.fitted {
                return Err(SklearsError::NotFitted {
                    operation: "predict".to_string(),
                });
            }
            // Prediction is simply the row sum, keeping expectations trivial.
            Ok(x.sum_axis(scirs2_core::ndarray::Axis(1)))
        }
    }

    impl Score<Array2<f64>, Array1<f64>> for MockRegressor {
        type Float = f64;

        fn score(&self, x: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
            let y_pred = self.predict(x)?;
            let mse = mean_squared_error(y, &y_pred)?;
            // Higher-is-better convention: report negated MSE.
            Ok(-mse)
        }
    }

    // Alias so the tests read like they drive a real estimator.
    type GradientBoostingRegressor = MockRegressor;

    #[test]
    fn test_parameter_value_extraction() {
        let int_param = ParameterValue::Int(42);
        assert_eq!(int_param.as_int(), Some(42));
        assert_eq!(int_param.as_float(), None);

        let float_param = ParameterValue::Float(std::f64::consts::PI);
        assert_eq!(float_param.as_float(), Some(std::f64::consts::PI));
        assert_eq!(float_param.as_int(), None);

        let opt_int_param = ParameterValue::OptionalInt(Some(10));
        assert_eq!(opt_int_param.as_optional_int(), Some(Some(10)));
    }

    #[test]
    #[ignore]
    fn test_grid_search_cv() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0],];
        let y = array![1.0, 4.0, 9.0, 16.0, 25.0];
        let mut param_grid = HashMap::new();
        param_grid.insert(
            "n_estimators".to_string(),
            vec![ParameterValue::Int(5), ParameterValue::Int(10)],
        );
        param_grid.insert(
            "learning_rate".to_string(),
            vec![ParameterValue::Float(0.1), ParameterValue::Float(0.3)],
        );

        let config_fn = |estimator: GradientBoostingRegressor,
                         params: &ParameterSet|
         -> Result<GradientBoostingRegressor> {
            let mut configured = estimator;

            if let Some(n_est) = params.get("n_estimators").and_then(|p| p.as_int()) {
                configured = configured.n_estimators(n_est as usize);
            }

            if let Some(lr) = params.get("learning_rate").and_then(|p| p.as_float()) {
                configured = configured.learning_rate(lr);
            }

            Ok(configured)
        };

        let base_estimator = GradientBoostingRegressor::new().random_state(Some(42));
        let grid_search = GridSearchCV::new(base_estimator, param_grid, config_fn)
            .cv(KFold::new(3))
            .fit(&x, &y)
            .unwrap();

        assert!(grid_search.best_score().is_some());
        assert!(grid_search.best_params().is_some());
        assert!(grid_search.best_estimator().is_some());
        assert!(grid_search.cv_results().is_some());

        let cv_results = grid_search.cv_results().unwrap();
        // 2 x 2 grid => 4 combinations.
        assert_eq!(cv_results.params.len(), 4);
        assert_eq!(cv_results.mean_test_scores.len(), 4);
        assert_eq!(cv_results.rank_test_scores.len(), 4);

        // The best combination carries rank 1.
        let best_rank = cv_results.rank_test_scores.iter().min().unwrap();
        assert_eq!(*best_rank, 1);

        let predictions = grid_search.predict(&x).unwrap();
        assert_eq!(predictions.len(), x.nrows());
    }

    #[test]
    #[ignore]
    fn test_grid_search_empty_grid() {
        let x = array![[1.0], [2.0]];
        let y = array![1.0, 2.0];

        // An empty grid is legal: it evaluates the base estimator once.
        let param_grid = HashMap::new();
        let config_fn = |estimator: GradientBoostingRegressor,
                         _params: &ParameterSet|
         -> Result<GradientBoostingRegressor> { Ok(estimator) };

        let base_estimator = GradientBoostingRegressor::new();
        let result = GridSearchCV::new(base_estimator, param_grid, config_fn)
            .cv(KFold::new(2))
            .fit(&x, &y);

        assert!(result.is_ok());
        let grid_search = result.unwrap();

        let cv_results = grid_search.cv_results().unwrap();
        assert_eq!(cv_results.params.len(), 1);
        assert!(cv_results.params[0].is_empty());
    }

    #[test]
    fn test_parameter_distribution_sampling() {
        use scirs2_core::random::rngs::StdRng;
        use scirs2_core::random::SeedableRng;
        let mut rng = StdRng::seed_from_u64(42);

        // Choice: sample must be one of the listed candidates.
        let choice_dist = ParameterDistribution::Choice(vec![
            ParameterValue::Int(1),
            ParameterValue::Int(2),
            ParameterValue::Int(3),
        ]);
        let sample = choice_dist.sample(&mut rng);
        if let ParameterValue::Int(val) = sample {
            assert!(val >= 1 && val <= 3);
        } else {
            panic!("Expected Int parameter value");
        }

        // RandInt: sample falls in [low, high).
        let int_dist = ParameterDistribution::RandInt { low: 10, high: 20 };
        let sample = int_dist.sample(&mut rng);
        if let ParameterValue::Int(val) = sample {
            assert!(val >= 10 && val < 20);
        } else {
            panic!("Expected Int parameter value");
        }

        // Uniform: sample falls in [low, high).
        let uniform_dist = ParameterDistribution::Uniform {
            low: 0.0,
            high: 1.0,
        };
        let sample = uniform_dist.sample(&mut rng);
        if let ParameterValue::Float(val) = sample {
            assert!(val >= 0.0 && val < 1.0);
        } else {
            panic!("Expected Float parameter value");
        }

        // Normal: only the variant is checked (the value is unbounded).
        let normal_dist = ParameterDistribution::Normal {
            mean: 0.0,
            std: 1.0,
        };
        let sample = normal_dist.sample(&mut rng);
        assert!(matches!(sample, ParameterValue::Float(_)));
    }

    #[test]
    #[ignore]
    fn test_randomized_search_cv() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0],];
        let y = array![1.0, 4.0, 9.0, 16.0, 25.0];
        let mut param_distributions = HashMap::new();
        param_distributions.insert(
            "n_estimators".to_string(),
            ParameterDistribution::Choice(vec![
                ParameterValue::Int(5),
                ParameterValue::Int(10),
                ParameterValue::Int(15),
            ]),
        );
        param_distributions.insert(
            "learning_rate".to_string(),
            ParameterDistribution::Uniform {
                low: 0.05,
                high: 0.5,
            },
        );

        let config_fn = |estimator: GradientBoostingRegressor,
                         params: &ParameterSet|
         -> Result<GradientBoostingRegressor> {
            let mut configured = estimator;

            if let Some(n_est) = params.get("n_estimators").and_then(|p| p.as_int()) {
                configured = configured.n_estimators(n_est as usize);
            }

            if let Some(lr) = params.get("learning_rate").and_then(|p| p.as_float()) {
                configured = configured.learning_rate(lr);
            }

            Ok(configured)
        };

        let base_estimator = GradientBoostingRegressor::new().random_state(Some(42));
        let randomized_search =
            RandomizedSearchCV::new(base_estimator, param_distributions, config_fn)
                .n_iter(8)
                .cv(KFold::new(3))
                .random_state(Some(42))
                .fit(&x, &y)
                .unwrap();

        assert!(randomized_search.best_score().is_some());
        assert!(randomized_search.best_params().is_some());
        assert!(randomized_search.best_estimator().is_some());
        assert!(randomized_search.cv_results().is_some());

        let cv_results = randomized_search.cv_results().unwrap();
        // n_iter = 8 => 8 sampled parameter sets.
        assert_eq!(cv_results.params.len(), 8);
        assert_eq!(cv_results.mean_test_scores.len(), 8);
        assert_eq!(cv_results.rank_test_scores.len(), 8);

        let best_rank = cv_results.rank_test_scores.iter().min().unwrap();
        assert_eq!(*best_rank, 1);

        let predictions = randomized_search.predict(&x).unwrap();
        assert_eq!(predictions.len(), x.nrows());

        // Every sampled value must respect its distribution's support.
        for params in &cv_results.params {
            if let Some(n_est) = params.get("n_estimators").and_then(|p| p.as_int()) {
                assert!(n_est == 5 || n_est == 10 || n_est == 15);
            }
            if let Some(lr) = params.get("learning_rate").and_then(|p| p.as_float()) {
                assert!(lr >= 0.05 && lr < 0.5);
            }
        }
    }

    #[test]
    #[ignore]
    fn test_randomized_search_empty_distributions() {
        let x = array![[1.0], [2.0]];
        let y = array![1.0, 2.0];

        // Unlike an empty grid, empty distributions are rejected outright.
        let param_distributions = HashMap::new();
        let config_fn = |estimator: GradientBoostingRegressor,
                         _params: &ParameterSet|
         -> Result<GradientBoostingRegressor> { Ok(estimator) };

        let base_estimator = GradientBoostingRegressor::new();
        let result = RandomizedSearchCV::new(base_estimator, param_distributions, config_fn)
            .cv(KFold::new(2))
            .fit(&x, &y);

        assert!(result.is_err());
    }

    #[test]
    #[ignore]
    fn test_randomized_search_reproducibility() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![1.0, 2.0, 3.0, 4.0];

        let mut param_distributions = HashMap::new();
        param_distributions.insert(
            "learning_rate".to_string(),
            ParameterDistribution::Uniform {
                low: 0.1,
                high: 0.5,
            },
        );

        let config_fn = |estimator: GradientBoostingRegressor,
                         params: &ParameterSet|
         -> Result<GradientBoostingRegressor> {
            let mut configured = estimator;
            if let Some(lr) = params.get("learning_rate").and_then(|p| p.as_float()) {
                configured = configured.learning_rate(lr);
            }
            Ok(configured)
        };

        // Two searches with the same seed must sample identical parameters.
        let base_estimator1 = GradientBoostingRegressor::new().random_state(Some(42));
        let result1 =
            RandomizedSearchCV::new(base_estimator1, param_distributions.clone(), config_fn)
                .n_iter(5)
                .random_state(Some(123))
                .cv(KFold::new(2))
                .fit(&x, &y)
                .unwrap();

        let base_estimator2 = GradientBoostingRegressor::new().random_state(Some(42));
        let result2 = RandomizedSearchCV::new(base_estimator2, param_distributions, config_fn)
            .n_iter(5)
            .random_state(Some(123))
            .cv(KFold::new(2))
            .fit(&x, &y)
            .unwrap();

        assert_eq!(result1.best_score(), result2.best_score());

        let params1 = result1.cv_results().unwrap();
        let params2 = result2.cv_results().unwrap();

        for (p1, p2) in params1.params.iter().zip(params2.params.iter()) {
            assert_eq!(p1, p2);
        }
    }
}