stratum_dsp/analysis/
confidence.rs

1//! Confidence scoring module
2//!
3//! Generates trustworthiness scores for analysis results. This module provides
4//! comprehensive confidence scoring that combines individual feature confidences
5//! into an overall assessment of analysis quality.
6//!
7//! # Confidence Components
8//!
9//! 1. **BPM Confidence**: Based on method agreement, peak prominence, and octave error handling
10//! 2. **Key Confidence**: Based on template matching score difference and key clarity
11//! 3. **Grid Stability**: Based on beat grid consistency and tempo variation
12//! 4. **Overall Confidence**: Weighted combination of all components
13//!
14//! # Example
15//!
16//! ```no_run
17//! use stratum_dsp::{analyze_audio, AnalysisConfig};
18//! use stratum_dsp::analysis::confidence::compute_confidence;
19//!
20//! let samples = vec![0.0f32; 44100 * 30];
21//! let result = analyze_audio(&samples, 44100, AnalysisConfig::default())?;
22//! let confidence = compute_confidence(&result);
23//!
24//! println!("Overall confidence: {:.2}", confidence.overall_confidence);
25//! # Ok::<(), stratum_dsp::AnalysisError>(())
26//! ```
27
28use super::result::{AnalysisResult, AnalysisFlag};
29use serde::{Deserialize, Serialize};
30
31/// Analysis confidence scores
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct AnalysisConfidence {
34    /// BPM confidence (0.0-1.0)
35    /// 
36    /// Based on:
37    /// - Method agreement (autocorrelation + comb filterbank)
38    /// - Peak prominence in period estimation
39    /// - Octave error handling
40    pub bpm_confidence: f32,
41    
42    /// Key confidence (0.0-1.0)
43    /// 
44    /// Based on:
45    /// - Template matching score difference
46    /// - Key clarity (tonal strength)
47    /// - Chroma vector quality
48    pub key_confidence: f32,
49    
50    /// Grid stability (0.0-1.0)
51    /// 
52    /// Based on:
53    /// - Beat grid consistency (coefficient of variation)
54    /// - Tempo variation detection
55    /// - Downbeat alignment
56    pub grid_stability: f32,
57    
58    /// Overall confidence (weighted average)
59    /// 
60    /// Weighted combination of BPM, key, and grid stability:
61    /// - BPM: 40% weight
62    /// - Key: 30% weight
63    /// - Grid: 30% weight
64    pub overall_confidence: f32,
65    
66    /// Confidence flags indicating specific issues
67    pub flags: Vec<AnalysisFlag>,
68}
69
70/// Compute confidence scores for analysis result
71///
72/// This function analyzes the analysis result and computes comprehensive
73/// confidence scores for each component (BPM, key, beat grid) as well as
74/// an overall confidence score.
75///
76/// # Arguments
77///
78/// * `result` - Analysis result from `analyze_audio()`
79///
80/// # Returns
81///
82/// `AnalysisConfidence` with individual and overall confidence scores
83///
84/// # Algorithm
85///
86/// 1. **BPM Confidence**: Uses the confidence from period estimation, adjusted for:
87///    - Method agreement (both autocorrelation and comb filterbank agree)
88///    - Peak prominence (strong vs weak peaks)
89///    - Edge cases (BPM = 0 indicates failure)
90///
91/// 2. **Key Confidence**: Uses the confidence from key detection, adjusted for:
92///    - Key clarity (tonal strength)
93///    - Template matching score difference
94///    - Edge cases (low confidence indicates ambiguous/atonal music)
95///
96/// 3. **Grid Stability**: Uses the grid stability from beat tracking, adjusted for:
97///    - Beat interval consistency
98///    - Tempo variation detection
99///    - Edge cases (empty grid indicates failure)
100///
101/// 4. **Overall Confidence**: Weighted average:
102///    - BPM: 40% weight (most important for DJ use case)
103///    - Key: 30% weight
104///    - Grid: 30% weight
105///
106/// # Example
107///
108/// ```no_run
109/// use stratum_dsp::{analyze_audio, AnalysisConfig};
110/// use stratum_dsp::analysis::confidence::compute_confidence;
111///
112/// let samples = vec![0.0f32; 44100 * 30];
113/// let result = analyze_audio(&samples, 44100, AnalysisConfig::default())?;
114/// let confidence = compute_confidence(&result);
115///
116/// if confidence.overall_confidence < 0.5 {
117///     println!("Warning: Low confidence analysis");
118/// }
119/// # Ok::<(), stratum_dsp::AnalysisError>(())
120/// ```
121pub fn compute_confidence(result: &AnalysisResult) -> AnalysisConfidence {
122    log::debug!("Computing confidence scores for analysis result");
123    
124    // 1. BPM Confidence
125    let bpm_confidence = compute_bpm_confidence(result);
126    
127    // 2. Key Confidence
128    let key_confidence = compute_key_confidence(result);
129    
130    // 3. Grid Stability (already computed, but we validate it)
131    let grid_stability = result.grid_stability.max(0.0).min(1.0);
132    
133    // 4. Overall Confidence (weighted average)
134    // Weights: BPM=40%, Key=30%, Grid=30%
135    // If any component failed (confidence = 0), reduce overall confidence
136    let overall_confidence = if bpm_confidence > 0.0 && key_confidence > 0.0 {
137        // All components succeeded: weighted average
138        (bpm_confidence * 0.4 + key_confidence * 0.3 + grid_stability * 0.3)
139            .max(0.0)
140            .min(1.0)
141    } else if bpm_confidence > 0.0 {
142        // Only BPM succeeded: use BPM confidence with penalty
143        bpm_confidence * 0.6
144    } else if key_confidence > 0.0 {
145        // Only key succeeded: use key confidence with penalty
146        key_confidence * 0.6
147    } else {
148        // All components failed
149        0.0
150    };
151    
152    // 5. Collect flags
153    let mut flags = result.metadata.flags.clone();
154    
155    // Add confidence-based flags
156    if bpm_confidence < 0.3 {
157        flags.push(AnalysisFlag::MultimodalBpm);
158    }
159    if key_confidence < 0.2 {
160        flags.push(AnalysisFlag::WeakTonality);
161    }
162    if grid_stability < 0.3 {
163        flags.push(AnalysisFlag::TempoVariation);
164    }
165    
166    log::debug!(
167        "Confidence scores: BPM={:.3}, Key={:.3}, Grid={:.3}, Overall={:.3}",
168        bpm_confidence,
169        key_confidence,
170        grid_stability,
171        overall_confidence
172    );
173    
174    AnalysisConfidence {
175        bpm_confidence,
176        key_confidence,
177        grid_stability,
178        overall_confidence,
179        flags,
180    }
181}
182
183impl AnalysisConfidence {
184    /// Check if overall confidence is high (>= 0.7)
185    ///
186    /// # Returns
187    ///
188    /// `true` if overall confidence is high, `false` otherwise
189    pub fn is_high_confidence(&self) -> bool {
190        self.overall_confidence >= 0.7
191    }
192    
193    /// Check if overall confidence is low (< 0.5)
194    ///
195    /// # Returns
196    ///
197    /// `true` if overall confidence is low, `false` otherwise
198    pub fn is_low_confidence(&self) -> bool {
199        self.overall_confidence < 0.5
200    }
201    
202    /// Check if overall confidence is medium (0.5-0.7)
203    ///
204    /// # Returns
205    ///
206    /// `true` if overall confidence is medium, `false` otherwise
207    pub fn is_medium_confidence(&self) -> bool {
208        self.overall_confidence >= 0.5 && self.overall_confidence < 0.7
209    }
210    
211    /// Get a human-readable confidence level description
212    ///
213    /// # Returns
214    ///
215    /// String describing the confidence level: "High", "Medium", or "Low"
216    pub fn confidence_level(&self) -> &'static str {
217        if self.is_high_confidence() {
218            "High"
219        } else if self.is_low_confidence() {
220            "Low"
221        } else {
222            "Medium"
223        }
224    }
225}
226
227/// Compute BPM confidence from analysis result
228///
229/// BPM confidence is based on:
230/// - The confidence score from period estimation
231/// - Method agreement (both autocorrelation and comb filterbank agree)
232/// - Edge cases (BPM = 0 indicates failure)
233fn compute_bpm_confidence(result: &AnalysisResult) -> f32 {
234    if result.bpm <= 0.0 {
235        // BPM detection failed
236        return 0.0;
237    }
238    
239    // Use the confidence from period estimation
240    // This already includes method agreement and peak prominence
241    let base_confidence = result.bpm_confidence.max(0.0).min(1.0);
242    
243    // Additional adjustments based on metadata
244    // Check if there are warnings about BPM
245    let has_bpm_warning = result.metadata.confidence_warnings.iter()
246        .any(|w| w.contains("BPM"));
247    
248    if has_bpm_warning {
249        // Reduce confidence if there are warnings
250        base_confidence * 0.7
251    } else {
252        base_confidence
253    }
254}
255
256/// Compute key confidence from analysis result
257///
258/// Key confidence is based on:
259/// - The confidence score from key detection
260/// - Key clarity (tonal strength) - directly incorporated
261/// - Edge cases (low confidence indicates ambiguous/atonal music)
262fn compute_key_confidence(result: &AnalysisResult) -> f32 {
263    if result.key_confidence <= 0.0 {
264        // Key detection failed or returned default
265        return 0.0;
266    }
267    
268    // Use the confidence from key detection
269    // This already includes template matching score difference
270    let base_confidence = result.key_confidence.max(0.0).min(1.0);
271    
272    // Incorporate key clarity directly: low clarity reduces confidence
273    // Key clarity is a strong indicator of detection reliability
274    let clarity_adjustment = if result.key_clarity < 0.2 {
275        // Low clarity: significant penalty
276        0.6
277    } else if result.key_clarity < 0.5 {
278        // Medium clarity: moderate penalty
279        0.85
280    } else {
281        // High clarity: no penalty (or slight boost)
282        1.0
283    };
284    
285    // Additional adjustments based on metadata warnings
286    let has_key_warning = result.metadata.confidence_warnings.iter()
287        .any(|w| w.contains("key") || w.contains("Key") || w.contains("tonality"));
288    
289    let warning_adjustment = if has_key_warning {
290        0.7
291    } else {
292        1.0
293    };
294    
295    // Apply both adjustments (multiplicative)
296    base_confidence * clarity_adjustment * warning_adjustment
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analysis::result::{AnalysisResult, AnalysisMetadata, BeatGrid, Key};

    /// Tolerance used when comparing computed overall scores.
    const TOLERANCE: f32 = 0.01;

    /// Build an `AnalysisResult` with the given scores, an empty beat grid,
    /// and metadata that carries no flags or warnings.
    fn create_test_result(
        bpm: f32,
        bpm_confidence: f32,
        key: Key,
        key_confidence: f32,
        key_clarity: f32,
        grid_stability: f32,
    ) -> AnalysisResult {
        let beat_grid = BeatGrid {
            downbeats: vec![],
            beats: vec![],
            bars: vec![],
        };

        let metadata = AnalysisMetadata {
            duration_seconds: 30.0,
            sample_rate: 44100,
            processing_time_ms: 100.0,
            algorithm_version: "0.1.0-alpha".to_string(),
            onset_method_consensus: 1.0,
            methods_used: vec![],
            flags: vec![],
            confidence_warnings: vec![],
            tempogram_candidates: None,
            tempogram_multi_res_triggered: None,
            tempogram_multi_res_used: None,
            tempogram_percussive_triggered: None,
            tempogram_percussive_used: None,
        };

        AnalysisResult {
            bpm,
            bpm_confidence,
            key,
            key_confidence,
            key_clarity,
            beat_grid,
            grid_stability,
            metadata,
        }
    }

    /// Result where every component succeeded with good scores and high key clarity.
    fn healthy_result() -> AnalysisResult {
        create_test_result(120.0, 0.9, Key::Major(0), 0.8, 0.7, 0.85)
    }

    /// Result where every component failed.
    fn failed_result() -> AnalysisResult {
        create_test_result(0.0, 0.0, Key::Major(0), 0.0, 0.0, 0.0)
    }

    #[test]
    fn test_compute_confidence_all_good() {
        let scores = compute_confidence(&healthy_result());

        assert_eq!(scores.bpm_confidence, 0.9);
        assert_eq!(scores.key_confidence, 0.8);
        assert_eq!(scores.grid_stability, 0.85);

        // Weighted average: 0.9*0.4 + 0.8*0.3 + 0.85*0.3 = 0.855
        assert!((scores.overall_confidence - 0.855).abs() < TOLERANCE);
    }

    #[test]
    fn test_compute_confidence_bpm_failed() {
        // High key clarity, but no BPM detected.
        let input = create_test_result(0.0, 0.0, Key::Major(0), 0.8, 0.7, 0.85);

        let scores = compute_confidence(&input);

        assert_eq!(scores.bpm_confidence, 0.0);
        assert_eq!(scores.key_confidence, 0.8);
        assert_eq!(scores.grid_stability, 0.85);

        // Only the key survived: 0.8 * 0.6 = 0.48
        assert!((scores.overall_confidence - 0.48).abs() < TOLERANCE);
    }

    #[test]
    fn test_compute_confidence_key_failed() {
        // No key confidence and no key clarity.
        let input = create_test_result(120.0, 0.9, Key::Major(0), 0.0, 0.0, 0.85);

        let scores = compute_confidence(&input);

        assert_eq!(scores.bpm_confidence, 0.9);
        assert_eq!(scores.key_confidence, 0.0);
        assert_eq!(scores.grid_stability, 0.85);

        // Only the BPM survived: 0.9 * 0.6 = 0.54
        assert!((scores.overall_confidence - 0.54).abs() < TOLERANCE);
    }

    #[test]
    fn test_compute_confidence_all_failed() {
        let scores = compute_confidence(&failed_result());

        assert_eq!(scores.bpm_confidence, 0.0);
        assert_eq!(scores.key_confidence, 0.0);
        assert_eq!(scores.grid_stability, 0.0);
        assert_eq!(scores.overall_confidence, 0.0);
    }

    #[test]
    fn test_compute_confidence_with_warnings() {
        let mut input = healthy_result();
        input
            .metadata
            .confidence_warnings
            .push("BPM detection failed: insufficient onsets".to_string());

        let scores = compute_confidence(&input);

        // The warning lowers the BPM confidence without zeroing it.
        assert!(scores.bpm_confidence < 0.9);
        assert!(scores.bpm_confidence > 0.0);
    }

    #[test]
    fn test_compute_confidence_clamping() {
        // Out-of-range inputs: BPM confidence > 1.0, key confidence < 0.0,
        // grid stability > 1.0 (key clarity is normal).
        let input = create_test_result(120.0, 1.5, Key::Major(0), -0.5, 0.7, 2.0);

        let scores = compute_confidence(&input);

        assert!(scores.bpm_confidence <= 1.0);
        assert!(scores.key_confidence >= 0.0);
        assert!(scores.grid_stability <= 1.0);
        assert!(scores.overall_confidence >= 0.0);
        assert!(scores.overall_confidence <= 1.0);
    }

    #[test]
    fn test_confidence_helper_methods() {
        // High confidence case
        let high = compute_confidence(&healthy_result());
        assert!(high.is_high_confidence());
        assert!(!high.is_low_confidence());
        assert!(!high.is_medium_confidence());
        assert_eq!(high.confidence_level(), "High");

        // Low confidence case
        let low = compute_confidence(&failed_result());
        assert!(low.is_low_confidence());
        assert!(!low.is_high_confidence());
        assert_eq!(low.confidence_level(), "Low");
    }

    #[test]
    fn test_key_clarity_adjustment() {
        // Identical inputs except for key clarity (0.7 vs 0.1).
        let clear = compute_confidence(&healthy_result());
        let unclear = compute_confidence(&create_test_result(
            120.0,
            0.9,
            Key::Major(0),
            0.8,
            0.1,
            0.85,
        ));

        // Low clarity must yield the lower key confidence.
        assert!(unclear.key_confidence < clear.key_confidence);
    }
}
531