whisper_cpp_plus/
async_api.rs1#[cfg(feature = "async")]
7use crate::{
8 context::WhisperContext, error::Result, params::FullParams,
9 state::WhisperState, stream::{WhisperStreamConfig, WhisperStream}, Segment, TranscriptionResult,
10};
11#[cfg(feature = "async")]
12use std::sync::Arc;
13#[cfg(feature = "async")]
14use tokio::sync::{mpsc, oneshot, Mutex};
15#[cfg(feature = "async")]
16use tokio::task;
17
#[cfg(feature = "async")]
impl WhisperContext {
    /// Runs [`WhisperContext::transcribe`] on Tokio's blocking thread pool.
    ///
    /// A clone of the context and the owned `audio` buffer are moved into the
    /// blocking task, so the calling task is only suspended while it waits.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `transcribe`, or a
    /// `WhisperError::TranscriptionError` carrying the join error text when the
    /// blocking task could not be joined.
    pub async fn transcribe_async(&self, audio: Vec<f32>) -> Result<String> {
        let worker_ctx = self.clone();
        let job = move || worker_ctx.transcribe(&audio);

        match task::spawn_blocking(job).await {
            Ok(outcome) => outcome,
            Err(join_err) => Err(crate::WhisperError::TranscriptionError(
                join_err.to_string(),
            )),
        }
    }

    /// Runs [`WhisperContext::transcribe_with_params`] on Tokio's blocking
    /// thread pool.
    ///
    /// `audio` and `params` are moved into the blocking task together with a
    /// clone of the context.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `transcribe_with_params`, or a
    /// `WhisperError::TranscriptionError` carrying the join error text when the
    /// blocking task could not be joined.
    pub async fn transcribe_with_params_async(
        &self,
        audio: Vec<f32>,
        params: crate::TranscriptionParams,
    ) -> Result<TranscriptionResult> {
        let worker_ctx = self.clone();
        let job = move || worker_ctx.transcribe_with_params(&audio, params);

        match task::spawn_blocking(job).await {
            Ok(outcome) => outcome,
            Err(join_err) => Err(crate::WhisperError::TranscriptionError(
                join_err.to_string(),
            )),
        }
    }

    /// Runs [`WhisperContext::create_state`] on Tokio's blocking thread pool.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `create_state`, or a
    /// `WhisperError::TranscriptionError` carrying the join error text when the
    /// blocking task could not be joined.
    pub async fn create_state_async(&self) -> Result<WhisperState> {
        let worker_ctx = self.clone();
        let job = move || worker_ctx.create_state();

        match task::spawn_blocking(job).await {
            Ok(outcome) => outcome,
            Err(join_err) => Err(crate::WhisperError::TranscriptionError(
                join_err.to_string(),
            )),
        }
    }
}
48
#[cfg(feature = "async")]
/// Async handle to a `WhisperStream` that lives on a blocking worker task.
///
/// Audio and control commands are sent to the worker over a channel, and the
/// worker sends batches of segments back over a second channel.
pub struct AsyncWhisperStream {
    /// Command channel into the worker (feed / flush / stop).
    audio_tx: mpsc::Sender<AudioCommand>,
    /// Batches of segments emitted by the worker while it processes fed audio.
    segment_rx: mpsc::Receiver<Vec<Segment>>,
    /// Join handle of the blocking worker; yields the worker's final result.
    handle: task::JoinHandle<Result<()>>,
}
59
#[cfg(feature = "async")]
/// Messages sent from an `AsyncWhisperStream` handle to its worker task.
enum AudioCommand {
    /// Audio samples to append to the stream and process.
    Feed(Vec<f32>),
    /// Request to flush the stream; the resulting segments are sent back
    /// through the enclosed one-shot sender.
    Flush(oneshot::Sender<Vec<Segment>>),
    /// Ask the worker to leave its command loop.
    Stop,
}
66
#[cfg(feature = "async")]
impl AsyncWhisperStream {
    /// Capacity of both the command channel and the segment channel.
    const CHANNEL_CAPACITY: usize = 100;

    /// Creates a stream using `WhisperStreamConfig::default()`.
    ///
    /// See [`AsyncWhisperStream::with_config`] for details.
    pub fn new(context: WhisperContext, params: FullParams) -> Result<Self> {
        Self::with_config(context, params, WhisperStreamConfig::default())
    }

    /// Spawns the blocking worker that owns the underlying `WhisperStream`
    /// and returns a handle for talking to it.
    ///
    /// The `WhisperStream` itself is constructed inside the worker, so a
    /// construction failure is not reported here; it becomes the worker's
    /// result and is returned by [`AsyncWhisperStream::stop`].
    ///
    /// # Panics
    ///
    /// `tokio::task::spawn_blocking` needs a Tokio runtime context, so this
    /// must be called from within one.
    pub fn with_config(
        context: WhisperContext,
        params: FullParams,
        config: WhisperStreamConfig,
    ) -> Result<Self> {
        let (audio_tx, mut audio_rx) = mpsc::channel::<AudioCommand>(Self::CHANNEL_CAPACITY);
        let (segment_tx, segment_rx) = mpsc::channel::<Vec<Segment>>(Self::CHANNEL_CAPACITY);

        let handle = task::spawn_blocking(move || {
            let mut stream = WhisperStream::with_config(&context, params, config)?;

            // The loop ends on `Stop`, or when every sender has been dropped.
            while let Some(cmd) = audio_rx.blocking_recv() {
                match cmd {
                    AudioCommand::Feed(audio) => {
                        stream.feed_audio(&audio);

                        // Drain every step that is ready before taking the
                        // next command.
                        while let Some(segments) = stream.process_step()? {
                            if !segments.is_empty() {
                                // Best effort: a send only fails once the
                                // receiving side is gone, in which case nobody
                                // is interested in the segments any more.
                                let _ = segment_tx.blocking_send(segments);
                            }
                        }
                    }
                    AudioCommand::Flush(response) => {
                        let segments = stream.flush()?;
                        // The requester may have stopped waiting; ignore that.
                        let _ = response.send(segments);
                    }
                    AudioCommand::Stop => break,
                }
            }

            Ok(())
        });

        Ok(Self {
            audio_tx,
            segment_rx,
            handle,
        })
    }

    /// Queues `audio` for processing by the worker.
    ///
    /// Waits while the command channel is full.
    ///
    /// # Errors
    ///
    /// Returns `WhisperError::TranscriptionError("Stream closed")` when the
    /// worker is no longer receiving commands.
    pub async fn feed_audio(&self, audio: Vec<f32>) -> Result<()> {
        self.audio_tx
            .send(AudioCommand::Feed(audio))
            .await
            .map_err(|_| crate::WhisperError::TranscriptionError("Stream closed".into()))
    }

    /// Waits for the next batch of segments produced by the worker.
    ///
    /// Returns `None` once the segment channel is closed and empty.
    pub async fn recv_segments(&mut self) -> Option<Vec<Segment>> {
        self.segment_rx.recv().await
    }

    /// Returns the next batch of segments if one is already queued, without
    /// waiting. `None` covers both "nothing queued" and "channel closed".
    pub fn try_recv_segments(&mut self) -> Option<Vec<Segment>> {
        self.segment_rx.try_recv().ok()
    }

    /// Asks the worker to flush the stream and waits for the flushed segments.
    ///
    /// # Errors
    ///
    /// Returns `"Stream closed"` when the command cannot be delivered, and
    /// `"Failed to flush"` when the worker drops the reply channel without
    /// answering (for example because flushing failed inside the worker; that
    /// underlying error is then the worker's result, returned by
    /// [`AsyncWhisperStream::stop`]).
    pub async fn flush(&self) -> Result<Vec<Segment>> {
        let (tx, rx) = oneshot::channel();
        self.audio_tx
            .send(AudioCommand::Flush(tx))
            .await
            .map_err(|_| crate::WhisperError::TranscriptionError("Stream closed".into()))?;
        rx.await
            .map_err(|_| crate::WhisperError::TranscriptionError("Failed to flush".into()))
    }

    /// Stops the worker and returns its final result.
    ///
    /// Segments that have not been received yet are discarded: the handle is
    /// consumed, so they could not be read afterwards anyway.
    ///
    /// # Errors
    ///
    /// Returns the worker's own error if it failed, or a
    /// `WhisperError::TranscriptionError` carrying the join error text when the
    /// worker task could not be joined.
    pub async fn stop(self) -> Result<()> {
        let Self {
            audio_tx,
            segment_rx,
            handle,
        } = self;

        // Close the segment channel first. If it is full, the worker is parked
        // in `blocking_send` and would never get to read `Stop`; dropping the
        // receiver makes that send fail immediately so the worker can go on.
        drop(segment_rx);

        // Best effort: the worker may already have exited.
        let _ = audio_tx.send(AudioCommand::Stop).await;
        // With the last sender gone the command loop ends even if `Stop`
        // could not be delivered.
        drop(audio_tx);

        handle
            .await
            .map_err(|e| crate::WhisperError::TranscriptionError(e.to_string()))?
    }
}
156
#[cfg(feature = "async")]
/// A `WhisperStream` behind a shared, asynchronously locked mutex.
///
/// All state is kept behind an `Arc`, so several owners can hold a reference
/// to the same underlying stream.
pub struct SharedAsyncStream {
    /// Shared stream state, guarded by a Tokio mutex.
    inner: Arc<Mutex<AsyncStreamInner>>,
}
162
#[cfg(feature = "async")]
/// State guarded by the mutex inside `SharedAsyncStream`.
struct AsyncStreamInner {
    /// The underlying streaming transcriber.
    stream: WhisperStream,
    /// Segments produced so far that have not yet been taken out of this
    /// buffer.
    pending_segments: Vec<Segment>,
}
168
#[cfg(feature = "async")]
impl SharedAsyncStream {
    /// Builds the underlying `WhisperStream` and wraps it for shared use.
    ///
    /// # Errors
    ///
    /// Returns the error from `WhisperStream::with_config` if the stream
    /// cannot be created.
    pub async fn new(
        context: &WhisperContext,
        params: FullParams,
        config: WhisperStreamConfig,
    ) -> Result<Self> {
        let stream = WhisperStream::with_config(context, params, config)?;

        Ok(Self {
            inner: Arc::new(Mutex::new(AsyncStreamInner {
                stream,
                pending_segments: Vec::new(),
            })),
        })
    }

    /// Feeds `audio` into the stream, runs every processing step that is
    /// ready, and returns the segments those steps produced.
    ///
    /// The same segments are also appended to the pending buffer read by
    /// [`SharedAsyncStream::drain_segments`].
    ///
    /// The work runs on Tokio's blocking thread pool so that the calling
    /// executor thread is not stalled while the stream is processing. Once
    /// started, the work runs to completion even if this future is dropped.
    ///
    /// # Errors
    ///
    /// Returns the error from `process_step`, or a
    /// `WhisperError::TranscriptionError` carrying the join error text when the
    /// blocking task could not be joined.
    pub async fn feed_and_process(&self, audio: Vec<f32>) -> Result<Vec<Segment>> {
        self.run_blocking(move |inner| {
            inner.stream.feed_audio(&audio);

            let mut segments: Vec<Segment> = Vec::new();
            while let Some(batch) = inner.stream.process_step()? {
                segments.extend(batch);
            }

            inner.pending_segments.extend_from_slice(&segments);
            Ok(segments)
        })
        .await
    }

    /// Takes every pending segment, leaving the pending buffer empty.
    pub async fn drain_segments(&self) -> Vec<Segment> {
        let mut inner = self.inner.lock().await;
        std::mem::take(&mut inner.pending_segments)
    }

    /// Flushes the stream and returns the segments produced by the flush.
    ///
    /// The same segments are also appended to the pending buffer. Like
    /// [`SharedAsyncStream::feed_and_process`], the work runs on Tokio's
    /// blocking thread pool.
    ///
    /// # Errors
    ///
    /// Returns the error from the underlying `flush`, or a
    /// `WhisperError::TranscriptionError` carrying the join error text when the
    /// blocking task could not be joined.
    pub async fn flush(&self) -> Result<Vec<Segment>> {
        self.run_blocking(|inner| {
            let segments = inner.stream.flush()?;
            inner.pending_segments.extend_from_slice(&segments);
            Ok(segments)
        })
        .await
    }

    /// Runs `op` against the locked inner state on the blocking thread pool.
    ///
    /// `blocking_lock` is used because the closure runs outside the async
    /// context; callers of this helper never hold the lock themselves.
    async fn run_blocking<T, F>(&self, op: F) -> Result<T>
    where
        F: FnOnce(&mut AsyncStreamInner) -> Result<T> + Send + 'static,
        T: Send + 'static,
    {
        let shared = Arc::clone(&self.inner);
        let joined = task::spawn_blocking(move || {
            let mut guard = shared.blocking_lock();
            op(&mut *guard)
        })
        .await;

        joined.map_err(|e| crate::WhisperError::TranscriptionError(e.to_string()))?
    }
}
219
#[cfg(all(test, feature = "async"))]
mod tests {
    use super::*;
    use crate::SamplingStrategy;
    use std::path::Path;

    /// Model file the tests need; each test returns early (and passes) when it
    /// is not present.
    const MODEL_PATH: &str = "tests/models/ggml-tiny.en.bin";

    /// Reports whether the model file exists on disk.
    fn model_available() -> bool {
        Path::new(MODEL_PATH).exists()
    }

    #[tokio::test]
    async fn test_async_transcribe() {
        if !model_available() {
            return;
        }

        let ctx = WhisperContext::new(MODEL_PATH).unwrap();
        // 16000 zero-valued samples.
        let samples = vec![0.0f32; 16000];

        let outcome = ctx.transcribe_async(samples).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_async_stream() {
        if !model_available() {
            return;
        }

        let ctx = WhisperContext::new(MODEL_PATH).unwrap();
        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });

        let created = AsyncWhisperStream::new(ctx, params);
        assert!(created.is_ok());
        let stream = created.unwrap();

        // Feeding audio must be accepted by the worker.
        let samples = vec![0.0f32; 16000];
        let fed = stream.feed_audio(samples).await;
        assert!(fed.is_ok());

        // Stopping must join the worker without error.
        let stopped = stream.stop().await;
        assert!(stopped.is_ok());
    }

    #[tokio::test]
    async fn test_shared_stream() {
        if !model_available() {
            return;
        }

        let ctx = WhisperContext::new(MODEL_PATH).unwrap();
        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });

        let created = SharedAsyncStream::new(&ctx, params, WhisperStreamConfig::default()).await;
        assert!(created.is_ok());
        let stream = created.unwrap();

        // Two tasks feed the same shared stream through separate clones.
        let first_clone = stream.clone();
        let first_task = tokio::spawn(async move {
            let samples = vec![0.0f32; 16000];
            first_clone.feed_and_process(samples).await
        });

        let second_clone = stream.clone();
        let second_task = tokio::spawn(async move {
            let samples = vec![0.0f32; 16000];
            second_clone.feed_and_process(samples).await
        });

        let first_outcome = first_task.await.unwrap();
        let second_outcome = second_task.await.unwrap();

        assert!(first_outcome.is_ok());
        assert!(second_outcome.is_ok());
    }
}
295
#[cfg(feature = "async")]
impl Clone for SharedAsyncStream {
    /// Returns another handle to the same underlying stream.
    ///
    /// Only the reference count of the shared state is incremented; the stream
    /// itself is not duplicated.
    fn clone(&self) -> Self {
        let inner = Arc::clone(&self.inner);
        Self { inner }
    }
}