stratum_dsp/analysis/confidence.rs
1//! Confidence scoring module
2//!
3//! Generates trustworthiness scores for analysis results. This module provides
4//! comprehensive confidence scoring that combines individual feature confidences
5//! into an overall assessment of analysis quality.
6//!
7//! # Confidence Components
8//!
9//! 1. **BPM Confidence**: Based on method agreement, peak prominence, and octave error handling
10//! 2. **Key Confidence**: Based on template matching score difference and key clarity
11//! 3. **Grid Stability**: Based on beat grid consistency and tempo variation
//! 4. **Overall Confidence**: Weighted combination of all components, falling back to a penalized single score when BPM or key detection fails
13//!
14//! # Example
15//!
16//! ```no_run
17//! use stratum_dsp::{analyze_audio, AnalysisConfig};
18//! use stratum_dsp::analysis::confidence::compute_confidence;
19//!
20//! let samples = vec![0.0f32; 44100 * 30];
21//! let result = analyze_audio(&samples, 44100, AnalysisConfig::default())?;
22//! let confidence = compute_confidence(&result);
23//!
24//! println!("Overall confidence: {:.2}", confidence.overall_confidence);
25//! # Ok::<(), stratum_dsp::AnalysisError>(())
26//! ```
27
28use super::result::{AnalysisResult, AnalysisFlag};
29use serde::{Deserialize, Serialize};
30
31/// Analysis confidence scores
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct AnalysisConfidence {
34 /// BPM confidence (0.0-1.0)
35 ///
36 /// Based on:
37 /// - Method agreement (autocorrelation + comb filterbank)
38 /// - Peak prominence in period estimation
39 /// - Octave error handling
40 pub bpm_confidence: f32,
41
42 /// Key confidence (0.0-1.0)
43 ///
44 /// Based on:
45 /// - Template matching score difference
46 /// - Key clarity (tonal strength)
47 /// - Chroma vector quality
48 pub key_confidence: f32,
49
50 /// Grid stability (0.0-1.0)
51 ///
52 /// Based on:
53 /// - Beat grid consistency (coefficient of variation)
54 /// - Tempo variation detection
55 /// - Downbeat alignment
56 pub grid_stability: f32,
57
58 /// Overall confidence (weighted average)
59 ///
60 /// Weighted combination of BPM, key, and grid stability:
61 /// - BPM: 40% weight
62 /// - Key: 30% weight
63 /// - Grid: 30% weight
64 pub overall_confidence: f32,
65
66 /// Confidence flags indicating specific issues
67 pub flags: Vec<AnalysisFlag>,
68}
69
70/// Compute confidence scores for analysis result
71///
72/// This function analyzes the analysis result and computes comprehensive
73/// confidence scores for each component (BPM, key, beat grid) as well as
74/// an overall confidence score.
75///
76/// # Arguments
77///
78/// * `result` - Analysis result from `analyze_audio()`
79///
80/// # Returns
81///
82/// `AnalysisConfidence` with individual and overall confidence scores
83///
84/// # Algorithm
85///
86/// 1. **BPM Confidence**: Uses the confidence from period estimation, adjusted for:
87/// - Method agreement (both autocorrelation and comb filterbank agree)
88/// - Peak prominence (strong vs weak peaks)
89/// - Edge cases (BPM = 0 indicates failure)
90///
91/// 2. **Key Confidence**: Uses the confidence from key detection, adjusted for:
92/// - Key clarity (tonal strength)
93/// - Template matching score difference
94/// - Edge cases (low confidence indicates ambiguous/atonal music)
95///
96/// 3. **Grid Stability**: Uses the grid stability from beat tracking, adjusted for:
97/// - Beat interval consistency
98/// - Tempo variation detection
99/// - Edge cases (empty grid indicates failure)
100///
101/// 4. **Overall Confidence**: Weighted average:
102/// - BPM: 40% weight (most important for DJ use case)
103/// - Key: 30% weight
104/// - Grid: 30% weight
105///
106/// # Example
107///
108/// ```no_run
109/// use stratum_dsp::{analyze_audio, AnalysisConfig};
110/// use stratum_dsp::analysis::confidence::compute_confidence;
111///
112/// let samples = vec![0.0f32; 44100 * 30];
113/// let result = analyze_audio(&samples, 44100, AnalysisConfig::default())?;
114/// let confidence = compute_confidence(&result);
115///
116/// if confidence.overall_confidence < 0.5 {
117/// println!("Warning: Low confidence analysis");
118/// }
119/// # Ok::<(), stratum_dsp::AnalysisError>(())
120/// ```
121pub fn compute_confidence(result: &AnalysisResult) -> AnalysisConfidence {
122 log::debug!("Computing confidence scores for analysis result");
123
124 // 1. BPM Confidence
125 let bpm_confidence = compute_bpm_confidence(result);
126
127 // 2. Key Confidence
128 let key_confidence = compute_key_confidence(result);
129
130 // 3. Grid Stability (already computed, but we validate it)
131 let grid_stability = result.grid_stability.max(0.0).min(1.0);
132
133 // 4. Overall Confidence (weighted average)
134 // Weights: BPM=40%, Key=30%, Grid=30%
135 // If any component failed (confidence = 0), reduce overall confidence
136 let overall_confidence = if bpm_confidence > 0.0 && key_confidence > 0.0 {
137 // All components succeeded: weighted average
138 (bpm_confidence * 0.4 + key_confidence * 0.3 + grid_stability * 0.3)
139 .max(0.0)
140 .min(1.0)
141 } else if bpm_confidence > 0.0 {
142 // Only BPM succeeded: use BPM confidence with penalty
143 bpm_confidence * 0.6
144 } else if key_confidence > 0.0 {
145 // Only key succeeded: use key confidence with penalty
146 key_confidence * 0.6
147 } else {
148 // All components failed
149 0.0
150 };
151
152 // 5. Collect flags
153 let mut flags = result.metadata.flags.clone();
154
155 // Add confidence-based flags
156 if bpm_confidence < 0.3 {
157 flags.push(AnalysisFlag::MultimodalBpm);
158 }
159 if key_confidence < 0.2 {
160 flags.push(AnalysisFlag::WeakTonality);
161 }
162 if grid_stability < 0.3 {
163 flags.push(AnalysisFlag::TempoVariation);
164 }
165
166 log::debug!(
167 "Confidence scores: BPM={:.3}, Key={:.3}, Grid={:.3}, Overall={:.3}",
168 bpm_confidence,
169 key_confidence,
170 grid_stability,
171 overall_confidence
172 );
173
174 AnalysisConfidence {
175 bpm_confidence,
176 key_confidence,
177 grid_stability,
178 overall_confidence,
179 flags,
180 }
181}
182
183impl AnalysisConfidence {
184 /// Check if overall confidence is high (>= 0.7)
185 ///
186 /// # Returns
187 ///
188 /// `true` if overall confidence is high, `false` otherwise
189 pub fn is_high_confidence(&self) -> bool {
190 self.overall_confidence >= 0.7
191 }
192
193 /// Check if overall confidence is low (< 0.5)
194 ///
195 /// # Returns
196 ///
197 /// `true` if overall confidence is low, `false` otherwise
198 pub fn is_low_confidence(&self) -> bool {
199 self.overall_confidence < 0.5
200 }
201
202 /// Check if overall confidence is medium (0.5-0.7)
203 ///
204 /// # Returns
205 ///
206 /// `true` if overall confidence is medium, `false` otherwise
207 pub fn is_medium_confidence(&self) -> bool {
208 self.overall_confidence >= 0.5 && self.overall_confidence < 0.7
209 }
210
211 /// Get a human-readable confidence level description
212 ///
213 /// # Returns
214 ///
215 /// String describing the confidence level: "High", "Medium", or "Low"
216 pub fn confidence_level(&self) -> &'static str {
217 if self.is_high_confidence() {
218 "High"
219 } else if self.is_low_confidence() {
220 "Low"
221 } else {
222 "Medium"
223 }
224 }
225}
226
227/// Compute BPM confidence from analysis result
228///
229/// BPM confidence is based on:
230/// - The confidence score from period estimation
231/// - Method agreement (both autocorrelation and comb filterbank agree)
232/// - Edge cases (BPM = 0 indicates failure)
233fn compute_bpm_confidence(result: &AnalysisResult) -> f32 {
234 if result.bpm <= 0.0 {
235 // BPM detection failed
236 return 0.0;
237 }
238
239 // Use the confidence from period estimation
240 // This already includes method agreement and peak prominence
241 let base_confidence = result.bpm_confidence.max(0.0).min(1.0);
242
243 // Additional adjustments based on metadata
244 // Check if there are warnings about BPM
245 let has_bpm_warning = result.metadata.confidence_warnings.iter()
246 .any(|w| w.contains("BPM"));
247
248 if has_bpm_warning {
249 // Reduce confidence if there are warnings
250 base_confidence * 0.7
251 } else {
252 base_confidence
253 }
254}
255
256/// Compute key confidence from analysis result
257///
258/// Key confidence is based on:
259/// - The confidence score from key detection
260/// - Key clarity (tonal strength) - directly incorporated
261/// - Edge cases (low confidence indicates ambiguous/atonal music)
262fn compute_key_confidence(result: &AnalysisResult) -> f32 {
263 if result.key_confidence <= 0.0 {
264 // Key detection failed or returned default
265 return 0.0;
266 }
267
268 // Use the confidence from key detection
269 // This already includes template matching score difference
270 let base_confidence = result.key_confidence.max(0.0).min(1.0);
271
272 // Incorporate key clarity directly: low clarity reduces confidence
273 // Key clarity is a strong indicator of detection reliability
274 let clarity_adjustment = if result.key_clarity < 0.2 {
275 // Low clarity: significant penalty
276 0.6
277 } else if result.key_clarity < 0.5 {
278 // Medium clarity: moderate penalty
279 0.85
280 } else {
281 // High clarity: no penalty (or slight boost)
282 1.0
283 };
284
285 // Additional adjustments based on metadata warnings
286 let has_key_warning = result.metadata.confidence_warnings.iter()
287 .any(|w| w.contains("key") || w.contains("Key") || w.contains("tonality"));
288
289 let warning_adjustment = if has_key_warning {
290 0.7
291 } else {
292 1.0
293 };
294
295 // Apply both adjustments (multiplicative)
296 base_confidence * clarity_adjustment * warning_adjustment
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analysis::result::{AnalysisResult, AnalysisMetadata, BeatGrid, Key};

    /// Absolute tolerance for comparing derived (non-exact) floating-point scores.
    const TOLERANCE: f32 = 0.01;

    /// Build an `AnalysisResult` with an empty beat grid and otherwise neutral metadata.
    fn make_result(
        bpm: f32,
        bpm_confidence: f32,
        key: Key,
        key_confidence: f32,
        key_clarity: f32,
        grid_stability: f32,
    ) -> AnalysisResult {
        let beat_grid = BeatGrid {
            downbeats: vec![],
            beats: vec![],
            bars: vec![],
        };
        let metadata = AnalysisMetadata {
            duration_seconds: 30.0,
            sample_rate: 44100,
            processing_time_ms: 100.0,
            algorithm_version: "0.1.0-alpha".to_string(),
            onset_method_consensus: 1.0,
            methods_used: vec![],
            flags: vec![],
            confidence_warnings: vec![],
            tempogram_candidates: None,
            tempogram_multi_res_triggered: None,
            tempogram_multi_res_used: None,
            tempogram_percussive_triggered: None,
            tempogram_percussive_used: None,
        };
        AnalysisResult {
            bpm,
            bpm_confidence,
            key,
            key_confidence,
            key_clarity,
            beat_grid,
            grid_stability,
            metadata,
        }
    }

    /// Every component succeeded (BPM 0.9, key 0.8 with clarity 0.7, grid 0.85).
    fn healthy_result() -> AnalysisResult {
        make_result(120.0, 0.9, Key::Major(0), 0.8, 0.7, 0.85)
    }

    /// Every component failed (tempo, all confidences and clarity are zero).
    fn failed_result() -> AnalysisResult {
        make_result(0.0, 0.0, Key::Major(0), 0.0, 0.0, 0.0)
    }

    /// Assert that `actual` lies within `TOLERANCE` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected approximately {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_compute_confidence_all_good() {
        let confidence = compute_confidence(&healthy_result());

        assert_eq!(confidence.bpm_confidence, 0.9);
        assert_eq!(confidence.key_confidence, 0.8);
        assert_eq!(confidence.grid_stability, 0.85);
        // 0.9*0.4 + 0.8*0.3 + 0.85*0.3 = 0.36 + 0.24 + 0.255 = 0.855
        assert_close(confidence.overall_confidence, 0.855);
    }

    #[test]
    fn test_compute_confidence_bpm_failed() {
        let result = make_result(0.0, 0.0, Key::Major(0), 0.8, 0.7, 0.85);
        let confidence = compute_confidence(&result);

        assert_eq!(confidence.bpm_confidence, 0.0);
        assert_eq!(confidence.key_confidence, 0.8);
        assert_eq!(confidence.grid_stability, 0.85);
        // Only key succeeded, so overall is the penalized key score: 0.8 * 0.6.
        assert_close(confidence.overall_confidence, 0.48);
    }

    #[test]
    fn test_compute_confidence_key_failed() {
        // Key clarity is zero as well, mirroring a failed key detection.
        let result = make_result(120.0, 0.9, Key::Major(0), 0.0, 0.0, 0.85);
        let confidence = compute_confidence(&result);

        assert_eq!(confidence.bpm_confidence, 0.9);
        assert_eq!(confidence.key_confidence, 0.0);
        assert_eq!(confidence.grid_stability, 0.85);
        // Only BPM succeeded, so overall is the penalized BPM score: 0.9 * 0.6.
        assert_close(confidence.overall_confidence, 0.54);
    }

    #[test]
    fn test_compute_confidence_all_failed() {
        let confidence = compute_confidence(&failed_result());

        let scores = [
            confidence.bpm_confidence,
            confidence.key_confidence,
            confidence.grid_stability,
            confidence.overall_confidence,
        ];
        for score in scores.iter() {
            assert_eq!(*score, 0.0);
        }
    }

    #[test]
    fn test_compute_confidence_with_warnings() {
        let mut result = healthy_result();
        result
            .metadata
            .confidence_warnings
            .push("BPM detection failed: insufficient onsets".to_string());

        let bpm_confidence = compute_confidence(&result).bpm_confidence;

        // A BPM-related warning lowers the BPM score without zeroing it.
        assert!(bpm_confidence < 0.9);
        assert!(bpm_confidence > 0.0);
    }

    #[test]
    fn test_compute_confidence_clamping() {
        // Out-of-range inputs: BPM confidence > 1, key confidence < 0, grid > 1.
        let result = make_result(120.0, 1.5, Key::Major(0), -0.5, 0.7, 2.0);
        let confidence = compute_confidence(&result);

        assert!(confidence.bpm_confidence <= 1.0);
        assert!(confidence.key_confidence >= 0.0);
        assert!(confidence.grid_stability <= 1.0);
        assert!(confidence.overall_confidence >= 0.0 && confidence.overall_confidence <= 1.0);
    }

    #[test]
    fn test_confidence_helper_methods() {
        let high = compute_confidence(&healthy_result());
        assert!(high.is_high_confidence());
        assert!(!high.is_low_confidence());
        assert!(!high.is_medium_confidence());
        assert_eq!(high.confidence_level(), "High");

        let low = compute_confidence(&failed_result());
        assert!(low.is_low_confidence());
        assert!(!low.is_high_confidence());
        assert_eq!(low.confidence_level(), "Low");
    }

    #[test]
    fn test_key_clarity_adjustment() {
        let clear = compute_confidence(&healthy_result());
        let murky = compute_confidence(&make_result(120.0, 0.9, Key::Major(0), 0.8, 0.1, 0.85));

        // Low key clarity must yield a lower key score than a clear key.
        assert!(murky.key_confidence < clear.key_confidence);
    }
}
531