whisper_macos_cli/audio/
decode.rs1use std::io::{Cursor, Read, Seek, SeekFrom};
2use std::path::Path;
3
4use symphonia::core::audio::{AudioBufferRef, Signal};
5use symphonia::core::codecs::{CODEC_TYPE_NULL, DecoderOptions};
6use symphonia::core::formats::FormatOptions;
7use symphonia::core::io::MediaSourceStream;
8use symphonia::core::meta::MetadataOptions;
9use symphonia::core::probe::Hint;
10
/// Number of leading decoded samples dropped from an Ogg/Opus stream when its
/// `OpusHead` packet does not provide a pre-skip value of its own.
const OPUS_PRESKIP_SAMPLES: usize = 3_840;

/// Largest payload, in bytes, accepted from standard input (2 GiB).
const STDIN_MAX_BYTES: u64 = 2 << 30;
13
/// Decoded audio held as signed 16-bit PCM.
pub struct PcmData {
    /// Sample values; with more than one channel the channels are interleaved
    /// frame by frame.
    pub samples: Vec<i16>,
    /// Sampling rate in Hz.
    pub sample_rate: u32,
    /// Number of channels represented in `samples`.
    pub channels: usize,
}

impl PcmData {
    /// Returns the playback length in seconds.
    ///
    /// The length is `samples.len() / (sample_rate * channels)`. A zero sample
    /// rate or a zero channel count yields `0.0` instead of dividing by zero.
    pub fn duration_seconds(&self) -> f64 {
        match (self.sample_rate, self.channels) {
            (0, _) | (_, 0) => 0.0,
            (rate, channel_count) => {
                let samples_per_second = f64::from(rate) * channel_count as f64;
                self.samples.len() as f64 / samples_per_second
            }
        }
    }
}
28
29pub fn decode_file(path: &Path) -> Result<PcmData, crate::error::Error> {
30 let file = std::fs::File::open(path).map_err(|e| {
31 if e.kind() == std::io::ErrorKind::NotFound {
32 crate::error::Error::InputNotFound {
33 path: path.display().to_string(),
34 }
35 } else {
36 crate::error::Error::Io(e)
37 }
38 })?;
39
40 let mut header = [0u8; 12];
41 let header_len = match (&file).read(&mut header) {
42 Ok(n) => n,
43 Err(e) => return Err(crate::error::Error::Io(e)),
44 };
45 if let Err(e) = (&file).seek(SeekFrom::Start(0)) {
46 return Err(crate::error::Error::Io(e));
47 }
48
49 if header_len >= 4 && is_ogg_opus_magic(&header[..header_len]) {
50 return decode_ogg_opus(file);
51 }
52
53 let source = MediaSourceStream::new(Box::new(file), Default::default());
54
55 let mut hint = Hint::new();
56 if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
57 hint.with_extension(ext);
58 }
59
60 match decode_stream(source, hint) {
61 Ok(pcm) => Ok(pcm),
62 Err(crate::error::Error::AudioDecode(ref e))
63 if e.to_string().contains("unsupported codec") =>
64 {
65 tracing::info!("symphonia unsupported codec, trying OGG/Opus fallback");
66 let file2 = std::fs::File::open(path).map_err(|e| {
67 if e.kind() == std::io::ErrorKind::NotFound {
68 crate::error::Error::InputNotFound {
69 path: path.display().to_string(),
70 }
71 } else {
72 crate::error::Error::Io(e)
73 }
74 })?;
75 decode_ogg_opus(file2)
76 }
77 Err(e) => Err(e),
78 }
79}
80
81pub fn decode_stdin(format_hint: Option<&str>) -> Result<PcmData, crate::error::Error> {
82 let mut buf = Vec::new();
83 let mut handle = std::io::stdin().take(STDIN_MAX_BYTES + 1);
84 handle
85 .read_to_end(&mut buf)
86 .map_err(crate::error::Error::Io)?;
87
88 if buf.is_empty() {
89 return Err(crate::error::Error::NoInput);
90 }
91 if buf.len() as u64 > STDIN_MAX_BYTES {
92 return Err(crate::error::Error::Config(format!(
93 "stdin input exceeds maximum size of {STDIN_MAX_BYTES} bytes"
94 )));
95 }
96
97 if is_ogg_opus_magic(&buf[..buf.len().min(12)]) {
98 return decode_ogg_opus(Cursor::new(buf));
99 }
100
101 let source = MediaSourceStream::new(Box::new(Cursor::new(buf.clone())), Default::default());
102
103 let mut hint = Hint::new();
104 if let Some(fmt) = format_hint {
105 hint.with_extension(fmt);
106 }
107
108 match decode_stream(source, hint) {
109 Ok(pcm) => Ok(pcm),
110 Err(crate::error::Error::AudioDecode(ref e))
111 if e.to_string().contains("unsupported codec") =>
112 {
113 tracing::info!("symphonia unsupported codec, trying OGG/Opus fallback");
114 decode_ogg_opus(Cursor::new(buf))
115 }
116 Err(e) => Err(e),
117 }
118}
119
/// Returns `true` when `header` begins with the four-byte Ogg capture pattern
/// `OggS`.
///
/// Only the container signature is inspected; the codec stored inside the Ogg
/// stream is not examined. Inputs shorter than four bytes are never a match.
pub fn is_ogg_opus_magic(header: &[u8]) -> bool {
    header.starts_with(b"OggS")
}
126
127fn decode_stream(source: MediaSourceStream, hint: Hint) -> Result<PcmData, crate::error::Error> {
128 let probed = symphonia::default::get_probe()
129 .format(
130 &hint,
131 source,
132 &FormatOptions::default(),
133 &MetadataOptions::default(),
134 )
135 .map_err(|e| crate::error::Error::AudioDecode(anyhow::anyhow!("probe failed: {e}")))?;
136
137 let mut reader = probed.format;
138
139 let track = reader
140 .tracks()
141 .iter()
142 .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
143 .ok_or_else(|| crate::error::Error::AudioDecode(anyhow::anyhow!("no audio track found")))?;
144
145 let track_id = track.id;
146 let codec_params = track.codec_params.clone();
147
148 let sample_rate = codec_params
149 .sample_rate
150 .ok_or_else(|| crate::error::Error::AudioDecode(anyhow::anyhow!("unknown sample rate")))?;
151
152 let channels = codec_params.channels.map(|c| c.count()).unwrap_or(2);
153
154 let mut decoder = symphonia::default::get_codecs()
155 .make(&codec_params, &DecoderOptions::default())
156 .map_err(|e| crate::error::Error::AudioDecode(anyhow::anyhow!("codec init failed: {e}")))?;
157
158 let mut all_samples: Vec<i16> = Vec::new();
159
160 loop {
161 let packet = match reader.next_packet() {
162 Ok(p) => p,
163 Err(symphonia::core::errors::Error::IoError(e))
164 if e.kind() == std::io::ErrorKind::UnexpectedEof =>
165 {
166 break;
167 }
168 Err(_) => continue,
169 };
170
171 if packet.track_id() != track_id {
172 continue;
173 }
174
175 let audio_buf = match decoder.decode(&packet) {
176 Ok(buf) => buf,
177 Err(_) => continue,
178 };
179
180 extract_i16_samples(&audio_buf, &mut all_samples);
181 }
182
183 if all_samples.is_empty() {
184 return Err(crate::error::Error::AudioDecode(anyhow::anyhow!(
185 "no audio samples decoded"
186 )));
187 }
188
189 Ok(PcmData {
190 samples: all_samples,
191 sample_rate,
192 channels,
193 })
194}
195
/// Down-mixes interleaved PCM to a single channel by averaging each frame.
///
/// * `samples` - interleaved samples, `channels` values per frame.
/// * `channels` - number of channels interleaved in `samples`.
///
/// With one channel the input is returned as a copy. With zero channels there
/// are no frames, so an empty vector is returned instead of dividing by zero.
/// A trailing partial frame (fewer than `channels` samples) is ignored. The
/// average is computed with integer division, which truncates toward zero.
pub fn to_mono(samples: &[i16], channels: usize) -> Vec<i16> {
    match channels {
        0 => Vec::new(),
        1 => samples.to_vec(),
        _ => samples
            .chunks_exact(channels)
            .map(|frame| {
                // i64 cannot overflow for any realistic channel count, and the
                // mean of i16 values always fits back into an i16.
                let sum: i64 = frame.iter().map(|&s| i64::from(s)).sum();
                (sum / channels as i64) as i16
            })
            .collect(),
    }
}
215
/// Converts 16-bit PCM samples to `f32` by dividing each value by 32768, so
/// the result lies in `[-1.0, 1.0)`.
pub fn i16_to_f32(samples: &[i16]) -> Vec<f32> {
    const FULL_SCALE: f32 = 32768.0;

    let mut normalized = Vec::with_capacity(samples.len());
    for &sample in samples {
        normalized.push(f32::from(sample) / FULL_SCALE);
    }
    normalized
}
219
220fn decode_ogg_opus<R: Read + Seek>(mut reader: R) -> Result<PcmData, crate::error::Error> {
221 use ogg::reading::PacketReader;
222
223 let mut ogg_reader = PacketReader::new(&mut reader);
224 let mut channels = 1u8;
225 let mut pre_skip = OPUS_PRESKIP_SAMPLES;
226 let mut header_packets = 0u8;
227
228 while header_packets < 2 {
229 let pkt = ogg_reader
230 .read_packet_expected()
231 .map_err(|e| crate::error::Error::AudioDecode(anyhow::anyhow!("ogg header: {e}")))?;
232
233 if header_packets == 0 && pkt.data.len() >= 16 && &pkt.data[..8] == b"OpusHead" {
234 channels = pkt.data[9];
235 pre_skip = u32::from_le_bytes([pkt.data[10], pkt.data[11], pkt.data[12], pkt.data[13]])
236 as usize;
237 }
238 header_packets += 1;
239 }
240
241 let channels_usize = channels.max(1) as usize;
242 let output_rate = 48000;
243
244 let mut decoder = opus_decoder::OpusDecoder::new(output_rate, channels_usize)
245 .map_err(|e| crate::error::Error::AudioDecode(anyhow::anyhow!("opus init: {e:?}")))?;
246
247 let max_frame = opus_decoder::OpusDecoder::MAX_FRAME_SIZE_48K;
248 let mut pcm_buf = vec![0i16; max_frame * channels_usize];
249 let mut all_samples: Vec<i16> = Vec::new();
250 let mut samples_to_skip = pre_skip;
251
252 loop {
253 let pkt = match ogg_reader.read_packet() {
254 Ok(Some(p)) => p,
255 Ok(None) => break,
256 Err(_) => continue,
257 };
258
259 match decoder.decode(&pkt.data, &mut pcm_buf, false) {
260 Ok(samples_per_channel) => {
261 let total = samples_per_channel * channels_usize;
262 let slice = &pcm_buf[..total];
263
264 if samples_to_skip >= total {
265 samples_to_skip -= total;
266 } else if samples_to_skip > 0 {
267 let kept = &slice[samples_to_skip..];
268 all_samples.extend_from_slice(kept);
269 samples_to_skip = 0;
270 } else {
271 all_samples.extend_from_slice(slice);
272 }
273 }
274 Err(_) => continue,
275 }
276 }
277
278 if all_samples.is_empty() {
279 return Err(crate::error::Error::AudioDecode(anyhow::anyhow!(
280 "no audio samples decoded from OGG/Opus"
281 )));
282 }
283
284 tracing::info!(
285 samples = all_samples.len(),
286 channels = channels_usize,
287 preskip_discarded = pre_skip,
288 "OGG/Opus decoded via fallback"
289 );
290
291 Ok(PcmData {
292 samples: all_samples,
293 sample_rate: output_rate,
294 channels: channels_usize,
295 })
296}
297
298fn extract_i16_samples(buffer: &AudioBufferRef, dest: &mut Vec<i16>) {
299 match buffer {
300 AudioBufferRef::U8(buf) => {
301 let ch = buf.spec().channels.count();
302 let frames = buf.frames();
303 dest.reserve(frames * ch);
304 for f in 0..frames {
305 for c in 0..ch {
306 dest.push(((buf.chan(c)[f] as i32 - 128) * 256) as i16);
307 }
308 }
309 }
310 AudioBufferRef::S16(buf) => {
311 let ch = buf.spec().channels.count();
312 let frames = buf.frames();
313 dest.reserve(frames * ch);
314 for f in 0..frames {
315 for c in 0..ch {
316 dest.push(buf.chan(c)[f]);
317 }
318 }
319 }
320 AudioBufferRef::S32(buf) => {
321 let ch = buf.spec().channels.count();
322 let frames = buf.frames();
323 dest.reserve(frames * ch);
324 for f in 0..frames {
325 for c in 0..ch {
326 dest.push((buf.chan(c)[f] >> 16) as i16);
327 }
328 }
329 }
330 AudioBufferRef::F32(buf) => {
331 let ch = buf.spec().channels.count();
332 let frames = buf.frames();
333 dest.reserve(frames * ch);
334 for f in 0..frames {
335 for c in 0..ch {
336 let v = buf.chan(c)[f].clamp(-1.0, 1.0);
337 dest.push((v * 32767.0) as i16);
338 }
339 }
340 }
341 AudioBufferRef::F64(buf) => {
342 let ch = buf.spec().channels.count();
343 let frames = buf.frames();
344 dest.reserve(frames * ch);
345 for f in 0..frames {
346 for c in 0..ch {
347 let v = buf.chan(c)[f].clamp(-1.0, 1.0);
348 dest.push((v * 32767.0) as i16);
349 }
350 }
351 }
352 _ => {}
353 }
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn to_mono_passthrough_single_channel() {
        let samples = vec![100i16, 200, 300];
        let result = to_mono(&samples, 1);
        assert_eq!(result, samples);
    }

    #[test]
    fn to_mono_averages_stereo() {
        let samples = vec![100i16, 200, 300, 400];
        let result = to_mono(&samples, 2);
        assert_eq!(result, vec![150, 350]);
    }

    #[test]
    fn to_mono_ignores_trailing_partial_frame() {
        // Three samples at two channels: one full frame plus a leftover.
        assert_eq!(to_mono(&[10, 20, 30], 2), vec![15]);
    }

    #[test]
    fn to_mono_truncates_average_toward_zero() {
        assert_eq!(to_mono(&[-3, 0], 2), vec![-1]);
        assert_eq!(to_mono(&[3, 0], 2), vec![1]);
    }

    #[test]
    fn to_mono_handles_extreme_values() {
        let samples = [i16::MAX, i16::MAX, i16::MIN, i16::MIN];
        assert_eq!(to_mono(&samples, 2), vec![i16::MAX, i16::MIN]);
    }

    #[test]
    fn i16_to_f32_converts_correctly() {
        let samples = vec![0i16, 32767, -32768];
        let result = i16_to_f32(&samples);
        assert!((result[0] - 0.0).abs() < 0.001);
        assert!((result[1] - 1.0).abs() < 0.001);
        assert!((result[2] - (-1.0)).abs() < 0.001);
    }

    #[test]
    fn i16_to_f32_is_exact_for_powers_of_two() {
        assert_eq!(i16_to_f32(&[16384, -16384, -32768]), vec![0.5, -0.5, -1.0]);
        assert!(i16_to_f32(&[]).is_empty());
    }

    #[test]
    fn opus_magic_detected() {
        let ogg = b"OggS\x00\x02\x00\x00\x00\x00\x00\x00";
        assert!(is_ogg_opus_magic(ogg));
    }

    #[test]
    fn opus_magic_detected_with_exactly_four_bytes() {
        assert!(is_ogg_opus_magic(b"OggS"));
    }

    #[test]
    fn non_opus_not_detected() {
        let wav = b"RIFF\x00\x00\x00\x00";
        assert!(!is_ogg_opus_magic(wav));
    }

    #[test]
    fn short_buffer_not_detected() {
        let short = b"Og";
        assert!(!is_ogg_opus_magic(short));
        assert!(!is_ogg_opus_magic(b""));
    }

    #[test]
    fn pcm_data_duration_computed_correctly() {
        let pcm = PcmData {
            samples: vec![0i16; 16000 * 2],
            sample_rate: 16000,
            channels: 1,
        };
        assert!((pcm.duration_seconds() - 2.0).abs() < 0.001);
    }

    #[test]
    fn pcm_data_duration_accounts_for_channels() {
        let pcm = PcmData {
            samples: vec![0i16; 48000 * 2],
            sample_rate: 48000,
            channels: 2,
        };
        assert!((pcm.duration_seconds() - 1.0).abs() < 0.001);
    }

    #[test]
    fn pcm_data_duration_is_zero_without_rate_or_channels() {
        let no_rate = PcmData {
            samples: vec![0i16; 100],
            sample_rate: 0,
            channels: 1,
        };
        assert_eq!(no_rate.duration_seconds(), 0.0);

        let no_channels = PcmData {
            samples: vec![0i16; 100],
            sample_rate: 16000,
            channels: 0,
        };
        assert_eq!(no_channels.duration_seconds(), 0.0);
    }
}