// whisper_cpp_plus/stream.rs
1//! Streaming transcription — faithful port of stream.cpp
2//!
3//! Replaces SDL audio capture with a push-based `feed_audio()` API
4//! since we're a library, not a binary.
5
6use crate::context::WhisperContext;
7use crate::error::Result;
8use crate::params::FullParams;
9use crate::state::{Segment, WhisperState};
10use std::collections::VecDeque;
11
12// Whisper models consume 16 kHz mono PCM; every ms→samples conversion below
// (`1e-3 * ms * rate`) and the VAD probe size assume this rate.
const WHISPER_SAMPLE_RATE: i32 = 16000;
13
14// ---------------------------------------------------------------------------
15// WhisperStreamConfig
16// ---------------------------------------------------------------------------
17
18/// Streaming config — maps to stream.cpp's whisper_params (streaming subset).
///
/// All durations are milliseconds of 16 kHz audio. `with_config` normalizes
/// values before use: `keep_ms` is clamped to at most `step_ms`, `length_ms`
/// is raised to at least `step_ms`, and `no_context` is forced to `true`
/// in VAD mode.
#[derive(Debug, Clone)]
pub struct WhisperStreamConfig {
    /// Audio step size in ms. Set <= 0 for VAD mode.
    pub step_ms: i32,
    /// Audio length per inference in ms (sliding-window size).
    pub length_ms: i32,
    /// Audio to keep from previous step in ms (overlap between windows).
    pub keep_ms: i32,
    /// VAD energy threshold (trailing-window energy vs overall mean).
    pub vad_thold: f32,
    /// High-pass frequency cutoff for VAD, in Hz. <= 0 disables the filter.
    pub freq_thold: f32,
    /// If true, don't carry prompt tokens across boundaries.
    pub no_context: bool,
}
34
35impl Default for WhisperStreamConfig {
    /// Defaults mirror stream.cpp's `whisper_params`: 3 s step, 10 s window,
    /// 200 ms overlap, VAD threshold 0.6, 100 Hz high-pass, no context carry.
    fn default() -> Self {
        WhisperStreamConfig {
            no_context: true,
            freq_thold: 100.0,
            vad_thold: 0.6,
            keep_ms: 200,
            length_ms: 10_000,
            step_ms: 3000,
        }
    }
}
47
48// ---------------------------------------------------------------------------
49// WhisperStream
50// ---------------------------------------------------------------------------
51
52/// Streaming transcriber — faithful port of stream.cpp main loop.
///
/// Two modes:
/// - **Fixed-step** (`step_ms > 0`): sliding window with overlap.
/// - **VAD** (`step_ms <= 0`): transcribe on speech activity.
pub struct WhisperStream {
    /// Reusable whisper inference state for this stream.
    state: WhisperState,
    /// Base decoding params; cloned per inference so prompt tokens can be attached.
    params: FullParams,
    /// Normalized streaming configuration (see `with_config`).
    config: WhisperStreamConfig,
    /// True when `step_ms <= 0` → VAD mode.
    use_vad: bool,

    // Pre-computed sample counts
    // Samples consumed per fixed step (0 in VAD mode — this is the mode flag's source).
    n_samples_step: usize,
    // Window size in samples (length_ms worth of audio).
    n_samples_len: usize,
    // Overlap carried between windows, in samples.
    n_samples_keep: usize,
    // Every n_new_line iterations, the overlap buffer is trimmed and prompt
    // tokens are (optionally) collected.
    n_new_line: i32,

    // Overlap buffer from previous inference
    pcmf32_old: Vec<f32>,
    // Context propagation
    prompt_tokens: Vec<i32>,

    // Fixed/VAD iterations completed so far (drives the n_new_line boundary).
    n_iter: i32,

    // Internal audio buffer (replaces SDL capture)
    audio_buf: VecDeque<f32>,

    // Total samples consumed from audio_buf
    total_samples_processed: i64,
}
82
83impl WhisperStream {
    /// Create with default config.
    ///
    /// # Errors
    /// Propagates any failure from `WhisperState::new`.
    pub fn new(ctx: &WhisperContext, params: FullParams) -> Result<Self> {
        Self::with_config(ctx, params, WhisperStreamConfig::default())
    }

    /// Create with custom config.
    ///
    /// Mirrors stream.cpp's `main()` setup: clamps the config, derives the
    /// per-step sample counts, picks fixed-step vs VAD mode, and forces the
    /// mode-dependent `FullParams` flags.
    ///
    /// # Errors
    /// Propagates any failure from `WhisperState::new`.
    pub fn with_config(
        ctx: &WhisperContext,
        mut params: FullParams,
        mut config: WhisperStreamConfig,
    ) -> Result<Self> {
        let state = WhisperState::new(ctx)?;

        // --- Config normalization (stream.cpp main()) ---
        // Overlap never exceeds one step; the window covers at least one step.
        config.keep_ms = config.keep_ms.min(config.step_ms);
        config.length_ms = config.length_ms.max(config.step_ms);

        // Sample counts
        // NOTE: for a non-positive *_ms the f64 → usize `as` cast saturates
        // to 0 — that is what makes the VAD-mode detection below work.
        let n_samples_step =
            (1e-3 * config.step_ms as f64 * WHISPER_SAMPLE_RATE as f64) as usize;
        let n_samples_len =
            (1e-3 * config.length_ms as f64 * WHISPER_SAMPLE_RATE as f64) as usize;
        let n_samples_keep =
            (1e-3 * config.keep_ms as f64 * WHISPER_SAMPLE_RATE as f64) as usize;

        // Mode detection
        let use_vad = n_samples_step == 0; // step_ms <= 0 → VAD

        // n_new_line: guard against division by zero when step_ms <= 0
        let n_new_line = if !use_vad {
            (config.length_ms / config.step_ms - 1).max(1)
        } else {
            1
        };

        // Auto-set mode-dependent FullParams (stream.cpp lines 141-143)
        params = params
            .no_timestamps(!use_vad)
            .max_tokens(0)
            .single_segment(!use_vad)
            .print_progress(false)
            .print_realtime(false);

        // Force no_context in VAD mode: no_context |= use_vad
        if use_vad {
            config.no_context = true;
            params = params.no_context(true);
        }

        Ok(Self {
            state,
            params,
            config,
            use_vad,
            n_samples_step,
            n_samples_len,
            n_samples_keep,
            n_new_line,
            pcmf32_old: Vec::new(),
            prompt_tokens: Vec::new(),
            n_iter: 0,
            audio_buf: VecDeque::new(),
            total_samples_processed: 0,
        })
    }

    // --- Audio input ---

    /// Push samples into the internal buffer (replaces SDL capture).
    ///
    /// Samples are interpreted as mono f32 PCM at `WHISPER_SAMPLE_RATE`
    /// (16 kHz); all step/window arithmetic assumes that rate.
    pub fn feed_audio(&mut self, samples: &[f32]) {
        self.audio_buf.extend(samples.iter());
    }

    // --- Processing ---

    /// Dispatch to fixed-step or VAD mode.
    ///
    /// Returns `Ok(None)` when there is not yet enough buffered audio (or,
    /// in VAD mode, when the probe window was judged silent), otherwise
    /// `Ok(Some(segments))` from one inference pass.
    pub fn process_step(&mut self) -> Result<Option<Vec<Segment>>> {
        if !self.use_vad {
            self.process_step_fixed()
        } else {
            self.process_step_vad()
        }
    }

    /// Fixed-step (sliding window) mode — port of stream.cpp lines 253-428.
    ///
    /// Consumes exactly `n_samples_step` samples, prepends the tail of the
    /// previous window (overlap), runs inference, and trims the overlap
    /// buffer at every `n_new_line` boundary.
    fn process_step_fixed(&mut self) -> Result<Option<Vec<Segment>>> {
        // Need at least n_samples_step new samples
        if self.audio_buf.len() < self.n_samples_step {
            return Ok(None);
        }

        // Pop n_samples_step from front of audio_buf
        let pcmf32_new: Vec<f32> = self.audio_buf.drain(..self.n_samples_step).collect();
        self.total_samples_processed += pcmf32_new.len() as i64;

        let n_samples_new = pcmf32_new.len();

        // Exact formula from stream.cpp line 279:
        // n_samples_take = min(pcmf32_old.size(), max(0, n_samples_keep + n_samples_len - n_samples_new))
        let n_samples_take = self.pcmf32_old.len().min(
            (self.n_samples_keep + self.n_samples_len).saturating_sub(n_samples_new),
        );

        // Build pcmf32: tail of pcmf32_old + pcmf32_new
        let mut pcmf32 = Vec::with_capacity(n_samples_take + n_samples_new);
        if n_samples_take > 0 && !self.pcmf32_old.is_empty() {
            let start = self.pcmf32_old.len() - n_samples_take;
            pcmf32.extend_from_slice(&self.pcmf32_old[start..]);
        }
        pcmf32.extend_from_slice(&pcmf32_new);

        // Save for next iteration
        // (between n_new_line boundaries the whole window is carried, so
        // consecutive steps re-transcribe overlapping audio — as in stream.cpp)
        self.pcmf32_old = pcmf32.clone();

        // Run inference
        let segments = self.run_inference(&pcmf32)?;

        self.n_iter += 1;

        // At n_new_line boundary (stream.cpp lines 408-425)
        if self.n_iter % self.n_new_line == 0 {
            // Keep only last n_samples_keep samples
            if self.n_samples_keep > 0 && pcmf32.len() >= self.n_samples_keep {
                self.pcmf32_old =
                    pcmf32[pcmf32.len() - self.n_samples_keep..].to_vec();
            } else {
                self.pcmf32_old.clear();
            }

            // Collect prompt tokens if !no_context
            if !self.config.no_context {
                self.collect_prompt_tokens();
            }
        }

        Ok(Some(segments))
    }

    /// VAD mode — port of stream.cpp lines 293-313.
    ///
    /// NOTE(review): unlike stream.cpp, which probes the capture buffer
    /// non-destructively, this drains the 2-second probe from `audio_buf`.
    /// A probe judged silent is therefore discarded; speech beginning near
    /// the end of a silent probe may lose its onset — confirm acceptable.
    fn process_step_vad(&mut self) -> Result<Option<Vec<Segment>>> {
        // Need at least 2 seconds of audio (stream.cpp: t_diff < 2000 → continue)
        let n_vad_samples = (WHISPER_SAMPLE_RATE * 2) as usize; // 32000 samples
        if self.audio_buf.len() < n_vad_samples {
            return Ok(None);
        }

        // Pop 2 seconds for VAD probe
        let pcmf32_vad: Vec<f32> = self.audio_buf.drain(..n_vad_samples).collect();
        self.total_samples_processed += pcmf32_vad.len() as i64;

        // Check for speech
        let is_silence = vad_simple(
            &pcmf32_vad,
            WHISPER_SAMPLE_RATE,
            1000,
            self.config.vad_thold,
            self.config.freq_thold,
        );

        if is_silence {
            return Ok(None);
        }

        // Speech detected — grab length_ms of audio total (stream.cpp line 305)
        let n_samples_len = self.n_samples_len;
        let additional = n_samples_len.saturating_sub(pcmf32_vad.len());
        let mut pcmf32 = pcmf32_vad;

        if additional > 0 {
            // Take whatever is available, even if less than a full window.
            let available = additional.min(self.audio_buf.len());
            let extra: Vec<f32> = self.audio_buf.drain(..available).collect();
            self.total_samples_processed += extra.len() as i64;
            pcmf32.extend_from_slice(&extra);
        }

        let segments = self.run_inference(&pcmf32)?;
        self.n_iter += 1;

        Ok(Some(segments))
    }

    /// Run whisper inference on audio — port of stream.cpp lines 316-344.
    ///
    /// Clones `self.params` per call so prompt tokens can be attached, then
    /// extracts all segments (text, timestamps, speaker-turn flag) from the
    /// state.
    fn run_inference(&mut self, audio: &[f32]) -> Result<Vec<Segment>> {
        if audio.is_empty() {
            return Ok(Vec::new());
        }

        // Clone params so we can set prompt_tokens pointer
        let mut params = self.params.clone();

        // Set prompt tokens on the clone, pointing to self.prompt_tokens.
        // The prompt_tokens() method stores a raw pointer. self.prompt_tokens
        // (Vec<i32>) lives on self and outlives the full() call, so this is safe.
        // (It is not mutated between here and full(): collect_prompt_tokens is
        // only called after full() returns.)
        if !self.config.no_context && !self.prompt_tokens.is_empty() {
            params = params.prompt_tokens(&self.prompt_tokens);
        }

        self.state.full(params, audio)?;

        // Extract segments
        let n_segments = self.state.full_n_segments();
        let mut segments = Vec::with_capacity(n_segments as usize);

        for i in 0..n_segments {
            let text = self.state.full_get_segment_text(i)?;
            let (start_ms, end_ms) = self.state.full_get_segment_timestamps(i);
            let speaker_turn_next = self.state.full_get_segment_speaker_turn_next(i);

            segments.push(Segment {
                start_ms,
                end_ms,
                text,
                speaker_turn_next,
            });
        }

        Ok(segments)
    }

    /// Collect prompt tokens from last inference — port of stream.cpp lines 416-425.
    ///
    /// Replaces (not appends to) `prompt_tokens` with every token of every
    /// segment from the most recent `full()` call.
    fn collect_prompt_tokens(&mut self) {
        self.prompt_tokens.clear();

        let n_segments = self.state.full_n_segments();
        for i in 0..n_segments {
            let token_count = self.state.full_n_tokens(i);
            for j in 0..token_count {
                self.prompt_tokens
                    .push(self.state.full_get_token_id(i, j));
            }
        }
    }

    // --- Convenience methods ---

    /// Process all remaining audio in buffer.
    ///
    /// Loops `process_step()` until it yields `None`, then transcribes any
    /// leftover partial chunk in one final inference. NOTE(review): in VAD
    /// mode a silent probe also yields `None`, so stepping stops at the
    /// first silent probe and the remainder is transcribed as one block.
    pub fn flush(&mut self) -> Result<Vec<Segment>> {
        let mut all_segments = Vec::new();

        loop {
            match self.process_step()? {
                Some(segments) => all_segments.extend(segments),
                None => break,
            }
        }

        // If there's leftover audio that's less than a full step, run inference on it
        if !self.audio_buf.is_empty() {
            let remaining: Vec<f32> = self.audio_buf.drain(..).collect();
            self.total_samples_processed += remaining.len() as i64;

            if !self.use_vad {
                // Build final buffer with overlap
                // (same take formula as process_step_fixed)
                let n_samples_take = self.pcmf32_old.len().min(
                    (self.n_samples_keep + self.n_samples_len)
                        .saturating_sub(remaining.len()),
                );
                let mut pcmf32 = Vec::with_capacity(n_samples_take + remaining.len());
                if n_samples_take > 0 && !self.pcmf32_old.is_empty() {
                    let start = self.pcmf32_old.len() - n_samples_take;
                    pcmf32.extend_from_slice(&self.pcmf32_old[start..]);
                }
                pcmf32.extend_from_slice(&remaining);

                let segments = self.run_inference(&pcmf32)?;
                all_segments.extend(segments);
            } else {
                let segments = self.run_inference(&remaining)?;
                all_segments.extend(segments);
            }
        }

        Ok(all_segments)
    }

    /// Clear buffers, counters, prompt tokens.
    ///
    /// Does not reset the underlying `WhisperState`; the stream is ready for
    /// a fresh `feed_audio`/`process_step` cycle afterwards.
    pub fn reset(&mut self) {
        self.audio_buf.clear();
        self.pcmf32_old.clear();
        self.prompt_tokens.clear();
        self.n_iter = 0;
        self.total_samples_processed = 0;
    }

    /// Samples currently in the internal buffer.
    pub fn buffer_size(&self) -> usize {
        self.audio_buf.len()
    }

    /// Total samples consumed from the buffer.
    pub fn processed_samples(&self) -> i64 {
        self.total_samples_processed
    }
}
378
379// ---------------------------------------------------------------------------
380// vad_simple + high_pass_filter — port from common.cpp
381// ---------------------------------------------------------------------------
382
383/// High-pass filter — port of common.cpp::high_pass_filter (lines 597-608).
///
/// Single-pole filter applied in place. Faithful to the C++ original,
/// including its quirk: `data[idx - 1]` reads the sample overwritten on the
/// previous iteration (i.e. the *filtered* previous value, not the raw
/// input) — preserved deliberately for port fidelity.
fn high_pass_filter(data: &mut [f32], cutoff: f32, sample_rate: f32) {
    if data.is_empty() {
        return;
    }

    // Filter coefficient: alpha = dt / (rc + dt)
    let alpha = {
        let rc = 1.0 / (2.0 * std::f32::consts::PI * cutoff);
        let dt = 1.0 / sample_rate;
        dt / (rc + dt)
    };

    // data[0] is left untouched; the recursion seeds from it.
    let mut prev_out = data[0];
    for idx in 1..data.len() {
        prev_out = alpha * (prev_out + data[idx] - data[idx - 1]);
        data[idx] = prev_out;
    }
}
398
399/// Energy-based VAD — port of common.cpp::vad_simple (lines 610-646).
///
/// Returns `true` if **silence** (no speech detected). Note the inverted
/// sense vs the C++ original, which returns `true` on speech.
///
/// * `last_ms` — length of the trailing window whose mean energy is compared
///   against the whole buffer's mean energy.
/// * `vad_thold` — speech is assumed when trailing energy exceeds
///   `vad_thold * overall energy`.
/// * `freq_thold` — high-pass cutoff (Hz) applied before measuring energy;
///   skipped when <= 0.
fn vad_simple(
    pcmf32: &[f32],
    sample_rate: i32,
    last_ms: i32,
    vad_thold: f32,
    freq_thold: f32,
) -> bool {
    let n_samples = pcmf32.len();
    let n_samples_last = (sample_rate as usize * last_ms.max(0) as usize) / 1000;

    // Degenerate windows: either the trailing window is empty (last_ms too
    // small) or it covers the whole buffer, so there is no "recent vs
    // overall" comparison to make — report silence.
    //
    // The n_samples_last == 0 guard is a bug fix: without it the division
    // below computes 0.0 / 0.0 = NaN, and since NaN comparisons are false
    // the function would report *speech* for any input, including silence.
    if n_samples_last == 0 || n_samples_last >= n_samples {
        return true;
    }

    // Work on a copy so we can apply the high-pass filter
    let mut data = pcmf32.to_vec();

    if freq_thold > 0.0 {
        high_pass_filter(&mut data, freq_thold, sample_rate as f32);
    }

    // Mean absolute amplitude over the whole buffer and over the tail.
    let mut energy_all: f32 = 0.0;
    let mut energy_last: f32 = 0.0;

    for (i, &s) in data.iter().enumerate() {
        energy_all += s.abs();
        if i >= n_samples - n_samples_last {
            energy_last += s.abs();
        }
    }

    energy_all /= n_samples as f32;
    energy_last /= n_samples_last as f32;

    // C++ returns false (speech) when energy_last > thold * energy_all.
    // We return true for silence.
    energy_last <= vad_thold * energy_all
}
442
443// ---------------------------------------------------------------------------
444// Tests
445// ---------------------------------------------------------------------------
446
447#[cfg(test)]
448mod tests {
449    use super::*;
450    use crate::SamplingStrategy;
451    use std::path::Path;
452
453    #[test]
454    fn test_config_defaults() {
455        let config = WhisperStreamConfig::default();
456        assert_eq!(config.step_ms, 3000);
457        assert_eq!(config.length_ms, 10000);
458        assert_eq!(config.keep_ms, 200);
459        assert!((config.vad_thold - 0.6).abs() < f32::EPSILON);
460        assert!((config.freq_thold - 100.0).abs() < f32::EPSILON);
461        assert!(config.no_context);
462    }
463
464    #[test]
465    fn test_config_normalization() {
466        // keep_ms clamped to step_ms
467        let model_path = "tests/models/ggml-tiny.en.bin";
468        if !Path::new(model_path).exists() {
469            // Can't test normalization without a model for the constructor.
470            // Test the logic directly instead.
471            let mut config = WhisperStreamConfig {
472                step_ms: 2000,
473                length_ms: 5000,
474                keep_ms: 3000, // > step_ms, should be clamped
475                ..Default::default()
476            };
477            config.keep_ms = config.keep_ms.min(config.step_ms);
478            config.length_ms = config.length_ms.max(config.step_ms);
479            assert_eq!(config.keep_ms, 2000);
480            assert_eq!(config.length_ms, 5000);
481
482            // length_ms clamped up to step_ms
483            let mut config2 = WhisperStreamConfig {
484                step_ms: 8000,
485                length_ms: 5000, // < step_ms, should be raised
486                keep_ms: 200,
487                ..Default::default()
488            };
489            config2.keep_ms = config2.keep_ms.min(config2.step_ms);
490            config2.length_ms = config2.length_ms.max(config2.step_ms);
491            assert_eq!(config2.length_ms, 8000);
492            assert_eq!(config2.keep_ms, 200);
493        }
494    }
495
496    #[test]
497    fn test_n_new_line_calculation() {
498        // n_new_line = max(1, length_ms / step_ms - 1) when !use_vad
499        // Defaults: length_ms=10000, step_ms=3000 → 10000/3000 - 1 = 2
500        let n = (10000i32 / 3000 - 1).max(1);
501        assert_eq!(n, 2);
502
503        // step_ms=5000, length_ms=10000 → 10000/5000 - 1 = 1
504        let n = (10000i32 / 5000 - 1).max(1);
505        assert_eq!(n, 1);
506
507        // step_ms=10000, length_ms=10000 → 10000/10000 - 1 = 0 → clamped to 1
508        let n = (10000i32 / 10000 - 1).max(1);
509        assert_eq!(n, 1);
510
511        // step_ms=2000, length_ms=10000 → 10000/2000 - 1 = 4
512        let n = (10000i32 / 2000 - 1).max(1);
513        assert_eq!(n, 4);
514
515        // VAD mode: always 1
516        let n_vad = 1i32;
517        assert_eq!(n_vad, 1);
518    }
519
520    #[test]
521    fn test_vad_mode_detection() {
522        // step_ms <= 0 → use_vad
523        let step_ms_values = [0, -1, -100];
524        for step_ms in step_ms_values {
525            let n_samples_step =
526                (1e-3 * step_ms as f64 * WHISPER_SAMPLE_RATE as f64) as usize;
527            assert_eq!(n_samples_step, 0, "step_ms={} should yield 0 samples", step_ms);
528        }
529
530        // step_ms > 0 → fixed step
531        let n = (1e-3 * 3000.0 * WHISPER_SAMPLE_RATE as f64) as usize;
532        assert_eq!(n, 48000);
533    }
534
535    #[test]
536    fn test_feed_and_buffer() {
537        let model_path = "tests/models/ggml-tiny.en.bin";
538        if !Path::new(model_path).exists() {
539            eprintln!("Skipping test_feed_and_buffer: model not found");
540            return;
541        }
542
543        let ctx = WhisperContext::new(model_path).unwrap();
544        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
545        let mut stream = WhisperStream::new(&ctx, params).unwrap();
546
547        assert_eq!(stream.buffer_size(), 0);
548
549        let samples = vec![0.0f32; 16000];
550        stream.feed_audio(&samples);
551        assert_eq!(stream.buffer_size(), 16000);
552
553        stream.feed_audio(&samples);
554        assert_eq!(stream.buffer_size(), 32000);
555    }
556
557    #[test]
558    fn test_vad_simple_silence() {
559        let silence = vec![0.0f32; 16000];
560        assert!(vad_simple(&silence, 16000, 100, 0.6, 100.0));
561    }
562
563    #[test]
564    fn test_vad_simple_too_few_samples() {
565        let short = vec![0.1f32; 100];
566        assert!(vad_simple(&short, 16000, 1000, 0.6, 100.0));
567    }
568
569    #[test]
570    fn test_high_pass_filter_basic() {
571        let mut data = vec![1.0, 0.0, 1.0, 0.0, 1.0];
572        high_pass_filter(&mut data, 100.0, 16000.0);
573        assert_ne!(data[2], 1.0);
574    }
575
576    #[test]
577    fn test_reset() {
578        let model_path = "tests/models/ggml-tiny.en.bin";
579        if !Path::new(model_path).exists() {
580            eprintln!("Skipping test_reset: model not found");
581            return;
582        }
583
584        let ctx = WhisperContext::new(model_path).unwrap();
585        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
586        let mut stream = WhisperStream::new(&ctx, params).unwrap();
587
588        stream.feed_audio(&vec![0.0f32; 16000]);
589        assert_eq!(stream.buffer_size(), 16000);
590
591        stream.reset();
592        assert_eq!(stream.buffer_size(), 0);
593        assert_eq!(stream.processed_samples(), 0);
594    }
595
596    // --- Integration tests (require model) ---
597
598    #[test]
599    fn test_fixed_step_basic() {
600        let model_path = "tests/models/ggml-tiny.en.bin";
601        if !Path::new(model_path).exists() {
602            eprintln!("Skipping test_fixed_step_basic: model not found");
603            return;
604        }
605
606        let ctx = WhisperContext::new(model_path).unwrap();
607        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 })
608            .language("en");
609
610        // Use a small step for testing
611        let config = WhisperStreamConfig {
612            step_ms: 3000,
613            length_ms: 10000,
614            keep_ms: 200,
615            ..Default::default()
616        };
617
618        let mut stream = WhisperStream::with_config(&ctx, params, config).unwrap();
619
620        // Feed enough audio for one step (3 seconds = 48000 samples)
621        let audio = vec![0.0f32; 48000];
622        stream.feed_audio(&audio);
623
624        let result = stream.process_step().unwrap();
625        assert!(result.is_some(), "Should produce segments with enough audio");
626        assert!(stream.processed_samples() > 0);
627    }
628
629    #[test]
630    fn test_prompt_propagation() {
631        let model_path = "tests/models/ggml-tiny.en.bin";
632        if !Path::new(model_path).exists() {
633            eprintln!("Skipping test_prompt_propagation: model not found");
634            return;
635        }
636
637        let ctx = WhisperContext::new(model_path).unwrap();
638        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 })
639            .language("en");
640
641        let config = WhisperStreamConfig {
642            step_ms: 3000,
643            length_ms: 6000,
644            keep_ms: 200,
645            no_context: false, // enable prompt propagation
646            ..Default::default()
647        };
648
649        let mut stream = WhisperStream::with_config(&ctx, params, config).unwrap();
650
651        // n_new_line = max(1, 6000/3000 - 1) = 1, so every iteration triggers
652        // prompt collection when no_context=false.
653
654        // Feed enough for one step
655        let audio = vec![0.0f32; 48000];
656        stream.feed_audio(&audio);
657
658        let result = stream.process_step().unwrap();
659        assert!(result.is_some());
660
661        // After one iteration at the n_new_line boundary, prompt_tokens should
662        // be populated (assuming whisper produced at least one token).
663        // With silence input, whisper may or may not produce tokens, so we
664        // just verify the mechanism didn't panic.
665        assert!(stream.processed_samples() > 0);
666    }
667}