stream_wave_parser/
wave.rs1use crate::{Error, Result};
4use futures_util::stream::{iter, BoxStream};
5use futures_util::{Stream, StreamExt as _};
6use std::collections::VecDeque;
7
8pub struct WaveStream<'a> {
10 stream: BoxStream<'a, Result<Vec<u8>>>,
11 current: VecDeque<u8>,
12 riff_size: Option<u32>,
13 spec: Option<WaveSpec>,
14 data_size: Option<u32>,
15}
16
/// Audio format description read from the `fmt ` chunk of a WAVE file.
///
/// The type is plain data, so besides `Clone` and `Debug` it also supports
/// equality comparison and hashing.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct WaveSpec {
    /// Format tag stored in the first field of the `fmt ` chunk
    /// (`1` denotes uncompressed PCM).
    pub pcm_format: u16,

    /// Number of interleaved channels.
    pub channels: u16,

    /// Sampling frequency in samples per second.
    pub sample_rate: u32,

    /// Number of bits used to store one sample of one channel.
    pub bits_per_sample: u16,
}
32
33struct DataChunk<'a> {
35 stream: BoxStream<'a, Result<Vec<u8>>>,
36 data_size: u32,
37 consumed: u32,
38}
39
40impl<'a> WaveStream<'a> {
41 pub fn new(stream: impl Stream<Item = Result<Vec<u8>>> + Send + 'a) -> Self {
43 Self {
44 stream: Box::pin(stream),
45 current: VecDeque::new(),
46 riff_size: None,
47 spec: None,
48 data_size: None,
49 }
50 }
51
52 pub async fn spec(&mut self) -> Result<WaveSpec> {
54 self.take_riff().await?;
55 self.skip_to_data_chunk().await?;
56
57 let spec = self.spec.as_ref().ok_or(Error::FmtChunkIsNotFound)?;
58 Ok(spec.clone())
59 }
60
61 pub async fn into_data(mut self) -> BoxStream<'a, Result<Vec<u8>>> {
63 if let Err(e) = self.take_riff().await {
64 return Box::pin(iter(vec![Err(e)]));
65 }
66
67 if let Err(e) = self.skip_to_data_chunk().await {
68 return Box::pin(iter(vec![Err(e)]));
69 }
70
71 let data_size = self.data_size.unwrap(); if data_size <= self.current.len() as u32 {
74 return Box::pin(iter(vec![Ok(self
75 .current
76 .into_iter()
77 .take(data_size as usize)
78 .collect())]));
79 }
80
81 let consumed = self.current.len() as u32;
82 let data_chunk = DataChunk {
83 stream: self.stream,
84 data_size,
85 consumed,
86 };
87
88 Box::pin(iter(vec![Ok(self.current.into())]).chain(data_chunk))
89 }
90
91 async fn take_riff(&mut self) -> Result<()> {
92 if self.riff_size.is_some() {
93 return Ok(());
94 }
95
96 let four = self.take::<4>().await?;
97 if b"RIFF" != &four {
98 return Err(Error::RiffChunkHeaderIsNotFound);
99 }
100
101 self.riff_size = Some(self.take_u32().await?);
102
103 let four = self.take::<4>().await?;
104 if b"WAVE" != &four {
105 return Err(Error::WaveChunkHeaderIsNotFound);
106 }
107
108 Ok(())
109 }
110
111 async fn skip_to_data_chunk(&mut self) -> Result<()> {
112 if self.data_size.is_some() {
113 return Ok(());
114 }
115
116 loop {
117 let four = self.take::<4>().await?;
118 let size = self.take_u32().await?;
119 match &four {
120 b"data" => {
121 self.data_size = Some(size);
122 return Ok(());
123 }
124
125 b"fmt " => {
126 let spec = self.parse_fmt(size).await?;
127 self.spec = Some(spec);
128 }
129
130 _ => {
132 for _ in 0..size {
133 self.next().await?;
134 }
135 }
136 }
137 }
138 }
139
140 async fn take_u16(&mut self) -> Result<u16> {
141 let four = self.take::<2>().await?;
142 Ok(u16::from_le_bytes(four))
143 }
144
145 async fn take_u32(&mut self) -> Result<u32> {
146 let four = self.take::<4>().await?;
147 Ok(u32::from_le_bytes(four))
148 }
149
150 async fn parse_fmt(&mut self, size: u32) -> Result<WaveSpec> {
151 let pcm_format = self.take_u16().await?;
152
153 let channels = self.take_u16().await?;
154 let sample_rate = self.take_u32().await?;
155 let _bit_rate = self.take_u32().await?;
156 let _block_size = self.take_u16().await?;
157 let bits_per_sample = self.take_u16().await?;
158
159 if size > 16 {
161 for _ in 0..(size - 16) {
162 self.next().await?;
163 }
164 }
165
166 let spec = WaveSpec {
167 pcm_format,
168 channels,
169 sample_rate,
170 bits_per_sample,
171 };
172 Ok(spec)
173 }
174
175 async fn take<const N: usize>(&mut self) -> Result<[u8; N]> {
176 let mut bytes = [0; N];
177 for item in bytes.iter_mut() {
178 *item = self.next().await?;
179 }
180 Ok(bytes)
181 }
182
183 async fn next(&mut self) -> Result<u8> {
184 while self.current.is_empty() {
185 self.current = self
186 .stream
187 .next()
188 .await
189 .ok_or(Error::DataIsNotEnough)??
190 .into();
191 }
192
193 Ok(self.current.pop_front().unwrap())
194 }
195}
196
197mod impls {
198 use super::*;
201
202 use std::pin::Pin;
203 use std::task::{Context, Poll};
204
205 impl<'a> Stream for DataChunk<'a> {
206 type Item = Result<Vec<u8>>;
207
208 fn poll_next(
209 mut self: Pin<&mut Self>,
210 context: &mut Context<'_>,
211 ) -> Poll<Option<<Self as Stream>::Item>> {
212 let polled = self.stream.as_mut().poll_next(context);
213 let ready = match polled {
214 Poll::Ready(ready) => ready,
215 Poll::Pending => return Poll::Pending,
216 };
217
218 let Some(chunk) = ready else {
219 return Poll::Ready(None);
220 };
221
222 let chunk = match chunk {
223 Ok(chunk) => chunk,
224 Err(e) => return Poll::Ready(Some(Err(e))),
225 };
226
227 let rest_size = (self.data_size - self.consumed) as usize;
228 if chunk.len() < rest_size {
229 self.consumed += chunk.len() as u32;
230 Poll::Ready(Some(Ok(chunk)))
231 } else {
232 let chunk = chunk.into_iter().take(rest_size).collect();
233 Poll::Ready(Some(Ok(chunk)))
234 }
235 }
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    use tokio::fs::read;

    /// Sample file used by the file-based tests.
    const FILE: &str = "./assets/test/maou_se_system49.wav";

    /// Splits `bytes` into stream items of at most 64 KiB each.
    fn split_into_chunks(bytes: &[u8]) -> Vec<Result<Vec<u8>>> {
        bytes.chunks(65536).map(|part| Ok(part.to_vec())).collect()
    }

    /// Drains `data` and returns the total number of bytes it produced.
    async fn total_len(mut data: BoxStream<'_, Result<Vec<u8>>>) -> usize {
        let mut total = 0;
        while let Some(chunk) = data.next().await {
            total += chunk.unwrap().len();
        }
        total
    }

    /// Checks the format values expected for the sample file.
    fn assert_file_spec(spec: &WaveSpec) {
        assert_eq!(spec.pcm_format, 1);
        assert_eq!(spec.channels, 2);
        assert_eq!(spec.sample_rate, 44100);
        assert_eq!(spec.bits_per_sample, 24);
    }

    /// The whole file delivered as a single stream item.
    #[tokio::test]
    async fn test_one_chunk() {
        let bytes = read(FILE).await.unwrap();
        let mut stream = WaveStream::new(iter(vec![Ok(bytes)]));

        assert_file_spec(&stream.spec().await.unwrap());

        let data_size = stream.data_size.unwrap();
        let size = total_len(stream.into_data().await).await;
        assert_eq!(data_size, size as u32);
    }

    /// The file delivered as several 64 KiB stream items.
    #[tokio::test]
    async fn test_chunks() {
        let bytes = read(FILE).await.unwrap();
        let mut stream = WaveStream::new(iter(split_into_chunks(&bytes)));

        assert_file_spec(&stream.spec().await.unwrap());

        let data_size = stream.data_size.unwrap();
        let size = total_len(stream.into_data().await).await;
        assert_eq!(data_size, size as u32);
    }

    /// A WAVE file built in memory: one second of a 440 Hz sine wave,
    /// mono, 8 kHz, 16-bit.
    #[tokio::test]
    async fn test_generate() {
        use std::f32::consts::PI;

        let data_chunk = (0..8000)
            .flat_map(|idx| {
                let t = idx as f32 / 8000.0;
                let sample = (t * 440. * 2. * PI).sin();
                ((sample * i16::MAX as f32) as i16).to_le_bytes()
            })
            .collect::<Vec<u8>>();

        let mut wave = b"RIFF".to_vec();
        // "WAVE" tag + fmt chunk (8 + 16) + data chunk header = 36 bytes.
        let riff_length = ((data_chunk.len() + 36) as u32).to_le_bytes();
        wave.extend(riff_length);
        wave.extend(b"WAVE");
        wave.extend(b"fmt ");
        wave.extend(16u32.to_le_bytes()); // fmt chunk size
        wave.extend(1u16.to_le_bytes()); // format tag
        wave.extend(1u16.to_le_bytes()); // channels
        wave.extend(8000u32.to_le_bytes()); // sample rate
        wave.extend(16000u32.to_le_bytes()); // bytes per second
        wave.extend(2u16.to_le_bytes()); // block size
        wave.extend(16u16.to_le_bytes()); // bits per sample
        wave.extend(b"data");
        wave.extend((data_chunk.len() as u32).to_le_bytes());
        wave.extend(&data_chunk);

        let mut stream = WaveStream::new(iter(split_into_chunks(&wave)));

        let spec = stream.spec().await.unwrap();
        assert_eq!(spec.pcm_format, 1);
        assert_eq!(spec.channels, 1);
        assert_eq!(spec.sample_rate, 8000);
        assert_eq!(spec.bits_per_sample, 16);

        let data_size = stream.data_size.unwrap();
        let size = total_len(stream.into_data().await).await;
        assert_eq!(data_size, size as u32);
        assert_eq!(data_size, data_chunk.len() as u32);
    }
}
351}