use crate::{Error, Result};
use futures_core::Stream;
use futures_util::stream::{iter, BoxStream};
use futures_util::StreamExt as _;
use std::collections::VecDeque;
use std::pin::Pin;
use std::task::Poll;
/// Incremental parser for a RIFF/WAVE file delivered as a stream of byte chunks.
///
/// The header is parsed lazily by [`WaveStream::spec`] / [`WaveStream::into_data`];
/// the `Option` fields below record how far that parsing has progressed.
pub struct WaveStream<'a> {
    /// Size field of the `RIFF` header; `None` until `take_riff` has read it.
    riff_size: Option<u32>,
    /// Format from the last `fmt ` chunk seen before the `data` chunk, if any.
    spec: Option<WaveSpec>,
    /// Declared size of the `data` chunk; `None` until that chunk is located.
    data_size: Option<u32>,
    /// Bytes already pulled from `stream` that have not been consumed yet.
    current: VecDeque<u8>,
    /// Source of raw byte chunks.
    stream: BoxStream<'a, Result<Vec<u8>>>,
}
/// Audio format fields read from a WAVE `fmt ` chunk.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct WaveSpec {
    /// Format tag of the `fmt ` chunk (1 for integer PCM).
    pub pcm_format: u16,
    /// Number of channels.
    pub channels: u16,
    /// Sample rate in hertz.
    pub sample_rate: u32,
    /// Bits per sample.
    pub bits_per_sample: u16,
}
/// Stream adapter over the input that remains after the buffered prefix of a
/// WAVE `data` chunk has been handed out.
struct DataChunk<'a> {
    /// Declared size of the `data` chunk in bytes.
    data_size: u32,
    /// Number of `data` bytes already handed out, including those that were
    /// buffered before this adapter was created.
    consumed: u32,
    /// Remaining input.
    stream: BoxStream<'a, Result<Vec<u8>>>,
}
impl<'a> WaveStream<'a> {
pub fn new(stream: impl Stream<Item = Result<Vec<u8>>> + Send + 'a) -> Self {
Self {
stream: Box::pin(stream),
current: VecDeque::new(),
riff_size: None,
spec: None,
data_size: None,
}
}
pub async fn spec(&mut self) -> Result<WaveSpec> {
self.take_riff().await?;
self.skip_to_data_chunk().await?;
let spec = self.spec.as_ref().ok_or(Error::FmtChunkIsNotFound)?;
Ok(spec.clone())
}
pub async fn into_data(mut self) -> BoxStream<'a, Result<Vec<u8>>> {
if let Err(e) = self.take_riff().await {
return Box::pin(iter(vec![Err(e)]));
}
if let Err(e) = self.skip_to_data_chunk().await {
return Box::pin(iter(vec![Err(e)]));
}
let data_size = self.data_size.unwrap(); if data_size <= self.current.len() as u32 {
return Box::pin(iter(vec![Ok(self
.current
.into_iter()
.take(data_size as usize)
.collect())]));
}
let consumed = self.current.len() as u32;
let data_chunk = DataChunk {
stream: self.stream,
data_size,
consumed,
};
Box::pin(iter(vec![Ok(self.current.into())]).chain(data_chunk))
}
async fn take_riff(&mut self) -> Result<()> {
if self.riff_size.is_some() {
return Ok(());
}
let four = self.take::<4>().await?;
if b"RIFF" != &four {
return Err(Error::RiffChunkHeaderIsNotFound);
}
self.riff_size = Some(self.take_u32().await?);
let four = self.take::<4>().await?;
if b"WAVE" != &four {
return Err(Error::WaveChunkHeaderIsNotFound);
}
Ok(())
}
async fn skip_to_data_chunk(&mut self) -> Result<()> {
if self.data_size.is_some() {
return Ok(());
}
loop {
let four = self.take::<4>().await?;
let size = self.take_u32().await?;
match &four {
b"data" => {
self.data_size = Some(size);
return Ok(());
}
b"fmt " => {
let spec = self.parse_fmt(size).await?;
self.spec = Some(spec);
}
_ => {
for _ in 0..size {
self.next().await?;
}
}
}
}
}
async fn take_u16(&mut self) -> Result<u16> {
let four = self.take::<2>().await?;
Ok(u16::from_le_bytes(four))
}
async fn take_u32(&mut self) -> Result<u32> {
let four = self.take::<4>().await?;
Ok(u32::from_le_bytes(four))
}
async fn parse_fmt(&mut self, size: u32) -> Result<WaveSpec> {
let pcm_format = self.take_u16().await?;
let channels = self.take_u16().await?;
let sample_rate = self.take_u32().await?;
let _bit_rate = self.take_u32().await?;
let _block_size = self.take_u16().await?;
let bits_per_sample = self.take_u16().await?;
if size > 16 {
for _ in 0..(size - 16) {
self.next().await?;
}
}
let spec = WaveSpec {
pcm_format,
channels,
sample_rate,
bits_per_sample,
};
Ok(spec)
}
async fn take<const N: usize>(&mut self) -> Result<[u8; N]> {
let mut bytes = [0; N];
for item in bytes.iter_mut() {
*item = self.next().await?;
}
Ok(bytes)
}
async fn next(&mut self) -> Result<u8> {
while self.current.is_empty() {
self.current = self
.stream
.next()
.await
.ok_or(Error::DataIsNotEnough)??
.into();
}
Ok(self.current.pop_front().unwrap())
}
}
impl<'a> Stream for DataChunk<'a> {
    type Item = Result<Vec<u8>>;

    /// Yields the rest of the `data` chunk piece by piece.
    ///
    /// The piece that crosses the end of the chunk is truncated. Once
    /// `data_size` bytes have been produced, the stream ends without polling
    /// the inner stream again, so bytes of trailing RIFF chunks are never
    /// returned as sample data. If the inner stream ends early, this stream
    /// ends too.
    fn poll_next(
        mut self: Pin<&mut Self>,
        context: &mut std::task::Context<'_>,
    ) -> Poll<Option<<Self as Stream>::Item>> {
        let rest_size = self.data_size.saturating_sub(self.consumed) as usize;
        if rest_size == 0 {
            return Poll::Ready(None);
        }
        let polled = self.stream.as_mut().poll_next(context);
        let mut chunk = match polled {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Some(Ok(chunk))) => chunk,
            Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
            Poll::Ready(None) => {
                // Mark the chunk as finished so an exhausted inner stream is
                // never polled again.
                self.consumed = self.data_size;
                return Poll::Ready(None);
            }
        };
        chunk.truncate(rest_size);
        // `chunk.len() <= rest_size <= u32::MAX`, so the cast is lossless.
        self.consumed += chunk.len() as u32;
        Poll::Ready(Some(Ok(chunk)))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::fs::read;

    const FILE: &str = "./assets/test/maou_se_system49.wav";

    /// Splits `bytes` into owned stream items of at most `size` bytes each.
    fn split(bytes: &[u8], size: usize) -> Vec<Result<Vec<u8>>> {
        bytes.chunks(size).map(|piece| Ok(piece.to_vec())).collect()
    }

    /// Parses `items` as a WAVE stream and checks its format against
    /// `(pcm_format, channels, sample_rate, bits_per_sample)`.
    ///
    /// Returns the declared size of the `data` chunk and the number of bytes
    /// actually streamed by `into_data`.
    async fn check(items: Vec<Result<Vec<u8>>>, expected: (u16, u16, u32, u16)) -> (u32, u32) {
        let mut wave = WaveStream::new(iter(items));
        let spec = wave.spec().await.unwrap();
        assert_eq!(
            (spec.pcm_format, spec.channels, spec.sample_rate, spec.bits_per_sample),
            expected
        );
        let declared = wave.data_size.unwrap();
        let mut data = wave.into_data().await;
        let mut streamed = 0usize;
        while let Some(piece) = data.next().await {
            streamed += piece.unwrap().len();
        }
        (declared, streamed as u32)
    }

    #[tokio::test]
    async fn test_one_chunk() {
        let bytes = read(FILE).await.unwrap();
        let (declared, streamed) = check(vec![Ok(bytes)], (1, 2, 44100, 24)).await;
        assert_eq!(declared, streamed);
    }

    #[tokio::test]
    async fn test_chunks() {
        let bytes = read(FILE).await.unwrap();
        let (declared, streamed) = check(split(&bytes, 65536), (1, 2, 44100, 24)).await;
        assert_eq!(declared, streamed);
    }

    #[tokio::test]
    async fn test_generate() {
        use std::f32::consts::PI;
        // One second of a 440 Hz sine wave, 16-bit mono at 8 kHz.
        let samples: Vec<u8> = (0..8000)
            .flat_map(|idx| {
                let t = idx as f32 / 8000.0;
                let amplitude = (t * 440. * 2. * PI).sin();
                ((amplitude * i16::MAX as f32) as i16).to_le_bytes()
            })
            .collect();

        let mut wave = Vec::with_capacity(44 + samples.len());
        wave.extend_from_slice(b"RIFF");
        wave.extend_from_slice(&((samples.len() + 36) as u32).to_le_bytes());
        wave.extend_from_slice(b"WAVE");
        wave.extend_from_slice(b"fmt ");
        wave.extend_from_slice(&16u32.to_le_bytes()); // fmt chunk size
        wave.extend_from_slice(&1u16.to_le_bytes()); // format tag
        wave.extend_from_slice(&1u16.to_le_bytes()); // channels
        wave.extend_from_slice(&8000u32.to_le_bytes()); // sample rate
        wave.extend_from_slice(&16000u32.to_le_bytes()); // byte rate
        wave.extend_from_slice(&2u16.to_le_bytes()); // block align
        wave.extend_from_slice(&16u16.to_le_bytes()); // bits per sample
        wave.extend_from_slice(b"data");
        wave.extend_from_slice(&(samples.len() as u32).to_le_bytes());
        wave.extend_from_slice(&samples);

        let (declared, streamed) = check(split(&wave, 65536), (1, 1, 8000, 16)).await;
        assert_eq!(declared, streamed);
        assert_eq!(declared, samples.len() as u32);
    }
}