Skip to main content

rust_audio_api/
context.rs

1use crate::graph::{ControlMessage, GraphBuilder, NodeId};
2use crate::types::{AUDIO_UNIT_SIZE, AudioUnit};
3use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
4use cpal::{SampleFormat, Stream, StreamConfig};
5use crossbeam_channel::{Sender, unbounded};
6use std::sync::Arc;
7use std::sync::atomic::{AtomicU8, AtomicU32, Ordering};
8use std::time::Instant;
9
/// Shared, thread-safe view of the audio callback's real-time performance.
///
/// Cloning a monitor is cheap and yields a handle onto the *same* counters
/// (both fields are reference-counted atomics), so a clone can be read from
/// any thread while the audio callback keeps updating it.
///
/// A default-constructed monitor starts with both counters at zero.
#[derive(Clone, Debug, Default)]
pub struct PerformanceMonitor {
    /// Number of audio callbacks that took longer than the playback duration
    /// of the buffer they were filling, i.e. missed the real-time deadline.
    pub late_callbacks: Arc<AtomicU32>,
    /// Time spent in the most recent audio callback, as a percentage of that
    /// buffer's playback duration.
    ///
    /// Normally in the range 0-100. A callback that overruns its deadline
    /// reports more than 100; the stored value saturates at 255.
    pub current_load_percent: Arc<AtomicU8>,
}
30
31/// The main entry point for the audio system.
32///
33/// `AudioContext` manages the audio graph, the audio backend (CPAL),
34/// and the real-time audio thread. It provides a high-level API for
35/// building and controlling audio processing graphs.
36///
37/// # Examples
38///
39/// ### Basic Usage
40/// ```no_run
41/// use rust_audio_api::AudioContext;
42///
43/// let mut ctx = AudioContext::new().unwrap();
44/// // ... build graph ...
45/// // ctx.resume(destination_id).unwrap();
46/// ```
47///
48/// ### Dynamic Parameter Updates
49/// ```no_run
50/// use rust_audio_api::{AudioContext, NodeParameter};
51/// use rust_audio_api::nodes::{GainNode, NodeType};
52///
53/// let mut ctx = AudioContext::new().unwrap();
54/// let mut gain_id = None;
55///
56/// let dest_id = ctx.build_graph(|builder| {
57///     let gain = builder.add_node(NodeType::Gain(GainNode::new(0.5)));
58///     gain_id = Some(gain);
59///     gain
60/// });
61///
62/// ctx.resume(dest_id).unwrap();
63///
64/// // Later, send a message to change the gain
65/// let sender = ctx.control_sender();
66/// sender.send(rust_audio_api::graph::ControlMessage::SetParameter(
67///     gain_id.unwrap(),
68///     NodeParameter::Gain(0.8)
69/// )).unwrap();
70/// ```
71pub struct AudioContext {
72    stream: Option<Stream>,
73    sample_rate: u32,
74    msg_sender: Sender<ControlMessage>,
75    graph_builder: Option<GraphBuilder>,
76    performance_monitor: PerformanceMonitor,
77}
78
79impl AudioContext {
80    /// Creates a new `AudioContext` with the default output device and sample rate.
81    pub fn new() -> Result<Self, anyhow::Error> {
82        let host = cpal::default_host();
83        let device = host
84            .default_output_device()
85            .expect("Default output device not found");
86        let supported_config = device.default_output_config()?;
87        let sample_rate = supported_config.sample_rate();
88
89        let (tx, _rx) = unbounded();
90
91        Ok(Self {
92            stream: None,
93            sample_rate,
94            msg_sender: tx,
95            graph_builder: Some(GraphBuilder::new()),
96            performance_monitor: PerformanceMonitor::default(),
97        })
98    }
99
100    /// Returns a `PerformanceMonitor` to track the audio thread's health.
101    pub fn performance_monitor(&self) -> PerformanceMonitor {
102        self.performance_monitor.clone()
103    }
104
105    /// Returns the sample rate of the audio context.
106    pub fn sample_rate(&self) -> u32 {
107        self.sample_rate
108    }
109
110    /// Provides a [`GraphBuilder`] to construct the audio processing graph.
111    ///
112    /// This method takes a closure where you can add nodes and define their connections.
113    /// It returns the [`NodeId`] of the final destination node in the graph.
114    pub fn build_graph<F>(&mut self, builder_func: F) -> NodeId
115    where
116        F: FnOnce(&mut GraphBuilder) -> NodeId,
117    {
118        if let Some(mut gb) = self.graph_builder.take() {
119            let dest_id = builder_func(&mut gb);
120            self.graph_builder = Some(gb);
121            dest_id
122        } else {
123            panic!("GraphBuilder already consumed, cannot rebuild topology");
124        }
125    }
126
127    /// Starts the audio processing thread and begins playback.
128    ///
129    /// This method finalizes the graph construction and hands it over to the audio backend.
130    /// `destination_id` should be the ID of the final node that outputs audio.
131    pub fn resume(&mut self, destination_id: NodeId) -> Result<(), anyhow::Error> {
132        if self.stream.is_some() {
133            return Ok(());
134        }
135
136        let host = cpal::default_host();
137        let device = host.default_output_device().unwrap();
138        let supported_config = device.default_output_config()?;
139        let sample_format = supported_config.sample_format();
140        let config: StreamConfig = supported_config.into();
141
142        // Take GraphBuilder and generate StaticGraph
143        let builder = self.graph_builder.take().expect("GraphBuilder is missing");
144        let (tx, rx) = unbounded();
145        self.msg_sender = tx; // Update control sender held by the main thread
146
147        let static_graph = builder.build(destination_id, rx);
148
149        let stream = match sample_format {
150            SampleFormat::F32 => self.build_stream::<f32>(&device, &config, static_graph)?,
151            SampleFormat::I16 => self.build_stream::<i16>(&device, &config, static_graph)?,
152            SampleFormat::U16 => self.build_stream::<u16>(&device, &config, static_graph)?,
153            _ => return Err(anyhow::anyhow!("Unsupported audio output device format")),
154        };
155
156        stream.play()?;
157        self.stream = Some(stream);
158        Ok(())
159    }
160
161    fn build_stream<T>(
162        &self,
163        device: &cpal::Device,
164        config: &StreamConfig,
165        mut graph: crate::graph::StaticGraph,
166    ) -> Result<Stream, anyhow::Error>
167    where
168        T: cpal::Sample + cpal::SizedSample + cpal::FromSample<f32>,
169    {
170        let channels = config.channels as usize;
171        let sample_rate = self.sample_rate;
172        let monitor = self.performance_monitor.clone();
173
174        let mut unit_frame_index = AUDIO_UNIT_SIZE;
175        let mut current_unit: AudioUnit = [[0.0; 2]; AUDIO_UNIT_SIZE];
176
177        let stream = device.build_output_stream(
178            config,
179            move |data: &mut [T], _: &cpal::OutputCallbackInfo| {
180                let start_time = Instant::now();
181                let frame_count = data.len() / channels;
182
183                for frame in data.chunks_mut(channels) {
184                    if unit_frame_index >= AUDIO_UNIT_SIZE {
185                        let new_unit = graph.pull_next_unit();
186                        current_unit.copy_from_slice(new_unit);
187                        unit_frame_index = 0;
188                    }
189
190                    let sample_f32 = current_unit[unit_frame_index];
191                    unit_frame_index += 1;
192
193                    // Format conversion to T (f32, i16, u16) in CPAL buffers & downmix/upmix handling
194                    if channels >= 2 {
195                        frame[0] = T::from_sample(sample_f32[0]);
196                        frame[1] = T::from_sample(sample_f32[1]);
197                        for f in frame.iter_mut().take(channels).skip(2) {
198                            *f = T::from_sample(0.0);
199                        }
200                    } else if channels == 1 {
201                        let mono = (sample_f32[0] + sample_f32[1]) * 0.5;
202                        frame[0] = T::from_sample(mono);
203                    }
204                }
205
206                let elapsed_micros = start_time.elapsed().as_micros();
207                let max_allowed_micros =
208                    (frame_count as f64 / sample_rate as f64 * 1_000_000.0) as u128;
209
210                let load_percent =
211                    ((elapsed_micros as f64 / max_allowed_micros as f64) * 100.0) as u8;
212                monitor
213                    .current_load_percent
214                    .store(load_percent, Ordering::Relaxed);
215
216                if elapsed_micros > max_allowed_micros {
217                    monitor.late_callbacks.fetch_add(1, Ordering::Relaxed);
218                }
219            },
220            |err| eprintln!("Audio stream error: {}", err),
221            None,
222        )?;
223
224        Ok(stream)
225    }
226
227    /// Returns a Sender for sending control messages (non-blocking)
228    pub fn control_sender(&self) -> Sender<ControlMessage> {
229        self.msg_sender.clone()
230    }
231}