voice_engine/media/vad/mod.rs

1use crate::event::{EventSender, SessionEvent};
2use crate::media::processor::Processor;
3use crate::media::{AudioFrame, PcmBuf, Samples};
4use anyhow::Result;
5use serde::{Deserialize, Serialize};
6use serde_with::skip_serializing_none;
7use std::any::Any;
8use std::cell::RefCell;
9use tokio_util::sync::CancellationToken;
10#[cfg(feature = "vad_silero")]
11mod silero;
12#[cfg(feature = "vad_ten")]
13mod ten;
14#[cfg(test)]
15mod tests;
16#[cfg(feature = "vad_webrtc")]
17mod webrtc;
18use std::sync::Mutex;
19
20#[allow(unused)]
21pub(crate) struct SessionPool<T, F>
22where
23    F: Fn() -> Result<T> + Send + Sync + 'static,
24{
25    max_items: usize,
26    items: Mutex<Vec<T>>,
27    factory: F,
28}
29
30#[allow(unused)]
31impl<T, F> SessionPool<T, F>
32where
33    F: Fn() -> Result<T> + Send + Sync + 'static,
34{
35    pub fn new(max_items: usize, factory: F) -> Self {
36        Self {
37            max_items,
38            items: Mutex::new(Vec::with_capacity(max_items)),
39            factory,
40        }
41    }
42
43    pub fn pop_or_create(&self) -> Result<T> {
44        let mut guard = self.items.lock().unwrap();
45        match guard.pop() {
46            Some(item) => Ok(item),
47            None => (self.factory)(),
48        }
49    }
50
51    pub fn push(&self, item: T) {
52        let mut guard = self.items.lock().unwrap();
53        if guard.len() < self.max_items {
54            guard.push(item);
55            return;
56        }
57        drop(guard);
58    }
59}
60
61#[skip_serializing_none]
62#[derive(Clone, Debug, Deserialize, Serialize)]
63#[serde(rename_all = "camelCase")]
64#[serde(default)]
65pub struct VADOption {
66    pub r#type: VadType,
67    pub samplerate: u32,
68    /// Padding before speech detection (in ms)
69    pub speech_padding: u64,
70    /// Padding after silence detection (in ms)
71    pub silence_padding: u64,
72    pub ratio: f32,
73    pub voice_threshold: f32,
74    pub max_buffer_duration_secs: u64,
75    /// Timeout duration for silence (in ms), None means disable this feature
76    pub silence_timeout: Option<u64>,
77    pub endpoint: Option<String>,
78    pub secret_key: Option<String>,
79    pub secret_id: Option<String>,
80}
81
82impl Default for VADOption {
83    fn default() -> Self {
84        Self {
85            #[cfg(feature = "vad_webrtc")]
86            r#type: VadType::WebRTC,
87            #[cfg(all(
88                not(feature = "vad_webrtc"),
89                not(feature = "vad_ten"),
90                feature = "vad_silero"
91            ))]
92            r#type: VadType::Silero,
93            #[cfg(all(
94                not(feature = "vad_webrtc"),
95                not(feature = "vad_silero"),
96                feature = "vad_ten"
97            ))]
98            r#type: VadType::Ten,
99            #[cfg(all(
100                not(feature = "vad_webrtc"),
101                not(feature = "vad_silero"),
102                not(feature = "vad_ten"),
103            ))]
104            r#type: VadType::Other("nop".to_string()),
105            samplerate: 16000,
106            // Python defaults: min_speech_duration_ms=250, min_silence_duration_ms=100, speech_pad_ms=30
107            speech_padding: 250,  // min_speech_duration_ms
108            silence_padding: 100, // min_silence_duration_ms
109            ratio: 0.5,
110            voice_threshold: 0.5,
111            max_buffer_duration_secs: 50,
112            silence_timeout: None,
113            endpoint: None,
114            secret_key: None,
115            secret_id: None,
116        }
117    }
118}
119
120#[derive(Clone, Debug, Serialize, Eq, Hash, PartialEq)]
121#[serde(rename_all = "lowercase")]
122pub enum VadType {
123    #[cfg(feature = "vad_webrtc")]
124    WebRTC,
125    #[cfg(feature = "vad_silero")]
126    Silero,
127    #[cfg(feature = "vad_ten")]
128    Ten,
129    Other(String),
130}
131
132impl<'de> Deserialize<'de> for VadType {
133    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
134    where
135        D: serde::Deserializer<'de>,
136    {
137        let value = String::deserialize(deserializer)?;
138        match value.as_str() {
139            #[cfg(feature = "vad_webrtc")]
140            "webrtc" => Ok(VadType::WebRTC),
141            #[cfg(feature = "vad_silero")]
142            "silero" => Ok(VadType::Silero),
143            #[cfg(feature = "vad_ten")]
144            "ten" => Ok(VadType::Ten),
145            _ => Ok(VadType::Other(value)),
146        }
147    }
148}
149
150impl std::fmt::Display for VadType {
151    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152        match self {
153            #[cfg(feature = "vad_webrtc")]
154            VadType::WebRTC => write!(f, "webrtc"),
155            #[cfg(feature = "vad_silero")]
156            VadType::Silero => write!(f, "silero"),
157            #[cfg(feature = "vad_ten")]
158            VadType::Ten => write!(f, "ten"),
159            VadType::Other(provider) => write!(f, "{}", provider),
160        }
161    }
162}
163
164impl TryFrom<&String> for VadType {
165    type Error = String;
166
167    fn try_from(value: &String) -> std::result::Result<Self, Self::Error> {
168        match value.as_str() {
169            #[cfg(feature = "vad_webrtc")]
170            "webrtc" => Ok(VadType::WebRTC),
171            #[cfg(feature = "vad_silero")]
172            "silero" => Ok(VadType::Silero),
173            #[cfg(feature = "vad_ten")]
174            "ten" => Ok(VadType::Ten),
175            other => Ok(VadType::Other(other.to_string())),
176        }
177    }
178}
179struct SpeechBuf {
180    samples: PcmBuf,
181    timestamp: u64,
182}
183
184struct VadProcessorInner {
185    vad: Box<dyn VadEngine>,
186    event_sender: EventSender,
187    option: VADOption,
188    window_bufs: Vec<SpeechBuf>,
189    triggered: bool,
190    current_speech_start: Option<u64>,
191    temp_end: Option<u64>,
192}
193pub struct VadProcessor {
194    inner: RefCell<VadProcessorInner>,
195}
196unsafe impl Send for VadProcessor {}
197unsafe impl Sync for VadProcessor {}
198
199pub trait VadEngine: Send + Sync + Any {
200    fn process(&mut self, frame: &mut AudioFrame) -> Option<(bool, u64)>;
201}
202
203impl VadProcessorInner {
204    pub fn process_frame(&mut self, frame: &mut AudioFrame) -> Result<()> {
205        let samples = match &frame.samples {
206            Samples::PCM { samples } => samples,
207            _ => return Ok(()),
208        };
209
210        let samples = samples.to_owned();
211        let result = self.vad.process(frame);
212        if let Some((is_speaking, timestamp)) = result {
213            if is_speaking || self.triggered {
214                let current_buf = SpeechBuf { samples, timestamp };
215                self.window_bufs.push(current_buf);
216            }
217            self.process_vad_logic(is_speaking, timestamp, &frame.track_id)?;
218
219            // Clean up old buffers periodically
220            if self.window_bufs.len() > 1000 || !self.triggered {
221                let cutoff = if self.triggered {
222                    timestamp.saturating_sub(5000)
223                } else {
224                    timestamp.saturating_sub(self.option.silence_padding)
225                };
226                self.window_bufs.retain(|buf| buf.timestamp > cutoff);
227            }
228        }
229
230        Ok(())
231    }
232
233    fn process_vad_logic(
234        &mut self,
235        is_speaking: bool,
236        timestamp: u64,
237        track_id: &str,
238    ) -> Result<()> {
239        if is_speaking && !self.triggered {
240            self.triggered = true;
241            self.current_speech_start = Some(timestamp);
242            let event = SessionEvent::Speaking {
243                track_id: track_id.to_string(),
244                timestamp: crate::media::get_timestamp(),
245                start_time: timestamp,
246            };
247            self.event_sender.send(event).ok();
248        } else if !is_speaking {
249            if self.temp_end.is_none() {
250                self.temp_end = Some(timestamp);
251            }
252
253            if let Some(temp_end) = self.temp_end {
254                // Use saturating_sub to handle timestamp wrapping or out-of-order frames
255                let silence_duration = timestamp.saturating_sub(temp_end);
256
257                // Process regular silence detection for speech segments
258                if self.triggered && silence_duration >= self.option.silence_padding {
259                    if let Some(start_time) = self.current_speech_start {
260                        // Use safe duration calculation
261                        let duration = temp_end.saturating_sub(start_time);
262                        if duration >= self.option.speech_padding {
263                            let samples_vec = self
264                                .window_bufs
265                                .iter()
266                                .filter(|buf| {
267                                    buf.timestamp >= start_time && buf.timestamp <= temp_end
268                                })
269                                .flat_map(|buf| buf.samples.iter())
270                                .cloned()
271                                .collect();
272                            self.window_bufs.clear();
273
274                            let event = SessionEvent::Silence {
275                                track_id: track_id.to_string(),
276                                timestamp: crate::media::get_timestamp(),
277                                start_time,
278                                duration,
279                                samples: Some(samples_vec),
280                            };
281                            self.event_sender.send(event).ok();
282                        }
283                    }
284                    self.triggered = false;
285                    self.current_speech_start = None;
286                    self.temp_end = Some(timestamp); // Update temp_end for silence timeout tracking
287                }
288
289                // Process silence timeout if configured
290                if let Some(timeout) = self.option.silence_timeout {
291                    // Use same safe calculation for silence timeout
292                    let timeout_duration = timestamp.saturating_sub(temp_end);
293
294                    if timeout_duration >= timeout {
295                        let event = SessionEvent::Silence {
296                            track_id: track_id.to_string(),
297                            timestamp: crate::media::get_timestamp(),
298                            start_time: temp_end,
299                            duration: timeout_duration,
300                            samples: None,
301                        };
302                        self.event_sender.send(event).ok();
303                        self.temp_end = Some(timestamp);
304                    }
305                }
306            }
307        } else if is_speaking && self.temp_end.is_some() {
308            self.temp_end = None;
309        }
310
311        Ok(())
312    }
313}
314
315impl VadProcessor {
316    #[cfg(feature = "vad_webrtc")]
317    pub fn create_webrtc(
318        _token: CancellationToken,
319        event_sender: EventSender,
320        option: VADOption,
321    ) -> Result<Box<dyn Processor>> {
322        let vad: Box<dyn VadEngine> = match option.r#type {
323            VadType::WebRTC => Box::new(webrtc::WebRtcVad::new(option.samplerate)?),
324            _ => Box::new(NopVad::new()?),
325        };
326        Ok(Box::new(VadProcessor::new(vad, event_sender, option)?))
327    }
328    #[cfg(feature = "vad_silero")]
329    pub fn create_silero(
330        _token: CancellationToken,
331        event_sender: EventSender,
332        option: VADOption,
333    ) -> Result<Box<dyn Processor>> {
334        let vad: Box<dyn VadEngine> = match option.r#type {
335            VadType::Silero => Box::new(silero::SileroVad::new(option.clone())?),
336            _ => Box::new(NopVad::new()?),
337        };
338        Ok(Box::new(VadProcessor::new(vad, event_sender, option)?))
339    }
340    #[cfg(feature = "vad_ten")]
341    pub fn create_ten(
342        _token: CancellationToken,
343        event_sender: EventSender,
344        option: VADOption,
345    ) -> Result<Box<dyn Processor>> {
346        let vad: Box<dyn VadEngine> = match option.r#type {
347            VadType::Ten => Box::new(ten::TenVad::new(option.clone())?),
348            _ => Box::new(NopVad::new()?),
349        };
350        Ok(Box::new(VadProcessor::new(vad, event_sender, option)?))
351    }
352
353    pub fn create_nop(
354        _token: CancellationToken,
355        event_sender: EventSender,
356        option: VADOption,
357    ) -> Result<Box<dyn Processor>> {
358        let vad: Box<dyn VadEngine> = match option.r#type {
359            _ => Box::new(NopVad::new()?),
360        };
361        Ok(Box::new(VadProcessor::new(vad, event_sender, option)?))
362    }
363
364    pub fn new(
365        engine: Box<dyn VadEngine>,
366        event_sender: EventSender,
367        option: VADOption,
368    ) -> Result<Self> {
369        let inner = VadProcessorInner {
370            vad: engine,
371            event_sender,
372            option,
373            window_bufs: Vec::new(),
374            triggered: false,
375            current_speech_start: None,
376            temp_end: None,
377        };
378        Ok(Self {
379            inner: RefCell::new(inner),
380        })
381    }
382}
383
384impl Processor for VadProcessor {
385    fn process_frame(&self, frame: &mut AudioFrame) -> Result<()> {
386        self.inner.borrow_mut().process_frame(frame)
387    }
388}
389
390struct NopVad {}
391
392impl NopVad {
393    pub fn new() -> Result<Self> {
394        Ok(Self {})
395    }
396}
397
398impl VadEngine for NopVad {
399    fn process(&mut self, frame: &mut AudioFrame) -> Option<(bool, u64)> {
400        let samples = match &frame.samples {
401            Samples::PCM { samples } => samples,
402            _ => return Some((false, frame.timestamp)),
403        };
404        // Check if there are any non-zero samples
405        let has_speech = samples.iter().any(|&x| x != 0);
406        Some((has_speech, frame.timestamp))
407    }
408}