Skip to main content

wavekat_vad/
adapter.rs

//! Frame adapter for matching audio frames to VAD backend requirements.
//!
//! Different VAD backends have different frame size requirements. This module
//! provides an adapter that buffers incoming audio and produces frames of the
//! exact size required by each backend.
6
7use crate::{VadCapabilities, VadError, VoiceActivityDetector};
8
9/// Adapts audio frames to match a VAD backend's requirements.
10///
11/// Buffers incoming samples and produces frames of the exact size
12/// required by the wrapped detector. Also handles sample rate validation.
13pub struct FrameAdapter {
14    /// The wrapped VAD detector.
15    inner: Box<dyn VoiceActivityDetector>,
16    /// Capabilities of the inner detector.
17    capabilities: VadCapabilities,
18    /// Buffer for accumulating samples.
19    buffer: Vec<i16>,
20}
21
22impl FrameAdapter {
23    /// Create a new frame adapter wrapping a VAD detector.
24    pub fn new(inner: Box<dyn VoiceActivityDetector>) -> Self {
25        let capabilities = inner.capabilities();
26        Self {
27            inner,
28            capabilities,
29            buffer: Vec::new(),
30        }
31    }
32
33    /// Returns the capabilities of the wrapped detector.
34    pub fn capabilities(&self) -> &VadCapabilities {
35        &self.capabilities
36    }
37
38    /// Returns the required sample rate.
39    pub fn sample_rate(&self) -> u32 {
40        self.capabilities.sample_rate
41    }
42
43    /// Returns the required frame size in samples.
44    pub fn frame_size(&self) -> usize {
45        self.capabilities.frame_size
46    }
47
48    /// Process audio samples, buffering until a complete frame is available.
49    ///
50    /// Returns `Some(probability)` when a complete frame was processed,
51    /// or `None` if more samples are needed.
52    ///
53    /// # Arguments
54    /// * `samples` - Audio samples (any length)
55    /// * `sample_rate` - Sample rate of the input audio
56    ///
57    /// # Errors
58    /// Returns an error if the sample rate doesn't match the detector's requirements.
59    pub fn process(&mut self, samples: &[i16], sample_rate: u32) -> Result<Option<f32>, VadError> {
60        if sample_rate != self.capabilities.sample_rate {
61            return Err(VadError::InvalidSampleRate(sample_rate));
62        }
63
64        self.buffer.extend_from_slice(samples);
65
66        if self.buffer.len() >= self.capabilities.frame_size {
67            let frame: Vec<i16> = self.buffer.drain(..self.capabilities.frame_size).collect();
68            let probability = self.inner.process(&frame, sample_rate)?;
69            Ok(Some(probability))
70        } else {
71            Ok(None)
72        }
73    }
74
75    /// Process all complete frames in the buffer.
76    ///
77    /// Returns a vector of probabilities, one for each complete frame processed.
78    /// Useful when you want to process multiple frames at once.
79    pub fn process_all(&mut self, samples: &[i16], sample_rate: u32) -> Result<Vec<f32>, VadError> {
80        if sample_rate != self.capabilities.sample_rate {
81            return Err(VadError::InvalidSampleRate(sample_rate));
82        }
83
84        self.buffer.extend_from_slice(samples);
85
86        let mut results = Vec::new();
87        while self.buffer.len() >= self.capabilities.frame_size {
88            let frame: Vec<i16> = self.buffer.drain(..self.capabilities.frame_size).collect();
89            let probability = self.inner.process(&frame, sample_rate)?;
90            results.push(probability);
91        }
92
93        Ok(results)
94    }
95
96    /// Returns the last probability from processing, or 0.0 if no frame was complete.
97    ///
98    /// This is a convenience method for real-time processing where you only
99    /// care about the most recent result.
100    pub fn process_latest(&mut self, samples: &[i16], sample_rate: u32) -> Result<f32, VadError> {
101        let results = self.process_all(samples, sample_rate)?;
102        Ok(results.into_iter().last().unwrap_or(0.0))
103    }
104
105    /// Reset the adapter and the wrapped detector.
106    pub fn reset(&mut self) {
107        self.buffer.clear();
108        self.inner.reset();
109    }
110
111    /// Returns the number of samples currently buffered.
112    pub fn buffered_samples(&self) -> usize {
113        self.buffer.len()
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample rate every test detector is configured with.
    const RATE: u32 = 16000;
    /// Frame length, in samples, every test detector asks for.
    const FRAME: usize = 512;

    /// Stand-in detector that checks it only ever sees full-size frames and
    /// answers every frame with a fixed probability of 0.5.
    struct MockVad {
        sample_rate: u32,
        frame_size: usize,
        call_count: usize,
    }

    impl MockVad {
        fn new(sample_rate: u32, frame_size: usize) -> Self {
            Self {
                sample_rate,
                frame_size,
                call_count: 0,
            }
        }
    }

    impl VoiceActivityDetector for MockVad {
        fn capabilities(&self) -> VadCapabilities {
            let frame_duration_ms = (self.frame_size as u32 * 1000) / self.sample_rate;
            VadCapabilities {
                sample_rate: self.sample_rate,
                frame_size: self.frame_size,
                frame_duration_ms,
            }
        }

        fn process(&mut self, samples: &[i16], _sample_rate: u32) -> Result<f32, VadError> {
            // The adapter must never forward a partial or oversized frame.
            assert_eq!(samples.len(), self.frame_size);
            self.call_count += 1;
            Ok(0.5)
        }

        fn reset(&mut self) {
            self.call_count = 0;
        }
    }

    /// Builds an adapter around a fresh mock using the shared test settings.
    fn make_adapter() -> FrameAdapter {
        FrameAdapter::new(Box::new(MockVad::new(RATE, FRAME)))
    }

    #[test]
    fn test_adapter_buffers_samples() {
        let mut adapter = make_adapter();
        let half_frame = [0i16; FRAME / 2];

        // Half a frame is not enough to produce a result yet.
        let first = adapter.process(&half_frame, RATE).unwrap();
        assert_eq!(first, None);
        assert_eq!(adapter.buffered_samples(), 256);

        // The second half completes the frame and empties the buffer.
        let second = adapter.process(&half_frame, RATE).unwrap();
        assert!(matches!(second, Some(_)));
        assert_eq!(adapter.buffered_samples(), 0);
    }

    #[test]
    fn test_adapter_handles_multiple_frames() {
        let mut adapter = make_adapter();

        // Exactly two frames' worth of samples in a single call.
        let probabilities = adapter.process_all(&[0i16; 2 * FRAME], RATE).unwrap();
        assert_eq!(probabilities.len(), 2);
    }

    #[test]
    fn test_adapter_wrong_sample_rate() {
        let mut adapter = make_adapter();

        let outcome = adapter.process(&[0i16; FRAME], 48000);
        assert!(matches!(outcome, Err(VadError::InvalidSampleRate(48000))));
    }

    #[test]
    fn test_adapter_reset() {
        let mut adapter = make_adapter();

        // Leave a partial frame sitting in the buffer.
        let _ = adapter.process(&[0i16; FRAME / 2], RATE);
        assert_eq!(adapter.buffered_samples(), 256);

        // Resetting must throw the partial frame away.
        adapter.reset();
        assert_eq!(adapter.buffered_samples(), 0);
    }

    #[test]
    fn test_process_latest() {
        let mut adapter = make_adapter();

        // 1600 samples = 3 full frames of 512 plus 64 left over.
        let latest = adapter.process_latest(&[0i16; 1600], RATE).unwrap();
        assert_eq!(latest, 0.5); // the mock always answers 0.5
        assert_eq!(adapter.buffered_samples(), 64); // 1600 - 3 * 512
    }
}