1use std::any::Any;
2
3use rubato::{FftFixedInOut, Resampler as _};
4
5use crate::{
6 context::{AudioContextRegistration, BaseAudioContext},
7 render::{AudioParamValues, AudioProcessor, AudioRenderQuantum, AudioWorkletGlobalScope},
8 RENDER_QUANTUM_SIZE,
9};
10
11use super::{AudioNode, AudioNodeOptions, ChannelConfig};
12
/// Amount of oversampling a [`WaveShaperNode`] applies around its curve.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OverSampleType {
    /// No oversampling: the curve is applied at the context sample rate.
    None,
    /// The signal is upsampled 2x before shaping and downsampled afterwards.
    X2,
    /// The signal is upsampled 4x before shaping and downsampled afterwards.
    X4,
}

impl Default for OverSampleType {
    /// Defaults to [`OverSampleType::None`].
    fn default() -> Self {
        Self::None
    }
}

impl From<u32> for OverSampleType {
    /// Maps `0`, `1` and `2` to `None`, `X2` and `X4` respectively.
    ///
    /// # Panics
    ///
    /// Panics if `i` is greater than `2`. This value is reachable from callers,
    /// so it is reported with a descriptive message rather than `unreachable!()`.
    fn from(i: u32) -> Self {
        match i {
            0 => Self::None,
            1 => Self::X2,
            2 => Self::X4,
            _ => panic!("TypeError - invalid OverSampleType value {i}, expected 0, 1 or 2"),
        }
    }
}
41
42#[derive(Clone, Debug)]
48pub struct WaveShaperOptions {
49 pub curve: Option<Vec<f32>>,
51 pub oversample: OverSampleType,
53 pub audio_node_options: AudioNodeOptions,
55}
56
57impl Default for WaveShaperOptions {
58 fn default() -> Self {
59 Self {
60 oversample: OverSampleType::None,
61 curve: None,
62 audio_node_options: AudioNodeOptions::default(),
63 }
64 }
65}
66
/// Audio node that passes each input sample through a shaping curve
/// (optionally oversampled), or passes the input through unmodified when no
/// curve is set.
#[derive(Debug)]
pub struct WaveShaperNode {
    /// Registration of this node with its audio context; used to post
    /// messages to the renderer.
    registration: AudioContextRegistration,
    /// Channel configuration returned by `AudioNode::channel_config`.
    channel_config: ChannelConfig,
    /// Control-side copy of the curve. It is `None` until `set_curve` is called
    /// and can only be set once.
    curve: Option<Vec<f32>>,
    /// Current oversampling mode, mirrored to the renderer by `set_oversample`.
    oversample: OverSampleType,
}
132
133impl AudioNode for WaveShaperNode {
134 fn registration(&self) -> &AudioContextRegistration {
135 &self.registration
136 }
137
138 fn channel_config(&self) -> &ChannelConfig {
139 &self.channel_config
140 }
141
142 fn number_of_inputs(&self) -> usize {
143 1
144 }
145
146 fn number_of_outputs(&self) -> usize {
147 1
148 }
149}
150
151impl WaveShaperNode {
152 pub fn new<C: BaseAudioContext>(context: &C, options: WaveShaperOptions) -> Self {
159 let WaveShaperOptions {
160 oversample,
161 curve,
162 audio_node_options: channel_config,
163 } = options;
164
165 let mut node = context.base().register(move |registration| {
166 let sample_rate = context.sample_rate() as usize;
167
168 let renderer = WaveShaperRenderer::new(RendererConfig {
169 oversample,
170 sample_rate,
171 });
172
173 let node = Self {
174 registration,
175 channel_config: channel_config.into(),
176 curve: None,
177 oversample,
178 };
179
180 (node, Box::new(renderer))
181 });
182
183 if let Some(curve) = curve {
185 node.set_curve(curve);
186 }
187
188 node
189 }
190
191 #[must_use]
193 pub fn curve(&self) -> Option<&[f32]> {
194 self.curve.as_deref()
195 }
196
197 pub fn set_curve(&mut self, curve: Vec<f32>) {
208 assert!(
209 self.curve.is_none(),
210 "InvalidStateError - cannot assign curve twice",
211 );
212
213 let clone = curve.clone();
214
215 self.curve = Some(curve);
216 self.registration.post_message(Some(clone));
217 }
218
219 #[must_use]
221 pub fn oversample(&self) -> OverSampleType {
222 self.oversample
223 }
224
225 pub fn set_oversample(&mut self, oversample: OverSampleType) {
231 self.oversample = oversample;
232 self.registration.post_message(oversample);
233 }
234}
235
236#[derive(Debug, Clone, PartialEq, Eq)]
237struct ResamplerConfig {
238 channels: usize,
239 chunk_size_in: usize,
240 sample_rate_in: usize,
241 sample_rate_out: usize,
242}
243
244impl ResamplerConfig {
245 fn upsample_x2(channels: usize, sample_rate: usize) -> Self {
246 let chunk_size_in = RENDER_QUANTUM_SIZE;
247 let sample_rate_in = sample_rate;
248 let sample_rate_out = sample_rate * 2;
249 Self {
250 channels,
251 chunk_size_in,
252 sample_rate_in,
253 sample_rate_out,
254 }
255 }
256
257 fn upsample_x4(channels: usize, sample_rate: usize) -> Self {
258 let chunk_size_in = RENDER_QUANTUM_SIZE;
259 let sample_rate_in = sample_rate;
260 let sample_rate_out = sample_rate * 4;
261 Self {
262 channels,
263 chunk_size_in,
264 sample_rate_in,
265 sample_rate_out,
266 }
267 }
268
269 fn downsample_x2(channels: usize, sample_rate: usize) -> Self {
270 let chunk_size_in = RENDER_QUANTUM_SIZE * 2;
271 let sample_rate_in = sample_rate * 2;
272 let sample_rate_out = sample_rate;
273 Self {
274 channels,
275 chunk_size_in,
276 sample_rate_in,
277 sample_rate_out,
278 }
279 }
280
281 fn downsample_x4(channels: usize, sample_rate: usize) -> Self {
282 let chunk_size_in = RENDER_QUANTUM_SIZE * 4;
283 let sample_rate_in = sample_rate * 4;
284 let sample_rate_out = sample_rate;
285 Self {
286 channels,
287 chunk_size_in,
288 sample_rate_in,
289 sample_rate_out,
290 }
291 }
292}
293
294struct Resampler {
295 config: ResamplerConfig,
296 processor: FftFixedInOut<f32>,
297 samples_out: Vec<Vec<f32>>,
298}
299
300impl Resampler {
301 fn new(config: ResamplerConfig) -> Self {
302 let ResamplerConfig {
303 channels,
304 chunk_size_in,
305 sample_rate_in,
306 sample_rate_out,
307 } = &config;
308
309 let processor =
310 FftFixedInOut::new(*sample_rate_in, *sample_rate_out, *chunk_size_in, *channels)
311 .unwrap();
312
313 let samples_out = processor.output_buffer_allocate(true);
314
315 Self {
316 config,
317 processor,
318 samples_out,
319 }
320 }
321
322 fn process<T>(&mut self, samples_in: &[T])
323 where
324 T: AsRef<[f32]>,
325 {
326 debug_assert_eq!(self.config.channels, samples_in.len());
327 debug_assert!(samples_in
329 .iter()
330 .all(|channel| channel.as_ref().len() == self.processor.input_frames_next()));
331 let (in_len, out_len) = self
332 .processor
333 .process_into_buffer(samples_in, &mut self.samples_out[..], None)
334 .unwrap();
335 debug_assert_eq!(in_len, samples_in[0].as_ref().len());
337 debug_assert!(self
339 .samples_out
340 .iter()
341 .all(|channel| channel.len() == out_len));
342 }
343
344 fn samples_out(&self) -> &[Vec<f32>] {
345 &self.samples_out[..]
346 }
347
348 fn samples_out_mut(&mut self) -> &mut [Vec<f32>] {
349 &mut self.samples_out[..]
350 }
351}
352
/// Parameters used to build a `WaveShaperRenderer`.
struct RendererConfig {
    /// Initial oversampling mode.
    oversample: OverSampleType,
    /// Context sample rate; `WaveShaperNode::new` truncates it to an integer.
    sample_rate: usize,
}
361
/// Render-side counterpart of `WaveShaperNode`; it receives curve and
/// oversampling updates through `onmessage`.
struct WaveShaperRenderer {
    /// Current oversampling mode.
    oversample: OverSampleType,
    /// Shaping curve; `None` means the input is passed through unmodified.
    curve: Option<Vec<f32>>,
    /// Integer sample rate used to (re)build the resamplers.
    sample_rate: usize,
    /// Channel count the x2 resamplers are currently built for.
    channels_x2: usize,
    /// Channel count the x4 resamplers are currently built for.
    channels_x4: usize,
    /// Upsamples from `sample_rate` to `2 * sample_rate`.
    upsampler_x2: Resampler,
    /// Upsamples from `sample_rate` to `4 * sample_rate`.
    upsampler_x4: Resampler,
    /// Downsamples from `2 * sample_rate` back to `sample_rate`.
    downsampler_x2: Resampler,
    /// Downsamples from `4 * sample_rate` back to `sample_rate`.
    downsampler_x4: Resampler,
    /// Whether a silent input may be answered with a silent output without
    /// processing, i.e. whether the curve maps 0 to (approximately) 0.
    can_propagate_silence: bool,
}
386
387impl AudioProcessor for WaveShaperRenderer {
388 fn process(
389 &mut self,
390 inputs: &[AudioRenderQuantum],
391 outputs: &mut [AudioRenderQuantum],
392 _params: AudioParamValues<'_>,
393 _scope: &AudioWorkletGlobalScope,
394 ) -> bool {
395 let input = &inputs[0];
397 let output = &mut outputs[0];
398
399 if input.is_silent() && self.can_propagate_silence {
400 output.make_silent();
401 return false;
402 }
403
404 *output = input.clone();
405
406 if let Some(curve) = &self.curve {
407 match self.oversample {
408 OverSampleType::None => {
409 output.modify_channels(|channel| {
410 channel.iter_mut().for_each(|o| *o = apply_curve(curve, *o));
411 });
412 }
413 OverSampleType::X2 => {
414 let channels = output.channels();
415
416 if channels.len() != self.channels_x2 {
418 self.channels_x2 = channels.len();
419
420 self.upsampler_x2 = Resampler::new(ResamplerConfig::upsample_x2(
421 self.channels_x2,
422 self.sample_rate,
423 ));
424
425 self.downsampler_x2 = Resampler::new(ResamplerConfig::downsample_x2(
426 self.channels_x2,
427 self.sample_rate,
428 ));
429 }
430
431 self.upsampler_x2.process(channels);
432 for channel in self.upsampler_x2.samples_out_mut().iter_mut() {
433 for s in channel.iter_mut() {
434 *s = apply_curve(curve, *s);
435 }
436 }
437
438 self.downsampler_x2.process(self.upsampler_x2.samples_out());
439
440 for (processed, output) in self
441 .downsampler_x2
442 .samples_out()
443 .iter()
444 .zip(output.channels_mut())
445 {
446 output.copy_from_slice(&processed[..]);
447 }
448 }
449 OverSampleType::X4 => {
450 let channels = output.channels();
451
452 if channels.len() != self.channels_x4 {
454 self.channels_x4 = channels.len();
455
456 self.upsampler_x4 = Resampler::new(ResamplerConfig::upsample_x4(
457 self.channels_x4,
458 self.sample_rate,
459 ));
460
461 self.downsampler_x4 = Resampler::new(ResamplerConfig::downsample_x4(
462 self.channels_x4,
463 self.sample_rate,
464 ));
465 }
466
467 self.upsampler_x4.process(channels);
468
469 for channel in self.upsampler_x4.samples_out_mut().iter_mut() {
470 for s in channel.iter_mut() {
471 *s = apply_curve(curve, *s);
472 }
473 }
474
475 self.downsampler_x4.process(self.upsampler_x4.samples_out());
476
477 for (processed, output) in self
478 .downsampler_x4
479 .samples_out()
480 .iter()
481 .zip(output.channels_mut())
482 {
483 output.copy_from_slice(&processed[..]);
484 }
485 }
486 }
487 }
488
489 false
491 }
492
493 fn onmessage(&mut self, msg: &mut dyn Any) {
494 if let Some(&oversample) = msg.downcast_ref::<OverSampleType>() {
495 self.oversample = oversample;
496 return;
497 }
498
499 if let Some(curve) = msg.downcast_mut::<Option<Vec<f32>>>() {
500 std::mem::swap(&mut self.curve, curve);
501
502 self.can_propagate_silence = if let Some(curve) = &self.curve {
504 if curve.len() % 2 == 1 {
505 curve[curve.len() / 2].abs() < 1e-9
506 } else {
507 let a = curve[curve.len() / 2 - 1];
508 let b = curve[curve.len() / 2];
509 ((a + b) / 2.).abs() < 1e-9
510 }
511 } else {
512 true
513 };
514
515 return;
516 }
517
518 log::warn!("WaveShaperRenderer: Dropping incoming message {msg:?}");
519 }
520}
521
522impl WaveShaperRenderer {
523 #[allow(clippy::missing_const_for_fn)]
525 fn new(config: RendererConfig) -> Self {
526 let RendererConfig {
527 sample_rate,
528 oversample,
529 } = config;
530
531 let channels_x2 = 1;
532 let channels_x4 = 1;
533
534 let upsampler_x2 = Resampler::new(ResamplerConfig::upsample_x2(channels_x2, sample_rate));
535
536 let downsampler_x2 =
537 Resampler::new(ResamplerConfig::downsample_x2(channels_x2, sample_rate));
538
539 let upsampler_x4 = Resampler::new(ResamplerConfig::upsample_x4(channels_x2, sample_rate));
540
541 let downsampler_x4 =
542 Resampler::new(ResamplerConfig::downsample_x4(channels_x2, sample_rate));
543
544 Self {
545 oversample,
546 curve: None,
547 sample_rate,
548 channels_x2,
549 channels_x4,
550 upsampler_x2,
551 upsampler_x4,
552 downsampler_x2,
553 downsampler_x4,
554 can_propagate_silence: true,
555 }
556 }
557}
558
/// Shape one sample through `curve`.
///
/// The input range `[-1, 1]` is mapped linearly onto the curve's indices
/// `[0, len - 1]`, and the result is linearly interpolated between the two
/// neighbouring curve points. Inputs outside that range clamp to the first or
/// last point, and an empty curve yields `0`.
///
/// Indices are derived from the integer length rather than `len as f32`.
/// For curves longer than 2^24 points the `f32` length rounds up, and an
/// index computed from it could run past the end of the slice.
#[inline]
fn apply_curve(curve: &[f32], input: f32) -> f32 {
    if curve.is_empty() {
        return 0.;
    }

    let last = curve.len() - 1;
    let n = curve.len() as f32;
    let v = (n - 1.) / 2.0 * (input + 1.);

    if v <= 0. {
        curve[0]
    } else if v >= n - 1. {
        curve[last]
    } else {
        let k = v.floor();
        let f = v - k;
        // Clamping is a no-op for ordinary curve sizes (k + 1 <= last there);
        // it only guards against f32 rounding on huge curves.
        let lo = (k as usize).min(last);
        let hi = (lo + 1).min(last);
        (1. - f) * curve[lo] + f * curve[hi]
    }
}
578
#[cfg(test)]
mod tests {
    use float_eq::assert_float_eq;

    use crate::context::OfflineAudioContext;
    use crate::node::AudioScheduledSourceNode;

    use super::*;

    // Render length (in frames) shared by the construction/option tests.
    const LENGTH: usize = 555;

    // Construction through `WaveShaperNode::new` with default options must not panic.
    #[test]
    fn build_with_new() {
        let context = OfflineAudioContext::new(2, LENGTH, 44_100.);
        let _shaper = WaveShaperNode::new(&context, WaveShaperOptions::default());
    }

    // Construction through the context factory method must not panic.
    #[test]
    fn build_with_factory_func() {
        let context = OfflineAudioContext::new(2, LENGTH, 44_100.);
        let _shaper = context.create_wave_shaper();
    }

    // Default options: no curve and no oversampling.
    #[test]
    fn test_default_options() {
        let context = OfflineAudioContext::new(2, LENGTH, 44_100.);
        let shaper = WaveShaperNode::new(&context, WaveShaperOptions::default());

        assert_eq!(shaper.curve(), None);
        assert_eq!(shaper.oversample(), OverSampleType::None);
    }

    // Options passed at construction are reported by the getters, also after
    // a render pass.
    #[test]
    fn test_user_defined_options() {
        let mut context = OfflineAudioContext::new(2, LENGTH, 44_100.);

        let options = WaveShaperOptions {
            curve: Some(vec![1.0]),
            oversample: OverSampleType::X2,
            ..Default::default()
        };

        let shaper = WaveShaperNode::new(&context, options);

        let _ = context.start_rendering_sync();

        assert_eq!(shaper.curve(), Some(&[1.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X2);
    }

    // A curve given in the options counts as "set", so the second assignment
    // via `set_curve` must panic (InvalidStateError). Everything after that
    // call is expected to be unreachable.
    #[test]
    #[should_panic]
    fn change_a_curve_for_another_curve_should_panic() {
        let mut context = OfflineAudioContext::new(2, LENGTH, 44_100.);

        let options = WaveShaperOptions {
            curve: Some(vec![1.0]),
            oversample: OverSampleType::X2,
            ..Default::default()
        };

        let mut shaper = WaveShaperNode::new(&context, options);
        assert_eq!(shaper.curve(), Some(&[1.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X2);

        shaper.set_curve(vec![2.0]);
        shaper.set_oversample(OverSampleType::X4);

        let _ = context.start_rendering_sync();

        assert_eq!(shaper.curve(), Some(&[2.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X4);
    }

    // Without an initial curve, one can be assigned after construction, and
    // the oversampling mode can be changed.
    #[test]
    fn change_none_for_curve_after_build() {
        let mut context = OfflineAudioContext::new(2, LENGTH, 44_100.);

        let options = WaveShaperOptions {
            curve: None,
            oversample: OverSampleType::X2,
            ..Default::default()
        };

        let mut shaper = WaveShaperNode::new(&context, options);
        assert_eq!(shaper.curve(), None);
        assert_eq!(shaper.oversample(), OverSampleType::X2);

        shaper.set_curve(vec![2.0]);
        shaper.set_oversample(OverSampleType::X4);

        let _ = context.start_rendering_sync();

        assert_eq!(shaper.curve(), Some(&[2.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X4);
    }

    // Inputs of -1, 0 and 1 (one render quantum each) must map exactly to the
    // first, middle and last curve points.
    #[test]
    fn test_shape_boundaries() {
        let sample_rate = 44100.;
        let mut context = OfflineAudioContext::new(1, 3 * RENDER_QUANTUM_SIZE, sample_rate);

        let mut shaper = context.create_wave_shaper();
        let curve = vec![-0.5, 0., 0.5];
        shaper.set_curve(curve);
        shaper.connect(&context.destination());

        let mut data = vec![0.; 3 * RENDER_QUANTUM_SIZE];
        let mut expected = vec![0.; 3 * RENDER_QUANTUM_SIZE];
        for i in 0..(3 * RENDER_QUANTUM_SIZE) {
            if i < RENDER_QUANTUM_SIZE {
                data[i] = -1.;
                expected[i] = -0.5;
            } else if i < 2 * RENDER_QUANTUM_SIZE {
                data[i] = 0.;
                expected[i] = 0.;
            } else {
                data[i] = 1.;
                expected[i] = 0.5;
            }
        }
        let mut buffer = context.create_buffer(1, 3 * RENDER_QUANTUM_SIZE, sample_rate);
        buffer.copy_to_channel(&data, 0);

        let mut src = context.create_buffer_source();
        src.connect(&shaper);
        src.set_buffer(buffer);
        src.start_at(0.);

        let result = context.start_rendering_sync();
        let channel = result.get_channel_data(0);

        assert_float_eq!(channel[..], expected[..], abs_all <= 0.);
    }

    // For the linear curve [-0.5, 0, 0.5], interpolation must give exactly
    // `input / 2` for a ramp covering [-1, 1).
    // NOTE(review): the buffer is 3 quanta long, but only the first quantum is
    // filled and only one quantum is rendered.
    #[test]
    fn test_shape_interpolation() {
        let sample_rate = 44100.;
        let mut context = OfflineAudioContext::new(1, RENDER_QUANTUM_SIZE, sample_rate);

        let mut shaper = context.create_wave_shaper();
        let curve = vec![-0.5, 0., 0.5];
        shaper.set_curve(curve);
        shaper.connect(&context.destination());

        let mut data = vec![0.; RENDER_QUANTUM_SIZE];
        let mut expected = vec![0.; RENDER_QUANTUM_SIZE];

        for i in 0..RENDER_QUANTUM_SIZE {
            let sample = i as f32 / (RENDER_QUANTUM_SIZE as f32) * 2. - 1.;
            data[i] = sample;
            expected[i] = sample / 2.;
        }

        let mut buffer = context.create_buffer(1, 3 * RENDER_QUANTUM_SIZE, sample_rate);
        buffer.copy_to_channel(&data, 0);

        let mut src = context.create_buffer_source();
        src.connect(&shaper);
        src.set_buffer(buffer);
        src.start_at(0.);

        let result = context.start_rendering_sync();
        let channel = result.get_channel_data(0);

        assert_float_eq!(channel[..], expected[..], abs_all <= 0.);
    }
}