// speech_prep/vad/detector.rs
1//! Core VAD detection engine with dual-metric analysis.
2
3use std::fmt;
4use std::sync::Arc;
5
6use crate::error::{Error, Result};
7use crate::monitoring::{AtomicCounter, VADStats};
8use crate::time::{AudioDuration, AudioInstant, AudioTimestamp};
9use parking_lot::Mutex;
10use realfft::{RealFftPlanner, RealToComplex};
11
12use super::config::VadConfig;
13use super::metrics::{AdaptiveThresholdSnapshot, VadMetricsCollector, VadMetricsSnapshot};
14
/// Number of nanoseconds in one second, used for time conversion.
const NANOS_PER_SECOND: u128 = 1_000_000_000;
/// Small epsilon value for numerical stability in floating-point comparisons.
const EPSILON: f32 = 1e-12;

/// Maximum smoothing factor for baseline tracking (energy, flux, threshold).
/// Capped at 0.999 to prevent numerical instability from exponential moving
/// average converging too slowly. At 0.999, half-life ≈ 693 samples (43ms at
/// 16kHz).
const MAX_SMOOTHING_FACTOR: f32 = 0.999;

/// Maximum normalized value for energy and spectral flux metrics.
/// Caps outliers at 10x the baseline to prevent extreme transients from
/// dominating the detection logic while still allowing headroom for loud
/// signals.
const MAX_NORMALIZED_METRIC: f32 = 10.0;

/// Energy level under which we consider a frame near-silence regardless of
/// normalization.
const SILENCE_ENERGY_GATE: f32 = 0.02;
/// Relative energy ratio below which we consider a frame near-silence.
/// NOTE(review): 1.7 means any frame below 1.7x the energy baseline is vetoed
/// as silence even when the combined score crosses the threshold — confirm
/// this gate is intentionally this aggressive.
const SILENCE_RELATIVE_GATE: f32 = 1.7;
37
/// Real-time voice activity detector combining energy and spectral flux
/// metrics.
pub struct VadDetector {
    /// Detector configuration (frame sizing, thresholds, margins, smoothing).
    config: VadConfig,
    /// Forward real-to-complex FFT plan used for spectral-flux computation.
    fft: Arc<dyn RealToComplex<f32>>,
    /// Precomputed Hann window, one coefficient per frame sample.
    window: Vec<f32>,
    /// Sink for per-chunk detection metric snapshots.
    metrics: Arc<dyn VadMetricsCollector>,
    /// Running count of samples consumed across all `detect` calls.
    processed_samples: AtomicCounter,
    /// Weight of the energy metric, normalized so energy + flux weights sum to 1.
    energy_weight: f32,
    /// Weight of the spectral-flux metric (see `energy_weight`).
    flux_weight: f32,
    /// Mutable detection state (baselines, spectrum history, open segment),
    /// behind a lock so `detect` can take `&self` and callers are serialized.
    state: Mutex<DetectorState>,
}
50
51impl fmt::Debug for VadDetector {
52    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53        let processed_samples = self.processed_samples.get();
54        f.debug_struct("VadDetector")
55            .field("config", &self.config)
56            .field("window_length", &self.window.len())
57            .field("energy_weight", &self.energy_weight)
58            .field("flux_weight", &self.flux_weight)
59            .field("processed_samples", &processed_samples)
60            .finish_non_exhaustive()
61    }
62}
63
impl VadDetector {
    /// Construct a new detector instance.
    ///
    /// # Errors
    /// Returns an error when the configuration fails validation or when the
    /// frame length / FFT size cannot be resolved from it.
    pub fn new(config: VadConfig, metrics: Arc<dyn VadMetricsCollector>) -> Result<Self> {
        config.validate()?;

        let window = hann_window(config.frame_length_samples()?);
        let fft = RealFftPlanner::<f32>::new().plan_fft_forward(config.fft_size()?);

        // Normalize the two metric weights so they sum to 1.
        let total_weight = config.energy_weight + config.flux_weight;
        let energy_weight = config.energy_weight / total_weight;
        let flux_weight = config.flux_weight / total_weight;

        // Seed the flux history with a zeroed spectrum of the plan's length.
        let spectrum_len = fft.make_output_vec().len();

        let initial_state = DetectorState {
            energy_baseline: config.energy_floor.max(EPSILON),
            flux_baseline: config.flux_floor.max(EPSILON),
            dynamic_threshold: config.base_threshold.max(EPSILON),
            previous_spectrum: vec![0.0; spectrum_len],
            pre_emphasis_prev: 0.0,
            active_segment: None,
        };

        Ok(Self {
            config,
            fft,
            window,
            metrics,
            processed_samples: AtomicCounter::new(0),
            energy_weight,
            flux_weight,
            state: Mutex::new(initial_state),
        })
    }
106
    /// Access detector configuration.
    #[must_use]
    pub fn config(&self) -> &VadConfig {
        &self.config
    }
112
113    /// Return the start sample of the currently active speech segment, if any.
114    #[must_use]
115    pub fn active_segment_start_sample(&self) -> Option<usize> {
116        let state = self.state.lock();
117        state
118            .active_segment
119            .as_ref()
120            .map(|segment| segment.start_sample)
121    }
122
    /// Reset processed sample count and stream start time.
    ///
    /// Also clears the active segment and the pre-emphasis carry sample.
    /// NOTE(review): the adaptive baselines, dynamic threshold, and
    /// `previous_spectrum` are NOT reset here — presumably so the learned
    /// noise floor survives a stream restart; confirm that is intended.
    pub fn reset(&mut self, stream_start_time: AudioTimestamp) {
        self.config.stream_start_time = stream_start_time;
        self.processed_samples.reset();
        let mut state = self.state.lock();
        state.active_segment = None;
        state.pre_emphasis_prev = 0.0;
    }
131
    /// Run detection on a slice of samples, returning detected speech segments.
    ///
    /// The internal state lock is held for the entire pass, so concurrent
    /// callers are serialized and chunks are processed in submission order.
    /// A metrics snapshot is recorded even when the chunk yields no frames.
    ///
    /// # Errors
    /// Propagates failures from framing, FFT processing, and metric merging.
    pub fn detect(&self, samples: &[f32]) -> Result<Vec<SpeechChunk>> {
        let detection_start = AudioInstant::now();
        let chunk_len = samples.len() as u64;

        let mut detector_state = self.state.lock();

        // Reserve this chunk's range in the global sample counter; the value
        // before the add is the absolute start sample of this chunk.
        let chunk_start_sample = self.processed_samples.fetch_add(chunk_len) as usize;
        let chunk_end_sample = chunk_start_sample + samples.len();

        let frames = match self.frame_signal(samples, chunk_start_sample, &mut detector_state) {
            Ok(frames) => frames,
            Err(err) => {
                // Roll back the reservation so sample accounting stays exact.
                // NOTE(review): flux/merge errors below do NOT roll the
                // counter back the same way — confirm that asymmetry is
                // intended.
                let _ = self.processed_samples.fetch_sub(chunk_len);
                drop(detector_state);
                return Err(err);
            }
        };

        if frames.is_empty() {
            // Nothing to analyze; still publish latency and the current
            // adaptive-threshold state so monitoring stays continuous.
            let latency = AudioInstant::now().duration_since(detection_start);
            let adaptive = AdaptiveThresholdSnapshot {
                energy_baseline: detector_state.energy_baseline,
                flux_baseline: detector_state.flux_baseline,
                dynamic_threshold: detector_state.dynamic_threshold,
            };
            let snapshot = VadMetricsSnapshot::new(VADStats::new(), latency, adaptive);
            self.metrics.record_vad_metrics(&snapshot);
            drop(detector_state);
            return Ok(Vec::new());
        }

        let energy = Self::compute_energy(&frames);
        let flux = self.compute_spectral_flux(&frames, &mut detector_state)?;
        let (chunks, mut stats) = self.merge_metrics(
            &frames,
            &energy,
            &flux,
            chunk_end_sample,
            &mut detector_state,
        )?;
        // NOTE(review): this records the number of finalized chunks, not the
        // number of speech frames — confirm against VADStats semantics.
        stats.speech_frames = chunks.len() as u64;
        let adaptive = AdaptiveThresholdSnapshot {
            energy_baseline: detector_state.energy_baseline,
            flux_baseline: detector_state.flux_baseline,
            dynamic_threshold: detector_state.dynamic_threshold,
        };
        drop(detector_state);

        let latency = AudioInstant::now().duration_since(detection_start);
        let snapshot = VadMetricsSnapshot::new(stats, latency, adaptive);
        self.metrics.record_vad_metrics(&snapshot);

        Ok(chunks)
    }
187
    /// Split `samples` into Hann-windowed, hop-spaced frames.
    ///
    /// Full frames cover `frame_length` samples; a trailing partial frame is
    /// zero-padded to `frame_length`, with `valid_len` recording how many
    /// samples were real. `start_sample` on each frame is absolute (offset by
    /// `absolute_start`).
    ///
    /// # Errors
    /// Fails when frame sizing cannot be resolved or resolves to zero.
    fn frame_signal(
        &self,
        samples: &[f32],
        absolute_start: usize,
        state: &mut DetectorState,
    ) -> Result<Vec<Frame>> {
        if samples.is_empty() {
            return Ok(Vec::new());
        }

        // Optional pre-emphasis runs before windowing.
        let processed = self.preprocess_signal(samples, state);
        let frame_length = self.config.frame_length_samples()?;
        let hop_length = self.config.hop_length_samples()?;

        if frame_length == 0 {
            return Err(Error::Processing("frame length resolved to zero".into()));
        }
        // NOTE(review): hop_length == 0 would loop forever below; presumably
        // VadConfig::validate rejects it — confirm.

        let mut frames = Vec::new();
        let mut start = 0usize;

        while start + frame_length <= processed.len() {
            #[allow(clippy::indexing_slicing)] // bounds checked by while condition
            let slice = &processed[start..start + frame_length];
            let mut frame = Vec::with_capacity(frame_length);
            // Apply the Hann window element-wise.
            frame.extend(
                slice
                    .iter()
                    .zip(&self.window)
                    .map(|(sample, window)| sample * window),
            );
            frames.push(Frame {
                data: frame,
                start_sample: absolute_start + start,
                valid_len: frame_length,
            });
            start += hop_length;
        }

        // Trailing partial frame: window what remains, then zero-pad so the
        // frame buffer length stays constant.
        if start < processed.len() {
            if let Some(slice) = processed.get(start..) {
                let available = slice.len().min(frame_length);
                let mut frame = Vec::with_capacity(frame_length);
                frame.extend(
                    slice
                        .iter()
                        .zip(&self.window)
                        .map(|(sample, window)| sample * window),
                );
                frame.resize(frame_length, 0.0);
                frames.push(Frame {
                    data: frame,
                    start_sample: absolute_start + start,
                    valid_len: available,
                });
            }
        }

        Ok(frames)
    }
248
249    fn preprocess_signal(&self, samples: &[f32], state: &mut DetectorState) -> Vec<f32> {
250        match self.config.pre_emphasis {
251            Some(coeff) if coeff > 0.0 => {
252                let mut processed = Vec::with_capacity(samples.len());
253                let mut previous = state.pre_emphasis_prev;
254                for &sample in samples {
255                    let emphasized = coeff.mul_add(-previous, sample);
256                    processed.push(emphasized);
257                    previous = sample;
258                }
259                if let Some(&last) = samples.last() {
260                    state.pre_emphasis_prev = last;
261                }
262                processed
263            }
264            _ => {
265                if let Some(&last) = samples.last() {
266                    state.pre_emphasis_prev = last;
267                }
268                samples.to_vec()
269            }
270        }
271    }
272
273    fn compute_energy(frames: &[Frame]) -> Vec<f32> {
274        let mut values = Vec::with_capacity(frames.len());
275
276        for frame in frames {
277            debug_assert!(!frame.data.is_empty(), "frame data should never be empty");
278            let sum_sq: f32 = frame.data.iter().map(|sample| sample * sample).sum();
279            let len = frame.data.len();
280            let rms = (sum_sq / len as f32).sqrt();
281
282            values.push(rms);
283        }
284
285        values
286    }
287
    /// Positive (half-wave rectified) spectral flux per frame.
    ///
    /// For each frame, computes the FFT magnitude spectrum and sums only the
    /// bin-wise increases relative to the previous frame's spectrum. The
    /// previous spectrum persists in `state` across chunks so flux is
    /// continuous over a stream.
    ///
    /// # Errors
    /// Fails when the FFT rejects the input buffer.
    fn compute_spectral_flux(
        &self,
        frames: &[Frame],
        state: &mut DetectorState,
    ) -> Result<Vec<f32>> {
        if frames.is_empty() {
            return Ok(Vec::new());
        }

        let mut input = self.fft.make_input_vec();
        let mut spectrum = self.fft.make_output_vec();
        let mut scratch = self.fft.make_scratch_vec();
        // Resize the persisted spectrum if the FFT configuration changed.
        if state.previous_spectrum.len() != spectrum.len() {
            state.previous_spectrum.resize(spectrum.len(), 0.0);
        }
        let previous = &mut state.previous_spectrum;

        let mut values = Vec::with_capacity(frames.len());

        for frame in frames {
            debug_assert!(!frame.data.is_empty(), "frame data should never be empty");
            // Zero-fill, then copy the frame into the FFT input buffer;
            // any tail beyond the frame stays zero-padded.
            input.fill(0.0);
            let len = frame.data.len().min(input.len());
            for (dst, &src) in input.iter_mut().zip(frame.data.iter()).take(len) {
                *dst = src;
            }

            self.fft
                .process_with_scratch(&mut input, &mut spectrum, &mut scratch)
                .map_err(|err| Error::Processing(format!("FFT processing failed: {err}")))?;

            // Sum positive magnitude deltas and refresh the history in place.
            let mut flux = 0.0f32;
            for (bin, prev) in spectrum.iter().zip(previous.iter_mut()) {
                let magnitude = bin.re.hypot(bin.im);
                let diff = (magnitude - *prev).max(0.0);
                flux += diff;
                *prev = magnitude;
            }

            values.push(flux);
        }

        Ok(values)
    }
332
    /// Fuse per-frame energy and flux into speech/non-speech decisions and
    /// assemble contiguous speech frames into `SpeechChunk`s.
    ///
    /// Uses hysteresis (activation vs. release margins), absolute and
    /// relative energy silence gates, and a hangover window before closing a
    /// segment. Baselines and the dynamic threshold adapt (exponential
    /// moving average) only during silence frames, to keep speech from
    /// contaminating the noise-floor estimate. An unfinished segment is
    /// persisted in `detector_state` for streaming continuity.
    ///
    /// # Errors
    /// Fails when the metric arrays are shorter than `frames`, or when
    /// segment finalization fails (the active segment is restored then).
    fn merge_metrics(
        &self,
        frames: &[Frame],
        energy: &[f32],
        flux: &[f32],
        chunk_end_sample: usize,
        detector_state: &mut DetectorState,
    ) -> Result<(Vec<SpeechChunk>, VADStats)> {
        let mut stats = VADStats::new();
        let mut segments = Vec::new();

        // Hydrate working copies of the adaptive state, clamped to floors.
        let mut dynamic_threshold = detector_state.dynamic_threshold.max(EPSILON);
        let mut energy_baseline = detector_state
            .energy_baseline
            .max(self.config.energy_floor)
            .max(EPSILON);
        let mut flux_baseline = detector_state
            .flux_baseline
            .max(self.config.flux_floor)
            .max(EPSILON);

        // Cap smoothing factors so the EMAs cannot freeze entirely.
        let silence_energy_smoothing = self.config.energy_smoothing.min(MAX_SMOOTHING_FACTOR);
        let silence_flux_smoothing = self.config.flux_smoothing.min(MAX_SMOOTHING_FACTOR);
        let silence_threshold_smoothing = self.config.threshold_smoothing.min(MAX_SMOOTHING_FACTOR);

        // Bounds for the adaptive threshold, derived from the base threshold.
        let dynamic_threshold_min =
            (self.config.base_threshold * self.config.release_margin).max(EPSILON);
        let dynamic_threshold_max =
            self.config.base_threshold * self.config.activation_margin * 2.0;

        // Resume a segment left open by the previous chunk, if any.
        let mut active_segment = detector_state.active_segment.take();
        let mut silence_run = active_segment
            .as_ref()
            .map_or(0usize, |state| state.silence_run);

        for (idx, frame) in frames.iter().enumerate() {
            let frame_start = AudioInstant::now();
            let raw_energy = energy.get(idx).copied().ok_or_else(|| {
                Error::Processing(format!("energy array length mismatch at index {idx}"))
            })?;
            let raw_flux = flux.get(idx).copied().ok_or_else(|| {
                Error::Processing(format!("flux array length mismatch at index {idx}"))
            })?;

            // Normalize both metrics against their baselines, capping
            // outliers at MAX_NORMALIZED_METRIC.
            let energy_denominator = energy_baseline.max(self.config.energy_floor).max(EPSILON);
            let normalized_energy =
                (raw_energy / energy_denominator).clamp(0.0, MAX_NORMALIZED_METRIC);
            let flux_denominator = flux_baseline.max(self.config.flux_floor).max(EPSILON);
            let normalized_flux = (raw_flux / flux_denominator).clamp(0.0, MAX_NORMALIZED_METRIC);
            let energy_ratio = raw_energy / energy_denominator;

            // Weighted combination of the two normalized metrics.
            let combined = self
                .energy_weight
                .mul_add(normalized_energy, self.flux_weight * normalized_flux);

            // Hysteresis: a lower bar to stay in speech than to enter it.
            let base_threshold = if active_segment.is_some() {
                dynamic_threshold * self.config.release_margin
            } else {
                dynamic_threshold * self.config.activation_margin
            };
            let threshold =
                base_threshold.max(self.config.base_threshold * self.config.release_margin);
            // Silence gates: veto a raw speech decision when the frame is
            // quiet in absolute or baseline-relative terms.
            let low_energy = raw_energy < SILENCE_ENERGY_GATE;
            let low_relative_energy = energy_ratio < SILENCE_RELATIVE_GATE;
            let mut raw_is_speech = combined >= threshold;
            if raw_is_speech && (low_energy || low_relative_energy) {
                raw_is_speech = false;
            }

            // Hangover: once in a segment, tolerate up to hangover_frames of
            // consecutive non-speech before actually closing it.
            let is_speech = if active_segment.is_some() {
                if raw_is_speech {
                    silence_run = 0;
                    true
                } else {
                    silence_run += 1;
                    silence_run <= self.config.hangover_frames
                }
            } else {
                silence_run = 0;
                raw_is_speech
            };

            if is_speech {
                // Open a segment on the first speech frame; accumulate stats.
                let segment_state = active_segment
                    .get_or_insert_with(|| ActiveSegmentState::new(frame.start_sample));
                segment_state.score_sum += combined;
                segment_state.energy_sum += raw_energy;
                segment_state.frame_count += 1;
                segment_state.last_end_sample = frame.start_sample + frame.valid_len.max(1);
                segment_state.silence_run = silence_run;
            } else if let Some(segment_state) = active_segment.take() {
                // Hangover exhausted: finalize the segment. On failure,
                // restore it so no audio is silently lost.
                let finalize_result =
                    self.finalize_segment(&segment_state, chunk_end_sample, &mut segments);
                if let Err(err) = finalize_result {
                    detector_state.active_segment = Some(segment_state);
                    return Err(err);
                }
                silence_run = 0;
            }

            // NOTE(review): per-frame latency is measured but never recorded
            // anywhere — presumably a placeholder; confirm.
            let _frame_processing = AudioInstant::now().duration_since(frame_start);
            stats.frames_processed += 1;

            // Update baselines only during silence to avoid noise floor contamination
            if !is_speech {
                dynamic_threshold = silence_threshold_smoothing.mul_add(
                    dynamic_threshold,
                    (1.0 - silence_threshold_smoothing) * combined,
                );
                energy_baseline = silence_energy_smoothing.mul_add(
                    energy_baseline,
                    (1.0 - silence_energy_smoothing) * raw_energy,
                );
                flux_baseline = silence_flux_smoothing
                    .mul_add(flux_baseline, (1.0 - silence_flux_smoothing) * raw_flux);
            }

            // Re-clamp the adaptive state after each frame.
            dynamic_threshold =
                dynamic_threshold.clamp(dynamic_threshold_min, dynamic_threshold_max);
            energy_baseline = energy_baseline.max(self.config.energy_floor).max(EPSILON);
            flux_baseline = flux_baseline.max(self.config.flux_floor).max(EPSILON);
        }

        detector_state.dynamic_threshold = dynamic_threshold;
        detector_state.energy_baseline = energy_baseline;
        detector_state.flux_baseline = flux_baseline;

        // Preserve active segment for streaming continuity; flush to finalize
        if let Some(mut segment_state) = active_segment {
            segment_state.silence_run = silence_run;
            detector_state.active_segment = Some(segment_state);
        } else {
            detector_state.active_segment = None;
        }

        Ok((segments, stats))
    }
470
471    fn finalize_segment(
472        &self,
473        segment: &ActiveSegmentState,
474        chunk_end_sample: usize,
475        segments: &mut Vec<SpeechChunk>,
476    ) -> Result<()> {
477        if segment.last_end_sample <= segment.start_sample {
478            return Ok(());
479        }
480
481        if segment.frame_count < self.config.min_speech_frames {
482            return Ok(());
483        }
484
485        let clamped_end = segment
486            .last_end_sample
487            .min(chunk_end_sample.max(segment.start_sample + 1));
488        let start_time = self.absolute_time_for_sample(segment.start_sample)?;
489        let end_time = self.absolute_time_for_sample(clamped_end)?;
490
491        if end_time <= start_time {
492            return Ok(());
493        }
494
495        let confidence = (segment.score_sum / segment.frame_count as f32).clamp(0.0, 1.0);
496        let avg_energy = if segment.frame_count > 0 {
497            segment.energy_sum / segment.frame_count as f32
498        } else {
499            0.0
500        };
501
502        segments.push(SpeechChunk {
503            start_time,
504            end_time,
505            confidence,
506            avg_energy,
507            frame_count: segment.frame_count,
508        });
509
510        Ok(())
511    }
512
    /// Translate an absolute sample index into a timestamp offset from the
    /// configured stream start time.
    fn absolute_time_for_sample(&self, sample_index: usize) -> Result<AudioTimestamp> {
        let offset = samples_to_duration(sample_index, self.config.sample_rate);
        Ok(self.config.stream_start_time.add_duration(offset))
    }
}
518
/// Symmetric Hann window of the given length.
///
/// Endpoints are zero for `length >= 2`; a single-sample window is the
/// identity (`[1.0]`), and a zero-length request yields an empty vector.
fn hann_window(length: usize) -> Vec<f32> {
    match length {
        0 => Vec::new(),
        1 => vec![1.0],
        _ => {
            // w[n] = 0.5 - 0.5 * cos(2*pi*n / (N - 1))
            let denom = (length - 1) as f32;
            (0..length)
                .map(|n| {
                    let angle = 2.0 * std::f32::consts::PI * n as f32 / denom;
                    0.5f32.mul_add(-angle.cos(), 0.5)
                })
                .collect()
        }
    }
}
536
/// Convert a sample count at `sample_rate` Hz into an `AudioDuration`,
/// rounding to the nearest nanosecond (half-up via the `sr / 2` addend).
/// NOTE(review): divides by `sample_rate`; a zero rate would panic —
/// presumably prevented by `VadConfig::validate`, confirm.
fn samples_to_duration(samples: usize, sample_rate: u32) -> AudioDuration {
    let sr = u128::from(sample_rate);
    let nanos = ((samples as u128) * NANOS_PER_SECOND + sr / 2) / sr;
    AudioDuration::from_nanos(nanos as u64)
}
542
/// A single analysis frame produced by `frame_signal`.
struct Frame {
    /// Hann-windowed samples, always one full frame length (zero-padded tail).
    data: Vec<f32>,
    /// Absolute index of the frame's first sample within the stream.
    start_sample: usize,
    /// Number of real (non-padding) samples in `data`.
    valid_len: usize,
}
548
/// Mutable detection state carried across `detect` calls.
pub(super) struct DetectorState {
    /// EMA of frame RMS energy, updated only during silence frames.
    pub(super) energy_baseline: f32,
    /// EMA of spectral flux, updated only during silence frames.
    pub(super) flux_baseline: f32,
    /// Adaptive decision threshold, clamped around the configured base.
    pub(super) dynamic_threshold: f32,
    /// Magnitude spectrum of the last FFT frame, for flux continuity.
    pub(super) previous_spectrum: Vec<f32>,
    /// Last raw input sample, carried across chunks for pre-emphasis.
    pub(super) pre_emphasis_prev: f32,
    /// Speech segment still open at the end of the last chunk, if any.
    pub(super) active_segment: Option<ActiveSegmentState>,
}
557
/// Accumulator for a speech segment that is currently open.
pub(super) struct ActiveSegmentState {
    /// Absolute sample index where the segment began.
    pub(super) start_sample: usize,
    /// Absolute sample index one past the last speech frame's valid samples.
    pub(super) last_end_sample: usize,
    /// Sum of combined (energy + flux) scores over the segment's frames.
    pub(super) score_sum: f32,
    /// Sum of raw frame energies over the segment's frames.
    pub(super) energy_sum: f32,
    /// Number of frames attributed to the segment so far.
    pub(super) frame_count: usize,
    /// Consecutive non-speech frames observed while in hangover.
    pub(super) silence_run: usize,
}

impl ActiveSegmentState {
    /// Start a new, empty segment at `start_sample`.
    pub(super) fn new(start_sample: usize) -> Self {
        Self {
            start_sample,
            last_end_sample: start_sample,
            score_sum: 0.0,
            energy_sum: 0.0,
            frame_count: 0,
            silence_run: 0,
        }
    }
}
579
/// Speech segment with temporal metadata emitted by the detector.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SpeechChunk {
    /// Start time of the detected speech segment.
    pub start_time: AudioTimestamp,
    /// End time of the detected speech segment.
    pub end_time: AudioTimestamp,
    /// Aggregated confidence score derived from combined metrics,
    /// clamped to `[0.0, 1.0]` by the detector.
    pub confidence: f32,
    /// Average energy observed within the segment.
    pub avg_energy: f32,
    /// Number of frames that contributed to the segment.
    pub frame_count: usize,
}
594
595impl SpeechChunk {
596    /// Duration of the speech segment.
597    pub fn duration(&self) -> Result<AudioDuration> {
598        self.end_time
599            .duration_since(self.start_time)
600            .ok_or_else(|| {
601                Error::Processing(
602                    "failed to compute segment duration: end_time precedes start_time".into(),
603                )
604            })
605    }
606}