Skip to main content

whisper_cpp_plus/enhanced/
fallback.rs

1//! Temperature fallback mechanism for improved transcription quality
2//!
3//! This module implements quality-based retry logic inspired by faster-whisper
4
5use crate::{FullParams, Result, Segment, TranscriptionResult, WhisperError, WhisperState};
6use flate2::write::ZlibEncoder;
7use flate2::Compression;
8use std::io::Write;
9use whisper_cpp_plus_sys as ffi;
10
/// Limits used to decide whether a transcription attempt is good enough.
///
/// Every limit is optional; a `None` field disables the corresponding check.
#[derive(Debug, Clone)]
pub struct QualityThresholds {
    /// Upper bound on the text compression ratio (default: 2.4)
    pub compression_ratio_threshold: Option<f32>,
    /// Lower bound on the average log probability (default: -1.0)
    pub log_prob_threshold: Option<f32>,
    /// Upper bound on the no-speech probability (default: 0.6)
    pub no_speech_threshold: Option<f32>,
}

impl Default for QualityThresholds {
    /// Builds the thresholds with all three checks enabled at their default limits.
    fn default() -> Self {
        const MAX_COMPRESSION_RATIO: f32 = 2.4;
        const MIN_AVG_LOG_PROB: f32 = -1.0;
        const MAX_NO_SPEECH_PROB: f32 = 0.6;

        QualityThresholds {
            compression_ratio_threshold: Some(MAX_COMPRESSION_RATIO),
            log_prob_threshold: Some(MIN_AVG_LOG_PROB),
            no_speech_threshold: Some(MAX_NO_SPEECH_PROB),
        }
    }
}
31
32/// Enhanced transcription parameters with fallback support
33#[derive(Clone)]
34pub struct EnhancedTranscriptionParams {
35    /// Base parameters
36    pub base: FullParams,
37    /// Temperature sequence for fallback
38    pub temperatures: Vec<f32>,
39    /// Quality thresholds
40    pub thresholds: QualityThresholds,
41    /// Whether to reset prompt on temperature increase
42    pub prompt_reset_on_temperature: f32,
43}
44
45impl EnhancedTranscriptionParams {
46    /// Create from base params with default enhancement settings
47    pub fn from_base(base: FullParams) -> Self {
48        Self {
49            base,
50            temperatures: vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
51            thresholds: QualityThresholds::default(),
52            prompt_reset_on_temperature: 0.5,
53        }
54    }
55
56    pub fn builder() -> EnhancedTranscriptionParamsBuilder {
57        EnhancedTranscriptionParamsBuilder::new()
58    }
59}
60
61pub struct EnhancedTranscriptionParamsBuilder {
62    params: EnhancedTranscriptionParams,
63}
64
65impl EnhancedTranscriptionParamsBuilder {
66    pub fn new() -> Self {
67        Self {
68            params: EnhancedTranscriptionParams::from_base(FullParams::default()),
69        }
70    }
71
72    pub fn base_params(mut self, params: FullParams) -> Self {
73        self.params.base = params;
74        self
75    }
76
77    pub fn language(mut self, lang: &str) -> Self {
78        self.params.base = self.params.base.language(lang);
79        self
80    }
81
82    pub fn temperatures(mut self, temps: Vec<f32>) -> Self {
83        self.params.temperatures = temps;
84        self
85    }
86
87    pub fn compression_ratio_threshold(mut self, threshold: Option<f32>) -> Self {
88        self.params.thresholds.compression_ratio_threshold = threshold;
89        self
90    }
91
92    pub fn log_prob_threshold(mut self, threshold: Option<f32>) -> Self {
93        self.params.thresholds.log_prob_threshold = threshold;
94        self
95    }
96
97    pub fn build(self) -> EnhancedTranscriptionParams {
98        self.params
99    }
100}
101
102impl Default for EnhancedTranscriptionParamsBuilder {
103    fn default() -> Self {
104        Self::new()
105    }
106}
107
108/// Calculate compression ratio for text using zlib
109pub fn calculate_compression_ratio(text: &str) -> f32 {
110    let text_bytes = text.as_bytes();
111    let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
112    encoder.write_all(text_bytes).unwrap();
113    let compressed = encoder.finish().unwrap();
114
115    text_bytes.len() as f32 / compressed.len() as f32
116}
117
118/// Result of a single transcription attempt
119#[derive(Debug)]
120pub struct TranscriptionAttempt {
121    pub text: String,
122    pub segments: Vec<Segment>,
123    pub temperature: f32,
124    pub compression_ratio: f32,
125    pub avg_logprob: f32,
126    pub no_speech_prob: f32,
127}
128
129impl TranscriptionAttempt {
130    /// Check if this attempt meets quality thresholds
131    pub fn meets_thresholds(&self, thresholds: &QualityThresholds) -> bool {
132        let mut meets = true;
133
134        if let Some(cr_threshold) = thresholds.compression_ratio_threshold {
135            if self.compression_ratio > cr_threshold {
136                meets = false;
137            }
138        }
139
140        if let Some(lp_threshold) = thresholds.log_prob_threshold {
141            if self.avg_logprob < lp_threshold {
142                // Check for silence exception
143                if let Some(ns_threshold) = thresholds.no_speech_threshold {
144                    if self.no_speech_prob <= ns_threshold {
145                        meets = false;
146                    }
147                } else {
148                    meets = false;
149                }
150            }
151        }
152
153        meets
154    }
155}
156
157/// Enhanced state with fallback support
158pub struct EnhancedWhisperState<'a> {
159    state: &'a mut WhisperState,
160}
161
162impl<'a> EnhancedWhisperState<'a> {
163    pub fn new(state: &'a mut WhisperState) -> Self {
164        Self { state }
165    }
166
167    /// Get no-speech probability for a segment (enhanced feature)
168    fn get_no_speech_prob(&self, segment_idx: i32) -> f32 {
169        unsafe {
170            // Direct FFI call using the exposed ptr
171            ffi::whisper_full_get_segment_no_speech_prob_from_state(self.state.ptr, segment_idx)
172        }
173    }
174
175    /// Calculate average log probability from token probabilities
176    fn calculate_avg_logprob(&self, segment_idx: i32) -> f32 {
177        let n_tokens = self.state.full_n_tokens(segment_idx);
178        if n_tokens == 0 {
179            return 0.0;
180        }
181
182        let mut sum_logprob = 0.0;
183        for i in 0..n_tokens {
184            let prob = self.state.full_get_token_prob(segment_idx, i);
185            if prob > 0.0 {
186                sum_logprob += prob.ln();
187            }
188        }
189
190        sum_logprob / n_tokens as f32
191    }
192
193    /// Transcribe with temperature fallback
194    pub fn transcribe_with_fallback(
195        &mut self,
196        params: EnhancedTranscriptionParams,
197        audio: &[f32],
198    ) -> Result<TranscriptionResult> {
199        let mut all_attempts = Vec::new();
200        let mut below_cr_attempts = Vec::new();
201
202        for temperature in &params.temperatures {
203            // Update temperature in params
204            let mut current_params = params.base.clone();
205            current_params = current_params.temperature(*temperature);
206
207            // Reset prompt if temperature is high
208            if *temperature > params.prompt_reset_on_temperature {
209                current_params = current_params.initial_prompt("");
210            }
211
212            // Attempt transcription
213            self.state.full(current_params, audio)?;
214
215            // Extract results
216            let n_segments = self.state.full_n_segments();
217            let mut segments = Vec::new();
218            let mut text = String::new();
219            let mut total_logprob = 0.0;
220            let mut total_tokens = 0;
221
222            for i in 0..n_segments {
223                let segment_text = self.state.full_get_segment_text(i)?;
224                let (start_ms, end_ms) = self.state.full_get_segment_timestamps(i);
225                let speaker_turn_next = self.state.full_get_segment_speaker_turn_next(i);
226
227                if i > 0 {
228                    text.push(' ');
229                }
230                text.push_str(&segment_text);
231
232                segments.push(Segment {
233                    start_ms,
234                    end_ms,
235                    text: segment_text,
236                    speaker_turn_next,
237                });
238
239                // Calculate average log probability
240                let avg_lp = self.calculate_avg_logprob(i);
241                let n_tokens = self.state.full_n_tokens(i);
242                total_logprob += avg_lp * n_tokens as f32;
243                total_tokens += n_tokens;
244            }
245
246            let avg_logprob = if total_tokens > 0 {
247                total_logprob / total_tokens as f32
248            } else {
249                0.0
250            };
251
252            // Calculate quality metrics
253            let compression_ratio = calculate_compression_ratio(&text);
254            let no_speech_prob = if n_segments > 0 {
255                self.get_no_speech_prob(0)
256            } else {
257                0.0
258            };
259
260            let attempt = TranscriptionAttempt {
261                text: text.clone(),
262                segments: segments.clone(),
263                temperature: *temperature,
264                compression_ratio,
265                avg_logprob,
266                no_speech_prob,
267            };
268
269            // Check if attempt meets thresholds
270            if attempt.meets_thresholds(&params.thresholds) {
271                return Ok(TranscriptionResult {
272                    text: attempt.text,
273                    segments: attempt.segments,
274                });
275            }
276
277            // Store attempt for potential fallback selection
278            if let Some(cr_threshold) = params.thresholds.compression_ratio_threshold {
279                if attempt.compression_ratio <= cr_threshold {
280                    below_cr_attempts.push(attempt);
281                } else {
282                    all_attempts.push(attempt);
283                }
284            } else {
285                all_attempts.push(attempt);
286            }
287        }
288
289        // All temperatures failed, select best attempt
290        let best_attempt = if !below_cr_attempts.is_empty() {
291            below_cr_attempts
292                .into_iter()
293                .max_by(|a, b| a.avg_logprob.partial_cmp(&b.avg_logprob).unwrap())
294        } else {
295            all_attempts
296                .into_iter()
297                .max_by(|a, b| a.avg_logprob.partial_cmp(&b.avg_logprob).unwrap())
298        };
299
300        best_attempt
301            .map(|a| TranscriptionResult {
302                text: a.text,
303                segments: a.segments,
304            })
305            .ok_or_else(|| {
306                WhisperError::TranscriptionError(
307                    "Failed to produce acceptable transcription with any temperature".into(),
308                )
309            })
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an attempt with the given metrics and placeholder content.
    fn attempt_with(
        compression_ratio: f32,
        avg_logprob: f32,
        no_speech_prob: f32,
    ) -> TranscriptionAttempt {
        TranscriptionAttempt {
            text: "Hello world".to_string(),
            segments: vec![],
            temperature: 0.0,
            compression_ratio,
            avg_logprob,
            no_speech_prob,
        }
    }

    #[test]
    fn test_compression_ratio_calculation() {
        // Short text might not compress well due to compression overhead
        let text = "The quick brown fox jumps over the lazy dog";
        let ratio = calculate_compression_ratio(text);
        assert!(ratio > 0.0); // Ratio should be positive

        // Longer text should compress better
        let longer_text = "The quick brown fox jumps over the lazy dog. ".repeat(10);
        let longer_ratio = calculate_compression_ratio(&longer_text);
        assert!(longer_ratio > 1.0); // Should achieve compression

        // Highly repetitive text should compress very well
        let repetitive = "a".repeat(1000);
        let repetitive_ratio = calculate_compression_ratio(&repetitive);
        assert!(repetitive_ratio > 5.0); // Highly compressible
    }

    #[test]
    fn test_compression_ratio_of_empty_text() {
        // The zlib stream is never empty, so empty text yields 0.0 rather than NaN
        assert_eq!(calculate_compression_ratio(""), 0.0);
    }

    #[test]
    fn test_quality_threshold_checking() {
        let thresholds = QualityThresholds {
            compression_ratio_threshold: Some(2.4),
            log_prob_threshold: Some(-1.0),
            no_speech_threshold: Some(0.6),
        };

        let good_attempt = TranscriptionAttempt {
            text: "Hello world".to_string(),
            segments: vec![],
            temperature: 0.0,
            compression_ratio: 1.5,
            avg_logprob: -0.5,
            no_speech_prob: 0.1,
        };

        assert!(good_attempt.meets_thresholds(&thresholds));

        let bad_attempt = TranscriptionAttempt {
            text: "a".repeat(100),
            segments: vec![],
            temperature: 0.0,
            compression_ratio: 10.0, // Too repetitive
            avg_logprob: -0.5,
            no_speech_prob: 0.1,
        };

        assert!(!bad_attempt.meets_thresholds(&thresholds));
    }

    #[test]
    fn test_log_prob_threshold_and_silence_exception() {
        let thresholds = QualityThresholds::default();

        // Low confidence with speech likely present is rejected
        assert!(!attempt_with(1.5, -1.5, 0.1).meets_thresholds(&thresholds));

        // Low confidence is excused when the no-speech probability is high
        assert!(attempt_with(1.5, -1.5, 0.9).meets_thresholds(&thresholds));

        // The silence exception does not excuse a repetitive transcript
        assert!(!attempt_with(10.0, -1.5, 0.9).meets_thresholds(&thresholds));

        // Without a no-speech threshold, low confidence is always rejected
        let no_silence_rule = QualityThresholds {
            no_speech_threshold: None,
            ..QualityThresholds::default()
        };
        assert!(!attempt_with(1.5, -1.5, 0.9).meets_thresholds(&no_silence_rule));
    }

    #[test]
    fn test_disabled_thresholds_accept_everything() {
        let thresholds = QualityThresholds {
            compression_ratio_threshold: None,
            log_prob_threshold: None,
            no_speech_threshold: None,
        };

        assert!(attempt_with(10.0, -5.0, 0.0).meets_thresholds(&thresholds));
    }

    #[test]
    fn test_enhanced_params_from_base() {
        let base = FullParams::default().language("en");

        let enhanced = EnhancedTranscriptionParams::from_base(base);

        assert_eq!(enhanced.temperatures.len(), 6);
        assert_eq!(enhanced.temperatures[0], 0.0);
        assert_eq!(enhanced.prompt_reset_on_temperature, 0.5);
        assert!(enhanced.thresholds.compression_ratio_threshold.is_some());
    }

    #[test]
    fn test_enhanced_transcription_params_builder() {
        let params = EnhancedTranscriptionParamsBuilder::new()
            .language("en")
            .temperatures(vec![0.0, 0.5, 1.0])
            .compression_ratio_threshold(Some(3.0))
            .build();

        assert_eq!(params.temperatures.len(), 3);
        assert_eq!(params.thresholds.compression_ratio_threshold, Some(3.0));
    }
}