1use anyhow::{Result, anyhow};
14use audioadapter_buffers::direct::SequentialSliceOfVecs;
15use rubato::{Async, FixedAsync, Resampler, SincInterpolationParameters, WindowFunction};
16use std::path::Path;
17use symphonia::core::codecs::audio::CODEC_ID_NULL_AUDIO;
18use symphonia::core::{
19 codecs::audio::AudioDecoderOptions,
20 formats::{FormatOptions, FormatReader, probe::Hint},
21 io::MediaSourceStream,
22 meta::MetadataOptions,
23};
24
/// One chunk of decoded audio produced by [`AudioChunkIterator`].
#[derive(Debug, Clone)]
pub struct AudioChunk {
    /// PCM samples of this chunk as 32-bit floats.
    pub samples: Vec<f32>,

    /// Position in the stream where this chunk begins, in seconds.
    pub start_sec: f32,

    /// Position in the stream where this chunk ends, in seconds.
    pub end_sec: f32,
}
35
36pub struct AudioChunkIterator {
41 reader: Box<dyn FormatReader>,
42 decoder: Box<dyn symphonia::core::codecs::audio::AudioDecoder>,
43 track_id: u32,
44 sample_rate: u32,
45 channels: u16,
46
47 resampler: Option<Async<f32>>,
49
50 overlap_buf: Vec<f32>,
52
53 chunk_samples: usize, overlap_samples: usize, target_rate: u32,
57
58 samples_out: usize, done: bool,
61}
62
63impl AudioChunkIterator {
64 pub fn new<P: AsRef<Path>>(path: P, chunk_sec: f32, overlap_sec: f32) -> Result<Self> {
71 let path = path.as_ref();
72 let file = std::fs::File::open(path)
73 .map_err(|e| anyhow!("Failed to open audio file '{}': {}", path.display(), e))?;
74 let mss = MediaSourceStream::new(Box::new(file), Default::default());
75
76 let mut hint = Hint::new();
77 if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
78 hint.with_extension(ext);
79 }
80
81 let format = symphonia::default::get_probe()
82 .probe(
83 &hint,
84 mss,
85 FormatOptions::default(),
86 MetadataOptions::default(),
87 )
88 .map_err(|e| anyhow!("Unsupported audio format '{}': {}", path.display(), e))?;
89 let track = format
90 .tracks()
91 .iter()
92 .find(|t| {
93 t.codec_params
94 .as_ref()
95 .and_then(|cp| cp.audio())
96 .map(|ap| ap.codec != CODEC_ID_NULL_AUDIO)
97 .unwrap_or(false)
98 })
99 .ok_or_else(|| anyhow!("No audio tracks found in '{}'", path.display()))?;
100
101 let track_id = track.id;
102 let codec_params = track
103 .codec_params
104 .as_ref()
105 .and_then(|cp| cp.audio())
106 .ok_or_else(|| anyhow!("Missing codec parameters in '{}'", path.display()))?;
107
108 let sample_rate = codec_params
109 .sample_rate
110 .ok_or_else(|| anyhow!("Unknown sample rate in '{}'", path.display()))?;
111 let channels = codec_params
112 .channels
113 .as_ref()
114 .ok_or_else(|| anyhow!("Unknown channel count in '{}'", path.display()))?
115 .count() as u16;
116
117 let decoder = symphonia::default::get_codecs()
118 .make_audio_decoder(codec_params, &AudioDecoderOptions::default())
119 .map_err(|e| anyhow!("Failed to create decoder for '{}': {}", path.display(), e))?;
120
121 let target_rate = 16000;
122 let chunk_samples = (chunk_sec * target_rate as f32) as usize;
123 let overlap_samples = (overlap_sec * target_rate as f32) as usize;
124
125 let resampler = if sample_rate != target_rate {
127 let f_ratio = target_rate as f64 / sample_rate as f64;
128 let params = SincInterpolationParameters {
129 sinc_len: 256,
130 f_cutoff: 0.95,
131 interpolation: rubato::SincInterpolationType::Cubic,
132 oversampling_factor: 256,
133 window: WindowFunction::BlackmanHarris2,
134 };
135 let resampler = Async::<f32>::new_sinc(
136 f_ratio,
137 2.0,
138 ¶ms,
139 128, 1, FixedAsync::Input,
142 )
143 .map_err(|e| anyhow!("Failed to create resampler: {}", e))?;
144 Some(resampler)
145 } else {
146 None
147 };
148
149 Ok(Self {
150 reader: format,
151 decoder,
152 track_id,
153 sample_rate,
154 channels,
155 resampler,
156 overlap_buf: Vec::new(),
157 chunk_samples,
158 overlap_samples,
159 target_rate,
160 samples_out: 0,
161 done: false,
162 })
163 }
164
165 pub fn default_whisper<P: AsRef<Path>>(path: P) -> Result<Self> {
167 Self::new(path, 30.0, 1.0)
168 }
169
170 fn next_chunk(&mut self) -> Result<Option<AudioChunk>> {
172 let target_samples = self.chunk_samples;
173 let mut samples = Vec::with_capacity(target_samples);
174
175 samples.extend_from_slice(&self.overlap_buf);
177 let overlap_len = self.overlap_buf.len();
178
179 loop {
181 if samples.len() >= target_samples {
182 break;
183 }
184
185 let packet = match self.reader.next_packet() {
186 Ok(Some(p)) => p,
187 Ok(None) => {
188 self.done = true;
189 break;
190 }
191 Err(symphonia::core::errors::Error::ResetRequired) => {
192 continue;
193 }
194 Err(e) => {
195 return Err(anyhow!("Error reading packet: {}", e));
196 }
197 };
198
199 if packet.track_id != self.track_id {
200 continue;
201 }
202
203 let decoded = match self.decoder.decode(&packet) {
204 Ok(d) => d,
205 Err(symphonia::core::errors::Error::IoError(_)) => continue,
206 Err(e) => {
207 return Err(anyhow!("Decode error: {}", e));
208 }
209 };
210
211 let mut packet_samples = Vec::new();
212 decoded.copy_to_vec_interleaved::<f32>(&mut packet_samples);
213
214 if self.resampler.is_some() {
216 let mut resampler = self.resampler.take().unwrap();
217 self.resample_packet_into_buffer(&packet_samples, &mut resampler, &mut samples)?;
218 self.resampler = Some(resampler);
219 } else {
220 samples.extend_from_slice(&packet_samples);
222 }
223 }
224
225 if self.channels > 1 && self.resampler.is_none() {
227 samples = self.to_mono(&samples);
228 }
229
230 if samples.len() > target_samples {
232 samples.truncate(target_samples);
233 }
234
235 if samples.len() <= overlap_len {
237 self.overlap_buf.clear();
238 return Ok(None);
239 }
240
241 let overlap_start = if samples.len() >= self.overlap_samples {
243 samples.len() - self.overlap_samples
244 } else {
245 0
246 };
247 self.overlap_buf = samples[overlap_start..].to_vec();
248
249 let start_sec = self.samples_out as f32 / self.target_rate as f32;
251 let end_sec = (self.samples_out + samples.len()) as f32 / self.target_rate as f32;
252 self.samples_out += samples.len() - overlap_len;
253
254 Ok(Some(AudioChunk {
255 samples,
256 start_sec,
257 end_sec,
258 }))
259 }
260
261 fn resample_packet_into_buffer(
263 &mut self,
264 packet_samples: &[f32],
265 resampler: &mut Async<f32>,
266 output: &mut Vec<f32>,
267 ) -> Result<()> {
268 if packet_samples.is_empty() {
269 return Ok(());
270 }
271
272 let frames_per_channel = packet_samples.len() / self.channels as usize;
274 let mut input_channels: Vec<Vec<f32>> =
275 vec![Vec::with_capacity(frames_per_channel); self.channels as usize];
276
277 for (i, &sample) in packet_samples.iter().enumerate() {
278 let channel = i % self.channels as usize;
279 input_channels[channel].push(sample);
280 }
281
282 if self.channels > 1 {
284 input_channels[0] = (0..frames_per_channel)
285 .map(|f| input_channels.iter().map(|ch| ch[f]).sum::<f32>() / self.channels as f32)
286 .collect();
287 input_channels.truncate(1);
288 }
289
290 let input_adapter = SequentialSliceOfVecs::new(&input_channels, 1, frames_per_channel)
292 .map_err(|e| anyhow!("Failed to create input adapter: {}", e))?;
293
294 let f_ratio = self.target_rate as f64 / self.sample_rate as f64;
296 let estimated_output_frames = (frames_per_channel as f64 * f_ratio) as usize + 10; let mut output_channels: Vec<Vec<f32>> = vec![vec![0.0f32; estimated_output_frames]; 1];
299 let mut output_adapter =
300 SequentialSliceOfVecs::new_mut(&mut output_channels, 1, estimated_output_frames)
301 .map_err(|e| anyhow!("Failed to create output adapter: {}", e))?;
302
303 let mut indexing = rubato::Indexing {
304 input_offset: 0,
305 output_offset: 0,
306 active_channels_mask: None,
307 partial_len: None,
308 };
309
310 let mut input_frames_left = frames_per_channel;
311 let mut input_frames_next = resampler.input_frames_next();
312
313 while input_frames_left >= input_frames_next {
315 let (frames_read, frames_written) = resampler
316 .process_into_buffer(&input_adapter, &mut output_adapter, Some(&indexing))
317 .map_err(|e| anyhow!("Resampling failed: {}", e))?;
318
319 indexing.input_offset += frames_read;
320 indexing.output_offset += frames_written;
321 input_frames_left -= frames_read;
322 input_frames_next = resampler.input_frames_next();
323 }
324
325 output.extend_from_slice(&output_channels[0][..indexing.output_offset]);
329 Ok(())
330 }
331
332 fn to_mono(&self, samples: &[f32]) -> Vec<f32> {
334 if self.channels == 1 {
335 return samples.to_vec();
336 }
337 samples
338 .chunks(self.channels as usize)
339 .map(|chunk| chunk.iter().sum::<f32>() / self.channels as f32)
340 .collect()
341 }
342}
343
344impl Iterator for AudioChunkIterator {
345 type Item = Result<AudioChunk>;
346
347 fn next(&mut self) -> Option<Self::Item> {
348 if self.done && self.overlap_buf.is_empty() {
349 return None;
350 }
351 match self.next_chunk() {
352 Ok(Some(chunk)) => Some(Ok(chunk)),
353 Ok(None) => None,
354 Err(e) => Some(Err(e)),
355 }
356 }
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    /// Opening a path that does not exist must fail with the "open" error.
    #[test]
    fn test_audio_chunk_iterator_creation() -> Result<()> {
        let failure = AudioChunkIterator::default_whisper("/nonexistent/file.wav")
            .err()
            .ok_or_else(|| anyhow!("Should have failed to open nonexistent file"))?;

        assert!(failure.to_string().contains("Failed to open audio file"));
        Ok(())
    }
}