Skip to main content

whisper_cpp_plus/
vad.rs

1//! Voice Activity Detection (VAD) support
2//!
3//! This module provides VAD capabilities for detecting speech segments
4//! in audio before transcription, improving performance and accuracy.
5
6use crate::error::{Result, WhisperError};
7use std::path::Path;
8use whisper_cpp_plus_sys as ffi;
9
/// Tuning knobs for VAD speech-segment detection.
///
/// `Default` takes its values from whisper.cpp; `VadParamsBuilder` can override
/// individual fields while clamping out-of-range inputs.
#[derive(Clone, Debug)]
pub struct VadParams {
    /// Speech-probability threshold, expected to lie in `0.0..=1.0`.
    pub threshold: f32,
    /// Shortest speech run, in milliseconds, that is kept as a segment.
    pub min_speech_duration_ms: i32,
    /// Shortest silence, in milliseconds, that splits two segments apart.
    pub min_silence_duration_ms: i32,
    /// Longest speech run, in seconds, before a segment break is forced.
    pub max_speech_duration_s: f32,
    /// Milliseconds of padding added before and after each speech segment.
    pub speech_pad_ms: i32,
    /// Overlap, in seconds, used when copying audio samples out of a speech segment.
    pub samples_overlap: f32,
}
26
27impl Default for VadParams {
28    fn default() -> Self {
29        // Use whisper.cpp's default VAD parameters
30        let default_params = unsafe { ffi::whisper_vad_default_params() };
31
32        Self {
33            threshold: default_params.threshold,
34            min_speech_duration_ms: default_params.min_speech_duration_ms,
35            min_silence_duration_ms: default_params.min_silence_duration_ms,
36            max_speech_duration_s: default_params.max_speech_duration_s,
37            speech_pad_ms: default_params.speech_pad_ms,
38            samples_overlap: default_params.samples_overlap,
39        }
40    }
41}
42
43impl VadParams {
44    /// Convert to FFI params
45    fn to_ffi(&self) -> ffi::whisper_vad_params {
46        ffi::whisper_vad_params {
47            threshold: self.threshold,
48            min_speech_duration_ms: self.min_speech_duration_ms,
49            min_silence_duration_ms: self.min_silence_duration_ms,
50            max_speech_duration_s: self.max_speech_duration_s,
51            speech_pad_ms: self.speech_pad_ms,
52            samples_overlap: self.samples_overlap,
53        }
54    }
55}
56
/// Settings used when creating a VAD context from a model file.
///
/// `Default` takes its values from whisper.cpp's `whisper_vad_default_context_params`.
#[derive(Clone, Debug)]
pub struct VadContextParams {
    /// Number of threads used for VAD processing.
    pub n_threads: i32,
    /// Requests GPU acceleration when `true`.
    pub use_gpu: bool,
    /// Index of the GPU device to use.
    pub gpu_device: i32,
}
67
68impl Default for VadContextParams {
69    fn default() -> Self {
70        let default_params = unsafe { ffi::whisper_vad_default_context_params() };
71
72        Self {
73            n_threads: default_params.n_threads,
74            use_gpu: default_params.use_gpu,
75            gpu_device: default_params.gpu_device,
76        }
77    }
78}
79
80impl VadContextParams {
81    /// Convert to FFI params
82    fn to_ffi(&self) -> ffi::whisper_vad_context_params {
83        ffi::whisper_vad_context_params {
84            n_threads: self.n_threads,
85            use_gpu: self.use_gpu,
86            gpu_device: self.gpu_device,
87        }
88    }
89}
90
91/// Voice Activity Detector
92pub struct WhisperVadProcessor {
93    ctx: *mut ffi::whisper_vad_context,
94}
95
96unsafe impl Send for WhisperVadProcessor {}
97unsafe impl Sync for WhisperVadProcessor {}
98
99impl Drop for WhisperVadProcessor {
100    fn drop(&mut self) {
101        unsafe {
102            if !self.ctx.is_null() {
103                ffi::whisper_vad_free(self.ctx);
104            }
105        }
106    }
107}
108
109impl WhisperVadProcessor {
110    /// Create a new VAD processor from a model file
111    pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
112        Self::new_with_params(model_path, VadContextParams::default())
113    }
114
115    /// Create a new VAD processor with custom parameters
116    pub fn new_with_params<P: AsRef<Path>>(
117        model_path: P,
118        params: VadContextParams,
119    ) -> Result<Self> {
120        let path_str = model_path
121            .as_ref()
122            .to_str()
123            .ok_or_else(|| WhisperError::ModelLoadError("Invalid path".into()))?;
124
125        let c_path = std::ffi::CString::new(path_str)?;
126
127        let ctx = unsafe {
128            ffi::whisper_vad_init_from_file_with_params(c_path.as_ptr(), params.to_ffi())
129        };
130
131        if ctx.is_null() {
132            return Err(WhisperError::ModelLoadError(
133                "Failed to load VAD model".into(),
134            ));
135        }
136
137        Ok(Self { ctx })
138    }
139
140    /// Detect speech in audio samples
141    pub fn detect_speech(&mut self, samples: &[f32]) -> bool {
142        if samples.is_empty() {
143            return false;
144        }
145
146        unsafe { ffi::whisper_vad_detect_speech(self.ctx, samples.as_ptr(), samples.len() as i32) }
147    }
148
149    /// Get the number of probability values
150    pub fn n_probs(&self) -> i32 {
151        unsafe { ffi::whisper_vad_n_probs(self.ctx) }
152    }
153
154    /// Get probability values
155    pub fn get_probs(&self) -> Vec<f32> {
156        let n = self.n_probs();
157        if n == 0 {
158            return Vec::new();
159        }
160
161        let probs_ptr = unsafe { ffi::whisper_vad_probs(self.ctx) };
162        if probs_ptr.is_null() {
163            return Vec::new();
164        }
165
166        let slice = unsafe { std::slice::from_raw_parts(probs_ptr, n as usize) };
167        slice.to_vec()
168    }
169
170    /// Get speech segments from probability values
171    pub fn segments_from_probs(&mut self, params: &VadParams) -> Result<VadSegments> {
172        let segments_ptr =
173            unsafe { ffi::whisper_vad_segments_from_probs(self.ctx, params.to_ffi()) };
174
175        if segments_ptr.is_null() {
176            return Err(WhisperError::InvalidContext);
177        }
178
179        Ok(VadSegments { ptr: segments_ptr })
180    }
181
182    /// Get speech segments directly from audio samples
183    pub fn segments_from_samples(
184        &mut self,
185        samples: &[f32],
186        params: &VadParams,
187    ) -> Result<VadSegments> {
188        if samples.is_empty() {
189            return Err(WhisperError::InvalidAudioFormat);
190        }
191
192        let segments_ptr = unsafe {
193            ffi::whisper_vad_segments_from_samples(
194                self.ctx,
195                params.to_ffi(),
196                samples.as_ptr(),
197                samples.len() as i32,
198            )
199        };
200
201        if segments_ptr.is_null() {
202            return Err(WhisperError::InvalidContext);
203        }
204
205        Ok(VadSegments { ptr: segments_ptr })
206    }
207}
208
209/// Speech segments detected by VAD
210pub struct VadSegments {
211    ptr: *mut ffi::whisper_vad_segments,
212}
213
214impl Drop for VadSegments {
215    fn drop(&mut self) {
216        unsafe {
217            if !self.ptr.is_null() {
218                ffi::whisper_vad_free_segments(self.ptr);
219            }
220        }
221    }
222}
223
224impl VadSegments {
225    /// Get the number of segments
226    pub fn n_segments(&self) -> i32 {
227        unsafe { ffi::whisper_vad_segments_n_segments(self.ptr) }
228    }
229
230    /// Get segment start time in seconds
231    pub fn get_segment_t0(&self, i_segment: i32) -> f32 {
232        // The FFI returns time in centiseconds, convert to seconds
233        unsafe { ffi::whisper_vad_segments_get_segment_t0(self.ptr, i_segment) / 100.0 }
234    }
235
236    /// Get segment end time in seconds
237    pub fn get_segment_t1(&self, i_segment: i32) -> f32 {
238        // The FFI returns time in centiseconds, convert to seconds
239        unsafe { ffi::whisper_vad_segments_get_segment_t1(self.ptr, i_segment) / 100.0 }
240    }
241
242    /// Get all segments as tuples of (start, end) times in seconds
243    pub fn get_all_segments(&self) -> Vec<(f32, f32)> {
244        let n = self.n_segments();
245        let mut segments = Vec::with_capacity(n as usize);
246
247        for i in 0..n {
248            segments.push((self.get_segment_t0(i), self.get_segment_t1(i)));
249        }
250
251        segments
252    }
253
254    /// Extract audio segments from the original audio based on VAD segments
255    pub fn extract_audio_segments(&self, audio: &[f32], sample_rate: f32) -> Vec<Vec<f32>> {
256        let segments = self.get_all_segments();
257        let mut audio_segments = Vec::with_capacity(segments.len());
258
259        for (start, end) in segments {
260            let start_sample = (start * sample_rate) as usize;
261            let end_sample = (end * sample_rate) as usize;
262
263            if start_sample < audio.len() && end_sample <= audio.len() {
264                audio_segments.push(audio[start_sample..end_sample].to_vec());
265            }
266        }
267
268        audio_segments
269    }
270}
271
272/// Builder for VadParams
273pub struct VadParamsBuilder {
274    params: VadParams,
275}
276
277impl VadParamsBuilder {
278    /// Create a new builder with default values
279    pub fn new() -> Self {
280        Self {
281            params: VadParams::default(),
282        }
283    }
284
285    /// Set the probability threshold (0.0 - 1.0)
286    pub fn threshold(mut self, threshold: f32) -> Self {
287        self.params.threshold = threshold.clamp(0.0, 1.0);
288        self
289    }
290
291    /// Set minimum speech duration in milliseconds
292    pub fn min_speech_duration_ms(mut self, ms: i32) -> Self {
293        self.params.min_speech_duration_ms = ms.max(0);
294        self
295    }
296
297    /// Set minimum silence duration in milliseconds
298    pub fn min_silence_duration_ms(mut self, ms: i32) -> Self {
299        self.params.min_silence_duration_ms = ms.max(0);
300        self
301    }
302
303    /// Set maximum speech duration in seconds
304    pub fn max_speech_duration_s(mut self, seconds: f32) -> Self {
305        self.params.max_speech_duration_s = seconds.max(0.0);
306        self
307    }
308
309    /// Set speech padding in milliseconds
310    pub fn speech_pad_ms(mut self, ms: i32) -> Self {
311        self.params.speech_pad_ms = ms.max(0);
312        self
313    }
314
315    /// Set samples overlap
316    pub fn samples_overlap(mut self, overlap: f32) -> Self {
317        self.params.samples_overlap = overlap.max(0.0);
318        self
319    }
320
321    /// Build the parameters
322    pub fn build(self) -> VadParams {
323        self.params
324    }
325}
326
327impl Default for VadParamsBuilder {
328    fn default() -> Self {
329        Self::new()
330    }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_vad_params_default() {
        let defaults = VadParams::default();
        let threshold_in_open_unit_range = defaults.threshold > 0.0 && defaults.threshold < 1.0;

        assert!(threshold_in_open_unit_range);
        assert!(defaults.min_speech_duration_ms >= 0);
        assert!(defaults.max_speech_duration_s > 0.0);
    }

    #[test]
    fn test_vad_params_builder() {
        let built = VadParamsBuilder::new()
            .threshold(0.6)
            .min_speech_duration_ms(250)
            .min_silence_duration_ms(100)
            .max_speech_duration_s(30.0)
            .speech_pad_ms(100)
            .build();

        assert_eq!(built.threshold, 0.6);
        assert_eq!(built.max_speech_duration_s, 30.0);
        assert_eq!(
            (
                built.min_speech_duration_ms,
                built.min_silence_duration_ms,
                built.speech_pad_ms
            ),
            (250, 100, 100)
        );
    }

    #[test]
    fn test_vad_params_builder_clamps() {
        // Out-of-range inputs are pulled back to the nearest allowed value.
        let clamped = VadParamsBuilder::new()
            .threshold(1.5)
            .min_speech_duration_ms(-100)
            .build();

        assert_eq!(clamped.threshold, 1.0);
        assert_eq!(clamped.min_speech_duration_ms, 0);
    }

    #[test]
    fn test_vad_processor_creation() {
        // Only meaningful when the VAD model has been placed under tests/models.
        const MODEL_PATH: &str = "tests/models/ggml-silero-vad.bin";

        if !Path::new(MODEL_PATH).exists() {
            eprintln!("Skipping VAD processor creation test: model not found");
            return;
        }

        assert!(WhisperVadProcessor::new(MODEL_PATH).is_ok());
    }

    #[test]
    fn test_vad_context_params() {
        assert!(VadContextParams::default().n_threads > 0);

        let custom = VadContextParams {
            n_threads: 4,
            use_gpu: true,
            gpu_device: 0,
        };

        assert_eq!(custom.n_threads, 4);
        assert!(custom.use_gpu);
        assert_eq!(custom.gpu_device, 0);
    }
}