subx_cli/services/vad/
detector.rs1use super::audio_processor::VadAudioProcessor;
2use crate::config::VadConfig;
3use crate::{Result, error::SubXError};
4use std::path::Path;
5use std::time::{Duration, Instant};
6use voice_activity_detector::{IteratorExt, LabeledAudio, VoiceActivityDetector};
7
8pub struct LocalVadDetector {
14 config: VadConfig,
15 audio_processor: VadAudioProcessor,
16}
17
18impl LocalVadDetector {
19 pub fn new(config: VadConfig) -> Result<Self> {
33 let cfg_clone = config.clone();
35 Ok(Self {
36 config,
37 audio_processor: VadAudioProcessor::new(cfg_clone.sample_rate, 1)?,
38 })
39 }
40
41 pub async fn detect_speech(&self, audio_path: &Path) -> Result<VadResult> {
61 let start_time = Instant::now();
62
63 let audio_data = self
65 .audio_processor
66 .load_and_prepare_audio(audio_path)
67 .await?;
68
69 let vad = VoiceActivityDetector::builder()
71 .sample_rate(self.config.sample_rate)
72 .chunk_size(self.config.chunk_size)
73 .build()
74 .map_err(|e| SubXError::audio_processing(format!("Failed to create VAD: {}", e)))?;
75
76 let speech_segments = self.detect_speech_segments(vad, &audio_data.samples)?;
78
79 let processing_duration = start_time.elapsed();
80
81 Ok(VadResult {
82 speech_segments,
83 processing_duration,
84 audio_info: audio_data.info,
85 })
86 }
87
88 fn detect_speech_segments(
89 &self,
90 vad: VoiceActivityDetector,
91 samples: &[i16],
92 ) -> Result<Vec<SpeechSegment>> {
93 let mut segments = Vec::new();
94 let chunk_duration_seconds = self.config.chunk_size as f64 / self.config.sample_rate as f64;
95
96 let labels: Vec<LabeledAudio<i16>> = samples
98 .iter()
99 .copied()
100 .label(
101 vad,
102 self.config.sensitivity,
103 self.config.padding_chunks as usize,
104 )
105 .collect();
106
107 let mut current_speech_start: Option<f64> = None;
108 let mut chunk_index = 0;
109
110 for label in labels {
111 let chunk_start_time = chunk_index as f64 * chunk_duration_seconds;
112
113 match label {
114 LabeledAudio::Speech(_chunk) => {
115 if current_speech_start.is_none() {
116 current_speech_start = Some(chunk_start_time);
117 }
118 }
119 LabeledAudio::NonSpeech(_chunk) => {
120 if let Some(start_time) = current_speech_start.take() {
121 let end_time = chunk_start_time;
122 let duration = end_time - start_time;
123
124 if duration >= self.config.min_speech_duration_ms as f64 / 1000.0 {
126 segments.push(SpeechSegment {
127 start_time,
128 end_time,
129 probability: self.config.sensitivity, duration,
131 });
132 }
133 }
134 }
135 }
136
137 chunk_index += 1;
138 }
139
140 if let Some(start_time) = current_speech_start {
142 let end_time = chunk_index as f64 * chunk_duration_seconds;
143 let duration = end_time - start_time;
144
145 if duration >= self.config.min_speech_duration_ms as f64 / 1000.0 {
146 segments.push(SpeechSegment {
147 start_time,
148 end_time,
149 probability: self.config.sensitivity,
150 duration,
151 });
152 }
153 }
154
155 Ok(self.merge_close_segments(segments))
157 }
158
159 fn merge_close_segments(&self, segments: Vec<SpeechSegment>) -> Vec<SpeechSegment> {
160 if segments.is_empty() {
161 return segments;
162 }
163
164 let mut merged = Vec::new();
165 let mut current = segments[0].clone();
166 let merge_threshold = self.config.speech_merge_gap_ms as f64 / 1000.0;
167
168 for segment in segments.into_iter().skip(1) {
169 if segment.start_time - current.end_time <= merge_threshold {
170 current.end_time = segment.end_time;
172 current.duration = current.end_time - current.start_time;
173 current.probability = current.probability.max(segment.probability);
174 } else {
175 merged.push(current);
177 current = segment;
178 }
179 }
180
181 merged.push(current);
182 merged
183 }
184}
185
186#[derive(Debug, Clone)]
191pub struct VadResult {
192 pub speech_segments: Vec<SpeechSegment>,
194 pub processing_duration: Duration,
196 pub audio_info: AudioInfo,
198}
199
/// A contiguous stretch of audio classified as speech.
#[derive(Debug, Clone)]
pub struct SpeechSegment {
    /// Start of the segment, in seconds from the beginning of the samples.
    pub start_time: f64,
    /// End of the segment, in seconds from the beginning of the samples.
    pub end_time: f64,
    /// Score attached to the segment. The local detector stores its
    /// configured sensitivity here rather than a per-segment model output.
    pub probability: f32,
    /// Length of the segment in seconds (`end_time - start_time`).
    pub duration: f64,
}
215
/// Basic metadata describing an audio stream.
///
/// NOTE(review): the values are filled in by the audio processor, not by the
/// detector; confirm there whether they describe the source file or the
/// prepared (resampled) audio.
#[derive(Debug, Clone)]
pub struct AudioInfo {
    /// Sample rate of the audio (presumably in Hz).
    pub sample_rate: u32,
    /// Number of audio channels.
    pub channels: u16,
    /// Length of the audio in seconds.
    pub duration_seconds: f64,
    /// Total number of samples.
    pub total_samples: usize,
}