1use anyhow::{Result, anyhow};
4use flate2::{Compression, write::GzEncoder};
5use rand::{RngExt, SeedableRng};
6use std::cmp::Ordering;
7use std::collections::BinaryHeap;
8use std::io::Write;
9
10use crate::language::Task;
11
/// Tunable parameters shared by the decoders defined in this file.
#[derive(Debug, Clone)]
pub struct DecodingConfig {
    /// Beam width: how many hypotheses `BeamSearchDecoder` keeps after each
    /// step, and how many top-scoring tokens it expands per step.
    pub beam_size: usize,
    /// Sampling temperatures tried in order by
    /// `HybridDecoder::decode_with_fallback`. A temperature `<= 0.0` selects
    /// the arg-max token instead of sampling.
    pub temperatures: Vec<f32>,
    /// Exponent applied to a hypothesis' token count when its accumulated
    /// score is length-normalised during beam search.
    pub length_penalty: f32,
    /// `decode_with_fallback` returns no tokens when the first step's softmax
    /// probability of the no-speech token exceeds this value.
    pub no_speech_threshold: f32,
    /// Upper bound on the number of decoding steps consumed from the input.
    pub max_length: usize,
    /// Language code.
    /// NOTE(review): not read by any decoder visible in this file — confirm
    /// where it is consumed.
    pub language: String,
    /// Task selector.
    /// NOTE(review): not read by any decoder visible in this file — confirm
    /// where it is consumed.
    pub task: Task,
    /// An attempt in `decode_with_fallback` is accepted only if the
    /// compression ratio of its decoded text is strictly below this value.
    pub compression_ratio_threshold: f32,
    /// An attempt in `decode_with_fallback` is accepted only if its average
    /// per-token log-probability is strictly above this value.
    pub log_prob_threshold: f32,
}
34
35impl Default for DecodingConfig {
36 fn default() -> Self {
37 Self {
38 beam_size: 5,
39 temperatures: vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
40 length_penalty: 1.0,
41 no_speech_threshold: 0.6,
42 max_length: 448, language: "en".to_string(),
44 task: Task::Transcribe,
45 compression_ratio_threshold: 2.4,
46 log_prob_threshold: -1.0,
47 }
48 }
49}
50
51impl DecodingConfig {
52 pub fn fast() -> Self {
54 Self {
55 beam_size: 1,
56 temperatures: vec![0.0],
57 length_penalty: 0.0,
58 ..Default::default()
59 }
60 }
61
62 pub fn balanced() -> Self {
64 Self {
65 beam_size: 5,
66 temperatures: vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
67 length_penalty: 1.0,
68 ..Default::default()
69 }
70 }
71
72 pub fn accurate() -> Self {
74 Self {
75 beam_size: 10,
76 temperatures: vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
77 length_penalty: 1.0,
78 ..Default::default()
79 }
80 }
81
82 pub fn with_beam_size(mut self, beam_size: usize) -> Self {
84 self.beam_size = beam_size.max(1);
85 self
86 }
87
88 pub fn with_temperature(mut self, temperature: f32) -> Self {
90 self.temperatures = vec![temperature.max(0.0)];
91 self
92 }
93
94 pub fn with_length_penalty(mut self, penalty: f32) -> Self {
96 self.length_penalty = penalty.max(0.0);
97 self
98 }
99
100 pub fn with_no_speech_threshold(mut self, threshold: f32) -> Self {
102 self.no_speech_threshold = threshold.clamp(0.0, 1.0);
103 self
104 }
105
106 pub fn with_language(mut self, language: String) -> Self {
108 self.language = language;
109 self
110 }
111
112 pub fn with_task(mut self, task: Task) -> Self {
114 self.task = task;
115 self
116 }
117}
118
119pub fn compression_ratio(text: &str) -> f32 {
128 let bytes = text.as_bytes();
129 if bytes.is_empty() {
130 return 0.0;
131 }
132 let mut enc = GzEncoder::new(Vec::new(), Compression::default());
133 enc.write_all(bytes).ok();
134 let compressed_len = enc.finish().unwrap_or_default().len().max(1);
135 bytes.len() as f32 / compressed_len as f32
136}
137
/// Softmax probability of `token` under `logits`.
///
/// Returns `0.0` when `token` is not a valid index into `logits` (which
/// includes the empty-slice case). The maximum logit is subtracted before
/// exponentiating to keep the computation numerically stable, and the
/// denominator is floored at `f32::EPSILON`.
fn softmax_at(logits: &[f32], token: u32) -> f32 {
    let Some(&target) = logits.get(token as usize) else {
        return 0.0;
    };

    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let denom = logits
        .iter()
        .map(|&logit| (logit - peak).exp())
        .sum::<f32>()
        .max(f32::EPSILON);

    (target - peak).exp() / denom
}
148
/// Log-softmax of `token` under `logits` after temperature scaling.
///
/// Each logit is divided by `temp` when `temp > 0.0`; otherwise the logits
/// are used unscaled. Returns `f32::NEG_INFINITY` when `token` is not a valid
/// index into `logits`.
fn log_softmax_at(logits: &[f32], token: u32, temp: f32) -> f32 {
    let Some(&raw_target) = logits.get(token as usize) else {
        return f32::NEG_INFINITY;
    };

    // Applied lazily to every logit instead of materialising a scaled copy.
    let rescale = |logit: f32| if temp > 0.0 { logit / temp } else { logit };

    let peak = logits
        .iter()
        .map(|&logit| rescale(logit))
        .fold(f32::NEG_INFINITY, f32::max);
    let exp_total = logits
        .iter()
        .map(|&logit| (rescale(logit) - peak).exp())
        .sum::<f32>();
    // log-sum-exp with the maximum factored out for numerical stability.
    let log_norm = peak + exp_total.ln();

    rescale(raw_target) - log_norm
}
164
165fn sample_from_logits(logits: &[f32], temp: f32, rng: &mut impl rand::Rng) -> u32 {
169 if temp <= 0.0 || logits.is_empty() {
170 return argmax_logits(logits);
171 }
172 let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
173 let exps: Vec<f32> = logits.iter().map(|&l| ((l - max) / temp).exp()).collect();
174 let sum: f32 = exps.iter().sum::<f32>().max(f32::EPSILON);
175 let threshold: f32 = rng.random::<f32>() * sum;
176 let mut cumsum = 0.0;
177 for (i, &e) in exps.iter().enumerate() {
178 cumsum += e;
179 if cumsum >= threshold {
180 return i as u32;
181 }
182 }
183 (logits.len() - 1) as u32
184}
185
/// Index of the largest logit, or `0` for an empty slice.
///
/// When several entries compare as equal (or are incomparable), the one with
/// the highest index is returned.
fn argmax_logits(logits: &[f32]) -> u32 {
    let mut winner = 0usize;
    for (idx, value) in logits.iter().enumerate().skip(1) {
        // The incumbent survives only when it is strictly greater.
        if logits[winner].partial_cmp(value) != Some(Ordering::Greater) {
            winner = idx;
        }
    }
    winner as u32
}
194
/// A partial hypothesis tracked during beam search.
#[derive(Debug, Clone)]
struct BeamCandidate {
    /// Token ids accumulated so far, beginning with the initial token.
    tokens: Vec<u32>,
    /// Running sum of the per-step scores of the appended tokens.
    log_prob: f32,
    /// Set once the hypothesis must no longer be extended.
    finished: bool,
    /// Token count used for length normalisation.
    token_count: usize,
}

impl BeamCandidate {
    /// Starts a hypothesis containing only `token`, with a score of `0.0`.
    fn new(token: u32) -> Self {
        Self {
            tokens: vec![token],
            log_prob: 0.0,
            finished: false,
            token_count: 1,
        }
    }

    /// Accumulated score divided by `token_count ^ length_penalty`.
    ///
    /// A candidate with a zero token count returns its raw score so the
    /// division can never be by zero.
    fn normalized_score(&self, length_penalty: f32) -> f32 {
        if self.token_count == 0 {
            return self.log_prob;
        }
        self.log_prob / (self.token_count as f32).powf(length_penalty)
    }

    /// Key shared by all comparison traits: the score normalised with a fixed
    /// length penalty of `1.0` (the configured penalty is not available here).
    fn ordering_key(&self) -> f32 {
        self.normalized_score(1.0)
    }
}

impl PartialEq for BeamCandidate {
    /// Equality is derived from `cmp` so that `a == b` holds exactly when
    /// `a.cmp(&b)` is `Equal`, as the `Eq`/`Ord` contract requires.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for BeamCandidate {}

impl PartialOrd for BeamCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for BeamCandidate {
    /// Reversed ordering on the normalised score: a higher-scoring candidate
    /// compares as `Less`, so a `BinaryHeap` surfaces the lowest score first.
    ///
    /// `f32::total_cmp` keeps this a genuine total order even when a score is
    /// NaN, which `partial_cmp(..).unwrap_or(Equal)` does not.
    fn cmp(&self, other: &Self) -> Ordering {
        other.ordering_key().total_cmp(&self.ordering_key())
    }
}
253
254pub struct BeamSearchDecoder {
256 config: DecodingConfig,
257}
258
259impl BeamSearchDecoder {
260 pub fn new(config: DecodingConfig) -> Self {
262 Self { config }
263 }
264
265 pub fn decode(
277 &self,
278 token_probs: &[Vec<f32>],
279 initial_token: u32,
280 vocab_size: usize,
281 eos_token: u32,
282 _pad_token: u32,
283 ) -> Result<Vec<u32>> {
284 if token_probs.is_empty() {
285 return Ok(vec![initial_token]);
286 }
287
288 if token_probs.iter().any(|probs| probs.len() != vocab_size) {
290 return Err(anyhow!("Invalid token probabilities shape"));
291 }
292
293 let mut candidates = BinaryHeap::new();
295 candidates.push(BeamCandidate::new(initial_token));
296
297 for step in 0..token_probs.len().min(self.config.max_length) {
299 let probs = &token_probs[step];
300 let mut next_candidates = Vec::new();
301
302 for candidate in candidates.iter().take(self.config.beam_size) {
304 if candidate.finished {
305 next_candidates.push(candidate.clone());
306 continue;
307 }
308
309 let top_k = self.get_top_k_tokens(probs, self.config.beam_size);
311
312 for (token, log_prob) in top_k {
313 let mut new_candidate = candidate.clone();
314 new_candidate.tokens.push(token);
315 new_candidate.log_prob += log_prob;
316 new_candidate.token_count += 1;
317
318 if token == eos_token || step == token_probs.len() - 1 {
320 new_candidate.finished = true;
321 }
322
323 next_candidates.push(new_candidate);
324 }
325 }
326
327 next_candidates.sort_by(|a, b| {
329 b.normalized_score(self.config.length_penalty)
330 .partial_cmp(&a.normalized_score(self.config.length_penalty))
331 .unwrap_or(Ordering::Equal)
332 });
333
334 candidates = next_candidates
335 .into_iter()
336 .take(self.config.beam_size)
337 .collect::<BinaryHeap<_>>();
338
339 if candidates.iter().all(|c| c.finished) {
341 break;
342 }
343 }
344
345 candidates
347 .iter()
348 .max_by(|a, b| {
349 a.normalized_score(self.config.length_penalty)
350 .partial_cmp(&b.normalized_score(self.config.length_penalty))
351 .unwrap_or(Ordering::Equal)
352 })
353 .map(|c| c.tokens.clone())
354 .ok_or_else(|| anyhow!("No valid candidates found"))
355 }
356
357 fn get_top_k_tokens(&self, log_probs: &[f32], k: usize) -> Vec<(u32, f32)> {
359 let mut indexed: Vec<(u32, f32)> = log_probs
360 .iter()
361 .enumerate()
362 .map(|(i, &prob)| (i as u32, prob))
363 .collect();
364
365 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
366
367 indexed.into_iter().take(k).collect()
368 }
369}
370
371pub struct GreedyDecoder;
373
374impl GreedyDecoder {
375 pub fn decode(
377 token_probs: &[Vec<f32>],
378 initial_token: u32,
379 _vocab_size: usize,
380 eos_token: u32,
381 _pad_token: u32,
382 ) -> Result<Vec<u32>> {
383 let mut tokens = vec![initial_token];
384
385 for probs in token_probs {
386 if probs.is_empty() {
387 break;
388 }
389
390 let (token, _) = probs
391 .iter()
392 .enumerate()
393 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(Ordering::Equal))
394 .unwrap_or((eos_token as usize, &f32::NEG_INFINITY));
395
396 let token = token as u32;
397 tokens.push(token);
398
399 if token == eos_token {
400 break;
401 }
402 }
403
404 Ok(tokens)
405 }
406}
407
408pub struct HybridDecoder {
410 config: DecodingConfig,
411 beam_decoder: BeamSearchDecoder,
412}
413
414impl HybridDecoder {
415 pub fn new(config: DecodingConfig) -> Self {
417 Self {
418 beam_decoder: BeamSearchDecoder::new(config.clone()),
419 config,
420 }
421 }
422
423 pub fn decode(
425 &self,
426 token_probs: &[Vec<f32>],
427 initial_token: u32,
428 vocab_size: usize,
429 eos_token: u32,
430 pad_token: u32,
431 ) -> Result<Vec<u32>> {
432 match self
433 .beam_decoder
434 .decode(token_probs, initial_token, vocab_size, eos_token, pad_token)
435 {
436 Ok(tokens) if tokens.len() > 1 => Ok(tokens),
437 _ => {
438 GreedyDecoder::decode(token_probs, initial_token, vocab_size, eos_token, pad_token)
439 }
440 }
441 }
442
443 pub fn decode_with_fallback(
462 &self,
463 token_probs: &[Vec<f32>],
464 initial_token: u32,
465 vocab_size: usize,
466 eos_token: u32,
467 no_speech_token: u32,
468 decode_text: impl Fn(&[u32]) -> String,
469 ) -> Result<Vec<u32>> {
470 if token_probs.is_empty() {
471 return Ok(vec![initial_token]);
472 }
473
474 if (no_speech_token as usize) < vocab_size {
476 let ns_prob = softmax_at(&token_probs[0], no_speech_token);
477 if ns_prob > self.config.no_speech_threshold {
478 return Ok(vec![]);
479 }
480 }
481
482 let mut best: Option<Vec<u32>> = None;
483
484 for &temp in &self.config.temperatures {
485 let mut rng = rand::rngs::StdRng::seed_from_u64(42);
486 let mut tokens = vec![initial_token];
487 let mut log_probs: Vec<f32> = Vec::new();
488
489 for step_logits in token_probs.iter().take(self.config.max_length) {
490 let selected = sample_from_logits(step_logits, temp, &mut rng);
491 log_probs.push(log_softmax_at(step_logits, selected, temp));
492 tokens.push(selected);
493 if selected == eos_token {
494 break;
495 }
496 }
497
498 let avg_lp = if log_probs.is_empty() {
499 0.0
500 } else {
501 log_probs.iter().sum::<f32>() / log_probs.len() as f32
502 };
503
504 let text = decode_text(&tokens);
505 let cr = compression_ratio(&text);
506
507 let quality_ok = avg_lp > self.config.log_prob_threshold
508 && cr < self.config.compression_ratio_threshold;
509
510 if best.is_none() {
511 best = Some(tokens.clone());
512 }
513
514 if quality_ok {
515 return Ok(tokens);
516 }
517 }
518
519 best.ok_or_else(|| anyhow!("decode_with_fallback: no temperatures configured"))
520 }
521}
522
#[cfg(test)]
mod tests {
    use super::*;

    /// Temperature ladder expected from the default and accurate presets.
    fn full_ladder() -> Vec<f32> {
        vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
    }

    /// Renders token ids as a space-separated string.
    fn join_ids(ids: &[u32]) -> String {
        ids.iter()
            .map(|i| i.to_string())
            .collect::<Vec<_>>()
            .join(" ")
    }

    // ---- configuration presets and builders ----

    #[test]
    fn test_decoding_config_defaults() {
        let defaults = DecodingConfig::default();
        assert_eq!(defaults.beam_size, 5);
        assert_eq!(defaults.temperatures, full_ladder());
        assert_eq!(defaults.language, "en");
    }

    #[test]
    fn test_decoding_config_fast() {
        let preset = DecodingConfig::fast();
        assert_eq!(preset.beam_size, 1);
        assert_eq!(preset.temperatures, vec![0.0]);
    }

    #[test]
    fn test_decoding_config_accurate() {
        let preset = DecodingConfig::accurate();
        assert_eq!(preset.beam_size, 10);
        assert_eq!(preset.temperatures, full_ladder());
    }

    #[test]
    fn test_with_temperature_overrides_sequence() {
        let overridden = DecodingConfig::default().with_temperature(0.7);
        assert_eq!(overridden.temperatures, vec![0.7]);
    }

    // ---- beam candidates ----

    #[test]
    fn test_beam_candidate_scoring() {
        // Two tokens at -4.0 total: -2.0 per token.
        let mut longer = BeamCandidate::new(1);
        longer.log_prob = -4.0;
        longer.token_count = 2;

        // One token at -1.0 total: -1.0 per token.
        let mut shorter = BeamCandidate::new(2);
        shorter.log_prob = -1.0;
        shorter.token_count = 1;

        assert!(shorter.normalized_score(1.0) > longer.normalized_score(1.0));
    }

    // ---- decoders ----

    #[test]
    fn test_greedy_decoder() -> Result<()> {
        // Arg-max per step is 2, 2, then 0 (the end-of-sequence token).
        let step_scores = vec![
            vec![-10.0, -5.0, -0.5, -10.0],
            vec![-5.0, -10.0, -0.1, -10.0],
            vec![-0.5, -5.0, -10.0, -10.0],
        ];

        let decoded = GreedyDecoder::decode(&step_scores, 50256, 4, 0, 50257)?;

        assert_eq!(decoded.len(), 4);
        assert_eq!(decoded, vec![50256, 2, 2, 0]);
        Ok(())
    }

    #[test]
    fn test_beam_search_decoder() -> Result<()> {
        let decoder = BeamSearchDecoder::new(DecodingConfig {
            beam_size: 2,
            ..Default::default()
        });

        let step_scores = vec![vec![-5.0, -0.5, -10.0], vec![-0.1, -5.0, -10.0]];

        let decoded = decoder.decode(&step_scores, 100, 3, 0, 99)?;

        assert!(decoded.len() >= 2);
        assert_eq!(decoded[0], 100);
        Ok(())
    }

    #[test]
    fn test_hybrid_decoder_fallback() -> Result<()> {
        let decoder = HybridDecoder::new(DecodingConfig::default());
        let step_scores = vec![vec![-0.5, -10.0, -10.0]];

        let decoded = decoder.decode(&step_scores, 100, 3, 0, 99)?;

        assert!(!decoded.is_empty());
        assert_eq!(decoded[0], 100);

        Ok(())
    }

    // ---- compression ratio ----

    #[test]
    fn test_compression_ratio_normal_text() {
        let cr = compression_ratio("The quick brown fox jumps over the lazy dog.");
        assert!(cr < 2.4, "normal text compression ratio was {cr}");
    }

    #[test]
    fn test_compression_ratio_repetitive_text() {
        let repeated = "the quick brown fox ".repeat(100);
        let cr = compression_ratio(&repeated);
        assert!(cr > 2.4, "repetitive text compression ratio was {cr}");
    }

    #[test]
    fn test_compression_ratio_empty() {
        assert_eq!(compression_ratio(""), 0.0);
    }

    // ---- softmax helper ----

    #[test]
    fn test_softmax_at_picks_max() {
        let logits = vec![-10.0, -0.1, -5.0];

        let highest = softmax_at(&logits, 1);
        let lowest = softmax_at(&logits, 0);
        assert!(highest > lowest, "softmax of max logit should be highest");

        let mut total: f32 = 0.0;
        for token in 0..3 {
            total += softmax_at(&logits, token);
        }
        assert!((total - 1.0).abs() < 1e-4, "softmax probs must sum to 1");
    }

    // ---- temperature fallback ----

    #[test]
    fn test_decode_with_fallback_passes_quality() -> Result<()> {
        // Thresholds are set so the first attempt is always accepted.
        let decoder = HybridDecoder::new(DecodingConfig {
            temperatures: vec![0.0],
            log_prob_threshold: -100.0,
            compression_ratio_threshold: 100.0,
            no_speech_threshold: 1.0,
            max_length: 5,
            ..Default::default()
        });

        let step_scores = vec![vec![-0.01, -10.0, -10.0], vec![-0.01, -10.0, -10.0]];

        let decoded = decoder.decode_with_fallback(&step_scores, 99, 3, 0, 2, join_ids)?;

        assert!(!decoded.is_empty());
        assert_eq!(decoded[0], 99);
        Ok(())
    }

    #[test]
    fn test_decode_with_fallback_no_speech() -> Result<()> {
        let decoder = HybridDecoder::new(DecodingConfig {
            temperatures: vec![0.0],
            no_speech_threshold: 0.5,
            log_prob_threshold: -100.0,
            compression_ratio_threshold: 100.0,
            max_length: 5,
            ..Default::default()
        });

        // Token 1 (the no-speech token) dominates the first step.
        let step_scores = vec![vec![-10.0, 100.0, -10.0]];

        let decoded = decoder.decode_with_fallback(&step_scores, 99, 3, 0, 1, |_| String::new())?;

        assert!(
            decoded.is_empty(),
            "should return empty when no-speech detected"
        );
        Ok(())
    }
}