//! WaveKat VAD — Unified voice activity detection with multiple backends.
//!
//! This crate provides a common [`VoiceActivityDetector`] trait with
//! implementations for different VAD backends, enabling experimentation
//! and benchmarking across technologies.
//!
//! # Backends
//!
//! | Backend | Feature | Sample Rates | Frame Size | Output |
//! |---------|---------|--------------|------------|--------|
//! | [WebRTC](`backends::webrtc`) | `webrtc` (default) | 8/16/32/48 kHz | 10, 20, or 30 ms | Binary (0.0 or 1.0) |
//! | [Silero](`backends::silero`) | `silero` | 8/16 kHz | 32 ms | Continuous (0.0–1.0) |
//! | [TEN-VAD](`backends::ten_vad`) | `ten-vad` | 16 kHz only | 16 ms | Continuous (0.0–1.0) |
//! | [FireRedVAD](`backends::firered`) | `firered` | 16 kHz only | 10 ms | Continuous (0.0–1.0) |
//!
//! # Quick start
//!
//! Add the crate with the backend you need (pick one line):
//!
//! ```toml
//! [dependencies]
//! wavekat-vad = "0.1"                                       # WebRTC only (default)
//! wavekat-vad = { version = "0.1", features = ["silero"] }  # Silero
//! wavekat-vad = { version = "0.1", features = ["ten-vad"] } # TEN-VAD
//! wavekat-vad = { version = "0.1", features = ["firered"] } # FireRedVAD
//! ```
//!
//! Then create a detector and process audio frames:
//!
//! ```no_run
//! # #[cfg(feature = "webrtc")]
//! # {
//! use wavekat_vad::VoiceActivityDetector;
//! use wavekat_vad::backends::webrtc::{WebRtcVad, WebRtcVadMode};
//!
//! let mut vad = WebRtcVad::new(16000, WebRtcVadMode::Quality).unwrap();
//! let samples = vec![0i16; 480]; // 30 ms at 16 kHz
//! let probability = vad.process(&samples, 16000).unwrap();
//! println!("Speech probability: {probability}");
//! # }
//! ```
//!
//! # Writing backend-generic code
//!
//! All backends implement [`VoiceActivityDetector`], so you can write code
//! that works with any backend:
//!
//! ```no_run
//! use wavekat_vad::VoiceActivityDetector;
//!
//! fn detect_speech(vad: &mut dyn VoiceActivityDetector, audio: &[i16], sample_rate: u32) {
//!     let caps = vad.capabilities();
//!     for frame in audio.chunks_exact(caps.frame_size) {
//!         let prob = vad.process(frame, sample_rate).unwrap();
//!         if prob > 0.5 {
//!             println!("Speech detected!");
//!         }
//!     }
//! }
//! ```
//!
//! # Handling arbitrary chunk sizes
//!
//! Real-world audio often arrives in chunks that don't match the backend's
//! required frame size. Use [`FrameAdapter`] to buffer and split automatically:
//!
//! ```no_run
//! # #[cfg(feature = "webrtc")]
//! # {
//! use wavekat_vad::FrameAdapter;
//! use wavekat_vad::backends::webrtc::{WebRtcVad, WebRtcVadMode};
//!
//! let vad = WebRtcVad::new(16000, WebRtcVadMode::Quality).unwrap();
//! let mut adapter = FrameAdapter::new(Box::new(vad));
//!
//! let chunk = vec![0i16; 1000]; // arbitrary size
//! let results = adapter.process_all(&chunk, 16000).unwrap();
//! for prob in &results {
//!     println!("{prob:.3}");
//! }
//! # }
//! ```
//!
//! # Audio preprocessing
//!
//! Optional preprocessing stages can improve accuracy with noisy input.
//! See the [`preprocessing`] module for details.
//!
//! ```
//! use wavekat_vad::preprocessing::{Preprocessor, PreprocessorConfig};
//!
//! let config = PreprocessorConfig::raw_mic(); // 80 Hz high-pass + normalization
//! let mut preprocessor = Preprocessor::new(&config, 16000);
//! let raw: Vec<i16> = vec![0; 512];
//! let cleaned = preprocessor.process(&raw);
//! // feed `cleaned` to your VAD
//! ```
//!
//! # Feature flags
//!
//! | Feature | Default | Description |
//! |---------|---------|-------------|
//! | `webrtc` | Yes | WebRTC VAD backend |
//! | `silero` | No | Silero VAD backend (ONNX model downloaded at build time) |
//! | `ten-vad` | No | TEN-VAD backend (ONNX model downloaded at build time) |
//! | `firered` | No | FireRedVAD backend (ONNX model + CMVN downloaded at build time) |
//! | `denoise` | No | RNNoise-based noise suppression in [`preprocessing`] |
//! | `serde` | No | `Serialize`/`Deserialize` for config types |
//!
//! ## ONNX model downloads
//!
//! The Silero, TEN-VAD, and FireRedVAD backends download their ONNX models
//! automatically at build time. The Silero backend is pinned to **v6.2.1** by
//! default.
//!
//! For offline or CI builds, set environment variables to point to local model
//! files:
//!
//! ```sh
//! SILERO_MODEL_PATH=/path/to/silero_vad.onnx cargo build --features silero
//! TEN_VAD_MODEL_PATH=/path/to/ten-vad.onnx cargo build --features ten-vad
//! FIRERED_MODEL_PATH=/path/to/model.onnx FIRERED_CMVN_PATH=/path/to/cmvn.ark cargo build --features firered
//! ```
//!
//! To use a different Silero model version, override the download URL:
//!
//! ```sh
//! SILERO_MODEL_URL=https://github.com/snakers4/silero-vad/raw/v6.0/src/silero_vad/data/silero_vad.onnx cargo build --features silero
//! ```
//!
//! # Error handling
//!
//! All backends return `Result<f32, VadError>`. Check a backend's
//! requirements with [`VoiceActivityDetector::capabilities()`] before processing:
//!
//! - [`VadError::InvalidSampleRate`] — unsupported sample rate
//! - [`VadError::InvalidFrameSize`] — wrong number of samples
//! - [`VadError::BackendError`] — backend-specific error (e.g. an ONNX runtime failure)
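//!
//! A sketch of dispatching on these variants. The `{ .. }` payload patterns
//! are illustrative; match whichever fields the variants actually carry:
//!
//! ```ignore
//! use wavekat_vad::VadError;
//!
//! match vad.process(&samples, 16000) {
//!     Ok(prob) => println!("speech probability: {prob:.3}"),
//!     // Recoverable input errors: fix the caller, don't tear down the pipeline.
//!     Err(VadError::InvalidSampleRate { .. }) => eprintln!("unsupported sample rate"),
//!     Err(VadError::InvalidFrameSize { .. }) => eprintln!("wrong number of samples"),
//!     // Backend failures (e.g. ONNX runtime) are usually not recoverable here.
//!     Err(e) => eprintln!("backend failure: {e}"),
//! }
//! ```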
//!
//! # Examples
//!
//! Runnable examples are in the
//! [`examples/`](https://github.com/wavekat/wavekat-vad/tree/main/crates/wavekat-vad/examples)
//! directory:
//!
//! - **[`detect_speech`](https://github.com/wavekat/wavekat-vad/blob/main/crates/wavekat-vad/examples/detect_speech.rs)** —
//!   Detect speech in a WAV file using any backend
//! - **[`ten_vad_file`](https://github.com/wavekat/wavekat-vad/blob/main/crates/wavekat-vad/examples/ten_vad_file.rs)** —
//!   Process a WAV file with TEN-VAD directly
//!
//! ```sh
//! cargo run --example detect_speech -- audio.wav
//! cargo run --example detect_speech --features silero -- -b silero audio.wav
//! cargo run --example ten_vad_file --features ten-vad -- audio.wav
//! ```
//!
//! # TEN-VAD model license
//!
//! The TEN-VAD ONNX model is distributed by the TEN-framework / Agora under an
//! Apache-2.0 license with an added non-compete clause that restricts
//! deployments competing with Agora's offerings. Review the
//! [TEN-VAD license](https://github.com/TEN-framework/ten-vad) before using it
//! in production.

pub mod adapter;
pub mod backends;
pub mod error;
pub mod frame;
pub mod preprocessing;

pub use adapter::FrameAdapter;

pub use error::VadError;

use std::time::Duration;

/// Accumulated processing time breakdown by named pipeline stage.
///
/// Each backend defines its own stages (e.g. `"fbank"`, `"cmvn"`, `"onnx"`),
/// so you can see exactly where time is spent without hardcoding a fixed set
/// of fields. Stages are returned in pipeline order.
///
/// Call [`VoiceActivityDetector::timings()`] to retrieve the current values.
/// Timings accumulate across all calls to [`process()`](VoiceActivityDetector::process)
/// and are **not** reset by [`reset()`](VoiceActivityDetector::reset).
///
/// # Example
///
/// ```ignore
/// let t = vad.timings();
/// for (name, dur) in &t.stages {
///     let avg_us = dur.as_secs_f64() * 1_000_000.0 / t.frames as f64;
///     println!("{name}: {avg_us:.1} µs/frame");
/// }
/// ```
#[derive(Debug, Clone, Default)]
pub struct ProcessTimings {
    /// Named timing stages in pipeline order.
    ///
    /// Each entry is `(stage_name, accumulated_duration)`. The stage names
    /// are backend-specific — for example FireRedVAD reports `"fbank"`,
    /// `"cmvn"`, and `"onnx"`, while Silero reports `"normalize"` and `"onnx"`.
    pub stages: Vec<(&'static str, Duration)>,
    /// Number of frames that produced a result (excludes buffering-only frames).
    pub frames: u64,
}

/// Describes the audio requirements of a VAD backend.
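///
/// `frame_duration_ms` is derived from the other two fields as
/// `frame_size * 1000 / sample_rate`. The values below are illustrative
/// (a real detector reports its own via
/// [`VoiceActivityDetector::capabilities()`]):
///
/// ```
/// # use wavekat_vad::VadCapabilities;
/// // 480 samples at 16 kHz is a 30 ms frame.
/// let caps = VadCapabilities { sample_rate: 16000, frame_size: 480, frame_duration_ms: 30 };
/// assert_eq!(caps.frame_duration_ms, caps.frame_size as u32 * 1000 / caps.sample_rate);
/// ```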
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VadCapabilities {
    /// Sample rate in Hz.
    pub sample_rate: u32,
    /// Required frame size in samples.
    pub frame_size: usize,
    /// Frame duration in milliseconds (derived from `sample_rate` and `frame_size`).
    pub frame_duration_ms: u32,
}

/// Common interface for voice activity detection backends.
///
/// Each backend implements this trait, allowing callers to swap
/// implementations without changing their processing logic.
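///
/// # Implementing a custom backend
///
/// A minimal sketch of a custom energy-threshold backend. The `EnergyVad`
/// type and its threshold are invented for this example; a real backend
/// should also validate the frame size and report [`VadError`] on bad input:
///
/// ```ignore
/// use wavekat_vad::{VadCapabilities, VadError, VoiceActivityDetector};
///
/// struct EnergyVad {
///     /// Mean-square energy above which a frame counts as speech.
///     threshold: f64,
/// }
///
/// impl VoiceActivityDetector for EnergyVad {
///     fn capabilities(&self) -> VadCapabilities {
///         VadCapabilities { sample_rate: 16000, frame_size: 480, frame_duration_ms: 30 }
///     }
///
///     fn process(&mut self, samples: &[i16], _sample_rate: u32) -> Result<f32, VadError> {
///         // Mean-square energy of the frame.
///         let energy = samples.iter().map(|&s| (s as f64) * (s as f64)).sum::<f64>()
///             / samples.len() as f64;
///         // Binary output, like the WebRTC backend.
///         Ok(if energy > self.threshold { 1.0 } else { 0.0 })
///     }
///
///     fn reset(&mut self) {} // stateless: nothing to clear
/// }
/// ```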
pub trait VoiceActivityDetector: Send {
    /// Returns the audio requirements of this detector.
    ///
    /// Use this to determine the expected sample rate and frame size
    /// before calling [`process`](Self::process).
    fn capabilities(&self) -> VadCapabilities;

    /// Process an audio frame and return the probability of speech.
    ///
    /// Returns a value between `0.0` (silence) and `1.0` (speech).
    /// Some backends (e.g. WebRTC) return only binary values (`0.0` or `1.0`),
    /// while others (e.g. Silero) return continuous probabilities.
    ///
    /// # Arguments
    ///
    /// * `samples` — Audio samples as 16-bit signed integers, mono channel.
    ///   Must match the `frame_size` from [`capabilities`](Self::capabilities).
    /// * `sample_rate` — Sample rate in Hz (must match the rate the detector was created with).
    ///
    /// # Errors
    ///
    /// Returns [`VadError`] if the sample rate or frame size is invalid,
    /// or if the backend encounters a processing error.
    fn process(&mut self, samples: &[i16], sample_rate: u32) -> Result<f32, VadError>;

    /// Reset the detector's internal state.
    ///
    /// Call this when starting a new audio stream or after a long pause.
    /// Does **not** reset accumulated [`timings()`](Self::timings).
    fn reset(&mut self);

    /// Return accumulated processing time breakdown.
    ///
    /// Timings accumulate across all calls to [`process()`](Self::process)
    /// and persist through [`reset()`](Self::reset). Returns default
    /// (zero) timings if the backend does not track them.
    fn timings(&self) -> ProcessTimings {
        ProcessTimings::default()
    }
}