Skip to main content

speech_prep/decoder/
wav.rs

1use std::io::{Cursor, Read, Seek};
2
3use crate::error::{Error, Result};
4
5use super::DecodedAudio;
6
/// Stateless decoder for WAV/PCM audio, built on top of the `hound` crate.
///
/// Accepts 16-bit and 24-bit integer PCM in mono or stereo layouts, and
/// scales every decoded sample so that it lies within `[-1.0, 1.0]`.
#[derive(Clone, Copy, Debug, Default)]
pub struct WavDecoder;
13
14impl WavDecoder {
15    /// Decode WAV audio from a byte slice.
16    pub fn decode(data: &[u8]) -> Result<DecodedAudio> {
17        let cursor = Cursor::new(data);
18        Self::decode_from_reader(cursor)
19    }
20
21    /// Decode WAV audio from any reader implementing `Read + Seek`.
22    pub fn decode_from_reader<R: Read + Seek>(reader: R) -> Result<DecodedAudio> {
23        let mut wav_reader = hound::WavReader::new(reader)
24            .map_err(|err| Error::InvalidInput(format!("failed to parse WAV header: {err}")))?;
25
26        let spec = wav_reader.spec();
27
28        if spec.sample_format != hound::SampleFormat::Int {
29            return Err(Error::InvalidInput(format!(
30                "unsupported WAV format: {:?} (only PCM is supported)",
31                spec.sample_format
32            )));
33        }
34
35        if spec.bits_per_sample != 16 && spec.bits_per_sample != 24 {
36            return Err(Error::InvalidInput(format!(
37                "unsupported bit depth: {} (only 16-bit and 24-bit PCM supported)",
38                spec.bits_per_sample
39            )));
40        }
41
42        if spec.channels > 2 {
43            return Err(Error::InvalidInput(format!(
44                "unsupported channel count: {} (only mono and stereo supported)",
45                spec.channels
46            )));
47        }
48
49        let samples = match spec.bits_per_sample {
50            16 => Self::decode_16bit(&mut wav_reader)?,
51            24 => Self::decode_24bit(&mut wav_reader)?,
52            _ => {
53                return Err(Error::InvalidInput(format!(
54                    "internal error: unhandled bit depth {}",
55                    spec.bits_per_sample
56                )));
57            }
58        };
59
60        let frame_count = samples.len() / spec.channels as usize;
61        let duration_sec = if spec.sample_rate > 0 {
62            frame_count as f64 / f64::from(spec.sample_rate)
63        } else {
64            0.0
65        };
66
67        Ok(DecodedAudio {
68            samples,
69            sample_rate: spec.sample_rate,
70            channels: spec.channels as u8,
71            bit_depth: spec.bits_per_sample,
72            duration_sec,
73        })
74    }
75
76    fn decode_16bit<R: Read + Seek>(wav_reader: &mut hound::WavReader<R>) -> Result<Vec<f32>> {
77        wav_reader
78            .samples::<i16>()
79            .map(|sample_result| {
80                sample_result.map(Self::normalize_i16).map_err(|err| {
81                    Error::InvalidInput(format!("failed to read 16-bit sample: {err}"))
82                })
83            })
84            .collect()
85    }
86
87    fn decode_24bit<R: Read + Seek>(wav_reader: &mut hound::WavReader<R>) -> Result<Vec<f32>> {
88        wav_reader
89            .samples::<i32>()
90            .map(|sample_result| {
91                sample_result.map(Self::normalize_i24).map_err(|err| {
92                    Error::InvalidInput(format!("failed to read 24-bit sample: {err}"))
93                })
94            })
95            .collect()
96    }
97
98    #[inline]
99    fn normalize_i16(sample: i16) -> f32 {
100        f32::from(sample) / 32768.0
101    }
102
103    #[inline]
104    fn normalize_i24(sample: i32) -> f32 {
105        (sample as f32) / 8_388_608.0
106    }
107}
108
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use super::*;

    /// Test-local result type: failures are carried as plain strings so that
    /// `?` can be used inside `#[test]` functions.
    type TestResult<T> = std::result::Result<T, String>;

    /// Tolerance, in seconds, used when comparing computed durations.
    const DURATION_TOLERANCE: f64 = 1e-6;

    #[test]
    fn test_decode_16bit_mono_44100hz() -> TestResult<()> {
        let wav_data = create_wav_header(44100, 1, 16, 4410)?; // 0.1s mono
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;

        assert_eq!(decoded.sample_rate, 44100);
        assert_eq!(decoded.channels, 1);
        assert_eq!(decoded.bit_depth, 16);
        assert_eq!(decoded.samples.len(), 4410);
        assert!((decoded.duration_sec - 0.1).abs() < DURATION_TOLERANCE);
        assert!(decoded.is_normalized());

        Ok(())
    }

    #[test]
    fn test_decode_16bit_stereo_48000hz() -> TestResult<()> {
        let wav_data = create_wav_header(48000, 2, 16, 9600)?; // 0.1s stereo (4800 frames)
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;

        assert_eq!(decoded.sample_rate, 48000);
        assert_eq!(decoded.channels, 2);
        assert_eq!(decoded.bit_depth, 16);
        assert_eq!(decoded.samples.len(), 9600);
        assert_eq!(decoded.frame_count(), 4800);
        assert!((decoded.duration_sec - 0.1).abs() < DURATION_TOLERANCE);
        assert!(decoded.is_normalized());

        Ok(())
    }

    #[test]
    fn test_decode_24bit_mono_96000hz() -> TestResult<()> {
        let wav_data = create_wav_header(96000, 1, 24, 9600)?; // 0.1s mono
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;

        assert_eq!(decoded.sample_rate, 96000);
        assert_eq!(decoded.channels, 1);
        assert_eq!(decoded.bit_depth, 24);
        assert_eq!(decoded.samples.len(), 9600);
        assert!((decoded.duration_sec - 0.1).abs() < DURATION_TOLERANCE);
        assert!(decoded.is_normalized());

        Ok(())
    }

    #[test]
    fn test_decode_24bit_stereo_192000hz() -> TestResult<()> {
        let wav_data = create_wav_header(192000, 2, 24, 19200)?; // 0.05s stereo (9600 frames)
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;

        assert_eq!(decoded.sample_rate, 192000);
        assert_eq!(decoded.channels, 2);
        assert_eq!(decoded.bit_depth, 24);
        assert_eq!(decoded.samples.len(), 19200);
        assert_eq!(decoded.frame_count(), 9600);
        // 9600 frames at 192 kHz is 0.05 s; check it like the other format tests do.
        assert!((decoded.duration_sec - 0.05).abs() < DURATION_TOLERANCE);
        assert!(decoded.is_normalized());

        Ok(())
    }

    #[test]
    fn test_decode_sine_wave_preserves_amplitude() -> TestResult<()> {
        let wav_data = create_sine_wave_wav(44100, 1, 16, 440.0, 0.1)?;
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;

        let max_amplitude = decoded
            .samples
            .iter()
            .map(|s| s.abs())
            .fold(0.0f32, f32::max);
        assert!(
            (max_amplitude - 0.8).abs() < 0.05,
            "expected max amplitude ~0.8, got {max_amplitude}"
        );

        Ok(())
    }

    #[test]
    fn test_reject_empty_data() {
        let result = WavDecoder::decode(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_reject_unsupported_channel_count() -> TestResult<()> {
        // One frame of three-channel audio: valid WAV, but not mono/stereo.
        let wav_data = create_wav_header(44_100, 3, 16, 3)?;
        assert!(WavDecoder::decode(&wav_data).is_err());
        Ok(())
    }

    #[test]
    fn test_reject_unsupported_bit_depth() -> TestResult<()> {
        // 32-bit integer PCM is a valid WAV encoding the decoder does not accept.
        let wav_data = create_wav_header(44_100, 1, 32, 4)?;
        assert!(WavDecoder::decode(&wav_data).is_err());
        Ok(())
    }

    #[test]
    fn test_decode_zero_samples() -> TestResult<()> {
        let wav_data = create_wav_header(44_100, 1, 16, 0)?;
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;
        assert_eq!(decoded.samples.len(), 0);
        assert_eq!(decoded.frame_count(), 0);
        Ok(())
    }

    #[test]
    fn test_decode_single_sample() -> TestResult<()> {
        let wav_data = create_wav_header(44_100, 1, 16, 1)?;
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;
        assert_eq!(decoded.samples.len(), 1);
        assert_eq!(decoded.frame_count(), 1);
        Ok(())
    }

    #[test]
    fn test_normalization_bounds_16bit() {
        let min_i16 = WavDecoder::normalize_i16(i16::MIN);
        let max_i16 = WavDecoder::normalize_i16(i16::MAX);
        let zero = WavDecoder::normalize_i16(0);

        assert!((-1.0..=1.0).contains(&min_i16));
        assert!((-1.0..=1.0).contains(&max_i16));
        assert!(zero.abs() < f32::EPSILON);
    }

    #[test]
    fn test_normalization_bounds_24bit() {
        // Extremes of the signed 24-bit range: -2^23 and 2^23 - 1.
        let min_i24 = WavDecoder::normalize_i24(-8_388_608);
        let max_i24 = WavDecoder::normalize_i24(8_388_607);
        let zero = WavDecoder::normalize_i24(0);

        assert!((-1.0..=1.0).contains(&min_i24));
        assert!((-1.0..=1.0).contains(&max_i24));
        assert!(zero.abs() < f32::EPSILON);
    }

    #[test]
    fn test_frame_count_calculation() -> TestResult<()> {
        let mono = create_wav_header(44_100, 1, 16, 4_410)?;
        let decoded_mono = WavDecoder::decode(&mono).map_err(|e| e.to_string())?;
        assert_eq!(decoded_mono.frame_count(), 4_410);

        let stereo = create_wav_header(44_100, 2, 16, 8_820)?;
        let decoded_stereo = WavDecoder::decode(&stereo).map_err(|e| e.to_string())?;
        assert_eq!(decoded_stereo.frame_count(), 4_410);
        Ok(())
    }

    #[test]
    fn test_duration_calculation_accuracy() -> TestResult<()> {
        let wav_data = create_wav_header(48_000, 2, 16, 96_000)?; // 1s stereo
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;
        assert!((decoded.duration_sec - 1.0).abs() < DURATION_TOLERANCE);
        Ok(())
    }

    /// Build a complete in-memory WAV file containing only silence.
    ///
    /// Despite the name, the result is a full file: header plus
    /// `num_samples` zero-valued integer samples. `num_samples` counts
    /// individual samples across all channels, not frames.
    ///
    /// Supports 16-, 24- and 32-bit integer PCM; any other bit depth yields
    /// an `Err` describing the unsupported depth.
    fn create_wav_header(
        sample_rate: u32,
        channels: u16,
        bits_per_sample: u16,
        num_samples: usize,
    ) -> TestResult<Vec<u8>> {
        let spec = hound::WavSpec {
            sample_rate,
            channels,
            bits_per_sample,
            sample_format: hound::SampleFormat::Int,
        };

        let mut cursor = Cursor::new(Vec::new());
        {
            // Scoped so the writer's mutable borrow of `cursor` ends before
            // the buffer is taken back out below.
            let mut writer = hound::WavWriter::new(&mut cursor, spec)
                .map_err(|err| format!("failed to create WAV writer: {err}"))?;

            for _ in 0..num_samples {
                match bits_per_sample {
                    16 => writer
                        .write_sample(0i16)
                        .map_err(|err| format!("failed to write 16-bit sample: {err}"))?,
                    // 24- and 32-bit samples are both handed to hound as `i32`.
                    24 | 32 => writer.write_sample(0i32).map_err(|err| {
                        format!("failed to write {bits_per_sample}-bit sample: {err}")
                    })?,
                    _ => {
                        return Err(format!("unsupported bit depth ({bits_per_sample})"));
                    }
                }
            }

            writer
                .finalize()
                .map_err(|err| format!("failed to finalize WAV: {err}"))?;
        }

        Ok(cursor.into_inner())
    }

    /// Build an in-memory WAV file holding a sine tone at 80% of full scale.
    ///
    /// The same sample value is written to every channel of a frame. Only
    /// 16- and 24-bit integer PCM are supported; other depths yield an `Err`.
    fn create_sine_wave_wav(
        sample_rate: u32,
        channels: u16,
        bits_per_sample: u16,
        frequency: f32,
        duration_sec: f32,
    ) -> TestResult<Vec<u8>> {
        let spec = hound::WavSpec {
            sample_rate,
            channels,
            bits_per_sample,
            sample_format: hound::SampleFormat::Int,
        };

        let mut cursor = Cursor::new(Vec::new());
        let mut writer = hound::WavWriter::new(&mut cursor, spec)
            .map_err(|err| format!("failed to create WAV writer for sine wave: {err}"))?;

        // Number of frames; truncation toward zero is intended.
        let num_samples = (sample_rate as f32 * duration_sec) as usize;
        // Peak value: 80% of the largest positive sample for the bit depth.
        let amplitude: f32 = match bits_per_sample {
            16 => 32767.0 * 0.8,
            24 => 8_388_607.0 * 0.8,
            _ => return Err(format!("unsupported bit depth ({bits_per_sample})")),
        };

        for i in 0..num_samples {
            let t = i as f32 / sample_rate as f32;
            let sample_f32 = amplitude * (2.0 * std::f32::consts::PI * frequency * t).sin();

            for _ in 0..channels {
                match bits_per_sample {
                    16 => writer
                        .write_sample(sample_f32 as i16)
                        .map_err(|err| format!("failed to write sine sample: {err}"))?,
                    24 => writer
                        .write_sample(sample_f32 as i32)
                        .map_err(|err| format!("failed to write sine sample: {err}"))?,
                    _ => return Err(format!("unsupported bit depth ({bits_per_sample})")),
                }
            }
        }

        // `finalize` consumes the writer, releasing its borrow of `cursor`.
        writer
            .finalize()
            .map_err(|err| format!("failed to finalize sine wave WAV: {err}"))?;
        Ok(cursor.into_inner())
    }
}