Skip to main content

whisper_cpp_plus/
stream_pcm.rs

1//! Direct port of stream-pcm.cpp — streaming transcription from raw PCM input.
2//!
3//! Architecture: async PCM reader → ring buffer → fixed-step or VAD-driven processing.
4
5use crate::context::WhisperContext;
6use crate::error::{Result, WhisperError};
7use crate::params::FullParams;
8use crate::state::{Segment, WhisperState};
9use crate::vad::WhisperVadProcessor;
10
11use std::io::Read;
12use std::sync::{Arc, Mutex};
13use std::thread;
14
/// Sample rate, in Hz, that whisper expects; used to turn millisecond durations into
/// sample counts.
const WHISPER_SAMPLE_RATE: i32 = 16_000;
16
17// ---------------------------------------------------------------------------
18// PcmFormat
19// ---------------------------------------------------------------------------
20
/// Sample encoding of the raw PCM stream consumed by [`PcmReader`].
///
/// Both encodings are decoded as little-endian.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PcmFormat {
    /// 32-bit IEEE-754 float samples, used as-is.
    F32,
    /// Signed 16-bit integer samples, scaled to `f32` by `1 / 32768`.
    S16,
}
27
28// ---------------------------------------------------------------------------
29// PcmReader — direct port of pcm_async
30// ---------------------------------------------------------------------------
31
32/// Configuration for [`PcmReader`].
33#[derive(Debug, Clone)]
34pub struct PcmReaderConfig {
35    /// Ring buffer length in milliseconds (maps to `m_len_ms`).
36    pub buffer_len_ms: i32,
37    /// Sample rate (must be 16000).
38    pub sample_rate: i32,
39    /// Input PCM format.
40    pub format: PcmFormat,
41}
42
43impl Default for PcmReaderConfig {
44    fn default() -> Self {
45        Self {
46            buffer_len_ms: 10000,
47            sample_rate: WHISPER_SAMPLE_RATE,
48            format: PcmFormat::F32,
49        }
50    }
51}
52
/// Ring-buffer state shared between the reader thread and consumers. It is always accessed
/// through a `Mutex`, mirroring the C++ `m_mutex`.
///
/// The push/pop code keeps `audio_len <= audio.len()`. The unread samples start at
/// `audio_read` and may wrap around the end of `audio`.
struct RingBuffer {
    /// Backing storage; its length is the ring capacity.
    audio: Vec<f32>,
    /// Index of the oldest unread sample (the next one to pop).
    audio_read: usize,
    /// Index at which the next incoming sample will be written.
    audio_pos: usize,
    /// Number of unread samples currently stored.
    audio_len: usize,
    /// Set by the reader thread once the source is exhausted or a read fails.
    eof: bool,
}

/// Threaded PCM reader — direct port of `pcm_async`.
///
/// A background thread pulls raw PCM from a `Read` source and converts S16 input to `f32`
/// when needed. It appends the samples to a shared ring buffer, which callers drain with
/// [`PcmReader::pop_ms`].
pub struct PcmReader {
    /// Ring buffer shared with the reader thread.
    shared: Arc<Mutex<RingBuffer>>,
    /// Reader-thread handle; `None` once [`PcmReader::stop`] has taken it.
    handle: Option<thread::JoinHandle<()>>,
    /// Cooperative stop flag that the reader thread polls between reads.
    stop: Arc<std::sync::atomic::AtomicBool>,
}
74
75impl PcmReader {
76    /// Create and immediately start the reader thread.
77    pub fn new(source: Box<dyn Read + Send>, config: PcmReaderConfig) -> Self {
78        let ring_samples = (config.sample_rate as usize * config.buffer_len_ms as usize) / 1000;
79
80        let shared = Arc::new(Mutex::new(RingBuffer {
81            audio: vec![0.0; ring_samples],
82            audio_pos: 0,
83            audio_len: 0,
84            audio_read: 0,
85            eof: false,
86        }));
87
88        let stop = Arc::new(std::sync::atomic::AtomicBool::new(false));
89
90        let shared_clone = Arc::clone(&shared);
91        let stop_clone = Arc::clone(&stop);
92        let format = config.format;
93
94        let handle = thread::spawn(move || {
95            reader_loop(source, shared_clone, stop_clone, format);
96        });
97
98        Self {
99            shared,
100            handle: Some(handle),
101            stop,
102        }
103    }
104
105    /// Pop up to `ms` milliseconds of audio from the ring buffer.
106    /// Returns fewer samples if not enough are available.
107    pub fn pop_ms(&self, ms: i32) -> Vec<f32> {
108        let mut ring = self.shared.lock().unwrap();
109        let n_samples = ((WHISPER_SAMPLE_RATE as usize) * ms.max(0) as usize) / 1000;
110        let n = n_samples.min(ring.audio_len);
111
112        if n == 0 {
113            return Vec::new();
114        }
115
116        let mut result = vec![0.0f32; n];
117        let cap = ring.audio.len();
118        let s0 = ring.audio_read;
119
120        if s0 + n > cap {
121            let n0 = cap - s0;
122            result[..n0].copy_from_slice(&ring.audio[s0..]);
123            result[n0..].copy_from_slice(&ring.audio[..n - n0]);
124        } else {
125            result.copy_from_slice(&ring.audio[s0..s0 + n]);
126        }
127
128        ring.audio_read = (ring.audio_read + n) % cap;
129        ring.audio_len -= n;
130        result
131    }
132
133    /// Number of unread samples currently in the ring buffer.
134    pub fn available_samples(&self) -> usize {
135        self.shared.lock().unwrap().audio_len
136    }
137
138    /// Whether the source has reached EOF.
139    pub fn is_eof(&self) -> bool {
140        self.shared.lock().unwrap().eof
141    }
142
143    /// Signal the reader thread to stop.
144    pub fn stop(&mut self) {
145        self.stop
146            .store(true, std::sync::atomic::Ordering::Relaxed);
147        if let Some(h) = self.handle.take() {
148            let _ = h.join();
149        }
150    }
151}
152
153impl Drop for PcmReader {
154    fn drop(&mut self) {
155        self.stop();
156    }
157}
158
159/// Background reader loop — direct port of `pcm_async::reader_loop`.
160fn reader_loop(
161    mut source: Box<dyn Read + Send>,
162    shared: Arc<Mutex<RingBuffer>>,
163    stop: Arc<std::sync::atomic::AtomicBool>,
164    format: PcmFormat,
165) {
166    let bytes_per_sample: usize = match format {
167        PcmFormat::F32 => 4,
168        PcmFormat::S16 => 2,
169    };
170
171    let mut buffer = vec![0u8; 4096];
172    let mut carry: Vec<u8> = Vec::new();
173
174    loop {
175        if stop.load(std::sync::atomic::Ordering::Relaxed) {
176            break;
177        }
178
179        let n_read = match source.read(&mut buffer) {
180            Ok(0) => {
181                shared.lock().unwrap().eof = true;
182                break;
183            }
184            Ok(n) => n,
185            Err(_) => {
186                shared.lock().unwrap().eof = true;
187                break;
188            }
189        };
190
191        // Combine carry bytes with freshly read bytes
192        let mut data = Vec::with_capacity(carry.len() + n_read);
193        data.extend_from_slice(&carry);
194        data.extend_from_slice(&buffer[..n_read]);
195        carry.clear();
196
197        let total_bytes = data.len();
198        let n_samples = total_bytes / bytes_per_sample;
199        let rem = total_bytes % bytes_per_sample;
200
201        if rem > 0 {
202            carry.extend_from_slice(&data[total_bytes - rem..]);
203        }
204
205        if n_samples == 0 {
206            continue;
207        }
208
209        // Convert to f32
210        let samples: Vec<f32> = match format {
211            PcmFormat::F32 => {
212                (0..n_samples)
213                    .map(|i| {
214                        let o = i * 4;
215                        f32::from_le_bytes([data[o], data[o + 1], data[o + 2], data[o + 3]])
216                    })
217                    .collect()
218            }
219            PcmFormat::S16 => {
220                (0..n_samples)
221                    .map(|i| {
222                        let o = i * 2;
223                        i16::from_le_bytes([data[o], data[o + 1]]) as f32 / 32768.0
224                    })
225                    .collect()
226            }
227        };
228
229        // Push into ring buffer
230        push_samples(&shared, &samples);
231    }
232}
233
234/// Push samples into the ring buffer — direct port of `pcm_async::push_samples`.
235fn push_samples(shared: &Arc<Mutex<RingBuffer>>, data: &[f32]) {
236    if data.is_empty() {
237        return;
238    }
239
240    let mut ring = shared.lock().unwrap();
241    let cap = ring.audio.len();
242    let mut src = data;
243    let mut n = data.len();
244
245    // If more samples than ring capacity, skip the oldest
246    if n > cap {
247        src = &data[n - cap..];
248        n = cap;
249    }
250
251    // Drop oldest unread samples if we'd overflow
252    if n > cap - ring.audio_len {
253        let drop = n - (cap - ring.audio_len);
254        ring.audio_read = (ring.audio_read + drop) % cap;
255        ring.audio_len -= drop;
256    }
257
258    // Write into ring
259    let pos = ring.audio_pos;
260    if pos + n > cap {
261        let n0 = cap - pos;
262        ring.audio[pos..].copy_from_slice(&src[..n0]);
263        ring.audio[..n - n0].copy_from_slice(&src[n0..]);
264    } else {
265        ring.audio[pos..pos + n].copy_from_slice(src);
266    }
267
268    ring.audio_pos = (ring.audio_pos + n) % cap;
269    ring.audio_len = (ring.audio_len + n).min(cap);
270}
271
272// ---------------------------------------------------------------------------
273// vad_simple — port of common.cpp::vad_simple + high_pass_filter
274// ---------------------------------------------------------------------------
275
/// High-pass filter — port of `common.cpp::high_pass_filter`.
///
/// Filters `data` in place with the C++ recurrence
/// `y = alpha * (y + data[i] - data[i - 1])`. Here `alpha = dt / (rc + dt)`,
/// `rc = 1 / (2π·cutoff)` and `dt = 1 / sample_rate`. `data[0]` is left unchanged; empty
/// input is a no-op.
///
/// NOTE(review): like the C++ original, each step reads `data[i - 1]` *after* it has been
/// overwritten with the previous output. Up to floating-point rounding, the recurrence
/// therefore reduces to `data[i] = alpha * data[i]` for `i >= 1`. This is kept unchanged
/// for parity with whisper.cpp; changing it would alter [`vad_simple`] decisions.
fn high_pass_filter(data: &mut [f32], cutoff: f32, sample_rate: f32) {
    if data.is_empty() {
        return;
    }
    let rc = 1.0 / (2.0 * std::f32::consts::PI * cutoff);
    let dt = 1.0 / sample_rate;
    let alpha = dt / (rc + dt);

    let mut y = data[0];
    for i in 1..data.len() {
        y = alpha * (y + data[i] - data[i - 1]);
        data[i] = y;
    }
}

/// Energy-based VAD — port of `common.cpp::vad_simple`.
///
/// Returns `true` if the chunk is **silence**, i.e. no speech was detected at its end:
/// - the chunk is not longer than `last_ms`, or
/// - the mean absolute amplitude of its last `last_ms` is not greater than `vad_thold`
///   times the mean over the whole chunk.
///
/// Amplitudes are measured after high-pass filtering when `freq_thold > 0`; `pcmf32`
/// itself is never modified. As in the C++ test `energy_last > vad_thold * energy_all`, a
/// NaN energy counts as silence. That happens for `last_ms == 0` (empty tail) or NaN
/// samples. Negative `sample_rate` / `last_ms` are treated as zero for the tail length.
pub fn vad_simple(
    pcmf32: &[f32],
    sample_rate: i32,
    last_ms: i32,
    vad_thold: f32,
    freq_thold: f32,
) -> bool {
    let n_samples = pcmf32.len();
    // Clamp before casting: a negative i32 cast to usize wraps and overflows the multiply.
    let n_samples_last =
        (sample_rate.max(0) as usize).saturating_mul(last_ms.max(0) as usize) / 1000;

    if n_samples_last >= n_samples {
        // not enough samples — assume no speech
        return true; // silence
    }

    // Only copy the input when the filter actually needs a mutable buffer.
    let filtered: Vec<f32>;
    let data: &[f32] = if freq_thold > 0.0 {
        let mut copy = pcmf32.to_vec();
        high_pass_filter(&mut copy, freq_thold, sample_rate as f32);
        filtered = copy;
        &filtered
    } else {
        pcmf32
    };

    // Sequential left-to-right sums, matching the C++ accumulation order.
    let tail_start = n_samples - n_samples_last;
    let energy_all = data.iter().fold(0.0f32, |acc, s| acc + s.abs()) / n_samples as f32;
    let energy_last =
        data[tail_start..].iter().fold(0.0f32, |acc, s| acc + s.abs()) / n_samples_last as f32;

    // The C++ code reports "speech at the end" exactly when this comparison holds; any NaN
    // makes it false, which therefore yields silence here as well.
    let speech_at_end = energy_last > vad_thold * energy_all;
    !speech_at_end
}
334
335// ---------------------------------------------------------------------------
336// WhisperStreamPcmConfig
337// ---------------------------------------------------------------------------
338
/// Streaming options for [`WhisperStreamPcm`] — the streaming subset of the C++
/// `whisper_params`.
#[derive(Debug, Clone)]
pub struct WhisperStreamPcmConfig {
    /// Amount of new audio processed per step in fixed-step mode, in ms.
    pub step_ms: i32,
    /// Maximum audio length per inference in ms; in VAD mode also the maximum segment length.
    pub length_ms: i32,
    /// Audio from the previous step retained as overlap, in ms (forced to 0 in VAD mode).
    pub keep_ms: i32,
    /// Segment audio on detected speech/silence instead of fixed steps.
    pub use_vad: bool,
    /// Speech threshold: energy ratio for `vad_simple`, mean probability for Silero.
    pub vad_thold: f32,
    /// High-pass cutoff frequency in Hz applied before `vad_simple`.
    pub freq_thold: f32,
    /// When true, tokens from earlier inferences are not used as the prompt
    /// (always true in VAD mode).
    pub no_context: bool,
    /// Size of each chunk classified by the VAD, in ms.
    pub vad_probe_ms: i32,
    /// Trailing silence that closes a speech segment, in ms.
    pub vad_silence_ms: i32,
    /// Audio kept from before the speech onset and prepended to a segment, in ms.
    pub vad_pre_roll_ms: i32,
}

impl Default for WhisperStreamPcmConfig {
    /// Fixed-step streaming: 3 s steps, a 10 s window, 200 ms overlap and no prompt context.
    /// VAD settings default to 200 ms probes, an 800 ms silence timeout and 300 ms pre-roll.
    fn default() -> Self {
        WhisperStreamPcmConfig {
            use_vad: false,
            no_context: true,
            step_ms: 3_000,
            length_ms: 10_000,
            keep_ms: 200,
            vad_thold: 0.6,
            freq_thold: 100.0,
            vad_probe_ms: 200,
            vad_silence_ms: 800,
            vad_pre_roll_ms: 300,
        }
    }
}
380
381// ---------------------------------------------------------------------------
382// WhisperStreamPcm — main processor
383// ---------------------------------------------------------------------------
384
385/// Streaming PCM transcriber — direct port of `stream-pcm.cpp` main loop.
386///
387/// Two modes:
388/// - **Fixed-step** (`use_vad = false`): process `step_ms` chunks with overlap.
389/// - **VAD-driven** (`use_vad = true`): accumulate speech, transcribe on silence.
390pub struct WhisperStreamPcm {
391    state: WhisperState,
392    params: FullParams,
393    config: WhisperStreamPcmConfig,
394    reader: PcmReader,
395    vad: Option<WhisperVadProcessor>,
396
397    // Pre-computed sample counts
398    n_samples_step: usize,
399    n_samples_len: usize,
400    n_samples_keep: usize,
401
402    // Fixed-step state
403    pcmf32_old: Vec<f32>,
404    n_new_line: i32,
405    prompt_tokens: Vec<i32>,
406
407    // VAD state machine
408    in_speech: bool,
409    speech_buf: Vec<f32>,
410    pre_roll: Vec<f32>,
411    silence_samples: usize,
412
413    total_samples: i64,
414    n_iter: i32,
415
416    // VAD pre-computed
417    vad_last_ms: i32,
418    vad_pre_roll_samples: usize,
419    vad_silence_samples: usize,
420    vad_max_segment_samples: usize,
421}
422
423impl WhisperStreamPcm {
424    /// Create a new WhisperStreamPcm processor (simple VAD or no VAD).
425    pub fn new(
426        ctx: &WhisperContext,
427        params: FullParams,
428        mut config: WhisperStreamPcmConfig,
429        reader: PcmReader,
430    ) -> Result<Self> {
431        Self::build(ctx, params, &mut config, reader, None)
432    }
433
434    /// Create a new WhisperStreamPcm processor with Silero VAD.
435    pub fn with_vad(
436        ctx: &WhisperContext,
437        params: FullParams,
438        mut config: WhisperStreamPcmConfig,
439        reader: PcmReader,
440        vad: WhisperVadProcessor,
441    ) -> Result<Self> {
442        Self::build(ctx, params, &mut config, reader, Some(vad))
443    }
444
445    fn build(
446        ctx: &WhisperContext,
447        params: FullParams,
448        config: &mut WhisperStreamPcmConfig,
449        reader: PcmReader,
450        vad: Option<WhisperVadProcessor>,
451    ) -> Result<Self> {
452        let state = WhisperState::new(ctx)?;
453
454        // Normalize config (matches C++ main)
455        if !config.use_vad {
456            if config.step_ms <= 0 {
457                return Err(WhisperError::InvalidParameter(
458                    "step_ms must be > 0 unless use_vad is true".into(),
459                ));
460            }
461            config.keep_ms = config.keep_ms.min(config.step_ms);
462            config.length_ms = config.length_ms.max(config.step_ms);
463        } else {
464            if config.length_ms <= 0 {
465                config.length_ms = 5000;
466            }
467            config.keep_ms = 0;
468            // Force no_context in VAD mode (stream.cpp: no_context |= use_vad)
469            config.no_context = true;
470        }
471
472        let n_samples_step = if config.use_vad {
473            0
474        } else {
475            (config.step_ms as f64 * 0.001 * WHISPER_SAMPLE_RATE as f64) as usize
476        };
477        let n_samples_len =
478            (config.length_ms as f64 * 0.001 * WHISPER_SAMPLE_RATE as f64) as usize;
479        let n_samples_keep =
480            (config.keep_ms as f64 * 0.001 * WHISPER_SAMPLE_RATE as f64) as usize;
481
482        let n_new_line = if !config.use_vad && config.step_ms > 0 {
483            (config.length_ms / config.step_ms - 1).max(1)
484        } else {
485            1
486        };
487
488        let vad_probe_ms = config.vad_probe_ms.max(1);
489        let vad_last_ms = (vad_probe_ms / 2).clamp(1, 1000);
490        let vad_pre_roll_samples =
491            (WHISPER_SAMPLE_RATE as usize * config.vad_pre_roll_ms.max(0) as usize) / 1000;
492        let vad_silence_samples =
493            (WHISPER_SAMPLE_RATE as usize * config.vad_silence_ms.max(0) as usize) / 1000;
494
495        Ok(Self {
496            state,
497            params,
498            config: config.clone(),
499            reader,
500            vad,
501            n_samples_step,
502            n_samples_len,
503            n_samples_keep,
504            pcmf32_old: Vec::new(),
505            n_new_line,
506            prompt_tokens: Vec::new(),
507            in_speech: false,
508            speech_buf: Vec::new(),
509            pre_roll: Vec::new(),
510            silence_samples: 0,
511            total_samples: 0,
512            n_iter: 0,
513            vad_last_ms,
514            vad_pre_roll_samples,
515            vad_silence_samples,
516            vad_max_segment_samples: n_samples_len,
517        })
518    }
519
520    /// Returns `true` when the underlying reader has hit EOF and all samples are drained.
521    pub fn is_eof(&self) -> bool {
522        self.reader.is_eof() && self.reader.available_samples() == 0
523    }
524
525    /// Run one iteration of the main loop.
526    ///
527    /// Returns `Ok(Some(segments))` if transcription occurred,
528    /// `Ok(None)` if waiting for more audio or sleeping,
529    /// `Err` on fatal error.
530    ///
531    /// Returns `Ok(None)` with no more audio when EOF + drained.
532    pub fn process_step(&mut self) -> Result<Option<Vec<Segment>>> {
533        if !self.config.use_vad {
534            self.process_step_fixed()
535        } else {
536            self.process_step_vad()
537        }
538    }
539
540    /// Run until EOF or error. Calls `callback` for each transcription.
541    pub fn run<F>(&mut self, mut callback: F) -> Result<()>
542    where
543        F: FnMut(&[Segment], i64, i64),
544    {
545        loop {
546            match self.process_step()? {
547                Some(segments) if !segments.is_empty() => {
548                    let start = segments.first().map(|s| s.start_ms).unwrap_or(0);
549                    let end = segments.last().map(|s| s.end_ms).unwrap_or(0);
550                    callback(&segments, start, end);
551                }
552                Some(_) => {} // empty segments, keep going
553                None => {
554                    // Check if truly done (EOF + no audio left)
555                    if self.reader.is_eof() && self.reader.available_samples() == 0 {
556                        // Flush any remaining VAD speech
557                        if self.config.use_vad && self.in_speech && !self.speech_buf.is_empty() {
558                            let segments =
559                                self.run_inference(&self.speech_buf.clone())?;
560                            if !segments.is_empty() {
561                                let start =
562                                    segments.first().map(|s| s.start_ms).unwrap_or(0);
563                                let end =
564                                    segments.last().map(|s| s.end_ms).unwrap_or(0);
565                                callback(&segments, start, end);
566                            }
567                            self.speech_buf.clear();
568                            self.in_speech = false;
569                        }
570                        break;
571                    }
572                    // Still waiting for audio
573                    std::thread::sleep(std::time::Duration::from_millis(5));
574                }
575            }
576        }
577        Ok(())
578    }
579
580    /// Fixed-step processing — port of the non-VAD branch of the C++ main loop.
581    fn process_step_fixed(&mut self) -> Result<Option<Vec<Segment>>> {
582        let available = self.reader.available_samples();
583
584        if available < self.n_samples_step {
585            if self.reader.is_eof() {
586                if available == 0 {
587                    return Ok(None); // done
588                }
589                // Fall through to process remaining
590            } else {
591                return Ok(None); // wait for more audio
592            }
593        }
594
595        let pcmf32_new = self.reader.pop_ms(self.config.step_ms);
596        if pcmf32_new.is_empty() {
597            return Ok(None);
598        }
599
600        self.total_samples += pcmf32_new.len() as i64;
601
602        let n_samples_new = pcmf32_new.len();
603        let n_samples_take = self
604            .pcmf32_old
605            .len()
606            .min((self.n_samples_keep + self.n_samples_len).saturating_sub(n_samples_new));
607
608        let mut pcmf32 = Vec::with_capacity(n_samples_new + n_samples_take);
609
610        // Prepend overlap from previous step
611        if n_samples_take > 0 && !self.pcmf32_old.is_empty() {
612            let start = self.pcmf32_old.len() - n_samples_take;
613            pcmf32.extend_from_slice(&self.pcmf32_old[start..]);
614        }
615        pcmf32.extend_from_slice(&pcmf32_new);
616
617        self.pcmf32_old = pcmf32.clone();
618
619        let segments = self.run_inference(&pcmf32)?;
620        self.n_iter += 1;
621
622        // At n_new_line boundary (stream.cpp lines 408-425)
623        if self.n_iter % self.n_new_line == 0 {
624            if self.n_samples_keep > 0 && pcmf32.len() >= self.n_samples_keep {
625                self.pcmf32_old = pcmf32[pcmf32.len() - self.n_samples_keep..].to_vec();
626            }
627
628            if !self.config.no_context {
629                self.collect_prompt_tokens();
630            }
631        }
632
633        Ok(Some(segments))
634    }
635
636    /// VAD-driven processing — port of the VAD branch of the C++ main loop.
637    fn process_step_vad(&mut self) -> Result<Option<Vec<Segment>>> {
638        let available = self.reader.available_samples();
639
640        if available == 0 {
641            if self.reader.is_eof() {
642                // Flush remaining speech
643                if self.in_speech && !self.speech_buf.is_empty() {
644                    let segments = self.run_inference(
645                        &self.speech_buf.clone(),
646                    )?;
647                    self.speech_buf.clear();
648                    self.in_speech = false;
649                    self.n_iter += 1;
650                    return Ok(Some(segments));
651                }
652                return Ok(None);
653            }
654            return Ok(None); // wait
655        }
656
657        let pcmf32_new = self.reader.pop_ms(self.config.vad_probe_ms);
658        if pcmf32_new.is_empty() {
659            return Ok(None);
660        }
661
662        self.total_samples += pcmf32_new.len() as i64;
663
664        // Determine silence via Silero or simple VAD
665        let silence = if let Some(ref mut vad) = self.vad {
666            if vad.detect_speech(&pcmf32_new) {
667                let probs = vad.get_probs();
668                let avg = if probs.is_empty() {
669                    0.0
670                } else {
671                    probs.iter().sum::<f32>() / probs.len() as f32
672                };
673                avg < self.config.vad_thold
674            } else {
675                true // detect failed → treat as silence
676            }
677        } else {
678            vad_simple(
679                &pcmf32_new,
680                WHISPER_SAMPLE_RATE,
681                self.vad_last_ms,
682                self.config.vad_thold,
683                self.config.freq_thold,
684            )
685        };
686
687        let mut result_segments: Option<Vec<Segment>> = None;
688
689        if !self.in_speech {
690            if !silence {
691                self.in_speech = true;
692                self.silence_samples = 0;
693                self.speech_buf.clear();
694                if !self.pre_roll.is_empty() {
695                    self.speech_buf.extend_from_slice(&self.pre_roll);
696                }
697                self.speech_buf.extend_from_slice(&pcmf32_new);
698            }
699        } else {
700            self.speech_buf.extend_from_slice(&pcmf32_new);
701            if !silence {
702                self.silence_samples = 0;
703            } else {
704                self.silence_samples += pcmf32_new.len();
705            }
706
707            if self.speech_buf.len() >= self.vad_max_segment_samples
708                || self.silence_samples >= self.vad_silence_samples
709            {
710                let segments = self.run_inference(
711                    &self.speech_buf.clone(),
712                )?;
713                self.speech_buf.clear();
714                self.in_speech = false;
715                self.silence_samples = 0;
716                self.n_iter += 1;
717                result_segments = Some(segments);
718            }
719        }
720
721        // Maintain pre-roll buffer
722        if self.vad_pre_roll_samples > 0 {
723            self.pre_roll.extend_from_slice(&pcmf32_new);
724            if self.pre_roll.len() > self.vad_pre_roll_samples {
725                let excess = self.pre_roll.len() - self.vad_pre_roll_samples;
726                self.pre_roll.drain(..excess);
727            }
728        }
729
730        Ok(result_segments)
731    }
732
733    /// Run whisper inference on an audio buffer — port of `run_inference` lambda.
734    fn run_inference(&mut self, audio: &[f32]) -> Result<Vec<Segment>> {
735        if audio.is_empty() {
736            return Ok(Vec::new());
737        }
738
739        let mut params = self.params.clone();
740        if !self.config.no_context && !self.prompt_tokens.is_empty() {
741            params = params.prompt_tokens(&self.prompt_tokens);
742        }
743
744        self.state.full(params, audio)?;
745
746        let n_segments = self.state.full_n_segments();
747        let mut segments = Vec::with_capacity(n_segments as usize);
748
749        for i in 0..n_segments {
750            let text = self.state.full_get_segment_text(i)?;
751            let (start_ms, end_ms) = self.state.full_get_segment_timestamps(i);
752            let speaker_turn_next = self.state.full_get_segment_speaker_turn_next(i);
753
754            segments.push(Segment {
755                start_ms,
756                end_ms,
757                text,
758                speaker_turn_next,
759            });
760        }
761
762        Ok(segments)
763    }
764
765    /// Collect prompt tokens from last inference — port of stream.cpp lines 416-425.
766    fn collect_prompt_tokens(&mut self) {
767        self.prompt_tokens.clear();
768
769        let n_segments = self.state.full_n_segments();
770        for i in 0..n_segments {
771            let token_count = self.state.full_n_tokens(i);
772            for j in 0..token_count {
773                self.prompt_tokens
774                    .push(self.state.full_get_token_id(i, j));
775            }
776        }
777    }
778
779    /// Get the total number of processed samples.
780    pub fn total_samples(&self) -> i64 {
781        self.total_samples
782    }
783
784    /// Get the iteration count.
785    pub fn n_iter(&self) -> i32 {
786        self.n_iter
787    }
788}
789
790// ---------------------------------------------------------------------------
791// Tests
792// ---------------------------------------------------------------------------
793
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, Instant};

    /// Poll until `reader` reports EOF, panicking after a generous timeout.
    ///
    /// The reader thread sets `eof` only after pushing every decoded sample, so once this
    /// returns the ring-buffer contents are final. Bounded polling replaces fixed sleeps,
    /// which made these tests flaky on slow or heavily loaded machines.
    fn wait_for_eof(reader: &PcmReader) {
        let deadline = Instant::now() + Duration::from_secs(10);
        while !reader.is_eof() {
            assert!(
                Instant::now() < deadline,
                "PCM reader did not reach EOF within the timeout"
            );
            std::thread::sleep(Duration::from_millis(1));
        }
    }

    #[test]
    fn test_pcm_format_eq() {
        assert_eq!(PcmFormat::F32, PcmFormat::F32);
        assert_ne!(PcmFormat::F32, PcmFormat::S16);
    }

    #[test]
    fn test_vad_simple_silence() {
        // All zeros = silence
        let silence = vec![0.0f32; 16000];
        assert!(vad_simple(&silence, 16000, 100, 0.6, 100.0));
    }

    #[test]
    fn test_vad_simple_too_few_samples() {
        let short = vec![0.1f32; 100];
        // last_ms=1000 → needs 16000 samples, only have 100 → silence
        assert!(vad_simple(&short, 16000, 1000, 0.6, 100.0));
    }

    #[test]
    fn test_high_pass_filter_basic() {
        let mut data = vec![1.0, 0.0, 1.0, 0.0, 1.0];
        high_pass_filter(&mut data, 100.0, 16000.0);
        // After filter, values should be modified
        assert_ne!(data[2], 1.0);
    }

    #[test]
    fn test_pcm_reader_f32() {
        // Simulate 1 second of f32 PCM data (16000 samples)
        let n = 16000;
        let mut raw = Vec::with_capacity(n * 4);
        for i in 0..n {
            let val = (i as f32 / n as f32) * 2.0 - 1.0; // ramp -1..1
            raw.extend_from_slice(&val.to_le_bytes());
        }

        let cursor = std::io::Cursor::new(raw);
        let config = PcmReaderConfig {
            buffer_len_ms: 2000,
            sample_rate: 16000,
            format: PcmFormat::F32,
        };
        let reader = PcmReader::new(Box::new(cursor), config);

        wait_for_eof(&reader);

        assert!(reader.is_eof());
        let samples = reader.pop_ms(1000);
        assert_eq!(samples.len(), 16000);
    }

    #[test]
    fn test_pcm_reader_s16() {
        let n = 16000;
        let mut raw = Vec::with_capacity(n * 2);
        for i in 0..n {
            let val = ((i as f32 / n as f32) * 2.0 - 1.0) * 32767.0;
            raw.extend_from_slice(&(val as i16).to_le_bytes());
        }

        let cursor = std::io::Cursor::new(raw);
        let config = PcmReaderConfig {
            buffer_len_ms: 2000,
            sample_rate: 16000,
            format: PcmFormat::S16,
        };
        let reader = PcmReader::new(Box::new(cursor), config);

        wait_for_eof(&reader);

        assert!(reader.is_eof());
        let samples = reader.pop_ms(1000);
        assert_eq!(samples.len(), 16000);

        // Check conversion — first sample should be near -1.0
        assert!(samples[0] < -0.9);
    }

    #[test]
    fn test_ring_buffer_overflow() {
        // Buffer only holds 500ms = 8000 samples, but we push 16000
        let n = 16000;
        let mut raw = Vec::with_capacity(n * 4);
        for i in 0..n {
            raw.extend_from_slice(&(i as f32).to_le_bytes());
        }

        let cursor = std::io::Cursor::new(raw);
        let config = PcmReaderConfig {
            buffer_len_ms: 500,
            sample_rate: 16000,
            format: PcmFormat::F32,
        };
        let reader = PcmReader::new(Box::new(cursor), config);

        wait_for_eof(&reader);

        // Exactly the 8000 most recent samples remain once everything has been pushed.
        let available = reader.available_samples();
        assert_eq!(available, 8000);

        let samples = reader.pop_ms(500);
        assert_eq!(samples.len(), 8000);
        // Last sample should be 15999.0
        assert!((samples[samples.len() - 1] - 15999.0).abs() < 0.01);
    }

    #[test]
    fn test_stream_pcm_config_defaults() {
        let config = WhisperStreamPcmConfig::default();
        assert_eq!(config.step_ms, 3000);
        assert_eq!(config.length_ms, 10000);
        assert_eq!(config.keep_ms, 200);
        assert!(!config.use_vad);
    }

    #[test]
    fn test_stream_pcm_config_vad_normalization() {
        // When use_vad=true, keep_ms should be forced to 0
        use std::path::Path;
        let model_path = "tests/models/ggml-tiny.en.bin";
        if !Path::new(model_path).exists() {
            eprintln!("Skipping: model not found");
            return;
        }

        let ctx = WhisperContext::new(model_path).unwrap();
        let params = FullParams::default();
        let cursor = std::io::Cursor::new(Vec::<u8>::new());
        let reader = PcmReader::new(Box::new(cursor), PcmReaderConfig::default());
        let config = WhisperStreamPcmConfig {
            use_vad: true,
            keep_ms: 500, // should be forced to 0
            ..Default::default()
        };

        let stream = WhisperStreamPcm::new(&ctx, params, config, reader).unwrap();
        assert_eq!(stream.config.keep_ms, 0);
    }
}