whisper_stream_rs/whisper_stream.rs

use std::sync::mpsc::{self, Receiver};
use std::thread;
use crate::model::Model;

/// Events emitted by the transcription stream.
///
/// These are sent through the channel returned by [`WhisperStreamBuilder::build`].
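///
/// A minimal sketch of consuming these events (the model choice and the way low-quality
/// text is filtered below are illustrative, not prescribed):
/// ```no_run
/// use whisper_stream_rs::{WhisperStream, Event, Model};
/// let (_stream, rx) = WhisperStream::builder().model(Model::BaseEn).build().unwrap();
/// for event in rx {
///     match event {
///         // Provisional text may still change; show it but do not store it.
///         Event::ProvisionalLiveUpdate { text, is_low_quality } => {
///             if !is_low_quality {
///                 println!("Live: {}", text);
///             }
///         }
///         // Final text for a completed segment window.
///         Event::SegmentTranscript { text, .. } => println!("Final: {}", text),
///         Event::SystemMessage(msg) => eprintln!("{}", msg),
///         Event::Error(err) => eprintln!("error: {:?}", err),
///     }
/// }
/// ```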
#[derive(Debug)]
pub enum Event {
    /// A provisional, live text update for the audio currently being processed.
    /// This is an intermediate result, suitable for real-time display. It is subject to change
    /// and will be superseded by later `ProvisionalLiveUpdate` messages (more refined guesses
    /// for the same ongoing audio) or ultimately by a `SegmentTranscript` for that segment.
    /// Provisional updates should not be stored or treated as definitive.
    ///
    /// `is_low_quality` is true if the text is considered low quality by the detector.
    ProvisionalLiveUpdate { text: String, is_low_quality: bool },

    /// The final transcription for a completed audio segment window.
    /// This is the text that should be treated as the definitive output for that portion of audio.
    ///
    /// `is_low_quality` is true if the text is considered low quality by the detector.
    SegmentTranscript { text: String, is_low_quality: bool },

    /// System messages (e.g., recording status, warnings).
    SystemMessage(String),
    /// Errors encountered during processing.
    Error(crate::error::WhisperStreamError),
}

/// Main entry point for configuring and running a Whisper transcription stream.
///
/// Use [`WhisperStream::builder()`] to create a [`WhisperStreamBuilder`], configure options,
/// and call `.build()` to start streaming and receive events.
///
/// Example:
/// ```no_run
/// use whisper_stream_rs::{WhisperStream, Event, Model};
/// let (_stream, rx) = WhisperStream::builder().model(Model::TinyEn).build().unwrap();
/// for event in rx {
///     match event {
///         Event::SegmentTranscript { text, .. } => println!("Final: {}", text),
///         _ => {}
///     }
/// }
/// ```
pub struct WhisperStream {
    // Will own the background thread and config
}

/// Builder for [`WhisperStream`].
///
/// All configuration is set via builder methods. Call `.build()` to start streaming and receive events.
///
/// Example:
/// ```no_run
/// use whisper_stream_rs::{WhisperStream, Model};
/// let (_stream, rx) = WhisperStream::builder()
///     .device("MacBook Pro Microphone")
///     .language("en")
///     .model(Model::SmallEn)
///     .build()
///     .unwrap();
/// ```
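///
/// The sliding-window timing can also be tuned; the values below are illustrative
/// rather than recommended settings:
/// ```no_run
/// use whisper_stream_rs::{WhisperStream, Model};
/// let (_stream, rx) = WhisperStream::builder()
///     .model(Model::BaseEn)
///     .step_ms(500)            // audio capture step size
///     .length_ms(8000)         // length of each transcribed segment window
///     .keep_ms(400)            // audio carried over between windows
///     .compute_partials(false) // only emit final segment transcripts
///     .build()
///     .unwrap();
/// ```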
pub struct WhisperStreamBuilder {
    device: Option<String>,
    language: Option<String>,
    record_to_wav: Option<String>,
    step_ms: u32,
    length_ms: u32,
    keep_ms: u32,
    max_tokens: i32,
    n_threads: i32,
    compute_partials: bool,
    logging_enabled: bool,
    model: Option<Model>,
}
78
79impl WhisperStreamBuilder {
80    pub fn device(mut self, name: &str) -> Self {
81        self.device = Some(name.to_string());
82        self
83    }
84    pub fn language(mut self, lang: &str) -> Self {
85        self.language = Some(lang.to_string());
86        self
87    }
88    pub fn record_to_wav(mut self, path: &str) -> Self {
89        self.record_to_wav = Some(path.to_string());
90        self
91    }
92    pub fn step_ms(mut self, ms: u32) -> Self {
93        self.step_ms = ms;
94        self
95    }
96    pub fn length_ms(mut self, ms: u32) -> Self {
97        self.length_ms = ms;
98        self
99    }
100    pub fn keep_ms(mut self, ms: u32) -> Self {
101        self.keep_ms = ms;
102        self
103    }
104    pub fn max_tokens(mut self, n: i32) -> Self {
105        self.max_tokens = n;
106        self
107    }
108    pub fn n_threads(mut self, n: i32) -> Self {
109        self.n_threads = n;
110        self
111    }
112    pub fn compute_partials(mut self, enabled: bool) -> Self {
113        self.compute_partials = enabled;
114        self
115    }
116    pub fn disable_logging(mut self) -> Self {
117        self.logging_enabled = false;
118        self
119    }
120    pub fn model(mut self, model: Model) -> Self {
121        self.model = Some(model);
122        self
123    }
    pub fn build(self) -> Result<(WhisperStream, Receiver<Event>), crate::error::WhisperStreamError> {
        // Set up logging if enabled
        if self.logging_enabled {
            // Safe to call multiple times; only installs once
            whisper_rs::install_logging_hooks();
        }

        let (tx, rx) = mpsc::channel();
        let config = self;
        let selected_model = config.model.unwrap_or(Model::BaseEn);
        thread::spawn(move || {
            use crate::model::ensure_model;
            use crate::audio::AudioInput;
            use crate::audio_utils::{pad_audio_if_needed, WavAudioRecorder};
            use whisper_rs::{WhisperContext, WhisperContextParameters, FullParams, SamplingStrategy};
            use log::info;
            use std::sync::Arc;

            const MIN_WHISPER_SAMPLES: usize = 16800; // 1050ms at 16kHz (increased buffer)

            let model_path = match ensure_model(selected_model) {
                Ok(p) => p,
                Err(e) => {
                    let _ = tx.send(Event::Error(e));
                    return;
                }
            };

            let system_info = whisper_rs::print_system_info();
            info!("Whisper System Info: \n{}", system_info);

            let ctx = match WhisperContext::new_with_params(
                model_path.to_str().unwrap_or("invalid_model_path"),
                WhisperContextParameters::default(),
            ) {
                Ok(c) => c,
                Err(e) => {
                    let _ = tx.send(Event::Error(crate::error::WhisperStreamError::from(e)));
                    return;
                }
            };

            let audio_input = match AudioInput::new(config.device.as_deref(), config.step_ms) {
                Ok(input) => input,
                Err(e) => {
                    let _ = tx.send(Event::Error(e));
                    return;
                }
            };
            let audio_rx = audio_input.start_capture_16k();
            let sample_rate = 16000;
            let n_samples_window = (sample_rate as f32 * (config.length_ms as f32 / 1000.0)) as usize;
            let n_samples_overlap = (sample_rate as f32 * (config.keep_ms as f32 / 1000.0)) as usize;
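            // With the builder defaults (length_ms = 5000, keep_ms = 200) at 16 kHz this is an
            // 80_000-sample segment window with a 3_200-sample overlap carried between windows.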
            let mut segment_window: Vec<f32> = Vec::with_capacity(n_samples_window);
            let mut state = match ctx.create_state() {
                Ok(s) => s,
                Err(e) => {
                    let _ = tx.send(Event::Error(crate::error::WhisperStreamError::from(e)));
                    return;
                }
            };

            let mut params_full = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
            params_full.set_n_threads(config.n_threads);
            params_full.set_max_tokens(config.max_tokens);
            params_full.set_print_special(false);
            params_full.set_print_progress(false);
            params_full.set_print_realtime(false);
            params_full.set_print_timestamps(false);
            if let Some(lang) = config.language.as_deref() {
                params_full.set_language(Some(lang));
            }
            let arc_params_full = Arc::new(params_full);

            let mut wav_audio_recorder = match WavAudioRecorder::new(config.record_to_wav.as_deref()) {
                Ok(recorder) => recorder,
                Err(e) => {
                    let _ = tx.send(Event::Error(e));
                    match WavAudioRecorder::new(None) {
                        Ok(no_op_recorder) => no_op_recorder,
                        Err(_) => return,
                    }
                }
            };

            if wav_audio_recorder.is_recording() {
                if let Some(path_str) = config.record_to_wav.as_ref() {
                    info!("[Recording] Saving transcribed audio to {}...", path_str);
                    let _ = tx.send(Event::SystemMessage(format!("[Recording] Saving transcribed audio to {}...", path_str)));
                }
            }

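            // Main loop: append each captured chunk to the sliding window and run a
            // transcription pass over the (padded) window contents.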
            for pcmf32_new_result in audio_rx {
                let pcmf32_new = match pcmf32_new_result {
                    Ok(audio_data) => {
                        if audio_data.is_empty() {
                            continue;
                        }
                        audio_data
                    }
                    Err(audio_err) => {
                        let _ = tx.send(Event::Error(audio_err));
                        continue;
                    }
                };

                if wav_audio_recorder.is_recording() {
                    if let Err(e) = wav_audio_recorder.write_audio_chunk(&pcmf32_new) {
                        let _ = tx.send(Event::Error(e));
                    }
                }

                segment_window.extend_from_slice(&pcmf32_new);
                // Pad short windows up to MIN_WHISPER_SAMPLES before handing them to Whisper.
                let audio_for_processing = pad_audio_if_needed(&segment_window, MIN_WHISPER_SAMPLES);

                if let Err(e) = state.full(arc_params_full.as_ref().clone(), &audio_for_processing) {
                    let _ = tx.send(Event::Error(crate::error::WhisperStreamError::from(e)));
                    continue;
                }

                // Collect the text of all decoded segments for this pass.
                let mut current_text = String::new();
                match state.full_n_segments() {
                    Ok(num_segments) => {
                        for i in 0..num_segments {
                            match state.full_get_segment_text(i) {
                                Ok(seg) => current_text.push_str(&seg),
                                Err(e) => {
                                    let _ = tx.send(Event::Error(crate::error::WhisperStreamError::from(e)));
                                    break;
                                }
                            }
                        }
                    }
                    Err(e) => {
                        let _ = tx.send(Event::Error(crate::error::WhisperStreamError::from(e)));
                        continue;
                    }
                }

                if !current_text.trim().is_empty() {
                    let is_low_quality = crate::score::is_low_quality_output(&current_text);
                    if segment_window.len() >= n_samples_window {
                        // The window is full: this text is the definitive transcript for the segment.
                        let _ = tx.send(Event::SegmentTranscript { text: current_text.clone(), is_low_quality });
                    } else if config.compute_partials {
                        // The window is still filling: emit a provisional update.
                        let _ = tx.send(Event::ProvisionalLiveUpdate { text: current_text.clone(), is_low_quality });
                    }
                }

                if segment_window.len() >= n_samples_window {
                    // Start the next window, keeping the trailing overlap as context.
                    if n_samples_overlap > 0 && segment_window.len() > n_samples_overlap {
                        segment_window = segment_window[segment_window.len() - n_samples_overlap..].to_vec();
                    } else {
                        segment_window.clear();
                    }
                }
            }

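            // The capture channel has closed; transcribe whatever audio is still buffered.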
            if !segment_window.is_empty() {
                let final_audio_for_processing = pad_audio_if_needed(&segment_window, MIN_WHISPER_SAMPLES);
                if let Err(e) = state.full(arc_params_full.as_ref().clone(), &final_audio_for_processing) {
                    let _ = tx.send(Event::Error(crate::error::WhisperStreamError::from(e)));
                } else {
                    let mut final_text = String::new();
                    match state.full_n_segments() {
                        Ok(num_segments) => {
                            for i in 0..num_segments {
                                match state.full_get_segment_text(i) {
                                    Ok(seg) => final_text.push_str(&seg),
                                    Err(e) => {
                                        let _ = tx.send(Event::Error(crate::error::WhisperStreamError::from(e)));
                                        break;
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Event::Error(crate::error::WhisperStreamError::from(e)));
                        }
                    }
                    if !final_text.trim().is_empty() {
                        let is_low_quality = crate::score::is_low_quality_output(&final_text);
                        let _ = tx.send(Event::SegmentTranscript { text: final_text, is_low_quality });
                    }
                }
            }

            match wav_audio_recorder.finalize() {
                Ok(Some(msg)) => {
                    info!("{}", msg);
                    let _ = tx.send(Event::SystemMessage(msg));
                }
                Ok(None) => { /* No recording was active, nothing to report */ }
                Err(e) => {
                    let _ = tx.send(Event::Error(e));
                }
            }
        });
        Ok((WhisperStream {}, rx))
    }
}

impl WhisperStream {
    /// Creates a [`WhisperStreamBuilder`] initialized with the default configuration.
    pub fn builder() -> WhisperStreamBuilder {
        WhisperStreamBuilder {
            device: None,
            language: Some("en".to_string()),
            record_to_wav: None,
            step_ms: 800,
            length_ms: 5000,
            keep_ms: 200,
            max_tokens: 32,
            n_threads: std::thread::available_parallelism().map(|n| n.get() as i32).unwrap_or(8),
            compute_partials: true,
            logging_enabled: true,
            model: None,
        }
    }
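    /// Returns the names of the available audio input devices.
    ///
    /// A minimal usage sketch (device names are platform-dependent):
    /// ```no_run
    /// use whisper_stream_rs::WhisperStream;
    /// for name in WhisperStream::list_devices().unwrap() {
    ///     println!("{}", name);
    /// }
    /// ```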
    pub fn list_devices() -> Result<Vec<String>, crate::error::WhisperStreamError> {
        crate::audio::AudioInput::available_input_devices()
    }
    /// Returns the available Whisper models.
    pub fn list_models() -> Vec<Model> {
        Model::list()
    }
    pub fn start(&mut self) -> Result<(), crate::error::WhisperStreamError> {
        // Will start the background thread in next phase
        Ok(())
    }
    pub fn stop(&mut self) -> Result<(), crate::error::WhisperStreamError> {
        // Will stop the background thread in next phase
        Ok(())
    }
}
355}