whisper_cpp_plus/enhanced/
fallback.rs1use crate::{FullParams, Result, Segment, TranscriptionResult, WhisperError, WhisperState};
6use flate2::write::ZlibEncoder;
7use flate2::Compression;
8use std::io::Write;
9use whisper_cpp_plus_sys as ffi;
10
/// Quality gates used to judge a single decoding attempt.
///
/// Every gate is optional: a `None` gate is never applied.
#[derive(Debug, Clone)]
pub struct QualityThresholds {
    /// Attempts whose compression ratio is strictly above this value are rejected.
    pub compression_ratio_threshold: Option<f32>,
    /// Attempts whose average log-probability is strictly below this value are
    /// rejected, unless the no-speech gate excuses them.
    pub log_prob_threshold: Option<f32>,
    /// A low log-probability is tolerated when the attempt's no-speech
    /// probability is above this value.
    pub no_speech_threshold: Option<f32>,
}

impl Default for QualityThresholds {
    /// Enables all three gates: 2.4 (compression ratio), -1.0 (log-probability)
    /// and 0.6 (no-speech).
    fn default() -> Self {
        let (compression_ratio, log_prob, no_speech) = (2.4, -1.0, 0.6);

        QualityThresholds {
            compression_ratio_threshold: Some(compression_ratio),
            log_prob_threshold: Some(log_prob),
            no_speech_threshold: Some(no_speech),
        }
    }
}
31
32#[derive(Clone)]
34pub struct EnhancedTranscriptionParams {
35 pub base: FullParams,
37 pub temperatures: Vec<f32>,
39 pub thresholds: QualityThresholds,
41 pub prompt_reset_on_temperature: f32,
43}
44
45impl EnhancedTranscriptionParams {
46 pub fn from_base(base: FullParams) -> Self {
48 Self {
49 base,
50 temperatures: vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
51 thresholds: QualityThresholds::default(),
52 prompt_reset_on_temperature: 0.5,
53 }
54 }
55
56 pub fn builder() -> EnhancedTranscriptionParamsBuilder {
57 EnhancedTranscriptionParamsBuilder::new()
58 }
59}
60
61pub struct EnhancedTranscriptionParamsBuilder {
62 params: EnhancedTranscriptionParams,
63}
64
65impl EnhancedTranscriptionParamsBuilder {
66 pub fn new() -> Self {
67 Self {
68 params: EnhancedTranscriptionParams::from_base(FullParams::default()),
69 }
70 }
71
72 pub fn base_params(mut self, params: FullParams) -> Self {
73 self.params.base = params;
74 self
75 }
76
77 pub fn language(mut self, lang: &str) -> Self {
78 self.params.base = self.params.base.language(lang);
79 self
80 }
81
82 pub fn temperatures(mut self, temps: Vec<f32>) -> Self {
83 self.params.temperatures = temps;
84 self
85 }
86
87 pub fn compression_ratio_threshold(mut self, threshold: Option<f32>) -> Self {
88 self.params.thresholds.compression_ratio_threshold = threshold;
89 self
90 }
91
92 pub fn log_prob_threshold(mut self, threshold: Option<f32>) -> Self {
93 self.params.thresholds.log_prob_threshold = threshold;
94 self
95 }
96
97 pub fn build(self) -> EnhancedTranscriptionParams {
98 self.params
99 }
100}
101
102impl Default for EnhancedTranscriptionParamsBuilder {
103 fn default() -> Self {
104 Self::new()
105 }
106}
107
108pub fn calculate_compression_ratio(text: &str) -> f32 {
110 let text_bytes = text.as_bytes();
111 let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
112 encoder.write_all(text_bytes).unwrap();
113 let compressed = encoder.finish().unwrap();
114
115 text_bytes.len() as f32 / compressed.len() as f32
116}
117
118#[derive(Debug)]
120pub struct TranscriptionAttempt {
121 pub text: String,
122 pub segments: Vec<Segment>,
123 pub temperature: f32,
124 pub compression_ratio: f32,
125 pub avg_logprob: f32,
126 pub no_speech_prob: f32,
127}
128
129impl TranscriptionAttempt {
130 pub fn meets_thresholds(&self, thresholds: &QualityThresholds) -> bool {
132 let mut meets = true;
133
134 if let Some(cr_threshold) = thresholds.compression_ratio_threshold {
135 if self.compression_ratio > cr_threshold {
136 meets = false;
137 }
138 }
139
140 if let Some(lp_threshold) = thresholds.log_prob_threshold {
141 if self.avg_logprob < lp_threshold {
142 if let Some(ns_threshold) = thresholds.no_speech_threshold {
144 if self.no_speech_prob <= ns_threshold {
145 meets = false;
146 }
147 } else {
148 meets = false;
149 }
150 }
151 }
152
153 meets
154 }
155}
156
157pub struct EnhancedWhisperState<'a> {
159 state: &'a mut WhisperState,
160}
161
162impl<'a> EnhancedWhisperState<'a> {
163 pub fn new(state: &'a mut WhisperState) -> Self {
164 Self { state }
165 }
166
167 fn get_no_speech_prob(&self, segment_idx: i32) -> f32 {
169 unsafe {
170 ffi::whisper_full_get_segment_no_speech_prob_from_state(self.state.ptr, segment_idx)
172 }
173 }
174
175 fn calculate_avg_logprob(&self, segment_idx: i32) -> f32 {
177 let n_tokens = self.state.full_n_tokens(segment_idx);
178 if n_tokens == 0 {
179 return 0.0;
180 }
181
182 let mut sum_logprob = 0.0;
183 for i in 0..n_tokens {
184 let prob = self.state.full_get_token_prob(segment_idx, i);
185 if prob > 0.0 {
186 sum_logprob += prob.ln();
187 }
188 }
189
190 sum_logprob / n_tokens as f32
191 }
192
193 pub fn transcribe_with_fallback(
195 &mut self,
196 params: EnhancedTranscriptionParams,
197 audio: &[f32],
198 ) -> Result<TranscriptionResult> {
199 let mut all_attempts = Vec::new();
200 let mut below_cr_attempts = Vec::new();
201
202 for temperature in ¶ms.temperatures {
203 let mut current_params = params.base.clone();
205 current_params = current_params.temperature(*temperature);
206
207 if *temperature > params.prompt_reset_on_temperature {
209 current_params = current_params.initial_prompt("");
210 }
211
212 self.state.full(current_params, audio)?;
214
215 let n_segments = self.state.full_n_segments();
217 let mut segments = Vec::new();
218 let mut text = String::new();
219 let mut total_logprob = 0.0;
220 let mut total_tokens = 0;
221
222 for i in 0..n_segments {
223 let segment_text = self.state.full_get_segment_text(i)?;
224 let (start_ms, end_ms) = self.state.full_get_segment_timestamps(i);
225 let speaker_turn_next = self.state.full_get_segment_speaker_turn_next(i);
226
227 if i > 0 {
228 text.push(' ');
229 }
230 text.push_str(&segment_text);
231
232 segments.push(Segment {
233 start_ms,
234 end_ms,
235 text: segment_text,
236 speaker_turn_next,
237 });
238
239 let avg_lp = self.calculate_avg_logprob(i);
241 let n_tokens = self.state.full_n_tokens(i);
242 total_logprob += avg_lp * n_tokens as f32;
243 total_tokens += n_tokens;
244 }
245
246 let avg_logprob = if total_tokens > 0 {
247 total_logprob / total_tokens as f32
248 } else {
249 0.0
250 };
251
252 let compression_ratio = calculate_compression_ratio(&text);
254 let no_speech_prob = if n_segments > 0 {
255 self.get_no_speech_prob(0)
256 } else {
257 0.0
258 };
259
260 let attempt = TranscriptionAttempt {
261 text: text.clone(),
262 segments: segments.clone(),
263 temperature: *temperature,
264 compression_ratio,
265 avg_logprob,
266 no_speech_prob,
267 };
268
269 if attempt.meets_thresholds(¶ms.thresholds) {
271 return Ok(TranscriptionResult {
272 text: attempt.text,
273 segments: attempt.segments,
274 });
275 }
276
277 if let Some(cr_threshold) = params.thresholds.compression_ratio_threshold {
279 if attempt.compression_ratio <= cr_threshold {
280 below_cr_attempts.push(attempt);
281 } else {
282 all_attempts.push(attempt);
283 }
284 } else {
285 all_attempts.push(attempt);
286 }
287 }
288
289 let best_attempt = if !below_cr_attempts.is_empty() {
291 below_cr_attempts
292 .into_iter()
293 .max_by(|a, b| a.avg_logprob.partial_cmp(&b.avg_logprob).unwrap())
294 } else {
295 all_attempts
296 .into_iter()
297 .max_by(|a, b| a.avg_logprob.partial_cmp(&b.avg_logprob).unwrap())
298 };
299
300 best_attempt
301 .map(|a| TranscriptionResult {
302 text: a.text,
303 segments: a.segments,
304 })
305 .ok_or_else(|| {
306 WhisperError::TranscriptionError(
307 "Failed to produce acceptable transcription with any temperature".into(),
308 )
309 })
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// The ratio is positive for short text and grows with repetitiveness.
    #[test]
    fn test_compression_ratio_calculation() {
        let short_ratio = calculate_compression_ratio("The quick brown fox jumps over the lazy dog");
        assert!(short_ratio > 0.0);

        // A sentence repeated ten times must compress to less than its own size.
        let repeated_sentence = "The quick brown fox jumps over the lazy dog. ".repeat(10);
        assert!(calculate_compression_ratio(&repeated_sentence) > 1.0);

        // A run of a single character is highly compressible.
        let single_char_run = "a".repeat(1000);
        assert!(calculate_compression_ratio(&single_char_run) > 5.0);
    }

    /// An attempt within every gate passes; one over the compression-ratio gate fails.
    #[test]
    fn test_quality_threshold_checking() {
        let thresholds = QualityThresholds {
            compression_ratio_threshold: Some(2.4),
            log_prob_threshold: Some(-1.0),
            no_speech_threshold: Some(0.6),
        };

        // Both attempts share a confident log-probability and a low no-speech
        // probability; only the text and compression ratio differ.
        let attempt_with = |text: String, compression_ratio: f32| TranscriptionAttempt {
            text,
            segments: vec![],
            temperature: 0.0,
            compression_ratio,
            avg_logprob: -0.5,
            no_speech_prob: 0.1,
        };

        let acceptable = attempt_with("Hello world".to_string(), 1.5);
        assert!(acceptable.meets_thresholds(&thresholds));

        let too_repetitive = attempt_with("a".repeat(100), 10.0);
        assert!(!too_repetitive.meets_thresholds(&thresholds));
    }

    /// `from_base` installs the default temperature ladder and thresholds.
    #[test]
    fn test_enhanced_params_from_base() {
        let enhanced = EnhancedTranscriptionParams::from_base(FullParams::default().language("en"));

        assert_eq!(enhanced.temperatures.len(), 6);
        assert_eq!(enhanced.temperatures[0], 0.0);
        assert_eq!(enhanced.prompt_reset_on_temperature, 0.5);
        assert!(enhanced.thresholds.compression_ratio_threshold.is_some());
    }

    /// Values supplied through the builder end up in the built parameters.
    #[test]
    fn test_enhanced_transcription_params_builder() {
        let built = EnhancedTranscriptionParamsBuilder::new()
            .language("en")
            .temperatures(vec![0.0, 0.5, 1.0])
            .compression_ratio_threshold(Some(3.0))
            .build();

        assert_eq!(built.temperatures.len(), 3);
        assert_eq!(built.thresholds.compression_ratio_threshold, Some(3.0));
    }
}