1use std::any::Any;
2
3use rubato::{FftFixedInOut, Resampler as _};
4
5use crate::{
6    context::{AudioContextRegistration, BaseAudioContext},
7    render::{AudioParamValues, AudioProcessor, AudioRenderQuantum, AudioWorkletGlobalScope},
8    RENDER_QUANTUM_SIZE,
9};
10
11use super::{AudioNode, AudioNodeOptions, ChannelConfig};
12
/// Oversampling mode of a [`WaveShaperNode`].
///
/// With `X2`/`X4` the input is upsampled by the given factor, shaped by the
/// curve at the higher rate, and downsampled back to the context rate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OverSampleType {
    /// The curve is applied directly at the context sample rate.
    None,
    /// The curve is applied at twice the context sample rate.
    X2,
    /// The curve is applied at four times the context sample rate.
    X4,
}

impl Default for OverSampleType {
    /// Defaults to [`OverSampleType::None`].
    fn default() -> Self {
        Self::None
    }
}

impl From<u32> for OverSampleType {
    /// Maps `0`, `1` and `2` to `None`, `X2` and `X4` respectively.
    ///
    /// # Panics
    ///
    /// Panics for any other value. This conversion is public, so an unknown
    /// value is a caller error rather than unreachable code; the message names
    /// the offending value to make it diagnosable.
    fn from(i: u32) -> Self {
        match i {
            0 => Self::None,
            1 => Self::X2,
            2 => Self::X4,
            other => panic!("invalid OverSampleType discriminant: {other}"),
        }
    }
}
41
42#[derive(Clone, Debug)]
48pub struct WaveShaperOptions {
49    pub curve: Option<Vec<f32>>,
51    pub oversample: OverSampleType,
53    pub audio_node_options: AudioNodeOptions,
55}
56
57impl Default for WaveShaperOptions {
58    fn default() -> Self {
59        Self {
60            oversample: OverSampleType::None,
61            curve: None,
62            audio_node_options: AudioNodeOptions::default(),
63        }
64    }
65}
66
/// Audio node that maps each input sample through a shaping curve, optionally
/// at an oversampled rate (see [`OverSampleType`]).
///
/// The node keeps its own copy of the curve and oversampling mode for the
/// getters; changes are forwarded to the renderer with `post_message`.
#[derive(Debug)]
pub struct WaveShaperNode {
    /// Registration with the context; also used to post messages to the renderer.
    registration: AudioContextRegistration,
    /// Channel configuration of the node.
    channel_config: ChannelConfig,
    /// Node-side copy of the curve, returned by `curve()`; `None` until assigned.
    curve: Option<Vec<f32>>,
    /// Node-side copy of the oversampling mode, returned by `oversample()`.
    oversample: OverSampleType,
}
132
133impl AudioNode for WaveShaperNode {
134    fn registration(&self) -> &AudioContextRegistration {
135        &self.registration
136    }
137
138    fn channel_config(&self) -> &ChannelConfig {
139        &self.channel_config
140    }
141
142    fn number_of_inputs(&self) -> usize {
143        1
144    }
145
146    fn number_of_outputs(&self) -> usize {
147        1
148    }
149}
150
151impl WaveShaperNode {
152    pub fn new<C: BaseAudioContext>(context: &C, options: WaveShaperOptions) -> Self {
159        let WaveShaperOptions {
160            oversample,
161            curve,
162            audio_node_options: channel_config,
163        } = options;
164
165        let mut node = context.base().register(move |registration| {
166            let sample_rate = context.sample_rate() as usize;
167
168            let renderer = WaveShaperRenderer::new(RendererConfig {
169                oversample,
170                sample_rate,
171            });
172
173            let node = Self {
174                registration,
175                channel_config: channel_config.into(),
176                curve: None,
177                oversample,
178            };
179
180            (node, Box::new(renderer))
181        });
182
183        if let Some(curve) = curve {
185            node.set_curve(curve);
186        }
187
188        node
189    }
190
191    #[must_use]
193    pub fn curve(&self) -> Option<&[f32]> {
194        self.curve.as_deref()
195    }
196
197    pub fn set_curve(&mut self, curve: Vec<f32>) {
208        assert!(
209            self.curve.is_none(),
210            "InvalidStateError - cannot assign curve twice",
211        );
212
213        let clone = curve.clone();
214
215        self.curve = Some(curve);
216        self.registration.post_message(Some(clone));
217    }
218
219    #[must_use]
221    pub fn oversample(&self) -> OverSampleType {
222        self.oversample
223    }
224
225    pub fn set_oversample(&mut self, oversample: OverSampleType) {
231        self.oversample = oversample;
232        self.registration.post_message(oversample);
233    }
234}
235
236#[derive(Debug, Clone, PartialEq, Eq)]
237struct ResamplerConfig {
238    channels: usize,
239    chunk_size_in: usize,
240    sample_rate_in: usize,
241    sample_rate_out: usize,
242}
243
244impl ResamplerConfig {
245    fn upsample_x2(channels: usize, sample_rate: usize) -> Self {
246        let chunk_size_in = RENDER_QUANTUM_SIZE;
247        let sample_rate_in = sample_rate;
248        let sample_rate_out = sample_rate * 2;
249        Self {
250            channels,
251            chunk_size_in,
252            sample_rate_in,
253            sample_rate_out,
254        }
255    }
256
257    fn upsample_x4(channels: usize, sample_rate: usize) -> Self {
258        let chunk_size_in = RENDER_QUANTUM_SIZE;
259        let sample_rate_in = sample_rate;
260        let sample_rate_out = sample_rate * 4;
261        Self {
262            channels,
263            chunk_size_in,
264            sample_rate_in,
265            sample_rate_out,
266        }
267    }
268
269    fn downsample_x2(channels: usize, sample_rate: usize) -> Self {
270        let chunk_size_in = RENDER_QUANTUM_SIZE * 2;
271        let sample_rate_in = sample_rate * 2;
272        let sample_rate_out = sample_rate;
273        Self {
274            channels,
275            chunk_size_in,
276            sample_rate_in,
277            sample_rate_out,
278        }
279    }
280
281    fn downsample_x4(channels: usize, sample_rate: usize) -> Self {
282        let chunk_size_in = RENDER_QUANTUM_SIZE * 4;
283        let sample_rate_in = sample_rate * 4;
284        let sample_rate_out = sample_rate;
285        Self {
286            channels,
287            chunk_size_in,
288            sample_rate_in,
289            sample_rate_out,
290        }
291    }
292}
293
294struct Resampler {
295    config: ResamplerConfig,
296    processor: FftFixedInOut<f32>,
297    samples_out: Vec<Vec<f32>>,
298}
299
300impl Resampler {
301    fn new(config: ResamplerConfig) -> Self {
302        let ResamplerConfig {
303            channels,
304            chunk_size_in,
305            sample_rate_in,
306            sample_rate_out,
307        } = &config;
308
309        let processor =
310            FftFixedInOut::new(*sample_rate_in, *sample_rate_out, *chunk_size_in, *channels)
311                .unwrap();
312
313        let samples_out = processor.output_buffer_allocate(true);
314
315        Self {
316            config,
317            processor,
318            samples_out,
319        }
320    }
321
322    fn process<T>(&mut self, samples_in: &[T])
323    where
324        T: AsRef<[f32]>,
325    {
326        debug_assert_eq!(self.config.channels, samples_in.len());
327        debug_assert!(samples_in
329            .iter()
330            .all(|channel| channel.as_ref().len() == self.processor.input_frames_next()));
331        let (in_len, out_len) = self
332            .processor
333            .process_into_buffer(samples_in, &mut self.samples_out[..], None)
334            .unwrap();
335        debug_assert_eq!(in_len, samples_in[0].as_ref().len());
337        debug_assert!(self
339            .samples_out
340            .iter()
341            .all(|channel| channel.len() == out_len));
342    }
343
344    fn samples_out(&self) -> &[Vec<f32>] {
345        &self.samples_out[..]
346    }
347
348    fn samples_out_mut(&mut self) -> &mut [Vec<f32>] {
349        &mut self.samples_out[..]
350    }
351}
352
353struct RendererConfig {
356    oversample: OverSampleType,
358    sample_rate: usize,
360}
361
362struct WaveShaperRenderer {
364    oversample: OverSampleType,
366    curve: Option<Vec<f32>>,
368    sample_rate: usize,
370    channels_x2: usize,
372    channels_x4: usize,
374    upsampler_x2: Resampler,
376    upsampler_x4: Resampler,
378    downsampler_x2: Resampler,
380    downsampler_x4: Resampler,
382    can_propagate_silence: bool,
385}
386
387impl AudioProcessor for WaveShaperRenderer {
388    fn process(
389        &mut self,
390        inputs: &[AudioRenderQuantum],
391        outputs: &mut [AudioRenderQuantum],
392        _params: AudioParamValues<'_>,
393        _scope: &AudioWorkletGlobalScope,
394    ) -> bool {
395        let input = &inputs[0];
397        let output = &mut outputs[0];
398
399        if input.is_silent() && self.can_propagate_silence {
400            output.make_silent();
401            return false;
402        }
403
404        *output = input.clone();
405
406        if let Some(curve) = &self.curve {
407            match self.oversample {
408                OverSampleType::None => {
409                    output.modify_channels(|channel| {
410                        channel.iter_mut().for_each(|o| *o = apply_curve(curve, *o));
411                    });
412                }
413                OverSampleType::X2 => {
414                    let channels = output.channels();
415
416                    if channels.len() != self.channels_x2 {
418                        self.channels_x2 = channels.len();
419
420                        self.upsampler_x2 = Resampler::new(ResamplerConfig::upsample_x2(
421                            self.channels_x2,
422                            self.sample_rate,
423                        ));
424
425                        self.downsampler_x2 = Resampler::new(ResamplerConfig::downsample_x2(
426                            self.channels_x2,
427                            self.sample_rate,
428                        ));
429                    }
430
431                    self.upsampler_x2.process(channels);
432                    for channel in self.upsampler_x2.samples_out_mut().iter_mut() {
433                        for s in channel.iter_mut() {
434                            *s = apply_curve(curve, *s);
435                        }
436                    }
437
438                    self.downsampler_x2.process(self.upsampler_x2.samples_out());
439
440                    for (processed, output) in self
441                        .downsampler_x2
442                        .samples_out()
443                        .iter()
444                        .zip(output.channels_mut())
445                    {
446                        output.copy_from_slice(&processed[..]);
447                    }
448                }
449                OverSampleType::X4 => {
450                    let channels = output.channels();
451
452                    if channels.len() != self.channels_x4 {
454                        self.channels_x4 = channels.len();
455
456                        self.upsampler_x4 = Resampler::new(ResamplerConfig::upsample_x4(
457                            self.channels_x4,
458                            self.sample_rate,
459                        ));
460
461                        self.downsampler_x4 = Resampler::new(ResamplerConfig::downsample_x4(
462                            self.channels_x4,
463                            self.sample_rate,
464                        ));
465                    }
466
467                    self.upsampler_x4.process(channels);
468
469                    for channel in self.upsampler_x4.samples_out_mut().iter_mut() {
470                        for s in channel.iter_mut() {
471                            *s = apply_curve(curve, *s);
472                        }
473                    }
474
475                    self.downsampler_x4.process(self.upsampler_x4.samples_out());
476
477                    for (processed, output) in self
478                        .downsampler_x4
479                        .samples_out()
480                        .iter()
481                        .zip(output.channels_mut())
482                    {
483                        output.copy_from_slice(&processed[..]);
484                    }
485                }
486            }
487        }
488
489        false
491    }
492
493    fn onmessage(&mut self, msg: &mut dyn Any) {
494        if let Some(&oversample) = msg.downcast_ref::<OverSampleType>() {
495            self.oversample = oversample;
496            return;
497        }
498
499        if let Some(curve) = msg.downcast_mut::<Option<Vec<f32>>>() {
500            std::mem::swap(&mut self.curve, curve);
501
502            self.can_propagate_silence = if let Some(curve) = &self.curve {
504                if curve.len() % 2 == 1 {
505                    curve[curve.len() / 2].abs() < 1e-9
506                } else {
507                    let a = curve[curve.len() / 2 - 1];
508                    let b = curve[curve.len() / 2];
509                    ((a + b) / 2.).abs() < 1e-9
510                }
511            } else {
512                true
513            };
514
515            return;
516        }
517
518        log::warn!("WaveShaperRenderer: Dropping incoming message {msg:?}");
519    }
520}
521
522impl WaveShaperRenderer {
523    #[allow(clippy::missing_const_for_fn)]
525    fn new(config: RendererConfig) -> Self {
526        let RendererConfig {
527            sample_rate,
528            oversample,
529        } = config;
530
531        let channels_x2 = 1;
532        let channels_x4 = 1;
533
534        let upsampler_x2 = Resampler::new(ResamplerConfig::upsample_x2(channels_x2, sample_rate));
535
536        let downsampler_x2 =
537            Resampler::new(ResamplerConfig::downsample_x2(channels_x2, sample_rate));
538
539        let upsampler_x4 = Resampler::new(ResamplerConfig::upsample_x4(channels_x2, sample_rate));
540
541        let downsampler_x4 =
542            Resampler::new(ResamplerConfig::downsample_x4(channels_x2, sample_rate));
543
544        Self {
545            oversample,
546            curve: None,
547            sample_rate,
548            channels_x2,
549            channels_x4,
550            upsampler_x2,
551            upsampler_x4,
552            downsampler_x2,
553            downsampler_x4,
554            can_propagate_silence: true,
555        }
556    }
557}
558
/// Evaluates `curve` at `input` by linear interpolation between neighbouring entries.
///
/// An input of `-1` maps to the first entry and `+1` to the last. Inputs beyond
/// that range are clamped to the end entries. An empty curve maps every input
/// to `0`.
#[inline]
fn apply_curve(curve: &[f32], input: f32) -> f32 {
    let len = curve.len();
    if len == 0 {
        return 0.;
    }

    // Fractional index into the curve.
    let last = len as f32 - 1.;
    let position = last / 2.0 * (input + 1.);

    if position <= 0. {
        return curve[0];
    }
    if position >= last {
        return curve[last as usize];
    }

    let lower = position.floor();
    let frac = position - lower;
    // Both indices come from float-to-int casts, which saturate NaN to 0, so a
    // NaN input can never index past the end of the curve.
    (1. - frac) * curve[lower as usize] + frac * curve[(lower + 1.) as usize]
}
578
#[cfg(test)]
mod tests {
    use float_eq::assert_float_eq;

    use crate::context::OfflineAudioContext;
    use crate::node::AudioScheduledSourceNode;

    use super::*;

    const LENGTH: usize = 555;

    #[test]
    fn build_with_new() {
        let context = OfflineAudioContext::new(2, LENGTH, 44_100.);
        let _shaper = WaveShaperNode::new(&context, WaveShaperOptions::default());
    }

    #[test]
    fn build_with_factory_func() {
        let context = OfflineAudioContext::new(2, LENGTH, 44_100.);
        let _shaper = context.create_wave_shaper();
    }

    #[test]
    fn test_default_options() {
        let context = OfflineAudioContext::new(2, LENGTH, 44_100.);
        let shaper = WaveShaperNode::new(&context, WaveShaperOptions::default());

        assert_eq!(shaper.curve(), None);
        assert_eq!(shaper.oversample(), OverSampleType::None);
    }

    #[test]
    fn test_user_defined_options() {
        let mut context = OfflineAudioContext::new(2, LENGTH, 44_100.);

        let options = WaveShaperOptions {
            curve: Some(vec![1.0]),
            oversample: OverSampleType::X2,
            ..Default::default()
        };

        let shaper = WaveShaperNode::new(&context, options);

        let _ = context.start_rendering_sync();

        assert_eq!(shaper.curve(), Some(&[1.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X2);
    }

    /// Assigning a second curve must fail with InvalidStateError. The expected
    /// message pins the panic to `set_curve`, so an unrelated panic cannot make
    /// this test pass.
    #[test]
    #[should_panic(expected = "InvalidStateError")]
    fn change_a_curve_for_another_curve_should_panic() {
        let context = OfflineAudioContext::new(2, LENGTH, 44_100.);

        let options = WaveShaperOptions {
            curve: Some(vec![1.0]),
            oversample: OverSampleType::X2,
            ..Default::default()
        };

        let mut shaper = WaveShaperNode::new(&context, options);
        assert_eq!(shaper.curve(), Some(&[1.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X2);

        // Panics here; nothing after this call would ever run.
        shaper.set_curve(vec![2.0]);
    }

    #[test]
    fn change_none_for_curve_after_build() {
        let mut context = OfflineAudioContext::new(2, LENGTH, 44_100.);

        let options = WaveShaperOptions {
            curve: None,
            oversample: OverSampleType::X2,
            ..Default::default()
        };

        let mut shaper = WaveShaperNode::new(&context, options);
        assert_eq!(shaper.curve(), None);
        assert_eq!(shaper.oversample(), OverSampleType::X2);

        shaper.set_curve(vec![2.0]);
        shaper.set_oversample(OverSampleType::X4);

        let _ = context.start_rendering_sync();

        assert_eq!(shaper.curve(), Some(&[2.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X4);
    }

    #[test]
    fn test_shape_boundaries() {
        let sample_rate = 44100.;
        let mut context = OfflineAudioContext::new(1, 3 * RENDER_QUANTUM_SIZE, sample_rate);

        let mut shaper = context.create_wave_shaper();
        let curve = vec![-0.5, 0., 0.5];
        shaper.set_curve(curve);
        shaper.connect(&context.destination());

        let mut data = vec![0.; 3 * RENDER_QUANTUM_SIZE];
        let mut expected = vec![0.; 3 * RENDER_QUANTUM_SIZE];
        for i in 0..(3 * RENDER_QUANTUM_SIZE) {
            if i < RENDER_QUANTUM_SIZE {
                data[i] = -1.;
                expected[i] = -0.5;
            } else if i < 2 * RENDER_QUANTUM_SIZE {
                data[i] = 0.;
                expected[i] = 0.;
            } else {
                data[i] = 1.;
                expected[i] = 0.5;
            }
        }
        let mut buffer = context.create_buffer(1, 3 * RENDER_QUANTUM_SIZE, sample_rate);
        buffer.copy_to_channel(&data, 0);

        let mut src = context.create_buffer_source();
        src.connect(&shaper);
        src.set_buffer(buffer);
        src.start_at(0.);

        let result = context.start_rendering_sync();
        let channel = result.get_channel_data(0);

        assert_float_eq!(channel[..], expected[..], abs_all <= 0.);
    }

    #[test]
    fn test_shape_interpolation() {
        let sample_rate = 44100.;
        let mut context = OfflineAudioContext::new(1, RENDER_QUANTUM_SIZE, sample_rate);

        let mut shaper = context.create_wave_shaper();
        let curve = vec![-0.5, 0., 0.5];
        shaper.set_curve(curve);
        shaper.connect(&context.destination());

        let mut data = vec![0.; RENDER_QUANTUM_SIZE];
        let mut expected = vec![0.; RENDER_QUANTUM_SIZE];

        for i in 0..RENDER_QUANTUM_SIZE {
            let sample = i as f32 / (RENDER_QUANTUM_SIZE as f32) * 2. - 1.;
            data[i] = sample;
            expected[i] = sample / 2.;
        }

        let mut buffer = context.create_buffer(1, 3 * RENDER_QUANTUM_SIZE, sample_rate);
        buffer.copy_to_channel(&data, 0);

        let mut src = context.create_buffer_source();
        src.connect(&shaper);
        src.set_buffer(buffer);
        src.start_at(0.);

        let result = context.start_rendering_sync();
        let channel = result.get_channel_data(0);

        assert_float_eq!(channel[..], expected[..], abs_all <= 0.);
    }
}