1#![allow(dead_code)]
2#![allow(non_snake_case)]
3#![allow(missing_docs)]
4#![allow(deprecated)]
5#![allow(clippy::needless_range_loop)]
6use scirs2_core::ndarray::{Array1, Array2, Axis};
14use sklears_core::{
15 error::{validate, Result},
16 prelude::{Fit, Predict, SklearsError},
17 traits::{Estimator, PredictProba, Trained, Untrained},
18 types::Float,
19};
20use std::{collections::HashMap, marker::PhantomData};
21
22pub mod core;
24
25pub mod metrics;
27
28pub mod binary;
30
31pub mod isotonic;
33
34pub mod temperature;
36
37pub mod histogram;
39
40pub mod bbq;
42
43pub mod multiclass;
45
46pub mod beta;
48
49pub mod local;
51
52pub mod kde;
54
55pub mod gaussian_process;
57
58pub mod visualization;
60
61pub mod conformal;
63
64#[allow(non_snake_case)]
66#[cfg(test)]
67mod property_tests;
68
69pub mod statistical_tests;
71
72pub mod numerical_stability;
74
75pub mod prediction_intervals;
77
78pub mod uncertainty_estimation;
80
81pub mod higher_order_uncertainty;
83
84pub mod bayesian;
86
87pub mod domain_specific;
89
90pub mod neural_calibration;
92
93pub mod streaming;
95
96pub mod calibration_aware_training;
98
99#[allow(non_snake_case)]
101#[cfg(test)]
102pub mod robustness_tests;
103
104pub mod high_precision;
106
107pub mod ultra_precision;
109
110pub mod theoretical_validation;
112
113pub mod fluent_api;
115
116pub mod multi_modal;
118
119pub mod large_scale;
121
122pub mod optimization;
124
125pub mod quantum_optimization;
127
128pub mod information_geometry;
130
131pub mod modular_framework;
133
134pub mod advanced;
136
137#[allow(non_snake_case)]
139#[cfg(test)]
140pub mod reference_tests;
141
142#[cfg(feature = "serde")]
144pub mod serialization;
145
146pub mod validation;
148
149pub mod performance;
151
152pub mod gpu_calibration;
154
155pub mod llm_calibration;
157
158pub mod differential_privacy;
160
161pub mod meta_learning;
163
164pub mod continual_learning;
166
167pub mod topological_calibration;
169
170pub mod category_theoretic;
172
173pub mod measure_theoretic;
175
176#[cfg(feature = "serde")]
178pub use serialization::{
179 CalibrationMetadata, CalibrationModelFactory, CalibrationSerializer, FromSerializable,
180 SerializableCalibrationModel, SerializableParameter, ToSerializable,
181};
182
183use advanced::{
184 train_bayesian_model_averaging_calibrators, train_conformal_cross_calibrators,
185 train_conformal_jackknife_calibrators, train_conformal_split_calibrators,
186 train_dirichlet_process_calibrators, train_hierarchical_bayesian_calibrators,
187 train_mcmc_calibrators, train_nonparametric_gp_calibrators, train_ranking_calibrators,
188 train_regression_calibrators, train_survival_calibrators, train_time_series_calibrators,
189 train_variational_inference_calibrators,
190};
191use binary::{
192 create_dummy_probabilities, train_adaptive_kde_calibrators, train_bbq_calibrators,
193 train_beta_calibrators, train_dirichlet_calibrators, train_ensemble_temperature_calibrators,
194 train_gaussian_process_calibrators, train_histogram_calibrators, train_isotonic_calibrators,
195 train_kde_calibrators, train_local_binning_calibrators, train_local_knn_calibrators,
196 train_matrix_scaling_calibrators, train_multiclass_temperature_calibrators,
197 train_one_vs_one_calibrators, train_sigmoid_calibrators, train_temperature_calibrators,
198 SigmoidCalibrator,
199};
200use gaussian_process::VariationalGPCalibrator;
201
/// Interface shared by the per-class calibrators stored (boxed) inside
/// `CalibratedClassifierCV`.
///
/// The trait is used as `Box<dyn CalibrationEstimator>`, and implementors must be
/// `Send + Sync + Debug`.
pub trait CalibrationEstimator: Send + Sync + std::fmt::Debug {
    /// Fits this calibrator in place from `probabilities` and the labels `y_true`.
    ///
    /// NOTE(review): the per-class training code in this file passes 0/1 labels to
    /// calibrators; presumably implementors expect binary labels here — confirm.
    fn fit(&mut self, probabilities: &Array1<Float>, y_true: &Array1<i32>) -> Result<()>;

    /// Returns one calibrated value per entry of `probabilities`, or an error.
    fn predict_proba(&self, probabilities: &Array1<Float>) -> Result<Array1<Float>>;

    /// Returns a boxed copy of this calibrator.
    ///
    /// This is what the `Clone` impl for `Box<dyn CalibrationEstimator>` calls.
    fn clone_box(&self) -> Box<dyn CalibrationEstimator>;
}
213
214impl Clone for Box<dyn CalibrationEstimator> {
215 fn clone(&self) -> Self {
216 self.clone_box()
217 }
218}
219
220impl<State> std::fmt::Debug for CalibratedClassifierCV<State> {
221 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222 f.debug_struct("CalibratedClassifierCV")
223 .field("config", &self.config)
224 .field(
225 "n_calibrators",
226 &self.calibrators_.as_ref().map(|c| c.len()),
227 )
228 .field("classes", &self.classes_)
229 .field("n_features", &self.n_features_)
230 .finish()
231 }
232}
233
/// Settings for `CalibratedClassifierCV`.
///
/// The `Default` impl in this file yields `CalibrationMethod::Sigmoid`, `cv = 3`
/// and `ensemble = true`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CalibratedClassifierCVConfig {
    /// Calibration method that `fit` dispatches on.
    pub method: CalibrationMethod,
    /// Value forwarded as the `cv` argument of every `train_*_calibrators` helper.
    /// NOTE(review): presumably the number of cross-validation folds — confirm.
    pub cv: usize,
    /// NOTE(review): not read by any code visible in this file; presumably toggles
    /// ensembling of per-fold calibrators — confirm against the helpers.
    pub ensemble: bool,
}
245
246impl Default for CalibratedClassifierCVConfig {
247 fn default() -> Self {
248 Self {
249 method: CalibrationMethod::Sigmoid,
250 cv: 3,
251 ensemble: true,
252 }
253 }
254}
255
/// Calibration strategy stored in `CalibratedClassifierCVConfig::method`.
///
/// `CalibratedClassifierCV::fit` matches on this value to choose a
/// `train_*_calibrators` helper. Several variants are currently routed to a simpler
/// helper and some or all of their fields are ignored at fit time; the per-variant
/// notes below state what `fit` in this file actually does.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum CalibrationMethod {
    /// Trained via `train_sigmoid_calibrators`.
    Sigmoid,
    /// Trained via `train_isotonic_calibrators`.
    Isotonic,
    /// Trained via `train_temperature_calibrators`.
    Temperature,
    /// Trained via `train_histogram_calibrators` with `n_bins`.
    HistogramBinning { n_bins: usize },
    /// Trained via `train_bbq_calibrators` with `min_bins` and `max_bins`.
    BBQ { min_bins: usize, max_bins: usize },
    /// Trained via `train_beta_calibrators`.
    Beta,
    /// Trained via `train_ensemble_temperature_calibrators` with `n_estimators`.
    EnsembleTemperature { n_estimators: usize },
    /// Trained via `train_one_vs_one_calibrators`.
    OneVsOne,
    /// Trained via `train_multiclass_temperature_calibrators`.
    MulticlassTemperature,
    /// Trained via `train_matrix_scaling_calibrators`.
    MatrixScaling,
    /// Trained via `train_dirichlet_calibrators` with `concentration`.
    Dirichlet { concentration: Float },
    /// Trained via `train_local_knn_calibrators` with `k`.
    LocalKNN { k: usize },
    /// Trained via `train_local_binning_calibrators` with `n_bins`.
    LocalBinning { n_bins: usize },
    /// Trained via `train_kde_calibrators`.
    KDE,
    /// Trained via `train_adaptive_kde_calibrators` with `adaptation_factor`.
    AdaptiveKDE { adaptation_factor: Float },
    /// Trained via `train_gaussian_process_calibrators`.
    GaussianProcess,
    /// Trained via this file's `train_variational_gp_calibrators`, which fits one
    /// `VariationalGPCalibrator::new(n_inducing)` per class.
    VariationalGP { n_inducing: usize },
    /// Trained via `train_conformal_split_calibrators` with `alpha`.
    ConformalSplit { alpha: Float },
    /// Trained via `train_conformal_cross_calibrators` with `alpha` and `n_folds`.
    ConformalCross { alpha: Float, n_folds: usize },
    /// Trained via `train_conformal_jackknife_calibrators` with `alpha`.
    ConformalJackknife { alpha: Float },
    /// Trained via `train_bayesian_model_averaging_calibrators` with `n_models`.
    BayesianModelAveraging { n_models: usize },
    /// Trained via `train_variational_inference_calibrators`; all three fields are
    /// forwarded.
    VariationalInference {
        learning_rate: Float,
        n_samples: usize,
        max_iter: usize,
    },
    /// Trained via `train_mcmc_calibrators`; all three fields are forwarded.
    MCMC {
        n_samples: usize,
        burn_in: usize,
        step_size: Float,
    },
    /// Trained via `train_hierarchical_bayesian_calibrators`.
    HierarchicalBayesian,
    /// Trained via `train_dirichlet_process_calibrators`; both fields are forwarded.
    DirichletProcess {
        concentration: Float,
        max_clusters: usize,
    },
    /// Trained via `train_nonparametric_gp_calibrators`; `kernel_type` is cloned and
    /// forwarded together with `n_inducing`.
    NonParametricGP {
        kernel_type: String,
        n_inducing: usize,
    },
    /// Trained via `train_time_series_calibrators`; both fields are forwarded.
    TimeSeries {
        window_size: usize,
        temporal_decay: Float,
    },
    /// Trained via `train_regression_calibrators` with `distributional`.
    Regression { distributional: bool },
    /// Trained via `train_ranking_calibrators`; both fields are forwarded.
    Ranking {
        ranking_weight: Float,
        listwise: bool,
    },
    /// Trained via `train_survival_calibrators`; `time_points` is cloned and
    /// forwarded together with `handle_censoring`.
    Survival {
        time_points: Vec<Float>,
        handle_censoring: bool,
    },
    /// Fit-time stand-in: `fit` trains via `train_sigmoid_calibrators` and reads none
    /// of the fields.
    NeuralCalibration {
        hidden_dims: Vec<usize>,
        activation: String,
        learning_rate: Float,
        epochs: usize,
    },
    /// Fit-time stand-in: `fit` trains via `train_sigmoid_calibrators` and reads none
    /// of the fields.
    MixupCalibration {
        base_method: String,
        alpha: Float,
        num_mixup_samples: usize,
    },
    /// Fit-time stand-in: `fit` trains via `train_sigmoid_calibrators` and reads none
    /// of the fields.
    DropoutCalibration {
        hidden_dims: Vec<usize>,
        dropout_prob: Float,
        mc_samples: usize,
    },
    /// `fit` trains via `train_ensemble_temperature_calibrators` with `n_estimators`;
    /// `hidden_dims` is not read.
    EnsembleNeuralCalibration {
        n_estimators: usize,
        hidden_dims: Vec<usize>,
    },
    /// `fit` trains via `train_regression_calibrators`, passing `use_mrf` in the
    /// position where `Regression` passes `distributional`; the other fields are not
    /// read.
    StructuredPrediction {
        structure_type: String,
        use_mrf: bool,
        temperature: Float,
    },
    /// Fit-time stand-in: `fit` trains via `train_sigmoid_calibrators` and reads none
    /// of the fields.
    OnlineSigmoid {
        learning_rate: Float,
        use_momentum: bool,
        momentum: Float,
    },
    /// Fit-time stand-in: `fit` trains via `train_sigmoid_calibrators` and reads none
    /// of the fields.
    AdaptiveOnline {
        window_size: usize,
        retrain_frequency: usize,
        drift_threshold: Float,
    },
    /// Fit-time stand-in: `fit` trains via `train_sigmoid_calibrators` and reads none
    /// of the fields.
    IncrementalUpdate {
        update_frequency: usize,
        learning_rate: Float,
        use_smoothing: bool,
    },
    /// Fit-time stand-in: `fit` trains via `train_sigmoid_calibrators` and reads none
    /// of the fields.
    CalibrationAwareFocal {
        gamma: Float,
        temperature: Float,
        learning_rate: Float,
        max_epochs: usize,
    },
    /// Fit-time stand-in: `fit` trains via `train_sigmoid_calibrators` and reads none
    /// of the fields.
    CalibrationAwareCrossEntropy {
        lambda: Float,
        learning_rate: Float,
        max_epochs: usize,
    },
    /// Fit-time stand-in: `fit` trains via `train_sigmoid_calibrators` and reads none
    /// of the fields.
    CalibrationAwareBrier {
        learning_rate: Float,
        max_epochs: usize,
    },
    /// Fit-time stand-in: `fit` trains via `train_sigmoid_calibrators` and reads none
    /// of the fields.
    CalibrationAwareECE {
        n_bins: usize,
        learning_rate: Float,
        max_epochs: usize,
    },
    /// Fit-time stand-in: `fit` trains via `train_sigmoid_calibrators` and reads none
    /// of the fields.
    MultiModal {
        n_modalities: usize,
        fusion_strategy: String,
    },
    /// Fit-time stand-in: `fit` trains via `train_sigmoid_calibrators` and does not
    /// read `adaptation_weights`.
    CrossModal { adaptation_weights: Vec<Float> },
    /// `fit` trains via `train_ensemble_temperature_calibrators` with a fixed
    /// estimator count of 5; `combination_strategy` is not read.
    HeterogeneousEnsemble { combination_strategy: String },
    /// Fit-time stand-in: `fit` trains via `train_sigmoid_calibrators` and does not
    /// read `adaptation_strength`.
    DomainAdaptation { adaptation_strength: Float },
    /// Fit-time stand-in: `fit` trains via `train_sigmoid_calibrators` and reads none
    /// of the fields.
    TransferLearning {
        transfer_strategy: String,
        learning_rate: Float,
        finetune_iterations: usize,
    },
    /// Fit-time stand-in: `fit` trains via `train_sigmoid_calibrators` and reads none
    /// of the fields.
    TokenLevel {
        max_seq_length: usize,
        use_positional_encoding: bool,
    },
    /// Fit-time stand-in: `fit` trains via `train_sigmoid_calibrators` and does not
    /// read `aggregation_method`.
    SequenceLevel { aggregation_method: String },
    /// Fit-time stand-in: `fit` trains via `train_sigmoid_calibrators` and does not
    /// read `confidence_patterns`.
    VerbalizedConfidence {
        confidence_patterns: HashMap<String, Float>,
    },
    /// Fit-time stand-in: `fit` trains via `train_sigmoid_calibrators` and does not
    /// read `aggregation_method`.
    AttentionBased { aggregation_method: String },
    /// `fit` trains via `train_sigmoid_calibrators`; `epsilon`, `delta` and
    /// `sensitivity` are not read there, so no privacy mechanism is applied by `fit`.
    DPPlattScaling {
        epsilon: Float,
        delta: Float,
        sensitivity: Float,
    },
    /// `fit` trains via `train_histogram_calibrators` with `n_bins`; `epsilon` and
    /// `delta` are not read there, so no privacy mechanism is applied by `fit`.
    DPHistogramBinning {
        n_bins: usize,
        epsilon: Float,
        delta: Float,
    },
    /// `fit` trains via `train_temperature_calibrators`; `epsilon` and `delta` are
    /// not read there, so no privacy mechanism is applied by `fit`.
    DPTemperatureScaling { epsilon: Float, delta: Float },
    /// Fit-time stand-in: `fit` trains via `train_sigmoid_calibrators` and reads none
    /// of the fields.
    ContinualLearning {
        base_method: String,
        replay_strategy: String,
        max_memory_size: usize,
        regularization_strength: Float,
    },
    /// Fit-time stand-in: `fit` trains via `train_sigmoid_calibrators` and reads none
    /// of the fields.
    DifferentiableECE {
        n_bins: usize,
        learning_rate: Float,
        max_iterations: usize,
        tolerance: Float,
        use_adaptive_bins: bool,
    },
}
475
/// Probability-calibrating classifier wrapper using a type-state marker.
///
/// `State` is `Untrained` by default; `fit` consumes the untrained value and
/// returns a `CalibratedClassifierCV<Trained>` with the optional fields populated.
#[derive(Clone)]
pub struct CalibratedClassifierCV<State = Untrained> {
    /// Settings chosen through the builder methods.
    config: CalibratedClassifierCVConfig,
    /// Zero-sized marker that carries the `State` type parameter.
    state: PhantomData<State>,
    /// Calibrators produced by `fit`; `None` before training.
    calibrators_: Option<Vec<Box<dyn CalibrationEstimator>>>,
    /// Sorted distinct labels seen by `fit`; `None` before training.
    classes_: Option<Array1<i32>>,
    /// Column count of the training matrix passed to `fit`; `None` before training.
    n_features_: Option<usize>,
}
505
506impl CalibratedClassifierCV<Untrained> {
507 pub fn new() -> Self {
509 Self {
510 config: CalibratedClassifierCVConfig::default(),
511 state: PhantomData,
512 calibrators_: None,
513 classes_: None,
514 n_features_: None,
515 }
516 }
517
518 pub fn method(mut self, method: CalibrationMethod) -> Self {
520 self.config.method = method;
521 self
522 }
523
524 pub fn cv(mut self, cv: usize) -> Self {
526 self.config.cv = cv;
527 self
528 }
529
530 pub fn ensemble(mut self, ensemble: bool) -> Self {
532 self.config.ensemble = ensemble;
533 self
534 }
535}
536
537impl Default for CalibratedClassifierCV<Untrained> {
538 fn default() -> Self {
539 Self::new()
540 }
541}
542
543impl Estimator for CalibratedClassifierCV<Untrained> {
544 type Config = CalibratedClassifierCVConfig;
545 type Error = SklearsError;
546 type Float = Float;
547
548 fn config(&self) -> &Self::Config {
549 &self.config
550 }
551}
552
553impl Fit<Array2<Float>, Array1<i32>> for CalibratedClassifierCV<Untrained> {
554 type Fitted = CalibratedClassifierCV<Trained>;
555
556 fn fit(self, x: &Array2<Float>, y: &Array1<i32>) -> Result<Self::Fitted> {
557 validate::check_consistent_length(x, y)?;
559
560 let (n_samples, n_features) = x.dim();
561 if n_samples == 0 {
562 return Err(SklearsError::InvalidInput(
563 "No samples provided".to_string(),
564 ));
565 }
566
567 let mut classes: Vec<i32> = y
569 .iter()
570 .cloned()
571 .collect::<std::collections::HashSet<_>>()
572 .into_iter()
573 .collect();
574 classes.sort();
575 let n_classes = classes.len();
576
577 if n_classes < 2 {
578 return Err(SklearsError::InvalidInput(
579 "Need at least 2 classes".to_string(),
580 ));
581 }
582
583 let probabilities = create_dummy_probabilities(x, y, &Array1::from(classes.clone()))?;
587
588 let calibrators = match self.config.method {
590 CalibrationMethod::Sigmoid => {
591 train_sigmoid_calibrators(&probabilities, y, &classes, self.config.cv)?
592 }
593 CalibrationMethod::Isotonic => {
594 train_isotonic_calibrators(&probabilities, y, &classes, self.config.cv)?
595 }
596 CalibrationMethod::Temperature => {
597 train_temperature_calibrators(&probabilities, y, &classes, self.config.cv)?
598 }
599 CalibrationMethod::HistogramBinning { n_bins } => {
600 train_histogram_calibrators(&probabilities, y, &classes, self.config.cv, n_bins)?
601 }
602 CalibrationMethod::BBQ { min_bins, max_bins } => train_bbq_calibrators(
603 &probabilities,
604 y,
605 &classes,
606 self.config.cv,
607 min_bins,
608 max_bins,
609 )?,
610 CalibrationMethod::Beta => {
611 train_beta_calibrators(&probabilities, y, &classes, self.config.cv)?
612 }
613 CalibrationMethod::EnsembleTemperature { n_estimators } => {
614 train_ensemble_temperature_calibrators(
615 &probabilities,
616 y,
617 &classes,
618 self.config.cv,
619 n_estimators,
620 )?
621 }
622 CalibrationMethod::OneVsOne => {
623 train_one_vs_one_calibrators(&probabilities, y, &classes, self.config.cv)?
624 }
625 CalibrationMethod::MulticlassTemperature => train_multiclass_temperature_calibrators(
626 &probabilities,
627 y,
628 &classes,
629 self.config.cv,
630 )?,
631 CalibrationMethod::MatrixScaling => {
632 train_matrix_scaling_calibrators(&probabilities, y, &classes, self.config.cv)?
633 }
634 CalibrationMethod::Dirichlet { concentration } => train_dirichlet_calibrators(
635 &probabilities,
636 y,
637 &classes,
638 self.config.cv,
639 concentration,
640 )?,
641 CalibrationMethod::LocalKNN { k } => {
642 train_local_knn_calibrators(&probabilities, y, &classes, self.config.cv, k)?
643 }
644 CalibrationMethod::LocalBinning { n_bins } => train_local_binning_calibrators(
645 &probabilities,
646 y,
647 &classes,
648 self.config.cv,
649 n_bins,
650 )?,
651 CalibrationMethod::KDE => {
652 train_kde_calibrators(&probabilities, y, &classes, self.config.cv)?
653 }
654 CalibrationMethod::AdaptiveKDE { adaptation_factor } => train_adaptive_kde_calibrators(
655 &probabilities,
656 y,
657 &classes,
658 self.config.cv,
659 adaptation_factor,
660 )?,
661 CalibrationMethod::GaussianProcess => {
662 train_gaussian_process_calibrators(&probabilities, y, &classes, self.config.cv)?
663 }
664 CalibrationMethod::VariationalGP { n_inducing } => train_variational_gp_calibrators(
665 &probabilities,
666 y,
667 &classes,
668 self.config.cv,
669 n_inducing,
670 )?,
671 CalibrationMethod::ConformalSplit { alpha } => train_conformal_split_calibrators(
672 &probabilities,
673 y,
674 &classes,
675 self.config.cv,
676 alpha,
677 )?,
678 CalibrationMethod::ConformalCross { alpha, n_folds } => {
679 train_conformal_cross_calibrators(
680 &probabilities,
681 y,
682 &classes,
683 self.config.cv,
684 alpha,
685 n_folds,
686 )?
687 }
688 CalibrationMethod::ConformalJackknife { alpha } => {
689 train_conformal_jackknife_calibrators(
690 &probabilities,
691 y,
692 &classes,
693 self.config.cv,
694 alpha,
695 )?
696 }
697 CalibrationMethod::BayesianModelAveraging { n_models } => {
698 train_bayesian_model_averaging_calibrators(
699 &probabilities,
700 y,
701 &classes,
702 self.config.cv,
703 n_models,
704 )?
705 }
706 CalibrationMethod::VariationalInference {
707 learning_rate,
708 n_samples,
709 max_iter,
710 } => train_variational_inference_calibrators(
711 &probabilities,
712 y,
713 &classes,
714 self.config.cv,
715 learning_rate,
716 n_samples,
717 max_iter,
718 )?,
719 CalibrationMethod::MCMC {
720 n_samples,
721 burn_in,
722 step_size,
723 } => train_mcmc_calibrators(
724 &probabilities,
725 y,
726 &classes,
727 self.config.cv,
728 n_samples,
729 burn_in,
730 step_size,
731 )?,
732 CalibrationMethod::HierarchicalBayesian => train_hierarchical_bayesian_calibrators(
733 &probabilities,
734 y,
735 &classes,
736 self.config.cv,
737 )?,
738 CalibrationMethod::DirichletProcess {
739 concentration,
740 max_clusters,
741 } => train_dirichlet_process_calibrators(
742 &probabilities,
743 y,
744 &classes,
745 self.config.cv,
746 concentration,
747 max_clusters,
748 )?,
749 CalibrationMethod::NonParametricGP {
750 ref kernel_type,
751 n_inducing,
752 } => train_nonparametric_gp_calibrators(
753 &probabilities,
754 y,
755 &classes,
756 self.config.cv,
757 kernel_type.clone(),
758 n_inducing,
759 )?,
760 CalibrationMethod::TimeSeries {
761 window_size,
762 temporal_decay,
763 } => train_time_series_calibrators(
764 &probabilities,
765 y,
766 &classes,
767 self.config.cv,
768 window_size,
769 temporal_decay,
770 )?,
771 CalibrationMethod::Regression { distributional } => train_regression_calibrators(
772 &probabilities,
773 y,
774 &classes,
775 self.config.cv,
776 distributional,
777 )?,
778 CalibrationMethod::Ranking {
779 ranking_weight,
780 listwise,
781 } => train_ranking_calibrators(
782 &probabilities,
783 y,
784 &classes,
785 self.config.cv,
786 ranking_weight,
787 listwise,
788 )?,
789 CalibrationMethod::Survival {
790 ref time_points,
791 handle_censoring,
792 } => train_survival_calibrators(
793 &probabilities,
794 y,
795 &classes,
796 self.config.cv,
797 time_points.clone(),
798 handle_censoring,
799 )?,
800 CalibrationMethod::NeuralCalibration {
801 hidden_dims: _,
802 activation: _,
803 learning_rate: _,
804 epochs: _,
805 } => train_sigmoid_calibrators(&probabilities, y, &classes, self.config.cv)?,
806 CalibrationMethod::MixupCalibration {
807 base_method: _,
808 alpha: _,
809 num_mixup_samples: _,
810 } => train_sigmoid_calibrators(&probabilities, y, &classes, self.config.cv)?,
811 CalibrationMethod::DropoutCalibration {
812 hidden_dims: _,
813 dropout_prob: _,
814 mc_samples: _,
815 } => train_sigmoid_calibrators(&probabilities, y, &classes, self.config.cv)?,
816 CalibrationMethod::EnsembleNeuralCalibration {
817 n_estimators,
818 hidden_dims: _,
819 } => train_ensemble_temperature_calibrators(
820 &probabilities,
821 y,
822 &classes,
823 self.config.cv,
824 n_estimators,
825 )?,
826 CalibrationMethod::StructuredPrediction {
827 structure_type: _,
828 use_mrf,
829 temperature: _,
830 } => train_regression_calibrators(
831 &probabilities,
832 y,
833 &classes,
834 self.config.cv,
835 use_mrf, )?,
837 CalibrationMethod::OnlineSigmoid {
838 learning_rate: _,
839 use_momentum: _,
840 momentum: _,
841 } => train_sigmoid_calibrators(&probabilities, y, &classes, self.config.cv)?,
842 CalibrationMethod::AdaptiveOnline {
843 window_size: _,
844 retrain_frequency: _,
845 drift_threshold: _,
846 } => train_sigmoid_calibrators(&probabilities, y, &classes, self.config.cv)?,
847 CalibrationMethod::IncrementalUpdate {
848 update_frequency: _,
849 learning_rate: _,
850 use_smoothing: _,
851 } => train_sigmoid_calibrators(&probabilities, y, &classes, self.config.cv)?,
852 CalibrationMethod::CalibrationAwareFocal {
853 gamma: _,
854 temperature: _,
855 learning_rate: _,
856 max_epochs: _,
857 } => train_sigmoid_calibrators(&probabilities, y, &classes, self.config.cv)?,
858 CalibrationMethod::CalibrationAwareCrossEntropy {
859 lambda: _,
860 learning_rate: _,
861 max_epochs: _,
862 } => train_sigmoid_calibrators(&probabilities, y, &classes, self.config.cv)?,
863 CalibrationMethod::CalibrationAwareBrier {
864 learning_rate: _,
865 max_epochs: _,
866 } => train_sigmoid_calibrators(&probabilities, y, &classes, self.config.cv)?,
867 CalibrationMethod::CalibrationAwareECE {
868 n_bins: _,
869 learning_rate: _,
870 max_epochs: _,
871 } => train_sigmoid_calibrators(&probabilities, y, &classes, self.config.cv)?,
872 CalibrationMethod::MultiModal {
873 n_modalities: _,
874 fusion_strategy: _,
875 } => train_sigmoid_calibrators(&probabilities, y, &classes, self.config.cv)?,
876 CalibrationMethod::CrossModal {
877 adaptation_weights: _,
878 } => train_sigmoid_calibrators(&probabilities, y, &classes, self.config.cv)?,
879 CalibrationMethod::HeterogeneousEnsemble {
880 combination_strategy: _,
881 } => train_ensemble_temperature_calibrators(
882 &probabilities,
883 y,
884 &classes,
885 self.config.cv,
886 5, )?,
888 CalibrationMethod::DomainAdaptation {
889 adaptation_strength: _,
890 } => train_sigmoid_calibrators(&probabilities, y, &classes, self.config.cv)?,
891 CalibrationMethod::TransferLearning {
892 transfer_strategy: _,
893 learning_rate: _,
894 finetune_iterations: _,
895 } => train_sigmoid_calibrators(&probabilities, y, &classes, self.config.cv)?,
896 CalibrationMethod::TokenLevel {
897 max_seq_length: _,
898 use_positional_encoding: _,
899 } => {
900 train_sigmoid_calibrators(&probabilities, y, &classes, self.config.cv)?
903 }
904 CalibrationMethod::SequenceLevel {
905 aggregation_method: _,
906 } => {
907 train_sigmoid_calibrators(&probabilities, y, &classes, self.config.cv)?
910 }
911 CalibrationMethod::VerbalizedConfidence {
912 confidence_patterns: _,
913 } => {
914 train_sigmoid_calibrators(&probabilities, y, &classes, self.config.cv)?
917 }
918 CalibrationMethod::AttentionBased {
919 aggregation_method: _,
920 } => {
921 train_sigmoid_calibrators(&probabilities, y, &classes, self.config.cv)?
924 }
925 CalibrationMethod::DPPlattScaling {
926 epsilon: _,
927 delta: _,
928 sensitivity: _,
929 } => {
930 train_sigmoid_calibrators(&probabilities, y, &classes, self.config.cv)?
933 }
934 CalibrationMethod::DPHistogramBinning {
935 n_bins,
936 epsilon: _,
937 delta: _,
938 } => {
939 train_histogram_calibrators(&probabilities, y, &classes, self.config.cv, n_bins)?
942 }
943 CalibrationMethod::DPTemperatureScaling {
944 epsilon: _,
945 delta: _,
946 } => {
947 train_temperature_calibrators(&probabilities, y, &classes, self.config.cv)?
950 }
951 CalibrationMethod::ContinualLearning {
952 base_method: _,
953 replay_strategy: _,
954 max_memory_size: _,
955 regularization_strength: _,
956 } => {
957 train_sigmoid_calibrators(&probabilities, y, &classes, self.config.cv)?
960 }
961 CalibrationMethod::DifferentiableECE {
962 n_bins: _,
963 learning_rate: _,
964 max_iterations: _,
965 tolerance: _,
966 use_adaptive_bins: _,
967 } => train_sigmoid_calibrators(&probabilities, y, &classes, self.config.cv)?,
968 };
969
970 Ok(CalibratedClassifierCV {
971 config: self.config,
972 state: PhantomData,
973 calibrators_: Some(calibrators),
974 classes_: Some(Array1::from(classes)),
975 n_features_: Some(n_features),
976 })
977 }
978}
979
980impl CalibratedClassifierCV<Trained> {
981 pub fn classes(&self) -> &Array1<i32> {
983 self.classes_.as_ref().expect("Model is trained")
984 }
985}
986
987impl Predict<Array2<Float>, Array1<i32>> for CalibratedClassifierCV<Trained> {
988 fn predict(&self, x: &Array2<Float>) -> Result<Array1<i32>> {
989 let probas = self.predict_proba(x)?;
990 let classes = self.classes_.as_ref().expect("Model is trained");
991
992 let predictions: Vec<i32> = probas
993 .axis_iter(Axis(0))
994 .map(|row| {
995 let max_idx = row
996 .iter()
997 .enumerate()
998 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
999 .unwrap()
1000 .0;
1001 classes[max_idx]
1002 })
1003 .collect();
1004
1005 Ok(Array1::from(predictions))
1006 }
1007}
1008
1009impl PredictProba<Array2<Float>, Array2<Float>> for CalibratedClassifierCV<Trained> {
1010 fn predict_proba(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
1011 let n_features = self.n_features_.expect("Model is trained");
1012 validate::check_n_features(x, n_features)?;
1013
1014 let classes = self.classes_.as_ref().expect("Model is trained");
1015 let calibrators = self.calibrators_.as_ref().expect("Model is trained");
1016 let (n_samples, _) = x.dim();
1017 let n_classes = classes.len();
1018
1019 let dummy_y = Array1::zeros(n_samples);
1021 let base_probas = create_dummy_probabilities(x, &dummy_y, classes)?;
1022
1023 let mut calibrated_probas = Array2::zeros((n_samples, n_classes));
1025
1026 for (i, calibrator) in calibrators.iter().enumerate().take(n_classes) {
1027 let class_probas = base_probas.column(i).to_owned();
1028 let calibrated = calibrator.predict_proba(&class_probas)?;
1029 calibrated_probas.column_mut(i).assign(&calibrated);
1030 }
1031
1032 for mut row in calibrated_probas.axis_iter_mut(Axis(0)) {
1034 let sum: Float = row.sum();
1035 if sum > 0.0 {
1036 row /= sum;
1037 } else {
1038 let n_classes = row.len();
1040 if n_classes > 0 {
1041 row.fill(1.0 / n_classes as Float);
1042 }
1043 }
1044 }
1045
1046 Ok(calibrated_probas)
1047 }
1048}
1049
1050fn train_variational_gp_calibrators(
1051 probabilities: &Array2<Float>,
1052 y: &Array1<i32>,
1053 classes: &[i32],
1054 _cv: usize,
1055 n_inducing: usize,
1056) -> Result<Vec<Box<dyn CalibrationEstimator>>> {
1057 let n_classes = classes.len();
1058 let mut calibrators: Vec<Box<dyn CalibrationEstimator>> = Vec::with_capacity(n_classes);
1059
1060 for (i, &class) in classes.iter().enumerate() {
1061 let y_binary: Array1<i32> = y.mapv(|yi| if yi == class { 1 } else { 0 });
1063
1064 let class_probas = probabilities.column(i).to_owned();
1066
1067 let calibrator = VariationalGPCalibrator::new(n_inducing).fit(&class_probas, &y_binary)?;
1069
1070 calibrators.push(Box::new(calibrator));
1071 }
1072
1073 Ok(calibrators)
1074}
1075
1076#[allow(non_snake_case)]
1077#[cfg(test)]
1078mod tests;