Skip to main content

whisperforge_core/
decoding.rs

//! SOTA decoding strategies for Whisper transcription.
//!
//! Implements faster-whisper's hybrid beam search + temperature fallback approach.
3use anyhow::{Result, anyhow};
4use flate2::{Compression, write::GzEncoder};
5use rand::{RngExt, SeedableRng};
6use std::cmp::Ordering;
7use std::collections::BinaryHeap;
8use std::io::Write;
9
10use crate::language::Task;
11
12/// Configuration for beam search decoding
13#[derive(Debug, Clone)]
14pub struct DecodingConfig {
15    /// Beam size for beam search (1 = greedy, 5 = default, 10+ = very accurate)
16    pub beam_size: usize,
17    /// Temperature fallback sequence (tried in order; retry at next temp if quality fails)
18    pub temperatures: Vec<f32>,
19    /// Length penalty to prevent repeating short sequences (0.0 = no penalty)
20    pub length_penalty: f32,
21    /// Threshold for no-speech detection based on cross-attention (0.0 to 1.0)
22    pub no_speech_threshold: f32,
23    /// Maximum tokens to generate
24    pub max_length: usize,
25    /// Language token (e.g., "en" for English; "auto" for first-token detection)
26    pub language: String,
27    /// Decode task: `Transcribe` (output in `language`) or `Translate` (X → English only).
28    pub task: Task,
29    /// Gzip compression ratio threshold — ratio above this signals a hallucination loop (default 2.4)
30    pub compression_ratio_threshold: f32,
31    /// Average log-probability threshold — below this signals low-confidence output (default -1.0)
32    pub log_prob_threshold: f32,
33}
34
35impl Default for DecodingConfig {
36    fn default() -> Self {
37        Self {
38            beam_size: 5,
39            temperatures: vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
40            length_penalty: 1.0,
41            no_speech_threshold: 0.6,
42            max_length: 448, // Whisper max context
43            language: "en".to_string(),
44            task: Task::Transcribe,
45            compression_ratio_threshold: 2.4,
46            log_prob_threshold: -1.0,
47        }
48    }
49}
50
51impl DecodingConfig {
52    /// Create a fast decoding config (greedy, minimal processing)
53    pub fn fast() -> Self {
54        Self {
55            beam_size: 1,
56            temperatures: vec![0.0],
57            length_penalty: 0.0,
58            ..Default::default()
59        }
60    }
61
62    /// Create a balanced decoding config (good quality/speed tradeoff)
63    pub fn balanced() -> Self {
64        Self {
65            beam_size: 5,
66            temperatures: vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
67            length_penalty: 1.0,
68            ..Default::default()
69        }
70    }
71
72    /// Create an accurate decoding config (highest quality, slowest)
73    pub fn accurate() -> Self {
74        Self {
75            beam_size: 10,
76            temperatures: vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
77            length_penalty: 1.0,
78            ..Default::default()
79        }
80    }
81
82    /// Set beam size
83    pub fn with_beam_size(mut self, beam_size: usize) -> Self {
84        self.beam_size = beam_size.max(1);
85        self
86    }
87
88    /// Override temperatures with a single value (disables fallback sequence)
89    pub fn with_temperature(mut self, temperature: f32) -> Self {
90        self.temperatures = vec![temperature.max(0.0)];
91        self
92    }
93
94    /// Set length penalty
95    pub fn with_length_penalty(mut self, penalty: f32) -> Self {
96        self.length_penalty = penalty.max(0.0);
97        self
98    }
99
100    /// Set no-speech threshold
101    pub fn with_no_speech_threshold(mut self, threshold: f32) -> Self {
102        self.no_speech_threshold = threshold.clamp(0.0, 1.0);
103        self
104    }
105
106    /// Set language
107    pub fn with_language(mut self, language: String) -> Self {
108        self.language = language;
109        self
110    }
111
112    /// Set decode task (transcribe vs translate-to-English)
113    pub fn with_task(mut self, task: Task) -> Self {
114        self.task = task;
115        self
116    }
117}
118
119// ============================================================================
120// Quality metrics
121// ============================================================================
122
123/// Gzip compression ratio of `text` — `text.len() / compressed_len`.
124///
125/// Values above `DecodingConfig::compression_ratio_threshold` (2.4) indicate a
126/// repetitive or hallucinated output.
127pub fn compression_ratio(text: &str) -> f32 {
128    let bytes = text.as_bytes();
129    if bytes.is_empty() {
130        return 0.0;
131    }
132    let mut enc = GzEncoder::new(Vec::new(), Compression::default());
133    enc.write_all(bytes).ok();
134    let compressed_len = enc.finish().unwrap_or_default().len().max(1);
135    bytes.len() as f32 / compressed_len as f32
136}
137
/// Softmax probability assigned to `token` by the raw `logits`.
///
/// Returns `0.0` when `token` is not a valid index into `logits`. The maximum
/// logit is subtracted before exponentiating so large logits do not overflow.
fn softmax_at(logits: &[f32], token: u32) -> f32 {
    let Some(&target) = logits.get(token as usize) else {
        return 0.0;
    };

    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let denominator: f32 = logits.iter().map(|&value| (value - peak).exp()).sum();

    // The floor on the denominator guards against dividing by zero.
    (target - peak).exp() / denominator.max(f32::EPSILON)
}
148
/// Log-softmax of `token` given logits scaled by `temp` (0.0 → greedy / unscaled).
///
/// Each logit is divided by `temp` when `temp > 0.0`; otherwise the logits are
/// used as-is. Returns `f32::NEG_INFINITY` when `token` is not a valid index
/// into `logits`.
///
/// The scaling is applied on the fly instead of materializing a scaled copy, so
/// a call performs no heap allocation (this runs once per decode step over the
/// whole vocabulary).
fn log_softmax_at(logits: &[f32], token: u32, temp: f32) -> f32 {
    let Some(&raw_target) = logits.get(token as usize) else {
        return f32::NEG_INFINITY;
    };

    let scale = |logit: f32| if temp > 0.0 { logit / temp } else { logit };

    // Log-sum-exp with the maximum factored out for numerical stability.
    let max = logits
        .iter()
        .map(|&logit| scale(logit))
        .fold(f32::NEG_INFINITY, f32::max);
    let sum_exp: f32 = logits.iter().map(|&logit| (scale(logit) - max).exp()).sum();
    let log_sum = max + sum_exp.ln();

    scale(raw_target) - log_sum
}
164
165/// Sample a token from `logits` at the given temperature.
166///
167/// At `temp == 0.0` this is equivalent to argmax (greedy).
168fn sample_from_logits(logits: &[f32], temp: f32, rng: &mut impl rand::Rng) -> u32 {
169    if temp <= 0.0 || logits.is_empty() {
170        return argmax_logits(logits);
171    }
172    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
173    let exps: Vec<f32> = logits.iter().map(|&l| ((l - max) / temp).exp()).collect();
174    let sum: f32 = exps.iter().sum::<f32>().max(f32::EPSILON);
175    let threshold: f32 = rng.random::<f32>() * sum;
176    let mut cumsum = 0.0;
177    for (i, &e) in exps.iter().enumerate() {
178        cumsum += e;
179        if cumsum >= threshold {
180            return i as u32;
181        }
182    }
183    (logits.len() - 1) as u32
184}
185
/// Index of the largest logit.
///
/// NaN entries are skipped: with a plain `partial_cmp(..).unwrap_or(Equal)`
/// comparison a NaN compares "equal" to everything and makes the scan forget the
/// running maximum. When several entries share the maximum, the last one is
/// returned. Returns `0` for an empty slice or when every entry is NaN.
fn argmax_logits(logits: &[f32]) -> u32 {
    logits
        .iter()
        .enumerate()
        .filter(|(_, value)| !value.is_nan())
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal))
        .map(|(i, _)| i as u32)
        .unwrap_or(0)
}
194
/// Represents a candidate sequence during beam search.
#[derive(Debug, Clone)]
struct BeamCandidate {
    /// Token IDs in this sequence
    tokens: Vec<u32>,
    /// Cumulative log probability of this sequence
    log_prob: f32,
    /// Whether this sequence has finished (reached end-of-sequence token)
    finished: bool,
    /// Number of tokens counted toward length normalization
    /// (starts at 1 for the initial token).
    token_count: usize,
}

impl BeamCandidate {
    /// Create a new, unfinished candidate holding only the initial token.
    fn new(token: u32) -> Self {
        Self {
            tokens: vec![token],
            log_prob: 0.0,
            finished: false,
            token_count: 1,
        }
    }

    /// Calculate normalized score for ranking (higher is better):
    /// `log_prob / token_count^length_penalty`.
    ///
    /// A candidate with `token_count == 0` is scored by its raw `log_prob`.
    fn normalized_score(&self, length_penalty: f32) -> f32 {
        if self.token_count == 0 {
            return self.log_prob;
        }
        // Length-normalized log probability (prevents bias toward short sequences)
        self.log_prob / ((self.token_count as f32).powf(length_penalty))
    }

    /// Key used by the `Eq`/`Ord` impls: the score at a fixed length penalty of
    /// 1.0, with NaN mapped to negative infinity so a broken score ranks last.
    fn ordering_key(&self) -> f32 {
        let score = self.normalized_score(1.0);
        if score.is_nan() {
            f32::NEG_INFINITY
        } else {
            score
        }
    }
}

/// Equality is defined through `Ord` so the two can never disagree (an epsilon
/// comparison is not transitive and would violate the `Eq` contract).
impl PartialEq for BeamCandidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for BeamCandidate {}

impl PartialOrd for BeamCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Total ordering by score: a higher score compares greater, so a
/// `BinaryHeap<BeamCandidate>` (a max-heap) surfaces the best candidate first.
impl Ord for BeamCandidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` is a genuine total order over f32, unlike
        // `partial_cmp(..).unwrap_or(Equal)`.
        self.ordering_key().total_cmp(&other.ordering_key())
    }
}
253
254/// Beam search decoder for Whisper
255pub struct BeamSearchDecoder {
256    config: DecodingConfig,
257}
258
259impl BeamSearchDecoder {
260    /// Create a new beam search decoder
261    pub fn new(config: DecodingConfig) -> Self {
262        Self { config }
263    }
264
265    /// Decode token probabilities to a sequence using beam search
266    ///
267    /// # Arguments
268    /// * `token_probs` - Matrix of shape (seq_len, vocab_size) with log probabilities per token
269    /// * `initial_token` - Starting token (usually language token or BOS)
270    /// * `vocab_size` - Size of vocabulary
271    /// * `eos_token` - End-of-sequence token ID
272    /// * `_pad_token` - Padding token ID (reserved for future use)
273    ///
274    /// # Returns
275    /// Vector of token IDs representing the decoded sequence
276    pub fn decode(
277        &self,
278        token_probs: &[Vec<f32>],
279        initial_token: u32,
280        vocab_size: usize,
281        eos_token: u32,
282        _pad_token: u32,
283    ) -> Result<Vec<u32>> {
284        if token_probs.is_empty() {
285            return Ok(vec![initial_token]);
286        }
287
288        // Validate input
289        if token_probs.iter().any(|probs| probs.len() != vocab_size) {
290            return Err(anyhow!("Invalid token probabilities shape"));
291        }
292
293        // Start with initial token
294        let mut candidates = BinaryHeap::new();
295        candidates.push(BeamCandidate::new(initial_token));
296
297        // Process each timestep
298        for step in 0..token_probs.len().min(self.config.max_length) {
299            let probs = &token_probs[step];
300            let mut next_candidates = Vec::new();
301
302            // Expand each existing candidate
303            for candidate in candidates.iter().take(self.config.beam_size) {
304                if candidate.finished {
305                    next_candidates.push(candidate.clone());
306                    continue;
307                }
308
309                // Get top-k tokens for this candidate
310                let top_k = self.get_top_k_tokens(probs, self.config.beam_size);
311
312                for (token, log_prob) in top_k {
313                    let mut new_candidate = candidate.clone();
314                    new_candidate.tokens.push(token);
315                    new_candidate.log_prob += log_prob;
316                    new_candidate.token_count += 1;
317
318                    // Check if sequence is finished
319                    if token == eos_token || step == token_probs.len() - 1 {
320                        new_candidate.finished = true;
321                    }
322
323                    next_candidates.push(new_candidate);
324                }
325            }
326
327            // Keep only top beam_size candidates
328            next_candidates.sort_by(|a, b| {
329                b.normalized_score(self.config.length_penalty)
330                    .partial_cmp(&a.normalized_score(self.config.length_penalty))
331                    .unwrap_or(Ordering::Equal)
332            });
333
334            candidates = next_candidates
335                .into_iter()
336                .take(self.config.beam_size)
337                .collect::<BinaryHeap<_>>();
338
339            // Early exit if all candidates are finished
340            if candidates.iter().all(|c| c.finished) {
341                break;
342            }
343        }
344
345        // Return best candidate
346        candidates
347            .iter()
348            .max_by(|a, b| {
349                a.normalized_score(self.config.length_penalty)
350                    .partial_cmp(&b.normalized_score(self.config.length_penalty))
351                    .unwrap_or(Ordering::Equal)
352            })
353            .map(|c| c.tokens.clone())
354            .ok_or_else(|| anyhow!("No valid candidates found"))
355    }
356
357    /// Get top-k tokens from log probability distribution
358    fn get_top_k_tokens(&self, log_probs: &[f32], k: usize) -> Vec<(u32, f32)> {
359        let mut indexed: Vec<(u32, f32)> = log_probs
360            .iter()
361            .enumerate()
362            .map(|(i, &prob)| (i as u32, prob))
363            .collect();
364
365        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
366
367        indexed.into_iter().take(k).collect()
368    }
369}
370
371/// Greedy decoder (baseline, fastest)
372pub struct GreedyDecoder;
373
374impl GreedyDecoder {
375    /// Decode using greedy approach (always take highest probability token)
376    pub fn decode(
377        token_probs: &[Vec<f32>],
378        initial_token: u32,
379        _vocab_size: usize,
380        eos_token: u32,
381        _pad_token: u32,
382    ) -> Result<Vec<u32>> {
383        let mut tokens = vec![initial_token];
384
385        for probs in token_probs {
386            if probs.is_empty() {
387                break;
388            }
389
390            let (token, _) = probs
391                .iter()
392                .enumerate()
393                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(Ordering::Equal))
394                .unwrap_or((eos_token as usize, &f32::NEG_INFINITY));
395
396            let token = token as u32;
397            tokens.push(token);
398
399            if token == eos_token {
400                break;
401            }
402        }
403
404        Ok(tokens)
405    }
406}
407
408/// Multi-hypothesis decoding strategy that tries multiple approaches
409pub struct HybridDecoder {
410    config: DecodingConfig,
411    beam_decoder: BeamSearchDecoder,
412}
413
414impl HybridDecoder {
415    /// Create a new hybrid decoder
416    pub fn new(config: DecodingConfig) -> Self {
417        Self {
418            beam_decoder: BeamSearchDecoder::new(config.clone()),
419            config,
420        }
421    }
422
423    /// Decode with beam-search / greedy fallback (no quality gating).
424    pub fn decode(
425        &self,
426        token_probs: &[Vec<f32>],
427        initial_token: u32,
428        vocab_size: usize,
429        eos_token: u32,
430        pad_token: u32,
431    ) -> Result<Vec<u32>> {
432        match self
433            .beam_decoder
434            .decode(token_probs, initial_token, vocab_size, eos_token, pad_token)
435        {
436            Ok(tokens) if tokens.len() > 1 => Ok(tokens),
437            _ => {
438                GreedyDecoder::decode(token_probs, initial_token, vocab_size, eos_token, pad_token)
439            }
440        }
441    }
442
443    /// Decode with quality-gated temperature fallback (faster-whisper SOTA strategy).
444    ///
445    /// Iterates over `config.temperatures` in order. At each temperature, samples
446    /// tokens from the collected logits and checks:
447    /// - `no_speech_prob > no_speech_threshold` → return empty (silence)
448    /// - `avg_log_prob < log_prob_threshold` → retry at next temperature
449    /// - `compression_ratio > compression_ratio_threshold` → retry at next temperature
450    ///
451    /// Returns the first result that passes all quality gates, or the last attempt
452    /// if all temperatures are exhausted.
453    ///
454    /// # Arguments
455    /// * `token_probs` — per-step raw logits collected autoregressively from the decoder
456    /// * `initial_token` — first token to prepend to the output sequence
457    /// * `vocab_size` — vocabulary size (bounds check for `no_speech_token`)
458    /// * `eos_token` — end-of-sequence token; stops generation when sampled
459    /// * `no_speech_token` — token whose first-step softmax probability signals silence
460    /// * `decode_text` — closure that converts token IDs to a UTF-8 string for compression ratio
461    pub fn decode_with_fallback(
462        &self,
463        token_probs: &[Vec<f32>],
464        initial_token: u32,
465        vocab_size: usize,
466        eos_token: u32,
467        no_speech_token: u32,
468        decode_text: impl Fn(&[u32]) -> String,
469    ) -> Result<Vec<u32>> {
470        if token_probs.is_empty() {
471            return Ok(vec![initial_token]);
472        }
473
474        // No-speech check on first decode step (independent of temperature).
475        if (no_speech_token as usize) < vocab_size {
476            let ns_prob = softmax_at(&token_probs[0], no_speech_token);
477            if ns_prob > self.config.no_speech_threshold {
478                return Ok(vec![]);
479            }
480        }
481
482        let mut best: Option<Vec<u32>> = None;
483
484        for &temp in &self.config.temperatures {
485            let mut rng = rand::rngs::StdRng::seed_from_u64(42);
486            let mut tokens = vec![initial_token];
487            let mut log_probs: Vec<f32> = Vec::new();
488
489            for step_logits in token_probs.iter().take(self.config.max_length) {
490                let selected = sample_from_logits(step_logits, temp, &mut rng);
491                log_probs.push(log_softmax_at(step_logits, selected, temp));
492                tokens.push(selected);
493                if selected == eos_token {
494                    break;
495                }
496            }
497
498            let avg_lp = if log_probs.is_empty() {
499                0.0
500            } else {
501                log_probs.iter().sum::<f32>() / log_probs.len() as f32
502            };
503
504            let text = decode_text(&tokens);
505            let cr = compression_ratio(&text);
506
507            let quality_ok = avg_lp > self.config.log_prob_threshold
508                && cr < self.config.compression_ratio_threshold;
509
510            if best.is_none() {
511                best = Some(tokens.clone());
512            }
513
514            if quality_ok {
515                return Ok(tokens);
516            }
517        }
518
519        best.ok_or_else(|| anyhow!("decode_with_fallback: no temperatures configured"))
520    }
521}
522
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders token IDs as space-separated decimals; stands in for a real
    /// tokenizer when exercising `decode_with_fallback`.
    fn ids_to_text(ids: &[u32]) -> String {
        ids.iter()
            .map(|id| id.to_string())
            .collect::<Vec<_>>()
            .join(" ")
    }

    #[test]
    fn test_decoding_config_defaults() {
        let config = DecodingConfig::default();
        assert_eq!(config.beam_size, 5);
        assert_eq!(config.temperatures, vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0]);
        assert_eq!(config.language, "en");
        assert_eq!(config.max_length, 448);
        assert_eq!(config.compression_ratio_threshold, 2.4);
        assert_eq!(config.log_prob_threshold, -1.0);
    }

    #[test]
    fn test_decoding_config_fast() {
        let config = DecodingConfig::fast();
        assert_eq!(config.beam_size, 1);
        assert_eq!(config.temperatures, vec![0.0]);
        assert_eq!(config.length_penalty, 0.0);
    }

    #[test]
    fn test_decoding_config_balanced() {
        let config = DecodingConfig::balanced();
        assert_eq!(config.beam_size, 5);
        assert_eq!(config.temperatures, vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0]);
        assert_eq!(config.length_penalty, 1.0);
    }

    #[test]
    fn test_decoding_config_accurate() {
        let config = DecodingConfig::accurate();
        assert_eq!(config.beam_size, 10);
        assert_eq!(config.temperatures, vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0]);
    }

    #[test]
    fn test_with_temperature_overrides_sequence() {
        let config = DecodingConfig::default().with_temperature(0.7);
        assert_eq!(config.temperatures, vec![0.7]);
    }

    #[test]
    fn test_builder_methods_clamp_inputs() {
        let config = DecodingConfig::default()
            .with_beam_size(0)
            .with_temperature(-0.5)
            .with_length_penalty(-2.0)
            .with_no_speech_threshold(1.5)
            .with_language("de".to_string());

        assert_eq!(config.beam_size, 1);
        assert_eq!(config.temperatures, vec![0.0]);
        assert_eq!(config.length_penalty, 0.0);
        assert_eq!(config.no_speech_threshold, 1.0);
        assert_eq!(config.language, "de");
    }

    #[test]
    fn test_beam_candidate_scoring() {
        let mut c1 = BeamCandidate::new(1);
        c1.log_prob = -4.0;
        c1.token_count = 2;

        let mut c2 = BeamCandidate::new(2);
        c2.log_prob = -1.0;
        c2.token_count = 1;

        // c2 has better (higher) normalized score: -1.0 / 1^1 = -1.0 > -4.0 / 2^1 = -2.0
        assert!(c2.normalized_score(1.0) > c1.normalized_score(1.0));
        // With no length penalty the raw log-probabilities are compared.
        assert_eq!(c1.normalized_score(0.0), -4.0);
    }

    #[test]
    fn test_greedy_decoder() -> Result<()> {
        // Create simple token probabilities (logits)
        let token_probs = vec![
            vec![-10.0, -5.0, -0.5, -10.0], // Token 2 is most likely
            vec![-5.0, -10.0, -0.1, -10.0], // Token 2 is most likely
            vec![-0.5, -5.0, -10.0, -10.0], // Token 0 is most likely (EOS)
        ];

        let tokens = GreedyDecoder::decode(&token_probs, 50256, 4, 0, 50257)?;

        // Initial token followed by the argmax of each of the 3 timesteps.
        assert_eq!(tokens, vec![50256, 2, 2, 0]);

        Ok(())
    }

    #[test]
    fn test_greedy_decoder_stops_at_eos() -> Result<()> {
        let token_probs = vec![
            vec![-0.1, -5.0], // Token 0 (EOS) is most likely → decoding stops here
            vec![-5.0, -0.1], // Never reached
        ];

        let tokens = GreedyDecoder::decode(&token_probs, 7, 2, 0, 9)?;

        assert_eq!(tokens, vec![7, 0]);
        Ok(())
    }

    #[test]
    fn test_beam_search_decoder() -> Result<()> {
        let config = DecodingConfig {
            beam_size: 2,
            ..Default::default()
        };

        let decoder = BeamSearchDecoder::new(config);

        let token_probs = vec![
            vec![-5.0, -0.5, -10.0], // Token 1 is most likely
            vec![-0.1, -5.0, -10.0], // Token 0 is most likely
        ];

        let tokens = decoder.decode(&token_probs, 100, 3, 0, 99)?;

        // Best hypothesis: [100, 1, 0] with score (-0.5 - 0.1) / 3 = -0.2.
        assert_eq!(tokens, vec![100, 1, 0]);

        Ok(())
    }

    #[test]
    fn test_beam_search_decoder_rejects_bad_shape() {
        let decoder = BeamSearchDecoder::new(DecodingConfig::default());
        let token_probs = vec![vec![-0.5, -1.0]];

        // Rows have 2 entries but the declared vocabulary size is 3.
        assert!(decoder.decode(&token_probs, 100, 3, 0, 99).is_err());
    }

    #[test]
    fn test_beam_search_decoder_empty_input() -> Result<()> {
        let decoder = BeamSearchDecoder::new(DecodingConfig::default());
        let tokens = decoder.decode(&[], 100, 3, 0, 99)?;
        assert_eq!(tokens, vec![100]);
        Ok(())
    }

    #[test]
    fn test_hybrid_decoder_fallback() -> Result<()> {
        let config = DecodingConfig::default();
        let decoder = HybridDecoder::new(config);

        let token_probs = vec![vec![-0.5, -10.0, -10.0]];

        let tokens = decoder.decode(&token_probs, 100, 3, 0, 99)?;

        // Single step: the most likely token (0) follows the initial token.
        assert_eq!(tokens, vec![100, 0]);

        Ok(())
    }

    #[test]
    fn test_compression_ratio_normal_text() {
        // Prose compresses to ratio < 2.4 (not a hallucination loop).
        let text = "The quick brown fox jumps over the lazy dog.";
        let cr = compression_ratio(text);
        assert!(cr < 2.4, "normal text compression ratio was {cr}");
    }

    #[test]
    fn test_compression_ratio_repetitive_text() {
        // Hallucination loops repeat the same phrase hundreds of times.
        // Gzip header overhead is ~20 bytes, so the input must be long enough
        // for the repetition signal to dominate.
        let phrase = "the quick brown fox ";
        let text = phrase.repeat(100); // 2000 chars, highly compressible
        let cr = compression_ratio(&text);
        assert!(cr > 2.4, "repetitive text compression ratio was {cr}");
    }

    #[test]
    fn test_compression_ratio_empty() {
        assert_eq!(compression_ratio(""), 0.0);
    }

    #[test]
    fn test_softmax_at_picks_max() {
        let logits = vec![-10.0, -0.1, -5.0];
        let p_max = softmax_at(&logits, 1);
        let p_min = softmax_at(&logits, 0);
        assert!(p_max > p_min, "softmax of max logit should be highest");
        let total: f32 = (0..3).map(|i| softmax_at(&logits, i)).sum();
        assert!((total - 1.0).abs() < 1e-4, "softmax probs must sum to 1");
        assert_eq!(softmax_at(&logits, 3), 0.0, "out-of-range token has probability 0");
    }

    #[test]
    fn test_log_softmax_at_uniform() {
        let ln2 = std::f32::consts::LN_2;
        assert!((log_softmax_at(&[0.0, 0.0], 0, 0.0) + ln2).abs() < 1e-6);
        assert!((log_softmax_at(&[3.0, 3.0], 1, 0.5) + ln2).abs() < 1e-6);
        assert_eq!(log_softmax_at(&[1.0, 2.0], 2, 1.0), f32::NEG_INFINITY);
    }

    #[test]
    fn test_argmax_logits() {
        assert_eq!(argmax_logits(&[0.1, 0.9, 0.3]), 1);
        assert_eq!(argmax_logits(&[-3.0, -2.0, -1.0]), 2);
        assert_eq!(argmax_logits(&[]), 0);
    }

    #[test]
    fn test_sample_from_logits_stays_in_range() {
        let logits = vec![0.5, 1.5, -0.5, 0.0];
        let mut rng = rand::rngs::StdRng::seed_from_u64(7);

        // Temperature 0 is plain argmax.
        assert_eq!(sample_from_logits(&logits, 0.0, &mut rng), 1);

        for _ in 0..64 {
            let token = sample_from_logits(&logits, 1.0, &mut rng);
            assert!((token as usize) < logits.len());
        }
    }

    #[test]
    fn test_decode_with_fallback_passes_quality() -> Result<()> {
        // Token 0 (the EOS token) dominates every step, so greedy sampling at
        // temperature 0 stops right after the first step.
        // Use a very lax threshold so it passes immediately.
        let config = DecodingConfig {
            temperatures: vec![0.0],
            log_prob_threshold: -100.0,
            compression_ratio_threshold: 100.0,
            no_speech_threshold: 1.0,
            max_length: 5,
            ..Default::default()
        };
        let decoder = HybridDecoder::new(config);

        let token_probs = vec![vec![-0.01, -10.0, -10.0], vec![-0.01, -10.0, -10.0]];

        let tokens = decoder.decode_with_fallback(
            &token_probs,
            99,
            3,
            0, // eos
            2, // no_speech_token (low prob → not triggered)
            ids_to_text,
        )?;

        assert_eq!(tokens, vec![99, 0]);
        Ok(())
    }

    #[test]
    fn test_decode_with_fallback_all_attempts_fail() -> Result<()> {
        // A log-prob threshold of 0.0 can never be exceeded (log-softmax <= 0),
        // so every temperature fails and the first attempt is returned.
        let config = DecodingConfig {
            temperatures: vec![0.0],
            log_prob_threshold: 0.0,
            compression_ratio_threshold: 100.0,
            no_speech_threshold: 1.0,
            max_length: 5,
            ..Default::default()
        };
        let decoder = HybridDecoder::new(config);

        let token_probs = vec![vec![-0.01, -10.0, -10.0]];

        let tokens = decoder.decode_with_fallback(&token_probs, 99, 3, 0, 2, ids_to_text)?;

        assert_eq!(tokens, vec![99, 0]);
        Ok(())
    }

    #[test]
    fn test_decode_with_fallback_requires_temperatures() {
        let config = DecodingConfig {
            temperatures: vec![],
            no_speech_threshold: 1.0,
            ..Default::default()
        };
        let decoder = HybridDecoder::new(config);

        let token_probs = vec![vec![-0.01, -10.0, -10.0]];

        let result = decoder.decode_with_fallback(&token_probs, 99, 3, 0, 2, ids_to_text);
        assert!(result.is_err());
    }

    #[test]
    fn test_decode_with_fallback_no_speech() -> Result<()> {
        // Make no_speech_token dominant at step 0 → should return empty Vec.
        let config = DecodingConfig {
            temperatures: vec![0.0],
            no_speech_threshold: 0.5,
            log_prob_threshold: -100.0,
            compression_ratio_threshold: 100.0,
            max_length: 5,
            ..Default::default()
        };
        let decoder = HybridDecoder::new(config);

        // Token 1 has very high logit → softmax ≈ 1.0, well above threshold 0.5.
        let token_probs = vec![vec![-10.0, 100.0, -10.0]];

        let tokens = decoder.decode_with_fallback(
            &token_probs,
            99,
            3,
            0, // eos
            1, // no_speech_token = token 1 (dominant)
            |_| String::new(),
        )?;

        assert!(
            tokens.is_empty(),
            "should return empty when no-speech detected"
        );
        Ok(())
    }
}