voice_engine/media/vad/mod.rs

1use crate::event::{EventSender, SessionEvent};
2use crate::media::processor::Processor;
3use crate::media::{AudioFrame, PcmBuf, Samples};
4use anyhow::Result;
5use serde::{Deserialize, Serialize};
6use serde_with::skip_serializing_none;
7use std::any::Any;
8use std::cell::RefCell;
9use tokio_util::sync::CancellationToken;
10#[cfg(feature = "vad_silero")]
11mod silero;
12#[cfg(feature = "vad_ten")]
13mod ten;
14
15#[cfg(test)]
16mod tests;
17#[cfg(feature = "vad_webrtc")]
18mod webrtc;
19
20#[skip_serializing_none]
21#[derive(Clone, Debug, Deserialize, Serialize)]
22#[serde(rename_all = "camelCase")]
23#[serde(default)]
24pub struct VADOption {
25    pub r#type: VadType,
26    pub samplerate: u32,
27    /// Padding before speech detection (in ms)
28    pub speech_padding: u64,
29    /// Padding after silence detection (in ms)
30    pub silence_padding: u64,
31    pub ratio: f32,
32    pub voice_threshold: f32,
33    pub max_buffer_duration_secs: u64,
34    /// Timeout duration for silence (in ms), None means disable this feature
35    pub silence_timeout: Option<u64>,
36    pub endpoint: Option<String>,
37    pub secret_key: Option<String>,
38    pub secret_id: Option<String>,
39}
40
41impl Default for VADOption {
42    fn default() -> Self {
43        Self {
44            #[cfg(feature = "vad_webrtc")]
45            r#type: VadType::WebRTC,
46            #[cfg(all(
47                not(feature = "vad_webrtc"),
48                not(feature = "vad_ten"),
49                feature = "vad_silero"
50            ))]
51            r#type: VadType::Silero,
52            #[cfg(all(
53                not(feature = "vad_webrtc"),
54                not(feature = "vad_silero"),
55                feature = "vad_ten"
56            ))]
57            r#type: VadType::Ten,
58            #[cfg(all(
59                not(feature = "vad_webrtc"),
60                not(feature = "vad_silero"),
61                not(feature = "vad_ten"),
62            ))]
63            r#type: VadType::Other("nop".to_string()),
64            samplerate: 16000,
65            // Python defaults: min_speech_duration_ms=250, min_silence_duration_ms=100, speech_pad_ms=30
66            speech_padding: 250,  // min_speech_duration_ms
67            silence_padding: 100, // min_silence_duration_ms
68            ratio: 0.5,
69            voice_threshold: 0.5,
70            max_buffer_duration_secs: 50,
71            silence_timeout: None,
72            endpoint: None,
73            secret_key: None,
74            secret_id: None,
75        }
76    }
77}
78
79#[derive(Clone, Debug, Serialize, Eq, Hash, PartialEq)]
80#[serde(rename_all = "lowercase")]
81pub enum VadType {
82    #[cfg(feature = "vad_webrtc")]
83    WebRTC,
84    #[cfg(feature = "vad_silero")]
85    Silero,
86    #[cfg(feature = "vad_ten")]
87    Ten,
88    Other(String),
89}
90
91impl<'de> Deserialize<'de> for VadType {
92    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
93    where
94        D: serde::Deserializer<'de>,
95    {
96        let value = String::deserialize(deserializer)?;
97        match value.as_str() {
98            #[cfg(feature = "vad_webrtc")]
99            "webrtc" => Ok(VadType::WebRTC),
100            #[cfg(feature = "vad_silero")]
101            "silero" => Ok(VadType::Silero),
102            #[cfg(feature = "vad_ten")]
103            "ten" => Ok(VadType::Ten),
104            _ => Ok(VadType::Other(value)),
105        }
106    }
107}
108
109impl std::fmt::Display for VadType {
110    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111        match self {
112            #[cfg(feature = "vad_webrtc")]
113            VadType::WebRTC => write!(f, "webrtc"),
114            #[cfg(feature = "vad_silero")]
115            VadType::Silero => write!(f, "silero"),
116            #[cfg(feature = "vad_ten")]
117            VadType::Ten => write!(f, "ten"),
118            VadType::Other(provider) => write!(f, "{}", provider),
119        }
120    }
121}
122
123impl TryFrom<&String> for VadType {
124    type Error = String;
125
126    fn try_from(value: &String) -> std::result::Result<Self, Self::Error> {
127        match value.as_str() {
128            #[cfg(feature = "vad_webrtc")]
129            "webrtc" => Ok(VadType::WebRTC),
130            #[cfg(feature = "vad_silero")]
131            "silero" => Ok(VadType::Silero),
132            #[cfg(feature = "vad_ten")]
133            "ten" => Ok(VadType::Ten),
134            other => Ok(VadType::Other(other.to_string())),
135        }
136    }
137}
138struct SpeechBuf {
139    samples: PcmBuf,
140    timestamp: u64,
141}
142
143struct VadProcessorInner {
144    vad: Box<dyn VadEngine>,
145    event_sender: EventSender,
146    option: VADOption,
147    window_bufs: Vec<SpeechBuf>,
148    triggered: bool,
149    current_speech_start: Option<u64>,
150    temp_end: Option<u64>,
151}
152pub struct VadProcessor {
153    inner: RefCell<VadProcessorInner>,
154}
155unsafe impl Send for VadProcessor {}
156unsafe impl Sync for VadProcessor {}
157
158pub trait VadEngine: Send + Sync + Any {
159    fn process(&mut self, frame: &mut AudioFrame) -> Option<(bool, u64)>;
160}
161
162impl VadProcessorInner {
163    pub fn process_frame(&mut self, frame: &mut AudioFrame) -> Result<()> {
164        let samples = match &frame.samples {
165            Samples::PCM { samples } => samples,
166            _ => return Ok(()),
167        };
168
169        let samples = samples.to_owned();
170        let result = self.vad.process(frame);
171        if let Some((is_speaking, timestamp)) = result {
172            if is_speaking || self.triggered {
173                let current_buf = SpeechBuf { samples, timestamp };
174                self.window_bufs.push(current_buf);
175            }
176            self.process_vad_logic(is_speaking, timestamp, &frame.track_id)?;
177
178            // Clean up old buffers periodically
179            if self.window_bufs.len() > 1000 || !self.triggered {
180                let cutoff = if self.triggered {
181                    timestamp.saturating_sub(5000)
182                } else {
183                    timestamp.saturating_sub(self.option.silence_padding)
184                };
185                self.window_bufs.retain(|buf| buf.timestamp > cutoff);
186            }
187        }
188
189        Ok(())
190    }
191
192    fn process_vad_logic(
193        &mut self,
194        is_speaking: bool,
195        timestamp: u64,
196        track_id: &str,
197    ) -> Result<()> {
198        if is_speaking && !self.triggered {
199            self.triggered = true;
200            self.current_speech_start = Some(timestamp);
201            let event = SessionEvent::Speaking {
202                track_id: track_id.to_string(),
203                timestamp: crate::media::get_timestamp(),
204                start_time: timestamp,
205            };
206            self.event_sender.send(event).ok();
207        } else if !is_speaking {
208            if self.temp_end.is_none() {
209                self.temp_end = Some(timestamp);
210            }
211
212            if let Some(temp_end) = self.temp_end {
213                // Use saturating_sub to handle timestamp wrapping or out-of-order frames
214                let silence_duration = timestamp.saturating_sub(temp_end);
215
216                // Process regular silence detection for speech segments
217                if self.triggered && silence_duration >= self.option.silence_padding {
218                    if let Some(start_time) = self.current_speech_start {
219                        // Use safe duration calculation
220                        let duration = temp_end.saturating_sub(start_time);
221                        if duration >= self.option.speech_padding {
222                            let samples_vec = self
223                                .window_bufs
224                                .iter()
225                                .filter(|buf| {
226                                    buf.timestamp >= start_time && buf.timestamp <= temp_end
227                                })
228                                .flat_map(|buf| buf.samples.iter())
229                                .cloned()
230                                .collect();
231                            self.window_bufs.clear();
232
233                            let event = SessionEvent::Silence {
234                                track_id: track_id.to_string(),
235                                timestamp: crate::media::get_timestamp(),
236                                start_time,
237                                duration,
238                                samples: Some(samples_vec),
239                            };
240                            self.event_sender.send(event).ok();
241                        }
242                    }
243                    self.triggered = false;
244                    self.current_speech_start = None;
245                    self.temp_end = Some(timestamp); // Update temp_end for silence timeout tracking
246                }
247
248                // Process silence timeout if configured
249                if let Some(timeout) = self.option.silence_timeout {
250                    // Use same safe calculation for silence timeout
251                    let timeout_duration = timestamp.saturating_sub(temp_end);
252
253                    if timeout_duration >= timeout {
254                        let event = SessionEvent::Silence {
255                            track_id: track_id.to_string(),
256                            timestamp: crate::media::get_timestamp(),
257                            start_time: temp_end,
258                            duration: timeout_duration,
259                            samples: None,
260                        };
261                        self.event_sender.send(event).ok();
262                        self.temp_end = Some(timestamp);
263                    }
264                }
265            }
266        } else if is_speaking && self.temp_end.is_some() {
267            self.temp_end = None;
268        }
269
270        Ok(())
271    }
272}
273
274impl VadProcessor {
275    #[cfg(feature = "vad_webrtc")]
276    pub fn create_webrtc(
277        _token: CancellationToken,
278        event_sender: EventSender,
279        option: VADOption,
280    ) -> Result<Box<dyn Processor>> {
281        let vad: Box<dyn VadEngine> = match option.r#type {
282            VadType::WebRTC => Box::new(webrtc::WebRtcVad::new(option.samplerate)?),
283            _ => Box::new(NopVad::new()?),
284        };
285        Ok(Box::new(VadProcessor::new(vad, event_sender, option)?))
286    }
287    #[cfg(feature = "vad_silero")]
288    pub fn create_silero(
289        _token: CancellationToken,
290        event_sender: EventSender,
291        option: VADOption,
292    ) -> Result<Box<dyn Processor>> {
293        let vad: Box<dyn VadEngine> = match option.r#type {
294            VadType::Silero => Box::new(silero::SileroVad::new(option.clone())?),
295            _ => Box::new(NopVad::new()?),
296        };
297        Ok(Box::new(VadProcessor::new(vad, event_sender, option)?))
298    }
299    #[cfg(feature = "vad_ten")]
300    pub fn create_ten(
301        _token: CancellationToken,
302        event_sender: EventSender,
303        option: VADOption,
304    ) -> Result<Box<dyn Processor>> {
305        let vad: Box<dyn VadEngine> = match option.r#type {
306            VadType::Ten => Box::new(ten::TenVad::new(option.clone())?),
307            _ => Box::new(NopVad::new()?),
308        };
309        Ok(Box::new(VadProcessor::new(vad, event_sender, option)?))
310    }
311
312    pub fn create_nop(
313        _token: CancellationToken,
314        event_sender: EventSender,
315        option: VADOption,
316    ) -> Result<Box<dyn Processor>> {
317        let vad: Box<dyn VadEngine> = match option.r#type {
318            _ => Box::new(NopVad::new()?),
319        };
320        Ok(Box::new(VadProcessor::new(vad, event_sender, option)?))
321    }
322
323    pub fn new(
324        engine: Box<dyn VadEngine>,
325        event_sender: EventSender,
326        option: VADOption,
327    ) -> Result<Self> {
328        let inner = VadProcessorInner {
329            vad: engine,
330            event_sender,
331            option,
332            window_bufs: Vec::new(),
333            triggered: false,
334            current_speech_start: None,
335            temp_end: None,
336        };
337        Ok(Self {
338            inner: RefCell::new(inner),
339        })
340    }
341}
342
343impl Processor for VadProcessor {
344    fn process_frame(&self, frame: &mut AudioFrame) -> Result<()> {
345        self.inner.borrow_mut().process_frame(frame)
346    }
347}
348
349struct NopVad {}
350
351impl NopVad {
352    pub fn new() -> Result<Self> {
353        Ok(Self {})
354    }
355}
356
357impl VadEngine for NopVad {
358    fn process(&mut self, frame: &mut AudioFrame) -> Option<(bool, u64)> {
359        let samples = match &frame.samples {
360            Samples::PCM { samples } => samples,
361            _ => return Some((false, frame.timestamp)),
362        };
363        // Check if there are any non-zero samples
364        let has_speech = samples.iter().any(|&x| x != 0);
365        Some((has_speech, frame.timestamp))
366    }
367}