Skip to main content

speech_prep/decoder/
resampler.rs

1use crate::error::{Error, Result};
2
/// Stateless resampler that converts audio between sample rates by linear interpolation.
///
/// The type holds no data: its operations are exposed as associated functions,
/// so values of it are trivially copyable and default-constructible.
#[derive(Clone, Copy, Debug, Default)]
pub struct SampleRateConverter;
6
7impl SampleRateConverter {
8    /// Standard target sample rate (16kHz) used across the pipeline.
9    pub const TARGET_SAMPLE_RATE: u32 = 16_000;
10
11    /// Resample audio between arbitrary rates using linear interpolation.
12    pub fn resample(
13        samples: &[f32],
14        channels: u8,
15        from_rate: u32,
16        to_rate: u32,
17    ) -> Result<Vec<f32>> {
18        if channels == 0 {
19            return Err(Error::InvalidInput("channel count cannot be zero".into()));
20        }
21        if from_rate == 0 {
22            return Err(Error::InvalidInput(
23                "input sample rate cannot be zero".into(),
24            ));
25        }
26        if to_rate == 0 {
27            return Err(Error::InvalidInput(
28                "output sample rate cannot be zero".into(),
29            ));
30        }
31
32        if from_rate == to_rate {
33            return Ok(samples.to_vec());
34        }
35        if samples.is_empty() {
36            return Ok(Vec::new());
37        }
38
39        let channel_count = channels as usize;
40        if !samples.len().is_multiple_of(channel_count) {
41            return Err(Error::InvalidInput(format!(
42                "sample count {} not divisible by channel count {}",
43                samples.len(),
44                channels
45            )));
46        }
47
48        let frames_in = samples.len() / channel_count;
49        if frames_in == 0 {
50            return Ok(Vec::new());
51        }
52
53        let ratio = f64::from(to_rate) / f64::from(from_rate);
54        let frames_out = (frames_in as f64 * ratio).ceil() as usize;
55        let output_len = frames_out * channel_count;
56        let mut output = Vec::with_capacity(output_len);
57
58        for frame_out in 0..frames_out {
59            let input_pos = (frame_out as f64) / ratio;
60            let idx = input_pos.floor() as usize;
61            let frac = (input_pos - idx as f64) as f32;
62
63            for channel_idx in 0..channel_count {
64                let sample = Self::interpolate_channel(
65                    samples,
66                    channel_count,
67                    frames_in,
68                    channel_idx,
69                    idx,
70                    frac,
71                );
72                output.push(sample);
73            }
74        }
75
76        Ok(output)
77    }
78
79    /// Convenience helper that resamples directly to 16kHz.
80    pub fn resample_to_16khz(samples: &[f32], channels: u8, from_rate: u32) -> Result<Vec<f32>> {
81        Self::resample(samples, channels, from_rate, Self::TARGET_SAMPLE_RATE)
82    }
83
84    #[inline]
85    #[allow(clippy::indexing_slicing)]
86    fn interpolate_channel(
87        samples: &[f32],
88        channel_count: usize,
89        frames_in: usize,
90        channel_idx: usize,
91        frame_idx: usize,
92        frac: f32,
93    ) -> f32 {
94        if frames_in == 0 {
95            return 0.0;
96        }
97
98        let idx_clamped = frame_idx.min(frames_in - 1);
99        let base = idx_clamped * channel_count + channel_idx;
100        let s0 = samples[base];
101
102        let next_frame = idx_clamped + 1;
103        let s1 = if next_frame < frames_in {
104            samples[next_frame * channel_count + channel_idx]
105        } else {
106            s0
107        };
108
109        frac.mul_add(s1 - s0, s0)
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    type TestResult<T> = std::result::Result<T, String>;

    #[test]
    fn test_resample_identity_16khz() -> TestResult<()> {
        let input = vec![0.0, 0.5, 1.0, 0.5, 0.0];
        let output =
            SampleRateConverter::resample(&input, 1, 16_000, 16_000).map_err(|e| e.to_string())?;

        assert_eq!(output, input);
        Ok(())
    }

    #[test]
    fn test_resample_44100_to_16000() -> TestResult<()> {
        let input = vec![0.0f32; 44_100];
        let output =
            SampleRateConverter::resample(&input, 1, 44_100, 16_000).map_err(|e| e.to_string())?;

        assert_eq!(output.len(), 16_000);
        assert!(output.iter().all(|&s| s.abs() < f32::EPSILON));
        Ok(())
    }

    #[test]
    fn test_resample_48000_to_16000() -> TestResult<()> {
        let input = vec![0.0f32; 48_000];
        let output =
            SampleRateConverter::resample(&input, 1, 48_000, 16_000).map_err(|e| e.to_string())?;

        assert_eq!(output.len(), 16_000);
        Ok(())
    }

    #[test]
    fn test_resample_8000_to_16000() -> TestResult<()> {
        let input = vec![0.0f32; 8_000];
        let output =
            SampleRateConverter::resample(&input, 1, 8_000, 16_000).map_err(|e| e.to_string())?;

        assert_eq!(output.len(), 16_000);
        Ok(())
    }

    #[test]
    fn test_resample_preserves_amplitude() -> TestResult<()> {
        let input = vec![0.0, 0.5, 1.0, 0.5, 0.0, -0.5, -1.0, -0.5];
        let output = SampleRateConverter::resample(&input, 1, 8, 16).map_err(|e| e.to_string())?;

        let max_input = input.iter().map(|s| s.abs()).fold(0.0f32, f32::max);
        let max_output = output.iter().map(|s| s.abs()).fold(0.0f32, f32::max);

        assert!((max_input - max_output).abs() < 0.05);
        Ok(())
    }

    #[test]
    fn test_resample_to_16khz_helper() -> TestResult<()> {
        let input = vec![0.0f32; 8_000];
        let output =
            SampleRateConverter::resample_to_16khz(&input, 1, 8_000).map_err(|e| e.to_string())?;
        assert_eq!(output.len(), 16_000);
        Ok(())
    }

    #[test]
    fn test_resample_reject_zero_channels() {
        let samples = vec![0.0, 0.0];
        let result = SampleRateConverter::resample(&samples, 0, 16_000, 8_000);
        assert!(result.is_err());
    }

    #[test]
    fn test_resample_reject_misaligned_samples() {
        let samples = vec![0.0, 0.0, 0.0];
        let result = SampleRateConverter::resample(&samples, 2, 16_000, 8_000);
        assert!(result.is_err());
    }

    #[test]
    fn test_resample_reject_zero_input_rate() {
        let samples = [0.0f32; 2];
        assert!(SampleRateConverter::resample(&samples, 1, 0, 16_000).is_err());
    }

    #[test]
    fn test_resample_reject_zero_output_rate() {
        let samples = [0.0f32; 2];
        assert!(SampleRateConverter::resample(&samples, 1, 16_000, 0).is_err());
    }

    #[test]
    fn test_resample_empty_input() -> TestResult<()> {
        let output =
            SampleRateConverter::resample(&[], 2, 44_100, 16_000).map_err(|e| e.to_string())?;
        assert!(output.is_empty());
        Ok(())
    }

    #[test]
    fn test_resample_linear_midpoints() -> TestResult<()> {
        // Doubling the rate inserts the midpoint; the trailing frame holds the last sample.
        let input = [0.0f32, 1.0];
        let output = SampleRateConverter::resample(&input, 1, 1, 2).map_err(|e| e.to_string())?;
        assert_eq!(output, [0.0, 0.5, 1.0, 1.0]);
        Ok(())
    }

    #[test]
    fn test_resample_keeps_channels_interleaved() -> TestResult<()> {
        // Frames: (L=0, R=10), (L=1, R=11); each channel is interpolated on its own.
        let input = [0.0f32, 10.0, 1.0, 11.0];
        let output = SampleRateConverter::resample(&input, 2, 1, 2).map_err(|e| e.to_string())?;
        assert_eq!(output, [0.0, 10.0, 0.5, 10.5, 1.0, 11.0, 1.0, 11.0]);
        Ok(())
    }
}