Skip to main content

rustant_core/voice/
session.rs

//! Voice command session — background listen→transcribe loop that delivers text via a callback.
//!
//! Runs as a `tokio::spawn` task with graceful shutdown via a `watch` channel.
4
5use crate::config::AgentConfig;
6use crate::error::VoiceError;
7use std::path::PathBuf;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicBool, Ordering};
10use tokio::sync::watch;
11use tokio::task::JoinHandle;
12use tracing::{debug, info, warn};
13
14/// A non-blocking voice command session that listens for speech,
15/// transcribes it, and delivers the text via a callback.
16pub struct VoiceCommandSession {
17    cancel_tx: watch::Sender<bool>,
18    handle: JoinHandle<()>,
19    active: Arc<AtomicBool>,
20}
21
22impl VoiceCommandSession {
23    /// Start a new voice command session in the background.
24    ///
25    /// The session records audio in chunks, detects speech via VAD,
26    /// transcribes via OpenAI Whisper, and sends the transcription text
27    /// through `on_transcription`. The response text can be spoken back
28    /// via `on_response`.
29    pub async fn start(
30        config: AgentConfig,
31        _workspace: PathBuf,
32        on_transcription: Arc<dyn Fn(String) + Send + Sync>,
33    ) -> Result<Self, VoiceError> {
34        let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| VoiceError::AuthFailed {
35            provider: "openai".into(),
36        })?;
37
38        let (cancel_tx, cancel_rx) = watch::channel(false);
39        let active = Arc::new(AtomicBool::new(true));
40        let active_clone = active.clone();
41
42        let voice_config = config.voice.clone().unwrap_or_default();
43        let sample_rate: u32 = 16000;
44        let chunk_duration: f32 = 3.0;
45
46        let handle = tokio::spawn(async move {
47            voice_loop(
48                cancel_rx,
49                active_clone,
50                api_key,
51                voice_config,
52                sample_rate,
53                chunk_duration,
54                on_transcription,
55            )
56            .await;
57        });
58
59        info!("Voice command session started");
60        Ok(Self {
61            cancel_tx,
62            handle,
63            active,
64        })
65    }
66
67    /// Gracefully stop the voice command session.
68    pub async fn stop(self) -> Result<(), VoiceError> {
69        info!("Stopping voice command session");
70        let _ = self.cancel_tx.send(true);
71        self.active.store(false, Ordering::SeqCst);
72        // Wait for the background task to finish (with timeout).
73        let _ = tokio::time::timeout(std::time::Duration::from_secs(10), self.handle).await;
74        info!("Voice command session stopped");
75        Ok(())
76    }
77
78    /// Check if the session is currently active.
79    pub fn is_active(&self) -> bool {
80        self.active.load(Ordering::SeqCst)
81    }
82}
83
84/// The main voice command loop.
85async fn voice_loop(
86    mut cancel_rx: watch::Receiver<bool>,
87    active: Arc<AtomicBool>,
88    api_key: String,
89    voice_config: crate::config::VoiceConfig,
90    sample_rate: u32,
91    chunk_duration: f32,
92    on_transcription: Arc<dyn Fn(String) + Send + Sync>,
93) {
94    use crate::voice::audio_io::record_audio_chunk;
95    use crate::voice::stt::{OpenAiSttProvider, SttProvider};
96    use crate::voice::vad::{VadEvent, VoiceActivityDetector};
97
98    let stt = OpenAiSttProvider::new(&api_key);
99    let vad_threshold = voice_config.vad_threshold;
100    let mut vad = VoiceActivityDetector::new(vad_threshold);
101    let mut speech_buffer: Vec<f32> = Vec::new();
102    let mut is_collecting = false;
103    let max_collect_secs = voice_config.max_listen_secs as f32;
104    let max_collect_samples = (max_collect_secs * sample_rate as f32) as usize;
105
106    info!(
107        threshold = vad_threshold,
108        max_listen = voice_config.max_listen_secs,
109        "Voice loop started"
110    );
111
112    loop {
113        // Check for cancellation.
114        if *cancel_rx.borrow() {
115            debug!("Voice loop cancelled");
116            break;
117        }
118
119        // Record a chunk.
120        let chunk = tokio::select! {
121            result = record_audio_chunk(chunk_duration, sample_rate) => {
122                match result {
123                    Ok(c) => c,
124                    Err(e) => {
125                        warn!(error = %e, "Voice loop: failed to record chunk");
126                        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
127                        continue;
128                    }
129                }
130            }
131            _ = cancel_rx.changed() => {
132                debug!("Voice loop cancelled during recording");
133                break;
134            }
135        };
136
137        // Process through VAD.
138        let event = vad.process_chunk(&chunk);
139
140        match event {
141            VadEvent::SpeechStart => {
142                debug!("Speech detected — collecting audio");
143                is_collecting = true;
144                speech_buffer.clear();
145                speech_buffer.extend_from_slice(&chunk.samples);
146            }
147            VadEvent::SpeechEnd if is_collecting => {
148                // Add final chunk and transcribe.
149                speech_buffer.extend_from_slice(&chunk.samples);
150                is_collecting = false;
151
152                let audio = crate::voice::types::AudioChunk::new(
153                    std::mem::take(&mut speech_buffer),
154                    sample_rate,
155                    1,
156                );
157
158                debug!(
159                    duration = audio.duration_secs(),
160                    "Speech ended — transcribing"
161                );
162
163                match stt.transcribe(&audio).await {
164                    Ok(result) if !result.text.trim().is_empty() => {
165                        info!(text = %result.text, "Transcription received");
166                        on_transcription(result.text);
167                    }
168                    Ok(_) => debug!("Empty transcription — ignoring"),
169                    Err(e) => warn!(error = %e, "Transcription failed"),
170                }
171            }
172            VadEvent::NoChange if is_collecting => {
173                speech_buffer.extend_from_slice(&chunk.samples);
174                // Enforce max collection length.
175                if speech_buffer.len() >= max_collect_samples {
176                    is_collecting = false;
177                    let audio = crate::voice::types::AudioChunk::new(
178                        std::mem::take(&mut speech_buffer),
179                        sample_rate,
180                        1,
181                    );
182                    debug!(
183                        duration = audio.duration_secs(),
184                        "Max listen reached — transcribing"
185                    );
186                    match stt.transcribe(&audio).await {
187                        Ok(result) if !result.text.trim().is_empty() => {
188                            info!(text = %result.text, "Transcription received (max length)");
189                            on_transcription(result.text);
190                        }
191                        Ok(_) => debug!("Empty transcription — ignoring"),
192                        Err(e) => warn!(error = %e, "Transcription failed"),
193                    }
194                }
195            }
196            _ => {
197                // Silence or speech continuing without change — do nothing special.
198            }
199        }
200    }
201
202    active.store(false, Ordering::SeqCst);
203    info!("Voice loop exited");
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// The active flag starts set, and a store through one `Arc` handle is
    /// observed through the other (mirroring the session/loop sharing).
    #[test]
    fn test_session_active_flag() {
        let session_side = Arc::new(AtomicBool::new(true));
        let loop_side = Arc::clone(&session_side);

        assert!(session_side.load(Ordering::SeqCst));

        loop_side.store(false, Ordering::SeqCst);
        assert!(!session_side.load(Ordering::SeqCst));
    }
}