1use anyhow::{Context, Result};
13use burn::tensor::{Tensor, backend::Backend};
14use tokenizers::Tokenizer;
15
16use crate::kv_cache::{KvCache, forward_decoder_cached};
17use crate::model::Whisper;
18
/// Language codes considered during language auto-detection.
///
/// Each code is looked up in the tokenizer as the special token `<|{code}|>`
/// (see [`language_token_id`]); [`detect_language`] walks this slice in order,
/// skips codes whose token the tokenizer does not contain, and keeps the
/// earliest code on a tie in logit value, so the order here is significant.
///
/// NOTE(review): the order presumably mirrors the order of Whisper's language
/// special tokens in the multilingual vocabulary — TODO confirm before
/// reordering or extending this list.
pub const LANGUAGE_CODES: &[&str] = &[
    "en", "zh", "de", "es", "ru", "ko", "fr", "ja", "pt", "tr", "pl", "ca", "nl", "ar", "sv", "it",
    "id", "hi", "fi", "vi", "he", "uk", "el", "ms", "cs", "ro", "da", "hu", "ta", "no", "th", "ur",
    "hr", "bg", "lt", "la", "mi", "ml", "cy", "sk", "te", "fa", "lv", "bn", "sr", "az", "sl", "kn",
    "et", "mk", "br", "eu", "is", "hy", "ne", "mn", "bs", "kk", "sq", "sw", "gl", "mr", "pa", "si",
    "km", "sn", "yo", "so", "af", "oc", "ka", "be", "tg", "sd", "gu", "am", "yi", "lo", "uz", "fo",
    "ht", "ps", "tk", "nn", "mt", "sa", "lb", "my", "bo", "tl", "mg", "as", "tt", "haw", "ln",
    "ha", "ba", "jw", "su", "yue",
];
30
/// Which decoding task the model is asked to perform.
///
/// Each variant corresponds to one special token, resolved by
/// [`task_token_id`]. The default is [`Task::Transcribe`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Task {
    /// Selects the `<|transcribe|>` special token.
    #[default]
    Transcribe,
    /// Selects the `<|translate|>` special token.
    Translate,
}
39
40pub fn language_token_id(tokenizer: &Tokenizer, code: &str) -> Option<u32> {
43 tokenizer.token_to_id(&format!("<|{code}|>"))
44}
45
46pub fn task_token_id(tokenizer: &Tokenizer, task: Task) -> Option<u32> {
48 let s = match task {
49 Task::Transcribe => "<|transcribe|>",
50 Task::Translate => "<|translate|>",
51 };
52 tokenizer.token_to_id(s)
53}
54
55pub fn detect_language<B: Backend>(
63 model: &Whisper<B>,
64 encoder_out: Tensor<B, 3>,
65 tokenizer: &Tokenizer,
66 sot_token: u32,
67 device: &B::Device,
68) -> Result<(String, u32)> {
69 let mut cache = KvCache::new(model, encoder_out);
70 let logits = forward_decoder_cached(model, sot_token, &mut cache, device)
71 .context("language-detection forward pass")?;
72
73 let mut best: Option<(f32, u32, &str)> = None;
74 for &code in LANGUAGE_CODES {
75 let Some(id) = language_token_id(tokenizer, code) else {
76 continue;
77 };
78 let Some(&logit) = logits.get(id as usize) else {
79 continue;
80 };
81 if best.is_none_or(|(b, _, _)| logit > b) {
82 best = Some((logit, id, code));
83 }
84 }
85
86 let (_, id, code) = best.context(
87 "no language tokens found in tokenizer — language auto-detection requires a \
88 multilingual model (English-only .en models cannot detect language)",
89 )?;
90 Ok((code.to_string(), id))
91}