Skip to main content

whisper_cpp_plus/enhanced/
fallback.rs

1//! Temperature fallback mechanism for improved transcription quality
2//!
3//! This module implements quality-based retry logic inspired by faster-whisper
4
5use crate::{WhisperState, FullParams, Result, WhisperError, Segment, TranscriptionResult};
6use std::io::Write;
7use flate2::Compression;
8use flate2::write::ZlibEncoder;
9use whisper_cpp_plus_sys as ffi;
10
/// Acceptance limits that a transcription attempt is validated against.
///
/// Each limit is optional; a `None` value switches the corresponding check
/// off in [`TranscriptionAttempt::meets_thresholds`].
#[derive(Debug, Clone)]
pub struct QualityThresholds {
    /// Upper bound on the zlib compression ratio of the text (default: 2.4).
    /// Attempts whose ratio is strictly greater than this are rejected.
    pub compression_ratio_threshold: Option<f32>,
    /// Lower bound on the average log probability (default: -1.0).
    /// Attempts strictly below this are rejected unless the silence
    /// exception controlled by `no_speech_threshold` applies.
    pub log_prob_threshold: Option<f32>,
    /// No-speech probability above which a low average log probability is
    /// tolerated (default: 0.6). When `None`, a low log probability always
    /// rejects the attempt.
    pub no_speech_threshold: Option<f32>,
}

impl Default for QualityThresholds {
    /// Build the default limits: ratio 2.4, log probability -1.0, no-speech 0.6.
    fn default() -> Self {
        let compression_ratio_threshold = Some(2.4);
        let log_prob_threshold = Some(-1.0);
        let no_speech_threshold = Some(0.6);

        QualityThresholds {
            compression_ratio_threshold,
            log_prob_threshold,
            no_speech_threshold,
        }
    }
}
31
/// Enhanced transcription parameters with fallback support
///
/// Bundles the base decoding parameters with the temperature schedule and
/// quality thresholds used by `EnhancedWhisperState::transcribe_with_fallback`.
#[derive(Clone)]
pub struct EnhancedTranscriptionParams {
    /// Base parameters; cloned for every attempt, with the temperature overridden
    pub base: FullParams,
    /// Temperature sequence for fallback, tried in the order given
    pub temperatures: Vec<f32>,
    /// Quality thresholds an attempt must meet to be accepted immediately
    pub thresholds: QualityThresholds,
    /// Temperature threshold for resetting the prompt: attempts whose
    /// temperature is strictly greater than this value run with an empty
    /// initial prompt
    pub prompt_reset_on_temperature: f32,
}
44
45impl EnhancedTranscriptionParams {
46    /// Create from base params with default enhancement settings
47    pub fn from_base(base: FullParams) -> Self {
48        Self {
49            base,
50            temperatures: vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
51            thresholds: QualityThresholds::default(),
52            prompt_reset_on_temperature: 0.5,
53        }
54    }
55
56    pub fn builder() -> EnhancedTranscriptionParamsBuilder {
57        EnhancedTranscriptionParamsBuilder::new()
58    }
59}
60
/// Builder for [`EnhancedTranscriptionParams`].
///
/// Each setter consumes the builder and returns it, so calls can be chained
/// and finished with `build`.
pub struct EnhancedTranscriptionParamsBuilder {
    // Parameters under construction; handed back unchanged by `build`.
    params: EnhancedTranscriptionParams,
}
64
65impl EnhancedTranscriptionParamsBuilder {
66    pub fn new() -> Self {
67        Self {
68            params: EnhancedTranscriptionParams::from_base(FullParams::default()),
69        }
70    }
71
72    pub fn base_params(mut self, params: FullParams) -> Self {
73        self.params.base = params;
74        self
75    }
76
77    pub fn language(mut self, lang: &str) -> Self {
78        self.params.base = self.params.base.language(lang);
79        self
80    }
81
82    pub fn temperatures(mut self, temps: Vec<f32>) -> Self {
83        self.params.temperatures = temps;
84        self
85    }
86
87    pub fn compression_ratio_threshold(mut self, threshold: Option<f32>) -> Self {
88        self.params.thresholds.compression_ratio_threshold = threshold;
89        self
90    }
91
92    pub fn log_prob_threshold(mut self, threshold: Option<f32>) -> Self {
93        self.params.thresholds.log_prob_threshold = threshold;
94        self
95    }
96
97    pub fn build(self) -> EnhancedTranscriptionParams {
98        self.params
99    }
100}
101
102impl Default for EnhancedTranscriptionParamsBuilder {
103    fn default() -> Self {
104        Self::new()
105    }
106}
107
108/// Calculate compression ratio for text using zlib
109pub fn calculate_compression_ratio(text: &str) -> f32 {
110    let text_bytes = text.as_bytes();
111    let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
112    encoder.write_all(text_bytes).unwrap();
113    let compressed = encoder.finish().unwrap();
114
115    text_bytes.len() as f32 / compressed.len() as f32
116}
117
/// Result of a single transcription attempt
///
/// The field descriptions below reflect how `transcribe_with_fallback`
/// populates this struct.
#[derive(Debug)]
pub struct TranscriptionAttempt {
    /// Full text: the segment texts joined with single spaces
    pub text: String,
    /// Segments produced by this attempt, in order
    pub segments: Vec<Segment>,
    /// Temperature the attempt was run with
    pub temperature: f32,
    /// Compression ratio of `text` (see `calculate_compression_ratio`)
    pub compression_ratio: f32,
    /// Token-weighted average log probability over all segments (0.0 when there are no tokens)
    pub avg_logprob: f32,
    /// No-speech probability of the first segment (0.0 when there are no segments)
    pub no_speech_prob: f32,
}
128
129impl TranscriptionAttempt {
130    /// Check if this attempt meets quality thresholds
131    pub fn meets_thresholds(&self, thresholds: &QualityThresholds) -> bool {
132        let mut meets = true;
133
134        if let Some(cr_threshold) = thresholds.compression_ratio_threshold {
135            if self.compression_ratio > cr_threshold {
136                meets = false;
137            }
138        }
139
140        if let Some(lp_threshold) = thresholds.log_prob_threshold {
141            if self.avg_logprob < lp_threshold {
142                // Check for silence exception
143                if let Some(ns_threshold) = thresholds.no_speech_threshold {
144                    if self.no_speech_prob <= ns_threshold {
145                        meets = false;
146                    }
147                } else {
148                    meets = false;
149                }
150            }
151        }
152
153        meets
154    }
155}
156
/// Enhanced state with fallback support
///
/// Holds a mutable borrow of a `WhisperState` for the lifetime `'a` and runs
/// transcriptions on it through `transcribe_with_fallback`.
pub struct EnhancedWhisperState<'a> {
    // Borrowed state that every attempt is executed on.
    state: &'a mut WhisperState,
}
161
162impl<'a> EnhancedWhisperState<'a> {
163    pub fn new(state: &'a mut WhisperState) -> Self {
164        Self { state }
165    }
166
167    /// Get no-speech probability for a segment (enhanced feature)
168    fn get_no_speech_prob(&self, segment_idx: i32) -> f32 {
169        unsafe {
170            // Direct FFI call using the exposed ptr
171            ffi::whisper_full_get_segment_no_speech_prob_from_state(
172                self.state.ptr,
173                segment_idx
174            )
175        }
176    }
177
178    /// Calculate average log probability from token probabilities
179    fn calculate_avg_logprob(&self, segment_idx: i32) -> f32 {
180        let n_tokens = self.state.full_n_tokens(segment_idx);
181        if n_tokens == 0 {
182            return 0.0;
183        }
184
185        let mut sum_logprob = 0.0;
186        for i in 0..n_tokens {
187            let prob = self.state.full_get_token_prob(segment_idx, i);
188            if prob > 0.0 {
189                sum_logprob += prob.ln();
190            }
191        }
192
193        sum_logprob / n_tokens as f32
194    }
195
196    /// Transcribe with temperature fallback
197    pub fn transcribe_with_fallback(
198        &mut self,
199        params: EnhancedTranscriptionParams,
200        audio: &[f32],
201    ) -> Result<TranscriptionResult> {
202        let mut all_attempts = Vec::new();
203        let mut below_cr_attempts = Vec::new();
204
205        for temperature in &params.temperatures {
206            // Update temperature in params
207            let mut current_params = params.base.clone();
208            current_params = current_params.temperature(*temperature);
209
210            // Reset prompt if temperature is high
211            if *temperature > params.prompt_reset_on_temperature {
212                current_params = current_params.initial_prompt("");
213            }
214
215            // Attempt transcription
216            self.state.full(current_params, audio)?;
217
218            // Extract results
219            let n_segments = self.state.full_n_segments();
220            let mut segments = Vec::new();
221            let mut text = String::new();
222            let mut total_logprob = 0.0;
223            let mut total_tokens = 0;
224
225            for i in 0..n_segments {
226                let segment_text = self.state.full_get_segment_text(i)?;
227                let (start_ms, end_ms) = self.state.full_get_segment_timestamps(i);
228                let speaker_turn_next = self.state.full_get_segment_speaker_turn_next(i);
229
230                if i > 0 {
231                    text.push(' ');
232                }
233                text.push_str(&segment_text);
234
235                segments.push(Segment {
236                    start_ms,
237                    end_ms,
238                    text: segment_text,
239                    speaker_turn_next,
240                });
241
242                // Calculate average log probability
243                let avg_lp = self.calculate_avg_logprob(i);
244                let n_tokens = self.state.full_n_tokens(i);
245                total_logprob += avg_lp * n_tokens as f32;
246                total_tokens += n_tokens;
247            }
248
249            let avg_logprob = if total_tokens > 0 {
250                total_logprob / total_tokens as f32
251            } else {
252                0.0
253            };
254
255            // Calculate quality metrics
256            let compression_ratio = calculate_compression_ratio(&text);
257            let no_speech_prob = if n_segments > 0 {
258                self.get_no_speech_prob(0)
259            } else {
260                0.0
261            };
262
263            let attempt = TranscriptionAttempt {
264                text: text.clone(),
265                segments: segments.clone(),
266                temperature: *temperature,
267                compression_ratio,
268                avg_logprob,
269                no_speech_prob,
270            };
271
272            // Check if attempt meets thresholds
273            if attempt.meets_thresholds(&params.thresholds) {
274                return Ok(TranscriptionResult {
275                    text: attempt.text,
276                    segments: attempt.segments,
277                });
278            }
279
280            // Store attempt for potential fallback selection
281            if let Some(cr_threshold) = params.thresholds.compression_ratio_threshold {
282                if attempt.compression_ratio <= cr_threshold {
283                    below_cr_attempts.push(attempt);
284                } else {
285                    all_attempts.push(attempt);
286                }
287            } else {
288                all_attempts.push(attempt);
289            }
290        }
291
292        // All temperatures failed, select best attempt
293        let best_attempt = if !below_cr_attempts.is_empty() {
294            below_cr_attempts.into_iter()
295                .max_by(|a, b| a.avg_logprob.partial_cmp(&b.avg_logprob).unwrap())
296        } else {
297            all_attempts.into_iter()
298                .max_by(|a, b| a.avg_logprob.partial_cmp(&b.avg_logprob).unwrap())
299        };
300
301        best_attempt
302            .map(|a| TranscriptionResult {
303                text: a.text,
304                segments: a.segments,
305            })
306            .ok_or_else(|| WhisperError::TranscriptionError(
307                "Failed to produce acceptable transcription with any temperature".into()
308            ))
309    }
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an attempt carrying only the metrics that threshold checks read.
    fn attempt_with_metrics(
        compression_ratio: f32,
        avg_logprob: f32,
        no_speech_prob: f32,
    ) -> TranscriptionAttempt {
        TranscriptionAttempt {
            text: String::new(),
            segments: vec![],
            temperature: 0.0,
            compression_ratio,
            avg_logprob,
            no_speech_prob,
        }
    }

    #[test]
    fn test_compression_ratio_calculation() {
        // Short text might not compress well due to compression overhead
        let text = "The quick brown fox jumps over the lazy dog";
        let ratio = calculate_compression_ratio(text);
        assert!(ratio > 0.0); // Ratio should be positive

        // Longer text should compress better
        let longer_text = "The quick brown fox jumps over the lazy dog. ".repeat(10);
        let longer_ratio = calculate_compression_ratio(&longer_text);
        assert!(longer_ratio > 1.0); // Should achieve compression

        // Highly repetitive text should compress very well
        let repetitive = "a".repeat(1000);
        let repetitive_ratio = calculate_compression_ratio(&repetitive);
        assert!(repetitive_ratio > 5.0); // Highly compressible
    }

    #[test]
    fn test_quality_threshold_checking() {
        let thresholds = QualityThresholds {
            compression_ratio_threshold: Some(2.4),
            log_prob_threshold: Some(-1.0),
            no_speech_threshold: Some(0.6),
        };

        let good_attempt = TranscriptionAttempt {
            text: "Hello world".to_string(),
            segments: vec![],
            temperature: 0.0,
            compression_ratio: 1.5,
            avg_logprob: -0.5,
            no_speech_prob: 0.1,
        };

        assert!(good_attempt.meets_thresholds(&thresholds));

        let bad_attempt = TranscriptionAttempt {
            text: "a".repeat(100),
            segments: vec![],
            temperature: 0.0,
            compression_ratio: 10.0, // Too repetitive
            avg_logprob: -0.5,
            no_speech_prob: 0.1,
        };

        assert!(!bad_attempt.meets_thresholds(&thresholds));
    }

    #[test]
    fn test_low_logprob_and_silence_exception() {
        let thresholds = QualityThresholds::default();

        // Low confidence on audio that is not silent: rejected.
        assert!(!attempt_with_metrics(1.5, -1.5, 0.1).meets_thresholds(&thresholds));

        // Low confidence on audio judged silent: tolerated.
        assert!(attempt_with_metrics(1.5, -1.5, 0.9).meets_thresholds(&thresholds));

        // Without a no-speech threshold the exception never applies.
        let no_silence_exception = QualityThresholds {
            no_speech_threshold: None,
            ..QualityThresholds::default()
        };
        assert!(!attempt_with_metrics(1.5, -1.5, 0.9).meets_thresholds(&no_silence_exception));
    }

    #[test]
    fn test_disabled_thresholds_accept_everything() {
        let disabled = QualityThresholds {
            compression_ratio_threshold: None,
            log_prob_threshold: None,
            no_speech_threshold: None,
        };

        assert!(attempt_with_metrics(10.0, -5.0, 0.0).meets_thresholds(&disabled));
    }

    #[test]
    fn test_enhanced_params_from_base() {
        let base = FullParams::default()
            .language("en");

        let enhanced = EnhancedTranscriptionParams::from_base(base);

        assert_eq!(enhanced.temperatures.len(), 6);
        assert_eq!(enhanced.temperatures[0], 0.0);
        assert_eq!(enhanced.prompt_reset_on_temperature, 0.5);
        assert!(enhanced.thresholds.compression_ratio_threshold.is_some());
    }

    #[test]
    fn test_enhanced_transcription_params_builder() {
        let params = EnhancedTranscriptionParamsBuilder::new()
            .language("en")
            .temperatures(vec![0.0, 0.5, 1.0])
            .compression_ratio_threshold(Some(3.0))
            .build();

        assert_eq!(params.temperatures.len(), 3);
        assert_eq!(params.thresholds.compression_ratio_threshold, Some(3.0));
    }
}