Skip to main content

trem_cpal/
driver.rs

//! `cpal` output stream driving a [`trem::graph::Graph`] with [`crate::bridge`] command/notification bridging.
//!
//! [`AudioEngine`] opens the output stream, drains any stale commands, and runs the graph inside the audio callback.

5use crate::bridge::{AudioBridge, Command, Notification, ScopeFocus};
6use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
7use cpal::{SampleFormat, Stream, StreamConfig};
8use trem::event::{GraphEvent, TimedEvent};
9use trem::graph::{Graph, Sig};
10
11struct CallbackState {
12    cmd_rx: rtrb::Consumer<Command>,
13    notif_tx: rtrb::Producer<Notification>,
14    graph: Graph,
15    output_node: u32,
16    /// Inst-bus node id for [`ScopeFocus::PatchBuses`] (pre-master submix).
17    scope_input_node: Option<u32>,
18    scope_focus: ScopeFocus,
19    preview_scratch_l: Vec<f32>,
20    preview_scratch_r: Vec<f32>,
21    sample_rate: f64,
22    playing: bool,
23    bpm: f64,
24    pattern_events: Vec<TimedEvent>,
25    pattern_len: usize,
26    playhead: usize,
27    block_events: Vec<TimedEvent>,
28    meter_acc: usize,
29    meter_peak_l: f32,
30    meter_peak_r: f32,
31    pos_acc: usize,
32    scope_master: Box<[f32]>,
33    scope_master_len: usize,
34    scope_graph_in: Box<[f32]>,
35    scope_graph_in_len: usize,
36    scope_acc: usize,
37}
38
39impl CallbackState {
40    fn pattern_len_from_events(events: &[TimedEvent]) -> usize {
41        events
42            .iter()
43            .map(|e| e.sample_offset.saturating_add(1))
44            .max()
45            .unwrap_or(0)
46    }
47
48    fn drain_commands(&mut self) {
49        while let Ok(cmd) = self.cmd_rx.pop() {
50            match cmd {
51                Command::NoteOn {
52                    frequency,
53                    velocity,
54                    voice,
55                } => {
56                    self.block_events.push(TimedEvent {
57                        sample_offset: 0,
58                        event: GraphEvent::NoteOn {
59                            frequency,
60                            velocity,
61                            voice,
62                        },
63                    });
64                }
65                Command::NoteOff { voice } => {
66                    self.block_events.push(TimedEvent {
67                        sample_offset: 0,
68                        event: GraphEvent::NoteOff { voice },
69                    });
70                }
71                Command::SetBpm(bpm) => {
72                    self.bpm = bpm;
73                }
74                Command::Play => {
75                    self.playing = true;
76                }
77                Command::Pause => {
78                    self.playing = false;
79                    self.graph.reset();
80                }
81                Command::Stop => {
82                    self.playing = false;
83                    self.playhead = 0;
84                    self.graph.reset();
85                }
86                Command::LoadEvents(mut events) => {
87                    let old_len = self.pattern_len;
88                    std::mem::swap(&mut self.pattern_events, &mut events);
89                    drop(events);
90                    self.pattern_len = Self::pattern_len_from_events(&self.pattern_events);
91                    if self.pattern_len == 0 {
92                        self.playhead = 0;
93                    } else if old_len == 0 {
94                        // First pattern bound to this stream: always start at loop start.
95                        self.playhead = 0;
96                    } else {
97                        // Hot-swap (edit while playing/paused): keep position in the loop.
98                        self.playhead %= self.pattern_len;
99                    }
100                    self.block_events
101                        .reserve(self.pattern_events.len().saturating_mul(8).max(256));
102                }
103                Command::SetParam {
104                    path,
105                    param_id,
106                    value,
107                } => {
108                    self.graph.set_param_at_path(&path, param_id, value);
109                }
110                Command::SetScopeFocus(focus) => {
111                    self.scope_focus = focus;
112                }
113            }
114        }
115    }
116
117    /// All sample offsets `k` in `[0, frames)` where `(playhead + k) % len == target`.
118    fn schedule_pattern_event(
119        block_events: &mut Vec<TimedEvent>,
120        playhead: usize,
121        pattern_len: usize,
122        frames: usize,
123        event: &TimedEvent,
124    ) {
125        if pattern_len == 0 {
126            return;
127        }
128        let target = event.sample_offset % pattern_len;
129        let first = if target >= playhead {
130            target - playhead
131        } else {
132            pattern_len - playhead + target
133        };
134        let mut k = first;
135        while k < frames {
136            block_events.push(TimedEvent {
137                sample_offset: k,
138                event: event.event.clone(),
139            });
140            k = k.saturating_add(pattern_len);
141        }
142    }
143
144    fn collect_pattern_events_for_block(&mut self, frames: usize) {
145        if !self.playing || self.pattern_len == 0 {
146            return;
147        }
148        let playhead = self.playhead % self.pattern_len;
149        for e in &self.pattern_events {
150            Self::schedule_pattern_event(
151                &mut self.block_events,
152                playhead,
153                self.pattern_len,
154                frames,
155                e,
156            );
157        }
158    }
159
160    fn sort_block_events(&mut self) {
161        self.block_events.sort_by_key(|e| e.sample_offset);
162    }
163
164    fn advance_playhead(&mut self, frames: usize) {
165        if !self.playing || self.pattern_len == 0 {
166            return;
167        }
168        self.playhead = (self.playhead + frames) % self.pattern_len;
169    }
170
171    fn push_position_if_due(&mut self, frames: usize) {
172        self.pos_acc += frames;
173        const INTERVAL: usize = 256;
174        if self.pos_acc < INTERVAL {
175            return;
176        }
177        self.pos_acc = 0;
178        if !self.playing {
179            return;
180        }
181        let beat = self.playhead as f64 * self.bpm / (60.0 * self.sample_rate);
182        let _ = self.notif_tx.push(Notification::Position { beat });
183    }
184
185    fn flush_meter_if_due(&mut self) {
186        const METER_INTERVAL: usize = 1024;
187        if self.meter_acc < METER_INTERVAL {
188            return;
189        }
190        self.meter_acc = 0;
191        let _ = self.notif_tx.push(Notification::Meter {
192            peak_l: self.meter_peak_l,
193            peak_r: self.meter_peak_r,
194        });
195        self.meter_peak_l = 0.0;
196        self.meter_peak_r = 0.0;
197    }
198
199    fn flush_scope_if_due(&mut self) {
200        const SCOPE_INTERVAL: usize = 2048;
201        if self.scope_acc < SCOPE_INTERVAL {
202            return;
203        }
204        self.scope_acc = 0;
205        if self.scope_master_len > 0 {
206            let master = self.scope_master[..self.scope_master_len].to_vec();
207            let graph_in = if self.scope_graph_in_len > 0 {
208                self.scope_graph_in[..self.scope_graph_in_len].to_vec()
209            } else {
210                master.clone()
211            };
212            self.scope_master_len = 0;
213            self.scope_graph_in_len = 0;
214            let _ = self
215                .notif_tx
216                .push(Notification::ScopeData(crate::bridge::ScopeSnapshot {
217                    master,
218                    graph_in,
219                }));
220        }
221    }
222
223    fn process_output(&mut self, data: &mut [f32], channels: usize) {
224        let frames = data.len() / channels;
225
226        self.block_events.clear();
227        self.drain_commands();
228        self.collect_pattern_events_for_block(frames);
229        self.sort_block_events();
230
231        self.graph.run(frames, self.sample_rate, &self.block_events);
232        let l = self.graph.output_buffer(self.output_node, 0);
233        let r = self.graph.output_buffer(self.output_node, 1);
234
235        for i in 0..frames {
236            let li = l.get(i).copied().unwrap_or(0.0);
237            let ri = r.get(i).copied().unwrap_or(li);
238            let al = li.abs();
239            let ar = ri.abs();
240            if al > self.meter_peak_l {
241                self.meter_peak_l = al;
242            }
243            if ar > self.meter_peak_r {
244                self.meter_peak_r = ar;
245            }
246            if channels >= 2 {
247                data[i * channels] = li;
248                data[i * channels + 1] = ri;
249            } else {
250                data[i] = 0.5 * (li + ri);
251            }
252        }
253
254        self.append_scope_samples(frames);
255        self.scope_acc += frames;
256        self.flush_scope_if_due();
257        self.meter_acc += frames;
258        self.flush_meter_if_due();
259        if self.playing {
260            self.advance_playhead(frames);
261        }
262
263        self.push_position_if_due(frames);
264    }
265
266    /// Fills `scope_graph_in` (left pane = “in”) and `scope_master` (right = “out”) from the
267    /// active [`ScopeFocus`].
268    fn append_scope_samples(&mut self, frames: usize) {
269        let cap = self.scope_master.len();
270        match &self.scope_focus {
271            ScopeFocus::PatchBuses => {
272                let ml = self.graph.output_buffer(self.output_node, 0);
273                let mr = self.graph.output_buffer(self.output_node, 1);
274                let (il, ir) = if let Some(nid) = self.scope_input_node {
275                    (
276                        self.graph.output_buffer(nid, 0),
277                        self.graph.output_buffer(nid, 1),
278                    )
279                } else {
280                    (ml, mr)
281                };
282                for i in 0..frames {
283                    if self.scope_master_len + 2 > cap {
284                        break;
285                    }
286                    let gli = il.get(i).copied().unwrap_or(0.0);
287                    let gri = ir.get(i).copied().unwrap_or(gli);
288                    self.scope_graph_in[self.scope_graph_in_len] = gli;
289                    self.scope_graph_in[self.scope_graph_in_len + 1] = gri;
290                    self.scope_graph_in_len += 2;
291                    let m0 = ml.get(i).copied().unwrap_or(0.0);
292                    let m1 = mr.get(i).copied().unwrap_or(m0);
293                    self.scope_master[self.scope_master_len] = m0;
294                    self.scope_master[self.scope_master_len + 1] = m1;
295                    self.scope_master_len += 2;
296                }
297            }
298            ScopeFocus::GraphNode { graph_path, node } => {
299                if self.preview_scratch_l.len() < frames {
300                    self.preview_scratch_l.resize(frames, 0.0);
301                    self.preview_scratch_r.resize(frames, 0.0);
302                }
303                let sig = self
304                    .graph
305                    .node_sig_at_path(graph_path, *node)
306                    .unwrap_or(Sig {
307                        inputs: 0,
308                        outputs: 0,
309                    });
310                let ins = sig.inputs as usize;
311                let outs = sig.outputs as usize;
312                if ins >= 1 {
313                    self.graph.mix_input_port_at_path(
314                        graph_path,
315                        *node,
316                        0,
317                        frames,
318                        &mut self.preview_scratch_l[..frames],
319                    );
320                } else {
321                    self.preview_scratch_l[..frames].fill(0.0);
322                }
323                if ins >= 2 {
324                    self.graph.mix_input_port_at_path(
325                        graph_path,
326                        *node,
327                        1,
328                        frames,
329                        &mut self.preview_scratch_r[..frames],
330                    );
331                } else {
332                    self.preview_scratch_r[..frames]
333                        .copy_from_slice(&self.preview_scratch_l[..frames]);
334                }
335                for i in 0..frames {
336                    if self.scope_master_len + 2 > cap {
337                        break;
338                    }
339                    self.scope_graph_in[self.scope_graph_in_len] = self.preview_scratch_l[i];
340                    self.scope_graph_in[self.scope_graph_in_len + 1] = self.preview_scratch_r[i];
341                    self.scope_graph_in_len += 2;
342
343                    let ol0 = self
344                        .graph
345                        .output_buffer_at_path(graph_path, *node, 0)
346                        .and_then(|b| b.get(i))
347                        .copied()
348                        .unwrap_or(0.0);
349                    let ol1 = if outs >= 2 {
350                        self.graph
351                            .output_buffer_at_path(graph_path, *node, 1)
352                            .and_then(|b| b.get(i))
353                            .copied()
354                            .unwrap_or(ol0)
355                    } else if outs == 1 {
356                        ol0
357                    } else {
358                        0.0
359                    };
360                    self.scope_master[self.scope_master_len] = ol0;
361                    self.scope_master[self.scope_master_len + 1] = ol1;
362                    self.scope_master_len += 2;
363                }
364            }
365        }
366    }
367}
368
369/// Live output stream; dropping it stops audio and releases the device.
370pub struct AudioEngine {
371    /// Kept so the device keeps playing until the engine is dropped.
372    pub stream: Stream,
373}
374
375impl AudioEngine {
376    /// Opens the default F32 stereo output at `sample_rate`, wiring `audio_bridge` and `graph` into the callback.
377    ///
378    /// `scope_input_node`: optional graph node id whose stereo output is copied into
379    /// [`crate::bridge::ScopeSnapshot::graph_in`] (e.g. instrument bus before master FX).
380    pub fn new(
381        audio_bridge: AudioBridge,
382        graph: Graph,
383        output_node: u32,
384        scope_input_node: Option<u32>,
385        sample_rate: f64,
386    ) -> Result<Self, anyhow::Error> {
387        let host = cpal::default_host();
388        let device = host
389            .default_output_device()
390            .ok_or_else(|| anyhow::anyhow!("no default output device"))?;
391
392        let supported = device.default_output_config()?;
393        if supported.sample_format() != SampleFormat::F32 {
394            return Err(anyhow::anyhow!(
395                "default output is not f32; pick an F32-capable device or config"
396            ));
397        }
398
399        let mut stream_config: StreamConfig = supported.config();
400        stream_config.channels = stream_config.channels.max(2);
401        stream_config.sample_rate = sample_rate.round() as u32;
402
403        let channels = stream_config.channels as usize;
404
405        let AudioBridge {
406            mut cmd_rx,
407            notif_tx,
408        } = audio_bridge;
409
410        while cmd_rx.pop().is_ok() {}
411
412        let mut state = CallbackState {
413            cmd_rx,
414            notif_tx,
415            graph,
416            output_node,
417            scope_input_node,
418            scope_focus: ScopeFocus::PatchBuses,
419            preview_scratch_l: Vec::new(),
420            preview_scratch_r: Vec::new(),
421            sample_rate,
422            playing: false,
423            bpm: 120.0,
424            pattern_events: Vec::new(),
425            pattern_len: 0,
426            playhead: 0,
427            block_events: Vec::with_capacity(1024),
428            meter_acc: 0,
429            meter_peak_l: 0.0,
430            meter_peak_r: 0.0,
431            pos_acc: 0,
432            scope_master: vec![0.0f32; 8192].into_boxed_slice(),
433            scope_master_len: 0,
434            scope_graph_in: vec![0.0f32; 8192].into_boxed_slice(),
435            scope_graph_in_len: 0,
436            scope_acc: 0,
437        };
438        state.block_events.reserve(4096);
439
440        let stream = device.build_output_stream(
441            &stream_config,
442            move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
443                state.process_output(data, channels);
444            },
445            |err| {
446                eprintln!("trem-cpal stream error: {err}");
447            },
448            None,
449        )?;
450
451        stream.play()?;
452
453        Ok(Self { stream })
454    }
455}