Skip to main content

speech_prep/
format.rs

//! Audio format detection and metadata extraction for the audio pipeline.
//!
//! This module provides fast, robust format detection using a hybrid approach:
//! - **Fast path**: magic-byte detection for WAV, FLAC, MP3, `WebM` and M4A -
//!   <1µs
//! - **Validation path**: Symphonia probe when the magic bytes are
//!   inconclusive (e.g. Ogg, which may carry Opus or Vorbis) or unrecognised,
//!   and whenever full metadata is requested - <10ms
//!
//! ## Performance Contract
//!
//! - Detection latency: <1ms for 99% of inputs
//! - Total processing: <10ms including validation
//! - Zero panics: all byte access is bounds-checked
//!
//! ## Supported Formats
//!
//! - **WAV** (RIFF/PCM): primary format, instant detection
//! - **FLAC**: lossless compression, instant detection
//! - **MP3**: MPEG-1/2 Layer 3, frame sync validation
//! - **Opus**: Ogg container with Opus codec
//! - **`WebM`**: Matroska container (audio track)
//! - **M4A/AAC**: MPEG-4 container with AAC codec
23
24use std::io::Cursor;
25
26use crate::error::{Error, Result};
27use symphonia::core::io::MediaSourceStream;
28use symphonia::core::probe::Hint;
29
/// Audio container and codec format identifier.
///
/// Each variant has a short lowercase name (see [`AudioFormat::as_str`]) that
/// is also its [`Display`](std::fmt::Display) text.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AudioFormat {
    /// RIFF WAV container with PCM encoding.
    WavPcm,
    /// Free Lossless Audio Codec.
    Flac,
    /// MPEG-1/2 Audio Layer 3.
    Mp3,
    /// Opus codec in Ogg container.
    Opus,
    /// `WebM` container (Matroska subset) with audio track.
    WebM,
    /// MPEG-4 container with AAC codec.
    Aac,
}

impl AudioFormat {
    /// Human-readable format name for logging and metrics
    /// (`"wav"`, `"flac"`, `"mp3"`, `"opus"`, `"webm"`, `"aac"`).
    ///
    /// Usable in `const` contexts.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::WavPcm => "wav",
            Self::Flac => "flac",
            Self::Mp3 => "mp3",
            Self::Opus => "opus",
            Self::WebM => "webm",
            Self::Aac => "aac",
        }
    }

    /// Whether this format is lossless (`WavPcm` and `Flac`).
    #[must_use]
    pub const fn is_lossless(self) -> bool {
        matches!(self, Self::WavPcm | Self::Flac)
    }

    /// Whether this format requires container parsing (vs raw frames).
    #[must_use]
    pub const fn is_container_format(self) -> bool {
        matches!(self, Self::WavPcm | Self::Opus | Self::WebM | Self::Aac)
    }
}

impl std::fmt::Display for AudioFormat {
    /// Writes [`AudioFormat::as_str`], honouring the caller's width, fill,
    /// alignment and precision flags (e.g. `format!("{:>5}", AudioFormat::Mp3)`
    /// yields `"  mp3"`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` applies formatting flags; `write_str` would silently drop them,
        // breaking column alignment in tabular log output.
        f.pad(self.as_str())
    }
}
79
80/// Audio metadata extracted during format detection.
81#[derive(Debug, Clone, Copy, PartialEq)]
82pub struct AudioMetadata {
83    /// Detected container/codec format.
84    pub format: AudioFormat,
85    /// Number of audio channels (if determinable from header).
86    pub channels: Option<u16>,
87    /// Sample rate in Hz (if determinable from header).
88    pub sample_rate: Option<u32>,
89    /// Bit depth (if applicable for PCM formats).
90    pub bit_depth: Option<u16>,
91    /// Total duration in seconds (if available in container).
92    pub duration_sec: Option<f64>,
93}
94
95impl AudioMetadata {
96    /// Create metadata with only format known (minimal detection).
97    #[must_use]
98    pub const fn format_only(format: AudioFormat) -> Self {
99        Self {
100            format,
101            channels: None,
102            sample_rate: None,
103            bit_depth: None,
104            duration_sec: None,
105        }
106    }
107
108    /// Create metadata with format and basic audio properties.
109    #[must_use]
110    pub const fn with_properties(
111        format: AudioFormat,
112        channels: u16,
113        sample_rate: u32,
114        bit_depth: Option<u16>,
115    ) -> Self {
116        Self {
117            format,
118            channels: Some(channels),
119            sample_rate: Some(sample_rate),
120            bit_depth,
121            duration_sec: None,
122        }
123    }
124}
125
126/// Audio format detector using hybrid magic-byte + Symphonia validation.
127#[derive(Debug, Default, Clone, Copy)]
128pub struct FormatDetector;
129
130impl FormatDetector {
131    /// Create a new format detector instance.
132    #[must_use]
133    pub const fn new() -> Self {
134        Self
135    }
136
137    /// Detect audio format from byte stream using fast magic-byte detection.
138    ///
139    /// This is the primary entry point optimized for speed (<1µs for common
140    /// formats). Falls back to Symphonia validation for complex/ambiguous
141    /// formats.
142    ///
143    /// # Errors
144    ///
145    /// Returns `Error::InvalidInput` if:
146    /// - Payload is too short for any valid audio format
147    /// - Format is unsupported or unrecognized
148    /// - Byte stream is malformed (detected via Symphonia probe)
149    pub fn detect(data: &[u8]) -> Result<AudioMetadata> {
150        if data.len() < 4 {
151            return Err(Error::InvalidInput(
152                "audio payload too short (minimum 4 bytes required)".into(),
153            ));
154        }
155
156        if let Some(format) = Self::detect_magic_bytes(data) {
157            return Ok(AudioMetadata::format_only(format));
158        }
159
160        Self::detect_with_symphonia(data)
161    }
162
163    /// Detect format and extract full metadata using Symphonia probe.
164    ///
165    /// This method provides comprehensive metadata extraction but is slower
166    /// (~10ms). Use when full audio properties are needed (channels, sample
167    /// rate, duration).
168    ///
169    /// # Errors
170    ///
171    /// Returns `Error::InvalidInput` if format cannot be determined.
172    pub fn detect_with_metadata(data: &[u8]) -> Result<AudioMetadata> {
173        Self::detect_with_symphonia(data)
174    }
175
176    /// Fast magic-byte detection for common formats.
177    ///
178    /// Returns `Some(AudioFormat)` if format is recognized via magic bytes,
179    /// `None` if validation via Symphonia is needed.
180    fn detect_magic_bytes(data: &[u8]) -> Option<AudioFormat> {
181        let len = data.len();
182
183        // WAV: RIFF + size + WAVE
184        if len >= 12 {
185            if let (Some(riff), Some(wave)) = (data.get(0..4), data.get(8..12)) {
186                if riff == b"RIFF" && wave == b"WAVE" {
187                    return Some(AudioFormat::WavPcm);
188                }
189            }
190        }
191
192        // FLAC
193        if len >= 4 {
194            if let Some(header) = data.get(0..4) {
195                if header == b"fLaC" {
196                    return Some(AudioFormat::Flac);
197                }
198            }
199        }
200
201        // MP3: frame sync heuristic, validated by Symphonia downstream
202        if len >= 2 {
203            if let (Some(&first), Some(&second)) = (data.first(), data.get(1)) {
204                if first == 0xFF && (second & 0xE0) == 0xE0 {
205                    let layer = (second >> 1) & 0x03;
206                    if layer == 0x01 {
207                        return Some(AudioFormat::Mp3);
208                    }
209                }
210            }
211        }
212
213        // Ogg: needs Symphonia to distinguish Opus from Vorbis
214        if len >= 4 {
215            if let Some(header) = data.get(0..4) {
216                if header == b"OggS" {
217                    return None;
218                }
219            }
220        }
221
222        // WebM/Matroska: EBML header
223        if len >= 4 {
224            if let Some(header) = data.get(0..4) {
225                if header == [0x1A, 0x45, 0xDF, 0xA3] {
226                    return Some(AudioFormat::WebM);
227                }
228            }
229        }
230
231        // M4A/AAC: ftyp box
232        if len >= 12 {
233            if let (Some(ftyp), Some(brand)) = (data.get(4..8), data.get(8..12)) {
234                if ftyp == b"ftyp" && (brand == b"M4A " || brand == b"mp42" || brand == b"isom") {
235                    return Some(AudioFormat::Aac);
236                }
237            }
238        }
239
240        None
241    }
242
243    /// Detect format using Symphonia's comprehensive probe.
244    ///
245    /// This provides robust format validation and metadata extraction.
246    fn detect_with_symphonia(data: &[u8]) -> Result<AudioMetadata> {
247        let data_vec = data.to_vec();
248        let cursor = Cursor::new(data_vec);
249        let mss = MediaSourceStream::new(
250            Box::new(cursor),
251            symphonia::core::io::MediaSourceStreamOptions::default(),
252        );
253
254        let hint = Hint::new();
255        let probe_result = symphonia::default::get_probe()
256            .format(
257                &hint,
258                mss,
259                &symphonia::core::formats::FormatOptions::default(),
260                &symphonia::core::meta::MetadataOptions::default(),
261            )
262            .map_err(|err| {
263                Error::InvalidInput(format!("unsupported or malformed audio format: {err}"))
264            })?;
265
266        let format_reader = probe_result.format;
267        let codec_params = &format_reader
268            .default_track()
269            .ok_or_else(|| Error::InvalidInput("no audio track found in container".into()))?
270            .codec_params;
271
272        let format = match codec_params.codec {
273            symphonia::core::codecs::CODEC_TYPE_PCM_S16LE
274            | symphonia::core::codecs::CODEC_TYPE_PCM_S24LE
275            | symphonia::core::codecs::CODEC_TYPE_PCM_S32LE
276            | symphonia::core::codecs::CODEC_TYPE_PCM_F32LE => AudioFormat::WavPcm,
277            symphonia::core::codecs::CODEC_TYPE_FLAC => AudioFormat::Flac,
278            symphonia::core::codecs::CODEC_TYPE_MP3 => AudioFormat::Mp3,
279            symphonia::core::codecs::CODEC_TYPE_OPUS => AudioFormat::Opus,
280            symphonia::core::codecs::CODEC_TYPE_VORBIS => {
281                return Err(Error::InvalidInput(
282                    "Vorbis codec not supported (use Opus instead)".into(),
283                ));
284            }
285            symphonia::core::codecs::CODEC_TYPE_AAC => AudioFormat::Aac,
286            _ => {
287                return Err(Error::InvalidInput(format!(
288                    "unsupported codec: {:?}",
289                    codec_params.codec
290                )));
291            }
292        };
293
294        let channels = codec_params.channels.map(|ch| ch.count() as u16);
295        let sample_rate = codec_params.sample_rate;
296        let bit_depth = codec_params.bits_per_sample.map(|b| b as u16);
297        let duration_sec = codec_params
298            .n_frames
299            .and_then(|frames| sample_rate.map(|rate| frames as f64 / f64::from(rate)));
300
301        Ok(AudioMetadata {
302            format,
303            channels,
304            sample_rate,
305            bit_depth,
306            duration_sec,
307        })
308    }
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    type TestResult<T> = std::result::Result<T, String>;

    /// Run [`FormatDetector::detect`], converting the error to a `String` so
    /// test bodies can use `?`.
    fn detect_format(data: &[u8]) -> TestResult<AudioMetadata> {
        FormatDetector::detect(data).map_err(|e| e.to_string())
    }

    // Magic-byte test fixtures
    fn wav_header() -> Vec<u8> {
        // Minimal WAV header: RIFF + size + WAVE + fmt chunk (no data chunk)
        let mut header = Vec::new();
        header.extend_from_slice(b"RIFF");
        header.extend_from_slice(&36u32.to_le_bytes()); // File size - 8
        header.extend_from_slice(b"WAVE");
        header.extend_from_slice(b"fmt ");
        header.extend_from_slice(&16u32.to_le_bytes()); // Subchunk1 size
        header.extend_from_slice(&1u16.to_le_bytes()); // Audio format (PCM)
        header.extend_from_slice(&2u16.to_le_bytes()); // Num channels (stereo)
        header.extend_from_slice(&44100u32.to_le_bytes()); // Sample rate
        header.extend_from_slice(&(44100u32 * 2 * 2).to_le_bytes()); // Byte rate
        header.extend_from_slice(&4u16.to_le_bytes()); // Block align
        header.extend_from_slice(&16u16.to_le_bytes()); // Bits per sample
        header
    }

    fn flac_header() -> Vec<u8> {
        // FLAC stream marker: "fLaC"
        b"fLaC".to_vec()
    }

    fn mp3_header() -> Vec<u8> {
        // MP3 frame sync: 0xFF 0xFB (MPEG-1 Layer 3, no CRC)
        // Frame header format: 11111111 111BBCCD EEEEFFGH IIJJKLMM
        // 0xFF 0xFB = 11111111 11111011
        // Bits: sync(11) + version(11=MPEG-1) + layer(01=Layer3) + CRC(1=no)
        // 0x90: bitrate index 1001, sample-rate index 00
        vec![0xFF, 0xFB, 0x90, 0x00]
    }

    fn webm_header() -> Vec<u8> {
        // WebM/Matroska EBML header
        vec![0x1A, 0x45, 0xDF, 0xA3, 0x00, 0x00, 0x00, 0x20]
    }

    fn aac_header() -> Vec<u8> {
        // M4A/AAC ftyp box
        let mut header = Vec::new();
        header.extend_from_slice(&20u32.to_be_bytes()); // Box size
        header.extend_from_slice(b"ftyp"); // Box type
        header.extend_from_slice(b"M4A "); // Major brand
        header.extend_from_slice(&0u32.to_be_bytes()); // Minor version
        header.extend_from_slice(b"mp42"); // Compatible brand
        header
    }

    // Positive path tests
    #[test]
    fn test_detect_wav_format() -> TestResult<()> {
        let metadata = detect_format(&wav_header())?;
        assert_eq!(metadata.format, AudioFormat::WavPcm);
        assert_eq!(metadata.format.as_str(), "wav");
        assert!(metadata.format.is_lossless());
        Ok(())
    }

    #[test]
    fn test_detect_flac_format() -> TestResult<()> {
        let metadata = detect_format(&flac_header())?;
        assert_eq!(metadata.format, AudioFormat::Flac);
        assert_eq!(metadata.format.as_str(), "flac");
        assert!(metadata.format.is_lossless());
        Ok(())
    }

    #[test]
    fn test_detect_mp3_format() -> TestResult<()> {
        let metadata = detect_format(&mp3_header())?;
        assert_eq!(metadata.format, AudioFormat::Mp3);
        assert_eq!(metadata.format.as_str(), "mp3");
        assert!(!metadata.format.is_lossless());
        Ok(())
    }

    #[test]
    fn test_detect_webm_format() -> TestResult<()> {
        let metadata = detect_format(&webm_header())?;
        assert_eq!(metadata.format, AudioFormat::WebM);
        assert_eq!(metadata.format.as_str(), "webm");
        Ok(())
    }

    #[test]
    fn test_detect_aac_format() -> TestResult<()> {
        let metadata = detect_format(&aac_header())?;
        assert_eq!(metadata.format, AudioFormat::Aac);
        assert_eq!(metadata.format.as_str(), "aac");
        assert!(!metadata.format.is_lossless());
        Ok(())
    }

    #[test]
    fn test_detect_generic_mp4_brands_as_aac() -> TestResult<()> {
        for brand in [b"mp42", b"isom"] {
            let mut header = aac_header();
            header[8..12].copy_from_slice(brand);
            let metadata = detect_format(&header)?;
            assert_eq!(metadata.format, AudioFormat::Aac);
        }
        Ok(())
    }

    #[test]
    fn test_fast_path_returns_format_only_metadata() -> TestResult<()> {
        // Magic-byte detection identifies the format but parses no header fields.
        let metadata = detect_format(&wav_header())?;
        assert_eq!(metadata, AudioMetadata::format_only(AudioFormat::WavPcm));
        Ok(())
    }

    #[test]
    fn test_detector_constructors() {
        let detector = FormatDetector::new();
        assert_eq!(format!("{detector:?}"), "FormatDetector");
        assert_eq!(format!("{:?}", FormatDetector::default()), "FormatDetector");
    }

    // Negative path tests
    #[test]
    fn test_reject_empty_payload() {
        let result = FormatDetector::detect(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_reject_too_short_payload() {
        let result = FormatDetector::detect(&[0xFF, 0xFE]); // Only 2 bytes
        assert!(result.is_err());
    }

    #[test]
    fn test_reject_random_bytes() {
        let random_data = vec![0xDE, 0xAD, 0xBE, 0xEF, 0xCA, 0xFE, 0xBA, 0xBE];
        let result = FormatDetector::detect(&random_data);
        assert!(result.is_err());
    }

    #[test]
    fn test_reject_truncated_wav_header() {
        let truncated = b"RIFF".to_vec(); // Missing size + WAVE
        let result = FormatDetector::detect(&truncated);
        assert!(result.is_err());
    }

    #[test]
    fn test_reject_mismatched_riff_signature() {
        let mut bad_wav = Vec::new();
        bad_wav.extend_from_slice(b"RIFF");
        bad_wav.extend_from_slice(&36u32.to_le_bytes());
        bad_wav.extend_from_slice(b"AVI "); // Wrong signature (should be WAVE)
        let result = FormatDetector::detect(&bad_wav);
        assert!(result.is_err());
    }

    // Edge case tests
    #[test]
    fn test_handle_exact_minimum_length() -> TestResult<()> {
        let flac_minimal = b"fLaC".to_vec(); // Exactly 4 bytes
        let metadata = detect_format(&flac_minimal)?;
        assert_eq!(metadata.format, AudioFormat::Flac);
        Ok(())
    }

    #[test]
    fn test_handle_large_payload_prefix() -> TestResult<()> {
        let mut large_payload = wav_header();
        large_payload.extend(vec![0u8; 1024 * 1024]); // 1MB of silence
        let metadata = detect_format(&large_payload)?;
        assert_eq!(metadata.format, AudioFormat::WavPcm);
        Ok(())
    }

    // Property tests
    #[test]
    fn test_format_display_matches_as_str() {
        let formats = [
            AudioFormat::WavPcm,
            AudioFormat::Flac,
            AudioFormat::Mp3,
            AudioFormat::Opus,
            AudioFormat::WebM,
            AudioFormat::Aac,
        ];
        for format in &formats {
            assert_eq!(format.to_string(), format.as_str());
        }
    }

    #[test]
    fn test_lossless_formats_identified() {
        assert!(AudioFormat::WavPcm.is_lossless());
        assert!(AudioFormat::Flac.is_lossless());
        assert!(!AudioFormat::Mp3.is_lossless());
        assert!(!AudioFormat::Opus.is_lossless());
        assert!(!AudioFormat::WebM.is_lossless());
        assert!(!AudioFormat::Aac.is_lossless());
    }

    #[test]
    fn test_container_formats_identified() {
        assert!(AudioFormat::WavPcm.is_container_format());
        assert!(AudioFormat::Opus.is_container_format());
        assert!(AudioFormat::WebM.is_container_format());
        assert!(AudioFormat::Aac.is_container_format());
        assert!(!AudioFormat::Flac.is_container_format());
        assert!(!AudioFormat::Mp3.is_container_format());
    }

    #[test]
    fn test_metadata_format_only_constructor() {
        let metadata = AudioMetadata::format_only(AudioFormat::Mp3);
        assert_eq!(metadata.format, AudioFormat::Mp3);
        assert_eq!(metadata.channels, None);
        assert_eq!(metadata.sample_rate, None);
        assert_eq!(metadata.bit_depth, None);
        assert_eq!(metadata.duration_sec, None);
    }

    #[test]
    fn test_metadata_with_properties_constructor() {
        let metadata = AudioMetadata::with_properties(AudioFormat::WavPcm, 2, 44100, Some(16));
        assert_eq!(metadata.format, AudioFormat::WavPcm);
        assert_eq!(metadata.channels, Some(2));
        assert_eq!(metadata.sample_rate, Some(44100));
        assert_eq!(metadata.bit_depth, Some(16));
        assert_eq!(metadata.duration_sec, None);
    }
}