// whisperforge_core/audio_stream.rs

//! Streaming audio decoder for processing large files without pre-loading.
//!
//! This module provides `AudioChunkIterator`, a pull-based iterator that decodes
//! and resamples audio on demand, holding at most one chunk (30 seconds by
//! default) in memory.
//!
//! # Memory efficiency
//!
//! Instead of loading entire files before processing:
//! - Peak RAM ≈ 2 MB per chunk (16 kHz mono `f32`, 30 s = 480k samples)
//! - Processes 1-hour files with <10 MB working set (model weights excluded)
//! - Streaming resampler avoids full-file copies

use anyhow::{Result, anyhow};
use audioadapter_buffers::direct::SequentialSliceOfVecs;
use rubato::{Async, FixedAsync, Resampler, SincInterpolationParameters, WindowFunction};
use std::path::Path;
use symphonia::core::codecs::audio::CODEC_ID_NULL_AUDIO;
use symphonia::core::{
    codecs::audio::AudioDecoderOptions,
    formats::{FormatOptions, FormatReader, probe::Hint},
    io::MediaSourceStream,
    meta::MetadataOptions,
};

/// A decoded chunk of audio at 16 kHz mono.
///
/// When chunks are produced with overlap, the leading samples of a chunk repeat the
/// trailing samples of the previous one; `start_sec` and `end_sec` describe the span
/// covered by `samples`.
#[derive(Debug, Clone, PartialEq)]
pub struct AudioChunk {
    /// 16 kHz mono samples.
    pub samples: Vec<f32>,
    /// Time of the first sample, in seconds from the start of the file.
    pub start_sec: f32,
    /// Time just past the last sample, in seconds from the start of the file.
    pub end_sec: f32,
}

36/// Streaming iterator over audio file chunks.
37///
38/// Decodes and resamples on-demand, holding only one packet + resampler state + overlap.
39/// Yields chunks with automatic 1-second overlap for alignment across boundaries.
40pub struct AudioChunkIterator {
41    reader: Box<dyn FormatReader>,
42    decoder: Box<dyn symphonia::core::codecs::audio::AudioDecoder>,
43    track_id: u32,
44    sample_rate: u32,
45    channels: u16,
46
47    // Rubato streaming resampler state (survives between packets)
48    resampler: Option<Async<f32>>,
49
50    // Overlap buffer from previous chunk
51    overlap_buf: Vec<f32>,
52
53    // Configuration
54    chunk_samples: usize,   // e.g., 480_000 for 30s @ 16 kHz
55    overlap_samples: usize, // e.g., 16_000 for 1s @ 16 kHz
56    target_rate: u32,
57
58    // Position tracking
59    samples_out: usize, // Total 16 kHz samples emitted so far
60    done: bool,
61}
62
63impl AudioChunkIterator {
64    /// Create a streaming iterator from an audio file.
65    ///
66    /// # Arguments
67    /// * `path` - Path to audio file (WAV, MP3, FLAC, OGG, AAC, MKV, etc.)
68    /// * `chunk_sec` - Chunk duration in seconds
69    /// * `overlap_sec` - Overlap between chunks in seconds
70    pub fn new<P: AsRef<Path>>(path: P, chunk_sec: f32, overlap_sec: f32) -> Result<Self> {
71        let path = path.as_ref();
72        let file = std::fs::File::open(path)
73            .map_err(|e| anyhow!("Failed to open audio file '{}': {}", path.display(), e))?;
74        let mss = MediaSourceStream::new(Box::new(file), Default::default());
75
76        let mut hint = Hint::new();
77        if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
78            hint.with_extension(ext);
79        }
80
81        let format = symphonia::default::get_probe()
82            .probe(
83                &hint,
84                mss,
85                FormatOptions::default(),
86                MetadataOptions::default(),
87            )
88            .map_err(|e| anyhow!("Unsupported audio format '{}': {}", path.display(), e))?;
89        let track = format
90            .tracks()
91            .iter()
92            .find(|t| {
93                t.codec_params
94                    .as_ref()
95                    .and_then(|cp| cp.audio())
96                    .map(|ap| ap.codec != CODEC_ID_NULL_AUDIO)
97                    .unwrap_or(false)
98            })
99            .ok_or_else(|| anyhow!("No audio tracks found in '{}'", path.display()))?;
100
101        let track_id = track.id;
102        let codec_params = track
103            .codec_params
104            .as_ref()
105            .and_then(|cp| cp.audio())
106            .ok_or_else(|| anyhow!("Missing codec parameters in '{}'", path.display()))?;
107
108        let sample_rate = codec_params
109            .sample_rate
110            .ok_or_else(|| anyhow!("Unknown sample rate in '{}'", path.display()))?;
111        let channels = codec_params
112            .channels
113            .as_ref()
114            .ok_or_else(|| anyhow!("Unknown channel count in '{}'", path.display()))?
115            .count() as u16;
116
117        let decoder = symphonia::default::get_codecs()
118            .make_audio_decoder(codec_params, &AudioDecoderOptions::default())
119            .map_err(|e| anyhow!("Failed to create decoder for '{}': {}", path.display(), e))?;
120
121        let target_rate = 16000;
122        let chunk_samples = (chunk_sec * target_rate as f32) as usize;
123        let overlap_samples = (overlap_sec * target_rate as f32) as usize;
124
125        // Create resampler if needed (always mono output: 1 channel)
126        let resampler = if sample_rate != target_rate {
127            let f_ratio = target_rate as f64 / sample_rate as f64;
128            let params = SincInterpolationParameters {
129                sinc_len: 256,
130                f_cutoff: 0.95,
131                interpolation: rubato::SincInterpolationType::Cubic,
132                oversampling_factor: 256,
133                window: WindowFunction::BlackmanHarris2,
134            };
135            let resampler = Async::<f32>::new_sinc(
136                f_ratio,
137                2.0,
138                &params,
139                128, // Small enough that any codec packet fits
140                1,   // Always mono output
141                FixedAsync::Input,
142            )
143            .map_err(|e| anyhow!("Failed to create resampler: {}", e))?;
144            Some(resampler)
145        } else {
146            None
147        };
148
149        Ok(Self {
150            reader: format,
151            decoder,
152            track_id,
153            sample_rate,
154            channels,
155            resampler,
156            overlap_buf: Vec::new(),
157            chunk_samples,
158            overlap_samples,
159            target_rate,
160            samples_out: 0,
161            done: false,
162        })
163    }
164
165    /// Create a 30s-chunk iterator with 1s overlap (Whisper default).
166    pub fn default_whisper<P: AsRef<Path>>(path: P) -> Result<Self> {
167        Self::new(path, 30.0, 1.0)
168    }
169
170    /// Decode and accumulate the next chunk of audio.
171    fn next_chunk(&mut self) -> Result<Option<AudioChunk>> {
172        let target_samples = self.chunk_samples;
173        let mut samples = Vec::with_capacity(target_samples);
174
175        // Prepend overlap from previous chunk
176        samples.extend_from_slice(&self.overlap_buf);
177        let overlap_len = self.overlap_buf.len();
178
179        // Accumulate decoded packets until we have enough samples
180        loop {
181            if samples.len() >= target_samples {
182                break;
183            }
184
185            let packet = match self.reader.next_packet() {
186                Ok(Some(p)) => p,
187                Ok(None) => {
188                    self.done = true;
189                    break;
190                }
191                Err(symphonia::core::errors::Error::ResetRequired) => {
192                    continue;
193                }
194                Err(e) => {
195                    return Err(anyhow!("Error reading packet: {}", e));
196                }
197            };
198
199            if packet.track_id != self.track_id {
200                continue;
201            }
202
203            let decoded = match self.decoder.decode(&packet) {
204                Ok(d) => d,
205                Err(symphonia::core::errors::Error::IoError(_)) => continue,
206                Err(e) => {
207                    return Err(anyhow!("Decode error: {}", e));
208                }
209            };
210
211            let mut packet_samples = Vec::new();
212            decoded.copy_to_vec_interleaved::<f32>(&mut packet_samples);
213
214            // Resample or copy samples
215            if self.resampler.is_some() {
216                let mut resampler = self.resampler.take().unwrap();
217                self.resample_packet_into_buffer(&packet_samples, &mut resampler, &mut samples)?;
218                self.resampler = Some(resampler);
219            } else {
220                // Already at target rate; convert to mono if needed
221                samples.extend_from_slice(&packet_samples);
222            }
223        }
224
225        // Convert to mono if needed
226        if self.channels > 1 && self.resampler.is_none() {
227            samples = self.to_mono(&samples);
228        }
229
230        // Trim to target size
231        if samples.len() > target_samples {
232            samples.truncate(target_samples);
233        }
234
235        // No new content beyond what we prepended — EOF with nothing more to yield
236        if samples.len() <= overlap_len {
237            self.overlap_buf.clear();
238            return Ok(None);
239        }
240
241        // Save overlap for next chunk
242        let overlap_start = if samples.len() >= self.overlap_samples {
243            samples.len() - self.overlap_samples
244        } else {
245            0
246        };
247        self.overlap_buf = samples[overlap_start..].to_vec();
248
249        // Calculate timestamps
250        let start_sec = self.samples_out as f32 / self.target_rate as f32;
251        let end_sec = (self.samples_out + samples.len()) as f32 / self.target_rate as f32;
252        self.samples_out += samples.len() - overlap_len;
253
254        Ok(Some(AudioChunk {
255            samples,
256            start_sec,
257            end_sec,
258        }))
259    }
260
261    /// Resample a packet of audio into the samples buffer.
262    fn resample_packet_into_buffer(
263        &mut self,
264        packet_samples: &[f32],
265        resampler: &mut Async<f32>,
266        output: &mut Vec<f32>,
267    ) -> Result<()> {
268        if packet_samples.is_empty() {
269            return Ok(());
270        }
271
272        // Deinterleave samples into per-channel vectors
273        let frames_per_channel = packet_samples.len() / self.channels as usize;
274        let mut input_channels: Vec<Vec<f32>> =
275            vec![Vec::with_capacity(frames_per_channel); self.channels as usize];
276
277        for (i, &sample) in packet_samples.iter().enumerate() {
278            let channel = i % self.channels as usize;
279            input_channels[channel].push(sample);
280        }
281
282        // Convert to mono by averaging channels
283        if self.channels > 1 {
284            input_channels[0] = (0..frames_per_channel)
285                .map(|f| input_channels.iter().map(|ch| ch[f]).sum::<f32>() / self.channels as f32)
286                .collect();
287            input_channels.truncate(1);
288        }
289
290        // Prepare adapters for rubato
291        let input_adapter = SequentialSliceOfVecs::new(&input_channels, 1, frames_per_channel)
292            .map_err(|e| anyhow!("Failed to create input adapter: {}", e))?;
293
294        // Estimate output size
295        let f_ratio = self.target_rate as f64 / self.sample_rate as f64;
296        let estimated_output_frames = (frames_per_channel as f64 * f_ratio) as usize + 10; // +10 for safety
297
298        let mut output_channels: Vec<Vec<f32>> = vec![vec![0.0f32; estimated_output_frames]; 1];
299        let mut output_adapter =
300            SequentialSliceOfVecs::new_mut(&mut output_channels, 1, estimated_output_frames)
301                .map_err(|e| anyhow!("Failed to create output adapter: {}", e))?;
302
303        let mut indexing = rubato::Indexing {
304            input_offset: 0,
305            output_offset: 0,
306            active_channels_mask: None,
307            partial_len: None,
308        };
309
310        let mut input_frames_left = frames_per_channel;
311        let mut input_frames_next = resampler.input_frames_next();
312
313        // Process full chunks from the resampler
314        while input_frames_left >= input_frames_next {
315            let (frames_read, frames_written) = resampler
316                .process_into_buffer(&input_adapter, &mut output_adapter, Some(&indexing))
317                .map_err(|e| anyhow!("Resampling failed: {}", e))?;
318
319            indexing.input_offset += frames_read;
320            indexing.output_offset += frames_written;
321            input_frames_left -= frames_read;
322            input_frames_next = resampler.input_frames_next();
323        }
324
325        // Remaining frames less than chunk size are buffered internally by the resampler
326        // and will be output on the next packet. No need to force-process them here.
327
328        output.extend_from_slice(&output_channels[0][..indexing.output_offset]);
329        Ok(())
330    }
331
332    /// Convert interleaved samples to mono by averaging channels.
333    fn to_mono(&self, samples: &[f32]) -> Vec<f32> {
334        if self.channels == 1 {
335            return samples.to_vec();
336        }
337        samples
338            .chunks(self.channels as usize)
339            .map(|chunk| chunk.iter().sum::<f32>() / self.channels as f32)
340            .collect()
341    }
342}
343
344impl Iterator for AudioChunkIterator {
345    type Item = Result<AudioChunk>;
346
347    fn next(&mut self) -> Option<Self::Item> {
348        if self.done && self.overlap_buf.is_empty() {
349            return None;
350        }
351        match self.next_chunk() {
352            Ok(Some(chunk)) => Some(Ok(chunk)),
353            Ok(None) => None,
354            Err(e) => Some(Err(e)),
355        }
356    }
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructing an iterator for a path that does not exist must fail with the
    /// "failed to open" error rather than succeed or panic.
    #[test]
    fn test_audio_chunk_iterator_creation() -> Result<()> {
        let Err(err) = AudioChunkIterator::default_whisper("/nonexistent/file.wav") else {
            return Err(anyhow!("Should have failed to open nonexistent file"));
        };
        assert!(err.to_string().contains("Failed to open audio file"));
        Ok(())
    }
}
374}