1use std::io::{Cursor, Read, Seek};
2
3use crate::error::{Error, Result};
4
5use super::DecodedAudio;
6
/// Stateless decoder for PCM WAV audio.
///
/// Accepts integer PCM at 16 or 24 bits per sample with one or two channels
/// and produces [`DecodedAudio`] whose samples are normalized `f32` values.
/// The type carries no data, so it is freely copyable and all instances are
/// equal.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WavDecoder;
13
14impl WavDecoder {
15 #[must_use]
17 pub const fn new() -> Self {
18 Self
19 }
20
21 pub fn decode(data: &[u8]) -> Result<DecodedAudio> {
23 let cursor = Cursor::new(data);
24 Self::decode_from_reader(cursor)
25 }
26
27 pub fn decode_from_reader<R: Read + Seek>(reader: R) -> Result<DecodedAudio> {
29 let mut wav_reader = hound::WavReader::new(reader)
30 .map_err(|err| Error::InvalidInput(format!("failed to parse WAV header: {err}")))?;
31
32 let spec = wav_reader.spec();
33
34 if spec.sample_format != hound::SampleFormat::Int {
35 return Err(Error::InvalidInput(format!(
36 "unsupported WAV format: {:?} (only PCM is supported)",
37 spec.sample_format
38 )));
39 }
40
41 if spec.bits_per_sample != 16 && spec.bits_per_sample != 24 {
42 return Err(Error::InvalidInput(format!(
43 "unsupported bit depth: {} (only 16-bit and 24-bit PCM supported)",
44 spec.bits_per_sample
45 )));
46 }
47
48 if spec.channels > 2 {
49 return Err(Error::InvalidInput(format!(
50 "unsupported channel count: {} (only mono and stereo supported)",
51 spec.channels
52 )));
53 }
54
55 let samples = match spec.bits_per_sample {
56 16 => Self::decode_16bit(&mut wav_reader)?,
57 24 => Self::decode_24bit(&mut wav_reader)?,
58 _ => {
59 return Err(Error::InvalidInput(format!(
60 "internal error: unhandled bit depth {}",
61 spec.bits_per_sample
62 )));
63 }
64 };
65
66 let frame_count = samples.len() / spec.channels as usize;
67 let duration_sec = if spec.sample_rate > 0 {
68 frame_count as f64 / f64::from(spec.sample_rate)
69 } else {
70 0.0
71 };
72
73 Ok(DecodedAudio {
74 samples,
75 sample_rate: spec.sample_rate,
76 channels: spec.channels as u8,
77 bit_depth: spec.bits_per_sample,
78 duration_sec,
79 })
80 }
81
82 fn decode_16bit<R: Read + Seek>(wav_reader: &mut hound::WavReader<R>) -> Result<Vec<f32>> {
83 wav_reader
84 .samples::<i16>()
85 .map(|sample_result| {
86 sample_result.map(Self::normalize_i16).map_err(|err| {
87 Error::InvalidInput(format!("failed to read 16-bit sample: {err}"))
88 })
89 })
90 .collect()
91 }
92
93 fn decode_24bit<R: Read + Seek>(wav_reader: &mut hound::WavReader<R>) -> Result<Vec<f32>> {
94 wav_reader
95 .samples::<i32>()
96 .map(|sample_result| {
97 sample_result.map(Self::normalize_i24).map_err(|err| {
98 Error::InvalidInput(format!("failed to read 24-bit sample: {err}"))
99 })
100 })
101 .collect()
102 }
103
104 #[inline]
105 fn normalize_i16(sample: i16) -> f32 {
106 f32::from(sample) / 32768.0
107 }
108
109 #[inline]
110 fn normalize_i24(sample: i32) -> f32 {
111 (sample as f32) / 8_388_608.0
112 }
113}
114
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use super::*;

    // The crate's own `Result` is in scope via `super::*`, so spell out std's.
    type TestResult<T> = std::result::Result<T, String>;

    #[test]
    fn test_decode_16bit_mono_44100hz() -> TestResult<()> {
        let wav_data = create_wav_header(44100, 1, 16, 4410)?;
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;

        assert_eq!(decoded.sample_rate, 44100);
        assert_eq!(decoded.channels, 1);
        assert_eq!(decoded.bit_depth, 16);
        assert_eq!(decoded.samples.len(), 4410);
        assert!((decoded.duration_sec - 0.1).abs() < 1e-6);
        assert!(decoded.is_normalized());

        Ok(())
    }

    #[test]
    fn test_decode_16bit_stereo_48000hz() -> TestResult<()> {
        let wav_data = create_wav_header(48000, 2, 16, 9600)?;
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;

        assert_eq!(decoded.sample_rate, 48000);
        assert_eq!(decoded.channels, 2);
        assert_eq!(decoded.bit_depth, 16);
        assert_eq!(decoded.samples.len(), 9600);
        assert_eq!(decoded.frame_count(), 4800);
        assert!((decoded.duration_sec - 0.1).abs() < 1e-6);
        assert!(decoded.is_normalized());

        Ok(())
    }

    #[test]
    fn test_decode_24bit_mono_96000hz() -> TestResult<()> {
        let wav_data = create_wav_header(96000, 1, 24, 9600)?;
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;

        assert_eq!(decoded.sample_rate, 96000);
        assert_eq!(decoded.channels, 1);
        assert_eq!(decoded.bit_depth, 24);
        assert_eq!(decoded.samples.len(), 9600);
        assert!((decoded.duration_sec - 0.1).abs() < 1e-6);
        assert!(decoded.is_normalized());

        Ok(())
    }

    #[test]
    fn test_decode_24bit_stereo_192000hz() -> TestResult<()> {
        let wav_data = create_wav_header(192000, 2, 24, 19200)?;
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;

        assert_eq!(decoded.sample_rate, 192000);
        assert_eq!(decoded.channels, 2);
        assert_eq!(decoded.bit_depth, 24);
        assert_eq!(decoded.samples.len(), 19200);
        // 19_200 samples / 2 channels = 9_600 frames at 192 kHz = 0.05 s.
        assert!((decoded.duration_sec - 0.05).abs() < 1e-6);
        assert!(decoded.is_normalized());

        Ok(())
    }

    #[test]
    fn test_decode_sine_wave_preserves_amplitude() -> TestResult<()> {
        let wav_data = create_sine_wave_wav(44100, 1, 16, 440.0, 0.1)?;
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;

        let max_amplitude = decoded
            .samples
            .iter()
            .map(|s| s.abs())
            .fold(0.0f32, f32::max);
        assert!(
            (max_amplitude - 0.8).abs() < 0.05,
            "expected max amplitude ~0.8, got {max_amplitude}"
        );

        Ok(())
    }

    #[test]
    fn test_decode_24bit_sine_wave_preserves_amplitude() -> TestResult<()> {
        let wav_data = create_sine_wave_wav(48_000, 2, 24, 440.0, 0.1)?;
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;

        assert_eq!(decoded.bit_depth, 24);
        assert_eq!(decoded.channels, 2);

        let max_amplitude = decoded
            .samples
            .iter()
            .map(|s| s.abs())
            .fold(0.0f32, f32::max);
        assert!(
            (max_amplitude - 0.8).abs() < 0.05,
            "expected max amplitude ~0.8, got {max_amplitude}"
        );

        Ok(())
    }

    #[test]
    fn test_reject_empty_data() {
        let result = WavDecoder::decode(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_reject_non_wav_data() {
        let result = WavDecoder::decode(b"definitely not a RIFF/WAVE file");
        assert!(matches!(result, Err(Error::InvalidInput(_))));
    }

    #[test]
    fn test_reject_float_format() -> TestResult<()> {
        let wav_data = create_empty_wav(hound::WavSpec {
            sample_rate: 44_100,
            channels: 1,
            bits_per_sample: 32,
            sample_format: hound::SampleFormat::Float,
        })?;
        let result = WavDecoder::decode(&wav_data);
        assert!(matches!(result, Err(Error::InvalidInput(_))));
        Ok(())
    }

    #[test]
    fn test_reject_8bit_pcm() -> TestResult<()> {
        let wav_data = create_empty_wav(hound::WavSpec {
            sample_rate: 44_100,
            channels: 1,
            bits_per_sample: 8,
            sample_format: hound::SampleFormat::Int,
        })?;
        let result = WavDecoder::decode(&wav_data);
        assert!(matches!(result, Err(Error::InvalidInput(_))));
        Ok(())
    }

    #[test]
    fn test_reject_more_than_two_channels() -> TestResult<()> {
        let wav_data = create_empty_wav(hound::WavSpec {
            sample_rate: 44_100,
            channels: 3,
            bits_per_sample: 16,
            sample_format: hound::SampleFormat::Int,
        })?;
        let result = WavDecoder::decode(&wav_data);
        assert!(matches!(result, Err(Error::InvalidInput(_))));
        Ok(())
    }

    #[test]
    fn test_decode_zero_samples() -> TestResult<()> {
        let wav_data = create_wav_header(44_100, 1, 16, 0)?;
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;
        assert_eq!(decoded.samples.len(), 0);
        assert_eq!(decoded.frame_count(), 0);
        Ok(())
    }

    #[test]
    fn test_decode_single_sample() -> TestResult<()> {
        let wav_data = create_wav_header(44_100, 1, 16, 1)?;
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;
        assert_eq!(decoded.samples.len(), 1);
        assert_eq!(decoded.frame_count(), 1);
        Ok(())
    }

    #[test]
    fn test_normalization_bounds_16bit() {
        let min_i16 = WavDecoder::normalize_i16(i16::MIN);
        let max_i16 = WavDecoder::normalize_i16(i16::MAX);
        let zero = WavDecoder::normalize_i16(0);

        assert!((-1.0..=1.0).contains(&min_i16));
        assert!((-1.0..=1.0).contains(&max_i16));
        // The most negative sample maps exactly to -1.0; the positive side
        // stays strictly below 1.0.
        assert!((min_i16 + 1.0).abs() < f32::EPSILON);
        assert!(max_i16 < 1.0);
        assert!(zero.abs() < f32::EPSILON);
    }

    #[test]
    fn test_normalization_bounds_24bit() {
        let min_i24 = WavDecoder::normalize_i24(-8_388_608);
        let max_i24 = WavDecoder::normalize_i24(8_388_607);
        let zero = WavDecoder::normalize_i24(0);

        assert!((min_i24 + 1.0).abs() < f32::EPSILON);
        assert!((-1.0..1.0).contains(&max_i24));
        assert!(zero.abs() < f32::EPSILON);
    }

    #[test]
    fn test_frame_count_calculation() -> TestResult<()> {
        let mono = create_wav_header(44_100, 1, 16, 4_410)?;
        let decoded_mono = WavDecoder::decode(&mono).map_err(|e| e.to_string())?;
        assert_eq!(decoded_mono.frame_count(), 4_410);

        let stereo = create_wav_header(44_100, 2, 16, 8_820)?;
        let decoded_stereo = WavDecoder::decode(&stereo).map_err(|e| e.to_string())?;
        assert_eq!(decoded_stereo.frame_count(), 4_410);
        Ok(())
    }

    #[test]
    fn test_duration_calculation_accuracy() -> TestResult<()> {
        let wav_data = create_wav_header(48_000, 2, 16, 96_000)?;
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;
        assert!((decoded.duration_sec - 1.0).abs() < 1e-6);
        Ok(())
    }

    /// Builds a complete in-memory WAV file (header plus data) containing
    /// `num_samples` zero-valued samples in total across all channels.
    ///
    /// For multi-channel specs, `num_samples` should be a multiple of
    /// `channels` so the file ends on a whole frame.
    fn create_wav_header(
        sample_rate: u32,
        channels: u16,
        bits_per_sample: u16,
        num_samples: usize,
    ) -> TestResult<Vec<u8>> {
        let spec = hound::WavSpec {
            sample_rate,
            channels,
            bits_per_sample,
            sample_format: hound::SampleFormat::Int,
        };

        let mut cursor = Cursor::new(Vec::new());
        {
            let mut writer = hound::WavWriter::new(&mut cursor, spec)
                .map_err(|err| format!("failed to create WAV writer: {err}"))?;

            for _ in 0..num_samples {
                match bits_per_sample {
                    16 => writer
                        .write_sample(0i16)
                        .map_err(|err| format!("failed to write 16-bit sample: {err}"))?,
                    24 => writer
                        .write_sample(0i32)
                        .map_err(|err| format!("failed to write 24-bit sample: {err}"))?,
                    _ => {
                        return Err(format!("unsupported bit depth ({bits_per_sample})"));
                    }
                }
            }

            writer
                .finalize()
                .map_err(|err| format!("failed to finalize WAV: {err}"))?;
        }

        Ok(cursor.into_inner())
    }

    /// Builds a WAV file with the given `spec` and no sample data.
    ///
    /// Used to exercise the decoder's spec validation with formats that the
    /// sample-writing helpers do not support.
    fn create_empty_wav(spec: hound::WavSpec) -> TestResult<Vec<u8>> {
        let mut cursor = Cursor::new(Vec::new());
        hound::WavWriter::new(&mut cursor, spec)
            .map_err(|err| format!("failed to create WAV writer: {err}"))?
            .finalize()
            .map_err(|err| format!("failed to finalize WAV: {err}"))?;
        Ok(cursor.into_inner())
    }

    /// Builds an integer-PCM WAV file containing a sine wave at `frequency` Hz
    /// with peak amplitude 0.8 of full scale, lasting `duration_sec` seconds.
    ///
    /// The same value is written to every channel of each frame.
    fn create_sine_wave_wav(
        sample_rate: u32,
        channels: u16,
        bits_per_sample: u16,
        frequency: f32,
        duration_sec: f32,
    ) -> TestResult<Vec<u8>> {
        let spec = hound::WavSpec {
            sample_rate,
            channels,
            bits_per_sample,
            sample_format: hound::SampleFormat::Int,
        };

        let mut cursor = Cursor::new(Vec::new());
        let mut writer = hound::WavWriter::new(&mut cursor, spec)
            .map_err(|err| format!("failed to create WAV writer for sine wave: {err}"))?;

        let num_samples = (sample_rate as f32 * duration_sec) as usize;
        let amplitude = match bits_per_sample {
            16 => 32767.0 * 0.8,
            24 => 8_388_607.0 * 0.8,
            _ => return Err(format!("unsupported bit depth ({bits_per_sample})")),
        };

        for i in 0..num_samples {
            let t = i as f32 / sample_rate as f32;
            let sample_f32 = amplitude * (2.0 * std::f32::consts::PI * frequency * t).sin();

            for _ in 0..channels {
                match bits_per_sample {
                    16 => writer
                        .write_sample(sample_f32 as i16)
                        .map_err(|err| format!("failed to write sine sample: {err}"))?,
                    24 => writer
                        .write_sample(sample_f32 as i32)
                        .map_err(|err| format!("failed to write sine sample: {err}"))?,
                    _ => return Err(format!("unsupported bit depth ({bits_per_sample})")),
                }
            }
        }

        writer
            .finalize()
            .map_err(|err| format!("failed to finalize sine wave WAV: {err}"))?;
        Ok(cursor.into_inner())
    }
}