Skip to main content

rust_audio_api/nodes/
convolver.rs

1use crate::types::{AUDIO_UNIT_SIZE, AudioUnit};
2use crossbeam_channel::{Sender, bounded};
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering};
5use thread_priority::*;
6
/// Configuration for the [`ConvolverNode`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConvolverConfig {
    /// Whether both channels are convolved.
    ///
    /// When `false`, only the left input channel is convolved and the result is
    /// written to both output channels.
    pub stereo: bool,
    /// Partition growth factor for the part of the Impulse Response (IR) that
    /// follows block 0.
    ///
    /// Despite its name this is used as a *multiplicative factor*, not a power:
    /// the first background partition is `AUDIO_UNIT_SIZE * growth_exponent`
    /// samples long, and later partitions grow by this factor once enough of the
    /// IR has been covered. It should be at least 1; a value of 1 yields
    /// uniformly sized partitions.
    pub growth_exponent: u32,
    /// Length in samples of the IR head (block 0), which is convolved directly
    /// on the audio thread rather than by the background workers.
    pub block_0_size: usize,
}
17
18impl Default for ConvolverConfig {
19    fn default() -> Self {
20        Self {
21            stereo: true,
22            growth_exponent: 2,
23            block_0_size: AUDIO_UNIT_SIZE * 4,
24        }
25    }
26}
27
/// An `f32` that can be shared between threads, stored as its IEEE-754 bit
/// pattern inside an [`AtomicU32`].
///
/// `load`, `store` and `swap` follow the ordering rules of the std atomics
/// (`load` panics on `Release`/`AcqRel`, `store` panics on `Acquire`/`AcqRel`).
/// `fetch_add` accepts every ordering.
#[derive(Default)]
pub struct AtomicF32(AtomicU32);

impl AtomicF32 {
    /// Creates a new atomic holding `v`.
    #[inline]
    pub fn new(v: f32) -> Self {
        Self(AtomicU32::new(v.to_bits()))
    }

    /// Loads the current value.
    #[inline]
    pub fn load(&self, order: Ordering) -> f32 {
        f32::from_bits(self.0.load(order))
    }

    /// Stores `val`.
    #[inline]
    pub fn store(&self, val: f32, order: Ordering) {
        self.0.store(val.to_bits(), order);
    }

    /// Atomically adds `val` to the stored value.
    ///
    /// Implemented as a compare-and-swap loop. `order` is used for the
    /// successful update; the loop's loads (initial read and CAS failure) use
    /// the strongest *load* ordering compatible with `order`. The std API
    /// rejects `Release`/`AcqRel` for loads, so passing `order` through
    /// unchanged would panic for those orderings.
    #[inline]
    pub fn fetch_add(&self, val: f32, order: Ordering) {
        let load_order = Self::load_ordering_for(order);
        let mut current = self.0.load(load_order);
        loop {
            let new_bits = (f32::from_bits(current) + val).to_bits();
            match self
                .0
                .compare_exchange_weak(current, new_bits, order, load_order)
            {
                Ok(_) => break,
                Err(actual) => current = actual,
            }
        }
    }

    /// Stores `val` and returns the previous value.
    #[inline]
    pub fn swap(&self, val: f32, order: Ordering) -> f32 {
        f32::from_bits(self.0.swap(val.to_bits(), order))
    }

    /// Maps an arbitrary ordering to the strongest ordering that is valid for a
    /// load / CAS-failure (which may not be `Release` or `AcqRel`).
    #[inline]
    fn load_ordering_for(order: Ordering) -> Ordering {
        match order {
            Ordering::Relaxed | Ordering::Release => Ordering::Relaxed,
            Ordering::Acquire | Ordering::AcqRel => Ordering::Acquire,
            _ => Ordering::SeqCst,
        }
    }
}

impl std::fmt::Debug for AtomicF32 {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("AtomicF32")
            .field(&self.load(Ordering::Relaxed))
            .finish()
    }
}
69
70struct PartitionBlock {
71    size: usize,
72    offset: usize,
73    fft_data_l: Arc<[rustfft::num_complex::Complex<f32>]>,
74    fft_data_r: Arc<[rustfft::num_complex::Complex<f32>]>,
75    fft_plan: Arc<dyn realfft::RealToComplex<f32>>,
76    ifft_plan: Arc<dyn realfft::ComplexToReal<f32>>,
77}
78
/// Work item sent from the audio thread to a convolution worker.
#[derive(Copy, Clone)]
struct TaskMsg {
    /// Index of the partition to convolve.
    block_index: usize,
    /// History write position at dispatch time (one past the newest input sample).
    history_write_ptr: usize,
    /// Carry-buffer read position of the audio callback that dispatched the task.
    carry_read_ptr: usize,
}
85
86/// A node that performs real-time convolution against an Impulse Response (IR).
87///
88/// Convolver is used for effects like reverb, speaker modeling, or virtual acoustics.
89/// It uses a partitioned convolution algorithm with background worker threads to
90/// achieve low-latency performance even with long IRs.
91pub struct ConvolverNode {
92    stereo: bool,
93    block_0_l: Vec<f32>,
94    block_0_r: Vec<f32>,
95    b0_out_l: Vec<f32>,
96    b0_out_r: Vec<f32>,
97    task_tx: Sender<TaskMsg>,
98
99    carry_buffer_l: Arc<Vec<AtomicF32>>,
100    carry_buffer_r: Arc<Vec<AtomicF32>>,
101    carry_mask: usize,
102    carry_read_ptr: usize,
103
104    history_buffer_l: Arc<Vec<AtomicF32>>,
105    history_buffer_r: Arc<Vec<AtomicF32>>,
106    history_mask: usize,
107    history_write_ptr: usize,
108
109    partition_blocks: Arc<[PartitionBlock]>,
110
111    shared_read_ptr: Arc<AtomicUsize>,
112    drop_count: Arc<AtomicUsize>,
113}
114
115impl ConvolverNode {
116    /// Creates a `ConvolverNode` by loading an Impulse Response (IR) from a WAV file.
117    ///
118    /// # Parameters
119    /// - `path`: Path to the WAV file.
120    /// - `target_sample_rate`: Target sample rate for processing.
121    /// - `max_len`: Optional maximum length (in samples) to truncate the IR.
122    pub fn from_file(
123        path: &str,
124        target_sample_rate: u32,
125        max_len: Option<usize>,
126    ) -> anyhow::Result<Self> {
127        Self::from_file_with_config(
128            path,
129            target_sample_rate,
130            max_len,
131            ConvolverConfig::default(),
132        )
133    }
134
135    /// Creates a `ConvolverNode` from a WAV file with custom configuration.
136    pub fn from_file_with_config(
137        path: &str,
138        target_sample_rate: u32,
139        max_len: Option<usize>,
140        config: ConvolverConfig,
141    ) -> anyhow::Result<Self> {
142        let mut reader = hound::WavReader::open(path)?;
143        let spec = reader.spec();
144        let mut ir = Vec::new();
145
146        if spec.sample_format == hound::SampleFormat::Float {
147            let mut iter = reader.samples::<f32>();
148            while let Some(Ok(l)) = iter.next() {
149                let r = if spec.channels == 2 {
150                    iter.next().unwrap().unwrap_or(l)
151                } else {
152                    l
153                };
154                ir.push([l, r]);
155            }
156        } else {
157            panic!("Unexpected IR file format")
158        }
159
160        let mut ir = if spec.sample_rate != target_sample_rate {
161            Self::resample_ir(&ir, spec.sample_rate, target_sample_rate)
162        } else {
163            ir
164        };
165
166        if let Some(max) = max_len
167            && ir.len() > max
168        {
169            ir.truncate(max);
170
171            // Apply fade-out to avoid artifacts from abrupt truncation (fade-out last 100ms)
172            let fade_len = (target_sample_rate as f32 * 0.1) as usize;
173            let fade_len = fade_len.min(max);
174            for i in 0..fade_len {
175                let idx = max - 1 - i;
176                let fade_gain = i as f32 / fade_len as f32;
177                // Use exponential or smooth fade-out; here we use simple linear
178                ir[idx][0] *= fade_gain;
179                ir[idx][1] *= fade_gain;
180            }
181        }
182
183        Ok(Self::with_config(&ir, config))
184    }
185
186    fn resample_ir(ir: &[[f32; 2]], from_hz: u32, to_hz: u32) -> Vec<[f32; 2]> {
187        use dasp::signal::Signal;
188        let signal = dasp::signal::from_iter(ir.iter().cloned());
189        let ring_buffer = dasp::ring_buffer::Fixed::from([[0.0; 2]; AUDIO_UNIT_SIZE]);
190        let sinc = dasp::interpolate::sinc::Sinc::new(ring_buffer);
191        let mut converter = signal.from_hz_to_hz(sinc, from_hz as f64, to_hz as f64);
192
193        let new_len = (ir.len() as f64 * (to_hz as f64 / from_hz as f64)).ceil() as usize;
194        let mut new_ir = Vec::with_capacity(new_len);
195        for _ in 0..new_len {
196            new_ir.push(converter.next());
197        }
198        new_ir
199    }
200
201    pub fn new(ir: &[[f32; 2]]) -> Self {
202        Self::with_config(ir, ConvolverConfig::default())
203    }
204
205    pub fn with_config(ir: &[[f32; 2]], config: ConvolverConfig) -> Self {
206        let stereo = config.stereo;
207        let (b0_l_vec, b0_r_vec, blocks_info) =
208            Self::partition_ir(ir, config.growth_exponent, config.block_0_size);
209
210        let max_block_size = blocks_info
211            .last()
212            .map(|b| b.size)
213            .unwrap_or(AUDIO_UNIT_SIZE);
214        let mut capacity = (ir.len() + max_block_size * 2).next_power_of_two() * 4;
215        if capacity < 65536 {
216            capacity = 65536;
217        }
218        let carry_mask = capacity - 1;
219
220        let carry_buffer_l = Arc::new(
221            (0..capacity)
222                .map(|_| AtomicF32::new(0.0))
223                .collect::<Vec<_>>(),
224        );
225        let carry_buffer_r = Arc::new(
226            (0..capacity)
227                .map(|_| AtomicF32::new(0.0))
228                .collect::<Vec<_>>(),
229        );
230
231        let mut history_capacity = max_block_size.next_power_of_two();
232        if history_capacity < 65536 {
233            history_capacity = 65536;
234        }
235        let history_mask = history_capacity - 1;
236
237        let history_buffer_l = Arc::new(
238            (0..history_capacity)
239                .map(|_| AtomicF32::new(0.0))
240                .collect::<Vec<_>>(),
241        );
242        let history_buffer_r = Arc::new(
243            (0..history_capacity)
244                .map(|_| AtomicF32::new(0.0))
245                .collect::<Vec<_>>(),
246        );
247
248        let drop_count = Arc::new(AtomicUsize::new(0));
249        let shared_read_ptr = Arc::new(AtomicUsize::new(0));
250
251        let b0_l = b0_l_vec.clone();
252        let b0_r = b0_r_vec.clone();
253        let b0_out_len = AUDIO_UNIT_SIZE + config.block_0_size - 1;
254        let b0_out_l = vec![0.0f32; b0_out_len];
255        let b0_out_r = vec![0.0f32; b0_out_len];
256
257        let max_queue_len = 2048;
258        let (task_tx, rx) = bounded::<TaskMsg>(max_queue_len);
259
260        let partition_blocks: Arc<[PartitionBlock]> = blocks_info.into();
261        let num_workers = std::thread::available_parallelism()
262            .map(|x| x.get())
263            .unwrap_or(4);
264
265        for _ in 0..num_workers {
266            let rx = rx.clone();
267            let worker_carry_l = Arc::clone(&carry_buffer_l);
268            let worker_carry_r = Arc::clone(&carry_buffer_r);
269            let worker_hist_l = Arc::clone(&history_buffer_l);
270            let worker_hist_r = Arc::clone(&history_buffer_r);
271            let worker_drop_count = Arc::clone(&drop_count);
272            let worker_shared_read_ptr = Arc::clone(&shared_read_ptr);
273            let worker_blocks = Arc::clone(&partition_blocks);
274            let worker_stereo = stereo;
275            let global_hist_cap = history_capacity;
276            let global_hist_mask = history_mask;
277            let global_carry_mask = carry_mask;
278
279            std::thread::spawn(move || {
280                if let Err(e) = set_current_thread_priority(ThreadPriority::Max) {
281                    eprintln!(
282                        "Warning: Failed to set convolution block thread priority: {:?}",
283                        e
284                    );
285                }
286
287                let max_len2 = max_block_size * 2;
288                let max_out_len = max_block_size + 1;
289
290                let mut pad_l = vec![0.0f32; max_len2];
291                let mut pad_r = vec![0.0f32; max_len2];
292                let mut out_l_slice =
293                    vec![rustfft::num_complex::Complex::new(0.0, 0.0); max_out_len];
294                let mut out_r_slice =
295                    vec![rustfft::num_complex::Complex::new(0.0, 0.0); max_out_len];
296                let mut res_l = vec![0.0f32; max_len2];
297                let mut res_r = vec![0.0f32; max_len2];
298
299                while let Ok(task) = rx.recv() {
300                    let queue_len = rx.len();
301
302                    let max_queue_age = if queue_len > max_queue_len / 2 { 2 } else { 8 };
303                    if queue_len > max_queue_age {
304                        worker_drop_count.fetch_add(1, Ordering::Relaxed);
305                        continue;
306                    }
307
308                    let block = &worker_blocks[task.block_index];
309                    let s = block.size;
310                    let len2 = s * 2;
311                    let out_len = s + 1;
312
313                    let start_idx =
314                        (task.history_write_ptr + global_hist_cap - s) & global_hist_mask;
315
316                    for i in 0..s {
317                        pad_l[i] = worker_hist_l[(start_idx + i) & global_hist_mask]
318                            .load(Ordering::Relaxed);
319                    }
320                    pad_l[s..len2].fill(0.0);
321
322                    if worker_stereo {
323                        for i in 0..s {
324                            pad_r[i] = worker_hist_r[(start_idx + i) & global_hist_mask]
325                                .load(Ordering::Relaxed);
326                        }
327                        pad_r[s..len2].fill(0.0);
328                    }
329
330                    let pad_l_slice = &mut pad_l[..len2];
331                    block
332                        .fft_plan
333                        .process(pad_l_slice, &mut out_l_slice[..out_len])
334                        .unwrap();
335
336                    if worker_stereo {
337                        let pad_r_slice = &mut pad_r[..len2];
338                        block
339                            .fft_plan
340                            .process(pad_r_slice, &mut out_r_slice[..out_len])
341                            .unwrap();
342                    }
343
344                    for i in 0..out_len {
345                        out_l_slice[i] *= block.fft_data_l[i];
346                        if worker_stereo {
347                            out_r_slice[i] *= block.fft_data_r[i];
348                        }
349                    }
350
351                    let res_l_mut = &mut res_l[..len2];
352                    block
353                        .ifft_plan
354                        .process(&mut out_l_slice[..out_len], res_l_mut)
355                        .unwrap();
356                    let scale = 1.0 / (len2 as f32);
357                    for x in res_l_mut.iter_mut() {
358                        *x *= scale;
359                    }
360
361                    if worker_stereo {
362                        let res_r_mut = &mut res_r[..len2];
363                        block
364                            .ifft_plan
365                            .process(&mut out_r_slice[..out_len], res_r_mut)
366                            .unwrap();
367                        for x in res_r_mut.iter_mut() {
368                            *x *= scale;
369                        }
370                    }
371
372                    let current_ptr = worker_shared_read_ptr.load(Ordering::Relaxed);
373                    let task_ptr = task.carry_read_ptr;
374                    let capacity = global_carry_mask + 1;
375
376                    let current_real = if current_ptr < task_ptr {
377                        current_ptr + capacity
378                    } else {
379                        current_ptr
380                    };
381
382                    let out_base_real =
383                        (task_ptr + AUDIO_UNIT_SIZE + block.offset).saturating_sub(s);
384                    let safe_current_real = current_real + AUDIO_UNIT_SIZE;
385
386                    let skip = safe_current_real.saturating_sub(out_base_real);
387
388                    const FADE_LEN: usize = AUDIO_UNIT_SIZE / 4;
389
390                    for i in skip..(len2 - 1) {
391                        let mut sample_l = res_l[i];
392                        let mut sample_r = res_r[i];
393
394                        // fade in
395                        let current_offset = i - skip;
396                        if current_offset < FADE_LEN {
397                            let gain = current_offset as f32 / FADE_LEN as f32;
398                            sample_l *= gain;
399                            if worker_stereo {
400                                sample_r *= gain;
401                            }
402                        }
403
404                        let idx = (out_base_real + i) & global_carry_mask;
405                        worker_carry_l[idx].fetch_add(sample_l, Ordering::Relaxed);
406                        if worker_stereo {
407                            worker_carry_r[idx].fetch_add(sample_r, Ordering::Relaxed);
408                        }
409                    }
410                }
411            });
412        }
413
414        Self {
415            stereo,
416            block_0_l: b0_l,
417            block_0_r: b0_r,
418            b0_out_l,
419            b0_out_r,
420            task_tx,
421            carry_buffer_l,
422            carry_buffer_r,
423            carry_mask,
424            carry_read_ptr: 0,
425
426            history_buffer_l,
427            history_buffer_r,
428            history_mask,
429            history_write_ptr: 0,
430
431            partition_blocks,
432
433            shared_read_ptr,
434            drop_count,
435        }
436    }
437
438    fn partition_ir(
439        ir: &[[f32; 2]],
440        growth_exponent: u32,
441        b0_len: usize,
442    ) -> (Vec<f32>, Vec<f32>, Vec<PartitionBlock>) {
443        let mut blocks = Vec::new();
444        let mut offset = 0;
445        let growth_factor = growth_exponent.max(1) as usize;
446
447        let b0_l = Self::take_slice_padded(ir, offset, b0_len, 0);
448        let b0_r = Self::take_slice_padded(ir, offset, b0_len, 1);
449        offset += b0_len;
450
451        let mut current_size = AUDIO_UNIT_SIZE * (growth_exponent as usize);
452        let mut planner = realfft::RealFftPlanner::<f32>::new();
453
454        while offset < ir.len() {
455            let len = current_size;
456            let len2 = len * 2;
457            let l_slice = Self::take_slice_padded(ir, offset, len, 0);
458            let r_slice = Self::take_slice_padded(ir, offset, len, 1);
459
460            let fft_fwd = planner.plan_fft_forward(len2);
461            let fft_inv = planner.plan_fft_inverse(len2);
462
463            let mut padded_l = vec![0.0; len2];
464            padded_l[..len].copy_from_slice(&l_slice);
465            let mut out_l = fft_fwd.make_output_vec();
466            fft_fwd.process(&mut padded_l, &mut out_l).unwrap();
467
468            let mut padded_r = vec![0.0; len2];
469            padded_r[..len].copy_from_slice(&r_slice);
470            let mut out_r = fft_fwd.make_output_vec();
471            fft_fwd.process(&mut padded_r, &mut out_r).unwrap();
472
473            blocks.push(PartitionBlock {
474                size: len,
475                offset,
476                fft_data_l: out_l.into(),
477                fft_data_r: out_r.into(),
478                fft_plan: fft_fwd,
479                ifft_plan: fft_inv,
480            });
481
482            offset += len;
483
484            // offset represents current input audio length
485            // current_size * growth_factor + AUDIO_UNIT_SIZE represents the threshold for growth
486            // When this condition is met, the new size can be calculated without waiting,
487            // preventing the block from becoming too large and starving the main thread.
488            if offset >= current_size * growth_factor + AUDIO_UNIT_SIZE {
489                current_size *= growth_factor;
490            }
491        }
492        (b0_l, b0_r, blocks)
493    }
494
495    fn take_slice_padded(ir: &[[f32; 2]], offset: usize, len: usize, ch: usize) -> Vec<f32> {
496        let mut res = vec![0.0; len];
497        if offset < ir.len() {
498            let take = (ir.len() - offset).min(len);
499            for i in 0..take {
500                res[i] = ir[offset + i][ch];
501            }
502        }
503        res
504    }
505
506    #[inline(always)]
507    pub fn process(&mut self, input: Option<&AudioUnit>, output: &mut AudioUnit) {
508        let empty_input = crate::types::empty_audio_unit();
509        let input_ref = input.unwrap_or(&empty_input);
510
511        let mut in_l = [0.0f32; AUDIO_UNIT_SIZE];
512        let mut in_r = [0.0f32; AUDIO_UNIT_SIZE];
513        for i in 0..AUDIO_UNIT_SIZE {
514            in_l[i] = input_ref[i][0];
515            in_r[i] = input_ref[i][1];
516        }
517
518        for i in 0..AUDIO_UNIT_SIZE {
519            let idx = (self.history_write_ptr + i) & self.history_mask;
520            self.history_buffer_l[idx].store(in_l[i], Ordering::Relaxed);
521            if self.stereo {
522                self.history_buffer_r[idx].store(in_r[i], Ordering::Relaxed);
523            }
524        }
525        self.history_write_ptr = (self.history_write_ptr + AUDIO_UNIT_SIZE) & self.history_mask;
526
527        let mask = self.carry_mask;
528
529        let b0_len = self.block_0_l.len();
530        let b0_out_len = AUDIO_UNIT_SIZE + b0_len - 1;
531        self.b0_out_l[..b0_out_len].fill(0.0);
532        if self.stereo {
533            self.b0_out_r[..b0_out_len].fill(0.0);
534        }
535
536        for i in 0..AUDIO_UNIT_SIZE {
537            let il = in_l[i];
538            let ir = in_r[i];
539            let out_l_slice = &mut self.b0_out_l[i..i + b0_len];
540
541            for (out_l, &b0l) in out_l_slice.iter_mut().zip(self.block_0_l.iter()) {
542                *out_l += il * b0l;
543            }
544            if self.stereo {
545                let out_r_slice = &mut self.b0_out_r[i..i + b0_len];
546                for (out_r, &b0r) in out_r_slice.iter_mut().zip(self.block_0_r.iter()) {
547                    *out_r += ir * b0r;
548                }
549            }
550        }
551
552        for i in 0..b0_out_len {
553            let idx = (self.carry_read_ptr + i) & mask;
554            self.carry_buffer_l[idx].fetch_add(self.b0_out_l[i], Ordering::Relaxed);
555            if self.stereo {
556                self.carry_buffer_r[idx].fetch_add(self.b0_out_r[i], Ordering::Relaxed);
557            }
558        }
559
560        for (i, out) in output.iter_mut().enumerate().take(AUDIO_UNIT_SIZE) {
561            let idx = (self.carry_read_ptr + i) & mask;
562            let out_l = self.carry_buffer_l[idx].swap(0.0, Ordering::Relaxed);
563            let out_r = self.carry_buffer_r[idx].swap(0.0, Ordering::Relaxed);
564
565            out[0] = out_l;
566            if self.stereo {
567                out[1] = out_r;
568            } else {
569                out[1] = out_l;
570            }
571        }
572
573        for (idx, block) in self.partition_blocks.iter().enumerate() {
574            if (self.carry_read_ptr + AUDIO_UNIT_SIZE).is_multiple_of(block.size) {
575                let task = TaskMsg {
576                    block_index: idx,
577                    carry_read_ptr: self.carry_read_ptr,
578                    history_write_ptr: self.history_write_ptr,
579                };
580                if self.task_tx.try_send(task).is_err() {
581                    self.drop_count.fetch_add(1, Ordering::Relaxed);
582                }
583            }
584        }
585
586        self.carry_read_ptr = (self.carry_read_ptr + AUDIO_UNIT_SIZE) & mask;
587        self.shared_read_ptr
588            .store(self.carry_read_ptr, Ordering::Relaxed);
589    }
590
591    pub fn get_drop_count(&self) -> usize {
592        self.drop_count.load(Ordering::Relaxed)
593    }
594
595    /// Returns a shared reference to the drop count.
596    ///
597    /// Drop count increases when the convolution worker threads fall behind
598    /// the real-time audio thread.
599    pub fn clone_drop_count(&self) -> Arc<AtomicUsize> {
600        Arc::clone(&self.drop_count)
601    }
602}