sklears_ensemble/adaboost/
ada_classifier.rs

//! AdaBoost Classifier implementation

use super::helpers::*;
use super::types::*;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::Rng;
use sklears_core::{
    error::{Result, SklearsError},
    prelude::{Fit, Predict},
    traits::{Trained, Untrained},
    types::Float,
};
use std::marker::PhantomData;

impl AdaBoostClassifier<Untrained> {
    /// Create a new AdaBoost classifier
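    ///
    /// # Example
    ///
    /// A minimal usage sketch (not a verified doctest; assumes the crate's
    /// `Fit`/`Predict` traits and `scirs2_core::ndarray` arrays are in scope,
    /// and that `x_train`, `y_train`, and `x_test` are caller-provided data):
    ///
    /// ```ignore
    /// let clf = AdaBoostClassifier::new()
    ///     .n_estimators(50)
    ///     .learning_rate(1.0)
    ///     .random_state(42);
    /// let fitted = clf.fit(&x_train, &y_train)?;
    /// let y_pred = fitted.predict(&x_test)?;
    /// ```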
    pub fn new() -> Self {
        Self {
            config: AdaBoostConfig::default(),
            state: PhantomData,
            estimators_: None,
            estimator_weights_: None,
            estimator_errors_: None,
            classes_: None,
            n_classes_: None,
            n_features_in_: None,
        }
    }

    /// Set the number of boosting iterations
    pub fn n_estimators(mut self, n_estimators: usize) -> Self {
        self.config.n_estimators = n_estimators;
        self
    }

    /// Set the learning rate
    pub fn learning_rate(mut self, learning_rate: Float) -> Self {
        self.config.learning_rate = learning_rate;
        self
    }

    /// Set the random state for reproducible results
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.config.random_state = Some(random_state);
        self
    }

    /// Set the algorithm variant
    pub fn algorithm(mut self, algorithm: AdaBoostAlgorithm) -> Self {
        self.config.algorithm = algorithm;
        self
    }

    /// Use the SAMME.R algorithm variant
    pub fn with_samme_r(mut self) -> Self {
        self.config.algorithm = AdaBoostAlgorithm::SAMMER;
        self
    }

    /// Use the Gentle AdaBoost algorithm variant
    pub fn with_gentle(mut self) -> Self {
        self.config.algorithm = AdaBoostAlgorithm::Gentle;
        self
    }

    /// Find unique classes in the target array
    pub(crate) fn find_classes(y: &Array1<Float>) -> Array1<Float> {
        let mut classes: Vec<Float> = y.iter().cloned().collect();
        classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
        classes.dedup();
        Array1::from_vec(classes)
    }

    /// Calculate sample weights for the current iteration
    fn calculate_sample_weights(
        &self,
        y: &Array1<Float>,
        y_pred: &Array1<Float>,
        sample_weight: &Array1<Float>,
        estimator_weight: Float,
    ) -> Array1<Float> {
        let n_samples = y.len();
        let mut new_weights = sample_weight.clone();

        // Upweight misclassified samples; correctly classified weights are untouched.
        for i in 0..n_samples {
            if y[i] != y_pred[i] {
                new_weights[i] *= estimator_weight.exp();
            }
        }

        // Renormalize to sum to one; fall back to uniform weights if
        // numerical underflow zeroed everything out.
        let weight_sum = new_weights.sum();
        if weight_sum > 0.0 {
            new_weights /= weight_sum;
        } else {
            new_weights.fill(1.0 / n_samples as Float);
        }

        new_weights
    }

    /// Calculate estimator weight based on error
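    ///
    /// For a weighted error `err` on `K` classes, SAMME uses
    /// `alpha = (ln((1 - err) / err) + ln(K - 1)) * learning_rate`; the
    /// Real/M1/M2/Gentle variants build on the binary form
    /// `0.5 * ln((1 - err) / err)` with variant-specific scaling, as
    /// implemented below. Perfect estimators (`err <= 0`) are capped at
    /// `10.0`, and estimators no better than chance (`err >= 1 - 1/K`)
    /// receive zero weight.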
    fn calculate_estimator_weight(&self, error: Float, n_classes: usize) -> Float {
        if error <= 0.0 {
            return 10.0;
        }

        if error >= 1.0 - 1.0 / n_classes as Float {
            return 0.0;
        }

        match self.config.algorithm {
            AdaBoostAlgorithm::SAMME => {
                let alpha = ((1.0 - error) / error).ln() + (n_classes as Float - 1.0).ln();
                alpha * self.config.learning_rate
            }
            AdaBoostAlgorithm::SAMMER => self.config.learning_rate,
            AdaBoostAlgorithm::RealAdaBoost => 0.5 * ((1.0 - error) / error).ln(),
            AdaBoostAlgorithm::M1 => {
                if error >= 0.5 {
                    return 0.0;
                }
                0.5 * ((1.0 - error) / error).ln()
            }
            AdaBoostAlgorithm::M2 => {
                let alpha = 0.5 * ((1.0 - error) / error).ln();
                alpha * self.config.learning_rate
            }
            AdaBoostAlgorithm::Gentle => {
                let alpha = 0.5 * ((1.0 - error) / error).ln();
                alpha * self.config.learning_rate * 0.5
            }
            AdaBoostAlgorithm::Discrete => ((1.0 - error) / error).ln() * self.config.learning_rate,
        }
    }

    /// Resample data according to sample weights
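    ///
    /// Draws `n_samples` indices with replacement via inverse-CDF sampling on
    /// the normalized weights, then patches the draw so every class present
    /// in `y` appears at least once (a stump cannot be fitted on a
    /// single-class sample).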
    fn resample_data(
        &self,
        x: &Array2<Float>,
        y: &Array1<i32>,
        sample_weight: &Array1<Float>,
        rng: &mut impl Rng,
    ) -> Result<(Array2<Float>, Array1<i32>)> {
        let n_samples = x.nrows();

        let weight_sum = sample_weight.sum();
        let normalized_weights = if weight_sum > 0.0 {
            sample_weight / weight_sum
        } else {
            Array1::<Float>::from_elem(n_samples, 1.0 / n_samples as Float)
        };

        let mut cumulative = Array1::<Float>::zeros(n_samples);
        cumulative[0] = normalized_weights[0];
        for i in 1..n_samples {
            cumulative[i] = cumulative[i - 1] + normalized_weights[i];
        }

        let mut selected_indices = Vec::new();
        let unique_classes: std::collections::HashSet<i32> = y.iter().cloned().collect();

        for _ in 0..n_samples {
            let rand_val = rng.random::<Float>() * cumulative[n_samples - 1];
            let idx = cumulative
                .iter()
                .position(|&cum| cum >= rand_val)
                .unwrap_or(n_samples - 1);
            selected_indices.push(idx);
        }

        let resampled_classes: std::collections::HashSet<i32> =
            selected_indices.iter().map(|&idx| y[idx]).collect();

        if resampled_classes.len() < unique_classes.len() {
            let mut replacement_count = 0;
            for &missing_class in unique_classes.difference(&resampled_classes) {
                if let Some(original_idx) = y.iter().position(|&class| class == missing_class) {
                    if replacement_count < selected_indices.len() {
                        selected_indices[replacement_count] = original_idx;
                        replacement_count += 1;
                    }
                }
            }
        }

        let mut x_resampled = Array2::<Float>::zeros(x.dim());
        let mut y_resampled = Array1::<i32>::zeros(y.len());

        for (i, &idx) in selected_indices.iter().enumerate() {
            x_resampled.row_mut(i).assign(&x.row(idx));
            y_resampled[i] = y[idx];
        }

        Ok((x_resampled, y_resampled))
    }

    /// Calculate sample weights for SAMME.R algorithm
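    ///
    /// Implements the SAMME.R update: for each sample, the pseudo-response
    /// `h(x_i) = (K - 1) * ln(p_true) - sum over k != true of ln(p_k)` is
    /// computed from clipped probability estimates, and the weight is scaled
    /// by `exp(-((K - 1) / K) * learning_rate * h(x_i) / K)`, then clamped to
    /// avoid over- and underflow.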
    fn calculate_sample_weights_sammer(
        &self,
        y: &Array1<Float>,
        prob_estimates: &Array2<Float>,
        sample_weight: &Array1<Float>,
        classes: &Array1<Float>,
        estimator_weight: Float,
    ) -> Array1<Float> {
        let n_samples = y.len();
        let n_classes = classes.len();
        let mut new_weights = sample_weight.clone();

        let factor = ((n_classes - 1) as Float / n_classes as Float) * estimator_weight;

        for i in 0..n_samples {
            let true_class = y[i];
            let true_class_idx = classes.iter().position(|&c| c == true_class);

            if let Some(class_idx) = true_class_idx {
                let probs = prob_estimates.row(i);

                let mut h_xi = 0.0;
                for k in 0..n_classes {
                    let p_k = probs[k].clamp(1e-7, 1.0 - 1e-7);

                    if k == class_idx {
                        h_xi += (n_classes as Float - 1.0) * p_k.ln();
                    } else {
                        h_xi -= p_k.ln();
                    }
                }

                let weight_multiplier = (-factor * h_xi / n_classes as Float).exp();
                new_weights[i] *= weight_multiplier;
                new_weights[i] = new_weights[i].clamp(1e-10, 1e3);
            }
        }

        let weight_sum = new_weights.sum();
        if weight_sum > 0.0 {
            new_weights /= weight_sum;
        } else {
            new_weights.fill(1.0 / n_samples as Float);
        }

        new_weights
    }

    /// Calculate sample weights for Real AdaBoost
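    ///
    /// Binary-only update: with labels mapped to `y_i` in `{-1, +1}` and the
    /// estimator's confidence `h(x_i) = 0.5 * ln(p_1 / p_0)`, each weight is
    /// scaled by `exp(-y_i * h(x_i))`, then clamped and renormalized.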
    fn calculate_sample_weights_real_adaboost(
        &self,
        y: &Array1<Float>,
        prob_estimates: &Array2<Float>,
        sample_weight: &Array1<Float>,
        classes: &Array1<Float>,
    ) -> Array1<Float> {
        let n_samples = y.len();
        let mut new_weights = sample_weight.clone();

        for i in 0..n_samples {
            let true_class = y[i];
            let y_i = if true_class == classes[0] { -1.0 } else { 1.0 };

            let p_0 = prob_estimates[[i, 0]].clamp(1e-7, 1.0 - 1e-7);
            let p_1 = prob_estimates[[i, 1]].clamp(1e-7, 1.0 - 1e-7);

            let h_xi = 0.5 * (p_1 / p_0).ln();

            let weight_multiplier = (-y_i * h_xi).exp();
            new_weights[i] *= weight_multiplier;
            new_weights[i] = new_weights[i].clamp(1e-10, 1e3);
        }

        let weight_sum = new_weights.sum();
        if weight_sum > 0.0 {
            new_weights /= weight_sum;
        } else {
            new_weights.fill(1.0 / n_samples as Float);
        }

        new_weights
    }
}

impl Default for AdaBoostClassifier<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl<State> std::fmt::Debug for AdaBoostClassifier<State> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AdaBoostClassifier")
            .field("config", &self.config)
            .field(
                "n_estimators_fitted",
                &self.estimators_.as_ref().map(|e| e.len()),
            )
            .field("n_classes", &self.n_classes_)
            .field("n_features_in", &self.n_features_in_)
            .finish()
    }
}

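// Boosting loop (shared shape across all variants): start from uniform sample
// weights, fit a depth-1 decision stump on a weighted resample, measure the
// weighted training error, convert it into an estimator weight, and reweight
// the samples so the next stump focuses on the current mistakes. The variants
// below differ only in their error measure and weight-update rule.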
impl Fit<Array2<Float>, Array1<Float>> for AdaBoostClassifier<Untrained> {
    type Fitted = AdaBoostClassifier<Trained>;
    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        let (n_samples, n_features) = x.dim();
        if n_samples != y.len() {
            return Err(SklearsError::InvalidInput(
                "Number of samples in X and y must match".to_string(),
            ));
        }
        if n_samples == 0 {
            return Err(SklearsError::InvalidInput(
                "Cannot fit AdaBoost on empty dataset".to_string(),
            ));
        }
        if self.config.n_estimators == 0 {
            return Err(SklearsError::InvalidParameter {
                name: "n_estimators".to_string(),
                reason: "Number of estimators must be positive".to_string(),
            });
        }
        let classes = Self::find_classes(y);
        let n_classes = classes.len();
        if n_classes < 2 {
            return Err(SklearsError::InvalidInput(
                "AdaBoost requires at least 2 classes".to_string(),
            ));
        }
        let mut sample_weight = Array1::<Float>::from_elem(n_samples, 1.0 / n_samples as Float);
        let mut estimators = Vec::new();
        let mut estimator_weights = Vec::new();
        let mut estimator_errors = Vec::new();
        let mut rng = match self.config.random_state {
            Some(seed) => scirs2_core::random::seeded_rng(seed),
            None => scirs2_core::random::seeded_rng(42),
        };
        let y_i32 = convert_labels_to_i32(y);
        for _iteration in 0..self.config.n_estimators {
            let base_estimator = DecisionTreeClassifier::new()
                .criterion(SplitCriterion::Gini)
                .max_depth(1)
                .min_samples_split(2)
                .min_samples_leaf(1);
            match self.config.algorithm {
                AdaBoostAlgorithm::SAMME => {
                    let (x_resampled, y_resampled) =
                        self.resample_data(x, &y_i32, &sample_weight, &mut rng)?;
                    let fitted_estimator = base_estimator.fit(&x_resampled, &y_resampled)?;
                    let y_pred_i32 = fitted_estimator.predict(x)?;
                    let y_pred = convert_predictions_to_float(&y_pred_i32);
                    if y_pred.len() != n_samples {
                        return Err(SklearsError::ShapeMismatch {
                            expected: format!("{} predictions", n_samples),
                            actual: format!("{} predictions", y_pred.len()),
                        });
                    }
                    let mut weighted_error = 0.0;
                    for i in 0..n_samples {
                        if y[i] != y_pred[i] {
                            weighted_error += sample_weight[i];
                        }
                    }
                    if weighted_error >= 0.5 {
                        if estimators.is_empty() {
                            estimators.push(fitted_estimator);
                            estimator_weights.push(0.0);
                            estimator_errors.push(weighted_error);
                        }
                        break;
                    }
                    let estimator_weight =
                        self.calculate_estimator_weight(weighted_error, n_classes);
                    estimators.push(fitted_estimator);
                    estimator_weights.push(estimator_weight);
                    estimator_errors.push(weighted_error);
                    sample_weight =
                        self.calculate_sample_weights(y, &y_pred, &sample_weight, estimator_weight);
                    if weighted_error < 1e-10 {
                        break;
                    }
                }
                AdaBoostAlgorithm::SAMMER => {
                    let (x_resampled, y_resampled) =
                        self.resample_data(x, &y_i32, &sample_weight, &mut rng)?;
                    let fitted_estimator = base_estimator.fit(&x_resampled, &y_resampled)?;
                    let y_pred_i32 = fitted_estimator.predict(x)?;
                    let y_pred = convert_predictions_to_float(&y_pred_i32);
                    if y_pred.len() != n_samples {
                        return Err(SklearsError::ShapeMismatch {
                            expected: format!("{} predictions", n_samples),
                            actual: format!("{} predictions", y_pred.len()),
                        });
                    }
                    let prob_estimates =
                        estimate_probabilities(&y_pred, &classes, n_samples, n_classes);
                    let mut weighted_error = 0.0;
                    for i in 0..n_samples {
                        if y[i] != y_pred[i] {
                            weighted_error += sample_weight[i];
                        }
                    }
                    let estimator_weight = self.config.learning_rate;
                    estimators.push(fitted_estimator);
                    estimator_weights.push(estimator_weight);
                    estimator_errors.push(weighted_error);
                    sample_weight = self.calculate_sample_weights_sammer(
                        y,
                        &prob_estimates,
                        &sample_weight,
                        &classes,
                        estimator_weight,
                    );
                    // Stop early if the error is negligible or no better than chance.
                    if !(1e-10..0.5).contains(&weighted_error) {
                        break;
                    }
                }
                AdaBoostAlgorithm::RealAdaBoost => {
                    // Real AdaBoost is defined here for binary problems only;
                    // reject multi-class input before fitting anything.
                    if n_classes != 2 {
                        return Err(SklearsError::InvalidInput(
                            "Real AdaBoost currently supports only binary classification"
                                .to_string(),
                        ));
                    }
                    let (x_resampled, y_resampled) =
                        self.resample_data(x, &y_i32, &sample_weight, &mut rng)?;
                    let fitted_estimator = base_estimator.fit(&x_resampled, &y_resampled)?;
                    let y_pred_i32 = fitted_estimator.predict(x)?;
                    let y_pred = convert_predictions_to_float(&y_pred_i32);
                    if y_pred.len() != n_samples {
                        return Err(SklearsError::ShapeMismatch {
                            expected: format!("{} predictions", n_samples),
                            actual: format!("{} predictions", y_pred.len()),
                        });
                    }
                    let prob_estimates = estimate_binary_probabilities(&y_pred, &classes);
                    let mut weighted_error = 0.0;
                    for i in 0..n_samples {
                        let correct_class_idx = if y[i] == classes[0] { 0 } else { 1 };
                        let prob_correct = prob_estimates[[i, correct_class_idx]];
                        if prob_correct < 0.5 {
                            weighted_error += sample_weight[i];
                        }
                    }
                    let estimator_weight = if weighted_error > 0.0 && weighted_error < 0.5 {
                        0.5 * ((1.0 - weighted_error) / weighted_error).ln()
                    } else if weighted_error == 0.0 {
                        10.0
                    } else {
                        0.0
                    };
                    estimators.push(fitted_estimator);
                    estimator_weights.push(estimator_weight);
                    estimator_errors.push(weighted_error);
                    sample_weight = self.calculate_sample_weights_real_adaboost(
                        y,
                        &prob_estimates,
                        &sample_weight,
                        &classes,
                    );
                    if !(1e-10..0.5).contains(&weighted_error) {
                        break;
                    }
                }
                AdaBoostAlgorithm::M1 => {
                    let (x_resampled, y_resampled) =
                        self.resample_data(x, &y_i32, &sample_weight, &mut rng)?;
                    let fitted_estimator = base_estimator.fit(&x_resampled, &y_resampled)?;
                    let y_pred_i32 = fitted_estimator.predict(x)?;
                    let y_pred = convert_predictions_to_float(&y_pred_i32);
                    if y_pred.len() != n_samples {
                        return Err(SklearsError::ShapeMismatch {
                            expected: format!("{} predictions", n_samples),
                            actual: format!("{} predictions", y_pred.len()),
                        });
                    }
                    let mut weighted_error = 0.0;
                    for i in 0..n_samples {
                        if y[i] != y_pred[i] {
                            weighted_error += sample_weight[i];
                        }
                    }
                    if weighted_error >= 0.5 {
                        if estimators.is_empty() {
                            return Err(SklearsError::InvalidInput(
                                "AdaBoost.M1 requires strong learners (error < 0.5)".to_string(),
                            ));
                        }
                        break;
                    }
                    let estimator_weight =
                        self.calculate_estimator_weight(weighted_error, n_classes);
                    estimators.push(fitted_estimator);
                    estimator_weights.push(estimator_weight);
                    estimator_errors.push(weighted_error);
                    sample_weight =
                        self.calculate_sample_weights(y, &y_pred, &sample_weight, estimator_weight);
                    if weighted_error < 1e-10 {
                        break;
                    }
                }
                AdaBoostAlgorithm::M2 => {
                    let (x_resampled, y_resampled) =
                        self.resample_data(x, &y_i32, &sample_weight, &mut rng)?;
                    let fitted_estimator = base_estimator.fit(&x_resampled, &y_resampled)?;
                    let y_pred_i32 = fitted_estimator.predict(x)?;
                    let y_pred = convert_predictions_to_float(&y_pred_i32);
                    if y_pred.len() != n_samples {
                        return Err(SklearsError::ShapeMismatch {
                            expected: format!("{} predictions", n_samples),
                            actual: format!("{} predictions", y_pred.len()),
                        });
                    }
                    let prob_estimates =
                        estimate_probabilities(&y_pred, &classes, n_samples, n_classes);
                    // AdaBoost.M2 pseudo-loss: accumulate weight for samples whose
                    // margin (true-class probability minus the average wrong-class
                    // probability) is non-positive.
                    let mut pseudo_loss = 0.0;
                    for i in 0..n_samples {
                        let true_class_idx = classes.iter().position(|&c| c == y[i]).unwrap_or(0);
                        let mut margin = prob_estimates[[i, true_class_idx]];
                        for j in 0..n_classes {
                            if j != true_class_idx {
                                margin -= prob_estimates[[i, j]] / (n_classes - 1) as Float;
                            }
                        }
                        if margin <= 0.0 {
                            pseudo_loss += sample_weight[i] * (1.0 - margin);
                        }
                    }
                    let total_weight: Float = sample_weight.sum();
                    if total_weight > 0.0 {
                        pseudo_loss /= total_weight;
                    }
                    if pseudo_loss >= 0.5 {
                        if estimators.is_empty() {
                            estimators.push(fitted_estimator);
                            estimator_weights.push(0.0);
                            estimator_errors.push(pseudo_loss);
                        }
                        break;
                    }
                    let estimator_weight = self.calculate_estimator_weight(pseudo_loss, n_classes);
                    estimators.push(fitted_estimator);
                    estimator_weights.push(estimator_weight);
                    estimator_errors.push(pseudo_loss);
                    let mut new_weights = sample_weight.clone();
                    for i in 0..n_samples {
                        let true_class_idx = classes.iter().position(|&c| c == y[i]).unwrap_or(0);
                        let confidence = prob_estimates[[i, true_class_idx]];
                        let weight_multiplier = if confidence < 0.5 {
                            (estimator_weight * (1.0 - confidence)).exp()
                        } else {
                            (estimator_weight * confidence).exp().recip()
                        };
                        new_weights[i] *= weight_multiplier;
                    }
                    let weight_sum = new_weights.sum();
                    if weight_sum > 0.0 {
                        new_weights /= weight_sum;
                    } else {
                        new_weights.fill(1.0 / n_samples as Float);
                    }
                    sample_weight = new_weights;
                    if pseudo_loss < 1e-10 {
                        break;
                    }
                }
                AdaBoostAlgorithm::Gentle | AdaBoostAlgorithm::Discrete => {
                    let (x_resampled, y_resampled) =
                        self.resample_data(x, &y_i32, &sample_weight, &mut rng)?;
                    let fitted_estimator = base_estimator.fit(&x_resampled, &y_resampled)?;
                    let y_pred_i32 = fitted_estimator.predict(x)?;
                    let y_pred = convert_predictions_to_float(&y_pred_i32);
                    if y_pred.len() != n_samples {
                        return Err(SklearsError::ShapeMismatch {
                            expected: format!("{} predictions", n_samples),
                            actual: format!("{} predictions", y_pred.len()),
                        });
                    }
                    let mut weighted_error = 0.0;
                    for i in 0..n_samples {
                        if y[i] != y_pred[i] {
                            weighted_error += sample_weight[i];
                        }
                    }
                    if weighted_error >= 0.6 {
                        if estimators.is_empty() {
                            estimators.push(fitted_estimator);
                            estimator_weights.push(0.1);
                            estimator_errors.push(weighted_error);
                        }
                        break;
                    }
                    let estimator_weight =
                        self.calculate_estimator_weight(weighted_error, n_classes);
                    estimators.push(fitted_estimator);
                    estimator_weights.push(estimator_weight);
                    estimator_errors.push(weighted_error);
                    // Symmetric, dampened update: misclassified samples go up,
                    // correctly classified samples go down by the same factor.
                    let mut new_weights = sample_weight.clone();
                    let gentle_factor = 0.5;
                    for i in 0..n_samples {
                        let multiplier = if y[i] != y_pred[i] {
                            (gentle_factor * estimator_weight).exp()
                        } else {
                            (-gentle_factor * estimator_weight).exp()
                        };
                        new_weights[i] *= multiplier;
                    }
                    let weight_sum = new_weights.sum();
                    if weight_sum > 0.0 {
                        new_weights /= weight_sum;
                        // Blend toward the uniform distribution so no single
                        // sample dominates the next iteration.
                        let smoothing = 0.01;
                        let uniform_weight = 1.0 / n_samples as Float;
                        for i in 0..n_samples {
                            new_weights[i] =
                                (1.0 - smoothing) * new_weights[i] + smoothing * uniform_weight;
                        }
                        let smoothed_sum = new_weights.sum();
                        if smoothed_sum > 0.0 {
                            new_weights /= smoothed_sum;
                        }
                    } else {
                        new_weights.fill(1.0 / n_samples as Float);
                    }
                    sample_weight = new_weights;
                    if weighted_error < 1e-10 {
                        break;
                    }
                }
            }
        }
        if estimators.is_empty() {
            return Err(SklearsError::InvalidInput(
                "AdaBoost failed to fit any estimators".to_string(),
            ));
        }
        Ok(AdaBoostClassifier {
            config: self.config,
            state: PhantomData,
            estimators_: Some(estimators),
            estimator_weights_: Some(Array1::from_vec(estimator_weights)),
            estimator_errors_: Some(Array1::from_vec(estimator_errors)),
            classes_: Some(classes),
            n_classes_: Some(n_classes),
            n_features_in_: Some(n_features),
        })
    }
}

impl AdaBoostClassifier<Trained> {
    /// Get the fitted base estimators
    pub fn estimators(&self) -> &[DecisionTreeClassifier<Trained>] {
        self.estimators_
            .as_ref()
            .expect("AdaBoost should be fitted")
    }

    /// Get the weights for each estimator
    pub fn estimator_weights(&self) -> &Array1<Float> {
        self.estimator_weights_
            .as_ref()
            .expect("AdaBoost should be fitted")
    }

    /// Get the errors for each estimator
    pub fn estimator_errors(&self) -> &Array1<Float> {
        self.estimator_errors_
            .as_ref()
            .expect("AdaBoost should be fitted")
    }

    /// Get the classes
    pub fn classes(&self) -> &Array1<Float> {
        self.classes_.as_ref().expect("AdaBoost should be fitted")
    }

    /// Get the number of classes
    pub fn n_classes(&self) -> usize {
        self.n_classes_.expect("AdaBoost should be fitted")
    }

    /// Get the number of input features
    pub fn n_features_in(&self) -> usize {
        self.n_features_in_.expect("AdaBoost should be fitted")
    }

    /// Get feature importances (averaged from all estimators)
    ///
    /// Note: this is currently a placeholder that distributes importance
    /// uniformly across features; it does not yet read importances from the
    /// fitted trees.
    pub fn feature_importances(&self) -> Result<Array1<Float>> {
        let estimators = self.estimators();
        let weights = self.estimator_weights();
        let n_features = self.n_features_in();

        if estimators.is_empty() {
            return Ok(Array1::<Float>::zeros(n_features));
        }

        let mut importances = Array1::<Float>::zeros(n_features);
        let mut total_weight = 0.0;

        for (_estimator, &weight) in estimators.iter().zip(weights.iter()) {
            // For decision stumps, we simulate feature importance.
            // In practice, this would use the actual feature importance from the tree.
            let tree_importances = Array1::<Float>::ones(n_features) / n_features as Float;
            importances += &(tree_importances * weight.abs());
            total_weight += weight.abs();
        }

        if total_weight > 0.0 {
            importances /= total_weight;
        } else {
            // If no valid weights, return uniform importance
            importances.fill(1.0 / n_features as Float);
        }

        Ok(importances)
    }

    /// Predict class probabilities using weighted voting
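    ///
    /// Aggregation depends on the configured variant: SAMME, M1, Gentle, and
    /// Discrete normalize weighted votes; SAMME.R combines weighted
    /// log-probabilities through a softmax; Real AdaBoost (binary only) sums
    /// decision scores and maps them through a sigmoid; M2 normalizes
    /// weighted confidence scores.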
    pub fn predict_proba(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let (n_samples, n_features) = x.dim();

        if n_features != self.n_features_in() {
            return Err(SklearsError::FeatureMismatch {
                expected: self.n_features_in(),
                actual: n_features,
            });
        }

        let estimators = self.estimators();
        let weights = self.estimator_weights();
        let classes = self.classes();
        let n_classes = self.n_classes();

        match self.config.algorithm {
            AdaBoostAlgorithm::SAMME => {
                // Original SAMME: use weighted voting
                let mut class_votes = Array2::<Float>::zeros((n_samples, n_classes));

                // Aggregate predictions from all estimators
                for (estimator, &weight) in estimators.iter().zip(weights.iter()) {
                    let predictions_i32 = estimator.predict(x)?;
                    let predictions = convert_predictions_to_float(&predictions_i32);

                    for (i, &pred) in predictions.iter().enumerate() {
                        if let Some(class_idx) = classes.iter().position(|&c| c == pred) {
                            class_votes[[i, class_idx]] += weight;
                        }
                    }
                }

                // Convert votes to probabilities
                let mut probabilities = Array2::<Float>::zeros((n_samples, n_classes));
                for i in 0..n_samples {
                    let vote_sum = class_votes.row(i).sum();
                    if vote_sum > 0.0 {
                        for j in 0..n_classes {
                            probabilities[[i, j]] = class_votes[[i, j]] / vote_sum;
                        }
                    } else {
                        // Uniform distribution if no votes
                        probabilities.row_mut(i).fill(1.0 / n_classes as Float);
                    }
                }

                Ok(probabilities)
            }
            AdaBoostAlgorithm::SAMMER => {
                // SAMME.R: aggregate weighted log-probabilities (a weighted geometric mean)
                let mut prob_sum = Array2::<Float>::zeros((n_samples, n_classes));

                // Seed every cell with the same constant; a per-row uniform offset
                // cancels out in the softmax normalization below.
                prob_sum.fill(1.0 / n_classes as Float);

                // Aggregate probability estimates from all estimators
                for (estimator, &weight) in estimators.iter().zip(weights.iter()) {
                    let predictions_i32 = estimator.predict(x)?;
                    let predictions = convert_predictions_to_float(&predictions_i32);

                    // Estimate probabilities for this estimator
                    let prob_estimates =
                        estimate_probabilities(&predictions, classes, n_samples, n_classes);

                    // Accumulate in log-space for numerical stability
                    for i in 0..n_samples {
                        for j in 0..n_classes {
                            prob_sum[[i, j]] += weight * prob_estimates[[i, j]].ln();
                        }
                    }
                }

                // Convert back from log-space and normalize
                let mut probabilities = Array2::<Float>::zeros((n_samples, n_classes));
                for i in 0..n_samples {
                    // Find max for numerical stability
                    let max_log = prob_sum
                        .row(i)
                        .iter()
                        .cloned()
                        .fold(f64::NEG_INFINITY, f64::max);

                    // Compute exp and normalize
                    let mut sum = 0.0;
                    for j in 0..n_classes {
                        probabilities[[i, j]] = (prob_sum[[i, j]] - max_log).exp();
                        sum += probabilities[[i, j]];
                    }

                    // Normalize
                    for j in 0..n_classes {
                        probabilities[[i, j]] /= sum;
                    }
                }

                Ok(probabilities)
            }
            AdaBoostAlgorithm::RealAdaBoost => {
                // Real AdaBoost probability aggregation for binary classification
                if n_classes != 2 {
                    return Err(SklearsError::InvalidInput(
                        "Real AdaBoost predict_proba only supports binary classification"
                            .to_string(),
                    ));
                }

                let mut decision_scores = Array1::<Float>::zeros(n_samples);

                // Aggregate decision scores from all estimators
                for (estimator, &weight) in estimators.iter().zip(weights.iter()) {
                    let predictions_i32 = estimator.predict(x)?;
                    let predictions = convert_predictions_to_float(&predictions_i32);

                    // Get probability estimates for Real AdaBoost
                    let prob_estimates = estimate_binary_probabilities(&predictions, classes);

                    for i in 0..n_samples {
                        let p_0 = prob_estimates[[i, 0]].clamp(1e-7, 1.0 - 1e-7);
                        let p_1 = prob_estimates[[i, 1]].clamp(1e-7, 1.0 - 1e-7);

                        // Real AdaBoost decision function: h_t(x) = 0.5 * ln(p_1/p_0)
                        let h_t = 0.5 * (p_1 / p_0).ln();
                        decision_scores[i] += weight * h_t;
                    }
                }

                // Convert decision scores to probabilities using sigmoid
                let mut probabilities = Array2::<Float>::zeros((n_samples, 2));
                for i in 0..n_samples {
                    let sigmoid = 1.0 / (1.0 + (-decision_scores[i]).exp());
                    probabilities[[i, 1]] = sigmoid;
                    probabilities[[i, 0]] = 1.0 - sigmoid;
                }

                Ok(probabilities)
            }
            AdaBoostAlgorithm::M1 => {
                // AdaBoost.M1: similar to SAMME but with strong learner assumption
                let mut class_votes = Array2::<Float>::zeros((n_samples, n_classes));

                // Aggregate predictions from all estimators
                for (estimator, &weight) in estimators.iter().zip(weights.iter()) {
                    let predictions_i32 = estimator.predict(x)?;
                    let predictions = convert_predictions_to_float(&predictions_i32);

                    for (i, &pred) in predictions.iter().enumerate() {
                        if let Some(class_idx) = classes.iter().position(|&c| c == pred) {
                            class_votes[[i, class_idx]] += weight;
                        }
                    }
                }

                // Convert votes to probabilities (same as SAMME)
                let mut probabilities = Array2::<Float>::zeros((n_samples, n_classes));
                for i in 0..n_samples {
                    let vote_sum = class_votes.row(i).sum();
                    if vote_sum > 0.0 {
                        for j in 0..n_classes {
                            probabilities[[i, j]] = class_votes[[i, j]] / vote_sum;
                        }
                    } else {
                        // Uniform distribution if no votes
                        probabilities.row_mut(i).fill(1.0 / n_classes as Float);
                    }
                }

                Ok(probabilities)
            }
            AdaBoostAlgorithm::M2 => {
                // AdaBoost.M2: confidence-rated predictions with probability weighting
                let mut confidence_scores = Array2::<Float>::zeros((n_samples, n_classes));

                // Aggregate confidence-rated predictions from all estimators
                for (estimator, &weight) in estimators.iter().zip(weights.iter()) {
                    let predictions_i32 = estimator.predict(x)?;
                    let predictions = convert_predictions_to_float(&predictions_i32);

                    // Estimate confidences for M2 (using pseudo-probabilities)
                    let prob_estimates =
                        estimate_probabilities(&predictions, classes, n_samples, n_classes);

                    for i in 0..n_samples {
                        for j in 0..n_classes {
                            // M2 confidence: higher confidence for more certain predictions
                            let confidence = prob_estimates[[i, j]];
                            confidence_scores[[i, j]] += weight * confidence;
                        }
                    }
                }

                // Convert confidence scores to probabilities
                let mut probabilities = Array2::<Float>::zeros((n_samples, n_classes));
                for i in 0..n_samples {
                    let score_sum = confidence_scores.row(i).sum();
                    if score_sum > 0.0 {
                        for j in 0..n_classes {
                            probabilities[[i, j]] = confidence_scores[[i, j]] / score_sum;
                        }
                    } else {
                        // Uniform distribution if no scores
                        probabilities.row_mut(i).fill(1.0 / n_classes as Float);
                    }
                }

                Ok(probabilities)
            }
            AdaBoostAlgorithm::Gentle | AdaBoostAlgorithm::Discrete => {
                // Gentle/Discrete AdaBoost: weighted voting like SAMME, followed by
                // smoothing toward the uniform distribution
                let mut class_votes = Array2::<Float>::zeros((n_samples, n_classes));

                // Aggregate predictions from all estimators with gentle weighting
                for (estimator, &weight) in estimators.iter().zip(weights.iter()) {
                    let predictions_i32 = estimator.predict(x)?;
                    let predictions = convert_predictions_to_float(&predictions_i32);

                    for (i, &pred) in predictions.iter().enumerate() {
                        if let Some(class_idx) = classes.iter().position(|&c| c == pred) {
                            class_votes[[i, class_idx]] += weight;
                        }
                    }
                }

                // Convert votes to probabilities with gentle smoothing
                let mut probabilities = Array2::<Float>::zeros((n_samples, n_classes));
                for i in 0..n_samples {
                    let vote_sum = class_votes.row(i).sum();
                    if vote_sum > 0.0 {
                        for j in 0..n_classes {
                            probabilities[[i, j]] = class_votes[[i, j]] / vote_sum;
                        }

                        // Apply gentle smoothing to avoid overconfident predictions
                        let alpha = 0.1; // Smoothing parameter
                        let uniform_prob = 1.0 / n_classes as Float;
                        for j in 0..n_classes {
                            probabilities[[i, j]] =
                                (1.0 - alpha) * probabilities[[i, j]] + alpha * uniform_prob;
                        }
                    } else {
                        // Uniform distribution if no votes
                        probabilities.row_mut(i).fill(1.0 / n_classes as Float);
                    }
                }

                Ok(probabilities)
            }
        }
    }

    /// Get decision function values
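    ///
    /// For binary problems this returns the log-odds `ln(p_1 / p_0)` as an
    /// `(n_samples, 1)` array; for multi-class problems it returns the
    /// element-wise log of the predicted probabilities.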
    pub fn decision_function(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let probas = self.predict_proba(x)?;

        // For binary classification, return log-odds
        if self.n_classes() == 2 {
            let mut decision = Array2::<Float>::zeros((probas.nrows(), 1));
            for i in 0..probas.nrows() {
                let p1 = probas[[i, 1]].max(1e-15); // Avoid log(0)
                let p0 = probas[[i, 0]].max(1e-15);
                decision[[i, 0]] = (p1 / p0).ln();
            }
            Ok(decision)
        } else {
            // For multi-class, return log probabilities
            Ok(probas.mapv(|p| p.max(1e-15).ln()))
        }
    }
}

impl Predict<Array2<Float>, Array1<Float>> for AdaBoostClassifier<Trained> {
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        let probas = self.predict_proba(x)?;
        let classes = self.classes();
        let mut predictions = Array1::<Float>::zeros(probas.nrows());
        // Pick the class with the highest aggregated probability for each row
        for (i, row) in probas.rows().into_iter().enumerate() {
            let max_idx = row
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                .map(|(idx, _)| idx)
                .unwrap_or(0);
            predictions[i] = classes[max_idx];
        }
        Ok(predictions)
    }
}
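
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal end-to-end sketch (assumes `Fit`/`Predict` behave as declared
    // above; the toy data and expected shapes are illustrative, not a fixture
    // from this crate's test suite).
    #[test]
    fn fit_predict_smoke_test() {
        // Two well-separated classes in one feature.
        let x = Array2::from_shape_vec((6, 1), vec![0.0, 0.1, 0.2, 5.0, 5.1, 5.2]).unwrap();
        let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);

        let fitted = AdaBoostClassifier::new()
            .n_estimators(5)
            .random_state(42)
            .fit(&x, &y)
            .unwrap();

        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);

        let probas = fitted.predict_proba(&x).unwrap();
        assert_eq!(probas.dim(), (6, 2));
    }
}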