Skip to main content

whisper_macos_cli/audio/
decode.rs

1use std::io::{Cursor, Read, Seek, SeekFrom};
2use std::path::Path;
3
4use symphonia::core::audio::{AudioBufferRef, Signal};
5use symphonia::core::codecs::{CODEC_TYPE_NULL, DecoderOptions};
6use symphonia::core::formats::FormatOptions;
7use symphonia::core::io::MediaSourceStream;
8use symphonia::core::meta::MetadataOptions;
9use symphonia::core::probe::Hint;
10
/// Number of leading decoded samples `decode_ogg_opus` discards when the first Ogg packet
/// does not carry a usable `OpusHead` header (3840 samples is 80 ms of one channel at 48 kHz).
const OPUS_PRESKIP_SAMPLES: usize = 3840;
/// Largest stdin payload `decode_stdin` accepts (2 GiB). The reader is capped at one byte more
/// than this so that oversized input can be detected and rejected.
const STDIN_MAX_BYTES: u64 = 2 * 1024 * 1024 * 1024;
13
/// Decoded audio held as interleaved signed 16-bit PCM.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PcmData {
    /// Interleaved samples: one value per channel for each frame, frames stored back to back.
    pub samples: Vec<i16>,
    /// Sample rate in Hz (frames per second).
    pub sample_rate: u32,
    /// Number of interleaved channels in `samples`.
    pub channels: usize,
}

impl PcmData {
    /// Returns the length of the audio in seconds.
    ///
    /// The result is `samples.len() / (sample_rate * channels)`. A zero sample rate or a zero
    /// channel count yields `0.0` instead of dividing by zero.
    pub fn duration_seconds(&self) -> f64 {
        if self.sample_rate == 0 || self.channels == 0 {
            return 0.0;
        }
        self.samples.len() as f64 / (self.sample_rate as f64 * self.channels as f64)
    }
}
28
29pub fn decode_file(path: &Path) -> Result<PcmData, crate::error::Error> {
30    let file = std::fs::File::open(path).map_err(|e| {
31        if e.kind() == std::io::ErrorKind::NotFound {
32            crate::error::Error::InputNotFound {
33                path: path.display().to_string(),
34            }
35        } else {
36            crate::error::Error::Io(e)
37        }
38    })?;
39
40    let mut header = [0u8; 12];
41    let header_len = match (&file).read(&mut header) {
42        Ok(n) => n,
43        Err(e) => return Err(crate::error::Error::Io(e)),
44    };
45    if let Err(e) = (&file).seek(SeekFrom::Start(0)) {
46        return Err(crate::error::Error::Io(e));
47    }
48
49    if header_len >= 4 && is_ogg_opus_magic(&header[..header_len]) {
50        return decode_ogg_opus(file);
51    }
52
53    let source = MediaSourceStream::new(Box::new(file), Default::default());
54
55    let mut hint = Hint::new();
56    if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
57        hint.with_extension(ext);
58    }
59
60    match decode_stream(source, hint) {
61        Ok(pcm) => Ok(pcm),
62        Err(crate::error::Error::AudioDecode(ref e))
63            if e.to_string().contains("unsupported codec") =>
64        {
65            tracing::info!("symphonia unsupported codec, trying OGG/Opus fallback");
66            let file2 = std::fs::File::open(path).map_err(|e| {
67                if e.kind() == std::io::ErrorKind::NotFound {
68                    crate::error::Error::InputNotFound {
69                        path: path.display().to_string(),
70                    }
71                } else {
72                    crate::error::Error::Io(e)
73                }
74            })?;
75            decode_ogg_opus(file2)
76        }
77        Err(e) => Err(e),
78    }
79}
80
81pub fn decode_stdin(format_hint: Option<&str>) -> Result<PcmData, crate::error::Error> {
82    let mut buf = Vec::new();
83    let mut handle = std::io::stdin().take(STDIN_MAX_BYTES + 1);
84    handle
85        .read_to_end(&mut buf)
86        .map_err(crate::error::Error::Io)?;
87
88    if buf.is_empty() {
89        return Err(crate::error::Error::NoInput);
90    }
91    if buf.len() as u64 > STDIN_MAX_BYTES {
92        return Err(crate::error::Error::Config(format!(
93            "stdin input exceeds maximum size of {STDIN_MAX_BYTES} bytes"
94        )));
95    }
96
97    if is_ogg_opus_magic(&buf[..buf.len().min(12)]) {
98        return decode_ogg_opus(Cursor::new(buf));
99    }
100
101    let source = MediaSourceStream::new(Box::new(Cursor::new(buf.clone())), Default::default());
102
103    let mut hint = Hint::new();
104    if let Some(fmt) = format_hint {
105        hint.with_extension(fmt);
106    }
107
108    match decode_stream(source, hint) {
109        Ok(pcm) => Ok(pcm),
110        Err(crate::error::Error::AudioDecode(ref e))
111            if e.to_string().contains("unsupported codec") =>
112        {
113            tracing::info!("symphonia unsupported codec, trying OGG/Opus fallback");
114            decode_ogg_opus(Cursor::new(buf))
115        }
116        Err(e) => Err(e),
117    }
118}
119
/// Returns `true` when `header` begins with the four-byte Ogg capture pattern `OggS`.
///
/// Only the container magic is inspected, so any Ogg stream matches, not just Opus ones.
/// Inputs shorter than four bytes never match.
pub fn is_ogg_opus_magic(header: &[u8]) -> bool {
    header.starts_with(b"OggS")
}
126
127fn decode_stream(source: MediaSourceStream, hint: Hint) -> Result<PcmData, crate::error::Error> {
128    let probed = symphonia::default::get_probe()
129        .format(
130            &hint,
131            source,
132            &FormatOptions::default(),
133            &MetadataOptions::default(),
134        )
135        .map_err(|e| crate::error::Error::AudioDecode(anyhow::anyhow!("probe failed: {e}")))?;
136
137    let mut reader = probed.format;
138
139    let track = reader
140        .tracks()
141        .iter()
142        .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
143        .ok_or_else(|| crate::error::Error::AudioDecode(anyhow::anyhow!("no audio track found")))?;
144
145    let track_id = track.id;
146    let codec_params = track.codec_params.clone();
147
148    let sample_rate = codec_params
149        .sample_rate
150        .ok_or_else(|| crate::error::Error::AudioDecode(anyhow::anyhow!("unknown sample rate")))?;
151
152    let channels = codec_params.channels.map(|c| c.count()).unwrap_or(2);
153
154    let mut decoder = symphonia::default::get_codecs()
155        .make(&codec_params, &DecoderOptions::default())
156        .map_err(|e| crate::error::Error::AudioDecode(anyhow::anyhow!("codec init failed: {e}")))?;
157
158    let mut all_samples: Vec<i16> = Vec::new();
159
160    loop {
161        let packet = match reader.next_packet() {
162            Ok(p) => p,
163            Err(symphonia::core::errors::Error::IoError(e))
164                if e.kind() == std::io::ErrorKind::UnexpectedEof =>
165            {
166                break;
167            }
168            Err(_) => continue,
169        };
170
171        if packet.track_id() != track_id {
172            continue;
173        }
174
175        let audio_buf = match decoder.decode(&packet) {
176            Ok(buf) => buf,
177            Err(_) => continue,
178        };
179
180        extract_i16_samples(&audio_buf, &mut all_samples);
181    }
182
183    if all_samples.is_empty() {
184        return Err(crate::error::Error::AudioDecode(anyhow::anyhow!(
185            "no audio samples decoded"
186        )));
187    }
188
189    Ok(PcmData {
190        samples: all_samples,
191        sample_rate,
192        channels,
193    })
194}
195
/// Downmixes interleaved PCM to a single channel by averaging the channels of each frame.
///
/// * `samples` - interleaved samples, `channels` values per frame.
/// * `channels` - number of interleaved channels.
///
/// Returns one sample per complete frame; a trailing partial frame is ignored. With one channel
/// the input is returned as a copy. With zero channels there are no frames, so the result is
/// empty (rather than panicking on a division by zero). The average truncates toward zero.
pub fn to_mono(samples: &[i16], channels: usize) -> Vec<i16> {
    match channels {
        0 => Vec::new(),
        1 => samples.to_vec(),
        _ => samples
            .chunks_exact(channels)
            .map(|frame| {
                // Accumulate in i64 so that no channel count can overflow the sum.
                let sum: i64 = frame.iter().map(|&s| i64::from(s)).sum();
                // The mean of i16 values always lies within the i16 range, so this cast is exact.
                (sum / channels as i64) as i16
            })
            .collect(),
    }
}
215
/// Converts 16-bit PCM samples to `f32` by dividing each one by 32768.
///
/// `i16::MIN` maps to exactly `-1.0`; `i16::MAX` maps to just under `1.0`.
pub fn i16_to_f32(samples: &[i16]) -> Vec<f32> {
    let mut scaled = Vec::with_capacity(samples.len());
    for &sample in samples {
        scaled.push(f32::from(sample) / 32768.0);
    }
    scaled
}
219
220fn decode_ogg_opus<R: Read + Seek>(mut reader: R) -> Result<PcmData, crate::error::Error> {
221    use ogg::reading::PacketReader;
222
223    let mut ogg_reader = PacketReader::new(&mut reader);
224    let mut channels = 1u8;
225    let mut pre_skip = OPUS_PRESKIP_SAMPLES;
226    let mut header_packets = 0u8;
227
228    while header_packets < 2 {
229        let pkt = ogg_reader
230            .read_packet_expected()
231            .map_err(|e| crate::error::Error::AudioDecode(anyhow::anyhow!("ogg header: {e}")))?;
232
233        if header_packets == 0 && pkt.data.len() >= 16 && &pkt.data[..8] == b"OpusHead" {
234            channels = pkt.data[9];
235            pre_skip = u32::from_le_bytes([pkt.data[10], pkt.data[11], pkt.data[12], pkt.data[13]])
236                as usize;
237        }
238        header_packets += 1;
239    }
240
241    let channels_usize = channels.max(1) as usize;
242    let output_rate = 48000;
243
244    let mut decoder = opus_decoder::OpusDecoder::new(output_rate, channels_usize)
245        .map_err(|e| crate::error::Error::AudioDecode(anyhow::anyhow!("opus init: {e:?}")))?;
246
247    let max_frame = opus_decoder::OpusDecoder::MAX_FRAME_SIZE_48K;
248    let mut pcm_buf = vec![0i16; max_frame * channels_usize];
249    let mut all_samples: Vec<i16> = Vec::new();
250    let mut samples_to_skip = pre_skip;
251
252    loop {
253        let pkt = match ogg_reader.read_packet() {
254            Ok(Some(p)) => p,
255            Ok(None) => break,
256            Err(_) => continue,
257        };
258
259        match decoder.decode(&pkt.data, &mut pcm_buf, false) {
260            Ok(samples_per_channel) => {
261                let total = samples_per_channel * channels_usize;
262                let slice = &pcm_buf[..total];
263
264                if samples_to_skip >= total {
265                    samples_to_skip -= total;
266                } else if samples_to_skip > 0 {
267                    let kept = &slice[samples_to_skip..];
268                    all_samples.extend_from_slice(kept);
269                    samples_to_skip = 0;
270                } else {
271                    all_samples.extend_from_slice(slice);
272                }
273            }
274            Err(_) => continue,
275        }
276    }
277
278    if all_samples.is_empty() {
279        return Err(crate::error::Error::AudioDecode(anyhow::anyhow!(
280            "no audio samples decoded from OGG/Opus"
281        )));
282    }
283
284    tracing::info!(
285        samples = all_samples.len(),
286        channels = channels_usize,
287        preskip_discarded = pre_skip,
288        "OGG/Opus decoded via fallback"
289    );
290
291    Ok(PcmData {
292        samples: all_samples,
293        sample_rate: output_rate,
294        channels: channels_usize,
295    })
296}
297
298fn extract_i16_samples(buffer: &AudioBufferRef, dest: &mut Vec<i16>) {
299    match buffer {
300        AudioBufferRef::U8(buf) => {
301            let ch = buf.spec().channels.count();
302            let frames = buf.frames();
303            dest.reserve(frames * ch);
304            for f in 0..frames {
305                for c in 0..ch {
306                    dest.push(((buf.chan(c)[f] as i32 - 128) * 256) as i16);
307                }
308            }
309        }
310        AudioBufferRef::S16(buf) => {
311            let ch = buf.spec().channels.count();
312            let frames = buf.frames();
313            dest.reserve(frames * ch);
314            for f in 0..frames {
315                for c in 0..ch {
316                    dest.push(buf.chan(c)[f]);
317                }
318            }
319        }
320        AudioBufferRef::S32(buf) => {
321            let ch = buf.spec().channels.count();
322            let frames = buf.frames();
323            dest.reserve(frames * ch);
324            for f in 0..frames {
325                for c in 0..ch {
326                    dest.push((buf.chan(c)[f] >> 16) as i16);
327                }
328            }
329        }
330        AudioBufferRef::F32(buf) => {
331            let ch = buf.spec().channels.count();
332            let frames = buf.frames();
333            dest.reserve(frames * ch);
334            for f in 0..frames {
335                for c in 0..ch {
336                    let v = buf.chan(c)[f].clamp(-1.0, 1.0);
337                    dest.push((v * 32767.0) as i16);
338                }
339            }
340        }
341        AudioBufferRef::F64(buf) => {
342            let ch = buf.spec().channels.count();
343            let frames = buf.frames();
344            dest.reserve(frames * ch);
345            for f in 0..frames {
346                for c in 0..ch {
347                    let v = buf.chan(c)[f].clamp(-1.0, 1.0);
348                    dest.push((v * 32767.0) as i16);
349                }
350            }
351        }
352        _ => {}
353    }
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Single-channel input is returned unchanged.
    #[test]
    fn to_mono_passthrough_single_channel() {
        let input = [100i16, 200, 300];
        assert_eq!(to_mono(&input, 1), input.to_vec());
    }

    /// Each stereo frame collapses to the average of its two samples.
    #[test]
    fn to_mono_averages_stereo() {
        let interleaved = [100i16, 200, 300, 400];
        let expected: Vec<i16> = vec![150, 350];
        assert_eq!(to_mono(&interleaved, 2), expected);
    }

    /// Zero and the extreme i16 values map to roughly 0.0, 1.0 and -1.0.
    #[test]
    fn i16_to_f32_converts_correctly() {
        let converted = i16_to_f32(&[0i16, 32767, -32768]);
        let tolerance = 0.001;
        assert!(converted[0].abs() < tolerance);
        assert!((converted[1] - 1.0).abs() < tolerance);
        assert!((converted[2] + 1.0).abs() < tolerance);
    }

    /// A buffer that starts with the Ogg capture pattern is recognised.
    #[test]
    fn opus_magic_detected() {
        let header: &[u8] = b"OggS\x00\x02\x00\x00\x00\x00\x00\x00";
        assert!(is_ogg_opus_magic(header));
    }

    /// A RIFF/WAV header is not mistaken for Ogg.
    #[test]
    fn non_opus_not_detected() {
        let header: &[u8] = b"RIFF\x00\x00\x00\x00";
        assert!(!is_ogg_opus_magic(header));
    }

    /// Fewer than four bytes can never match.
    #[test]
    fn short_buffer_not_detected() {
        let header: &[u8] = b"Og";
        assert!(!is_ogg_opus_magic(header));
    }

    /// 32000 mono samples at 16 kHz last two seconds.
    #[test]
    fn pcm_data_duration_computed_correctly() {
        let two_seconds = PcmData {
            samples: vec![0i16; 16000 * 2],
            sample_rate: 16000,
            channels: 1,
        };
        let error = (two_seconds.duration_seconds() - 2.0).abs();
        assert!(error < 0.001);
    }
}
410}