use std::fmt;
use std::sync::Arc;

use crate::error::{Error, Result};
use crate::monitoring::{AtomicCounter, VADStats};
use crate::time::{AudioDuration, AudioInstant, AudioTimestamp};
use parking_lot::Mutex;
use realfft::{RealFftPlanner, RealToComplex};

use super::config::VadConfig;
use super::metrics::{AdaptiveThresholdSnapshot, VadMetricsCollector, VadMetricsSnapshot};

/// Nanoseconds per second, used when converting sample counts to durations.
const NANOS_PER_SECOND: u128 = 1_000_000_000;
/// Small positive floor that keeps baselines and divisors away from zero.
const EPSILON: f32 = 1e-12;

/// Upper bound on the exponential smoothing factors so the adaptive
/// baselines always absorb at least a sliver of each new frame.
const MAX_SMOOTHING_FACTOR: f32 = 0.999;

/// Cap on the normalized energy and flux values so a single loud frame
/// cannot dominate the combined score.
const MAX_NORMALIZED_METRIC: f32 = 10.0;

/// Absolute RMS gate: frames below this energy are never counted as speech.
const SILENCE_ENERGY_GATE: f32 = 0.02;
/// Relative gate: a frame must reach this multiple of the energy baseline
/// before it can count as speech.
const SILENCE_RELATIVE_GATE: f32 = 1.7;

/// Energy- and spectral-flux-based voice activity detector.
///
/// Detection state lives behind a `Mutex`, so `detect` can be called
/// through a shared reference; the running sample count is atomic.
pub struct VadDetector {
    config: VadConfig,
    fft: Arc<dyn RealToComplex<f32>>,
    window: Vec<f32>,
    metrics: Arc<dyn VadMetricsCollector>,
    processed_samples: AtomicCounter,
    energy_weight: f32,
    flux_weight: f32,
    state: Mutex<DetectorState>,
}

impl fmt::Debug for VadDetector {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let processed_samples = self.processed_samples.get();
        f.debug_struct("VadDetector")
            .field("config", &self.config)
            .field("window_length", &self.window.len())
            .field("energy_weight", &self.energy_weight)
            .field("flux_weight", &self.flux_weight)
            .field("processed_samples", &processed_samples)
            .finish_non_exhaustive()
    }
}

impl VadDetector {
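    /// Builds a detector from a validated configuration.
    ///
    /// The energy and flux weights are renormalized to sum to one, so only
    /// their ratio matters; `validate()` is assumed to reject configurations
    /// whose weights sum to zero, which would otherwise divide by zero here.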
    pub fn new(config: VadConfig, metrics: Arc<dyn VadMetricsCollector>) -> Result<Self> {
        config.validate()?;

        let frame_length = config.frame_length_samples()?;
        let window = hann_window(frame_length);

        let mut planner = RealFftPlanner::<f32>::new();
        let fft = planner.plan_fft_forward(config.fft_size()?);

        // Renormalize so the two weights sum to one.
        let total_weight = config.energy_weight + config.flux_weight;
        let (energy_weight, flux_weight) = (
            config.energy_weight / total_weight,
            config.flux_weight / total_weight,
        );

        // Size the previous-spectrum buffer to match the FFT output length.
        let previous_spectrum = {
            let tmp = fft.make_output_vec();
            vec![0.0; tmp.len()]
        };

        let state = DetectorState {
            energy_baseline: config.energy_floor.max(EPSILON),
            flux_baseline: config.flux_floor.max(EPSILON),
            dynamic_threshold: config.base_threshold.max(EPSILON),
            previous_spectrum,
            pre_emphasis_prev: 0.0,
            active_segment: None,
        };

        Ok(Self {
            config,
            fft,
            window,
            metrics,
            processed_samples: AtomicCounter::new(0),
            energy_weight,
            flux_weight,
            state: Mutex::new(state),
        })
    }

    /// Returns the configuration the detector was constructed with.
    #[must_use]
    pub fn config(&self) -> &VadConfig {
        &self.config
    }

    /// Returns the absolute start sample of the in-progress speech segment,
    /// if the detector is currently inside one.
    #[must_use]
    pub fn active_segment_start_sample(&self) -> Option<usize> {
        let state = self.state.lock();
        state
            .active_segment
            .as_ref()
            .map(|segment| segment.start_sample)
    }

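    /// Rebinds the stream start time and clears per-stream state.
    ///
    /// Only the active segment and the pre-emphasis memory are cleared; the
    /// adaptive baselines, dynamic threshold, and previous spectrum carry
    /// over, so adaptation persists across resets.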
    pub fn reset(&mut self, stream_start_time: AudioTimestamp) {
        self.config.stream_start_time = stream_start_time;
        self.processed_samples.reset();
        let mut state = self.state.lock();
        state.active_segment = None;
        state.pre_emphasis_prev = 0.0;
    }

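    /// Runs detection over one chunk of mono samples and returns any speech
    /// segments finalized within it. A segment still open at the end of the
    /// chunk is carried in internal state until a later chunk closes it.
    ///
    /// # Example
    ///
    /// A minimal sketch. It assumes `VadConfig` implements `Default` and
    /// that some no-op `VadMetricsCollector` (the hypothetical
    /// `NoopCollector` below) is available, so it is marked `ignore`.
    ///
    /// ```ignore
    /// let metrics: Arc<dyn VadMetricsCollector> = Arc::new(NoopCollector);
    /// let detector = VadDetector::new(VadConfig::default(), metrics)?;
    /// // One second of digital silence at 16 kHz yields no speech chunks,
    /// // since zero-energy frames never pass the silence gates.
    /// let chunks = detector.detect(&vec![0.0f32; 16_000])?;
    /// assert!(chunks.is_empty());
    /// ```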
    pub fn detect(&self, samples: &[f32]) -> Result<Vec<SpeechChunk>> {
        let detection_start = AudioInstant::now();
        let chunk_len = samples.len() as u64;

        let mut detector_state = self.state.lock();

        let chunk_start_sample = self.processed_samples.fetch_add(chunk_len) as usize;
        let chunk_end_sample = chunk_start_sample + samples.len();

        let frames = match self.frame_signal(samples, chunk_start_sample, &mut detector_state) {
            Ok(frames) => frames,
            Err(err) => {
                // Roll back the sample counter so a failed chunk is not counted.
                let _ = self.processed_samples.fetch_sub(chunk_len);
                drop(detector_state);
                return Err(err);
            }
        };

        if frames.is_empty() {
            let latency = AudioInstant::now().duration_since(detection_start);
            let adaptive = AdaptiveThresholdSnapshot {
                energy_baseline: detector_state.energy_baseline,
                flux_baseline: detector_state.flux_baseline,
                dynamic_threshold: detector_state.dynamic_threshold,
            };
            let snapshot = VadMetricsSnapshot::new(VADStats::new(), latency, adaptive);
            self.metrics.record_vad_metrics(&snapshot);
            drop(detector_state);
            return Ok(Vec::new());
        }

        let energy = Self::compute_energy(&frames);
        let flux = self.compute_spectral_flux(&frames, &mut detector_state)?;
        let (chunks, mut stats) = self.merge_metrics(
            &frames,
            &energy,
            &flux,
            chunk_end_sample,
            &mut detector_state,
        )?;
        // Note: `speech_frames` records the number of finalized chunks in
        // this call, not a per-frame count.
        stats.speech_frames = chunks.len() as u64;
        let adaptive = AdaptiveThresholdSnapshot {
            energy_baseline: detector_state.energy_baseline,
            flux_baseline: detector_state.flux_baseline,
            dynamic_threshold: detector_state.dynamic_threshold,
        };
        drop(detector_state);

        let latency = AudioInstant::now().duration_since(detection_start);
        let snapshot = VadMetricsSnapshot::new(stats, latency, adaptive);
        self.metrics.record_vad_metrics(&snapshot);

        Ok(chunks)
    }

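    /// Splits the (pre-emphasized, Hann-windowed) signal into frames: one
    /// full frame every `hop_length` samples, plus a final zero-padded frame
    /// for any remainder, with `valid_len` recording the unpadded length.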
    fn frame_signal(
        &self,
        samples: &[f32],
        absolute_start: usize,
        state: &mut DetectorState,
    ) -> Result<Vec<Frame>> {
        if samples.is_empty() {
            return Ok(Vec::new());
        }

        let processed = self.preprocess_signal(samples, state);
        let frame_length = self.config.frame_length_samples()?;
        let hop_length = self.config.hop_length_samples()?;

        // A zero hop length would loop forever below, so guard both values.
        if frame_length == 0 || hop_length == 0 {
            return Err(Error::Processing(
                "frame or hop length resolved to zero".into(),
            ));
        }

        let mut frames = Vec::new();
        let mut start = 0usize;

        while start + frame_length <= processed.len() {
            #[allow(clippy::indexing_slicing)]
            let slice = &processed[start..start + frame_length];
            let mut frame = Vec::with_capacity(frame_length);
            frame.extend(
                slice
                    .iter()
                    .zip(&self.window)
                    .map(|(sample, window)| sample * window),
            );
            frames.push(Frame {
                data: frame,
                start_sample: absolute_start + start,
                valid_len: frame_length,
            });
            start += hop_length;
        }

        // Window and zero-pad the trailing partial frame, if any.
        if start < processed.len() {
            if let Some(slice) = processed.get(start..) {
                let available = slice.len().min(frame_length);
                let mut frame = Vec::with_capacity(frame_length);
                frame.extend(
                    slice
                        .iter()
                        .zip(&self.window)
                        .map(|(sample, window)| sample * window),
                );
                frame.resize(frame_length, 0.0);
                frames.push(Frame {
                    data: frame,
                    start_sample: absolute_start + start,
                    valid_len: available,
                });
            }
        }

        Ok(frames)
    }

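    /// Applies the standard first-order pre-emphasis filter when configured:
    /// `y[n] = x[n] - coeff * x[n-1]`, with `x[-1]` carried across chunks in
    /// `state.pre_emphasis_prev`. For example, with `coeff = 0.97` and input
    /// `[1.0, 1.0]` starting from silence, the output is `[1.0, 0.03]`.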
    fn preprocess_signal(&self, samples: &[f32], state: &mut DetectorState) -> Vec<f32> {
        match self.config.pre_emphasis {
            Some(coeff) if coeff > 0.0 => {
                let mut processed = Vec::with_capacity(samples.len());
                let mut previous = state.pre_emphasis_prev;
                for &sample in samples {
                    // y[n] = x[n] - coeff * x[n-1]
                    let emphasized = coeff.mul_add(-previous, sample);
                    processed.push(emphasized);
                    previous = sample;
                }
                if let Some(&last) = samples.last() {
                    state.pre_emphasis_prev = last;
                }
                processed
            }
            _ => {
                // Keep the filter memory current even when pre-emphasis is off.
                if let Some(&last) = samples.last() {
                    state.pre_emphasis_prev = last;
                }
                samples.to_vec()
            }
        }
    }

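    /// Computes the RMS energy of each (already windowed) frame:
    /// `rms = sqrt(sum(x[n]^2) / N)`.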
    fn compute_energy(frames: &[Frame]) -> Vec<f32> {
        let mut values = Vec::with_capacity(frames.len());

        for frame in frames {
            debug_assert!(!frame.data.is_empty(), "frame data should never be empty");
            let sum_sq: f32 = frame.data.iter().map(|sample| sample * sample).sum();
            let len = frame.data.len();
            let rms = (sum_sq / len as f32).sqrt();

            values.push(rms);
        }

        values
    }

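    /// Computes the positive spectral flux of each frame: the frame is
    /// zero-padded into the FFT input, and the flux is the sum of positive
    /// magnitude increases across bins,
    /// `flux = sum_k max(0, |X_t[k]| - |X_{t-1}[k]|)`,
    /// with the previous spectrum carried across chunks in detector state.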
    fn compute_spectral_flux(
        &self,
        frames: &[Frame],
        state: &mut DetectorState,
    ) -> Result<Vec<f32>> {
        if frames.is_empty() {
            return Ok(Vec::new());
        }

        let mut input = self.fft.make_input_vec();
        let mut spectrum = self.fft.make_output_vec();
        let mut scratch = self.fft.make_scratch_vec();
        if state.previous_spectrum.len() != spectrum.len() {
            state.previous_spectrum.resize(spectrum.len(), 0.0);
        }
        let previous = &mut state.previous_spectrum;

        let mut values = Vec::with_capacity(frames.len());

        for frame in frames {
            debug_assert!(!frame.data.is_empty(), "frame data should never be empty");
            // Zero-pad the frame into the FFT input; `zip` stops at the
            // shorter of the two buffers.
            input.fill(0.0);
            for (dst, &src) in input.iter_mut().zip(frame.data.iter()) {
                *dst = src;
            }

            self.fft
                .process_with_scratch(&mut input, &mut spectrum, &mut scratch)
                .map_err(|err| Error::Processing(format!("FFT processing failed: {err}")))?;

            let mut flux = 0.0f32;
            for (bin, prev) in spectrum.iter().zip(previous.iter_mut()) {
                let magnitude = bin.re.hypot(bin.im);
                let diff = (magnitude - *prev).max(0.0);
                flux += diff;
                *prev = magnitude;
            }

            values.push(flux);
        }

        Ok(values)
    }

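    /// Classifies frames and merges them into speech segments.
    ///
    /// Each frame's energy and flux are normalized against adaptive
    /// baselines, combined with the configured weights, and compared to a
    /// dynamic threshold scaled by the activation margin when idle and the
    /// release margin inside a segment (hysteresis). Hangover keeps a
    /// segment alive through short pauses, and baselines adapt via an
    /// exponential moving average during silence only:
    /// `baseline = a * baseline + (1 - a) * value`.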
    fn merge_metrics(
        &self,
        frames: &[Frame],
        energy: &[f32],
        flux: &[f32],
        chunk_end_sample: usize,
        detector_state: &mut DetectorState,
    ) -> Result<(Vec<SpeechChunk>, VADStats)> {
        let mut stats = VADStats::new();
        let mut segments = Vec::new();

        let mut dynamic_threshold = detector_state.dynamic_threshold.max(EPSILON);
        let mut energy_baseline = detector_state
            .energy_baseline
            .max(self.config.energy_floor)
            .max(EPSILON);
        let mut flux_baseline = detector_state
            .flux_baseline
            .max(self.config.flux_floor)
            .max(EPSILON);

        let silence_energy_smoothing = self.config.energy_smoothing.min(MAX_SMOOTHING_FACTOR);
        let silence_flux_smoothing = self.config.flux_smoothing.min(MAX_SMOOTHING_FACTOR);
        let silence_threshold_smoothing = self.config.threshold_smoothing.min(MAX_SMOOTHING_FACTOR);

        let dynamic_threshold_min =
            (self.config.base_threshold * self.config.release_margin).max(EPSILON);
        let dynamic_threshold_max =
            self.config.base_threshold * self.config.activation_margin * 2.0;

        // Resume any segment left open by the previous chunk.
        let mut active_segment = detector_state.active_segment.take();
        let mut silence_run = active_segment
            .as_ref()
            .map_or(0usize, |state| state.silence_run);

        for (idx, frame) in frames.iter().enumerate() {
            let frame_start = AudioInstant::now();
            let raw_energy = energy.get(idx).copied().ok_or_else(|| {
                Error::Processing(format!("energy array length mismatch at index {idx}"))
            })?;
            let raw_flux = flux.get(idx).copied().ok_or_else(|| {
                Error::Processing(format!("flux array length mismatch at index {idx}"))
            })?;

            let energy_denominator = energy_baseline.max(self.config.energy_floor).max(EPSILON);
            let normalized_energy =
                (raw_energy / energy_denominator).clamp(0.0, MAX_NORMALIZED_METRIC);
            let flux_denominator = flux_baseline.max(self.config.flux_floor).max(EPSILON);
            let normalized_flux = (raw_flux / flux_denominator).clamp(0.0, MAX_NORMALIZED_METRIC);
            let energy_ratio = raw_energy / energy_denominator;

            let combined = self
                .energy_weight
                .mul_add(normalized_energy, self.flux_weight * normalized_flux);

            // Hysteresis: a lower bar to stay in a segment than to enter one.
            let base_threshold = if active_segment.is_some() {
                dynamic_threshold * self.config.release_margin
            } else {
                dynamic_threshold * self.config.activation_margin
            };
            let threshold =
                base_threshold.max(self.config.base_threshold * self.config.release_margin);
            let low_energy = raw_energy < SILENCE_ENERGY_GATE;
            let low_relative_energy = energy_ratio < SILENCE_RELATIVE_GATE;
            let mut raw_is_speech = combined >= threshold;
            if raw_is_speech && (low_energy || low_relative_energy) {
                raw_is_speech = false;
            }

            // Hangover: tolerate short silence runs inside an active segment.
            let is_speech = if active_segment.is_some() {
                if raw_is_speech {
                    silence_run = 0;
                    true
                } else {
                    silence_run += 1;
                    silence_run <= self.config.hangover_frames
                }
            } else {
                silence_run = 0;
                raw_is_speech
            };

            if is_speech {
                let segment_state = active_segment
                    .get_or_insert_with(|| ActiveSegmentState::new(frame.start_sample));
                segment_state.score_sum += combined;
                segment_state.energy_sum += raw_energy;
                segment_state.frame_count += 1;
                segment_state.last_end_sample = frame.start_sample + frame.valid_len.max(1);
                segment_state.silence_run = silence_run;
            } else if let Some(segment_state) = active_segment.take() {
                let finalize_result =
                    self.finalize_segment(&segment_state, chunk_end_sample, &mut segments);
                if let Err(err) = finalize_result {
                    // Restore the segment so a later call can finalize it.
                    detector_state.active_segment = Some(segment_state);
                    return Err(err);
                }
                silence_run = 0;
            }

            // Per-frame latency is measured but currently unused.
            let _frame_processing = AudioInstant::now().duration_since(frame_start);
            stats.frames_processed += 1;

            // Adapt baselines and the threshold during silence only, so
            // sustained speech cannot inflate them.
            if !is_speech {
                dynamic_threshold = silence_threshold_smoothing.mul_add(
                    dynamic_threshold,
                    (1.0 - silence_threshold_smoothing) * combined,
                );
                energy_baseline = silence_energy_smoothing.mul_add(
                    energy_baseline,
                    (1.0 - silence_energy_smoothing) * raw_energy,
                );
                flux_baseline = silence_flux_smoothing
                    .mul_add(flux_baseline, (1.0 - silence_flux_smoothing) * raw_flux);
            }

            dynamic_threshold =
                dynamic_threshold.clamp(dynamic_threshold_min, dynamic_threshold_max);
            energy_baseline = energy_baseline.max(self.config.energy_floor).max(EPSILON);
            flux_baseline = flux_baseline.max(self.config.flux_floor).max(EPSILON);
        }

        detector_state.dynamic_threshold = dynamic_threshold;
        detector_state.energy_baseline = energy_baseline;
        detector_state.flux_baseline = flux_baseline;

        // Carry an unfinished segment (and its silence run) into the next chunk.
        if let Some(mut segment_state) = active_segment {
            segment_state.silence_run = silence_run;
            detector_state.active_segment = Some(segment_state);
        } else {
            detector_state.active_segment = None;
        }

        Ok((segments, stats))
    }

    /// Converts a completed segment into a `SpeechChunk`, dropping segments
    /// that are empty or shorter than `min_speech_frames`.
    fn finalize_segment(
        &self,
        segment: &ActiveSegmentState,
        chunk_end_sample: usize,
        segments: &mut Vec<SpeechChunk>,
    ) -> Result<()> {
        if segment.last_end_sample <= segment.start_sample {
            return Ok(());
        }

        if segment.frame_count < self.config.min_speech_frames {
            return Ok(());
        }

        let clamped_end = segment
            .last_end_sample
            .min(chunk_end_sample.max(segment.start_sample + 1));
        let start_time = self.absolute_time_for_sample(segment.start_sample)?;
        let end_time = self.absolute_time_for_sample(clamped_end)?;

        if end_time <= start_time {
            return Ok(());
        }

        // The mean combined score can exceed 1.0, so clamp it into [0, 1].
        let confidence = (segment.score_sum / segment.frame_count as f32).clamp(0.0, 1.0);
        let avg_energy = if segment.frame_count > 0 {
            segment.energy_sum / segment.frame_count as f32
        } else {
            0.0
        };

        segments.push(SpeechChunk {
            start_time,
            end_time,
            confidence,
            avg_energy,
            frame_count: segment.frame_count,
        });

        Ok(())
    }

    /// Maps an absolute sample index to a wall-clock timestamp relative to
    /// the stream start.
    fn absolute_time_for_sample(&self, sample_index: usize) -> Result<AudioTimestamp> {
        let offset = samples_to_duration(sample_index, self.config.sample_rate);
        Ok(self.config.stream_start_time.add_duration(offset))
    }
}

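/// Builds a symmetric Hann window:
/// `w[n] = 0.5 * (1 - cos(2 * pi * n / (N - 1)))` for `n` in `0..N`, so the
/// endpoints are 0 and an odd-length window peaks at exactly 1.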
fn hann_window(length: usize) -> Vec<f32> {
    if length == 0 {
        return Vec::new();
    }

    if length == 1 {
        return vec![1.0];
    }

    let denom = (length - 1) as f32;
    (0..length)
        .map(|n| {
            let angle = 2.0 * std::f32::consts::PI * n as f32 / denom;
            0.5f32.mul_add(-angle.cos(), 0.5)
        })
        .collect()
}

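/// Converts a sample count to a duration, rounding to the nearest
/// nanosecond: `nanos = round(samples * 1e9 / sample_rate)`. At 16 kHz,
/// 16_000 samples map to exactly one second.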
fn samples_to_duration(samples: usize, sample_rate: u32) -> AudioDuration {
    let sr = u128::from(sample_rate);
    // Adding sr / 2 before dividing rounds half up instead of truncating.
    let nanos = ((samples as u128) * NANOS_PER_SECOND + sr / 2) / sr;
    AudioDuration::from_nanos(nanos as u64)
}

/// A single windowed analysis frame.
struct Frame {
    data: Vec<f32>,
    start_sample: usize,
    /// Number of real (non-padding) samples in `data`.
    valid_len: usize,
}

/// Mutable detection state carried across chunks.
pub(super) struct DetectorState {
    pub(super) energy_baseline: f32,
    pub(super) flux_baseline: f32,
    pub(super) dynamic_threshold: f32,
    pub(super) previous_spectrum: Vec<f32>,
    pub(super) pre_emphasis_prev: f32,
    pub(super) active_segment: Option<ActiveSegmentState>,
}

/// Accumulator for a speech segment that is still open.
pub(super) struct ActiveSegmentState {
    pub(super) start_sample: usize,
    pub(super) last_end_sample: usize,
    pub(super) score_sum: f32,
    pub(super) energy_sum: f32,
    pub(super) frame_count: usize,
    pub(super) silence_run: usize,
}

impl ActiveSegmentState {
    pub(super) fn new(start_sample: usize) -> Self {
        Self {
            start_sample,
            last_end_sample: start_sample,
            score_sum: 0.0,
            energy_sum: 0.0,
            frame_count: 0,
            silence_run: 0,
        }
    }
}

/// A finalized span of detected speech.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SpeechChunk {
    /// Absolute timestamp of the first sample in the segment.
    pub start_time: AudioTimestamp,
    /// Absolute timestamp one past the last sample in the segment.
    pub end_time: AudioTimestamp,
    /// Mean combined detection score, clamped to `[0.0, 1.0]`.
    pub confidence: f32,
    /// Mean RMS energy across the segment's frames.
    pub avg_energy: f32,
    /// Number of frames classified as speech within the segment.
    pub frame_count: usize,
}

impl SpeechChunk {
    /// Returns the chunk's duration, or an error if `end_time` precedes
    /// `start_time`.
    pub fn duration(&self) -> Result<AudioDuration> {
        self.end_time
            .duration_since(self.start_time)
            .ok_or_else(|| {
                Error::Processing(
                    "failed to compute segment duration: end_time precedes start_time".into(),
                )
            })
    }
}
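
// A minimal test sketch for the pure helpers above. It only exercises code
// defined in this file; the `AudioDuration` comparisons additionally assume
// that type implements `PartialEq` and `Debug`, so adjust to the crate's
// actual API if it does not.
#[cfg(test)]
mod detector_helper_tests {
    use super::*;

    #[test]
    fn hann_window_has_zero_endpoints_and_unit_peak() {
        let window = hann_window(9);
        assert_eq!(window.len(), 9);
        // Endpoints of a symmetric Hann window are (numerically) zero.
        assert!(window[0].abs() < 1e-6);
        assert!(window[8].abs() < 1e-6);
        // The midpoint of an odd-length window is the peak, exactly 1.
        assert!((window[4] - 1.0).abs() < 1e-6);
        // Symmetry: w[n] == w[N - 1 - n].
        for n in 0..9 {
            assert!((window[n] - window[8 - n]).abs() < 1e-6);
        }
    }

    #[test]
    fn samples_to_duration_rounds_to_nearest_nanosecond() {
        // 16_000 samples at 16 kHz is exactly one second.
        assert_eq!(
            samples_to_duration(16_000, 16_000),
            AudioDuration::from_nanos(1_000_000_000)
        );
        // One sample at 48 kHz is 20833.33.. ns, which rounds to 20_833 ns.
        assert_eq!(
            samples_to_duration(1, 48_000),
            AudioDuration::from_nanos(20_833)
        );
    }
}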