1use std::ffi::CString;
13use std::path::Path;
14use std::sync::Arc;
15use std::sync::Mutex;
16use thiserror::Error;
17
18use whisper_cpp_plus_sys as ffi;
19
/// Errors reported by the quantization helpers in this file.
///
/// The `#[error(...)]` attributes supply the `Display` text for each variant.
#[derive(Debug, Error)]
pub enum QuantizeError {
    /// The given model path does not exist. The payload is the path as displayed.
    #[error("Model file not found: {0}")]
    FileNotFound(String),

    /// A file could not be opened. In this file it is returned by
    /// `get_model_quantization_type` when the native library reports a negative ftype;
    /// the payload is the path as displayed.
    #[error("Failed to open file: {0}")]
    FileOpenError(String),

    /// A file could not be written.
    // NOTE(review): no code visible in this file constructs this variant — confirm whether
    // it is used elsewhere.
    #[error("Failed to write file: {0}")]
    FileWriteError(String),

    /// The model file is not in a recognised format.
    // NOTE(review): no code visible in this file constructs this variant — confirm whether
    // it is used elsewhere.
    #[error("Invalid model format")]
    InvalidModel,

    /// A string could not be parsed into a `QuantizationType` (returned by its `FromStr` impl).
    #[error("Invalid quantization type")]
    InvalidQuantizationType,

    /// Catch-all failure carrying a human-readable reason. Used for path conversion
    /// failures, metadata read failures, and every non-OK status from the native
    /// quantization call.
    #[error("Quantization failed: {0}")]
    QuantizationFailed(String),
}
41
/// File-local shorthand for `std::result::Result` with `QuantizeError` as the error type.
type Result<T> = std::result::Result<T, QuantizeError>;
43
/// Quantization formats that can be requested for a model.
///
/// Each variant's discriminant is the matching `ffi::GGML_FTYPE_MOSTLY_*` constant, so a
/// variant can be handed to the native library with `as i32`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
// Variant names keep the underscore spelling of the `GGML_FTYPE_MOSTLY_*` suffixes.
#[allow(non_camel_case_types)]
pub enum QuantizationType {
    /// Maps to `GGML_FTYPE_MOSTLY_Q4_0`.
    Q4_0 = ffi::GGML_FTYPE_MOSTLY_Q4_0,

    /// Maps to `GGML_FTYPE_MOSTLY_Q4_1`.
    Q4_1 = ffi::GGML_FTYPE_MOSTLY_Q4_1,

    /// Maps to `GGML_FTYPE_MOSTLY_Q5_0`.
    Q5_0 = ffi::GGML_FTYPE_MOSTLY_Q5_0,

    /// Maps to `GGML_FTYPE_MOSTLY_Q5_1`.
    Q5_1 = ffi::GGML_FTYPE_MOSTLY_Q5_1,

    /// Maps to `GGML_FTYPE_MOSTLY_Q8_0`.
    Q8_0 = ffi::GGML_FTYPE_MOSTLY_Q8_0,

    /// Maps to `GGML_FTYPE_MOSTLY_Q2_K`.
    Q2_K = ffi::GGML_FTYPE_MOSTLY_Q2_K,

    /// Maps to `GGML_FTYPE_MOSTLY_Q3_K`.
    Q3_K = ffi::GGML_FTYPE_MOSTLY_Q3_K,

    /// Maps to `GGML_FTYPE_MOSTLY_Q4_K`.
    Q4_K = ffi::GGML_FTYPE_MOSTLY_Q4_K,

    /// Maps to `GGML_FTYPE_MOSTLY_Q5_K`.
    Q5_K = ffi::GGML_FTYPE_MOSTLY_Q5_K,

    /// Maps to `GGML_FTYPE_MOSTLY_Q6_K`.
    Q6_K = ffi::GGML_FTYPE_MOSTLY_Q6_K,
}
79
80impl QuantizationType {
81 pub fn name(&self) -> &'static str {
83 match self {
84 Self::Q4_0 => "Q4_0",
85 Self::Q4_1 => "Q4_1",
86 Self::Q5_0 => "Q5_0",
87 Self::Q5_1 => "Q5_1",
88 Self::Q8_0 => "Q8_0",
89 Self::Q2_K => "Q2_K",
90 Self::Q3_K => "Q3_K",
91 Self::Q4_K => "Q4_K",
92 Self::Q5_K => "Q5_K",
93 Self::Q6_K => "Q6_K",
94 }
95 }
96
97 pub fn size_factor(&self) -> f32 {
100 match self {
101 Self::Q2_K => 0.19, Self::Q3_K => 0.26, Self::Q4_0 => 0.31, Self::Q4_1 => 0.35, Self::Q4_K => 0.33, Self::Q5_0 => 0.39, Self::Q5_1 => 0.43, Self::Q5_K => 0.41, Self::Q6_K => 0.49, Self::Q8_0 => 0.69, }
112 }
113
114 pub fn all() -> &'static [QuantizationType] {
116 &[
117 Self::Q4_0,
118 Self::Q4_1,
119 Self::Q5_0,
120 Self::Q5_1,
121 Self::Q8_0,
122 Self::Q2_K,
123 Self::Q3_K,
124 Self::Q4_K,
125 Self::Q5_K,
126 Self::Q6_K,
127 ]
128 }
129
130}
131
132impl std::str::FromStr for QuantizationType {
133 type Err = QuantizeError;
134
135 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
136 match s.to_uppercase().as_str() {
137 "Q4_0" | "Q40" => Ok(Self::Q4_0),
138 "Q4_1" | "Q41" => Ok(Self::Q4_1),
139 "Q5_0" | "Q50" => Ok(Self::Q5_0),
140 "Q5_1" | "Q51" => Ok(Self::Q5_1),
141 "Q8_0" | "Q80" => Ok(Self::Q8_0),
142 "Q2_K" | "Q2K" => Ok(Self::Q2_K),
143 "Q3_K" | "Q3K" => Ok(Self::Q3_K),
144 "Q4_K" | "Q4K" => Ok(Self::Q4_K),
145 "Q5_K" | "Q5K" => Ok(Self::Q5_K),
146 "Q6_K" | "Q6K" => Ok(Self::Q6_K),
147 _ => Err(QuantizeError::InvalidQuantizationType),
148 }
149 }
150}
151
152impl std::fmt::Display for QuantizationType {
153 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154 write!(f, "{}", self.name())
155 }
156}
157
/// Boxed progress callback: a `Send` closure invoked with an `f32` progress value.
// NOTE(review): the range of the value (e.g. 0.0–1.0) is set by the native library and is
// not visible in this file — confirm.
pub type ProgressCallback = Box<dyn Fn(f32) + Send>;
160
/// Field-less type that groups the quantization entry points; all of its functions are
/// associated functions and no instance is required.
pub struct WhisperQuantize;
163
164impl WhisperQuantize {
165 pub fn quantize_model_file<P: AsRef<Path>>(
183 input_path: P,
184 output_path: P,
185 qtype: QuantizationType,
186 ) -> Result<()> {
187 Self::quantize_model_file_impl(input_path.as_ref(), output_path.as_ref(), qtype, None)
188 }
189
190 pub fn quantize_model_file_with_progress<P, F>(
212 input_path: P,
213 output_path: P,
214 qtype: QuantizationType,
215 callback: F,
216 ) -> Result<()>
217 where
218 P: AsRef<Path>,
219 F: Fn(f32) + Send + 'static,
220 {
221 Self::quantize_model_file_impl(
222 input_path.as_ref(),
223 output_path.as_ref(),
224 qtype,
225 Some(Box::new(callback)),
226 )
227 }
228
229 fn quantize_model_file_impl(
230 input_path: &Path,
231 output_path: &Path,
232 qtype: QuantizationType,
233 callback: Option<ProgressCallback>,
234 ) -> Result<()> {
235 if !input_path.exists() {
237 return Err(QuantizeError::FileNotFound(format!(
238 "{}",
239 input_path.display()
240 )));
241 }
242
243 let input_cstr = path_to_cstring(input_path)?;
245 let output_cstr = path_to_cstring(output_path)?;
246
247 let callback_data = callback.map(|cb| Arc::new(Mutex::new(cb)));
249 let callback_ptr = callback_data.as_ref().map(|data| {
250 Arc::clone(data) as Arc<Mutex<dyn Fn(f32) + Send>>
251 });
252
253 let ffi_callback: ffi::whisper_quantize_progress_callback = if callback_ptr.is_some() {
255 Some(quantize_progress_callback)
256 } else {
257 None
258 };
259
260 if let Some(ptr) = callback_ptr {
262 CALLBACK_DATA.with(|data| {
263 *data.borrow_mut() = Some(ptr);
264 });
265 }
266
267 let result = unsafe {
269 ffi::whisper_model_quantize(
270 input_cstr.as_ptr(),
271 output_cstr.as_ptr(),
272 qtype as i32,
273 ffi_callback,
274 )
275 };
276
277 CALLBACK_DATA.with(|data| {
279 *data.borrow_mut() = None;
280 });
281
282 match result {
284 ffi::WHISPER_QUANTIZE_OK => Ok(()),
285 ffi::WHISPER_QUANTIZE_ERROR_INVALID_MODEL => {
286 Err(QuantizeError::QuantizationFailed("Invalid model file".to_string()))
287 }
288 ffi::WHISPER_QUANTIZE_ERROR_FILE_OPEN => {
289 Err(QuantizeError::QuantizationFailed(format!(
290 "Failed to open input file: {}",
291 input_path.display()
292 )))
293 }
294 ffi::WHISPER_QUANTIZE_ERROR_FILE_WRITE => {
295 Err(QuantizeError::QuantizationFailed(format!(
296 "Failed to write output file: {}",
297 output_path.display()
298 )))
299 }
300 ffi::WHISPER_QUANTIZE_ERROR_INVALID_FTYPE => {
301 Err(QuantizeError::QuantizationFailed(format!(
302 "Invalid quantization type: {}",
303 qtype
304 )))
305 }
306 ffi::WHISPER_QUANTIZE_ERROR_QUANTIZATION_FAILED => {
307 Err(QuantizeError::QuantizationFailed("Quantization failed".to_string()))
308 }
309 _ => Err(QuantizeError::QuantizationFailed(format!(
310 "Unknown quantization error: {}",
311 result
312 ))),
313 }
314 }
315
316 pub fn get_model_quantization_type<P: AsRef<Path>>(
334 model_path: P,
335 ) -> Result<Option<QuantizationType>> {
336 let path = model_path.as_ref();
337
338 if !path.exists() {
339 return Err(QuantizeError::FileNotFound(format!(
340 "{}",
341 path.display()
342 )));
343 }
344
345 let path_cstr = path_to_cstring(path)?;
346
347 let ftype = unsafe {
348 ffi::whisper_model_get_ftype(path_cstr.as_ptr())
349 };
350
351 if ftype < 0 {
352 return Err(QuantizeError::FileOpenError(format!(
353 "{}",
354 path.display()
355 )));
356 }
357
358 let qtype = match ftype {
360 x if x == ffi::GGML_FTYPE_ALL_F32 => None,
361 x if x == ffi::GGML_FTYPE_MOSTLY_F16 => None,
362 x if x == QuantizationType::Q4_0 as i32 => Some(QuantizationType::Q4_0),
363 x if x == QuantizationType::Q4_1 as i32 => Some(QuantizationType::Q4_1),
364 x if x == QuantizationType::Q5_0 as i32 => Some(QuantizationType::Q5_0),
365 x if x == QuantizationType::Q5_1 as i32 => Some(QuantizationType::Q5_1),
366 x if x == QuantizationType::Q8_0 as i32 => Some(QuantizationType::Q8_0),
367 x if x == QuantizationType::Q2_K as i32 => Some(QuantizationType::Q2_K),
368 x if x == QuantizationType::Q3_K as i32 => Some(QuantizationType::Q3_K),
369 x if x == QuantizationType::Q4_K as i32 => Some(QuantizationType::Q4_K),
370 x if x == QuantizationType::Q5_K as i32 => Some(QuantizationType::Q5_K),
371 x if x == QuantizationType::Q6_K as i32 => Some(QuantizationType::Q6_K),
372 _ => None,
373 };
374
375 Ok(qtype)
376 }
377
378 pub fn estimate_quantized_size<P: AsRef<Path>>(
395 model_path: P,
396 qtype: QuantizationType,
397 ) -> Result<u64> {
398 let path = model_path.as_ref();
399 let metadata = std::fs::metadata(path)
400 .map_err(|e| QuantizeError::QuantizationFailed(format!("Failed to read model file: {}", e)))?;
401
402 let original_size = metadata.len();
403 let estimated_size = (original_size as f64 * qtype.size_factor() as f64) as u64;
404
405 Ok(estimated_size)
406 }
407}
408
thread_local! {
    // Per-thread slot for the progress callback of the quantization call currently
    // running on this thread. `quantize_model_file_impl` stores a handle here before
    // calling the native quantizer and resets the slot to `None` afterwards;
    // `quantize_progress_callback` reads it. A thread-local is used because the native
    // callback signature takes only the progress value and no user-data pointer.
    static CALLBACK_DATA: std::cell::RefCell<Option<Arc<Mutex<dyn Fn(f32) + Send>>>> =
        std::cell::RefCell::new(None);
}
414
415extern "C" fn quantize_progress_callback(progress: f32) {
417 CALLBACK_DATA.with(|data| {
418 if let Some(callback) = data.borrow().as_ref() {
419 if let Ok(cb) = callback.lock() {
420 cb(progress);
421 }
422 }
423 });
424}
425
426fn path_to_cstring(path: &Path) -> Result<CString> {
428 let path_str = path.to_str()
429 .ok_or_else(|| QuantizeError::QuantizationFailed("Invalid UTF-8 in path".to_string()))?;
430
431 CString::new(path_str)
432 .map_err(|_| QuantizeError::QuantizationFailed("Path contains null byte".to_string()))
433}
434
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks `name()` on a sample of variants.
    #[test]
    fn test_quantization_type_names() {
        let expectations = [
            (QuantizationType::Q4_0, "Q4_0"),
            (QuantizationType::Q5_1, "Q5_1"),
            (QuantizationType::Q8_0, "Q8_0"),
            (QuantizationType::Q3_K, "Q3_K"),
        ];

        for &(qtype, label) in expectations.iter() {
            assert_eq!(qtype.name(), label);
        }
    }

    /// Parsing accepts mixed case and the spelling without underscore, and rejects junk.
    #[test]
    fn test_quantization_type_from_str() {
        let accepted = [
            ("q4_0", QuantizationType::Q4_0),
            ("Q5_1", QuantizationType::Q5_1),
            ("q8_0", QuantizationType::Q8_0),
            ("Q3K", QuantizationType::Q3_K),
        ];

        for &(text, expected) in accepted.iter() {
            assert_eq!(text.parse::<QuantizationType>().unwrap(), expected);
        }

        assert!("invalid".parse::<QuantizationType>().is_err());
    }

    /// Every size factor must lie strictly between 0 and 1.
    #[test]
    fn test_size_factors() {
        QuantizationType::all().iter().for_each(|qtype| {
            let factor = qtype.size_factor();
            assert!(
                factor > 0.0 && factor < 1.0,
                "{} has invalid size factor: {}",
                qtype,
                factor
            );
        });
    }
}