Skip to main content

whisper_cpp_plus/enhanced/
vad.rs

1//! Enhanced VAD functionality with segment aggregation
2//!
3//! This module provides advanced VAD features beyond the basic whisper.cpp implementation,
4//! inspired by faster-whisper's optimizations. VAD is a preprocessing step that happens
5//! BEFORE transcription, not part of the transcription API itself.
6
7use crate::vad::{WhisperVadProcessor, VadParams};
8use crate::error::Result;
9use std::path::Path;
10
11/// Enhanced VAD parameters with aggregation settings
12#[derive(Debug, Clone)]
13pub struct EnhancedVadParams {
14    /// Base VAD parameters from whisper.cpp
15    pub base: VadParams,
16    /// Maximum duration for aggregated segments (seconds)
17    pub max_segment_duration_s: f32,
18    /// Whether to merge adjacent segments
19    pub merge_segments: bool,
20    /// Minimum gap between segments to keep them separate (ms)
21    pub min_gap_ms: i32,
22}
23
24impl Default for EnhancedVadParams {
25    fn default() -> Self {
26        Self {
27            base: VadParams::default(),
28            max_segment_duration_s: 30.0,
29            merge_segments: true,
30            min_gap_ms: 100,
31        }
32    }
33}
34
35/// Enhanced VAD processor with segment aggregation
36pub struct EnhancedWhisperVadProcessor {
37    inner: WhisperVadProcessor,
38}
39
40impl EnhancedWhisperVadProcessor {
41    pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
42        Ok(Self {
43            inner: WhisperVadProcessor::new(model_path)?,
44        })
45    }
46
47    /// Process audio with segment aggregation
48    /// Returns aggregated speech chunks optimized for transcription
49    pub fn process_with_aggregation(
50        &mut self,
51        audio: &[f32],
52        params: &EnhancedVadParams,
53    ) -> Result<Vec<AudioChunk>> {
54        // Get raw segments from base VAD
55        let segments = self.inner.segments_from_samples(audio, &params.base)?;
56        let raw_segments = segments.get_all_segments();
57
58        // Apply aggregation
59        let aggregated = self.aggregate_segments(
60            raw_segments,
61            params.max_segment_duration_s,
62            params.min_gap_ms,
63            params.merge_segments,
64        );
65
66        // Extract audio chunks with metadata
67        let chunks = self.extract_audio_chunks(audio, aggregated, 16000.0);
68        Ok(chunks)
69    }
70
71    /// Aggregate segments to optimize for transcription
72    #[doc(hidden)]
73    pub fn aggregate_segments(
74        &self,
75        segments: Vec<(f32, f32)>,
76        max_duration: f32,
77        min_gap_ms: i32,
78        merge: bool,
79    ) -> Vec<(f32, f32)> {
80        if segments.is_empty() {
81            return Vec::new();
82        }
83
84        let mut aggregated = Vec::new();
85        let min_gap = min_gap_ms as f32 / 1000.0;
86
87        let mut current_start = segments[0].0;
88        let mut current_end = segments[0].1;
89
90        for (start, end) in segments.iter().skip(1) {
91            let gap = start - current_end;
92            let combined_duration = end - current_start;
93
94            // Check if we should merge with current segment
95            if merge && gap < min_gap && combined_duration <= max_duration {
96                // Extend current segment
97                current_end = *end;
98            } else {
99                // Save current segment and start new one
100                aggregated.push((current_start, current_end));
101                current_start = *start;
102                current_end = *end;
103            }
104        }
105
106        // Don't forget the last segment
107        aggregated.push((current_start, current_end));
108
109        aggregated
110    }
111
112    /// Extract audio chunks with metadata
113    fn extract_audio_chunks(
114        &self,
115        audio: &[f32],
116        segments: Vec<(f32, f32)>,
117        sample_rate: f32,
118    ) -> Vec<AudioChunk> {
119        segments
120            .into_iter()
121            .map(|(start, end)| {
122                let start_sample = (start * sample_rate) as usize;
123                let end_sample = ((end * sample_rate) as usize).min(audio.len());
124
125                AudioChunk {
126                    audio: audio[start_sample..end_sample].to_vec(),
127                    offset_seconds: start,
128                    duration_seconds: end - start,
129                    metadata: ChunkMetadata {
130                        original_start: start,
131                        original_end: end,
132                        sample_offset: start_sample,
133                    },
134                }
135            })
136            .collect()
137    }
138}
139
140/// Audio chunk with metadata for transcription
141#[derive(Debug, Clone)]
142pub struct AudioChunk {
143    /// Audio samples
144    pub audio: Vec<f32>,
145    /// Offset from original audio start (seconds)
146    pub offset_seconds: f32,
147    /// Duration of this chunk (seconds)
148    pub duration_seconds: f32,
149    /// Additional metadata
150    pub metadata: ChunkMetadata,
151}
152
/// Provenance of an [`AudioChunk`]: the segment it was cut from and where
/// that segment begins in the original sample buffer.
#[derive(Clone, Debug)]
pub struct ChunkMetadata {
    /// Start time of the source segment, in seconds.
    pub original_start: f32,
    /// End time of the source segment, in seconds.
    pub original_end: f32,
    /// Index of the chunk's first sample within the original audio.
    pub sample_offset: usize,
}
162
163/// Builder for enhanced VAD parameters
164pub struct EnhancedVadParamsBuilder {
165    params: EnhancedVadParams,
166}
167
168impl EnhancedVadParamsBuilder {
169    pub fn new() -> Self {
170        Self {
171            params: EnhancedVadParams::default(),
172        }
173    }
174
175    pub fn threshold(mut self, threshold: f32) -> Self {
176        self.params.base.threshold = threshold;
177        self
178    }
179
180    pub fn max_segment_duration(mut self, seconds: f32) -> Self {
181        self.params.max_segment_duration_s = seconds;
182        self
183    }
184
185    pub fn merge_segments(mut self, merge: bool) -> Self {
186        self.params.merge_segments = merge;
187        self
188    }
189
190    pub fn min_gap_ms(mut self, ms: i32) -> Self {
191        self.params.min_gap_ms = ms;
192        self
193    }
194
195    pub fn speech_pad_ms(mut self, ms: i32) -> Self {
196        self.params.base.speech_pad_ms = ms;
197        self
198    }
199
200    pub fn build(self) -> EnhancedVadParams {
201        self.params
202    }
203}
204
205impl Default for EnhancedVadParamsBuilder {
206    fn default() -> Self {
207        Self::new()
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem::ManuallyDrop;

    /// Builds a processor whose inner VAD is a placeholder, for exercising
    /// `aggregate_segments`, which never reads `inner`.
    ///
    /// The value is wrapped in `ManuallyDrop` so that no destructor ever runs
    /// on the placeholder.
    fn mock_processor() -> ManuallyDrop<EnhancedWhisperVadProcessor> {
        // SAFETY: NOTE(review) — this is sound only if an all-zero bit pattern
        // is a valid `WhisperVadProcessor` (e.g. it holds nothing but raw
        // pointers and plain integers); confirm against its definition. The
        // placeholder is never used and, thanks to `ManuallyDrop`, never dropped.
        let inner = unsafe { std::mem::zeroed() };
        ManuallyDrop::new(EnhancedWhisperVadProcessor { inner })
    }

    #[test]
    fn test_segment_aggregation() {
        let processor = mock_processor();

        let segments = vec![
            (0.0, 2.0),
            (2.1, 4.0),   // Small gap - should merge
            (4.5, 6.0),   // Larger gap
            (10.0, 12.0), // Large gap - separate segment
        ];

        let aggregated = processor.aggregate_segments(segments, 30.0, 100, true);

        assert_eq!(aggregated.len(), 3);
        assert_eq!(aggregated[0], (0.0, 4.0)); // First two merged
        assert_eq!(aggregated[1], (4.5, 6.0));
        assert_eq!(aggregated[2], (10.0, 12.0));
    }

    #[test]
    fn test_max_duration_split() {
        let processor = mock_processor();

        let segments = vec![
            (0.0, 20.0),
            (20.1, 40.0), // Would exceed 30s if merged
        ];

        let aggregated = processor.aggregate_segments(segments, 30.0, 100, true);

        assert_eq!(aggregated.len(), 2); // Should not merge due to max duration
    }

    #[test]
    fn test_empty_input_yields_no_segments() {
        let processor = mock_processor();

        let aggregated = processor.aggregate_segments(Vec::new(), 30.0, 100, true);

        assert!(aggregated.is_empty());
    }

    #[test]
    fn test_merge_disabled_keeps_segments_separate() {
        let processor = mock_processor();

        let segments = vec![(0.0, 1.0), (1.0, 2.0), (2.0, 3.0)];

        let aggregated = processor.aggregate_segments(segments.clone(), 30.0, 100, false);

        assert_eq!(aggregated, segments);
    }

    #[test]
    fn test_enhanced_vad_params_builder() {
        let params = EnhancedVadParamsBuilder::new()
            .threshold(0.6)
            .max_segment_duration(25.0)
            .merge_segments(false)
            .min_gap_ms(200)
            .build();

        assert_eq!(params.base.threshold, 0.6);
        assert_eq!(params.max_segment_duration_s, 25.0);
        assert!(!params.merge_segments);
        assert_eq!(params.min_gap_ms, 200);
    }
}