1use crate::event::{EventSender, SessionEvent};
2use crate::media::processor::Processor;
3use crate::media::{AudioFrame, PcmBuf, Samples};
4use anyhow::Result;
5use serde::{Deserialize, Serialize};
6use serde_with::skip_serializing_none;
7use std::any::Any;
8use std::cell::RefCell;
9use tokio_util::sync::CancellationToken;
10#[cfg(feature = "vad_silero")]
11mod silero;
12#[cfg(feature = "vad_ten")]
13mod ten;
14#[cfg(test)]
15mod tests;
16#[cfg(feature = "vad_webrtc")]
17mod webrtc;
18use std::sync::Mutex;
19
20#[allow(unused)]
21pub(crate) struct SessionPool<T, F>
22where
23 F: Fn() -> Result<T> + Send + Sync + 'static,
24{
25 max_items: usize,
26 items: Mutex<Vec<T>>,
27 factory: F,
28}
29
30#[allow(unused)]
31impl<T, F> SessionPool<T, F>
32where
33 F: Fn() -> Result<T> + Send + Sync + 'static,
34{
35 pub fn new(max_items: usize, factory: F) -> Self {
36 Self {
37 max_items,
38 items: Mutex::new(Vec::with_capacity(max_items)),
39 factory,
40 }
41 }
42
43 pub fn pop_or_create(&self) -> Result<T> {
44 let mut guard = self.items.lock().unwrap();
45 match guard.pop() {
46 Some(item) => Ok(item),
47 None => (self.factory)(),
48 }
49 }
50
51 pub fn push(&self, item: T) {
52 let mut guard = self.items.lock().unwrap();
53 if guard.len() < self.max_items {
54 guard.push(item);
55 return;
56 }
57 drop(guard);
58 }
59}
60
/// Configuration for voice-activity detection.
///
/// Field names are camelCase on the wire. Because of `#[serde(default)]`, any
/// field missing from the input takes its value from `VADOption::default()`,
/// and `#[skip_serializing_none]` omits `None` fields when serializing.
#[skip_serializing_none]
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
#[serde(default)]
pub struct VADOption {
    /// Which detection engine to use.
    pub r#type: VadType,
    /// Audio sample rate handed to the engine (presumably Hz — TODO confirm).
    pub samplerate: u32,
    /// Minimum length of a speech segment for it to be reported with its
    /// samples; shorter segments are discarded. Same unit as frame timestamps
    /// (presumably milliseconds — TODO confirm).
    pub speech_padding: u64,
    /// How long non-speech must last before an ongoing speech segment is
    /// considered finished; also the retention window for buffered audio while
    /// no speech is in progress. Same unit as frame timestamps.
    pub silence_padding: u64,
    /// NOTE(review): not read by the code visible in this file; presumably an
    /// engine-specific tuning value — confirm against the engine modules.
    pub ratio: f32,
    /// NOTE(review): not read by the code visible in this file; presumably the
    /// engine's speech-probability threshold — confirm.
    pub voice_threshold: f32,
    /// NOTE(review): not read by the code visible in this file.
    pub max_buffer_duration_secs: u64,
    /// When set, a sample-less silence event is emitted each time non-speech
    /// has lasted at least this long. Same unit as frame timestamps.
    pub silence_timeout: Option<u64>,
    /// NOTE(review): not read by the code visible in this file; presumably the
    /// address of a remote VAD service — confirm.
    pub endpoint: Option<String>,
    /// NOTE(review): not read by the code visible in this file; presumably a
    /// credential for `endpoint` — confirm.
    pub secret_key: Option<String>,
    /// NOTE(review): not read by the code visible in this file; presumably a
    /// credential for `endpoint` — confirm.
    pub secret_id: Option<String>,
}
81
82impl Default for VADOption {
83 fn default() -> Self {
84 Self {
85 #[cfg(feature = "vad_webrtc")]
86 r#type: VadType::WebRTC,
87 #[cfg(all(
88 not(feature = "vad_webrtc"),
89 not(feature = "vad_ten"),
90 feature = "vad_silero"
91 ))]
92 r#type: VadType::Silero,
93 #[cfg(all(
94 not(feature = "vad_webrtc"),
95 not(feature = "vad_silero"),
96 feature = "vad_ten"
97 ))]
98 r#type: VadType::Ten,
99 #[cfg(all(
100 not(feature = "vad_webrtc"),
101 not(feature = "vad_silero"),
102 not(feature = "vad_ten"),
103 ))]
104 r#type: VadType::Other("nop".to_string()),
105 samplerate: 16000,
106 speech_padding: 250, silence_padding: 100, ratio: 0.5,
110 voice_threshold: 0.5,
111 max_buffer_duration_secs: 50,
112 silence_timeout: None,
113 endpoint: None,
114 secret_key: None,
115 secret_id: None,
116 }
117 }
118}
119
120#[derive(Clone, Debug, Serialize, Eq, Hash, PartialEq)]
121#[serde(rename_all = "lowercase")]
122pub enum VadType {
123 #[cfg(feature = "vad_webrtc")]
124 WebRTC,
125 #[cfg(feature = "vad_silero")]
126 Silero,
127 #[cfg(feature = "vad_ten")]
128 Ten,
129 Other(String),
130}
131
132impl<'de> Deserialize<'de> for VadType {
133 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
134 where
135 D: serde::Deserializer<'de>,
136 {
137 let value = String::deserialize(deserializer)?;
138 match value.as_str() {
139 #[cfg(feature = "vad_webrtc")]
140 "webrtc" => Ok(VadType::WebRTC),
141 #[cfg(feature = "vad_silero")]
142 "silero" => Ok(VadType::Silero),
143 #[cfg(feature = "vad_ten")]
144 "ten" => Ok(VadType::Ten),
145 _ => Ok(VadType::Other(value)),
146 }
147 }
148}
149
150impl std::fmt::Display for VadType {
151 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152 match self {
153 #[cfg(feature = "vad_webrtc")]
154 VadType::WebRTC => write!(f, "webrtc"),
155 #[cfg(feature = "vad_silero")]
156 VadType::Silero => write!(f, "silero"),
157 #[cfg(feature = "vad_ten")]
158 VadType::Ten => write!(f, "ten"),
159 VadType::Other(provider) => write!(f, "{}", provider),
160 }
161 }
162}
163
164impl TryFrom<&String> for VadType {
165 type Error = String;
166
167 fn try_from(value: &String) -> std::result::Result<Self, Self::Error> {
168 match value.as_str() {
169 #[cfg(feature = "vad_webrtc")]
170 "webrtc" => Ok(VadType::WebRTC),
171 #[cfg(feature = "vad_silero")]
172 "silero" => Ok(VadType::Silero),
173 #[cfg(feature = "vad_ten")]
174 "ten" => Ok(VadType::Ten),
175 other => Ok(VadType::Other(other.to_string())),
176 }
177 }
178}
/// One buffered audio frame: its PCM samples and the timestamp the VAD engine
/// reported for it.
struct SpeechBuf {
    /// PCM samples copied from the frame.
    samples: PcmBuf,
    /// Timestamp returned by the engine for this frame.
    timestamp: u64,
}
183
/// Mutable state of a [`VadProcessor`]: the engine plus the speech/silence
/// state machine that turns per-frame decisions into session events.
struct VadProcessorInner {
    /// Engine that classifies each frame as speech or non-speech.
    vad: Box<dyn VadEngine>,
    /// Channel on which `Speaking` / `Silence` events are sent.
    event_sender: EventSender,
    /// Thresholds used by the state machine.
    option: VADOption,
    /// Frames buffered while speech is detected or in progress; pruned by age
    /// in `process_frame` and cleared when a speech segment is reported.
    window_bufs: Vec<SpeechBuf>,
    /// `true` from speech onset until the end of that segment is detected.
    triggered: bool,
    /// Timestamp of the frame that started the current speech segment.
    current_speech_start: Option<u64>,
    /// Timestamp at which the current run of non-speech frames started, or at
    /// which the last silence event was emitted; cleared while speech
    /// continues.
    temp_end: Option<u64>,
}
/// Audio processor that runs a [`VadEngine`] over PCM frames and emits
/// speaking / silence events.
///
/// The state sits behind a `RefCell` so it can be mutated through the `&self`
/// receiver of `Processor::process_frame`.
pub struct VadProcessor {
    inner: RefCell<VadProcessorInner>,
}
// SAFETY: NOTE(review): `RefCell` is not `Sync`, so these impls are only sound
// if no two threads ever touch the same `VadProcessor` at the same time.
// Nothing in this file enforces that — confirm the callers guarantee it, or
// replace the `RefCell` with a `Mutex` and delete both `unsafe impl`s.
unsafe impl Send for VadProcessor {}
unsafe impl Sync for VadProcessor {}
198
/// A voice-activity-detection engine.
pub trait VadEngine: Send + Sync + Any {
    /// Classifies one audio frame.
    ///
    /// Returns `Some((is_speaking, timestamp))` when the engine has a decision
    /// for the given timestamp, or `None` when it has none; the processor in
    /// this file ignores `None`. The frame is passed mutably, so an
    /// implementation may modify it.
    fn process(&mut self, frame: &mut AudioFrame) -> Option<(bool, u64)>;
}
202
203impl VadProcessorInner {
204 pub fn process_frame(&mut self, frame: &mut AudioFrame) -> Result<()> {
205 let samples = match &frame.samples {
206 Samples::PCM { samples } => samples,
207 _ => return Ok(()),
208 };
209
210 let samples = samples.to_owned();
211 let result = self.vad.process(frame);
212 if let Some((is_speaking, timestamp)) = result {
213 if is_speaking || self.triggered {
214 let current_buf = SpeechBuf { samples, timestamp };
215 self.window_bufs.push(current_buf);
216 }
217 self.process_vad_logic(is_speaking, timestamp, &frame.track_id)?;
218
219 if self.window_bufs.len() > 1000 || !self.triggered {
221 let cutoff = if self.triggered {
222 timestamp.saturating_sub(5000)
223 } else {
224 timestamp.saturating_sub(self.option.silence_padding)
225 };
226 self.window_bufs.retain(|buf| buf.timestamp > cutoff);
227 }
228 }
229
230 Ok(())
231 }
232
233 fn process_vad_logic(
234 &mut self,
235 is_speaking: bool,
236 timestamp: u64,
237 track_id: &str,
238 ) -> Result<()> {
239 if is_speaking && !self.triggered {
240 self.triggered = true;
241 self.current_speech_start = Some(timestamp);
242 let event = SessionEvent::Speaking {
243 track_id: track_id.to_string(),
244 timestamp: crate::media::get_timestamp(),
245 start_time: timestamp,
246 };
247 self.event_sender.send(event).ok();
248 } else if !is_speaking {
249 if self.temp_end.is_none() {
250 self.temp_end = Some(timestamp);
251 }
252
253 if let Some(temp_end) = self.temp_end {
254 let silence_duration = timestamp.saturating_sub(temp_end);
256
257 if self.triggered && silence_duration >= self.option.silence_padding {
259 if let Some(start_time) = self.current_speech_start {
260 let duration = temp_end.saturating_sub(start_time);
262 if duration >= self.option.speech_padding {
263 let samples_vec = self
264 .window_bufs
265 .iter()
266 .filter(|buf| {
267 buf.timestamp >= start_time && buf.timestamp <= temp_end
268 })
269 .flat_map(|buf| buf.samples.iter())
270 .cloned()
271 .collect();
272 self.window_bufs.clear();
273
274 let event = SessionEvent::Silence {
275 track_id: track_id.to_string(),
276 timestamp: crate::media::get_timestamp(),
277 start_time,
278 duration,
279 samples: Some(samples_vec),
280 };
281 self.event_sender.send(event).ok();
282 }
283 }
284 self.triggered = false;
285 self.current_speech_start = None;
286 self.temp_end = Some(timestamp); }
288
289 if let Some(timeout) = self.option.silence_timeout {
291 let timeout_duration = timestamp.saturating_sub(temp_end);
293
294 if timeout_duration >= timeout {
295 let event = SessionEvent::Silence {
296 track_id: track_id.to_string(),
297 timestamp: crate::media::get_timestamp(),
298 start_time: temp_end,
299 duration: timeout_duration,
300 samples: None,
301 };
302 self.event_sender.send(event).ok();
303 self.temp_end = Some(timestamp);
304 }
305 }
306 }
307 } else if is_speaking && self.temp_end.is_some() {
308 self.temp_end = None;
309 }
310
311 Ok(())
312 }
313}
314
315impl VadProcessor {
316 #[cfg(feature = "vad_webrtc")]
317 pub fn create_webrtc(
318 _token: CancellationToken,
319 event_sender: EventSender,
320 option: VADOption,
321 ) -> Result<Box<dyn Processor>> {
322 let vad: Box<dyn VadEngine> = match option.r#type {
323 VadType::WebRTC => Box::new(webrtc::WebRtcVad::new(option.samplerate)?),
324 _ => Box::new(NopVad::new()?),
325 };
326 Ok(Box::new(VadProcessor::new(vad, event_sender, option)?))
327 }
328 #[cfg(feature = "vad_silero")]
329 pub fn create_silero(
330 _token: CancellationToken,
331 event_sender: EventSender,
332 option: VADOption,
333 ) -> Result<Box<dyn Processor>> {
334 let vad: Box<dyn VadEngine> = match option.r#type {
335 VadType::Silero => Box::new(silero::SileroVad::new(option.clone())?),
336 _ => Box::new(NopVad::new()?),
337 };
338 Ok(Box::new(VadProcessor::new(vad, event_sender, option)?))
339 }
340 #[cfg(feature = "vad_ten")]
341 pub fn create_ten(
342 _token: CancellationToken,
343 event_sender: EventSender,
344 option: VADOption,
345 ) -> Result<Box<dyn Processor>> {
346 let vad: Box<dyn VadEngine> = match option.r#type {
347 VadType::Ten => Box::new(ten::TenVad::new(option.clone())?),
348 _ => Box::new(NopVad::new()?),
349 };
350 Ok(Box::new(VadProcessor::new(vad, event_sender, option)?))
351 }
352
353 pub fn create_nop(
354 _token: CancellationToken,
355 event_sender: EventSender,
356 option: VADOption,
357 ) -> Result<Box<dyn Processor>> {
358 let vad: Box<dyn VadEngine> = match option.r#type {
359 _ => Box::new(NopVad::new()?),
360 };
361 Ok(Box::new(VadProcessor::new(vad, event_sender, option)?))
362 }
363
364 pub fn new(
365 engine: Box<dyn VadEngine>,
366 event_sender: EventSender,
367 option: VADOption,
368 ) -> Result<Self> {
369 let inner = VadProcessorInner {
370 vad: engine,
371 event_sender,
372 option,
373 window_bufs: Vec::new(),
374 triggered: false,
375 current_speech_start: None,
376 temp_end: None,
377 };
378 Ok(Self {
379 inner: RefCell::new(inner),
380 })
381 }
382}
383
384impl Processor for VadProcessor {
385 fn process_frame(&self, frame: &mut AudioFrame) -> Result<()> {
386 self.inner.borrow_mut().process_frame(frame)
387 }
388}
389
390struct NopVad {}
391
392impl NopVad {
393 pub fn new() -> Result<Self> {
394 Ok(Self {})
395 }
396}
397
398impl VadEngine for NopVad {
399 fn process(&mut self, frame: &mut AudioFrame) -> Option<(bool, u64)> {
400 let samples = match &frame.samples {
401 Samples::PCM { samples } => samples,
402 _ => return Some((false, frame.timestamp)),
403 };
404 let has_speech = samples.iter().any(|&x| x != 0);
406 Some((has_speech, frame.timestamp))
407 }
408}