sklears_ensemble/voting/
core.rs

1//! Core VotingClassifier implementation with type-safe state management
2
3use scirs2_core::ndarray::{Array1, Array2};
4use sklears_core::{
5    error::{Result, SklearsError},
6    traits::{Fit, Predict, Trained, Untrained},
7    types::Float,
8};
9use std::marker::PhantomData;
10
11use crate::voting::{
12    config::{
13        EnsembleSizeAnalysis, EnsembleSizeRecommendations, VotingClassifierConfig, VotingStrategy,
14    },
15    ensemble::{ensemble_utils, EnsembleMember},
16    simd_ops::{
17        simd_bayesian_averaging, simd_bootstrap_aggregate, simd_confidence_weighted_voting,
18        simd_ensemble_disagreement, simd_entropy_weighted_voting, simd_hard_voting_weighted,
19        simd_soft_voting_weighted, simd_variance_weighted_voting,
20    },
21    strategies::{
22        adaptive_ensemble_voting, consensus_voting, dynamic_weight_adjustment, meta_voting,
23        rank_based_voting, temperature_scaled_voting, uncertainty_aware_voting,
24    },
25};
26
/// Voting Classifier with type-safe state management
///
/// A voting classifier is an ensemble meta-algorithm that fits several base
/// classifiers, each on the whole dataset. It then aggregates the individual
/// predictions to form a final prediction.
///
/// The `State` type parameter (`Untrained` by default) is a compile-time
/// marker only: it is carried in a `PhantomData` field and stores no data.
///
/// # Examples
///
/// ```rust
/// use sklears_ensemble::voting::{VotingClassifier, VotingStrategy};
/// use sklears_ensemble::voting::config::VotingClassifierConfig;
///
/// // Create voting classifier with configuration
/// let config = VotingClassifierConfig {
///     voting: VotingStrategy::Soft,
///     weights: Some(vec![1.0, 2.0, 1.5]),
///     confidence_weighting: true,
///     confidence_threshold: 0.8,
///     ..Default::default()
/// };
/// let classifier = VotingClassifier::new(config);
/// assert_eq!(classifier.estimators().len(), 0);
/// ```
pub struct VotingClassifier<State = Untrained> {
    /// Voting strategy and its tuning parameters.
    config: VotingClassifierConfig,
    /// Ensemble members, in the order they were added.
    estimators_: Vec<Box<dyn EnsembleMember + Send + Sync>>,
    /// Class labels; `None` on a freshly constructed classifier.
    classes_: Option<Array1<Float>>,
    /// Number of input features; `None` on a freshly constructed classifier.
    n_features_in_: Option<usize>,
    /// Zero-sized marker tying the value to its `State`.
    state: PhantomData<State>,
}
57
/// Type alias for a trained voting classifier, i.e. `VotingClassifier<Trained>`.
pub type TrainedVotingClassifier = VotingClassifier<Trained>;
60
61impl VotingClassifier<Untrained> {
62    /// Create a new untrained voting classifier
63    pub fn new(config: VotingClassifierConfig) -> Self {
64        Self {
65            config,
66            estimators_: Vec::new(),
67            classes_: None,
68            n_features_in_: None,
69            state: PhantomData,
70        }
71    }
72
73    /// Create a builder for configuring the voting classifier
74    pub fn builder() -> VotingClassifierBuilder {
75        VotingClassifierBuilder::new()
76    }
77
78    /// Add an estimator to the ensemble
79    pub fn add_estimator(&mut self, estimator: Box<dyn EnsembleMember + Send + Sync>) {
80        self.estimators_.push(estimator);
81    }
82
83    /// Get the current configuration
84    pub fn config(&self) -> &VotingClassifierConfig {
85        &self.config
86    }
87
88    /// Get the estimators in the ensemble
89    pub fn estimators(&self) -> &[Box<dyn EnsembleMember + Send + Sync>] {
90        &self.estimators_
91    }
92
93    /// Optimize ensemble size based on training data
94    pub fn optimize_ensemble_size(
95        &self,
96        x: &Array2<Float>,
97        y: &Array1<Float>,
98    ) -> Result<EnsembleSizeRecommendations> {
99        let n_samples = x.nrows();
100        let n_features = x.ncols();
101
102        // Simple heuristics for ensemble size recommendations
103        let min_size = if n_features > 100 { 5 } else { 3 };
104        let max_size = (n_samples / 10).min(50).max(10);
105        let sweet_spot = (n_features / 5).max(5).min(20);
106        let diminishing_returns_threshold = sweet_spot + (sweet_spot / 2);
107
108        Ok(EnsembleSizeRecommendations {
109            min_size,
110            max_size,
111            sweet_spot,
112            diminishing_returns_threshold,
113        })
114    }
115
116    /// Analyze ensemble size performance characteristics
117    pub fn analyze_ensemble_size(
118        &self,
119        x: &Array2<Float>,
120        y: &Array1<Float>,
121    ) -> Result<EnsembleSizeAnalysis> {
122        let n_samples = x.nrows();
123        let n_features = x.ncols();
124
125        // Generate synthetic performance and diversity curves
126        let sizes: Vec<usize> = (1..=20).collect();
127        let mut performance_curve = Array1::zeros(sizes.len());
128        let mut diversity_curve = Array1::zeros(sizes.len());
129
130        for (i, &size) in sizes.iter().enumerate() {
131            // Simple model for performance curve: logarithmic growth with plateau
132            let perf = 0.5 + 0.45 * (1.0 - (-0.3 * size as Float).exp());
133            performance_curve[i] = perf;
134
135            // Simple model for diversity curve: increase then plateau
136            let div = 0.1 + 0.7 * (1.0 - (-0.2 * size as Float).exp());
137            diversity_curve[i] = div;
138        }
139
140        let optimal_size = 8;
141        let performance_plateau_size = 15;
142        let diversity_saturation_size = 12;
143
144        Ok(EnsembleSizeAnalysis {
145            performance_curve,
146            diversity_curve,
147            optimal_size,
148            performance_plateau_size,
149            diversity_saturation_size,
150        })
151    }
152}
153
154impl VotingClassifier<Trained> {
155    /// Get the classes discovered during training
156    pub fn classes(&self) -> &Array1<Float> {
157        self.classes_.as_ref().unwrap()
158    }
159
160    /// Get the number of features seen during training
161    pub fn n_features_in(&self) -> usize {
162        self.n_features_in_.unwrap()
163    }
164
165    /// Get the estimators in the ensemble
166    pub fn estimators(&self) -> &[Box<dyn EnsembleMember + Send + Sync>] {
167        &self.estimators_
168    }
169
170    /// Make predictions with confidence scores
171    pub fn predict_with_confidence(
172        &self,
173        x: &Array2<Float>,
174    ) -> Result<(Array1<Float>, Array1<Float>)> {
175        let predictions = self.predict(x)?;
176
177        // Calculate confidence based on ensemble agreement
178        let n_samples = x.nrows();
179        let mut confidence = Array1::ones(n_samples);
180
181        if self.estimators_.len() > 1 {
182            // Collect all predictions
183            let mut all_predictions = Vec::new();
184            for estimator in &self.estimators_ {
185                if estimator.is_fitted() {
186                    match estimator.predict(x) {
187                        Ok(pred) => all_predictions.push(pred),
188                        Err(_) => continue,
189                    }
190                }
191            }
192
193            // Calculate ensemble disagreement as inverse confidence
194            if !all_predictions.is_empty() {
195                let mut disagreement = Array1::zeros(n_samples);
196                if simd_ensemble_disagreement(&all_predictions, &mut disagreement).is_ok() {
197                    for i in 0..n_samples {
198                        confidence[i] = 1.0 / (1.0 + disagreement[i]);
199                    }
200                }
201            }
202        }
203
204        Ok((predictions, confidence))
205    }
206
207    /// Make predictions with confidence weighting
208    pub fn predict_with_confidence_weighting(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
209        if x.ncols() != self.n_features_in() {
210            return Err(SklearsError::FeatureMismatch {
211                expected: self.n_features_in(),
212                actual: x.ncols(),
213            });
214        }
215
216        // Apply confidence weighting to all estimator predictions
217        let mut all_predictions = Vec::new();
218        let mut confidence_weights = Vec::new();
219
220        for estimator in &self.estimators_ {
221            if estimator.is_fitted() {
222                match estimator.predict(x) {
223                    Ok(pred) => {
224                        all_predictions.push(pred);
225                        confidence_weights.push(estimator.confidence());
226                    }
227                    Err(_) => continue,
228                }
229            }
230        }
231
232        if all_predictions.is_empty() {
233            return Err(SklearsError::InvalidOperation(
234                "No fitted estimators available".to_string(),
235            ));
236        }
237
238        // Use weighted voting based on confidence scores
239        let n_samples = x.nrows();
240        let mut result = Array1::zeros(n_samples);
241
242        for sample_idx in 0..n_samples {
243            let mut weighted_sum = 0.0;
244            let mut weight_sum = 0.0;
245
246            for (pred_idx, prediction) in all_predictions.iter().enumerate() {
247                let weight = confidence_weights[pred_idx];
248                weighted_sum += prediction[sample_idx] * weight;
249                weight_sum += weight;
250            }
251
252            result[sample_idx] = if weight_sum > 1e-8 {
253                weighted_sum / weight_sum
254            } else {
255                all_predictions
256                    .iter()
257                    .map(|pred| pred[sample_idx])
258                    .sum::<Float>()
259                    / all_predictions.len() as Float
260            };
261        }
262
263        Ok(result)
264    }
265
266    /// Get confidence scores for individual estimators
267    pub fn estimator_confidence_scores(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
268        let mut confidence_scores = Array1::zeros(self.estimators_.len());
269
270        for (i, estimator) in self.estimators_.iter().enumerate() {
271            confidence_scores[i] = estimator.confidence();
272        }
273
274        Ok(confidence_scores)
275    }
276
277    /// Make probability predictions (if supported by estimators)
278    pub fn predict_proba(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
279        if x.ncols() != self.n_features_in() {
280            return Err(SklearsError::FeatureMismatch {
281                expected: self.n_features_in(),
282                actual: x.ncols(),
283            });
284        }
285
286        let mut all_probabilities = Vec::new();
287        let mut weights = Vec::new();
288
289        for estimator in &self.estimators_ {
290            if estimator.is_fitted() && estimator.supports_proba() {
291                match estimator.predict_proba(x) {
292                    Ok(proba) => {
293                        all_probabilities.push(proba);
294                        weights.push(estimator.weight());
295                    }
296                    Err(_) => continue,
297                }
298            }
299        }
300
301        if all_probabilities.is_empty() {
302            return Err(SklearsError::InvalidOperation(
303                "No estimators support probability predictions".to_string(),
304            ));
305        }
306
307        // Use the configured voting strategy for probability aggregation
308        match self.config.voting {
309            VotingStrategy::Soft | VotingStrategy::Weighted => {
310                Ok(simd_soft_voting_weighted(&all_probabilities, &weights))
311            }
312            VotingStrategy::EntropyWeighted => Ok(simd_entropy_weighted_voting(
313                &all_probabilities,
314                self.config.entropy_weight_factor as f32,
315            )),
316            VotingStrategy::VarianceWeighted => Ok(simd_variance_weighted_voting(
317                &all_probabilities,
318                self.config.variance_weight_factor as f32,
319            )),
320            VotingStrategy::ConfidenceWeighted => Ok(simd_confidence_weighted_voting(
321                &all_probabilities,
322                self.config.confidence_threshold as f32,
323            )),
324            VotingStrategy::BayesianAveraging => {
325                // Use weights as model evidences
326                Ok(simd_bayesian_averaging(&all_probabilities, &weights))
327            }
328            VotingStrategy::TemperatureScaled => {
329                temperature_scaled_voting(&all_probabilities, self.config.temperature as f32)
330            }
331            _ => {
332                // Fall back to simple soft voting
333                Ok(simd_soft_voting_weighted(&all_probabilities, &weights))
334            }
335        }
336    }
337
338    /// Update ensemble weights dynamically based on performance
339    pub fn update_weights_dynamically(&mut self, recent_performances: &[Float]) -> Result<()> {
340        if recent_performances.len() != self.estimators_.len() {
341            return Err(SklearsError::InvalidParameter {
342                name: "recent_performances".to_string(),
343                reason: "Performance array length must match number of estimators".to_string(),
344            });
345        }
346
347        let current_weights: Vec<Float> = self.estimators_.iter().map(|e| e.weight()).collect();
348
349        let new_weights = dynamic_weight_adjustment(
350            &current_weights,
351            recent_performances,
352            self.config.weight_adjustment_rate as f32,
353        )?;
354
355        for (estimator, &new_weight) in self.estimators_.iter_mut().zip(new_weights.iter()) {
356            estimator.set_weight(new_weight);
357        }
358
359        Ok(())
360    }
361
362    /// Get ensemble size recommendations
363    pub fn get_ensemble_size_recommendations(&self) -> EnsembleSizeRecommendations {
364        let current_size = self.estimators_.len();
365
366        EnsembleSizeRecommendations {
367            min_size: 3,
368            max_size: current_size * 2,
369            sweet_spot: (current_size + 5).min(15),
370            diminishing_returns_threshold: current_size + 10,
371        }
372    }
373}
374
375// Default implementation for untrained classifier
376impl Default for VotingClassifier<Untrained> {
377    fn default() -> Self {
378        Self::new(VotingClassifierConfig::default())
379    }
380}
381
382// Debug implementations
383impl std::fmt::Debug for VotingClassifier<Untrained> {
384    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
385        f.debug_struct("VotingClassifier<Untrained>")
386            .field("config", &self.config)
387            .field("n_estimators", &self.estimators_.len())
388            .finish()
389    }
390}
391
392impl std::fmt::Debug for VotingClassifier<Trained> {
393    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
394        f.debug_struct("VotingClassifier<Trained>")
395            .field("config", &self.config)
396            .field("n_estimators", &self.estimators_.len())
397            .field("classes", &self.classes_)
398            .field("n_features_in", &self.n_features_in_)
399            .finish()
400    }
401}
402
403// Implement Predict trait for trained classifier
404impl Predict<Array2<Float>, Array1<Float>> for VotingClassifier<Trained> {
405    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
406        if x.ncols() != self.n_features_in() {
407            return Err(SklearsError::FeatureMismatch {
408                expected: self.n_features_in(),
409                actual: x.ncols(),
410            });
411        }
412
413        if self.estimators_.is_empty() {
414            return Err(SklearsError::InvalidOperation(
415                "No estimators in ensemble".to_string(),
416            ));
417        }
418
419        // Apply the configured voting strategy
420        match self.config.voting {
421            VotingStrategy::Hard => self.hard_voting(x),
422            VotingStrategy::Soft => self.soft_voting(x),
423            VotingStrategy::Weighted => self.weighted_voting(x),
424            VotingStrategy::ConfidenceWeighted => self.confidence_weighted_voting(x),
425            VotingStrategy::BayesianAveraging => self.bayesian_averaging(x),
426            VotingStrategy::RankBased => self.rank_based_voting(x),
427            VotingStrategy::MetaVoting => self.meta_voting(x),
428            VotingStrategy::DynamicWeightAdjustment => self.dynamic_voting(x),
429            VotingStrategy::UncertaintyAware => self.uncertainty_aware_voting(x),
430            VotingStrategy::ConsensusBased => self.consensus_voting(x),
431            VotingStrategy::EntropyWeighted => self.entropy_weighted_voting(x),
432            VotingStrategy::VarianceWeighted => self.variance_weighted_voting(x),
433            VotingStrategy::BootstrapAggregation => self.bootstrap_voting(x),
434            VotingStrategy::TemperatureScaled => self.temperature_scaled_voting(x),
435            VotingStrategy::AdaptiveEnsemble => self.adaptive_voting(x),
436        }
437    }
438}
439
440// Private implementation methods for different voting strategies
441impl VotingClassifier<Trained> {
442    fn hard_voting(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
443        let mut all_predictions = Vec::new();
444        let mut weights = Vec::new();
445
446        for estimator in &self.estimators_ {
447            if estimator.is_fitted() {
448                match estimator.predict(x) {
449                    Ok(pred) => {
450                        all_predictions.push(pred);
451                        weights.push(estimator.weight());
452                    }
453                    Err(_) => continue,
454                }
455            }
456        }
457
458        if all_predictions.is_empty() {
459            return Err(SklearsError::InvalidOperation(
460                "No fitted estimators available".to_string(),
461            ));
462        }
463
464        let classes = self.classes();
465        Ok(simd_hard_voting_weighted(
466            &all_predictions,
467            &weights,
468            classes.as_slice().unwrap(),
469        ))
470    }
471
472    fn soft_voting(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
473        let probabilities = self.predict_proba(x)?;
474        let n_samples = probabilities.nrows();
475        let mut result = Array1::zeros(n_samples);
476
477        // Convert probabilities to class predictions
478        for i in 0..n_samples {
479            let mut max_prob = probabilities[[i, 0]];
480            let mut best_class = 0;
481
482            for j in 1..probabilities.ncols() {
483                if probabilities[[i, j]] > max_prob {
484                    max_prob = probabilities[[i, j]];
485                    best_class = j;
486                }
487            }
488
489            result[i] = best_class as Float;
490        }
491
492        Ok(result)
493    }
494
495    fn weighted_voting(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
496        // Use configured weights if available, otherwise use estimator weights
497        let weights = if let Some(ref config_weights) = self.config.weights {
498            config_weights.clone()
499        } else {
500            self.estimators_.iter().map(|e| e.weight()).collect()
501        };
502
503        let mut all_predictions = Vec::new();
504        for estimator in &self.estimators_ {
505            if estimator.is_fitted() {
506                match estimator.predict(x) {
507                    Ok(pred) => all_predictions.push(pred),
508                    Err(_) => continue,
509                }
510            }
511        }
512
513        let classes = self.classes();
514        Ok(simd_hard_voting_weighted(
515            &all_predictions,
516            &weights,
517            classes.as_slice().unwrap(),
518        ))
519    }
520
521    fn confidence_weighted_voting(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
522        self.predict_with_confidence_weighting(x)
523    }
524
525    fn bayesian_averaging(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
526        let probabilities = self.predict_proba(x)?;
527        let n_samples = probabilities.nrows();
528        let mut result = Array1::zeros(n_samples);
529
530        for i in 0..n_samples {
531            let mut max_prob = probabilities[[i, 0]];
532            let mut best_class = 0;
533
534            for j in 1..probabilities.ncols() {
535                if probabilities[[i, j]] > max_prob {
536                    max_prob = probabilities[[i, j]];
537                    best_class = j;
538                }
539            }
540
541            result[i] = best_class as Float;
542        }
543
544        Ok(result)
545    }
546
547    fn rank_based_voting(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
548        let mut all_probabilities = Vec::new();
549
550        for estimator in &self.estimators_ {
551            if estimator.is_fitted() && estimator.supports_proba() {
552                match estimator.predict_proba(x) {
553                    Ok(proba) => all_probabilities.push(proba),
554                    Err(_) => continue,
555                }
556            }
557        }
558
559        if all_probabilities.is_empty() {
560            return self.hard_voting(x);
561        }
562
563        rank_based_voting(&all_probabilities)
564    }
565
566    fn meta_voting(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
567        // For simplicity, use uniform meta-weights
568        let n_samples = x.nrows();
569        let n_estimators = self.estimators_.len();
570        let meta_weights = Array2::ones((n_samples, n_estimators)) / n_estimators as Float;
571
572        let mut all_predictions = Vec::new();
573        for estimator in &self.estimators_ {
574            if estimator.is_fitted() {
575                match estimator.predict(x) {
576                    Ok(pred) => all_predictions.push(pred),
577                    Err(_) => continue,
578                }
579            }
580        }
581
582        meta_voting(&all_predictions, &meta_weights)
583    }
584
585    fn dynamic_voting(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
586        // Use recent performance as dynamic weights
587        let performances: Vec<Float> = self.estimators_.iter().map(|e| e.performance()).collect();
588
589        let mut all_predictions = Vec::new();
590        for estimator in &self.estimators_ {
591            if estimator.is_fitted() {
592                match estimator.predict(x) {
593                    Ok(pred) => all_predictions.push(pred),
594                    Err(_) => continue,
595                }
596            }
597        }
598
599        let classes = self.classes();
600        Ok(simd_hard_voting_weighted(
601            &all_predictions,
602            &performances,
603            classes.as_slice().unwrap(),
604        ))
605    }
606
607    fn uncertainty_aware_voting(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
608        let mut all_predictions = Vec::new();
609        let mut all_uncertainties = Vec::new();
610
611        for estimator in &self.estimators_ {
612            if estimator.is_fitted() {
613                match (estimator.predict(x), estimator.uncertainty(x)) {
614                    (Ok(pred), Ok(unc)) => {
615                        all_predictions.push(pred);
616                        all_uncertainties.push(unc);
617                    }
618                    _ => continue,
619                }
620            }
621        }
622
623        if all_predictions.is_empty() {
624            return self.hard_voting(x);
625        }
626
627        uncertainty_aware_voting(
628            &all_predictions,
629            &all_uncertainties,
630            self.config.confidence_threshold as f32,
631        )
632    }
633
634    fn consensus_voting(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
635        let mut all_predictions = Vec::new();
636
637        for estimator in &self.estimators_ {
638            if estimator.is_fitted() {
639                match estimator.predict(x) {
640                    Ok(pred) => all_predictions.push(pred),
641                    Err(_) => continue,
642                }
643            }
644        }
645
646        consensus_voting(&all_predictions, self.config.consensus_threshold as f32)
647    }
648
649    fn entropy_weighted_voting(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
650        let probabilities = self.predict_proba(x)?;
651        let n_samples = probabilities.nrows();
652        let mut result = Array1::zeros(n_samples);
653
654        for i in 0..n_samples {
655            let mut max_prob = probabilities[[i, 0]];
656            let mut best_class = 0;
657
658            for j in 1..probabilities.ncols() {
659                if probabilities[[i, j]] > max_prob {
660                    max_prob = probabilities[[i, j]];
661                    best_class = j;
662                }
663            }
664
665            result[i] = best_class as Float;
666        }
667
668        Ok(result)
669    }
670
671    fn variance_weighted_voting(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
672        let probabilities = self.predict_proba(x)?;
673        let n_samples = probabilities.nrows();
674        let mut result = Array1::zeros(n_samples);
675
676        for i in 0..n_samples {
677            let mut max_prob = probabilities[[i, 0]];
678            let mut best_class = 0;
679
680            for j in 1..probabilities.ncols() {
681                if probabilities[[i, j]] > max_prob {
682                    max_prob = probabilities[[i, j]];
683                    best_class = j;
684                }
685            }
686
687            result[i] = best_class as Float;
688        }
689
690        Ok(result)
691    }
692
693    fn bootstrap_voting(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
694        let mut all_predictions = Vec::new();
695
696        for estimator in &self.estimators_ {
697            if estimator.is_fitted() {
698                match estimator.predict(x) {
699                    Ok(pred) => all_predictions.push(pred),
700                    Err(_) => continue,
701                }
702            }
703        }
704
705        if all_predictions.is_empty() {
706            return Err(SklearsError::InvalidOperation(
707                "No fitted estimators available".to_string(),
708            ));
709        }
710
711        let mut result = Array1::zeros(x.nrows());
712        simd_bootstrap_aggregate(
713            &all_predictions,
714            self.config.n_bootstrap_samples,
715            &mut result,
716        )?;
717        Ok(result)
718    }
719
720    fn temperature_scaled_voting(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
721        let probabilities = self.predict_proba(x)?;
722        let n_samples = probabilities.nrows();
723        let mut result = Array1::zeros(n_samples);
724
725        for i in 0..n_samples {
726            let mut max_prob = probabilities[[i, 0]];
727            let mut best_class = 0;
728
729            for j in 1..probabilities.ncols() {
730                if probabilities[[i, j]] > max_prob {
731                    max_prob = probabilities[[i, j]];
732                    best_class = j;
733                }
734            }
735
736            result[i] = best_class as Float;
737        }
738
739        Ok(result)
740    }
741
742    fn adaptive_voting(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
743        let mut all_predictions = Vec::new();
744        let mut all_probabilities = Vec::new();
745
746        for estimator in &self.estimators_ {
747            if estimator.is_fitted() {
748                match estimator.predict(x) {
749                    Ok(pred) => all_predictions.push(pred),
750                    Err(_) => continue,
751                }
752
753                if estimator.supports_proba() {
754                    if let Ok(proba) = estimator.predict_proba(x) {
755                        all_probabilities.push(proba);
756                    }
757                }
758            }
759        }
760
761        let diversity =
762            ensemble_utils::calculate_ensemble_diversity(&self.estimators_, x).unwrap_or(0.5);
763        let performance_history: Vec<f32> = self
764            .estimators_
765            .iter()
766            .map(|e| e.performance() as f32)
767            .collect();
768
769        let probabilities = if all_probabilities.is_empty() {
770            None
771        } else {
772            Some(&all_probabilities[..])
773        };
774
775        adaptive_ensemble_voting(
776            &all_predictions,
777            probabilities,
778            diversity as f32,
779            &performance_history,
780        )
781    }
782}
783
784// Implement Fit trait for VotingClassifier
785impl Fit<Array2<Float>, Array1<Float>> for VotingClassifier<Untrained> {
786    type Fitted = VotingClassifier<Trained>;
787
788    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
789        if x.nrows() != y.len() {
790            return Err(SklearsError::ShapeMismatch {
791                expected: format!("{} samples", x.nrows()),
792                actual: format!("{} labels", y.len()),
793            });
794        }
795
796        // Discover unique classes
797        let mut classes: Vec<Float> = y.iter().cloned().collect();
798        classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
799        classes.dedup();
800        let classes_array = Array1::from_vec(classes);
801
802        Ok(VotingClassifier {
803            config: self.config,
804            estimators_: self.estimators_,
805            classes_: Some(classes_array),
806            n_features_in_: Some(x.ncols()),
807            state: PhantomData,
808        })
809    }
810}
811
812/// Builder for VotingClassifier
813#[derive(Debug)]
814pub struct VotingClassifierBuilder {
815    config: VotingClassifierConfig,
816}
817
818impl VotingClassifierBuilder {
819    pub fn new() -> Self {
820        Self {
821            config: VotingClassifierConfig::default(),
822        }
823    }
824
825    pub fn voting(mut self, voting: VotingStrategy) -> Self {
826        self.config.voting = voting;
827        self
828    }
829
830    pub fn weights(mut self, weights: Vec<Float>) -> Self {
831        self.config.weights = Some(weights);
832        self
833    }
834
835    pub fn confidence_weighting(mut self, enable: bool) -> Self {
836        self.config.confidence_weighting = enable;
837        self
838    }
839
840    pub fn confidence_threshold(mut self, threshold: Float) -> Self {
841        self.config.confidence_threshold = threshold;
842        self
843    }
844
845    pub fn min_confidence_weight(mut self, weight: Float) -> Self {
846        self.config.min_confidence_weight = weight;
847        self
848    }
849
850    pub fn enable_uncertainty(mut self, enable: bool) -> Self {
851        self.config.enable_uncertainty = enable;
852        self
853    }
854
855    pub fn temperature(mut self, temp: Float) -> Self {
856        self.config.temperature = temp;
857        self
858    }
859
860    pub fn meta_regularization(mut self, reg: Float) -> Self {
861        self.config.meta_regularization = reg;
862        self
863    }
864
865    pub fn n_bootstrap_samples(mut self, n: usize) -> Self {
866        self.config.n_bootstrap_samples = n;
867        self
868    }
869
870    pub fn consensus_threshold(mut self, threshold: Float) -> Self {
871        self.config.consensus_threshold = threshold;
872        self
873    }
874
875    pub fn entropy_weight_factor(mut self, factor: Float) -> Self {
876        self.config.entropy_weight_factor = factor;
877        self
878    }
879
880    pub fn variance_weight_factor(mut self, factor: Float) -> Self {
881        self.config.variance_weight_factor = factor;
882        self
883    }
884
885    pub fn weight_adjustment_rate(mut self, rate: Float) -> Self {
886        self.config.weight_adjustment_rate = rate;
887        self
888    }
889
890    // Convenience methods for common configurations
891    pub fn confidence_weighted() -> Self {
892        Self::new()
893            .voting(VotingStrategy::ConfidenceWeighted)
894            .confidence_weighting(true)
895    }
896
897    pub fn bayesian_averaging() -> Self {
898        Self::new().voting(VotingStrategy::BayesianAveraging)
899    }
900
901    pub fn meta_voting() -> Self {
902        Self::new().voting(VotingStrategy::MetaVoting)
903    }
904
905    pub fn uncertainty_aware() -> Self {
906        Self::new()
907            .voting(VotingStrategy::UncertaintyAware)
908            .enable_uncertainty(true)
909    }
910
911    pub fn entropy_weighted() -> Self {
912        Self::new()
913            .voting(VotingStrategy::EntropyWeighted)
914            .entropy_weight_factor(1.0)
915    }
916
917    pub fn variance_weighted() -> Self {
918        Self::new()
919            .voting(VotingStrategy::VarianceWeighted)
920            .variance_weight_factor(1.0)
921    }
922
923    pub fn consensus_based() -> Self {
924        Self::new()
925            .voting(VotingStrategy::ConsensusBased)
926            .consensus_threshold(0.7)
927    }
928
929    pub fn build(self) -> VotingClassifier<Untrained> {
930        VotingClassifier::new(self.config)
931    }
932}
933
934impl Default for VotingClassifierBuilder {
935    fn default() -> Self {
936        Self::new()
937    }
938}