//! WaveShaper node implementation (`web_audio_api/node/waveshaper.rs`).

1use std::any::Any;
2
3use rubato::{FftFixedInOut, Resampler as _};
4
5use crate::{
6    context::{AudioContextRegistration, BaseAudioContext},
7    render::{AudioParamValues, AudioProcessor, AudioRenderQuantum, AudioWorkletGlobalScope},
8    RENDER_QUANTUM_SIZE,
9};
10
11use super::{AudioNode, AudioNodeOptions, ChannelConfig};
12
/// Enumerates the oversampling rates available for `WaveShaperNode`
// the variant naming comes from the web audio specification
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum OverSampleType {
    /// No oversampling is applied
    #[default]
    None,
    /// Oversampled by a factor of 2
    X2,
    /// Oversampled by a factor of 4
    X4,
}
26
27impl From<u32> for OverSampleType {
28    fn from(i: u32) -> Self {
29        match i {
30            0 => OverSampleType::None,
31            1 => OverSampleType::X2,
32            2 => OverSampleType::X4,
33            _ => unreachable!(),
34        }
35    }
36}
37
/// `WaveShaperNode` options
// dictionary WaveShaperOptions : AudioNodeOptions {
//   sequence<float> curve;
//   OverSampleType oversample = "none";
// };
#[derive(Clone, Debug)]
pub struct WaveShaperOptions {
    /// The distortion curve (`None` leaves the signal unshaped)
    pub curve: Option<Vec<f32>>,
    /// Oversampling rate - defaults to `OverSampleType::None`
    pub oversample: OverSampleType,
    /// Channel configuration options shared by all audio nodes
    pub audio_node_options: AudioNodeOptions,
}
52
53impl Default for WaveShaperOptions {
54    fn default() -> Self {
55        Self {
56            oversample: OverSampleType::None,
57            curve: None,
58            audio_node_options: AudioNodeOptions::default(),
59        }
60    }
61}
62
/// `WaveShaperNode` applies a non-linear distortion effect to an audio
/// signal. Arbitrary non-linear shaping curves may be specified.
65///
66/// - MDN documentation: <https://developer.mozilla.org/en-US/docs/Web/API/WaveShaperNode>
67/// - specification: <https://webaudio.github.io/web-audio-api/#WaveShaperNode>
68/// - see also: [`BaseAudioContext::create_wave_shaper`]
69///
70/// # Usage
71///
72/// ```no_run
73/// use std::fs::File;
74/// use web_audio_api::context::{BaseAudioContext, AudioContext};
75/// use web_audio_api::node::{AudioNode, AudioScheduledSourceNode};
76///
77/// # use std::f32::consts::PI;
78/// # fn make_distortion_curve(size: usize) -> Vec<f32> {
79/// #     let mut curve = vec![0.; size];
80/// #     let mut phase = 0.;
81/// #     let phase_incr = PI / (size - 1) as f32;
82/// #     for i in 0..size {
83/// #         curve[i] = (PI + phase).cos();
84/// #         phase += phase_incr;
85/// #     }
86/// #     curve
87/// # }
88/// let context = AudioContext::default();
89///
90/// let file = File::open("sample.wav").unwrap();
91/// let buffer = context.decode_audio_data_sync(file).unwrap();
92/// let curve = make_distortion_curve(2048);
93/// let drive = 4.;
94///
95/// let post_gain = context.create_gain();
96/// post_gain.connect(&context.destination());
97/// post_gain.gain().set_value(1. / drive);
98///
99/// let mut shaper = context.create_wave_shaper();
100/// shaper.connect(&post_gain);
101/// shaper.set_curve(curve);
102///
103/// let pre_gain = context.create_gain();
104/// pre_gain.connect(&shaper);
105/// pre_gain.gain().set_value(drive);
106///
107/// let mut src = context.create_buffer_source();
108/// src.connect(&pre_gain);
109/// src.set_buffer(buffer);
110///
111/// src.start();
112/// ```
113///
114/// # Example
115///
116/// - `cargo run --release --example waveshaper`
#[derive(Debug)]
pub struct WaveShaperNode {
    /// Represents the node instance and its associated audio context
    registration: AudioContextRegistration,
    /// Information about this node's channel configuration
    channel_config: ChannelConfig,
    /// Distortion curve (`None` until one is assigned via `set_curve`)
    curve: Option<Vec<f32>>,
    /// Oversampling strategy applied around the curve by the renderer
    oversample: OverSampleType,
}
128
impl AudioNode for WaveShaperNode {
    fn registration(&self) -> &AudioContextRegistration {
        &self.registration
    }

    fn channel_config(&self) -> &ChannelConfig {
        &self.channel_config
    }

    // the wave shaper is a single-input node
    fn number_of_inputs(&self) -> usize {
        1
    }

    // the wave shaper is a single-output node
    fn number_of_outputs(&self) -> usize {
        1
    }
}
146
impl WaveShaperNode {
    /// returns a `WaveShaperNode` instance
    ///
    /// # Arguments
    ///
    /// * `context` - audio context in which the audio node will live.
    /// * `options` - waveshaper options
    pub fn new<C: BaseAudioContext>(context: &C, options: WaveShaperOptions) -> Self {
        let WaveShaperOptions {
            oversample,
            curve,
            audio_node_options: channel_config,
        } = options;

        let mut node = context.base().register(move |registration| {
            let sample_rate = context.sample_rate() as usize;

            let renderer = WaveShaperRenderer::new(RendererConfig {
                oversample,
                sample_rate,
            });

            // `curve` is deliberately left as `None` here; the options curve is
            // forwarded below through `set_curve` so that the control-thread copy
            // and the renderer copy are installed by the same code path
            let node = Self {
                registration,
                channel_config: channel_config.into(),
                curve: None,
                oversample,
            };

            (node, Box::new(renderer))
        });

        // renderer has been sent to the render thread, we can send it messages
        if let Some(curve) = curve {
            node.set_curve(curve);
        }

        node
    }

    /// Returns the distortion curve
    #[must_use]
    pub fn curve(&self) -> Option<&[f32]> {
        self.curve.as_deref()
    }

    /// Set the distortion `curve` of this node
    ///
    /// # Arguments
    ///
    /// * `curve` - the desired distortion `curve`
    ///
    /// # Panics
    ///
    /// Panics if a curve has already been given to the source (through `new` or through
    /// `set_curve`)
    pub fn set_curve(&mut self, curve: Vec<f32>) {
        assert!(
            self.curve.is_none(),
            "InvalidStateError - cannot assign curve twice",
        );

        // one copy stays on the control thread so `curve()` can return it,
        // the other is posted to the renderer on the render thread
        let clone = curve.clone();

        self.curve = Some(curve);
        self.registration.post_message(Some(clone));
    }

    /// Returns the `oversample` factor of this node
    #[must_use]
    pub fn oversample(&self) -> OverSampleType {
        self.oversample
    }

    /// Set the `oversample` factor of this node
    ///
    /// # Arguments
    ///
    /// * `oversample` - the desired `OverSampleType` variant
    pub fn set_oversample(&mut self, oversample: OverSampleType) {
        self.oversample = oversample;
        self.registration.post_message(oversample);
    }
}
231
/// Parameters used to build a `Resampler`
#[derive(Debug, Clone, PartialEq, Eq)]
struct ResamplerConfig {
    /// Number of audio channels processed per chunk
    channels: usize,
    /// Number of input frames consumed per `process` call
    chunk_size_in: usize,
    /// Input sample rate in Hz
    sample_rate_in: usize,
    /// Output sample rate in Hz
    sample_rate_out: usize,
}
239
240impl ResamplerConfig {
241    fn upsample_x2(channels: usize, sample_rate: usize) -> Self {
242        let chunk_size_in = RENDER_QUANTUM_SIZE;
243        let sample_rate_in = sample_rate;
244        let sample_rate_out = sample_rate * 2;
245        Self {
246            channels,
247            chunk_size_in,
248            sample_rate_in,
249            sample_rate_out,
250        }
251    }
252
253    fn upsample_x4(channels: usize, sample_rate: usize) -> Self {
254        let chunk_size_in = RENDER_QUANTUM_SIZE;
255        let sample_rate_in = sample_rate;
256        let sample_rate_out = sample_rate * 4;
257        Self {
258            channels,
259            chunk_size_in,
260            sample_rate_in,
261            sample_rate_out,
262        }
263    }
264
265    fn downsample_x2(channels: usize, sample_rate: usize) -> Self {
266        let chunk_size_in = RENDER_QUANTUM_SIZE * 2;
267        let sample_rate_in = sample_rate * 2;
268        let sample_rate_out = sample_rate;
269        Self {
270            channels,
271            chunk_size_in,
272            sample_rate_in,
273            sample_rate_out,
274        }
275    }
276
277    fn downsample_x4(channels: usize, sample_rate: usize) -> Self {
278        let chunk_size_in = RENDER_QUANTUM_SIZE * 4;
279        let sample_rate_in = sample_rate * 4;
280        let sample_rate_out = sample_rate;
281        Self {
282            channels,
283            chunk_size_in,
284            sample_rate_in,
285            sample_rate_out,
286        }
287    }
288}
289
/// Fixed-ratio FFT resampler together with its preallocated output storage
struct Resampler {
    /// Parameters this resampler was built from (kept for invariant checks)
    config: ResamplerConfig,
    /// Underlying rubato fixed input/output size FFT resampler
    processor: FftFixedInOut<f32>,
    /// Output buffer, one `Vec` per channel, reused across `process` calls
    samples_out: Vec<Vec<f32>>,
}
295
impl Resampler {
    /// Builds a resampler and preallocates its output buffer for the given config.
    fn new(config: ResamplerConfig) -> Self {
        let ResamplerConfig {
            channels,
            chunk_size_in,
            sample_rate_in,
            sample_rate_out,
        } = &config;

        // NOTE(review): `FftFixedInOut::new` is assumed to only fail for
        // degenerate parameters, which the configs built in this file avoid
        let processor =
            FftFixedInOut::new(*sample_rate_in, *sample_rate_out, *chunk_size_in, *channels)
                .unwrap();

        // allocate (and zero-fill) upfront so `process` never allocates on the
        // render thread
        let samples_out = processor.output_buffer_allocate(true);

        Self {
            config,
            processor,
            samples_out,
        }
    }

    /// Resamples one chunk of audio into the internal output buffer.
    ///
    /// `samples_in` must hold exactly `config.channels` channels, each exactly
    /// `input_frames_next()` frames long.
    fn process<T>(&mut self, samples_in: &[T])
    where
        T: AsRef<[f32]>,
    {
        debug_assert_eq!(self.config.channels, samples_in.len());
        // Processing the output from another resampler directly as input requires this assumption.
        debug_assert!(samples_in
            .iter()
            .all(|channel| channel.as_ref().len() == self.processor.input_frames_next()));
        let (in_len, out_len) = self
            .processor
            .process_into_buffer(samples_in, &mut self.samples_out[..], None)
            .unwrap();
        // All input samples must have been consumed.
        debug_assert_eq!(in_len, samples_in[0].as_ref().len());
        // All output samples must have been initialized.
        debug_assert!(self
            .samples_out
            .iter()
            .all(|channel| channel.len() == out_len));
    }

    /// Read-only view of the most recently produced output, one `Vec` per channel.
    fn samples_out(&self) -> &[Vec<f32>] {
        &self.samples_out[..]
    }

    /// Mutable view of the most recently produced output (used to apply the
    /// shaping curve in place between up- and downsampling).
    fn samples_out_mut(&mut self) -> &mut [Vec<f32>] {
        &mut self.samples_out[..]
    }
}
348
/// Helper struct which regroups all parameters
/// required to build `WaveShaperRenderer`
struct RendererConfig {
    /// oversample factor
    oversample: OverSampleType,
    /// Sample rate (equals the audio context sample rate)
    sample_rate: usize,
}
357
/// `WaveShaperRenderer` represents the rendering part of `WaveShaperNode`
struct WaveShaperRenderer {
    /// oversample factor
    oversample: OverSampleType,
    /// distortion curve
    curve: Option<Vec<f32>>,
    /// Sample rate (equals the audio context sample rate)
    sample_rate: usize,
    /// Number of channels used to build the up/down sampler X2
    channels_x2: usize,
    /// Number of channels used to build the up/down sampler X4
    channels_x4: usize,
    // up sampler configured to multiply by 2 the input signal
    upsampler_x2: Resampler,
    // up sampler configured to multiply by 4 the input signal
    upsampler_x4: Resampler,
    // down sampler configured to divide by 2 the upsampled signal
    downsampler_x2: Resampler,
    // down sampler configured to divide by 4 the upsampled signal
    downsampler_x4: Resampler,
    // check if silence can be propagated, i.e. if curve is None or if
    // its output value for a zero signal is zero (i.e. < 1e-9)
    can_propagate_silence: bool,
}
382
impl AudioProcessor for WaveShaperRenderer {
    /// Renders one quantum: applies the curve (with optional 2x/4x oversampling)
    /// to the input and writes the result to the output.
    fn process(
        &mut self,
        inputs: &[AudioRenderQuantum],
        outputs: &mut [AudioRenderQuantum],
        _params: AudioParamValues<'_>,
        _scope: &AudioWorkletGlobalScope,
    ) -> bool {
        // single input/output node
        let input = &inputs[0];
        let output = &mut outputs[0];

        // silence can only be short-circuited when the curve maps 0 to (near) 0,
        // see `onmessage` where `can_propagate_silence` is computed
        if input.is_silent() && self.can_propagate_silence {
            output.make_silent();
            return false;
        }

        // without a curve the node is a pass-through; the shaping below (if any)
        // is applied in place on this copy
        *output = input.clone();

        if let Some(curve) = &self.curve {
            match self.oversample {
                OverSampleType::None => {
                    // shape each sample directly at the context sample rate
                    output.modify_channels(|channel| {
                        channel.iter_mut().for_each(|o| *o = apply_curve(curve, *o));
                    });
                }
                OverSampleType::X2 => {
                    let channels = output.channels();

                    // recreate up/down sampler if number of channels changed
                    // (this also discards the resamplers' internal FFT state)
                    if channels.len() != self.channels_x2 {
                        self.channels_x2 = channels.len();

                        self.upsampler_x2 = Resampler::new(ResamplerConfig::upsample_x2(
                            self.channels_x2,
                            self.sample_rate,
                        ));

                        self.downsampler_x2 = Resampler::new(ResamplerConfig::downsample_x2(
                            self.channels_x2,
                            self.sample_rate,
                        ));
                    }

                    // upsample x2, shape at the higher rate, then downsample back
                    self.upsampler_x2.process(channels);
                    for channel in self.upsampler_x2.samples_out_mut().iter_mut() {
                        for s in channel.iter_mut() {
                            *s = apply_curve(curve, *s);
                        }
                    }

                    self.downsampler_x2.process(self.upsampler_x2.samples_out());

                    // copy the downsampled result back into the output quantum
                    for (processed, output) in self
                        .downsampler_x2
                        .samples_out()
                        .iter()
                        .zip(output.channels_mut())
                    {
                        output.copy_from_slice(&processed[..]);
                    }
                }
                OverSampleType::X4 => {
                    let channels = output.channels();

                    // recreate up/down sampler if number of channels changed
                    // (this also discards the resamplers' internal FFT state)
                    if channels.len() != self.channels_x4 {
                        self.channels_x4 = channels.len();

                        self.upsampler_x4 = Resampler::new(ResamplerConfig::upsample_x4(
                            self.channels_x4,
                            self.sample_rate,
                        ));

                        self.downsampler_x4 = Resampler::new(ResamplerConfig::downsample_x4(
                            self.channels_x4,
                            self.sample_rate,
                        ));
                    }

                    // upsample x4, shape at the higher rate, then downsample back
                    self.upsampler_x4.process(channels);

                    for channel in self.upsampler_x4.samples_out_mut().iter_mut() {
                        for s in channel.iter_mut() {
                            *s = apply_curve(curve, *s);
                        }
                    }

                    self.downsampler_x4.process(self.upsampler_x4.samples_out());

                    // copy the downsampled result back into the output quantum
                    for (processed, output) in self
                        .downsampler_x4
                        .samples_out()
                        .iter()
                        .zip(output.channels_mut())
                    {
                        output.copy_from_slice(&processed[..]);
                    }
                }
            }
        }

        // @tbc - rubato::FftFixedInOut doesn't seem to introduce any latency
        false
    }

    /// Handles control-thread messages: either a new oversample factor or a
    /// (one-shot) distortion curve.
    fn onmessage(&mut self, msg: &mut dyn Any) {
        if let Some(&oversample) = msg.downcast_ref::<OverSampleType>() {
            self.oversample = oversample;
            return;
        }

        if let Some(curve) = msg.downcast_mut::<Option<Vec<f32>>>() {
            // take ownership of the curve without allocating on the render thread
            std::mem::swap(&mut self.curve, curve);

            // We can propagate silent input only if the center of the curve is zero
            self.can_propagate_silence = if let Some(curve) = &self.curve {
                if curve.len() % 2 == 1 {
                    // odd length: the exact middle sample is the value for input 0
                    curve[curve.len() / 2].abs() < 1e-9
                } else {
                    // even length: input 0 falls between the two middle samples
                    let a = curve[curve.len() / 2 - 1];
                    let b = curve[curve.len() / 2];
                    ((a + b) / 2.).abs() < 1e-9
                }
            } else {
                true
            };

            return;
        }

        log::warn!("WaveShaperRenderer: Dropping incoming message {msg:?}");
    }
}
517
518impl WaveShaperRenderer {
519    /// returns an `WaveShaperRenderer` instance
520    #[allow(clippy::missing_const_for_fn)]
521    fn new(config: RendererConfig) -> Self {
522        let RendererConfig {
523            sample_rate,
524            oversample,
525        } = config;
526
527        let channels_x2 = 1;
528        let channels_x4 = 1;
529
530        let upsampler_x2 = Resampler::new(ResamplerConfig::upsample_x2(channels_x2, sample_rate));
531
532        let downsampler_x2 =
533            Resampler::new(ResamplerConfig::downsample_x2(channels_x2, sample_rate));
534
535        let upsampler_x4 = Resampler::new(ResamplerConfig::upsample_x4(channels_x2, sample_rate));
536
537        let downsampler_x4 =
538            Resampler::new(ResamplerConfig::downsample_x4(channels_x2, sample_rate));
539
540        Self {
541            oversample,
542            curve: None,
543            sample_rate,
544            channels_x2,
545            channels_x4,
546            upsampler_x2,
547            upsampler_x4,
548            downsampler_x2,
549            downsampler_x4,
550            can_propagate_silence: true,
551        }
552    }
553}
554
/// Maps `input` through the distortion `curve` with linear interpolation.
///
/// The input range [-1, 1] is projected onto the curve indices [0, len - 1];
/// inputs outside that range clamp to the first/last curve sample, and an
/// empty curve shapes everything to silence (0.).
#[inline]
fn apply_curve(curve: &[f32], input: f32) -> f32 {
    if curve.is_empty() {
        return 0.;
    }

    let n = curve.len() as f32;
    // fractional position of `input` along the curve
    let pos = (n - 1.) / 2.0 * (input + 1.);

    // clamp below the first point
    if pos <= 0. {
        return curve[0];
    }
    // clamp above the last point
    if pos >= n - 1. {
        return curve[(n - 1.) as usize];
    }

    // linear interpolation between the two surrounding curve points
    let lower = pos.floor();
    let frac = pos - lower;
    let index = lower as usize;
    (1. - frac) * curve[index] + frac * curve[index + 1]
}
574
#[cfg(test)]
mod tests {
    use float_eq::assert_float_eq;

    use crate::context::OfflineAudioContext;
    use crate::node::AudioScheduledSourceNode;

    use super::*;

    // arbitrary non-quantum-aligned render length
    const LENGTH: usize = 555;

    #[test]
    fn build_with_new() {
        let context = OfflineAudioContext::new(2, LENGTH, 44_100.);
        let _shaper = WaveShaperNode::new(&context, WaveShaperOptions::default());
    }

    #[test]
    fn build_with_factory_func() {
        let context = OfflineAudioContext::new(2, LENGTH, 44_100.);
        let _shaper = context.create_wave_shaper();
    }

    #[test]
    fn test_default_options() {
        let context = OfflineAudioContext::new(2, LENGTH, 44_100.);
        let shaper = WaveShaperNode::new(&context, WaveShaperOptions::default());

        assert_eq!(shaper.curve(), None);
        assert_eq!(shaper.oversample(), OverSampleType::None);
    }

    #[test]
    fn test_user_defined_options() {
        let mut context = OfflineAudioContext::new(2, LENGTH, 44_100.);

        let options = WaveShaperOptions {
            curve: Some(vec![1.0]),
            oversample: OverSampleType::X2,
            ..Default::default()
        };

        let shaper = WaveShaperNode::new(&context, options);

        let _ = context.start_rendering_sync();

        assert_eq!(shaper.curve(), Some(&[1.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X2);
    }

    #[test]
    #[should_panic]
    fn change_a_curve_for_another_curve_should_panic() {
        let mut context = OfflineAudioContext::new(2, LENGTH, 44_100.);

        let options = WaveShaperOptions {
            curve: Some(vec![1.0]),
            oversample: OverSampleType::X2,
            ..Default::default()
        };

        let mut shaper = WaveShaperNode::new(&context, options);
        assert_eq!(shaper.curve(), Some(&[1.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X2);

        // a curve was already provided through the options, so this must panic
        shaper.set_curve(vec![2.0]);

        // NOTE: everything below is unreachable — `set_curve` panics above;
        // kept for symmetry with `change_none_for_curve_after_build`
        shaper.set_oversample(OverSampleType::X4);

        let _ = context.start_rendering_sync();

        assert_eq!(shaper.curve(), Some(&[2.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X4);
    }

    #[test]
    fn change_none_for_curve_after_build() {
        let mut context = OfflineAudioContext::new(2, LENGTH, 44_100.);

        let options = WaveShaperOptions {
            curve: None,
            oversample: OverSampleType::X2,
            ..Default::default()
        };

        let mut shaper = WaveShaperNode::new(&context, options);
        assert_eq!(shaper.curve(), None);
        assert_eq!(shaper.oversample(), OverSampleType::X2);

        // assigning a first curve after construction is allowed
        shaper.set_curve(vec![2.0]);
        shaper.set_oversample(OverSampleType::X4);

        let _ = context.start_rendering_sync();

        assert_eq!(shaper.curve(), Some(&[2.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X4);
    }

    // feeds one quantum of -1., one of 0. and one of 1. through the curve
    // [-0.5, 0., 0.5] and checks the clamped boundary values
    #[test]
    fn test_shape_boundaries() {
        let sample_rate = 44100.;
        let mut context = OfflineAudioContext::new(1, 3 * RENDER_QUANTUM_SIZE, sample_rate);

        let mut shaper = context.create_wave_shaper();
        let curve = vec![-0.5, 0., 0.5];
        shaper.set_curve(curve);
        shaper.connect(&context.destination());

        let mut data = vec![0.; 3 * RENDER_QUANTUM_SIZE];
        let mut expected = vec![0.; 3 * RENDER_QUANTUM_SIZE];
        for i in 0..(3 * RENDER_QUANTUM_SIZE) {
            if i < RENDER_QUANTUM_SIZE {
                data[i] = -1.;
                expected[i] = -0.5;
            } else if i < 2 * RENDER_QUANTUM_SIZE {
                data[i] = 0.;
                expected[i] = 0.;
            } else {
                data[i] = 1.;
                expected[i] = 0.5;
            }
        }
        let mut buffer = context.create_buffer(1, 3 * RENDER_QUANTUM_SIZE, sample_rate);
        buffer.copy_to_channel(&data, 0);

        let mut src = context.create_buffer_source();
        src.connect(&shaper);
        src.set_buffer(buffer);
        src.start_at(0.);

        let result = context.start_rendering_sync();
        let channel = result.get_channel_data(0);

        assert_float_eq!(channel[..], expected[..], abs_all <= 0.);
    }

    // feeds a ramp from -1. to 1. through the curve [-0.5, 0., 0.5] and checks
    // the linear interpolation (output should be input / 2)
    #[test]
    fn test_shape_interpolation() {
        let sample_rate = 44100.;
        let mut context = OfflineAudioContext::new(1, RENDER_QUANTUM_SIZE, sample_rate);

        let mut shaper = context.create_wave_shaper();
        let curve = vec![-0.5, 0., 0.5];
        shaper.set_curve(curve);
        shaper.connect(&context.destination());

        let mut data = vec![0.; RENDER_QUANTUM_SIZE];
        let mut expected = vec![0.; RENDER_QUANTUM_SIZE];

        for i in 0..RENDER_QUANTUM_SIZE {
            let sample = i as f32 / (RENDER_QUANTUM_SIZE as f32) * 2. - 1.;
            data[i] = sample;
            expected[i] = sample / 2.;
        }

        // fix: buffer length was `3 * RENDER_QUANTUM_SIZE`, a leftover from
        // `test_shape_boundaries`; only one quantum is rendered and asserted
        let mut buffer = context.create_buffer(1, RENDER_QUANTUM_SIZE, sample_rate);
        buffer.copy_to_channel(&data, 0);

        let mut src = context.create_buffer_source();
        src.connect(&shaper);
        src.set_buffer(buffer);
        src.start_at(0.);

        let result = context.start_rendering_sync();
        let channel = result.get_channel_data(0);

        assert_float_eq!(channel[..], expected[..], abs_all <= 0.);
    }
}