whisper_cpp_plus/enhanced/
fallback.rs1use crate::{WhisperState, FullParams, Result, WhisperError, Segment, TranscriptionResult};
6use std::io::Write;
7use flate2::Compression;
8use flate2::write::ZlibEncoder;
9use whisper_cpp_plus_sys as ffi;
10
/// Quality gates used to decide whether a transcription attempt is acceptable.
///
/// Every gate is optional: `None` disables the corresponding check in
/// [`TranscriptionAttempt::meets_thresholds`].
#[derive(Debug, Clone, PartialEq)]
pub struct QualityThresholds {
    /// Maximum allowed compression ratio of the transcribed text; an attempt
    /// whose ratio is strictly greater than this value is rejected.
    pub compression_ratio_threshold: Option<f32>,
    /// Minimum allowed average token log-probability; an attempt whose average
    /// is strictly lower than this value is rejected unless the no-speech gate
    /// below excuses it.
    pub log_prob_threshold: Option<f32>,
    /// No-speech probability above which a low average log-probability is
    /// tolerated instead of causing a rejection.
    pub no_speech_threshold: Option<f32>,
}
21
22impl Default for QualityThresholds {
23 fn default() -> Self {
24 Self {
25 compression_ratio_threshold: Some(2.4),
26 log_prob_threshold: Some(-1.0),
27 no_speech_threshold: Some(0.6),
28 }
29 }
30}
31
32#[derive(Clone)]
34pub struct EnhancedTranscriptionParams {
35 pub base: FullParams,
37 pub temperatures: Vec<f32>,
39 pub thresholds: QualityThresholds,
41 pub prompt_reset_on_temperature: f32,
43}
44
45impl EnhancedTranscriptionParams {
46 pub fn from_base(base: FullParams) -> Self {
48 Self {
49 base,
50 temperatures: vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
51 thresholds: QualityThresholds::default(),
52 prompt_reset_on_temperature: 0.5,
53 }
54 }
55
56 pub fn builder() -> EnhancedTranscriptionParamsBuilder {
57 EnhancedTranscriptionParamsBuilder::new()
58 }
59}
60
61pub struct EnhancedTranscriptionParamsBuilder {
62 params: EnhancedTranscriptionParams,
63}
64
65impl EnhancedTranscriptionParamsBuilder {
66 pub fn new() -> Self {
67 Self {
68 params: EnhancedTranscriptionParams::from_base(FullParams::default()),
69 }
70 }
71
72 pub fn base_params(mut self, params: FullParams) -> Self {
73 self.params.base = params;
74 self
75 }
76
77 pub fn language(mut self, lang: &str) -> Self {
78 self.params.base = self.params.base.language(lang);
79 self
80 }
81
82 pub fn temperatures(mut self, temps: Vec<f32>) -> Self {
83 self.params.temperatures = temps;
84 self
85 }
86
87 pub fn compression_ratio_threshold(mut self, threshold: Option<f32>) -> Self {
88 self.params.thresholds.compression_ratio_threshold = threshold;
89 self
90 }
91
92 pub fn log_prob_threshold(mut self, threshold: Option<f32>) -> Self {
93 self.params.thresholds.log_prob_threshold = threshold;
94 self
95 }
96
97 pub fn build(self) -> EnhancedTranscriptionParams {
98 self.params
99 }
100}
101
102impl Default for EnhancedTranscriptionParamsBuilder {
103 fn default() -> Self {
104 Self::new()
105 }
106}
107
108pub fn calculate_compression_ratio(text: &str) -> f32 {
110 let text_bytes = text.as_bytes();
111 let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
112 encoder.write_all(text_bytes).unwrap();
113 let compressed = encoder.finish().unwrap();
114
115 text_bytes.len() as f32 / compressed.len() as f32
116}
117
118#[derive(Debug)]
120pub struct TranscriptionAttempt {
121 pub text: String,
122 pub segments: Vec<Segment>,
123 pub temperature: f32,
124 pub compression_ratio: f32,
125 pub avg_logprob: f32,
126 pub no_speech_prob: f32,
127}
128
129impl TranscriptionAttempt {
130 pub fn meets_thresholds(&self, thresholds: &QualityThresholds) -> bool {
132 let mut meets = true;
133
134 if let Some(cr_threshold) = thresholds.compression_ratio_threshold {
135 if self.compression_ratio > cr_threshold {
136 meets = false;
137 }
138 }
139
140 if let Some(lp_threshold) = thresholds.log_prob_threshold {
141 if self.avg_logprob < lp_threshold {
142 if let Some(ns_threshold) = thresholds.no_speech_threshold {
144 if self.no_speech_prob <= ns_threshold {
145 meets = false;
146 }
147 } else {
148 meets = false;
149 }
150 }
151 }
152
153 meets
154 }
155}
156
157pub struct EnhancedWhisperState<'a> {
159 state: &'a mut WhisperState,
160}
161
162impl<'a> EnhancedWhisperState<'a> {
163 pub fn new(state: &'a mut WhisperState) -> Self {
164 Self { state }
165 }
166
167 fn get_no_speech_prob(&self, segment_idx: i32) -> f32 {
169 unsafe {
170 ffi::whisper_full_get_segment_no_speech_prob_from_state(
172 self.state.ptr,
173 segment_idx
174 )
175 }
176 }
177
178 fn calculate_avg_logprob(&self, segment_idx: i32) -> f32 {
180 let n_tokens = self.state.full_n_tokens(segment_idx);
181 if n_tokens == 0 {
182 return 0.0;
183 }
184
185 let mut sum_logprob = 0.0;
186 for i in 0..n_tokens {
187 let prob = self.state.full_get_token_prob(segment_idx, i);
188 if prob > 0.0 {
189 sum_logprob += prob.ln();
190 }
191 }
192
193 sum_logprob / n_tokens as f32
194 }
195
196 pub fn transcribe_with_fallback(
198 &mut self,
199 params: EnhancedTranscriptionParams,
200 audio: &[f32],
201 ) -> Result<TranscriptionResult> {
202 let mut all_attempts = Vec::new();
203 let mut below_cr_attempts = Vec::new();
204
205 for temperature in ¶ms.temperatures {
206 let mut current_params = params.base.clone();
208 current_params = current_params.temperature(*temperature);
209
210 if *temperature > params.prompt_reset_on_temperature {
212 current_params = current_params.initial_prompt("");
213 }
214
215 self.state.full(current_params, audio)?;
217
218 let n_segments = self.state.full_n_segments();
220 let mut segments = Vec::new();
221 let mut text = String::new();
222 let mut total_logprob = 0.0;
223 let mut total_tokens = 0;
224
225 for i in 0..n_segments {
226 let segment_text = self.state.full_get_segment_text(i)?;
227 let (start_ms, end_ms) = self.state.full_get_segment_timestamps(i);
228 let speaker_turn_next = self.state.full_get_segment_speaker_turn_next(i);
229
230 if i > 0 {
231 text.push(' ');
232 }
233 text.push_str(&segment_text);
234
235 segments.push(Segment {
236 start_ms,
237 end_ms,
238 text: segment_text,
239 speaker_turn_next,
240 });
241
242 let avg_lp = self.calculate_avg_logprob(i);
244 let n_tokens = self.state.full_n_tokens(i);
245 total_logprob += avg_lp * n_tokens as f32;
246 total_tokens += n_tokens;
247 }
248
249 let avg_logprob = if total_tokens > 0 {
250 total_logprob / total_tokens as f32
251 } else {
252 0.0
253 };
254
255 let compression_ratio = calculate_compression_ratio(&text);
257 let no_speech_prob = if n_segments > 0 {
258 self.get_no_speech_prob(0)
259 } else {
260 0.0
261 };
262
263 let attempt = TranscriptionAttempt {
264 text: text.clone(),
265 segments: segments.clone(),
266 temperature: *temperature,
267 compression_ratio,
268 avg_logprob,
269 no_speech_prob,
270 };
271
272 if attempt.meets_thresholds(¶ms.thresholds) {
274 return Ok(TranscriptionResult {
275 text: attempt.text,
276 segments: attempt.segments,
277 });
278 }
279
280 if let Some(cr_threshold) = params.thresholds.compression_ratio_threshold {
282 if attempt.compression_ratio <= cr_threshold {
283 below_cr_attempts.push(attempt);
284 } else {
285 all_attempts.push(attempt);
286 }
287 } else {
288 all_attempts.push(attempt);
289 }
290 }
291
292 let best_attempt = if !below_cr_attempts.is_empty() {
294 below_cr_attempts.into_iter()
295 .max_by(|a, b| a.avg_logprob.partial_cmp(&b.avg_logprob).unwrap())
296 } else {
297 all_attempts.into_iter()
298 .max_by(|a, b| a.avg_logprob.partial_cmp(&b.avg_logprob).unwrap())
299 };
300
301 best_attempt
302 .map(|a| TranscriptionResult {
303 text: a.text,
304 segments: a.segments,
305 })
306 .ok_or_else(|| WhisperError::TranscriptionError(
307 "Failed to produce acceptable transcription with any temperature".into()
308 ))
309 }
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an attempt carrying only the metrics that `meets_thresholds`
    /// inspects; the remaining fields are left empty.
    fn attempt_with(
        compression_ratio: f32,
        avg_logprob: f32,
        no_speech_prob: f32,
    ) -> TranscriptionAttempt {
        TranscriptionAttempt {
            text: String::new(),
            segments: vec![],
            temperature: 0.0,
            compression_ratio,
            avg_logprob,
            no_speech_prob,
        }
    }

    #[test]
    fn test_compression_ratio_calculation() {
        // Short text: the ratio is only required to be positive.
        let text = "The quick brown fox jumps over the lazy dog";
        let ratio = calculate_compression_ratio(text);
        assert!(ratio > 0.0);

        // A repeated sentence compresses to less than its original size.
        let longer_text = "The quick brown fox jumps over the lazy dog. ".repeat(10);
        let longer_ratio = calculate_compression_ratio(&longer_text);
        assert!(longer_ratio > 1.0);

        // A single repeated character compresses extremely well.
        let repetitive = "a".repeat(1000);
        let repetitive_ratio = calculate_compression_ratio(&repetitive);
        assert!(repetitive_ratio > 5.0);
    }

    #[test]
    fn test_compression_ratio_of_empty_text_is_zero() {
        // Zero input bytes over a non-empty compressed stream.
        assert_eq!(calculate_compression_ratio(""), 0.0);
    }

    #[test]
    fn test_quality_threshold_checking() {
        let thresholds = QualityThresholds {
            compression_ratio_threshold: Some(2.4),
            log_prob_threshold: Some(-1.0),
            no_speech_threshold: Some(0.6),
        };

        let good_attempt = TranscriptionAttempt {
            text: "Hello world".to_string(),
            segments: vec![],
            temperature: 0.0,
            compression_ratio: 1.5,
            avg_logprob: -0.5,
            no_speech_prob: 0.1,
        };

        assert!(good_attempt.meets_thresholds(&thresholds));

        // Compression ratio far above the gate.
        let bad_attempt = TranscriptionAttempt {
            text: "a".repeat(100),
            segments: vec![],
            temperature: 0.0,
            compression_ratio: 10.0,
            avg_logprob: -0.5,
            no_speech_prob: 0.1,
        };

        assert!(!bad_attempt.meets_thresholds(&thresholds));
    }

    #[test]
    fn test_log_prob_and_no_speech_interaction() {
        let thresholds = QualityThresholds::default();

        // Low log-probability with a low no-speech probability is rejected.
        assert!(!attempt_with(1.5, -1.5, 0.1).meets_thresholds(&thresholds));

        // A no-speech probability exactly at the gate does not excuse it.
        assert!(!attempt_with(1.5, -1.5, 0.6).meets_thresholds(&thresholds));

        // Low log-probability is excused when the no-speech probability is
        // above the gate.
        assert!(attempt_with(1.5, -1.5, 0.9).meets_thresholds(&thresholds));

        // The no-speech gate never excuses an excessive compression ratio.
        assert!(!attempt_with(10.0, -1.5, 0.9).meets_thresholds(&thresholds));

        // Without a no-speech gate a low log-probability always rejects.
        let no_silence_gate = QualityThresholds {
            no_speech_threshold: None,
            ..QualityThresholds::default()
        };
        assert!(!attempt_with(1.5, -1.5, 0.9).meets_thresholds(&no_silence_gate));
    }

    #[test]
    fn test_disabled_thresholds_accept_everything() {
        let disabled = QualityThresholds {
            compression_ratio_threshold: None,
            log_prob_threshold: None,
            no_speech_threshold: None,
        };

        assert!(attempt_with(100.0, -50.0, 0.0).meets_thresholds(&disabled));
    }

    #[test]
    fn test_enhanced_params_from_base() {
        let base = FullParams::default().language("en");

        let enhanced = EnhancedTranscriptionParams::from_base(base);

        assert_eq!(enhanced.temperatures.len(), 6);
        assert_eq!(enhanced.temperatures[0], 0.0);
        assert_eq!(enhanced.prompt_reset_on_temperature, 0.5);
        assert!(enhanced.thresholds.compression_ratio_threshold.is_some());
    }

    #[test]
    fn test_enhanced_transcription_params_builder() {
        let params = EnhancedTranscriptionParamsBuilder::new()
            .language("en")
            .temperatures(vec![0.0, 0.5, 1.0])
            .compression_ratio_threshold(Some(3.0))
            .build();

        assert_eq!(params.temperatures.len(), 3);
        assert_eq!(params.thresholds.compression_ratio_threshold, Some(3.0));
    }

    #[test]
    fn test_builder_can_disable_log_prob_threshold() {
        let params = EnhancedTranscriptionParamsBuilder::new()
            .log_prob_threshold(None)
            .build();

        assert_eq!(params.thresholds.log_prob_threshold, None);
        // Gates that were not touched keep their defaults.
        assert_eq!(params.thresholds.no_speech_threshold, Some(0.6));
    }
}