1use anyhow::Result;
6use std::collections::VecDeque;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::RwLock;
10
/// Voice-activity detector ("Marine VAD").
///
/// All state lives behind `Arc<RwLock<...>>` so one instance can be shared
/// across async tasks: the detection state machine, an audio level monitor,
/// the latest on/off decision, and an optional transition callback.
pub struct MarineVAD {
    // Detection state machine: peak history, EMAs, salience, tick timing.
    detector: Arc<RwLock<MarineDetectorState>>,

    // Running RMS/peak levels, noise-floor estimate and derived SNR.
    audio_monitor: Arc<RwLock<AudioMonitor>>,

    // Most recent voice-active decision, readable without the detector lock.
    is_voice_active: Arc<RwLock<bool>>,

    // Invoked with the new state on every voice on/off transition, if set.
    state_callback: StateCallback,
}

/// Shared, optional callback fired with the new voice-active flag.
type StateCallback = Arc<RwLock<Option<Box<dyn Fn(bool) + Send + Sync>>>>;
28
/// Internal state machine for the Marine VAD.
struct MarineDetectorState {
    // dBFS level below which a frame is treated as silence.
    voice_threshold: f64,

    // Evaluation rate in Hz; frames arriving faster than this are skipped.
    tick_rate: f64,

    // Recent peak events. NOTE(review): never populated anywhere in this
    // file — confirm whether peak tracking is implemented elsewhere.
    peak_history: VecDeque<PeakEvent>,

    // Smoothed inter-peak period estimate. NOTE(review): never updated in
    // this file.
    period_ema: ExponentialMovingAverage,

    // Smoothed estimate of the frame RMS amplitude.
    amplitude_ema: ExponentialMovingAverage,

    // Heuristics deciding whether a frame "sounds like" speech.
    speech_detector: SpeechPatternDetector,

    // Blended confidence in [0, 1] that voice is currently present.
    voice_salience: f64,

    // Time of the last evaluated tick, used for rate limiting.
    last_tick: Instant,

    // When the most recent voice segment started, if any.
    voice_onset: Option<Instant>,

    // When the most recent voice segment ended, if any.
    voice_offset: Option<Instant>,
}
61
/// A single detected signal peak.
///
/// NOTE(review): never constructed in this file; `peak_history` stays empty.
#[derive(Clone, Debug)]
struct PeakEvent {
    // When the peak was observed.
    timestamp: Instant,
    // Peak amplitude (linear).
    amplitude: f64,
    // Frequency associated with the peak, in Hz.
    frequency: f64,
    // Whether the peak was classified as part of voiced audio.
    is_voiced: bool,
}
70
/// Exponentially weighted moving average with smoothing factor `alpha`.
struct ExponentialMovingAverage {
    // Current smoothed value. Seeded at 0.0, so early samples are biased
    // toward zero until the average warms up.
    value: f64,
    // Smoothing factor; larger values track new samples faster.
    alpha: f64,
}

impl ExponentialMovingAverage {
    /// Creates an average seeded at 0.0 with the given smoothing factor.
    fn new(alpha: f64) -> Self {
        Self { value: 0.0, alpha }
    }

    /// Folds `sample` into the average and returns the updated value.
    fn update(&mut self, sample: f64) -> f64 {
        let blended = self.alpha * sample + (1.0 - self.alpha) * self.value;
        self.value = blended;
        blended
    }

    /// Absolute deviation of `sample` from the current average.
    fn jitter(&self, sample: f64) -> f64 {
        let delta = sample - self.value;
        delta.abs()
    }
}
91
/// Heuristic speech-pattern analyzer used to distinguish voice from noise.
struct SpeechPatternDetector {
    // Expected fundamental-frequency range for human speech, in Hz.
    f0_min: f64,
    f0_max: f64,

    // Formant band ranges (F1-F3); not consulted by `analyze` in this file.
    formant_tracker: FormantTracker,

    // Syllable-rate bookkeeping; not consulted by `analyze` in this file.
    syllable_detector: SyllableRateDetector,

    // Latest voice-quality measurements (only the ZCR is updated today).
    voice_quality: VoiceQuality,
}
107
/// Frequency ranges (Hz) for the first three speech formants.
struct FormantTracker {
    f1_range: (f64, f64),
    f2_range: (f64, f64),
    f3_range: (f64, f64),
}
/// Tracks energy-envelope peaks to estimate syllable rate.
struct SyllableRateDetector {
    // Recent per-frame energy values.
    energy_envelope: VecDeque<f64>,
    // Timestamps of detected energy peaks.
    peak_times: VecDeque<Instant>,
    // Plausible spacing between consecutive syllables.
    min_syllable_gap: Duration,
    max_syllable_gap: Duration,
}
122
/// Scalar voice-quality measurements derived from the signal.
struct VoiceQuality {
    harmonicity: f64,
    spectral_tilt: f64,
    // Fraction of adjacent-sample sign changes; low for voiced sound.
    zero_crossing_rate: f64,
    energy_variance: f64,
}
130
/// Running level statistics for the incoming audio stream.
struct AudioMonitor {
    // RMS level of the most recent frame (linear scale).
    current_level: f64,

    // Absolute sample peak of the most recent frame.
    peak_level: f64,

    // Slow EMA of frame level in dB, used as the noise-floor estimate.
    noise_floor: f64,

    // Signal-to-noise ratio in dB relative to the estimated floor.
    snr: f64,

    // Which input the audio is assumed to come from.
    source: AudioSource,
}
148
/// Origin of the audio being monitored.
#[derive(Clone, Debug)]
enum AudioSource {
    Microphone,
    LineIn,
    // Virtual/loopback device.
    Virtual,
}
155
156impl MarineVAD {
157 pub fn new() -> Result<Self> {
159 Ok(Self {
160 detector: Arc::new(RwLock::new(MarineDetectorState::new())),
161 audio_monitor: Arc::new(RwLock::new(AudioMonitor::new())),
162 is_voice_active: Arc::new(RwLock::new(false)),
163 state_callback: Arc::new(RwLock::new(None)),
164 })
165 }
166
167 pub async fn process_audio(&self, samples: &[f32], sample_rate: u32) -> Result<bool> {
169 let mut detector = self.detector.write().await;
170 let mut monitor = self.audio_monitor.write().await;
171
172 monitor.update_levels(samples);
174
175 let now = Instant::now();
177 let tick_duration = Duration::from_secs_f64(1.0 / detector.tick_rate);
178
179 if now.duration_since(detector.last_tick) < tick_duration {
180 return Ok(*self.is_voice_active.read().await);
181 }
182
183 detector.last_tick = now;
184
185 let voice_detected = detector.evaluate_voice(samples, sample_rate, monitor.snr);
187
188 let mut is_active = self.is_voice_active.write().await;
190 if voice_detected != *is_active {
191 *is_active = voice_detected;
192
193 if let Some(callback) = &*self.state_callback.read().await {
195 callback(voice_detected);
196 }
197
198 if voice_detected {
200 println!("🎤 Voice detected - switching to minimal output mode");
201 detector.voice_onset = Some(now);
202 } else {
203 println!("🔇 Voice ended - returning to normal output mode");
204 detector.voice_offset = Some(now);
205 }
206 }
207
208 Ok(voice_detected)
209 }
210
211 pub async fn set_state_callback<F>(&self, callback: F)
213 where
214 F: Fn(bool) + Send + Sync + 'static,
215 {
216 let mut cb = self.state_callback.write().await;
217 *cb = Some(Box::new(callback));
218 }
219
220 pub async fn is_voice_active(&self) -> bool {
222 *self.is_voice_active.read().await
223 }
224
225 pub async fn get_salience(&self) -> f64 {
227 self.detector.read().await.voice_salience
228 }
229
230 pub async fn get_voice_quality(&self) -> VoiceQualityReport {
232 let detector = self.detector.read().await;
233 VoiceQualityReport {
234 salience: detector.voice_salience,
235 harmonicity: detector.speech_detector.voice_quality.harmonicity,
236 spectral_tilt: detector.speech_detector.voice_quality.spectral_tilt,
237 zero_crossing_rate: detector.speech_detector.voice_quality.zero_crossing_rate,
238 energy_variance: detector.speech_detector.voice_quality.energy_variance,
239 }
240 }
241}
242
243impl MarineDetectorState {
244 fn new() -> Self {
245 Self {
246 voice_threshold: -40.0, tick_rate: 100.0, peak_history: VecDeque::with_capacity(100),
249 period_ema: ExponentialMovingAverage::new(0.1),
250 amplitude_ema: ExponentialMovingAverage::new(0.05),
251 speech_detector: SpeechPatternDetector::new(),
252 voice_salience: 0.0,
253 last_tick: Instant::now(),
254 voice_onset: None,
255 voice_offset: None,
256 }
257 }
258
259 fn evaluate_voice(&mut self, samples: &[f32], sample_rate: u32, snr: f64) -> bool {
261 let energy: f64 =
263 samples.iter().map(|&s| (s as f64).powi(2)).sum::<f64>() / samples.len() as f64;
264 let rms = energy.sqrt();
265 let db = 20.0 * rms.log10();
266
267 self.amplitude_ema.update(rms);
269
270 if db < self.voice_threshold {
272 self.voice_salience *= 0.9; return false;
274 }
275
276 let has_speech_pattern = self.speech_detector.analyze(samples, sample_rate);
278
279 let mut salience = 0.0;
281
282 let energy_score = ((db - self.voice_threshold) / 20.0).clamp(0.0, 1.0);
284 salience += energy_score * 0.3;
285
286 let snr_score = (snr / 20.0).clamp(0.0, 1.0);
288 salience += snr_score * 0.2;
289
290 if has_speech_pattern {
292 salience += 0.5;
293 }
294
295 self.voice_salience = 0.7 * salience + 0.3 * self.voice_salience;
297
298 self.voice_salience > 0.5
300 }
301}
302
303impl SpeechPatternDetector {
304 fn new() -> Self {
305 Self {
306 f0_min: 80.0,
307 f0_max: 400.0,
308 formant_tracker: FormantTracker {
309 f1_range: (200.0, 1000.0),
310 f2_range: (500.0, 2500.0),
311 f3_range: (1500.0, 3500.0),
312 },
313 syllable_detector: SyllableRateDetector {
314 energy_envelope: VecDeque::with_capacity(100),
315 peak_times: VecDeque::with_capacity(20),
316 min_syllable_gap: Duration::from_millis(100),
317 max_syllable_gap: Duration::from_millis(500),
318 },
319 voice_quality: VoiceQuality {
320 harmonicity: 0.0,
321 spectral_tilt: 0.0,
322 zero_crossing_rate: 0.0,
323 energy_variance: 0.0,
324 },
325 }
326 }
327
328 fn analyze(&mut self, samples: &[f32], sample_rate: u32) -> bool {
329 let mut zero_crossings = 0;
331 for i in 1..samples.len() {
332 if samples[i - 1] * samples[i] < 0.0 {
333 zero_crossings += 1;
334 }
335 }
336
337 let zcr = zero_crossings as f64 / samples.len() as f64;
338 self.voice_quality.zero_crossing_rate = zcr;
339
340 let is_voiced = zcr < 0.3;
342
343 let estimated_freq = zcr * sample_rate as f64 / 2.0;
345 let in_speech_range = estimated_freq >= self.f0_min && estimated_freq <= self.f0_max * 10.0;
346
347 is_voiced && in_speech_range
348 }
349}
350
351impl AudioMonitor {
352 fn new() -> Self {
353 Self {
354 current_level: 0.0,
355 peak_level: 0.0,
356 noise_floor: -60.0, snr: 0.0,
358 source: AudioSource::Microphone,
359 }
360 }
361
362 fn update_levels(&mut self, samples: &[f32]) {
363 let sum_squares: f32 = samples.iter().map(|&s| s * s).sum();
365 let rms = (sum_squares / samples.len() as f32).sqrt();
366 self.current_level = rms as f64;
367
368 let peak = samples.iter().map(|&s| s.abs()).fold(0.0f32, f32::max) as f64;
370 self.peak_level = peak;
371
372 if rms as f64 > 0.0 {
374 let db = 20.0 * (rms as f64).log10();
375 self.noise_floor = 0.99 * self.noise_floor + 0.01 * db;
376 self.snr = db - self.noise_floor;
377 }
378 }
379}
380
/// Public snapshot of the detector's voice-quality metrics.
#[derive(Debug, Clone)]
pub struct VoiceQualityReport {
    /// Blended voice-salience score in [0, 1].
    pub salience: f64,
    /// Harmonic-to-noise measure. Not updated anywhere in this file; stays 0.0.
    pub harmonicity: f64,
    /// Spectral tilt. Not updated anywhere in this file; stays 0.0.
    pub spectral_tilt: f64,
    /// Fraction of adjacent-sample sign changes in the last analyzed frame.
    pub zero_crossing_rate: f64,
    /// Energy variance. Not updated anywhere in this file; stays 0.0.
    pub energy_variance: f64,
}
390
391impl super::rust_shell::RustShell {
393 pub async fn enable_marine_vad(&self) -> Result<()> {
395 println!("🎖️ Enabling Marine VAD - Semper Fi to voice detection!");
396
397 let vad = MarineVAD::new()?;
398
399 let output_mode = self.output_mode.clone();
401 vad.set_state_callback(move |is_voice| {
402 let mode = output_mode.clone();
404 tokio::spawn(async move {
405 let mut m = mode.write().await;
406 if is_voice {
407 m.verbosity = super::rust_shell::VerbosityLevel::Minimal;
408 m.format = super::rust_shell::OutputFormat::Voice;
409 } else {
410 m.verbosity = super::rust_shell::VerbosityLevel::Normal;
411 m.format = super::rust_shell::OutputFormat::Text;
412 }
413 });
414 })
415 .await;
416
417 Ok(())
421 }
422}
423
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_marine_vad_creation() {
        // Construction should never fail with default settings.
        assert!(MarineVAD::new().is_ok());
    }

    #[tokio::test]
    async fn test_voice_detection() {
        let vad = MarineVAD::new().unwrap();

        // Synthesize 100 ms of a 200 Hz sine at half amplitude — squarely
        // inside the detector's expected F0 range.
        let sample_rate = 16000;
        let frequency = 200.0;
        let num_samples = (sample_rate as f64 * 0.1) as usize;
        let samples: Vec<f32> = (0..num_samples)
            .map(|i| {
                let t = i as f64 / sample_rate as f64;
                (2.0 * std::f64::consts::PI * frequency * t).sin() as f32 * 0.5
            })
            .collect();

        // The very first frame may be rate-limited by the detector tick, so
        // only verify that processing succeeds rather than pinning the
        // specific decision.
        let _is_voice = vad.process_audio(&samples, sample_rate).await.unwrap();
    }
}