//! `whisper_cpp_plus/enhanced/vad.rs` — enhanced voice-activity detection:
//! merges raw VAD segments into bounded-length spans and copies the matching
//! audio out of the input buffer.

use crate::error::Result;
use crate::vad::{VadParams, WhisperVadProcessor};
use std::path::Path;
10
11#[derive(Debug, Clone)]
13pub struct EnhancedVadParams {
14 pub base: VadParams,
16 pub max_segment_duration_s: f32,
18 pub merge_segments: bool,
20 pub min_gap_ms: i32,
22}
23
24impl Default for EnhancedVadParams {
25 fn default() -> Self {
26 Self {
27 base: VadParams::default(),
28 max_segment_duration_s: 30.0,
29 merge_segments: true,
30 min_gap_ms: 100,
31 }
32 }
33}
34
35pub struct EnhancedWhisperVadProcessor {
37 inner: WhisperVadProcessor,
38}
39
40impl EnhancedWhisperVadProcessor {
41 pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
42 Ok(Self {
43 inner: WhisperVadProcessor::new(model_path)?,
44 })
45 }
46
47 pub fn process_with_aggregation(
50 &mut self,
51 audio: &[f32],
52 params: &EnhancedVadParams,
53 ) -> Result<Vec<AudioChunk>> {
54 let segments = self.inner.segments_from_samples(audio, ¶ms.base)?;
56 let raw_segments = segments.get_all_segments();
57
58 let aggregated = self.aggregate_segments(
60 raw_segments,
61 params.max_segment_duration_s,
62 params.min_gap_ms,
63 params.merge_segments,
64 );
65
66 let chunks = self.extract_audio_chunks(audio, aggregated, 16000.0);
68 Ok(chunks)
69 }
70
71 #[doc(hidden)]
73 pub fn aggregate_segments(
74 &self,
75 segments: Vec<(f32, f32)>,
76 max_duration: f32,
77 min_gap_ms: i32,
78 merge: bool,
79 ) -> Vec<(f32, f32)> {
80 if segments.is_empty() {
81 return Vec::new();
82 }
83
84 let mut aggregated = Vec::new();
85 let min_gap = min_gap_ms as f32 / 1000.0;
86
87 let mut current_start = segments[0].0;
88 let mut current_end = segments[0].1;
89
90 for (start, end) in segments.iter().skip(1) {
91 let gap = start - current_end;
92 let combined_duration = end - current_start;
93
94 if merge && gap < min_gap && combined_duration <= max_duration {
96 current_end = *end;
98 } else {
99 aggregated.push((current_start, current_end));
101 current_start = *start;
102 current_end = *end;
103 }
104 }
105
106 aggregated.push((current_start, current_end));
108
109 aggregated
110 }
111
112 fn extract_audio_chunks(
114 &self,
115 audio: &[f32],
116 segments: Vec<(f32, f32)>,
117 sample_rate: f32,
118 ) -> Vec<AudioChunk> {
119 segments
120 .into_iter()
121 .map(|(start, end)| {
122 let start_sample = (start * sample_rate) as usize;
123 let end_sample = ((end * sample_rate) as usize).min(audio.len());
124
125 AudioChunk {
126 audio: audio[start_sample..end_sample].to_vec(),
127 offset_seconds: start,
128 duration_seconds: end - start,
129 metadata: ChunkMetadata {
130 original_start: start,
131 original_end: end,
132 sample_offset: start_sample,
133 },
134 }
135 })
136 .collect()
137 }
138}
139
140#[derive(Debug, Clone)]
142pub struct AudioChunk {
143 pub audio: Vec<f32>,
145 pub offset_seconds: f32,
147 pub duration_seconds: f32,
149 pub metadata: ChunkMetadata,
151}
152
/// Provenance of an [`AudioChunk`]: where it was cut from in the original input.
#[derive(Debug, Clone)]
pub struct ChunkMetadata {
    /// Start timestamp, in seconds, of the span the chunk was cut from.
    pub original_start: f32,
    /// End timestamp, in seconds, of the span the chunk was cut from.
    pub original_end: f32,
    /// Index of the chunk's first sample within the original audio buffer.
    pub sample_offset: usize,
}
162
163pub struct EnhancedVadParamsBuilder {
165 params: EnhancedVadParams,
166}
167
168impl EnhancedVadParamsBuilder {
169 pub fn new() -> Self {
170 Self {
171 params: EnhancedVadParams::default(),
172 }
173 }
174
175 pub fn threshold(mut self, threshold: f32) -> Self {
176 self.params.base.threshold = threshold;
177 self
178 }
179
180 pub fn max_segment_duration(mut self, seconds: f32) -> Self {
181 self.params.max_segment_duration_s = seconds;
182 self
183 }
184
185 pub fn merge_segments(mut self, merge: bool) -> Self {
186 self.params.merge_segments = merge;
187 self
188 }
189
190 pub fn min_gap_ms(mut self, ms: i32) -> Self {
191 self.params.min_gap_ms = ms;
192 self
193 }
194
195 pub fn speech_pad_ms(mut self, ms: i32) -> Self {
196 self.params.base.speech_pad_ms = ms;
197 self
198 }
199
200 pub fn build(self) -> EnhancedVadParams {
201 self.params
202 }
203}
204
205impl Default for EnhancedVadParamsBuilder {
206 fn default() -> Self {
207 Self::new()
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem::ManuallyDrop;

    /// Builds a processor with an all-zero inner handle so the pure aggregation
    /// logic can be exercised without loading a model.
    ///
    /// It is wrapped in `ManuallyDrop` so the zeroed handle is never handed to
    /// `WhisperVadProcessor`'s destructor.
    fn dummy_processor() -> ManuallyDrop<EnhancedWhisperVadProcessor> {
        // SAFETY: NOTE(review) — assumes every field of `WhisperVadProcessor`
        // admits the all-zero bit pattern (raw pointers/integers, not `NonNull`,
        // `Box` or references); confirm against its definition. `aggregate_segments`
        // never reads `inner`, and `ManuallyDrop` keeps its destructor from running.
        ManuallyDrop::new(EnhancedWhisperVadProcessor {
            inner: unsafe { std::mem::zeroed() },
        })
    }

    #[test]
    fn test_segment_aggregation() {
        let processor = dummy_processor();

        // The 50 ms gap is clearly below the 100 ms threshold; the later gaps are not.
        let segments = vec![(0.0, 2.0), (2.05, 4.0), (4.5, 6.0), (10.0, 12.0)];

        let aggregated = processor.aggregate_segments(segments, 30.0, 100, true);

        assert_eq!(aggregated, vec![(0.0, 4.0), (4.5, 6.0), (10.0, 12.0)]);
    }

    #[test]
    fn test_max_duration_split() {
        let processor = dummy_processor();

        // The gap is well under the threshold, so only the duration cap can stop
        // the merge: the merged span would last 40 s.
        let segments = vec![(0.0, 20.0), (20.05, 40.0)];

        let capped = processor.aggregate_segments(segments.clone(), 30.0, 100, true);
        assert_eq!(capped, vec![(0.0, 20.0), (20.05, 40.0)]);

        // Control: with a cap that fits the merged span, the same input collapses.
        let uncapped = processor.aggregate_segments(segments, 60.0, 100, true);
        assert_eq!(uncapped, vec![(0.0, 40.0)]);
    }

    #[test]
    fn test_aggregation_empty_input() {
        let processor = dummy_processor();
        assert!(processor
            .aggregate_segments(Vec::new(), 30.0, 100, true)
            .is_empty());
    }

    #[test]
    fn test_aggregation_merge_disabled() {
        let processor = dummy_processor();
        let segments = vec![(0.0, 2.0), (2.05, 4.0)];

        let aggregated = processor.aggregate_segments(segments, 30.0, 100, false);

        assert_eq!(aggregated, vec![(0.0, 2.0), (2.05, 4.0)]);
    }

    #[test]
    fn test_enhanced_vad_params_builder() {
        let params = EnhancedVadParamsBuilder::new()
            .threshold(0.6)
            .max_segment_duration(25.0)
            .merge_segments(false)
            .min_gap_ms(200)
            .speech_pad_ms(50)
            .build();

        assert_eq!(params.base.threshold, 0.6);
        assert_eq!(params.base.speech_pad_ms, 50);
        assert_eq!(params.max_segment_duration_s, 25.0);
        assert!(!params.merge_segments);
        assert_eq!(params.min_gap_ms, 200);
    }

    #[test]
    fn test_enhanced_vad_params_defaults() {
        let params = EnhancedVadParams::default();
        assert_eq!(params.max_segment_duration_s, 30.0);
        assert!(params.merge_segments);
        assert_eq!(params.min_gap_ms, 100);

        let built = EnhancedVadParamsBuilder::default().build();
        assert_eq!(built.max_segment_duration_s, params.max_segment_duration_s);
        assert_eq!(built.merge_segments, params.merge_segments);
        assert_eq!(built.min_gap_ms, params.min_gap_ms);
    }
}