Skip to main content

whisper_macos_cli/audio/
decode.rs

1use std::io::{Cursor, Read, Seek, SeekFrom};
2use std::path::Path;
3
4use symphonia::core::audio::{AudioBufferRef, Signal};
5use symphonia::core::codecs::{CODEC_TYPE_NULL, DecoderOptions};
6use symphonia::core::formats::FormatOptions;
7use symphonia::core::io::MediaSourceStream;
8use symphonia::core::meta::MetadataOptions;
9use symphonia::core::probe::Hint;
10
11use crate::video::ffmpeg::{FfmpegRunner, RealFfmpeg, TempOutputGuard};
12use crate::video::is_video_magic_bytes;
13
/// Fallback Opus pre-skip, in samples, used when the stream's identification
/// header does not provide one.
const OPUS_PRESKIP_SAMPLES: usize = 3_840;
/// Largest payload accepted from stdin, in bytes (2 GiB).
const STDIN_MAX_BYTES: u64 = 2 << 30;
16
/// Decoded PCM audio held entirely in memory.
#[derive(Debug, Clone, PartialEq)]
pub struct PcmData {
    /// Signed 16-bit samples, interleaved frame by frame across `channels`.
    pub samples: Vec<i16>,
    /// Sample rate of `samples`, in Hz.
    pub sample_rate: u32,
    /// Number of interleaved channels in `samples`.
    pub channels: usize,
}

impl PcmData {
    /// Playback length in seconds.
    ///
    /// Returns `0.0` when either `sample_rate` or `channels` is zero, so a
    /// malformed value never produces a division by zero.
    pub fn duration_seconds(&self) -> f64 {
        if self.sample_rate == 0 || self.channels == 0 {
            return 0.0;
        }
        // `samples` is interleaved, so one second holds rate * channels values.
        self.samples.len() as f64 / (self.sample_rate as f64 * self.channels as f64)
    }
}
31
32pub fn decode_file(path: &Path) -> Result<PcmData, crate::error::Error> {
33    decode_file_inner(path, &RealFfmpeg::new("ffmpeg"), true)
34}
35
36/// Decode a file with optional ffmpeg fallback for unsupported audio
37/// formats (notably OGG/Opus from WhatsApp) and for video containers
38/// (MP4, MOV, MKV, AVI, WebM, M4V).
39///
40/// # Arguments
41///
42/// * `path` — input file path
43/// * `runner` — ffmpeg implementation (real or mock)
44/// * `auto_fallback` — if `true`, transparently use ffmpeg when the
45///   native decode fails or the input is a video container
46pub fn decode_file_with_runner(
47    path: &Path,
48    runner: &dyn FfmpegRunner,
49    auto_fallback: bool,
50) -> Result<PcmData, crate::error::Error> {
51    decode_file_inner(path, runner, auto_fallback)
52}
53
54/// Internal entry point that performs the actual decode logic.
55///
56/// Split from `decode_file` to avoid recursion when `decode_via_ffmpeg`
57/// calls back into this function with the temp WAV (which is a regular
58/// audio file, so auto_fallback MUST be disabled for the inner call).
59fn decode_file_inner(
60    path: &Path,
61    runner: &dyn FfmpegRunner,
62    auto_fallback: bool,
63) -> Result<PcmData, crate::error::Error> {
64    let file = std::fs::File::open(path).map_err(|e| {
65        if e.kind() == std::io::ErrorKind::NotFound {
66            crate::error::Error::InputNotFound {
67                path: path.display().to_string(),
68            }
69        } else {
70            crate::error::Error::Io(e)
71        }
72    })?;
73
74    let mut header = [0u8; 12];
75    let header_len = match (&file).read(&mut header) {
76        Ok(n) => n,
77        Err(e) => return Err(crate::error::Error::Io(e)),
78    };
79    if let Err(e) = (&file).seek(SeekFrom::Start(0)) {
80        return Err(crate::error::Error::Io(e));
81    }
82
83    // Branch 1: Video container detected by magic bytes — must extract
84    // audio via ffmpeg. We do this BEFORE attempting native decode
85    // because symphonia will misidentify the format.
86    if header_len >= 4 && is_video_magic_bytes(&header[..header_len]) {
87        if !auto_fallback {
88            return Err(crate::error::Error::UnsupportedVideoFormat {
89                format: path
90                    .extension()
91                    .and_then(|e| e.to_str())
92                    .unwrap_or("unknown")
93                    .to_string(),
94            });
95        }
96        tracing::info!(
97            path = %path.display(),
98            "video container detected, routing through ffmpeg"
99        );
100        return decode_via_ffmpeg(path, runner);
101    }
102
103    // Branch 2: OGG/Opus magic — try native first, then OGG fallback.
104    if header_len >= 4 && is_ogg_opus_magic(&header[..header_len]) {
105        return decode_ogg_opus(file);
106    }
107
108    let source = MediaSourceStream::new(Box::new(file), Default::default());
109
110    let mut hint = Hint::new();
111    if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
112        hint.with_extension(ext);
113    }
114
115    match decode_stream(source, hint) {
116        Ok(pcm) => Ok(pcm),
117        Err(crate::error::Error::AudioDecode(ref e))
118            if e.to_string().contains("unsupported codec") =>
119        {
120            tracing::info!("symphonia unsupported codec, trying OGG/Opus fallback");
121            let file2 = std::fs::File::open(path).map_err(|e| {
122                if e.kind() == std::io::ErrorKind::NotFound {
123                    crate::error::Error::InputNotFound {
124                        path: path.display().to_string(),
125                    }
126                } else {
127                    crate::error::Error::Io(e)
128                }
129            })?;
130            decode_ogg_opus(file2)
131        }
132        Err(e) => {
133            // Branch 3: Native decode failed for a non-OGG reason
134            // (symphonia bug with OGG/Opus from WhatsApp is a known
135            // issue). If the file looks like OGG/Opus, auto-fallback to
136            // ffmpeg, which handles the codec correctly.
137            if auto_fallback && is_ogg_opus_magic(&header[..header_len.min(4)]) {
138                return decode_via_ffmpeg(path, runner);
139            }
140            // Branch 4: Other formats — try ffmpeg as last resort
141            // (covers HLS, MPEG-TS, exotic MP3 variants, etc.)
142            if auto_fallback && runner.is_available() {
143                tracing::warn!(
144                    path = %path.display(),
145                    error = %e,
146                    "native decode failed, attempting ffmpeg fallback"
147                );
148                return decode_via_ffmpeg(path, runner);
149            }
150            Err(e)
151        }
152    }
153}
154
155/// Extract audio from `input` to a temp WAV via ffmpeg, then decode the
156/// WAV with the native pipeline. The temp file is cleaned up via
157/// [`TempOutputGuard`].
158///
159/// # Recursion note
160///
161/// The inner call uses `auto_fallback=false` to prevent infinite
162/// recursion: the temp WAV is a normal audio file, not a video.
163fn decode_via_ffmpeg(
164    input: &Path,
165    runner: &dyn FfmpegRunner,
166) -> Result<PcmData, crate::error::Error> {
167    let result = runner.extract_audio_wav(input)?;
168    let wav_path = result.output_path;
169    let _guard = TempOutputGuard::new(wav_path.clone());
170    // The inner call MUST use auto_fallback=false to prevent recursion
171    // (the temp WAV is not a video and ffmpeg must not be called again).
172    decode_file_inner(&wav_path, runner, false)
173}
174
175pub fn decode_stdin(format_hint: Option<&str>) -> Result<PcmData, crate::error::Error> {
176    let mut buf = Vec::new();
177    let mut handle = std::io::stdin().take(STDIN_MAX_BYTES + 1);
178    handle
179        .read_to_end(&mut buf)
180        .map_err(crate::error::Error::Io)?;
181
182    if buf.is_empty() {
183        return Err(crate::error::Error::NoInput);
184    }
185    if buf.len() as u64 > STDIN_MAX_BYTES {
186        return Err(crate::error::Error::Config(format!(
187            "stdin input exceeds maximum size of {STDIN_MAX_BYTES} bytes"
188        )));
189    }
190
191    if is_ogg_opus_magic(&buf[..buf.len().min(12)]) {
192        return decode_ogg_opus(Cursor::new(buf));
193    }
194
195    let source = MediaSourceStream::new(Box::new(Cursor::new(buf.clone())), Default::default());
196
197    let mut hint = Hint::new();
198    if let Some(fmt) = format_hint {
199        hint.with_extension(fmt);
200    }
201
202    match decode_stream(source, hint) {
203        Ok(pcm) => Ok(pcm),
204        Err(crate::error::Error::AudioDecode(ref e))
205            if e.to_string().contains("unsupported codec") =>
206        {
207            tracing::info!("symphonia unsupported codec, trying OGG/Opus fallback");
208            decode_ogg_opus(Cursor::new(buf))
209        }
210        Err(e) => Err(e),
211    }
212}
213
/// Report whether `header` begins with the four-byte Ogg capture pattern
/// `OggS`.
///
/// Only the container magic is inspected, so any Ogg stream matches, not
/// just Opus. Buffers shorter than four bytes never match.
pub fn is_ogg_opus_magic(header: &[u8]) -> bool {
    header.starts_with(b"OggS")
}
220
221fn decode_stream(source: MediaSourceStream, hint: Hint) -> Result<PcmData, crate::error::Error> {
222    let probed = symphonia::default::get_probe()
223        .format(
224            &hint,
225            source,
226            &FormatOptions::default(),
227            &MetadataOptions::default(),
228        )
229        .map_err(|e| crate::error::Error::AudioDecode(anyhow::anyhow!("probe failed: {e}")))?;
230
231    let mut reader = probed.format;
232
233    let track = reader
234        .tracks()
235        .iter()
236        .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
237        .ok_or_else(|| crate::error::Error::AudioDecode(anyhow::anyhow!("no audio track found")))?;
238
239    let track_id = track.id;
240    let codec_params = track.codec_params.clone();
241
242    let sample_rate = codec_params
243        .sample_rate
244        .ok_or_else(|| crate::error::Error::AudioDecode(anyhow::anyhow!("unknown sample rate")))?;
245
246    let channels = codec_params.channels.map(|c| c.count()).unwrap_or(2);
247
248    let mut decoder = symphonia::default::get_codecs()
249        .make(&codec_params, &DecoderOptions::default())
250        .map_err(|e| crate::error::Error::AudioDecode(anyhow::anyhow!("codec init failed: {e}")))?;
251
252    let mut all_samples: Vec<i16> = Vec::new();
253
254    loop {
255        let packet = match reader.next_packet() {
256            Ok(p) => p,
257            Err(symphonia::core::errors::Error::IoError(e))
258                if e.kind() == std::io::ErrorKind::UnexpectedEof =>
259            {
260                break;
261            }
262            Err(_) => continue,
263        };
264
265        if packet.track_id() != track_id {
266            continue;
267        }
268
269        let audio_buf = match decoder.decode(&packet) {
270            Ok(buf) => buf,
271            Err(_) => continue,
272        };
273
274        extract_i16_samples(&audio_buf, &mut all_samples);
275    }
276
277    if all_samples.is_empty() {
278        return Err(crate::error::Error::AudioDecode(anyhow::anyhow!(
279            "no audio samples decoded"
280        )));
281    }
282
283    Ok(PcmData {
284        samples: all_samples,
285        sample_rate,
286        channels,
287    })
288}
289
/// Downmix interleaved PCM to a single channel by averaging each frame.
///
/// * `channels == 1` returns a copy of `samples`.
/// * `channels == 0` returns an empty vector (there are no frames to mix)
///   instead of panicking on a division by zero.
/// * Trailing samples that do not form a complete frame are dropped.
///
/// The average uses integer division, which truncates toward zero.
pub fn to_mono(samples: &[i16], channels: usize) -> Vec<i16> {
    match channels {
        0 => Vec::new(),
        1 => samples.to_vec(),
        _ => samples
            .chunks_exact(channels)
            .map(|frame| {
                // i64 accumulator: cannot overflow for any slice length.
                let sum: i64 = frame.iter().map(|&s| i64::from(s)).sum();
                // The mean of i16 values always fits back into an i16.
                (sum / channels as i64) as i16
            })
            .collect(),
    }
}
309
/// Convert signed 16-bit samples to `f32` by dividing each by 32768, giving
/// values in `[-1.0, 1.0)`.
pub fn i16_to_f32(samples: &[i16]) -> Vec<f32> {
    const FULL_SCALE: f32 = 32768.0;

    let mut normalized = Vec::with_capacity(samples.len());
    for &sample in samples {
        normalized.push(f32::from(sample) / FULL_SCALE);
    }
    normalized
}
313
314fn decode_ogg_opus<R: Read + Seek>(mut reader: R) -> Result<PcmData, crate::error::Error> {
315    use ogg::reading::PacketReader;
316
317    let mut ogg_reader = PacketReader::new(&mut reader);
318    let mut channels = 1u8;
319    let mut pre_skip = OPUS_PRESKIP_SAMPLES;
320    let mut header_packets = 0u8;
321
322    while header_packets < 2 {
323        let pkt = ogg_reader
324            .read_packet_expected()
325            .map_err(|e| crate::error::Error::AudioDecode(anyhow::anyhow!("ogg header: {e}")))?;
326
327        if header_packets == 0 && pkt.data.len() >= 16 && &pkt.data[..8] == b"OpusHead" {
328            channels = pkt.data[9];
329            pre_skip = u32::from_le_bytes([pkt.data[10], pkt.data[11], pkt.data[12], pkt.data[13]])
330                as usize;
331        }
332        header_packets += 1;
333    }
334
335    let channels_usize = channels.max(1) as usize;
336    let output_rate = 48000;
337
338    let mut decoder = opus_decoder::OpusDecoder::new(output_rate, channels_usize)
339        .map_err(|e| crate::error::Error::AudioDecode(anyhow::anyhow!("opus init: {e:?}")))?;
340
341    let max_frame = opus_decoder::OpusDecoder::MAX_FRAME_SIZE_48K;
342    let mut pcm_buf = vec![0i16; max_frame * channels_usize];
343    let mut all_samples: Vec<i16> = Vec::new();
344    let mut samples_to_skip = pre_skip;
345
346    loop {
347        let pkt = match ogg_reader.read_packet() {
348            Ok(Some(p)) => p,
349            Ok(None) => break,
350            Err(_) => continue,
351        };
352
353        match decoder.decode(&pkt.data, &mut pcm_buf, false) {
354            Ok(samples_per_channel) => {
355                let total = samples_per_channel * channels_usize;
356                let slice = &pcm_buf[..total];
357
358                if samples_to_skip >= total {
359                    samples_to_skip -= total;
360                } else if samples_to_skip > 0 {
361                    let kept = &slice[samples_to_skip..];
362                    all_samples.extend_from_slice(kept);
363                    samples_to_skip = 0;
364                } else {
365                    all_samples.extend_from_slice(slice);
366                }
367            }
368            Err(_) => continue,
369        }
370    }
371
372    if all_samples.is_empty() {
373        return Err(crate::error::Error::AudioDecode(anyhow::anyhow!(
374            "no audio samples decoded from OGG/Opus"
375        )));
376    }
377
378    tracing::info!(
379        samples = all_samples.len(),
380        channels = channels_usize,
381        preskip_discarded = pre_skip,
382        "OGG/Opus decoded via fallback"
383    );
384
385    Ok(PcmData {
386        samples: all_samples,
387        sample_rate: output_rate,
388        channels: channels_usize,
389    })
390}
391
392fn extract_i16_samples(buffer: &AudioBufferRef, dest: &mut Vec<i16>) {
393    match buffer {
394        AudioBufferRef::U8(buf) => {
395            let ch = buf.spec().channels.count();
396            let frames = buf.frames();
397            dest.reserve(frames * ch);
398            for f in 0..frames {
399                for c in 0..ch {
400                    dest.push(((buf.chan(c)[f] as i32 - 128) * 256) as i16);
401                }
402            }
403        }
404        AudioBufferRef::S16(buf) => {
405            let ch = buf.spec().channels.count();
406            let frames = buf.frames();
407            dest.reserve(frames * ch);
408            for f in 0..frames {
409                for c in 0..ch {
410                    dest.push(buf.chan(c)[f]);
411                }
412            }
413        }
414        AudioBufferRef::S32(buf) => {
415            let ch = buf.spec().channels.count();
416            let frames = buf.frames();
417            dest.reserve(frames * ch);
418            for f in 0..frames {
419                for c in 0..ch {
420                    dest.push((buf.chan(c)[f] >> 16) as i16);
421                }
422            }
423        }
424        AudioBufferRef::F32(buf) => {
425            let ch = buf.spec().channels.count();
426            let frames = buf.frames();
427            dest.reserve(frames * ch);
428            for f in 0..frames {
429                for c in 0..ch {
430                    let v = buf.chan(c)[f].clamp(-1.0, 1.0);
431                    dest.push((v * 32767.0) as i16);
432                }
433            }
434        }
435        AudioBufferRef::F64(buf) => {
436            let ch = buf.spec().channels.count();
437            let frames = buf.frames();
438            dest.reserve(frames * ch);
439            for f in 0..frames {
440                for c in 0..ch {
441                    let v = buf.chan(c)[f].clamp(-1.0, 1.0);
442                    dest.push((v * 32767.0) as i16);
443                }
444            }
445        }
446        _ => {}
447    }
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for floating-point comparisons.
    const F64_TOLERANCE: f64 = 0.001;
    const F32_TOLERANCE: f32 = 0.001;

    /// Build a `PcmData` holding `sample_count` samples of value `fill`.
    fn pcm_of(fill: i16, sample_count: usize, sample_rate: u32, channels: usize) -> PcmData {
        PcmData {
            samples: vec![fill; sample_count],
            sample_rate,
            channels,
        }
    }

    /// Assert that two `f32` values agree within `F32_TOLERANCE`.
    fn assert_close_f32(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < F32_TOLERANCE);
    }

    // --- to_mono -----------------------------------------------------------

    #[test]
    fn to_mono_passthrough_single_channel() {
        let input = vec![100i16, 200, 300];
        assert_eq!(to_mono(&input, 1), input);
    }

    #[test]
    fn to_mono_averages_stereo() {
        let input = [100i16, 200, 300, 400];
        assert_eq!(to_mono(&input, 2), vec![150, 350]);
    }

    // --- i16_to_f32 --------------------------------------------------------

    #[test]
    fn i16_to_f32_converts_correctly() {
        let converted = i16_to_f32(&[0i16, 32767, -32768]);
        assert_close_f32(converted[0], 0.0);
        assert_close_f32(converted[1], 1.0);
        assert_close_f32(converted[2], -1.0);
    }

    // --- is_ogg_opus_magic -------------------------------------------------

    #[test]
    fn opus_magic_detected() {
        assert!(is_ogg_opus_magic(b"OggS\x00\x02\x00\x00\x00\x00\x00\x00"));
    }

    #[test]
    fn non_opus_not_detected() {
        assert!(!is_ogg_opus_magic(b"RIFF\x00\x00\x00\x00"));
    }

    #[test]
    fn short_buffer_not_detected() {
        assert!(!is_ogg_opus_magic(b"Og"));
    }

    // --- PcmData::duration_seconds -----------------------------------------

    #[test]
    fn pcm_data_duration_computed_correctly() {
        let two_seconds_mono = pcm_of(0, 16000 * 2, 16000, 1);
        assert!((two_seconds_mono.duration_seconds() - 2.0).abs() < F64_TOLERANCE);
    }

    #[test]
    fn pcm_data_duration_with_zero_sample_rate_is_zero() {
        let no_rate = pcm_of(100, 1000, 0, 1);
        assert_eq!(no_rate.duration_seconds(), 0.0);
    }

    #[test]
    fn pcm_data_duration_with_zero_channels_is_zero() {
        let no_channels = pcm_of(100, 1000, 16000, 0);
        assert_eq!(no_channels.duration_seconds(), 0.0);
    }

    #[test]
    fn pcm_data_duration_with_empty_samples_is_zero() {
        let silent = PcmData {
            samples: Vec::new(),
            sample_rate: 16000,
            channels: 1,
        };
        assert_eq!(silent.duration_seconds(), 0.0);
    }

    #[test]
    fn pcm_data_duration_stereo_divides_by_channels() {
        let two_seconds_stereo = pcm_of(0, 16000 * 2 * 2, 16000, 2);
        assert!((two_seconds_stereo.duration_seconds() - 2.0).abs() < F64_TOLERANCE);
    }

    // --- to_mono edge cases ------------------------------------------------

    #[test]
    fn to_mono_handles_empty_input() {
        for channels in [1usize, 2] {
            assert!(to_mono(&[], channels).is_empty());
        }
    }

    #[test]
    fn to_mono_six_channels_averages() {
        let one_frame = [100i16, 200, 300, 400, 500, 600];
        let mixed = to_mono(&one_frame, 6);
        assert_eq!(mixed.len(), 1);
        assert_eq!(mixed[0], 350);
    }

    #[test]
    fn to_mono_quad_channel_averages() {
        let one_frame = [100i16, 200, 300, 400];
        assert_eq!(to_mono(&one_frame, 4), vec![250]);
    }

    // --- i16_to_f32 edge cases ---------------------------------------------

    #[test]
    fn i16_to_f32_handles_min_max_boundary() {
        let converted = i16_to_f32(&[i16::MIN, i16::MAX, 0i16]);
        assert_close_f32(converted[0], -1.0);
        assert_close_f32(converted[1], 1.0);
        assert_close_f32(converted[2], 0.0);
    }

    #[test]
    fn i16_to_f32_handles_empty_input() {
        assert!(i16_to_f32(&[]).is_empty());
    }

    // --- is_ogg_opus_magic edge cases --------------------------------------

    #[test]
    fn ogg_opus_magic_rejects_truncated_headers() {
        let truncated: [&[u8]; 3] = [b"Og", b"Ogg", b""];
        for header in truncated {
            assert!(!is_ogg_opus_magic(header));
        }
    }

    #[test]
    fn ogg_opus_magic_accepts_full_header() {
        assert!(is_ogg_opus_magic(
            b"OggS\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
        ));
    }
}