Skip to main content

speech_prep/decoder/
wav.rs

1use std::io::{Cursor, Read, Seek};
2
3use crate::error::{Error, Result};
4
5use super::DecodedAudio;
6
/// Stateless decoder for uncompressed WAV (integer PCM) audio, built on `hound`.
///
/// Accepts 16-bit or 24-bit PCM in mono or stereo and scales every sample to
/// an `f32` within the [-1.0, 1.0] range.
#[derive(Clone, Copy, Debug, Default)]
pub struct WavDecoder;
13
14impl WavDecoder {
15    /// Create a new WAV decoder instance.
16    #[must_use]
17    pub const fn new() -> Self {
18        Self
19    }
20
21    /// Decode WAV audio from a byte slice.
22    pub fn decode(data: &[u8]) -> Result<DecodedAudio> {
23        let cursor = Cursor::new(data);
24        Self::decode_from_reader(cursor)
25    }
26
27    /// Decode WAV audio from any reader implementing `Read + Seek`.
28    pub fn decode_from_reader<R: Read + Seek>(reader: R) -> Result<DecodedAudio> {
29        let mut wav_reader = hound::WavReader::new(reader)
30            .map_err(|err| Error::InvalidInput(format!("failed to parse WAV header: {err}")))?;
31
32        let spec = wav_reader.spec();
33
34        if spec.sample_format != hound::SampleFormat::Int {
35            return Err(Error::InvalidInput(format!(
36                "unsupported WAV format: {:?} (only PCM is supported)",
37                spec.sample_format
38            )));
39        }
40
41        if spec.bits_per_sample != 16 && spec.bits_per_sample != 24 {
42            return Err(Error::InvalidInput(format!(
43                "unsupported bit depth: {} (only 16-bit and 24-bit PCM supported)",
44                spec.bits_per_sample
45            )));
46        }
47
48        if spec.channels > 2 {
49            return Err(Error::InvalidInput(format!(
50                "unsupported channel count: {} (only mono and stereo supported)",
51                spec.channels
52            )));
53        }
54
55        let samples = match spec.bits_per_sample {
56            16 => Self::decode_16bit(&mut wav_reader)?,
57            24 => Self::decode_24bit(&mut wav_reader)?,
58            _ => {
59                return Err(Error::InvalidInput(format!(
60                    "internal error: unhandled bit depth {}",
61                    spec.bits_per_sample
62                )));
63            }
64        };
65
66        let frame_count = samples.len() / spec.channels as usize;
67        let duration_sec = if spec.sample_rate > 0 {
68            frame_count as f64 / f64::from(spec.sample_rate)
69        } else {
70            0.0
71        };
72
73        Ok(DecodedAudio {
74            samples,
75            sample_rate: spec.sample_rate,
76            channels: spec.channels as u8,
77            bit_depth: spec.bits_per_sample,
78            duration_sec,
79        })
80    }
81
82    fn decode_16bit<R: Read + Seek>(wav_reader: &mut hound::WavReader<R>) -> Result<Vec<f32>> {
83        wav_reader
84            .samples::<i16>()
85            .map(|sample_result| {
86                sample_result.map(Self::normalize_i16).map_err(|err| {
87                    Error::InvalidInput(format!("failed to read 16-bit sample: {err}"))
88                })
89            })
90            .collect()
91    }
92
93    fn decode_24bit<R: Read + Seek>(wav_reader: &mut hound::WavReader<R>) -> Result<Vec<f32>> {
94        wav_reader
95            .samples::<i32>()
96            .map(|sample_result| {
97                sample_result.map(Self::normalize_i24).map_err(|err| {
98                    Error::InvalidInput(format!("failed to read 24-bit sample: {err}"))
99                })
100            })
101            .collect()
102    }
103
104    #[inline]
105    fn normalize_i16(sample: i16) -> f32 {
106        f32::from(sample) / 32768.0
107    }
108
109    #[inline]
110    fn normalize_i24(sample: i32) -> f32 {
111        (sample as f32) / 8_388_608.0
112    }
113}
114
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use super::*;

    type TestResult<T> = std::result::Result<T, String>;

    #[test]
    fn test_decode_16bit_mono_44100hz() -> TestResult<()> {
        let wav_data = create_wav_header(44100, 1, 16, 4410)?; // 0.1s mono
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;

        assert_eq!(decoded.sample_rate, 44100);
        assert_eq!(decoded.channels, 1);
        assert_eq!(decoded.bit_depth, 16);
        assert_eq!(decoded.samples.len(), 4410);
        assert!((decoded.duration_sec - 0.1).abs() < 1e-6);
        assert!(decoded.is_normalized());

        Ok(())
    }

    #[test]
    fn test_decode_16bit_stereo_48000hz() -> TestResult<()> {
        let wav_data = create_wav_header(48000, 2, 16, 9600)?; // 0.1s stereo (4800 frames)
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;

        assert_eq!(decoded.sample_rate, 48000);
        assert_eq!(decoded.channels, 2);
        assert_eq!(decoded.bit_depth, 16);
        assert_eq!(decoded.samples.len(), 9600);
        assert_eq!(decoded.frame_count(), 4800);
        assert!((decoded.duration_sec - 0.1).abs() < 1e-6);
        assert!(decoded.is_normalized());

        Ok(())
    }

    #[test]
    fn test_decode_24bit_mono_96000hz() -> TestResult<()> {
        let wav_data = create_wav_header(96000, 1, 24, 9600)?; // 0.1s mono
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;

        assert_eq!(decoded.sample_rate, 96000);
        assert_eq!(decoded.channels, 1);
        assert_eq!(decoded.bit_depth, 24);
        assert_eq!(decoded.samples.len(), 9600);
        assert!((decoded.duration_sec - 0.1).abs() < 1e-6);
        assert!(decoded.is_normalized());

        Ok(())
    }

    #[test]
    fn test_decode_24bit_stereo_192000hz() -> TestResult<()> {
        let wav_data = create_wav_header(192000, 2, 24, 19200)?; // 0.05s stereo (9600 frames)
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;

        assert_eq!(decoded.sample_rate, 192000);
        assert_eq!(decoded.channels, 2);
        assert_eq!(decoded.bit_depth, 24);
        assert_eq!(decoded.samples.len(), 19200);
        assert_eq!(decoded.frame_count(), 9600);
        assert!((decoded.duration_sec - 0.05).abs() < 1e-6);
        assert!(decoded.is_normalized());

        Ok(())
    }

    #[test]
    fn test_decode_sine_wave_preserves_amplitude() -> TestResult<()> {
        let wav_data = create_sine_wave_wav(44100, 1, 16, 440.0, 0.1)?;
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;

        let max_amplitude = decoded
            .samples
            .iter()
            .map(|s| s.abs())
            .fold(0.0f32, f32::max);
        assert!(
            (max_amplitude - 0.8).abs() < 0.05,
            "expected max amplitude ~0.8, got {max_amplitude}"
        );

        Ok(())
    }

    #[test]
    fn test_reject_empty_data() {
        let result = WavDecoder::decode(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_reject_float_format() -> TestResult<()> {
        assert_spec_rejected(hound::WavSpec {
            sample_rate: 44_100,
            channels: 1,
            bits_per_sample: 32,
            sample_format: hound::SampleFormat::Float,
        })
    }

    #[test]
    fn test_reject_unsupported_bit_depths() -> TestResult<()> {
        for bits_per_sample in [8, 32] {
            assert_spec_rejected(pcm_spec(44_100, 1, bits_per_sample))?;
        }
        Ok(())
    }

    #[test]
    fn test_reject_more_than_two_channels() -> TestResult<()> {
        assert_spec_rejected(pcm_spec(44_100, 3, 16))
    }

    #[test]
    fn test_decode_zero_samples() -> TestResult<()> {
        let wav_data = create_wav_header(44_100, 1, 16, 0)?;
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;
        assert_eq!(decoded.samples.len(), 0);
        assert_eq!(decoded.frame_count(), 0);
        Ok(())
    }

    #[test]
    fn test_decode_single_sample() -> TestResult<()> {
        let wav_data = create_wav_header(44_100, 1, 16, 1)?;
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;
        assert_eq!(decoded.samples.len(), 1);
        assert_eq!(decoded.frame_count(), 1);
        Ok(())
    }

    #[test]
    fn test_normalization_bounds_16bit() {
        let min_i16 = WavDecoder::normalize_i16(i16::MIN);
        let max_i16 = WavDecoder::normalize_i16(i16::MAX);
        let zero = WavDecoder::normalize_i16(0);

        assert!((-1.0..=1.0).contains(&min_i16));
        assert!((-1.0..=1.0).contains(&max_i16));
        // The most negative sample maps exactly to -1.0; the positive peak stays below 1.0.
        assert!((min_i16 + 1.0).abs() < f32::EPSILON);
        assert!(max_i16 < 1.0);
        assert!(zero.abs() < f32::EPSILON);
    }

    #[test]
    fn test_normalization_bounds_24bit() {
        let min_i24 = WavDecoder::normalize_i24(-8_388_608);
        let max_i24 = WavDecoder::normalize_i24(8_388_607);
        let zero = WavDecoder::normalize_i24(0);

        assert!((min_i24 + 1.0).abs() < f32::EPSILON);
        assert!((0.0..1.0).contains(&max_i24));
        assert!(zero.abs() < f32::EPSILON);
    }

    #[test]
    fn test_frame_count_calculation() -> TestResult<()> {
        let mono = create_wav_header(44_100, 1, 16, 4_410)?;
        let decoded_mono = WavDecoder::decode(&mono).map_err(|e| e.to_string())?;
        assert_eq!(decoded_mono.frame_count(), 4_410);

        let stereo = create_wav_header(44_100, 2, 16, 8_820)?;
        let decoded_stereo = WavDecoder::decode(&stereo).map_err(|e| e.to_string())?;
        assert_eq!(decoded_stereo.frame_count(), 4_410);
        Ok(())
    }

    #[test]
    fn test_duration_calculation_accuracy() -> TestResult<()> {
        let wav_data = create_wav_header(48_000, 2, 16, 96_000)?; // 1s stereo
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;
        assert!((decoded.duration_sec - 1.0).abs() < 1e-6);
        Ok(())
    }

    /// Build an integer-PCM `WavSpec`.
    fn pcm_spec(sample_rate: u32, channels: u16, bits_per_sample: u16) -> hound::WavSpec {
        hound::WavSpec {
            sample_rate,
            channels,
            bits_per_sample,
            sample_format: hound::SampleFormat::Int,
        }
    }

    /// Build a complete WAV file (header and data) of `num_samples` silent samples.
    ///
    /// `num_samples` counts individual samples across all channels, not frames.
    /// Only 16-bit and 24-bit depths are supported.
    fn create_wav_header(
        sample_rate: u32,
        channels: u16,
        bits_per_sample: u16,
        num_samples: usize,
    ) -> TestResult<Vec<u8>> {
        let spec = pcm_spec(sample_rate, channels, bits_per_sample);

        let mut cursor = Cursor::new(Vec::new());
        {
            let mut writer = hound::WavWriter::new(&mut cursor, spec)
                .map_err(|err| format!("failed to create WAV writer: {err}"))?;

            for _ in 0..num_samples {
                match bits_per_sample {
                    16 => writer
                        .write_sample(0i16)
                        .map_err(|err| format!("failed to write 16-bit sample: {err}"))?,
                    24 => writer
                        .write_sample(0i32)
                        .map_err(|err| format!("failed to write 24-bit sample: {err}"))?,
                    _ => {
                        return Err(format!("unsupported bit depth ({bits_per_sample})"));
                    }
                }
            }

            writer
                .finalize()
                .map_err(|err| format!("failed to finalize WAV: {err}"))?;
        }

        Ok(cursor.into_inner())
    }

    /// Build a WAV containing a sine wave at 80% of full scale.
    ///
    /// The same value is written to every channel of each frame.
    fn create_sine_wave_wav(
        sample_rate: u32,
        channels: u16,
        bits_per_sample: u16,
        frequency: f32,
        duration_sec: f32,
    ) -> TestResult<Vec<u8>> {
        let spec = pcm_spec(sample_rate, channels, bits_per_sample);

        let mut cursor = Cursor::new(Vec::new());
        let mut writer = hound::WavWriter::new(&mut cursor, spec)
            .map_err(|err| format!("failed to create WAV writer for sine wave: {err}"))?;

        let num_samples = (sample_rate as f32 * duration_sec) as usize;
        let amplitude = match bits_per_sample {
            16 => 32767.0 * 0.8,
            24 => 8_388_607.0 * 0.8,
            _ => return Err(format!("unsupported bit depth ({bits_per_sample})")),
        };

        for i in 0..num_samples {
            let t = i as f32 / sample_rate as f32;
            let sample_f32 = amplitude * (2.0 * std::f32::consts::PI * frequency * t).sin();

            for _ in 0..channels {
                match bits_per_sample {
                    16 => writer
                        .write_sample(sample_f32 as i16)
                        .map_err(|err| format!("failed to write sine sample: {err}"))?,
                    24 => writer
                        .write_sample(sample_f32 as i32)
                        .map_err(|err| format!("failed to write sine sample: {err}"))?,
                    _ => return Err(format!("unsupported bit depth ({bits_per_sample})")),
                }
            }
        }

        writer
            .finalize()
            .map_err(|err| format!("failed to finalize sine wave WAV: {err}"))?;
        Ok(cursor.into_inner())
    }

    /// Build a complete WAV file for `spec` that contains no samples.
    ///
    /// Spec validation runs before any sample is read, so this is enough to
    /// exercise the rejection paths.
    fn create_empty_wav(spec: hound::WavSpec) -> TestResult<Vec<u8>> {
        let mut cursor = Cursor::new(Vec::new());
        hound::WavWriter::new(&mut cursor, spec)
            .map_err(|err| format!("failed to create WAV writer: {err}"))?
            .finalize()
            .map_err(|err| format!("failed to finalize WAV: {err}"))?;
        Ok(cursor.into_inner())
    }

    /// Check that decoding a sample-less WAV with `spec` fails with `InvalidInput`.
    fn assert_spec_rejected(spec: hound::WavSpec) -> TestResult<()> {
        let wav_data = create_empty_wav(spec)?;
        match WavDecoder::decode(&wav_data) {
            Err(Error::InvalidInput(_)) => Ok(()),
            Err(other) => Err(format!("expected InvalidInput for {spec:?}, got: {other}")),
            Ok(_) => Err(format!("expected {spec:?} to be rejected, but it decoded")),
        }
    }
}