rustant_core/voice/
session.rs1use crate::config::AgentConfig;
6use crate::error::VoiceError;
7use std::path::PathBuf;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicBool, Ordering};
10use tokio::sync::watch;
11use tokio::task::JoinHandle;
12use tracing::{debug, info, warn};
13
14pub struct VoiceCommandSession {
17 cancel_tx: watch::Sender<bool>,
18 handle: JoinHandle<()>,
19 active: Arc<AtomicBool>,
20}
21
22impl VoiceCommandSession {
23 pub async fn start(
30 config: AgentConfig,
31 _workspace: PathBuf,
32 on_transcription: Arc<dyn Fn(String) + Send + Sync>,
33 ) -> Result<Self, VoiceError> {
34 let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| VoiceError::AuthFailed {
35 provider: "openai".into(),
36 })?;
37
38 let (cancel_tx, cancel_rx) = watch::channel(false);
39 let active = Arc::new(AtomicBool::new(true));
40 let active_clone = active.clone();
41
42 let voice_config = config.voice.clone().unwrap_or_default();
43 let sample_rate: u32 = 16000;
44 let chunk_duration: f32 = 3.0;
45
46 let handle = tokio::spawn(async move {
47 voice_loop(
48 cancel_rx,
49 active_clone,
50 api_key,
51 voice_config,
52 sample_rate,
53 chunk_duration,
54 on_transcription,
55 )
56 .await;
57 });
58
59 info!("Voice command session started");
60 Ok(Self {
61 cancel_tx,
62 handle,
63 active,
64 })
65 }
66
67 pub async fn stop(self) -> Result<(), VoiceError> {
69 info!("Stopping voice command session");
70 let _ = self.cancel_tx.send(true);
71 self.active.store(false, Ordering::SeqCst);
72 let _ = tokio::time::timeout(std::time::Duration::from_secs(10), self.handle).await;
74 info!("Voice command session stopped");
75 Ok(())
76 }
77
78 pub fn is_active(&self) -> bool {
80 self.active.load(Ordering::SeqCst)
81 }
82}
83
84async fn voice_loop(
86 mut cancel_rx: watch::Receiver<bool>,
87 active: Arc<AtomicBool>,
88 api_key: String,
89 voice_config: crate::config::VoiceConfig,
90 sample_rate: u32,
91 chunk_duration: f32,
92 on_transcription: Arc<dyn Fn(String) + Send + Sync>,
93) {
94 use crate::voice::audio_io::record_audio_chunk;
95 use crate::voice::stt::{OpenAiSttProvider, SttProvider};
96 use crate::voice::vad::{VadEvent, VoiceActivityDetector};
97
98 let stt = OpenAiSttProvider::new(&api_key);
99 let vad_threshold = voice_config.vad_threshold;
100 let mut vad = VoiceActivityDetector::new(vad_threshold);
101 let mut speech_buffer: Vec<f32> = Vec::new();
102 let mut is_collecting = false;
103 let max_collect_secs = voice_config.max_listen_secs as f32;
104 let max_collect_samples = (max_collect_secs * sample_rate as f32) as usize;
105
106 info!(
107 threshold = vad_threshold,
108 max_listen = voice_config.max_listen_secs,
109 "Voice loop started"
110 );
111
112 loop {
113 if *cancel_rx.borrow() {
115 debug!("Voice loop cancelled");
116 break;
117 }
118
119 let chunk = tokio::select! {
121 result = record_audio_chunk(chunk_duration, sample_rate) => {
122 match result {
123 Ok(c) => c,
124 Err(e) => {
125 warn!(error = %e, "Voice loop: failed to record chunk");
126 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
127 continue;
128 }
129 }
130 }
131 _ = cancel_rx.changed() => {
132 debug!("Voice loop cancelled during recording");
133 break;
134 }
135 };
136
137 let event = vad.process_chunk(&chunk);
139
140 match event {
141 VadEvent::SpeechStart => {
142 debug!("Speech detected — collecting audio");
143 is_collecting = true;
144 speech_buffer.clear();
145 speech_buffer.extend_from_slice(&chunk.samples);
146 }
147 VadEvent::SpeechEnd if is_collecting => {
148 speech_buffer.extend_from_slice(&chunk.samples);
150 is_collecting = false;
151
152 let audio = crate::voice::types::AudioChunk::new(
153 std::mem::take(&mut speech_buffer),
154 sample_rate,
155 1,
156 );
157
158 debug!(
159 duration = audio.duration_secs(),
160 "Speech ended — transcribing"
161 );
162
163 match stt.transcribe(&audio).await {
164 Ok(result) if !result.text.trim().is_empty() => {
165 info!(text = %result.text, "Transcription received");
166 on_transcription(result.text);
167 }
168 Ok(_) => debug!("Empty transcription — ignoring"),
169 Err(e) => warn!(error = %e, "Transcription failed"),
170 }
171 }
172 VadEvent::NoChange if is_collecting => {
173 speech_buffer.extend_from_slice(&chunk.samples);
174 if speech_buffer.len() >= max_collect_samples {
176 is_collecting = false;
177 let audio = crate::voice::types::AudioChunk::new(
178 std::mem::take(&mut speech_buffer),
179 sample_rate,
180 1,
181 );
182 debug!(
183 duration = audio.duration_secs(),
184 "Max listen reached — transcribing"
185 );
186 match stt.transcribe(&audio).await {
187 Ok(result) if !result.text.trim().is_empty() => {
188 info!(text = %result.text, "Transcription received (max length)");
189 on_transcription(result.text);
190 }
191 Ok(_) => debug!("Empty transcription — ignoring"),
192 Err(e) => warn!(error = %e, "Transcription failed"),
193 }
194 }
195 }
196 _ => {
197 }
199 }
200 }
201
202 active.store(false, Ordering::SeqCst);
203 info!("Voice loop exited");
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// A shared `AtomicBool` starts set and reads back as cleared through another clone of
    /// the `Arc` once `false` is stored. The session shares its `active` flag with the loop
    /// task in the same way.
    #[test]
    fn test_session_active_flag() {
        let flag = Arc::new(AtomicBool::new(true));
        let observer = Arc::clone(&flag);

        assert!(observer.load(Ordering::SeqCst), "flag should start set");

        flag.store(false, Ordering::SeqCst);
        assert!(
            !observer.load(Ordering::SeqCst),
            "flag should read cleared after store"
        );
    }
}