//! Successive-halving hyperparameter search: `HalvingGridSearch` and
//! `HalvingRandomSearchCV`.

use crate::cross_validation::CrossValidator;
use crate::grid_search::{ParameterDistributions, ParameterSet};
use crate::validation::Scoring;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::rngs::StdRng;
use scirs2_core::random::SeedableRng;
use sklears_core::{
    error::{Result, SklearsError},
    traits::{Estimator, Fit, Predict},
};
use std::collections::HashMap;
use std::marker::PhantomData;
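/// Results produced by a successive-halving search.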
#[derive(Debug, Clone)]
pub struct HalvingGridSearchResults {
    /// Best cross-validated score observed across all iterations.
    pub best_score_: f64,
    /// Parameter set that achieved `best_score_`.
    pub best_params_: ParameterSet,
    /// Index of the best candidate within its iteration's candidate list.
    pub best_index_: usize,
    /// Per-iteration score lists, keyed as `iteration_{i}_scores`.
    pub cv_results_: HashMap<String, Vec<f64>>,
    /// Number of halving iterations recorded by the search loop.
    pub n_iterations_: usize,
    /// Number of candidates evaluated at each iteration.
    pub n_candidates_: Vec<usize>,
}
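/// Configuration shared by the halving grid search.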
pub struct HalvingGridSearchConfig {
    pub estimator_name: String,
    pub param_distributions: ParameterDistributions,
    pub n_candidates: usize,
    pub cv: Box<dyn CrossValidator>,
    pub scoring: Scoring,
    pub factor: f64,
    pub resource: String,
    pub max_resource: Option<usize>,
    pub min_resource: Option<usize>,
    pub aggressive_elimination: bool,
    pub random_state: Option<u64>,
}
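/// Successive-halving search over parameter candidates drawn from the
/// configured distributions.
///
/// Candidates are first evaluated on a small resource budget (a subset of the
/// samples); after each iteration roughly `1/factor` of them survive and the
/// budget grows by `factor`, until the budget reaches `max_resource`.
///
/// A minimal construction sketch (mirrors the unit tests below; marked
/// `ignore` because wiring in a concrete estimator depends on the caller):
///
/// ```ignore
/// use std::collections::HashMap;
///
/// let mut dists = HashMap::new();
/// dists.insert(
///     "param1".to_string(),
///     crate::grid_search::ParameterDistribution::Choice(vec!["a".into(), "b".into()]),
/// );
/// let search = HalvingGridSearch::<Array2<f64>, Array1<f64>>::new(dists)
///     .n_candidates(16)
///     .factor(2.0)
///     .random_state(42);
/// ```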
pub struct HalvingGridSearch<X, Y> {
    config: HalvingGridSearchConfig,
    _phantom: PhantomData<(X, Y)>,
}

impl<X, Y> HalvingGridSearch<X, Y> {
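    /// Create a search over `param_distributions` with default settings
    /// (32 candidates, 5-fold CV, factor 3, `n_samples` as the resource).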
    pub fn new(param_distributions: ParameterDistributions) -> Self {
        let cv = Box::new(crate::cross_validation::KFold::new(5));
        let config = HalvingGridSearchConfig {
            estimator_name: "unknown".to_string(),
            param_distributions,
            n_candidates: 32,
            cv,
            scoring: Scoring::EstimatorScore,
            factor: 3.0,
            resource: "n_samples".to_string(),
            max_resource: None,
            min_resource: None,
            aggressive_elimination: true,
            random_state: None,
        };

        Self {
            config,
            _phantom: PhantomData,
        }
    }
    pub fn n_candidates(mut self, n_candidates: usize) -> Self {
        self.config.n_candidates = n_candidates;
        self
    }

    pub fn factor(mut self, factor: f64) -> Self {
        assert!(factor > 1.0, "factor must be greater than 1.0");
        self.config.factor = factor;
        self
    }

    pub fn cv(mut self, cv: Box<dyn CrossValidator>) -> Self {
        self.config.cv = cv;
        self
    }

    pub fn scoring(mut self, scoring: Scoring) -> Self {
        self.config.scoring = scoring;
        self
    }

    pub fn resource(mut self, resource: String) -> Self {
        self.config.resource = resource;
        self
    }

    pub fn max_resource(mut self, max_resource: usize) -> Self {
        self.config.max_resource = Some(max_resource);
        self
    }

    pub fn min_resource(mut self, min_resource: usize) -> Self {
        self.config.min_resource = Some(min_resource);
        self
    }

    pub fn aggressive_elimination(mut self, aggressive: bool) -> Self {
        self.config.aggressive_elimination = aggressive;
        self
    }

    pub fn random_state(mut self, seed: u64) -> Self {
        self.config.random_state = Some(seed);
        self
    }
}
impl HalvingGridSearch<Array2<f64>, Array1<f64>> {
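    /// Run the halving search with a regression estimator (`f64` targets).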
    pub fn fit<E, F>(
        &self,
        base_estimator: E,
        x: &Array2<f64>,
        y: &Array1<f64>,
    ) -> Result<HalvingGridSearchResults>
    where
        E: Estimator + Clone,
        E: Fit<Array2<f64>, Array1<f64>, Fitted = F>,
        F: Predict<Array2<f64>, Array1<f64>>,
    {
        self.fit_impl(base_estimator, x, y, false)
    }
}
impl HalvingGridSearch<Array2<f64>, Array1<i32>> {
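    /// Run the halving search with a classification estimator (`i32` labels).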
    pub fn fit_classification<E, F>(
        &self,
        base_estimator: E,
        x: &Array2<f64>,
        y: &Array1<i32>,
    ) -> Result<HalvingGridSearchResults>
    where
        E: Estimator + Clone,
        E: Fit<Array2<f64>, Array1<i32>, Fitted = F>,
        F: Predict<Array2<f64>, Array1<i32>>,
    {
        self.fit_impl(base_estimator, x, y, true)
    }
}
impl<X, Y> HalvingGridSearch<X, Y> {
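    /// Shared successive-halving driver behind `fit` and `fit_classification`:
    /// samples candidates, then iteratively scores and eliminates them.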
    fn fit_impl<E, F, T>(
        &self,
        base_estimator: E,
        x: &Array2<f64>,
        y: &Array1<T>,
        is_classification: bool,
    ) -> Result<HalvingGridSearchResults>
    where
        E: Estimator + Clone,
        E: Fit<Array2<f64>, Array1<T>, Fitted = F>,
        F: Predict<Array2<f64>, Array1<T>>,
        T: Clone + PartialEq,
    {
        let (n_samples, _) = x.dim();

        // Resource budget: start with about 10% of the samples (at least one)
        // and grow towards the full training set unless the caller overrides it.
        let min_resource = self.config.min_resource.unwrap_or((n_samples / 10).max(1));
        let max_resource = self.config.max_resource.unwrap_or(n_samples);

        // A fixed fallback seed keeps candidate sampling reproducible by default.
        let mut rng = match self.config.random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::seed_from_u64(42),
        };

        let candidates = self.generate_candidates(&mut rng)?;

        let mut cv_results: HashMap<String, Vec<f64>> = HashMap::new();
        let mut n_candidates_per_iteration = Vec::new();
        let mut best_score = f64::NEG_INFINITY;
        let mut best_params = candidates[0].clone();
        let mut best_index = 0;

        let mut current_candidates = candidates;
        let mut current_resource = min_resource;
        let mut iteration = 0;
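        // Successive halving: score every surviving candidate on the current
        // resource budget, keep roughly the top `1/factor` of them, then grow
        // the budget by `factor` and repeat until the full budget is reached.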
        while !current_candidates.is_empty() && current_resource <= max_resource {
            n_candidates_per_iteration.push(current_candidates.len());

            let mut candidate_scores = Vec::new();

            for (idx, params) in current_candidates.iter().enumerate() {
                let score = self.evaluate_candidate_with_resource::<E, F, T>(
                    &base_estimator,
                    params,
                    x,
                    y,
                    current_resource,
                    is_classification,
                )?;

                candidate_scores.push((idx, score));

                if score > best_score {
                    best_score = score;
                    best_params = params.clone();
                    best_index = idx;
                }

                let key = format!("iteration_{iteration}_scores");
                cv_results.entry(key).or_default().push(score);
            }

            candidate_scores
                .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

            let n_to_keep = if current_resource >= max_resource {
                1
            } else {
                let elimination_factor = if self.config.aggressive_elimination && iteration == 0 {
                    self.config.factor * 1.5
                } else {
                    self.config.factor
                };

                (current_candidates.len() as f64 / elimination_factor)
                    .ceil()
                    .max(1.0) as usize
            };

            current_candidates = candidate_scores
                .into_iter()
                .take(n_to_keep)
                .map(|(idx, _)| current_candidates[idx].clone())
                .collect();

            if current_resource < max_resource {
                current_resource = ((current_resource as f64 * self.config.factor).round()
                    as usize)
                    .min(max_resource);
            } else {
                break;
            }

            iteration += 1;
        }

        Ok(HalvingGridSearchResults {
            best_score_: best_score,
            best_params_: best_params,
            best_index_: best_index,
            cv_results_: cv_results,
            n_iterations_: iteration,
            n_candidates_: n_candidates_per_iteration,
        })
    }
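    /// Sample `n_candidates` parameter sets from the configured distributions.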
    fn generate_candidates(&self, rng: &mut StdRng) -> Result<Vec<ParameterSet>> {
        let mut candidates = Vec::new();

        for _ in 0..self.config.n_candidates {
            let mut params = ParameterSet::new();

            for (param_name, distribution) in &self.config.param_distributions {
                let selected_value = distribution.sample(rng);
                params.insert(param_name.clone(), selected_value);
            }

            candidates.push(params);
        }

        Ok(candidates)
    }
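    /// Cross-validate one candidate on the first `resource` rows of the data.
    ///
    /// Note: the candidate's parameters are not yet applied to the estimator
    /// here; the base estimator is cloned as-is, so this is a scaffold for the
    /// full parameter-injection logic.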
    fn evaluate_candidate_with_resource<E, F, T>(
        &self,
        base_estimator: &E,
        _params: &ParameterSet,
        x: &Array2<f64>,
        y: &Array1<T>,
        resource: usize,
        is_classification: bool,
    ) -> Result<f64>
    where
        E: Estimator + Clone,
        E: Fit<Array2<f64>, Array1<T>, Fitted = F>,
        F: Predict<Array2<f64>, Array1<T>>,
        T: Clone + PartialEq,
    {
        let (n_samples, _) = x.dim();
        let effective_samples = resource.min(n_samples);

        let x_subset = x
            .slice(scirs2_core::ndarray::s![..effective_samples, ..])
            .to_owned();
        let y_subset = y
            .slice(scirs2_core::ndarray::s![..effective_samples])
            .to_owned();

        // Candidate parameters are not injected yet; the base estimator is cloned as-is.
        let configured_estimator = base_estimator.clone();

        // All-zero dummy labels are passed since no class information is
        // available for a generic `T`.
        let splits = self
            .config
            .cv
            .split(effective_samples, Some(&y_subset.mapv(|_| 0i32)));
        let mut scores = Vec::new();

        for (train_indices, test_indices) in splits {
            let x_train = x_subset.select(scirs2_core::ndarray::Axis(0), &train_indices);
            let y_train = y_subset.select(scirs2_core::ndarray::Axis(0), &train_indices);
            let x_test = x_subset.select(scirs2_core::ndarray::Axis(0), &test_indices);
            let y_test = y_subset.select(scirs2_core::ndarray::Axis(0), &test_indices);

            let trained = configured_estimator.clone().fit(&x_train, &y_train)?;
            let predictions = trained.predict(&x_test)?;

            let score = self.calculate_score(&predictions, &y_test, is_classification)?;
            scores.push(score);
        }

        if scores.is_empty() {
            return Err(SklearsError::InvalidInput(
                "cross-validator produced no splits".to_string(),
            ));
        }

        Ok(scores.iter().sum::<f64>() / scores.len() as f64)
    }
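    /// Score predictions against the held-out targets.
    ///
    /// Only classification accuracy is computed for `Scoring::EstimatorScore`;
    /// the remaining scoring variants currently return fixed placeholder values.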
    fn calculate_score<T>(
        &self,
        predictions: &Array1<T>,
        y_true: &Array1<T>,
        is_classification: bool,
    ) -> Result<f64>
    where
        T: Clone + PartialEq,
    {
        if predictions.len() != y_true.len() {
            return Err(SklearsError::InvalidInput(
                "Predictions and true values must have the same length".to_string(),
            ));
        }

        match &self.config.scoring {
            Scoring::EstimatorScore => {
                if is_classification {
                    // Classification: fraction of exact matches (accuracy).
                    let correct = predictions
                        .iter()
                        .zip(y_true.iter())
                        .filter(|(pred, true_val)| pred == true_val)
                        .count();
                    Ok(correct as f64 / predictions.len() as f64)
                } else {
                    // Placeholder value; regression scoring is not implemented yet.
                    Ok(0.8)
                }
            }
            Scoring::Custom(_) => {
                // Placeholder value for custom scoring functions.
                Ok(0.7)
            }
            Scoring::Metric(_metric_name) => {
                // Placeholder value for named metrics.
                Ok(0.75)
            }
            Scoring::Scorer(_scorer) => {
                // Placeholder value for scorer objects.
                Ok(0.8)
            }
            Scoring::MultiMetric(_metrics) => {
                // Placeholder value for multi-metric scoring.
                Ok(0.85)
            }
        }
    }
}
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cross_validation::KFold;

    #[test]
    fn test_halving_grid_search_creation() {
        let mut param_distributions = HashMap::new();
        param_distributions.insert(
            "param1".to_string(),
            crate::grid_search::ParameterDistribution::Choice(vec!["a".into(), "b".into()]),
        );

        let search = HalvingGridSearch::<Array2<f64>, Array1<f64>>::new(param_distributions)
            .n_candidates(16)
            .factor(2.0)
            .cv(Box::new(KFold::new(3)))
            .random_state(42);

        assert_eq!(search.config.n_candidates, 16);
        assert_eq!(search.config.factor, 2.0);
    }

    #[test]
    fn test_candidate_generation() {
        let mut param_distributions = HashMap::new();
        param_distributions.insert(
            "param1".to_string(),
            crate::grid_search::ParameterDistribution::Choice(vec![
                "a".into(),
                "b".into(),
                "c".into(),
            ]),
        );
        param_distributions.insert(
            "param2".to_string(),
            crate::grid_search::ParameterDistribution::Choice(vec![1.into(), 2.into()]),
        );

        let search = HalvingGridSearch::<Array2<f64>, Array1<f64>>::new(param_distributions)
            .n_candidates(6)
            .random_state(42);

        let mut rng = StdRng::seed_from_u64(42);
        let candidates = search.generate_candidates(&mut rng).unwrap();

        assert_eq!(candidates.len(), 6);

        for candidate in &candidates {
            assert!(candidate.contains_key("param1"));
            assert!(candidate.contains_key("param2"));
        }
    }

    #[test]
    fn test_halving_grid_search_configuration() {
        let mut param_distributions = HashMap::new();
        param_distributions.insert(
            "test_param".to_string(),
            crate::grid_search::ParameterDistribution::Choice(vec![1.into(), 2.into()]),
        );

        let search = HalvingGridSearch::<Array2<f64>, Array1<f64>>::new(param_distributions)
            .n_candidates(8)
            .factor(2.5)
            .min_resource(10)
            .max_resource(100)
            .aggressive_elimination(false);

        assert_eq!(search.config.n_candidates, 8);
        assert_eq!(search.config.factor, 2.5);
        assert_eq!(search.config.min_resource, Some(10));
        assert_eq!(search.config.max_resource, Some(100));
        assert!(!search.config.aggressive_elimination);
    }
}
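/// Randomized successive-halving search, functionally parallel to
/// `HalvingGridSearch` but with its configuration stored directly on the
/// struct and an additional `n_jobs` option.
///
/// A minimal construction sketch (builder style mirrors the tests above; the
/// `alpha` parameter name is hypothetical and the block is marked `ignore`
/// because concrete distributions depend on the caller):
///
/// ```ignore
/// use std::collections::HashMap;
///
/// let mut dists = HashMap::new();
/// dists.insert(
///     "alpha".to_string(),
///     crate::grid_search::ParameterDistribution::Choice(vec![1.into(), 2.into()]),
/// );
/// let search = HalvingRandomSearchCV::new(dists)
///     .n_candidates(16)
///     .factor(2.0)
///     .random_state(42);
/// ```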
pub struct HalvingRandomSearchCV {
    pub param_distributions: ParameterDistributions,
    pub n_candidates: usize,
    pub cv: Box<dyn CrossValidator>,
    pub scoring: Scoring,
    pub factor: f64,
    pub resource: String,
    pub max_resource: Option<usize>,
    pub min_resource: Option<usize>,
    pub aggressive_elimination: bool,
    pub random_state: Option<u64>,
    pub n_jobs: Option<i32>,
}
impl HalvingRandomSearchCV {
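    /// Create a search with library defaults (32 candidates, 5-fold CV, factor 3).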
    pub fn new(param_distributions: ParameterDistributions) -> Self {
        Self {
            param_distributions,
            n_candidates: 32,
            cv: Box::new(crate::KFold::new(5)),
            scoring: Scoring::EstimatorScore,
            factor: 3.0,
            resource: "n_samples".to_string(),
            max_resource: None,
            min_resource: None,
            aggressive_elimination: false,
            random_state: None,
            n_jobs: None,
        }
    }
    pub fn n_candidates(mut self, n_candidates: usize) -> Self {
        self.n_candidates = n_candidates;
        self
    }
    pub fn factor(mut self, factor: f64) -> Self {
        assert!(factor > 1.0, "factor must be greater than 1.0");
        self.factor = factor;
        self
    }
    pub fn cv(mut self, cv: Box<dyn CrossValidator>) -> Self {
        self.cv = cv;
        self
    }

    pub fn scoring(mut self, scoring: Scoring) -> Self {
        self.scoring = scoring;
        self
    }

    pub fn resource(mut self, resource: String) -> Self {
        self.resource = resource;
        self
    }

    pub fn max_resource(mut self, max_resource: usize) -> Self {
        self.max_resource = Some(max_resource);
        self
    }

    pub fn min_resource(mut self, min_resource: usize) -> Self {
        self.min_resource = Some(min_resource);
        self
    }

    pub fn aggressive_elimination(mut self, aggressive_elimination: bool) -> Self {
        self.aggressive_elimination = aggressive_elimination;
        self
    }

    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    pub fn n_jobs(mut self, n_jobs: i32) -> Self {
        self.n_jobs = Some(n_jobs);
        self
    }
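    /// Run the randomized halving search with a regression estimator (`f64` targets).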
    pub fn fit_regression<E, F>(
        &self,
        base_estimator: E,
        x: &Array2<f64>,
        y: &Array1<f64>,
    ) -> Result<HalvingGridSearchResults>
    where
        E: Estimator + Clone,
        E: Fit<Array2<f64>, Array1<f64>, Fitted = F>,
        F: Predict<Array2<f64>, Array1<f64>>,
    {
        self.fit_impl(base_estimator, x, y, false)
    }
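    /// Run the randomized halving search with a classification estimator (`i32` labels).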
    pub fn fit_classification<E, F>(
        &self,
        base_estimator: E,
        x: &Array2<f64>,
        y: &Array1<i32>,
    ) -> Result<HalvingGridSearchResults>
    where
        E: Estimator + Clone,
        E: Fit<Array2<f64>, Array1<i32>, Fitted = F>,
        F: Predict<Array2<f64>, Array1<i32>>,
    {
        self.fit_impl(base_estimator, x, y, true)
    }
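    /// Shared successive-halving driver behind `fit_regression` and
    /// `fit_classification`.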
    fn fit_impl<E, F, T>(
        &self,
        base_estimator: E,
        x: &Array2<f64>,
        y: &Array1<T>,
        is_classification: bool,
    ) -> Result<HalvingGridSearchResults>
    where
        E: Estimator + Clone,
        E: Fit<Array2<f64>, Array1<T>, Fitted = F>,
        F: Predict<Array2<f64>, Array1<T>>,
        T: Clone + PartialEq,
    {
        let (n_samples, _) = x.dim();

        let min_resource = self.min_resource.unwrap_or((n_samples / 10).max(1));
        let max_resource = self.max_resource.unwrap_or(n_samples);

        let mut rng = match self.random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::seed_from_u64(42),
        };

        let candidates = self.generate_random_candidates(&mut rng)?;

        let mut cv_results: HashMap<String, Vec<f64>> = HashMap::new();
        let mut n_candidates_per_iteration = Vec::new();
        let mut best_score = f64::NEG_INFINITY;
        let mut best_params = candidates[0].clone();
        let mut best_index = 0;

        let mut current_candidates = candidates;
        let mut current_resource = min_resource;
        let mut iteration = 0;
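        // Same successive-halving loop as `HalvingGridSearch::fit_impl`: score,
        // keep the top candidates, grow the budget by `factor`, repeat.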
        while !current_candidates.is_empty() && current_resource <= max_resource {
            n_candidates_per_iteration.push(current_candidates.len());

            let mut candidate_scores = Vec::new();

            for (idx, params) in current_candidates.iter().enumerate() {
                let score = self.evaluate_candidate_with_resource::<E, F, T>(
                    &base_estimator,
                    params,
                    x,
                    y,
                    current_resource,
                    is_classification,
                )?;

                candidate_scores.push((idx, score));

                if score > best_score {
                    best_score = score;
                    best_params = params.clone();
                    best_index = idx;
                }

                let key = format!("iteration_{iteration}_scores");
                cv_results.entry(key).or_default().push(score);
            }

            candidate_scores
                .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

            let n_to_keep = if current_resource >= max_resource {
                1
            } else {
                let elimination_factor = if self.aggressive_elimination && iteration == 0 {
                    self.factor * 1.5
                } else {
                    self.factor
                };

                (current_candidates.len() as f64 / elimination_factor)
                    .ceil()
                    .max(1.0) as usize
            };

            current_candidates = candidate_scores
                .into_iter()
                .take(n_to_keep)
                .map(|(idx, _)| current_candidates[idx].clone())
                .collect();

            if current_resource < max_resource {
                current_resource =
                    ((current_resource as f64 * self.factor).round() as usize).min(max_resource);
            } else {
                break;
            }

            iteration += 1;
        }

        Ok(HalvingGridSearchResults {
            best_score_: best_score,
            best_params_: best_params,
            best_index_: best_index,
            cv_results_: cv_results,
            n_iterations_: iteration,
            n_candidates_: n_candidates_per_iteration,
        })
    }
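    /// Sample `n_candidates` parameter sets from the configured distributions.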
    fn generate_random_candidates(&self, rng: &mut StdRng) -> Result<Vec<ParameterSet>> {
        let mut candidates = Vec::new();

        for _ in 0..self.n_candidates {
            let mut params = ParameterSet::new();

            for (param_name, distribution) in &self.param_distributions {
                let selected_value = distribution.sample(rng);
                params.insert(param_name.clone(), selected_value);
            }

            candidates.push(params);
        }

        Ok(candidates)
    }
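    /// Cross-validate one candidate on the first `resource` rows; as in
    /// `HalvingGridSearch`, the candidate's parameters are not yet applied to
    /// the estimator (scaffold).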
    fn evaluate_candidate_with_resource<E, F, T>(
        &self,
        base_estimator: &E,
        _params: &ParameterSet,
        x: &Array2<f64>,
        y: &Array1<T>,
        resource: usize,
        is_classification: bool,
    ) -> Result<f64>
    where
        E: Estimator + Clone,
        E: Fit<Array2<f64>, Array1<T>, Fitted = F>,
        F: Predict<Array2<f64>, Array1<T>>,
        T: Clone + PartialEq,
    {
        let (n_samples, _) = x.dim();
        let effective_samples = resource.min(n_samples);

        let x_subset = x
            .slice(scirs2_core::ndarray::s![..effective_samples, ..])
            .to_owned();
        let y_subset = y
            .slice(scirs2_core::ndarray::s![..effective_samples])
            .to_owned();

        // Candidate parameters are not injected yet; the base estimator is cloned as-is.
        let configured_estimator = base_estimator.clone();

        // All-zero dummy labels are passed since no class information is
        // available for a generic `T`.
        let splits = self
            .cv
            .split(effective_samples, Some(&y_subset.mapv(|_| 0i32)));
        let mut scores = Vec::new();

        for (train_indices, test_indices) in splits {
            let x_train = x_subset.select(scirs2_core::ndarray::Axis(0), &train_indices);
            let y_train = y_subset.select(scirs2_core::ndarray::Axis(0), &train_indices);
            let x_test = x_subset.select(scirs2_core::ndarray::Axis(0), &test_indices);
            let y_test = y_subset.select(scirs2_core::ndarray::Axis(0), &test_indices);

            let trained = configured_estimator.clone().fit(&x_train, &y_train)?;
            let predictions = trained.predict(&x_test)?;

            let score = self.calculate_score(&predictions, &y_test, is_classification)?;
            scores.push(score);
        }

        if scores.is_empty() {
            return Err(SklearsError::InvalidInput(
                "cross-validator produced no splits".to_string(),
            ));
        }

        Ok(scores.iter().sum::<f64>() / scores.len() as f64)
    }
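    /// Score predictions against the held-out targets.
    ///
    /// Only classification accuracy is computed for `Scoring::EstimatorScore`;
    /// the remaining scoring variants currently return fixed placeholder values.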
    fn calculate_score<T>(
        &self,
        predictions: &Array1<T>,
        y_true: &Array1<T>,
        is_classification: bool,
    ) -> Result<f64>
    where
        T: Clone + PartialEq,
    {
        if predictions.len() != y_true.len() {
            return Err(SklearsError::InvalidInput(
                "Predictions and true values must have the same length".to_string(),
            ));
        }

        match &self.scoring {
            Scoring::EstimatorScore => {
                if is_classification {
                    // Classification: fraction of exact matches (accuracy).
                    let correct = predictions
                        .iter()
                        .zip(y_true.iter())
                        .filter(|(pred, true_val)| pred == true_val)
                        .count();
                    Ok(correct as f64 / predictions.len() as f64)
                } else {
                    // Placeholder value; regression scoring is not implemented yet.
                    Ok(0.8)
                }
            }
            Scoring::Custom(_) => {
                // Placeholder value for custom scoring functions.
                Ok(0.7)
            }
            Scoring::Metric(_metric_name) => {
                // Placeholder value for named metrics.
                Ok(0.75)
            }
            Scoring::Scorer(_scorer) => {
                // Placeholder value for scorer objects.
                Ok(0.8)
            }
            Scoring::MultiMetric(_metrics) => {
                // Placeholder value for multi-metric scoring.
                Ok(0.85)
            }
        }
    }
}