Skip to main content

whisper_cpp_plus/
async_api.rs

1//! Async API for non-blocking transcription
2//!
3//! This module provides async wrappers around the synchronous whisper.cpp API,
4//! enabling non-blocking transcription in async Rust applications.
5
6#[cfg(feature = "async")]
7use crate::{
8    context::WhisperContext,
9    error::Result,
10    params::FullParams,
11    state::WhisperState,
12    stream::{WhisperStream, WhisperStreamConfig},
13    Segment, TranscriptionResult,
14};
15#[cfg(feature = "async")]
16use std::sync::Arc;
17#[cfg(feature = "async")]
18use tokio::sync::{mpsc, oneshot, Mutex};
19#[cfg(feature = "async")]
20use tokio::task;
21
#[cfg(feature = "async")]
impl WhisperContext {
    /// Transcribe `audio` with default parameters without blocking the async
    /// runtime.
    ///
    /// The context is cloned and [`WhisperContext::transcribe`] runs on tokio's
    /// blocking thread pool.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `transcribe`. A panic or cancellation of
    /// the blocking task is reported as `WhisperError::TranscriptionError`.
    pub async fn transcribe_async(&self, audio: Vec<f32>) -> Result<String> {
        let context = self.clone();
        run_blocking(move || context.transcribe(&audio)).await
    }

    /// Transcribe `audio` with custom `params` without blocking the async
    /// runtime.
    ///
    /// The context is cloned and [`WhisperContext::transcribe_with_params`] runs
    /// on tokio's blocking thread pool.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `transcribe_with_params`. A panic or
    /// cancellation of the blocking task is reported as
    /// `WhisperError::TranscriptionError`.
    pub async fn transcribe_with_params_async(
        &self,
        audio: Vec<f32>,
        params: crate::TranscriptionParams,
    ) -> Result<TranscriptionResult> {
        let context = self.clone();
        run_blocking(move || context.transcribe_with_params(&audio, params)).await
    }

    /// Create a [`WhisperState`] for manual control, off the async runtime.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by [`WhisperContext::create_state`]. A panic
    /// or cancellation of the blocking task is reported as
    /// `WhisperError::TranscriptionError`.
    pub async fn create_state_async(&self) -> Result<WhisperState> {
        let context = self.clone();
        run_blocking(move || context.create_state()).await
    }
}

/// Run `job` on tokio's blocking thread pool and flatten the join result.
///
/// A `JoinError` (the job panicked or was cancelled) is converted into
/// `WhisperError::TranscriptionError` carrying the join error's message.
/// Otherwise the job's own `Result` is returned unchanged.
#[cfg(feature = "async")]
async fn run_blocking<T, F>(job: F) -> Result<T>
where
    F: FnOnce() -> Result<T> + Send + 'static,
    T: Send + 'static,
{
    match task::spawn_blocking(job).await {
        Ok(outcome) => outcome,
        Err(join_err) => Err(crate::WhisperError::TranscriptionError(
            join_err.to_string(),
        )),
    }
}
52
/// An async streaming transcriber fed through channels.
///
/// A dedicated blocking worker owns the underlying [`WhisperStream`]. This
/// handle sends it feed/flush/stop commands and receives the segment batches it
/// produces. Dropping the handle without calling [`stop`](Self::stop) closes
/// both channels, and the worker winds down in the background.
#[cfg(feature = "async")]
pub struct AsyncWhisperStream {
    /// Channel for sending commands (audio, flush, stop) to the worker.
    audio_tx: mpsc::Sender<AudioCommand>,
    /// Channel for receiving transcribed segment batches from the worker.
    segment_rx: mpsc::Receiver<Vec<Segment>>,
    /// Handle to the background worker; yields the worker's final result.
    handle: task::JoinHandle<Result<()>>,
}

/// Commands processed, strictly in order, by the [`AsyncWhisperStream`] worker.
#[cfg(feature = "async")]
enum AudioCommand {
    /// Append samples to the stream and run every available processing step.
    Feed(Vec<f32>),
    /// Flush the stream and reply with the resulting segments.
    Flush(oneshot::Sender<Vec<Segment>>),
    /// Exit the worker loop.
    Stop,
}

#[cfg(feature = "async")]
impl AsyncWhisperStream {
    /// Capacity of both the command channel and the segment channel.
    const CHANNEL_CAPACITY: usize = 100;

    /// Create a new async streaming transcriber with the default
    /// [`WhisperStreamConfig`].
    ///
    /// See [`AsyncWhisperStream::with_config`] for details.
    pub fn new(context: WhisperContext, params: FullParams) -> Result<Self> {
        Self::with_config(context, params, WhisperStreamConfig::default())
    }

    /// Create a new async streaming transcriber with custom configuration.
    ///
    /// Spawns a worker on tokio's blocking thread pool. The worker builds the
    /// [`WhisperStream`] and then handles commands in the order they were sent.
    ///
    /// Both channels are bounded. If segment batches are not drained with
    /// [`recv_segments`](Self::recv_segments) or
    /// [`try_recv_segments`](Self::try_recv_segments), the worker eventually
    /// waits for room, and later `feed_audio`/`flush` calls wait with it.
    ///
    /// # Errors
    ///
    /// This constructor currently always returns `Ok`. The stream is created on
    /// the worker, so a construction failure, like any later processing error,
    /// ends the worker and is reported by [`stop`](Self::stop). After that,
    /// `feed_audio`/`flush` fail with a "Stream closed" or "Failed to flush"
    /// error.
    pub fn with_config(
        context: WhisperContext,
        params: FullParams,
        config: WhisperStreamConfig,
    ) -> Result<Self> {
        let (audio_tx, mut audio_rx) = mpsc::channel::<AudioCommand>(Self::CHANNEL_CAPACITY);
        let (segment_tx, segment_rx) = mpsc::channel::<Vec<Segment>>(Self::CHANNEL_CAPACITY);

        let handle = task::spawn_blocking(move || -> Result<()> {
            let mut stream = WhisperStream::with_config(&context, params, config)?;

            while let Some(cmd) = audio_rx.blocking_recv() {
                match cmd {
                    AudioCommand::Feed(audio) => {
                        stream.feed_audio(&audio);

                        // Run every processing step the buffered audio allows.
                        while let Some(segments) = stream.process_step()? {
                            if segments.is_empty() {
                                continue;
                            }
                            // A send error means the receiver, and therefore the
                            // `AsyncWhisperStream`, is gone. Nobody can observe
                            // further output, so exit instead of running
                            // inference on audio that is still queued.
                            if segment_tx.blocking_send(segments).is_err() {
                                return Ok(());
                            }
                        }
                    }
                    AudioCommand::Flush(response) => {
                        let segments = stream.flush()?;
                        // The caller may have stopped waiting (future dropped).
                        let _ = response.send(segments);
                    }
                    AudioCommand::Stop => break,
                }
            }

            Ok(())
        });

        Ok(Self {
            audio_tx,
            segment_rx,
            handle,
        })
    }

    /// Queue audio samples for the worker.
    ///
    /// Waits while the command channel is full.
    ///
    /// # Errors
    ///
    /// Returns a "Stream closed" `TranscriptionError` if the worker has exited.
    pub async fn feed_audio(&self, audio: Vec<f32>) -> Result<()> {
        self.audio_tx
            .send(AudioCommand::Feed(audio))
            .await
            .map_err(|_| crate::WhisperError::TranscriptionError("Stream closed".into()))
    }

    /// Wait for the next batch of transcribed segments.
    ///
    /// Returns `None` once the worker has exited and every buffered batch has
    /// been received.
    pub async fn recv_segments(&mut self) -> Option<Vec<Segment>> {
        self.segment_rx.recv().await
    }

    /// Take the next batch of segments if one is ready, without waiting.
    ///
    /// Returns `None` both when no batch is ready and when the worker has
    /// exited.
    pub fn try_recv_segments(&mut self) -> Option<Vec<Segment>> {
        self.segment_rx.try_recv().ok()
    }

    /// Flush the stream and return the segments produced by the flush.
    ///
    /// Commands are handled in order. Segments from earlier `feed_audio` calls
    /// are delivered through the segment channel, not returned here.
    ///
    /// # Errors
    ///
    /// Returns "Stream closed" if the worker had exited before the command was
    /// queued. Returns "Failed to flush" if the worker exited before replying,
    /// for example because of a processing error; that error is reported by
    /// [`stop`](Self::stop).
    pub async fn flush(&self) -> Result<Vec<Segment>> {
        let (tx, rx) = oneshot::channel();
        self.audio_tx
            .send(AudioCommand::Flush(tx))
            .await
            .map_err(|_| crate::WhisperError::TranscriptionError("Stream closed".into()))?;
        rx.await
            .map_err(|_| crate::WhisperError::TranscriptionError("Failed to flush".into()))
    }

    /// Stop the stream and wait for the worker to finish.
    ///
    /// Segment batches not yet received are discarded.
    ///
    /// # Errors
    ///
    /// Returns the error that ended the worker, if any. A panic in the worker
    /// is reported as a `TranscriptionError`.
    pub async fn stop(self) -> Result<()> {
        let Self {
            audio_tx,
            segment_rx,
            handle,
        } = self;

        // Segments can no longer be read once `self` is consumed, so close the
        // receiver first. A worker blocked in `blocking_send` on a full segment
        // channel then wakes with an error and exits. Without this, both the
        // send below and the join could wait on it forever.
        drop(segment_rx);

        // An error here means the worker already exited; `handle` reports why.
        let _ = audio_tx.send(AudioCommand::Stop).await;

        handle
            .await
            .map_err(|e| crate::WhisperError::TranscriptionError(e.to_string()))?
    }
}
157
/// A cloneable handle to one streaming transcriber shared by many producers.
///
/// All clones share a single [`WhisperStream`] behind an async mutex, so calls
/// from different tasks are serialized. The CPU-heavy whisper work runs on
/// tokio's blocking thread pool rather than on the async worker threads.
#[cfg(feature = "async")]
pub struct SharedAsyncStream {
    inner: Arc<Mutex<AsyncStreamInner>>,
}

/// State guarded by the [`SharedAsyncStream`] mutex.
#[cfg(feature = "async")]
struct AsyncStreamInner {
    /// The synchronous streaming transcriber doing the actual work.
    stream: WhisperStream,
    /// Segments produced by `feed_and_process`/`flush` since the last
    /// `drain_segments` call.
    pending_segments: Vec<Segment>,
}

#[cfg(feature = "async")]
impl SharedAsyncStream {
    /// Create a new shared async stream.
    ///
    /// The underlying [`WhisperStream`] is constructed on the blocking thread
    /// pool so that its setup does not stall the async runtime.
    ///
    /// # Errors
    ///
    /// Returns the error from [`WhisperStream::with_config`]. Returns a
    /// `TranscriptionError` if the blocking task panicked or was cancelled.
    pub async fn new(
        context: &WhisperContext,
        params: FullParams,
        config: WhisperStreamConfig,
    ) -> Result<Self> {
        let context = context.clone();
        let stream = task::spawn_blocking(move || {
            WhisperStream::with_config(&context, params, config)
        })
        .await
        .map_err(|e| crate::WhisperError::TranscriptionError(e.to_string()))??;

        Ok(Self {
            inner: Arc::new(Mutex::new(AsyncStreamInner {
                stream,
                pending_segments: Vec::new(),
            })),
        })
    }

    /// Feed audio and return the segments it produced, atomically.
    ///
    /// The lock is held for the whole feed-and-process cycle, so concurrent
    /// callers never interleave their audio mid-step. The returned segments are
    /// also appended to the buffer read by
    /// [`drain_segments`](Self::drain_segments).
    ///
    /// The work runs inside `spawn_blocking` and owns the lock guard. If this
    /// future is dropped early, the blocking work still runs to completion
    /// while holding the lock.
    ///
    /// # Errors
    ///
    /// Propagates errors from the stream's processing step. A panic in the
    /// blocking task is reported as a `TranscriptionError`.
    pub async fn feed_and_process(&self, audio: Vec<f32>) -> Result<Vec<Segment>> {
        let mut inner = Arc::clone(&self.inner).lock_owned().await;

        task::spawn_blocking(move || -> Result<Vec<Segment>> {
            inner.stream.feed_audio(&audio);

            // Run every processing step the buffered audio allows.
            let mut segments = Vec::new();
            while let Some(segs) = inner.stream.process_step()? {
                segments.extend(segs);
            }

            inner.pending_segments.extend_from_slice(&segments);
            Ok(segments)
        })
        .await
        .map_err(|e| crate::WhisperError::TranscriptionError(e.to_string()))?
    }

    /// Take every pending segment, leaving the buffer empty.
    pub async fn drain_segments(&self) -> Vec<Segment> {
        let mut inner = self.inner.lock().await;
        std::mem::take(&mut inner.pending_segments)
    }

    /// Flush the stream and return the segments the flush produced.
    ///
    /// The segments are also appended to the pending buffer. Like
    /// `feed_and_process`, the work runs on the blocking thread pool while
    /// holding the lock.
    ///
    /// # Errors
    ///
    /// Propagates errors from the stream's flush. A panic in the blocking task
    /// is reported as a `TranscriptionError`.
    pub async fn flush(&self) -> Result<Vec<Segment>> {
        let mut inner = Arc::clone(&self.inner).lock_owned().await;

        task::spawn_blocking(move || -> Result<Vec<Segment>> {
            let segments = inner.stream.flush()?;
            inner.pending_segments.extend_from_slice(&segments);
            Ok(segments)
        })
        .await
        .map_err(|e| crate::WhisperError::TranscriptionError(e.to_string()))?
    }
}
220
#[cfg(all(test, feature = "async"))]
mod tests {
    use super::*;
    use crate::SamplingStrategy;
    use std::path::Path;

    /// Model file used by these tests; it is not bundled with the sources.
    const MODEL_PATH: &str = "tests/models/ggml-tiny.en.bin";

    /// Returns the model path if the file exists; otherwise logs a skip notice.
    ///
    /// Without the notice, a missing model makes every test below pass
    /// silently, which hides that nothing was exercised.
    fn model_path() -> Option<&'static str> {
        if Path::new(MODEL_PATH).exists() {
            Some(MODEL_PATH)
        } else {
            eprintln!("skipping: model not found at {}", MODEL_PATH);
            None
        }
    }

    #[tokio::test]
    async fn test_async_transcribe() {
        let path = match model_path() {
            Some(path) => path,
            None => return,
        };
        let ctx = WhisperContext::new(path).unwrap();
        let audio = vec![0.0f32; 16000]; // 1 second of silence

        ctx.transcribe_async(audio)
            .await
            .expect("async transcription failed");
    }

    #[tokio::test]
    async fn test_async_stream() {
        let path = match model_path() {
            Some(path) => path,
            None => return,
        };
        let ctx = WhisperContext::new(path).unwrap();
        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });

        let stream = AsyncWhisperStream::new(ctx, params).expect("failed to create stream");

        stream
            .feed_audio(vec![0.0f32; 16000])
            .await
            .expect("feeding audio failed");

        stream.stop().await.expect("stream worker failed");
    }

    #[tokio::test]
    async fn test_shared_stream() {
        let path = match model_path() {
            Some(path) => path,
            None => return,
        };
        let ctx = WhisperContext::new(path).unwrap();
        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });

        let stream = SharedAsyncStream::new(&ctx, params, WhisperStreamConfig::default())
            .await
            .expect("failed to create shared stream");

        // Feed audio concurrently from two tasks sharing the same stream.
        let handles: Vec<_> = (0..2)
            .map(|_| {
                let producer = stream.clone();
                tokio::spawn(async move { producer.feed_and_process(vec![0.0f32; 16000]).await })
            })
            .collect();

        for handle in handles {
            handle
                .await
                .expect("producer task panicked")
                .expect("feed_and_process failed");
        }
    }
}
296
// Cloning a `SharedAsyncStream` is cheap: it only bumps the reference count,
// and every clone drives the same underlying stream through the shared mutex.
#[cfg(feature = "async")]
impl Clone for SharedAsyncStream {
    fn clone(&self) -> Self {
        let inner = Arc::clone(&self.inner);
        Self { inner }
    }
}