Skip to main content

whisperforge_core/
audio_capture.rs

1use anyhow::{Context, Result, bail};
2use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
3use ringbuf::{
4    HeapRb,
5    traits::{Consumer, Producer, Split},
6};
7use rubato::audioadapter_buffers::direct::SequentialSlice;
8use rubato::{Fft, FixedSync, Resampler};
9use std::path::Path;
10use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
11use std::sync::{Arc, Mutex};
12use std::thread;
13use std::thread::JoinHandle;
14
/// Chunk size handed to the rubato FFT resampler when `run_resample` constructs it
/// in fixed-input mode.
const RESAMPLER_CHUNK: usize = 1024;
16
17/// Captures audio from a microphone and resamples it to 16 kHz mono.
18pub struct MicCapture {
19    _stream: cpal::Stream,
20    _resample_thread: Option<std::thread::JoinHandle<()>>,
21    pub consumer: Arc<Mutex<ringbuf::HeapCons<f32>>>,
22    pub native_sample_rate: u32,
23    pub native_channels: u16,
24    /// Cumulative samples dropped because a downstream ring filled up. Non-zero
25    /// values indicate the consumer (typically the decoder) is not keeping up
26    /// with real-time audio arrival.
27    pub dropped_samples: Arc<AtomicU64>,
28}
29
30impl MicCapture {
31    /// Open a microphone input device.
32    /// If `device_name` is None, uses the system default device.
33    /// Returns the consumer side of a ring buffer containing 16 kHz mono samples.
34    pub fn open(device_name: Option<&str>) -> Result<Self> {
35        let host = cpal::default_host();
36
37        let device = if let Some(name) = device_name {
38            host.input_devices()
39                .context("Failed to enumerate input devices")?
40                .find(|d| d.name().ok().as_deref() == Some(name))
41                .context(format!("Input device '{name}' not found"))?
42        } else {
43            host.default_input_device()
44                .context("No default input device found")?
45        };
46
47        let device_name = device.name().unwrap_or_else(|_| "<unknown>".to_string());
48        tracing::info!("Opening input device: {}", device_name);
49
50        let config = device
51            .default_input_config()
52            .context("Failed to get device config")?;
53
54        let native_sample_rate = config.sample_rate().0;
55        let native_channels = config.channels();
56
57        tracing::info!(
58            "Device config: {} Hz, {} channels, format: {:?}",
59            native_sample_rate,
60            native_channels,
61            config.sample_format()
62        );
63
64        // Native-rate ring sized to a couple of seconds so callback bursts never block the
65        // audio thread. cpal callbacks typically deliver 10–40 ms per fire, so this is generous.
66        let ring_raw_rb = HeapRb::<f32>::new(((native_sample_rate as usize) * 2).max(4096));
67        let (prod_raw, cons_raw) = ring_raw_rb.split();
68
69        let ring_16khz_rb = HeapRb::<f32>::new(32_000);
70        let (prod_16khz, cons_16khz) = ring_16khz_rb.split();
71        let ring_16khz_cons = Arc::new(Mutex::new(cons_16khz));
72
73        let dropped_samples = Arc::new(AtomicU64::new(0));
74
75        // Spawn the resampling worker thread — moves `cons_raw` and `prod_16khz` in
76        // directly; ringbuf SPSC needs no synchronization.
77        let dropped_samples_resample = Arc::clone(&dropped_samples);
78        let native_sr = native_sample_rate;
79
80        let resample_thread = thread::Builder::new()
81            .name("audio-resample".to_string())
82            .spawn(move || {
83                run_resample(native_sr, cons_raw, prod_16khz, dropped_samples_resample);
84            })
85            .context("Failed to spawn resample thread")?;
86
87        // Build the cpal input stream — `prod_raw` is moved into the matching arm's
88        // closure (cpal picks exactly one format, so the producer has exactly one owner)
89        // so no Arc<Mutex<>> is needed around the audio callback's producer.
90        let dropped_samples_callback = Arc::clone(&dropped_samples);
91
92        let stream = match config.sample_format() {
93            cpal::SampleFormat::F32 => {
94                let mut prod = prod_raw;
95                let dropped_cb = dropped_samples_callback;
96                device.build_input_stream(
97                    &config.into(),
98                    move |data: &[f32], _info| {
99                        let mono: Vec<f32> = data
100                            .chunks_exact(native_channels as usize)
101                            .map(|ch| ch.iter().sum::<f32>() / native_channels as f32)
102                            .collect();
103                        let written = prod.push_slice(&mono);
104                        let dropped = mono.len() - written;
105                        if dropped > 0 {
106                            dropped_cb.fetch_add(dropped as u64, Ordering::Relaxed);
107                        }
108                    },
109                    |err| tracing::error!("Stream error: {}", err),
110                    None,
111                )
112            }
113            cpal::SampleFormat::I16 => {
114                let mut prod = prod_raw;
115                let dropped_cb = dropped_samples_callback;
116                device.build_input_stream(
117                    &config.into(),
118                    move |data: &[i16], _info| {
119                        let mono: Vec<f32> = data
120                            .chunks_exact(native_channels as usize)
121                            .map(|ch| {
122                                ch.iter().map(|&s| s as f32 / 32_768.0).sum::<f32>()
123                                    / native_channels as f32
124                            })
125                            .collect();
126                        let written = prod.push_slice(&mono);
127                        let dropped = mono.len() - written;
128                        if dropped > 0 {
129                            dropped_cb.fetch_add(dropped as u64, Ordering::Relaxed);
130                        }
131                    },
132                    |err| tracing::error!("Stream error: {}", err),
133                    None,
134                )
135            }
136            cpal::SampleFormat::U16 => {
137                let mut prod = prod_raw;
138                let dropped_cb = dropped_samples_callback;
139                device.build_input_stream(
140                    &config.into(),
141                    move |data: &[u16], _info| {
142                        let mono: Vec<f32> = data
143                            .chunks_exact(native_channels as usize)
144                            .map(|ch| {
145                                ch.iter()
146                                    .map(|&s| (s as f32 - 32_768.0) / 32_768.0)
147                                    .sum::<f32>()
148                                    / native_channels as f32
149                            })
150                            .collect();
151                        let written = prod.push_slice(&mono);
152                        let dropped = mono.len() - written;
153                        if dropped > 0 {
154                            dropped_cb.fetch_add(dropped as u64, Ordering::Relaxed);
155                        }
156                    },
157                    |err| tracing::error!("Stream error: {}", err),
158                    None,
159                )
160            }
161            _ => bail!("Unsupported sample format"),
162        }?;
163
164        stream.play().context("Failed to start playback")?;
165
166        Ok(MicCapture {
167            _stream: stream,
168            _resample_thread: Some(resample_thread),
169            consumer: ring_16khz_cons,
170            native_sample_rate,
171            native_channels,
172            dropped_samples,
173        })
174    }
175
176    /// Stop capturing audio and drop the stream.
177    pub fn stop(self) {
178        drop(self._stream);
179        drop(self._resample_thread);
180    }
181}
182
183/// Worker body: drain `cons_raw` at native rate, resample to 16 kHz mono with rubato,
184/// push into `prod_16khz`. Loops forever until the upstream ring is dropped (shutdown).
185fn run_resample(
186    native_sr: u32,
187    mut cons_raw: ringbuf::HeapCons<f32>,
188    mut prod_16khz: ringbuf::HeapProd<f32>,
189    dropped_samples: Arc<AtomicU64>,
190) {
191    // Fast path: device is already 16 kHz, just forward without resampling.
192    if native_sr == 16_000 {
193        let mut buf = vec![0.0f32; 4096];
194        loop {
195            let n = cons_raw.pop_slice(&mut buf);
196            if n == 0 {
197                thread::sleep(std::time::Duration::from_millis(1));
198                continue;
199            }
200            let written = prod_16khz.push_slice(&buf[..n]);
201            let dropped = n - written;
202            if dropped > 0 {
203                dropped_samples.fetch_add(dropped as u64, Ordering::Relaxed);
204            }
205        }
206    }
207
208    // FFT resampler: fixed input chunk, mono, 1.1× ratio headroom isn't required for synchronous.
209    let mut resampler = match Fft::<f32>::new(
210        native_sr as usize,
211        16_000,
212        RESAMPLER_CHUNK,
213        2,
214        1,
215        FixedSync::Input,
216    ) {
217        Ok(r) => r,
218        Err(e) => {
219            tracing::error!("rubato resampler construction failed ({native_sr} Hz → 16 kHz): {e}");
220            return;
221        }
222    };
223
224    let max_in = resampler.input_frames_max();
225    let max_out = resampler.output_frames_max();
226    let mut input_buf = vec![0.0f32; max_in];
227    let mut output_buf = vec![0.0f32; max_out];
228
229    loop {
230        let chunk_in = resampler.input_frames_next();
231        let mut filled = 0;
232        while filled < chunk_in {
233            let n = cons_raw.pop_slice(&mut input_buf[filled..chunk_in]);
234            if n == 0 {
235                thread::sleep(std::time::Duration::from_millis(1));
236                continue;
237            }
238            filled += n;
239        }
240
241        let in_adapter = match SequentialSlice::new(&input_buf[..chunk_in], 1, chunk_in) {
242            Ok(a) => a,
243            Err(e) => {
244                tracing::error!("rubato input adapter: {e:?}");
245                return;
246            }
247        };
248        let mut out_adapter = match SequentialSlice::new_mut(&mut output_buf[..max_out], 1, max_out)
249        {
250            Ok(a) => a,
251            Err(e) => {
252                tracing::error!("rubato output adapter: {e:?}");
253                return;
254            }
255        };
256        let out_n = match resampler.process_into_buffer(&in_adapter, &mut out_adapter, None) {
257            Ok((_in_used, out_n)) => out_n,
258            Err(e) => {
259                tracing::error!("rubato resample failed: {e:?}");
260                return;
261            }
262        };
263
264        if out_n > 0 {
265            let written = prod_16khz.push_slice(&output_buf[..out_n]);
266            let dropped = out_n - written;
267            if dropped > 0 {
268                dropped_samples.fetch_add(dropped as u64, Ordering::Relaxed);
269            }
270        }
271    }
272}
273
274/// File-backed drop-in for `MicCapture`. Feeds a WAV file into a ring buffer at
275/// either real-time (16 kHz wall-clock) or as-fast-as-possible pace.
276///
277/// The WAV must be 16 kHz; multi-channel files are downmixed to mono.
278pub struct FakeMic {
279    pub consumer: Arc<Mutex<ringbuf::HeapCons<f32>>>,
280    pub native_sample_rate: u32,
281    pub native_channels: u16,
282    is_done: Arc<AtomicBool>,
283    shutdown: Arc<AtomicBool>,
284}
285
286impl FakeMic {
287    /// Open a WAV file and start a background feeder thread.
288    ///
289    /// Returns the `FakeMic` consumer handle and the feeder `JoinHandle`.
290    /// `realtime=true` throttles the feeder to 16 kHz wall-clock pace;
291    /// `realtime=false` pushes as fast as the consumer drains.
292    pub fn open(path: &Path, realtime: bool) -> Result<(Self, JoinHandle<()>)> {
293        let reader = hound::WavReader::open(path)
294            .with_context(|| format!("open WAV: {}", path.display()))?;
295        let spec = reader.spec();
296
297        if spec.sample_rate != 16_000 {
298            bail!(
299                "FakeMic: expected 16 kHz WAV, got {} Hz ({})",
300                spec.sample_rate,
301                path.display()
302            );
303        }
304
305        let channels = spec.channels;
306        let native_sample_rate = spec.sample_rate;
307
308        // Read all samples upfront; downmix to mono f32.
309        let samples: Vec<f32> = match spec.sample_format {
310            hound::SampleFormat::Float => reader
311                .into_samples::<f32>()
312                .collect::<std::result::Result<Vec<_>, _>>()
313                .context("read f32 samples")?
314                .chunks(channels as usize)
315                .map(|ch| ch.iter().sum::<f32>() / channels as f32)
316                .collect(),
317            hound::SampleFormat::Int => {
318                let max_val = (1i64 << (spec.bits_per_sample - 1)) as f32;
319                reader
320                    .into_samples::<i32>()
321                    .collect::<std::result::Result<Vec<_>, _>>()
322                    .context("read i32 samples")?
323                    .chunks(channels as usize)
324                    .map(|ch| ch.iter().map(|&s| s as f32 / max_val).sum::<f32>() / channels as f32)
325                    .collect()
326            }
327        };
328
329        let rb = HeapRb::<f32>::new(32_000);
330        let (prod, cons) = rb.split();
331        let consumer = Arc::new(Mutex::new(cons));
332        let prod = Arc::new(Mutex::new(prod));
333
334        let is_done = Arc::new(AtomicBool::new(false));
335        let shutdown = Arc::new(AtomicBool::new(false));
336
337        let is_done_thread = Arc::clone(&is_done);
338        let shutdown_thread = Arc::clone(&shutdown);
339        let prod_thread = Arc::clone(&prod);
340
341        let handle = thread::Builder::new()
342            .name("fake-mic-feeder".to_string())
343            .spawn(move || {
344                const CHUNK: usize = 512;
345                let mut offset = 0;
346                while offset < samples.len() && !shutdown_thread.load(Ordering::Relaxed) {
347                    let end = (offset + CHUNK).min(samples.len());
348                    let chunk = &samples[offset..end];
349
350                    // Push chunk; if buffer is full, wait briefly and retry.
351                    loop {
352                        if shutdown_thread.load(Ordering::Relaxed) {
353                            return;
354                        }
355                        let written = prod_thread.lock().unwrap().push_slice(chunk);
356                        if written == chunk.len() {
357                            break;
358                        }
359                        // Partial push — back up and retry after a short sleep.
360                        // (Only possible if ring buffer is smaller than CHUNK; in
361                        // practice the buffer is 32 000 >> 512 so this is a safety net.)
362                        thread::sleep(std::time::Duration::from_millis(1));
363                    }
364
365                    offset = end;
366
367                    if realtime {
368                        // 512 samples @ 16 kHz = 32 ms
369                        thread::sleep(std::time::Duration::from_millis(32));
370                    }
371                }
372                is_done_thread.store(true, Ordering::Release);
373            })
374            .context("spawn fake-mic feeder thread")?;
375
376        Ok((
377            FakeMic {
378                consumer,
379                native_sample_rate,
380                native_channels: 1,
381                is_done,
382                shutdown,
383            },
384            handle,
385        ))
386    }
387
388    /// Returns `true` once the feeder thread has pushed all file samples.
389    pub fn is_done(&self) -> bool {
390        self.is_done.load(Ordering::Acquire)
391    }
392
393    /// Signal the feeder thread to stop (used on early shutdown).
394    pub fn stop(self) {
395        self.shutdown.store(true, Ordering::SeqCst);
396    }
397}
398
399/// Unified capture source: either a live microphone or a file-fed `FakeMic`.
400pub enum CaptureSource {
401    Microphone(MicCapture),
402    File(FakeMic),
403}
404
405impl CaptureSource {
406    /// Drain up to `buf.len()` 16 kHz mono samples. Returns the count actually read.
407    pub fn pop_samples(&self, buf: &mut [f32]) -> usize {
408        match self {
409            CaptureSource::Microphone(mic) => mic.consumer.lock().unwrap().pop_slice(buf),
410            CaptureSource::File(fake) => fake.consumer.lock().unwrap().pop_slice(buf),
411        }
412    }
413
414    /// Returns `true` when a `File` source's feeder thread has finished and the
415    /// ring buffer is empty. Always `false` for a live `Microphone` source.
416    pub fn is_file_done(&self) -> bool {
417        match self {
418            CaptureSource::Microphone(_) => false,
419            CaptureSource::File(fake) => fake.is_done(),
420        }
421    }
422
423    pub fn native_sample_rate(&self) -> u32 {
424        match self {
425            CaptureSource::Microphone(mic) => mic.native_sample_rate,
426            CaptureSource::File(fake) => fake.native_sample_rate,
427        }
428    }
429
430    pub fn native_channels(&self) -> u16 {
431        match self {
432            CaptureSource::Microphone(mic) => mic.native_channels,
433            CaptureSource::File(fake) => fake.native_channels,
434        }
435    }
436
437    /// Cumulative samples dropped due to a full downstream ring. Always 0 for file sources.
438    /// A growing value means the consumer (decoder) can't keep up with real-time audio.
439    pub fn dropped_samples(&self) -> u64 {
440        match self {
441            CaptureSource::Microphone(mic) => mic.dropped_samples.load(Ordering::Relaxed),
442            CaptureSource::File(_) => 0,
443        }
444    }
445
446    /// Shut down the underlying capture source.
447    pub fn stop(self) {
448        match self {
449            CaptureSource::Microphone(mic) => mic.stop(),
450            CaptureSource::File(fake) => fake.stop(),
451        }
452    }
453}
454
455/// Lists all available input devices with their host names.
456pub fn list_input_devices() -> Result<Vec<(String, String)>> {
457    let mut devices = Vec::new();
458    let hosts = cpal::ALL_HOSTS;
459
460    for host_id in hosts {
461        let host = cpal::host_from_id(*host_id).context("Failed to instantiate host")?;
462        let host_name = host.id().name().to_string();
463
464        if let Ok(input_devices) = host.input_devices() {
465            for device in input_devices {
466                if let Ok(name) = device.name() {
467                    devices.push((host_name.clone(), name));
468                }
469            }
470        }
471    }
472
473    Ok(devices)
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens the default microphone and checks the reported native format.
    /// Needs real audio hardware, hence ignored by default (e.g. in CI).
    #[test]
    #[ignore]
    fn test_mic_capture_opens() {
        match MicCapture::open(None) {
            Ok(mic) => {
                assert!(mic.native_sample_rate > 0);
                assert!(mic.native_channels > 0);
                mic.stop();
            }
            Err(err) => panic!("Failed to open MicCapture: {:?}", Some(err)),
        }
    }

    /// Device enumeration must succeed.
    #[test]
    fn test_list_input_devices() {
        let listing = list_input_devices();
        assert!(listing.is_ok());
    }
}