Skip to main content

whisperforge_core/
stream_decode.rs

1use std::time::Instant;
2
3use anyhow::{Context, Result};
4use burn::tensor::{Tensor, backend::Backend};
5use tokenizers::Tokenizer;
6use tracing::{Level, event};
7
8use crate::kv_cache::{KvCache, forward_decoder_cached};
9use crate::model::Whisper;
10
/// One token produced by `decode_window`, together with the metadata the streaming
/// caller needs to build text and judge confidence.
#[derive(Clone, Debug)]
pub struct TokenEmit {
    /// Vocabulary ID of the emitted token.
    pub id: u32,
    /// Detokenised surface text for this single token (via `tokenizer.decode(&[id], false)`),
    /// so concatenating regular-token `text` fields yields a properly-spaced transcript.
    /// Empty string for special and timestamp tokens.
    pub text: String,
    /// Log-probability of `id` under the softmax of the logits it was selected from.
    pub logprob: f32,
    /// Seconds within the 30-s window for timestamp tokens (0.02 s per timestamp step);
    /// `None` for regular tokens.
    pub window_ts_secs: Option<f32>,
    /// `true` when `id >= eot_token` (EOT, language, task, timestamp, etc.).
    pub is_special: bool,
}
24
/// Per-window inputs for `decode_window`: the optional prompt prefix, the special-token
/// IDs of the vocabulary in use, and the decode limits.
#[derive(Clone, Debug)]
pub struct DecodeContext<'a> {
    /// Tokens from the previous committed utterance fed as a prompt prefix before `<|sot|>`.
    /// Empty slice for the first window.
    pub prompt_tokens: &'a [u32],
    /// Language token fed as the second init token (e.g. `<|en|>`).
    pub language_token: u32,
    /// Task token (`<|transcribe|>` or `<|translate|>`). Translate is X → English only.
    pub task_token: u32,
    /// `<|startoftranscript|>` token ID, fed as the first init token.
    pub sot_token: u32,
    /// `<|endoftext|>` token ID. Decoding stops when this is selected, and every ID at or
    /// above it is classified as special.
    pub eot_token: u32,
    /// `<|nospeech|>` token ID whose probability drives the no-speech gate.
    pub no_speech_token: u32,
    /// First timestamp token ID (50364 in the vocabulary used by this file's tests).
    /// NOTE(review): other Whisper vocabularies may place it elsewhere — prefer resolving
    /// it from the tokenizer over hard-coding.
    pub timestamp_begin_token: u32,
    /// `<|notimestamps|>` token ID. Pushed as the 4th init token so the decoder produces
    /// plain text (no per-token timestamps between words). Required for greedy decode to
    /// produce coherent content on this checkpoint — greedy + timestamps-on emits mostly
    /// timestamps with little content even with `ApplyTimestampRules`-style filtering.
    pub notimestamps_token: u32,
    /// Hard cap on new tokens generated (not counting prompt or init tokens).
    pub max_new_tokens: usize,
    /// `decode_window` returns an empty `Vec` when P(no_speech), measured on the logits
    /// produced after the last init token, exceeds this value. The real-model test in
    /// this file uses 0.6.
    pub no_speech_threshold: f32,
}
47
48/// Greedy KV-cached decode for one streaming window.
49///
50/// `encoder_out` is consumed by `KvCache::new`; clone before calling if you need it again.
51///
52/// `<|notimestamps|>` is pushed as the fourth init token so the decoder emits plain text
53/// (no per-token timestamps between words). This matches the one-shot transcribe path and
54/// gives reliable greedy output on short windows. The original plan was to leave timestamps
55/// enabled so the streaming caller could anchor buffer trims at committed-token boundaries
56/// — but empirically, greedy decode on tiny.en with timestamps on emits mostly timestamps
57/// and little content, even with `ApplyTimestampRules`-style logit filtering. Reliable
58/// timestamps in streaming would require temperature-fallback sampling (like the one-shot
59/// `HybridDecoder`), which is out of scope here. The streaming caller uses a stride-based
60/// trim heuristic on cap-hit instead — see `whisperforge/src/commands/stream.rs`.
61pub fn decode_window<B: Backend>(
62    model: &Whisper<B>,
63    encoder_out: Tensor<B, 3>,
64    ctx: &DecodeContext,
65    tokenizer: &Tokenizer,
66    device: &B::Device,
67) -> Result<Vec<TokenEmit>> {
68    let t0 = Instant::now();
69
70    let mut cache = KvCache::new(model, encoder_out);
71
72    // Seed the self-attention KV cache with any prior-context prompt tokens.
73    if !ctx.prompt_tokens.is_empty() {
74        event!(
75            Level::DEBUG,
76            prompt_len = ctx.prompt_tokens.len(),
77            prompt_first_token = ctx.prompt_tokens[0],
78            "feeding prompt prefix into KV cache",
79        );
80    }
81    for &tok in ctx.prompt_tokens {
82        forward_decoder_cached(model, tok, &mut cache, device)
83            .with_context(|| format!("feeding prompt token {tok}"))?;
84    }
85
86    // Feed the four init tokens. `<|notimestamps|>` is included so the greedy decoder
87    // produces plain text (matching the one-shot transcribe path). With timestamps enabled,
88    // greedy decode on this checkpoint emits mostly timestamps with little content even
89    // under `ApplyTimestampRules`-style filtering — see function docstring.
90    let init = [
91        ctx.sot_token,
92        ctx.language_token,
93        ctx.task_token,
94        ctx.notimestamps_token,
95    ];
96    let mut logits: Vec<f32> = Vec::new();
97    for (i, &tok) in init.iter().enumerate() {
98        logits = forward_decoder_cached(model, tok, &mut cache, device)
99            .with_context(|| format!("feeding init token at index {i}"))?;
100    }
101
102    // No-speech gate: if the model is confident there is no speech, skip this window.
103    if softmax_at(&logits, ctx.no_speech_token) > ctx.no_speech_threshold {
104        event!(
105            Level::DEBUG,
106            decode_ms = t0.elapsed().as_millis(),
107            n_tokens = 0usize,
108            skipped = true
109        );
110        return Ok(Vec::new());
111    }
112
113    // Suppress EOT at step 0 to avoid a premature stop on the very first generated token.
114    if (ctx.eot_token as usize) < logits.len() {
115        logits[ctx.eot_token as usize] = f32::NEG_INFINITY;
116    }
117
118    let mut emits: Vec<TokenEmit> = Vec::new();
119
120    for _ in 0..ctx.max_new_tokens {
121        let token_id = argmax(&logits);
122
123        if token_id == ctx.eot_token {
124            break;
125        }
126
127        let logprob = log_softmax_at(&logits, token_id);
128        let is_special = token_id >= ctx.eot_token;
129        let window_ts_secs = if token_id >= ctx.timestamp_begin_token {
130            Some((token_id - ctx.timestamp_begin_token) as f32 * 0.02)
131        } else {
132            None
133        };
134        let text = if is_special {
135            String::new()
136        } else {
137            tokenizer.decode(&[token_id], false).unwrap_or_default()
138        };
139
140        emits.push(TokenEmit {
141            id: token_id,
142            text,
143            logprob,
144            window_ts_secs,
145            is_special,
146        });
147
148        logits = forward_decoder_cached(model, token_id, &mut cache, device)
149            .with_context(|| format!("decode step {}", emits.len()))?;
150    }
151
152    event!(
153        Level::DEBUG,
154        decode_ms = t0.elapsed().as_millis(),
155        n_tokens = emits.len()
156    );
157
158    // Punctuation-only decodes are a Whisper failure mode on short/uncertain windows: the
159    // model emits a lone `.` or `,` then EOT. Treat as no-speech so the streaming committer
160    // doesn't pick up the noise as a stable prefix. Only fires when there *is* regular text
161    // — an emits list containing only specials/timestamps still passes through (the
162    // committer ignores them) so we don't accidentally drop legitimate timestamp-only
163    // windows.
164    let regular_text: String = emits
165        .iter()
166        .filter(|t| !t.is_special)
167        .map(|t| t.text.as_str())
168        .collect();
169    let trimmed = regular_text.trim();
170    if !trimmed.is_empty()
171        && trimmed
172            .chars()
173            .all(|c| c.is_ascii_punctuation() || c.is_whitespace())
174    {
175        event!(
176            Level::DEBUG,
177            dropped_punctuation_only = true,
178            text = %trimmed
179        );
180        return Ok(Vec::new());
181    }
182
183    Ok(emits)
184}
185
/// Returns the index of the largest logit, or `0` when `logits` is empty or holds only
/// NaNs.
///
/// NaN entries are skipped. Without the filter a NaN compares as `Equal` to everything,
/// which makes `max_by` restart its running maximum at each NaN and can return the index
/// of a smaller value that merely comes later in the slice.
///
/// When several entries share the maximum, the last one wins (the `Iterator::max_by`
/// contract).
fn argmax(logits: &[f32]) -> u32 {
    logits
        .iter()
        .enumerate()
        .filter(|(_, value)| !value.is_nan())
        // With NaNs removed `partial_cmp` is always `Some`; the fallback is unreachable.
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(i, _)| i as u32)
        .unwrap_or(0)
}
194
/// Softmax probability of `token` within `logits`.
///
/// Returns `0.0` when `token` is outside the slice. The maximum logit is subtracted
/// before exponentiating, and the denominator is floored at `f32::EPSILON`.
fn softmax_at(logits: &[f32], token: u32) -> f32 {
    let target = match logits.get(token as usize) {
        Some(&value) => value,
        None => return 0.0,
    };
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let denominator: f32 = logits.iter().map(|&logit| (logit - peak).exp()).sum();
    let numerator = (target - peak).exp();
    numerator / denominator.max(f32::EPSILON)
}
204
/// Log-softmax of `token` within `logits`, computed as `logit - logsumexp(logits)`.
///
/// Returns `f32::NEG_INFINITY` when `token` is outside the slice. The maximum logit is
/// factored out of the sum before taking the logarithm.
fn log_softmax_at(logits: &[f32], token: u32) -> f32 {
    let target = match logits.get(token as usize) {
        Some(&value) => value,
        None => return f32::NEG_INFINITY,
    };
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let shifted_sum: f32 = logits.iter().map(|&logit| (logit - peak).exp()).sum();
    let log_normalizer = peak + shifted_sum.ln();
    target - log_normalizer
}
214
215pub fn avg_logprob(tokens: &[TokenEmit]) -> f32 {
216    let content: Vec<f32> = tokens
217        .iter()
218        .filter(|t| !t.is_special)
219        .map(|t| t.logprob)
220        .collect();
221    if content.is_empty() {
222        return 0.0;
223    }
224    content.iter().sum::<f32>() / content.len() as f32
225}
226
/// Thresholds for accepting or rejecting one decoded window, modelled on
/// faster-whisper's `log_prob_threshold` and `compression_ratio_threshold`.
///
/// The streaming caller applies them before a window reaches the LocalAgreement
/// committer. LA-2 only filters *unstable* output, so a *confident* hallucination loop
/// (the `*sigh* *sigh* *sigh*` failure mode) would otherwise be committed. The `Default`
/// values match faster-whisper.
#[derive(Debug, Clone, Copy)]
pub struct QualityGate {
    /// A window is rejected when the `avg_logprob` of its content tokens falls below
    /// this value (default -1.0).
    pub log_prob_threshold: f32,
    /// A window is rejected when the gzip compression ratio of its text rises above
    /// this value (default 2.4); a high ratio signals a repetition/hallucination loop.
    pub compression_ratio_threshold: f32,
}

impl Default for QualityGate {
    /// Builds the gate with the faster-whisper default thresholds.
    fn default() -> Self {
        let log_prob_threshold = -1.0;
        let compression_ratio_threshold = 2.4;
        QualityGate {
            log_prob_threshold,
            compression_ratio_threshold,
        }
    }
}
249
250/// Returns `false` when `emits` should be dropped as low-confidence or repetitive.
251///
252/// Windows with no content (regular) tokens always pass: the no-speech and
253/// punctuation-only gates in [`decode_window`] already handle empties, and the
254/// streaming committer ignores special/timestamp-only emits.
255pub fn passes_quality_gate(emits: &[TokenEmit], gate: &QualityGate) -> bool {
256    let text: String = emits
257        .iter()
258        .filter(|t| !t.is_special)
259        .map(|t| t.text.as_str())
260        .collect();
261    if text.trim().is_empty() {
262        return true;
263    }
264    if avg_logprob(emits) < gate.log_prob_threshold {
265        return false;
266    }
267    if crate::decoding::compression_ratio(&text) > gate.compression_ratio_threshold {
268        return false;
269    }
270    true
271}
272
// Unit tests for the quality gate and `decode_window`. The quality-gate tests use
// hand-built `TokenEmit`s; the `decode_window` tests use a freshly initialised tiny.en
// model, and one ignored test needs real model files and audio on disk.
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use burn_flex::{Flex, FlexDevice};

    use crate::model::WhisperConfig;

    /// Builds a model from `WhisperConfig::tiny_en()` on the Flex backend without loading
    /// any checkpoint (weights are whatever `init` produces — presumably random).
    fn tiny_en_random() -> (Whisper<Flex<f32>>, FlexDevice) {
        let device = FlexDevice;
        let config = WhisperConfig::tiny_en();
        let model = config.init::<Flex<f32>>(&device);
        (model, device)
    }

    /// Tokenizer stand-in for tests that never inspect decoded text.
    fn dummy_tokenizer() -> Tokenizer {
        // Backed by an empty BPE model, so no ID maps to a token. `decode_window` falls
        // back to an empty string via `unwrap_or_default()` if decoding fails.
        Tokenizer::new(tokenizers::models::bpe::BPE::default())
    }

    /// Decode context with fixed special-token IDs, an 8-token cap, and a no-speech
    /// threshold high enough that the gate is effectively disabled.
    fn ctx_no_gate<'a>() -> DecodeContext<'a> {
        DecodeContext {
            prompt_tokens: &[],
            language_token: 50259,
            task_token: 50359,
            sot_token: 50258,
            eot_token: 50257,
            no_speech_token: 50362,
            notimestamps_token: 50363,
            timestamp_begin_token: 50364,
            max_new_tokens: 8,
            // Set very high so the no-speech gate never fires with random weights.
            no_speech_threshold: 0.999,
        }
    }

    /// Builds a regular (non-special) emit with the given text and log-probability.
    fn content_emit(text: &str, logprob: f32) -> TokenEmit {
        TokenEmit {
            id: 1,
            text: text.to_string(),
            logprob,
            window_ts_secs: None,
            is_special: false,
        }
    }

    /// Varied text with log-probabilities above the default floor passes the gate.
    #[test]
    fn test_quality_gate_passes_normal() {
        let gate = QualityGate::default();
        let emits = vec![
            content_emit(" the", -0.2),
            content_emit(" quick", -0.4),
            content_emit(" brown", -0.3),
            content_emit(" fox", -0.5),
        ];
        assert!(
            passes_quality_gate(&emits, &gate),
            "varied, confident text should pass"
        );
    }

    /// A mean log-probability below the default -1.0 floor rejects the window.
    #[test]
    fn test_quality_gate_rejects_low_logprob() {
        let gate = QualityGate::default();
        // avg_logprob well below the -1.0 floor; compression ratio is irrelevant here.
        let emits = vec![content_emit(" maybe", -2.5), content_emit(" perhaps", -3.0)];
        assert!(
            !passes_quality_gate(&emits, &gate),
            "low-confidence window should be rejected"
        );
    }

    /// A confident but highly repetitive window is rejected on compression ratio.
    #[test]
    fn test_quality_gate_rejects_repetition() {
        let gate = QualityGate::default();
        // Confident (high logprob) but a repetition loop → high compression ratio.
        let mut emits = Vec::new();
        for _ in 0..60 {
            emits.push(content_emit(" sigh", -0.1));
        }
        assert!(
            !passes_quality_gate(&emits, &gate),
            "confident repetition loop should be rejected on compression ratio"
        );
    }

    /// Windows without any content text pass, regardless of the specials' log-probability.
    #[test]
    fn test_quality_gate_empty_passes() {
        let gate = QualityGate::default();
        // Specials-only / empty content → always passes (handled elsewhere).
        let emits: Vec<TokenEmit> = vec![TokenEmit {
            id: 50364,
            text: String::new(),
            logprob: -5.0,
            window_ts_secs: Some(0.0),
            is_special: true,
        }];
        assert!(passes_quality_gate(&emits, &gate));
        assert!(passes_quality_gate(&[], &gate));
    }

    /// Structural test: verify `decode_window` compiles, runs, and returns Ok without panicking
    /// on a random model with a zero encoder output. Also checks the `max_new_tokens` cap.
    #[test]
    fn test_decode_window_random_model() -> Result<()> {
        let (model, device) = tiny_en_random();
        // All-zero stand-in for the encoder output, shaped [1, 1500, 384].
        let encoder_out = burn::tensor::Tensor::<Flex<f32>, 3>::zeros([1, 1500, 384], &device);
        let tokenizer = dummy_tokenizer();
        let ctx = ctx_no_gate();

        let emits = decode_window(&model, encoder_out, &ctx, &tokenizer, &device)?;
        assert!(emits.len() <= 8, "emits exceeded max_new_tokens");
        Ok(())
    }

    /// No-speech gate: with threshold = 0.0 the gate fires for any strictly positive
    /// P(no_speech), so the decode is expected to return an empty vec.
    #[test]
    fn test_decode_window_no_speech_gate() -> Result<()> {
        let (model, device) = tiny_en_random();
        let encoder_out = burn::tensor::Tensor::<Flex<f32>, 3>::zeros([1, 1500, 384], &device);
        let tokenizer = dummy_tokenizer();
        let ctx = DecodeContext {
            no_speech_threshold: 0.0,
            ..ctx_no_gate()
        };

        let emits = decode_window(&model, encoder_out, &ctx, &tokenizer, &device)?;
        assert!(
            emits.is_empty(),
            "no-speech gate should have returned an empty vec"
        );
        Ok(())
    }

    /// Real-model test: decode_window on tiny_en produces non-empty output for a speech clip
    /// and the text (after filtering specials, trimming and lowercasing) equals the output
    /// of the one-shot transcribe path.
    #[test]
    #[ignore = "requires tiny_en_converted in ./models/ AND test_data/LJ001-0001_16k.wav at repo root"]
    fn test_decode_window_matches_transcribe_path() -> Result<()> {
        use crate::{
            WhisperInference, WhisperTranscriber, audio::compute_mel_from_samples,
            decoding::DecodingConfig, load::load_whisper,
        };
        use burn_flex::{Flex, FlexDevice};
        use std::path::PathBuf;

        let device = FlexDevice;
        // `models/` lives one directory above this crate's manifest (the workspace root).
        let models_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .parent()
            .expect("workspace root")
            .join("models");
        // Per-model layout: `<name>/model.{mpk,cfg}` + `<name>/tokenizer.json`.
        let model_dir = models_dir.join("tiny_en_converted");
        let model_path = model_dir.join("model");
        let model_path_str = model_path.to_str().expect("valid UTF-8 model path");

        let model = load_whisper::<Flex<f32>>(model_path_str, &device)?;
        let tokenizer_path = model_dir.join("tokenizer.json");
        let tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;

        // Read the test WAV and end up with exactly 480_000 samples (30 s at 16 kHz):
        // longer audio is truncated, shorter audio is zero-padded.
        let wav_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .parent()
            .expect("workspace root")
            .join("test_data")
            .join("LJ001-0001_16k.wav");
        let raw =
            std::fs::read(&wav_path).with_context(|| format!("read {}", wav_path.display()))?;
        // Minimal WAV reader: skip header, parse 16-bit PCM or IEEE float.
        // Reuse the same approach as streaming.rs tests.
        let samples_30s = {
            let needed = 480_000usize;
            // Chunk scanning starts at byte 12, i.e. after the 12-byte file header.
            let mut pos = 12usize;
            // Format tag read from the `fmt ` chunk; 3 selects the float branch below,
            // anything else is read as 16-bit integers. Defaults to 1.
            let mut audio_format = 1u16;
            let mut data_start = None;
            let mut data_len = 0usize;
            // Walk the chunk list: 4-byte ID, 4-byte little-endian size, then the payload.
            while pos + 8 <= raw.len() {
                let chunk_id = &raw[pos..pos + 4];
                let size = u32::from_le_bytes(raw[pos + 4..pos + 8].try_into().unwrap()) as usize;
                if chunk_id == b"fmt " {
                    audio_format = u16::from_le_bytes(raw[pos + 8..pos + 10].try_into().unwrap());
                } else if chunk_id == b"data" {
                    data_start = Some(pos + 8);
                    data_len = size;
                    break;
                }
                // Advance past the payload; odd-sized chunks carry one pad byte.
                pos += 8 + size + (size & 1);
            }
            let start = data_start.context("no 'data' chunk")?;
            // Clamp in case the declared size runs past the end of the file.
            let end = (start + data_len).min(raw.len());
            let all: Vec<f32> = if audio_format == 3 {
                // 32-bit little-endian float samples.
                (0..(end - start) / 4)
                    .map(|i| {
                        f32::from_le_bytes(
                            raw[start + i * 4..start + i * 4 + 4].try_into().unwrap(),
                        )
                    })
                    .collect()
            } else {
                // 16-bit little-endian integer samples, scaled by 1/32768.
                (0..(end - start) / 2)
                    .map(|i| {
                        i16::from_le_bytes(
                            raw[start + i * 2..start + i * 2 + 2].try_into().unwrap(),
                        ) as f32
                            / 32768.0
                    })
                    .collect()
            };
            let mut padded = all;
            padded.resize(needed, 0.0);
            padded
        };

        // Reference path: compute mel and transcribe with <|notimestamps|>.
        let mel = compute_mel_from_samples::<Flex<f32>>(&samples_30s, 400, 160, 80, &device)?;
        let transcriber =
            WhisperTranscriber::new(model.clone(), tokenizer.clone(), DecodingConfig::fast());
        let ref_result = transcriber.transcribe(mel.clone())?;
        let ref_text = ref_result.text.trim().to_lowercase();

        // Stream-decode path on the same encoder output.
        let encoder_out = model.forward_encoder(mel);
        // Resolve each special token from the tokenizer, falling back to a fixed ID.
        let tok = |s: &str, fb: u32| tokenizer.token_to_id(s).unwrap_or(fb);
        let ctx = DecodeContext {
            prompt_tokens: &[],
            sot_token: tok("<|startoftranscript|>", 50258),
            language_token: tok("<|en|>", 50259),
            task_token: tok("<|transcribe|>", 50359),
            eot_token: tok("<|endoftext|>", 50257),
            no_speech_token: tok("<|nospeech|>", 50362),
            notimestamps_token: tok("<|notimestamps|>", 50363),
            timestamp_begin_token: 50364,
            max_new_tokens: 128,
            no_speech_threshold: 0.6,
        };

        let emits = decode_window(&model, encoder_out, &ctx, &tokenizer, &device)?;
        assert!(
            !emits.is_empty(),
            "decode_window produced no tokens for a speech clip"
        );

        // Filter to regular text tokens only and decode.
        let text_ids: Vec<u32> = emits
            .iter()
            .filter(|e| !e.is_special)
            .map(|e| e.id)
            .collect();
        assert!(
            !text_ids.is_empty(),
            "no regular text tokens in decode_window output"
        );

        // Decode the whole ID sequence in one call, then normalise like `ref_text`.
        let stream_text = tokenizer
            .decode(&text_ids, true)
            .map_err(|e| anyhow::anyhow!("{e}"))?
            .trim()
            .to_lowercase();

        assert_eq!(
            stream_text, ref_text,
            "stream_decode text diverges from one-shot path\n  stream: {stream_text:?}\n  ref:    {ref_text:?}"
        );

        Ok(())
    }
}