// sphereql_embed/confidence.rs

//! Unified confidence scoring for query results.
//!
//! Combines three independent quality signals into a single score:
//! - **EVR**: global projection quality
//! - **Certainty**: per-point projection fidelity
//! - **Gap confidence**: proximity to category caps (points in the voids between caps get low confidence)
//!
//! The combined score is the product of the three, bucketed into a discrete confidence level.
7
8use sphereql_core::SphericalPoint;
9
10/// Confidence assessment for a single query result or point on S².
11#[derive(Debug, Clone, Copy)]
12pub struct QualitySignal {
13    pub evr: f64,
14    pub certainty: f64,
15    pub void_distance: f64,
16    pub gap_confidence: f64,
17    /// evr × certainty × gap_confidence
18    pub combined: f64,
19    pub level: ConfidenceLevel,
20}
21
/// Discrete confidence bucket for a [`QualitySignal`].
///
/// Variants are declared from least to most confident. The derived
/// `PartialOrd`/`Ord` depend on that order, giving
/// `Unreliable < Low < Moderate < High`, so do not reorder them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum ConfidenceLevel {
    /// Combined score at or below 0.005 (NaN scores also land here).
    Unreliable,
    /// Combined score in (0.005, 0.03].
    Low,
    /// Combined score in (0.03, 0.10].
    Moderate,
    /// Combined score above 0.10.
    High,
}
29
30impl std::fmt::Display for ConfidenceLevel {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        match self {
33            Self::High => write!(f, "HIGH"),
34            Self::Moderate => write!(f, "MODERATE"),
35            Self::Low => write!(f, "LOW"),
36            Self::Unreliable => write!(f, "UNRELIABLE"),
37        }
38    }
39}
40
41impl QualitySignal {
42    pub fn compute(evr: f64, certainty: f64, void_dist: f64, sharpness: f64) -> Self {
43        let gap_confidence = 1.0 / (1.0 + (sharpness * void_dist).exp());
44        let combined = evr * certainty * gap_confidence;
45        let level = classify(combined);
46        Self {
47            evr,
48            certainty,
49            void_distance: void_dist,
50            gap_confidence,
51            combined,
52            level,
53        }
54    }
55
56    /// Simplified: no void distance available (e.g., raw k-NN results).
57    pub fn from_certainty(evr: f64, certainty: f64) -> Self {
58        let gap_confidence = certainty.sqrt().max(0.01);
59        let combined = evr * certainty * gap_confidence;
60        let level = classify(combined);
61        Self {
62            evr,
63            certainty,
64            void_distance: 0.0,
65            gap_confidence,
66            combined,
67            level,
68        }
69    }
70
71    pub fn passes_threshold(&self, min_combined: f64) -> bool {
72        self.combined >= min_combined
73    }
74}
75
76fn classify(combined: f64) -> ConfidenceLevel {
77    if combined > 0.10 {
78        ConfidenceLevel::High
79    } else if combined > 0.03 {
80        ConfidenceLevel::Moderate
81    } else if combined > 0.005 {
82        ConfidenceLevel::Low
83    } else {
84        ConfidenceLevel::Unreliable
85    }
86}
87
88/// Full quality signal using centroids and half-angles for void distance.
89pub fn point_quality(
90    evr: f64,
91    certainty: f64,
92    position: &SphericalPoint,
93    centroids: &[SphericalPoint],
94    half_angles: &[f64],
95    sharpness: f64,
96) -> QualitySignal {
97    let void_dist = sphereql_core::spatial::void_distance(position, centroids, half_angles);
98    QualitySignal::compute(evr, certainty, void_dist, sharpness)
99}
100
/// Configuration for quality-based filtering.
///
/// With the documented defaults, no results are filtered out. Only the gap
/// sigmoid sharpness and the EVR warning threshold carry non-zero values.
#[derive(Debug, Clone, PartialEq)]
pub struct QualityConfig {
    /// Minimum per-point certainty. Default: 0.0 (no filtering).
    pub min_certainty: f64,
    /// Minimum combined confidence (cf. [`QualitySignal::passes_threshold`]).
    /// Default: 0.0 (no filtering).
    pub min_combined: f64,
    /// Sigmoid sharpness for gap confidence (cf. the `sharpness` parameter of
    /// [`QualitySignal::compute`]). Default: 5.0.
    pub gap_sharpness: f64,
    /// EVR threshold for projection warnings (cf. the `threshold` parameter
    /// of [`ProjectionWarning::from_evr`]). Default: 0.35.
    pub warn_below_evr: f64,
}
113
114impl Default for QualityConfig {
115    fn default() -> Self {
116        Self {
117            min_certainty: 0.0,
118            min_combined: 0.0,
119            gap_sharpness: 5.0,
120            warn_below_evr: 0.35,
121        }
122    }
123}
124
125/// A structured warning about projection quality.
126#[derive(Debug, Clone)]
127pub struct ProjectionWarning {
128    pub message: String,
129    pub evr: f64,
130    pub severity: WarningSeverity,
131}
132
/// Severity of a [`ProjectionWarning`].
///
/// Variants are declared from least to most severe. The derived ordering
/// therefore gives `Info < Warning < Critical`, so callers can filter with
/// e.g. `severity >= WarningSeverity::Warning`. This matches how
/// `ConfidenceLevel` is ordered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum WarningSeverity {
    /// EVR is below the requested threshold, but results are usable with caution.
    Info,
    /// Projection is lossy; certainty-weighted results are recommended.
    Warning,
    /// Projection captures very little variance; category routing and bridges
    /// are unreliable.
    Critical,
}
139
140impl ProjectionWarning {
141    pub fn from_evr(evr: f64, threshold: f64) -> Option<Self> {
142        if evr >= threshold {
143            return None;
144        }
145        let (message, severity) = if evr < 0.15 {
146            (
147                format!(
148                    "EVR={:.1}% \u{2014} projection captures very little variance. \
149                 Category routing and bridges are unreliable. Use inner spheres.",
150                    evr * 100.0
151                ),
152                WarningSeverity::Critical,
153            )
154        } else if evr < 0.25 {
155            (
156                format!(
157                    "EVR={:.1}% \u{2014} projection is lossy. Bridge counts may be inflated. \
158                 Certainty-weighted results recommended.",
159                    evr * 100.0
160                ),
161                WarningSeverity::Warning,
162            )
163        } else {
164            (
165                format!(
166                    "EVR={:.1}% \u{2014} below recommended {:.0}%. Results usable with caution.",
167                    evr * 100.0,
168                    threshold * 100.0
169                ),
170                WarningSeverity::Info,
171            )
172        };
173        Some(Self {
174            message,
175            evr,
176            severity,
177        })
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn high_confidence_all_good() {
        let sig = QualitySignal::compute(0.6, 0.8, -0.5, 5.0);
        assert_eq!(sig.level, ConfidenceLevel::High);
        assert!(sig.combined > 0.10);
    }

    #[test]
    fn low_certainty_kills_confidence() {
        let sig = QualitySignal::compute(0.6, 0.007, -0.5, 5.0);
        assert!(sig.combined < 0.01);
    }

    #[test]
    fn void_kills_confidence() {
        let sig = QualitySignal::compute(0.6, 0.8, 1.0, 5.0);
        assert!(sig.gap_confidence < 0.01);
        assert_eq!(sig.level, ConfidenceLevel::Unreliable);
    }

    #[test]
    fn gap_confidence_is_half_on_cap_boundary() {
        // exp(0) == 1, so the logistic sits exactly at its midpoint.
        let sig = QualitySignal::compute(0.6, 0.8, 0.0, 5.0);
        assert!((sig.gap_confidence - 0.5).abs() < 1e-12);
    }

    #[test]
    fn low_evr_reduces_confidence() {
        let good = QualitySignal::compute(0.6, 0.5, -0.3, 5.0);
        let bad = QualitySignal::compute(0.19, 0.5, -0.3, 5.0);
        assert!(good.combined > bad.combined);
    }

    #[test]
    fn from_certainty_fallback() {
        let sig = QualitySignal::from_certainty(0.5, 0.3);
        assert!(sig.combined > 0.0);
        assert_eq!(sig.void_distance, 0.0);
    }

    #[test]
    fn from_certainty_floors_gap_confidence() {
        let sig = QualitySignal::from_certainty(1.0, 0.0);
        assert_eq!(sig.gap_confidence, 0.01);
        assert_eq!(sig.level, ConfidenceLevel::Unreliable);
    }

    #[test]
    fn threshold_filtering() {
        let sig = QualitySignal::compute(0.19, 0.26, -0.2, 5.0);
        assert!(sig.passes_threshold(0.0));
        assert!(sig.passes_threshold(0.01));
    }

    #[test]
    fn threshold_is_inclusive_and_rejects_higher_bars() {
        let sig = QualitySignal::compute(0.19, 0.26, -0.2, 5.0);
        assert!(sig.passes_threshold(sig.combined));
        assert!(!sig.passes_threshold(sig.combined + 1e-9));
    }

    #[test]
    fn classify_thresholds_are_exclusive() {
        assert_eq!(classify(0.10), ConfidenceLevel::Moderate);
        assert_eq!(classify(0.03), ConfidenceLevel::Low);
        assert_eq!(classify(0.005), ConfidenceLevel::Unreliable);
        assert_eq!(classify(f64::NAN), ConfidenceLevel::Unreliable);
    }

    #[test]
    fn warning_at_low_evr() {
        let w = ProjectionWarning::from_evr(0.19, 0.35).unwrap();
        assert_eq!(w.severity, WarningSeverity::Warning);
    }

    #[test]
    fn info_warning_between_warning_band_and_threshold() {
        let w = ProjectionWarning::from_evr(0.30, 0.35).unwrap();
        assert_eq!(w.severity, WarningSeverity::Info);
        assert!(w.message.contains("below recommended 35%"));
    }

    #[test]
    fn no_warning_at_high_evr() {
        assert!(ProjectionWarning::from_evr(0.60, 0.35).is_none());
    }

    #[test]
    fn no_warning_exactly_at_threshold() {
        assert!(ProjectionWarning::from_evr(0.35, 0.35).is_none());
    }

    #[test]
    fn critical_at_very_low_evr() {
        let w = ProjectionWarning::from_evr(0.10, 0.35).unwrap();
        assert_eq!(w.severity, WarningSeverity::Critical);
    }

    #[test]
    fn confidence_levels_ordered() {
        assert!(ConfidenceLevel::High > ConfidenceLevel::Moderate);
        assert!(ConfidenceLevel::Moderate > ConfidenceLevel::Low);
        assert!(ConfidenceLevel::Low > ConfidenceLevel::Unreliable);
    }

    #[test]
    fn confidence_level_display_labels() {
        assert_eq!(ConfidenceLevel::High.to_string(), "HIGH");
        assert_eq!(ConfidenceLevel::Moderate.to_string(), "MODERATE");
        assert_eq!(ConfidenceLevel::Low.to_string(), "LOW");
        assert_eq!(ConfidenceLevel::Unreliable.to_string(), "UNRELIABLE");
    }

    #[test]
    fn default_config_disables_filtering() {
        let cfg = QualityConfig::default();
        assert_eq!(cfg.min_certainty, 0.0);
        assert_eq!(cfg.min_combined, 0.0);
        assert_eq!(cfg.gap_sharpness, 5.0);
        assert_eq!(cfg.warn_below_evr, 0.35);
    }
}
249}