1use std::sync::Arc;
25
26use anyhow::{Context, Result};
27use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
28use cpal::{Device, Stream, StreamConfig};
29use wavecraft_protocol::MeterUpdateNotification;
30
31use super::atomic_params::AtomicParameterBridge;
32use super::ffi_processor::DevAudioProcessor;
33
/// Smallest gain multiplier accepted from the parameter bridge; lower values are clamped up.
const GAIN_MULTIPLIER_MIN: f32 = 0.0;
/// Largest gain multiplier accepted from the parameter bridge; higher values are clamped down.
const GAIN_MULTIPLIER_MAX: f32 = 2.0;
/// Parameter id looked up for the input-stage gain multiplier.
const INPUT_GAIN_PARAM_ID: &str = "input_gain_level";
/// Parameter id looked up for the output-stage gain multiplier.
const OUTPUT_GAIN_PARAM_ID: &str = "output_gain_level";
39
/// User-supplied settings for the audio server.
#[derive(Clone, Debug)]
pub struct AudioConfig {
    /// Requested sample rate in Hz.
    ///
    /// NOTE(review): nothing in this file reads this field; the processor is
    /// configured with the input device's reported rate instead. Confirm
    /// whether other modules rely on it.
    pub sample_rate: f32,
    /// Number of frames per channel that the server preallocates for each
    /// processing pass; also used to size the input-to-output ring buffer.
    pub buffer_size: u32,
}
48
49pub struct AudioHandle {
52 _input_stream: Stream,
53 _output_stream: Option<Stream>,
54}
55
56pub struct AudioServer {
59 processor: Box<dyn DevAudioProcessor>,
60 config: AudioConfig,
61 input_device: Device,
62 output_device: Device,
63 input_config: StreamConfig,
64 output_config: StreamConfig,
65 param_bridge: Arc<AtomicParameterBridge>,
66}
67
68impl AudioServer {
69 pub fn new(
72 processor: Box<dyn DevAudioProcessor>,
73 config: AudioConfig,
74 param_bridge: Arc<AtomicParameterBridge>,
75 ) -> Result<Self> {
76 let host = cpal::default_host();
77
78 let input_device = host
80 .default_input_device()
81 .context("No input device available")?;
82 tracing::info!("Using input device: {}", input_device.name()?);
83
84 let supported_input = input_device
85 .default_input_config()
86 .context("Failed to get default input config")?;
87 let input_sample_rate = supported_input.sample_rate().0;
88 tracing::info!("Input sample rate: {} Hz", input_sample_rate);
89 let input_config: StreamConfig = supported_input.into();
90
91 let output_device = host
93 .default_output_device()
94 .context("No output device available")?;
95
96 match output_device.name() {
97 Ok(name) => tracing::info!("Using output device: {}", name),
98 Err(_) => tracing::info!("Using output device: (unnamed)"),
99 }
100
101 let supported_output = output_device
102 .default_output_config()
103 .context("Failed to get default output config")?;
104 let output_sr = supported_output.sample_rate().0;
105 tracing::info!("Output sample rate: {} Hz", output_sr);
106 if output_sr != input_sample_rate {
107 tracing::warn!(
108 "Input/output sample rate mismatch ({} vs {}). \
109 Processing at input rate; output device may resample.",
110 input_sample_rate,
111 output_sr
112 );
113 }
114 let output_config: StreamConfig = supported_output.into();
115
116 Ok(Self {
117 processor,
118 config,
119 input_device,
120 output_device,
121 input_config,
122 output_config,
123 param_bridge,
124 })
125 }
126
127 pub fn start(mut self) -> Result<(AudioHandle, rtrb::Consumer<MeterUpdateNotification>)> {
135 let actual_sample_rate = self.input_config.sample_rate.0 as f32;
137 self.processor.set_sample_rate(actual_sample_rate);
138
139 let mut processor = self.processor;
140 let buffer_size = self.config.buffer_size as usize;
141 let input_channels = self.input_config.channels as usize;
142 let output_channels = self.output_config.channels as usize;
143 let param_bridge = Arc::clone(&self.param_bridge);
144
145 let ring_capacity = buffer_size * 2 * 4;
149 let (mut ring_producer, mut ring_consumer) = rtrb::RingBuffer::new(ring_capacity);
150
151 let (mut meter_producer, meter_consumer) =
156 rtrb::RingBuffer::<MeterUpdateNotification>::new(64);
157
158 let mut frame_counter = 0u64;
159 let mut oscillator_phase = 0.0_f32;
160
161 let mut left_buf = vec![0.0f32; buffer_size];
165 let mut right_buf = vec![0.0f32; buffer_size];
166
167 let mut interleave_buf = vec![0.0f32; buffer_size * 2];
169
170 let input_stream = self
171 .input_device
172 .build_input_stream(
173 &self.input_config,
174 move |data: &[f32], _: &cpal::InputCallbackInfo| {
175 frame_counter += 1;
176
177 let num_samples = data.len() / input_channels.max(1);
178 if num_samples == 0 || input_channels == 0 {
179 return;
180 }
181
182 let actual_samples = num_samples.min(left_buf.len());
183 let left = &mut left_buf[..actual_samples];
184 let right = &mut right_buf[..actual_samples];
185
186 left.fill(0.0);
188 right.fill(0.0);
189
190 for i in 0..actual_samples {
191 left[i] = data[i * input_channels];
192 if input_channels > 1 {
193 right[i] = data[i * input_channels + 1];
194 } else {
195 right[i] = left[i];
196 }
197 }
198
199 {
201 let mut channels: [&mut [f32]; 2] = [left, right];
202 processor.process(&mut channels);
203 }
204
205 apply_output_modifiers(
209 left,
210 right,
211 ¶m_bridge,
212 &mut oscillator_phase,
213 actual_sample_rate,
214 );
215
216 let left = &left_buf[..actual_samples];
218 let right = &right_buf[..actual_samples];
219
220 let (peak_left, rms_left) = compute_peak_and_rms(left);
222 let (peak_right, rms_right) = compute_peak_and_rms(right);
223
224 if frame_counter.is_multiple_of(2) {
229 let notification = MeterUpdateNotification {
230 timestamp_us: frame_counter,
231 left_peak: peak_left,
232 left_rms: rms_left,
233 right_peak: peak_right,
234 right_rms: rms_right,
235 };
236 let _ = meter_producer.push(notification);
240 }
241
242 let interleave = &mut interleave_buf[..actual_samples * 2];
246 for i in 0..actual_samples {
247 interleave[i * 2] = left[i];
248 interleave[i * 2 + 1] = right[i];
249 }
250
251 for &sample in interleave.iter() {
253 if ring_producer.push(sample).is_err() {
254 break;
255 }
256 }
257 },
258 |err| {
259 tracing::error!("Audio input stream error: {}", err);
260 },
261 None,
262 )
263 .context("Failed to build input stream")?;
264
265 input_stream
266 .play()
267 .context("Failed to start input stream")?;
268 tracing::info!("Input stream started");
269
270 let output_stream = self
272 .output_device
273 .build_output_stream(
274 &self.output_config,
275 move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
276 if output_channels == 0 {
277 data.fill(0.0);
278 return;
279 }
280
281 for frame in data.chunks_mut(output_channels) {
284 let left = ring_consumer.pop().unwrap_or(0.0);
285 let right = ring_consumer.pop().unwrap_or(0.0);
286
287 if output_channels == 1 {
288 frame[0] = 0.5 * (left + right);
289 continue;
290 }
291
292 frame[0] = left;
293 frame[1] = right;
294
295 for channel in frame.iter_mut().skip(2) {
296 *channel = 0.0;
297 }
298 }
299 },
300 |err| {
301 tracing::error!("Audio output stream error: {}", err);
302 },
303 None,
304 )
305 .context("Failed to build output stream")?;
306
307 output_stream
308 .play()
309 .context("Failed to start output stream")?;
310 tracing::info!("Output stream started");
311
312 tracing::info!("Audio server started in full-duplex (input + output) mode");
313
314 Ok((
315 AudioHandle {
316 _input_stream: input_stream,
317 _output_stream: Some(output_stream),
318 },
319 meter_consumer,
320 ))
321 }
322
323 pub fn has_output(&self) -> bool {
325 true
326 }
327}
328
329fn apply_output_modifiers(
330 left: &mut [f32],
331 right: &mut [f32],
332 param_bridge: &AtomicParameterBridge,
333 oscillator_phase: &mut f32,
334 sample_rate: f32,
335) {
336 let input_gain = read_gain_multiplier(param_bridge, INPUT_GAIN_PARAM_ID);
337 let output_gain = read_gain_multiplier(param_bridge, OUTPUT_GAIN_PARAM_ID);
338 let combined_gain = input_gain * output_gain;
339
340 if let Some(enabled) = param_bridge.read("oscillator_enabled")
343 && enabled < 0.5
344 {
345 left.fill(0.0);
346 right.fill(0.0);
347 apply_gain(left, right, combined_gain);
348 return;
349 }
350
351 let oscillator_frequency = param_bridge.read("oscillator_frequency");
354 let oscillator_level = param_bridge.read("oscillator_level");
355
356 if let (Some(frequency), Some(level)) = (oscillator_frequency, oscillator_level) {
357 if !sample_rate.is_finite() || sample_rate <= 0.0 {
358 apply_gain(left, right, combined_gain);
359 return;
360 }
361
362 let clamped_frequency = if frequency.is_finite() {
363 frequency.clamp(20.0, 5000.0)
364 } else {
365 440.0
366 };
367 let clamped_level = if level.is_finite() {
368 level.clamp(0.0, 1.0)
369 } else {
370 0.0
371 };
372
373 let phase_delta = clamped_frequency / sample_rate;
374 let mut phase = if oscillator_phase.is_finite() {
375 *oscillator_phase
376 } else {
377 0.0
378 };
379
380 for (left_sample, right_sample) in left.iter_mut().zip(right.iter_mut()) {
381 let sample = (phase * std::f32::consts::TAU).sin() * clamped_level;
382 *left_sample = sample;
383 *right_sample = sample;
384
385 phase += phase_delta;
386 if phase >= 1.0 {
387 phase -= phase.floor();
388 }
389 }
390
391 *oscillator_phase = phase;
392 }
393
394 apply_gain(left, right, combined_gain);
395}
396
397fn read_gain_multiplier(param_bridge: &AtomicParameterBridge, id: &str) -> f32 {
398 if let Some(value) = param_bridge.read(id)
399 && value.is_finite()
400 {
401 return value.clamp(GAIN_MULTIPLIER_MIN, GAIN_MULTIPLIER_MAX);
402 }
403
404 1.0
405}
406
/// Computes the peak (maximum absolute value) and RMS level of `samples`.
///
/// Returns `(peak, rms)`. An empty slice yields `(0.0, 0.0)` instead of a NaN
/// RMS from dividing zero by zero, so callers never publish NaN for an empty
/// block.
fn compute_peak_and_rms(samples: &[f32]) -> (f32, f32) {
    if samples.is_empty() {
        return (0.0, 0.0);
    }

    // One pass gathers both the running peak and the sum of squares.
    let (peak, sum_of_squares) = samples
        .iter()
        .fold((0.0f32, 0.0f32), |(peak, sum), &sample| {
            (peak.max(sample.abs()), sum + sample * sample)
        });

    let rms = (sum_of_squares / samples.len() as f32).sqrt();

    (peak, rms)
}
417
/// Multiplies both channels by `gain` in place.
///
/// A gain within `f32::EPSILON` of `1.0` is treated as unity and leaves the
/// buffers untouched. Only the first `min(left.len(), right.len())` samples of
/// each channel are scaled.
fn apply_gain(left: &mut [f32], right: &mut [f32], gain: f32) {
    let is_unity = (gain - 1.0).abs() <= f32::EPSILON;
    if !is_unity {
        let paired = left.len().min(right.len());
        left[..paired]
            .iter_mut()
            .chain(right[..paired].iter_mut())
            .for_each(|sample| *sample *= gain);
    }
}
428
#[cfg(test)]
mod tests {
    use super::apply_output_modifiers;
    use crate::audio::atomic_params::AtomicParameterBridge;
    use wavecraft_protocol::{ParameterInfo, ParameterType};

    /// Sample rate passed to every `apply_output_modifiers` call in these tests.
    const SAMPLE_RATE: f32 = 48_000.0;

    /// Float parameter in the "Oscillator" group with a 0..=1 range and "%" unit.
    fn percent_param(id: &str, name: &str, value: f32) -> ParameterInfo {
        ParameterInfo {
            id: id.to_string(),
            name: name.to_string(),
            param_type: ParameterType::Float,
            value,
            default: value,
            unit: Some("%".to_string()),
            min: 0.0,
            max: 1.0,
            group: Some("Oscillator".to_string()),
        }
    }

    /// The `oscillator_frequency` parameter (20..=5000 Hz).
    fn frequency_param(value: f32) -> ParameterInfo {
        ParameterInfo {
            id: "oscillator_frequency".to_string(),
            name: "Frequency".to_string(),
            param_type: ParameterType::Float,
            value,
            default: value,
            min: 20.0,
            max: 5_000.0,
            unit: Some("Hz".to_string()),
            group: Some("Oscillator".to_string()),
        }
    }

    /// Float gain parameter with a 0..=2 range and "x" unit.
    fn gain_param(id: &str, name: &str, group: &str, value: f32) -> ParameterInfo {
        ParameterInfo {
            id: id.to_string(),
            name: name.to_string(),
            param_type: ParameterType::Float,
            value,
            default: value,
            unit: Some("x".to_string()),
            min: 0.0,
            max: 2.0,
            group: Some(group.to_string()),
        }
    }

    /// Runs the modifiers from a zero phase and returns the phase afterwards.
    fn run_modifiers(bridge: &AtomicParameterBridge, left: &mut [f32], right: &mut [f32]) -> f32 {
        let mut phase = 0.0;
        apply_output_modifiers(left, right, bridge, &mut phase, SAMPLE_RATE);
        phase
    }

    /// Largest absolute sample value in `samples`.
    fn peak_of(samples: &[f32]) -> f32 {
        samples
            .iter()
            .fold(0.0_f32, |acc, sample| acc.max(sample.abs()))
    }

    /// True when every sample is within `1e-6` of `expected`.
    fn all_close(samples: &[f32], expected: f32) -> bool {
        samples
            .iter()
            .all(|sample| (*sample - expected).abs() < 1e-6)
    }

    /// True when every sample is within `f32::EPSILON` of zero.
    fn is_silent(samples: &[f32]) -> bool {
        samples.iter().all(|s| s.abs() <= f32::EPSILON)
    }

    fn bridge_with_enabled(default_value: f32) -> AtomicParameterBridge {
        AtomicParameterBridge::new(&[percent_param(
            "oscillator_enabled",
            "Enabled",
            default_value,
        )])
    }

    fn oscillator_bridge(
        frequency: f32,
        level: f32,
        enabled: f32,
        input_gain_level: f32,
        output_gain_level: f32,
    ) -> AtomicParameterBridge {
        AtomicParameterBridge::new(&[
            percent_param("oscillator_enabled", "Enabled", enabled),
            frequency_param(frequency),
            percent_param("oscillator_level", "Level", level),
            gain_param("input_gain_level", "Level", "InputGain", input_gain_level),
            gain_param(
                "output_gain_level",
                "Level",
                "OutputGain",
                output_gain_level,
            ),
        ])
    }

    #[test]
    fn output_modifiers_mute_when_oscillator_disabled() {
        let bridge = bridge_with_enabled(1.0);
        bridge.write("oscillator_enabled", 0.0);

        let mut left = [0.25_f32, -0.5, 0.75];
        let mut right = [0.2_f32, -0.4, 0.6];
        run_modifiers(&bridge, &mut left, &mut right);

        assert!(is_silent(&left));
        assert!(is_silent(&right));
    }

    #[test]
    fn output_modifiers_keep_signal_when_oscillator_enabled() {
        let bridge = bridge_with_enabled(1.0);

        let mut left = [0.25_f32, -0.5, 0.75];
        let mut right = [0.2_f32, -0.4, 0.6];
        run_modifiers(&bridge, &mut left, &mut right);

        assert_eq!(left, [0.25, -0.5, 0.75]);
        assert_eq!(right, [0.2, -0.4, 0.6]);
    }

    #[test]
    fn output_modifiers_generate_runtime_oscillator_from_frequency_and_level() {
        let bridge = oscillator_bridge(880.0, 0.75, 1.0, 1.0, 1.0);
        let mut left = [0.0_f32; 128];
        let mut right = [0.0_f32; 128];

        let phase = run_modifiers(&bridge, &mut left, &mut right);

        assert!(peak_of(&left) > 0.2, "expected audible generated oscillator");
        assert!(peak_of(&right) > 0.2, "expected audible generated oscillator");
        assert_eq!(left, right, "expected in-phase stereo oscillator output");
        assert!(phase > 0.0, "phase should advance after generation");
    }

    #[test]
    fn output_modifiers_level_zero_produces_silence() {
        let bridge = oscillator_bridge(440.0, 0.0, 1.0, 1.0, 1.0);
        let mut left = [0.1_f32; 64];
        let mut right = [0.1_f32; 64];

        run_modifiers(&bridge, &mut left, &mut right);

        assert!(is_silent(&left));
        assert!(is_silent(&right));
    }

    #[test]
    fn output_modifiers_frequency_change_changes_waveform() {
        let low_freq_bridge = oscillator_bridge(220.0, 0.5, 1.0, 1.0, 1.0);
        let high_freq_bridge = oscillator_bridge(1760.0, 0.5, 1.0, 1.0, 1.0);

        let mut low_left = [0.0_f32; 256];
        let mut low_right = [0.0_f32; 256];
        let mut high_left = [0.0_f32; 256];
        let mut high_right = [0.0_f32; 256];

        run_modifiers(&low_freq_bridge, &mut low_left, &mut low_right);
        run_modifiers(&high_freq_bridge, &mut high_left, &mut high_right);

        assert_ne!(
            low_left, high_left,
            "frequency updates should alter waveform"
        );
        assert_eq!(low_left, low_right);
        assert_eq!(high_left, high_right);
    }

    #[test]
    fn output_modifiers_apply_input_and_output_gain_levels() {
        let unity_bridge = oscillator_bridge(880.0, 0.5, 1.0, 1.0, 1.0);
        let boosted_bridge = oscillator_bridge(880.0, 0.5, 1.0, 1.5, 2.0);

        let mut unity_left = [0.0_f32; 256];
        let mut unity_right = [0.0_f32; 256];
        let mut boosted_left = [0.0_f32; 256];
        let mut boosted_right = [0.0_f32; 256];

        run_modifiers(&unity_bridge, &mut unity_left, &mut unity_right);
        run_modifiers(&boosted_bridge, &mut boosted_left, &mut boosted_right);

        let unity_peak = peak_of(&unity_left);
        let boosted_peak = peak_of(&boosted_left);

        assert!(boosted_peak > unity_peak * 2.5);
        assert_eq!(boosted_left, boosted_right);
        assert_eq!(unity_left, unity_right);
    }

    #[test]
    fn output_modifiers_apply_gain_without_oscillator_params() {
        let bridge = AtomicParameterBridge::new(&[
            gain_param("input_gain_level", "Level", "InputGain", 1.5),
            gain_param("output_gain_level", "Level", "OutputGain", 1.2),
        ]);

        let mut left = [0.25_f32, -0.5, 0.75];
        let mut right = [0.2_f32, -0.4, 0.6];

        run_modifiers(&bridge, &mut left, &mut right);

        let expected_gain = 1.5 * 1.2;
        assert_eq!(
            left,
            [
                0.25 * expected_gain,
                -0.5 * expected_gain,
                0.75 * expected_gain
            ]
        );
        assert_eq!(
            right,
            [
                0.2 * expected_gain,
                -0.4 * expected_gain,
                0.6 * expected_gain
            ]
        );
    }

    #[test]
    fn output_modifiers_ignore_compact_legacy_gain_ids() {
        let bridge = AtomicParameterBridge::new(&[
            gain_param("inputgain_level", "Level", "InputGain", 0.2),
            gain_param("outputgain_level", "Level", "OutputGain", 0.2),
        ]);

        let mut left = [0.5_f32; 16];
        let mut right = [0.5_f32; 16];

        run_modifiers(&bridge, &mut left, &mut right);

        // Legacy ids are not looked up, so the signal passes at unity gain.
        let expected = 0.5;
        assert!(all_close(&left, expected));
        assert!(all_close(&right, expected));
    }

    #[test]
    fn output_modifiers_ignore_legacy_snake_case_gain_suffix_ids() {
        let bridge = AtomicParameterBridge::new(&[
            gain_param("input_gain_gain", "Gain", "InputGain", 0.2),
            gain_param("output_gain_gain", "Gain", "OutputGain", 0.2),
        ]);

        let mut left = [0.5_f32; 16];
        let mut right = [0.5_f32; 16];

        run_modifiers(&bridge, &mut left, &mut right);

        // Legacy ids are not looked up, so the signal passes at unity gain.
        let expected = 0.5;
        assert!(all_close(&left, expected));
        assert!(all_close(&right, expected));
    }

    #[test]
    fn output_modifiers_use_canonical_ids_even_when_legacy_variants_exist() {
        let bridge = AtomicParameterBridge::new(&[
            gain_param("input_gain_level", "Level", "InputGain", 1.6),
            gain_param("inputgain_level", "Level", "InputGain", 0.4),
            gain_param("output_gain_level", "Level", "OutputGain", 1.0),
        ]);

        let mut left = [0.5_f32; 8];
        let mut right = [0.5_f32; 8];

        run_modifiers(&bridge, &mut left, &mut right);

        // Only the canonical input gain (1.6) may take effect.
        let expected = 0.5 * 1.6;
        assert!(all_close(&left, expected));
        assert!(all_close(&right, expected));
    }
}