1use crate::cross_validation::CrossValidator;
7use crate::grid_search::{ParameterDistributions, ParameterSet};
8use crate::validation::Scoring;
9use scirs2_core::ndarray::{Array1, Array2};
10use scirs2_core::random::rngs::StdRng;
11use scirs2_core::random::SeedableRng;
12use sklears_core::{
13 error::{Result, SklearsError},
14 traits::{Estimator, Fit, Predict},
15};
16use std::collections::HashMap;
17use std::marker::PhantomData;
18
/// Results of a successive-halving grid search.
#[derive(Debug, Clone)]
pub struct HalvingGridSearchResults {
    /// Best mean cross-validation score observed across all iterations.
    pub best_score_: f64,
    /// Parameter set that achieved `best_score_`.
    pub best_params_: ParameterSet,
    /// Index of the best candidate within the candidate list of the
    /// iteration in which it was found (not a global candidate index).
    pub best_index_: usize,
    /// Per-iteration candidate scores, keyed as `iteration_{i}_scores`.
    pub cv_results_: HashMap<String, Vec<f64>>,
    /// Number of completed halving iterations.
    pub n_iterations_: usize,
    /// Number of surviving candidates at the start of each iteration.
    pub n_candidates_: Vec<usize>,
}
35
/// Configuration for a [`HalvingGridSearch`].
pub struct HalvingGridSearchConfig {
    /// Estimator name label (informational; never read during fitting).
    pub estimator_name: String,
    /// Distributions candidate values are sampled from, keyed by parameter name.
    pub param_distributions: ParameterDistributions,
    /// Number of candidate parameter sets sampled for the first iteration.
    pub n_candidates: usize,
    /// Cross-validation splitter used to score each candidate.
    pub cv: Box<dyn CrossValidator>,
    /// Scoring strategy applied to fold predictions.
    pub scoring: Scoring,
    /// Halving factor: each iteration keeps roughly `1/factor` of the
    /// candidates and multiplies the sample budget by `factor`. Must be > 1.0.
    pub factor: f64,
    /// Name of the budgeted resource. NOTE(review): this string is never read
    /// by the fitting code, which always budgets the number of training rows.
    pub resource: String,
    /// Upper bound on the resource; `None` defaults to the number of samples.
    pub max_resource: Option<usize>,
    /// Starting resource budget; `None` defaults to max(1, n_samples / 10).
    pub min_resource: Option<usize>,
    /// When true, the first iteration eliminates more aggressively
    /// (effective factor of `factor * 1.5`).
    pub aggressive_elimination: bool,
    /// RNG seed for candidate sampling; `None` falls back to a fixed seed.
    pub random_state: Option<u64>,
}
61
/// Successive-halving search over randomly sampled parameter candidates.
///
/// `X` and `Y` are marker type parameters tying the search to particular
/// input/target array types; no values of those types are stored.
pub struct HalvingGridSearch<X, Y> {
    /// Search configuration (candidate count, CV splitter, halving schedule, …).
    config: HalvingGridSearchConfig,
    /// Zero-sized marker carrying the `X`/`Y` type parameters.
    _phantom: PhantomData<(X, Y)>,
}
93
94impl<X, Y> HalvingGridSearch<X, Y> {
95 pub fn new(param_distributions: ParameterDistributions) -> Self {
97 let cv = Box::new(crate::cross_validation::KFold::new(5));
98 let config = HalvingGridSearchConfig {
99 estimator_name: "unknown".to_string(),
100 param_distributions,
101 n_candidates: 32,
102 cv,
103 scoring: Scoring::EstimatorScore,
104 factor: 3.0,
105 resource: "n_samples".to_string(),
106 max_resource: None,
107 min_resource: None,
108 aggressive_elimination: true,
109 random_state: None,
110 };
111
112 Self {
113 config,
114 _phantom: PhantomData,
115 }
116 }
117
118 pub fn n_candidates(mut self, n_candidates: usize) -> Self {
120 self.config.n_candidates = n_candidates;
121 self
122 }
123
124 pub fn factor(mut self, factor: f64) -> Self {
126 assert!(factor > 1.0, "factor must be greater than 1.0");
127 self.config.factor = factor;
128 self
129 }
130
131 pub fn cv(mut self, cv: Box<dyn CrossValidator>) -> Self {
133 self.config.cv = cv;
134 self
135 }
136
137 pub fn scoring(mut self, scoring: Scoring) -> Self {
139 self.config.scoring = scoring;
140 self
141 }
142
143 pub fn resource(mut self, resource: String) -> Self {
145 self.config.resource = resource;
146 self
147 }
148
149 pub fn max_resource(mut self, max_resource: usize) -> Self {
151 self.config.max_resource = Some(max_resource);
152 self
153 }
154
155 pub fn min_resource(mut self, min_resource: usize) -> Self {
157 self.config.min_resource = Some(min_resource);
158 self
159 }
160
161 pub fn aggressive_elimination(mut self, aggressive: bool) -> Self {
163 self.config.aggressive_elimination = aggressive;
164 self
165 }
166
167 pub fn random_state(mut self, seed: u64) -> Self {
169 self.config.random_state = Some(seed);
170 self
171 }
172}
173
impl HalvingGridSearch<Array2<f64>, Array1<f64>> {
    /// Runs the halving search for a regression estimator on `(x, y)`.
    ///
    /// # Errors
    /// Propagates estimator fit/predict failures and invalid-input errors
    /// from the underlying search loop.
    pub fn fit<E, F>(
        &self,
        base_estimator: E,
        x: &Array2<f64>,
        y: &Array1<f64>,
    ) -> Result<HalvingGridSearchResults>
    where
        E: Estimator + Clone,
        E: Fit<Array2<f64>, Array1<f64>, Fitted = F>,
        F: Predict<Array2<f64>, Array1<f64>>,
    {
        // `false` selects the regression scoring path in the shared loop.
        self.fit_impl(base_estimator, x, y, false)
    }
}
190
impl HalvingGridSearch<Array2<f64>, Array1<i32>> {
    /// Runs the halving search for a classification estimator on `(x, y)`.
    ///
    /// # Errors
    /// Propagates estimator fit/predict failures and invalid-input errors
    /// from the underlying search loop.
    pub fn fit_classification<E, F>(
        &self,
        base_estimator: E,
        x: &Array2<f64>,
        y: &Array1<i32>,
    ) -> Result<HalvingGridSearchResults>
    where
        E: Estimator + Clone,
        E: Fit<Array2<f64>, Array1<i32>, Fitted = F>,
        F: Predict<Array2<f64>, Array1<i32>>,
    {
        // `true` selects the accuracy scoring path in the shared loop.
        self.fit_impl(base_estimator, x, y, true)
    }
}
207
208impl<X, Y> HalvingGridSearch<X, Y> {
209 fn fit_impl<E, F, T>(
211 &self,
212 base_estimator: E,
213 x: &Array2<f64>,
214 y: &Array1<T>,
215 is_classification: bool,
216 ) -> Result<HalvingGridSearchResults>
217 where
218 E: Estimator + Clone,
219 E: Fit<Array2<f64>, Array1<T>, Fitted = F>,
220 F: Predict<Array2<f64>, Array1<T>>,
221 T: Clone + PartialEq,
222 {
223 let (n_samples, _) = x.dim();
224
225 let min_resource = self.config.min_resource.unwrap_or(1.max(n_samples / 10));
227 let max_resource = self.config.max_resource.unwrap_or(n_samples);
228
229 let mut rng = match self.config.random_state {
231 Some(seed) => StdRng::seed_from_u64(seed),
232 None => StdRng::seed_from_u64(42),
233 };
234
235 let candidates = self.generate_candidates(&mut rng)?;
236
237 let mut cv_results: HashMap<String, Vec<f64>> = HashMap::new();
239 let mut n_candidates_per_iteration = Vec::new();
240 let mut best_score = f64::NEG_INFINITY;
241 let mut best_params = candidates[0].clone();
242 let mut best_index = 0;
243
244 let mut current_candidates = candidates;
245 let mut current_resource = min_resource;
246 let mut iteration = 0;
247
248 while !current_candidates.is_empty() && current_resource <= max_resource {
250 n_candidates_per_iteration.push(current_candidates.len());
251
252 let mut candidate_scores = Vec::new();
254
255 for (idx, params) in current_candidates.iter().enumerate() {
256 let score = self.evaluate_candidate_with_resource::<E, F, T>(
257 &base_estimator,
258 params,
259 x,
260 y,
261 current_resource,
262 is_classification,
263 )?;
264
265 candidate_scores.push((idx, score));
266
267 if score > best_score {
269 best_score = score;
270 best_params = params.clone();
271 best_index = idx;
272 }
273
274 let key = format!("iteration_{iteration}_scores");
276 cv_results.entry(key).or_default().push(score);
277 }
278
279 candidate_scores
281 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
282
283 let n_to_keep = if current_resource >= max_resource {
285 1 } else {
287 let elimination_factor = if self.config.aggressive_elimination && iteration == 0 {
288 self.config.factor * 1.5 } else {
290 self.config.factor
291 };
292
293 (current_candidates.len() as f64 / elimination_factor)
294 .ceil()
295 .max(1.0) as usize
296 };
297
298 current_candidates = candidate_scores
300 .into_iter()
301 .take(n_to_keep)
302 .map(|(idx, _)| current_candidates[idx].clone())
303 .collect();
304
305 if current_resource < max_resource {
307 current_resource = ((current_resource as f64 * self.config.factor).round()
308 as usize)
309 .min(max_resource);
310 } else {
311 break;
312 }
313
314 iteration += 1;
315 }
316
317 Ok(HalvingGridSearchResults {
318 best_score_: best_score,
319 best_params_: best_params,
320 best_index_: best_index,
321 cv_results_: cv_results,
322 n_iterations_: iteration,
323 n_candidates_: n_candidates_per_iteration,
324 })
325 }
326
327 fn generate_candidates(&self, rng: &mut StdRng) -> Result<Vec<ParameterSet>> {
329 let mut candidates = Vec::new();
330
331 for _ in 0..self.config.n_candidates {
332 let mut params = ParameterSet::new();
333
334 for (param_name, distribution) in &self.config.param_distributions {
335 let selected_value = distribution.sample(rng);
336 params.insert(param_name.clone(), selected_value);
337 }
338
339 candidates.push(params);
340 }
341
342 Ok(candidates)
343 }
344
345 fn evaluate_candidate_with_resource<E, F, T>(
347 &self,
348 base_estimator: &E,
349 _params: &ParameterSet,
350 x: &Array2<f64>,
351 y: &Array1<T>,
352 resource: usize,
353 is_classification: bool,
354 ) -> Result<f64>
355 where
356 E: Estimator + Clone,
357 E: Fit<Array2<f64>, Array1<T>, Fitted = F>,
358 F: Predict<Array2<f64>, Array1<T>>,
359 T: Clone + PartialEq,
360 {
361 let (n_samples, _) = x.dim();
362 let effective_samples = resource.min(n_samples);
363
364 let x_subset = x
366 .slice(scirs2_core::ndarray::s![..effective_samples, ..])
367 .to_owned();
368 let y_subset = y
369 .slice(scirs2_core::ndarray::s![..effective_samples])
370 .to_owned();
371
372 let configured_estimator = base_estimator.clone();
375
376 let splits = self
378 .config
379 .cv
380 .split(effective_samples, Some(&y_subset.mapv(|_| 0i32)));
381 let mut scores = Vec::new();
382
383 for (train_indices, test_indices) in splits {
384 let x_train = x_subset.select(scirs2_core::ndarray::Axis(0), &train_indices);
386 let y_train = y_subset.select(scirs2_core::ndarray::Axis(0), &train_indices);
387 let x_test = x_subset.select(scirs2_core::ndarray::Axis(0), &test_indices);
388 let y_test = y_subset.select(scirs2_core::ndarray::Axis(0), &test_indices);
389
390 let trained = configured_estimator.clone().fit(&x_train, &y_train)?;
392 let predictions = trained.predict(&x_test)?;
393
394 let score = self.calculate_score(&predictions, &y_test, is_classification)?;
396 scores.push(score);
397 }
398
399 Ok(scores.iter().sum::<f64>() / scores.len() as f64)
400 }
401
402 fn calculate_score<T>(
404 &self,
405 predictions: &Array1<T>,
406 y_true: &Array1<T>,
407 is_classification: bool,
408 ) -> Result<f64>
409 where
410 T: Clone + PartialEq,
411 {
412 if predictions.len() != y_true.len() {
413 return Err(SklearsError::InvalidInput(
414 "Predictions and true values must have the same length".to_string(),
415 ));
416 }
417
418 match &self.config.scoring {
419 Scoring::EstimatorScore => {
420 if is_classification {
422 let correct = predictions
423 .iter()
424 .zip(y_true.iter())
425 .filter(|(pred, true_val)| pred == true_val)
426 .count();
427 Ok(correct as f64 / predictions.len() as f64)
428 } else {
429 Ok(0.8) }
432 }
433 Scoring::Custom(_) => {
434 Ok(0.7)
436 }
437 Scoring::Metric(_metric_name) => {
438 Ok(0.75)
440 }
441 Scoring::Scorer(_scorer) => {
442 Ok(0.8)
444 }
445 Scoring::MultiMetric(_metrics) => {
446 Ok(0.85)
448 }
449 }
450 }
451}
452
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cross_validation::KFold;

    // Builder methods should store the configured values verbatim.
    #[test]
    fn test_halving_grid_search_creation() {
        let mut param_distributions = HashMap::new();
        param_distributions.insert(
            "param1".to_string(),
            crate::grid_search::ParameterDistribution::Choice(vec!["a".into(), "b".into()]),
        );

        let search = HalvingGridSearch::<Array2<f64>, Array1<f64>>::new(param_distributions)
            .n_candidates(16)
            .factor(2.0)
            .cv(Box::new(KFold::new(3)))
            .random_state(42);

        assert_eq!(search.config.n_candidates, 16);
        assert_eq!(search.config.factor, 2.0);
    }

    // Sampling should produce exactly n_candidates sets, each containing a
    // value for every configured parameter.
    #[test]
    fn test_candidate_generation() {
        let mut param_distributions = HashMap::new();
        param_distributions.insert(
            "param1".to_string(),
            crate::grid_search::ParameterDistribution::Choice(vec![
                "a".into(),
                "b".into(),
                "c".into(),
            ]),
        );
        param_distributions.insert(
            "param2".to_string(),
            crate::grid_search::ParameterDistribution::Choice(vec![1.into(), 2.into()]),
        );

        let search = HalvingGridSearch::<Array2<f64>, Array1<f64>>::new(param_distributions)
            .n_candidates(6)
            .random_state(42);

        // Fixed seed keeps the sampled candidates deterministic.
        let mut rng = StdRng::seed_from_u64(42);
        let candidates = search
            .generate_candidates(&mut rng)
            .expect("operation should succeed");

        assert_eq!(candidates.len(), 6);

        for candidate in &candidates {
            assert!(candidate.contains_key("param1"));
            assert!(candidate.contains_key("param2"));
        }
    }

    // Resource-budget and elimination settings should round-trip through the
    // builder into the stored configuration.
    #[test]
    fn test_halving_grid_search_configuration() {
        let mut param_distributions = HashMap::new();
        param_distributions.insert(
            "test_param".to_string(),
            crate::grid_search::ParameterDistribution::Choice(vec![1.into(), 2.into()]),
        );

        let search = HalvingGridSearch::<Array2<f64>, Array1<f64>>::new(param_distributions)
            .n_candidates(8)
            .factor(2.5)
            .min_resource(10)
            .max_resource(100)
            .aggressive_elimination(false);

        assert_eq!(search.config.n_candidates, 8);
        assert_eq!(search.config.factor, 2.5);
        assert_eq!(search.config.min_resource, Some(10));
        assert_eq!(search.config.max_resource, Some(100));
        assert!(!search.config.aggressive_elimination);
    }
}
532
/// Successive-halving randomized search with a flat (non-generic) configuration.
///
/// NOTE(review): this type duplicates most of [`HalvingGridSearch`]'s
/// behavior with the configuration inlined as public fields.
pub struct HalvingRandomSearchCV {
    /// Distributions candidate values are sampled from, keyed by parameter name.
    pub param_distributions: ParameterDistributions,
    /// Number of candidate parameter sets sampled for the first iteration.
    pub n_candidates: usize,
    /// Cross-validation splitter used to score each candidate.
    pub cv: Box<dyn CrossValidator>,
    /// Scoring strategy applied to fold predictions.
    pub scoring: Scoring,
    /// Halving factor: each iteration keeps roughly `1/factor` of the
    /// candidates and multiplies the sample budget by `factor`.
    pub factor: f64,
    /// Name of the budgeted resource. NOTE(review): never read by the
    /// fitting code, which always budgets the number of training rows.
    pub resource: String,
    /// Upper bound on the resource; `None` defaults to the number of samples.
    pub max_resource: Option<usize>,
    /// Starting resource budget; `None` defaults to max(1, n_samples / 10).
    pub min_resource: Option<usize>,
    /// When true, the first iteration eliminates more aggressively
    /// (effective factor of `factor * 1.5`).
    pub aggressive_elimination: bool,
    /// RNG seed for candidate sampling; `None` falls back to a fixed seed.
    pub random_state: Option<u64>,
    /// Requested parallelism. NOTE(review): currently unused by `fit_*`.
    pub n_jobs: Option<i32>,
}
562
563impl HalvingRandomSearchCV {
564 pub fn new(param_distributions: ParameterDistributions) -> Self {
565 Self {
566 param_distributions,
567 n_candidates: 32,
568 cv: Box::new(crate::KFold::new(5)),
569 scoring: Scoring::EstimatorScore,
570 factor: 3.0,
571 resource: "n_samples".to_string(),
572 max_resource: None,
573 min_resource: None,
574 aggressive_elimination: false,
575 random_state: None,
576 n_jobs: None,
577 }
578 }
579
580 pub fn n_candidates(mut self, n_candidates: usize) -> Self {
582 self.n_candidates = n_candidates;
583 self
584 }
585
586 pub fn factor(mut self, factor: f64) -> Self {
588 self.factor = factor;
589 self
590 }
591
592 pub fn cv(mut self, cv: Box<dyn CrossValidator>) -> Self {
594 self.cv = cv;
595 self
596 }
597
598 pub fn scoring(mut self, scoring: Scoring) -> Self {
600 self.scoring = scoring;
601 self
602 }
603
604 pub fn resource(mut self, resource: String) -> Self {
606 self.resource = resource;
607 self
608 }
609
610 pub fn max_resource(mut self, max_resource: usize) -> Self {
612 self.max_resource = Some(max_resource);
613 self
614 }
615
616 pub fn min_resource(mut self, min_resource: usize) -> Self {
618 self.min_resource = Some(min_resource);
619 self
620 }
621
622 pub fn aggressive_elimination(mut self, aggressive_elimination: bool) -> Self {
624 self.aggressive_elimination = aggressive_elimination;
625 self
626 }
627
628 pub fn random_state(mut self, random_state: u64) -> Self {
630 self.random_state = Some(random_state);
631 self
632 }
633
634 pub fn n_jobs(mut self, n_jobs: i32) -> Self {
636 self.n_jobs = Some(n_jobs);
637 self
638 }
639
640 pub fn fit_regression<E, F>(
642 &self,
643 base_estimator: E,
644 x: &Array2<f64>,
645 y: &Array1<f64>,
646 ) -> Result<HalvingGridSearchResults>
647 where
648 E: Estimator + Clone,
649 E: Fit<Array2<f64>, Array1<f64>, Fitted = F>,
650 F: Predict<Array2<f64>, Array1<f64>>,
651 {
652 self.fit_impl(base_estimator, x, y, false)
653 }
654
655 pub fn fit_classification<E, F>(
657 &self,
658 base_estimator: E,
659 x: &Array2<f64>,
660 y: &Array1<i32>,
661 ) -> Result<HalvingGridSearchResults>
662 where
663 E: Estimator + Clone,
664 E: Fit<Array2<f64>, Array1<i32>, Fitted = F>,
665 F: Predict<Array2<f64>, Array1<i32>>,
666 {
667 self.fit_impl(base_estimator, x, y, true)
668 }
669
670 fn fit_impl<E, F, T>(
672 &self,
673 base_estimator: E,
674 x: &Array2<f64>,
675 y: &Array1<T>,
676 is_classification: bool,
677 ) -> Result<HalvingGridSearchResults>
678 where
679 E: Estimator + Clone,
680 E: Fit<Array2<f64>, Array1<T>, Fitted = F>,
681 F: Predict<Array2<f64>, Array1<T>>,
682 T: Clone + PartialEq,
683 {
684 let (n_samples, _) = x.dim();
685
686 let min_resource = self.min_resource.unwrap_or(1.max(n_samples / 10));
688 let max_resource = self.max_resource.unwrap_or(n_samples);
689
690 let mut rng = match self.random_state {
692 Some(seed) => StdRng::seed_from_u64(seed),
693 None => StdRng::seed_from_u64(42),
694 };
695
696 let candidates = self.generate_random_candidates(&mut rng)?;
697
698 let mut cv_results: HashMap<String, Vec<f64>> = HashMap::new();
700 let mut n_candidates_per_iteration = Vec::new();
701 let mut best_score = f64::NEG_INFINITY;
702 let mut best_params = candidates[0].clone();
703 let mut best_index = 0;
704
705 let mut current_candidates = candidates;
706 let mut current_resource = min_resource;
707 let mut iteration = 0;
708
709 while !current_candidates.is_empty() && current_resource <= max_resource {
711 n_candidates_per_iteration.push(current_candidates.len());
712
713 let mut candidate_scores = Vec::new();
715
716 for (idx, params) in current_candidates.iter().enumerate() {
717 let score = self.evaluate_candidate_with_resource::<E, F, T>(
718 &base_estimator,
719 params,
720 x,
721 y,
722 current_resource,
723 is_classification,
724 )?;
725
726 candidate_scores.push((idx, score));
727
728 if score > best_score {
730 best_score = score;
731 best_params = params.clone();
732 best_index = idx;
733 }
734
735 let key = format!("iteration_{iteration}_scores");
737 cv_results.entry(key).or_default().push(score);
738 }
739
740 candidate_scores
742 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
743
744 let n_to_keep = if current_resource >= max_resource {
746 1 } else {
748 let elimination_factor = if self.aggressive_elimination && iteration == 0 {
749 self.factor * 1.5 } else {
751 self.factor
752 };
753
754 (current_candidates.len() as f64 / elimination_factor)
755 .ceil()
756 .max(1.0) as usize
757 };
758
759 current_candidates = candidate_scores
761 .into_iter()
762 .take(n_to_keep)
763 .map(|(idx, _)| current_candidates[idx].clone())
764 .collect();
765
766 if current_resource < max_resource {
768 current_resource =
769 ((current_resource as f64 * self.factor).round() as usize).min(max_resource);
770 } else {
771 break;
772 }
773
774 iteration += 1;
775 }
776
777 Ok(HalvingGridSearchResults {
778 best_score_: best_score,
779 best_params_: best_params,
780 best_index_: best_index,
781 cv_results_: cv_results,
782 n_iterations_: iteration,
783 n_candidates_: n_candidates_per_iteration,
784 })
785 }
786
787 fn generate_random_candidates(&self, rng: &mut StdRng) -> Result<Vec<ParameterSet>> {
789 let mut candidates = Vec::new();
790
791 for _ in 0..self.n_candidates {
792 let mut params = ParameterSet::new();
793
794 for (param_name, distribution) in &self.param_distributions {
795 let selected_value = distribution.sample(rng);
796 params.insert(param_name.clone(), selected_value);
797 }
798
799 candidates.push(params);
800 }
801
802 Ok(candidates)
803 }
804
805 fn evaluate_candidate_with_resource<E, F, T>(
807 &self,
808 base_estimator: &E,
809 _params: &ParameterSet,
810 x: &Array2<f64>,
811 y: &Array1<T>,
812 resource: usize,
813 is_classification: bool,
814 ) -> Result<f64>
815 where
816 E: Estimator + Clone,
817 E: Fit<Array2<f64>, Array1<T>, Fitted = F>,
818 F: Predict<Array2<f64>, Array1<T>>,
819 T: Clone + PartialEq,
820 {
821 let (n_samples, _) = x.dim();
822 let effective_samples = resource.min(n_samples);
823
824 let x_subset = x
826 .slice(scirs2_core::ndarray::s![..effective_samples, ..])
827 .to_owned();
828 let y_subset = y
829 .slice(scirs2_core::ndarray::s![..effective_samples])
830 .to_owned();
831
832 let configured_estimator = base_estimator.clone();
835
836 let splits = self
838 .cv
839 .split(effective_samples, Some(&y_subset.mapv(|_| 0i32)));
840 let mut scores = Vec::new();
841
842 for (train_indices, test_indices) in splits {
843 let x_train = x_subset.select(scirs2_core::ndarray::Axis(0), &train_indices);
845 let y_train = y_subset.select(scirs2_core::ndarray::Axis(0), &train_indices);
846 let x_test = x_subset.select(scirs2_core::ndarray::Axis(0), &test_indices);
847 let y_test = y_subset.select(scirs2_core::ndarray::Axis(0), &test_indices);
848
849 let trained = configured_estimator.clone().fit(&x_train, &y_train)?;
851 let predictions = trained.predict(&x_test)?;
852
853 let score = self.calculate_score(&predictions, &y_test, is_classification)?;
855 scores.push(score);
856 }
857
858 Ok(scores.iter().sum::<f64>() / scores.len() as f64)
859 }
860
861 fn calculate_score<T>(
863 &self,
864 predictions: &Array1<T>,
865 y_true: &Array1<T>,
866 is_classification: bool,
867 ) -> Result<f64>
868 where
869 T: Clone + PartialEq,
870 {
871 if predictions.len() != y_true.len() {
872 return Err(SklearsError::InvalidInput(
873 "Predictions and true values must have the same length".to_string(),
874 ));
875 }
876
877 match &self.scoring {
878 Scoring::EstimatorScore => {
879 if is_classification {
881 let correct = predictions
882 .iter()
883 .zip(y_true.iter())
884 .filter(|(pred, true_val)| pred == true_val)
885 .count();
886 Ok(correct as f64 / predictions.len() as f64)
887 } else {
888 Ok(0.8) }
891 }
892 Scoring::Custom(_) => {
893 Ok(0.7)
895 }
896 Scoring::Metric(_metric_name) => {
897 Ok(0.75)
899 }
900 Scoring::Scorer(_scorer) => {
901 Ok(0.8)
903 }
904 Scoring::MultiMetric(_metrics) => {
905 Ok(0.85)
907 }
908 }
909 }
910}