whisper_stream_rs/
whisper_stream.rs1use std::sync::mpsc::{self, Receiver};
2use std::thread;
3use crate::model::Model;
4
5#[derive(Debug)]
9pub enum Event {
10 ProvisionalLiveUpdate { text: String, is_low_quality: bool },
18
19 SegmentTranscript { text: String, is_low_quality: bool },
24
25 SystemMessage(String),
27 Error(crate::error::WhisperStreamError),
29}
30
/// Handle returned by [`WhisperStreamBuilder::build`].
///
/// It currently carries no state. The transcription thread is spawned and
/// detached by `build`, and all results arrive on the `Receiver<Event>`
/// returned alongside this handle.
pub struct WhisperStream {}
50
51pub struct WhisperStreamBuilder {
66 device: Option<String>,
67 language: Option<String>,
68 record_to_wav: Option<String>,
69 step_ms: u32,
70 length_ms: u32,
71 keep_ms: u32,
72 max_tokens: i32,
73 n_threads: i32,
74 compute_partials: bool,
75 logging_enabled: bool,
76 model: Option<Model>,
77}
78
79impl WhisperStreamBuilder {
80 pub fn device(mut self, name: &str) -> Self {
81 self.device = Some(name.to_string());
82 self
83 }
84 pub fn language(mut self, lang: &str) -> Self {
85 self.language = Some(lang.to_string());
86 self
87 }
88 pub fn record_to_wav(mut self, path: &str) -> Self {
89 self.record_to_wav = Some(path.to_string());
90 self
91 }
92 pub fn step_ms(mut self, ms: u32) -> Self {
93 self.step_ms = ms;
94 self
95 }
96 pub fn length_ms(mut self, ms: u32) -> Self {
97 self.length_ms = ms;
98 self
99 }
100 pub fn keep_ms(mut self, ms: u32) -> Self {
101 self.keep_ms = ms;
102 self
103 }
104 pub fn max_tokens(mut self, n: i32) -> Self {
105 self.max_tokens = n;
106 self
107 }
108 pub fn n_threads(mut self, n: i32) -> Self {
109 self.n_threads = n;
110 self
111 }
112 pub fn compute_partials(mut self, enabled: bool) -> Self {
113 self.compute_partials = enabled;
114 self
115 }
116 pub fn disable_logging(mut self) -> Self {
117 self.logging_enabled = false;
118 self
119 }
120 pub fn model(mut self, model: Model) -> Self {
121 self.model = Some(model);
122 self
123 }
124 pub fn build(self) -> Result<(WhisperStream, Receiver<Event>), crate::error::WhisperStreamError> {
125 if self.logging_enabled {
127 whisper_rs::install_logging_hooks();
129 }
130
131 let (tx, rx) = mpsc::channel();
132 let config = self;
133 let selected_model = config.model.unwrap_or(Model::BaseEn);
134 thread::spawn(move || {
135 use crate::model::ensure_model;
136 use crate::audio::{AudioInput};
137 use crate::audio_utils::{pad_audio_if_needed, WavAudioRecorder};
138 use whisper_rs::{WhisperContext, WhisperContextParameters, FullParams, SamplingStrategy};
139 use log::info;
140 use std::sync::Arc;
141
142 const MIN_WHISPER_SAMPLES: usize = 16800; let model_path = match ensure_model(selected_model) {
145 Ok(p) => p,
146 Err(e) => {
147 let _ = tx.send(Event::Error(e));
148 return;
149 }
150 };
151
152 let system_info = whisper_rs::print_system_info();
153 info!("Whisper System Info: \n{}", system_info);
154
155 let ctx = match WhisperContext::new_with_params(
156 model_path.to_str().unwrap_or("invalid_model_path"),
157 WhisperContextParameters::default(),
158 ) {
159 Ok(c) => c,
160 Err(e) => {
161 let _ = tx.send(Event::Error(crate::error::WhisperStreamError::from(e)));
162 return;
163 }
164 };
165
166 let audio_input = match AudioInput::new(config.device.as_deref(), config.step_ms) {
167 Ok(input) => input,
168 Err(e) => {
169 let _ = tx.send(Event::Error(e));
170 return;
171 }
172 };
173 let audio_rx = audio_input.start_capture_16k();
174 let sample_rate = 16000;
175 let n_samples_window = (sample_rate as f32 * (config.length_ms as f32 / 1000.0)) as usize;
176 let n_samples_overlap = (sample_rate as f32 * (config.keep_ms as f32 / 1000.0)) as usize;
177 let mut segment_window: Vec<f32> = Vec::with_capacity(n_samples_window);
178 let mut state = match ctx.create_state() {
179 Ok(s) => s,
180 Err(e) => {
181 let _ = tx.send(Event::Error(crate::error::WhisperStreamError::from(e)));
182 return;
183 }
184 };
185
186 let mut params_full = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
187 params_full.set_n_threads(config.n_threads);
188 params_full.set_max_tokens(config.max_tokens);
189 params_full.set_print_special(false);
190 params_full.set_print_progress(false);
191 params_full.set_print_realtime(false);
192 params_full.set_print_timestamps(false);
193 if let Some(ref lang) = config.language {
194 params_full.set_language(Some(lang));
195 }
196 let arc_params_full = Arc::new(params_full);
197
198 let mut wav_audio_recorder = match WavAudioRecorder::new(config.record_to_wav.as_deref()) {
199 Ok(recorder) => recorder,
200 Err(e) => {
201 let _ = tx.send(Event::Error(e));
202 match WavAudioRecorder::new(None) {
203 Ok(no_op_recorder) => no_op_recorder,
204 Err(_) => return,
205 }
206 }
207 };
208
209 if wav_audio_recorder.is_recording() {
210 if let Some(path_str) = config.record_to_wav.as_ref() {
211 info!("[Recording] Saving transcribed audio to {}...", path_str);
212 let _ = tx.send(Event::SystemMessage(format!("[Recording] Saving transcribed audio to {}...", path_str)));
213 }
214 }
215
216 for pcmf32_new_result in audio_rx {
217 let pcmf32_new = match pcmf32_new_result {
218 Ok(audio_data) => {
219 if audio_data.is_empty() {
220 continue;
221 }
222 audio_data
223 }
224 Err(audio_err) => {
225 let _ = tx.send(Event::Error(audio_err));
226 continue;
227 }
228 };
229
230 if wav_audio_recorder.is_recording() {
231 if let Err(e) = wav_audio_recorder.write_audio_chunk(&pcmf32_new) {
232 let _ = tx.send(Event::Error(e));
233 }
234 }
235
236 segment_window.extend_from_slice(&pcmf32_new);
237 let audio_for_processing = pad_audio_if_needed(&segment_window, MIN_WHISPER_SAMPLES);
238
239 if let Err(e) = state.full(arc_params_full.as_ref().clone(), &audio_for_processing) {
240 let _ = tx.send(Event::Error(crate::error::WhisperStreamError::from(e)));
241 continue;
242 }
243
244 let mut current_text = String::new();
245 match state.full_n_segments() {
246 Ok(num_segments) => {
247 for i in 0..num_segments {
248 match state.full_get_segment_text(i) {
249 Ok(seg) => current_text.push_str(&seg),
250 Err(e) => {
251 let _ = tx.send(Event::Error(crate::error::WhisperStreamError::from(e)));
252 break;
253 }
254 }
255 }
256 }
257 Err(e) => {
258 let _ = tx.send(Event::Error(crate::error::WhisperStreamError::from(e)));
259 continue;
260 }
261 }
262
263 if !current_text.trim().is_empty() {
264 let is_low_quality = crate::score::is_low_quality_output(¤t_text);
265 if segment_window.len() >= n_samples_window {
266 let _ = tx.send(Event::SegmentTranscript { text: current_text.clone(), is_low_quality });
267 } else if config.compute_partials {
268 let _ = tx.send(Event::ProvisionalLiveUpdate { text: current_text.clone(), is_low_quality });
269 }
270 }
271
272 if segment_window.len() >= n_samples_window {
273 if n_samples_overlap > 0 && segment_window.len() > n_samples_overlap {
274 segment_window = segment_window[segment_window.len() - n_samples_overlap..].to_vec();
275 } else {
276 segment_window.clear();
277 }
278 }
279 }
280
281 if !segment_window.is_empty() {
282 let final_audio_for_processing = pad_audio_if_needed(&segment_window, MIN_WHISPER_SAMPLES);
283 if let Err(e) = state.full(arc_params_full.as_ref().clone(), &final_audio_for_processing) {
284 let _ = tx.send(Event::Error(crate::error::WhisperStreamError::from(e)));
285 } else {
286 let mut final_text = String::new();
287 match state.full_n_segments() {
288 Ok(num_segments) => {
289 for i in 0..num_segments {
290 match state.full_get_segment_text(i) {
291 Ok(seg) => final_text.push_str(&seg),
292 Err(e) => {
293 let _ = tx.send(Event::Error(crate::error::WhisperStreamError::from(e)));
294 break;
295 }
296 }
297 }
298 }
299 Err(e) => {
300 let _ = tx.send(Event::Error(crate::error::WhisperStreamError::from(e)));
301 }
302 }
303 if !final_text.trim().is_empty() {
304 let is_low_quality = crate::score::is_low_quality_output(&final_text);
305 let _ = tx.send(Event::SegmentTranscript { text: final_text, is_low_quality });
306 }
307 }
308 }
309
310 match wav_audio_recorder.finalize() {
311 Ok(Some(msg)) => {
312 info!("{}", msg);
313 let _ = tx.send(Event::SystemMessage(msg));
314 }
315 Ok(None) => { }
316 Err(e) => {
317 let _ = tx.send(Event::Error(e));
318 }
319 }
320 });
321 Ok((WhisperStream {}, rx))
322 }
323}
324
325impl WhisperStream {
326 pub fn builder() -> WhisperStreamBuilder {
327 WhisperStreamBuilder {
328 device: None,
329 language: Some("en".to_string()),
330 record_to_wav: None,
331 step_ms: 800,
332 length_ms: 5000,
333 keep_ms: 200,
334 max_tokens: 32,
335 n_threads: std::thread::available_parallelism().map(|n| n.get() as i32).unwrap_or(8),
336 compute_partials: true,
337 logging_enabled: true,
338 model: None,
339 }
340 }
341 pub fn list_devices() -> Result<Vec<String>, crate::error::WhisperStreamError> {
342 crate::audio::AudioInput::available_input_devices()
343 }
344 pub fn list_models() -> Vec<Model> {
345 Model::list()
346 }
347 pub fn start(&mut self) -> Result<(), crate::error::WhisperStreamError> {
348 Ok(())
350 }
351 pub fn stop(&mut self) -> Result<(), crate::error::WhisperStreamError> {
352 Ok(())
354 }
355}