1use crate::error::{Result, WhisperError};
7use std::path::Path;
8use whisper_cpp_plus_sys as ffi;
9
/// Detection parameters for whisper.cpp's VAD segmentation.
///
/// These are converted to `ffi::whisper_vad_params` and passed to
/// `whisper_vad_segments_from_probs` / `whisper_vad_segments_from_samples`.
/// `Default` reads its values from `whisper_vad_default_params`.
#[derive(Debug, Clone, PartialEq)]
pub struct VadParams {
    /// Speech-probability threshold (`VadParamsBuilder::threshold` clamps it to
    /// `[0.0, 1.0]`). Presumably probabilities above it count as speech — confirm
    /// against whisper.cpp.
    pub threshold: f32,
    /// Minimum speech duration, in milliseconds (per the field name).
    pub min_speech_duration_ms: i32,
    /// Minimum silence duration, in milliseconds (per the field name).
    pub min_silence_duration_ms: i32,
    /// Maximum speech duration, in seconds (per the field name).
    pub max_speech_duration_s: f32,
    /// Speech padding, in milliseconds (per the field name).
    pub speech_pad_ms: i32,
    /// Forwarded unchanged as whisper.cpp's `samples_overlap`.
    /// NOTE(review): units/meaning are defined by whisper.cpp — confirm in `whisper.h`.
    pub samples_overlap: f32,
}
26
27impl Default for VadParams {
28 fn default() -> Self {
29 let default_params = unsafe { ffi::whisper_vad_default_params() };
31
32 Self {
33 threshold: default_params.threshold,
34 min_speech_duration_ms: default_params.min_speech_duration_ms,
35 min_silence_duration_ms: default_params.min_silence_duration_ms,
36 max_speech_duration_s: default_params.max_speech_duration_s,
37 speech_pad_ms: default_params.speech_pad_ms,
38 samples_overlap: default_params.samples_overlap,
39 }
40 }
41}
42
43impl VadParams {
44 fn to_ffi(&self) -> ffi::whisper_vad_params {
46 ffi::whisper_vad_params {
47 threshold: self.threshold,
48 min_speech_duration_ms: self.min_speech_duration_ms,
49 min_silence_duration_ms: self.min_silence_duration_ms,
50 max_speech_duration_s: self.max_speech_duration_s,
51 speech_pad_ms: self.speech_pad_ms,
52 samples_overlap: self.samples_overlap,
53 }
54 }
55}
56
/// Parameters used when loading a VAD model via `whisper_vad_init_from_file_with_params`.
///
/// `Default` reads its values from `whisper_vad_default_context_params`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VadContextParams {
    /// Thread count forwarded to whisper.cpp.
    pub n_threads: i32,
    /// Whether whisper.cpp should try to use a GPU backend.
    pub use_gpu: bool,
    /// GPU device index forwarded to whisper.cpp (presumably only relevant when
    /// `use_gpu` is set — confirm).
    pub gpu_device: i32,
}
67
68impl Default for VadContextParams {
69 fn default() -> Self {
70 let default_params = unsafe { ffi::whisper_vad_default_context_params() };
71
72 Self {
73 n_threads: default_params.n_threads,
74 use_gpu: default_params.use_gpu,
75 gpu_device: default_params.gpu_device,
76 }
77 }
78}
79
80impl VadContextParams {
81 fn to_ffi(&self) -> ffi::whisper_vad_context_params {
83 ffi::whisper_vad_context_params {
84 n_threads: self.n_threads,
85 use_gpu: self.use_gpu,
86 gpu_device: self.gpu_device,
87 }
88 }
89}
90
/// Owner of a whisper.cpp VAD context (`whisper_vad_context`).
///
/// Built by `WhisperVadProcessor::new` / `WhisperVadProcessor::new_with_params`, which
/// only return a processor whose context pointer is non-null. The context is released
/// with `whisper_vad_free` when the processor is dropped.
pub struct WhisperVadProcessor {
    // Raw context handle; non-null for every value built by the constructors in this file.
    ctx: *mut ffi::whisper_vad_context,
}
95
// SAFETY: `WhisperVadProcessor` is the sole owner of its context pointer (it is freed
// only in `Drop`), so sending the processor moves that ownership with it.
// NOTE(review): assumes whisper.cpp does not bind a VAD context to the thread that
// created it — confirm.
unsafe impl Send for WhisperVadProcessor {}
// SAFETY: `detect_speech`, `segments_from_probs` and `segments_from_samples` require
// `&mut self`; only `n_probs` and `get_probs` can run through a shared `&self`.
// NOTE(review): soundness relies on `whisper_vad_n_probs` / `whisper_vad_probs` being
// safe to call concurrently on one context (presumably read-only) — confirm.
unsafe impl Sync for WhisperVadProcessor {}
98
99impl Drop for WhisperVadProcessor {
100 fn drop(&mut self) {
101 unsafe {
102 if !self.ctx.is_null() {
103 ffi::whisper_vad_free(self.ctx);
104 }
105 }
106 }
107}
108
109impl WhisperVadProcessor {
110 pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
112 Self::new_with_params(model_path, VadContextParams::default())
113 }
114
115 pub fn new_with_params<P: AsRef<Path>>(
117 model_path: P,
118 params: VadContextParams,
119 ) -> Result<Self> {
120 let path_str = model_path
121 .as_ref()
122 .to_str()
123 .ok_or_else(|| WhisperError::ModelLoadError("Invalid path".into()))?;
124
125 let c_path = std::ffi::CString::new(path_str)?;
126
127 let ctx = unsafe {
128 ffi::whisper_vad_init_from_file_with_params(c_path.as_ptr(), params.to_ffi())
129 };
130
131 if ctx.is_null() {
132 return Err(WhisperError::ModelLoadError(
133 "Failed to load VAD model".into(),
134 ));
135 }
136
137 Ok(Self { ctx })
138 }
139
140 pub fn detect_speech(&mut self, samples: &[f32]) -> bool {
142 if samples.is_empty() {
143 return false;
144 }
145
146 unsafe {
147 ffi::whisper_vad_detect_speech(
148 self.ctx,
149 samples.as_ptr(),
150 samples.len() as i32,
151 )
152 }
153 }
154
155 pub fn n_probs(&self) -> i32 {
157 unsafe { ffi::whisper_vad_n_probs(self.ctx) }
158 }
159
160 pub fn get_probs(&self) -> Vec<f32> {
162 let n = self.n_probs();
163 if n == 0 {
164 return Vec::new();
165 }
166
167 let probs_ptr = unsafe { ffi::whisper_vad_probs(self.ctx) };
168 if probs_ptr.is_null() {
169 return Vec::new();
170 }
171
172 let slice = unsafe { std::slice::from_raw_parts(probs_ptr, n as usize) };
173 slice.to_vec()
174 }
175
176 pub fn segments_from_probs(&mut self, params: &VadParams) -> Result<VadSegments> {
178 let segments_ptr = unsafe {
179 ffi::whisper_vad_segments_from_probs(self.ctx, params.to_ffi())
180 };
181
182 if segments_ptr.is_null() {
183 return Err(WhisperError::InvalidContext);
184 }
185
186 Ok(VadSegments {
187 ptr: segments_ptr,
188 })
189 }
190
191 pub fn segments_from_samples(
193 &mut self,
194 samples: &[f32],
195 params: &VadParams,
196 ) -> Result<VadSegments> {
197 if samples.is_empty() {
198 return Err(WhisperError::InvalidAudioFormat);
199 }
200
201 let segments_ptr = unsafe {
202 ffi::whisper_vad_segments_from_samples(
203 self.ctx,
204 params.to_ffi(),
205 samples.as_ptr(),
206 samples.len() as i32,
207 )
208 };
209
210 if segments_ptr.is_null() {
211 return Err(WhisperError::InvalidContext);
212 }
213
214 Ok(VadSegments {
215 ptr: segments_ptr,
216 })
217 }
218}
219
/// Speech segments produced by whisper.cpp's VAD (`whisper_vad_segments`).
///
/// Returned by `WhisperVadProcessor::segments_from_probs` and
/// `WhisperVadProcessor::segments_from_samples`. The underlying C object is freed with
/// `whisper_vad_free_segments` on drop.
pub struct VadSegments {
    // Raw handle; the constructors in this file only store non-null pointers.
    ptr: *mut ffi::whisper_vad_segments,
}
224
225impl Drop for VadSegments {
226 fn drop(&mut self) {
227 unsafe {
228 if !self.ptr.is_null() {
229 ffi::whisper_vad_free_segments(self.ptr);
230 }
231 }
232 }
233}
234
235impl VadSegments {
236 pub fn n_segments(&self) -> i32 {
238 unsafe { ffi::whisper_vad_segments_n_segments(self.ptr) }
239 }
240
241 pub fn get_segment_t0(&self, i_segment: i32) -> f32 {
243 unsafe { ffi::whisper_vad_segments_get_segment_t0(self.ptr, i_segment) / 100.0 }
245 }
246
247 pub fn get_segment_t1(&self, i_segment: i32) -> f32 {
249 unsafe { ffi::whisper_vad_segments_get_segment_t1(self.ptr, i_segment) / 100.0 }
251 }
252
253 pub fn get_all_segments(&self) -> Vec<(f32, f32)> {
255 let n = self.n_segments();
256 let mut segments = Vec::with_capacity(n as usize);
257
258 for i in 0..n {
259 segments.push((self.get_segment_t0(i), self.get_segment_t1(i)));
260 }
261
262 segments
263 }
264
265 pub fn extract_audio_segments(&self, audio: &[f32], sample_rate: f32) -> Vec<Vec<f32>> {
267 let segments = self.get_all_segments();
268 let mut audio_segments = Vec::with_capacity(segments.len());
269
270 for (start, end) in segments {
271 let start_sample = (start * sample_rate) as usize;
272 let end_sample = (end * sample_rate) as usize;
273
274 if start_sample < audio.len() && end_sample <= audio.len() {
275 audio_segments.push(audio[start_sample..end_sample].to_vec());
276 }
277 }
278
279 audio_segments
280 }
281}
282
283pub struct VadParamsBuilder {
285 params: VadParams,
286}
287
288impl VadParamsBuilder {
289 pub fn new() -> Self {
291 Self {
292 params: VadParams::default(),
293 }
294 }
295
296 pub fn threshold(mut self, threshold: f32) -> Self {
298 self.params.threshold = threshold.clamp(0.0, 1.0);
299 self
300 }
301
302 pub fn min_speech_duration_ms(mut self, ms: i32) -> Self {
304 self.params.min_speech_duration_ms = ms.max(0);
305 self
306 }
307
308 pub fn min_silence_duration_ms(mut self, ms: i32) -> Self {
310 self.params.min_silence_duration_ms = ms.max(0);
311 self
312 }
313
314 pub fn max_speech_duration_s(mut self, seconds: f32) -> Self {
316 self.params.max_speech_duration_s = seconds.max(0.0);
317 self
318 }
319
320 pub fn speech_pad_ms(mut self, ms: i32) -> Self {
322 self.params.speech_pad_ms = ms.max(0);
323 self
324 }
325
326 pub fn samples_overlap(mut self, overlap: f32) -> Self {
328 self.params.samples_overlap = overlap.max(0.0);
329 self
330 }
331
332 pub fn build(self) -> VadParams {
334 self.params
335 }
336}
337
338impl Default for VadParamsBuilder {
339 fn default() -> Self {
340 Self::new()
341 }
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Optional on-disk model; tests that need it are skipped when it is absent.
    const MODEL_PATH: &str = "tests/models/ggml-silero-vad.bin";

    #[test]
    fn test_vad_params_default() {
        let params = VadParams::default();
        assert!(params.threshold > 0.0 && params.threshold < 1.0);
        assert!(params.min_speech_duration_ms >= 0);
        assert!(params.max_speech_duration_s > 0.0);
    }

    #[test]
    fn test_vad_params_builder() {
        let params = VadParamsBuilder::new()
            .threshold(0.6)
            .min_speech_duration_ms(250)
            .min_silence_duration_ms(100)
            .max_speech_duration_s(30.0)
            .speech_pad_ms(100)
            .samples_overlap(0.1)
            .build();

        assert_eq!(params.threshold, 0.6);
        assert_eq!(params.min_speech_duration_ms, 250);
        assert_eq!(params.min_silence_duration_ms, 100);
        assert_eq!(params.max_speech_duration_s, 30.0);
        assert_eq!(params.speech_pad_ms, 100);
        assert_eq!(params.samples_overlap, 0.1);
    }

    #[test]
    fn test_vad_params_builder_clamps() {
        let params = VadParamsBuilder::new()
            .threshold(1.5)
            .min_speech_duration_ms(-100)
            .min_silence_duration_ms(-1)
            .max_speech_duration_s(-5.0)
            .speech_pad_ms(-30)
            .samples_overlap(-0.5)
            .build();

        assert_eq!(params.threshold, 1.0);
        assert_eq!(params.min_speech_duration_ms, 0);
        assert_eq!(params.min_silence_duration_ms, 0);
        assert_eq!(params.max_speech_duration_s, 0.0);
        assert_eq!(params.speech_pad_ms, 0);
        assert_eq!(params.samples_overlap, 0.0);

        // Lower bound of the threshold clamp.
        let params = VadParamsBuilder::new().threshold(-0.2).build();
        assert_eq!(params.threshold, 0.0);
    }

    #[test]
    fn test_vad_processor_creation() {
        if Path::new(MODEL_PATH).exists() {
            let processor = WhisperVadProcessor::new(MODEL_PATH);
            assert!(processor.is_ok());
        } else {
            eprintln!("Skipping VAD processor creation test: model not found");
        }
    }

    #[test]
    fn test_vad_processor_rejects_empty_input() {
        if !Path::new(MODEL_PATH).exists() {
            eprintln!("Skipping VAD empty-input test: model not found");
            return;
        }
        let mut processor = match WhisperVadProcessor::new(MODEL_PATH) {
            Ok(processor) => processor,
            Err(_) => panic!("failed to load VAD model"),
        };

        assert!(!processor.detect_speech(&[]));
        let result = processor.segments_from_samples(&[], &VadParams::default());
        assert!(matches!(result, Err(WhisperError::InvalidAudioFormat)));
    }

    #[test]
    fn test_vad_context_params() {
        let params = VadContextParams::default();
        assert!(params.n_threads > 0);

        let custom_params = VadContextParams {
            n_threads: 4,
            use_gpu: true,
            gpu_device: 0,
        };
        assert_eq!(custom_params.n_threads, 4);
        assert!(custom_params.use_gpu);
        assert_eq!(custom_params.gpu_device, 0);
    }
}