use std::any::Any;
use rubato::{FftFixedInOut, Resampler as _};
use crate::{
    context::{AudioContextRegistration, BaseAudioContext},
    render::{AudioParamValues, AudioProcessor, AudioRenderQuantum, RenderScope},
    RENDER_QUANTUM_SIZE,
};
use super::{AudioNode, ChannelConfig, ChannelConfigOptions};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OverSampleType {
    None,
    X2,
    X4,
}
impl Default for OverSampleType {
    fn default() -> Self {
        Self::None
    }
}
impl From<u32> for OverSampleType {
    fn from(i: u32) -> Self {
        match i {
            0 => OverSampleType::None,
            1 => OverSampleType::X2,
            2 => OverSampleType::X4,
            _ => unreachable!(),
        }
    }
}
#[derive(Clone, Debug)]
pub struct WaveShaperOptions {
    pub curve: Option<Vec<f32>>,
    pub oversample: OverSampleType,
    pub channel_config: ChannelConfigOptions,
}
impl Default for WaveShaperOptions {
    fn default() -> Self {
        Self {
            oversample: OverSampleType::None,
            curve: None,
            channel_config: ChannelConfigOptions::default(),
        }
    }
}
pub struct WaveShaperNode {
    registration: AudioContextRegistration,
    channel_config: ChannelConfig,
    curve: Option<Vec<f32>>,
    oversample: OverSampleType,
}
impl AudioNode for WaveShaperNode {
    fn registration(&self) -> &AudioContextRegistration {
        &self.registration
    }
    fn channel_config(&self) -> &ChannelConfig {
        &self.channel_config
    }
    fn number_of_inputs(&self) -> usize {
        1
    }
    fn number_of_outputs(&self) -> usize {
        1
    }
}
impl WaveShaperNode {
    pub fn new<C: BaseAudioContext>(context: &C, options: WaveShaperOptions) -> Self {
        let WaveShaperOptions {
            oversample,
            curve,
            channel_config,
        } = options;
        let mut node = context.register(move |registration| {
            let sample_rate = context.sample_rate() as usize;
            let renderer = WaveShaperRenderer::new(RendererConfig {
                oversample,
                sample_rate,
            });
            let node = Self {
                registration,
                channel_config: channel_config.into(),
                curve: None,
                oversample,
            };
            (node, Box::new(renderer))
        });
        if let Some(curve) = curve {
            node.set_curve(curve);
        }
        node
    }
    #[must_use]
    pub fn curve(&self) -> Option<&[f32]> {
        self.curve.as_deref()
    }
    pub fn set_curve(&mut self, curve: Vec<f32>) {
        if self.curve.is_some() {
            panic!("InvalidStateError - cannot assign curve twice");
        }
        let clone = curve.clone();
        self.curve = Some(curve);
        self.registration.post_message(Some(clone));
    }
    #[must_use]
    pub fn oversample(&self) -> OverSampleType {
        self.oversample
    }
    pub fn set_oversample(&mut self, oversample: OverSampleType) {
        self.oversample = oversample;
        self.registration.post_message(oversample);
    }
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct ResamplerConfig {
    channels: usize,
    chunk_size_in: usize,
    sample_rate_in: usize,
    sample_rate_out: usize,
}
impl ResamplerConfig {
    fn upsample_x2(channels: usize, sample_rate: usize) -> Self {
        let chunk_size_in = RENDER_QUANTUM_SIZE * 2;
        let sample_rate_in = sample_rate;
        let sample_rate_out = sample_rate * 2;
        Self {
            channels,
            chunk_size_in,
            sample_rate_in,
            sample_rate_out,
        }
    }
    fn upsample_x4(channels: usize, sample_rate: usize) -> Self {
        let chunk_size_in = RENDER_QUANTUM_SIZE * 4;
        let sample_rate_in = sample_rate;
        let sample_rate_out = sample_rate * 4;
        Self {
            channels,
            chunk_size_in,
            sample_rate_in,
            sample_rate_out,
        }
    }
    fn downsample_x2(channels: usize, sample_rate: usize) -> Self {
        let chunk_size_in = RENDER_QUANTUM_SIZE;
        let sample_rate_in = sample_rate * 2;
        let sample_rate_out = sample_rate;
        Self {
            channels,
            chunk_size_in,
            sample_rate_in,
            sample_rate_out,
        }
    }
    fn downsample_x4(channels: usize, sample_rate: usize) -> Self {
        let chunk_size_in = RENDER_QUANTUM_SIZE;
        let sample_rate_in = sample_rate * 4;
        let sample_rate_out = sample_rate;
        Self {
            channels,
            chunk_size_in,
            sample_rate_in,
            sample_rate_out,
        }
    }
}
struct Resampler {
    config: ResamplerConfig,
    processor: FftFixedInOut<f32>,
    samples_out: Vec<Vec<f32>>,
}
impl Resampler {
    fn new(config: ResamplerConfig) -> Self {
        let ResamplerConfig {
            channels,
            chunk_size_in,
            sample_rate_in,
            sample_rate_out,
        } = &config;
        let processor =
            FftFixedInOut::new(*sample_rate_in, *sample_rate_out, *chunk_size_in, *channels)
                .unwrap();
        let samples_out = processor.output_buffer_allocate(true);
        Self {
            config,
            processor,
            samples_out,
        }
    }
    fn process<T>(&mut self, samples_in: &[T])
    where
        T: AsRef<[f32]>,
    {
        debug_assert_eq!(self.config.channels, samples_in.len());
        debug_assert!(samples_in
            .iter()
            .all(|channel| channel.as_ref().len() == self.processor.input_frames_next()));
        let (in_len, out_len) = self
            .processor
            .process_into_buffer(samples_in, &mut self.samples_out[..], None)
            .unwrap();
        debug_assert_eq!(in_len, samples_in[0].as_ref().len());
        debug_assert!(self
            .samples_out
            .iter()
            .all(|channel| channel.len() == out_len));
    }
    fn samples_out(&self) -> &[Vec<f32>] {
        &self.samples_out[..]
    }
    fn samples_out_mut(&mut self) -> &mut [Vec<f32>] {
        &mut self.samples_out[..]
    }
}
struct RendererConfig {
    oversample: OverSampleType,
    sample_rate: usize,
}
struct WaveShaperRenderer {
    oversample: OverSampleType,
    curve: Option<Vec<f32>>,
    sample_rate: usize,
    channels_x2: usize,
    channels_x4: usize,
    upsampler_x2: Resampler,
    upsampler_x4: Resampler,
    downsampler_x2: Resampler,
    downsampler_x4: Resampler,
}
impl AudioProcessor for WaveShaperRenderer {
    fn process(
        &mut self,
        inputs: &[AudioRenderQuantum],
        outputs: &mut [AudioRenderQuantum],
        _params: AudioParamValues<'_>,
        _scope: &RenderScope,
    ) -> bool {
        let input = &inputs[0];
        let output = &mut outputs[0];
        if input.is_silent() {
            output.make_silent();
            return false;
        }
        *output = input.clone();
        if let Some(curve) = &self.curve {
            match self.oversample {
                OverSampleType::None => {
                    output.modify_channels(|channel| {
                        channel.iter_mut().for_each(|o| *o = apply_curve(curve, *o));
                    });
                }
                OverSampleType::X2 => {
                    let channels = output.channels();
                    if channels.len() != self.channels_x2 {
                        self.channels_x2 = channels.len();
                        self.upsampler_x2 = Resampler::new(ResamplerConfig::upsample_x2(
                            self.channels_x2,
                            self.sample_rate,
                        ));
                        self.downsampler_x2 = Resampler::new(ResamplerConfig::downsample_x2(
                            self.channels_x2,
                            self.sample_rate,
                        ));
                    }
                    self.upsampler_x2.process(channels);
                    for channel in self.upsampler_x2.samples_out_mut().iter_mut() {
                        for s in channel.iter_mut() {
                            *s = apply_curve(curve, *s);
                        }
                    }
                    self.downsampler_x2.process(self.upsampler_x2.samples_out());
                    for (processed, output) in self
                        .downsampler_x2
                        .samples_out()
                        .iter()
                        .zip(output.channels_mut())
                    {
                        output.copy_from_slice(&processed[..]);
                    }
                }
                OverSampleType::X4 => {
                    let channels = output.channels();
                    if channels.len() != self.channels_x4 {
                        self.channels_x4 = channels.len();
                        self.upsampler_x4 = Resampler::new(ResamplerConfig::upsample_x4(
                            self.channels_x4,
                            self.sample_rate,
                        ));
                        self.downsampler_x4 = Resampler::new(ResamplerConfig::downsample_x4(
                            self.channels_x4,
                            self.sample_rate,
                        ));
                    }
                    self.upsampler_x4.process(channels);
                    for channel in self.upsampler_x4.samples_out_mut().iter_mut() {
                        for s in channel.iter_mut() {
                            *s = apply_curve(curve, *s);
                        }
                    }
                    self.downsampler_x4.process(self.upsampler_x4.samples_out());
                    for (processed, output) in self
                        .downsampler_x4
                        .samples_out()
                        .iter()
                        .zip(output.channels_mut())
                    {
                        output.copy_from_slice(&processed[..]);
                    }
                }
            }
        }
        false
    }
    fn onmessage(&mut self, msg: &mut dyn Any) {
        if let Some(&oversample) = msg.downcast_ref::<OverSampleType>() {
            self.oversample = oversample;
            return;
        }
        if let Some(curve) = msg.downcast_mut::<Option<Vec<f32>>>() {
            std::mem::swap(&mut self.curve, curve);
            return;
        }
        log::warn!("WaveShaperRenderer: Dropping incoming message {msg:?}");
    }
}
impl WaveShaperRenderer {
    #[allow(clippy::missing_const_for_fn)]
    fn new(config: RendererConfig) -> Self {
        let RendererConfig {
            sample_rate,
            oversample,
        } = config;
        let channels_x2 = 1;
        let channels_x4 = 1;
        let upsampler_x2 = Resampler::new(ResamplerConfig::upsample_x2(channels_x2, sample_rate));
        let downsampler_x2 =
            Resampler::new(ResamplerConfig::downsample_x2(channels_x2, sample_rate));
        let upsampler_x4 = Resampler::new(ResamplerConfig::upsample_x4(channels_x2, sample_rate));
        let downsampler_x4 =
            Resampler::new(ResamplerConfig::downsample_x4(channels_x2, sample_rate));
        Self {
            oversample,
            curve: None,
            sample_rate,
            channels_x2,
            channels_x4,
            upsampler_x2,
            upsampler_x4,
            downsampler_x2,
            downsampler_x4,
        }
    }
}
#[inline]
fn apply_curve(curve: &[f32], input: f32) -> f32 {
    if curve.is_empty() {
        return 0.;
    }
    let n = curve.len() as f32;
    let v = (n - 1.) / 2.0 * (input + 1.);
    if v <= 0. {
        curve[0]
    } else if v >= n - 1. {
        curve[(n - 1.) as usize]
    } else {
        let k = v.floor();
        let f = v - k;
        (1. - f) * curve[k as usize] + f * curve[(k + 1.) as usize]
    }
}
#[cfg(test)]
mod tests {
    use float_eq::assert_float_eq;
    use crate::context::OfflineAudioContext;
    use crate::node::AudioScheduledSourceNode;
    use super::*;
    const LENGTH: usize = 555;
    #[test]
    fn build_with_new() {
        let context = OfflineAudioContext::new(2, LENGTH, 44_100.);
        let _shaper = WaveShaperNode::new(&context, WaveShaperOptions::default());
    }
    #[test]
    fn build_with_factory_func() {
        let context = OfflineAudioContext::new(2, LENGTH, 44_100.);
        let _shaper = context.create_wave_shaper();
    }
    #[test]
    fn test_default_options() {
        let context = OfflineAudioContext::new(2, LENGTH, 44_100.);
        let shaper = WaveShaperNode::new(&context, WaveShaperOptions::default());
        assert_eq!(shaper.curve(), None);
        assert_eq!(shaper.oversample(), OverSampleType::None);
    }
    #[test]
    fn test_user_defined_options() {
        let context = OfflineAudioContext::new(2, LENGTH, 44_100.);
        let options = WaveShaperOptions {
            curve: Some(vec![1.0]),
            oversample: OverSampleType::X2,
            ..Default::default()
        };
        let shaper = WaveShaperNode::new(&context, options);
        context.start_rendering_sync();
        assert_eq!(shaper.curve(), Some(&[1.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X2);
    }
    #[test]
    #[should_panic]
    fn change_a_curve_for_another_curve_should_panic() {
        let context = OfflineAudioContext::new(2, LENGTH, 44_100.);
        let options = WaveShaperOptions {
            curve: Some(vec![1.0]),
            oversample: OverSampleType::X2,
            ..Default::default()
        };
        let mut shaper = WaveShaperNode::new(&context, options);
        assert_eq!(shaper.curve(), Some(&[1.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X2);
        shaper.set_curve(vec![2.0]);
        shaper.set_oversample(OverSampleType::X4);
        context.start_rendering_sync();
        assert_eq!(shaper.curve(), Some(&[2.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X4);
    }
    #[test]
    fn change_none_for_curve_after_build() {
        let context = OfflineAudioContext::new(2, LENGTH, 44_100.);
        let options = WaveShaperOptions {
            curve: None,
            oversample: OverSampleType::X2,
            ..Default::default()
        };
        let mut shaper = WaveShaperNode::new(&context, options);
        assert_eq!(shaper.curve(), None);
        assert_eq!(shaper.oversample(), OverSampleType::X2);
        shaper.set_curve(vec![2.0]);
        shaper.set_oversample(OverSampleType::X4);
        context.start_rendering_sync();
        assert_eq!(shaper.curve(), Some(&[2.0][..]));
        assert_eq!(shaper.oversample(), OverSampleType::X4);
    }
    #[test]
    fn test_shape_boundaries() {
        let sample_rate = 44100.;
        let context = OfflineAudioContext::new(1, 3 * RENDER_QUANTUM_SIZE, sample_rate);
        let mut shaper = context.create_wave_shaper();
        let curve = vec![-0.5, 0., 0.5];
        shaper.set_curve(curve);
        shaper.connect(&context.destination());
        let mut data = vec![0.; 3 * RENDER_QUANTUM_SIZE];
        let mut expected = vec![0.; 3 * RENDER_QUANTUM_SIZE];
        for i in 0..(3 * RENDER_QUANTUM_SIZE) {
            if i < RENDER_QUANTUM_SIZE {
                data[i] = -1.;
                expected[i] = -0.5;
            } else if i < 2 * RENDER_QUANTUM_SIZE {
                data[i] = 0.;
                expected[i] = 0.;
            } else {
                data[i] = 1.;
                expected[i] = 0.5;
            }
        }
        let mut buffer = context.create_buffer(1, 3 * RENDER_QUANTUM_SIZE, sample_rate);
        buffer.copy_to_channel(&data, 0);
        let mut src = context.create_buffer_source();
        src.connect(&shaper);
        src.set_buffer(buffer);
        src.start_at(0.);
        let result = context.start_rendering_sync();
        let channel = result.get_channel_data(0);
        assert_float_eq!(channel[..], expected[..], abs_all <= 0.);
    }
    #[test]
    fn test_shape_interpolation() {
        let sample_rate = 44100.;
        let context = OfflineAudioContext::new(1, RENDER_QUANTUM_SIZE, sample_rate);
        let mut shaper = context.create_wave_shaper();
        let curve = vec![-0.5, 0., 0.5];
        shaper.set_curve(curve);
        shaper.connect(&context.destination());
        let mut data = vec![0.; RENDER_QUANTUM_SIZE];
        let mut expected = vec![0.; RENDER_QUANTUM_SIZE];
        for i in 0..RENDER_QUANTUM_SIZE {
            let sample = i as f32 / (RENDER_QUANTUM_SIZE as f32) * 2. - 1.;
            data[i] = sample;
            expected[i] = sample / 2.;
        }
        let mut buffer = context.create_buffer(1, 3 * RENDER_QUANTUM_SIZE, sample_rate);
        buffer.copy_to_channel(&data, 0);
        let mut src = context.create_buffer_source();
        src.connect(&shaper);
        src.set_buffer(buffer);
        src.start_at(0.);
        let result = context.start_rendering_sync();
        let channel = result.get_channel_data(0);
        assert_float_eq!(channel[..], expected[..], abs_all <= 0.);
    }
}