1use anyhow::{Context, Result, bail};
2use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
3use ringbuf::{
4 HeapRb,
5 traits::{Consumer, Producer, Split},
6};
7use rubato::audioadapter_buffers::direct::SequentialSlice;
8use rubato::{Fft, FixedSync, Resampler};
9use std::path::Path;
10use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
11use std::sync::{Arc, Mutex};
12use std::thread;
13use std::thread::JoinHandle;
14
/// Chunk size passed to the rubato FFT resampler constructor in `run_resample`
/// (2^10 = 1024).
const RESAMPLER_CHUNK: usize = 1 << 10;
16
17pub struct MicCapture {
19 _stream: cpal::Stream,
20 _resample_thread: Option<std::thread::JoinHandle<()>>,
21 pub consumer: Arc<Mutex<ringbuf::HeapCons<f32>>>,
22 pub native_sample_rate: u32,
23 pub native_channels: u16,
24 pub dropped_samples: Arc<AtomicU64>,
28}
29
30impl MicCapture {
31 pub fn open(device_name: Option<&str>) -> Result<Self> {
35 let host = cpal::default_host();
36
37 let device = if let Some(name) = device_name {
38 host.input_devices()
39 .context("Failed to enumerate input devices")?
40 .find(|d| d.name().ok().as_deref() == Some(name))
41 .context(format!("Input device '{name}' not found"))?
42 } else {
43 host.default_input_device()
44 .context("No default input device found")?
45 };
46
47 let device_name = device.name().unwrap_or_else(|_| "<unknown>".to_string());
48 tracing::info!("Opening input device: {}", device_name);
49
50 let config = device
51 .default_input_config()
52 .context("Failed to get device config")?;
53
54 let native_sample_rate = config.sample_rate().0;
55 let native_channels = config.channels();
56
57 tracing::info!(
58 "Device config: {} Hz, {} channels, format: {:?}",
59 native_sample_rate,
60 native_channels,
61 config.sample_format()
62 );
63
64 let ring_raw_rb = HeapRb::<f32>::new(((native_sample_rate as usize) * 2).max(4096));
67 let (prod_raw, cons_raw) = ring_raw_rb.split();
68
69 let ring_16khz_rb = HeapRb::<f32>::new(32_000);
70 let (prod_16khz, cons_16khz) = ring_16khz_rb.split();
71 let ring_16khz_cons = Arc::new(Mutex::new(cons_16khz));
72
73 let dropped_samples = Arc::new(AtomicU64::new(0));
74
75 let dropped_samples_resample = Arc::clone(&dropped_samples);
78 let native_sr = native_sample_rate;
79
80 let resample_thread = thread::Builder::new()
81 .name("audio-resample".to_string())
82 .spawn(move || {
83 run_resample(native_sr, cons_raw, prod_16khz, dropped_samples_resample);
84 })
85 .context("Failed to spawn resample thread")?;
86
87 let dropped_samples_callback = Arc::clone(&dropped_samples);
91
92 let stream = match config.sample_format() {
93 cpal::SampleFormat::F32 => {
94 let mut prod = prod_raw;
95 let dropped_cb = dropped_samples_callback;
96 device.build_input_stream(
97 &config.into(),
98 move |data: &[f32], _info| {
99 let mono: Vec<f32> = data
100 .chunks_exact(native_channels as usize)
101 .map(|ch| ch.iter().sum::<f32>() / native_channels as f32)
102 .collect();
103 let written = prod.push_slice(&mono);
104 let dropped = mono.len() - written;
105 if dropped > 0 {
106 dropped_cb.fetch_add(dropped as u64, Ordering::Relaxed);
107 }
108 },
109 |err| tracing::error!("Stream error: {}", err),
110 None,
111 )
112 }
113 cpal::SampleFormat::I16 => {
114 let mut prod = prod_raw;
115 let dropped_cb = dropped_samples_callback;
116 device.build_input_stream(
117 &config.into(),
118 move |data: &[i16], _info| {
119 let mono: Vec<f32> = data
120 .chunks_exact(native_channels as usize)
121 .map(|ch| {
122 ch.iter().map(|&s| s as f32 / 32_768.0).sum::<f32>()
123 / native_channels as f32
124 })
125 .collect();
126 let written = prod.push_slice(&mono);
127 let dropped = mono.len() - written;
128 if dropped > 0 {
129 dropped_cb.fetch_add(dropped as u64, Ordering::Relaxed);
130 }
131 },
132 |err| tracing::error!("Stream error: {}", err),
133 None,
134 )
135 }
136 cpal::SampleFormat::U16 => {
137 let mut prod = prod_raw;
138 let dropped_cb = dropped_samples_callback;
139 device.build_input_stream(
140 &config.into(),
141 move |data: &[u16], _info| {
142 let mono: Vec<f32> = data
143 .chunks_exact(native_channels as usize)
144 .map(|ch| {
145 ch.iter()
146 .map(|&s| (s as f32 - 32_768.0) / 32_768.0)
147 .sum::<f32>()
148 / native_channels as f32
149 })
150 .collect();
151 let written = prod.push_slice(&mono);
152 let dropped = mono.len() - written;
153 if dropped > 0 {
154 dropped_cb.fetch_add(dropped as u64, Ordering::Relaxed);
155 }
156 },
157 |err| tracing::error!("Stream error: {}", err),
158 None,
159 )
160 }
161 _ => bail!("Unsupported sample format"),
162 }?;
163
164 stream.play().context("Failed to start playback")?;
165
166 Ok(MicCapture {
167 _stream: stream,
168 _resample_thread: Some(resample_thread),
169 consumer: ring_16khz_cons,
170 native_sample_rate,
171 native_channels,
172 dropped_samples,
173 })
174 }
175
176 pub fn stop(self) {
178 drop(self._stream);
179 drop(self._resample_thread);
180 }
181}
182
183fn run_resample(
186 native_sr: u32,
187 mut cons_raw: ringbuf::HeapCons<f32>,
188 mut prod_16khz: ringbuf::HeapProd<f32>,
189 dropped_samples: Arc<AtomicU64>,
190) {
191 if native_sr == 16_000 {
193 let mut buf = vec![0.0f32; 4096];
194 loop {
195 let n = cons_raw.pop_slice(&mut buf);
196 if n == 0 {
197 thread::sleep(std::time::Duration::from_millis(1));
198 continue;
199 }
200 let written = prod_16khz.push_slice(&buf[..n]);
201 let dropped = n - written;
202 if dropped > 0 {
203 dropped_samples.fetch_add(dropped as u64, Ordering::Relaxed);
204 }
205 }
206 }
207
208 let mut resampler = match Fft::<f32>::new(
210 native_sr as usize,
211 16_000,
212 RESAMPLER_CHUNK,
213 2,
214 1,
215 FixedSync::Input,
216 ) {
217 Ok(r) => r,
218 Err(e) => {
219 tracing::error!("rubato resampler construction failed ({native_sr} Hz → 16 kHz): {e}");
220 return;
221 }
222 };
223
224 let max_in = resampler.input_frames_max();
225 let max_out = resampler.output_frames_max();
226 let mut input_buf = vec![0.0f32; max_in];
227 let mut output_buf = vec![0.0f32; max_out];
228
229 loop {
230 let chunk_in = resampler.input_frames_next();
231 let mut filled = 0;
232 while filled < chunk_in {
233 let n = cons_raw.pop_slice(&mut input_buf[filled..chunk_in]);
234 if n == 0 {
235 thread::sleep(std::time::Duration::from_millis(1));
236 continue;
237 }
238 filled += n;
239 }
240
241 let in_adapter = match SequentialSlice::new(&input_buf[..chunk_in], 1, chunk_in) {
242 Ok(a) => a,
243 Err(e) => {
244 tracing::error!("rubato input adapter: {e:?}");
245 return;
246 }
247 };
248 let mut out_adapter = match SequentialSlice::new_mut(&mut output_buf[..max_out], 1, max_out)
249 {
250 Ok(a) => a,
251 Err(e) => {
252 tracing::error!("rubato output adapter: {e:?}");
253 return;
254 }
255 };
256 let out_n = match resampler.process_into_buffer(&in_adapter, &mut out_adapter, None) {
257 Ok((_in_used, out_n)) => out_n,
258 Err(e) => {
259 tracing::error!("rubato resample failed: {e:?}");
260 return;
261 }
262 };
263
264 if out_n > 0 {
265 let written = prod_16khz.push_slice(&output_buf[..out_n]);
266 let dropped = out_n - written;
267 if dropped > 0 {
268 dropped_samples.fetch_add(dropped as u64, Ordering::Relaxed);
269 }
270 }
271 }
272}
273
274pub struct FakeMic {
279 pub consumer: Arc<Mutex<ringbuf::HeapCons<f32>>>,
280 pub native_sample_rate: u32,
281 pub native_channels: u16,
282 is_done: Arc<AtomicBool>,
283 shutdown: Arc<AtomicBool>,
284}
285
286impl FakeMic {
287 pub fn open(path: &Path, realtime: bool) -> Result<(Self, JoinHandle<()>)> {
293 let reader = hound::WavReader::open(path)
294 .with_context(|| format!("open WAV: {}", path.display()))?;
295 let spec = reader.spec();
296
297 if spec.sample_rate != 16_000 {
298 bail!(
299 "FakeMic: expected 16 kHz WAV, got {} Hz ({})",
300 spec.sample_rate,
301 path.display()
302 );
303 }
304
305 let channels = spec.channels;
306 let native_sample_rate = spec.sample_rate;
307
308 let samples: Vec<f32> = match spec.sample_format {
310 hound::SampleFormat::Float => reader
311 .into_samples::<f32>()
312 .collect::<std::result::Result<Vec<_>, _>>()
313 .context("read f32 samples")?
314 .chunks(channels as usize)
315 .map(|ch| ch.iter().sum::<f32>() / channels as f32)
316 .collect(),
317 hound::SampleFormat::Int => {
318 let max_val = (1i64 << (spec.bits_per_sample - 1)) as f32;
319 reader
320 .into_samples::<i32>()
321 .collect::<std::result::Result<Vec<_>, _>>()
322 .context("read i32 samples")?
323 .chunks(channels as usize)
324 .map(|ch| ch.iter().map(|&s| s as f32 / max_val).sum::<f32>() / channels as f32)
325 .collect()
326 }
327 };
328
329 let rb = HeapRb::<f32>::new(32_000);
330 let (prod, cons) = rb.split();
331 let consumer = Arc::new(Mutex::new(cons));
332 let prod = Arc::new(Mutex::new(prod));
333
334 let is_done = Arc::new(AtomicBool::new(false));
335 let shutdown = Arc::new(AtomicBool::new(false));
336
337 let is_done_thread = Arc::clone(&is_done);
338 let shutdown_thread = Arc::clone(&shutdown);
339 let prod_thread = Arc::clone(&prod);
340
341 let handle = thread::Builder::new()
342 .name("fake-mic-feeder".to_string())
343 .spawn(move || {
344 const CHUNK: usize = 512;
345 let mut offset = 0;
346 while offset < samples.len() && !shutdown_thread.load(Ordering::Relaxed) {
347 let end = (offset + CHUNK).min(samples.len());
348 let chunk = &samples[offset..end];
349
350 loop {
352 if shutdown_thread.load(Ordering::Relaxed) {
353 return;
354 }
355 let written = prod_thread.lock().unwrap().push_slice(chunk);
356 if written == chunk.len() {
357 break;
358 }
359 thread::sleep(std::time::Duration::from_millis(1));
363 }
364
365 offset = end;
366
367 if realtime {
368 thread::sleep(std::time::Duration::from_millis(32));
370 }
371 }
372 is_done_thread.store(true, Ordering::Release);
373 })
374 .context("spawn fake-mic feeder thread")?;
375
376 Ok((
377 FakeMic {
378 consumer,
379 native_sample_rate,
380 native_channels: 1,
381 is_done,
382 shutdown,
383 },
384 handle,
385 ))
386 }
387
388 pub fn is_done(&self) -> bool {
390 self.is_done.load(Ordering::Acquire)
391 }
392
393 pub fn stop(self) {
395 self.shutdown.store(true, Ordering::SeqCst);
396 }
397}
398
399pub enum CaptureSource {
401 Microphone(MicCapture),
402 File(FakeMic),
403}
404
405impl CaptureSource {
406 pub fn pop_samples(&self, buf: &mut [f32]) -> usize {
408 match self {
409 CaptureSource::Microphone(mic) => mic.consumer.lock().unwrap().pop_slice(buf),
410 CaptureSource::File(fake) => fake.consumer.lock().unwrap().pop_slice(buf),
411 }
412 }
413
414 pub fn is_file_done(&self) -> bool {
417 match self {
418 CaptureSource::Microphone(_) => false,
419 CaptureSource::File(fake) => fake.is_done(),
420 }
421 }
422
423 pub fn native_sample_rate(&self) -> u32 {
424 match self {
425 CaptureSource::Microphone(mic) => mic.native_sample_rate,
426 CaptureSource::File(fake) => fake.native_sample_rate,
427 }
428 }
429
430 pub fn native_channels(&self) -> u16 {
431 match self {
432 CaptureSource::Microphone(mic) => mic.native_channels,
433 CaptureSource::File(fake) => fake.native_channels,
434 }
435 }
436
437 pub fn dropped_samples(&self) -> u64 {
440 match self {
441 CaptureSource::Microphone(mic) => mic.dropped_samples.load(Ordering::Relaxed),
442 CaptureSource::File(_) => 0,
443 }
444 }
445
446 pub fn stop(self) {
448 match self {
449 CaptureSource::Microphone(mic) => mic.stop(),
450 CaptureSource::File(fake) => fake.stop(),
451 }
452 }
453}
454
455pub fn list_input_devices() -> Result<Vec<(String, String)>> {
457 let mut devices = Vec::new();
458 let hosts = cpal::ALL_HOSTS;
459
460 for host_id in hosts {
461 let host = cpal::host_from_id(*host_id).context("Failed to instantiate host")?;
462 let host_name = host.id().name().to_string();
463
464 if let Ok(input_devices) = host.input_devices() {
465 for device in input_devices {
466 if let Ok(name) = device.name() {
467 devices.push((host_name.clone(), name));
468 }
469 }
470 }
471 }
472
473 Ok(devices)
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens the default input device; ignored by default.
    #[test]
    #[ignore]
    fn test_mic_capture_opens() {
        match MicCapture::open(None) {
            Ok(mic) => {
                assert!(mic.native_sample_rate > 0);
                assert!(mic.native_channels > 0);
                mic.stop();
            }
            Err(err) => panic!("Failed to open MicCapture: {:?}", Some(err)),
        }
    }

    /// Device listing must succeed.
    #[test]
    fn test_list_input_devices() {
        assert!(list_input_devices().is_ok());
    }
}