// SPDX-License-Identifier: MIT OR Apache-2.0
//! Probability calibration for classifiers.
//!
//! Many classifiers (especially tree-based ones) produce poorly-calibrated
//! probability estimates. This module provides methods to map raw classifier
//! outputs to well-calibrated probabilities.
//!
//! # Methods
//!
//! - [`PlattScaling`] — Fits a logistic sigmoid `P(y=1|f) = 1/(1+exp(-(A·f+B)))`
//!   using Platt's (1999) method. Best for small datasets.
//! - [`IsotonicRegression`] — Non-parametric calibration using the Pool
//!   Adjacent Violators (PAV) algorithm. More flexible but needs more data.
//! - [`CalibratedClassifierCV`] — Wraps any classifier and calibrates its
//!   `predict_proba` output using cross-validation.
//!
//! # Example
//!
//! ```ignore
//! use scry_learn::calibration::{CalibratedClassifierCV, CalibrationMethod};
//! use scry_learn::tree::RandomForestClassifier;
//!
//! let cal = CalibratedClassifierCV::new(
//!     RandomForestClassifier::new().n_estimators(50),
//!     CalibrationMethod::Isotonic,
//! );
//! ```

use crate::error::{Result, ScryLearnError};

// ---------------------------------------------------------------------------
// Platt Scaling
// ---------------------------------------------------------------------------

/// Sigmoid calibration using Platt's method.
///
/// Fits the parameters A and B of the sigmoid:
/// `P(y=1 | f) = 1 / (1 + exp(-(A·f + B)))`
/// (Platt's paper writes `1 / (1 + exp(A·f + B))`; the two forms differ only
/// in the signs of A and B, and this implementation follows the sigmoid
/// convention used by [`PlattScaling::predict`].)
///
/// The parameters are fitted by Newton's method with a backtracking line
/// search (the robust variant of Lin, Lin & Weng, 2007), using Platt's
/// (1999) modified target values to avoid saturation:
/// `t+ = (N+ + 1) / (N+ + 2)` and `t- = 1 / (N- + 2)`.
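/// For example, with `N+ = 8` positives and `N- = 2` negatives the targets
/// become `t+ = 9/10 = 0.9` and `t- = 1/4 = 0.25`.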
///
/// # Example
///
/// ```
/// use scry_learn::calibration::PlattScaling;
///
/// let mut platt = PlattScaling::new();
/// // decision_values: raw SVM or tree output
/// // labels: 0.0 or 1.0
/// platt.fit(&[2.0, 1.5, -0.5, -1.0], &[1.0, 1.0, 0.0, 0.0]).unwrap();
/// let probs = platt.predict(&[1.0, -1.0]);
/// assert!(probs[0] > 0.5);
/// assert!(probs[1] < 0.5);
/// ```
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct PlattScaling {
    /// Sigmoid parameter A (slope).
    a: f64,
    /// Sigmoid parameter B (intercept).
    b: f64,
    /// Maximum iterations for Newton's method.
    max_iter: usize,
    /// Whether the model has been fitted.
    fitted: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}

impl PlattScaling {
    /// Create a new unfitted Platt scaler with default parameters.
    pub fn new() -> Self {
        Self {
            a: 0.0,
            b: 0.0,
            max_iter: 100,
            fitted: false,
            _schema_version: crate::version::SCHEMA_VERSION,
        }
    }

    /// Set the maximum number of Newton iterations (default: 100).
    pub fn max_iter(mut self, n: usize) -> Self {
        self.max_iter = n;
        self
    }

    /// The fitted A parameter (sigmoid slope).
    pub fn a(&self) -> f64 {
        self.a
    }

    /// The fitted B parameter (sigmoid intercept).
    pub fn b(&self) -> f64 {
        self.b
    }

    /// Fit the sigmoid parameters to decision values and binary labels.
    ///
    /// `decision_values`: raw classifier output (e.g. distance to hyperplane).
    /// `labels`: binary ground truth, each element must be 0.0 or 1.0.
    pub fn fit(&mut self, decision_values: &[f64], labels: &[f64]) -> Result<()> {
        let n = decision_values.len();
        if n != labels.len() {
            return Err(ScryLearnError::InvalidParameter(
                "decision_values and labels must have the same length".into(),
            ));
        }
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }

        // Count positives and negatives.
        let n_pos = labels.iter().filter(|&&y| y > 0.5).count();
        let n_neg = n - n_pos;
        if n_pos == 0 || n_neg == 0 {
            return Err(ScryLearnError::InvalidParameter(
                "labels must contain both positive and negative samples".into(),
            ));
        }

        // Bail early if all decision values are identical — the sigmoid
        // cannot separate anything and Newton will produce a singular Hessian.
        let dv_min = decision_values
            .iter()
            .copied()
            .fold(f64::INFINITY, f64::min);
        let dv_max = decision_values
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max);
        if (dv_max - dv_min).abs() < f64::EPSILON {
            return Err(ScryLearnError::InvalidData(
                "all decision values are identical — Platt scaling cannot calibrate".into(),
            ));
        }

        // Modified target values to avoid saturation (Platt 1999).
        let t_pos = (n_pos as f64 + 1.0) / (n_pos as f64 + 2.0);
        let t_neg = 1.0 / (n_neg as f64 + 2.0);

        let t: Vec<f64> = labels
            .iter()
            .map(|&y| if y > 0.5 { t_pos } else { t_neg })
            .collect();

        // Newton's method to minimize the negative log-likelihood:
        //   L = -sum[ t_i * log(p_i) + (1-t_i) * log(1-p_i) ]
        // where p_i = sigmoid(A*f_i + B) = 1 / (1 + exp(-(A*f_i + B))).
        let mut a = 0.0_f64;
        // Start B at the log-odds of the positive prior implied by the
        // modified targets; note the sign is flipped relative to Platt's
        // paper because of the sigmoid convention used here.
        let mut b = ((n_pos as f64 + 1.0) / (n_neg as f64 + 1.0)).ln();

        let min_step = crate::constants::PLATT_MIN_STEP;
        let sigma = crate::constants::PLATT_HESSIAN_REG;

        for _ in 0..self.max_iter {
            // Compute gradient and Hessian (the diagonal starts at `sigma`
            // as Tikhonov regularization, keeping H positive definite).
            let mut g1 = 0.0_f64; // dL/dA
            let mut g2 = 0.0_f64; // dL/dB
            let mut h11 = sigma; // d²L/dA²
            let mut h22 = sigma; // d²L/dB²
            let mut h21 = 0.0_f64; // d²L/dAdB

            for i in 0..n {
                let fval = decision_values[i] * a + b;
                let p = sigmoid(fval);
                let d = p - t[i];
                // Floor the full product p*(1-p) rather than only the (1-p)
                // factor, so saturated points cannot zero out the Hessian.
                let w = (p * (1.0 - p)).max(crate::constants::SINGULAR_THRESHOLD);
                let fi = decision_values[i];

                g1 += fi * d;
                g2 += d;
                h11 += fi * fi * w;
                h22 += w;
                h21 += fi * w;
            }

            // Solve the 2×2 system: H * [dA, dB]' = -[g1, g2]'
            let det = h11 * h22 - h21 * h21;
            if det.abs() < crate::constants::PLATT_SINGULAR_DET {
                return Err(ScryLearnError::ConvergenceFailure {
                    iterations: self.max_iter,
                    tolerance: crate::constants::PLATT_SINGULAR_DET,
                });
            }
            let da = -(h22 * g1 - h21 * g2) / det;
            let db = -(h11 * g2 - h21 * g1) / det;

            // Backtracking line search (Armijo condition) with step halving.
            let mut step = 1.0;
            let old_nll = neg_log_likelihood(decision_values, &t, a, b);
            loop {
                let new_a = a + step * da;
                let new_b = b + step * db;
                let new_nll = neg_log_likelihood(decision_values, &t, new_a, new_b);
                if new_nll < old_nll + crate::constants::ARMIJO_C * step * (g1 * da + g2 * db) {
                    a = new_a;
                    b = new_b;
                    break;
                }
                step *= 0.5;
                if step < min_step {
                    // Give up halving and accept the tiny residual step.
                    a += step * da;
                    b += step * db;
                    break;
                }
            }

            // Check convergence.
            if (da * step).abs() < crate::constants::PLATT_CONVERGENCE
                && (db * step).abs() < crate::constants::PLATT_CONVERGENCE
            {
                break;
            }
        }

        self.a = a;
        self.b = b;
        self.fitted = true;
        Ok(())
    }

    /// Transform decision values into calibrated probabilities.
    pub fn predict(&self, decision_values: &[f64]) -> Vec<f64> {
        decision_values
            .iter()
            .map(|&f| sigmoid(self.a * f + self.b))
            .collect()
    }
}

impl Default for PlattScaling {
    fn default() -> Self {
        Self::new()
    }
}

/// Numerically stable sigmoid: 1 / (1 + exp(-x)).
fn sigmoid(x: f64) -> f64 {
    if x >= 0.0 {
        // exp(-x) <= 1 here, so this branch cannot overflow.
        1.0 / (1.0 + (-x).exp())
    } else {
        // For x < 0, rewrite as exp(x) / (1 + exp(x)) so only a
        // non-positive argument is ever exponentiated.
        let ex = x.exp();
        ex / (1.0 + ex)
    }
}

/// Negative log-likelihood for Platt scaling.
fn neg_log_likelihood(f: &[f64], t: &[f64], a: f64, b: f64) -> f64 {
    let mut nll = 0.0;
    for i in 0..f.len() {
        let p = sigmoid(a * f[i] + b);
        // Clamp away from 0 and 1 so the logarithms stay finite.
        let p_clamped = p.clamp(
            crate::constants::NEAR_ZERO,
            1.0 - crate::constants::NEAR_ZERO,
        );
        nll -= t[i] * p_clamped.ln() + (1.0 - t[i]) * (1.0 - p_clamped).ln();
    }
    nll
}

// ---------------------------------------------------------------------------
// Isotonic Regression
// ---------------------------------------------------------------------------

/// Non-parametric calibration using isotonic (monotone) regression.
///
/// Uses the Pool Adjacent Violators (PAV) algorithm to fit a non-decreasing
/// step function to the data. Predictions use linear interpolation between
/// fitted values.
///
/// # Example
///
/// ```
/// use scry_learn::calibration::IsotonicRegression;
///
/// let mut iso = IsotonicRegression::new();
/// iso.fit(&[0.1, 0.4, 0.6, 0.9], &[0.0, 0.0, 1.0, 1.0]).unwrap();
/// let p = iso.predict(&[0.5]);
/// assert!(p[0] >= 0.0 && p[0] <= 1.0);
/// ```
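///
/// A small worked instance of the pooling step: the out-of-order pair
/// (0.6, 0.4) below is pooled to its mean 0.5, placed at the block's mean
/// x of 2.5 (values chosen purely for illustration):
///
/// ```
/// use scry_learn::calibration::IsotonicRegression;
///
/// let mut iso = IsotonicRegression::new();
/// iso.fit(&[1.0, 2.0, 3.0, 4.0], &[0.2, 0.6, 0.4, 0.8]).unwrap();
/// assert!((iso.predict(&[2.5])[0] - 0.5).abs() < 1e-9);
/// // Queries below the smallest fitted x clamp to the first fitted value.
/// assert!((iso.predict(&[0.0])[0] - 0.2).abs() < 1e-9);
/// ```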
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct IsotonicRegression {
    /// Fitted x-values (sorted).
    xs: Vec<f64>,
    /// Fitted y-values (non-decreasing, corresponding to xs).
    ys: Vec<f64>,
    /// Whether the model has been fitted.
    fitted: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}

impl IsotonicRegression {
    /// Create a new unfitted isotonic regression model.
    pub fn new() -> Self {
        Self {
            xs: Vec::new(),
            ys: Vec::new(),
            fitted: false,
            _schema_version: crate::version::SCHEMA_VERSION,
        }
    }

    /// Fit the isotonic regression to (x, y) pairs.
    ///
    /// `x`: predictor values (e.g. uncalibrated probabilities).
    /// `y`: response values (e.g. 0/1 labels or true probabilities).
    pub fn fit(&mut self, x: &[f64], y: &[f64]) -> Result<()> {
        let n = x.len();
        if n != y.len() {
            return Err(ScryLearnError::InvalidParameter(
                "x and y must have the same length".into(),
            ));
        }
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }

        // Sort by x.
        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&a, &b| x[a].partial_cmp(&x[b]).unwrap_or(std::cmp::Ordering::Equal));

        let sorted_x: Vec<f64> = indices.iter().map(|&i| x[i]).collect();
        let sorted_y: Vec<f64> = indices.iter().map(|&i| y[i]).collect();

        // Pool Adjacent Violators (PAV) algorithm.
        // Each block holds (sum_y, count); its mean is sum_y / count.
        let mut blocks: Vec<(f64, usize)> = Vec::with_capacity(n);

        for i in 0..n {
            blocks.push((sorted_y[i], 1));

            // Merge backwards while the last block violates monotonicity.
            while blocks.len() >= 2 {
                let len = blocks.len();
                let mean_last = blocks[len - 1].0 / blocks[len - 1].1 as f64;
                let mean_prev = blocks[len - 2].0 / blocks[len - 2].1 as f64;
                if mean_prev <= mean_last {
                    break;
                }
                // Merge — loop guard above ensures blocks.len() >= 2.
                let last = blocks.pop().expect("loop guard: blocks.len() >= 2");
                let prev = blocks.last_mut().expect("loop guard: blocks.len() >= 2");
                prev.0 += last.0;
                prev.1 += last.1;
            }
        }

        // Build the fitted (x, y) pairs: one pair per pooled block, using
        // the block's mean y and mean x as the representative point.
        let mut fit_x = Vec::with_capacity(blocks.len());
        let mut fit_y = Vec::with_capacity(blocks.len());

        let mut idx = 0;
        for &(sum_y, count) in &blocks {
            let mean_y = sum_y / count as f64;
            let block_start = idx;
            let block_end = idx + count;
            let mean_x: f64 = sorted_x[block_start..block_end].iter().sum::<f64>() / count as f64;
            fit_x.push(mean_x);
            fit_y.push(mean_y);
            idx = block_end;
        }

        self.xs = fit_x;
        self.ys = fit_y;
        self.fitted = true;
        Ok(())
    }

    /// Predict calibrated values using linear interpolation.
    pub fn predict(&self, x: &[f64]) -> Vec<f64> {
        x.iter().map(|&v| self.interpolate(v)).collect()
    }

    /// Linear interpolation between fitted points.
    fn interpolate(&self, x: f64) -> f64 {
        if self.xs.is_empty() {
            return 0.5;
        }
        if self.xs.len() == 1 {
            return self.ys[0];
        }

        // Clamp to range.
        if x <= self.xs[0] {
            return self.ys[0];
        }
        // xs.len() >= 2 is guaranteed by the early returns above.
        if x >= self.xs[self.xs.len() - 1] {
            return self.ys[self.ys.len() - 1];
        }

        // Binary search for the interval.
        let mut lo = 0;
        let mut hi = self.xs.len() - 1;
        while lo + 1 < hi {
            let mid = usize::midpoint(lo, hi);
            if self.xs[mid] <= x {
                lo = mid;
            } else {
                hi = mid;
            }
        }

        // Linear interpolation.
        let dx = self.xs[hi] - self.xs[lo];
        if dx.abs() < crate::constants::NEAR_ZERO {
            return f64::midpoint(self.ys[lo], self.ys[hi]);
        }
        let t = (x - self.xs[lo]) / dx;
        self.ys[lo] + t * (self.ys[hi] - self.ys[lo])
    }
}

impl Default for IsotonicRegression {
    fn default() -> Self {
        Self::new()
    }
}

// ---------------------------------------------------------------------------
// Calibration method enum
// ---------------------------------------------------------------------------

/// Calibration method for [`CalibratedClassifierCV`].
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum CalibrationMethod {
    /// Platt's sigmoid-based calibration.
    #[default]
    Sigmoid,
    /// Isotonic regression (non-parametric).
    Isotonic,
}

// ---------------------------------------------------------------------------
// CalibratedClassifierCV
// ---------------------------------------------------------------------------

/// A calibrated classifier wrapper.
///
/// Uses internal cross-validation to produce calibrated probability
/// estimates from any classifier that supports `predict_proba`.
///
/// During `fit`, the data is split into `n_folds` folds. For each fold, the
/// base classifier is trained on the training portion and produces
/// predictions on the held-out portion. The pooled out-of-fold predictions
/// are then used to fit one calibrator per class (one-vs-rest). At predict
/// time, the raw probabilities of the base classifier (refit on the full
/// dataset) are passed through the calibrators and renormalized.
///
/// # Example
///
/// ```ignore
/// use scry_learn::calibration::{CalibratedClassifierCV, CalibrationMethod};
/// use scry_learn::tree::DecisionTreeClassifier;
/// use scry_learn::dataset::Dataset;
///
/// let data = Dataset::from_csv("iris.csv", "species").unwrap();
/// let mut cal = CalibratedClassifierCV::new(
///     DecisionTreeClassifier::new(),
///     CalibrationMethod::Isotonic,
/// ).n_folds(5);
///
/// cal.fit(&data).unwrap();
/// let probs = cal.predict_proba(&data.feature_matrix()).unwrap();
/// ```
#[non_exhaustive]
pub struct CalibratedClassifierCV {
    /// The base classifier (boxed for heterogeneity).
    base: Box<dyn CalibrableClassifier>,
    /// Calibration method.
    method: CalibrationMethod,
    /// Number of cross-validation folds.
    n_folds: usize,
    /// Per-class calibrators (fitted during `.fit()`).
    calibrators: Vec<CalibratorKind>,
    /// Whether the model has been fitted.
    fitted: bool,
}

/// Internal enum wrapping a fitted calibrator.
enum CalibratorKind {
    Platt(PlattScaling),
    Isotonic(IsotonicRegression),
}

impl CalibratorKind {
    fn predict(&self, values: &[f64]) -> Vec<f64> {
        match self {
            Self::Platt(p) => p.predict(values),
            Self::Isotonic(iso) => iso.predict(values),
        }
    }
}

/// Trait for classifiers that can be calibrated.
///
/// Implemented by any classifier with `fit`, `predict`, and `predict_proba`
/// methods, where `predict_proba` returns `Vec<Vec<f64>>`.
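///
/// A minimal sketch of implementing the trait for a custom model (the
/// `ConstantClassifier` stub below is purely illustrative, not a type in
/// this crate):
///
/// ```ignore
/// use scry_learn::calibration::CalibrableClassifier;
/// use scry_learn::dataset::Dataset;
/// use scry_learn::error::Result;
///
/// #[derive(Clone)]
/// struct ConstantClassifier;
///
/// impl CalibrableClassifier for ConstantClassifier {
///     fn fit(&mut self, _data: &Dataset) -> Result<()> {
///         Ok(())
///     }
///     fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
///         // Always predicts class 0.
///         Ok(vec![0.0; features.len()])
///     }
///     fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
///         Ok(vec![vec![1.0, 0.0]; features.len()])
///     }
///     fn clone_box(&self) -> Box<dyn CalibrableClassifier> {
///         Box::new(self.clone())
///     }
/// }
/// ```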
pub trait CalibrableClassifier {
    /// Train on a dataset.
    fn fit(&mut self, data: &crate::dataset::Dataset) -> Result<()>;
    /// Predict class labels.
    fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>>;
    /// Predict class probabilities. Returns `[n_samples][n_classes]`.
    fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>>;
    /// Clone into a boxed trait object.
    fn clone_box(&self) -> Box<dyn CalibrableClassifier>;
}

// Implement CalibrableClassifier for common classifiers. Inherent methods
// take precedence over trait methods inside these impls, so the calls
// below resolve to the classifiers' own `fit`/`predict` and do not recurse.
macro_rules! impl_calibrable {
    ($($ty:ty),* $(,)?) => {
        $(
            impl CalibrableClassifier for $ty {
                fn fit(&mut self, data: &crate::dataset::Dataset) -> Result<()> {
                    self.fit(data)
                }
                fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
                    self.predict(features)
                }
                fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
                    self.predict_proba(features)
                }
                fn clone_box(&self) -> Box<dyn CalibrableClassifier> {
                    Box::new(self.clone())
                }
            }
        )*
    };
}

impl_calibrable! {
    crate::tree::DecisionTreeClassifier,
    crate::tree::RandomForestClassifier,
    crate::tree::GradientBoostingClassifier,
    crate::tree::HistGradientBoostingClassifier,
    crate::linear::LogisticRegression,
    crate::naive_bayes::GaussianNb,
    crate::naive_bayes::BernoulliNB,
    crate::naive_bayes::MultinomialNB,
    crate::svm::LinearSVC,
    crate::neighbors::KnnClassifier,
}

#[cfg(feature = "experimental")]
impl_calibrable! {
    crate::svm::KernelSVC,
}

impl CalibratedClassifierCV {
    /// Create a new calibrated classifier wrapper.
    pub fn new<C: CalibrableClassifier + 'static>(
        classifier: C,
        method: CalibrationMethod,
    ) -> Self {
        Self {
            base: Box::new(classifier),
            method,
            n_folds: 5,
            calibrators: Vec::new(),
            fitted: false,
        }
    }

    /// Set the number of cross-validation folds (default: 5).
    pub fn n_folds(mut self, n: usize) -> Self {
        self.n_folds = n;
        self
    }

    /// Fit the calibrated classifier.
    ///
    /// 1. Splits data into `n_folds` deterministic, interleaved folds.
    /// 2. For each fold, trains a clone of the base classifier on the
    ///    training portion and collects `predict_proba` on the held-out portion.
    /// 3. Fits per-class calibrators on the aggregated out-of-fold predictions.
    /// 4. Re-trains the base classifier on the full dataset.
    pub fn fit(&mut self, data: &crate::dataset::Dataset) -> Result<()> {
        let n = data.n_samples();
        if n < self.n_folds {
            return Err(ScryLearnError::InvalidParameter(format!(
                "n_folds ({}) must be ≤ n_samples ({n})",
                self.n_folds
            )));
        }

        let features = data.feature_matrix(); // row-major
        let targets = &data.target;

        // Determine the number of classes, assuming targets are integer
        // class ids starting at 0, encoded as f64.
        let n_classes = {
            let mut max_class = 0usize;
            for &t in targets {
                let c = t as usize;
                if c > max_class {
                    max_class = c;
                }
            }
            max_class + 1
        };

        // Hold-out predictions: proba[i][c] for sample i, class c.
        let mut oof_proba: Vec<Vec<f64>> = vec![vec![0.0; n_classes]; n];

        // Simple k-fold splitting: deterministic round-robin assignment
        // (not stratified).
        let fold_indices = k_fold_indices(n, self.n_folds);

        for fold in 0..self.n_folds {
            let val_mask = &fold_indices[fold];
            let train_indices: Vec<usize> = (0..n).filter(|i| !val_mask.contains(i)).collect();
            let val_indices: Vec<usize> = val_mask.clone();

            // Build the training dataset (features are stored column-major).
            let train_features: Vec<Vec<f64>> = data
                .features
                .iter()
                .map(|col| train_indices.iter().map(|&i| col[i]).collect())
                .collect();
            let train_target: Vec<f64> = train_indices.iter().map(|&i| targets[i]).collect();
            let train_data = crate::dataset::Dataset::new(
                train_features,
                train_target,
                data.feature_names.clone(),
                &data.target_name,
            );

            // Train a clone of the base classifier.
            let mut clf = self.base.clone_box();
            clf.fit(&train_data)?;

            // Predict probabilities on the validation fold.
            let val_features: Vec<Vec<f64>> =
                val_indices.iter().map(|&i| features[i].clone()).collect();

            let proba = clf.predict_proba(&val_features)?;

            // Store out-of-fold predictions.
            for (j, &val_idx) in val_indices.iter().enumerate() {
                if j < proba.len() {
                    for c in 0..n_classes.min(proba[j].len()) {
                        oof_proba[val_idx][c] = proba[j][c];
                    }
                }
            }
        }

        // Fit per-class calibrators using OOF predictions.
        self.calibrators = Vec::with_capacity(n_classes);
        for c in 0..n_classes {
            let proba_c: Vec<f64> = oof_proba.iter().map(|p| p[c]).collect();
            let labels_c: Vec<f64> = targets
                .iter()
                .map(|&t| if (t as usize) == c { 1.0 } else { 0.0 })
                .collect();

            let cal = match &self.method {
                CalibrationMethod::Sigmoid => {
                    let mut platt = PlattScaling::new();
                    platt.fit(&proba_c, &labels_c)?;
                    CalibratorKind::Platt(platt)
                }
                CalibrationMethod::Isotonic => {
                    let mut iso = IsotonicRegression::new();
                    iso.fit(&proba_c, &labels_c)?;
                    CalibratorKind::Isotonic(iso)
                }
            };
            self.calibrators.push(cal);
        }

        // Re-train the base classifier on the full dataset.
        self.base.fit(data)?;
        self.fitted = true;
        Ok(())
    }

    /// Predict calibrated probabilities.
    ///
    /// Returns `[n_samples][n_classes]` with calibrated probabilities
    /// that sum to 1 for each sample.
    pub fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }

        // Get raw probabilities from the base classifier.
        let raw = self.base.predict_proba(features)?;
        let n_classes = self.calibrators.len();

        // Calibrate each class independently, then normalize (per-class
        // calibration does not preserve the sum-to-one constraint).
        let mut result = Vec::with_capacity(raw.len());
        for row in &raw {
            let mut calibrated = Vec::with_capacity(n_classes);
            for c in 0..n_classes {
                let raw_p = if c < row.len() { row[c] } else { 0.0 };
                let cal_p = self.calibrators[c].predict(&[raw_p])[0];
                calibrated.push(cal_p.max(0.0));
            }

            // Normalize to sum to 1.
            let sum: f64 = calibrated.iter().sum();
            if sum > 0.0 {
                for p in &mut calibrated {
                    *p /= sum;
                }
            } else {
                // Uniform fallback.
                let uniform = 1.0 / n_classes as f64;
                calibrated.fill(uniform);
            }
            result.push(calibrated);
        }
        Ok(result)
    }

    /// Predict class labels (argmax of calibrated probabilities).
    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        let proba = self.predict_proba(features)?;
        Ok(proba
            .iter()
            .map(|row| {
                row.iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                    .map_or(0.0, |(i, _)| i as f64)
            })
            .collect())
    }
}

/// Simple k-fold index splitting: returns `k` vecs of sample indices,
/// assigning sample `i` to fold `i % k` (deterministic, not shuffled).
fn k_fold_indices(n: usize, k: usize) -> Vec<Vec<usize>> {
    let mut folds: Vec<Vec<usize>> = (0..k).map(|_| Vec::new()).collect();
    for i in 0..n {
        folds[i % k].push(i);
    }
    folds
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_platt_basic_separation() {
        let mut platt = PlattScaling::new();
        let dv = vec![3.0, 2.0, 1.0, 0.5, -0.5, -1.0, -2.0, -3.0];
        let labels = vec![1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0];
        platt.fit(&dv, &labels).unwrap();

        let probs = platt.predict(&[2.0, -2.0]);
        assert!(
            probs[0] > 0.7,
            "positive should have high prob: {}",
            probs[0]
        );
        assert!(
            probs[1] < 0.3,
            "negative should have low prob: {}",
            probs[1]
        );
    }

    #[test]
    fn test_platt_monotone() {
        let mut platt = PlattScaling::new();
        let dv: Vec<f64> = (-10..=10).map(|x| x as f64).collect();
        let labels: Vec<f64> = dv
            .iter()
            .map(|&x| if x >= 0.0 { 1.0 } else { 0.0 })
            .collect();
        platt.fit(&dv, &labels).unwrap();

        let test_vals = vec![-5.0, -2.0, 0.0, 2.0, 5.0];
        let probs = platt.predict(&test_vals);
        for w in probs.windows(2) {
            assert!(
                w[1] >= w[0] - 1e-6,
                "probabilities should be monotone: {} < {}",
                w[0],
                w[1]
            );
        }
    }

    #[test]
    fn test_isotonic_monotone_output() {
        let mut iso = IsotonicRegression::new();
        let x = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9];
        let y = vec![0.0, 0.0, 0.3, 0.2, 0.5, 0.4, 0.8, 0.9, 1.0];
        iso.fit(&x, &y).unwrap();

        let pred = iso.predict(&x);
        for w in pred.windows(2) {
            assert!(
                w[1] >= w[0] - 1e-10,
                "isotonic output must be non-decreasing: {} > {}",
                w[0],
                w[1]
            );
        }
    }

    #[test]
    fn test_isotonic_perfect_data() {
        let mut iso = IsotonicRegression::new();
        let x = vec![0.0, 0.25, 0.5, 0.75, 1.0];
        let y = vec![0.0, 0.25, 0.5, 0.75, 1.0];
        iso.fit(&x, &y).unwrap();

        let pred = iso.predict(&[0.0, 0.5, 1.0]);
        assert!((pred[0] - 0.0).abs() < 0.05);
        assert!((pred[1] - 0.5).abs() < 0.05);
        assert!((pred[2] - 1.0).abs() < 0.05);
    }

    #[test]
    fn test_isotonic_clamp_extrapolation() {
        let mut iso = IsotonicRegression::new();
        iso.fit(&[0.2, 0.5, 0.8], &[0.1, 0.5, 0.9]).unwrap();

        let pred = iso.predict(&[0.0, 1.0]);
        // Should clamp to boundary values.
        assert!((pred[0] - 0.1).abs() < 1e-6);
        assert!((pred[1] - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_calibrated_classifier_cv_smoke() {
        use crate::dataset::Dataset;
        use crate::tree::DecisionTreeClassifier;

        let features = vec![
            vec![0.0, 0.5, 1.0, 1.5, 5.0, 5.5, 6.0, 6.5],
            vec![0.0, 0.5, 1.0, 1.5, 5.0, 5.5, 6.0, 6.5],
        ];
        let target = vec![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0];
        let data = Dataset::new(features, target, vec!["x".into(), "y".into()], "class");

        let mut cal =
            CalibratedClassifierCV::new(DecisionTreeClassifier::new(), CalibrationMethod::Isotonic)
                .n_folds(2);

        cal.fit(&data).unwrap();
        let proba = cal.predict_proba(&data.feature_matrix()).unwrap();

        assert_eq!(proba.len(), 8);
        for row in &proba {
            assert_eq!(row.len(), 2);
            let sum: f64 = row.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "probabilities should sum to 1, got {sum}"
            );
        }

        let preds = cal.predict(&data.feature_matrix()).unwrap();
        assert_eq!(preds.len(), 8);
    }

    #[test]
    fn test_calibrated_classifier_cv_sigmoid() {
        use crate::dataset::Dataset;
        use crate::tree::DecisionTreeClassifier;

        let features = vec![
            vec![0.0, 0.5, 1.0, 1.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5],
            vec![0.0, 0.5, 1.0, 1.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5],
        ];
        let target = vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let data = Dataset::new(features, target, vec!["x".into(), "y".into()], "class");

        let mut cal =
            CalibratedClassifierCV::new(DecisionTreeClassifier::new(), CalibrationMethod::Sigmoid)
                .n_folds(2);

        cal.fit(&data).unwrap();
        let proba = cal.predict_proba(&data.feature_matrix()).unwrap();

        // All probabilities should be valid.
        for row in &proba {
            for &p in row {
                assert!((0.0..=1.0).contains(&p), "prob out of range: {p}");
            }
        }
    }

    #[test]
    fn test_platt_constant_decision_values() {
        let mut platt = PlattScaling::new();
        let dv = vec![1.0; 10];
        let labels = vec![1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let result = platt.fit(&dv, &labels);
        assert!(result.is_err(), "constant decision values should fail");
        let err = result.unwrap_err();
        assert!(
            matches!(err, ScryLearnError::InvalidData(_)),
            "expected InvalidData, got {err:?}"
        );
    }

    #[test]
    fn test_platt_near_singular() {
        let mut platt = PlattScaling::new().max_iter(5);
        // Decision values with near-zero variance — likely to produce singular Hessian.
        let dv = vec![1.0, 1.0 + 1e-18, 1.0 - 1e-18, 1.0, 1.0 + 1e-18, 1.0 - 1e-18];
        let labels = vec![1.0, 1.0, 1.0, 0.0, 0.0, 0.0];
        let result = platt.fit(&dv, &labels);
        // This should either hit the InvalidData check (values within EPSILON)
        // or the ConvergenceFailure from singular det.
        assert!(result.is_err(), "near-singular should fail");
        let err = result.unwrap_err();
        assert!(
            matches!(
                err,
                ScryLearnError::InvalidData(_) | ScryLearnError::ConvergenceFailure { .. }
            ),
            "expected InvalidData or ConvergenceFailure, got {err:?}"
        );
    }
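
    // Sanity sketches for the private helpers above. The data and tolerances
    // below are illustrative assumptions, not documented guarantees.

    #[test]
    fn test_platt_fit_reduces_nll() {
        // On balanced labels the Newton solver starts at (a, b) = (0, 0), so
        // fitting should not increase the objective on the modified targets.
        let dv = vec![2.0, 1.0, 0.5, -0.5, -1.0, -2.0];
        let labels = vec![1.0, 1.0, 1.0, 0.0, 0.0, 0.0];
        let mut platt = PlattScaling::new();
        platt.fit(&dv, &labels).unwrap();

        let t_pos = (3.0 + 1.0) / (3.0 + 2.0);
        let t_neg = 1.0 / (3.0 + 2.0);
        let t: Vec<f64> = labels
            .iter()
            .map(|&y| if y > 0.5 { t_pos } else { t_neg })
            .collect();
        let fitted = neg_log_likelihood(&dv, &t, platt.a(), platt.b());
        let flat = neg_log_likelihood(&dv, &t, 0.0, 0.0);
        assert!(fitted <= flat + 1e-9, "fit should reduce NLL: {fitted} vs {flat}");
    }

    #[test]
    fn test_sigmoid_numerically_stable() {
        // Large-magnitude inputs must stay finite and inside [0, 1].
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-12);
        assert!(sigmoid(1000.0).is_finite() && sigmoid(1000.0) > 0.999);
        assert!(sigmoid(-1000.0).is_finite() && sigmoid(-1000.0) < 1e-3);
        // Symmetry: sigmoid(x) + sigmoid(-x) == 1 (up to rounding).
        for &x in &[0.5, 2.0, 10.0] {
            assert!((sigmoid(x) + sigmoid(-x) - 1.0).abs() < 1e-12);
        }
    }

    #[test]
    fn test_k_fold_indices_partition() {
        // Round-robin assignment: every sample lands in exactly one fold,
        // and fold sizes differ by at most one.
        let folds = k_fold_indices(10, 3);
        assert_eq!(folds.len(), 3);
        let mut all: Vec<usize> = folds.iter().flatten().copied().collect();
        all.sort_unstable();
        assert_eq!(all, (0..10).collect::<Vec<usize>>());
        let sizes: Vec<usize> = folds.iter().map(|f| f.len()).collect();
        assert_eq!(sizes, vec![4, 3, 3]);
    }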
}