web_audio_api/node/waveshaper.rs

use std::any::Any;

use rubato::{FftFixedInOut, Resampler as _};

use crate::{
    context::{AudioContextRegistration, BaseAudioContext},
    render::{AudioParamValues, AudioProcessor, AudioRenderQuantum, AudioWorkletGlobalScope},
    RENDER_QUANTUM_SIZE,
};

use super::{AudioNode, AudioNodeOptions, ChannelConfig};

/// Enumerates the oversampling rates available for `WaveShaperNode`
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
// the naming comes from the web audio specification
pub enum OverSampleType {
    /// No oversampling is applied
    None,
    /// Oversampled by a factor of 2
    X2,
    /// Oversampled by a factor of 4
    X4,
}

impl Default for OverSampleType {
    fn default() -> Self {
        Self::None
    }
}

impl From<u32> for OverSampleType {
    fn from(i: u32) -> Self {
        match i {
            0 => OverSampleType::None,
            1 => OverSampleType::X2,
            2 => OverSampleType::X4,
            _ => unreachable!(),
        }
    }
}

/// `WaveShaperNode` options
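///
/// A minimal construction sketch (assuming `WaveShaperOptions`, `OverSampleType`
/// and `WaveShaperNode` are re-exported from the `node` module):
///
/// ```no_run
/// use web_audio_api::context::AudioContext;
/// use web_audio_api::node::{OverSampleType, WaveShaperNode, WaveShaperOptions};
///
/// let context = AudioContext::default();
/// let options = WaveShaperOptions {
///     curve: Some(vec![-1., 0., 1.]),
///     oversample: OverSampleType::X2,
///     ..WaveShaperOptions::default()
/// };
/// let _shaper = WaveShaperNode::new(&context, options);
/// ```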
// dictionary WaveShaperOptions : AudioNodeOptions {
//   sequence<float> curve;
//   OverSampleType oversample = "none";
// };
#[derive(Clone, Debug)]
pub struct WaveShaperOptions {
    /// The distortion curve
    pub curve: Option<Vec<f32>>,
    /// Oversampling rate - defaults to `None`
    pub oversample: OverSampleType,
    /// Audio node options
    pub audio_node_options: AudioNodeOptions,
}

impl Default for WaveShaperOptions {
    fn default() -> Self {
        Self {
            oversample: OverSampleType::None,
            curve: None,
            audio_node_options: AudioNodeOptions::default(),
        }
    }
}

/// `WaveShaperNode` applies a non-linear distortion effect to an audio
/// signal. Arbitrary non-linear shaping curves may be specified.
///
/// - MDN documentation: <https://developer.mozilla.org/en-US/docs/Web/API/WaveShaperNode>
/// - specification: <https://webaudio.github.io/web-audio-api/#WaveShaperNode>
/// - see also: [`BaseAudioContext::create_wave_shaper`]
///
/// # Usage
///
/// ```no_run
/// use std::fs::File;
/// use web_audio_api::context::{BaseAudioContext, AudioContext};
/// use web_audio_api::node::{AudioNode, AudioScheduledSourceNode};
///
/// # use std::f32::consts::PI;
/// # fn make_distortion_curve(size: usize) -> Vec<f32> {
/// #     let mut curve = vec![0.; size];
/// #     let mut phase = 0.;
/// #     let phase_incr = PI / (size - 1) as f32;
/// #     for i in 0..size {
/// #         curve[i] = (PI + phase).cos();
/// #         phase += phase_incr;
/// #     }
/// #     curve
/// # }
/// let context = AudioContext::default();
///
/// let file = File::open("sample.wav").unwrap();
/// let buffer = context.decode_audio_data_sync(file).unwrap();
/// let curve = make_distortion_curve(2048);
/// let drive = 4.;
///
/// let post_gain = context.create_gain();
/// post_gain.connect(&context.destination());
/// post_gain.gain().set_value(1. / drive);
///
/// let mut shaper = context.create_wave_shaper();
/// shaper.connect(&post_gain);
/// shaper.set_curve(curve);
///
/// let pre_gain = context.create_gain();
/// pre_gain.connect(&shaper);
/// pre_gain.gain().set_value(drive);
///
/// let mut src = context.create_buffer_source();
/// src.connect(&pre_gain);
/// src.set_buffer(buffer);
///
/// src.start();
/// ```
///
/// # Example
///
/// - `cargo run --release --example waveshaper`
#[derive(Debug)]
pub struct WaveShaperNode {
    /// Represents the node instance and its associated audio context
    registration: AudioContextRegistration,
    /// Information about the audio node channel configuration
    channel_config: ChannelConfig,
    /// distortion curve
    curve: Option<Vec<f32>>,
    /// oversample type
    oversample: OverSampleType,
}

impl AudioNode for WaveShaperNode {
    fn registration(&self) -> &AudioContextRegistration {
        &self.registration
    }

    fn channel_config(&self) -> &ChannelConfig {
        &self.channel_config
    }

    fn number_of_inputs(&self) -> usize {
        1
    }

    fn number_of_outputs(&self) -> usize {
        1
    }
}

impl WaveShaperNode {
    /// returns a `WaveShaperNode` instance
    ///
    /// # Arguments
    ///
    /// * `context` - audio context in which the audio node will live
    /// * `options` - waveshaper options
    pub fn new<C: BaseAudioContext>(context: &C, options: WaveShaperOptions) -> Self {
        let WaveShaperOptions {
            oversample,
            curve,
            audio_node_options: channel_config,
        } = options;

        let mut node = context.base().register(move |registration| {
            let sample_rate = context.sample_rate() as usize;

            let renderer = WaveShaperRenderer::new(RendererConfig {
                oversample,
                sample_rate,
            });

            let node = Self {
                registration,
                channel_config: channel_config.into(),
                curve: None,
                oversample,
            };

            (node, Box::new(renderer))
        });

        // the renderer has been sent to the render thread, we can send it messages
        if let Some(curve) = curve {
            node.set_curve(curve);
        }

        node
    }

    /// Returns the distortion curve
    #[must_use]
    pub fn curve(&self) -> Option<&[f32]> {
        self.curve.as_deref()
    }

    /// Set the distortion `curve` of this node
    ///
    /// # Arguments
    ///
    /// * `curve` - the desired distortion `curve`
    ///
    /// # Panics
    ///
    /// Panics if a curve has already been given to the node (through `new` or through
    /// `set_curve`)
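    ///
    /// # Example
    ///
    /// A minimal sketch (the curve values here are arbitrary):
    ///
    /// ```no_run
    /// use web_audio_api::context::{AudioContext, BaseAudioContext};
    ///
    /// let context = AudioContext::default();
    /// let mut shaper = context.create_wave_shaper();
    /// // a 3-point curve: clamps the output into [-0.5, 0.5] and halves values in between
    /// shaper.set_curve(vec![-0.5, 0., 0.5]);
    /// ```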
    pub fn set_curve(&mut self, curve: Vec<f32>) {
        assert!(
            self.curve.is_none(),
            "InvalidStateError - cannot assign curve twice",
        );

        let clone = curve.clone();

        self.curve = Some(curve);
        self.registration.post_message(Some(clone));
    }

    /// Returns the `oversample` factor of this node
    #[must_use]
    pub fn oversample(&self) -> OverSampleType {
        self.oversample
    }

    /// Set the `oversample` factor of this node
    ///
    /// # Arguments
    ///
    /// * `oversample` - the desired `OverSampleType` variant
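    ///
    /// # Example
    ///
    /// A minimal sketch (assuming `OverSampleType` is re-exported from the `node` module):
    ///
    /// ```no_run
    /// use web_audio_api::context::{AudioContext, BaseAudioContext};
    /// use web_audio_api::node::OverSampleType;
    ///
    /// let context = AudioContext::default();
    /// let mut shaper = context.create_wave_shaper();
    /// shaper.set_oversample(OverSampleType::X4);
    /// ```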
    pub fn set_oversample(&mut self, oversample: OverSampleType) {
        self.oversample = oversample;
        self.registration.post_message(oversample);
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
struct ResamplerConfig {
    channels: usize,
    chunk_size_in: usize,
    sample_rate_in: usize,
    sample_rate_out: usize,
}

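// Presets for the four resampling stages used by X2/X4 oversampling: the
// upsamplers consume one render quantum per call, while the downsamplers consume
// the matching oversampled block (2 or 4 render quanta) and emit one render quantum.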
impl ResamplerConfig {
    fn upsample_x2(channels: usize, sample_rate: usize) -> Self {
        let chunk_size_in = RENDER_QUANTUM_SIZE;
        let sample_rate_in = sample_rate;
        let sample_rate_out = sample_rate * 2;
        Self {
            channels,
            chunk_size_in,
            sample_rate_in,
            sample_rate_out,
        }
    }

    fn upsample_x4(channels: usize, sample_rate: usize) -> Self {
        let chunk_size_in = RENDER_QUANTUM_SIZE;
        let sample_rate_in = sample_rate;
        let sample_rate_out = sample_rate * 4;
        Self {
            channels,
            chunk_size_in,
            sample_rate_in,
            sample_rate_out,
        }
    }

    fn downsample_x2(channels: usize, sample_rate: usize) -> Self {
        let chunk_size_in = RENDER_QUANTUM_SIZE * 2;
        let sample_rate_in = sample_rate * 2;
        let sample_rate_out = sample_rate;
        Self {
            channels,
            chunk_size_in,
            sample_rate_in,
            sample_rate_out,
        }
    }

    fn downsample_x4(channels: usize, sample_rate: usize) -> Self {
        let chunk_size_in = RENDER_QUANTUM_SIZE * 4;
        let sample_rate_in = sample_rate * 4;
        let sample_rate_out = sample_rate;
        Self {
            channels,
            chunk_size_in,
            sample_rate_in,
            sample_rate_out,
        }
    }
}

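/// Thin wrapper around `rubato::FftFixedInOut` bundling its configuration and a
/// preallocated output buffer, so repeated calls to `process` reuse the same storage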
struct Resampler {
    config: ResamplerConfig,
    processor: FftFixedInOut<f32>,
    samples_out: Vec<Vec<f32>>,
}

impl Resampler {
    fn new(config: ResamplerConfig) -> Self {
        let ResamplerConfig {
            channels,
            chunk_size_in,
            sample_rate_in,
            sample_rate_out,
        } = &config;

        let processor =
            FftFixedInOut::new(*sample_rate_in, *sample_rate_out, *chunk_size_in, *channels)
                .unwrap();

        let samples_out = processor.output_buffer_allocate(true);

        Self {
            config,
            processor,
            samples_out,
        }
    }

    fn process<T>(&mut self, samples_in: &[T])
    where
        T: AsRef<[f32]>,
    {
        debug_assert_eq!(self.config.channels, samples_in.len());
        // Processing the output from another resampler directly as input requires this assumption.
        debug_assert!(samples_in
            .iter()
            .all(|channel| channel.as_ref().len() == self.processor.input_frames_next()));
        let (in_len, out_len) = self
            .processor
            .process_into_buffer(samples_in, &mut self.samples_out[..], None)
            .unwrap();
        // All input samples must have been consumed.
        debug_assert_eq!(in_len, samples_in[0].as_ref().len());
        // All output samples must have been initialized.
        debug_assert!(self
            .samples_out
            .iter()
            .all(|channel| channel.len() == out_len));
    }

    fn samples_out(&self) -> &[Vec<f32>] {
        &self.samples_out[..]
    }

    fn samples_out_mut(&mut self) -> &mut [Vec<f32>] {
        &mut self.samples_out[..]
    }
}

/// Helper struct which gathers all parameters
/// required to build a `WaveShaperRenderer`
struct RendererConfig {
    /// oversample factor
    oversample: OverSampleType,
    /// Sample rate (equal to the audio context sample rate)
    sample_rate: usize,
}

/// `WaveShaperRenderer` represents the rendering part of `WaveShaperNode`
struct WaveShaperRenderer {
    /// oversample factor
    oversample: OverSampleType,
    /// distortion curve
    curve: Option<Vec<f32>>,
    /// Sample rate (equal to the audio context sample rate)
    sample_rate: usize,
    /// Number of channels used to build the up/down sampler X2
    channels_x2: usize,
    /// Number of channels used to build the up/down sampler X4
    channels_x4: usize,
    // upsampler configured to multiply the input signal rate by 2
    upsampler_x2: Resampler,
    // upsampler configured to multiply the input signal rate by 4
    upsampler_x4: Resampler,
    // downsampler configured to divide the upsampled signal rate by 2
    downsampler_x2: Resampler,
    // downsampler configured to divide the upsampled signal rate by 4
    downsampler_x4: Resampler,
    // whether silence can be propagated, i.e. if the curve is None or if
    // its output value for a zero signal is zero (i.e. < 1e-9)
    can_propagate_silence: bool,
}

impl AudioProcessor for WaveShaperRenderer {
    fn process(
        &mut self,
        inputs: &[AudioRenderQuantum],
        outputs: &mut [AudioRenderQuantum],
        _params: AudioParamValues<'_>,
        _scope: &AudioWorkletGlobalScope,
    ) -> bool {
        // single input/output node
        let input = &inputs[0];
        let output = &mut outputs[0];

        if input.is_silent() && self.can_propagate_silence {
            output.make_silent();
            return false;
        }

        *output = input.clone();

        if let Some(curve) = &self.curve {
            match self.oversample {
                OverSampleType::None => {
                    output.modify_channels(|channel| {
                        channel.iter_mut().for_each(|o| *o = apply_curve(curve, *o));
                    });
                }
                OverSampleType::X2 => {
                    let channels = output.channels();

                    // recreate up/down sampler if number of channels changed
                    if channels.len() != self.channels_x2 {
                        self.channels_x2 = channels.len();

                        self.upsampler_x2 = Resampler::new(ResamplerConfig::upsample_x2(
                            self.channels_x2,
                            self.sample_rate,
                        ));

                        self.downsampler_x2 = Resampler::new(ResamplerConfig::downsample_x2(
                            self.channels_x2,
                            self.sample_rate,
                        ));
                    }

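                    // oversampling chain: upsample to twice the context rate,
                    // apply the curve on the oversampled signal, then downsample
                    // back into the output buffer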
                    self.upsampler_x2.process(channels);
                    for channel in self.upsampler_x2.samples_out_mut().iter_mut() {
                        for s in channel.iter_mut() {
                            *s = apply_curve(curve, *s);
                        }
                    }

                    self.downsampler_x2.process(self.upsampler_x2.samples_out());

                    for (processed, output) in self
                        .downsampler_x2
                        .samples_out()
                        .iter()
                        .zip(output.channels_mut())
                    {
                        output.copy_from_slice(&processed[..]);
                    }
                }
                OverSampleType::X4 => {
                    let channels = output.channels();

                    // recreate up/down sampler if number of channels changed
                    if channels.len() != self.channels_x4 {
                        self.channels_x4 = channels.len();

                        self.upsampler_x4 = Resampler::new(ResamplerConfig::upsample_x4(
                            self.channels_x4,
                            self.sample_rate,
                        ));

                        self.downsampler_x4 = Resampler::new(ResamplerConfig::downsample_x4(
                            self.channels_x4,
                            self.sample_rate,
                        ));
                    }

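                    // same chain as for X2, but at four times the context rate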
                    self.upsampler_x4.process(channels);

                    for channel in self.upsampler_x4.samples_out_mut().iter_mut() {
                        for s in channel.iter_mut() {
                            *s = apply_curve(curve, *s);
                        }
                    }

                    self.downsampler_x4.process(self.upsampler_x4.samples_out());

                    for (processed, output) in self
                        .downsampler_x4
                        .samples_out()
                        .iter()
                        .zip(output.channels_mut())
                    {
                        output.copy_from_slice(&processed[..]);
                    }
                }
            }
        }

        // @tbc - rubato::FftFixedInOut doesn't seem to introduce any latency
        false
    }

    fn onmessage(&mut self, msg: &mut dyn Any) {
        if let Some(&oversample) = msg.downcast_ref::<OverSampleType>() {
            self.oversample = oversample;
            return;
        }

        if let Some(curve) = msg.downcast_mut::<Option<Vec<f32>>>() {
            std::mem::swap(&mut self.curve, curve);

            // We can propagate silent input only if the center of the curve is zero
            self.can_propagate_silence = if let Some(curve) = &self.curve {
                if curve.len() % 2 == 1 {
                    curve[curve.len() / 2].abs() < 1e-9
                } else {
                    let a = curve[curve.len() / 2 - 1];
                    let b = curve[curve.len() / 2];
                    ((a + b) / 2.).abs() < 1e-9
                }
            } else {
                true
            };

            return;
        }

        log::warn!("WaveShaperRenderer: Dropping incoming message {msg:?}");
    }
}

impl WaveShaperRenderer {
    /// returns a `WaveShaperRenderer` instance
    #[allow(clippy::missing_const_for_fn)]
    fn new(config: RendererConfig) -> Self {
        let RendererConfig {
            sample_rate,
            oversample,
        } = config;

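        // resamplers are initially built for a single channel; `process` recreates
        // them whenever the actual channel count of the input differs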
        let channels_x2 = 1;
        let channels_x4 = 1;

        let upsampler_x2 = Resampler::new(ResamplerConfig::upsample_x2(channels_x2, sample_rate));

        let downsampler_x2 =
            Resampler::new(ResamplerConfig::downsample_x2(channels_x2, sample_rate));

        let upsampler_x4 = Resampler::new(ResamplerConfig::upsample_x4(channels_x4, sample_rate));

        let downsampler_x4 =
            Resampler::new(ResamplerConfig::downsample_x4(channels_x4, sample_rate));

        Self {
            oversample,
            curve: None,
            sample_rate,
            channels_x2,
            channels_x4,
            upsampler_x2,
            upsampler_x4,
            downsampler_x2,
            downsampler_x4,
            can_propagate_silence: true,
        }
    }
}

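/// Apply the distortion `curve` to a single `input` sample.
///
/// Following the WaveShaper curve indexing, the input (nominally in [-1, 1]) is
/// mapped onto the fractional index `v = (N - 1) / 2 * (input + 1)`, where `N` is
/// the curve length; out-of-range values are clamped to the curve end points and
/// in-between values are linearly interpolated. An empty curve outputs silence.
/// E.g. for `curve = [-0.5, 0., 0.5]` and `input = 0.5`, `v = 1.5` and the result
/// is `0.5 * curve[1] + 0.5 * curve[2] = 0.25`.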
#[inline]
fn apply_curve(curve: &[f32], input: f32) -> f32 {
    if curve.is_empty() {
        return 0.;
    }

    let n = curve.len() as f32;
    let v = (n - 1.) / 2.0 * (input + 1.);

    if v <= 0. {
        curve[0]
    } else if v >= n - 1. {
        curve[(n - 1.) as usize]
    } else {
        let k = v.floor();
        let f = v - k;
        (1. - f) * curve[k as usize] + f * curve[(k + 1.) as usize]
    }
}

#[cfg(test)]
mod tests {
    use float_eq::assert_float_eq;

    use crate::context::OfflineAudioContext;
    use crate::node::AudioScheduledSourceNode;

    use super::*;

    const LENGTH: usize = 555;

    #[test]
    fn build_with_new() {
        let context = OfflineAudioContext::new(2, LENGTH, 44_100.);
        let _shaper = WaveShaperNode::new(&context, WaveShaperOptions::default());
    }

    #[test]
    fn build_with_factory_func() {
        let context = OfflineAudioContext::new(2, LENGTH, 44_100.);
        let _shaper = context.create_wave_shaper();
    }

    #[test]
    fn test_default_options() {
        let context = OfflineAudioContext::new(2, LENGTH, 44_100.);
        let shaper = WaveShaperNode::new(&context, WaveShaperOptions::default());

        assert_eq!(shaper.curve(), None);
        assert_eq!(shaper.oversample(), OverSampleType::None);
    }

    #[test]
    fn test_user_defined_options() {
        let mut context = OfflineAudioContext::new(2, LENGTH, 44_100.);

        let options = WaveShaperOptions {
            curve: Some(vec![1.0]),
            oversample: OverSampleType::X2,
            ..Default::default()
        };

        let shaper = WaveShaperNode::new(&context, options);

        let _ = context.start_rendering_sync();

        assert_eq!(shaper.curve(), Some(&[1.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X2);
    }

    #[test]
    #[should_panic]
    fn change_a_curve_for_another_curve_should_panic() {
        let mut context = OfflineAudioContext::new(2, LENGTH, 44_100.);

        let options = WaveShaperOptions {
            curve: Some(vec![1.0]),
            oversample: OverSampleType::X2,
            ..Default::default()
        };

        let mut shaper = WaveShaperNode::new(&context, options);
        assert_eq!(shaper.curve(), Some(&[1.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X2);

        shaper.set_curve(vec![2.0]);
        shaper.set_oversample(OverSampleType::X4);

        let _ = context.start_rendering_sync();

        assert_eq!(shaper.curve(), Some(&[2.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X4);
    }

    #[test]
    fn change_none_for_curve_after_build() {
        let mut context = OfflineAudioContext::new(2, LENGTH, 44_100.);

        let options = WaveShaperOptions {
            curve: None,
            oversample: OverSampleType::X2,
            ..Default::default()
        };

        let mut shaper = WaveShaperNode::new(&context, options);
        assert_eq!(shaper.curve(), None);
        assert_eq!(shaper.oversample(), OverSampleType::X2);

        shaper.set_curve(vec![2.0]);
        shaper.set_oversample(OverSampleType::X4);

        let _ = context.start_rendering_sync();

        assert_eq!(shaper.curve(), Some(&[2.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X4);
    }

    #[test]
    fn test_shape_boundaries() {
        let sample_rate = 44100.;
        let mut context = OfflineAudioContext::new(1, 3 * RENDER_QUANTUM_SIZE, sample_rate);

        let mut shaper = context.create_wave_shaper();
        let curve = vec![-0.5, 0., 0.5];
        shaper.set_curve(curve);
        shaper.connect(&context.destination());

        let mut data = vec![0.; 3 * RENDER_QUANTUM_SIZE];
        let mut expected = vec![0.; 3 * RENDER_QUANTUM_SIZE];
        for i in 0..(3 * RENDER_QUANTUM_SIZE) {
            if i < RENDER_QUANTUM_SIZE {
                data[i] = -1.;
                expected[i] = -0.5;
            } else if i < 2 * RENDER_QUANTUM_SIZE {
                data[i] = 0.;
                expected[i] = 0.;
            } else {
                data[i] = 1.;
                expected[i] = 0.5;
            }
        }
        let mut buffer = context.create_buffer(1, 3 * RENDER_QUANTUM_SIZE, sample_rate);
        buffer.copy_to_channel(&data, 0);

        let mut src = context.create_buffer_source();
        src.connect(&shaper);
        src.set_buffer(buffer);
        src.start_at(0.);

        let result = context.start_rendering_sync();
        let channel = result.get_channel_data(0);

        assert_float_eq!(channel[..], expected[..], abs_all <= 0.);
    }

    #[test]
    fn test_shape_interpolation() {
        let sample_rate = 44100.;
        let mut context = OfflineAudioContext::new(1, RENDER_QUANTUM_SIZE, sample_rate);

        let mut shaper = context.create_wave_shaper();
        let curve = vec![-0.5, 0., 0.5];
        shaper.set_curve(curve);
        shaper.connect(&context.destination());

        let mut data = vec![0.; RENDER_QUANTUM_SIZE];
        let mut expected = vec![0.; RENDER_QUANTUM_SIZE];

        for i in 0..RENDER_QUANTUM_SIZE {
            let sample = i as f32 / (RENDER_QUANTUM_SIZE as f32) * 2. - 1.;
            data[i] = sample;
            expected[i] = sample / 2.;
        }

        let mut buffer = context.create_buffer(1, 3 * RENDER_QUANTUM_SIZE, sample_rate);
        buffer.copy_to_channel(&data, 0);

        let mut src = context.create_buffer_source();
        src.connect(&shaper);
        src.set_buffer(buffer);
        src.start_at(0.);

        let result = context.start_rendering_sync();
        let channel = result.get_channel_data(0);

        assert_float_eq!(channel[..], expected[..], abs_all <= 0.);
    }
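
    // Additional sketch of a direct unit test for the private `apply_curve`
    // helper, exercising end-point clamping and linear interpolation as derived
    // from the indexing formula documented above.
    #[test]
    fn test_apply_curve_clamps_and_interpolates() {
        let curve = [-0.5, 0., 0.5];

        // inputs outside [-1, 1] are clamped to the curve end points
        assert_float_eq!(apply_curve(&curve, -2.), -0.5, abs <= 0.);
        assert_float_eq!(apply_curve(&curve, 2.), 0.5, abs <= 0.);
        // in-between values are linearly interpolated
        assert_float_eq!(apply_curve(&curve, 0.), 0., abs <= 0.);
        assert_float_eq!(apply_curve(&curve, 0.5), 0.25, abs <= 0.);
        // an empty curve outputs silence
        assert_float_eq!(apply_curve(&[], 1.), 0., abs <= 0.);
    }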
}