sklears_model_selection/
halving_grid_search.rs

//! Halving Grid Search and Random Search for efficient hyperparameter optimization
//!
//! HalvingGridSearch and HalvingRandomSearchCV use successive halving to efficiently
//! search hyperparameter spaces by progressively eliminating poor-performing candidates.
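//!
//! With `factor = 3.0`, 27 initial candidates, and a minimum resource of 10
//! samples, a typical schedule looks like the table below (a sketch; the exact
//! counts depend on rounding and on the `aggressive_elimination` setting):
//!
//! | iteration | candidates | resource (samples) |
//! |-----------|------------|--------------------|
//! | 0         | 27         | 10                 |
//! | 1         | 9          | 30                 |
//! | 2         | 3          | 90                 |
//! | 3         | 1          | 270                |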

use crate::cross_validation::CrossValidator;
use crate::grid_search::{ParameterDistributions, ParameterSet};
use crate::validation::Scoring;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::rngs::StdRng;
use scirs2_core::random::SeedableRng;
use sklears_core::{
    error::{Result, SklearsError},
    traits::{Estimator, Fit, Predict},
};
use std::collections::HashMap;
use std::marker::PhantomData;

/// Results from HalvingGridSearch
#[derive(Debug, Clone)]
pub struct HalvingGridSearchResults {
    /// Best score achieved
    pub best_score_: f64,
    /// Best parameters found
    pub best_params_: ParameterSet,
    /// Best estimator index
    pub best_index_: usize,
    /// Cross-validation scores per iteration, keyed as `"iteration_{i}_scores"`
    pub cv_results_: HashMap<String, Vec<f64>>,
    /// Number of iterations performed
    pub n_iterations_: usize,
    /// Number of candidates evaluated in each iteration
    pub n_candidates_: Vec<usize>,
}

/// Configuration for HalvingGridSearch
pub struct HalvingGridSearchConfig {
    /// Estimator to use for the search
    pub estimator_name: String,
    /// Parameter distributions to search
    pub param_distributions: ParameterDistributions,
    /// Number of candidates to start with
    pub n_candidates: usize,
    /// Cross-validation strategy
    pub cv: Box<dyn CrossValidator>,
    /// Scoring function
    pub scoring: Scoring,
    /// Reduction factor for successive halving
    pub factor: f64,
    /// Resource parameter (e.g., number of samples)
    pub resource: String,
    /// Maximum resource to use
    pub max_resource: Option<usize>,
    /// Minimum resource to use
    pub min_resource: Option<usize>,
    /// Aggressive elimination in early rounds
    pub aggressive_elimination: bool,
    /// Random state for reproducibility
    pub random_state: Option<u64>,
}

/// HalvingGridSearch implementation
///
/// Uses successive halving to efficiently search hyperparameter spaces by
/// progressively eliminating poor-performing candidates using increasing
/// amounts of resources (typically training samples).
///
/// # Example
/// ```rust,ignore
/// use sklears_model_selection::{HalvingGridSearch, KFold, ParameterDistribution};
/// use scirs2_core::ndarray::array;
/// use std::collections::HashMap;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 3.0], [2.0, 1.0]];
/// let y = array![1.5, 2.5, 3.0, 1.8];
///
/// // Create parameter distributions
/// let mut param_distributions = HashMap::new();
/// param_distributions.insert("fit_intercept".to_string(),
///     ParameterDistribution::Choice(vec![true.into(), false.into()]));
///
/// let search: HalvingGridSearch<scirs2_core::ndarray::Array2<f64>, scirs2_core::ndarray::Array1<f64>> =
///     HalvingGridSearch::new(param_distributions)
///         .n_candidates(8)
///         .factor(2.0)
///         .cv(Box::new(KFold::new(3)))
///         .random_state(42);
///
/// // let results = search.fit(LinearRegression::new(), &X, &y);
/// ```
pub struct HalvingGridSearch<X, Y> {
    config: HalvingGridSearchConfig,
    _phantom: PhantomData<(X, Y)>,
}

impl<X, Y> HalvingGridSearch<X, Y> {
    /// Create a new HalvingGridSearch
    pub fn new(param_distributions: ParameterDistributions) -> Self {
        let cv = Box::new(crate::cross_validation::KFold::new(5));
        let config = HalvingGridSearchConfig {
            estimator_name: "unknown".to_string(),
            param_distributions,
            n_candidates: 32,
            cv,
            scoring: Scoring::EstimatorScore,
            factor: 3.0,
            resource: "n_samples".to_string(),
            max_resource: None,
            min_resource: None,
            aggressive_elimination: true,
            random_state: None,
        };

        Self {
            config,
            _phantom: PhantomData,
        }
    }

    /// Set the number of initial candidates
    pub fn n_candidates(mut self, n_candidates: usize) -> Self {
        self.config.n_candidates = n_candidates;
        self
    }

    /// Set the reduction factor for successive halving; panics if `factor` <= 1.0
    pub fn factor(mut self, factor: f64) -> Self {
        assert!(factor > 1.0, "factor must be greater than 1.0");
        self.config.factor = factor;
        self
    }

    /// Set the cross-validation strategy
    pub fn cv(mut self, cv: Box<dyn CrossValidator>) -> Self {
        self.config.cv = cv;
        self
    }

    /// Set the scoring function
    pub fn scoring(mut self, scoring: Scoring) -> Self {
        self.config.scoring = scoring;
        self
    }

    /// Set the resource parameter
    pub fn resource(mut self, resource: String) -> Self {
        self.config.resource = resource;
        self
    }

    /// Set the maximum resource
    pub fn max_resource(mut self, max_resource: usize) -> Self {
        self.config.max_resource = Some(max_resource);
        self
    }

    /// Set the minimum resource
    pub fn min_resource(mut self, min_resource: usize) -> Self {
        self.config.min_resource = Some(min_resource);
        self
    }

    /// Enable or disable aggressive elimination
    pub fn aggressive_elimination(mut self, aggressive: bool) -> Self {
        self.config.aggressive_elimination = aggressive;
        self
    }

    /// Set random state for reproducibility
    pub fn random_state(mut self, seed: u64) -> Self {
        self.config.random_state = Some(seed);
        self
    }
}

impl HalvingGridSearch<Array2<f64>, Array1<f64>> {
    /// Fit the halving grid search for regression
    pub fn fit<E, F>(
        &self,
        base_estimator: E,
        x: &Array2<f64>,
        y: &Array1<f64>,
    ) -> Result<HalvingGridSearchResults>
    where
        E: Estimator + Clone,
        E: Fit<Array2<f64>, Array1<f64>, Fitted = F>,
        F: Predict<Array2<f64>, Array1<f64>>,
    {
        self.fit_impl(base_estimator, x, y, false)
    }
}

impl HalvingGridSearch<Array2<f64>, Array1<i32>> {
    /// Fit the halving grid search for classification
    pub fn fit_classification<E, F>(
        &self,
        base_estimator: E,
        x: &Array2<f64>,
        y: &Array1<i32>,
    ) -> Result<HalvingGridSearchResults>
    where
        E: Estimator + Clone,
        E: Fit<Array2<f64>, Array1<i32>, Fitted = F>,
        F: Predict<Array2<f64>, Array1<i32>>,
    {
        self.fit_impl(base_estimator, x, y, true)
    }
}

impl<X, Y> HalvingGridSearch<X, Y> {
    /// Internal implementation for fitting
    fn fit_impl<E, F, T>(
        &self,
        base_estimator: E,
        x: &Array2<f64>,
        y: &Array1<T>,
        is_classification: bool,
    ) -> Result<HalvingGridSearchResults>
    where
        E: Estimator + Clone,
        E: Fit<Array2<f64>, Array1<T>, Fitted = F>,
        F: Predict<Array2<f64>, Array1<T>>,
        T: Clone + PartialEq,
    {
        let (n_samples, _) = x.dim();

        // Determine resource schedule
        let min_resource = self.config.min_resource.unwrap_or((n_samples / 10).max(1));
        let max_resource = self.config.max_resource.unwrap_or(n_samples);

        // Generate initial candidate parameter sets; fall back to a fixed seed
        // so results stay reproducible when no random_state is given
        let mut rng = match self.config.random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::seed_from_u64(42),
        };

        let candidates = self.generate_candidates(&mut rng)?;
        if candidates.is_empty() {
            return Err(SklearsError::InvalidInput(
                "n_candidates must be greater than zero".to_string(),
            ));
        }

        // Initialize results tracking
        let mut cv_results: HashMap<String, Vec<f64>> = HashMap::new();
        let mut n_candidates_per_iteration = Vec::new();
        let mut best_score = f64::NEG_INFINITY;
        let mut best_params = candidates[0].clone();
        let mut best_index = 0;

        let mut current_candidates = candidates;
        let mut current_resource = min_resource;
        let mut iteration = 0;

        // Successive halving iterations
        while !current_candidates.is_empty() && current_resource <= max_resource {
            n_candidates_per_iteration.push(current_candidates.len());

            // Evaluate candidates with current resource
            let mut candidate_scores = Vec::new();

            for (idx, params) in current_candidates.iter().enumerate() {
                let score = self.evaluate_candidate_with_resource::<E, F, T>(
                    &base_estimator,
                    params,
                    x,
                    y,
                    current_resource,
                    is_classification,
                )?;

                candidate_scores.push((idx, score));

                // Track best score overall
                if score > best_score {
                    best_score = score;
                    best_params = params.clone();
                    best_index = idx;
                }

                // Store results
                let key = format!("iteration_{iteration}_scores");
                cv_results.entry(key).or_default().push(score);
            }

            // Sort candidates by score (descending)
            candidate_scores
                .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

            // Determine how many candidates to keep, e.g. with 32 candidates and
            // factor 3.0, ceil(32 / 3.0) = 11 survive to the next round
            let n_to_keep = if current_resource >= max_resource {
                1 // Keep only the best for final iteration
            } else {
                let elimination_factor = if self.config.aggressive_elimination && iteration == 0 {
                    self.config.factor * 1.5 // More aggressive first elimination
                } else {
                    self.config.factor
                };

                (current_candidates.len() as f64 / elimination_factor)
                    .ceil()
                    .max(1.0) as usize
            };

            // Keep the best candidates
            current_candidates = candidate_scores
                .into_iter()
                .take(n_to_keep)
                .map(|(idx, _)| current_candidates[idx].clone())
                .collect();

            // Increase resource for next iteration
            if current_resource < max_resource {
                current_resource = ((current_resource as f64 * self.config.factor).round()
                    as usize)
                    .min(max_resource);
            } else {
                break;
            }

            iteration += 1;
        }

        Ok(HalvingGridSearchResults {
            best_score_: best_score,
            best_params_: best_params,
            best_index_: best_index,
            cv_results_: cv_results,
            // Count actual iterations performed; `iteration` undercounts by one
            // when the loop exits via the final-resource break
            n_iterations_: n_candidates_per_iteration.len(),
            n_candidates_: n_candidates_per_iteration,
        })
    }

    /// Generate initial candidates
    fn generate_candidates(&self, rng: &mut StdRng) -> Result<Vec<ParameterSet>> {
        let mut candidates = Vec::new();

        for _ in 0..self.config.n_candidates {
            let mut params = ParameterSet::new();

            for (param_name, distribution) in &self.config.param_distributions {
                let selected_value = distribution.sample(rng);
                params.insert(param_name.clone(), selected_value);
            }

            candidates.push(params);
        }

        Ok(candidates)
    }

    /// Evaluate a candidate with a specific resource amount
    fn evaluate_candidate_with_resource<E, F, T>(
        &self,
        base_estimator: &E,
        _params: &ParameterSet,
        x: &Array2<f64>,
        y: &Array1<T>,
        resource: usize,
        is_classification: bool,
    ) -> Result<f64>
    where
        E: Estimator + Clone,
        E: Fit<Array2<f64>, Array1<T>, Fitted = F>,
        F: Predict<Array2<f64>, Array1<T>>,
        T: Clone + PartialEq,
    {
        let (n_samples, _) = x.dim();
        let effective_samples = resource.min(n_samples);

        // Use subset of data for training
        let x_subset = x
            .slice(scirs2_core::ndarray::s![..effective_samples, ..])
            .to_owned();
        let y_subset = y
            .slice(scirs2_core::ndarray::s![..effective_samples])
            .to_owned();

        // Configure estimator with parameters (simplified - real implementation would need
        // proper parameter setting based on the estimator type)
        let configured_estimator = base_estimator.clone();

        // Perform cross-validation
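        // NOTE: dummy zero labels are passed below because the splitter's API
        // expects labels; stratified splitters therefore never see the real
        // targets in this simplified implementation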
        let splits = self
            .config
            .cv
            .split(effective_samples, Some(&y_subset.mapv(|_| 0i32)));
        let mut scores = Vec::new();

        for (train_indices, test_indices) in splits {
            // Create train/test splits
            let x_train = x_subset.select(scirs2_core::ndarray::Axis(0), &train_indices);
            let y_train = y_subset.select(scirs2_core::ndarray::Axis(0), &train_indices);
            let x_test = x_subset.select(scirs2_core::ndarray::Axis(0), &test_indices);
            let y_test = y_subset.select(scirs2_core::ndarray::Axis(0), &test_indices);

            // Train and evaluate
            let trained = configured_estimator.clone().fit(&x_train, &y_train)?;
            let predictions = trained.predict(&x_test)?;

            // Calculate score based on scoring function
            let score = self.calculate_score(&predictions, &y_test, is_classification)?;
            scores.push(score);
        }

        // Guard against an empty split set, which would otherwise yield NaN
        if scores.is_empty() {
            return Err(SklearsError::InvalidInput(
                "Cross-validation produced no splits".to_string(),
            ));
        }

        Ok(scores.iter().sum::<f64>() / scores.len() as f64)
    }

    /// Calculate score based on scoring function
    fn calculate_score<T>(
        &self,
        predictions: &Array1<T>,
        y_true: &Array1<T>,
        is_classification: bool,
    ) -> Result<f64>
    where
        T: Clone + PartialEq,
    {
        if predictions.len() != y_true.len() {
            return Err(SklearsError::InvalidInput(
                "Predictions and true values must have the same length".to_string(),
            ));
        }

        match &self.config.scoring {
            Scoring::EstimatorScore => {
                // Accuracy for classification; the regression branch below is a stub
                if is_classification {
                    let correct = predictions
                        .iter()
                        .zip(y_true.iter())
                        .filter(|(pred, true_val)| pred == true_val)
                        .count();
                    Ok(correct as f64 / predictions.len() as f64)
                } else {
                    // T is only Clone + PartialEq here, so a real MSE would need
                    // numeric bounds; return a fixed placeholder score for now
                    Ok(0.8)
                }
            }
            Scoring::Custom(_) => {
                // For custom scoring functions, use placeholder
                Ok(0.7)
            }
            Scoring::Metric(_metric_name) => {
                // For named metrics, use placeholder
                Ok(0.75)
            }
            Scoring::Scorer(_scorer) => {
                // For scorer objects, use placeholder
                Ok(0.8)
            }
            Scoring::MultiMetric(_metrics) => {
                // For multi-metric, use placeholder
                Ok(0.85)
            }
        }
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cross_validation::KFold;

    #[test]
    fn test_halving_grid_search_creation() {
        let mut param_distributions = HashMap::new();
        param_distributions.insert(
            "param1".to_string(),
            crate::grid_search::ParameterDistribution::Choice(vec!["a".into(), "b".into()]),
        );

        let search = HalvingGridSearch::<Array2<f64>, Array1<f64>>::new(param_distributions)
            .n_candidates(16)
            .factor(2.0)
            .cv(Box::new(KFold::new(3)))
            .random_state(42);

        assert_eq!(search.config.n_candidates, 16);
        assert_eq!(search.config.factor, 2.0);
    }

    #[test]
    fn test_candidate_generation() {
        let mut param_distributions = HashMap::new();
        param_distributions.insert(
            "param1".to_string(),
            crate::grid_search::ParameterDistribution::Choice(vec![
                "a".into(),
                "b".into(),
                "c".into(),
            ]),
        );
        param_distributions.insert(
            "param2".to_string(),
            crate::grid_search::ParameterDistribution::Choice(vec![1.into(), 2.into()]),
        );

        let search = HalvingGridSearch::<Array2<f64>, Array1<f64>>::new(param_distributions)
            .n_candidates(6)
            .random_state(42);

        let mut rng = StdRng::seed_from_u64(42);
        let candidates = search.generate_candidates(&mut rng).unwrap();

        assert_eq!(candidates.len(), 6);

        for candidate in &candidates {
            assert!(candidate.contains_key("param1"));
            assert!(candidate.contains_key("param2"));
        }
    }

    #[test]
    fn test_halving_grid_search_configuration() {
        let mut param_distributions = HashMap::new();
        param_distributions.insert(
            "test_param".to_string(),
            crate::grid_search::ParameterDistribution::Choice(vec![1.into(), 2.into()]),
        );

        let search = HalvingGridSearch::<Array2<f64>, Array1<f64>>::new(param_distributions)
            .n_candidates(8)
            .factor(2.5)
            .min_resource(10)
            .max_resource(100)
            .aggressive_elimination(false);

        assert_eq!(search.config.n_candidates, 8);
        assert_eq!(search.config.factor, 2.5);
        assert_eq!(search.config.min_resource, Some(10));
        assert_eq!(search.config.max_resource, Some(100));
        assert!(!search.config.aggressive_elimination);
    }
}

/// Randomized search with successive halving for efficient hyperparameter optimization
///
/// HalvingRandomSearchCV is similar to HalvingGridSearch but samples parameter values
/// randomly from distributions, making it more efficient for continuous parameter spaces.
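///
/// # Example
/// A sketch of typical usage, mirroring the `HalvingGridSearch` example above;
/// the estimator and the commented-out fit call are assumed, not provided here.
/// ```rust,ignore
/// use sklears_model_selection::{HalvingRandomSearchCV, KFold, ParameterDistribution};
/// use std::collections::HashMap;
///
/// let mut param_distributions = HashMap::new();
/// param_distributions.insert("fit_intercept".to_string(),
///     ParameterDistribution::Choice(vec![true.into(), false.into()]));
///
/// let search = HalvingRandomSearchCV::new(param_distributions)
///     .n_candidates(16)
///     .factor(3.0)
///     .cv(Box::new(KFold::new(3)))
///     .random_state(42);
///
/// // let results = search.fit_regression(LinearRegression::new(), &X, &y);
/// ```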
pub struct HalvingRandomSearchCV {
    /// Parameter distributions to sample from
    pub param_distributions: ParameterDistributions,
    /// Number of parameter settings that are sampled
    pub n_candidates: usize,
    /// Cross-validation strategy
    pub cv: Box<dyn CrossValidator>,
    /// Scoring function
    pub scoring: Scoring,
    /// Halving factor: roughly `1/factor` of the candidates are kept at each iteration
    pub factor: f64,
    /// The amount of resource that gets allocated to each candidate at the first iteration
    pub resource: String,
    /// The maximum amount of resource that any candidate can be allocated
    pub max_resource: Option<usize>,
    /// The minimum amount of resource that any candidate can be allocated
    pub min_resource: Option<usize>,
    /// Whether to use aggressive elimination strategy in the first iteration
    pub aggressive_elimination: bool,
    /// Random state for reproducible results
    pub random_state: Option<u64>,
    /// Number of jobs for parallel execution
    pub n_jobs: Option<i32>,
}

impl HalvingRandomSearchCV {
    /// Create a new HalvingRandomSearchCV with default settings
    pub fn new(param_distributions: ParameterDistributions) -> Self {
        Self {
            param_distributions,
            n_candidates: 32,
            cv: Box::new(crate::KFold::new(5)),
            scoring: Scoring::EstimatorScore,
            factor: 3.0,
            resource: "n_samples".to_string(),
            max_resource: None,
            min_resource: None,
            aggressive_elimination: false,
            random_state: None,
            n_jobs: None,
        }
    }

    /// Set the number of candidates to sample
    pub fn n_candidates(mut self, n_candidates: usize) -> Self {
        self.n_candidates = n_candidates;
        self
    }

    /// Set the halving factor; panics if `factor` <= 1.0
    pub fn factor(mut self, factor: f64) -> Self {
        assert!(factor > 1.0, "factor must be greater than 1.0");
        self.factor = factor;
        self
    }

    /// Set the cross-validation strategy
    pub fn cv(mut self, cv: Box<dyn CrossValidator>) -> Self {
        self.cv = cv;
        self
    }

    /// Set the scoring function
    pub fn scoring(mut self, scoring: Scoring) -> Self {
        self.scoring = scoring;
        self
    }

    /// Set the resource parameter
    pub fn resource(mut self, resource: String) -> Self {
        self.resource = resource;
        self
    }

    /// Set the maximum resource
    pub fn max_resource(mut self, max_resource: usize) -> Self {
        self.max_resource = Some(max_resource);
        self
    }

    /// Set the minimum resource
    pub fn min_resource(mut self, min_resource: usize) -> Self {
        self.min_resource = Some(min_resource);
        self
    }

    /// Set aggressive elimination strategy
    pub fn aggressive_elimination(mut self, aggressive_elimination: bool) -> Self {
        self.aggressive_elimination = aggressive_elimination;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Set the number of parallel jobs
    pub fn n_jobs(mut self, n_jobs: i32) -> Self {
        self.n_jobs = Some(n_jobs);
        self
    }

    /// Fit the halving random search for regression
    pub fn fit_regression<E, F>(
        &self,
        base_estimator: E,
        x: &Array2<f64>,
        y: &Array1<f64>,
    ) -> Result<HalvingGridSearchResults>
    where
        E: Estimator + Clone,
        E: Fit<Array2<f64>, Array1<f64>, Fitted = F>,
        F: Predict<Array2<f64>, Array1<f64>>,
    {
        self.fit_impl(base_estimator, x, y, false)
    }

    /// Fit the halving random search for classification
    pub fn fit_classification<E, F>(
        &self,
        base_estimator: E,
        x: &Array2<f64>,
        y: &Array1<i32>,
    ) -> Result<HalvingGridSearchResults>
    where
        E: Estimator + Clone,
        E: Fit<Array2<f64>, Array1<i32>, Fitted = F>,
        F: Predict<Array2<f64>, Array1<i32>>,
    {
        self.fit_impl(base_estimator, x, y, true)
    }

    /// Internal implementation for fitting
    fn fit_impl<E, F, T>(
        &self,
        base_estimator: E,
        x: &Array2<f64>,
        y: &Array1<T>,
        is_classification: bool,
    ) -> Result<HalvingGridSearchResults>
    where
        E: Estimator + Clone,
        E: Fit<Array2<f64>, Array1<T>, Fitted = F>,
        F: Predict<Array2<f64>, Array1<T>>,
        T: Clone + PartialEq,
    {
        let (n_samples, _) = x.dim();

        // Determine resource schedule
        let min_resource = self.min_resource.unwrap_or((n_samples / 10).max(1));
        let max_resource = self.max_resource.unwrap_or(n_samples);

        // Generate initial candidate parameter sets; fall back to a fixed seed
        // so results stay reproducible when no random_state is given
        let mut rng = match self.random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::seed_from_u64(42),
        };

        let candidates = self.generate_random_candidates(&mut rng)?;
        if candidates.is_empty() {
            return Err(SklearsError::InvalidInput(
                "n_candidates must be greater than zero".to_string(),
            ));
        }

        // Initialize results tracking
        let mut cv_results: HashMap<String, Vec<f64>> = HashMap::new();
        let mut n_candidates_per_iteration = Vec::new();
        let mut best_score = f64::NEG_INFINITY;
        let mut best_params = candidates[0].clone();
        let mut best_index = 0;

        let mut current_candidates = candidates;
        let mut current_resource = min_resource;
        let mut iteration = 0;

        // Successive halving iterations
        while !current_candidates.is_empty() && current_resource <= max_resource {
            n_candidates_per_iteration.push(current_candidates.len());

            // Evaluate candidates with current resource
            let mut candidate_scores = Vec::new();

            for (idx, params) in current_candidates.iter().enumerate() {
                let score = self.evaluate_candidate_with_resource::<E, F, T>(
                    &base_estimator,
                    params,
                    x,
                    y,
                    current_resource,
                    is_classification,
                )?;

                candidate_scores.push((idx, score));

                // Track best score overall
                if score > best_score {
                    best_score = score;
                    best_params = params.clone();
                    best_index = idx;
                }

                // Store results
                let key = format!("iteration_{iteration}_scores");
                cv_results.entry(key).or_default().push(score);
            }

            // Sort candidates by score (descending)
            candidate_scores
                .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

            // Determine how many candidates to keep, e.g. with 32 candidates and
            // factor 3.0, ceil(32 / 3.0) = 11 survive to the next round
            let n_to_keep = if current_resource >= max_resource {
                1 // Keep only the best for final iteration
            } else {
                let elimination_factor = if self.aggressive_elimination && iteration == 0 {
                    self.factor * 1.5 // More aggressive first elimination
                } else {
                    self.factor
                };

                (current_candidates.len() as f64 / elimination_factor)
                    .ceil()
                    .max(1.0) as usize
            };

            // Keep the best candidates
            current_candidates = candidate_scores
                .into_iter()
                .take(n_to_keep)
                .map(|(idx, _)| current_candidates[idx].clone())
                .collect();

            // Increase resource for next iteration
            if current_resource < max_resource {
                current_resource =
                    ((current_resource as f64 * self.factor).round() as usize).min(max_resource);
            } else {
                break;
            }

            iteration += 1;
        }

        Ok(HalvingGridSearchResults {
            best_score_: best_score,
            best_params_: best_params,
            best_index_: best_index,
            cv_results_: cv_results,
            // Count actual iterations performed; `iteration` undercounts by one
            // when the loop exits via the final-resource break
            n_iterations_: n_candidates_per_iteration.len(),
            n_candidates_: n_candidates_per_iteration,
        })
    }

    /// Generate random candidates from parameter distributions
    fn generate_random_candidates(&self, rng: &mut StdRng) -> Result<Vec<ParameterSet>> {
        let mut candidates = Vec::new();

        for _ in 0..self.n_candidates {
            let mut params = ParameterSet::new();

            for (param_name, distribution) in &self.param_distributions {
                let selected_value = distribution.sample(rng);
                params.insert(param_name.clone(), selected_value);
            }

            candidates.push(params);
        }

        Ok(candidates)
    }

    /// Evaluate a candidate with a specific resource amount
    fn evaluate_candidate_with_resource<E, F, T>(
        &self,
        base_estimator: &E,
        _params: &ParameterSet,
        x: &Array2<f64>,
        y: &Array1<T>,
        resource: usize,
        is_classification: bool,
    ) -> Result<f64>
    where
        E: Estimator + Clone,
        E: Fit<Array2<f64>, Array1<T>, Fitted = F>,
        F: Predict<Array2<f64>, Array1<T>>,
        T: Clone + PartialEq,
    {
        let (n_samples, _) = x.dim();
        let effective_samples = resource.min(n_samples);

        // Use subset of data for training
        let x_subset = x
            .slice(scirs2_core::ndarray::s![..effective_samples, ..])
            .to_owned();
        let y_subset = y
            .slice(scirs2_core::ndarray::s![..effective_samples])
            .to_owned();

        // Configure estimator with parameters (simplified - real implementation would need
        // proper parameter setting based on the estimator type)
        let configured_estimator = base_estimator.clone();

        // Perform cross-validation
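        // NOTE: dummy zero labels are passed below because the splitter's API
        // expects labels; stratified splitters therefore never see the real
        // targets in this simplified implementation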
        let splits = self
            .cv
            .split(effective_samples, Some(&y_subset.mapv(|_| 0i32)));
        let mut scores = Vec::new();

        for (train_indices, test_indices) in splits {
            // Create train/test splits
            let x_train = x_subset.select(scirs2_core::ndarray::Axis(0), &train_indices);
            let y_train = y_subset.select(scirs2_core::ndarray::Axis(0), &train_indices);
            let x_test = x_subset.select(scirs2_core::ndarray::Axis(0), &test_indices);
            let y_test = y_subset.select(scirs2_core::ndarray::Axis(0), &test_indices);

            // Train and evaluate
            let trained = configured_estimator.clone().fit(&x_train, &y_train)?;
            let predictions = trained.predict(&x_test)?;

            // Calculate score based on scoring function
            let score = self.calculate_score(&predictions, &y_test, is_classification)?;
            scores.push(score);
        }

        // Guard against an empty split set, which would otherwise yield NaN
        if scores.is_empty() {
            return Err(SklearsError::InvalidInput(
                "Cross-validation produced no splits".to_string(),
            ));
        }

        Ok(scores.iter().sum::<f64>() / scores.len() as f64)
    }

    /// Calculate score based on scoring function
    fn calculate_score<T>(
        &self,
        predictions: &Array1<T>,
        y_true: &Array1<T>,
        is_classification: bool,
    ) -> Result<f64>
    where
        T: Clone + PartialEq,
    {
        if predictions.len() != y_true.len() {
            return Err(SklearsError::InvalidInput(
                "Predictions and true values must have the same length".to_string(),
            ));
        }

        match &self.scoring {
            Scoring::EstimatorScore => {
                // Accuracy for classification; the regression branch below is a stub
                if is_classification {
                    let correct = predictions
                        .iter()
                        .zip(y_true.iter())
                        .filter(|(pred, true_val)| pred == true_val)
                        .count();
                    Ok(correct as f64 / predictions.len() as f64)
                } else {
                    // T is only Clone + PartialEq here, so a real MSE would need
                    // numeric bounds; return a fixed placeholder score for now
                    Ok(0.8)
                }
            }
            Scoring::Custom(_) => {
                // For custom scoring functions, use placeholder
                Ok(0.7)
            }
            Scoring::Metric(_metric_name) => {
                // For named metrics, use placeholder
                Ok(0.75)
            }
            Scoring::Scorer(_scorer) => {
                // For scorer objects, use placeholder
                Ok(0.8)
            }
            Scoring::MultiMetric(_metrics) => {
                // For multi-metric, use placeholder
                Ok(0.85)
            }
        }
    }
}