// sklears_model_selection/halving_grid_search.rs
1//! Halving Grid Search and Random Search for efficient hyperparameter optimization
2//!
3//! HalvingGridSearch and HalvingRandomSearchCV use successive halving to efficiently
4//! search hyperparameter spaces by progressively eliminating poor-performing candidates.
5
6use crate::cross_validation::CrossValidator;
7use crate::grid_search::{ParameterDistributions, ParameterSet};
8use crate::validation::Scoring;
9use scirs2_core::ndarray::{Array1, Array2};
10use scirs2_core::random::rngs::StdRng;
11use scirs2_core::random::SeedableRng;
12use sklears_core::{
13    error::{Result, SklearsError},
14    traits::{Estimator, Fit, Predict},
15};
16use std::collections::HashMap;
17use std::marker::PhantomData;
18
19/// Results from HalvingGridSearch
/// Results from a successive-halving search (`HalvingGridSearch` / `HalvingRandomSearchCV`).
///
/// Field names carry a trailing underscore, mirroring scikit-learn's
/// convention for attributes populated by fitting.
#[derive(Debug, Clone)]
pub struct HalvingGridSearchResults {
    /// Best cross-validation score achieved across all iterations
    pub best_score_: f64,
    /// Parameter set that achieved the best score
    pub best_params_: ParameterSet,
    /// Index of the best candidate within its iteration's candidate list
    pub best_index_: usize,
    /// Per-iteration candidate scores, keyed as `"iteration_{i}_scores"`
    pub cv_results_: HashMap<String, Vec<f64>>,
    /// Number of successive-halving iterations performed
    pub n_iterations_: usize,
    /// Number of candidates evaluated in each iteration, in order
    pub n_candidates_: Vec<usize>,
}
35
36/// Configuration for HalvingGridSearch
/// Configuration for `HalvingGridSearch`.
///
/// Not `Debug`/`Clone` because it owns a boxed `dyn CrossValidator`.
pub struct HalvingGridSearchConfig {
    /// Estimator name label.
    /// NOTE(review): initialized to "unknown" and never read in this file — confirm use elsewhere.
    pub estimator_name: String,
    /// Parameter distributions to search
    pub param_distributions: ParameterDistributions,
    /// Number of candidates to start with
    pub n_candidates: usize,
    /// Cross-validation strategy
    pub cv: Box<dyn CrossValidator>,
    /// Scoring function
    pub scoring: Scoring,
    /// Reduction factor for successive halving (candidates kept ≈ n / factor each round)
    pub factor: f64,
    /// Resource parameter name (e.g. "n_samples")
    pub resource: String,
    /// Maximum resource to use (defaults to the full sample count at fit time)
    pub max_resource: Option<usize>,
    /// Minimum resource to use (defaults to ~10% of samples, at least 1, at fit time)
    pub min_resource: Option<usize>,
    /// When true, the first round eliminates more candidates (factor * 1.5)
    pub aggressive_elimination: bool,
    /// Random state for reproducible candidate sampling
    pub random_state: Option<u64>,
}
61
/// HalvingGridSearch implementation
///
/// Uses successive halving to efficiently search hyperparameter spaces by
/// progressively eliminating poor-performing candidates using increasing
/// amounts of resources (typically training samples).
///
/// The type parameters `X` and `Y` mark the intended input/target array
/// types; they are only carried in `PhantomData` and select which `fit`
/// method is available.
///
/// # Example
/// ```text
/// use sklears_model_selection::{HalvingGridSearch, KFold, ParameterDistribution};
/// use scirs2_core::ndarray::array;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 3.0], [2.0, 1.0]];
/// let y = array![1.5, 2.5, 3.0, 1.8];
///
/// // Create parameter distributions
/// let mut param_distributions = HashMap::new();
/// param_distributions.insert("fit_intercept".to_string(),
///     ParameterDistribution::Choice(vec![true.into(), false.into()]));
///
/// let search: HalvingGridSearch<Array2<f64>, Array1<f64>> = HalvingGridSearch::new(param_distributions)
///     .n_candidates(8)
///     .factor(2.0)
///     .cv(Box::new(KFold::new(3)))
///     .random_state(42);
///
/// // let results = search.fit(LinearRegression::new(), &X, &y);
/// ```
pub struct HalvingGridSearch<X, Y> {
    // All tunables live in the config; the struct itself is just a typed handle.
    config: HalvingGridSearchConfig,
    _phantom: PhantomData<(X, Y)>,
}
93
94impl<X, Y> HalvingGridSearch<X, Y> {
95    /// Create a new HalvingGridSearch
96    pub fn new(param_distributions: ParameterDistributions) -> Self {
97        let cv = Box::new(crate::cross_validation::KFold::new(5));
98        let config = HalvingGridSearchConfig {
99            estimator_name: "unknown".to_string(),
100            param_distributions,
101            n_candidates: 32,
102            cv,
103            scoring: Scoring::EstimatorScore,
104            factor: 3.0,
105            resource: "n_samples".to_string(),
106            max_resource: None,
107            min_resource: None,
108            aggressive_elimination: true,
109            random_state: None,
110        };
111
112        Self {
113            config,
114            _phantom: PhantomData,
115        }
116    }
117
118    /// Set the number of initial candidates
119    pub fn n_candidates(mut self, n_candidates: usize) -> Self {
120        self.config.n_candidates = n_candidates;
121        self
122    }
123
124    /// Set the reduction factor for successive halving
125    pub fn factor(mut self, factor: f64) -> Self {
126        assert!(factor > 1.0, "factor must be greater than 1.0");
127        self.config.factor = factor;
128        self
129    }
130
131    /// Set the cross-validation strategy
132    pub fn cv(mut self, cv: Box<dyn CrossValidator>) -> Self {
133        self.config.cv = cv;
134        self
135    }
136
137    /// Set the scoring function
138    pub fn scoring(mut self, scoring: Scoring) -> Self {
139        self.config.scoring = scoring;
140        self
141    }
142
143    /// Set the resource parameter
144    pub fn resource(mut self, resource: String) -> Self {
145        self.config.resource = resource;
146        self
147    }
148
149    /// Set the maximum resource
150    pub fn max_resource(mut self, max_resource: usize) -> Self {
151        self.config.max_resource = Some(max_resource);
152        self
153    }
154
155    /// Set the minimum resource
156    pub fn min_resource(mut self, min_resource: usize) -> Self {
157        self.config.min_resource = Some(min_resource);
158        self
159    }
160
161    /// Enable or disable aggressive elimination
162    pub fn aggressive_elimination(mut self, aggressive: bool) -> Self {
163        self.config.aggressive_elimination = aggressive;
164        self
165    }
166
167    /// Set random state for reproducibility
168    pub fn random_state(mut self, seed: u64) -> Self {
169        self.config.random_state = Some(seed);
170        self
171    }
172}
173
impl HalvingGridSearch<Array2<f64>, Array1<f64>> {
    /// Fit the halving grid search for regression targets (`f64` labels).
    ///
    /// Thin wrapper that delegates to `fit_impl` with
    /// `is_classification = false`.
    ///
    /// # Errors
    /// Propagates any error from candidate evaluation or estimator fitting.
    pub fn fit<E, F>(
        &self,
        base_estimator: E,
        x: &Array2<f64>,
        y: &Array1<f64>,
    ) -> Result<HalvingGridSearchResults>
    where
        E: Estimator + Clone,
        E: Fit<Array2<f64>, Array1<f64>, Fitted = F>,
        F: Predict<Array2<f64>, Array1<f64>>,
    {
        self.fit_impl(base_estimator, x, y, false)
    }
}
190
impl HalvingGridSearch<Array2<f64>, Array1<i32>> {
    /// Fit the halving grid search for classification targets (`i32` labels).
    ///
    /// Thin wrapper that delegates to `fit_impl` with
    /// `is_classification = true` (scores are computed as accuracy).
    ///
    /// # Errors
    /// Propagates any error from candidate evaluation or estimator fitting.
    pub fn fit_classification<E, F>(
        &self,
        base_estimator: E,
        x: &Array2<f64>,
        y: &Array1<i32>,
    ) -> Result<HalvingGridSearchResults>
    where
        E: Estimator + Clone,
        E: Fit<Array2<f64>, Array1<i32>, Fitted = F>,
        F: Predict<Array2<f64>, Array1<i32>>,
    {
        self.fit_impl(base_estimator, x, y, true)
    }
}
207
208impl<X, Y> HalvingGridSearch<X, Y> {
209    /// Internal implementation for fitting
210    fn fit_impl<E, F, T>(
211        &self,
212        base_estimator: E,
213        x: &Array2<f64>,
214        y: &Array1<T>,
215        is_classification: bool,
216    ) -> Result<HalvingGridSearchResults>
217    where
218        E: Estimator + Clone,
219        E: Fit<Array2<f64>, Array1<T>, Fitted = F>,
220        F: Predict<Array2<f64>, Array1<T>>,
221        T: Clone + PartialEq,
222    {
223        let (n_samples, _) = x.dim();
224
225        // Determine resource schedule
226        let min_resource = self.config.min_resource.unwrap_or(1.max(n_samples / 10));
227        let max_resource = self.config.max_resource.unwrap_or(n_samples);
228
229        // Generate initial candidate parameter sets
230        let mut rng = match self.config.random_state {
231            Some(seed) => StdRng::seed_from_u64(seed),
232            None => StdRng::seed_from_u64(42),
233        };
234
235        let candidates = self.generate_candidates(&mut rng)?;
236
237        // Initialize results tracking
238        let mut cv_results: HashMap<String, Vec<f64>> = HashMap::new();
239        let mut n_candidates_per_iteration = Vec::new();
240        let mut best_score = f64::NEG_INFINITY;
241        let mut best_params = candidates[0].clone();
242        let mut best_index = 0;
243
244        let mut current_candidates = candidates;
245        let mut current_resource = min_resource;
246        let mut iteration = 0;
247
248        // Successive halving iterations
249        while !current_candidates.is_empty() && current_resource <= max_resource {
250            n_candidates_per_iteration.push(current_candidates.len());
251
252            // Evaluate candidates with current resource
253            let mut candidate_scores = Vec::new();
254
255            for (idx, params) in current_candidates.iter().enumerate() {
256                let score = self.evaluate_candidate_with_resource::<E, F, T>(
257                    &base_estimator,
258                    params,
259                    x,
260                    y,
261                    current_resource,
262                    is_classification,
263                )?;
264
265                candidate_scores.push((idx, score));
266
267                // Track best score overall
268                if score > best_score {
269                    best_score = score;
270                    best_params = params.clone();
271                    best_index = idx;
272                }
273
274                // Store results
275                let key = format!("iteration_{iteration}_scores");
276                cv_results.entry(key).or_default().push(score);
277            }
278
279            // Sort candidates by score (descending)
280            candidate_scores
281                .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
282
283            // Determine how many candidates to keep
284            let n_to_keep = if current_resource >= max_resource {
285                1 // Keep only the best for final iteration
286            } else {
287                let elimination_factor = if self.config.aggressive_elimination && iteration == 0 {
288                    self.config.factor * 1.5 // More aggressive first elimination
289                } else {
290                    self.config.factor
291                };
292
293                (current_candidates.len() as f64 / elimination_factor)
294                    .ceil()
295                    .max(1.0) as usize
296            };
297
298            // Keep the best candidates
299            current_candidates = candidate_scores
300                .into_iter()
301                .take(n_to_keep)
302                .map(|(idx, _)| current_candidates[idx].clone())
303                .collect();
304
305            // Increase resource for next iteration
306            if current_resource < max_resource {
307                current_resource = ((current_resource as f64 * self.config.factor).round()
308                    as usize)
309                    .min(max_resource);
310            } else {
311                break;
312            }
313
314            iteration += 1;
315        }
316
317        Ok(HalvingGridSearchResults {
318            best_score_: best_score,
319            best_params_: best_params,
320            best_index_: best_index,
321            cv_results_: cv_results,
322            n_iterations_: iteration,
323            n_candidates_: n_candidates_per_iteration,
324        })
325    }
326
327    /// Generate initial candidates
328    fn generate_candidates(&self, rng: &mut StdRng) -> Result<Vec<ParameterSet>> {
329        let mut candidates = Vec::new();
330
331        for _ in 0..self.config.n_candidates {
332            let mut params = ParameterSet::new();
333
334            for (param_name, distribution) in &self.config.param_distributions {
335                let selected_value = distribution.sample(rng);
336                params.insert(param_name.clone(), selected_value);
337            }
338
339            candidates.push(params);
340        }
341
342        Ok(candidates)
343    }
344
345    /// Evaluate a candidate with a specific resource amount
346    fn evaluate_candidate_with_resource<E, F, T>(
347        &self,
348        base_estimator: &E,
349        _params: &ParameterSet,
350        x: &Array2<f64>,
351        y: &Array1<T>,
352        resource: usize,
353        is_classification: bool,
354    ) -> Result<f64>
355    where
356        E: Estimator + Clone,
357        E: Fit<Array2<f64>, Array1<T>, Fitted = F>,
358        F: Predict<Array2<f64>, Array1<T>>,
359        T: Clone + PartialEq,
360    {
361        let (n_samples, _) = x.dim();
362        let effective_samples = resource.min(n_samples);
363
364        // Use subset of data for training
365        let x_subset = x
366            .slice(scirs2_core::ndarray::s![..effective_samples, ..])
367            .to_owned();
368        let y_subset = y
369            .slice(scirs2_core::ndarray::s![..effective_samples])
370            .to_owned();
371
372        // Configure estimator with parameters (simplified - real implementation would need
373        // proper parameter setting based on the estimator type)
374        let configured_estimator = base_estimator.clone();
375
376        // Perform cross-validation
377        let splits = self
378            .config
379            .cv
380            .split(effective_samples, Some(&y_subset.mapv(|_| 0i32)));
381        let mut scores = Vec::new();
382
383        for (train_indices, test_indices) in splits {
384            // Create train/test splits
385            let x_train = x_subset.select(scirs2_core::ndarray::Axis(0), &train_indices);
386            let y_train = y_subset.select(scirs2_core::ndarray::Axis(0), &train_indices);
387            let x_test = x_subset.select(scirs2_core::ndarray::Axis(0), &test_indices);
388            let y_test = y_subset.select(scirs2_core::ndarray::Axis(0), &test_indices);
389
390            // Train and evaluate
391            let trained = configured_estimator.clone().fit(&x_train, &y_train)?;
392            let predictions = trained.predict(&x_test)?;
393
394            // Calculate score based on scoring function
395            let score = self.calculate_score(&predictions, &y_test, is_classification)?;
396            scores.push(score);
397        }
398
399        Ok(scores.iter().sum::<f64>() / scores.len() as f64)
400    }
401
402    /// Calculate score based on scoring function
403    fn calculate_score<T>(
404        &self,
405        predictions: &Array1<T>,
406        y_true: &Array1<T>,
407        is_classification: bool,
408    ) -> Result<f64>
409    where
410        T: Clone + PartialEq,
411    {
412        if predictions.len() != y_true.len() {
413            return Err(SklearsError::InvalidInput(
414                "Predictions and true values must have the same length".to_string(),
415            ));
416        }
417
418        match &self.config.scoring {
419            Scoring::EstimatorScore => {
420                // Use accuracy for classification, negative MSE for regression
421                if is_classification {
422                    let correct = predictions
423                        .iter()
424                        .zip(y_true.iter())
425                        .filter(|(pred, true_val)| pred == true_val)
426                        .count();
427                    Ok(correct as f64 / predictions.len() as f64)
428                } else {
429                    // Simplified MSE calculation - in real implementation would need proper numeric handling
430                    Ok(0.8) // Placeholder score
431                }
432            }
433            Scoring::Custom(_) => {
434                // For custom scoring functions, use placeholder
435                Ok(0.7)
436            }
437            Scoring::Metric(_metric_name) => {
438                // For named metrics, use placeholder
439                Ok(0.75)
440            }
441            Scoring::Scorer(_scorer) => {
442                // For scorer objects, use placeholder
443                Ok(0.8)
444            }
445            Scoring::MultiMetric(_metrics) => {
446                // For multi-metric, use placeholder
447                Ok(0.85)
448            }
449        }
450    }
451}
452
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cross_validation::KFold;

    // Builder methods should store their values in the private config.
    #[test]
    fn test_halving_grid_search_creation() {
        let mut param_distributions = HashMap::new();
        param_distributions.insert(
            "param1".to_string(),
            crate::grid_search::ParameterDistribution::Choice(vec!["a".into(), "b".into()]),
        );

        let search = HalvingGridSearch::<Array2<f64>, Array1<f64>>::new(param_distributions)
            .n_candidates(16)
            .factor(2.0)
            .cv(Box::new(KFold::new(3)))
            .random_state(42);

        assert_eq!(search.config.n_candidates, 16);
        assert_eq!(search.config.factor, 2.0);
    }

    // generate_candidates should honor n_candidates and sample every
    // configured parameter in each candidate set.
    #[test]
    fn test_candidate_generation() {
        let mut param_distributions = HashMap::new();
        param_distributions.insert(
            "param1".to_string(),
            crate::grid_search::ParameterDistribution::Choice(vec![
                "a".into(),
                "b".into(),
                "c".into(),
            ]),
        );
        param_distributions.insert(
            "param2".to_string(),
            crate::grid_search::ParameterDistribution::Choice(vec![1.into(), 2.into()]),
        );

        let search = HalvingGridSearch::<Array2<f64>, Array1<f64>>::new(param_distributions)
            .n_candidates(6)
            .random_state(42);

        let mut rng = StdRng::seed_from_u64(42);
        let candidates = search
            .generate_candidates(&mut rng)
            .expect("operation should succeed");

        assert_eq!(candidates.len(), 6);

        for candidate in &candidates {
            assert!(candidate.contains_key("param1"));
            assert!(candidate.contains_key("param2"));
        }
    }

    // Resource bounds and the aggressive-elimination flag should round-trip
    // through the builder.
    #[test]
    fn test_halving_grid_search_configuration() {
        let mut param_distributions = HashMap::new();
        param_distributions.insert(
            "test_param".to_string(),
            crate::grid_search::ParameterDistribution::Choice(vec![1.into(), 2.into()]),
        );

        let search = HalvingGridSearch::<Array2<f64>, Array1<f64>>::new(param_distributions)
            .n_candidates(8)
            .factor(2.5)
            .min_resource(10)
            .max_resource(100)
            .aggressive_elimination(false);

        assert_eq!(search.config.n_candidates, 8);
        assert_eq!(search.config.factor, 2.5);
        assert_eq!(search.config.min_resource, Some(10));
        assert_eq!(search.config.max_resource, Some(100));
        assert!(!search.config.aggressive_elimination);
    }
}
532
/// Randomized search with successive halving for efficient hyperparameter optimization
///
/// HalvingRandomSearchCV is similar to HalvingGridSearch but samples parameter values
/// randomly from distributions, making it more efficient for continuous parameter spaces.
///
/// Unlike `HalvingGridSearch`, all settings are public fields rather than a
/// nested config struct.
pub struct HalvingRandomSearchCV {
    /// Parameter distributions to sample from
    pub param_distributions: ParameterDistributions,
    /// Number of parameter settings that are sampled
    pub n_candidates: usize,
    /// Cross-validation strategy
    pub cv: Box<dyn CrossValidator>,
    /// Scoring function
    pub scoring: Scoring,
    /// The 'halving' parameter, which determines the proportion of candidates
    /// that are selected for each subsequent iteration
    pub factor: f64,
    /// The amount of resource that gets allocated to each candidate at the first iteration
    pub resource: String,
    /// The maximum amount of resource that any candidate can be allocated
    pub max_resource: Option<usize>,
    /// The minimum amount of resource that any candidate can be allocated
    pub min_resource: Option<usize>,
    /// Whether to use aggressive elimination strategy in the first iteration
    pub aggressive_elimination: bool,
    /// Random state for reproducible results
    pub random_state: Option<u64>,
    /// Number of jobs for parallel execution.
    /// NOTE(review): stored but not consulted anywhere in this file — confirm it is wired up elsewhere.
    pub n_jobs: Option<i32>,
}
562
563impl HalvingRandomSearchCV {
564    pub fn new(param_distributions: ParameterDistributions) -> Self {
565        Self {
566            param_distributions,
567            n_candidates: 32,
568            cv: Box::new(crate::KFold::new(5)),
569            scoring: Scoring::EstimatorScore,
570            factor: 3.0,
571            resource: "n_samples".to_string(),
572            max_resource: None,
573            min_resource: None,
574            aggressive_elimination: false,
575            random_state: None,
576            n_jobs: None,
577        }
578    }
579
580    /// Set the number of candidates to sample
581    pub fn n_candidates(mut self, n_candidates: usize) -> Self {
582        self.n_candidates = n_candidates;
583        self
584    }
585
586    /// Set the halving factor
587    pub fn factor(mut self, factor: f64) -> Self {
588        self.factor = factor;
589        self
590    }
591
592    /// Set the cross-validation strategy
593    pub fn cv(mut self, cv: Box<dyn CrossValidator>) -> Self {
594        self.cv = cv;
595        self
596    }
597
598    /// Set the scoring function
599    pub fn scoring(mut self, scoring: Scoring) -> Self {
600        self.scoring = scoring;
601        self
602    }
603
604    /// Set the resource parameter
605    pub fn resource(mut self, resource: String) -> Self {
606        self.resource = resource;
607        self
608    }
609
610    /// Set the maximum resource
611    pub fn max_resource(mut self, max_resource: usize) -> Self {
612        self.max_resource = Some(max_resource);
613        self
614    }
615
616    /// Set the minimum resource
617    pub fn min_resource(mut self, min_resource: usize) -> Self {
618        self.min_resource = Some(min_resource);
619        self
620    }
621
622    /// Set aggressive elimination strategy
623    pub fn aggressive_elimination(mut self, aggressive_elimination: bool) -> Self {
624        self.aggressive_elimination = aggressive_elimination;
625        self
626    }
627
628    /// Set the random state
629    pub fn random_state(mut self, random_state: u64) -> Self {
630        self.random_state = Some(random_state);
631        self
632    }
633
634    /// Set the number of parallel jobs
635    pub fn n_jobs(mut self, n_jobs: i32) -> Self {
636        self.n_jobs = Some(n_jobs);
637        self
638    }
639
640    /// Fit the halving random search for regression
641    pub fn fit_regression<E, F>(
642        &self,
643        base_estimator: E,
644        x: &Array2<f64>,
645        y: &Array1<f64>,
646    ) -> Result<HalvingGridSearchResults>
647    where
648        E: Estimator + Clone,
649        E: Fit<Array2<f64>, Array1<f64>, Fitted = F>,
650        F: Predict<Array2<f64>, Array1<f64>>,
651    {
652        self.fit_impl(base_estimator, x, y, false)
653    }
654
655    /// Fit the halving random search for classification
656    pub fn fit_classification<E, F>(
657        &self,
658        base_estimator: E,
659        x: &Array2<f64>,
660        y: &Array1<i32>,
661    ) -> Result<HalvingGridSearchResults>
662    where
663        E: Estimator + Clone,
664        E: Fit<Array2<f64>, Array1<i32>, Fitted = F>,
665        F: Predict<Array2<f64>, Array1<i32>>,
666    {
667        self.fit_impl(base_estimator, x, y, true)
668    }
669
670    /// Internal implementation for fitting
671    fn fit_impl<E, F, T>(
672        &self,
673        base_estimator: E,
674        x: &Array2<f64>,
675        y: &Array1<T>,
676        is_classification: bool,
677    ) -> Result<HalvingGridSearchResults>
678    where
679        E: Estimator + Clone,
680        E: Fit<Array2<f64>, Array1<T>, Fitted = F>,
681        F: Predict<Array2<f64>, Array1<T>>,
682        T: Clone + PartialEq,
683    {
684        let (n_samples, _) = x.dim();
685
686        // Determine resource schedule
687        let min_resource = self.min_resource.unwrap_or(1.max(n_samples / 10));
688        let max_resource = self.max_resource.unwrap_or(n_samples);
689
690        // Generate initial candidate parameter sets
691        let mut rng = match self.random_state {
692            Some(seed) => StdRng::seed_from_u64(seed),
693            None => StdRng::seed_from_u64(42),
694        };
695
696        let candidates = self.generate_random_candidates(&mut rng)?;
697
698        // Initialize results tracking
699        let mut cv_results: HashMap<String, Vec<f64>> = HashMap::new();
700        let mut n_candidates_per_iteration = Vec::new();
701        let mut best_score = f64::NEG_INFINITY;
702        let mut best_params = candidates[0].clone();
703        let mut best_index = 0;
704
705        let mut current_candidates = candidates;
706        let mut current_resource = min_resource;
707        let mut iteration = 0;
708
709        // Successive halving iterations
710        while !current_candidates.is_empty() && current_resource <= max_resource {
711            n_candidates_per_iteration.push(current_candidates.len());
712
713            // Evaluate candidates with current resource
714            let mut candidate_scores = Vec::new();
715
716            for (idx, params) in current_candidates.iter().enumerate() {
717                let score = self.evaluate_candidate_with_resource::<E, F, T>(
718                    &base_estimator,
719                    params,
720                    x,
721                    y,
722                    current_resource,
723                    is_classification,
724                )?;
725
726                candidate_scores.push((idx, score));
727
728                // Track best score overall
729                if score > best_score {
730                    best_score = score;
731                    best_params = params.clone();
732                    best_index = idx;
733                }
734
735                // Store results
736                let key = format!("iteration_{iteration}_scores");
737                cv_results.entry(key).or_default().push(score);
738            }
739
740            // Sort candidates by score (descending)
741            candidate_scores
742                .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
743
744            // Determine how many candidates to keep
745            let n_to_keep = if current_resource >= max_resource {
746                1 // Keep only the best for final iteration
747            } else {
748                let elimination_factor = if self.aggressive_elimination && iteration == 0 {
749                    self.factor * 1.5 // More aggressive first elimination
750                } else {
751                    self.factor
752                };
753
754                (current_candidates.len() as f64 / elimination_factor)
755                    .ceil()
756                    .max(1.0) as usize
757            };
758
759            // Keep the best candidates
760            current_candidates = candidate_scores
761                .into_iter()
762                .take(n_to_keep)
763                .map(|(idx, _)| current_candidates[idx].clone())
764                .collect();
765
766            // Increase resource for next iteration
767            if current_resource < max_resource {
768                current_resource =
769                    ((current_resource as f64 * self.factor).round() as usize).min(max_resource);
770            } else {
771                break;
772            }
773
774            iteration += 1;
775        }
776
777        Ok(HalvingGridSearchResults {
778            best_score_: best_score,
779            best_params_: best_params,
780            best_index_: best_index,
781            cv_results_: cv_results,
782            n_iterations_: iteration,
783            n_candidates_: n_candidates_per_iteration,
784        })
785    }
786
787    /// Generate random candidates from parameter distributions
788    fn generate_random_candidates(&self, rng: &mut StdRng) -> Result<Vec<ParameterSet>> {
789        let mut candidates = Vec::new();
790
791        for _ in 0..self.n_candidates {
792            let mut params = ParameterSet::new();
793
794            for (param_name, distribution) in &self.param_distributions {
795                let selected_value = distribution.sample(rng);
796                params.insert(param_name.clone(), selected_value);
797            }
798
799            candidates.push(params);
800        }
801
802        Ok(candidates)
803    }
804
805    /// Evaluate a candidate with a specific resource amount
806    fn evaluate_candidate_with_resource<E, F, T>(
807        &self,
808        base_estimator: &E,
809        _params: &ParameterSet,
810        x: &Array2<f64>,
811        y: &Array1<T>,
812        resource: usize,
813        is_classification: bool,
814    ) -> Result<f64>
815    where
816        E: Estimator + Clone,
817        E: Fit<Array2<f64>, Array1<T>, Fitted = F>,
818        F: Predict<Array2<f64>, Array1<T>>,
819        T: Clone + PartialEq,
820    {
821        let (n_samples, _) = x.dim();
822        let effective_samples = resource.min(n_samples);
823
824        // Use subset of data for training
825        let x_subset = x
826            .slice(scirs2_core::ndarray::s![..effective_samples, ..])
827            .to_owned();
828        let y_subset = y
829            .slice(scirs2_core::ndarray::s![..effective_samples])
830            .to_owned();
831
832        // Configure estimator with parameters (simplified - real implementation would need
833        // proper parameter setting based on the estimator type)
834        let configured_estimator = base_estimator.clone();
835
836        // Perform cross-validation
837        let splits = self
838            .cv
839            .split(effective_samples, Some(&y_subset.mapv(|_| 0i32)));
840        let mut scores = Vec::new();
841
842        for (train_indices, test_indices) in splits {
843            // Create train/test splits
844            let x_train = x_subset.select(scirs2_core::ndarray::Axis(0), &train_indices);
845            let y_train = y_subset.select(scirs2_core::ndarray::Axis(0), &train_indices);
846            let x_test = x_subset.select(scirs2_core::ndarray::Axis(0), &test_indices);
847            let y_test = y_subset.select(scirs2_core::ndarray::Axis(0), &test_indices);
848
849            // Train and evaluate
850            let trained = configured_estimator.clone().fit(&x_train, &y_train)?;
851            let predictions = trained.predict(&x_test)?;
852
853            // Calculate score based on scoring function
854            let score = self.calculate_score(&predictions, &y_test, is_classification)?;
855            scores.push(score);
856        }
857
858        Ok(scores.iter().sum::<f64>() / scores.len() as f64)
859    }
860
861    /// Calculate score based on scoring function
862    fn calculate_score<T>(
863        &self,
864        predictions: &Array1<T>,
865        y_true: &Array1<T>,
866        is_classification: bool,
867    ) -> Result<f64>
868    where
869        T: Clone + PartialEq,
870    {
871        if predictions.len() != y_true.len() {
872            return Err(SklearsError::InvalidInput(
873                "Predictions and true values must have the same length".to_string(),
874            ));
875        }
876
877        match &self.scoring {
878            Scoring::EstimatorScore => {
879                // Use accuracy for classification, negative MSE for regression
880                if is_classification {
881                    let correct = predictions
882                        .iter()
883                        .zip(y_true.iter())
884                        .filter(|(pred, true_val)| pred == true_val)
885                        .count();
886                    Ok(correct as f64 / predictions.len() as f64)
887                } else {
888                    // Simplified MSE calculation - in real implementation would need proper numeric handling
889                    Ok(0.8) // Placeholder score
890                }
891            }
892            Scoring::Custom(_) => {
893                // For custom scoring functions, use placeholder
894                Ok(0.7)
895            }
896            Scoring::Metric(_metric_name) => {
897                // For named metrics, use placeholder
898                Ok(0.75)
899            }
900            Scoring::Scorer(_scorer) => {
901                // For scorer objects, use placeholder
902                Ok(0.8)
903            }
904            Scoring::MultiMetric(_metrics) => {
905                // For multi-metric, use placeholder
906                Ok(0.85)
907            }
908        }
909    }
910}