// rustwav_lib/resampler.rs

use std::{
    cmp::min,
    error::Error,
    fmt::{self, Debug, Formatter},
    sync::Arc,
};

use rustfft::{num_complex::Complex, Fft, FftPlanner};

/// Errors reported by [`Resampler`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ResamplerError {
    /// A buffer length (input chunk or desired output) is incompatible with the
    /// resampler's FFT size or processing size; the payload explains which.
    SizeError(String),
}

impl fmt::Display for ResamplerError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            Self::SizeError(msg) => write!(f, "resampler size error: {msg}"),
        }
    }
}

// Lets callers propagate the error with `?` into `Box<dyn Error>` / `anyhow`.
impl Error for ResamplerError {}

10#[derive(Clone)]
11pub struct Resampler {
12    fft_forward: Arc<dyn Fft<f64>>,
13    fft_inverse: Arc<dyn Fft<f64>>,
14    fft_size: usize,
15    normalize_scaler: f64,
16}
17
18fn get_average(complexes: &[Complex<f64>]) -> Complex<f64> {
19    let sum: Complex<f64> = complexes.iter().copied().sum();
20    let scaler = 1.0 / complexes.len() as f64;
21    Complex::<f64> {
22        re: sum.re * scaler,
23        im: sum.im * scaler,
24    }
25}
26
27fn interpolate(c1: Complex<f64>, c2: Complex<f64>, s: f64) -> Complex<f64> {
28    c1 + (c2 - c1) * s
29}
30
31// How the Resampler works
32// For audio stretching:
33//   1. The input audio remains its original length, and zero-padding is applied at the end to reach the target length.
34//   2. Perform FFT transformation to obtain the frequency domain.
35//   3. In the frequency domain, scale down the frequency values proportionally (shift them lower).
36//   4. Perform inverse FFT to obtain the stretched audio.
37// 
38// For audio compression:
39//   1. Take the input audio.
40//   2. Perform FFT transformation.
41//   3. In the frequency domain, scale up the frequency values proportionally (shift them higher).
42//   4. Perform inverse FFT to obtain audio with increased pitch but unchanged length.
43//   5. Truncate the audio to shorten its duration.
44// 
45// This implies: the FFT length must be chosen as the longest possible length involved.
46impl Resampler {
47    pub fn new(fft_size: usize) -> Self {
48        let mut planner = FftPlanner::new();
49        if fft_size & 1 != 0 {
50            panic!("The input size and the output size must be times of 2, got {fft_size}");
51        }
52        Self {
53            fft_forward: planner.plan_fft_forward(fft_size),
54            fft_inverse: planner.plan_fft_inverse(fft_size),
55            fft_size,
56            normalize_scaler: 1.0 / fft_size as f64,
57        }
58    }
59
60    // desired_length: The target audio length to achieve, which must not exceed the FFT size.
61    // When samples.len() < desired_length, it indicates audio stretching to desired_length.
62    // When samples.len() > desired_length, it indicates audio compression to desired_length.
63    pub fn resample_core(&self, samples: &[f32], desired_length: usize) -> Result<Vec<f32>, ResamplerError> {
64        const INTERPOLATE_UPSCALE: bool = true;
65        const INTERPOLATE_DNSCALE: bool = true;
66
67        let input_size = samples.len();
68        if input_size == desired_length {
69            return Ok(samples.to_vec());
70        }
71
72        if desired_length > self.fft_size {
73            return Err(ResamplerError::SizeError(format!("The desired size {desired_length} must not exceed the FFT size {}", self.fft_size)));
74        }
75
76        let mut fftbuf: Vec<Complex<f64>> = samples.iter().map(|sample: &f32| -> Complex<f64> {Complex{re: *sample as f64, im: 0.0}}).collect();
77
78        if fftbuf.len() <= self.fft_size {
79            fftbuf.resize(self.fft_size, Complex{re: 0.0, im: 0.0});
80        } else {
81            return Err(ResamplerError::SizeError(format!("The input size {} must not exceed the FFT size {}", fftbuf.len(), self.fft_size)));
82        }
83
84        // 进行 FFT 正向变换
85        self.fft_forward.process(&mut fftbuf);
86
87        // 准备进行插值
88        let mut fftdst = vec![Complex::<f64>{re: 0.0, im: 0.0}; self.fft_size];
89
90        let half = self.fft_size / 2;
91        let back = self.fft_size - 1;
92        let scaling = desired_length as f64 / input_size as f64;
93        if input_size > desired_length {
94            // Input size exceeds output size, indicating audio compression.
95            // This implies stretching in the frequency domain (scaling up).
96            for i in 0..half {
97                let scaled = i as f64 * scaling;
98                let i1 = scaled.trunc() as usize;
99                let i2 = i1 + 1;
100                let s = scaled.fract();
101                if INTERPOLATE_DNSCALE {
102                    fftdst[i] = interpolate(fftbuf[i1], fftbuf[i2], s);
103                    fftdst[back - i] = interpolate(fftbuf[back - i1], fftbuf[back - i2], s);
104                } else {
105                    fftdst[i] = fftbuf[i1];
106                    fftdst[back - i] = fftbuf[back - i1];
107                }
108            }
109        } else {
110            // Input size is smaller than the output size, indicating audio stretching.
111            // This implies compression in the frequency domain (scaling down).
112            for i in 0..half {
113                let i1 = (i as f64 * scaling).trunc() as usize;
114                let i2 = ((i + 1) as f64 * scaling).trunc() as usize;
115                if i2 >= half {break;}
116                let j1 = back - i2;
117                let j2 = back - i1;
118                if INTERPOLATE_UPSCALE {
119                    fftdst[i] = get_average(&fftbuf[i1..i2]);
120                    fftdst[back - i] = get_average(&fftbuf[j1..j2]);
121                } else {
122                    fftdst[i] = fftbuf[i1];
123                    fftdst[back - i] = fftbuf[back - i1];
124                }
125            }
126        }
127
128        self.fft_inverse.process(&mut fftdst);
129
130        fftdst.truncate(desired_length);
131
132        Ok(fftdst.into_iter().map(|c| -> f32 {(c.re * self.normalize_scaler) as f32}).collect())
133    }
134
135    pub fn get_process_size(&self, orig_size: usize, src_sample_rate: u32, dst_sample_rate: u32) -> usize {
136        const MAX_INFRASOUND_FREQ: usize = 20;
137        // The processing unit size should be adjusted to work in "chunks per second", 
138        // and artifacts will vanish when the chunk count aligns with the maximum infrasonic frequency.
139        // Calling `self.get_desired_length()` determines the processed chunk size calculated based on the target sample rate.
140        if src_sample_rate == dst_sample_rate {
141            min(self.fft_size, orig_size)
142        } else {
143            min(self.fft_size, src_sample_rate as usize / MAX_INFRASOUND_FREQ)
144        }
145    }
146
147    pub fn get_desired_length(&self, proc_size: usize, src_sample_rate: u32, dst_sample_rate: u32) -> usize {
148        min(self.fft_size, proc_size * dst_sample_rate as usize / src_sample_rate as usize)
149    }
150
151    pub fn resample(&self, input: &[f32], src_sample_rate: u32, dst_sample_rate: u32) -> Result<Vec<f32>, ResamplerError> {
152        if src_sample_rate == dst_sample_rate {
153            Ok(input.to_vec())
154        } else {
155            let proc_size = self.get_process_size(self.fft_size, src_sample_rate, dst_sample_rate);
156            let desired_length = self.get_desired_length(proc_size, src_sample_rate, dst_sample_rate);
157            if input.len() > proc_size {
158                Err(ResamplerError::SizeError(format!("To resize the waveform, the input size should be {proc_size}, not {}", input.len())))
159            } else if src_sample_rate > dst_sample_rate {
160                // Source sample rate is higher than the target, indicating waveform compression.
161                self.resample_core(input, desired_length)
162            } else {
163                // Source sample rate is lower than the target, indicating waveform stretching.
164                // When the input length is less than the desired length, zero-padding is applied to the end.
165                input.to_vec().resize(proc_size, 0.0);
166                self.resample_core(input, desired_length)
167            }
168        }
169    }
170
171    pub fn get_fft_size(&self) -> usize {
172        self.fft_size
173    }
174}
175
176impl Debug for Resampler {
177    fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
178        fmt.debug_struct("Resampler")
179            .field("fft_forward", &format_args!("..."))
180            .field("fft_inverse", &format_args!("..."))
181            .field("fft_size", &self.fft_size)
182            .field("normalize_scaler", &self.normalize_scaler)
183            .finish()
184    }
185}