1use std::time::Instant;
2
3use anyhow::{Context, Result};
4use burn::tensor::{Tensor, backend::Backend};
5use tokenizers::Tokenizer;
6use tracing::{Level, event};
7
8use crate::kv_cache::{KvCache, forward_decoder_cached};
9use crate::model::Whisper;
10
/// A single token produced by `decode_window`.
///
/// Derives `Debug` in addition to `Clone` so emitted tokens can be logged and
/// printed in test failure messages.
#[derive(Clone, Debug)]
pub struct TokenEmit {
    /// Vocabulary id of the token.
    pub id: u32,
    /// Text decoded from `id` alone. Empty for special tokens, and also empty
    /// when the tokenizer fails to decode the id.
    pub text: String,
    /// Log-softmax value of this token in the logits it was selected from.
    pub logprob: f32,
    /// For ids at or above the timestamp-begin token:
    /// `(id - timestamp_begin_token) * 0.02` (seconds, per the field name).
    /// `None` for every other token.
    pub window_ts_secs: Option<f32>,
    /// `true` when `id` is greater than or equal to the end-of-text token id.
    pub is_special: bool,
}
24
/// Token ids and limits that steer one call to `decode_window`.
///
/// Every field is either a borrowed slice or a plain number, so the type is
/// cheap to clone; `Debug` makes a context printable in logs and test output.
#[derive(Clone, Debug)]
pub struct DecodeContext<'a> {
    /// Tokens fed through the decoder before the init sequence. May be empty.
    pub prompt_tokens: &'a [u32],
    /// Second token of the init sequence.
    pub language_token: u32,
    /// Third token of the init sequence.
    pub task_token: u32,
    /// First token of the init sequence.
    pub sot_token: u32,
    /// Selecting this token ends decoding. It is masked out for the first
    /// decode step, and every id `>=` it is flagged as special.
    pub eot_token: u32,
    /// Token whose softmax probability is compared with `no_speech_threshold`.
    pub no_speech_token: u32,
    /// Ids at or above this value are converted to a time offset of
    /// `(id - timestamp_begin_token) * 0.02`.
    pub timestamp_begin_token: u32,
    /// Fourth and last token of the init sequence.
    pub notimestamps_token: u32,
    /// Upper bound on the number of tokens emitted per window.
    pub max_new_tokens: usize,
    /// The window is skipped (nothing emitted) when the probability of
    /// `no_speech_token` is strictly greater than this value.
    pub no_speech_threshold: f32,
}
47
48pub fn decode_window<B: Backend>(
62 model: &Whisper<B>,
63 encoder_out: Tensor<B, 3>,
64 ctx: &DecodeContext,
65 tokenizer: &Tokenizer,
66 device: &B::Device,
67) -> Result<Vec<TokenEmit>> {
68 let t0 = Instant::now();
69
70 let mut cache = KvCache::new(model, encoder_out);
71
72 if !ctx.prompt_tokens.is_empty() {
74 event!(
75 Level::DEBUG,
76 prompt_len = ctx.prompt_tokens.len(),
77 prompt_first_token = ctx.prompt_tokens[0],
78 "feeding prompt prefix into KV cache",
79 );
80 }
81 for &tok in ctx.prompt_tokens {
82 forward_decoder_cached(model, tok, &mut cache, device)
83 .with_context(|| format!("feeding prompt token {tok}"))?;
84 }
85
86 let init = [
91 ctx.sot_token,
92 ctx.language_token,
93 ctx.task_token,
94 ctx.notimestamps_token,
95 ];
96 let mut logits: Vec<f32> = Vec::new();
97 for (i, &tok) in init.iter().enumerate() {
98 logits = forward_decoder_cached(model, tok, &mut cache, device)
99 .with_context(|| format!("feeding init token at index {i}"))?;
100 }
101
102 if softmax_at(&logits, ctx.no_speech_token) > ctx.no_speech_threshold {
104 event!(
105 Level::DEBUG,
106 decode_ms = t0.elapsed().as_millis(),
107 n_tokens = 0usize,
108 skipped = true
109 );
110 return Ok(Vec::new());
111 }
112
113 if (ctx.eot_token as usize) < logits.len() {
115 logits[ctx.eot_token as usize] = f32::NEG_INFINITY;
116 }
117
118 let mut emits: Vec<TokenEmit> = Vec::new();
119
120 for _ in 0..ctx.max_new_tokens {
121 let token_id = argmax(&logits);
122
123 if token_id == ctx.eot_token {
124 break;
125 }
126
127 let logprob = log_softmax_at(&logits, token_id);
128 let is_special = token_id >= ctx.eot_token;
129 let window_ts_secs = if token_id >= ctx.timestamp_begin_token {
130 Some((token_id - ctx.timestamp_begin_token) as f32 * 0.02)
131 } else {
132 None
133 };
134 let text = if is_special {
135 String::new()
136 } else {
137 tokenizer.decode(&[token_id], false).unwrap_or_default()
138 };
139
140 emits.push(TokenEmit {
141 id: token_id,
142 text,
143 logprob,
144 window_ts_secs,
145 is_special,
146 });
147
148 logits = forward_decoder_cached(model, token_id, &mut cache, device)
149 .with_context(|| format!("decode step {}", emits.len()))?;
150 }
151
152 event!(
153 Level::DEBUG,
154 decode_ms = t0.elapsed().as_millis(),
155 n_tokens = emits.len()
156 );
157
158 let regular_text: String = emits
165 .iter()
166 .filter(|t| !t.is_special)
167 .map(|t| t.text.as_str())
168 .collect();
169 let trimmed = regular_text.trim();
170 if !trimmed.is_empty()
171 && trimmed
172 .chars()
173 .all(|c| c.is_ascii_punctuation() || c.is_whitespace())
174 {
175 event!(
176 Level::DEBUG,
177 dropped_punctuation_only = true,
178 text = %trimmed
179 );
180 return Ok(Vec::new());
181 }
182
183 Ok(emits)
184}
185
/// Returns the index of the largest non-NaN logit, or `0` when `logits` is
/// empty or contains only NaN values.
///
/// NaN entries are skipped. Without that, treating an incomparable pair as
/// "equal" lets a NaN hand the lead to whatever value follows it, so a
/// smaller logit could win (e.g. `[3.0, NaN, 1.0]` would yield index 2).
///
/// When several entries share the maximum, the last of them is returned
/// (the behaviour of `Iterator::max_by`).
fn argmax(logits: &[f32]) -> u32 {
    logits
        .iter()
        .enumerate()
        .filter(|(_, value)| !value.is_nan())
        // With NaN removed `partial_cmp` always yields `Some`; the fallback
        // only exists to avoid a panic path.
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        // Indices are vocabulary positions; the callers in this file use `u32` token ids.
        .map_or(0, |(index, _)| index as u32)
}
194
/// Softmax probability of `token` within `logits`.
///
/// Returns `0.0` when `token` is not a valid index. The maximum logit is
/// subtracted before exponentiating, and the denominator is floored at
/// `f32::EPSILON`.
fn softmax_at(logits: &[f32], token: u32) -> f32 {
    let target = match logits.get(token as usize) {
        Some(&value) => value,
        None => return 0.0,
    };
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let denominator = logits
        .iter()
        .map(|&logit| (logit - peak).exp())
        .sum::<f32>();
    let numerator = (target - peak).exp();
    numerator / denominator.max(f32::EPSILON)
}
204
/// Log-softmax value of `token` within `logits`.
///
/// Returns negative infinity when `token` is not a valid index. Computed as
/// `logit - (max + ln(sum(exp(l - max))))` so large logits do not overflow.
fn log_softmax_at(logits: &[f32], token: u32) -> f32 {
    match logits.get(token as usize) {
        None => f32::NEG_INFINITY,
        Some(&target) => {
            let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let shifted_sum: f32 = logits.iter().map(|&logit| (logit - peak).exp()).sum();
            let log_normalizer = peak + shifted_sum.ln();
            target - log_normalizer
        }
    }
}
214
215pub fn avg_logprob(tokens: &[TokenEmit]) -> f32 {
216 let content: Vec<f32> = tokens
217 .iter()
218 .filter(|t| !t.is_special)
219 .map(|t| t.logprob)
220 .collect();
221 if content.is_empty() {
222 return 0.0;
223 }
224 content.iter().sum::<f32>() / content.len() as f32
225}
226
/// Thresholds consulted by `passes_quality_gate`.
#[derive(Debug, Clone, Copy)]
pub struct QualityGate {
    /// A window is rejected when the average log-probability of its content
    /// tokens is below this value.
    pub log_prob_threshold: f32,
    /// A window is rejected when the compression ratio of its text is above
    /// this value.
    pub compression_ratio_threshold: f32,
}

impl Default for QualityGate {
    /// Gate with a log-probability floor of `-1.0` and a compression-ratio
    /// ceiling of `2.4`.
    fn default() -> Self {
        let log_prob_threshold = -1.0;
        let compression_ratio_threshold = 2.4;
        QualityGate {
            log_prob_threshold,
            compression_ratio_threshold,
        }
    }
}
249
250pub fn passes_quality_gate(emits: &[TokenEmit], gate: &QualityGate) -> bool {
256 let text: String = emits
257 .iter()
258 .filter(|t| !t.is_special)
259 .map(|t| t.text.as_str())
260 .collect();
261 if text.trim().is_empty() {
262 return true;
263 }
264 if avg_logprob(emits) < gate.log_prob_threshold {
265 return false;
266 }
267 if crate::decoding::compression_ratio(&text) > gate.compression_ratio_threshold {
268 return false;
269 }
270 true
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use burn_flex::{Flex, FlexDevice};
    use std::path::{Path, PathBuf};

    use crate::model::WhisperConfig;

    /// Backend every test in this module runs on.
    type TestBackend = Flex<f32>;

    /// Builds a model from the tiny-en config alone; no weights are loaded here.
    fn tiny_en_random() -> (Whisper<TestBackend>, FlexDevice) {
        let device = FlexDevice;
        let model = WhisperConfig::tiny_en().init::<TestBackend>(&device);
        (model, device)
    }

    /// Tokenizer wrapping a default-constructed BPE model.
    fn dummy_tokenizer() -> Tokenizer {
        let bpe = tokenizers::models::bpe::BPE::default();
        Tokenizer::new(bpe)
    }

    /// All-zero encoder output of shape `[1, 1500, 384]`.
    fn silent_encoder_out(device: &FlexDevice) -> Tensor<TestBackend, 3> {
        Tensor::zeros([1, 1500, 384], device)
    }

    /// Context with an empty prompt, an 8-token budget and a no-speech
    /// threshold of `0.999`.
    fn ctx_no_gate<'a>() -> DecodeContext<'a> {
        DecodeContext {
            prompt_tokens: &[],
            sot_token: 50258,
            language_token: 50259,
            task_token: 50359,
            eot_token: 50257,
            no_speech_token: 50362,
            notimestamps_token: 50363,
            timestamp_begin_token: 50364,
            max_new_tokens: 8,
            no_speech_threshold: 0.999,
        }
    }

    /// Non-special token with id 1 and the given text / log-probability.
    fn content_emit(text: &str, logprob: f32) -> TokenEmit {
        TokenEmit {
            id: 1,
            text: String::from(text),
            logprob,
            window_ts_secs: None,
            is_special: false,
        }
    }

    /// Parent directory of this crate's manifest directory.
    fn workspace_root() -> PathBuf {
        Path::new(env!("CARGO_MANIFEST_DIR"))
            .parent()
            .expect("workspace root")
            .to_path_buf()
    }

    /// Reads the samples of the first `data` chunk of a RIFF/WAVE buffer.
    ///
    /// Chunks are walked starting at byte 12. The format tag of a preceding
    /// `fmt ` chunk selects the sample encoding: tag `3` is read as 32-bit
    /// little-endian floats, anything else as 16-bit little-endian integers
    /// scaled by `1 / 32768`. Trailing bytes that do not fill a whole sample
    /// are ignored. Panics if a `fmt ` chunk is cut off before its format tag.
    fn wav_samples(raw: &[u8]) -> Result<Vec<f32>> {
        let mut cursor = 12usize;
        let mut audio_format = 1u16;
        let mut data_span: Option<(usize, usize)> = None;

        while cursor + 8 <= raw.len() {
            let body = cursor + 8;
            let chunk_id = &raw[cursor..cursor + 4];
            let size = u32::from_le_bytes(raw[cursor + 4..body].try_into().unwrap()) as usize;
            if chunk_id == b"fmt " {
                audio_format = u16::from_le_bytes(raw[body..body + 2].try_into().unwrap());
            } else if chunk_id == b"data" {
                data_span = Some((body, size));
                break;
            }
            // Chunks are padded to an even number of bytes.
            cursor = body + size + (size & 1);
        }

        let (start, len) = data_span.context("no 'data' chunk")?;
        // Tolerate a declared size that runs past the end of the file.
        let end = (start + len).min(raw.len());
        let payload = &raw[start..end];

        let samples = if audio_format == 3 {
            payload
                .chunks_exact(4)
                .map(|bytes| f32::from_le_bytes(bytes.try_into().unwrap()))
                .collect()
        } else {
            payload
                .chunks_exact(2)
                .map(|bytes| i16::from_le_bytes(bytes.try_into().unwrap()) as f32 / 32768.0)
                .collect()
        };
        Ok(samples)
    }

    #[test]
    fn test_quality_gate_passes_normal() {
        let emits = [
            content_emit(" the", -0.2),
            content_emit(" quick", -0.4),
            content_emit(" brown", -0.3),
            content_emit(" fox", -0.5),
        ];
        let accepted = passes_quality_gate(&emits, &QualityGate::default());
        assert!(accepted, "varied, confident text should pass");
    }

    #[test]
    fn test_quality_gate_rejects_low_logprob() {
        let emits = [content_emit(" maybe", -2.5), content_emit(" perhaps", -3.0)];
        let accepted = passes_quality_gate(&emits, &QualityGate::default());
        assert!(!accepted, "low-confidence window should be rejected");
    }

    #[test]
    fn test_quality_gate_rejects_repetition() {
        // Sixty identical, confident tokens.
        let emits = vec![content_emit(" sigh", -0.1); 60];
        let accepted = passes_quality_gate(&emits, &QualityGate::default());
        assert!(
            !accepted,
            "confident repetition loop should be rejected on compression ratio"
        );
    }

    #[test]
    fn test_quality_gate_empty_passes() {
        let gate = QualityGate::default();
        let only_special = [TokenEmit {
            id: 50364,
            text: String::new(),
            logprob: -5.0,
            window_ts_secs: Some(0.0),
            is_special: true,
        }];
        assert!(passes_quality_gate(&only_special, &gate));
        assert!(passes_quality_gate(&[], &gate));
    }

    /// The decode loop must respect `max_new_tokens` for a model built from
    /// the config alone.
    #[test]
    fn test_decode_window_random_model() -> Result<()> {
        let (model, device) = tiny_en_random();
        let encoder_out = silent_encoder_out(&device);
        let tokenizer = dummy_tokenizer();
        let ctx = ctx_no_gate();

        let emits = decode_window(&model, encoder_out, &ctx, &tokenizer, &device)?;
        assert!(emits.len() <= 8, "emits exceeded max_new_tokens");
        Ok(())
    }

    /// With a threshold of `0.0` the no-speech gate is expected to fire.
    #[test]
    fn test_decode_window_no_speech_gate() -> Result<()> {
        let (model, device) = tiny_en_random();
        let encoder_out = silent_encoder_out(&device);
        let tokenizer = dummy_tokenizer();
        let ctx = DecodeContext {
            no_speech_threshold: 0.0,
            ..ctx_no_gate()
        };

        let emits = decode_window(&model, encoder_out, &ctx, &tokenizer, &device)?;
        assert!(
            emits.is_empty(),
            "no-speech gate should have returned an empty vec"
        );
        Ok(())
    }

    /// Compares the text produced by `decode_window` with the text produced
    /// by `WhisperTranscriber::transcribe` on the same clip.
    #[test]
    #[ignore = "requires tiny_en_converted in ./models/ AND test_data/LJ001-0001_16k.wav at repo root"]
    fn test_decode_window_matches_transcribe_path() -> Result<()> {
        use crate::{
            WhisperInference, WhisperTranscriber, audio::compute_mel_from_samples,
            decoding::DecodingConfig, load::load_whisper,
        };

        let device = FlexDevice;
        let root = workspace_root();

        // Model and tokenizer.
        let model_dir = root.join("models").join("tiny_en_converted");
        let model_path = model_dir.join("model");
        let model_path_str = model_path.to_str().expect("valid UTF-8 model path");
        let model = load_whisper::<TestBackend>(model_path_str, &device)?;
        let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json"))
            .map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;

        // Audio, padded or truncated to exactly 480 000 samples.
        let wav_path = root.join("test_data").join("LJ001-0001_16k.wav");
        let raw =
            std::fs::read(&wav_path).with_context(|| format!("read {}", wav_path.display()))?;
        let mut samples_30s = wav_samples(&raw)?;
        samples_30s.resize(480_000, 0.0);

        let mel = compute_mel_from_samples::<TestBackend>(&samples_30s, 400, 160, 80, &device)?;

        // Reference: the one-shot transcribe path.
        let transcriber =
            WhisperTranscriber::new(model.clone(), tokenizer.clone(), DecodingConfig::fast());
        let ref_result = transcriber.transcribe(mel.clone())?;
        let ref_text = ref_result.text.trim().to_lowercase();

        // Candidate: the windowed decoder under test.
        let encoder_out = model.forward_encoder(mel);
        let tok = |name: &str, fallback: u32| tokenizer.token_to_id(name).unwrap_or(fallback);
        let ctx = DecodeContext {
            prompt_tokens: &[],
            sot_token: tok("<|startoftranscript|>", 50258),
            language_token: tok("<|en|>", 50259),
            task_token: tok("<|transcribe|>", 50359),
            eot_token: tok("<|endoftext|>", 50257),
            no_speech_token: tok("<|nospeech|>", 50362),
            notimestamps_token: tok("<|notimestamps|>", 50363),
            timestamp_begin_token: 50364,
            max_new_tokens: 128,
            no_speech_threshold: 0.6,
        };

        let emits = decode_window(&model, encoder_out, &ctx, &tokenizer, &device)?;
        assert!(
            !emits.is_empty(),
            "decode_window produced no tokens for a speech clip"
        );

        let text_ids: Vec<u32> = emits
            .iter()
            .filter_map(|emit| if emit.is_special { None } else { Some(emit.id) })
            .collect();
        assert!(
            !text_ids.is_empty(),
            "no regular text tokens in decode_window output"
        );

        let stream_text = tokenizer
            .decode(&text_ids, true)
            .map_err(|e| anyhow::anyhow!("{e}"))?
            .trim()
            .to_lowercase();

        assert_eq!(
            stream_text, ref_text,
            "stream_decode text diverges from one-shot path\n stream: {stream_text:?}\n ref: {ref_text:?}"
        );

        Ok(())
    }
}