use super::helpers::*;
use super::types::*;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::Rng;
use sklears_core::{
    error::{Result, SklearsError},
    prelude::{Fit, Predict},
    traits::{Trained, Untrained},
    types::Float,
};
use std::marker::PhantomData;

use super::types::AdaBoostClassifier;
impl AdaBoostClassifier<Untrained> {
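    /// Creates a new untrained AdaBoost classifier with the default
    /// configuration.
    ///
    /// A usage sketch with this file's builder methods and the `Fit`/`Predict`
    /// impls below (variable names are illustrative):
    ///
    /// ```ignore
    /// let model = AdaBoostClassifier::new()
    ///     .n_estimators(100)
    ///     .learning_rate(0.5)
    ///     .with_samme_r()
    ///     .random_state(42);
    /// let fitted = model.fit(&x_train, &y_train)?;
    /// let predictions = fitted.predict(&x_test)?;
    /// ```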
    pub fn new() -> Self {
        Self {
            config: AdaBoostConfig::default(),
            state: PhantomData,
            estimators_: None,
            estimator_weights_: None,
            estimator_errors_: None,
            classes_: None,
            n_classes_: None,
            n_features_in_: None,
        }
    }

    /// Sets the number of boosting rounds (weak learners) to fit.
    pub fn n_estimators(mut self, n_estimators: usize) -> Self {
        self.config.n_estimators = n_estimators;
        self
    }

    /// Sets the learning rate that shrinks each estimator's contribution.
    pub fn learning_rate(mut self, learning_rate: Float) -> Self {
        self.config.learning_rate = learning_rate;
        self
    }

    /// Sets the seed for the resampling RNG, making fits reproducible.
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.config.random_state = Some(random_state);
        self
    }

    /// Selects the boosting algorithm variant.
    pub fn algorithm(mut self, algorithm: AdaBoostAlgorithm) -> Self {
        self.config.algorithm = algorithm;
        self
    }

    /// Shorthand for `algorithm(AdaBoostAlgorithm::SAMMER)`.
    pub fn with_samme_r(mut self) -> Self {
        self.config.algorithm = AdaBoostAlgorithm::SAMMER;
        self
    }

    /// Shorthand for `algorithm(AdaBoostAlgorithm::Gentle)`.
    pub fn with_gentle(mut self) -> Self {
        self.config.algorithm = AdaBoostAlgorithm::Gentle;
        self
    }

    /// Returns the sorted, deduplicated class labels present in `y`.
    ///
    /// Note: the sort panics if `y` contains NaN labels.
    pub(crate) fn find_classes(y: &Array1<Float>) -> Array1<Float> {
        let mut classes: Vec<Float> = y.iter().cloned().collect();
        classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
        classes.dedup();
        Array1::from_vec(classes)
    }

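    /// Reweights samples after a boosting round using the discrete AdaBoost
    /// rule: each misclassified sample is scaled by `exp(alpha)` and the
    /// weight vector is renormalized to sum to one. Schematically:
    ///
    /// ```text
    /// w_i <- w_i * exp(alpha * 1[y_i != h(x_i)]) / Z
    /// ```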
    fn calculate_sample_weights(
        &self,
        y: &Array1<Float>,
        y_pred: &Array1<Float>,
        sample_weight: &Array1<Float>,
        estimator_weight: Float,
    ) -> Array1<Float> {
        let n_samples = y.len();
        let mut new_weights = sample_weight.clone();

        // Upweight only the misclassified samples.
        for i in 0..n_samples {
            if y[i] != y_pred[i] {
                new_weights[i] *= estimator_weight.exp();
            }
        }

        // Renormalize; fall back to uniform weights if everything underflowed.
        let weight_sum = new_weights.sum();
        if weight_sum > 0.0 {
            new_weights /= weight_sum;
        } else {
            new_weights.fill(1.0 / n_samples as Float);
        }

        new_weights
    }

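    /// Computes the estimator weight `alpha` from the weighted error.
    ///
    /// Perfect estimators get a capped weight of 10.0 and estimators no
    /// better than random chance (`error >= 1 - 1/K`) get weight 0.0.
    /// Otherwise the formula depends on the variant, e.g. SAMME uses the
    /// multi-class correction `ln((1 - e)/e) + ln(K - 1)` while the binary
    /// variants use `0.5 * ln((1 - e)/e)`.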
    fn calculate_estimator_weight(&self, error: Float, n_classes: usize) -> Float {
        // A perfect estimator would produce an infinite weight; cap it.
        if error <= 0.0 {
            return 10.0;
        }

        // No better than random guessing among K classes: contribute nothing.
        if error >= 1.0 - 1.0 / n_classes as Float {
            return 0.0;
        }

        match self.config.algorithm {
            AdaBoostAlgorithm::SAMME => {
                let alpha = ((1.0 - error) / error).ln() + (n_classes as Float - 1.0).ln();
                alpha * self.config.learning_rate
            }
            AdaBoostAlgorithm::SAMMER => self.config.learning_rate,
            AdaBoostAlgorithm::RealAdaBoost => 0.5 * ((1.0 - error) / error).ln(),
            AdaBoostAlgorithm::M1 => {
                // AdaBoost.M1 requires better-than-coin-flip learners.
                if error >= 0.5 {
                    return 0.0;
                }
                0.5 * ((1.0 - error) / error).ln()
            }
            AdaBoostAlgorithm::M2 => {
                let alpha = 0.5 * ((1.0 - error) / error).ln();
                alpha * self.config.learning_rate
            }
            AdaBoostAlgorithm::Gentle => {
                // Gentle AdaBoost damps the update by an extra factor of 0.5.
                let alpha = 0.5 * ((1.0 - error) / error).ln();
                alpha * self.config.learning_rate * 0.5
            }
            AdaBoostAlgorithm::Discrete => ((1.0 - error) / error).ln() * self.config.learning_rate,
        }
    }

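    /// Draws a bootstrap sample of `(x, y)` proportional to `sample_weight`
    /// via inverse-CDF sampling, then patches the draw so every class present
    /// in `y` also appears in the resample (a weak learner cannot be fit on a
    /// single-class sample).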
    fn resample_data(
        &self,
        x: &Array2<Float>,
        y: &Array1<i32>,
        sample_weight: &Array1<Float>,
        rng: &mut impl Rng,
    ) -> Result<(Array2<Float>, Array1<i32>)> {
        let n_samples = x.nrows();

        // Normalize weights; fall back to uniform if they sum to zero.
        let weight_sum = sample_weight.sum();
        let normalized_weights = if weight_sum > 0.0 {
            sample_weight / weight_sum
        } else {
            Array1::<Float>::from_elem(n_samples, 1.0 / n_samples as Float)
        };

        // Build the cumulative distribution for inverse-CDF sampling.
        let mut cumulative = Array1::<Float>::zeros(n_samples);
        cumulative[0] = normalized_weights[0];
        for i in 1..n_samples {
            cumulative[i] = cumulative[i - 1] + normalized_weights[i];
        }

        let mut selected_indices = Vec::new();
        let unique_classes: std::collections::HashSet<i32> = y.iter().cloned().collect();

        for _ in 0..n_samples {
            let rand_val = rng.random::<Float>() * cumulative[n_samples - 1];
            let idx = cumulative
                .iter()
                .position(|&cum| cum >= rand_val)
                .unwrap_or(n_samples - 1);
            selected_indices.push(idx);
        }

        // Ensure every original class is represented in the resample by
        // swapping in one representative per missing class.
        let resampled_classes: std::collections::HashSet<i32> =
            selected_indices.iter().map(|&idx| y[idx]).collect();

        if resampled_classes.len() < unique_classes.len() {
            let mut replacement_count = 0;
            for &missing_class in unique_classes.difference(&resampled_classes) {
                if let Some(original_idx) = y.iter().position(|&class| class == missing_class) {
                    if replacement_count < selected_indices.len() {
                        selected_indices[replacement_count] = original_idx;
                        replacement_count += 1;
                    }
                }
            }
        }

        let mut x_resampled = Array2::<Float>::zeros(x.dim());
        let mut y_resampled = Array1::<i32>::zeros(y.len());

        for (i, &idx) in selected_indices.iter().enumerate() {
            x_resampled.row_mut(i).assign(&x.row(idx));
            y_resampled[i] = y[idx];
        }

        Ok((x_resampled, y_resampled))
    }

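    /// SAMME.R-style reweighting from class probability estimates: each
    /// sample's weight is scaled by `exp(-factor * h(x_i) / K)` with
    /// `factor = ((K - 1)/K) * alpha`, where `h(x_i)` is a log-probability
    /// score that rewards mass placed on the true class. Weights are clamped
    /// to `[1e-10, 1e3]` for numerical stability and then renormalized.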
    fn calculate_sample_weights_sammer(
        &self,
        y: &Array1<Float>,
        prob_estimates: &Array2<Float>,
        sample_weight: &Array1<Float>,
        classes: &Array1<Float>,
        estimator_weight: Float,
    ) -> Array1<Float> {
        let n_samples = y.len();
        let n_classes = classes.len();
        let mut new_weights = sample_weight.clone();

        let factor = ((n_classes - 1) as Float / n_classes as Float) * estimator_weight;

        for i in 0..n_samples {
            let true_class = y[i];
            let true_class_idx = classes.iter().position(|&c| c == true_class);

            if let Some(class_idx) = true_class_idx {
                let probs = prob_estimates.row(i);

                // Log-probability score: positive when mass sits on the true
                // class, negative when it leaks to the other classes.
                let mut h_xi = 0.0;
                for k in 0..n_classes {
                    // Clamp to avoid ln(0).
                    let p_k = probs[k].clamp(1e-7, 1.0 - 1e-7);

                    if k == class_idx {
                        h_xi += (n_classes as Float - 1.0) * p_k.ln();
                    } else {
                        h_xi -= p_k.ln();
                    }
                }

                let weight_multiplier = (-factor * h_xi / n_classes as Float).exp();
                new_weights[i] *= weight_multiplier;
                new_weights[i] = new_weights[i].clamp(1e-10, 1e3);
            }
        }

        let weight_sum = new_weights.sum();
        if weight_sum > 0.0 {
            new_weights /= weight_sum;
        } else {
            new_weights.fill(1.0 / n_samples as Float);
        }

        new_weights
    }

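    /// Real AdaBoost (binary) reweighting: with labels mapped to
    /// `y_i in {-1, +1}` and the half-log-odds score
    /// `h(x_i) = 0.5 * ln(p_1 / p_0)`, each weight is scaled by
    /// `exp(-y_i * h(x_i))`, clamped, and renormalized.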
    fn calculate_sample_weights_real_adaboost(
        &self,
        y: &Array1<Float>,
        prob_estimates: &Array2<Float>,
        sample_weight: &Array1<Float>,
        classes: &Array1<Float>,
    ) -> Array1<Float> {
        let n_samples = y.len();
        let mut new_weights = sample_weight.clone();

        for i in 0..n_samples {
            // Map the first class to -1 and the second to +1.
            let true_class = y[i];
            let y_i = if true_class == classes[0] { -1.0 } else { 1.0 };

            let p_0 = prob_estimates[[i, 0]].clamp(1e-7, 1.0 - 1e-7);
            let p_1 = prob_estimates[[i, 1]].clamp(1e-7, 1.0 - 1e-7);

            // Half log-odds score of the positive class.
            let h_xi = 0.5 * (p_1 / p_0).ln();

            let weight_multiplier = (-y_i * h_xi).exp();
            new_weights[i] *= weight_multiplier;
            new_weights[i] = new_weights[i].clamp(1e-10, 1e3);
        }

        let weight_sum = new_weights.sum();
        if weight_sum > 0.0 {
            new_weights /= weight_sum;
        } else {
            new_weights.fill(1.0 / n_samples as Float);
        }

        new_weights
    }
}

impl Default for AdaBoostClassifier<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl<State> std::fmt::Debug for AdaBoostClassifier<State> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AdaBoostClassifier")
            .field("config", &self.config)
            .field(
                "n_estimators_fitted",
                &self.estimators_.as_ref().map(|e| e.len()),
            )
            .field("n_classes", &self.n_classes_)
            .field("n_features_in", &self.n_features_in_)
            .finish()
    }
}

impl Fit<Array2<Float>, Array1<Float>> for AdaBoostClassifier<Untrained> {
    type Fitted = AdaBoostClassifier<Trained>;

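    /// Fits the ensemble: validates the inputs, initializes uniform sample
    /// weights, then repeatedly (1) resamples the training set by weight,
    /// (2) fits a depth-1 decision tree (a stump), (3) measures its weighted
    /// error, (4) assigns it a weight `alpha`, and (5) reweights the samples.
    /// Boosting stops early when the weak learner becomes too weak or the
    /// training error effectively reaches zero.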
    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        let (n_samples, n_features) = x.dim();
        if n_samples != y.len() {
            return Err(SklearsError::InvalidInput(
                "Number of samples in X and y must match".to_string(),
            ));
        }
        if n_samples == 0 {
            return Err(SklearsError::InvalidInput(
                "Cannot fit AdaBoost on empty dataset".to_string(),
            ));
        }
        if self.config.n_estimators == 0 {
            return Err(SklearsError::InvalidParameter {
                name: "n_estimators".to_string(),
                reason: "Number of estimators must be positive".to_string(),
            });
        }

        let classes = Self::find_classes(y);
        let n_classes = classes.len();
        if n_classes < 2 {
            return Err(SklearsError::InvalidInput(
                "AdaBoost requires at least 2 classes".to_string(),
            ));
        }

        // Start from uniform sample weights.
        let mut sample_weight = Array1::<Float>::from_elem(n_samples, 1.0 / n_samples as Float);
        let mut estimators = Vec::new();
        let mut estimator_weights = Vec::new();
        let mut estimator_errors = Vec::new();

        // Without an explicit random_state, fall back to a fixed default seed
        // so that fits remain deterministic.
        let mut rng = match self.config.random_state {
            Some(seed) => scirs2_core::random::seeded_rng(seed),
            None => scirs2_core::random::seeded_rng(42),
        };
        let y_i32 = convert_labels_to_i32(y);

        for _iteration in 0..self.config.n_estimators {
            // The weak learner is a decision stump (depth-1 tree).
            let base_estimator = DecisionTreeClassifier::new()
                .criterion(SplitCriterion::Gini)
                .max_depth(1)
                .min_samples_split(2)
                .min_samples_leaf(1);

            match self.config.algorithm {
                AdaBoostAlgorithm::SAMME => {
                    let (x_resampled, y_resampled) =
                        self.resample_data(x, &y_i32, &sample_weight, &mut rng)?;
                    let fitted_estimator = base_estimator.fit(&x_resampled, &y_resampled)?;
                    let y_pred_i32 = fitted_estimator.predict(x)?;
                    let y_pred = convert_predictions_to_float(&y_pred_i32);
                    if y_pred.len() != n_samples {
                        return Err(SklearsError::ShapeMismatch {
                            expected: format!("{} predictions", n_samples),
                            actual: format!("{} predictions", y_pred.len()),
                        });
                    }

                    // Weighted misclassification error on the full training set.
                    let mut weighted_error = 0.0;
                    for i in 0..n_samples {
                        if y[i] != y_pred[i] {
                            weighted_error += sample_weight[i];
                        }
                    }

                    // The stump is no better than chance: stop boosting, but
                    // keep at least one estimator in the ensemble.
                    if weighted_error >= 0.5 {
                        if estimators.is_empty() {
                            estimators.push(fitted_estimator);
                            estimator_weights.push(0.0);
                            estimator_errors.push(weighted_error);
                        }
                        break;
                    }
                    let estimator_weight =
                        self.calculate_estimator_weight(weighted_error, n_classes);
                    estimators.push(fitted_estimator);
                    estimator_weights.push(estimator_weight);
                    estimator_errors.push(weighted_error);
                    sample_weight =
                        self.calculate_sample_weights(y, &y_pred, &sample_weight, estimator_weight);

                    // Training error is effectively zero: nothing left to learn.
                    if weighted_error < 1e-10 {
                        break;
                    }
                }
                AdaBoostAlgorithm::SAMMER => {
                    let (x_resampled, y_resampled) =
                        self.resample_data(x, &y_i32, &sample_weight, &mut rng)?;
                    let fitted_estimator = base_estimator.fit(&x_resampled, &y_resampled)?;
                    let y_pred_i32 = fitted_estimator.predict(x)?;
                    let y_pred = convert_predictions_to_float(&y_pred_i32);
                    if y_pred.len() != n_samples {
                        return Err(SklearsError::ShapeMismatch {
                            expected: format!("{} predictions", n_samples),
                            actual: format!("{} predictions", y_pred.len()),
                        });
                    }
                    let prob_estimates =
                        estimate_probabilities(&y_pred, &classes, n_samples, n_classes);
                    let mut weighted_error = 0.0;
                    for i in 0..n_samples {
                        if y[i] != y_pred[i] {
                            weighted_error += sample_weight[i];
                        }
                    }

                    // SAMME.R works from probability estimates rather than
                    // alpha, so every estimator gets the same weight.
                    let estimator_weight = self.config.learning_rate;
                    estimators.push(fitted_estimator);
                    estimator_weights.push(estimator_weight);
                    estimator_errors.push(weighted_error);
                    sample_weight = self.calculate_sample_weights_sammer(
                        y,
                        &prob_estimates,
                        &sample_weight,
                        &classes,
                        estimator_weight,
                    );

                    // Stop if the error is effectively zero or at least as bad
                    // as random guessing.
                    if !(1e-10..0.5).contains(&weighted_error) {
                        break;
                    }
                }
                AdaBoostAlgorithm::RealAdaBoost => {
                    // Real AdaBoost is formulated for the two-class case;
                    // reject multiclass input before doing any work.
                    if n_classes != 2 {
                        return Err(SklearsError::InvalidInput(
                            "Real AdaBoost currently supports only binary classification"
                                .to_string(),
                        ));
                    }
                    let (x_resampled, y_resampled) =
                        self.resample_data(x, &y_i32, &sample_weight, &mut rng)?;
                    let fitted_estimator = base_estimator.fit(&x_resampled, &y_resampled)?;
                    let y_pred_i32 = fitted_estimator.predict(x)?;
                    let y_pred = convert_predictions_to_float(&y_pred_i32);
                    if y_pred.len() != n_samples {
                        return Err(SklearsError::ShapeMismatch {
                            expected: format!("{} predictions", n_samples),
                            actual: format!("{} predictions", y_pred.len()),
                        });
                    }
                    let prob_estimates = estimate_binary_probabilities(&y_pred, &classes);

                    // Count a sample as an error when the estimated probability
                    // of its true class is below one half.
                    let mut weighted_error = 0.0;
                    for i in 0..n_samples {
                        let correct_class_idx = if y[i] == classes[0] { 0 } else { 1 };
                        let prob_correct = prob_estimates[[i, correct_class_idx]];
                        if prob_correct < 0.5 {
                            weighted_error += sample_weight[i];
                        }
                    }
                    let estimator_weight = if weighted_error > 0.0 && weighted_error < 0.5 {
                        0.5 * ((1.0 - weighted_error) / weighted_error).ln()
                    } else if weighted_error == 0.0 {
                        // Perfect estimator: use the same capped weight as
                        // calculate_estimator_weight.
                        10.0
                    } else {
                        0.0
                    };
                    estimators.push(fitted_estimator);
                    estimator_weights.push(estimator_weight);
                    estimator_errors.push(weighted_error);
                    sample_weight = self.calculate_sample_weights_real_adaboost(
                        y,
                        &prob_estimates,
                        &sample_weight,
                        &classes,
                    );

                    // Stop on an effectively perfect or a chance-level round.
                    if !(1e-10..0.5).contains(&weighted_error) {
                        break;
                    }
                }
                AdaBoostAlgorithm::M1 => {
                    let (x_resampled, y_resampled) =
                        self.resample_data(x, &y_i32, &sample_weight, &mut rng)?;
                    let fitted_estimator = base_estimator.fit(&x_resampled, &y_resampled)?;
                    let y_pred_i32 = fitted_estimator.predict(x)?;
                    let y_pred = convert_predictions_to_float(&y_pred_i32);
                    if y_pred.len() != n_samples {
                        return Err(SklearsError::ShapeMismatch {
                            expected: format!("{} predictions", n_samples),
                            actual: format!("{} predictions", y_pred.len()),
                        });
                    }
                    let mut weighted_error = 0.0;
                    for i in 0..n_samples {
                        if y[i] != y_pred[i] {
                            weighted_error += sample_weight[i];
                        }
                    }

                    // AdaBoost.M1 is only defined for error < 0.5; if even the
                    // first round fails that bound, fitting cannot proceed.
                    if weighted_error >= 0.5 {
                        if estimators.is_empty() {
                            return Err(SklearsError::InvalidInput(
                                "AdaBoost.M1 requires strong learners (error < 0.5)".to_string(),
                            ));
                        }
                        break;
                    }
                    let estimator_weight =
                        self.calculate_estimator_weight(weighted_error, n_classes);
                    estimators.push(fitted_estimator);
                    estimator_weights.push(estimator_weight);
                    estimator_errors.push(weighted_error);
                    sample_weight =
                        self.calculate_sample_weights(y, &y_pred, &sample_weight, estimator_weight);
                    if weighted_error < 1e-10 {
                        break;
                    }
                }
                AdaBoostAlgorithm::M2 => {
                    let (x_resampled, y_resampled) =
                        self.resample_data(x, &y_i32, &sample_weight, &mut rng)?;
                    let fitted_estimator = base_estimator.fit(&x_resampled, &y_resampled)?;
                    let y_pred_i32 = fitted_estimator.predict(x)?;
                    let y_pred = convert_predictions_to_float(&y_pred_i32);
                    if y_pred.len() != n_samples {
                        return Err(SklearsError::ShapeMismatch {
                            expected: format!("{} predictions", n_samples),
                            actual: format!("{} predictions", y_pred.len()),
                        });
                    }
                    let prob_estimates =
                        estimate_probabilities(&y_pred, &classes, n_samples, n_classes);

                    // AdaBoost.M2 uses a pseudo-loss based on the margin
                    // between the true class and the average of the others.
                    let mut pseudo_loss = 0.0;
                    for i in 0..n_samples {
                        let true_class_idx = classes.iter().position(|&c| c == y[i]).unwrap_or(0);
                        let mut margin = prob_estimates[[i, true_class_idx]];
                        for j in 0..n_classes {
                            if j != true_class_idx {
                                margin -= prob_estimates[[i, j]] / (n_classes - 1) as Float;
                            }
                        }
                        if margin <= 0.0 {
                            pseudo_loss += sample_weight[i] * (1.0 - margin);
                        }
                    }
                    let total_weight: Float = sample_weight.sum();
                    if total_weight > 0.0 {
                        pseudo_loss /= total_weight;
                    }

                    if pseudo_loss >= 0.5 {
                        if estimators.is_empty() {
                            estimators.push(fitted_estimator);
                            estimator_weights.push(0.0);
                            estimator_errors.push(pseudo_loss);
                        }
                        break;
                    }
                    let estimator_weight = self.calculate_estimator_weight(pseudo_loss, n_classes);
                    estimators.push(fitted_estimator);
                    estimator_weights.push(estimator_weight);
                    estimator_errors.push(pseudo_loss);

                    // Reweight by confidence in the true class: low-confidence
                    // samples are upweighted, high-confidence ones downweighted.
                    let mut new_weights = sample_weight.clone();
                    for i in 0..n_samples {
                        let true_class_idx = classes.iter().position(|&c| c == y[i]).unwrap_or(0);
                        let confidence = prob_estimates[[i, true_class_idx]];
                        let weight_multiplier = if confidence < 0.5 {
                            (estimator_weight * (1.0 - confidence)).exp()
                        } else {
                            (estimator_weight * confidence).exp().recip()
                        };
                        new_weights[i] *= weight_multiplier;
                    }
                    let weight_sum = new_weights.sum();
                    if weight_sum > 0.0 {
                        new_weights /= weight_sum;
                    } else {
                        new_weights.fill(1.0 / n_samples as Float);
                    }
                    sample_weight = new_weights;
                    if pseudo_loss < 1e-10 {
                        break;
                    }
                }
                AdaBoostAlgorithm::Gentle | AdaBoostAlgorithm::Discrete => {
                    let (x_resampled, y_resampled) =
                        self.resample_data(x, &y_i32, &sample_weight, &mut rng)?;
                    let fitted_estimator = base_estimator.fit(&x_resampled, &y_resampled)?;
                    let y_pred_i32 = fitted_estimator.predict(x)?;
                    let y_pred = convert_predictions_to_float(&y_pred_i32);
                    if y_pred.len() != n_samples {
                        return Err(SklearsError::ShapeMismatch {
                            expected: format!("{} predictions", n_samples),
                            actual: format!("{} predictions", y_pred.len()),
                        });
                    }
                    let mut weighted_error = 0.0;
                    for i in 0..n_samples {
                        if y[i] != y_pred[i] {
                            weighted_error += sample_weight[i];
                        }
                    }

                    // These variants tolerate slightly worse learners than
                    // SAMME, stopping only at error >= 0.6.
                    if weighted_error >= 0.6 {
                        if estimators.is_empty() {
                            estimators.push(fitted_estimator);
                            estimator_weights.push(0.1);
                            estimator_errors.push(weighted_error);
                        }
                        break;
                    }
                    let estimator_weight =
                        self.calculate_estimator_weight(weighted_error, n_classes);
                    estimators.push(fitted_estimator);
                    estimator_weights.push(estimator_weight);
                    estimator_errors.push(weighted_error);

                    // Gentle update: damped symmetric reweighting of both
                    // correct and incorrect samples.
                    let mut new_weights = sample_weight.clone();
                    let gentle_factor = 0.5;
                    for i in 0..n_samples {
                        let multiplier = if y[i] != y_pred[i] {
                            (gentle_factor * estimator_weight).exp()
                        } else {
                            (-gentle_factor * estimator_weight).exp()
                        };
                        new_weights[i] *= multiplier;
                    }
                    let weight_sum = new_weights.sum();
                    if weight_sum > 0.0 {
                        new_weights /= weight_sum;

                        // Blend in a little of the uniform distribution so no
                        // sample's weight collapses to zero, then renormalize.
                        let smoothing = 0.01;
                        let uniform_weight = 1.0 / n_samples as Float;
                        for i in 0..n_samples {
                            new_weights[i] =
                                (1.0 - smoothing) * new_weights[i] + smoothing * uniform_weight;
                        }
                        let smoothed_sum = new_weights.sum();
                        if smoothed_sum > 0.0 {
                            new_weights /= smoothed_sum;
                        }
                    } else {
                        new_weights.fill(1.0 / n_samples as Float);
                    }
                    sample_weight = new_weights;
                    if weighted_error < 1e-10 {
                        break;
                    }
                }
            }
        }

        if estimators.is_empty() {
            return Err(SklearsError::InvalidInput(
                "AdaBoost failed to fit any estimators".to_string(),
            ));
        }

        Ok(AdaBoostClassifier {
            config: self.config,
            state: PhantomData,
            estimators_: Some(estimators),
            estimator_weights_: Some(Array1::from_vec(estimator_weights)),
            estimator_errors_: Some(Array1::from_vec(estimator_errors)),
            classes_: Some(classes),
            n_classes_: Some(n_classes),
            n_features_in_: Some(n_features),
        })
    }
}

impl AdaBoostClassifier<Trained> {
    /// Returns the fitted weak learners.
    pub fn estimators(&self) -> &[DecisionTreeClassifier<Trained>] {
        self.estimators_
            .as_ref()
            .expect("AdaBoost should be fitted")
    }

    /// Returns the weight (alpha) of each fitted estimator.
    pub fn estimator_weights(&self) -> &Array1<Float> {
        self.estimator_weights_
            .as_ref()
            .expect("AdaBoost should be fitted")
    }

    /// Returns the weighted training error of each fitted estimator.
    pub fn estimator_errors(&self) -> &Array1<Float> {
        self.estimator_errors_
            .as_ref()
            .expect("AdaBoost should be fitted")
    }

    /// Returns the sorted class labels seen during fit.
    pub fn classes(&self) -> &Array1<Float> {
        self.classes_.as_ref().expect("AdaBoost should be fitted")
    }

    /// Returns the number of classes seen during fit.
    pub fn n_classes(&self) -> usize {
        self.n_classes_.expect("AdaBoost should be fitted")
    }

    /// Returns the number of features seen during fit.
    pub fn n_features_in(&self) -> usize {
        self.n_features_in_.expect("AdaBoost should be fitted")
    }

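    /// Returns per-feature importances aggregated over the ensemble, weighted
    /// by `|alpha|` and normalized to sum to one.
    ///
    /// Note: the per-tree importances are currently a uniform placeholder
    /// (`1/n_features` for every feature), so the result is uniform until
    /// real per-tree importances are wired through.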
    pub fn feature_importances(&self) -> Result<Array1<Float>> {
        let estimators = self.estimators();
        let weights = self.estimator_weights();
        let n_features = self.n_features_in();

        if estimators.is_empty() {
            return Ok(Array1::<Float>::zeros(n_features));
        }

        let mut importances = Array1::<Float>::zeros(n_features);
        let mut total_weight = 0.0;

        for (_estimator, &weight) in estimators.iter().zip(weights.iter()) {
            // Placeholder: treat every feature as equally important within
            // each tree until per-tree importances are available.
            let tree_importances = Array1::<Float>::ones(n_features) / n_features as Float;
            importances += &(tree_importances * weight.abs());
            total_weight += weight.abs();
        }

        if total_weight > 0.0 {
            importances /= total_weight;
        } else {
            importances.fill(1.0 / n_features as Float);
        }

        Ok(importances)
    }

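    /// Predicts class probabilities for `x`. Aggregation depends on the
    /// variant: the voting algorithms (SAMME, M1, Gentle, Discrete) normalize
    /// alpha-weighted votes per row; SAMME.R sums weighted log-probabilities
    /// and applies a numerically stable softmax; Real AdaBoost (binary only)
    /// passes the accumulated half-log-odds score through a sigmoid; M2
    /// normalizes weighted confidence scores.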
    pub fn predict_proba(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let (n_samples, n_features) = x.dim();

        if n_features != self.n_features_in() {
            return Err(SklearsError::FeatureMismatch {
                expected: self.n_features_in(),
                actual: n_features,
            });
        }

        let estimators = self.estimators();
        let weights = self.estimator_weights();
        let classes = self.classes();
        let n_classes = self.n_classes();

        match self.config.algorithm {
            AdaBoostAlgorithm::SAMME => {
                // Accumulate alpha-weighted votes per class.
                let mut class_votes = Array2::<Float>::zeros((n_samples, n_classes));

                for (estimator, &weight) in estimators.iter().zip(weights.iter()) {
                    let predictions_i32 = estimator.predict(x)?;
                    let predictions = convert_predictions_to_float(&predictions_i32);

                    for (i, &pred) in predictions.iter().enumerate() {
                        if let Some(class_idx) = classes.iter().position(|&c| c == pred) {
                            class_votes[[i, class_idx]] += weight;
                        }
                    }
                }

                // Normalize votes per row; fall back to uniform if no votes.
                let mut probabilities = Array2::<Float>::zeros((n_samples, n_classes));
                for i in 0..n_samples {
                    let vote_sum = class_votes.row(i).sum();
                    if vote_sum > 0.0 {
                        for j in 0..n_classes {
                            probabilities[[i, j]] = class_votes[[i, j]] / vote_sum;
                        }
                    } else {
                        probabilities.row_mut(i).fill(1.0 / n_classes as Float);
                    }
                }

                Ok(probabilities)
            }
            AdaBoostAlgorithm::SAMMER => {
                // Accumulate weighted log-probabilities. The uniform initial
                // fill is a constant offset per row, so it cancels in the
                // softmax below.
                let mut prob_sum = Array2::<Float>::zeros((n_samples, n_classes));
                prob_sum.fill(1.0 / n_classes as Float);

                for (estimator, &weight) in estimators.iter().zip(weights.iter()) {
                    let predictions_i32 = estimator.predict(x)?;
                    let predictions = convert_predictions_to_float(&predictions_i32);

                    let prob_estimates =
                        estimate_probabilities(&predictions, classes, n_samples, n_classes);

                    for i in 0..n_samples {
                        for j in 0..n_classes {
                            prob_sum[[i, j]] += weight * prob_estimates[[i, j]].ln();
                        }
                    }
                }

                // Numerically stable softmax over each row of log scores.
                let mut probabilities = Array2::<Float>::zeros((n_samples, n_classes));
                for i in 0..n_samples {
                    let max_log = prob_sum
                        .row(i)
                        .iter()
                        .cloned()
                        .fold(Float::NEG_INFINITY, Float::max);

                    let mut sum = 0.0;
                    for j in 0..n_classes {
                        probabilities[[i, j]] = (prob_sum[[i, j]] - max_log).exp();
                        sum += probabilities[[i, j]];
                    }

                    for j in 0..n_classes {
                        probabilities[[i, j]] /= sum;
                    }
                }

                Ok(probabilities)
            }
            AdaBoostAlgorithm::RealAdaBoost => {
                if n_classes != 2 {
                    return Err(SklearsError::InvalidInput(
                        "Real AdaBoost predict_proba only supports binary classification"
                            .to_string(),
                    ));
                }

                // Accumulate the weighted half-log-odds scores h_t(x).
                let mut decision_scores = Array1::<Float>::zeros(n_samples);

                for (estimator, &weight) in estimators.iter().zip(weights.iter()) {
                    let predictions_i32 = estimator.predict(x)?;
                    let predictions = convert_predictions_to_float(&predictions_i32);

                    let prob_estimates = estimate_binary_probabilities(&predictions, classes);

                    for i in 0..n_samples {
                        let p_0 = prob_estimates[[i, 0]].clamp(1e-7, 1.0 - 1e-7);
                        let p_1 = prob_estimates[[i, 1]].clamp(1e-7, 1.0 - 1e-7);

                        let h_t = 0.5 * (p_1 / p_0).ln();
                        decision_scores[i] += weight * h_t;
                    }
                }

                // Map the aggregated score to a probability via the sigmoid.
                let mut probabilities = Array2::<Float>::zeros((n_samples, 2));
                for i in 0..n_samples {
                    let sigmoid = 1.0 / (1.0 + (-decision_scores[i]).exp());
                    probabilities[[i, 1]] = sigmoid;
                    probabilities[[i, 0]] = 1.0 - sigmoid;
                }

                Ok(probabilities)
            }
            AdaBoostAlgorithm::M1 => {
                // M1 aggregates exactly like SAMME: normalized weighted votes.
                let mut class_votes = Array2::<Float>::zeros((n_samples, n_classes));

                for (estimator, &weight) in estimators.iter().zip(weights.iter()) {
                    let predictions_i32 = estimator.predict(x)?;
                    let predictions = convert_predictions_to_float(&predictions_i32);

                    for (i, &pred) in predictions.iter().enumerate() {
                        if let Some(class_idx) = classes.iter().position(|&c| c == pred) {
                            class_votes[[i, class_idx]] += weight;
                        }
                    }
                }

                let mut probabilities = Array2::<Float>::zeros((n_samples, n_classes));
                for i in 0..n_samples {
                    let vote_sum = class_votes.row(i).sum();
                    if vote_sum > 0.0 {
                        for j in 0..n_classes {
                            probabilities[[i, j]] = class_votes[[i, j]] / vote_sum;
                        }
                    } else {
                        probabilities.row_mut(i).fill(1.0 / n_classes as Float);
                    }
                }

                Ok(probabilities)
            }
            AdaBoostAlgorithm::M2 => {
                // Accumulate alpha-weighted per-class confidence scores.
                let mut confidence_scores = Array2::<Float>::zeros((n_samples, n_classes));

                for (estimator, &weight) in estimators.iter().zip(weights.iter()) {
                    let predictions_i32 = estimator.predict(x)?;
                    let predictions = convert_predictions_to_float(&predictions_i32);

                    let prob_estimates =
                        estimate_probabilities(&predictions, classes, n_samples, n_classes);

                    for i in 0..n_samples {
                        for j in 0..n_classes {
                            confidence_scores[[i, j]] += weight * prob_estimates[[i, j]];
                        }
                    }
                }

                let mut probabilities = Array2::<Float>::zeros((n_samples, n_classes));
                for i in 0..n_samples {
                    let score_sum = confidence_scores.row(i).sum();
                    if score_sum > 0.0 {
                        for j in 0..n_classes {
                            probabilities[[i, j]] = confidence_scores[[i, j]] / score_sum;
                        }
                    } else {
                        probabilities.row_mut(i).fill(1.0 / n_classes as Float);
                    }
                }

                Ok(probabilities)
            }
            AdaBoostAlgorithm::Gentle | AdaBoostAlgorithm::Discrete => {
                let mut class_votes = Array2::<Float>::zeros((n_samples, n_classes));

                for (estimator, &weight) in estimators.iter().zip(weights.iter()) {
                    let predictions_i32 = estimator.predict(x)?;
                    let predictions = convert_predictions_to_float(&predictions_i32);

                    for (i, &pred) in predictions.iter().enumerate() {
                        if let Some(class_idx) = classes.iter().position(|&c| c == pred) {
                            class_votes[[i, class_idx]] += weight;
                        }
                    }
                }

                let mut probabilities = Array2::<Float>::zeros((n_samples, n_classes));
                for i in 0..n_samples {
                    let vote_sum = class_votes.row(i).sum();
                    if vote_sum > 0.0 {
                        for j in 0..n_classes {
                            probabilities[[i, j]] = class_votes[[i, j]] / vote_sum;
                        }

                        // Smooth toward the uniform distribution so no class
                        // probability is exactly zero.
                        let alpha = 0.1;
                        let uniform_prob = 1.0 / n_classes as Float;
                        for j in 0..n_classes {
                            probabilities[[i, j]] =
                                (1.0 - alpha) * probabilities[[i, j]] + alpha * uniform_prob;
                        }
                    } else {
                        probabilities.row_mut(i).fill(1.0 / n_classes as Float);
                    }
                }

                Ok(probabilities)
            }
        }
    }

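    /// Computes decision scores from the predicted probabilities: the binary
    /// case returns a single log-odds column `ln(p_1 / p_0)`, the multiclass
    /// case returns clamped log-probabilities, one column per class.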
    pub fn decision_function(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let probas = self.predict_proba(x)?;

        if self.n_classes() == 2 {
            let mut decision = Array2::<Float>::zeros((probas.nrows(), 1));
            for i in 0..probas.nrows() {
                // Clamp away from zero before taking the log-odds.
                let p1 = probas[[i, 1]].max(1e-15);
                let p0 = probas[[i, 0]].max(1e-15);
                decision[[i, 0]] = (p1 / p0).ln();
            }
            Ok(decision)
        } else {
            Ok(probas.mapv(|p| p.max(1e-15).ln()))
        }
    }
}

impl Predict<Array2<Float>, Array1<Float>> for AdaBoostClassifier<Trained> {
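    /// Predicts class labels by taking, for each sample, the class with the
    /// highest probability from `predict_proba` (ties resolve to the first
    /// maximum).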
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        let probas = self.predict_proba(x)?;
        let classes = self.classes();
        let mut predictions = Array1::<Float>::zeros(probas.nrows());
        for (i, row) in probas.rows().into_iter().enumerate() {
            // Arg-max over the probability row; predict_proba never emits NaN,
            // so partial_cmp cannot fail here.
            let max_idx = row
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                .map(|(idx, _)| idx)
                .unwrap_or(0);
            predictions[i] = classes[max_idx];
        }
        Ok(predictions)
    }
}