stream_wave_parser/
mixer.rs

//! Types that mix multi-channel WAVE data down to a single channel.
2
3use crate::{Error, Result, WaveSpec};
4use futures_util::stream::BoxStream;
5use futures_util::Stream;
6
/// Rejection message for WAVE formats other than integer PCM.
const INVALID_PCM_FORMAT: &str = "only supports integer PCM (`pcm_format` is 1)";
/// Rejection message for bit depths that are not a whole number of bytes.
const INVALID_BITS_PER_SAMPLE: &str = "only supports a `bits_per_sample` divisible by 8";
/// Rejection message for bit depths wider than 64 bits.
const INVALID_BYTES_PER_SAMPLE: &str =
    "only supports a `bits_per_sample` less than or equal to 64";
10
11/// The type that converts WAVE multi channels data into a single channel.
12pub struct WaveChannelMixer<'a> {
13    stream: BoxStream<'a, Result<Vec<u8>>>,
14
15    channels: u16,
16    bits_per_sample: u16,
17    options: WaveChannelMixerOptions,
18
19    rest: Vec<u8>,
20}
21
/// How a `WaveChannelMixer` combines the channels of each frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WaveChannelMixerOptions {
    /// Keeps only the channel with the given zero-based index and drops the others.
    ///
    /// The mixer's constructor takes the index modulo the channel count, so an
    /// out-of-range index wraps around instead of being rejected.
    Pick(u16),

    /// Outputs the mean of all channels of each frame, rounded toward zero.
    Mean,
}
30
31impl<'a> WaveChannelMixer<'a> {
32    /// The constructor.
33    pub fn new(
34        spec: &WaveSpec,
35        options: WaveChannelMixerOptions,
36        stream: impl Stream<Item = Result<Vec<u8>>> + Send + 'a,
37    ) -> Result<Self> {
38        if spec.pcm_format != 1 {
39            return Err(Error::MixerConstruction(INVALID_PCM_FORMAT));
40        }
41
42        if spec.bits_per_sample % 8 != 0 {
43            return Err(Error::MixerConstruction(INVALID_BITS_PER_SAMPLE));
44        }
45        if spec.bits_per_sample > 64 {
46            return Err(Error::MixerConstruction(INVALID_BYTES_PER_SAMPLE));
47        }
48
49        let options = if let WaveChannelMixerOptions::Pick(idx) = options {
50            WaveChannelMixerOptions::Pick(idx % spec.channels)
51        } else {
52            options
53        };
54
55        let ret = Self {
56            stream: Box::pin(stream),
57
58            channels: spec.channels,
59            bits_per_sample: spec.bits_per_sample,
60            options,
61
62            rest: vec![],
63        };
64
65        Ok(ret)
66    }
67
68    /// Converts multiple channels into ([single channel], [rest data]).
69    fn convert(&self, input: &[u8]) -> (Vec<u8>, Vec<u8>) {
70        match &self.options {
71            WaveChannelMixerOptions::Pick(idx) => self.pick(*idx, input),
72            WaveChannelMixerOptions::Mean => self.mean(input),
73        }
74    }
75
76    /// Picks single channel in multiple channels.
77    fn pick(&self, pick_idx: u16, input: &[u8]) -> (Vec<u8>, Vec<u8>) {
78        let mut idx = 0u16;
79        let mut converted = vec![];
80
81        println!("input = {input:?}");
82
83        let bytes_per_sample = (self.bits_per_sample / 8) as usize;
84        let mut head = 0usize;
85        let end = input.len() - input.len() % (bytes_per_sample * self.channels as usize);
86        while head < end {
87            if idx == pick_idx {
88                let sample = self.take_sample(input, head);
89                converted.extend(self.sample_into_vec(sample));
90            }
91            head += bytes_per_sample;
92            idx = (idx + 1) % self.channels;
93        }
94
95        let rest = input.iter().skip(head).cloned().collect();
96
97        println!("rest = {rest:?}");
98
99        (converted, rest)
100    }
101
102    /// Makes the mean of multiple channels into a single channel.
103    fn mean(&self, input: &[u8]) -> (Vec<u8>, Vec<u8>) {
104        let mut idx = 0u16;
105        let mut sum = 0f32;
106        let mut converted = vec![];
107
108        let bytes_per_sample = (self.bits_per_sample / 8) as usize;
109        let mut head = 0usize;
110        let end = input.len() - input.len() % (bytes_per_sample * self.channels as usize);
111        while head < end {
112            let sample = self.take_sample(input, head);
113            sum += sample;
114
115            head += bytes_per_sample;
116            idx = (idx + 1) % self.channels;
117
118            if idx == 0 {
119                converted.extend(self.sample_into_vec(sum / self.channels as f32));
120                sum = 0.;
121            }
122        }
123
124        let rest = input.iter().skip(head).cloned().collect();
125
126        (converted, rest)
127    }
128
129    fn take_sample(&self, input: &[u8], head: usize) -> f32 {
130        let bytes_per_sample = (self.bits_per_sample / 8) as usize;
131        match bytes_per_sample {
132            1..=1 => from_bytes::take_i8_sample(input, head) as f32,
133            2..=2 => from_bytes::take_i16_sample(input, head) as f32,
134            3..=4 => from_bytes::take_i32_sample(input, head, bytes_per_sample) as f32,
135            5..=8 => from_bytes::take_i64_sample(input, head, bytes_per_sample) as f32,
136            _ => unreachable!(),
137        }
138    }
139
140    fn sample_into_vec(&self, sample: f32) -> Vec<u8> {
141        let bytes_per_sample = (self.bits_per_sample / 8) as usize;
142        match bytes_per_sample {
143            1..=1 => (sample as i8).to_le_bytes().to_vec(),
144            2..=2 => (sample as i16).to_le_bytes().to_vec(),
145            3..=4 => (sample as i32).to_le_bytes()[..bytes_per_sample].to_vec(),
146            5..=8 => (sample as i64).to_le_bytes()[..bytes_per_sample].to_vec(),
147            _ => unreachable!(),
148        }
149    }
150}
151
152mod impls {
153    //! Implements [`Stream`] for [`WaveChannelMixer`].
154
155    use super::*;
156
157    use std::pin::Pin;
158    use std::task::{Context, Poll};
159
160    impl<'a> Stream for WaveChannelMixer<'a> {
161        type Item = Result<Vec<u8>>;
162
163        fn poll_next(
164            mut self: Pin<&mut Self>,
165            context: &mut Context<'_>,
166        ) -> Poll<Option<Self::Item>> {
167            let polled = self.stream.as_mut().poll_next(context);
168            let ready = match polled {
169                Poll::Ready(ready) => ready,
170                Poll::Pending => return Poll::Pending,
171            };
172
173            // the input stream is exhausted
174            let Some(chunk) = ready else {
175                if self.rest.is_empty() {
176                    return Poll::Ready(None);
177                } else {
178                    return Poll::Ready(Some(Err(Error::DataIsNotEnough)));
179                }
180            };
181
182            // the input stream has error
183            let chunk = match chunk {
184                Ok(chunk) => chunk,
185                Err(e) => return Poll::Ready(Some(Err(e))),
186            };
187
188            // convert
189            self.rest.extend(&chunk);
190            let (converted, rest) = self.convert(&self.rest);
191            self.rest = rest;
192
193            Poll::Ready(Some(Ok(converted)))
194        }
195    }
196}
197
mod from_bytes {
    //! Little-endian integer decoding helpers built on `from_le_bytes()`.
    //!
    //! Each helper reads the value that starts at byte offset `head` of `input`
    //! and panics if the requested bytes extend past the end of `input`.

    /// Reads one byte as a signed 8-bit integer.
    pub fn take_i8_sample(input: &[u8], head: usize) -> i8 {
        i8::from_le_bytes([input[head]])
    }

    /// Reads two bytes as a little-endian signed 16-bit integer.
    pub fn take_i16_sample(input: &[u8], head: usize) -> i16 {
        let mut buf = [0u8; 2];
        buf.copy_from_slice(&input[head..head + 2]);
        i16::from_le_bytes(buf)
    }

    /// Reads `width` bytes as a little-endian two's-complement integer.
    ///
    /// Widths shorter than 4 bytes (e.g. 24-bit PCM) are sign-extended.
    ///
    /// # Panics
    ///
    /// Panics if `width > 4`.
    pub fn take_i32_sample(input: &[u8], head: usize, width: usize) -> i32 {
        let mut buf = [0u8; 4];
        buf[..width].copy_from_slice(&input[head..(head + width)]);
        let value = i32::from_le_bytes(buf);
        match width {
            // Move the sample's sign bit up to bit 31, then shift back arithmetically.
            1..=3 => {
                let unused_bits = 32 - 8 * width as u32;
                (value << unused_bits) >> unused_bits
            }
            // 0 bytes decode to 0; 4 bytes already fill the whole integer.
            _ => value,
        }
    }

    /// Reads `width` bytes as a little-endian two's-complement integer.
    ///
    /// Widths shorter than 8 bytes are sign-extended.
    ///
    /// # Panics
    ///
    /// Panics if `width > 8`.
    pub fn take_i64_sample(input: &[u8], head: usize, width: usize) -> i64 {
        let mut buf = [0u8; 8];
        buf[..width].copy_from_slice(&input[head..(head + width)]);
        let value = i64::from_le_bytes(buf);
        match width {
            // Move the sample's sign bit up to bit 63, then shift back arithmetically.
            1..=7 => {
                let unused_bits = 64 - 8 * width as u32;
                (value << unused_bits) >> unused_bits
            }
            // 0 bytes decode to 0; 8 bytes already fill the whole integer.
            _ => value,
        }
    }
}
224
#[cfg(test)]
mod tests {
    use super::*;

    use futures_util::stream::iter;
    use futures_util::StreamExt as _;

    /// Stereo, 16-bit, 8 kHz integer PCM: the format every test below uses.
    fn stereo_16bit_spec() -> WaveSpec {
        WaveSpec {
            pcm_format: 1,
            channels: 2,
            sample_rate: 8000,
            bits_per_sample: 16,
        }
    }

    /// Encodes `samples` as little-endian `i16` bytes and splits them into
    /// 31-byte chunks. The odd chunk size cuts samples and frames across chunk
    /// boundaries, exercising the mixer's carry-over buffer.
    fn chunked_stream(
        samples: impl Iterator<Item = i16>,
    ) -> impl Stream<Item = Result<Vec<u8>>> {
        let bytes: Vec<u8> = samples.flat_map(i16::to_le_bytes).collect();
        let chunks: Vec<Result<Vec<u8>>> = bytes
            .chunks(31)
            .map(|chunk| Ok(chunk.to_vec()))
            .collect();
        iter(chunks)
    }

    /// Runs `mixer` to completion, unwrapping every item, and decodes the
    /// output as little-endian `i16` samples.
    async fn collect_i16(mut mixer: WaveChannelMixer<'_>) -> Vec<i16> {
        let mut output = Vec::new();
        while let Some(chunk) = mixer.next().await {
            output.extend(chunk.unwrap());
        }
        output
            .chunks(2)
            .map(|pair| from_bytes::take_i16_sample(pair, 0))
            .collect()
    }

    #[tokio::test]
    async fn test_pick() {
        let spec = stereo_16bit_spec();

        // The samples 0, 1, 2, ... interleave, so channel 0 holds the even
        // values and channel 1 the odd ones.
        for (channel, parity) in [(0u16, 0i16), (1, 1)] {
            let options = WaveChannelMixerOptions::Pick(channel);
            let mixer =
                WaveChannelMixer::new(&spec, options, chunked_stream(0i16..200)).unwrap();

            let expected: Vec<i16> = (0i16..100).map(|x| x * 2 + parity).collect();
            assert_eq!(collect_i16(mixer).await, expected);
        }
    }

    #[tokio::test]
    async fn test_mean() {
        let spec = stereo_16bit_spec();
        let samples = (0i16..200).map(|x| x * 2);

        let mixer =
            WaveChannelMixer::new(&spec, WaveChannelMixerOptions::Mean, chunked_stream(samples))
                .unwrap();

        // Frame k holds (4k, 4k + 2), whose mean is 4k + 1.
        let expected: Vec<i16> = (0i16..100).map(|x| x * 4 + 1).collect();
        assert_eq!(collect_i16(mixer).await, expected);
    }
}