Skip to main content

whisperforge_core/
language.rs

1//! Whisper language + task token helpers.
2//!
3//! Whisper's decoder is seeded with `[<|sot|>, <|lang|>, <|task|>, <|notimestamps|>]`.
4//! The language token selects the spoken language (transcription) or source language
5//! (translation); the task token selects `transcribe` (output in the spoken language)
6//! or `translate` (output in **English only** — Whisper has no other-target path).
7//!
8//! This module turns CLI strings (`--language hi`, `--task translate`) into the
9//! corresponding token ids via the model tokenizer, and provides Whisper's
10//! first-token language auto-detection.
11
12use anyhow::{Context, Result};
13use burn::tensor::{Tensor, backend::Backend};
14use tokenizers::Tokenizer;
15
16use crate::kv_cache::{KvCache, forward_decoder_cached};
17use crate::model::Whisper;
18
/// Whisper language codes (100 entries), listed in token-id order.
///
/// The order mirrors the language span of [`crate::SPECIAL_TOKENS`], i.e. everything
/// between `<|en|>` and `<|notimestamps|>`. Each code `c` corresponds to the special
/// token `<|c|>`; the position of a code in this slice is its offset within that span,
/// so entries must never be reordered or sorted.
// Ten codes per row so a position can be read off the trailing index comment.
#[rustfmt::skip]
pub const LANGUAGE_CODES: &[&str] = &[
    "en", "zh", "de", "es", "ru", "ko", "fr", "ja", "pt", "tr", //  0..=9
    "pl", "ca", "nl", "ar", "sv", "it", "id", "hi", "fi", "vi", // 10..=19
    "he", "uk", "el", "ms", "cs", "ro", "da", "hu", "ta", "no", // 20..=29
    "th", "ur", "hr", "bg", "lt", "la", "mi", "ml", "cy", "sk", // 30..=39
    "te", "fa", "lv", "bn", "sr", "az", "sl", "kn", "et", "mk", // 40..=49
    "br", "eu", "is", "hy", "ne", "mn", "bs", "kk", "sq", "sw", // 50..=59
    "gl", "mr", "pa", "si", "km", "sn", "yo", "so", "af", "oc", // 60..=69
    "ka", "be", "tg", "sd", "gu", "am", "yi", "lo", "uz", "fo", // 70..=79
    "ht", "ps", "tk", "nn", "mt", "sa", "lb", "my", "bo", "tl", // 80..=89
    "mg", "as", "tt", "haw", "ln", "ha", "ba", "jw", "su", "yue", // 90..=99
];
30
/// What the Whisper decoder is asked to produce.
///
/// Note that `Translate` means **X → English only**: Whisper cannot emit any other
/// target language (see module docs).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum Task {
    /// Output text in the spoken language. This is the default task.
    #[default]
    Transcribe,
    /// Output English text for speech in another language.
    Translate,
}
39
40/// Token id for a language code (e.g. `"hi"` → `<|hi|>`), or `None` if the tokenizer
41/// lacks it (the signature of an English-only `.en` model).
42pub fn language_token_id(tokenizer: &Tokenizer, code: &str) -> Option<u32> {
43    tokenizer.token_to_id(&format!("<|{code}|>"))
44}
45
/// Resolves a [`Task`] to the id of its special token.
///
/// `Task::Transcribe` maps to `<|transcribe|>` and `Task::Translate` to `<|translate|>`.
/// Returns `None` if the tokenizer does not contain the token.
pub fn task_token_id(tokenizer: &Tokenizer, task: Task) -> Option<u32> {
    // Exhaustive on purpose: adding a `Task` variant must force a decision here.
    tokenizer.token_to_id(match task {
        Task::Translate => "<|translate|>",
        Task::Transcribe => "<|transcribe|>",
    })
}
54
/// Whisper first-token language detection.
///
/// Runs a single decoder step seeded with `<|sot|>` only, then takes the argmax of the
/// resulting logits **restricted to the language token ids** (an unrestricted argmax
/// would return a content token, not a language). Returns `(code, token_id)`.
///
/// Candidates are the entries of [`LANGUAGE_CODES`]. A candidate is skipped when the
/// tokenizer has no `<|code|>` token, when its token id lies outside the logits, or when
/// its logit is NaN. On equal logits the earliest entry of [`LANGUAGE_CODES`] wins.
///
/// `encoder_out` is consumed by [`KvCache::new`]; clone before calling if you need it again.
///
/// # Errors
///
/// Returns an error if the decoder forward pass fails, if the tokenizer exposes no
/// language token at all (English-only `.en` models), or if language tokens exist but
/// none of them has a usable (in-range, non-NaN) logit.
pub fn detect_language<B: Backend>(
    model: &Whisper<B>,
    encoder_out: Tensor<B, 3>,
    tokenizer: &Tokenizer,
    sot_token: u32,
    device: &B::Device,
) -> Result<(String, u32)> {
    let mut cache = KvCache::new(model, encoder_out);
    let logits = forward_decoder_cached(model, sot_token, &mut cache, device)
        .context("language-detection forward pass")?;

    // Tracked separately from `best` so the error can tell "tokenizer has no language
    // tokens" apart from "tokens exist but the logits are unusable".
    let mut tokenizer_has_languages = false;
    let mut best: Option<(f32, u32, &str)> = None;
    for &code in LANGUAGE_CODES {
        let Some(id) = language_token_id(tokenizer, code) else {
            continue;
        };
        tokenizer_has_languages = true;
        let Some(&logit) = logits.get(id as usize) else {
            continue;
        };
        // Every `>` comparison against NaN is false, so a NaN stored as the running best
        // could never be displaced and would silently decide the result. Drop NaN
        // candidates regardless of their position in the list.
        if logit.is_nan() {
            continue;
        }
        // Strict `>` keeps the first of several equal maxima.
        if best.is_none_or(|(top, _, _)| logit > top) {
            best = Some((logit, id, code));
        }
    }

    let (_, id, code) = best.context(if tokenizer_has_languages {
        "language tokens exist in the tokenizer but none has a usable logit \
         (token id outside the logits, or logit is NaN) — cannot auto-detect language"
    } else {
        "no language tokens found in tokenizer — language auto-detection requires a \
         multilingual model (English-only .en models cannot detect language)"
    })?;
    Ok((code.to_string(), id))
}