1use cpal::{
4 traits::{DeviceTrait, HostTrait, StreamTrait},
5 ChannelCount, Device, Host, SampleRate, Stream,
6};
7use std::sync::{Arc, Mutex};
8use std::time::Duration;
9use voirs_sdk::{Result, VoirsError};
10
11use super::AudioData;
12
/// Settings used when opening the cpal output stream of a `RealTimeAudioStream`.
#[derive(Debug, Clone)]
pub struct RealTimeStreamConfig {
    /// Sample rate requested from the output device (wrapped in `cpal::SampleRate`).
    pub sample_rate: u32,
    /// Number of output channels requested from the device.
    pub channels: u16,
    /// Device buffer size requested via `cpal::BufferSize::Fixed`.
    pub buffer_size: u32,
    /// Target latency in milliseconds. Stored only; nothing in this file reads it.
    pub target_latency_ms: u32,
    /// Exact name of the output device to open; `None` selects the host's default
    /// output device.
    pub device_name: Option<String>,
}

impl Default for RealTimeStreamConfig {
    /// 22 050 Hz, one channel, 512-entry device buffer, 50 ms target latency,
    /// default output device.
    fn default() -> Self {
        let (sample_rate, channels) = (22_050, 1);
        let (buffer_size, target_latency_ms) = (512, 50);

        Self {
            sample_rate,
            channels,
            buffer_size,
            target_latency_ms,
            device_name: None,
        }
    }
}
39
/// Sizing of the ring of sample buffers that sits between the writer and the
/// audio callback.
#[derive(Debug, Clone)]
pub struct BufferConfig {
    /// How many buffers the ring holds.
    pub buffer_count: usize,
    /// Capacity of each buffer, in `f32` samples.
    pub buffer_size: usize,
    /// Underrun threshold. Stored only; nothing in this file reads it.
    pub underrun_threshold: usize,
}

impl Default for BufferConfig {
    /// Eight buffers of 512 samples each, underrun threshold of two.
    fn default() -> Self {
        let buffer_count = 8;
        let buffer_size = 512;
        let underrun_threshold = 2;

        Self {
            buffer_count,
            buffer_size,
            underrun_threshold,
        }
    }
}
60
/// One slot of the playback ring: fixed-size sample storage plus a flag telling
/// whether it currently holds audio that has not been read yet.
#[derive(Debug)]
struct AudioBuffer {
    /// Sample storage; its length is fixed at construction.
    data: Vec<f32>,
    /// `true` from a `write_samples` call until the `read_samples` call that
    /// consumes it.
    is_ready: bool,
    /// Time of construction or of the most recent write.
    timestamp: std::time::Instant,
}

impl AudioBuffer {
    /// Creates an empty (not ready) buffer holding `size` zeroed samples.
    fn new(size: usize) -> Self {
        Self {
            data: vec![0.0; size],
            is_ready: false,
            timestamp: std::time::Instant::now(),
        }
    }

    /// Stores `samples` in the buffer and marks it ready.
    ///
    /// Input longer than the buffer is truncated to the buffer length. Input
    /// shorter than the buffer is padded with silence, so samples left over from
    /// an earlier, longer write are never handed to the reader again.
    fn write_samples(&mut self, samples: &[f32]) {
        let copy_len = samples.len().min(self.data.len());
        let (filled, padding) = self.data.split_at_mut(copy_len);
        filled.copy_from_slice(&samples[..copy_len]);
        // Without this the tail would still hold the previous write's audio.
        padding.fill(0.0);

        self.is_ready = true;
        self.timestamp = std::time::Instant::now();
    }

    /// Copies the buffered samples into `output` and marks the buffer consumed.
    ///
    /// Returns the number of samples copied, which is `0` when the buffer is not
    /// ready. Every element of `output` is always written: whatever is not
    /// covered by buffered samples is set to silence.
    fn read_samples(&mut self, output: &mut [f32]) -> usize {
        if !self.is_ready {
            output.fill(0.0);
            return 0;
        }

        let copy_len = output.len().min(self.data.len());
        let (filled, padding) = output.split_at_mut(copy_len);
        filled.copy_from_slice(&self.data[..copy_len]);
        // `output` may be longer than this buffer; do not leave its tail untouched.
        padding.fill(0.0);

        self.is_ready = false;

        copy_len
    }
}
103
/// Playback counters for a `RealTimeAudioStream`.
///
/// `Default` is derived: every counter starts at zero, which is exactly what the
/// previous hand-written implementation produced.
#[derive(Debug, Clone, Default)]
pub struct StreamStats {
    /// Number of ring buffers the audio callback has successfully read.
    pub buffers_played: u64,
    /// Number of buffers dropped.
    pub buffers_dropped: u64,
    /// Number of audio callbacks that found no ready buffer and emitted silence.
    pub underruns: u64,
    /// Average latency in milliseconds.
    pub average_latency_ms: f32,
    /// Fraction of ring buffers that were ready when the stats were read;
    /// filled in by `RealTimeAudioStream::get_stats`.
    pub current_buffer_fill: f32,
}
125
126pub struct RealTimeAudioStream {
128 config: RealTimeStreamConfig,
129 buffer_config: BufferConfig,
130 device: Device,
131 stream: Option<Stream>,
132 buffers: Arc<Mutex<Vec<AudioBuffer>>>,
133 write_index: Arc<Mutex<usize>>,
134 read_index: Arc<Mutex<usize>>,
135 stats: Arc<Mutex<StreamStats>>,
136 is_active: Arc<Mutex<bool>>,
137}
138
139impl RealTimeAudioStream {
140 pub fn new(stream_config: RealTimeStreamConfig, buffer_config: BufferConfig) -> Result<Self> {
142 let host = cpal::default_host();
143 let device = if let Some(device_name) = &stream_config.device_name {
144 Self::find_device_by_name(&host, device_name)?.ok_or_else(|| {
145 VoirsError::device_error(
146 "audio_device",
147 format!("Audio device '{}' not found", device_name),
148 )
149 })?
150 } else {
151 host.default_output_device().ok_or_else(|| {
152 VoirsError::device_error("audio_device", "No default audio output device found")
153 })?
154 };
155
156 let mut buffers = Vec::with_capacity(buffer_config.buffer_count);
158 for _ in 0..buffer_config.buffer_count {
159 buffers.push(AudioBuffer::new(buffer_config.buffer_size));
160 }
161
162 Ok(Self {
163 config: stream_config,
164 buffer_config,
165 device,
166 stream: None,
167 buffers: Arc::new(Mutex::new(buffers)),
168 write_index: Arc::new(Mutex::new(0)),
169 read_index: Arc::new(Mutex::new(0)),
170 stats: Arc::new(Mutex::new(StreamStats::default())),
171 is_active: Arc::new(Mutex::new(false)),
172 })
173 }
174
175 fn find_device_by_name(host: &Host, device_name: &str) -> Result<Option<Device>> {
177 for device in host.output_devices().map_err(|e| {
178 VoirsError::device_error(
179 "audio_device",
180 format!("Failed to enumerate devices: {}", e),
181 )
182 })? {
183 if let Ok(name) = device.name() {
184 if name == device_name {
185 return Ok(Some(device));
186 }
187 }
188 }
189 Ok(None)
190 }
191
192 pub async fn start(&mut self) -> Result<()> {
194 if self.is_active()? {
195 return Ok(());
196 }
197
198 let stream_config = cpal::StreamConfig {
199 channels: self.config.channels as ChannelCount,
200 sample_rate: SampleRate(self.config.sample_rate),
201 buffer_size: cpal::BufferSize::Fixed(self.config.buffer_size),
202 };
203
204 let buffers = self.buffers.clone();
205 let read_index = self.read_index.clone();
206 let stats = self.stats.clone();
207 let is_active = self.is_active.clone();
208
209 let stream = self
210 .device
211 .build_output_stream(
212 &stream_config,
213 move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
214 let active = if let Ok(guard) = is_active.lock() {
216 *guard
217 } else {
218 false
219 };
220
221 if !active {
222 for sample in data.iter_mut() {
224 *sample = 0.0;
225 }
226 return;
227 }
228
229 let mut read_idx = if let Ok(guard) = read_index.lock() {
231 *guard
232 } else {
233 0
234 };
235
236 let (samples_read, buffer_count) = if let Ok(mut buffers_guard) = buffers.lock()
237 {
238 let count = buffers_guard.len();
239 let read = if read_idx < count {
240 buffers_guard[read_idx].read_samples(data)
241 } else {
242 0
243 };
244 (read, count)
245 } else {
246 (0, 1)
247 };
248
249 if samples_read > 0 {
250 if let Ok(mut guard) = read_index.lock() {
252 *guard = (read_idx + 1) % buffer_count;
253 }
254
255 if let Ok(mut stats_guard) = stats.lock() {
257 stats_guard.buffers_played += 1;
258 }
259 } else {
260 for sample in data.iter_mut() {
262 *sample = 0.0;
263 }
264
265 if let Ok(mut stats_guard) = stats.lock() {
267 stats_guard.underruns += 1;
268 }
269 }
270 },
271 move |err| {
272 tracing::error!("Real-time audio stream error: {}", err);
273 },
274 None, )
276 .map_err(|e| {
277 VoirsError::device_error(
278 "audio_device",
279 format!("Failed to build output stream: {}", e),
280 )
281 })?;
282
283 stream.play().map_err(|e| {
284 VoirsError::device_error("audio_device", format!("Failed to start stream: {}", e))
285 })?;
286
287 self.stream = Some(stream);
288 self.set_active(true)?;
289
290 Ok(())
291 }
292
293 pub fn stop(&mut self) -> Result<()> {
295 self.set_active(false)?;
296
297 if let Some(stream) = self.stream.take() {
298 stream.pause().map_err(|e| {
299 VoirsError::device_error("audio_device", format!("Failed to stop stream: {}", e))
300 })?;
301 }
302
303 Ok(())
304 }
305
306 pub fn write_audio(&self, audio_data: &AudioData) -> Result<()> {
308 let samples_f32: Vec<f32> = audio_data
309 .samples
310 .iter()
311 .map(|&s| s as f32 / i16::MAX as f32)
312 .collect();
313
314 self.write_samples(&samples_f32)
315 }
316
317 pub fn write_samples(&self, samples: &[f32]) -> Result<()> {
319 let mut write_idx = self.write_index.lock().map_err(|_| {
320 VoirsError::device_error("audio_stream", "Failed to lock write_index mutex")
321 })?;
322
323 let mut buffers = self.buffers.lock().map_err(|_| {
324 VoirsError::device_error("audio_stream", "Failed to lock buffers mutex")
325 })?;
326
327 if *write_idx < buffers.len() {
328 buffers[*write_idx].write_samples(samples);
329 *write_idx = (*write_idx + 1) % buffers.len();
330 }
331
332 Ok(())
333 }
334
335 pub fn is_active(&self) -> Result<bool> {
337 let active = self.is_active.lock().map_err(|_| {
338 VoirsError::device_error("audio_stream", "Failed to lock is_active mutex")
339 })?;
340 Ok(*active)
341 }
342
343 fn set_active(&self, active: bool) -> Result<()> {
345 let mut state = self.is_active.lock().map_err(|_| {
346 VoirsError::device_error("audio_stream", "Failed to lock is_active mutex")
347 })?;
348 *state = active;
349 Ok(())
350 }
351
352 pub fn get_buffer_fill_level(&self) -> Result<f32> {
354 let write_idx = self.write_index.lock().map_err(|_| {
355 VoirsError::device_error("audio_stream", "Failed to lock write_index mutex")
356 })?;
357 let read_idx = self.read_index.lock().map_err(|_| {
358 VoirsError::device_error("audio_stream", "Failed to lock read_index mutex")
359 })?;
360
361 let buffers = self.buffers.lock().map_err(|_| {
362 VoirsError::device_error("audio_stream", "Failed to lock buffers mutex")
363 })?;
364
365 let ready_buffers = buffers.iter().filter(|b| b.is_ready).count();
366
367 Ok(ready_buffers as f32 / self.buffer_config.buffer_count as f32)
368 }
369
370 pub fn get_stats(&self) -> Result<StreamStats> {
372 let stats = self
373 .stats
374 .lock()
375 .map_err(|_| VoirsError::device_error("audio_stream", "Failed to lock stats mutex"))?;
376
377 let mut stats_copy = stats.clone();
378 stats_copy.current_buffer_fill = self.get_buffer_fill_level()?;
379
380 Ok(stats_copy)
381 }
382
383 pub fn reset_stats(&self) -> Result<()> {
385 let mut stats = self
386 .stats
387 .lock()
388 .map_err(|_| VoirsError::device_error("audio_stream", "Failed to lock stats mutex"))?;
389 *stats = StreamStats::default();
390 Ok(())
391 }
392
393 pub fn get_estimated_latency_ms(&self) -> Result<f32> {
395 let buffer_fill = self.get_buffer_fill_level()?;
396 let buffer_duration_ms =
397 (self.buffer_config.buffer_size as f32 / self.config.sample_rate as f32) * 1000.0;
398 let total_buffer_latency =
399 buffer_fill * buffer_duration_ms * self.buffer_config.buffer_count as f32;
400
401 Ok(total_buffer_latency)
402 }
403
404 pub fn has_sufficient_buffer_space(&self) -> Result<bool> {
406 let buffer_fill = self.get_buffer_fill_level()?;
407 Ok(buffer_fill < 0.8) }
409
410 pub async fn wait_for_buffer_space(&self, timeout: Duration) -> Result<bool> {
412 let start_time = std::time::Instant::now();
413
414 while start_time.elapsed() < timeout {
415 if self.has_sufficient_buffer_space()? {
416 return Ok(true);
417 }
418
419 tokio::time::sleep(Duration::from_millis(1)).await;
420 }
421
422 Ok(false)
423 }
424}
425
#[cfg(test)]
mod tests {
    use super::AudioData;
    use super::*;

    /// Tries to build a stream with default settings. Yields `None` when
    /// construction fails, in which case the device-dependent tests below
    /// check nothing.
    fn open_default_stream() -> Option<RealTimeAudioStream> {
        RealTimeAudioStream::new(RealTimeStreamConfig::default(), BufferConfig::default()).ok()
    }

    #[test]
    fn test_stream_config_default() {
        let RealTimeStreamConfig {
            sample_rate,
            channels,
            target_latency_ms,
            ..
        } = RealTimeStreamConfig::default();

        assert_eq!(sample_rate, 22050);
        assert_eq!(channels, 1);
        assert_eq!(target_latency_ms, 50);
    }

    #[test]
    fn test_buffer_config_default() {
        let BufferConfig {
            buffer_count,
            buffer_size,
            underrun_threshold,
        } = BufferConfig::default();

        assert_eq!(buffer_count, 8);
        assert_eq!(buffer_size, 512);
        assert_eq!(underrun_threshold, 2);
    }

    #[test]
    fn test_audio_buffer() {
        let mut buffer = AudioBuffer::new(4);
        assert!(!buffer.is_ready);

        // Writing marks the buffer ready.
        let samples = vec![0.1, 0.2, 0.3, 0.4];
        buffer.write_samples(&samples);
        assert!(buffer.is_ready);

        // Reading returns the samples and consumes the buffer.
        let mut output = vec![0.0; 4];
        let samples_read = buffer.read_samples(&mut output);
        assert_eq!(samples_read, 4);
        assert_eq!(output, samples);
        assert!(!buffer.is_ready);
    }

    #[tokio::test]
    async fn test_realtime_stream_creation() {
        if let Some(stream) = open_default_stream() {
            assert!(!stream.is_active().unwrap());

            let fill_level = stream.get_buffer_fill_level().unwrap();
            assert!((0.0..=1.0).contains(&fill_level));
        }
    }

    #[tokio::test]
    async fn test_stream_buffer_operations() {
        if let Some(stream) = open_default_stream() {
            let audio_data = AudioData {
                samples: vec![0, 1000, 2000, 3000],
                sample_rate: 22050,
                channels: 1,
            };

            stream.write_audio(&audio_data).unwrap();

            // One buffer out of eight is now ready.
            let fill_level = stream.get_buffer_fill_level().unwrap();
            assert!(fill_level > 0.0);

            assert!(stream.has_sufficient_buffer_space().unwrap());
        }
    }

    #[tokio::test]
    async fn test_stream_stats() {
        if let Some(stream) = open_default_stream() {
            let stats = stream.get_stats().unwrap();
            assert_eq!(stats.buffers_played, 0);
            assert_eq!(stats.buffers_dropped, 0);
            assert_eq!(stats.underruns, 0);

            stream.reset_stats().unwrap();

            let stats_after_reset = stream.get_stats().unwrap();
            assert_eq!(stats_after_reset.buffers_played, 0);
        }
    }
}