stream_wave_parser/
wave.rs

1//! Types to parse stream as WAVE header and `data` stream.
2
3use crate::{Error, Result};
4use futures_util::stream::{iter, BoxStream};
5use futures_util::{Stream, StreamExt as _};
6use std::collections::VecDeque;
7
/// The structure that handles WAVE files as a stream.
///
/// It wraps a stream of byte chunks, parses the RIFF/WAVE header from it on
/// demand (see `spec()`), and can then be converted into a stream that yields
/// the contents of the `data` chunk (see `into_data()`).
pub struct WaveStream<'a> {
    /// The underlying source of raw byte chunks.
    stream: BoxStream<'a, Result<Vec<u8>>>,
    /// Bytes of the most recently fetched chunk that the parser has not consumed yet.
    current: VecDeque<u8>,
    /// Size field read from the `RIFF` header; `take_riff()` treats `Some(_)` as "already parsed".
    riff_size: Option<u32>,
    /// Metadata parsed from the `fmt ` chunk, if one has been seen.
    spec: Option<WaveSpec>,
    /// Size field of the `data` chunk header; `Some(_)` once the `data` chunk has been reached.
    data_size: Option<u32>,
}
16
/// The structure representing the metadata of the WAVE file.
///
/// The values are taken verbatim from the `fmt ` chunk of the header.
/// Two specs compare equal when all four fields are equal.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct WaveSpec {
    /// Audio format. (1: PCM integer, ...).
    pub pcm_format: u16,

    /// Number of channels.
    pub channels: u16,

    /// Sample rate.
    pub sample_rate: u32,

    /// Number of bits per sample.
    pub bits_per_sample: u16,
}
32
/// A stream adapter that yields the remaining bytes of the `data` chunk
/// from the wrapped stream.
///
/// It is created by `WaveStream::into_data()` after the header has been parsed.
struct DataChunk<'a> {
    /// The underlying source of raw byte chunks.
    stream: BoxStream<'a, Result<Vec<u8>>>,
    /// Total size of the `data` chunk in bytes, as read from its chunk header.
    data_size: u32,
    /// Number of `data` bytes already accounted for
    /// (`into_data()` initialises it with the bytes that were already buffered).
    consumed: u32,
}
39
40impl<'a> WaveStream<'a> {
41    /// The constructor.
42    pub fn new(stream: impl Stream<Item = Result<Vec<u8>>> + Send + 'a) -> Self {
43        Self {
44            stream: Box::pin(stream),
45            current: VecDeque::new(),
46            riff_size: None,
47            spec: None,
48            data_size: None,
49        }
50    }
51
52    /// Parses a WAVE header and returns [`WaveSpec`].
53    pub async fn spec(&mut self) -> Result<WaveSpec> {
54        self.take_riff().await?;
55        self.skip_to_data_chunk().await?;
56
57        let spec = self.spec.as_ref().ok_or(Error::FmtChunkIsNotFound)?;
58        Ok(spec.clone())
59    }
60
61    /// Returns the stream that returns data chunks.
62    pub async fn into_data(mut self) -> BoxStream<'a, Result<Vec<u8>>> {
63        if let Err(e) = self.take_riff().await {
64            return Box::pin(iter(vec![Err(e)]));
65        }
66
67        if let Err(e) = self.skip_to_data_chunk().await {
68            return Box::pin(iter(vec![Err(e)]));
69        }
70
71        let data_size = self.data_size.unwrap(); // If `skip_to_data_chunk()` is `Ok(_)`, it is `Some(_)` and can be `unwrap()`.
72
73        if data_size <= self.current.len() as u32 {
74            return Box::pin(iter(vec![Ok(self
75                .current
76                .into_iter()
77                .take(data_size as usize)
78                .collect())]));
79        }
80
81        let consumed = self.current.len() as u32;
82        let data_chunk = DataChunk {
83            stream: self.stream,
84            data_size,
85            consumed,
86        };
87
88        Box::pin(iter(vec![Ok(self.current.into())]).chain(data_chunk))
89    }
90
91    async fn take_riff(&mut self) -> Result<()> {
92        if self.riff_size.is_some() {
93            return Ok(());
94        }
95
96        let four = self.take::<4>().await?;
97        if b"RIFF" != &four {
98            return Err(Error::RiffChunkHeaderIsNotFound);
99        }
100
101        self.riff_size = Some(self.take_u32().await?);
102
103        let four = self.take::<4>().await?;
104        if b"WAVE" != &four {
105            return Err(Error::WaveChunkHeaderIsNotFound);
106        }
107
108        Ok(())
109    }
110
111    async fn skip_to_data_chunk(&mut self) -> Result<()> {
112        if self.data_size.is_some() {
113            return Ok(());
114        }
115
116        loop {
117            let four = self.take::<4>().await?;
118            let size = self.take_u32().await?;
119            match &four {
120                b"data" => {
121                    self.data_size = Some(size);
122                    return Ok(());
123                }
124
125                b"fmt " => {
126                    let spec = self.parse_fmt(size).await?;
127                    self.spec = Some(spec);
128                }
129
130                // skip other chunk
131                _ => {
132                    for _ in 0..size {
133                        self.next().await?;
134                    }
135                }
136            }
137        }
138    }
139
140    async fn take_u16(&mut self) -> Result<u16> {
141        let four = self.take::<2>().await?;
142        Ok(u16::from_le_bytes(four))
143    }
144
145    async fn take_u32(&mut self) -> Result<u32> {
146        let four = self.take::<4>().await?;
147        Ok(u32::from_le_bytes(four))
148    }
149
150    async fn parse_fmt(&mut self, size: u32) -> Result<WaveSpec> {
151        let pcm_format = self.take_u16().await?;
152
153        let channels = self.take_u16().await?;
154        let sample_rate = self.take_u32().await?;
155        let _bit_rate = self.take_u32().await?;
156        let _block_size = self.take_u16().await?;
157        let bits_per_sample = self.take_u16().await?;
158
159        // skip extension
160        if size > 16 {
161            for _ in 0..(size - 16) {
162                self.next().await?;
163            }
164        }
165
166        let spec = WaveSpec {
167            pcm_format,
168            channels,
169            sample_rate,
170            bits_per_sample,
171        };
172        Ok(spec)
173    }
174
175    async fn take<const N: usize>(&mut self) -> Result<[u8; N]> {
176        let mut bytes = [0; N];
177        for item in bytes.iter_mut() {
178            *item = self.next().await?;
179        }
180        Ok(bytes)
181    }
182
183    async fn next(&mut self) -> Result<u8> {
184        while self.current.is_empty() {
185            self.current = self
186                .stream
187                .next()
188                .await
189                .ok_or(Error::DataIsNotEnough)??
190                .into();
191        }
192
193        Ok(self.current.pop_front().unwrap())
194    }
195}
196
197mod impls {
198    //! Implements [`Stream`] for [`DataChunk`].
199
200    use super::*;
201
202    use std::pin::Pin;
203    use std::task::{Context, Poll};
204
205    impl<'a> Stream for DataChunk<'a> {
206        type Item = Result<Vec<u8>>;
207
208        fn poll_next(
209            mut self: Pin<&mut Self>,
210            context: &mut Context<'_>,
211        ) -> Poll<Option<<Self as Stream>::Item>> {
212            let polled = self.stream.as_mut().poll_next(context);
213            let ready = match polled {
214                Poll::Ready(ready) => ready,
215                Poll::Pending => return Poll::Pending,
216            };
217
218            let Some(chunk) = ready else {
219                return Poll::Ready(None);
220            };
221
222            let chunk = match chunk {
223                Ok(chunk) => chunk,
224                Err(e) => return Poll::Ready(Some(Err(e))),
225            };
226
227            let rest_size = (self.data_size - self.consumed) as usize;
228            if chunk.len() < rest_size {
229                self.consumed += chunk.len() as u32;
230                Poll::Ready(Some(Ok(chunk)))
231            } else {
232                let chunk = chunk.into_iter().take(rest_size).collect();
233                Poll::Ready(Some(Ok(chunk)))
234            }
235        }
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    use tokio::fs::read;

    /// The file path of WAVE file. The file from `魔王魂®︎`. <https://maou.audio/se_system49/>
    const FILE: &str = "./assets/test/maou_se_system49.wav";

    /// Converts `stream` into its data stream, drains it and returns the total
    /// number of bytes yielded. Panics if any item is an error.
    async fn drain_len(stream: WaveStream<'_>) -> usize {
        let mut data = stream.into_data().await;
        let mut total = 0;
        while let Some(chunk) = data.next().await {
            total += chunk.unwrap().len();
        }
        total
    }

    /// The whole file is delivered as a single chunk.
    #[tokio::test]
    async fn test_one_chunk() {
        let bytes = read(FILE).await.unwrap();
        let mut stream = WaveStream::new(iter(vec![Ok(bytes)]));

        let spec = stream.spec().await.unwrap();
        assert_eq!(spec.pcm_format, 1);
        assert_eq!(spec.channels, 2);
        assert_eq!(spec.sample_rate, 44100);
        assert_eq!(spec.bits_per_sample, 24);

        let data_size = stream.data_size.unwrap();
        let total = drain_len(stream).await;
        assert_eq!(data_size, total as u32);
    }

    /// The file is delivered in pieces of 65536 bytes.
    #[tokio::test]
    async fn test_chunks() {
        let bytes = read(FILE).await.unwrap();
        let pieces: Vec<_> = bytes
            .chunks(65536)
            .map(|piece| Ok(piece.to_vec()))
            .collect();
        let mut stream = WaveStream::new(iter(pieces));

        let spec = stream.spec().await.unwrap();
        assert_eq!(spec.pcm_format, 1);
        assert_eq!(spec.channels, 2);
        assert_eq!(spec.sample_rate, 44100);
        assert_eq!(spec.bits_per_sample, 24);

        let data_size = stream.data_size.unwrap();
        let total = drain_len(stream).await;
        assert_eq!(data_size, total as u32);
    }

    /// A WAVE file generated in memory is parsed back.
    #[tokio::test]
    async fn test_generate() {
        // create sine wave (440 Hz 1 seconds)
        use std::f32::consts::PI;

        let data_chunk: Vec<u8> = (0..8000)
            .flat_map(|idx| {
                let t = idx as f32 / 8000.0;
                let sample = (t * 440. * 2. * PI).sin();
                ((sample * i16::MAX as f32) as i16).to_le_bytes()
            })
            .collect();

        let mut wave = Vec::new();
        wave.extend_from_slice(b"RIFF");
        wave.extend_from_slice(&((data_chunk.len() + 36) as u32).to_le_bytes());
        wave.extend_from_slice(b"WAVE");
        wave.extend_from_slice(b"fmt ");
        wave.extend_from_slice(&16u32.to_le_bytes()); // `fmt ` chunk size
        wave.extend_from_slice(&1u16.to_le_bytes()); // PCM format
        wave.extend_from_slice(&1u16.to_le_bytes()); // channels
        wave.extend_from_slice(&8000u32.to_le_bytes()); // sample rate
        wave.extend_from_slice(&16000u32.to_le_bytes()); // bit rate
        wave.extend_from_slice(&2u16.to_le_bytes()); // block size
        wave.extend_from_slice(&16u16.to_le_bytes()); // bits per sample

        wave.extend_from_slice(b"data");
        wave.extend_from_slice(&(data_chunk.len() as u32).to_le_bytes());
        wave.extend_from_slice(&data_chunk);

        // create stream
        let pieces: Vec<_> = wave
            .chunks(65536)
            .map(|piece| Ok(piece.to_vec()))
            .collect();
        let mut stream = WaveStream::new(iter(pieces));

        // read stream
        let spec = stream.spec().await.unwrap();
        assert_eq!(spec.pcm_format, 1);
        assert_eq!(spec.channels, 1);
        assert_eq!(spec.sample_rate, 8000);
        assert_eq!(spec.bits_per_sample, 16);

        let data_size = stream.data_size.unwrap();
        let total = drain_len(stream).await;
        assert_eq!(data_size, total as u32);
        assert_eq!(data_size, data_chunk.len() as u32);
    }
}