Skip to main content

speech_prep/decoder/
resampler.rs

1use crate::error::{Error, Result};
2
/// Stateless sample-rate converter for interleaved `f32` audio.
///
/// Conversion is done by linear interpolation between neighbouring frames;
/// the type carries no fields, so all work happens in associated functions.
#[derive(Clone, Copy, Debug, Default)]
pub struct SampleRateConverter;
6
7impl SampleRateConverter {
8    /// Standard target sample rate (16kHz) used across the pipeline.
9    pub const TARGET_SAMPLE_RATE: u32 = 16_000;
10
11    /// Construct a new converter.
12    #[must_use]
13    pub const fn new() -> Self {
14        Self
15    }
16
17    /// Resample audio between arbitrary rates using linear interpolation.
18    pub fn resample(
19        samples: &[f32],
20        channels: u8,
21        from_rate: u32,
22        to_rate: u32,
23    ) -> Result<Vec<f32>> {
24        if channels == 0 {
25            return Err(Error::InvalidInput("channel count cannot be zero".into()));
26        }
27        if from_rate == 0 {
28            return Err(Error::InvalidInput(
29                "input sample rate cannot be zero".into(),
30            ));
31        }
32        if to_rate == 0 {
33            return Err(Error::InvalidInput(
34                "output sample rate cannot be zero".into(),
35            ));
36        }
37
38        if from_rate == to_rate {
39            return Ok(samples.to_vec());
40        }
41        if samples.is_empty() {
42            return Ok(Vec::new());
43        }
44
45        let channel_count = channels as usize;
46        if !samples.len().is_multiple_of(channel_count) {
47            return Err(Error::InvalidInput(format!(
48                "sample count {} not divisible by channel count {}",
49                samples.len(),
50                channels
51            )));
52        }
53
54        let frames_in = samples.len() / channel_count;
55        if frames_in == 0 {
56            return Ok(Vec::new());
57        }
58
59        let ratio = f64::from(to_rate) / f64::from(from_rate);
60        let frames_out = (frames_in as f64 * ratio).ceil() as usize;
61        let output_len = frames_out * channel_count;
62        let mut output = Vec::with_capacity(output_len);
63
64        for frame_out in 0..frames_out {
65            let input_pos = (frame_out as f64) / ratio;
66            let idx = input_pos.floor() as usize;
67            let frac = (input_pos - idx as f64) as f32;
68
69            for channel_idx in 0..channel_count {
70                let sample = Self::interpolate_channel(
71                    samples,
72                    channel_count,
73                    frames_in,
74                    channel_idx,
75                    idx,
76                    frac,
77                );
78                output.push(sample);
79            }
80        }
81
82        Ok(output)
83    }
84
85    /// Convenience helper that resamples directly to 16kHz.
86    pub fn resample_to_16khz(samples: &[f32], channels: u8, from_rate: u32) -> Result<Vec<f32>> {
87        Self::resample(samples, channels, from_rate, Self::TARGET_SAMPLE_RATE)
88    }
89
90    #[inline]
91    #[allow(clippy::indexing_slicing)]
92    fn interpolate_channel(
93        samples: &[f32],
94        channel_count: usize,
95        frames_in: usize,
96        channel_idx: usize,
97        frame_idx: usize,
98        frac: f32,
99    ) -> f32 {
100        if frames_in == 0 {
101            return 0.0;
102        }
103
104        let idx_clamped = frame_idx.min(frames_in - 1);
105        let base = idx_clamped * channel_count + channel_idx;
106        let s0 = samples[base];
107
108        let next_frame = idx_clamped + 1;
109        let s1 = if next_frame < frames_in {
110            samples[next_frame * channel_count + channel_idx]
111        } else {
112            s0
113        };
114
115        frac.mul_add(s1 - s0, s0)
116    }
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    type TestResult<T> = std::result::Result<T, String>;

    #[test]
    fn test_resample_identity_16khz() -> TestResult<()> {
        let input = vec![0.0, 0.5, 1.0, 0.5, 0.0];
        let output =
            SampleRateConverter::resample(&input, 1, 16_000, 16_000).map_err(|e| e.to_string())?;

        assert_eq!(output, input);
        Ok(())
    }

    #[test]
    fn test_resample_44100_to_16000() -> TestResult<()> {
        let input = vec![0.0f32; 44_100];
        let output =
            SampleRateConverter::resample(&input, 1, 44_100, 16_000).map_err(|e| e.to_string())?;

        assert_eq!(output.len(), 16_000);
        assert!(output.iter().all(|&s| s.abs() < f32::EPSILON));
        Ok(())
    }

    #[test]
    fn test_resample_48000_to_16000() -> TestResult<()> {
        let input = vec![0.0f32; 48_000];
        let output =
            SampleRateConverter::resample(&input, 1, 48_000, 16_000).map_err(|e| e.to_string())?;

        assert_eq!(output.len(), 16_000);
        Ok(())
    }

    #[test]
    fn test_resample_8000_to_16000() -> TestResult<()> {
        let input = vec![0.0f32; 8_000];
        let output =
            SampleRateConverter::resample(&input, 1, 8_000, 16_000).map_err(|e| e.to_string())?;

        assert_eq!(output.len(), 16_000);
        Ok(())
    }

    #[test]
    fn test_resample_preserves_amplitude() -> TestResult<()> {
        let input = vec![0.0, 0.5, 1.0, 0.5, 0.0, -0.5, -1.0, -0.5];
        let output = SampleRateConverter::resample(&input, 1, 8, 16).map_err(|e| e.to_string())?;

        let max_input = input.iter().map(|s| s.abs()).fold(0.0f32, f32::max);
        let max_output = output.iter().map(|s| s.abs()).fold(0.0f32, f32::max);

        assert!((max_input - max_output).abs() < 0.05);
        Ok(())
    }

    #[test]
    fn test_resample_to_16khz_helper() -> TestResult<()> {
        let input = vec![0.0f32; 8_000];
        let output =
            SampleRateConverter::resample_to_16khz(&input, 1, 8_000).map_err(|e| e.to_string())?;
        assert_eq!(output.len(), 16_000);
        Ok(())
    }

    #[test]
    fn test_resample_reject_zero_channels() {
        let samples = vec![0.0, 0.0];
        let result = SampleRateConverter::resample(&samples, 0, 16_000, 8_000);
        assert!(result.is_err());
    }

    #[test]
    fn test_resample_reject_misaligned_samples() {
        let samples = vec![0.0, 0.0, 0.0];
        let result = SampleRateConverter::resample(&samples, 2, 16_000, 8_000);
        assert!(result.is_err());
    }

    /// Both zero-rate guards must reject input (previously untested paths).
    #[test]
    fn test_resample_reject_zero_rates() {
        let samples = vec![0.0, 0.0];
        assert!(SampleRateConverter::resample(&samples, 1, 0, 16_000).is_err());
        assert!(SampleRateConverter::resample(&samples, 1, 16_000, 0).is_err());
    }

    /// Empty but otherwise valid input yields empty output.
    #[test]
    fn test_resample_empty_input() -> TestResult<()> {
        let output =
            SampleRateConverter::resample(&[], 2, 44_100, 16_000).map_err(|e| e.to_string())?;
        assert!(output.is_empty());
        Ok(())
    }

    /// Channels are interpolated independently and stay interleaved; the last
    /// input frame is held once the read position passes it.
    #[test]
    fn test_resample_stereo_interleaving() -> TestResult<()> {
        // Left channel ramps 0 -> 1, right channel ramps 10 -> 11.
        let input = vec![0.0, 10.0, 1.0, 11.0];
        let output = SampleRateConverter::resample(&input, 2, 1, 2).map_err(|e| e.to_string())?;
        assert_eq!(output, vec![0.0, 10.0, 0.5, 10.5, 1.0, 11.0, 1.0, 11.0]);
        Ok(())
    }

    /// An integer downsampling factor picks every n-th input frame.
    #[test]
    fn test_resample_integer_decimation() -> TestResult<()> {
        let input: Vec<f32> = (0..7u8).map(f32::from).collect();
        let output = SampleRateConverter::resample(&input, 1, 3, 1).map_err(|e| e.to_string())?;
        let expected = [0.0f32, 3.0, 6.0];
        assert_eq!(output.len(), expected.len());
        assert!(output
            .iter()
            .zip(expected)
            .all(|(&got, want)| (got - want).abs() < 1e-6));
        Ok(())
    }
}