whisper_cpp_plus/
async_api.rs1#[cfg(feature = "async")]
7use crate::{
8 context::WhisperContext,
9 error::Result,
10 params::FullParams,
11 state::WhisperState,
12 stream::{WhisperStream, WhisperStreamConfig},
13 Segment, TranscriptionResult,
14};
15#[cfg(feature = "async")]
16use std::sync::Arc;
17#[cfg(feature = "async")]
18use tokio::sync::{mpsc, oneshot, Mutex};
19#[cfg(feature = "async")]
20use tokio::task;
21
#[cfg(feature = "async")]
impl WhisperContext {
    /// Runs `transcribe` on Tokio's blocking pool and returns its result.
    ///
    /// A clone of the context and the owned `audio` buffer are moved into the
    /// blocking task, so the returned future does not borrow from `self`.
    ///
    /// # Errors
    ///
    /// Propagates whatever `transcribe` returns. If the blocking task cannot be
    /// joined, the join error's message is wrapped in
    /// `WhisperError::TranscriptionError`.
    pub async fn transcribe_async(&self, audio: Vec<f32>) -> Result<String> {
        let context = self.clone();
        let joined = task::spawn_blocking(move || context.transcribe(&audio)).await;

        match joined {
            Ok(outcome) => outcome,
            Err(join_error) => Err(crate::WhisperError::TranscriptionError(
                join_error.to_string(),
            )),
        }
    }

    /// Runs `transcribe_with_params` on Tokio's blocking pool.
    ///
    /// `audio` and `params` are moved into the blocking task together with a
    /// clone of the context.
    ///
    /// # Errors
    ///
    /// Propagates whatever `transcribe_with_params` returns. If the blocking
    /// task cannot be joined, the join error's message is wrapped in
    /// `WhisperError::TranscriptionError`.
    pub async fn transcribe_with_params_async(
        &self,
        audio: Vec<f32>,
        params: crate::TranscriptionParams,
    ) -> Result<TranscriptionResult> {
        let context = self.clone();
        let joined =
            task::spawn_blocking(move || context.transcribe_with_params(&audio, params)).await;

        match joined {
            Ok(outcome) => outcome,
            Err(join_error) => Err(crate::WhisperError::TranscriptionError(
                join_error.to_string(),
            )),
        }
    }

    /// Runs `create_state` on Tokio's blocking pool and returns the new state.
    ///
    /// # Errors
    ///
    /// Propagates whatever `create_state` returns. If the blocking task cannot
    /// be joined, the join error's message is wrapped in
    /// `WhisperError::TranscriptionError`.
    pub async fn create_state_async(&self) -> Result<WhisperState> {
        let context = self.clone();
        let joined = task::spawn_blocking(move || context.create_state()).await;

        match joined {
            Ok(outcome) => outcome,
            Err(join_error) => Err(crate::WhisperError::TranscriptionError(
                join_error.to_string(),
            )),
        }
    }
}
52
/// Async front end to a `WhisperStream` that lives on a blocking worker task.
///
/// Commands travel to the worker over `audio_tx`; batches of segments come
/// back over `segment_rx`.
#[cfg(feature = "async")]
pub struct AsyncWhisperStream {
    /// Command channel into the worker (audio chunks, flush requests, stop).
    audio_tx: mpsc::Sender<AudioCommand>,
    /// Segment batches emitted by the worker while it processes fed audio.
    segment_rx: mpsc::Receiver<Vec<Segment>>,
    /// Join handle of the blocking worker; resolves to the worker's result.
    handle: task::JoinHandle<Result<()>>,
}
63
/// Messages sent from an `AsyncWhisperStream` handle to its worker task.
#[cfg(feature = "async")]
enum AudioCommand {
    /// Append these samples to the stream and run any processing steps that
    /// become available.
    Feed(Vec<f32>),
    /// Flush the stream and send the resulting segments back on the enclosed
    /// one-shot channel.
    Flush(oneshot::Sender<Vec<Segment>>),
    /// Leave the worker loop.
    Stop,
}
70
#[cfg(feature = "async")]
impl AsyncWhisperStream {
    /// Starts a streaming worker configured with `WhisperStreamConfig::default()`.
    ///
    /// This is a thin wrapper around [`AsyncWhisperStream::with_config`]; the
    /// same caveats apply.
    pub fn new(context: WhisperContext, params: FullParams) -> Result<Self> {
        Self::with_config(context, params, WhisperStreamConfig::default())
    }

    /// Starts a streaming worker on Tokio's blocking pool.
    ///
    /// Two bounded channels (capacity 100 each) connect the handle and the
    /// worker: one carries commands to the worker, the other carries segment
    /// batches back. The `WhisperStream` itself is created *inside* the worker,
    /// so a failure to create it does not make this function return `Err`;
    /// it terminates the worker instead and is reported by
    /// [`AsyncWhisperStream::stop`]. Calls made after the worker has exited
    /// fail with "Stream closed".
    ///
    /// NOTE: this uses `tokio::task::spawn_blocking`, so it has to be called
    /// from within a Tokio runtime context.
    pub fn with_config(
        context: WhisperContext,
        params: FullParams,
        config: WhisperStreamConfig,
    ) -> Result<Self> {
        let (audio_tx, mut audio_rx) = mpsc::channel::<AudioCommand>(100);
        let (segment_tx, segment_rx) = mpsc::channel::<Vec<Segment>>(100);

        let handle = task::spawn_blocking(move || {
            let mut stream = WhisperStream::with_config(&context, params, config)?;

            // `blocking_recv` yields `None` once every sender is gone, which
            // ends the worker when the handle is dropped without `stop`.
            while let Some(cmd) = audio_rx.blocking_recv() {
                match cmd {
                    AudioCommand::Feed(audio) => {
                        stream.feed_audio(&audio);

                        // Run every processing step the new audio made
                        // available; empty batches are not forwarded.
                        while let Some(segments) = stream.process_step()? {
                            if !segments.is_empty() {
                                // A send error only means the receiving side
                                // is gone, so nobody is left to read segments.
                                let _ = segment_tx.blocking_send(segments);
                            }
                        }
                    }
                    AudioCommand::Flush(response) => {
                        let segments = stream.flush()?;
                        // The requester may have given up waiting; that is
                        // not an error for the worker.
                        let _ = response.send(segments);
                    }
                    AudioCommand::Stop => break,
                }
            }

            Ok(())
        });

        Ok(Self {
            audio_tx,
            segment_rx,
            handle,
        })
    }

    /// Queues a chunk of samples for the worker.
    ///
    /// Waits for room in the command channel when it is full.
    ///
    /// # Errors
    ///
    /// Returns `WhisperError::TranscriptionError("Stream closed")` when the
    /// worker is no longer receiving commands.
    pub async fn feed_audio(&self, audio: Vec<f32>) -> Result<()> {
        self.audio_tx
            .send(AudioCommand::Feed(audio))
            .await
            .map_err(|_| crate::WhisperError::TranscriptionError("Stream closed".into()))
    }

    /// Waits for the next batch of segments from the worker.
    ///
    /// Returns `None` once the worker has exited and all queued batches have
    /// been consumed.
    pub async fn recv_segments(&mut self) -> Option<Vec<Segment>> {
        self.segment_rx.recv().await
    }

    /// Returns the next batch of segments if one is already queued.
    ///
    /// Returns `None` both when the queue is currently empty and when the
    /// worker has exited; the two cases are not distinguished.
    pub fn try_recv_segments(&mut self) -> Option<Vec<Segment>> {
        self.segment_rx.try_recv().ok()
    }

    /// Asks the worker to flush the stream and returns the flushed segments.
    ///
    /// The request is queued behind any audio already fed, so it is handled
    /// after that audio has been processed.
    ///
    /// # Errors
    ///
    /// * "Stream closed" when the worker is no longer receiving commands.
    /// * "Failed to flush" when the worker dropped the reply channel without
    ///   answering, e.g. because it exited on an error. The underlying error
    ///   is then reported by [`AsyncWhisperStream::stop`].
    pub async fn flush(&self) -> Result<Vec<Segment>> {
        let (tx, rx) = oneshot::channel();
        self.audio_tx
            .send(AudioCommand::Flush(tx))
            .await
            .map_err(|_| crate::WhisperError::TranscriptionError("Stream closed".into()))?;
        rx.await
            .map_err(|_| crate::WhisperError::TranscriptionError("Failed to flush".into()))
    }

    /// Stops the worker and returns its final result.
    ///
    /// Commands queued before the stop request are still handled first.
    /// Segments that have not been received yet are discarded, since the
    /// handle is consumed.
    ///
    /// # Errors
    ///
    /// Returns the error that terminated the worker, if any (including a
    /// failure to create the underlying stream). If the worker task cannot be
    /// joined, the join error's message is wrapped in
    /// `WhisperError::TranscriptionError`.
    pub async fn stop(self) -> Result<()> {
        let Self {
            audio_tx,
            segment_rx,
            handle,
        } = self;

        // Close the segment channel before signalling. If it is full, the
        // worker is parked in `blocking_send` and would never reach the stop
        // command while the receiver is alive, leaving this call waiting on
        // the join handle forever. Dropping the receiver makes pending and
        // future sends fail fast so the worker can proceed.
        drop(segment_rx);

        // A send error means the worker has already exited; joining below
        // still reports how it ended.
        let _ = audio_tx.send(AudioCommand::Stop).await;

        handle
            .await
            .map_err(|e| crate::WhisperError::TranscriptionError(e.to_string()))?
    }
}
157
/// A `WhisperStream` behind an async mutex, usable from several tasks.
///
/// Clones share the same underlying stream and pending-segment buffer.
#[cfg(feature = "async")]
pub struct SharedAsyncStream {
    /// Shared, lock-protected stream state.
    inner: Arc<Mutex<AsyncStreamInner>>,
}
163
/// State guarded by the mutex inside `SharedAsyncStream`.
#[cfg(feature = "async")]
struct AsyncStreamInner {
    /// The streaming transcriber that audio is fed into.
    stream: WhisperStream,
    /// Segments produced so far and not yet taken by `drain_segments`.
    pending_segments: Vec<Segment>,
}
169
#[cfg(feature = "async")]
impl SharedAsyncStream {
    /// Creates a shared stream from `context`, `params` and `config`.
    ///
    /// The stream is built with `WhisperStream::with_config` and starts with an
    /// empty pending-segment buffer.
    ///
    /// # Errors
    ///
    /// Propagates the error from `WhisperStream::with_config`.
    pub async fn new(
        context: &WhisperContext,
        params: FullParams,
        config: WhisperStreamConfig,
    ) -> Result<Self> {
        let stream = WhisperStream::with_config(context, params, config)?;

        Ok(Self {
            inner: Arc::new(Mutex::new(AsyncStreamInner {
                stream,
                pending_segments: Vec::new(),
            })),
        })
    }

    /// Feeds `audio` into the stream, runs every available processing step and
    /// returns the segments those steps produced.
    ///
    /// The same segments are also appended to the pending buffer read by
    /// [`SharedAsyncStream::drain_segments`]. Concurrent callers are
    /// serialized by the internal mutex.
    ///
    /// The processing runs on Tokio's blocking pool rather than on the calling
    /// task, so synchronous stream work does not stall the async executor.
    /// Once started, the work runs to completion even if the returned future
    /// is dropped; its segments then remain available through the pending
    /// buffer.
    ///
    /// # Errors
    ///
    /// Propagates the first error from `process_step`; segments gathered
    /// earlier in the same call are discarded in that case. If the blocking
    /// task cannot be joined, the join error's message is wrapped in
    /// `WhisperError::TranscriptionError`.
    pub async fn feed_and_process(&self, audio: Vec<f32>) -> Result<Vec<Segment>> {
        let shared = Arc::clone(&self.inner);

        task::spawn_blocking(move || -> Result<Vec<Segment>> {
            // `blocking_lock` is the lock variant meant for synchronous code
            // such as a blocking-pool thread.
            let mut inner = shared.blocking_lock();

            inner.stream.feed_audio(&audio);

            let mut segments = Vec::new();
            while let Some(batch) = inner.stream.process_step()? {
                segments.extend(batch);
            }

            inner.pending_segments.extend(segments.iter().cloned());

            Ok(segments)
        })
        .await
        .map_err(|e| crate::WhisperError::TranscriptionError(e.to_string()))?
    }

    /// Takes all pending segments, leaving the buffer empty.
    pub async fn drain_segments(&self) -> Vec<Segment> {
        let mut inner = self.inner.lock().await;
        std::mem::take(&mut inner.pending_segments)
    }

    /// Flushes the stream and returns the flushed segments.
    ///
    /// The segments are also appended to the pending buffer. Like
    /// [`SharedAsyncStream::feed_and_process`], the work runs on Tokio's
    /// blocking pool.
    ///
    /// # Errors
    ///
    /// Propagates the error from the stream's `flush`. If the blocking task
    /// cannot be joined, the join error's message is wrapped in
    /// `WhisperError::TranscriptionError`.
    pub async fn flush(&self) -> Result<Vec<Segment>> {
        let shared = Arc::clone(&self.inner);

        task::spawn_blocking(move || -> Result<Vec<Segment>> {
            let mut inner = shared.blocking_lock();
            let segments = inner.stream.flush()?;
            inner.pending_segments.extend(segments.iter().cloned());
            Ok(segments)
        })
        .await
        .map_err(|e| crate::WhisperError::TranscriptionError(e.to_string()))?
    }
}
220
#[cfg(all(test, feature = "async"))]
mod tests {
    use super::*;
    use crate::SamplingStrategy;
    use std::path::Path;

    // Every test here needs the tiny English model on disk and returns early
    // (passing) when the file is absent.

    /// One-shot async transcription of a zero-filled buffer succeeds.
    #[tokio::test]
    async fn test_async_transcribe() {
        let model_path = "tests/models/ggml-tiny.en.bin";
        if !Path::new(model_path).exists() {
            return;
        }

        let ctx = WhisperContext::new(model_path).unwrap();
        let samples = vec![0.0f32; 16000];

        let outcome = ctx.transcribe_async(samples).await;
        assert!(outcome.is_ok());
    }

    /// A streaming worker can be created, fed once and stopped cleanly.
    #[tokio::test]
    async fn test_async_stream() {
        let model_path = "tests/models/ggml-tiny.en.bin";
        if !Path::new(model_path).exists() {
            return;
        }

        let ctx = WhisperContext::new(model_path).unwrap();
        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });

        let created = AsyncWhisperStream::new(ctx, params);
        assert!(created.is_ok());
        let stream = created.unwrap();

        let samples = vec![0.0f32; 16000];
        let fed = stream.feed_audio(samples).await;
        assert!(fed.is_ok());

        let stopped = stream.stop().await;
        assert!(stopped.is_ok());
    }

    /// Two tasks can feed the same shared stream through separate clones.
    #[tokio::test]
    async fn test_shared_stream() {
        let model_path = "tests/models/ggml-tiny.en.bin";
        if !Path::new(model_path).exists() {
            return;
        }

        let ctx = WhisperContext::new(model_path).unwrap();
        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });

        let created = SharedAsyncStream::new(&ctx, params, WhisperStreamConfig::default()).await;
        assert!(created.is_ok());
        let stream = created.unwrap();

        let first = stream.clone();
        let first_task = tokio::spawn(async move {
            let samples = vec![0.0f32; 16000];
            first.feed_and_process(samples).await
        });

        let second = stream.clone();
        let second_task = tokio::spawn(async move {
            let samples = vec![0.0f32; 16000];
            second.feed_and_process(samples).await
        });

        let first_outcome = first_task.await.unwrap();
        let second_outcome = second_task.await.unwrap();

        assert!(first_outcome.is_ok());
        assert!(second_outcome.is_ok());
    }
}
296
#[cfg(feature = "async")]
impl Clone for SharedAsyncStream {
    /// Returns a new handle to the same underlying stream.
    ///
    /// Only the reference count is bumped; the stream and its pending-segment
    /// buffer are shared between the original and the clone.
    fn clone(&self) -> Self {
        let inner = Arc::clone(&self.inner);
        Self { inner }
    }
}