Skip to main content

whisper_cpp_plus/
async_api.rs

1//! Async API for non-blocking transcription
2//!
3//! This module provides async wrappers around the synchronous whisper.cpp API,
4//! enabling non-blocking transcription in async Rust applications.
5
6#[cfg(feature = "async")]
7use crate::{
8    context::WhisperContext, error::Result, params::FullParams,
9    state::WhisperState, stream::{WhisperStreamConfig, WhisperStream}, Segment, TranscriptionResult,
10};
11#[cfg(feature = "async")]
12use std::sync::Arc;
13#[cfg(feature = "async")]
14use tokio::sync::{mpsc, oneshot, Mutex};
15#[cfg(feature = "async")]
16use tokio::task;
17
/// Run `job` on Tokio's blocking thread pool and flatten the join result.
///
/// If the blocking task panics or is cancelled, the resulting join error is
/// reported as a `WhisperError::TranscriptionError` carrying its message.
#[cfg(feature = "async")]
async fn offload<T, F>(job: F) -> Result<T>
where
    T: Send + 'static,
    F: FnOnce() -> Result<T> + Send + 'static,
{
    match task::spawn_blocking(job).await {
        Ok(outcome) => outcome,
        Err(join_error) => Err(crate::WhisperError::TranscriptionError(
            join_error.to_string(),
        )),
    }
}

#[cfg(feature = "async")]
impl WhisperContext {
    /// Transcribe `audio` with default parameters without blocking the async executor.
    ///
    /// A clone of this context runs [`WhisperContext::transcribe`] on Tokio's
    /// blocking thread pool.
    pub async fn transcribe_async(&self, audio: Vec<f32>) -> Result<String> {
        let context = self.clone();
        offload(move || context.transcribe(&audio)).await
    }

    /// Transcribe `audio` with caller-supplied parameters without blocking the executor.
    ///
    /// A clone of this context runs [`WhisperContext::transcribe_with_params`]
    /// on Tokio's blocking thread pool.
    pub async fn transcribe_with_params_async(
        &self,
        audio: Vec<f32>,
        params: crate::TranscriptionParams,
    ) -> Result<TranscriptionResult> {
        let context = self.clone();
        offload(move || context.transcribe_with_params(&audio, params)).await
    }

    /// Create a [`WhisperState`] for manual control, off the async executor.
    ///
    /// [`WhisperContext::create_state`] is run on Tokio's blocking thread pool.
    pub async fn create_state_async(&self) -> Result<WhisperState> {
        let context = self.clone();
        offload(move || context.create_state()).await
    }
}
48
/// An async streaming transcriber with channels for audio input.
///
/// Audio is sent as commands to a worker running on Tokio's blocking thread
/// pool. The worker drives a [`WhisperStream`] and publishes segment batches
/// through a channel, read with [`recv_segments`](Self::recv_segments) or
/// [`try_recv_segments`](Self::try_recv_segments).
///
/// Both channels are bounded. If segment batches are never received, the
/// worker eventually waits for room in the segment channel. After that,
/// [`feed_audio`](Self::feed_audio) waits for room in the command channel.
#[cfg(feature = "async")]
pub struct AsyncWhisperStream {
    /// Channel for sending commands (audio, flush, stop) to the worker.
    audio_tx: mpsc::Sender<AudioCommand>,
    /// Channel for receiving transcribed segment batches from the worker.
    segment_rx: mpsc::Receiver<Vec<Segment>>,
    /// Handle to the worker; resolves to the worker's final result.
    handle: task::JoinHandle<Result<()>>,
}

/// Commands understood by the [`AsyncWhisperStream`] worker.
#[cfg(feature = "async")]
enum AudioCommand {
    /// Append samples to the stream and process every step that is ready.
    Feed(Vec<f32>),
    /// Flush the stream and reply with the segments the flush produced.
    Flush(oneshot::Sender<Vec<Segment>>),
    /// Leave the worker loop.
    Stop,
}

#[cfg(feature = "async")]
impl AsyncWhisperStream {
    /// Capacity of both the command channel and the segment channel.
    const CHANNEL_CAPACITY: usize = 100;

    /// Create a new async streaming transcriber using [`WhisperStreamConfig::default`].
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying [`WhisperStream`] cannot be created.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because the worker is started
    /// with [`tokio::task::spawn_blocking`].
    pub fn new(context: WhisperContext, params: FullParams) -> Result<Self> {
        Self::with_config(context, params, WhisperStreamConfig::default())
    }

    /// Create a new async streaming transcriber with a custom configuration.
    ///
    /// The [`WhisperStream`] is built on the calling thread. This means a
    /// construction failure is returned here, instead of silently ending the
    /// worker and only surfacing later from [`stop`](Self::stop). All audio
    /// processing then happens on Tokio's blocking thread pool.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying [`WhisperStream`] cannot be created.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because the worker is started
    /// with [`tokio::task::spawn_blocking`].
    pub fn with_config(
        context: WhisperContext,
        params: FullParams,
        config: WhisperStreamConfig,
    ) -> Result<Self> {
        let stream = WhisperStream::with_config(&context, params, config)?;

        let (audio_tx, audio_rx) = mpsc::channel::<AudioCommand>(Self::CHANNEL_CAPACITY);
        let (segment_tx, segment_rx) = mpsc::channel::<Vec<Segment>>(Self::CHANNEL_CAPACITY);

        let handle = task::spawn_blocking(move || {
            // Keep the context alive until the stream is dropped. This is cheap,
            // and it avoids depending on whether `WhisperStream` holds its own
            // handle to the context (not visible from this module).
            let _context = context;
            Self::run_worker(stream, audio_rx, segment_tx)
        });

        Ok(Self {
            audio_tx,
            segment_rx,
            handle,
        })
    }

    /// Worker loop: apply commands in order until `Stop` or until all senders are gone.
    fn run_worker(
        mut stream: WhisperStream,
        mut audio_rx: mpsc::Receiver<AudioCommand>,
        segment_tx: mpsc::Sender<Vec<Segment>>,
    ) -> Result<()> {
        while let Some(command) = audio_rx.blocking_recv() {
            match command {
                AudioCommand::Feed(audio) => {
                    stream.feed_audio(&audio);

                    // Process every step the newly fed audio makes available.
                    while let Some(segments) = stream.process_step()? {
                        if !segments.is_empty() {
                            // A send error only means the receiver was dropped
                            // (handle dropped or stopping); keep processing.
                            let _ = segment_tx.blocking_send(segments);
                        }
                    }
                }
                AudioCommand::Flush(response) => {
                    let segments = stream.flush()?;
                    // The caller may have stopped waiting; nothing to do then.
                    let _ = response.send(segments);
                }
                AudioCommand::Stop => break,
            }
        }

        Ok(())
    }

    /// Queue audio samples for the worker.
    ///
    /// Waits while the command channel is full.
    ///
    /// # Errors
    ///
    /// Returns `TranscriptionError("Stream closed")` if the worker has exited.
    pub async fn feed_audio(&self, audio: Vec<f32>) -> Result<()> {
        self.audio_tx
            .send(AudioCommand::Feed(audio))
            .await
            .map_err(|_| crate::WhisperError::TranscriptionError("Stream closed".into()))
    }

    /// Wait for the next batch of transcribed segments.
    ///
    /// Returns `None` once the worker has exited and every buffered batch has
    /// been received.
    pub async fn recv_segments(&mut self) -> Option<Vec<Segment>> {
        self.segment_rx.recv().await
    }

    /// Take the next batch of segments if one is ready, without waiting.
    ///
    /// Returns `None` both when no batch is ready and when the worker has exited.
    pub fn try_recv_segments(&mut self) -> Option<Vec<Segment>> {
        self.segment_rx.try_recv().ok()
    }

    /// Flush the stream and return the segments produced by the flush.
    ///
    /// Segments produced by earlier [`feed_audio`](Self::feed_audio) calls are
    /// delivered through the segment channel, not returned here.
    ///
    /// # Errors
    ///
    /// Returns `TranscriptionError("Stream closed")` if the worker has already
    /// exited. Returns `TranscriptionError("Failed to flush")` if the worker
    /// ended before replying; in that case the underlying error is reported
    /// by [`stop`](Self::stop).
    pub async fn flush(&self) -> Result<Vec<Segment>> {
        let (tx, rx) = oneshot::channel();
        self.audio_tx
            .send(AudioCommand::Flush(tx))
            .await
            .map_err(|_| crate::WhisperError::TranscriptionError("Stream closed".into()))?;
        rx.await
            .map_err(|_| crate::WhisperError::TranscriptionError("Failed to flush".into()))
    }

    /// Stop the stream gracefully and return the worker's final result.
    ///
    /// Commands queued before the stop are still processed. Segment batches not
    /// yet received are discarded, since they can no longer be read once `self`
    /// is consumed.
    ///
    /// # Errors
    ///
    /// Returns the first error the worker hit. Returns a `TranscriptionError`
    /// if the worker task panicked.
    pub async fn stop(self) -> Result<()> {
        let Self {
            audio_tx,
            segment_rx,
            handle,
        } = self;

        // Drop the receiver first. Otherwise a worker blocked on a full segment
        // channel would never reach `Stop`, and this call would hang forever.
        drop(segment_rx);

        // If the worker already exited this send fails; its outcome is still
        // reported through the join handle below.
        let _ = audio_tx.send(AudioCommand::Stop).await;

        handle
            .await
            .map_err(|e| crate::WhisperError::TranscriptionError(e.to_string()))?
    }
}
156
/// A shared async stream for multiple producers.
///
/// Every clone refers to the same underlying [`WhisperStream`], and access is
/// serialized through an async mutex. Whisper work runs on Tokio's blocking
/// thread pool, so these futures must be awaited inside a Tokio runtime.
#[cfg(feature = "async")]
pub struct SharedAsyncStream {
    inner: Arc<Mutex<AsyncStreamInner>>,
}

/// State guarded by the mutex inside [`SharedAsyncStream`].
#[cfg(feature = "async")]
struct AsyncStreamInner {
    /// The synchronous streaming transcriber.
    stream: WhisperStream,
    /// Segments produced by `feed_and_process`/`flush` not yet taken by `drain_segments`.
    pending_segments: Vec<Segment>,
}

#[cfg(feature = "async")]
impl SharedAsyncStream {
    /// Create a new shared async stream.
    ///
    /// The context is cloned, and [`WhisperStream::with_config`] runs on Tokio's
    /// blocking thread pool. Any heavy setup work therefore does not stall the
    /// async executor.
    ///
    /// # Errors
    ///
    /// Returns the error from [`WhisperStream::with_config`]. Returns a
    /// `TranscriptionError` if the blocking task panics.
    pub async fn new(
        context: &WhisperContext,
        params: FullParams,
        config: WhisperStreamConfig,
    ) -> Result<Self> {
        let context = context.clone();
        let stream =
            Self::run_blocking(move || WhisperStream::with_config(&context, params, config))
                .await?;

        Ok(Self {
            inner: Arc::new(Mutex::new(AsyncStreamInner {
                stream,
                pending_segments: Vec::new(),
            })),
        })
    }

    /// Feed audio and process it atomically, returning the segments this call produced.
    ///
    /// The lock is held for the whole feed-and-process step, so concurrent callers
    /// are serialized. The returned segments are also appended to the pending
    /// buffer read by [`drain_segments`](Self::drain_segments).
    ///
    /// The processing runs on Tokio's blocking thread pool. Once started, it
    /// runs to completion even if this future is dropped.
    ///
    /// # Errors
    ///
    /// Returns the first error from the stream's processing step; segments
    /// produced earlier in the same call are then neither returned nor recorded.
    /// Returns a `TranscriptionError` if the blocking task panics.
    pub async fn feed_and_process(&self, audio: Vec<f32>) -> Result<Vec<Segment>> {
        self.with_inner(move |inner| {
            inner.stream.feed_audio(&audio);

            // Process every step the newly fed audio makes available.
            let mut segments = Vec::new();
            while let Some(segs) = inner.stream.process_step()? {
                segments.extend(segs);
            }

            inner.pending_segments.extend_from_slice(&segments);
            Ok(segments)
        })
        .await
    }

    /// Take every segment accumulated by `feed_and_process` and `flush` since the last drain.
    pub async fn drain_segments(&self) -> Vec<Segment> {
        let mut inner = self.inner.lock().await;
        std::mem::take(&mut inner.pending_segments)
    }

    /// Flush the stream, returning the flushed segments and recording them as pending.
    ///
    /// # Errors
    ///
    /// Returns the error from the stream's flush. Returns a `TranscriptionError`
    /// if the blocking task panics.
    pub async fn flush(&self) -> Result<Vec<Segment>> {
        self.with_inner(|inner| {
            let segments = inner.stream.flush()?;
            inner.pending_segments.extend_from_slice(&segments);
            Ok(segments)
        })
        .await
    }

    /// Lock the shared state and run `job` on it in the blocking pool, holding the lock throughout.
    async fn with_inner<T, F>(&self, job: F) -> Result<T>
    where
        T: Send + 'static,
        F: FnOnce(&mut AsyncStreamInner) -> Result<T> + Send + 'static,
    {
        // An owned guard can move into the blocking task, so the mutex stays
        // held for the whole job without blocking an executor thread.
        let mut guard = Arc::clone(&self.inner).lock_owned().await;
        Self::run_blocking(move || job(&mut *guard)).await
    }

    /// Run `job` on Tokio's blocking pool, mapping a join failure to `TranscriptionError`.
    async fn run_blocking<T, F>(job: F) -> Result<T>
    where
        T: Send + 'static,
        F: FnOnce() -> Result<T> + Send + 'static,
    {
        task::spawn_blocking(job)
            .await
            .map_err(|e| crate::WhisperError::TranscriptionError(e.to_string()))?
    }
}
219
#[cfg(all(test, feature = "async"))]
mod tests {
    use super::*;
    use crate::SamplingStrategy;
    use std::path::Path;

    /// Model used by every test; each test passes trivially when it is absent.
    const MODEL_PATH: &str = "tests/models/ggml-tiny.en.bin";

    /// Load the test model, or `None` when the fixture file is not present.
    fn load_model() -> Option<WhisperContext> {
        if !Path::new(MODEL_PATH).exists() {
            return None;
        }
        Some(WhisperContext::new(MODEL_PATH).unwrap())
    }

    /// Greedy decoding with a single candidate.
    fn greedy_params() -> FullParams {
        FullParams::new(SamplingStrategy::Greedy { best_of: 1 })
    }

    /// One second of silence (16000 samples).
    fn silence() -> Vec<f32> {
        vec![0.0f32; 16000]
    }

    #[tokio::test]
    async fn test_async_transcribe() {
        let ctx = match load_model() {
            Some(ctx) => ctx,
            None => return,
        };

        let outcome = ctx.transcribe_async(silence()).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_async_stream() {
        let ctx = match load_model() {
            Some(ctx) => ctx,
            None => return,
        };

        let created = AsyncWhisperStream::new(ctx, greedy_params());
        assert!(created.is_ok());
        let stream = created.unwrap();

        let fed = stream.feed_audio(silence()).await;
        assert!(fed.is_ok());

        let stopped = stream.stop().await;
        assert!(stopped.is_ok());
    }

    #[tokio::test]
    async fn test_shared_stream() {
        let ctx = match load_model() {
            Some(ctx) => ctx,
            None => return,
        };

        let created =
            SharedAsyncStream::new(&ctx, greedy_params(), WhisperStreamConfig::default()).await;
        assert!(created.is_ok());
        let stream = created.unwrap();

        // Two producers feed the same shared stream concurrently.
        let producers: Vec<_> = (0..2)
            .map(|_| {
                let producer = stream.clone();
                tokio::spawn(async move { producer.feed_and_process(silence()).await })
            })
            .collect();

        // Wait for every producer before checking any outcome.
        let mut outcomes = Vec::with_capacity(producers.len());
        for producer in producers {
            outcomes.push(producer.await.unwrap());
        }

        for outcome in outcomes {
            assert!(outcome.is_ok());
        }
    }
}
295
// Hand-written `Clone`: every clone shares the same `Arc`, so all clones
// drive one underlying stream behind one mutex.
#[cfg(feature = "async")]
impl Clone for SharedAsyncStream {
    fn clone(&self) -> Self {
        let shared = Arc::clone(&self.inner);
        Self { inner: shared }
    }
}