stream_wave_parser/
mixer.rs1use crate::{Error, Result, WaveSpec};
4use futures_util::stream::BoxStream;
5use futures_util::Stream;
6
/// Rejection message for a spec whose `pcm_format` is not 1 (integer PCM).
const INVALID_PCM_FORMAT: &str = "only support PCM integer (`pcm_format` is 1)";

/// Rejection message for a spec whose `bits_per_sample` is not a whole number of bytes.
const INVALID_BITS_PER_SAMPLE: &str = "only supports a `bits_per_sample` divisible by 8";

/// Rejection message for a spec whose samples are wider than 64 bits (8 bytes).
const INVALID_BYTES_PER_SAMPLE: &str =
    "only supports a `bits_per_sample` less than or equal 64";
10
11pub struct WaveChannelMixer<'a> {
13 stream: BoxStream<'a, Result<Vec<u8>>>,
14
15 channels: u16,
16 bits_per_sample: u16,
17 options: WaveChannelMixerOptions,
18
19 rest: Vec<u8>,
20}
21
/// Strategy used by `WaveChannelMixer` to reduce every multi-channel frame to one sample.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WaveChannelMixerOptions {
    /// Keep only the sample of the channel with this zero-based index.
    ///
    /// `WaveChannelMixer::new` wraps an out-of-range index modulo the channel count.
    Pick(u16),

    /// Average the samples of all channels; the result is truncated toward zero.
    Mean,
}
30
31impl<'a> WaveChannelMixer<'a> {
32 pub fn new(
34 spec: &WaveSpec,
35 options: WaveChannelMixerOptions,
36 stream: impl Stream<Item = Result<Vec<u8>>> + Send + 'a,
37 ) -> Result<Self> {
38 if spec.pcm_format != 1 {
39 return Err(Error::MixerConstruction(INVALID_PCM_FORMAT));
40 }
41
42 if spec.bits_per_sample % 8 != 0 {
43 return Err(Error::MixerConstruction(INVALID_BITS_PER_SAMPLE));
44 }
45 if spec.bits_per_sample > 64 {
46 return Err(Error::MixerConstruction(INVALID_BYTES_PER_SAMPLE));
47 }
48
49 let options = if let WaveChannelMixerOptions::Pick(idx) = options {
50 WaveChannelMixerOptions::Pick(idx % spec.channels)
51 } else {
52 options
53 };
54
55 let ret = Self {
56 stream: Box::pin(stream),
57
58 channels: spec.channels,
59 bits_per_sample: spec.bits_per_sample,
60 options,
61
62 rest: vec![],
63 };
64
65 Ok(ret)
66 }
67
68 fn convert(&self, input: &[u8]) -> (Vec<u8>, Vec<u8>) {
70 match &self.options {
71 WaveChannelMixerOptions::Pick(idx) => self.pick(*idx, input),
72 WaveChannelMixerOptions::Mean => self.mean(input),
73 }
74 }
75
76 fn pick(&self, pick_idx: u16, input: &[u8]) -> (Vec<u8>, Vec<u8>) {
78 let mut idx = 0u16;
79 let mut converted = vec![];
80
81 println!("input = {input:?}");
82
83 let bytes_per_sample = (self.bits_per_sample / 8) as usize;
84 let mut head = 0usize;
85 let end = input.len() - input.len() % (bytes_per_sample * self.channels as usize);
86 while head < end {
87 if idx == pick_idx {
88 let sample = self.take_sample(input, head);
89 converted.extend(self.sample_into_vec(sample));
90 }
91 head += bytes_per_sample;
92 idx = (idx + 1) % self.channels;
93 }
94
95 let rest = input.iter().skip(head).cloned().collect();
96
97 println!("rest = {rest:?}");
98
99 (converted, rest)
100 }
101
102 fn mean(&self, input: &[u8]) -> (Vec<u8>, Vec<u8>) {
104 let mut idx = 0u16;
105 let mut sum = 0f32;
106 let mut converted = vec![];
107
108 let bytes_per_sample = (self.bits_per_sample / 8) as usize;
109 let mut head = 0usize;
110 let end = input.len() - input.len() % (bytes_per_sample * self.channels as usize);
111 while head < end {
112 let sample = self.take_sample(input, head);
113 sum += sample;
114
115 head += bytes_per_sample;
116 idx = (idx + 1) % self.channels;
117
118 if idx == 0 {
119 converted.extend(self.sample_into_vec(sum / self.channels as f32));
120 sum = 0.;
121 }
122 }
123
124 let rest = input.iter().skip(head).cloned().collect();
125
126 (converted, rest)
127 }
128
129 fn take_sample(&self, input: &[u8], head: usize) -> f32 {
130 let bytes_per_sample = (self.bits_per_sample / 8) as usize;
131 match bytes_per_sample {
132 1..=1 => from_bytes::take_i8_sample(input, head) as f32,
133 2..=2 => from_bytes::take_i16_sample(input, head) as f32,
134 3..=4 => from_bytes::take_i32_sample(input, head, bytes_per_sample) as f32,
135 5..=8 => from_bytes::take_i64_sample(input, head, bytes_per_sample) as f32,
136 _ => unreachable!(),
137 }
138 }
139
140 fn sample_into_vec(&self, sample: f32) -> Vec<u8> {
141 let bytes_per_sample = (self.bits_per_sample / 8) as usize;
142 match bytes_per_sample {
143 1..=1 => (sample as i8).to_le_bytes().to_vec(),
144 2..=2 => (sample as i16).to_le_bytes().to_vec(),
145 3..=4 => (sample as i32).to_le_bytes()[..bytes_per_sample].to_vec(),
146 5..=8 => (sample as i64).to_le_bytes()[..bytes_per_sample].to_vec(),
147 _ => unreachable!(),
148 }
149 }
150}
151
152mod impls {
153 use super::*;
156
157 use std::pin::Pin;
158 use std::task::{Context, Poll};
159
160 impl<'a> Stream for WaveChannelMixer<'a> {
161 type Item = Result<Vec<u8>>;
162
163 fn poll_next(
164 mut self: Pin<&mut Self>,
165 context: &mut Context<'_>,
166 ) -> Poll<Option<Self::Item>> {
167 let polled = self.stream.as_mut().poll_next(context);
168 let ready = match polled {
169 Poll::Ready(ready) => ready,
170 Poll::Pending => return Poll::Pending,
171 };
172
173 let Some(chunk) = ready else {
175 if self.rest.is_empty() {
176 return Poll::Ready(None);
177 } else {
178 return Poll::Ready(Some(Err(Error::DataIsNotEnough)));
179 }
180 };
181
182 let chunk = match chunk {
184 Ok(chunk) => chunk,
185 Err(e) => return Poll::Ready(Some(Err(e))),
186 };
187
188 self.rest.extend(&chunk);
190 let (converted, rest) = self.convert(&self.rest);
191 self.rest = rest;
192
193 Poll::Ready(Some(Ok(converted)))
194 }
195 }
196}
197
mod from_bytes {
    //! Little-endian decoding of signed integer PCM samples from raw byte buffers.
    //!
    //! Every function panics if `input` does not contain the requested bytes at `head`.

    /// Reads the 8-bit signed sample at `input[head]`.
    pub fn take_i8_sample(input: &[u8], head: usize) -> i8 {
        i8::from_le_bytes([input[head]])
    }

    /// Reads the 16-bit little-endian signed sample at `input[head..head + 2]`.
    pub fn take_i16_sample(input: &[u8], head: usize) -> i16 {
        i16::from_le_bytes([input[head], input[head + 1]])
    }

    /// Reads a `width`-byte little-endian signed sample, sign-extended to `i32`.
    ///
    /// # Panics
    ///
    /// Panics if `width` is not in `1..=4` or `input` is too short.
    pub fn take_i32_sample(input: &[u8], head: usize, width: usize) -> i32 {
        assert!(
            (1..=4).contains(&width),
            "sample width must be 1..=4 bytes, got {width}"
        );
        let mut buf = [0u8; 4];
        // Put the sample in the most significant bytes, then shift it back down. The
        // arithmetic right shift sign-extends narrow (e.g. 24-bit) samples.
        buf[4 - width..].copy_from_slice(&input[head..head + width]);
        i32::from_le_bytes(buf) >> (32 - 8 * width)
    }

    /// Reads a `width`-byte little-endian signed sample, sign-extended to `i64`.
    ///
    /// # Panics
    ///
    /// Panics if `width` is not in `1..=8` or `input` is too short.
    pub fn take_i64_sample(input: &[u8], head: usize, width: usize) -> i64 {
        assert!(
            (1..=8).contains(&width),
            "sample width must be 1..=8 bytes, got {width}"
        );
        let mut buf = [0u8; 8];
        // Same trick as `take_i32_sample`: high-align the bytes, then shift arithmetically.
        buf[8 - width..].copy_from_slice(&input[head..head + width]);
        i64::from_le_bytes(buf) >> (64 - 8 * width)
    }
}
224
#[cfg(test)]
mod tests {
    use super::*;

    use futures_util::stream::iter;
    use futures_util::StreamExt as _;

    /// Stereo, 16-bit, 8 kHz integer-PCM spec shared by every test.
    fn stereo_16bit_spec() -> WaveSpec {
        WaveSpec {
            pcm_format: 1,
            channels: 2,
            sample_rate: 8000,
            bits_per_sample: 16,
        }
    }

    /// Serializes `samples` as little-endian bytes split into 31-byte chunks, so chunk
    /// boundaries fall in the middle of samples and frames.
    fn chunked(samples: impl Iterator<Item = i16>) -> Vec<Result<Vec<u8>>> {
        let bytes: Vec<u8> = samples.flat_map(i16::to_le_bytes).collect();
        bytes.chunks(31).map(|piece| Ok(piece.to_vec())).collect()
    }

    /// Drives a mixer over `chunks` and decodes its entire output as 16-bit samples.
    async fn mix(options: WaveChannelMixerOptions, chunks: Vec<Result<Vec<u8>>>) -> Vec<i16> {
        let stream = iter(chunks);
        let mut mixer = WaveChannelMixer::new(&stereo_16bit_spec(), options, stream).unwrap();

        let mut output = vec![];
        while let Some(piece) = mixer.next().await {
            output.extend(piece.unwrap());
        }

        output
            .chunks(2)
            .map(|sample| from_bytes::take_i16_sample(sample, 0))
            .collect()
    }

    #[tokio::test]
    async fn test_pick() {
        let left = mix(WaveChannelMixerOptions::Pick(0), chunked(0i16..200)).await;
        let expected_left: Vec<_> = (0i16..100).map(|x| x * 2).collect();
        assert_eq!(left, expected_left);

        let right = mix(WaveChannelMixerOptions::Pick(1), chunked(0i16..200)).await;
        let expected_right: Vec<_> = (0i16..100).map(|x| x * 2 + 1).collect();
        assert_eq!(right, expected_right);
    }

    #[tokio::test]
    async fn test_mean() {
        let samples = (0i16..200).map(|x| x * 2);
        let mixed = mix(WaveChannelMixerOptions::Mean, chunked(samples)).await;
        let expected: Vec<_> = (0i16..100).map(|x| x * 4 + 1).collect();
        assert_eq!(mixed, expected);
    }
}