Skip to main content

wedeo_resample/
lib.rs

1// wedeo-resample: Audio resampling (libswresample equivalent)
2// Wraps the rubato crate behind a simple interleaved f32 interface.
3
4use audioadapter_buffers::direct::InterleavedSlice;
5use rubato::{
6    Async, FixedAsync, PolynomialDegree, SincInterpolationParameters, SincInterpolationType,
7    WindowFunction,
8};
9use wedeo_core::error::{Error, Result};
10
/// Interpolation algorithm used by [`Resampler`], trading speed for fidelity.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Quality {
    /// Cubic polynomial interpolation with no anti-aliasing filter.
    /// Cheapest option; choose it when throughput matters more than fidelity.
    Fast,
    /// Sinc interpolation with a medium-length filter — the balanced choice.
    Normal,
    /// Sinc interpolation with a long filter — best fidelity, most CPU.
    High,
}
22
23/// Audio resampler that accepts interleaved f32 samples.
24///
25/// Internally buffers partial chunks and drives rubato's fixed-input-size
26/// resampler, producing interleaved f32 output.
27pub struct Resampler {
28    inner: Box<dyn rubato::Resampler<f32>>,
29    channels: usize,
30    from_rate: u32,
31    to_rate: u32,
32    /// rubato's required input chunk size (frames per channel).
33    chunk_size: usize,
34    /// Accumulator for input frames that don't fill a complete chunk.
35    pending: Vec<f32>,
36}
37
/// Frames per channel handed to rubato on every call; passed as the fixed
/// input chunk size to the rubato constructors.
const DEFAULT_CHUNK_SIZE: usize = 1 << 10;
40
41fn map_rubato_err<E: std::fmt::Display>(e: E) -> Error {
42    Error::Other(format!("resample: {e}"))
43}
44
45impl Resampler {
46    /// Create a new resampler.
47    ///
48    /// * `from_rate` / `to_rate` — input / output sample rates in Hz.
49    /// * `channels` — number of interleaved channels (must be >= 1).
50    /// * `quality` — algorithm selection (see [`Quality`]).
51    pub fn new(from_rate: u32, to_rate: u32, channels: usize, quality: Quality) -> Result<Self> {
52        if from_rate == 0 || to_rate == 0 {
53            return Err(Error::InvalidArgument);
54        }
55        if channels == 0 {
56            return Err(Error::InvalidArgument);
57        }
58
59        let ratio = to_rate as f64 / from_rate as f64;
60        let chunk_size = DEFAULT_CHUNK_SIZE;
61
62        let inner: Box<dyn rubato::Resampler<f32>> = match quality {
63            Quality::Fast => Box::new(
64                Async::<f32>::new_poly(
65                    ratio,
66                    1.0,
67                    PolynomialDegree::Cubic,
68                    chunk_size,
69                    channels,
70                    FixedAsync::Input,
71                )
72                .map_err(map_rubato_err)?,
73            ),
74            Quality::Normal => {
75                let params = SincInterpolationParameters {
76                    sinc_len: 128,
77                    f_cutoff: 0.95,
78                    interpolation: SincInterpolationType::Linear,
79                    oversampling_factor: 128,
80                    window: WindowFunction::BlackmanHarris2,
81                };
82                Box::new(
83                    Async::<f32>::new_sinc(
84                        ratio,
85                        1.0,
86                        &params,
87                        chunk_size,
88                        channels,
89                        FixedAsync::Input,
90                    )
91                    .map_err(map_rubato_err)?,
92                )
93            }
94            Quality::High => {
95                let params = SincInterpolationParameters {
96                    sinc_len: 256,
97                    f_cutoff: 0.95,
98                    interpolation: SincInterpolationType::Cubic,
99                    oversampling_factor: 256,
100                    window: WindowFunction::BlackmanHarris2,
101                };
102                Box::new(
103                    Async::<f32>::new_sinc(
104                        ratio,
105                        1.0,
106                        &params,
107                        chunk_size,
108                        channels,
109                        FixedAsync::Input,
110                    )
111                    .map_err(map_rubato_err)?,
112                )
113            }
114        };
115
116        Ok(Self {
117            inner,
118            channels,
119            from_rate,
120            to_rate,
121            chunk_size,
122            pending: Vec::new(),
123        })
124    }
125
126    /// Feed interleaved f32 samples and receive resampled interleaved output.
127    ///
128    /// Input length must be a multiple of `channels`. Output length will also
129    /// be a multiple of `channels`.
130    ///
131    /// The resampler buffers internally, so output may be shorter or longer
132    /// than a naive ratio calculation — call [`flush`](Self::flush) at the end
133    /// to drain remaining samples.
134    pub fn process(&mut self, input: &[f32]) -> Result<Vec<f32>> {
135        if !input.len().is_multiple_of(self.channels) {
136            return Err(Error::InvalidArgument);
137        }
138
139        // Append new data to pending buffer.
140        self.pending.extend_from_slice(input);
141
142        let samples_per_chunk = self.chunk_size * self.channels;
143        let mut output = Vec::new();
144
145        // Process as many full chunks as we can.
146        while self.pending.len() >= samples_per_chunk {
147            // Split off one chunk to avoid borrowing self.pending while calling
148            // process_chunk(&mut self).
149            let rest = self.pending.split_off(samples_per_chunk);
150            let chunk = std::mem::replace(&mut self.pending, rest);
151            self.process_chunk(&chunk, &mut output)?;
152        }
153
154        Ok(output)
155    }
156
157    /// Flush remaining buffered samples by zero-padding to a full chunk.
158    ///
159    /// Call this once after the last `process()` call to retrieve the tail
160    /// of the resampled signal.
161    pub fn flush(&mut self) -> Result<Vec<f32>> {
162        if self.pending.is_empty() {
163            return Ok(Vec::new());
164        }
165
166        let samples_per_chunk = self.chunk_size * self.channels;
167
168        // Pad pending data to a full chunk with zeros.
169        let pending_frames = self.pending.len() / self.channels;
170        let partial_samples = self.pending.len() % self.channels;
171
172        // If pending isn't frame-aligned, pad to the next frame boundary first.
173        if partial_samples != 0 {
174            self.pending
175                .resize(self.pending.len() + (self.channels - partial_samples), 0.0);
176        }
177
178        // Now pad to a full chunk.
179        self.pending.resize(samples_per_chunk, 0.0);
180
181        let mut output = Vec::new();
182
183        let input_adapter = InterleavedSlice::new(&self.pending, self.channels, self.chunk_size)
184            .map_err(map_rubato_err)?;
185
186        let result = self
187            .inner
188            .process(&input_adapter, 0, None)
189            .map_err(map_rubato_err)?;
190
191        // The result is an InterleavedOwned — extract the interleaved data.
192        let out_data = result.take_data();
193        let out_frames = self.inner.output_frames_next();
194        // rubato may have allocated more than needed; only take valid frames.
195        let valid_samples = out_frames.min(out_data.len() / self.channels) * self.channels;
196        output.extend_from_slice(&out_data[..valid_samples]);
197
198        // Trim output to the expected number of frames based on the real
199        // pending frame count (not the zero-padded chunk).
200        let expected_out_frames = self.output_frames_estimate(pending_frames);
201        let expected_samples = expected_out_frames * self.channels;
202        if output.len() > expected_samples {
203            output.truncate(expected_samples);
204        }
205
206        self.pending.clear();
207        Ok(output)
208    }
209
210    /// Estimate the number of output frames for a given number of input frames.
211    pub fn output_frames_estimate(&self, input_frames: usize) -> usize {
212        (input_frames as u64 * self.to_rate as u64).div_ceil(self.from_rate as u64) as usize
213    }
214
215    /// Get the input sample rate.
216    pub fn from_rate(&self) -> u32 {
217        self.from_rate
218    }
219
220    /// Get the output sample rate.
221    pub fn to_rate(&self) -> u32 {
222        self.to_rate
223    }
224
225    /// Get the number of channels.
226    pub fn channels(&self) -> usize {
227        self.channels
228    }
229
230    /// Reset the resampler state and clear all internal buffers.
231    pub fn reset(&mut self) {
232        self.inner.reset();
233        self.pending.clear();
234    }
235
236    /// Process exactly one chunk of interleaved data through rubato and
237    /// append the interleaved output to `output`.
238    fn process_chunk(&mut self, chunk: &[f32], output: &mut Vec<f32>) -> Result<()> {
239        let frames = chunk.len() / self.channels;
240        let input_adapter =
241            InterleavedSlice::new(chunk, self.channels, frames).map_err(map_rubato_err)?;
242
243        let result = self
244            .inner
245            .process(&input_adapter, 0, None)
246            .map_err(map_rubato_err)?;
247
248        let out_data = result.take_data();
249        let out_frames = out_data.len() / self.channels;
250        output.extend_from_slice(&out_data[..out_frames * self.channels]);
251        Ok(())
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn create_resampler_all_qualities() {
        for quality in [Quality::Fast, Quality::Normal, Quality::High] {
            let r = Resampler::new(44100, 48000, 2, quality);
            assert!(r.is_ok(), "Failed to create resampler with {quality:?}");
            let r = r.unwrap();
            assert_eq!(r.from_rate(), 44100);
            assert_eq!(r.to_rate(), 48000);
            assert_eq!(r.channels(), 2);
        }
    }

    #[test]
    fn invalid_params() {
        assert!(Resampler::new(0, 48000, 2, Quality::Fast).is_err());
        assert!(Resampler::new(44100, 0, 2, Quality::Fast).is_err());
        assert!(Resampler::new(44100, 48000, 0, Quality::Fast).is_err());
    }

    #[test]
    fn process_silence() {
        let mut r = Resampler::new(44100, 48000, 1, Quality::Fast).unwrap();
        let input = vec![0.0f32; 44100]; // 1 second mono
        let output = r.process(&input).unwrap();
        let tail = r.flush().unwrap();
        let total_frames = output.len() + tail.len();
        // Should be approximately 48000 frames (1 second at 48 kHz).
        // Allow generous margin for resampler latency.
        assert!(
            total_frames > 40000 && total_frames < 56000,
            "Unexpected output length: {total_frames}"
        );
        // All input was silence, so output should be (approximately) silence.
        for &s in output.iter().chain(tail.iter()) {
            assert!(s.abs() < 1e-6, "Non-silent sample in silence resample: {s}");
        }
    }

    #[test]
    fn process_non_multiple_of_chunk() {
        // Feed an odd number of frames that doesn't align with chunk_size.
        let mut r = Resampler::new(48000, 16000, 2, Quality::Fast).unwrap();
        let frames = 3000; // not a multiple of 1024
        let input = vec![0.0f32; frames * 2];
        let output = r.process(&input).unwrap();
        let tail = r.flush().unwrap();
        let total_frames = (output.len() + tail.len()) / 2;
        let expected = r.output_frames_estimate(frames);
        // Allow 20% tolerance.
        assert!(
            (total_frames as f64 - expected as f64).abs() / expected as f64 <= 0.2,
            "Output frame count {total_frames} too far from estimate {expected}"
        );
    }

    #[test]
    fn chunking_is_transparent() {
        // Feeding the same signal in one call or in small odd-sized pieces
        // must hand rubato identical chunks and so yield identical output.
        let frames = 5000;
        let input: Vec<f32> = (0..frames * 2)
            .map(|i| (i % 97) as f32 / 97.0 - 0.5)
            .collect();

        let mut whole = Resampler::new(48000, 44100, 2, Quality::Fast).unwrap();
        let mut expected = whole.process(&input).unwrap();
        expected.extend(whole.flush().unwrap());

        let mut pieces = Resampler::new(48000, 44100, 2, Quality::Fast).unwrap();
        let mut actual = Vec::new();
        for piece in input.chunks(37 * 2) {
            actual.extend(pieces.process(piece).unwrap());
        }
        actual.extend(pieces.flush().unwrap());

        assert_eq!(actual, expected);
    }

    #[test]
    fn output_frames_estimate_rounds_up() {
        let r = Resampler::new(44100, 48000, 1, Quality::Fast).unwrap();
        assert_eq!(r.output_frames_estimate(0), 0);
        assert_eq!(r.output_frames_estimate(1), 2);
        assert_eq!(r.output_frames_estimate(44100), 48000);
    }

    #[test]
    fn flush_trims_padding_to_estimate() {
        let mut r = Resampler::new(44100, 48000, 2, Quality::Fast).unwrap();
        let output = r.process(&[0.25f32; 500 * 2]).unwrap();
        assert!(output.is_empty(), "less than one chunk must stay buffered");
        let tail = r.flush().unwrap();
        assert_eq!(tail.len(), r.output_frames_estimate(500) * 2);
    }

    #[test]
    fn flush_with_nothing_pending_is_empty() {
        let mut r = Resampler::new(44100, 48000, 1, Quality::Fast).unwrap();
        let output = r.process(&[0.0f32; DEFAULT_CHUNK_SIZE]).unwrap();
        assert!(!output.is_empty());
        assert!(r.flush().unwrap().is_empty());
    }

    #[test]
    fn reset_clears_state() {
        let mut r = Resampler::new(44100, 48000, 1, Quality::Fast).unwrap();
        let _ = r.process(&[0.5f32; 500]).unwrap();
        r.reset();
        // After reset, pending should be empty.
        let tail = r.flush().unwrap();
        assert!(tail.is_empty());
    }

    #[test]
    fn channel_mismatch_rejected() {
        let mut r = Resampler::new(44100, 48000, 2, Quality::Fast).unwrap();
        // 3 samples is not a multiple of 2 channels.
        let result = r.process(&[1.0, 2.0, 3.0]);
        assert!(result.is_err());
    }
}