1use crate::types::{AUDIO_UNIT_SIZE, AudioUnit};
2use crossbeam_channel::{Sender, bounded};
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering};
5use thread_priority::*;
6
/// Tuning parameters for [`ConvolverNode`].
pub struct ConvolverConfig {
    /// Process left and right channels independently. When `false`, only the
    /// left channel is convolved and its result is copied to the right output.
    pub stereo: bool,
    /// Controls FFT partition sizing: the first partition is
    /// `AUDIO_UNIT_SIZE * growth_exponent` samples, and subsequent partitions
    /// grow by this factor (clamped to at least 1) as the IR offset advances.
    pub growth_exponent: u32,
    /// Length in samples of "block 0" — the head of the impulse response that
    /// is convolved directly in the time domain on the audio thread.
    pub block_0_size: usize,
}
17
18impl Default for ConvolverConfig {
19 fn default() -> Self {
20 Self {
21 stereo: true,
22 growth_exponent: 2,
23 block_0_size: AUDIO_UNIT_SIZE * 4,
24 }
25 }
26}
27
/// A lock-free `f32` cell backed by an [`AtomicU32`] holding the float's raw
/// bit pattern. Used to share audio samples between the real-time audio
/// thread and the convolution worker threads without locks.
pub struct AtomicF32(AtomicU32);

impl AtomicF32 {
    /// Creates a new cell holding `v`.
    #[inline(always)]
    pub fn new(v: f32) -> Self {
        Self(AtomicU32::new(v.to_bits()))
    }

    /// Atomically loads the current value with the given ordering.
    #[inline(always)]
    pub fn load(&self, order: Ordering) -> f32 {
        f32::from_bits(self.0.load(order))
    }

    /// Atomically stores `val` with the given ordering.
    #[inline(always)]
    pub fn store(&self, val: f32, order: Ordering) {
        self.0.store(val.to_bits(), order);
    }

    /// Atomically adds `val` and returns the *previous* value, mirroring the
    /// std atomic `fetch_add` convention. (Previously this returned `()`;
    /// existing statement-position callers are unaffected.)
    ///
    /// Implemented as a CAS loop via [`AtomicU32::fetch_update`]. The failure
    /// ordering is always `Relaxed`: the old code reused `order` for both the
    /// success and failure orderings, which panics at runtime when a caller
    /// passes `Release` or `AcqRel` (both are invalid failure orderings).
    #[inline(always)]
    pub fn fetch_add(&self, val: f32, order: Ordering) -> f32 {
        let old_bits = self
            .0
            .fetch_update(order, Ordering::Relaxed, |bits| {
                Some((f32::from_bits(bits) + val).to_bits())
            })
            .expect("update closure always returns Some");
        f32::from_bits(old_bits)
    }

    /// Atomically replaces the value with `val`, returning the old value.
    #[inline(always)]
    pub fn swap(&self, val: f32, order: Ordering) -> f32 {
        let old = self.0.swap(val.to_bits(), order);
        f32::from_bits(old)
    }
}
69
/// One frequency-domain partition of the impulse response, shared read-only
/// with the worker threads.
struct PartitionBlock {
    /// Partition length in samples; the FFT size used is `size * 2`
    /// (zero-padded for linear convolution).
    size: usize,
    /// Sample offset of this partition within the full impulse response.
    offset: usize,
    /// Forward FFT of the zero-padded left-channel IR segment (`size + 1` bins).
    fft_data_l: Arc<[rustfft::num_complex::Complex<f32>]>,
    /// Forward FFT of the zero-padded right-channel IR segment.
    fft_data_r: Arc<[rustfft::num_complex::Complex<f32>]>,
    /// Real-to-complex FFT plan sized for `size * 2`.
    fft_plan: Arc<dyn realfft::RealToComplex<f32>>,
    /// Complex-to-real inverse FFT plan sized for `size * 2`.
    ifft_plan: Arc<dyn realfft::ComplexToReal<f32>>,
}
78
/// Work item sent from the audio thread to the worker pool: which partition
/// to convolve, plus snapshots of the ring-buffer pointers at enqueue time so
/// the worker reads/writes consistent positions regardless of later audio
/// callbacks.
#[derive(Clone, Copy)]
struct TaskMsg {
    /// Index into `ConvolverNode::partition_blocks`.
    block_index: usize,
    /// Carry-ring read pointer at the moment the task was enqueued.
    carry_read_ptr: usize,
    /// History-ring write pointer (end of the freshest input) at enqueue time.
    history_write_ptr: usize,
}
85
/// Real-time partitioned-convolution node.
///
/// The head of the impulse response (`block_0_*`) is convolved directly in
/// the time domain on the audio thread; the remainder is split into growing
/// FFT partitions processed by background worker threads, which accumulate
/// their results into a shared lock-free "carry" ring that `process` drains.
pub struct ConvolverNode {
    /// When false, only the left channel is convolved and duplicated right.
    stereo: bool,
    /// Time-domain IR head taps, left channel.
    block_0_l: Vec<f32>,
    /// Time-domain IR head taps, right channel.
    block_0_r: Vec<f32>,
    /// Scratch for the direct convolution (len = AUDIO_UNIT_SIZE + b0 - 1).
    b0_out_l: Vec<f32>,
    b0_out_r: Vec<f32>,
    /// Sends partition jobs to the worker pool (bounded channel).
    task_tx: Sender<TaskMsg>,

    /// Lock-free accumulation ring; workers `fetch_add` IFFT results here and
    /// the audio thread consumes with `swap(0.0)`.
    carry_buffer_l: Arc<Vec<AtomicF32>>,
    carry_buffer_r: Arc<Vec<AtomicF32>>,
    /// Capacity - 1 (capacity is a power of two) for cheap index wrapping.
    carry_mask: usize,
    /// Next read position in the carry ring (audio-thread local).
    carry_read_ptr: usize,

    /// Ring of recent input samples from which workers read FFT windows.
    history_buffer_l: Arc<Vec<AtomicF32>>,
    history_buffer_r: Arc<Vec<AtomicF32>>,
    /// Capacity - 1 of the history ring (power-of-two capacity).
    history_mask: usize,
    /// Next write position in the history ring (audio-thread local).
    history_write_ptr: usize,

    /// Frequency-domain IR partitions, shared read-only with workers.
    partition_blocks: Arc<[PartitionBlock]>,

    /// Published copy of `carry_read_ptr` so workers can tell how far the
    /// audio thread has already consumed (used to skip stale output).
    shared_count: Arc<AtomicUsize>,
    /// Count of tasks dropped because of queue backlog (observability).
    drop_count: Arc<AtomicUsize>,
}
114
impl ConvolverNode {
    /// Loads an impulse response from a 32-bit-float WAV file and builds a
    /// convolver with the default [`ConvolverConfig`].
    ///
    /// `max_len` optionally truncates the IR (in samples at the target rate);
    /// see [`Self::from_file_with_config`] for details and panics.
    pub fn from_file(
        path: &str,
        target_sample_rate: u32,
        max_len: Option<usize>,
    ) -> anyhow::Result<Self> {
        Self::from_file_with_config(
            path,
            target_sample_rate,
            max_len,
            ConvolverConfig::default(),
        )
    }

    /// Loads an impulse response WAV and builds a convolver with `config`.
    ///
    /// Mono files are duplicated to both channels; stereo files are read as
    /// interleaved L/R pairs. If the file's sample rate differs from
    /// `target_sample_rate`, the IR is sinc-resampled. When `max_len` is given
    /// and the IR is longer, it is truncated and the tail faded out over up to
    /// 100 ms so the cut is inaudible.
    ///
    /// # Errors
    /// Returns an error if the file cannot be opened or parsed by `hound`.
    ///
    /// # Panics
    /// Panics on non-float WAVs, and on a stereo file with an odd total sample
    /// count (the `iter.next().unwrap()` below).
    pub fn from_file_with_config(
        path: &str,
        target_sample_rate: u32,
        max_len: Option<usize>,
        config: ConvolverConfig,
    ) -> anyhow::Result<Self> {
        let mut reader = hound::WavReader::open(path)?;
        let spec = reader.spec();
        let mut ir = Vec::new();

        if spec.sample_format == hound::SampleFormat::Float {
            let mut iter = reader.samples::<f32>();
            // NOTE(review): a decode error mid-stream silently ends this loop
            // (`while let Some(Ok(..))`) — the IR is truncated without warning.
            while let Some(Ok(l)) = iter.next() {
                let r = if spec.channels == 2 {
                    // Interleaved stereo: the right sample follows the left;
                    // on a read error fall back to the left sample.
                    iter.next().unwrap().unwrap_or(l)
                } else {
                    l
                };
                ir.push([l, r]);
            }
        } else {
            panic!("Unexpected IR file format")
        }

        let mut ir = if spec.sample_rate != target_sample_rate {
            Self::resample_ir(&ir, spec.sample_rate, target_sample_rate)
        } else {
            ir
        };

        if let Some(max) = max_len
            && ir.len() > max
        {
            ir.truncate(max);

            // Fade the last ~100 ms linearly to zero. `i = 0` maps to the
            // final sample with gain 0; gain ramps toward 1 at the fade start.
            let fade_len = (target_sample_rate as f32 * 0.1) as usize;
            let fade_len = fade_len.min(max);
            for i in 0..fade_len {
                let idx = max - 1 - i;
                let fade_gain = i as f32 / fade_len as f32;
                ir[idx][0] *= fade_gain;
                ir[idx][1] *= fade_gain;
            }
        }

        Ok(Self::with_config(&ir, config))
    }

    /// Sinc-resamples a stereo IR from `from_hz` to `to_hz` using `dasp`,
    /// producing `ceil(len * to/from)` frames.
    fn resample_ir(ir: &[[f32; 2]], from_hz: u32, to_hz: u32) -> Vec<[f32; 2]> {
        use dasp::signal::Signal;
        let signal = dasp::signal::from_iter(ir.iter().cloned());
        // Sinc interpolation window of AUDIO_UNIT_SIZE frames.
        let ring_buffer = dasp::ring_buffer::Fixed::from([[0.0; 2]; AUDIO_UNIT_SIZE]);
        let sinc = dasp::interpolate::sinc::Sinc::new(ring_buffer);
        let mut converter = signal.from_hz_to_hz(sinc, from_hz as f64, to_hz as f64);

        let new_len = (ir.len() as f64 * (to_hz as f64 / from_hz as f64)).ceil() as usize;
        let mut new_ir = Vec::with_capacity(new_len);
        for _ in 0..new_len {
            new_ir.push(converter.next());
        }
        new_ir
    }

    /// Builds a convolver from an in-memory stereo IR with the default config.
    pub fn new(ir: &[[f32; 2]]) -> Self {
        Self::with_config(ir, ConvolverConfig::default())
    }

    /// Builds the convolver: partitions the IR, allocates the lock-free rings
    /// shared with the worker pool, and spawns one worker thread per available
    /// CPU core (fallback: 4).
    pub fn with_config(ir: &[[f32; 2]], config: ConvolverConfig) -> Self {
        let stereo = config.stereo;
        let (b0_l_vec, b0_r_vec, blocks_info) =
            Self::partition_ir(ir, config.growth_exponent, config.block_0_size);

        // Carry ring must cover the whole IR tail plus slack for the largest
        // partition; power-of-two capacity (min 64 Ki) so wrap-around is a
        // single mask operation.
        let max_block_size = blocks_info
            .last()
            .map(|b| b.size)
            .unwrap_or(AUDIO_UNIT_SIZE);
        let mut capacity = (ir.len() + max_block_size * 2).next_power_of_two() * 4;
        if capacity < 65536 {
            capacity = 65536;
        }
        let carry_mask = capacity - 1;

        let carry_buffer_l = Arc::new(
            (0..capacity)
                .map(|_| AtomicF32::new(0.0))
                .collect::<Vec<_>>(),
        );
        let carry_buffer_r = Arc::new(
            (0..capacity)
                .map(|_| AtomicF32::new(0.0))
                .collect::<Vec<_>>(),
        );

        // History ring only needs to hold the largest FFT window (min 64 Ki).
        let mut history_capacity = max_block_size.next_power_of_two();
        if history_capacity < 65536 {
            history_capacity = 65536;
        }
        let history_mask = history_capacity - 1;

        let history_buffer_l = Arc::new(
            (0..history_capacity)
                .map(|_| AtomicF32::new(0.0))
                .collect::<Vec<_>>(),
        );
        let history_buffer_r = Arc::new(
            (0..history_capacity)
                .map(|_| AtomicF32::new(0.0))
                .collect::<Vec<_>>(),
        );

        let drop_count = Arc::new(AtomicUsize::new(0));
        let shared_read_ptr = Arc::new(AtomicUsize::new(0));

        let b0_l = b0_l_vec.clone();
        let b0_r = b0_r_vec.clone();
        // Linear convolution of one audio unit with block 0 yields
        // AUDIO_UNIT_SIZE + b0 - 1 samples.
        let b0_out_len = AUDIO_UNIT_SIZE + config.block_0_size - 1;
        let b0_out_l = vec![0.0f32; b0_out_len];
        let b0_out_r = vec![0.0f32; b0_out_len];

        let max_queue_len = 2048;
        let (task_tx, rx) = bounded::<TaskMsg>(max_queue_len);

        let partition_blocks: Arc<[PartitionBlock]> = blocks_info.into();
        let num_workers = std::thread::available_parallelism()
            .map(|x| x.get())
            .unwrap_or(4);

        for _ in 0..num_workers {
            // Each worker gets its own clones of the shared state; the
            // crossbeam receiver is shared, so workers compete for tasks.
            let rx = rx.clone();
            let worker_carry_l = Arc::clone(&carry_buffer_l);
            let worker_carry_r = Arc::clone(&carry_buffer_r);
            let worker_hist_l = Arc::clone(&history_buffer_l);
            let worker_hist_r = Arc::clone(&history_buffer_r);
            let worker_drop_count = Arc::clone(&drop_count);
            let worker_shared_read_ptr = Arc::clone(&shared_read_ptr);
            let worker_blocks = Arc::clone(&partition_blocks);
            let worker_stereo = stereo;
            let global_hist_cap = history_capacity;
            let global_hist_mask = history_mask;
            let global_carry_mask = carry_mask;

            std::thread::spawn(move || {
                // Best-effort: raise priority so FFT work keeps up with the
                // audio thread; a failure is only logged.
                if let Err(e) = set_current_thread_priority(ThreadPriority::Max) {
                    eprintln!(
                        "Warning: Failed to set convolution block thread priority: {:?}",
                        e
                    );
                }

                // Scratch buffers sized for the largest partition, reused for
                // every task to avoid per-task allocation.
                // Real FFT of 2*max samples produces max + 1 complex bins.
                let max_len2 = max_block_size * 2;
                let max_out_len = max_block_size + 1;

                let mut pad_l = vec![0.0f32; max_len2];
                let mut pad_r = vec![0.0f32; max_len2];
                let mut out_l_slice =
                    vec![rustfft::num_complex::Complex::new(0.0, 0.0); max_out_len];
                let mut out_r_slice =
                    vec![rustfft::num_complex::Complex::new(0.0, 0.0); max_out_len];
                let mut res_l = vec![0.0f32; max_len2];
                let mut res_r = vec![0.0f32; max_len2];

                while let Ok(task) = rx.recv() {
                    let queue_len = rx.len();

                    // Load shedding: with a large backlog, drop tasks once
                    // more than 2 remain queued; otherwise tolerate up to 8.
                    // NOTE(review): "age" is a misnomer — this is a backlog
                    // threshold in tasks, not an age in time.
                    let max_queue_age = if queue_len > max_queue_len / 2 { 2 } else { 8 };
                    if queue_len > max_queue_age {
                        worker_drop_count.fetch_add(1, Ordering::Relaxed);
                        continue;
                    }

                    let block = &worker_blocks[task.block_index];
                    let s = block.size;
                    let len2 = s * 2;
                    let out_len = s + 1;

                    // The `s` most recent input samples end at the snapshot
                    // write pointer; `+ global_hist_cap` keeps the subtraction
                    // from underflowing before masking.
                    let start_idx =
                        (task.history_write_ptr + global_hist_cap - s) & global_hist_mask;

                    for i in 0..s {
                        pad_l[i] = worker_hist_l[(start_idx + i) & global_hist_mask]
                            .load(Ordering::Relaxed);
                    }
                    // Zero-pad the second half for linear (not circular)
                    // convolution.
                    pad_l[s..len2].fill(0.0);

                    if worker_stereo {
                        for i in 0..s {
                            pad_r[i] = worker_hist_r[(start_idx + i) & global_hist_mask]
                                .load(Ordering::Relaxed);
                        }
                        pad_r[s..len2].fill(0.0);
                    }

                    // Forward FFT of the input window(s).
                    let pad_l_slice = &mut pad_l[..len2];
                    block
                        .fft_plan
                        .process(pad_l_slice, &mut out_l_slice[..out_len])
                        .unwrap();

                    if worker_stereo {
                        let pad_r_slice = &mut pad_r[..len2];
                        block
                            .fft_plan
                            .process(pad_r_slice, &mut out_r_slice[..out_len])
                            .unwrap();
                    }

                    // Pointwise multiply with the partition's IR spectrum =
                    // convolution in the time domain.
                    for i in 0..out_len {
                        out_l_slice[i] *= block.fft_data_l[i];
                        if worker_stereo {
                            out_r_slice[i] *= block.fft_data_r[i];
                        }
                    }

                    // Inverse FFT; realfft is unnormalized, so scale by
                    // 1 / FFT size.
                    let res_l_mut = &mut res_l[..len2];
                    block
                        .ifft_plan
                        .process(&mut out_l_slice[..out_len], res_l_mut)
                        .unwrap();
                    let scale = 1.0 / (len2 as f32);
                    for x in res_l_mut.iter_mut() {
                        *x *= scale;
                    }

                    if worker_stereo {
                        let res_r_mut = &mut res_r[..len2];
                        block
                            .ifft_plan
                            .process(&mut out_r_slice[..out_len], res_r_mut)
                            .unwrap();
                        for x in res_r_mut.iter_mut() {
                            *x *= scale;
                        }
                    }

                    // Work out where this block's output lands in the carry
                    // ring, in "unwrapped" coordinates so comparisons are
                    // monotonic even across ring wrap-around.
                    let current_ptr = worker_shared_read_ptr.load(Ordering::Relaxed);
                    let task_ptr = task.carry_read_ptr;
                    let capacity = global_carry_mask + 1;

                    let current_real = if current_ptr < task_ptr {
                        current_ptr + capacity
                    } else {
                        current_ptr
                    };

                    // Output starts `block.offset` samples after the unit that
                    // queued the task, minus the partition length the window
                    // already covers.
                    let out_base_real =
                        (task_ptr + AUDIO_UNIT_SIZE + block.offset).saturating_sub(s);
                    // Safety margin: assume the audio thread may be one unit
                    // ahead of its published read pointer.
                    let safe_current_real = current_real + AUDIO_UNIT_SIZE;

                    // Samples the reader has (or may have) already consumed
                    // are skipped rather than written late.
                    let skip = safe_current_real.saturating_sub(out_base_real);

                    const FADE_LEN: usize = AUDIO_UNIT_SIZE / 4;

                    // Accumulate the linear-convolution result (2s - 1 valid
                    // samples; the last IFFT sample is dropped). When samples
                    // were skipped, fade the first FADE_LEN written samples in
                    // to mask the discontinuity.
                    for i in skip..(len2 - 1) {
                        let mut sample_l = res_l[i];
                        let mut sample_r = res_r[i];

                        let current_offset = i - skip;
                        if current_offset < FADE_LEN {
                            let gain = current_offset as f32 / FADE_LEN as f32;
                            sample_l *= gain;
                            if worker_stereo {
                                sample_r *= gain;
                            }
                        }

                        let idx = (out_base_real + i) & global_carry_mask;
                        worker_carry_l[idx].fetch_add(sample_l, Ordering::Relaxed);
                        if worker_stereo {
                            worker_carry_r[idx].fetch_add(sample_r, Ordering::Relaxed);
                        }
                    }
                }
            });
        }

        Self {
            stereo,
            block_0_l: b0_l,
            block_0_r: b0_r,
            b0_out_l,
            b0_out_r,
            task_tx,
            carry_buffer_l,
            carry_buffer_r,
            carry_mask,
            carry_read_ptr: 0,

            history_buffer_l,
            history_buffer_r,
            history_mask,
            history_write_ptr: 0,

            partition_blocks,

            shared_read_ptr,
            drop_count,
        }
    }

    /// Splits the IR into the time-domain head ("block 0", `b0_len` samples
    /// per channel) plus a list of frequency-domain partitions. Each partition
    /// stores the forward FFT of its zero-padded IR segment together with the
    /// FFT plans sized for `2 * size`. Partition sizes start at
    /// `AUDIO_UNIT_SIZE * growth_exponent` and grow by `growth_factor` as the
    /// offset advances (heuristic condition below).
    fn partition_ir(
        ir: &[[f32; 2]],
        growth_exponent: u32,
        b0_len: usize,
    ) -> (Vec<f32>, Vec<f32>, Vec<PartitionBlock>) {
        let mut blocks = Vec::new();
        let mut offset = 0;
        // Guard against growth_exponent == 0, which would stall growth.
        let growth_factor = growth_exponent.max(1) as usize;

        let b0_l = Self::take_slice_padded(ir, offset, b0_len, 0);
        let b0_r = Self::take_slice_padded(ir, offset, b0_len, 1);
        offset += b0_len;

        let mut current_size = AUDIO_UNIT_SIZE * (growth_exponent as usize);
        let mut planner = realfft::RealFftPlanner::<f32>::new();

        while offset < ir.len() {
            let len = current_size;
            let len2 = len * 2;
            let l_slice = Self::take_slice_padded(ir, offset, len, 0);
            let r_slice = Self::take_slice_padded(ir, offset, len, 1);

            let fft_fwd = planner.plan_fft_forward(len2);
            let fft_inv = planner.plan_fft_inverse(len2);

            // Zero-pad each IR segment to 2*len and precompute its spectrum.
            let mut padded_l = vec![0.0; len2];
            padded_l[..len].copy_from_slice(&l_slice);
            let mut out_l = fft_fwd.make_output_vec();
            fft_fwd.process(&mut padded_l, &mut out_l).unwrap();

            let mut padded_r = vec![0.0; len2];
            padded_r[..len].copy_from_slice(&r_slice);
            let mut out_r = fft_fwd.make_output_vec();
            fft_fwd.process(&mut padded_r, &mut out_r).unwrap();

            blocks.push(PartitionBlock {
                size: len,
                offset,
                fft_data_l: out_l.into(),
                fft_data_r: out_r.into(),
                fft_plan: fft_fwd,
                ifft_plan: fft_inv,
            });

            offset += len;

            // Grow the partition size once the offset is comfortably past the
            // next size, keeping early partitions small for low latency.
            if offset >= current_size * growth_factor + AUDIO_UNIT_SIZE {
                current_size *= growth_factor;
            }
        }
        (b0_l, b0_r, blocks)
    }

    /// Copies `len` samples of channel `ch` starting at `offset` into a new
    /// vector, zero-padding past the end of the IR (and returning all zeros
    /// when `offset` is already out of range).
    fn take_slice_padded(ir: &[[f32; 2]], offset: usize, len: usize, ch: usize) -> Vec<f32> {
        let mut res = vec![0.0; len];
        if offset < ir.len() {
            let take = (ir.len() - offset).min(len);
            for i in 0..take {
                res[i] = ir[offset + i][ch];
            }
        }
        res
    }

    /// Real-time hot path: convolves one audio unit.
    ///
    /// Steps: (1) publish the input into the history ring for the workers,
    /// (2) convolve the input with block 0 directly in the time domain,
    /// (3) add that into the carry ring, (4) drain one unit from the carry
    /// ring into `output` (consuming via `swap(0.0)`), (5) enqueue FFT tasks
    /// for every partition whose boundary is reached this unit, and
    /// (6) advance and publish the read pointer.
    ///
    /// `input = None` is treated as silence.
    #[inline(always)]
    pub fn process(&mut self, input: Option<&AudioUnit>, output: &mut AudioUnit) {
        let empty_input = crate::types::empty_audio_unit();
        let input_ref = input.unwrap_or(&empty_input);

        // De-interleave into locals for cache-friendly access below.
        let mut in_l = [0.0f32; AUDIO_UNIT_SIZE];
        let mut in_r = [0.0f32; AUDIO_UNIT_SIZE];
        for i in 0..AUDIO_UNIT_SIZE {
            in_l[i] = input_ref[i][0];
            in_r[i] = input_ref[i][1];
        }

        // (1) Publish input to the history ring the workers read from.
        for i in 0..AUDIO_UNIT_SIZE {
            let idx = (self.history_write_ptr + i) & self.history_mask;
            self.history_buffer_l[idx].store(in_l[i], Ordering::Relaxed);
            if self.stereo {
                self.history_buffer_r[idx].store(in_r[i], Ordering::Relaxed);
            }
        }
        self.history_write_ptr = (self.history_write_ptr + AUDIO_UNIT_SIZE) & self.history_mask;

        let mask = self.carry_mask;

        let b0_len = self.block_0_l.len();
        let b0_out_len = AUDIO_UNIT_SIZE + b0_len - 1;
        self.b0_out_l[..b0_out_len].fill(0.0);
        if self.stereo {
            self.b0_out_r[..b0_out_len].fill(0.0);
        }

        // (2) Direct FIR convolution of this unit with block 0 (zero-latency
        // head of the IR).
        for i in 0..AUDIO_UNIT_SIZE {
            let il = in_l[i];
            let ir = in_r[i];
            let out_l_slice = &mut self.b0_out_l[i..i + b0_len];

            for (out_l, &b0l) in out_l_slice.iter_mut().zip(self.block_0_l.iter()) {
                *out_l += il * b0l;
            }
            if self.stereo {
                let out_r_slice = &mut self.b0_out_r[i..i + b0_len];
                for (out_r, &b0r) in out_r_slice.iter_mut().zip(self.block_0_r.iter()) {
                    *out_r += ir * b0r;
                }
            }
        }

        // (3) Accumulate the direct result into the shared carry ring.
        for i in 0..b0_out_len {
            let idx = (self.carry_read_ptr + i) & mask;
            self.carry_buffer_l[idx].fetch_add(self.b0_out_l[i], Ordering::Relaxed);
            if self.stereo {
                self.carry_buffer_r[idx].fetch_add(self.b0_out_r[i], Ordering::Relaxed);
            }
        }

        // (4) Drain one unit from the carry ring; swap(0.0) both reads and
        // clears the slot so workers can keep accumulating ahead.
        for (i, out) in output.iter_mut().enumerate().take(AUDIO_UNIT_SIZE) {
            let idx = (self.carry_read_ptr + i) & mask;
            let out_l = self.carry_buffer_l[idx].swap(0.0, Ordering::Relaxed);
            let out_r = self.carry_buffer_r[idx].swap(0.0, Ordering::Relaxed);

            out[0] = out_l;
            if self.stereo {
                out[1] = out_r;
            } else {
                // Mono mode: duplicate the left result.
                out[1] = out_l;
            }
        }

        // (5) Enqueue a task for each partition whose size boundary is hit at
        // the end of this unit; drops (full queue) are counted, not blocked on.
        for (idx, block) in self.partition_blocks.iter().enumerate() {
            if (self.carry_read_ptr + AUDIO_UNIT_SIZE).is_multiple_of(block.size) {
                let task = TaskMsg {
                    block_index: idx,
                    carry_read_ptr: self.carry_read_ptr,
                    history_write_ptr: self.history_write_ptr,
                };
                if self.task_tx.try_send(task).is_err() {
                    self.drop_count.fetch_add(1, Ordering::Relaxed);
                }
            }
        }

        // (6) Advance and publish the read pointer for the workers' skip
        // calculation.
        self.carry_read_ptr = (self.carry_read_ptr + AUDIO_UNIT_SIZE) & mask;
        self.shared_read_ptr
            .store(self.carry_read_ptr, Ordering::Relaxed);
    }

    /// Total number of worker tasks dropped so far (queue backlog).
    pub fn get_drop_count(&self) -> usize {
        self.drop_count.load(Ordering::Relaxed)
    }

    /// Shares the drop counter so it can be monitored from another thread.
    pub fn clone_drop_count(&self) -> Arc<AtomicUsize> {
        Arc::clone(&self.drop_count)
    }
}