use crate::context::WhisperContext;
use crate::error::Result;
use crate::params::FullParams;
use crate::state::{Segment, WhisperState};
use std::collections::VecDeque;

/// Whisper models operate on 16 kHz mono PCM audio.
const WHISPER_SAMPLE_RATE: i32 = 16000;

/// Configuration for [`WhisperStream`], modeled on the parameters of
/// whisper.cpp's `stream` example.
#[derive(Debug, Clone)]
pub struct WhisperStreamConfig {
    /// Audio chunk processed per step, in milliseconds.
    pub step_ms: i32,
    /// Total audio window passed to the model, in milliseconds.
    pub length_ms: i32,
    /// Audio carried over from the previous step for context, in milliseconds.
    pub keep_ms: i32,
    /// Energy threshold used by the simple VAD.
    pub vad_thold: f32,
    /// High-pass cutoff frequency in Hz applied before the VAD (0 disables it).
    pub freq_thold: f32,
    /// If true, do not feed the previous step's tokens back as a prompt.
    pub no_context: bool,
}

impl Default for WhisperStreamConfig {
    fn default() -> Self {
        Self {
            step_ms: 3000,
            length_ms: 10000,
            keep_ms: 200,
            vad_thold: 0.6,
            freq_thold: 100.0,
            no_context: true,
        }
    }
}

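/// Streaming transcription over incrementally fed audio, following the
/// sliding-window loop of whisper.cpp's `stream` example.
///
/// Minimal usage sketch (imports omitted; `read_more_audio` is a placeholder
/// for whatever capture source delivers 16 kHz mono `f32` samples):
///
/// ```ignore
/// let ctx = WhisperContext::new("path/to/ggml-model.bin")?;
/// let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
/// let mut stream = WhisperStream::with_config(&ctx, params, WhisperStreamConfig::default())?;
///
/// while let Some(chunk) = read_more_audio() {
///     stream.feed_audio(&chunk);
///     if let Some(segments) = stream.process_step()? {
///         for seg in segments {
///             println!("[{} - {}] {}", seg.start_ms, seg.end_ms, seg.text);
///         }
///     }
/// }
///
/// // Transcribe whatever is still buffered.
/// for seg in stream.flush()? {
///     println!("{}", seg.text);
/// }
/// ```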
pub struct WhisperStream {
    state: WhisperState,
    params: FullParams,
    config: WhisperStreamConfig,
    use_vad: bool,

    // Derived sample counts for the step, window, and carry-over sizes.
    n_samples_step: usize,
    n_samples_len: usize,
    n_samples_keep: usize,
    // Number of steps between context resets in fixed-step mode.
    n_new_line: i32,

    // Audio carried over from previous steps for context.
    pcmf32_old: Vec<f32>,
    // Tokens from the previous output, used as the prompt when context is enabled.
    prompt_tokens: Vec<i32>,

    n_iter: i32,

    // Incoming audio waiting to be processed.
    audio_buf: VecDeque<f32>,

    total_samples_processed: i64,
}

impl WhisperStream {
    /// Create a stream with the default [`WhisperStreamConfig`].
    pub fn new(ctx: &WhisperContext, params: FullParams) -> Result<Self> {
        Self::with_config(ctx, params, WhisperStreamConfig::default())
    }

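    /// Create a stream with an explicit configuration.
    ///
    /// The configuration is normalized (`keep_ms` is clamped to `step_ms`,
    /// `length_ms` is raised to at least `step_ms`), and a `step_ms` of zero
    /// or less selects VAD-driven mode instead of fixed-step mode.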
    pub fn with_config(
        ctx: &WhisperContext,
        mut params: FullParams,
        mut config: WhisperStreamConfig,
    ) -> Result<Self> {
        let state = WhisperState::new(ctx)?;

        // Normalize the configuration: keep_ms cannot exceed step_ms, and the
        // window must be at least one step long.
        config.keep_ms = config.keep_ms.min(config.step_ms);
        config.length_ms = config.length_ms.max(config.step_ms);

        let n_samples_step =
            (1e-3 * config.step_ms as f64 * WHISPER_SAMPLE_RATE as f64) as usize;
        let n_samples_len =
            (1e-3 * config.length_ms as f64 * WHISPER_SAMPLE_RATE as f64) as usize;
        let n_samples_keep =
            (1e-3 * config.keep_ms as f64 * WHISPER_SAMPLE_RATE as f64) as usize;

        // A non-positive step size selects VAD-driven mode.
        let use_vad = n_samples_step == 0;

        // In fixed-step mode, reset the sliding window every n_new_line steps.
        let n_new_line = if !use_vad {
            (config.length_ms / config.step_ms - 1).max(1)
        } else {
            1
        };

        params = params
            .no_timestamps(!use_vad)
            .max_tokens(0)
            .single_segment(!use_vad)
            .print_progress(false)
            .print_realtime(false);

        if use_vad {
            config.no_context = true;
            params = params.no_context(true);
        }

        Ok(Self {
            state,
            params,
            config,
            use_vad,
            n_samples_step,
            n_samples_len,
            n_samples_keep,
            n_new_line,
            pcmf32_old: Vec::new(),
            prompt_tokens: Vec::new(),
            n_iter: 0,
            audio_buf: VecDeque::new(),
            total_samples_processed: 0,
        })
    }

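    /// Append audio to the internal buffer.
    ///
    /// Samples are expected as 16 kHz mono `f32` PCM, matching `WHISPER_SAMPLE_RATE`.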
    pub fn feed_audio(&mut self, samples: &[f32]) {
        self.audio_buf.extend(samples.iter());
    }

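    /// Run one transcription step over the buffered audio.
    ///
    /// Returns `Ok(None)` when not enough audio has been buffered yet (or, in
    /// VAD-driven mode, when the next chunk is classified as silence), and
    /// `Ok(Some(segments))` once a window has been transcribed.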
    pub fn process_step(&mut self) -> Result<Option<Vec<Segment>>> {
        if !self.use_vad {
            self.process_step_fixed()
        } else {
            self.process_step_vad()
        }
    }

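    // Fixed-step mode: consume step_ms of new audio, prepend previously
    // processed samples so the window stays within keep_ms + length_ms, and
    // transcribe the combined buffer.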
    fn process_step_fixed(&mut self) -> Result<Option<Vec<Segment>>> {
        if self.audio_buf.len() < self.n_samples_step {
            return Ok(None);
        }

        let pcmf32_new: Vec<f32> = self.audio_buf.drain(..self.n_samples_step).collect();
        self.total_samples_processed += pcmf32_new.len() as i64;

        let n_samples_new = pcmf32_new.len();

        // Take as much previously processed audio as still fits in the window.
        let n_samples_take = self.pcmf32_old.len().min(
            (self.n_samples_keep + self.n_samples_len).saturating_sub(n_samples_new),
        );

        let mut pcmf32 = Vec::with_capacity(n_samples_take + n_samples_new);
        if n_samples_take > 0 && !self.pcmf32_old.is_empty() {
            let start = self.pcmf32_old.len() - n_samples_take;
            pcmf32.extend_from_slice(&self.pcmf32_old[start..]);
        }
        pcmf32.extend_from_slice(&pcmf32_new);

        self.pcmf32_old = pcmf32.clone();

        let segments = self.run_inference(&pcmf32)?;

        self.n_iter += 1;

        // Periodically trim the carried audio to keep_ms and refresh the prompt.
        if self.n_iter % self.n_new_line == 0 {
            if self.n_samples_keep > 0 && pcmf32.len() >= self.n_samples_keep {
                self.pcmf32_old =
                    pcmf32[pcmf32.len() - self.n_samples_keep..].to_vec();
            } else {
                self.pcmf32_old.clear();
            }

            if !self.config.no_context {
                self.collect_prompt_tokens();
            }
        }

        Ok(Some(segments))
    }

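    // VAD-driven mode: take a two-second chunk, skip it if the energy-based
    // VAD classifies it as silence, otherwise extend it toward length_ms with
    // whatever else is buffered and transcribe it.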
    fn process_step_vad(&mut self) -> Result<Option<Vec<Segment>>> {
        // Run the VAD over two-second chunks.
        let n_vad_samples = (WHISPER_SAMPLE_RATE * 2) as usize;
        if self.audio_buf.len() < n_vad_samples {
            return Ok(None);
        }

        let pcmf32_vad: Vec<f32> = self.audio_buf.drain(..n_vad_samples).collect();
        self.total_samples_processed += pcmf32_vad.len() as i64;

        let is_silence = vad_simple(
            &pcmf32_vad,
            WHISPER_SAMPLE_RATE,
            1000,
            self.config.vad_thold,
            self.config.freq_thold,
        );

        if is_silence {
            return Ok(None);
        }

        // Grow the chunk toward the full window length with buffered audio.
        let n_samples_len = self.n_samples_len;
        let additional = n_samples_len.saturating_sub(pcmf32_vad.len());
        let mut pcmf32 = pcmf32_vad;

        if additional > 0 {
            let available = additional.min(self.audio_buf.len());
            let extra: Vec<f32> = self.audio_buf.drain(..available).collect();
            self.total_samples_processed += extra.len() as i64;
            pcmf32.extend_from_slice(&extra);
        }

        let segments = self.run_inference(&pcmf32)?;
        self.n_iter += 1;

        Ok(Some(segments))
    }

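    // Run the model over `audio` and collect the resulting segments.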
    fn run_inference(&mut self, audio: &[f32]) -> Result<Vec<Segment>> {
        if audio.is_empty() {
            return Ok(Vec::new());
        }

        let mut params = self.params.clone();

        // Seed the decoder with the previous step's tokens when context is enabled.
        if !self.config.no_context && !self.prompt_tokens.is_empty() {
            params = params.prompt_tokens(&self.prompt_tokens);
        }

        self.state.full(params, audio)?;

        let n_segments = self.state.full_n_segments();
        let mut segments = Vec::with_capacity(n_segments as usize);

        for i in 0..n_segments {
            let text = self.state.full_get_segment_text(i)?;
            let (start_ms, end_ms) = self.state.full_get_segment_timestamps(i);
            let speaker_turn_next = self.state.full_get_segment_speaker_turn_next(i);

            segments.push(Segment {
                start_ms,
                end_ms,
                text,
                speaker_turn_next,
            });
        }

        Ok(segments)
    }

    // Store the tokens of the current result so they can seed the next step's prompt.
    fn collect_prompt_tokens(&mut self) {
        self.prompt_tokens.clear();

        let n_segments = self.state.full_n_segments();
        for i in 0..n_segments {
            let token_count = self.state.full_n_tokens(i);
            for j in 0..token_count {
                self.prompt_tokens
                    .push(self.state.full_get_token_id(i, j));
            }
        }
    }

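    /// Process any remaining buffered audio and return all resulting segments.
    ///
    /// Runs full steps while enough audio is buffered, then transcribes
    /// whatever is left over in one final pass.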
    pub fn flush(&mut self) -> Result<Vec<Segment>> {
        let mut all_segments = Vec::new();

        // Drain complete steps first.
        loop {
            match self.process_step()? {
                Some(segments) => all_segments.extend(segments),
                None => break,
            }
        }

        // Transcribe whatever is left over in one final pass.
        if !self.audio_buf.is_empty() {
            let remaining: Vec<f32> = self.audio_buf.drain(..).collect();
            self.total_samples_processed += remaining.len() as i64;

            if !self.use_vad {
                // Prepend carried-over context, as in a regular fixed step.
                let n_samples_take = self.pcmf32_old.len().min(
                    (self.n_samples_keep + self.n_samples_len)
                        .saturating_sub(remaining.len()),
                );
                let mut pcmf32 = Vec::with_capacity(n_samples_take + remaining.len());
                if n_samples_take > 0 && !self.pcmf32_old.is_empty() {
                    let start = self.pcmf32_old.len() - n_samples_take;
                    pcmf32.extend_from_slice(&self.pcmf32_old[start..]);
                }
                pcmf32.extend_from_slice(&remaining);

                let segments = self.run_inference(&pcmf32)?;
                all_segments.extend(segments);
            } else {
                let segments = self.run_inference(&remaining)?;
                all_segments.extend(segments);
            }
        }

        Ok(all_segments)
    }

    /// Clear buffered audio, carried context, and counters.
    pub fn reset(&mut self) {
        self.audio_buf.clear();
        self.pcmf32_old.clear();
        self.prompt_tokens.clear();
        self.n_iter = 0;
        self.total_samples_processed = 0;
    }

    /// Number of samples currently waiting in the input buffer.
    pub fn buffer_size(&self) -> usize {
        self.audio_buf.len()
    }

    /// Total number of samples consumed from the input buffer so far.
    pub fn processed_samples(&self) -> i64 {
        self.total_samples_processed
    }
}

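// Single-pole high-pass filter applied in place; this follows the same
// recurrence as the helper used by whisper.cpp's stream example.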
fn high_pass_filter(data: &mut [f32], cutoff: f32, sample_rate: f32) {
    if data.is_empty() {
        return;
    }
    let rc = 1.0 / (2.0 * std::f32::consts::PI * cutoff);
    let dt = 1.0 / sample_rate;
    let alpha = dt / (rc + dt);

    let mut y = data[0];
    for i in 1..data.len() {
        y = alpha * (y + data[i] - data[i - 1]);
        data[i] = y;
    }
}

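// Energy-based voice activity check: returns true when the trailing `last_ms`
// window is quiet relative to the whole chunk (i.e. no new speech detected),
// and also when the chunk is too short to evaluate.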
fn vad_simple(
    pcmf32: &[f32],
    sample_rate: i32,
    last_ms: i32,
    vad_thold: f32,
    freq_thold: f32,
) -> bool {
    let n_samples = pcmf32.len();
    let n_samples_last = (sample_rate as usize * last_ms.max(0) as usize) / 1000;

    if n_samples_last >= n_samples {
        // Not enough samples to evaluate the trailing window; treat as silence.
        return true;
    }

    let mut data = pcmf32.to_vec();

    if freq_thold > 0.0 {
        high_pass_filter(&mut data, freq_thold, sample_rate as f32);
    }

    // Mean absolute amplitude over the whole chunk and over the trailing window.
    let mut energy_all: f32 = 0.0;
    let mut energy_last: f32 = 0.0;

    for (i, &s) in data.iter().enumerate() {
        energy_all += s.abs();
        if i >= n_samples - n_samples_last {
            energy_last += s.abs();
        }
    }

    energy_all /= n_samples as f32;
    energy_last /= n_samples_last as f32;

    energy_last <= vad_thold * energy_all
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::SamplingStrategy;
    use std::path::Path;

    #[test]
    fn test_config_defaults() {
        let config = WhisperStreamConfig::default();
        assert_eq!(config.step_ms, 3000);
        assert_eq!(config.length_ms, 10000);
        assert_eq!(config.keep_ms, 200);
        assert!((config.vad_thold - 0.6).abs() < f32::EPSILON);
        assert!((config.freq_thold - 100.0).abs() < f32::EPSILON);
        assert!(config.no_context);
    }

    #[test]
    fn test_config_normalization() {
        // This test replicates the normalization arithmetic and needs no model.
        // keep_ms is clamped to step_ms; length_ms is raised to at least step_ms.
        let mut config = WhisperStreamConfig {
            step_ms: 2000,
            length_ms: 5000,
            keep_ms: 3000,
            ..Default::default()
        };
        config.keep_ms = config.keep_ms.min(config.step_ms);
        config.length_ms = config.length_ms.max(config.step_ms);
        assert_eq!(config.keep_ms, 2000);
        assert_eq!(config.length_ms, 5000);

        let mut config2 = WhisperStreamConfig {
            step_ms: 8000,
            length_ms: 5000,
            keep_ms: 200,
            ..Default::default()
        };
        config2.keep_ms = config2.keep_ms.min(config2.step_ms);
        config2.length_ms = config2.length_ms.max(config2.step_ms);
        assert_eq!(config2.length_ms, 8000);
        assert_eq!(config2.keep_ms, 200);
    }

    #[test]
    fn test_n_new_line_calculation() {
        // length 10000 ms, step 3000 ms -> reset every 2 steps
        let n = (10000i32 / 3000 - 1).max(1);
        assert_eq!(n, 2);

        let n = (10000i32 / 5000 - 1).max(1);
        assert_eq!(n, 1);

        let n = (10000i32 / 10000 - 1).max(1);
        assert_eq!(n, 1);

        let n = (10000i32 / 2000 - 1).max(1);
        assert_eq!(n, 4);

        // VAD-driven mode always uses 1.
        let n_vad = 1i32;
        assert_eq!(n_vad, 1);
    }

    #[test]
    fn test_vad_mode_detection() {
        // Zero or negative step_ms yields zero step samples (the float-to-usize
        // cast saturates at 0), which selects VAD-driven mode.
        let step_ms_values = [0, -1, -100];
        for step_ms in step_ms_values {
            let n_samples_step =
                (1e-3 * step_ms as f64 * WHISPER_SAMPLE_RATE as f64) as usize;
            assert_eq!(n_samples_step, 0, "step_ms={} should yield 0 samples", step_ms);
        }

        let n = (1e-3 * 3000.0 * WHISPER_SAMPLE_RATE as f64) as usize;
        assert_eq!(n, 48000);
    }

    #[test]
    fn test_feed_and_buffer() {
        let model_path = "tests/models/ggml-tiny.en.bin";
        if !Path::new(model_path).exists() {
            eprintln!("Skipping test_feed_and_buffer: model not found");
            return;
        }

        let ctx = WhisperContext::new(model_path).unwrap();
        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
        let mut stream = WhisperStream::new(&ctx, params).unwrap();

        assert_eq!(stream.buffer_size(), 0);

        let samples = vec![0.0f32; 16000];
        stream.feed_audio(&samples);
        assert_eq!(stream.buffer_size(), 16000);

        stream.feed_audio(&samples);
        assert_eq!(stream.buffer_size(), 32000);
    }

    #[test]
    fn test_vad_simple_silence() {
        let silence = vec![0.0f32; 16000];
        assert!(vad_simple(&silence, 16000, 100, 0.6, 100.0));
    }

    #[test]
    fn test_vad_simple_too_few_samples() {
        let short = vec![0.1f32; 100];
        assert!(vad_simple(&short, 16000, 1000, 0.6, 100.0));
    }

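    // Added illustrative check (not part of the original test set): when the
    // signal energy is concentrated in the trailing window, vad_simple should
    // not classify the chunk as silence.
    #[test]
    fn test_vad_simple_activity_at_end() {
        let mut data = vec![0.0f32; 16000];
        // Last 100 ms (1600 samples) carry an alternating +/-0.5 signal.
        for (i, s) in data.iter_mut().enumerate().skip(14400) {
            *s = if i % 2 == 0 { 0.5 } else { -0.5 };
        }
        assert!(!vad_simple(&data, 16000, 100, 0.6, 100.0));
    }
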
    #[test]
    fn test_high_pass_filter_basic() {
        let mut data = vec![1.0, 0.0, 1.0, 0.0, 1.0];
        high_pass_filter(&mut data, 100.0, 16000.0);
        assert_ne!(data[2], 1.0);
    }

    #[test]
    fn test_reset() {
        let model_path = "tests/models/ggml-tiny.en.bin";
        if !Path::new(model_path).exists() {
            eprintln!("Skipping test_reset: model not found");
            return;
        }

        let ctx = WhisperContext::new(model_path).unwrap();
        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
        let mut stream = WhisperStream::new(&ctx, params).unwrap();

        stream.feed_audio(&vec![0.0f32; 16000]);
        assert_eq!(stream.buffer_size(), 16000);

        stream.reset();
        assert_eq!(stream.buffer_size(), 0);
        assert_eq!(stream.processed_samples(), 0);
    }

    #[test]
    fn test_fixed_step_basic() {
        let model_path = "tests/models/ggml-tiny.en.bin";
        if !Path::new(model_path).exists() {
            eprintln!("Skipping test_fixed_step_basic: model not found");
            return;
        }

        let ctx = WhisperContext::new(model_path).unwrap();
        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 })
            .language("en");

        let config = WhisperStreamConfig {
            step_ms: 3000,
            length_ms: 10000,
            keep_ms: 200,
            ..Default::default()
        };

        let mut stream = WhisperStream::with_config(&ctx, params, config).unwrap();

        // 48000 samples = 3 s at 16 kHz, exactly one step.
        let audio = vec![0.0f32; 48000];
        stream.feed_audio(&audio);

        let result = stream.process_step().unwrap();
        assert!(result.is_some(), "Should produce segments with enough audio");
        assert!(stream.processed_samples() > 0);
    }

    #[test]
    fn test_prompt_propagation() {
        let model_path = "tests/models/ggml-tiny.en.bin";
        if !Path::new(model_path).exists() {
            eprintln!("Skipping test_prompt_propagation: model not found");
            return;
        }

        let ctx = WhisperContext::new(model_path).unwrap();
        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 })
            .language("en");

        let config = WhisperStreamConfig {
            step_ms: 3000,
            length_ms: 6000,
            keep_ms: 200,
            // Keep context so prompt tokens are propagated between steps.
            no_context: false,
            ..Default::default()
        };

        let mut stream = WhisperStream::with_config(&ctx, params, config).unwrap();

        let audio = vec![0.0f32; 48000];
        stream.feed_audio(&audio);

        let result = stream.process_step().unwrap();
        assert!(result.is_some());

        assert!(stream.processed_samples() > 0);
    }
}