1use std::any::Any;
2
3use rubato::{FftFixedInOut, Resampler as _};
4
5use crate::{
6 context::{AudioContextRegistration, BaseAudioContext},
7 render::{AudioParamValues, AudioProcessor, AudioRenderQuantum, AudioWorkletGlobalScope},
8 RENDER_QUANTUM_SIZE,
9};
10
11use super::{AudioNode, AudioNodeOptions, ChannelConfig};
12
/// Oversampling mode applied around the shaping curve of a `WaveShaperNode`.
///
/// With [`OverSampleType::X2`] / [`OverSampleType::X4`] the renderer upsamples the signal,
/// applies the curve at the higher rate and downsamples the result back to the context rate.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum OverSampleType {
    /// No oversampling: the curve is applied directly at the context sample rate (default).
    #[default]
    None,
    /// The curve is applied at twice the context sample rate.
    X2,
    /// The curve is applied at four times the context sample rate.
    X4,
}

impl From<u32> for OverSampleType {
    /// Maps `0`, `1` and `2` to [`OverSampleType::None`], [`OverSampleType::X2`] and
    /// [`OverSampleType::X4`] respectively.
    ///
    /// # Panics
    ///
    /// Panics for any other value. The input comes from the caller, so this branch is
    /// reachable and gets an explicit message instead of `unreachable!()`.
    fn from(i: u32) -> Self {
        match i {
            0 => Self::None,
            1 => Self::X2,
            2 => Self::X4,
            _ => panic!("invalid OverSampleType discriminant {i}, expected 0, 1 or 2"),
        }
    }
}
37
38#[derive(Clone, Debug)]
44pub struct WaveShaperOptions {
45 pub curve: Option<Vec<f32>>,
47 pub oversample: OverSampleType,
49 pub audio_node_options: AudioNodeOptions,
51}
52
53impl Default for WaveShaperOptions {
54 fn default() -> Self {
55 Self {
56 oversample: OverSampleType::None,
57 curve: None,
58 audio_node_options: AudioNodeOptions::default(),
59 }
60 }
61}
62
63#[derive(Debug)]
118pub struct WaveShaperNode {
119 registration: AudioContextRegistration,
121 channel_config: ChannelConfig,
123 curve: Option<Vec<f32>>,
125 oversample: OverSampleType,
127}
128
129impl AudioNode for WaveShaperNode {
130 fn registration(&self) -> &AudioContextRegistration {
131 &self.registration
132 }
133
134 fn channel_config(&self) -> &ChannelConfig {
135 &self.channel_config
136 }
137
138 fn number_of_inputs(&self) -> usize {
139 1
140 }
141
142 fn number_of_outputs(&self) -> usize {
143 1
144 }
145}
146
147impl WaveShaperNode {
148 pub fn new<C: BaseAudioContext>(context: &C, options: WaveShaperOptions) -> Self {
155 let WaveShaperOptions {
156 oversample,
157 curve,
158 audio_node_options: channel_config,
159 } = options;
160
161 let mut node = context.base().register(move |registration| {
162 let sample_rate = context.sample_rate() as usize;
163
164 let renderer = WaveShaperRenderer::new(RendererConfig {
165 oversample,
166 sample_rate,
167 });
168
169 let node = Self {
170 registration,
171 channel_config: channel_config.into(),
172 curve: None,
173 oversample,
174 };
175
176 (node, Box::new(renderer))
177 });
178
179 if let Some(curve) = curve {
181 node.set_curve(curve);
182 }
183
184 node
185 }
186
187 #[must_use]
189 pub fn curve(&self) -> Option<&[f32]> {
190 self.curve.as_deref()
191 }
192
193 pub fn set_curve(&mut self, curve: Vec<f32>) {
204 assert!(
205 self.curve.is_none(),
206 "InvalidStateError - cannot assign curve twice",
207 );
208
209 let clone = curve.clone();
210
211 self.curve = Some(curve);
212 self.registration.post_message(Some(clone));
213 }
214
215 #[must_use]
217 pub fn oversample(&self) -> OverSampleType {
218 self.oversample
219 }
220
221 pub fn set_oversample(&mut self, oversample: OverSampleType) {
227 self.oversample = oversample;
228 self.registration.post_message(oversample);
229 }
230}
231
232#[derive(Debug, Clone, PartialEq, Eq)]
233struct ResamplerConfig {
234 channels: usize,
235 chunk_size_in: usize,
236 sample_rate_in: usize,
237 sample_rate_out: usize,
238}
239
240impl ResamplerConfig {
241 fn upsample_x2(channels: usize, sample_rate: usize) -> Self {
242 let chunk_size_in = RENDER_QUANTUM_SIZE;
243 let sample_rate_in = sample_rate;
244 let sample_rate_out = sample_rate * 2;
245 Self {
246 channels,
247 chunk_size_in,
248 sample_rate_in,
249 sample_rate_out,
250 }
251 }
252
253 fn upsample_x4(channels: usize, sample_rate: usize) -> Self {
254 let chunk_size_in = RENDER_QUANTUM_SIZE;
255 let sample_rate_in = sample_rate;
256 let sample_rate_out = sample_rate * 4;
257 Self {
258 channels,
259 chunk_size_in,
260 sample_rate_in,
261 sample_rate_out,
262 }
263 }
264
265 fn downsample_x2(channels: usize, sample_rate: usize) -> Self {
266 let chunk_size_in = RENDER_QUANTUM_SIZE * 2;
267 let sample_rate_in = sample_rate * 2;
268 let sample_rate_out = sample_rate;
269 Self {
270 channels,
271 chunk_size_in,
272 sample_rate_in,
273 sample_rate_out,
274 }
275 }
276
277 fn downsample_x4(channels: usize, sample_rate: usize) -> Self {
278 let chunk_size_in = RENDER_QUANTUM_SIZE * 4;
279 let sample_rate_in = sample_rate * 4;
280 let sample_rate_out = sample_rate;
281 Self {
282 channels,
283 chunk_size_in,
284 sample_rate_in,
285 sample_rate_out,
286 }
287 }
288}
289
290struct Resampler {
291 config: ResamplerConfig,
292 processor: FftFixedInOut<f32>,
293 samples_out: Vec<Vec<f32>>,
294}
295
296impl Resampler {
297 fn new(config: ResamplerConfig) -> Self {
298 let ResamplerConfig {
299 channels,
300 chunk_size_in,
301 sample_rate_in,
302 sample_rate_out,
303 } = &config;
304
305 let processor =
306 FftFixedInOut::new(*sample_rate_in, *sample_rate_out, *chunk_size_in, *channels)
307 .unwrap();
308
309 let samples_out = processor.output_buffer_allocate(true);
310
311 Self {
312 config,
313 processor,
314 samples_out,
315 }
316 }
317
318 fn process<T>(&mut self, samples_in: &[T])
319 where
320 T: AsRef<[f32]>,
321 {
322 debug_assert_eq!(self.config.channels, samples_in.len());
323 debug_assert!(samples_in
325 .iter()
326 .all(|channel| channel.as_ref().len() == self.processor.input_frames_next()));
327 let (in_len, out_len) = self
328 .processor
329 .process_into_buffer(samples_in, &mut self.samples_out[..], None)
330 .unwrap();
331 debug_assert_eq!(in_len, samples_in[0].as_ref().len());
333 debug_assert!(self
335 .samples_out
336 .iter()
337 .all(|channel| channel.len() == out_len));
338 }
339
340 fn samples_out(&self) -> &[Vec<f32>] {
341 &self.samples_out[..]
342 }
343
344 fn samples_out_mut(&mut self) -> &mut [Vec<f32>] {
345 &mut self.samples_out[..]
346 }
347}
348
349struct RendererConfig {
352 oversample: OverSampleType,
354 sample_rate: usize,
356}
357
358struct WaveShaperRenderer {
360 oversample: OverSampleType,
362 curve: Option<Vec<f32>>,
364 sample_rate: usize,
366 channels_x2: usize,
368 channels_x4: usize,
370 upsampler_x2: Resampler,
372 upsampler_x4: Resampler,
374 downsampler_x2: Resampler,
376 downsampler_x4: Resampler,
378 can_propagate_silence: bool,
381}
382
383impl AudioProcessor for WaveShaperRenderer {
384 fn process(
385 &mut self,
386 inputs: &[AudioRenderQuantum],
387 outputs: &mut [AudioRenderQuantum],
388 _params: AudioParamValues<'_>,
389 _scope: &AudioWorkletGlobalScope,
390 ) -> bool {
391 let input = &inputs[0];
393 let output = &mut outputs[0];
394
395 if input.is_silent() && self.can_propagate_silence {
396 output.make_silent();
397 return false;
398 }
399
400 *output = input.clone();
401
402 if let Some(curve) = &self.curve {
403 match self.oversample {
404 OverSampleType::None => {
405 output.modify_channels(|channel| {
406 channel.iter_mut().for_each(|o| *o = apply_curve(curve, *o));
407 });
408 }
409 OverSampleType::X2 => {
410 let channels = output.channels();
411
412 if channels.len() != self.channels_x2 {
414 self.channels_x2 = channels.len();
415
416 self.upsampler_x2 = Resampler::new(ResamplerConfig::upsample_x2(
417 self.channels_x2,
418 self.sample_rate,
419 ));
420
421 self.downsampler_x2 = Resampler::new(ResamplerConfig::downsample_x2(
422 self.channels_x2,
423 self.sample_rate,
424 ));
425 }
426
427 self.upsampler_x2.process(channels);
428 for channel in self.upsampler_x2.samples_out_mut().iter_mut() {
429 for s in channel.iter_mut() {
430 *s = apply_curve(curve, *s);
431 }
432 }
433
434 self.downsampler_x2.process(self.upsampler_x2.samples_out());
435
436 for (processed, output) in self
437 .downsampler_x2
438 .samples_out()
439 .iter()
440 .zip(output.channels_mut())
441 {
442 output.copy_from_slice(&processed[..]);
443 }
444 }
445 OverSampleType::X4 => {
446 let channels = output.channels();
447
448 if channels.len() != self.channels_x4 {
450 self.channels_x4 = channels.len();
451
452 self.upsampler_x4 = Resampler::new(ResamplerConfig::upsample_x4(
453 self.channels_x4,
454 self.sample_rate,
455 ));
456
457 self.downsampler_x4 = Resampler::new(ResamplerConfig::downsample_x4(
458 self.channels_x4,
459 self.sample_rate,
460 ));
461 }
462
463 self.upsampler_x4.process(channels);
464
465 for channel in self.upsampler_x4.samples_out_mut().iter_mut() {
466 for s in channel.iter_mut() {
467 *s = apply_curve(curve, *s);
468 }
469 }
470
471 self.downsampler_x4.process(self.upsampler_x4.samples_out());
472
473 for (processed, output) in self
474 .downsampler_x4
475 .samples_out()
476 .iter()
477 .zip(output.channels_mut())
478 {
479 output.copy_from_slice(&processed[..]);
480 }
481 }
482 }
483 }
484
485 false
487 }
488
489 fn onmessage(&mut self, msg: &mut dyn Any) {
490 if let Some(&oversample) = msg.downcast_ref::<OverSampleType>() {
491 self.oversample = oversample;
492 return;
493 }
494
495 if let Some(curve) = msg.downcast_mut::<Option<Vec<f32>>>() {
496 std::mem::swap(&mut self.curve, curve);
497
498 self.can_propagate_silence = if let Some(curve) = &self.curve {
500 if curve.len() % 2 == 1 {
501 curve[curve.len() / 2].abs() < 1e-9
502 } else {
503 let a = curve[curve.len() / 2 - 1];
504 let b = curve[curve.len() / 2];
505 ((a + b) / 2.).abs() < 1e-9
506 }
507 } else {
508 true
509 };
510
511 return;
512 }
513
514 log::warn!("WaveShaperRenderer: Dropping incoming message {msg:?}");
515 }
516}
517
518impl WaveShaperRenderer {
519 #[allow(clippy::missing_const_for_fn)]
521 fn new(config: RendererConfig) -> Self {
522 let RendererConfig {
523 sample_rate,
524 oversample,
525 } = config;
526
527 let channels_x2 = 1;
528 let channels_x4 = 1;
529
530 let upsampler_x2 = Resampler::new(ResamplerConfig::upsample_x2(channels_x2, sample_rate));
531
532 let downsampler_x2 =
533 Resampler::new(ResamplerConfig::downsample_x2(channels_x2, sample_rate));
534
535 let upsampler_x4 = Resampler::new(ResamplerConfig::upsample_x4(channels_x2, sample_rate));
536
537 let downsampler_x4 =
538 Resampler::new(ResamplerConfig::downsample_x4(channels_x2, sample_rate));
539
540 Self {
541 oversample,
542 curve: None,
543 sample_rate,
544 channels_x2,
545 channels_x4,
546 upsampler_x2,
547 upsampler_x4,
548 downsampler_x2,
549 downsampler_x4,
550 can_propagate_silence: true,
551 }
552 }
553}
554
/// Maps `input` through `curve` using linear interpolation.
///
/// The input range `[-1, 1]` is spread evenly over the curve's indices. Inputs at or below
/// `-1` return the first curve value, and inputs at or above `1` return the last one. An
/// empty curve maps every input to `0`.
#[inline]
fn apply_curve(curve: &[f32], input: f32) -> f32 {
    if curve.is_empty() {
        return 0.;
    }

    let last_index = curve.len() - 1;
    let n = curve.len() as f32;
    let v = (n - 1.) / 2.0 * (input + 1.);

    if v <= 0. {
        curve[0]
    } else if v >= n - 1. {
        curve[last_index]
    } else {
        let k = v.floor();
        let f = v - k;
        // For very long curves `curve.len() as f32` is rounded, so `k` can land on the last
        // index. Clamp both neighbours to stay in bounds. For ordinary lengths this is a no-op
        // (`lo <= last_index - 1`, `hi == lo + 1`).
        let lo = (k as usize).min(last_index);
        let hi = (lo + 1).min(last_index);
        (1. - f) * curve[lo] + f * curve[hi]
    }
}
574
#[cfg(test)]
mod tests {
    use float_eq::assert_float_eq;

    use crate::context::OfflineAudioContext;
    use crate::node::AudioScheduledSourceNode;

    use super::*;

    // Render length (in frames) used by the construction / option tests.
    const LENGTH: usize = 555;

    // Constructing through `WaveShaperNode::new` with default options must not panic.
    #[test]
    fn build_with_new() {
        let context = OfflineAudioContext::new(2, LENGTH, 44_100.);
        let _shaper = WaveShaperNode::new(&context, WaveShaperOptions::default());
    }

    // Constructing through the context factory method must not panic.
    #[test]
    fn build_with_factory_func() {
        let context = OfflineAudioContext::new(2, LENGTH, 44_100.);
        let _shaper = context.create_wave_shaper();
    }

    // Default options yield no curve and no oversampling.
    #[test]
    fn test_default_options() {
        let context = OfflineAudioContext::new(2, LENGTH, 44_100.);
        let shaper = WaveShaperNode::new(&context, WaveShaperOptions::default());

        assert_eq!(shaper.curve(), None);
        assert_eq!(shaper.oversample(), OverSampleType::None);
    }

    // Options passed at construction are reported by the getters after rendering.
    #[test]
    fn test_user_defined_options() {
        let mut context = OfflineAudioContext::new(2, LENGTH, 44_100.);

        let options = WaveShaperOptions {
            curve: Some(vec![1.0]),
            oversample: OverSampleType::X2,
            ..Default::default()
        };

        let shaper = WaveShaperNode::new(&context, options);

        let _ = context.start_rendering_sync();

        assert_eq!(shaper.curve(), Some(&[1.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X2);
    }

    // A curve was already supplied through the options, so the second `set_curve` call
    // must hit the "cannot assign curve twice" assertion.
    #[test]
    #[should_panic]
    fn change_a_curve_for_another_curve_should_panic() {
        let mut context = OfflineAudioContext::new(2, LENGTH, 44_100.);

        let options = WaveShaperOptions {
            curve: Some(vec![1.0]),
            oversample: OverSampleType::X2,
            ..Default::default()
        };

        let mut shaper = WaveShaperNode::new(&context, options);
        assert_eq!(shaper.curve(), Some(&[1.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X2);

        // panics here
        shaper.set_curve(vec![2.0]);
        shaper.set_oversample(OverSampleType::X4);

        let _ = context.start_rendering_sync();

        assert_eq!(shaper.curve(), Some(&[2.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X4);
    }

    // Without an initial curve, a curve can be assigned once after construction, and the
    // oversampling mode can be changed.
    #[test]
    fn change_none_for_curve_after_build() {
        let mut context = OfflineAudioContext::new(2, LENGTH, 44_100.);

        let options = WaveShaperOptions {
            curve: None,
            oversample: OverSampleType::X2,
            ..Default::default()
        };

        let mut shaper = WaveShaperNode::new(&context, options);
        assert_eq!(shaper.curve(), None);
        assert_eq!(shaper.oversample(), OverSampleType::X2);

        shaper.set_curve(vec![2.0]);
        shaper.set_oversample(OverSampleType::X4);

        let _ = context.start_rendering_sync();

        assert_eq!(shaper.curve(), Some(&[2.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X4);
    }

    // Constant inputs of -1, 0 and 1 (one render quantum each) must map exactly onto the
    // first, middle and last values of the curve.
    #[test]
    fn test_shape_boundaries() {
        let sample_rate = 44100.;
        let mut context = OfflineAudioContext::new(1, 3 * RENDER_QUANTUM_SIZE, sample_rate);

        let mut shaper = context.create_wave_shaper();
        let curve = vec![-0.5, 0., 0.5];
        shaper.set_curve(curve);
        shaper.connect(&context.destination());

        let mut data = vec![0.; 3 * RENDER_QUANTUM_SIZE];
        let mut expected = vec![0.; 3 * RENDER_QUANTUM_SIZE];
        for i in 0..(3 * RENDER_QUANTUM_SIZE) {
            if i < RENDER_QUANTUM_SIZE {
                data[i] = -1.;
                expected[i] = -0.5;
            } else if i < 2 * RENDER_QUANTUM_SIZE {
                data[i] = 0.;
                expected[i] = 0.;
            } else {
                data[i] = 1.;
                expected[i] = 0.5;
            }
        }
        let mut buffer = context.create_buffer(1, 3 * RENDER_QUANTUM_SIZE, sample_rate);
        buffer.copy_to_channel(&data, 0);

        let mut src = context.create_buffer_source();
        src.connect(&shaper);
        src.set_buffer(buffer);
        src.start_at(0.);

        let result = context.start_rendering_sync();
        let channel = result.get_channel_data(0);

        assert_float_eq!(channel[..], expected[..], abs_all <= 0.);
    }

    // A ramp from -1 towards 1 is linearly interpolated; with the curve [-0.5, 0., 0.5] every
    // sample is halved. Only the first render quantum is rendered (context length is one
    // quantum), even though the buffer is three quanta long.
    #[test]
    fn test_shape_interpolation() {
        let sample_rate = 44100.;
        let mut context = OfflineAudioContext::new(1, RENDER_QUANTUM_SIZE, sample_rate);

        let mut shaper = context.create_wave_shaper();
        let curve = vec![-0.5, 0., 0.5];
        shaper.set_curve(curve);
        shaper.connect(&context.destination());

        let mut data = vec![0.; RENDER_QUANTUM_SIZE];
        let mut expected = vec![0.; RENDER_QUANTUM_SIZE];

        for i in 0..RENDER_QUANTUM_SIZE {
            let sample = i as f32 / (RENDER_QUANTUM_SIZE as f32) * 2. - 1.;
            data[i] = sample;
            expected[i] = sample / 2.;
        }

        let mut buffer = context.create_buffer(1, 3 * RENDER_QUANTUM_SIZE, sample_rate);
        buffer.copy_to_channel(&data, 0);

        let mut src = context.create_buffer_source();
        src.connect(&shaper);
        src.set_buffer(buffer);
        src.start_at(0.);

        let result = context.start_rendering_sync();
        let channel = result.get_channel_data(0);

        assert_float_eq!(channel[..], expected[..], abs_all <= 0.);
    }
}