scirs2_metrics/visualization/
calibration.rs1use scirs2_core::ndarray::{ArrayBase, Data, Ix1};
6use std::error::Error;
7
8use super::{MetricVisualizer, PlotType, VisualizationData, VisualizationMetadata};
9use crate::classification::curves::calibration_curve;
10use crate::error::{MetricsError, Result};
11
12#[derive(Debug, Clone)]
16#[allow(dead_code)]
17pub struct CalibrationVisualizer<'a, T, S>
18where
19 T: Clone + PartialOrd,
20 S: Data<Elem = T>,
21{
22 fraction_of_positives: Option<Vec<f64>>,
24 mean_predicted_value: Option<Vec<f64>>,
26 n_bins: usize,
28 strategy: String,
30 title: String,
32 show_perfectly_calibrated: bool,
34 y_true: Option<&'a ArrayBase<S, Ix1>>,
36 y_prob: Option<&'a ArrayBase<S, Ix1>>,
38 pos_label: Option<T>,
40}
41
42impl<'a, T, S> CalibrationVisualizer<'a, T, S>
43where
44 T: Clone + PartialOrd + 'static,
45 S: Data<Elem = T>,
46 f64: From<T>,
47{
48 pub fn new(
61 fraction_of_positives: Vec<f64>,
62 mean_predicted_value: Vec<f64>,
63 n_bins: usize,
64 strategy: String,
65 ) -> Self {
66 CalibrationVisualizer {
67 fraction_of_positives: Some(fraction_of_positives),
68 mean_predicted_value: Some(mean_predicted_value),
69 n_bins,
70 strategy,
71 title: "Calibration Curve".to_string(),
72 show_perfectly_calibrated: true,
73 y_true: None,
74 y_prob: None,
75 pos_label: None,
76 }
77 }
78
79 pub fn from_labels(
93 y_true: &'a ArrayBase<S, Ix1>,
94 y_prob: &'a ArrayBase<S, Ix1>,
95 n_bins: usize,
96 strategy: String,
97 pos_label: Option<T>,
98 ) -> Self {
99 CalibrationVisualizer {
100 fraction_of_positives: None,
101 mean_predicted_value: None,
102 n_bins,
103 strategy,
104 title: "Calibration Curve".to_string(),
105 show_perfectly_calibrated: true,
106 y_true: Some(y_true),
107 y_prob: Some(y_prob),
108 pos_label,
109 }
110 }
111
112 pub fn with_title(mut self, title: String) -> Self {
122 self.title = title;
123 self
124 }
125
126 pub fn with_show_perfectly_calibrated(mut self, show_perfectlycalibrated: bool) -> Self {
136 self.show_perfectly_calibrated = show_perfectlycalibrated;
137 self
138 }
139
140 fn compute_calibration(&self) -> Result<(Vec<f64>, Vec<f64>)> {
146 if self.fraction_of_positives.is_some() && self.mean_predicted_value.is_some() {
147 return Ok((
149 self.fraction_of_positives.clone().unwrap(),
150 self.mean_predicted_value.clone().unwrap(),
151 ));
152 }
153
154 if self.y_true.is_none() || self.y_prob.is_none() {
155 return Err(MetricsError::InvalidInput(
156 "No data provided for calibration curve computation".to_string(),
157 ));
158 }
159
160 let y_true = self.y_true.unwrap();
161 let y_prob = self.y_prob.unwrap();
162
163 let calib_result = calibration_curve(y_true, y_prob, Some(self.n_bins))?;
165
166 Ok((calib_result.0.to_vec(), calib_result.1.to_vec()))
167 }
168}
169
170impl<T, S> MetricVisualizer for CalibrationVisualizer<'_, T, S>
171where
172 T: Clone + PartialOrd + 'static,
173 S: Data<Elem = T>,
174 f64: From<T>,
175{
176 fn prepare_data(&self) -> std::result::Result<VisualizationData, Box<dyn Error>> {
177 let (fraction_of_positives, mean_predicted_value) = self
178 .compute_calibration()
179 .map_err(|e| Box::new(e) as Box<dyn Error>)?;
180
181 let mut x_values = mean_predicted_value;
183 let mut y_values = fraction_of_positives;
184
185 let mut series_names = Vec::new();
187
188 series_names.push(format!("Calibration curve (bins={})", self.n_bins));
189
190 if self.show_perfectly_calibrated {
192 series_names.push("Perfectly calibrated".to_string());
194
195 x_values.push(0.0);
197 x_values.push(1.0);
198 y_values.push(0.0);
199 y_values.push(1.0);
200 }
201
202 Ok(VisualizationData {
203 x: x_values,
204 y: y_values,
205 z: None,
206 series_names: Some(series_names),
207 x_labels: None,
208 y_labels: None,
209 auxiliary_data: std::collections::HashMap::new(),
210 auxiliary_metadata: std::collections::HashMap::new(),
211 series: std::collections::HashMap::new(),
212 })
213 }
214
215 fn get_metadata(&self) -> VisualizationMetadata {
216 VisualizationMetadata {
217 title: self.title.clone(),
218 x_label: "Mean predicted probability".to_string(),
219 y_label: "Fraction of positives".to_string(),
220 plot_type: PlotType::Line,
221 description: Some("Calibration curve (reliability diagram) showing the relationship between predicted probabilities and the actual fraction of positive samples".to_string()),
222 }
223 }
224}
225
226#[allow(dead_code)]
239pub fn calibration_visualization(
240 fraction_of_positives: Vec<f64>,
241 mean_predicted_value: Vec<f64>,
242 n_bins: usize,
243 strategy: String,
244) -> CalibrationVisualizer<'static, f64, scirs2_core::ndarray::OwnedRepr<f64>> {
245 CalibrationVisualizer::new(
246 fraction_of_positives,
247 mean_predicted_value,
248 n_bins,
249 strategy,
250 )
251}
252
253#[allow(dead_code)]
267pub fn calibration_from_labels<'a, T, S>(
268 y_true: &'a ArrayBase<S, Ix1>,
269 y_prob: &'a ArrayBase<S, Ix1>,
270 n_bins: usize,
271 strategy: &str,
272 pos_label: Option<T>,
273) -> CalibrationVisualizer<'a, T, S>
274where
275 T: Clone + PartialOrd + 'static,
276 S: Data<Elem = T>,
277 f64: From<T>,
278{
279 CalibrationVisualizer::from_labels(y_true, y_prob, n_bins, strategy.to_string(), pos_label)
280}