Skip to main content

wavekat_vad/
adapter.rs

//! Frame adapter for matching audio frames to VAD backend requirements.
//!
//! Different VAD backends have different frame size requirements. This module
//! provides an adapter that buffers incoming audio and produces frames of the
//! exact size required by each backend.
6
7use crate::{ProcessTimings, VadCapabilities, VadError, VoiceActivityDetector};
8
9/// Adapts audio frames to match a VAD backend's requirements.
10///
11/// Buffers incoming samples and produces frames of the exact size
12/// required by the wrapped detector. Also handles sample rate validation.
13pub struct FrameAdapter {
14    /// The wrapped VAD detector.
15    inner: Box<dyn VoiceActivityDetector>,
16    /// Capabilities of the inner detector.
17    capabilities: VadCapabilities,
18    /// Buffer for accumulating samples.
19    buffer: Vec<i16>,
20}
21
22impl FrameAdapter {
23    /// Create a new frame adapter wrapping a VAD detector.
24    pub fn new(inner: Box<dyn VoiceActivityDetector>) -> Self {
25        let capabilities = inner.capabilities();
26        Self {
27            inner,
28            capabilities,
29            buffer: Vec::new(),
30        }
31    }
32
33    /// Returns the capabilities of the wrapped detector.
34    pub fn capabilities(&self) -> &VadCapabilities {
35        &self.capabilities
36    }
37
38    /// Returns the required sample rate.
39    pub fn sample_rate(&self) -> u32 {
40        self.capabilities.sample_rate
41    }
42
43    /// Returns the required frame size in samples.
44    pub fn frame_size(&self) -> usize {
45        self.capabilities.frame_size
46    }
47
48    /// Process audio samples, buffering until a complete frame is available.
49    ///
50    /// Returns `Some(probability)` when a complete frame was processed,
51    /// or `None` if more samples are needed.
52    ///
53    /// # Arguments
54    /// * `samples` - Audio samples (any length)
55    /// * `sample_rate` - Sample rate of the input audio
56    ///
57    /// # Errors
58    /// Returns an error if the sample rate doesn't match the detector's requirements.
59    pub fn process(&mut self, samples: &[i16], sample_rate: u32) -> Result<Option<f32>, VadError> {
60        if sample_rate != self.capabilities.sample_rate {
61            return Err(VadError::InvalidSampleRate(sample_rate));
62        }
63
64        self.buffer.extend_from_slice(samples);
65
66        if self.buffer.len() >= self.capabilities.frame_size {
67            let frame: Vec<i16> = self.buffer.drain(..self.capabilities.frame_size).collect();
68            let probability = self.inner.process(&frame, sample_rate)?;
69            Ok(Some(probability))
70        } else {
71            Ok(None)
72        }
73    }
74
75    /// Process all complete frames in the buffer.
76    ///
77    /// Returns a vector of probabilities, one for each complete frame processed.
78    /// Useful when you want to process multiple frames at once.
79    pub fn process_all(&mut self, samples: &[i16], sample_rate: u32) -> Result<Vec<f32>, VadError> {
80        if sample_rate != self.capabilities.sample_rate {
81            return Err(VadError::InvalidSampleRate(sample_rate));
82        }
83
84        self.buffer.extend_from_slice(samples);
85
86        let mut results = Vec::new();
87        while self.buffer.len() >= self.capabilities.frame_size {
88            let frame: Vec<i16> = self.buffer.drain(..self.capabilities.frame_size).collect();
89            let probability = self.inner.process(&frame, sample_rate)?;
90            results.push(probability);
91        }
92
93        Ok(results)
94    }
95
96    /// Returns the last probability from processing, or 0.0 if no frame was complete.
97    ///
98    /// This is a convenience method for real-time processing where you only
99    /// care about the most recent result.
100    pub fn process_latest(&mut self, samples: &[i16], sample_rate: u32) -> Result<f32, VadError> {
101        let results = self.process_all(samples, sample_rate)?;
102        Ok(results.into_iter().last().unwrap_or(0.0))
103    }
104
105    /// Reset the adapter and the wrapped detector.
106    pub fn reset(&mut self) {
107        self.buffer.clear();
108        self.inner.reset();
109    }
110
111    /// Returns the number of samples currently buffered.
112    pub fn buffered_samples(&self) -> usize {
113        self.buffer.len()
114    }
115
116    /// Returns accumulated processing timings from the inner detector.
117    pub fn timings(&self) -> ProcessTimings {
118        self.inner.timings()
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Sample rate used by every test detector.
    const RATE: u32 = 16000;
    /// Frame size (in samples) used by every test detector.
    const FRAME: usize = 512;

    /// Mock VAD that always reports a probability of 0.5.
    ///
    /// The call counter is shared through an `Arc` so a test can still read
    /// it after the mock has been boxed and moved into the adapter.
    struct MockVad {
        sample_rate: u32,
        frame_size: usize,
        calls: Arc<AtomicUsize>,
    }

    impl VoiceActivityDetector for MockVad {
        fn capabilities(&self) -> VadCapabilities {
            VadCapabilities {
                sample_rate: self.sample_rate,
                frame_size: self.frame_size,
                frame_duration_ms: (self.frame_size as u32 * 1000) / self.sample_rate,
            }
        }

        fn process(&mut self, samples: &[i16], _sample_rate: u32) -> Result<f32, VadError> {
            // The adapter must only ever hand over exact-size frames.
            assert_eq!(samples.len(), self.frame_size);
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(0.5)
        }

        fn reset(&mut self) {
            self.calls.store(0, Ordering::SeqCst);
        }
    }

    /// Builds an adapter around a fresh mock plus a handle to its call counter.
    fn adapter_with_counter() -> (FrameAdapter, Arc<AtomicUsize>) {
        let calls = Arc::new(AtomicUsize::new(0));
        let mock = MockVad {
            sample_rate: RATE,
            frame_size: FRAME,
            calls: Arc::clone(&calls),
        };
        (FrameAdapter::new(Box::new(mock)), calls)
    }

    #[test]
    fn test_adapter_reports_detector_requirements() {
        let (adapter, _calls) = adapter_with_counter();

        assert_eq!(adapter.sample_rate(), RATE);
        assert_eq!(adapter.frame_size(), FRAME);
        assert_eq!(adapter.capabilities().sample_rate, RATE);
        assert_eq!(adapter.capabilities().frame_size, FRAME);
        assert_eq!(adapter.buffered_samples(), 0);
    }

    #[test]
    fn test_adapter_buffers_samples() {
        let (mut adapter, calls) = adapter_with_counter();

        // Send less than a full frame
        let result = adapter.process(&[0i16; 256], RATE).unwrap();
        assert!(result.is_none());
        assert_eq!(adapter.buffered_samples(), 256);
        assert_eq!(calls.load(Ordering::SeqCst), 0);

        // Send more to complete the frame
        let result = adapter.process(&[0i16; 256], RATE).unwrap();
        assert_eq!(result, Some(0.5));
        assert_eq!(adapter.buffered_samples(), 0);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_process_keeps_surplus_frames_buffered() {
        let (mut adapter, calls) = adapter_with_counter();

        // Two frames plus 100 samples: `process` consumes only the first frame.
        let result = adapter.process(&[0i16; 2 * FRAME + 100], RATE).unwrap();
        assert_eq!(result, Some(0.5));
        assert_eq!(adapter.buffered_samples(), FRAME + 100);
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // The second frame is delivered on the next call, even with no new input.
        let no_samples: [i16; 0] = [];
        let result = adapter.process(&no_samples, RATE).unwrap();
        assert_eq!(result, Some(0.5));
        assert_eq!(adapter.buffered_samples(), 100);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_adapter_handles_multiple_frames() {
        let (mut adapter, calls) = adapter_with_counter();

        // Send two complete frames worth
        let results = adapter.process_all(&[0i16; 2 * FRAME], RATE).unwrap();
        assert_eq!(results, vec![0.5, 0.5]);
        assert_eq!(adapter.buffered_samples(), 0);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_adapter_wrong_sample_rate() {
        let (mut adapter, calls) = adapter_with_counter();

        let result = adapter.process(&[0i16; FRAME], 48000);
        assert!(matches!(result, Err(VadError::InvalidSampleRate(48000))));

        let result = adapter.process_all(&[0i16; FRAME], 48000);
        assert!(matches!(result, Err(VadError::InvalidSampleRate(48000))));

        let result = adapter.process_latest(&[0i16; FRAME], 48000);
        assert!(matches!(result, Err(VadError::InvalidSampleRate(48000))));

        // Rejected input is neither buffered nor forwarded to the detector.
        assert_eq!(adapter.buffered_samples(), 0);
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn test_adapter_reset() {
        let (mut adapter, calls) = adapter_with_counter();

        // Process one frame and leave half a frame buffered
        let _ = adapter.process(&[0i16; FRAME + 256], RATE);
        assert_eq!(adapter.buffered_samples(), 256);
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // Reset clears the buffer and is forwarded to the detector
        adapter.reset();
        assert_eq!(adapter.buffered_samples(), 0);
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn test_process_latest() {
        let (mut adapter, calls) = adapter_with_counter();

        // Send multiple frames (1600 = 3 full frames + 64 left over)
        let result = adapter.process_latest(&[0i16; 1600], RATE).unwrap();
        assert_eq!(result, 0.5); // Mock returns 0.5
        assert_eq!(adapter.buffered_samples(), 64); // 1600 - 3*512 = 64 left over
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_process_latest_without_complete_frame() {
        let (mut adapter, calls) = adapter_with_counter();

        // Not enough samples for a frame: falls back to 0.0 and keeps buffering
        let result = adapter.process_latest(&[0i16; 100], RATE).unwrap();
        assert_eq!(result, 0.0);
        assert_eq!(adapter.buffered_samples(), 100);
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }
}