whisper_cpp_plus/vad.rs

//! Voice Activity Detection (VAD) support
//!
//! This module provides VAD capabilities for detecting speech segments
//! in audio before transcription, improving performance and accuracy.
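//!
//! # Example
//!
//! A minimal sketch of the intended flow. The `whisper_cpp_plus::vad` import
//! path and the model file name are assumptions about your setup, so the
//! example is marked `ignore`:
//!
//! ```ignore
//! use whisper_cpp_plus::vad::{VadParams, WhisperVadProcessor};
//!
//! // Load a Silero VAD model (placeholder path).
//! let mut vad = WhisperVadProcessor::new("models/ggml-silero-vad.bin")
//!     .expect("failed to load VAD model");
//!
//! // 16 kHz mono f32 samples (one second of silence here as a stand-in).
//! let samples = vec![0.0_f32; 16_000];
//!
//! // Detect speech segments directly from the samples.
//! let segments = vad
//!     .segments_from_samples(&samples, &VadParams::default())
//!     .expect("VAD failed");
//!
//! for (start, end) in segments.get_all_segments() {
//!     println!("speech from {start:.2}s to {end:.2}s");
//! }
//! ```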

use crate::error::{Result, WhisperError};
use std::path::Path;
use whisper_cpp_plus_sys as ffi;

/// VAD parameters for speech detection
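///
/// Defaults come from whisper.cpp's `whisper_vad_default_params()`. A sketch
/// of overriding individual fields with struct-update syntax (values are
/// illustrative, hence `ignore`):
///
/// ```ignore
/// use whisper_cpp_plus::vad::VadParams;
///
/// let params = VadParams {
///     threshold: 0.6,               // demand higher confidence per frame
///     min_silence_duration_ms: 200, // require 200 ms of silence to split segments
///     ..VadParams::default()
/// };
/// ```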
#[derive(Debug, Clone)]
pub struct VadParams {
    /// Probability threshold to consider as speech (0.0 - 1.0)
    pub threshold: f32,
    /// Minimum duration for a valid speech segment (in milliseconds)
    pub min_speech_duration_ms: i32,
    /// Minimum duration of silence to split segments (in milliseconds)
    pub min_silence_duration_ms: i32,
    /// Maximum speech duration before forcing a segment break (in seconds)
    pub max_speech_duration_s: f32,
    /// Padding added before and after speech segments (in milliseconds)
    pub speech_pad_ms: i32,
    /// Overlap in seconds when copying audio samples from speech segment
    pub samples_overlap: f32,
}

impl Default for VadParams {
    fn default() -> Self {
        // Use whisper.cpp's default VAD parameters
        let default_params = unsafe { ffi::whisper_vad_default_params() };

        Self {
            threshold: default_params.threshold,
            min_speech_duration_ms: default_params.min_speech_duration_ms,
            min_silence_duration_ms: default_params.min_silence_duration_ms,
            max_speech_duration_s: default_params.max_speech_duration_s,
            speech_pad_ms: default_params.speech_pad_ms,
            samples_overlap: default_params.samples_overlap,
        }
    }
}

impl VadParams {
    /// Convert to FFI params
    fn to_ffi(&self) -> ffi::whisper_vad_params {
        ffi::whisper_vad_params {
            threshold: self.threshold,
            min_speech_duration_ms: self.min_speech_duration_ms,
            min_silence_duration_ms: self.min_silence_duration_ms,
            max_speech_duration_s: self.max_speech_duration_s,
            speech_pad_ms: self.speech_pad_ms,
            samples_overlap: self.samples_overlap,
        }
    }
}

/// VAD context parameters
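///
/// A sketch of forcing CPU-only processing (field values are illustrative):
///
/// ```ignore
/// use whisper_cpp_plus::vad::VadContextParams;
///
/// let ctx_params = VadContextParams {
///     n_threads: 2,
///     use_gpu: false,
///     ..VadContextParams::default()
/// };
/// ```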
#[derive(Debug, Clone)]
pub struct VadContextParams {
    /// Number of threads to use for processing
    pub n_threads: i32,
    /// Whether to use GPU acceleration
    pub use_gpu: bool,
    /// GPU device ID to use
    pub gpu_device: i32,
}

impl Default for VadContextParams {
    fn default() -> Self {
        let default_params = unsafe { ffi::whisper_vad_default_context_params() };

        Self {
            n_threads: default_params.n_threads,
            use_gpu: default_params.use_gpu,
            gpu_device: default_params.gpu_device,
        }
    }
}

impl VadContextParams {
    /// Convert to FFI params
    fn to_ffi(&self) -> ffi::whisper_vad_context_params {
        ffi::whisper_vad_context_params {
            n_threads: self.n_threads,
            use_gpu: self.use_gpu,
            gpu_device: self.gpu_device,
        }
    }
}

/// Voice Activity Detector
pub struct WhisperVadProcessor {
    ctx: *mut ffi::whisper_vad_context,
}

unsafe impl Send for WhisperVadProcessor {}
unsafe impl Sync for WhisperVadProcessor {}

impl Drop for WhisperVadProcessor {
    fn drop(&mut self) {
        unsafe {
            if !self.ctx.is_null() {
                ffi::whisper_vad_free(self.ctx);
            }
        }
    }
}

impl WhisperVadProcessor {
    /// Create a new VAD processor from a model file
    pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
        Self::new_with_params(model_path, VadContextParams::default())
    }

    /// Create a new VAD processor with custom parameters
    pub fn new_with_params<P: AsRef<Path>>(
        model_path: P,
        params: VadContextParams,
    ) -> Result<Self> {
        let path_str = model_path
            .as_ref()
            .to_str()
            .ok_or_else(|| WhisperError::ModelLoadError("Invalid path".into()))?;

        let c_path = std::ffi::CString::new(path_str)?;

        let ctx = unsafe {
            ffi::whisper_vad_init_from_file_with_params(c_path.as_ptr(), params.to_ffi())
        };

        if ctx.is_null() {
            return Err(WhisperError::ModelLoadError(
                "Failed to load VAD model".into(),
            ));
        }

        Ok(Self { ctx })
    }

    /// Detect speech in audio samples
    pub fn detect_speech(&mut self, samples: &[f32]) -> bool {
        if samples.is_empty() {
            return false;
        }

        unsafe {
            ffi::whisper_vad_detect_speech(
                self.ctx,
                samples.as_ptr(),
                samples.len() as i32,
            )
        }
    }

    /// Get the number of probability values
    pub fn n_probs(&self) -> i32 {
        unsafe { ffi::whisper_vad_n_probs(self.ctx) }
    }

    /// Get probability values
    pub fn get_probs(&self) -> Vec<f32> {
        let n = self.n_probs();
        if n <= 0 {
            // Guard against a zero or negative count from the FFI layer.
            return Vec::new();
        }

        let probs_ptr = unsafe { ffi::whisper_vad_probs(self.ctx) };
        if probs_ptr.is_null() {
            return Vec::new();
        }

        let slice = unsafe { std::slice::from_raw_parts(probs_ptr, n as usize) };
        slice.to_vec()
    }

    /// Get speech segments from probability values
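    ///
    /// A sketch of the two-step flow: run [`detect_speech`](Self::detect_speech)
    /// first to populate the probabilities, then derive segments from them.
    /// The model path is a placeholder:
    ///
    /// ```ignore
    /// use whisper_cpp_plus::vad::{VadParams, WhisperVadProcessor};
    ///
    /// let mut vad = WhisperVadProcessor::new("models/ggml-silero-vad.bin")
    ///     .expect("failed to load VAD model");
    /// let samples = vec![0.0_f32; 16_000]; // 16 kHz mono audio
    ///
    /// if vad.detect_speech(&samples) {
    ///     let segments = vad
    ///         .segments_from_probs(&VadParams::default())
    ///         .expect("segmentation failed");
    ///     println!("found {} segment(s)", segments.n_segments());
    /// }
    /// ```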
    pub fn segments_from_probs(&mut self, params: &VadParams) -> Result<VadSegments> {
        let segments_ptr = unsafe {
            ffi::whisper_vad_segments_from_probs(self.ctx, params.to_ffi())
        };

        if segments_ptr.is_null() {
            return Err(WhisperError::InvalidContext);
        }

        Ok(VadSegments { ptr: segments_ptr })
    }

    /// Get speech segments directly from audio samples
    pub fn segments_from_samples(
        &mut self,
        samples: &[f32],
        params: &VadParams,
    ) -> Result<VadSegments> {
        if samples.is_empty() {
            return Err(WhisperError::InvalidAudioFormat);
        }

        let segments_ptr = unsafe {
            ffi::whisper_vad_segments_from_samples(
                self.ctx,
                params.to_ffi(),
                samples.as_ptr(),
                samples.len() as i32,
            )
        };

        if segments_ptr.is_null() {
            return Err(WhisperError::InvalidContext);
        }

        Ok(VadSegments { ptr: segments_ptr })
    }
}

/// Speech segments detected by VAD
pub struct VadSegments {
    ptr: *mut ffi::whisper_vad_segments,
}

impl Drop for VadSegments {
    fn drop(&mut self) {
        unsafe {
            if !self.ptr.is_null() {
                ffi::whisper_vad_free_segments(self.ptr);
            }
        }
    }
}

impl VadSegments {
    /// Get the number of segments
    pub fn n_segments(&self) -> i32 {
        unsafe { ffi::whisper_vad_segments_n_segments(self.ptr) }
    }

    /// Get segment start time in seconds
    pub fn get_segment_t0(&self, i_segment: i32) -> f32 {
        // The FFI returns time in centiseconds, convert to seconds
        unsafe { ffi::whisper_vad_segments_get_segment_t0(self.ptr, i_segment) / 100.0 }
    }

    /// Get segment end time in seconds
    pub fn get_segment_t1(&self, i_segment: i32) -> f32 {
        // The FFI returns time in centiseconds, convert to seconds
        unsafe { ffi::whisper_vad_segments_get_segment_t1(self.ptr, i_segment) / 100.0 }
    }

    /// Get all segments as tuples of (start, end) times in seconds
    pub fn get_all_segments(&self) -> Vec<(f32, f32)> {
        let n = self.n_segments();
        let mut segments = Vec::with_capacity(n as usize);

        for i in 0..n {
            segments.push((self.get_segment_t0(i), self.get_segment_t1(i)));
        }

        segments
    }

    /// Extract audio segments from the original audio based on VAD segments
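    ///
    /// `sample_rate` is the rate of `audio` in Hz (16 kHz for whisper-style
    /// pipelines). A sketch, assuming `segments` was produced from the same
    /// `audio` buffer and a placeholder model path:
    ///
    /// ```ignore
    /// use whisper_cpp_plus::vad::{VadParams, WhisperVadProcessor};
    ///
    /// let mut vad = WhisperVadProcessor::new("models/ggml-silero-vad.bin")
    ///     .expect("failed to load VAD model");
    /// let audio = vec![0.0_f32; 16_000];
    /// let segments = vad
    ///     .segments_from_samples(&audio, &VadParams::default())
    ///     .expect("VAD failed");
    ///
    /// let chunks = segments.extract_audio_segments(&audio, 16_000.0);
    /// for chunk in &chunks {
    ///     println!("speech chunk with {} samples", chunk.len());
    /// }
    /// ```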
    pub fn extract_audio_segments(&self, audio: &[f32], sample_rate: f32) -> Vec<Vec<f32>> {
        let segments = self.get_all_segments();
        let mut audio_segments = Vec::with_capacity(segments.len());

        for (start, end) in segments {
            let start_sample = (start * sample_rate) as usize;
            let end_sample = (end * sample_rate) as usize;

            if start_sample < audio.len() && end_sample <= audio.len() {
                audio_segments.push(audio[start_sample..end_sample].to_vec());
            }
        }

        audio_segments
    }
}

/// Builder for VadParams
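///
/// A sketch of fluent construction (values are illustrative):
///
/// ```ignore
/// use whisper_cpp_plus::vad::VadParamsBuilder;
///
/// let params = VadParamsBuilder::new()
///     .threshold(0.5)
///     .min_speech_duration_ms(250)
///     .min_silence_duration_ms(100)
///     .speech_pad_ms(100)
///     .build();
/// ```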
pub struct VadParamsBuilder {
    params: VadParams,
}

impl VadParamsBuilder {
    /// Create a new builder with default values
    pub fn new() -> Self {
        Self {
            params: VadParams::default(),
        }
    }

    /// Set the probability threshold (0.0 - 1.0)
    pub fn threshold(mut self, threshold: f32) -> Self {
        self.params.threshold = threshold.clamp(0.0, 1.0);
        self
    }

    /// Set minimum speech duration in milliseconds
    pub fn min_speech_duration_ms(mut self, ms: i32) -> Self {
        self.params.min_speech_duration_ms = ms.max(0);
        self
    }

    /// Set minimum silence duration in milliseconds
    pub fn min_silence_duration_ms(mut self, ms: i32) -> Self {
        self.params.min_silence_duration_ms = ms.max(0);
        self
    }

    /// Set maximum speech duration in seconds
    pub fn max_speech_duration_s(mut self, seconds: f32) -> Self {
        self.params.max_speech_duration_s = seconds.max(0.0);
        self
    }

    /// Set speech padding in milliseconds
    pub fn speech_pad_ms(mut self, ms: i32) -> Self {
        self.params.speech_pad_ms = ms.max(0);
        self
    }

    /// Set samples overlap
    pub fn samples_overlap(mut self, overlap: f32) -> Self {
        self.params.samples_overlap = overlap.max(0.0);
        self
    }

    /// Build the parameters
    pub fn build(self) -> VadParams {
        self.params
    }
}

impl Default for VadParamsBuilder {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_vad_params_default() {
        let params = VadParams::default();
        assert!(params.threshold > 0.0 && params.threshold < 1.0);
        assert!(params.min_speech_duration_ms >= 0);
        assert!(params.max_speech_duration_s > 0.0);
    }

    #[test]
    fn test_vad_params_builder() {
        let params = VadParamsBuilder::new()
            .threshold(0.6)
            .min_speech_duration_ms(250)
            .min_silence_duration_ms(100)
            .max_speech_duration_s(30.0)
            .speech_pad_ms(100)
            .build();

        assert_eq!(params.threshold, 0.6);
        assert_eq!(params.min_speech_duration_ms, 250);
        assert_eq!(params.min_silence_duration_ms, 100);
        assert_eq!(params.max_speech_duration_s, 30.0);
        assert_eq!(params.speech_pad_ms, 100);
    }

    #[test]
    fn test_vad_params_builder_clamps() {
        let params = VadParamsBuilder::new()
            .threshold(1.5) // Should be clamped to 1.0
            .min_speech_duration_ms(-100) // Should be clamped to 0
            .build();

        assert_eq!(params.threshold, 1.0);
        assert_eq!(params.min_speech_duration_ms, 0);
    }

    #[test]
    fn test_vad_processor_creation() {
        // This test will only run if a VAD model is available
        let model_path = "tests/models/ggml-silero-vad.bin";
        if Path::new(model_path).exists() {
            let processor = WhisperVadProcessor::new(model_path);
            assert!(processor.is_ok());
        } else {
            eprintln!("Skipping VAD processor creation test: model not found");
        }
    }

    #[test]
    fn test_vad_context_params() {
        let params = VadContextParams::default();
        assert!(params.n_threads > 0);

        let custom_params = VadContextParams {
            n_threads: 4,
            use_gpu: true,
            gpu_device: 0,
        };
        assert_eq!(custom_params.n_threads, 4);
        assert!(custom_params.use_gpu);
        assert_eq!(custom_params.gpu_device, 0);
    }
}