Skip to main content

whisper_cpp_plus/
stream_pcm.rs

1//! Direct port of stream-pcm.cpp — streaming transcription from raw PCM input.
2//!
3//! Architecture: async PCM reader → ring buffer → fixed-step or VAD-driven processing.
4
5use crate::context::WhisperContext;
6use crate::error::{Result, WhisperError};
7use crate::params::FullParams;
8use crate::state::{Segment, WhisperState};
9use crate::vad::WhisperVadProcessor;
10
11use std::io::Read;
12use std::sync::{Arc, Mutex};
13use std::thread;
14
15const WHISPER_SAMPLE_RATE: i32 = 16000;
16
/// Requests the macOS `USER_INTERACTIVE` QoS class for the calling thread.
///
/// Used by the capture thread so audio is not lost under CPU contention.
/// Best-effort: the status code is discarded because some environments may
/// return EPERM after incompatible scheduling calls.
#[cfg(target_os = "macos")]
fn set_thread_qos_user_interactive() {
    extern "C" {
        fn pthread_set_qos_class_self_np(qos_class: u32, relative_priority: i32) -> i32;
    }

    const QOS_CLASS_USER_INTERACTIVE: u32 = 0x21;

    // SAFETY: the call passes two plain integers by value and no pointers.
    // NOTE(review): the declared signature is assumed to match the platform
    // header (<pthread/qos.h>) — confirm.
    let _status = unsafe { pthread_set_qos_class_self_np(QOS_CLASS_USER_INTERACTIVE, 0) };
}
29
/// No-op counterpart of the macOS QoS helper for every other target.
#[cfg(not(target_os = "macos"))]
fn set_thread_qos_user_interactive() {
    // Intentionally empty: a thread QoS class is only requested on macOS.
}
32
33// ---------------------------------------------------------------------------
34// PcmFormat
35// ---------------------------------------------------------------------------
36
/// Sample encoding of the raw PCM byte stream handed to the reader.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PcmFormat {
    /// Little-endian 32-bit float samples (4 bytes per sample).
    F32,
    /// Little-endian signed 16-bit samples (2 bytes per sample); the reader
    /// converts them to `f32` by dividing by 32768.
    S16,
}
43
44// ---------------------------------------------------------------------------
45// PcmReader — direct port of pcm_async
46// ---------------------------------------------------------------------------
47
48/// Configuration for [`PcmReader`].
49#[derive(Debug, Clone)]
50pub struct PcmReaderConfig {
51    /// Ring buffer length in milliseconds (maps to `m_len_ms`).
52    pub buffer_len_ms: i32,
53    /// Sample rate (must be 16000).
54    pub sample_rate: i32,
55    /// Input PCM format.
56    pub format: PcmFormat,
57}
58
59impl Default for PcmReaderConfig {
60    fn default() -> Self {
61        Self {
62            buffer_len_ms: 10000,
63            sample_rate: WHISPER_SAMPLE_RATE,
64            format: PcmFormat::F32,
65        }
66    }
67}
68
/// Ring-buffer state shared between the reader thread and the consumer
/// (kept behind a `Mutex`, matching C++ `m_mutex`).
struct RingBuffer {
    /// Backing storage; its length is the ring capacity in samples.
    audio: Vec<f32>,
    /// Index where the next incoming sample will be written.
    audio_pos: usize,
    /// Count of samples written but not yet popped.
    audio_len: usize,
    /// Index of the next sample to hand out.
    audio_read: usize,
    /// Running total of samples discarded because the ring overflowed.
    dropped: u64,
    /// Set by the reader thread once the source returns EOF or a read error.
    eof: bool,
}
82
83/// Threaded PCM reader — direct port of `pcm_async`.
84///
85/// Reads raw PCM from any `Read` source on a background thread,
86/// converts S16→f32 if needed, and fills a ring buffer.
87pub struct PcmReader {
88    shared: Arc<Mutex<RingBuffer>>,
89    handle: Option<thread::JoinHandle<()>>,
90    stop: Arc<std::sync::atomic::AtomicBool>,
91}
92
93impl PcmReader {
94    /// Create and immediately start the reader thread.
95    pub fn new(source: Box<dyn Read + Send>, config: PcmReaderConfig) -> Self {
96        let ring_samples = (config.sample_rate as usize * config.buffer_len_ms as usize) / 1000;
97
98        let shared = Arc::new(Mutex::new(RingBuffer {
99            audio: vec![0.0; ring_samples],
100            audio_pos: 0,
101            audio_len: 0,
102            audio_read: 0,
103            dropped: 0,
104            eof: false,
105        }));
106
107        let stop = Arc::new(std::sync::atomic::AtomicBool::new(false));
108
109        let shared_clone = Arc::clone(&shared);
110        let stop_clone = Arc::clone(&stop);
111        let format = config.format;
112
113        let handle = thread::spawn(move || {
114            set_thread_qos_user_interactive();
115            reader_loop(source, shared_clone, stop_clone, format);
116        });
117
118        Self {
119            shared,
120            handle: Some(handle),
121            stop,
122        }
123    }
124
125    /// Pop up to `ms` milliseconds of audio from the ring buffer.
126    /// Returns fewer samples if not enough are available.
127    pub fn pop_ms(&self, ms: i32) -> Vec<f32> {
128        let mut ring = self.shared.lock().unwrap();
129        let n_samples = ((WHISPER_SAMPLE_RATE as usize) * ms.max(0) as usize) / 1000;
130        let n = n_samples.min(ring.audio_len);
131
132        if n == 0 {
133            return Vec::new();
134        }
135
136        let mut result = vec![0.0f32; n];
137        let cap = ring.audio.len();
138        let s0 = ring.audio_read;
139
140        if s0 + n > cap {
141            let n0 = cap - s0;
142            result[..n0].copy_from_slice(&ring.audio[s0..]);
143            result[n0..].copy_from_slice(&ring.audio[..n - n0]);
144        } else {
145            result.copy_from_slice(&ring.audio[s0..s0 + n]);
146        }
147
148        ring.audio_read = (ring.audio_read + n) % cap;
149        ring.audio_len -= n;
150        result
151    }
152
153    /// Number of unread samples currently in the ring buffer.
154    pub fn available_samples(&self) -> usize {
155        self.shared.lock().unwrap().audio_len
156    }
157
158    /// Total number of samples dropped due to ring buffer overflow.
159    pub fn dropped_samples(&self) -> u64 {
160        self.shared.lock().unwrap().dropped
161    }
162
163    /// Whether the source has reached EOF.
164    pub fn is_eof(&self) -> bool {
165        self.shared.lock().unwrap().eof
166    }
167
168    /// Signal the reader thread to stop.
169    pub fn stop(&mut self) {
170        self.stop.store(true, std::sync::atomic::Ordering::Relaxed);
171        if let Some(h) = self.handle.take() {
172            let _ = h.join();
173        }
174    }
175}
176
177impl Drop for PcmReader {
178    fn drop(&mut self) {
179        self.stop();
180    }
181}
182
183/// Background reader loop — direct port of `pcm_async::reader_loop`.
184fn reader_loop(
185    mut source: Box<dyn Read + Send>,
186    shared: Arc<Mutex<RingBuffer>>,
187    stop: Arc<std::sync::atomic::AtomicBool>,
188    format: PcmFormat,
189) {
190    let bytes_per_sample: usize = match format {
191        PcmFormat::F32 => 4,
192        PcmFormat::S16 => 2,
193    };
194
195    let mut buffer = vec![0u8; 4096];
196    let mut carry: Vec<u8> = Vec::new();
197
198    loop {
199        if stop.load(std::sync::atomic::Ordering::Relaxed) {
200            break;
201        }
202
203        let n_read = match source.read(&mut buffer) {
204            Ok(0) => {
205                shared.lock().unwrap().eof = true;
206                break;
207            }
208            Ok(n) => n,
209            Err(_) => {
210                shared.lock().unwrap().eof = true;
211                break;
212            }
213        };
214
215        // Combine carry bytes with freshly read bytes
216        let mut data = Vec::with_capacity(carry.len() + n_read);
217        data.extend_from_slice(&carry);
218        data.extend_from_slice(&buffer[..n_read]);
219        carry.clear();
220
221        let total_bytes = data.len();
222        let n_samples = total_bytes / bytes_per_sample;
223        let rem = total_bytes % bytes_per_sample;
224
225        if rem > 0 {
226            carry.extend_from_slice(&data[total_bytes - rem..]);
227        }
228
229        if n_samples == 0 {
230            continue;
231        }
232
233        // Convert to f32
234        let samples: Vec<f32> = match format {
235            PcmFormat::F32 => (0..n_samples)
236                .map(|i| {
237                    let o = i * 4;
238                    f32::from_le_bytes([data[o], data[o + 1], data[o + 2], data[o + 3]])
239                })
240                .collect(),
241            PcmFormat::S16 => (0..n_samples)
242                .map(|i| {
243                    let o = i * 2;
244                    i16::from_le_bytes([data[o], data[o + 1]]) as f32 / 32768.0
245                })
246                .collect(),
247        };
248
249        // Push into ring buffer
250        push_samples(&shared, &samples);
251    }
252}
253
254/// Push samples into the ring buffer — direct port of `pcm_async::push_samples`.
255fn push_samples(shared: &Arc<Mutex<RingBuffer>>, data: &[f32]) {
256    if data.is_empty() {
257        return;
258    }
259
260    let mut ring = shared.lock().unwrap();
261    let cap = ring.audio.len();
262    let mut src = data;
263    let mut n = data.len();
264
265    // If more samples than ring capacity, skip the oldest
266    if n > cap {
267        ring.dropped += (n - cap) as u64;
268        src = &data[n - cap..];
269        n = cap;
270    }
271
272    // Drop oldest unread samples if we'd overflow
273    if n > cap - ring.audio_len {
274        let drop = n - (cap - ring.audio_len);
275        ring.audio_read = (ring.audio_read + drop) % cap;
276        ring.audio_len -= drop;
277        ring.dropped += drop as u64;
278    }
279
280    // Write into ring
281    let pos = ring.audio_pos;
282    if pos + n > cap {
283        let n0 = cap - pos;
284        ring.audio[pos..].copy_from_slice(&src[..n0]);
285        ring.audio[..n - n0].copy_from_slice(&src[n0..]);
286    } else {
287        ring.audio[pos..pos + n].copy_from_slice(src);
288    }
289
290    ring.audio_pos = (ring.audio_pos + n) % cap;
291    ring.audio_len = (ring.audio_len + n).min(cap);
292}
293
294// ---------------------------------------------------------------------------
295// vad_simple — port of common.cpp::vad_simple + high_pass_filter
296// ---------------------------------------------------------------------------
297
/// In-place first-order filter — port of `common.cpp::high_pass_filter`.
///
/// `data[0]` is left untouched and seeds the recurrence; every later sample
/// is replaced by `alpha * (y + data[i] - data[i - 1])`, where `y` is the
/// previous output. Empty input is a no-op.
///
/// NOTE(review): `data[i - 1]` is read after it has been overwritten with the
/// previous output, so the difference term uses the filtered neighbour, not
/// the raw one. This mirrors the visible recurrence exactly and is kept as is.
fn high_pass_filter(data: &mut [f32], cutoff: f32, sample_rate: f32) {
    if data.is_empty() {
        return;
    }

    let dt = 1.0 / sample_rate;
    let rc = 1.0 / (2.0 * std::f32::consts::PI * cutoff);
    let alpha = dt / (rc + dt);

    // View the slice as cells so adjacent pairs can be walked with `windows`
    // while the second element of each pair is updated in place.
    let cells = std::cell::Cell::from_mut(data).as_slice_of_cells();
    let mut acc = cells[0].get();
    for pair in cells.windows(2) {
        acc = alpha * (acc + pair[1].get() - pair[0].get());
        pair[1].set(acc);
    }
}
313
314/// Energy-based VAD — port of `common.cpp::vad_simple`.
315///
316/// Returns `true` if the audio chunk is **silence** (no speech detected).
317pub fn vad_simple(
318    pcmf32: &[f32],
319    sample_rate: i32,
320    last_ms: i32,
321    vad_thold: f32,
322    freq_thold: f32,
323) -> bool {
324    let n_samples = pcmf32.len();
325    let n_samples_last = (sample_rate as usize * last_ms.max(0) as usize) / 1000;
326
327    if n_samples_last >= n_samples {
328        // not enough samples — assume no speech
329        return true; // silence
330    }
331
332    // Work on a copy so we can apply the high-pass filter
333    let mut data = pcmf32.to_vec();
334
335    if freq_thold > 0.0 {
336        high_pass_filter(&mut data, freq_thold, sample_rate as f32);
337    }
338
339    let mut energy_all: f32 = 0.0;
340    let mut energy_last: f32 = 0.0;
341
342    for (i, &s) in data.iter().enumerate() {
343        energy_all += s.abs();
344        if i >= n_samples - n_samples_last {
345            energy_last += s.abs();
346        }
347    }
348
349    energy_all /= n_samples as f32;
350    energy_last /= n_samples_last as f32;
351
352    // C++ returns false (= NOT silence) when energy_last > thold*energy_all
353    // We return true for silence, matching the C++ sense where true = silence.
354    energy_last <= vad_thold * energy_all
355}
356
357// ---------------------------------------------------------------------------
358// WhisperStreamPcmConfig
359// ---------------------------------------------------------------------------
360
/// Settings for [`WhisperStreamPcm`] — maps to the streaming subset of `whisper_params`.
///
/// The values given here are normalized when the processor is constructed.
#[derive(Debug, Clone)]
pub struct WhisperStreamPcmConfig {
    /// Amount of new audio consumed per iteration in fixed-step (non-VAD) mode, in ms.
    pub step_ms: i32,
    /// Upper bound on the audio handed to a single inference, in ms.
    pub length_ms: i32,
    /// Audio from the previous step retained as overlap, in ms.
    pub keep_ms: i32,
    /// Switch to VAD-driven segmentation instead of fixed steps.
    pub use_vad: bool,
    /// VAD threshold (both simple & Silero).
    pub vad_thold: f32,
    /// Cutoff frequency of the high-pass filter used by the simple VAD.
    pub freq_thold: f32,
    /// When true, prompt tokens are not carried across inference boundaries.
    pub no_context: bool,
    /// Size of each chunk probed by the VAD, in ms.
    pub vad_probe_ms: i32,
    /// Accumulated silence that ends a speech segment, in ms.
    pub vad_silence_ms: i32,
    /// Audio from before the VAD trigger that is prepended to a segment, in ms.
    pub vad_pre_roll_ms: i32,
}
385
386impl Default for WhisperStreamPcmConfig {
387    fn default() -> Self {
388        Self {
389            step_ms: 3000,
390            length_ms: 10000,
391            keep_ms: 200,
392            use_vad: false,
393            vad_thold: 0.6,
394            freq_thold: 100.0,
395            no_context: true,
396            vad_probe_ms: 200,
397            vad_silence_ms: 800,
398            vad_pre_roll_ms: 300,
399        }
400    }
401}
402
403// ---------------------------------------------------------------------------
404// WhisperStreamPcm — main processor
405// ---------------------------------------------------------------------------
406
407/// Streaming PCM transcriber — direct port of `stream-pcm.cpp` main loop.
408///
409/// Two modes:
410/// - **Fixed-step** (`use_vad = false`): process `step_ms` chunks with overlap.
411/// - **VAD-driven** (`use_vad = true`): accumulate speech, transcribe on silence.
412pub struct WhisperStreamPcm {
413    state: WhisperState,
414    params: FullParams,
415    config: WhisperStreamPcmConfig,
416    reader: PcmReader,
417    vad: Option<WhisperVadProcessor>,
418
419    // Pre-computed sample counts
420    n_samples_step: usize,
421    n_samples_len: usize,
422    n_samples_keep: usize,
423
424    // Fixed-step state
425    pcmf32_old: Vec<f32>,
426    n_new_line: i32,
427    prompt_tokens: Vec<i32>,
428
429    // VAD state machine
430    in_speech: bool,
431    speech_buf: Vec<f32>,
432    pre_roll: Vec<f32>,
433    silence_samples: usize,
434
435    total_samples: i64,
436    n_iter: i32,
437
438    // VAD pre-computed
439    vad_last_ms: i32,
440    vad_pre_roll_samples: usize,
441    vad_silence_samples: usize,
442    vad_max_segment_samples: usize,
443}
444
445impl WhisperStreamPcm {
446    /// Create a new WhisperStreamPcm processor (simple VAD or no VAD).
447    pub fn new(
448        ctx: &WhisperContext,
449        params: FullParams,
450        mut config: WhisperStreamPcmConfig,
451        reader: PcmReader,
452    ) -> Result<Self> {
453        Self::build(ctx, params, &mut config, reader, None)
454    }
455
456    /// Create a new WhisperStreamPcm processor with Silero VAD.
457    pub fn with_vad(
458        ctx: &WhisperContext,
459        params: FullParams,
460        mut config: WhisperStreamPcmConfig,
461        reader: PcmReader,
462        vad: WhisperVadProcessor,
463    ) -> Result<Self> {
464        Self::build(ctx, params, &mut config, reader, Some(vad))
465    }
466
467    fn build(
468        ctx: &WhisperContext,
469        params: FullParams,
470        config: &mut WhisperStreamPcmConfig,
471        reader: PcmReader,
472        vad: Option<WhisperVadProcessor>,
473    ) -> Result<Self> {
474        let state = WhisperState::new(ctx)?;
475
476        // Normalize config (matches C++ main)
477        if !config.use_vad {
478            if config.step_ms <= 0 {
479                return Err(WhisperError::InvalidParameter(
480                    "step_ms must be > 0 unless use_vad is true".into(),
481                ));
482            }
483            config.keep_ms = config.keep_ms.min(config.step_ms);
484            config.length_ms = config.length_ms.max(config.step_ms);
485        } else {
486            if config.length_ms <= 0 {
487                config.length_ms = 5000;
488            }
489            config.keep_ms = 0;
490            // Force no_context in VAD mode (stream.cpp: no_context |= use_vad)
491            config.no_context = true;
492        }
493
494        let n_samples_step = if config.use_vad {
495            0
496        } else {
497            (config.step_ms as f64 * 0.001 * WHISPER_SAMPLE_RATE as f64) as usize
498        };
499        let n_samples_len = (config.length_ms as f64 * 0.001 * WHISPER_SAMPLE_RATE as f64) as usize;
500        let n_samples_keep = (config.keep_ms as f64 * 0.001 * WHISPER_SAMPLE_RATE as f64) as usize;
501
502        let n_new_line = if !config.use_vad && config.step_ms > 0 {
503            (config.length_ms / config.step_ms - 1).max(1)
504        } else {
505            1
506        };
507
508        let vad_probe_ms = config.vad_probe_ms.max(1);
509        let vad_last_ms = (vad_probe_ms / 2).clamp(1, 1000);
510        let vad_pre_roll_samples =
511            (WHISPER_SAMPLE_RATE as usize * config.vad_pre_roll_ms.max(0) as usize) / 1000;
512        let vad_silence_samples =
513            (WHISPER_SAMPLE_RATE as usize * config.vad_silence_ms.max(0) as usize) / 1000;
514
515        Ok(Self {
516            state,
517            params,
518            config: config.clone(),
519            reader,
520            vad,
521            n_samples_step,
522            n_samples_len,
523            n_samples_keep,
524            pcmf32_old: Vec::new(),
525            n_new_line,
526            prompt_tokens: Vec::new(),
527            in_speech: false,
528            speech_buf: Vec::new(),
529            pre_roll: Vec::new(),
530            silence_samples: 0,
531            total_samples: 0,
532            n_iter: 0,
533            vad_last_ms,
534            vad_pre_roll_samples,
535            vad_silence_samples,
536            vad_max_segment_samples: n_samples_len,
537        })
538    }
539
540    /// Returns `true` when the underlying reader has hit EOF and all samples are drained.
541    pub fn is_eof(&self) -> bool {
542        self.reader.is_eof() && self.reader.available_samples() == 0
543    }
544
545    /// Run one iteration of the main loop.
546    ///
547    /// Returns `Ok(Some(segments))` if transcription occurred,
548    /// `Ok(None)` if waiting for more audio or sleeping,
549    /// `Err` on fatal error.
550    ///
551    /// Returns `Ok(None)` with no more audio when EOF + drained.
552    pub fn process_step(&mut self) -> Result<Option<Vec<Segment>>> {
553        if !self.config.use_vad {
554            self.process_step_fixed()
555        } else {
556            self.process_step_vad()
557        }
558    }
559
560    /// Run until EOF or error. Calls `callback` for each transcription.
561    pub fn run<F>(&mut self, mut callback: F) -> Result<()>
562    where
563        F: FnMut(&[Segment], i64, i64),
564    {
565        loop {
566            match self.process_step()? {
567                Some(segments) if !segments.is_empty() => {
568                    let start = segments.first().map(|s| s.start_ms).unwrap_or(0);
569                    let end = segments.last().map(|s| s.end_ms).unwrap_or(0);
570                    callback(&segments, start, end);
571                }
572                Some(_) => {} // empty segments, keep going
573                None => {
574                    // Check if truly done (EOF + no audio left)
575                    if self.reader.is_eof() && self.reader.available_samples() == 0 {
576                        // Flush any remaining VAD speech
577                        if self.config.use_vad && self.in_speech && !self.speech_buf.is_empty() {
578                            let segments = self.run_inference(&self.speech_buf.clone())?;
579                            if !segments.is_empty() {
580                                let start = segments.first().map(|s| s.start_ms).unwrap_or(0);
581                                let end = segments.last().map(|s| s.end_ms).unwrap_or(0);
582                                callback(&segments, start, end);
583                            }
584                            self.speech_buf.clear();
585                            self.in_speech = false;
586                        }
587                        break;
588                    }
589                    // Still waiting for audio
590                    std::thread::sleep(std::time::Duration::from_millis(5));
591                }
592            }
593        }
594        Ok(())
595    }
596
597    /// Fixed-step processing — port of the non-VAD branch of the C++ main loop.
598    fn process_step_fixed(&mut self) -> Result<Option<Vec<Segment>>> {
599        let available = self.reader.available_samples();
600
601        if available < self.n_samples_step {
602            if self.reader.is_eof() {
603                if available == 0 {
604                    return Ok(None); // done
605                }
606                // Fall through to process remaining
607            } else {
608                return Ok(None); // wait for more audio
609            }
610        }
611
612        let pcmf32_new = self.reader.pop_ms(self.config.step_ms);
613        if pcmf32_new.is_empty() {
614            return Ok(None);
615        }
616
617        self.total_samples += pcmf32_new.len() as i64;
618
619        let n_samples_new = pcmf32_new.len();
620        let n_samples_take = self
621            .pcmf32_old
622            .len()
623            .min((self.n_samples_keep + self.n_samples_len).saturating_sub(n_samples_new));
624
625        let mut pcmf32 = Vec::with_capacity(n_samples_new + n_samples_take);
626
627        // Prepend overlap from previous step
628        if n_samples_take > 0 && !self.pcmf32_old.is_empty() {
629            let start = self.pcmf32_old.len() - n_samples_take;
630            pcmf32.extend_from_slice(&self.pcmf32_old[start..]);
631        }
632        pcmf32.extend_from_slice(&pcmf32_new);
633
634        self.pcmf32_old = pcmf32.clone();
635
636        let segments = self.run_inference(&pcmf32)?;
637        self.n_iter += 1;
638
639        // At n_new_line boundary (stream.cpp lines 408-425)
640        if self.n_iter % self.n_new_line == 0 {
641            if self.n_samples_keep > 0 && pcmf32.len() >= self.n_samples_keep {
642                self.pcmf32_old = pcmf32[pcmf32.len() - self.n_samples_keep..].to_vec();
643            }
644
645            if !self.config.no_context {
646                self.collect_prompt_tokens();
647            }
648        }
649
650        Ok(Some(segments))
651    }
652
653    /// VAD-driven processing — port of the VAD branch of the C++ main loop.
654    fn process_step_vad(&mut self) -> Result<Option<Vec<Segment>>> {
655        let available = self.reader.available_samples();
656
657        if available == 0 {
658            if self.reader.is_eof() {
659                // Flush remaining speech
660                if self.in_speech && !self.speech_buf.is_empty() {
661                    let segments = self.run_inference(&self.speech_buf.clone())?;
662                    self.speech_buf.clear();
663                    self.in_speech = false;
664                    self.n_iter += 1;
665                    return Ok(Some(segments));
666                }
667                return Ok(None);
668            }
669            return Ok(None); // wait
670        }
671
672        let pcmf32_new = self.reader.pop_ms(self.config.vad_probe_ms);
673        if pcmf32_new.is_empty() {
674            return Ok(None);
675        }
676
677        self.total_samples += pcmf32_new.len() as i64;
678
679        // Determine silence via Silero or simple VAD
680        let silence = if let Some(ref mut vad) = self.vad {
681            if vad.detect_speech(&pcmf32_new) {
682                let probs = vad.get_probs();
683                let avg = if probs.is_empty() {
684                    0.0
685                } else {
686                    probs.iter().sum::<f32>() / probs.len() as f32
687                };
688                avg < self.config.vad_thold
689            } else {
690                true // detect failed → treat as silence
691            }
692        } else {
693            vad_simple(
694                &pcmf32_new,
695                WHISPER_SAMPLE_RATE,
696                self.vad_last_ms,
697                self.config.vad_thold,
698                self.config.freq_thold,
699            )
700        };
701
702        let mut result_segments: Option<Vec<Segment>> = None;
703
704        if !self.in_speech {
705            if !silence {
706                self.in_speech = true;
707                self.silence_samples = 0;
708                self.speech_buf.clear();
709                if !self.pre_roll.is_empty() {
710                    self.speech_buf.extend_from_slice(&self.pre_roll);
711                }
712                self.speech_buf.extend_from_slice(&pcmf32_new);
713            }
714        } else {
715            self.speech_buf.extend_from_slice(&pcmf32_new);
716            if !silence {
717                self.silence_samples = 0;
718            } else {
719                self.silence_samples += pcmf32_new.len();
720            }
721
722            if self.speech_buf.len() >= self.vad_max_segment_samples
723                || self.silence_samples >= self.vad_silence_samples
724            {
725                let segments = self.run_inference(&self.speech_buf.clone())?;
726                self.speech_buf.clear();
727                self.in_speech = false;
728                self.silence_samples = 0;
729                self.n_iter += 1;
730                result_segments = Some(segments);
731            }
732        }
733
734        // Maintain pre-roll buffer
735        if self.vad_pre_roll_samples > 0 {
736            self.pre_roll.extend_from_slice(&pcmf32_new);
737            if self.pre_roll.len() > self.vad_pre_roll_samples {
738                let excess = self.pre_roll.len() - self.vad_pre_roll_samples;
739                self.pre_roll.drain(..excess);
740            }
741        }
742
743        Ok(result_segments)
744    }
745
746    /// Run whisper inference on an audio buffer — port of `run_inference` lambda.
747    fn run_inference(&mut self, audio: &[f32]) -> Result<Vec<Segment>> {
748        if audio.is_empty() {
749            return Ok(Vec::new());
750        }
751
752        let mut params = self.params.clone();
753        if !self.config.no_context && !self.prompt_tokens.is_empty() {
754            params = params.prompt_tokens(&self.prompt_tokens);
755        }
756
757        self.state.full(params, audio)?;
758
759        let n_segments = self.state.full_n_segments();
760        let mut segments = Vec::with_capacity(n_segments as usize);
761
762        for i in 0..n_segments {
763            let text = self.state.full_get_segment_text(i)?;
764            let (start_ms, end_ms) = self.state.full_get_segment_timestamps(i);
765            let speaker_turn_next = self.state.full_get_segment_speaker_turn_next(i);
766
767            segments.push(Segment {
768                start_ms,
769                end_ms,
770                text,
771                speaker_turn_next,
772            });
773        }
774
775        Ok(segments)
776    }
777
778    /// Collect prompt tokens from last inference — port of stream.cpp lines 416-425.
779    fn collect_prompt_tokens(&mut self) {
780        self.prompt_tokens.clear();
781
782        let n_segments = self.state.full_n_segments();
783        for i in 0..n_segments {
784            let token_count = self.state.full_n_tokens(i);
785            for j in 0..token_count {
786                self.prompt_tokens.push(self.state.full_get_token_id(i, j));
787            }
788        }
789    }
790
791    /// Get the total number of processed samples.
792    pub fn total_samples(&self) -> i64 {
793        self.total_samples
794    }
795
796    /// Get the iteration count.
797    pub fn n_iter(&self) -> i32 {
798        self.n_iter
799    }
800}
801
802// ---------------------------------------------------------------------------
803// Tests
804// ---------------------------------------------------------------------------
805
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, Instant};

    /// Blocks until the reader thread reports EOF, failing the test after a
    /// generous deadline instead of relying on a fixed sleep.
    fn wait_for_eof(reader: &PcmReader) {
        let deadline = Instant::now() + Duration::from_secs(5);
        while !reader.is_eof() {
            assert!(
                Instant::now() < deadline,
                "reader thread did not reach EOF in time"
            );
            std::thread::sleep(Duration::from_millis(1));
        }
    }

    /// Serializes samples as little-endian f32 bytes.
    fn f32_bytes(samples: &[f32]) -> Vec<u8> {
        let mut raw = Vec::with_capacity(samples.len() * 4);
        for v in samples {
            raw.extend_from_slice(&v.to_le_bytes());
        }
        raw
    }

    /// `n` samples ramping from -1.0 towards 1.0.
    fn unit_ramp(n: usize) -> Vec<f32> {
        (0..n).map(|i| (i as f32 / n as f32) * 2.0 - 1.0).collect()
    }

    /// `n` samples whose value equals their index.
    fn index_ramp(n: usize) -> Vec<f32> {
        (0..n).map(|i| i as f32).collect()
    }

    /// Starts a 16 kHz reader over `raw` and waits until it has consumed everything.
    fn drained_reader(raw: Vec<u8>, buffer_len_ms: i32, format: PcmFormat) -> PcmReader {
        let config = PcmReaderConfig {
            buffer_len_ms,
            sample_rate: 16000,
            format,
        };
        let reader = PcmReader::new(Box::new(std::io::Cursor::new(raw)), config);
        wait_for_eof(&reader);
        reader
    }

    #[test]
    fn test_pcm_format_eq() {
        assert_eq!(PcmFormat::F32, PcmFormat::F32);
        assert_ne!(PcmFormat::F32, PcmFormat::S16);
    }

    #[test]
    fn test_vad_simple_silence() {
        // All zeros = silence
        let silence = vec![0.0f32; 16000];
        assert!(vad_simple(&silence, 16000, 100, 0.6, 100.0));
    }

    #[test]
    fn test_vad_simple_too_few_samples() {
        let short = vec![0.1f32; 100];
        // last_ms=1000 → needs 16000 samples, only have 100 → silence
        assert!(vad_simple(&short, 16000, 1000, 0.6, 100.0));
    }

    #[test]
    fn test_high_pass_filter_basic() {
        let mut data = vec![1.0, 0.0, 1.0, 0.0, 1.0];
        high_pass_filter(&mut data, 100.0, 16000.0);
        // After filter, values should be modified
        assert_ne!(data[2], 1.0);
    }

    #[test]
    fn test_pcm_reader_f32() {
        // 1 second of f32 PCM data (16000 samples) into a 2-second ring
        let reader = drained_reader(f32_bytes(&unit_ramp(16000)), 2000, PcmFormat::F32);

        assert!(reader.is_eof());
        let samples = reader.pop_ms(1000);
        assert_eq!(samples.len(), 16000);
    }

    #[test]
    fn test_pcm_reader_s16() {
        let n = 16000;
        let mut raw = Vec::with_capacity(n * 2);
        for v in unit_ramp(n) {
            raw.extend_from_slice(&((v * 32767.0) as i16).to_le_bytes());
        }

        let reader = drained_reader(raw, 2000, PcmFormat::S16);

        assert!(reader.is_eof());
        let samples = reader.pop_ms(1000);
        assert_eq!(samples.len(), 16000);

        // Check conversion — first sample should be near -1.0
        assert!(samples[0] < -0.9);
    }

    #[test]
    fn test_ring_buffer_overflow() {
        // Buffer only holds 500ms = 8000 samples, but we push 16000
        let reader = drained_reader(f32_bytes(&index_ramp(16000)), 500, PcmFormat::F32);

        // Should only have 8000 samples (most recent)
        let available = reader.available_samples();
        assert!(available <= 8000);

        // Overflow should have been tracked
        let dropped = reader.dropped_samples();
        assert!(dropped > 0, "Expected dropped samples on overflow");

        let samples = reader.pop_ms(500);
        assert_eq!(samples.len(), 8000);
        // Last sample should be 15999.0
        assert!((samples[samples.len() - 1] - 15999.0).abs() < 0.01);
    }

    #[test]
    fn test_dropped_samples_zero_when_no_overflow() {
        // Buffer holds 2000ms = 32000 samples, push only 16000
        let reader = drained_reader(f32_bytes(&unit_ramp(16000)), 2000, PcmFormat::F32);

        assert_eq!(reader.dropped_samples(), 0);
    }

    #[test]
    fn test_dropped_samples_tracked_on_overflow() {
        // Buffer holds 500ms = 8000 samples, push 16000 — should drop 8000
        let reader = drained_reader(f32_bytes(&index_ramp(16000)), 500, PcmFormat::F32);

        let dropped = reader.dropped_samples();
        assert_eq!(
            dropped, 8000,
            "Expected 8000 dropped samples, got {}",
            dropped
        );
    }

    #[test]
    fn test_qos_does_not_panic() {
        // On macOS this sets QoS, on other platforms it's a no-op.
        // Either way it should not panic.
        set_thread_qos_user_interactive();
    }

    #[test]
    fn test_stream_pcm_config_defaults() {
        let config = WhisperStreamPcmConfig::default();
        assert_eq!(config.step_ms, 3000);
        assert_eq!(config.length_ms, 10000);
        assert_eq!(config.keep_ms, 200);
        assert!(!config.use_vad);
    }

    #[test]
    fn test_stream_pcm_config_vad_normalization() {
        // When use_vad=true, keep_ms should be forced to 0
        use std::path::Path;
        let model_path = "tests/models/ggml-tiny.en.bin";
        if !Path::new(model_path).exists() {
            eprintln!("Skipping: model not found");
            return;
        }

        let ctx = WhisperContext::new(model_path).unwrap();
        let params = FullParams::default();
        let cursor = std::io::Cursor::new(Vec::<u8>::new());
        let reader = PcmReader::new(Box::new(cursor), PcmReaderConfig::default());
        let config = WhisperStreamPcmConfig {
            use_vad: true,
            keep_ms: 500, // should be forced to 0
            ..Default::default()
        };

        let stream = WhisperStreamPcm::new(&ctx, params, config, reader).unwrap();
        assert_eq!(stream.config.keep_ms, 0);
    }
}
1013}