
whisper_cpp_plus/quantize.rs

//! Model quantization for reducing model size and improving inference speed.
//!
//! Provides functionality to quantize Whisper models to various bit depths,
//! reducing model size while maintaining reasonable accuracy. Quantization is
//! particularly useful for deployment on resource-constrained devices.
//!
//! Enable with the `quantization` feature flag:
//! ```toml
//! whisper-cpp-plus = { version = "0.1.3", features = ["quantization"] }
//! ```
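//!
//! A minimal end-to-end sketch using only the APIs in this module (the model
//! paths are illustrative, not shipped fixtures):
//! ```no_run
//! use whisper_cpp_plus::{QuantizationType, WhisperQuantize};
//!
//! // Pick a quantization type, e.g. from a CLI argument.
//! let qtype: QuantizationType = "q5_0".parse().expect("unknown quantization type");
//!
//! // Estimate the output size before committing to the conversion.
//! let bytes = WhisperQuantize::estimate_quantized_size("models/ggml-base.bin", qtype)
//!     .expect("failed to read model file");
//! println!("Estimated {} size: {} MB", qtype, bytes / 1024 / 1024);
//!
//! // Perform the quantization.
//! WhisperQuantize::quantize_model_file(
//!     "models/ggml-base.bin",
//!     "models/ggml-base-q5_0.bin",
//!     qtype,
//! )
//! .expect("failed to quantize model");
//! ```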

use std::ffi::CString;
use std::path::Path;
use std::sync::Arc;
use std::sync::Mutex;
use thiserror::Error;

use whisper_cpp_plus_sys as ffi;

/// Error type for quantization operations
#[derive(Debug, Error)]
pub enum QuantizeError {
    #[error("Model file not found: {0}")]
    FileNotFound(String),

    #[error("Failed to open file: {0}")]
    FileOpenError(String),

    #[error("Failed to write file: {0}")]
    FileWriteError(String),

    #[error("Invalid model format")]
    InvalidModel,

    #[error("Invalid quantization type")]
    InvalidQuantizationType,

    #[error("Quantization failed: {0}")]
    QuantizationFailed(String),
}

type Result<T> = std::result::Result<T, QuantizeError>;

/// Quantization types supported by whisper.cpp
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
#[allow(non_camel_case_types)]
pub enum QuantizationType {
    /// 4-bit quantization (method 0) - roughly 31% of the F32 model size
    Q4_0 = ffi::GGML_FTYPE_MOSTLY_Q4_0,

    /// 4-bit quantization (method 1) - roughly 35% of the F32 model size
    Q4_1 = ffi::GGML_FTYPE_MOSTLY_Q4_1,

    /// 5-bit quantization (method 0) - roughly 39% of the F32 model size
    Q5_0 = ffi::GGML_FTYPE_MOSTLY_Q5_0,

    /// 5-bit quantization (method 1) - roughly 43% of the F32 model size
    Q5_1 = ffi::GGML_FTYPE_MOSTLY_Q5_1,

    /// 8-bit quantization - roughly 69% of the F32 model size
    Q8_0 = ffi::GGML_FTYPE_MOSTLY_Q8_0,

    /// 2-bit k-quantization
    Q2_K = ffi::GGML_FTYPE_MOSTLY_Q2_K,

    /// 3-bit k-quantization
    Q3_K = ffi::GGML_FTYPE_MOSTLY_Q3_K,

    /// 4-bit k-quantization
    Q4_K = ffi::GGML_FTYPE_MOSTLY_Q4_K,

    /// 5-bit k-quantization
    Q5_K = ffi::GGML_FTYPE_MOSTLY_Q5_K,

    /// 6-bit k-quantization
    Q6_K = ffi::GGML_FTYPE_MOSTLY_Q6_K,
}

impl QuantizationType {
    /// Get a human-readable name for the quantization type
    pub fn name(&self) -> &'static str {
        match self {
            Self::Q4_0 => "Q4_0",
            Self::Q4_1 => "Q4_1",
            Self::Q5_0 => "Q5_0",
            Self::Q5_1 => "Q5_1",
            Self::Q8_0 => "Q8_0",
            Self::Q2_K => "Q2_K",
            Self::Q3_K => "Q3_K",
            Self::Q4_K => "Q4_K",
            Self::Q5_K => "Q5_K",
            Self::Q6_K => "Q6_K",
        }
    }

    /// Estimate the size reduction factor for this quantization type.
    /// Returns the approximate size as a fraction of the original F32 model.
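    ///
    /// # Example
    /// ```
    /// use whisper_cpp_plus::QuantizationType;
    ///
    /// // Q4_0 keeps roughly 31% of the original F32 size.
    /// assert_eq!(QuantizationType::Q4_0.size_factor(), 0.31);
    /// ```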
    pub fn size_factor(&self) -> f32 {
        match self {
            Self::Q2_K => 0.19, // ~19% of original
            Self::Q3_K => 0.26, // ~26% of original
            Self::Q4_0 => 0.31, // ~31% of original
            Self::Q4_1 => 0.35, // ~35% of original
            Self::Q4_K => 0.33, // ~33% of original
            Self::Q5_0 => 0.39, // ~39% of original
            Self::Q5_1 => 0.43, // ~43% of original
            Self::Q5_K => 0.41, // ~41% of original
            Self::Q6_K => 0.49, // ~49% of original
            Self::Q8_0 => 0.69, // ~69% of original
        }
    }

    /// Get all available quantization types
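    ///
    /// # Example
    /// ```
    /// use whisper_cpp_plus::QuantizationType;
    ///
    /// // List every supported type with its estimated size fraction.
    /// for qtype in QuantizationType::all() {
    ///     println!("{}: ~{:.0}% of F32", qtype, qtype.size_factor() * 100.0);
    /// }
    /// ```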
    pub fn all() -> &'static [QuantizationType] {
        &[
            Self::Q4_0,
            Self::Q4_1,
            Self::Q5_0,
            Self::Q5_1,
            Self::Q8_0,
            Self::Q2_K,
            Self::Q3_K,
            Self::Q4_K,
            Self::Q5_K,
            Self::Q6_K,
        ]
    }
}

impl std::str::FromStr for QuantizationType {
    type Err = QuantizeError;

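    /// Parses a quantization type name case-insensitively; the underscore is
    /// optional (e.g. both "q5_0" and "Q50" are accepted).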
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        match s.to_uppercase().as_str() {
            "Q4_0" | "Q40" => Ok(Self::Q4_0),
            "Q4_1" | "Q41" => Ok(Self::Q4_1),
            "Q5_0" | "Q50" => Ok(Self::Q5_0),
            "Q5_1" | "Q51" => Ok(Self::Q5_1),
            "Q8_0" | "Q80" => Ok(Self::Q8_0),
            "Q2_K" | "Q2K" => Ok(Self::Q2_K),
            "Q3_K" | "Q3K" => Ok(Self::Q3_K),
            "Q4_K" | "Q4K" => Ok(Self::Q4_K),
            "Q5_K" | "Q5K" => Ok(Self::Q5_K),
            "Q6_K" | "Q6K" => Ok(Self::Q6_K),
            _ => Err(QuantizeError::InvalidQuantizationType),
        }
    }
}

impl std::fmt::Display for QuantizationType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.name())
    }
}

/// Progress callback for quantization operations
pub type ProgressCallback = Box<dyn Fn(f32) + Send>;

/// Model quantizer for converting Whisper models to different quantization formats
pub struct WhisperQuantize;

impl WhisperQuantize {
    /// Quantize a model file to a specified quantization type
    ///
    /// # Arguments
    /// * `input_path` - Path to the input model file (must be in GGML format)
    /// * `output_path` - Path where the quantized model will be saved
    /// * `qtype` - The quantization type to use
    ///
    /// # Example
    /// ```no_run
    /// use whisper_cpp_plus::{WhisperQuantize, QuantizationType};
    ///
    /// WhisperQuantize::quantize_model_file(
    ///     "models/ggml-base.bin",
    ///     "models/ggml-base-q5_0.bin",
    ///     QuantizationType::Q5_0
    /// ).expect("Failed to quantize model");
    /// ```
    pub fn quantize_model_file<P: AsRef<Path>>(
        input_path: P,
        output_path: P,
        qtype: QuantizationType,
    ) -> Result<()> {
        Self::quantize_model_file_impl(input_path.as_ref(), output_path.as_ref(), qtype, None)
    }

    /// Quantize a model file with progress callback
    ///
    /// # Arguments
    /// * `input_path` - Path to the input model file
    /// * `output_path` - Path where the quantized model will be saved
    /// * `qtype` - The quantization type to use
    /// * `callback` - Progress callback function (receives values from 0.0 to 1.0)
    ///
    /// # Example
    /// ```no_run
    /// use whisper_cpp_plus::{WhisperQuantize, QuantizationType};
    ///
    /// WhisperQuantize::quantize_model_file_with_progress(
    ///     "models/ggml-base.bin",
    ///     "models/ggml-base-q4_0.bin",
    ///     QuantizationType::Q4_0,
    ///     |progress| {
    ///         println!("Progress: {:.1}%", progress * 100.0);
    ///     }
    /// ).expect("Failed to quantize model");
    /// ```
    pub fn quantize_model_file_with_progress<P, F>(
        input_path: P,
        output_path: P,
        qtype: QuantizationType,
        callback: F,
    ) -> Result<()>
    where
        P: AsRef<Path>,
        F: Fn(f32) + Send + 'static,
    {
        Self::quantize_model_file_impl(
            input_path.as_ref(),
            output_path.as_ref(),
            qtype,
            Some(Box::new(callback)),
        )
    }

    fn quantize_model_file_impl(
        input_path: &Path,
        output_path: &Path,
        qtype: QuantizationType,
        callback: Option<ProgressCallback>,
    ) -> Result<()> {
        // Validate input file exists
        if !input_path.exists() {
            return Err(QuantizeError::FileNotFound(
                input_path.display().to_string(),
            ));
        }

        // Convert paths to C strings
        let input_cstr = path_to_cstring(input_path)?;
        let output_cstr = path_to_cstring(output_path)?;

        // Set up progress callback if provided
        let callback_data = callback.map(|cb| Arc::new(Mutex::new(cb)));
        let callback_ptr = callback_data
            .as_ref()
            .map(|data| Arc::clone(data) as Arc<Mutex<dyn Fn(f32) + Send>>);

        // Create the FFI callback function
        let ffi_callback: ffi::whisper_quantize_progress_callback = if callback_ptr.is_some() {
            Some(quantize_progress_callback)
        } else {
            None
        };

        // Store callback data in thread-local storage for the callback to access
        if let Some(ptr) = callback_ptr {
            CALLBACK_DATA.with(|data| {
                *data.borrow_mut() = Some(ptr);
            });
        }

        // Perform quantization
        let result = unsafe {
            ffi::whisper_model_quantize(
                input_cstr.as_ptr(),
                output_cstr.as_ptr(),
                qtype as i32,
                ffi_callback,
            )
        };

        // Clear callback data
        CALLBACK_DATA.with(|data| {
            *data.borrow_mut() = None;
        });

        // Check result
        // Map FFI status codes onto the specific error variants defined above
        match result {
            ffi::WHISPER_QUANTIZE_OK => Ok(()),
            ffi::WHISPER_QUANTIZE_ERROR_INVALID_MODEL => Err(QuantizeError::InvalidModel),
            ffi::WHISPER_QUANTIZE_ERROR_FILE_OPEN => Err(QuantizeError::FileOpenError(
                input_path.display().to_string(),
            )),
            ffi::WHISPER_QUANTIZE_ERROR_FILE_WRITE => Err(QuantizeError::FileWriteError(
                output_path.display().to_string(),
            )),
            ffi::WHISPER_QUANTIZE_ERROR_INVALID_FTYPE => {
                Err(QuantizeError::InvalidQuantizationType)
            }
            ffi::WHISPER_QUANTIZE_ERROR_QUANTIZATION_FAILED => Err(
                QuantizeError::QuantizationFailed(format!("quantizing to {} failed", qtype)),
            ),
            _ => Err(QuantizeError::QuantizationFailed(format!(
                "unknown quantization error code: {}",
                result
            ))),
        }
    }

    /// Get the quantization type of an existing model file
    ///
    /// # Returns
    /// * `Ok(Some(qtype))` - The quantization type if the model is quantized
    /// * `Ok(None)` - If the model is in full precision (F32 or F16)
    /// * `Err(_)` - If the file cannot be read or is not a valid model
    ///
    /// # Example
    /// ```no_run
    /// use whisper_cpp_plus::WhisperQuantize;
    ///
    /// match WhisperQuantize::get_model_quantization_type("models/ggml-base-q5_0.bin") {
    ///     Ok(Some(qtype)) => println!("Model is quantized as: {}", qtype),
    ///     Ok(None) => println!("Model is not quantized"),
    ///     Err(e) => println!("Error reading model: {}", e),
    /// }
    /// ```
    pub fn get_model_quantization_type<P: AsRef<Path>>(
        model_path: P,
    ) -> Result<Option<QuantizationType>> {
        let path = model_path.as_ref();

        if !path.exists() {
            return Err(QuantizeError::FileNotFound(path.display().to_string()));
        }

        let path_cstr = path_to_cstring(path)?;

        let ftype = unsafe { ffi::whisper_model_get_ftype(path_cstr.as_ptr()) };

        if ftype < 0 {
            return Err(QuantizeError::FileOpenError(path.display().to_string()));
        }

        // Map the ftype to our QuantizationType enum
        let qtype = match ftype {
            x if x == ffi::GGML_FTYPE_ALL_F32 => None,
            x if x == ffi::GGML_FTYPE_MOSTLY_F16 => None,
            x if x == QuantizationType::Q4_0 as i32 => Some(QuantizationType::Q4_0),
            x if x == QuantizationType::Q4_1 as i32 => Some(QuantizationType::Q4_1),
            x if x == QuantizationType::Q5_0 as i32 => Some(QuantizationType::Q5_0),
            x if x == QuantizationType::Q5_1 as i32 => Some(QuantizationType::Q5_1),
            x if x == QuantizationType::Q8_0 as i32 => Some(QuantizationType::Q8_0),
            x if x == QuantizationType::Q2_K as i32 => Some(QuantizationType::Q2_K),
            x if x == QuantizationType::Q3_K as i32 => Some(QuantizationType::Q3_K),
            x if x == QuantizationType::Q4_K as i32 => Some(QuantizationType::Q4_K),
            x if x == QuantizationType::Q5_K as i32 => Some(QuantizationType::Q5_K),
            x if x == QuantizationType::Q6_K as i32 => Some(QuantizationType::Q6_K),
            _ => None,
        };

        Ok(qtype)
    }

    /// Estimate the size of a quantized model given the original model path
    /// and the target quantization type
    ///
    /// # Returns
    /// Estimated size in bytes of the quantized model
    ///
    /// # Example
    /// ```no_run
    /// use whisper_cpp_plus::{WhisperQuantize, QuantizationType};
    ///
    /// let estimated_size = WhisperQuantize::estimate_quantized_size(
    ///     "models/ggml-base.bin",
    ///     QuantizationType::Q5_0
    /// ).unwrap_or(0);
    ///
    /// println!("Estimated after Q5_0: {} MB", estimated_size / 1024 / 1024);
    /// ```
    pub fn estimate_quantized_size<P: AsRef<Path>>(
        model_path: P,
        qtype: QuantizationType,
    ) -> Result<u64> {
        let path = model_path.as_ref();
        let metadata = std::fs::metadata(path).map_err(|e| {
            QuantizeError::QuantizationFailed(format!("Failed to read model file: {}", e))
        })?;

        let original_size = metadata.len();
        let estimated_size = (original_size as f64 * qtype.size_factor() as f64) as u64;

        Ok(estimated_size)
    }
}

// Thread-local storage for callback data
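// NOTE: this design assumes whisper_model_quantize invokes the progress
// callback synchronously on the calling thread; a callback fired from a
// different thread would find this thread-local slot empty.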
thread_local! {
    static CALLBACK_DATA: std::cell::RefCell<Option<Arc<Mutex<dyn Fn(f32) + Send>>>> =
        std::cell::RefCell::new(None);
}

// FFI callback function that forwards to the Rust callback
extern "C" fn quantize_progress_callback(progress: f32) {
    CALLBACK_DATA.with(|data| {
        if let Some(callback) = data.borrow().as_ref() {
            if let Ok(cb) = callback.lock() {
                cb(progress);
            }
        }
    });
}

// Helper function to convert Path to CString
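// Non-UTF-8 paths (possible on Windows and on Unix, where paths are arbitrary
// bytes) are rejected with an error rather than converted lossily, since the
// C API expects a NUL-terminated byte string.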
fn path_to_cstring(path: &Path) -> Result<CString> {
    let path_str = path
        .to_str()
        .ok_or_else(|| QuantizeError::QuantizationFailed("Invalid UTF-8 in path".to_string()))?;

    CString::new(path_str)
        .map_err(|_| QuantizeError::QuantizationFailed("Path contains null byte".to_string()))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantization_type_names() {
        assert_eq!(QuantizationType::Q4_0.name(), "Q4_0");
        assert_eq!(QuantizationType::Q5_1.name(), "Q5_1");
        assert_eq!(QuantizationType::Q8_0.name(), "Q8_0");
        assert_eq!(QuantizationType::Q3_K.name(), "Q3_K");
    }

    #[test]
    fn test_quantization_type_from_str() {
        assert_eq!("q4_0".parse::<QuantizationType>().unwrap(), QuantizationType::Q4_0);
        assert_eq!("Q5_1".parse::<QuantizationType>().unwrap(), QuantizationType::Q5_1);
        assert_eq!("q8_0".parse::<QuantizationType>().unwrap(), QuantizationType::Q8_0);
        assert_eq!("Q3K".parse::<QuantizationType>().unwrap(), QuantizationType::Q3_K);
        assert!("invalid".parse::<QuantizationType>().is_err());
    }

    #[test]
    fn test_size_factors() {
        for qtype in QuantizationType::all() {
            let factor = qtype.size_factor();
            assert!(
                factor > 0.0 && factor < 1.0,
                "{} has invalid size factor: {}",
                qtype,
                factor
            );
        }
    }
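
    // Extra check added as a sketch: Display should match name() for every
    // variant, since the Display impl delegates to name().
    #[test]
    fn test_display_matches_name() {
        for qtype in QuantizationType::all() {
            assert_eq!(qtype.to_string(), qtype.name());
        }
    }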
}