scirs2_metrics/visualization/
calibration.rs

1//! Calibration curve visualization
2//!
3//! This module provides tools for visualizing calibration curves (reliability diagrams).
4
5use scirs2_core::ndarray::{ArrayBase, Data, Ix1};
6use std::error::Error;
7
8use super::{MetricVisualizer, PlotType, VisualizationData, VisualizationMetadata};
9use crate::classification::curves::calibration_curve;
10use crate::error::{MetricsError, Result};
11
/// Calibration curve visualizer
///
/// This struct provides methods for visualizing calibration curves (reliability diagrams).
/// It can be built two ways: from pre-computed curve data (`new`), in which case
/// `fraction_of_positives`/`mean_predicted_value` are populated, or from raw labels and
/// probabilities (`from_labels`), in which case `y_true`/`y_prob` are stored and the
/// curve is computed lazily on first use.
#[derive(Debug, Clone)]
#[allow(dead_code)] // some fields (e.g. `strategy`, `pos_label`) are stored but not yet read here
pub struct CalibrationVisualizer<'a, T, S>
where
    T: Clone + PartialOrd,
    S: Data<Elem = T>,
{
    /// Fraction of positive samples in each bin (empirical probabilities);
    /// `None` until the curve has been computed or supplied up-front.
    fraction_of_positives: Option<Vec<f64>>,
    /// Mean predicted probability in each bin; paired element-wise with
    /// `fraction_of_positives`.
    mean_predicted_value: Option<Vec<f64>>,
    /// Number of bins
    n_bins: usize,
    /// Strategy for binning ("uniform" or "quantile")
    strategy: String,
    /// Title for the plot
    title: String,
    /// Whether to display the perfect calibration line (the y = x diagonal)
    show_perfectly_calibrated: bool,
    /// Original y_true data (borrowed; only set by `from_labels`)
    y_true: Option<&'a ArrayBase<S, Ix1>>,
    /// Original y_prob data (borrowed; only set by `from_labels`)
    y_prob: Option<&'a ArrayBase<S, Ix1>>,
    /// Class label for multi-class calibration
    pos_label: Option<T>,
}
41
42impl<'a, T, S> CalibrationVisualizer<'a, T, S>
43where
44    T: Clone + PartialOrd + 'static,
45    S: Data<Elem = T>,
46    f64: From<T>,
47{
48    /// Create a new CalibrationVisualizer from pre-computed calibration curve data
49    ///
50    /// # Arguments
51    ///
52    /// * `fraction_of_positives` - Fraction of positive samples in each bin
53    /// * `mean_predicted_value` - Mean predicted probability in each bin
54    /// * `n_bins` - Number of bins used
55    /// * `strategy` - Strategy used for binning ("uniform" or "quantile")
56    ///
57    /// # Returns
58    ///
59    /// * A new CalibrationVisualizer
60    pub fn new(
61        fraction_of_positives: Vec<f64>,
62        mean_predicted_value: Vec<f64>,
63        n_bins: usize,
64        strategy: String,
65    ) -> Self {
66        CalibrationVisualizer {
67            fraction_of_positives: Some(fraction_of_positives),
68            mean_predicted_value: Some(mean_predicted_value),
69            n_bins,
70            strategy,
71            title: "Calibration Curve".to_string(),
72            show_perfectly_calibrated: true,
73            y_true: None,
74            y_prob: None,
75            pos_label: None,
76        }
77    }
78
79    /// Create a CalibrationVisualizer from true labels and probabilities
80    ///
81    /// # Arguments
82    ///
83    /// * `y_true` - True binary labels
84    /// * `y_prob` - Predicted probabilities
85    /// * `n_bins` - Number of bins to use
86    /// * `strategy` - Strategy for binning
87    /// * `pos_label` - Label of the positive class
88    ///
89    /// # Returns
90    ///
91    /// * A new CalibrationVisualizer
92    pub fn from_labels(
93        y_true: &'a ArrayBase<S, Ix1>,
94        y_prob: &'a ArrayBase<S, Ix1>,
95        n_bins: usize,
96        strategy: String,
97        pos_label: Option<T>,
98    ) -> Self {
99        CalibrationVisualizer {
100            fraction_of_positives: None,
101            mean_predicted_value: None,
102            n_bins,
103            strategy,
104            title: "Calibration Curve".to_string(),
105            show_perfectly_calibrated: true,
106            y_true: Some(y_true),
107            y_prob: Some(y_prob),
108            pos_label,
109        }
110    }
111
112    /// Set the title for the plot
113    ///
114    /// # Arguments
115    ///
116    /// * `title` - Title for the plot
117    ///
118    /// # Returns
119    ///
120    /// * Self for method chaining
121    pub fn with_title(mut self, title: String) -> Self {
122        self.title = title;
123        self
124    }
125
126    /// Set whether to display the perfect calibration line
127    ///
128    /// # Arguments
129    ///
130    /// * `show_perfectly_calibrated` - Whether to display the perfect calibration line
131    ///
132    /// # Returns
133    ///
134    /// * Self for method chaining
135    pub fn with_show_perfectly_calibrated(mut self, show_perfectlycalibrated: bool) -> Self {
136        self.show_perfectly_calibrated = show_perfectlycalibrated;
137        self
138    }
139
140    /// Compute the calibration curve if not already computed
141    ///
142    /// # Returns
143    ///
144    /// * Result containing (fraction_of_positives, mean_predicted_value)
145    fn compute_calibration(&self) -> Result<(Vec<f64>, Vec<f64>)> {
146        if self.fraction_of_positives.is_some() && self.mean_predicted_value.is_some() {
147            // Return pre-computed values
148            return Ok((
149                self.fraction_of_positives.clone().unwrap(),
150                self.mean_predicted_value.clone().unwrap(),
151            ));
152        }
153
154        if self.y_true.is_none() || self.y_prob.is_none() {
155            return Err(MetricsError::InvalidInput(
156                "No data provided for calibration curve computation".to_string(),
157            ));
158        }
159
160        let y_true = self.y_true.unwrap();
161        let y_prob = self.y_prob.unwrap();
162
163        // Compute calibration curve
164        let calib_result = calibration_curve(y_true, y_prob, Some(self.n_bins))?;
165
166        Ok((calib_result.0.to_vec(), calib_result.1.to_vec()))
167    }
168}
169
170impl<T, S> MetricVisualizer for CalibrationVisualizer<'_, T, S>
171where
172    T: Clone + PartialOrd + 'static,
173    S: Data<Elem = T>,
174    f64: From<T>,
175{
176    fn prepare_data(&self) -> std::result::Result<VisualizationData, Box<dyn Error>> {
177        let (fraction_of_positives, mean_predicted_value) = self
178            .compute_calibration()
179            .map_err(|e| Box::new(e) as Box<dyn Error>)?;
180
181        // Prepare data for visualization
182        let mut x_values = mean_predicted_value;
183        let mut y_values = fraction_of_positives;
184
185        // Prepare series names
186        let mut series_names = Vec::new();
187
188        series_names.push(format!("Calibration curve (bins={})", self.n_bins));
189
190        // Add perfect calibration line if requested
191        if self.show_perfectly_calibrated {
192            // Add the perfect calibration line (y = x)
193            series_names.push("Perfectly calibrated".to_string());
194
195            // Add points for the perfect calibration line
196            x_values.push(0.0);
197            x_values.push(1.0);
198            y_values.push(0.0);
199            y_values.push(1.0);
200        }
201
202        Ok(VisualizationData {
203            x: x_values,
204            y: y_values,
205            z: None,
206            series_names: Some(series_names),
207            x_labels: None,
208            y_labels: None,
209            auxiliary_data: std::collections::HashMap::new(),
210            auxiliary_metadata: std::collections::HashMap::new(),
211            series: std::collections::HashMap::new(),
212        })
213    }
214
215    fn get_metadata(&self) -> VisualizationMetadata {
216        VisualizationMetadata {
217            title: self.title.clone(),
218            x_label: "Mean predicted probability".to_string(),
219            y_label: "Fraction of positives".to_string(),
220            plot_type: PlotType::Line,
221            description: Some("Calibration curve (reliability diagram) showing the relationship between predicted probabilities and the actual fraction of positive samples".to_string()),
222        }
223    }
224}
225
226/// Create a calibration curve visualization from pre-computed calibration curve data
227///
228/// # Arguments
229///
230/// * `fraction_of_positives` - Fraction of positive samples in each bin
231/// * `mean_predicted_value` - Mean predicted probability in each bin
232/// * `n_bins` - Number of bins used
233/// * `strategy` - Strategy used for binning ("uniform" or "quantile")
234///
235/// # Returns
236///
237/// * A CalibrationVisualizer
238#[allow(dead_code)]
239pub fn calibration_visualization(
240    fraction_of_positives: Vec<f64>,
241    mean_predicted_value: Vec<f64>,
242    n_bins: usize,
243    strategy: String,
244) -> CalibrationVisualizer<'static, f64, scirs2_core::ndarray::OwnedRepr<f64>> {
245    CalibrationVisualizer::new(
246        fraction_of_positives,
247        mean_predicted_value,
248        n_bins,
249        strategy,
250    )
251}
252
253/// Create a calibration curve visualization from true labels and probabilities
254///
255/// # Arguments
256///
257/// * `y_true` - True binary labels
258/// * `y_prob` - Predicted probabilities
259/// * `n_bins` - Number of bins to use
260/// * `strategy` - Strategy for binning
261/// * `pos_label` - Optional label of the positive class
262///
263/// # Returns
264///
265/// * A CalibrationVisualizer
266#[allow(dead_code)]
267pub fn calibration_from_labels<'a, T, S>(
268    y_true: &'a ArrayBase<S, Ix1>,
269    y_prob: &'a ArrayBase<S, Ix1>,
270    n_bins: usize,
271    strategy: &str,
272    pos_label: Option<T>,
273) -> CalibrationVisualizer<'a, T, S>
274where
275    T: Clone + PartialOrd + 'static,
276    S: Data<Elem = T>,
277    f64: From<T>,
278{
279    CalibrationVisualizer::from_labels(y_true, y_prob, n_bins, strategy.to_string(), pos_label)
280}